Skip to content

Commit 07990cc

Browse files
aymuos15ericspod
andauthored
Fix make_gaussian_kernel truncated parameter unit mismatch (#8780) (#8781)
Fixes: #8780 ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
1 parent 24f4924 commit 07990cc

File tree

2 files changed

+35
-7
lines changed

2 files changed

+35
-7
lines changed

monai/losses/image_dissimilarity.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from torch.nn import functional as F
1616
from torch.nn.modules.loss import _Loss
1717

18-
from monai.networks.layers import gaussian_1d, separable_filtering
18+
from monai.networks.layers import separable_filtering
1919
from monai.utils import LossReduction
2020
from monai.utils.module import look_up_option
2121

@@ -34,11 +34,11 @@ def make_triangular_kernel(kernel_size: int) -> torch.Tensor:
3434

3535

3636
def make_gaussian_kernel(kernel_size: int) -> torch.Tensor:
37-
sigma = torch.tensor(kernel_size / 3.0)
38-
kernel = gaussian_1d(sigma=sigma, truncated=kernel_size // 2, approx="sampled", normalize=False) * (
39-
2.5066282 * sigma
40-
)
41-
return kernel[:kernel_size]
37+
sigma = kernel_size / 3.0
38+
half = kernel_size // 2
39+
x = torch.arange(-half, half + 1, dtype=torch.float)
40+
kernel = torch.exp(-0.5 / (sigma * sigma) * x**2)
41+
return kernel
4242

4343

4444
kernel_dict = {

tests/losses/image_dissimilarity/test_local_normalized_cross_correlation_loss.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import torch
1818
from parameterized import parameterized
1919

20-
from monai.losses.image_dissimilarity import LocalNormalizedCrossCorrelationLoss
20+
from monai.losses.image_dissimilarity import LocalNormalizedCrossCorrelationLoss, make_gaussian_kernel
2121

2222
device = "cuda" if torch.cuda.is_available() else "cpu"
2323

@@ -113,6 +113,25 @@
113113
},
114114
-0.95406944,
115115
],
116+
# Regression tests for gh-8780: gaussian kernel_size > 3 was broken due to
117+
# truncated parameter being passed as pixel radius instead of sigma multiplier.
118+
# Identical images must yield loss == -1.0 for any kernel size.
119+
[
120+
{"spatial_dims": 1, "kernel_type": "gaussian", "kernel_size": 5},
121+
{
122+
"pred": torch.arange(0, 5).reshape(1, 1, -1).to(dtype=torch.float, device=device),
123+
"target": torch.arange(0, 5).reshape(1, 1, -1).to(dtype=torch.float, device=device),
124+
},
125+
-1.0,
126+
],
127+
[
128+
{"spatial_dims": 1, "kernel_type": "gaussian", "kernel_size": 9},
129+
{
130+
"pred": torch.arange(0, 9).reshape(1, 1, -1).to(dtype=torch.float, device=device),
131+
"target": torch.arange(0, 9).reshape(1, 1, -1).to(dtype=torch.float, device=device),
132+
},
133+
-1.0,
134+
],
116135
]
117136

118137

@@ -138,6 +157,15 @@ def test_ill_shape(self):
138157
torch.ones((1, 3, 4, 4, 4), dtype=torch.float, device=device),
139158
)
140159

160+
def test_gaussian_kernel_shape_and_symmetry(self):
161+
# gh-8780: kernel must have correct length, be symmetric, and peak at center
162+
for kernel_size in [3, 5, 7, 9, 11, 15]:
163+
k = make_gaussian_kernel(kernel_size)
164+
self.assertEqual(len(k), kernel_size)
165+
self.assertTrue(torch.allclose(k, k.flip(0)), f"kernel_size={kernel_size} not symmetric")
166+
self.assertEqual(k.argmax().item(), kernel_size // 2)
167+
np.testing.assert_allclose(k.max().item(), 1.0, rtol=1e-6)
168+
141169
def test_ill_opts(self):
142170
pred = torch.ones((1, 3, 3, 3, 3), dtype=torch.float)
143171
target = torch.ones((1, 3, 3, 3, 3), dtype=torch.float)

0 commit comments

Comments
 (0)