We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 7a2f49b commit 52e6d19Copy full SHA for 52e6d19
avg_checkpoints.py
@@ -17,6 +17,7 @@
17
import glob
18
import hashlib
19
from timm.models import load_state_dict
20
+from timm.models._helpers import _torch_load
21
try:
22
import safetensors.torch
23
_has_safetensors = True
@@ -47,7 +48,7 @@ def checkpoint_metric(checkpoint_path):
47
48
if not checkpoint_path or not os.path.isfile(checkpoint_path):
49
return {}
50
print("=> Extracting metric from checkpoint '{}'".format(checkpoint_path))
- checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
51
+ checkpoint = _torch_load(checkpoint_path, map_location='cpu', weights_only=True)
52
metric = None
53
if 'metric' in checkpoint:
54
metric = checkpoint['metric']
0 commit comments