Skip to content

Commit 58f0a34

Browse files
authored
Merge branch 'dev' into docs-segresnet-ds-shape-constraints
2 parents 93c9437 + a537499 commit 58f0a34

36 files changed

Lines changed: 1048 additions & 82 deletions

.github/workflows/pythonapp.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ jobs:
8282
find /opt/hostedtoolcache/* -maxdepth 0 ! -name 'Python' -exec rm -rf {} \;
8383
- name: Install the dependencies
8484
run: |
85-
python -m pip install --user --upgrade pip wheel
85+
python -m pip install --user --upgrade pip wheel pybind11
8686
python -m pip install torch==2.5.1 torchvision==0.20.1
8787
cat "requirements-dev.txt"
8888
python -m pip install --no-build-isolation -r requirements-dev.txt

MANIFEST.in

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,5 @@ include monai/_version.py
33

44
include README.md
55
include LICENSE
6+
7+
prune tests

docs/source/losses.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,11 @@ Segmentation Losses
9898
.. autoclass:: NACLLoss
9999
:members:
100100

101+
`MCCLoss`
102+
~~~~~~~~~
103+
.. autoclass:: MCCLoss
104+
:members:
105+
101106
Registration Losses
102107
-------------------
103108

monai/apps/auto3dseg/auto_runner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ def __init__(
229229
input = os.path.join(os.path.abspath(work_dir), "input.yaml")
230230
logger.info(f"Input config is not provided, using the default {input}")
231231

232-
self.data_src_cfg = dict()
232+
self.data_src_cfg = {}
233233
if isinstance(input, dict):
234234
self.data_src_cfg = input
235235
elif isinstance(input, str) and os.path.isfile(input):

monai/apps/reconstruction/transforms/dictionary.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from monai.apps.reconstruction.transforms.array import EquispacedKspaceMask, RandomKspaceMask
2121
from monai.config import DtypeLike, KeysCollection
2222
from monai.config.type_definitions import NdarrayOrTensor
23+
from monai.data.meta_tensor import MetaTensor
2324
from monai.transforms import InvertibleTransform
2425
from monai.transforms.croppad.array import SpatialCrop
2526
from monai.transforms.intensity.array import NormalizeIntensity
@@ -33,15 +34,36 @@ class ExtractDataKeyFromMetaKeyd(MapTransform):
3334
Moves keys from meta to data. It is useful when a dataset of paired samples
3435
is loaded and certain keys should be moved from meta to data.
3536
37+
This transform supports two modes:
38+
39+
1. When ``meta_key`` references a metadata dictionary in the data (e.g., when
40+
``image_only=False`` was used with ``LoadImaged``), the requested keys are
41+
extracted directly from that dictionary.
42+
43+
2. When ``meta_key`` references a ``MetaTensor`` in the data (e.g., when
44+
``image_only=True`` was used with ``LoadImaged``), the requested keys are
45+
extracted from its ``.meta`` attribute.
46+
3647
Args:
3748
keys: keys to be transferred from meta to data
38-
meta_key: the meta key where all the meta-data is stored
49+
meta_key: the key in the data dictionary where the metadata source is
50+
stored. This can be either a metadata dictionary or a ``MetaTensor``.
3951
allow_missing_keys: don't raise exception if key is missing
4052
4153
Example:
4254
When the fastMRI dataset is loaded, "kspace" is stored in the data dictionary,
4355
but the ground-truth image with the key "reconstruction_rss" is stored in the meta data.
4456
In this case, ExtractDataKeyFromMetaKeyd moves "reconstruction_rss" to data.
57+
58+
When ``LoadImaged`` is used with ``image_only=True`` (the default), the loaded
59+
data is a ``MetaTensor`` with metadata accessible via ``.meta``. In this case,
60+
set ``meta_key`` to the key of the ``MetaTensor`` itself::
61+
62+
li = LoadImaged(keys="image") # image_only=True by default
63+
dat = li({"image": "image.nii"})
64+
e = ExtractDataKeyFromMetaKeyd("filename_or_obj", meta_key="image")
65+
dat = e(dat)
66+
assert dat["image"].meta["filename_or_obj"] == dat["filename_or_obj"]
4567
"""
4668

4769
def __init__(self, keys: KeysCollection, meta_key: str, allow_missing_keys: bool = False) -> None:
@@ -58,9 +80,18 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, T
5880
the new data dictionary
5981
"""
6082
d = dict(data)
83+
meta_obj = d[self.meta_key]
84+
85+
# If meta_key references a MetaTensor, extract from its .meta attribute;
86+
# otherwise treat it as a metadata dictionary directly.
87+
if isinstance(meta_obj, MetaTensor):
88+
meta_dict: dict = meta_obj.meta
89+
else:
90+
meta_dict = dict(meta_obj)
91+
6192
for key in self.keys:
62-
if key in d[self.meta_key]:
63-
d[key] = d[self.meta_key][key] # type: ignore
93+
if key in meta_dict:
94+
d[key] = meta_dict[key] # type: ignore
6495
elif not self.allow_missing_keys:
6596
raise KeyError(
6697
f"Key `{key}` of transform `{self.__class__.__name__}` was missing in the meta data"

monai/auto3dseg/analyzer.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from abc import ABC, abstractmethod
1616
from collections.abc import Hashable, Mapping
1717
from copy import deepcopy
18-
from typing import Any
18+
from typing import Any, cast
1919

2020
import numpy as np
2121
import torch
@@ -468,21 +468,35 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe
468468
"""
469469
d: dict[Hashable, MetaTensor] = dict(data)
470470
start = time.time()
471-
if isinstance(d[self.image_key], (torch.Tensor, MetaTensor)) and d[self.image_key].device.type == "cuda":
472-
using_cuda = True
473-
else:
474-
using_cuda = False
471+
image_tensor = d[self.image_key]
472+
label_tensor = d[self.label_key]
473+
# Check if either tensor is on CUDA to determine if we should move both to CUDA for processing
474+
using_cuda = any(
475+
isinstance(t, (torch.Tensor, MetaTensor)) and t.device.type == "cuda" for t in (image_tensor, label_tensor)
476+
)
475477
restore_grad_state = torch.is_grad_enabled()
476478
torch.set_grad_enabled(False)
477479

478-
ndas: list[MetaTensor] = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])] # type: ignore
479-
ndas_label: MetaTensor = d[self.label_key].astype(torch.int16) # (H,W,D)
480+
if isinstance(image_tensor, (MetaTensor, torch.Tensor)) and isinstance(
481+
label_tensor, (MetaTensor, torch.Tensor)
482+
):
483+
if label_tensor.device != image_tensor.device:
484+
if using_cuda:
485+
# Move both tensors to CUDA when mixing devices
486+
cuda_device = image_tensor.device if image_tensor.device.type == "cuda" else label_tensor.device
487+
image_tensor = cast(MetaTensor, image_tensor.to(cuda_device))
488+
label_tensor = cast(MetaTensor, label_tensor.to(cuda_device))
489+
else:
490+
label_tensor = cast(MetaTensor, label_tensor.to(image_tensor.device))
491+
492+
ndas: list[MetaTensor] = [image_tensor[i] for i in range(image_tensor.shape[0])] # type: ignore
493+
ndas_label: MetaTensor = label_tensor.astype(torch.int16) # (H,W,D)
480494

481495
if ndas_label.shape != ndas[0].shape:
482496
raise ValueError(f"Label shape {ndas_label.shape} is different from image shape {ndas[0].shape}")
483497

484498
nda_foregrounds: list[torch.Tensor] = [get_foreground_label(nda, ndas_label) for nda in ndas]
485-
nda_foregrounds = [nda if nda.numel() > 0 else torch.Tensor([0]) for nda in nda_foregrounds]
499+
nda_foregrounds = [nda if nda.numel() > 0 else MetaTensor([0.0]) for nda in nda_foregrounds]
486500

487501
unique_label = unique(ndas_label)
488502
if isinstance(ndas_label, (MetaTensor, torch.Tensor)):

monai/data/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -881,7 +881,7 @@ def compute_shape_offset(
881881
Default is False, using option 1 to compute the shape and offset.
882882
883883
"""
884-
shape = np.array(spatial_shape, copy=True, dtype=float)
884+
shape = np.array(tuple(spatial_shape), copy=True, dtype=float)
885885
sr = len(shape)
886886
in_affine_ = convert_data_type(to_affine_nd(sr, in_affine), np.ndarray)[0]
887887
out_affine_ = convert_data_type(to_affine_nd(sr, out_affine), np.ndarray)[0]

monai/engines/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -219,8 +219,8 @@ def __call__(
219219
`kwargs` supports other args for `Tensor.to()` API.
220220
"""
221221
image, label = default_prepare_batch(batchdata, device, non_blocking, **kwargs)
222-
args_ = list()
223-
kwargs_ = dict()
222+
args_ = []
223+
kwargs_ = {}
224224

225225
def _get_data(key: str) -> torch.Tensor:
226226
data = batchdata[key]

monai/losses/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from .giou_loss import BoxGIoULoss, giou
3737
from .hausdorff_loss import HausdorffDTLoss, LogHausdorffDTLoss
3838
from .image_dissimilarity import GlobalMutualInformationLoss, LocalNormalizedCrossCorrelationLoss
39+
from .mcc_loss import MCCLoss
3940
from .multi_scale import MultiScaleLoss
4041
from .nacl_loss import NACLLoss
4142
from .perceptual import PerceptualLoss

monai/losses/image_dissimilarity.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from torch.nn import functional as F
1616
from torch.nn.modules.loss import _Loss
1717

18-
from monai.networks.layers import gaussian_1d, separable_filtering
18+
from monai.networks.layers import separable_filtering
1919
from monai.utils import LossReduction
2020
from monai.utils.module import look_up_option
2121

@@ -34,11 +34,11 @@ def make_triangular_kernel(kernel_size: int) -> torch.Tensor:
3434

3535

3636
def make_gaussian_kernel(kernel_size: int) -> torch.Tensor:
37-
sigma = torch.tensor(kernel_size / 3.0)
38-
kernel = gaussian_1d(sigma=sigma, truncated=kernel_size // 2, approx="sampled", normalize=False) * (
39-
2.5066282 * sigma
40-
)
41-
return kernel[:kernel_size]
37+
sigma = kernel_size / 3.0
38+
half = kernel_size // 2
39+
x = torch.arange(-half, half + 1, dtype=torch.float)
40+
kernel = torch.exp(-0.5 / (sigma * sigma) * x**2)
41+
return kernel
4242

4343

4444
kernel_dict = {

0 commit comments

Comments
 (0)