@@ -223,20 +223,15 @@ def sliding_window_inference(
223223 for idx in slice_range
224224 ]
225225 if sw_batch_size > 1 :
226- win_data = torch .cat (
227- [inputs [tuple (win_slice ) if isinstance (win_slice , list ) else win_slice ] for win_slice in unravel_slice ]
228- ).to (sw_device )
226+ win_data = torch .cat ([inputs [ensure_tuple (win_slice )] for win_slice in unravel_slice ]).to (sw_device )
229227 if condition is not None :
230- win_condition = torch .cat (
231- [
232- condition [tuple (win_slice ) if isinstance (win_slice , list ) else win_slice ]
233- for win_slice in unravel_slice
234- ]
235- ).to (sw_device )
228+ win_condition = torch .cat ([condition [ensure_tuple (win_slice )] for win_slice in unravel_slice ]).to (
229+ sw_device
230+ )
236231 kwargs ["condition" ] = win_condition
237232 else :
238233 s0 = unravel_slice [0 ]
239- s0_idx = tuple (s0 ) if isinstance ( s0 , list ) else s0
234+ s0_idx = ensure_tuple (s0 )
240235
241236 win_data = inputs [s0_idx ].to (sw_device )
242237 if condition is not None :
@@ -267,7 +262,7 @@ def sliding_window_inference(
267262 offset = s [buffer_dim + 2 ].start - c_start
268263 s [buffer_dim + 2 ] = slice (offset , offset + roi_size [buffer_dim ])
269264 s [0 ] = slice (0 , 1 )
270- sw_device_buffer [0 ][tuple (s ) if isinstance ( s , list ) else s ] += p * w_t
265+ sw_device_buffer [0 ][ensure_tuple (s )] += p * w_t
271266 b_i += len (unravel_slice )
272267 if b_i < b_slices [b_s ][0 ]:
273268 continue
@@ -298,7 +293,7 @@ def sliding_window_inference(
298293 o_slice [buffer_dim + 2 ] = slice (c_start , c_end )
299294 img_b = b_s // n_per_batch # image batch index
300295 o_slice [0 ] = slice (img_b , img_b + 1 )
301- o_slice_idx = tuple (o_slice ) if isinstance ( o_slice , list ) else o_slice
296+ o_slice_idx = ensure_tuple (o_slice )
302297 if non_blocking :
303298 output_image_list [0 ][o_slice_idx ].copy_ (sw_device_buffer [0 ], non_blocking = non_blocking )
304299 else :
@@ -378,7 +373,7 @@ def _compute_coords(coords, z_scale, out, patch):
378373 idx_zm [axis ] = slice (
379374 int (original_idx [axis ].start * z_scale [axis - 2 ]), int (original_idx [axis ].stop * z_scale [axis - 2 ])
380375 )
381- out [tuple (idx_zm )] += p
376+ out [ensure_tuple (idx_zm )] += p
382377
383378
384379def _get_scan_interval (
0 commit comments