Skip to content

nickjanssen544/seqjax

 
 

Repository files navigation

SeqJAX

SeqJAX provides utilities for building sequential probabilistic models with JAX. The library encourages a functional style: models are composed of Prior, Transition and Emission classes which operate on simple dataclasses for particles, observations and parameters. Runtime interface checks ensure that these components implement the correct methods and signatures, reducing boilerplate errors. The three components are grouped together in a SequentialModel definition.

SeqJAX ships with a handful of toy models demonstrating this approach: an AR(1) process, a stochastic volatility model and a multidimensional linear Gaussian state space model.

Installation

SeqJAX requires Python 3.13 or later. Only installation from source is currently available:

pip install git+https://github.com/bayesianshift/seqjax.git

Example

The seqjax.model.ar module contains a small autoregressive example. The snippet below defines a model using these components and simulates a short path of data.

import jax.random as jrandom
from seqjax.model.ar import AR1Target, ARParameters
from seqjax import simulate

# Model parameters and target
parameters = ARParameters(
    ar=0.8,            # autoregressive coefficient
    observation_std=1.0,
    transition_std=0.5,
)
model = AR1Target()

key = jrandom.key(0)

# Simulate a sequence of length 5
latent_path, observation_path = simulate.simulate(
    key, model, condition=None, parameters=parameters, sequence_length=5
)
print(observation_path)

SeqJAX will check at runtime that AR1Target and its components implement the required interface. When extending the library, these checks help catch mistakes such as missing sample or log_prob implementations.

Design

  • A registry converts static config (a dataclass) to objects
  • A run submodule defines an interface between static config and models, building appropriate objects and using them to produce posterior samples
  • The cli submodule is a command line interface for targeting data with inference configured by command line options. It relies on a remmote backend for storing+loading run outputs. The idea is to convert flat strings into config understood by a registry.

About

Sequential models and inference in JAX

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors

Languages

  • Python 58.6%
  • Jupyter Notebook 41.4%