|
13 | 13 | from dataclasses import dataclass |
14 | 14 | from pathlib import Path |
15 | 15 | from typing import Literal |
| 16 | +import warnings |
16 | 17 |
|
17 | 18 | import numpy as np |
18 | 19 | import torch |
@@ -131,15 +132,25 @@ def __init__( |
131 | 132 | path: str | Path, |
132 | 133 | device: str = "auto", |
133 | 134 | precision: Literal["FP16", "FP32"] = "FP32", |
134 | | - single_animal: bool = True, |
| 135 | + single_animal: bool | None = None, |
135 | 136 | dynamic: dict | dynamic_cropping.DynamicCropper | None = None, |
136 | 137 | top_down_config: dict | TopDownConfig | None = None, |
137 | 138 | ) -> None: |
138 | 139 | super().__init__(path) |
139 | 140 | self.device = _parse_device(device) |
140 | 141 | self.precision = precision |
| 142 | + if single_animal is not None: |
| 143 | + warnings.warn( |
| 144 | + "The `single_animal` parameter is deprecated and will be removed " |
| 145 | + "in a future version. The number of individuals will be automaticalliy inferred " |
| 146 | + "from the model configuration. Remove argument `single_animal` or set " |
| 147 | + "`single_animal=None` to accept the inferred value and silence this warning.", |
| 148 | + DeprecationWarning, |
| 149 | + stacklevel=2, |
| 150 | + ) |
141 | 151 | self.single_animal = single_animal |
142 | | - |
| 152 | + self.n_individuals = None |
| 153 | + self.n_bodyparts = None |
143 | 154 | self.cfg = None |
144 | 155 | self.detector = None |
145 | 156 | self.model = None |
@@ -260,6 +271,15 @@ def load_model(self) -> None: |
260 | 271 | raw_data = torch.load(self.path, map_location="cpu", weights_only=True) |
261 | 272 |
|
262 | 273 | self.cfg = raw_data["config"] |
| 274 | + |
| 275 | + # Infer single animal mode and n_bodyparts from model configuration |
| 276 | + individuals = self.cfg.get("metadata", {}).get("individuals", ['idv1']) |
| 277 | + bodyparts = self.cfg.get("metadata", {}).get("bodyparts", []) |
| 278 | + self.n_individuals = len(individuals) |
| 279 | + self.n_bodyparts = len(bodyparts) |
| 280 | + if self.single_animal is None: |
| 281 | + self.single_animal = self.n_individuals == 1 |
| 282 | + |
263 | 283 | self.model = models.PoseModel.build(self.cfg["model"]) |
264 | 284 | self.model.load_state_dict(raw_data["pose"]) |
265 | 285 | self.model = self.model.to(self.device) |
|
0 commit comments