Skip to content

Commit 52e6d19

Browse files
committed
Change avg_checkpoints.py to use more secure load helper
1 parent 7a2f49b commit 52e6d19

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

avg_checkpoints.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import glob
1818
import hashlib
1919
from timm.models import load_state_dict
20+
from timm.models._helpers import _torch_load
2021
try:
2122
import safetensors.torch
2223
_has_safetensors = True
@@ -47,7 +48,7 @@ def checkpoint_metric(checkpoint_path):
4748
if not checkpoint_path or not os.path.isfile(checkpoint_path):
4849
return {}
4950
print("=> Extracting metric from checkpoint '{}'".format(checkpoint_path))
50-
checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
51+
checkpoint = _torch_load(checkpoint_path, map_location='cpu', weights_only=True)
5152
metric = None
5253
if 'metric' in checkpoint:
5354
metric = checkpoint['metric']

0 commit comments

Comments
 (0)