Skip to content

Commit 252d26e

Browse files
aymuos15ericspodpre-commit-ci[bot]
authored
Fix memory leak in optional_import traceback handling (#8782)
## Summary - Fix memory leak in `optional_import` where stored traceback objects retain references to entire call stacks, preventing garbage collection - Format traceback to string immediately and clear `import_exception.__traceback__ = None` - Hoist 3 `optional_import` calls for cucim to module level in `monai/transforms/utils.py` Fixes #7480, #7727 ## Details **`monai/utils/module.py`:** - Replace storing live `__traceback__` object with `traceback.format_exception()` string - Clear `import_exception.__traceback__ = None` to break reference chain - Embed formatted traceback in error message under "Original traceback:" section **`monai/transforms/utils.py`:** - Move cucim `optional_import` calls to module level (consistent with existing skimage/scipy/cupy pattern) - Fixes incidental name shadowing in `distance_transform_edt` ## Test plan - [x] `test_no_traceback_leak` — weakref regression test for the leak - [x] `test_failed_import_shows_traceback_string` — verifies traceback in error message - [x] All 12 `test_optional_import.py` tests pass - [x] pre-commit, black, ruff, pytype, mypy — all clean --------- Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk> Signed-off-by: Soumya Snigdha Kundu <soumyawork15@gmail.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 089f02d commit 252d26e

File tree

3 files changed

+45
-7
lines changed

3 files changed

+45
-7
lines changed

monai/transforms/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,9 @@
8585
cp, has_cp = optional_import("cupy")
8686
cp_ndarray, _ = optional_import("cupy", name="ndarray")
8787
exposure, has_skimage = optional_import("skimage.exposure")
88+
# NOTE: cucim is deliberately NOT imported at module level.
89+
# Module-level cucim imports caused very slow import times and other buggy behaviour.
90+
# Keep cucim imports inside the functions that need them.
8891

8992
__all__ = [
9093
"allow_missing_keys_mode",

monai/utils/module.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import pdb
1818
import re
1919
import sys
20+
import traceback as traceback_mod
2021
import warnings
2122
from collections.abc import Callable, Collection, Hashable, Iterable, Mapping
2223
from functools import partial, wraps
@@ -368,8 +369,9 @@ def optional_import(
368369
OptionalImportError: from torch.nn.functional import conv1d (requires version '42' by 'min_version').
369370
"""
370371

371-
tb = None
372+
had_exception = False
372373
exception_str = ""
374+
tb_str = ""
373375
if name:
374376
actual_cmd = f"from {module} import {name}"
375377
else:
@@ -384,8 +386,12 @@ def optional_import(
384386
if name: # user specified to load class/function/... from the module
385387
the_module = getattr(the_module, name)
386388
except Exception as import_exception: # any exceptions during import
387-
tb = import_exception.__traceback__
389+
tb_str = "".join(
390+
traceback_mod.format_exception(type(import_exception), import_exception, import_exception.__traceback__)
391+
)
392+
import_exception.__traceback__ = None
388393
exception_str = f"{import_exception}"
394+
had_exception = True
389395
else: # found the module
390396
if version_args and version_checker(pkg, f"{version}", version_args):
391397
return the_module, True
@@ -394,7 +400,7 @@ def optional_import(
394400

395401
# preparing lazy error message
396402
msg = descriptor.format(actual_cmd)
397-
if version and tb is None: # a pure version issue
403+
if version and not had_exception: # a pure version issue
398404
msg += f" (requires '{module} {version}' by '{version_checker.__name__}')"
399405
if exception_str:
400406
msg += f" ({exception_str})"
@@ -407,10 +413,9 @@ def __init__(self, *_args, **_kwargs):
407413
+ "\n\nFor details about installing the optional dependencies, please visit:"
408414
+ "\n https://monai.readthedocs.io/en/latest/installation.html#installing-the-recommended-dependencies"
409415
)
410-
if tb is None:
411-
self._exception = OptionalImportError(_default_msg)
412-
else:
413-
self._exception = OptionalImportError(_default_msg).with_traceback(tb)
416+
if tb_str:
417+
_default_msg += f"\n\nOriginal traceback:\n{tb_str}"
418+
self._exception = OptionalImportError(_default_msg)
414419

415420
def __getattr__(self, name):
416421
"""

tests/utils/test_optional_import.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111

1212
from __future__ import annotations
1313

14+
import gc
1415
import unittest
16+
import weakref
1517

1618
from parameterized import parameterized
1719

@@ -75,6 +77,34 @@ def versioning(module, ver, a):
7577
nn, flag = optional_import("torch", "1.1", version_checker=versioning, name="nn", version_args=test_args)
7678
self.assertTrue(flag)
7779

80+
def test_no_traceback_leak(self):
81+
"""Verify optional_import does not retain references to stack frames (issue #7480)."""
82+
83+
class _Marker:
84+
pass
85+
86+
def _do_import():
87+
marker = _Marker()
88+
ref = weakref.ref(marker)
89+
# Call optional_import for a module that does not exist.
90+
# If the traceback is leaked, `marker` stays alive via frame references.
91+
_mod, flag = optional_import("nonexistent_module_for_leak_test")
92+
self.assertFalse(flag)
93+
return ref
94+
95+
ref = _do_import()
96+
gc.collect()
97+
self.assertIsNone(ref(), "optional_import is leaking frame references via traceback")
98+
99+
def test_failed_import_shows_traceback_string(self):
100+
"""Verify the error message includes the original traceback as a string."""
101+
mod, flag = optional_import("nonexistent_module_for_tb_test")
102+
self.assertFalse(flag)
103+
with self.assertRaises(OptionalImportError) as ctx:
104+
_ = mod.something
105+
self.assertIn("Original traceback", str(ctx.exception))
106+
self.assertIn("ModuleNotFoundError", str(ctx.exception))
107+
78108

79109
if __name__ == "__main__":
80110
unittest.main()

0 commit comments

Comments
 (0)