Skip to content

Commit 021be9d

Browse files
committed
Infer single-animal mode and n_bodyparts from config metadata.
Instead using a single_animal parameter for PytorchRunner, which defaults to True, single_animal mode will be inferred from the models metadata configuration. This is useful for cases when you want to safely leave out the single_animal paramter, e.g. when running a multi-animal model in DeepLabCut-live-GUI, just passing a model configuration suffices.
1 parent a096c94 commit 021be9d

File tree

1 file changed

+22
-2
lines changed

1 file changed

+22
-2
lines changed

dlclive/pose_estimation_pytorch/runner.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from dataclasses import dataclass
1414
from pathlib import Path
1515
from typing import Literal
16+
import warnings
1617

1718
import numpy as np
1819
import torch
@@ -131,15 +132,25 @@ def __init__(
131132
path: str | Path,
132133
device: str = "auto",
133134
precision: Literal["FP16", "FP32"] = "FP32",
134-
single_animal: bool = True,
135+
single_animal: bool | None = None,
135136
dynamic: dict | dynamic_cropping.DynamicCropper | None = None,
136137
top_down_config: dict | TopDownConfig | None = None,
137138
) -> None:
138139
super().__init__(path)
139140
self.device = _parse_device(device)
140141
self.precision = precision
142+
if single_animal is not None:
143+
warnings.warn(
144+
"The `single_animal` parameter is deprecated and will be removed "
145+
"in a future version. The number of individuals will be automaticalliy inferred "
146+
"from the model configuration. Remove argument `single_animal` or set "
147+
"`single_animal=None` to accept the inferred value and silence this warning.",
148+
DeprecationWarning,
149+
stacklevel=2,
150+
)
141151
self.single_animal = single_animal
142-
152+
self.n_individuals = None
153+
self.n_bodyparts = None
143154
self.cfg = None
144155
self.detector = None
145156
self.model = None
@@ -260,6 +271,15 @@ def load_model(self) -> None:
260271
raw_data = torch.load(self.path, map_location="cpu", weights_only=True)
261272

262273
self.cfg = raw_data["config"]
274+
275+
# Infer single animal mode and n_bodyparts from model configuration
276+
individuals = self.cfg.get("metadata", {}).get("individuals", ['idv1'])
277+
bodyparts = self.cfg.get("metadata", {}).get("bodyparts", [])
278+
self.n_individuals = len(individuals)
279+
self.n_bodyparts = len(bodyparts)
280+
if self.single_animal is None:
281+
self.single_animal = self.n_individuals == 1
282+
263283
self.model = models.PoseModel.build(self.cfg["model"])
264284
self.model.load_state_dict(raw_data["pose"])
265285
self.model = self.model.to(self.device)

0 commit comments

Comments
 (0)