Skip to content

Commit 939a5bb

Browse files
committed
Use monai.utils.ensure_tuple instead of explicit casting.
Signed-off-by: Aaron Ponti <aaron@aaronponti.ch>
1 parent 57ffb77 commit 939a5bb

File tree

1 file changed

+8
-13
lines changed

1 file changed

+8
-13
lines changed

monai/inferers/utils.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -223,20 +223,15 @@ def sliding_window_inference(
223223
for idx in slice_range
224224
]
225225
if sw_batch_size > 1:
226-
win_data = torch.cat(
227-
[inputs[tuple(win_slice) if isinstance(win_slice, list) else win_slice] for win_slice in unravel_slice]
228-
).to(sw_device)
226+
win_data = torch.cat([inputs[ensure_tuple(win_slice)] for win_slice in unravel_slice]).to(sw_device)
229227
if condition is not None:
230-
win_condition = torch.cat(
231-
[
232-
condition[tuple(win_slice) if isinstance(win_slice, list) else win_slice]
233-
for win_slice in unravel_slice
234-
]
235-
).to(sw_device)
228+
win_condition = torch.cat([condition[ensure_tuple(win_slice)] for win_slice in unravel_slice]).to(
229+
sw_device
230+
)
236231
kwargs["condition"] = win_condition
237232
else:
238233
s0 = unravel_slice[0]
239-
s0_idx = tuple(s0) if isinstance(s0, list) else s0
234+
s0_idx = ensure_tuple(s0)
240235

241236
win_data = inputs[s0_idx].to(sw_device)
242237
if condition is not None:
@@ -267,7 +262,7 @@ def sliding_window_inference(
267262
offset = s[buffer_dim + 2].start - c_start
268263
s[buffer_dim + 2] = slice(offset, offset + roi_size[buffer_dim])
269264
s[0] = slice(0, 1)
270-
sw_device_buffer[0][tuple(s) if isinstance(s, list) else s] += p * w_t
265+
sw_device_buffer[0][ensure_tuple(s)] += p * w_t
271266
b_i += len(unravel_slice)
272267
if b_i < b_slices[b_s][0]:
273268
continue
@@ -298,7 +293,7 @@ def sliding_window_inference(
298293
o_slice[buffer_dim + 2] = slice(c_start, c_end)
299294
img_b = b_s // n_per_batch # image batch index
300295
o_slice[0] = slice(img_b, img_b + 1)
301-
o_slice_idx = tuple(o_slice) if isinstance(o_slice, list) else o_slice
296+
o_slice_idx = ensure_tuple(o_slice)
302297
if non_blocking:
303298
output_image_list[0][o_slice_idx].copy_(sw_device_buffer[0], non_blocking=non_blocking)
304299
else:
@@ -378,7 +373,7 @@ def _compute_coords(coords, z_scale, out, patch):
378373
idx_zm[axis] = slice(
379374
int(original_idx[axis].start * z_scale[axis - 2]), int(original_idx[axis].stop * z_scale[axis - 2])
380375
)
381-
out[tuple(idx_zm)] += p
376+
out[ensure_tuple(idx_zm)] += p
382377

383378

384379
def _get_scan_interval(

0 commit comments

Comments
 (0)