Call Python from inside jax.jit-compiled code via XLA FFI.
JAX already ships with a Python callback feature, but it breaks jit caching because it embeds callback pointers directly in the HLO. This project uses named callbacks instead, similar to how XLA FFI itself registers handlers, so the HLO contains only a stable name and caching keeps working. It also supports executing callbacks on the jax:cuda platform, which avoids device-to-host (D2H) copy overhead. Big caveat: because of JAX's internal locking, this can deadlock unless you are careful.
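To see why a name is friendlier to caching than a pointer, here is a toy sketch in plain Python. It does not use JAX or this project's code: `REGISTRY`, `register`, `lowered_by_pointer`, `lowered_by_name`, and `cache_key` are all hypothetical names made up for illustration, and the "program text" is a stand-in string rather than real HLO.

```python
import hashlib

# Hypothetical name -> callable registry (illustrative only, not this project's API).
REGISTRY = {}

def register(name, fn):
    """Bind a stable name to a Python callable, replacing any earlier binding."""
    REGISTRY[name] = fn

def lowered_by_pointer(fn):
    # Stand-in for a lowering that bakes the callable's address into the program text.
    return f"custom_call(target='py_callback', ptr={id(fn)})"

def lowered_by_name(name):
    # Stand-in for a lowering that records only the registered name.
    return f"custom_call(target='py_callback', name={name!r})"

def cache_key(program_text):
    # Toy cache key: a hash of the program text.
    return hashlib.sha256(program_text.encode()).hexdigest()

def make_scale(scale):
    def scale_fn(x):
        return x * scale
    return scale_fn

# Two distinct function objects with identical behaviour, both kept alive.
first, second = make_scale(2), make_scale(2)

# Pointer-based lowering: each object yields different text, hence a different key.
pointer_keys = {cache_key(lowered_by_pointer(f)) for f in (first, second)}

# Name-based lowering: rebinding the name changes the callable, not the text.
register("scale", first)
name_key_before = cache_key(lowered_by_name("scale"))
register("scale", second)
name_key_after = cache_key(lowered_by_name("scale"))

# At run time the callable is looked up by name.
result = REGISTRY["scale"](21)
```

In this sketch the pointer-based text differs for every function object, so a cache keyed on it would miss each time, while the name-based text is unchanged when the callable is re-registered.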
make_callback: wrap a Python callable so it can be invoked from compiled JAX code. Pytree inputs/outputs are flattened automatically and validated against an abstract_eval. Supports jax:cpu, jax:cuda, and numpy platforms.
from jax_callback.callback import make_callback

@make_callback(
    name="my_callback",
    abstract_eval=lambda *args, **kwargs: tuple(args),
    platform="jax:cpu",
)
def my_callback(*arrays, scale=1):
    return tuple(arr * scale for arr in arrays)

See main.py for full usage.
uv sync
Built with scikit-build-core + nanobind; requires JAX 0.8.2 and CUDA 12.