Inference Pipeline Refactor for Compile/Export Support #343

@gitttt-1234

Description

Background

The current inference pipeline in sleap-nn has several architectural issues that limit performance and portability:

  • Performance: Inference could be much faster with torch.compile, ONNX, and TensorRT support, but the current architecture has no compilation or export capabilities
  • Data processing: Preprocessing and postprocessing are tightly coupled to the overall pipeline. The logic is split across multiple classes (data loading, preprocessing, model forward pass, postprocessing) and mixed between Predictor classes and the underlying torch.nn.Modules, making it difficult to optimize the pipeline end to end
  • APIs: The interface is not user-friendly. There should be a simple, lower-level API: load a model → pass in raw images → get predictions back
  • Data structure organization: Outputs are currently plain dictionaries; a structured results data class would be easier to use and document

This issue builds upon the discussions in #77 and #46.

Proposed Solution

1. Structured Output Data Container

import attrs
import torch
from typing import Dict, Optional


@attrs.define
class Outputs:
    """Outputs data structure for inference.

    This data structure is used to store the outputs of the inference model.
    """

    original_image: Optional[torch.Tensor]
    processed_image: Optional[torch.Tensor]
    pred_raw_keypoints: Optional[torch.Tensor] = None
    pred_keypoints: Optional[torch.Tensor] = None
    pred_confmaps: Optional[torch.Tensor] = None
    pred_pafs: Optional[torch.Tensor] = None
    pred_peak_values: Optional[torch.Tensor] = None
    pred_raw_centroids: Optional[torch.Tensor] = None
    pred_centroids: Optional[torch.Tensor] = None
    device: Optional[torch.device] = None
    frame_idx: Optional[int] = None
    video_idx: Optional[int] = None

    def as_dict(self) -> Dict[str, Optional[torch.Tensor]]:
        """Convert to a dictionary."""
        pass

    def to(self, device: torch.device) -> "Outputs":
        """Move all tensors to the given device."""
        pass

    def numpy(self) -> "Outputs":
        """Move all tensors to cpu and convert to numpy."""
        pass

    def plot(self):
        """Plot the results (overlay ground truth and predictions on the original image)."""
        pass

2. API Architecture

High-Level API: Predictor

  • Load preprocessing configs and models
  • Process data (filtering frames, batching)
  • Create sio.Labels objects from raw outputs
  • Simple predict() function that takes raw data or video/labels objects
  • Each model type would have its own Predictor sub class
@attrs.define
class Predictor:
    """Constructs pipelines (e.g., top-down: centroid + centered-instance)."""

    runners: Union[InferenceRunnerTorch, InferenceRunnerONNX]
    skeletons: List[sio.Skeleton] = attrs.Factory(list)
    videos: List[sio.Video] = attrs.Factory(list)
    batch_size: int = 4
    pipeline: Optional[Iterator] = None  # set by make_pipeline()

    @classmethod
    def from_model_paths(cls, model_paths: List[Union[str, Path]], **kw) -> "Predictor":
        # Detect model type(s), build appropriate runners (Torch/ONNX), and return a Predictor instance
        raise NotImplementedError

    def make_pipeline(self, source: Union[str, Path, sio.Labels, sio.Video]):
        # Set up the thread-based pipeline to load frames from the source
        raise NotImplementedError

    def _to_labels(self, outputs: List[Outputs]) -> sio.Labels:
        # Convert list of Outputs → sio.Labels
        raise NotImplementedError

    def predict(
        self, source: Union[sio.Labels, sio.Video], as_labels: bool = True
    ) -> Union[sio.Labels, List[Outputs]]:
        if self.pipeline is None:
            self.make_pipeline(source)
        results = []
        for batch in self.pipeline:
            outputs = self.runners.predict(batch)
            results.extend(outputs)  # list[Outputs]
        return results if not as_labels else self._to_labels(results)
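The batching step that make_pipeline would feed into predict() can be sketched as a plain generator (a single-threaded stand-in for the proposed thread-based pipeline; batch_frames is a hypothetical helper, not part of sleap-nn):

```python
import numpy as np
from typing import Iterator, List


def batch_frames(frames: List[np.ndarray], batch_size: int) -> Iterator[np.ndarray]:
    """Yield stacked batches of frames; the last batch may be smaller."""
    for i in range(0, len(frames), batch_size):
        yield np.stack(frames[i : i + batch_size], axis=0)
```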

Low-Level API: BaseInferenceRunner

  • Framework-agnostic inference execution
  • Handles preprocessing (normalization, resizing, padding) and postprocessing (peak finding, coordinate adjustment)
  • Supports both PyTorch and ONNX backends
@attrs.define
class BaseInferenceRunner:
    # preprocessing
    input_scaling: float = 1.0
    max_height: Optional[int] = None
    max_width: Optional[int] = None
    pad_to_stride: Optional[int] = 32
    ensure_rgb: bool = False
    ensure_grayscale: bool = False
    max_stride: int = 32

    # postprocessing
    peak_threshold: float = 0.2
    refinement: Optional[str] = "integral"
    integral_patch_size: int = 5
    confmap_output_stride: int = 4

    def preprocess_image(
        self,
        imgs: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Preprocess the image(s)."""
        # normalize -> rgb/grayscale conversion -> size matching (resize + pad) -> pad to max stride
        # if numpy convert to torch tensor
        # if list, concatenate along batch axis
        preprocess_cfg = {
            "scale": self.input_scaling,
            "size_matcher_scale": None,
        }  # parameters used for postprocessing
        return imgs, preprocess_cfg

    def postprocess_results(
        self,
        result: Dict[str, torch.Tensor],
        preprocess_cfg: Dict[str, torch.Tensor],
    ) -> List[Outputs]:
        """Map confmap/peaks back to original image coords."""
        # peak finding, integral refinement, paf grouping, unpad/unscale.
        # return list of Outputs objects
        raise NotImplementedError

    def _predict(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Framework-specific forward (batched). Must be implemented in backends."""
        raise NotImplementedError

    def predict(
        self,
        x: Union[
            torch.Tensor,
            List[torch.Tensor],
            np.ndarray,
            List[np.ndarray],
            Dict[str, torch.Tensor],
        ],
    ) -> List[Outputs]:
        """Predict the keypoints from the image(s)."""
        if isinstance(x, dict):
            imgs = x["image"]
        else:
            imgs, x = x, {}
        imgs, preprocess_cfg = self.preprocess_image(imgs)
        x["image"] = imgs
        result = self._predict(x)
        return self.postprocess_results(result, preprocess_cfg)
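One concrete piece of the preprocessing chain, padding to the model's maximum stride, could look like this (a sketch, not the sleap-nn implementation; the helper name is illustrative):

```python
import torch
import torch.nn.functional as F


def pad_to_stride(imgs: torch.Tensor, max_stride: int = 32) -> torch.Tensor:
    """Zero-pad (B, C, H, W) images on the bottom/right so H and W are
    divisible by max_stride, as encoder-decoder backbones require."""
    _, _, h, w = imgs.shape
    pad_h = (max_stride - h % max_stride) % max_stride
    pad_w = (max_stride - w % max_stride) % max_stride
    # F.pad pads the last dims first: (left, right, top, bottom).
    return F.pad(imgs, (0, pad_w, 0, pad_h))
```

The pad amounts would be recorded in preprocess_cfg so postprocess_results can crop the predictions back to original coordinates.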

PyTorch inference runner

@attrs.define
class InferenceRunnerTorch(BaseInferenceRunner):
    module: nn.Module = attrs.field(repr=False) # also could use torch.compile?
    device: Union[str, torch.device] = "cpu"

    @torch.inference_mode()
    def _predict(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        device = next(self.module.parameters()).device
        x["image"] = x["image"].to(device, non_blocking=True)
        y = self.module(x)
        # y is sent on to postprocess_results by predict()
        return y
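On the torch.compile question in the comment above, one option is to wrap the module opportunistically and fall back to eager mode when compilation is unavailable (a sketch; maybe_compile and its enable flag are hypothetical):

```python
import torch
import torch.nn as nn


def maybe_compile(module: nn.Module, enable: bool = True) -> nn.Module:
    """Wrap a module with torch.compile when possible, otherwise return
    the original module unchanged."""
    if not enable or not hasattr(torch, "compile"):
        return module
    try:
        return torch.compile(module)
    except Exception:
        # e.g., unsupported platform/backend; the eager module still works.
        return module
```

This keeps the runner usable on platforms where torch.compile is unsupported, at the cost of the first-call compilation overhead when it is enabled.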

ONNX Backend

@attrs.define
class InferenceRunnerONNX(BaseInferenceRunner):
    onnx_session: ort.InferenceSession = attrs.field(repr=False)
    input_name: Optional[str] = None
    providers: Tuple[str, ...] = ("CUDAExecutionProvider", "CPUExecutionProvider")

    def _predict(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        x_np = x["image"].detach().cpu().numpy()  # ORT needs numpy inputs
        out_names = [o.name for o in self.onnx_session.get_outputs()]
        outs = self.onnx_session.run(out_names, {self.input_name: x_np})
        # map names to tensors so postprocess_results receives a dict
        return {name: torch.from_numpy(out) for name, out in zip(out_names, outs)}

Implementation Plan

PR 1: Core Modules

  • Implement Outputs data structure
  • Implement function to export model to ONNX

PR 2: Implement Inference runners

  • Create BaseInferenceRunner with preprocessing/postprocessing
  • Implement InferenceRunnerTorch
    • Benchmark with torch.compile support
  • Implement InferenceRunnerONNX for ONNX models

PR 3: Refactor peak finding and PAF grouping logic

  • Refactor peak finding and PAF grouping logic

PR 4: Implement Predictor classes

  • Add Predictor classes for each model type

Additional features:

  • Pass in list of videos to generate a single output slp file

Example Usage

# Simple inference
predictor = Predictor.from_model_paths(["model.ckpt"])
outputs = predictor.predict(imgs) # imgs: np.ndarray

# ONNX inference
predictor = SingleInstancePredictor.from_model_paths(["model.onnx"])
outputs = predictor.predict(labels)
