Skip to content

Commit 20f7d7f

Browse files
authored
Merge branch 'dev' into cicd_test_update
2 parents 4af7d58 + 1dc47a5 commit 20f7d7f

22 files changed

+429
-57
lines changed

MANIFEST.in

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,5 @@ include monai/_version.py
33

44
include README.md
55
include LICENSE
6+
7+
prune tests

monai/losses/image_dissimilarity.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from torch.nn import functional as F
1616
from torch.nn.modules.loss import _Loss
1717

18-
from monai.networks.layers import gaussian_1d, separable_filtering
18+
from monai.networks.layers import separable_filtering
1919
from monai.utils import LossReduction
2020
from monai.utils.module import look_up_option
2121

@@ -34,11 +34,11 @@ def make_triangular_kernel(kernel_size: int) -> torch.Tensor:
3434

3535

3636
def make_gaussian_kernel(kernel_size: int) -> torch.Tensor:
37-
sigma = torch.tensor(kernel_size / 3.0)
38-
kernel = gaussian_1d(sigma=sigma, truncated=kernel_size // 2, approx="sampled", normalize=False) * (
39-
2.5066282 * sigma
40-
)
41-
return kernel[:kernel_size]
37+
sigma = kernel_size / 3.0
38+
half = kernel_size // 2
39+
x = torch.arange(-half, half + 1, dtype=torch.float)
40+
kernel = torch.exp(-0.5 / (sigma * sigma) * x**2)
41+
return kernel
4242

4343

4444
kernel_dict = {

monai/losses/spectral_loss.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,8 @@ def __init__(
5555
self.fft_norm = fft_norm
5656

5757
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
58-
input_amplitude = self._get_fft_amplitude(target)
59-
target_amplitude = self._get_fft_amplitude(input)
58+
input_amplitude = self._get_fft_amplitude(input)
59+
target_amplitude = self._get_fft_amplitude(target)
6060

6161
# Compute distance between amplitude of frequency components
6262
# See Section 3.3 from https://arxiv.org/abs/2005.00341

monai/losses/ssim_loss.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,17 +111,17 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
111111
# 2D data
112112
x = torch.ones([1,1,10,10])/2
113113
y = torch.ones([1,1,10,10])/2
114-
print(1-SSIMLoss(spatial_dims=2)(x,y))
114+
print(SSIMLoss(spatial_dims=2)(x,y))
115115
116116
# pseudo-3D data
117117
x = torch.ones([1,5,10,10])/2 # 5 could represent number of slices
118118
y = torch.ones([1,5,10,10])/2
119-
print(1-SSIMLoss(spatial_dims=2)(x,y))
119+
print(SSIMLoss(spatial_dims=2)(x,y))
120120
121121
# 3D data
122122
x = torch.ones([1,1,10,10,10])/2
123123
y = torch.ones([1,1,10,10,10])/2
124-
print(1-SSIMLoss(spatial_dims=3)(x,y))
124+
print(SSIMLoss(spatial_dims=3)(x,y))
125125
"""
126126
ssim_value = self.ssim_metric._compute_tensor(input, target).view(-1, 1)
127127
loss: torch.Tensor = 1 - ssim_value

monai/networks/layers/filtering.py

Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,8 @@ def __init__(self, spatial_sigma, color_sigma):
221221
self.len_spatial_sigma = 3
222222
else:
223223
raise ValueError(
224-
f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims {self.ken_spatial_sigma}."
224+
f"Length of `spatial_sigma` must match number of spatial dims (1, 2 or 3) "
225+
f"or be a single float value ({spatial_sigma=})."
225226
)
226227

227228
# Register sigmas as trainable parameters.
@@ -231,6 +232,10 @@ def __init__(self, spatial_sigma, color_sigma):
231232
self.sigma_color = torch.nn.Parameter(torch.tensor(color_sigma))
232233

233234
def forward(self, input_tensor):
235+
if len(input_tensor.shape) < 3:
236+
raise ValueError(
237+
f"Input must have at least 3 dimensions (batch, channel, *spatial_dims), got {len(input_tensor.shape)}"
238+
)
234239
if input_tensor.shape[1] != 1:
235240
raise ValueError(
236241
f"Currently channel dimensions >1 ({input_tensor.shape[1]}) are not supported. "
@@ -239,24 +244,27 @@ def forward(self, input_tensor):
239244
)
240245

241246
len_input = len(input_tensor.shape)
247+
spatial_dims = len_input - 2
242248

243249
# C++ extension so far only supports 5-dim inputs.
244-
if len_input == 3:
250+
if spatial_dims == 1:
245251
input_tensor = input_tensor.unsqueeze(3).unsqueeze(4)
246-
elif len_input == 4:
252+
elif spatial_dims == 2:
247253
input_tensor = input_tensor.unsqueeze(4)
248254

249-
if self.len_spatial_sigma != len_input:
250-
raise ValueError(f"Spatial dimension ({len_input}) must match initialized len(spatial_sigma).")
255+
if self.len_spatial_sigma != spatial_dims:
256+
raise ValueError(
257+
f"Number of spatial dimensions ({spatial_dims}) must match initialized `len(spatial_sigma)`."
258+
)
251259

252260
prediction = TrainableBilateralFilterFunction.apply(
253261
input_tensor, self.sigma_x, self.sigma_y, self.sigma_z, self.sigma_color
254262
)
255263

256264
# Make sure to return tensor of the same shape as the input.
257-
if len_input == 3:
265+
if spatial_dims == 1:
258266
prediction = prediction.squeeze(4).squeeze(3)
259-
elif len_input == 4:
267+
elif spatial_dims == 2:
260268
prediction = prediction.squeeze(4)
261269

262270
return prediction
@@ -389,7 +397,8 @@ def __init__(self, spatial_sigma, color_sigma):
389397
self.len_spatial_sigma = 3
390398
else:
391399
raise ValueError(
392-
f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims {self.ken_spatial_sigma}."
400+
f"Length of `spatial_sigma` must match number of spatial dims (1, 2 or 3)\n"
401+
f"or be a single float value ({spatial_sigma=})."
393402
)
394403

395404
# Register sigmas as trainable parameters.
@@ -399,39 +408,45 @@ def __init__(self, spatial_sigma, color_sigma):
399408
self.sigma_color = torch.nn.Parameter(torch.tensor(color_sigma))
400409

401410
def forward(self, input_tensor, guidance_tensor):
411+
if len(input_tensor.shape) < 3:
412+
raise ValueError(
413+
f"Input must have at least 3 dimensions (batch, channel, *spatial_dims), got {len(input_tensor.shape)}"
414+
)
402415
if input_tensor.shape[1] != 1:
403416
raise ValueError(
404-
f"Currently channel dimensions >1 ({input_tensor.shape[1]}) are not supported. "
417+
f"Currently channel dimensions > 1 ({input_tensor.shape[1]}) are not supported. "
405418
"Please use multiple parallel filter layers if you want "
406419
"to filter multiple channels."
407420
)
408421
if input_tensor.shape != guidance_tensor.shape:
409422
raise ValueError(
410-
"Shape of input image must equal shape of guidance image."
411-
f"Got {input_tensor.shape} and {guidance_tensor.shape}."
423+
f"Shape of input image must equal shape of guidance image, got {input_tensor.shape} and {guidance_tensor.shape}."
412424
)
413425

414426
len_input = len(input_tensor.shape)
427+
spatial_dims = len_input - 2
415428

416429
# C++ extension so far only supports 5-dim inputs.
417-
if len_input == 3:
430+
if spatial_dims == 1:
418431
input_tensor = input_tensor.unsqueeze(3).unsqueeze(4)
419432
guidance_tensor = guidance_tensor.unsqueeze(3).unsqueeze(4)
420-
elif len_input == 4:
433+
elif spatial_dims == 2:
421434
input_tensor = input_tensor.unsqueeze(4)
422435
guidance_tensor = guidance_tensor.unsqueeze(4)
423436

424-
if self.len_spatial_sigma != len_input:
425-
raise ValueError(f"Spatial dimension ({len_input}) must match initialized len(spatial_sigma).")
437+
if self.len_spatial_sigma != spatial_dims:
438+
raise ValueError(
439+
f"Number of spatial dimensions ({spatial_dims}) must match initialized `len(spatial_sigma)`."
440+
)
426441

427442
prediction = TrainableJointBilateralFilterFunction.apply(
428443
input_tensor, guidance_tensor, self.sigma_x, self.sigma_y, self.sigma_z, self.sigma_color
429444
)
430445

431446
# Make sure to return tensor of the same shape as the input.
432-
if len_input == 3:
447+
if spatial_dims == 1:
433448
prediction = prediction.squeeze(4).squeeze(3)
434-
elif len_input == 4:
449+
elif spatial_dims == 2:
435450
prediction = prediction.squeeze(4)
436451

437452
return prediction

monai/networks/nets/autoencoderkl.py

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -680,6 +680,7 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None:
680680
681681
Args:
682682
old_state_dict: state dict from the old AutoencoderKL model.
683+
verbose: if True, print diagnostic information about key mismatches.
683684
"""
684685

685686
new_state_dict = self.state_dict()
@@ -715,13 +716,39 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None:
715716
new_state_dict[f"{block}.attn.to_k.bias"] = old_state_dict.pop(f"{block}.to_k.bias")
716717
new_state_dict[f"{block}.attn.to_v.bias"] = old_state_dict.pop(f"{block}.to_v.bias")
717718

718-
# old version did not have a projection so set these to the identity
719-
new_state_dict[f"{block}.attn.out_proj.weight"] = torch.eye(
720-
new_state_dict[f"{block}.attn.out_proj.weight"].shape[0]
721-
)
722-
new_state_dict[f"{block}.attn.out_proj.bias"] = torch.zeros(
723-
new_state_dict[f"{block}.attn.out_proj.bias"].shape
724-
)
719+
out_w = f"{block}.attn.out_proj.weight"
720+
out_b = f"{block}.attn.out_proj.bias"
721+
proj_w = f"{block}.proj_attn.weight"
722+
proj_b = f"{block}.proj_attn.bias"
723+
724+
if out_w in new_state_dict:
725+
if proj_w in old_state_dict:
726+
new_state_dict[out_w] = old_state_dict.pop(proj_w)
727+
if proj_b in old_state_dict:
728+
new_state_dict[out_b] = old_state_dict.pop(proj_b)
729+
else:
730+
new_state_dict[out_b] = torch.zeros(
731+
new_state_dict[out_b].shape,
732+
dtype=new_state_dict[out_b].dtype,
733+
device=new_state_dict[out_b].device,
734+
)
735+
else:
736+
# No legacy proj_attn - initialize out_proj to identity/zero
737+
new_state_dict[out_w] = torch.eye(
738+
new_state_dict[out_w].shape[0],
739+
dtype=new_state_dict[out_w].dtype,
740+
device=new_state_dict[out_w].device,
741+
)
742+
new_state_dict[out_b] = torch.zeros(
743+
new_state_dict[out_b].shape,
744+
dtype=new_state_dict[out_b].dtype,
745+
device=new_state_dict[out_b].device,
746+
)
747+
elif proj_w in old_state_dict:
748+
# new model has no out_proj at all - discard the legacy keys so they
749+
# don't surface as "unexpected keys" during load_state_dict
750+
old_state_dict.pop(proj_w)
751+
old_state_dict.pop(proj_b, None)
725752

726753
# fix the upsample conv blocks which were renamed postconv
727754
for k in new_state_dict:

monai/networks/nets/segresnet_ds.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,20 @@ class SegResNetDS(nn.Module):
254254
image spacing into an approximately isotropic space.
255255
Otherwise, by default, the kernel size and downsampling is always isotropic.
256256
257+
**Spatial shape constraints**: If ``resolution`` is ``None`` (isotropic mode),
258+
each spatial dimension must be divisible by ``2 ** (len(blocks_down) - 1)``.
259+
With the default ``blocks_down=(1, 2, 2, 4)``, each dimension must be
260+
divisible by 8. If ``resolution`` is provided (anisotropic mode),
261+
divisibility can differ per dimension; use :py:meth:`shape_factor` for
262+
the exact required factors and :py:meth:`is_valid_shape` to verify a shape.
263+
264+
Example::
265+
266+
model = SegResNetDS(spatial_dims=3, blocks_down=(1, 2, 2, 4))
267+
print(model.shape_factor()) # [8, 8, 8]
268+
print(model.is_valid_shape((1, 1, 128, 128, 128))) # True
269+
print(model.is_valid_shape((1, 1, 100, 100, 100))) # False
270+
257271
"""
258272

259273
def __init__(

monai/networks/nets/swin_unetr.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,19 @@ class SwinUNETR(nn.Module):
4747
Swin UNETR based on: "Hatamizadeh et al.,
4848
Swin UNETR: Swin Transformers for Semantic Segmentation of Brain Tumors in MRI Images
4949
<https://arxiv.org/abs/2201.01266>"
50+
51+
Spatial Shape Constraints:
52+
Each spatial dimension of the input must be divisible by ``patch_size ** 5``.
53+
With the default ``patch_size=2``, this means each spatial dimension must be divisible by **32**
54+
(i.e., 2^5 = 32). This requirement comes from the patch embedding step followed by 4 stages
55+
of PatchMerging downsampling, each halving the spatial resolution.
56+
57+
For a custom ``patch_size``, the divisibility requirement is ``patch_size ** 5``.
58+
59+
Examples of valid 3D input sizes (with default ``patch_size=2``):
60+
``(32, 32, 32)``, ``(64, 64, 64)``, ``(96, 96, 96)``, ``(128, 128, 128)``, ``(64, 32, 192)``.
61+
62+
A ``ValueError`` is raised in ``forward()`` if the input spatial shape violates this constraint.
5063
"""
5164

5265
def __init__(
@@ -76,7 +89,8 @@ def __init__(
7689
Args:
7790
in_channels: dimension of input channels.
7891
out_channels: dimension of output channels.
79-
patch_size: size of the patch token.
92+
patch_size: size of the patch token. Input spatial dimensions must be divisible by
93+
``patch_size ** 5`` (e.g., divisible by 32 when ``patch_size=2``).
8094
feature_size: dimension of network feature size.
8195
depths: number of layers in each stage.
8296
num_heads: number of attention heads.
@@ -108,6 +122,10 @@ def __init__(
108122
# for 2D single channel input with size (96,96), 2-channel output and gradient checkpointing.
109123
>>> net = SwinUNETR(in_channels=3, out_channels=2, use_checkpoint=True, spatial_dims=2)
110124
125+
Raises:
126+
ValueError: When a spatial dimension of the input is not divisible by ``patch_size ** 5``.
127+
Use ``net._check_input_size(spatial_shape)`` to validate a shape before inference.
128+
111129
"""
112130

113131
super().__init__()

monai/transforms/signal/array.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ def __call__(self, signal: NdarrayOrTensor) -> NdarrayOrTensor:
273273
data = convert_to_tensor(self.freqs * time_partial)
274274
sine_partial = self.magnitude * torch.sin(data)
275275

276-
loc = np.random.choice(range(length))
276+
loc = self.R.choice(range(length))
277277
signal = paste(signal, sine_partial, (loc,))
278278

279279
return signal
@@ -354,7 +354,7 @@ def __call__(self, signal: NdarrayOrTensor) -> NdarrayOrTensor:
354354
time_partial = np.arange(0, round(self.fracs * length), 1)
355355
squaredpulse_partial = self.magnitude * squarepulse(self.freqs * time_partial)
356356

357-
loc = np.random.choice(range(length))
357+
loc = self.R.choice(range(length))
358358
signal = paste(signal, squaredpulse_partial, (loc,))
359359

360360
return signal

monai/transforms/utility/array.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1049,19 +1049,34 @@ class ConvertToMultiChannelBasedOnBratsClasses(Transform):
10491049
which include TC (Tumor core), WT (Whole tumor) and ET (Enhancing tumor):
10501050
label 1 is the necrotic and non-enhancing tumor core, which should be counted under TC and WT subregion,
10511051
label 2 is the peritumoral edema, which is counted only under WT subregion,
1052-
label 4 is the GD-enhancing tumor, which should be counted under ET, TC, WT subregions.
1052+
the specified `et_label` (default 4) is the GD-enhancing tumor, which should be counted under ET, TC, WT subregions.
1053+
1054+
Args:
1055+
et_label: the label used for the GD-enhancing tumor (ET).
1056+
- Use 4 for BraTS 2018-2022.
1057+
- Use 3 for BraTS 2023.
1058+
Defaults to 4.
10531059
"""
10541060

10551061
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
10561062

1063+
def __init__(self, et_label: int = 4) -> None:
1064+
if et_label in (1, 2):
1065+
raise ValueError(f"et_label cannot be 1 or 2, as these are reserved. Got {et_label}.")
1066+
self.et_label = et_label
1067+
10571068
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
10581069
# if img has channel dim, squeeze it
10591070
if img.ndim == 4 and img.shape[0] == 1:
10601071
img = img.squeeze(0)
10611072

1062-
result = [(img == 1) | (img == 4), (img == 1) | (img == 4) | (img == 2), img == 4]
1063-
# merge labels 1 (tumor non-enh) and 4 (tumor enh) and 2 (large edema) to WT
1064-
# label 4 is ET
1073+
result = [
1074+
(img == 1) | (img == self.et_label),
1075+
(img == 1) | (img == self.et_label) | (img == 2),
1076+
img == self.et_label,
1077+
]
1078+
# merge labels 1 (tumor non-enh) and self.et_label (tumor enh) and 2 (large edema) to WT
1079+
# self.et_label is ET (4 or 3)
10651080
return torch.stack(result, dim=0) if isinstance(img, torch.Tensor) else np.stack(result, axis=0)
10661081

10671082

0 commit comments

Comments
 (0)