
NovelAI/jax-callback


jax-callback

Call Python from inside jax.jit-compiled code via XLA FFI.

What's the point of this?

JAX already comes with Python callbacks (e.g. jax.pure_callback), but they break jit caching because the callback pointer is saved directly into the HLO. This project instead uses named callbacks, similar to how XLA FFI itself refers to handlers by name, so caching is not a problem. It can also execute callbacks on the jax:cuda platform, which removes device-to-host (D2H) copy overhead entirely. This comes with a big caveat: because of JAX's internal locking, it can deadlock unless you are careful.

What it provides

  • make_callback: wraps a Python callable so it can be invoked from compiled JAX code. Pytree inputs and outputs are flattened automatically, and the results are validated against the abstract values returned by abstract_eval. Supports the jax:cpu, jax:cuda, and numpy platforms.
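One plausible shape for the flatten-and-validate step, sketched under assumptions: pytrees are limited to tuples, lists and dicts (dict children in sorted-key order, as in jax.tree_util), and validation compares the tree structure plus each leaf's shape and dtype. Abstract is a hypothetical stand-in for jax.ShapeDtypeStruct, and flatten/validate are illustrative helpers; the checks jax-callback actually performs may differ.

```python
from typing import Any, List, Tuple


class Abstract:
    """Hypothetical stand-in for jax.ShapeDtypeStruct: a shape and a dtype."""

    def __init__(self, shape, dtype):
        self.shape, self.dtype = tuple(shape), dtype


def flatten(tree: Any) -> Tuple[List[Any], Any]:
    """Flatten nested tuples/lists/dicts into (leaves, structure)."""
    if isinstance(tree, (tuple, list)):
        leaves, children = [], []
        for child in tree:
            child_leaves, child_struct = flatten(child)
            leaves.extend(child_leaves)
            children.append(child_struct)
        return leaves, (type(tree).__name__, tuple(children))
    if isinstance(tree, dict):
        keys = sorted(tree)  # dict children are visited in sorted-key order
        leaves, children = [], []
        for key in keys:
            child_leaves, child_struct = flatten(tree[key])
            leaves.extend(child_leaves)
            children.append(child_struct)
        return leaves, ("dict", tuple(keys), tuple(children))
    return [tree], "*"  # anything else is a single leaf


def validate(outputs: Any, expected: Any) -> List[Any]:
    """Check outputs against the abstract result; return the flat leaves."""
    out_leaves, out_struct = flatten(outputs)
    exp_leaves, exp_struct = flatten(expected)
    if out_struct != exp_struct:
        raise ValueError(f"pytree mismatch: {out_struct} != {exp_struct}")
    for i, (got, want) in enumerate(zip(out_leaves, exp_leaves)):
        if tuple(got.shape) != tuple(want.shape) or got.dtype != want.dtype:
            raise ValueError(
                f"leaf {i}: got {got.shape}/{got.dtype}, "
                f"expected {want.shape}/{want.dtype}"
            )
    return out_leaves


if __name__ == "__main__":
    outputs = {"scaled": (Abstract((3, 4), "float32"),), "count": Abstract((), "int32")}
    expected = {"count": Abstract((), "int32"), "scaled": (Abstract((3, 4), "float32"),)}
    print([leaf.shape for leaf in validate(outputs, expected)])  # [(), (3, 4)]
```

Only the shape and dtype attributes of each leaf are read, so any object exposing them, such as a NumPy array, can be passed as an output leaf.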

Example

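from jax_callback.callback import make_callback

@make_callback(
    name="my_callback",  # name the callback is registered and looked up under
    abstract_eval=lambda *args, **kwargs: tuple(args),  # one output per input, same shape and dtype
    platform="jax:cpu",  # target platform: "jax:cpu", "jax:cuda" or "numpy"
)
def my_callback(*arrays, scale=1):
    # Multiply every input array by the keyword argument `scale`.
    return tuple(arr * scale for arr in arrays)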

See main.py for full usage.

Build

uv sync

Built with scikit-build-core + nanobind; requires JAX 0.8.2 and CUDA 12.
