|
3 | 3 | import logging |
4 | 4 | import typing |
5 | 5 | from abc import ABC, abstractmethod |
| 6 | +from platform import version |
6 | 7 | from typing import Union, Tuple, Callable, Optional, List, Sequence |
7 | 8 |
|
8 | 9 | import numpy as np |
|
21 | 22 | from ..util.dtype import to_float_array |
22 | 23 | from ..util.pickle import setstate |
23 | 24 | from ..util.string import ToStringMixin |
| 25 | +from ..util.version import Version |
24 | 26 | from ..vector_model import VectorRegressionModel, VectorClassificationModel, TrainingContext |
25 | 27 |
|
26 | 28 | log: logging.Logger = logging.getLogger(__name__) |
| 29 | +torch_version = Version(torch) |
27 | 30 |
|
28 | 31 |
|
29 | 32 | class MCDropoutCapableNNModule(nn.Module, ABC): |
@@ -139,9 +142,12 @@ def _set_cuda_enabled(self, is_cuda_enabled: bool) -> None: |
139 | 142 | def _is_cuda_enabled(self) -> bool: |
140 | 143 | return self.cuda |
141 | 144 |
|
142 | | - def _load_model(self, model_file) -> None: # TODO: complete type hints: what types are allowed for modelFile? |
| 145 | + def _load_model(self, model_file) -> None: |
143 | 146 | try: |
144 | | - self.module = torch.load(model_file) |
| 147 | + load_kwargs = {} |
| 148 | + if torch_version.is_at_least(2, 4): |
| 149 | + load_kwargs["weights_only"] = False |
| 150 | + self.module = torch.load(model_file, **load_kwargs) |
145 | 151 | self._gpu = self._get_gpu_from_model_parameter_device() |
146 | 152 | except: |
147 | 153 | if self._is_cuda_enabled(): |
|
0 commit comments