Background
The current inference pipeline in sleap-nn has several architectural issues that limit performance and portability:
- Performance: Inference could be made much faster with `torch.compile`, ONNX, and TensorRT support, but the current architecture has no compilation or export capabilities.
- Data processing: Preprocessing and postprocessing are tightly coupled to the overall pipeline, with logic spread across multiple classes (data loading, preprocessing, model forward pass, postprocessing) and mixed between `Predictor` classes and the underlying `torch.nn.Module`s, making it difficult to optimize the pipeline end to end.
- APIs: The interface is not user-friendly. There should be a simple, lower-level API: load a model → pass in raw images → get predictions back.
- Data structure organization: Outputs are currently plain dictionaries, which would be clearer and safer as a structured results data class.
This issue builds upon the discussions in #77 and #46.
Proposed Solution
1. Structured Output Data Container
```python
from typing import Dict, Optional

import attrs
import torch


@attrs.define
class Outputs:
    """Outputs data structure for inference.

    This data structure is used to store the outputs of the inference model.
    """

    original_image: Optional[torch.Tensor]
    processed_image: Optional[torch.Tensor]
    pred_raw_keypoints: Optional[torch.Tensor] = None
    pred_keypoints: Optional[torch.Tensor] = None
    pred_confmaps: Optional[torch.Tensor] = None
    pred_pafs: Optional[torch.Tensor] = None
    pred_peak_values: Optional[torch.Tensor] = None
    pred_raw_centroids: Optional[torch.Tensor] = None
    pred_centroids: Optional[torch.Tensor] = None
    device: Optional[torch.device] = None
    frame_idx: Optional[int] = None
    video_idx: Optional[int] = None

    def as_dict(self) -> Dict[str, Optional[torch.Tensor]]:
        """Convert to a dictionary."""
        pass

    def to(self, device: torch.device) -> "Outputs":
        """Move all tensors to the given device."""
        pass

    def numpy(self) -> "Outputs":
        """Move all tensors to CPU and convert to numpy."""
        pass

    def plot(self):
        """Plot the results (overlay ground truth and predictions on the original image)."""
        pass
```

2. API Architecture
High-Level API: Predictor
- Load preprocessing configs and models
- Process data (filtering frames, batching)
- Create `sio.Labels` objects from raw outputs
- Simple `predict()` function that takes raw data or video/labels objects
- Each model type would have its own `Predictor` subclass
```python
from pathlib import Path
from typing import Iterable, List, Optional, Union

import attrs
import sleap_io as sio


@attrs.define
class Predictor:
    """Constructs pipelines (e.g., top-down: centroid + centered-instance)."""

    runners: Union[InferenceRunnerTorch, InferenceRunnerONNX]
    skeletons: List[sio.Skeleton] = attrs.Factory(list)
    videos: List[sio.Video] = attrs.Factory(list)
    batch_size: int = 4
    pipeline: Optional[Iterable] = None  # set by make_pipeline()

    @classmethod
    def from_model_paths(cls, model_paths: List[Union[str, Path]], **kw) -> "Predictor":
        # Detect model type(s), build appropriate runners (Torch/ONNX), and
        # return a Predictor instance.
        raise NotImplementedError

    def make_pipeline(self, source: Union[str, Path, sio.Labels, sio.Video]):
        # Set up the thread-based pipeline to load frames from the source.
        raise NotImplementedError

    def _to_labels(self, outputs: List[Outputs]) -> sio.Labels:
        # Convert list of Outputs -> sio.Labels.
        raise NotImplementedError

    def predict(
        self, source: Union[sio.Labels, sio.Video], as_labels: bool = True
    ) -> Union[sio.Labels, List[Outputs]]:
        if self.pipeline is None:
            self.make_pipeline(source)
        results = []
        for batch in self.pipeline:
            outputs = self.runners.predict(batch)
            results.extend(outputs)  # list[Outputs]
        return results if not as_labels else self._to_labels(results)
```

Low-Level API: BaseInferenceRunner
- Framework-agnostic inference execution
- Handles preprocessing (normalization, resizing, padding) and postprocessing (peak finding, coordinate adjustment)
- Supports both PyTorch and ONNX backends
```python
from typing import Dict, List, Optional, Tuple, Union

import attrs
import numpy as np
import torch


@attrs.define
class BaseInferenceRunner:
    # Preprocessing
    input_scaling: float = 1.0
    max_height: Optional[int] = None
    max_width: Optional[int] = None
    pad_to_stride: Optional[int] = 32
    ensure_rgb: bool = False
    ensure_grayscale: bool = False
    max_stride: int = 32

    # Postprocessing
    peak_threshold: float = 0.2
    refinement: Optional[str] = "integral"
    integral_patch_size: int = 5
    confmap_output_stride: int = 4

    def preprocess_image(
        self,
        imgs: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Preprocess the image(s)."""
        # normalize -> rgb/grayscale -> size matching (resize + pad) -> pad to max stride
        # if numpy, convert to torch tensor
        # if list, concatenate along batch axis
        preprocess_cfg = {
            "scale": self.input_scaling,
            "size_matcher_scale": None,
        }  # parameters used for postprocessing
        return imgs, preprocess_cfg

    def postprocess_results(
        self,
        result: Dict[str, torch.Tensor],
        preprocess_cfg: Dict[str, torch.Tensor],
    ) -> List[Outputs]:
        """Map confmaps/peaks back to original image coords."""
        # Peak finding, integral refinement, PAF grouping, unpad/unscale.
        # Return a list of Outputs objects.
        raise NotImplementedError

    def _predict(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Framework-specific forward (batched). Must be implemented in backends."""
        raise NotImplementedError

    def predict(
        self,
        x: Union[
            torch.Tensor,
            List[torch.Tensor],
            np.ndarray,
            List[np.ndarray],
            Dict[str, torch.Tensor],
        ],
    ) -> List[Outputs]:
        """Predict the keypoints from the image(s)."""
        if not isinstance(x, dict):
            x = {"image": x}
        imgs, preprocess_cfg = self.preprocess_image(x["image"])
        x["image"] = imgs
        result = self._predict(x)
        return self.postprocess_results(result, preprocess_cfg)
```
PyTorch inference runner
```python
import attrs
import torch
import torch.nn as nn


@attrs.define
class InferenceRunnerTorch(BaseInferenceRunner):
    module: nn.Module = attrs.field(repr=False)  # could also wrap with torch.compile?

    device: Union[str, torch.device] = "cpu"

    @torch.inference_mode()
    def _predict(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        device = next(self.module.parameters()).device
        imgs = x["image"].to(device, non_blocking=True)
        y = self.module(imgs)
        # Sent on to postprocess_results.
        return y
```

ONNX Backend
```python
from typing import Dict, Optional, Tuple

import attrs
import onnxruntime as ort
import torch


@attrs.define
class InferenceRunnerONNX(BaseInferenceRunner):
    onnx_session: ort.InferenceSession = attrs.field(repr=False)
    input_name: Optional[str] = None
    providers: Tuple[str, ...] = ("CUDAExecutionProvider", "CPUExecutionProvider")

    def _predict(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        x_np = x["image"].detach().cpu().numpy()  # ORT needs numpy
        out_names = [o.name for o in self.onnx_session.get_outputs()]
        outs = self.onnx_session.run(out_names, {self.input_name: x_np})
        # These would be sent to postprocess_results.
        return outs
```

Implementation Plan
PR 1: Core Modules
- Implement `Outputs` data structure
- Implement function to export model to ONNX
PR 2: Implement Inference Runners
- Create `BaseInferenceRunner` with preprocessing/postprocessing
- Implement `InferenceRunnerTorch`
  - Benchmark with `torch.compile` support
- Implement `InferenceRunnerONNX` for ONNX models
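For the `torch.compile` benchmarking item, a micro-benchmark could be as simple as the sketch below. The toy model and timing loop are illustrative only; on CPU, the inductor backend needs a working C compiler, so the sketch falls back to eager timing if compilation fails.

```python
# Hypothetical micro-benchmark: eager vs. torch.compile forward passes.
import time

import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(1, 16, 3, padding=1), nn.ReLU(), nn.Conv2d(16, 1, 3, padding=1)
).eval()
x = torch.zeros(4, 1, 64, 64)


def bench(fn, iters=10):
    """Average seconds per forward pass (first call is warm-up/compilation)."""
    with torch.inference_mode():
        fn(x)
        t0 = time.perf_counter()
        for _ in range(iters):
            fn(x)
        return (time.perf_counter() - t0) / iters


eager_s = bench(model)
try:
    compiled_s = bench(torch.compile(model))
except Exception:  # no C compiler / unsupported platform -> skip
    compiled_s = float("nan")
print(f"eager: {eager_s * 1e3:.2f} ms, compiled: {compiled_s * 1e3:.2f} ms")
```

Excluding the first (compilation) call from the timed loop is essential, since `torch.compile` does its work lazily on first invocation.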
PR 3: Refactor Peak Finding and PAF Grouping
- Refactor peak finding and PAF grouping logic
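As context for this refactor, grid-aligned peak finding on confidence maps is commonly implemented with a max-pooling trick: a pixel is a peak if it equals the max of its neighborhood and clears the threshold. A minimal sketch (the function name and threshold default are assumptions):

```python
# Hypothetical sketch of grid-aligned peak finding via max-pooling.
import torch
import torch.nn.functional as F


def find_local_peaks(confmaps: torch.Tensor, threshold: float = 0.2):
    """Return (b, c, y, x) indices and values of local maxima above threshold.

    confmaps: (batch, channels, height, width)
    """
    # A pixel is a peak if it equals the max of its 3x3 neighborhood.
    pooled = F.max_pool2d(confmaps, kernel_size=3, stride=1, padding=1)
    is_peak = (confmaps == pooled) & (confmaps > threshold)
    idxs = is_peak.nonzero()  # (n_peaks, 4)
    vals = confmaps[is_peak]  # (n_peaks,)
    return idxs, vals


# Toy example: one confidence map with a single bright pixel.
cm = torch.zeros(1, 1, 8, 8)
cm[0, 0, 3, 5] = 1.0
idxs, vals = find_local_peaks(cm)
print(idxs.tolist(), vals.tolist())  # [[0, 0, 3, 5]] [1.0]
```

Integral refinement would then sub-pixel-adjust each of these grid peaks using the surrounding `integral_patch_size` patch.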
PR 4: Implement Predictor Classes
- Add Predictor classes for each model type

Additional features:
- Pass in a list of videos to generate a single output `.slp` file
Example Usage
```python
# Simple inference
predictor = Predictor.from_model_paths(["model.ckpt"])
outputs = predictor.predict(imgs)  # imgs: np.ndarray

# ONNX inference
predictor = SingleInstancePredictor.from_model_paths(
    ["model.onnx"],
)
outputs = predictor.predict(labels)
```
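To get a feel for the low-level flow (preprocess → forward → raw outputs), here is a self-contained toy sketch. The `pad_to_stride` helper and the stand-in model are assumptions for illustration, not sleap-nn code.

```python
# Hypothetical end-to-end sketch: normalize + pad to stride -> forward -> raw dict.
import torch
import torch.nn as nn
import torch.nn.functional as F


def pad_to_stride(img: torch.Tensor, stride: int = 32) -> torch.Tensor:
    """Right/bottom pad (B, C, H, W) so H and W are multiples of `stride`."""
    h, w = img.shape[-2:]
    pad_h = (stride - h % stride) % stride
    pad_w = (stride - w % stride) % stride
    return F.pad(img, (0, pad_w, 0, pad_h))


# Raw uint8-like frame -> float in [0, 1], batched and padded.
raw = torch.randint(0, 256, (1, 1, 100, 90), dtype=torch.uint8)
img = pad_to_stride(raw.float() / 255.0)

toy_model = nn.Conv2d(1, 2, 3, padding=1)  # stands in for a confmap head
with torch.inference_mode():
    result = {"pred_confmaps": toy_model(img)}

print(img.shape, result["pred_confmaps"].shape)
```

Keeping the padding amounts in `preprocess_cfg` is what lets postprocessing map peak coordinates back to the original 100×90 frame.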