Skip to content

Commit 59a2da9

Browse files
committed
Explicitly set weights_only=False for torch version 2.4+
1 parent a0115bc commit 59a2da9

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

src/sensai/torch/torch_base.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import logging
44
import typing
55
from abc import ABC, abstractmethod
6+
from platform import version
67
from typing import Union, Tuple, Callable, Optional, List, Sequence
78

89
import numpy as np
@@ -21,9 +22,11 @@
2122
from ..util.dtype import to_float_array
2223
from ..util.pickle import setstate
2324
from ..util.string import ToStringMixin
25+
from ..util.version import Version
2426
from ..vector_model import VectorRegressionModel, VectorClassificationModel, TrainingContext
2527

2628
log: logging.Logger = logging.getLogger(__name__)
29+
torch_version = Version(torch)
2730

2831

2932
class MCDropoutCapableNNModule(nn.Module, ABC):
@@ -139,9 +142,12 @@ def _set_cuda_enabled(self, is_cuda_enabled: bool) -> None:
139142
def _is_cuda_enabled(self) -> bool:
140143
return self.cuda
141144

142-
def _load_model(self, model_file) -> None: # TODO: complete type hints: what types are allowed for modelFile?
145+
def _load_model(self, model_file) -> None:
143146
try:
144-
self.module = torch.load(model_file)
147+
load_kwargs = {}
148+
if torch_version.is_at_least(2, 4):
149+
load_kwargs["weights_only"] = False
150+
self.module = torch.load(model_file, **load_kwargs)
145151
self._gpu = self._get_gpu_from_model_parameter_device()
146152
except:
147153
if self._is_cuda_enabled():

0 commit comments

Comments
 (0)