Releases: jax-ml/jax

JAX v0.9.2

18 Mar 23:40

JAX 0.9.2 (March 2, 2026)

  • Changes:
    • The semi-private type jax._src.literals.TypedNdArray is now a subclass of
      np.ndarray, rather than a duck type of it.
    • jax.numpy.arange with step specified no longer generates the array
      on host. The benefit is more efficient code, though this can lead to less
      precise outputs for narrow-width floats (e.g. bfloat16). To recover the
      previous behavior in this case, use jnp.array(np.arange(...)).
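A minimal sketch of the two spellings, assuming a bfloat16 target dtype. The start, stop, and step values are arbitrary. The first call uses the new device-side jnp.arange. The second follows the suggested workaround: NumPy computes the sequence on the host, and the result is cast to bfloat16 once.

```python
import numpy as np
import jax.numpy as jnp

# Illustrative values only; 0.1 is not exactly representable in bfloat16.
start, stop, step = 0.0, 2.0, 0.1

# New behavior: with a step given, jnp.arange builds the values in compiled
# code rather than on the host, so the narrow dtype can affect precision.
on_device = jnp.arange(start, stop, step, dtype=jnp.bfloat16)

# Previous behavior: NumPy computes the sequence on the host in float64,
# and the result is cast to bfloat16 once at the end.
on_host = jnp.array(np.arange(start, stop, step), dtype=jnp.bfloat16)

print(on_device)
print(on_host)
```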

JAX v0.9.1

02 Mar 11:13

  • Changes:

    • JAX tracers that are not of Array type (e.g., of Ref type) will no
      longer report themselves to be instances of Array.
    • Using jax.shard_map in Explicit mode will raise an error
      if the PartitionSpec of an input does not match the PartitionSpec
      specified in in_specs. In other words, it will act like an assertion
      instead of an implicit reshard.
      in_specs is an optional argument, so you can omit it and shard_map will
      infer the PartitionSpec from the argument. If you want to reshard your
      inputs, call jax.reshard on the arguments and then pass the results to
      shard_map.
  • New features:

    • Added a debug config, jax_compilation_cache_check_contents. If set, we
      report a cache miss when get() is called on a value that has not been
      put() by the current process, even if the value is actually in the disk
      cache. When a value is put(), we verify that its contents match.

JAX v0.9.0.1

05 Feb 18:51

JAX v0.9.0.1 is identical to v0.9.0 with the commits from the following four PRs patched in:

JAX v0.8.3

29 Jan 23:10

JAX v0.8.3 is identical to v0.8.2 with the following two bug fixes patched in:

JAX v0.9.0

20 Jan 23:23

  • New features:

    • Added jax.thread_guard, a context manager that detects when devices
      are used by multiple threads in multi-controller JAX.
  • Bug fixes:

    • Fixed a workspace size calculation error for pivoted QR (magma_zgeqp3_gpu)
      in MAGMA 2.9.0 when using use_magma=True and pivoting=True.
      (#34145).
  • Deprecations:

    • The flag jax_collectives_common_channel_id was removed.
    • The jax_pmap_no_rank_reduction config state has been removed. The
      no-rank-reduction behavior is now the only supported behavior: a
      jax.pmapped function f sees inputs of the same rank as the input to
      jax.pmap(f). For example, if jax.pmap(f) receives shape (8, 128) on
      8 devices, then f receives shape (1, 128).
    • Setting the jax_pmap_shmap_merge config state is deprecated in JAX v0.9.0
      and will be removed in JAX v0.10.0.
    • jax.numpy.fix is deprecated, anticipating the deprecation of
      numpy.fix in NumPy v2.5.0. jax.numpy.trunc is a drop-in
      replacement.
  • Changes:

    • jax.export now supports explicit sharding. This required a new
      export serialization format version that records the NamedSharding,
      including the abstract mesh and the partition spec. As part of this
      change, we have added a restriction on the use of exported modules: when
      calling them, the abstract mesh must match the one used at export time,
      including the axis names. Previously, only the number of devices
      mattered.
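Because jax.numpy.trunc is described as a drop-in replacement for the deprecated jax.numpy.fix, migrating should amount to a rename. A minimal sketch with arbitrary floating-point inputs:

```python
import numpy as np
import jax.numpy as jnp

x = jnp.array([-2.7, -0.5, 0.5, 2.7])

# jnp.trunc rounds toward zero, which is what jnp.fix computed.
y = jnp.trunc(x)
print(y)
```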

JAX v0.8.2

18 Dec 18:50

  • Deprecations

    • jax.lax.pvary has been deprecated.
      Please use jax.lax.pcast(..., to='varying') as the replacement.
    • Complex arguments passed to jax.numpy.arange now result in a
      deprecation warning, because the output is poorly defined.
    • A number of symbols in jax.core are newly deprecated, including:
      call_impl, get_aval, mapped_aval, subjaxprs, set_current_trace,
      take_current_trace, traverse_jaxpr_params, unmapped_aval,
      AbstractToken, and TraceTag.
    • All symbols in jax.interpreters.pxla are deprecated. These are
      primarily JAX internal APIs, and users should not rely on them.
  • Changes:

    • JAX's Tracer no longer inherits from jax.Array at runtime. However,
      jax.Array now uses a custom metaclass such that isinstance(x, Array) is
      true if an object x represents a traced Array. Only some Tracers
      represent Arrays, so it is not correct for Tracer to inherit from Array.

      For the moment, during Python type checking, we continue to declare
      Tracer as a subclass of Array; however, we expect to remove this in a
      future release.

    • jax.experimental.si_vjp has been deleted.
      jax.vjp subsumes its functionality.
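A short sketch that exercises two of the changes above. The first is an isinstance check against jax.Array from inside traced code, which the new metaclass is meant to keep returning True for tracers that represent arrays. The second is a plain jax.vjp call of the kind that replaces si_vjp. The function f and its inputs are arbitrary illustrations, and the sketch does not reproduce si_vjp's former signature.

```python
import jax
import jax.numpy as jnp

observed = []

def f(x):
    # Under jax.jit or jax.vjp, x is a Tracer that represents an array,
    # so this isinstance check is expected to be True.
    observed.append(isinstance(x, jax.Array))
    return jnp.sin(x) * x

x = jnp.array([0.0, 1.0, 2.0])
jax.jit(f)(x)

# jax.vjp returns the primal output and a function that maps output
# cotangents to input cotangents.
y, f_vjp = jax.vjp(f, x)
(x_bar,) = f_vjp(jnp.ones_like(y))
print(x_bar)
```

With a cotangent of ones, x_bar should equal the elementwise derivative cos(x) * x + sin(x).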

JAX v0.8.1

18 Nov 18:45

  • New features:

    • jax.jit now supports the decorator factory pattern; i.e., instead of
      writing

        @functools.partial(jax.jit, static_argnames=['n'])
        def f(x, n):
          ...

      you may write

        @jax.jit(static_argnames=['n'])
        def f(x, n):
          ...
  • Changes:

    • jax.lax.linalg.eigh now accepts an implementation argument to
      select between QR (CPU/GPU), Jacobi (GPU/TPU), and QDWH (TPU)
      implementations. The EighImplementation enum is publicly exported from
      jax.lax.linalg.

    • jax.lax.linalg.svd now implements an algorithm that uses the polar
      decomposition on CUDA GPUs. On TPUs, this algorithm choice is an alias
      for the existing algorithm.

  • Bug fixes:

    • Fixed a bug introduced in JAX 0.7.2 where eigh failed for large matrices on
      GPU (#33062).
  • Deprecations:

    • jax.sharding.PmapSharding is now deprecated. Please use
      jax.NamedSharding instead.
    • jax.device_put_replicated is now deprecated. Please use jax.device_put
      with the appropriate sharding instead.
    • jax.device_put_sharded is now deprecated. Please use jax.device_put with
      the appropriate sharding instead.
    • The default axis_types of jax.make_mesh will change to
      jax.sharding.AxisType.Explicit in JAX v0.9.0. Leaving axis_types
      unspecified will raise a DeprecationWarning.
    • jax.cloud_tpu_init and its contents were deprecated. There is no reason
      for a user to import or use the contents of this module; JAX handles
      this for you automatically if needed.
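A sketch of jax.device_put with explicit shardings in place of the deprecated device_put_replicated and device_put_sharded. It assumes a single-process run and a 1-D mesh over all devices; the axis name "x" and the array contents are arbitrary. The old functions returned an array with a leading axis of length len(devices), so the second and third calls reproduce that layout only approximately. They do not recreate the old PmapSharding.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 1-D mesh over every device; the axis name "x" is an arbitrary choice.
devices = np.array(jax.devices())
mesh = Mesh(devices, ("x",))
n = devices.size

x = jnp.arange(4.0)

# Every device holds a full copy of x (an empty PartitionSpec means replicated).
replicated = jax.device_put(x, NamedSharding(mesh, P()))

# Roughly the layout device_put_replicated(x, devices) produced: a new leading
# axis of length n, with one slice per device.
stacked = jax.device_put(
    jnp.broadcast_to(x, (n,) + x.shape), NamedSharding(mesh, P("x"))
)

# Roughly what device_put_sharded(shards, devices) produced: one caller-provided
# shard per device, stacked and split along the leading axis.
shards = [jnp.full((2,), float(i)) for i in range(n)]
sharded = jax.device_put(jnp.stack(shards), NamedSharding(mesh, P("x")))
```

If the data only needs to be available on every device for use inside jax.jit, the replicated form is usually all that is required.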

JAX v0.8.0

15 Oct 23:38

  • Breaking changes:

    • JAX is changing the default jax.pmap implementation to one implemented in
      terms of jax.jit and jax.shard_map. jax.pmap is in maintenance mode
      and we encourage all new code to use jax.shard_map directly. See the
      migration guide for
      more information.
    • The auto= parameter of jax.experimental.shard_map.shard_map has been
      removed. This means that jax.experimental.shard_map.shard_map no longer
      supports nesting. If you want to nest shard_map calls, please use
      jax.shard_map.
    • JAX no longer allows passing objects that support __jax_array__ directly
      to, e.g., jitted functions. Call jax.numpy.asarray on them first.
    • jax.numpy.cov now returns NaN for empty arrays (#32305), and matches
      NumPy 2.2 behavior for single-row design matrices (#32308).
    • JAX no longer accepts Array values where a dtype value is expected. Call
      .dtype on these values first.
    • The deprecated function jax.interpreters.mlir.custom_call was
      removed.
    • The jax.util, jax.extend.ffi, and jax.experimental.host_callback
      modules have been removed. All public APIs within these modules were
      deprecated and removed in v0.7.0 or earlier.
    • The deprecated symbol jax.custom_derivatives.custom_jvp_call_jaxpr_p
      was removed.
    • jax.experimental.multihost_utils.process_allgather raises an error when
      the input is a jax.Array and not fully-addressable and tiled=False. To fix
      this, pass tiled=True to your process_allgather invocation.
    • The deprecated symbols is_initialized and initialize_cache were removed
      from jax.experimental.compilation_cache.
    • The deprecated function jax.interpreters.xla.canonicalize_dtype
      was removed.
    • jaxlib.hlo_helpers has been removed. Use jax.ffi instead.
    • The option jax_cpu_enable_gloo_collectives has been removed. Use
      jax_cpu_collectives_implementation instead.
    • The previously-deprecated interpolation argument to
      jax.numpy.percentile and jax.numpy.quantile has been
      removed; use method instead.
    • The JAX-internal for_loop primitive was removed. Its functionality,
      reading from and writing to refs in the loop body, is now directly
      supported by jax.lax.fori_loop. If you need help updating your
      code, please file a bug.
    • jax.numpy.trimzeros now errors for non-1D input.
    • The where argument to jax.numpy.sum and other reductions is now
      required to be boolean. Non-boolean values have resulted in a
      DeprecationWarning since JAX v0.5.0.
    • The deprecated functions in jax.dlpack, jax.errors,
      jax.lib.xla_bridge, jax.lib.xla_client, and
      jax.lib.xla_extension were removed.
    • jax.interpreters.mlir.dense_bool_array was removed. Use MLIR APIs to
      construct attributes instead.
  • Changes

    • jax.numpy.linalg.eig now returns a namedtuple (with attributes
      eigenvalues and eigenvectors) instead of a plain tuple.
    • jax.grad and jax.vjp will now always round primals to float32 if
      float64 mode is not enabled.
    • jax.dlpack.from_dlpack now accepts arrays with non-default layouts,
      for example, transposed.
    • The default nonsymmetric eigendecomposition on NVIDIA GPUs now uses
      cusolver. The magma and LAPACK implementations are still available via the
      new implementation argument to jax.lax.linalg.eig
      (#27265). The use_magma argument is now deprecated in favor
      of implementation.
    • jax.numpy.trim_zeros now follows NumPy 2.2 in supporting
      multi-dimensional inputs.
  • Deprecations

    • jax.experimental.enable_x64 and jax.experimental.disable_x64
      are deprecated in favor of the new non-experimental context manager
      jax.enable_x64.
    • jax.experimental.shard_map.shard_map is deprecated; going forward use
      jax.shard_map.
    • jax.experimental.pjit.pjit is deprecated; going forward use
      jax.jit.
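A combined migration sketch for several v0.8.0 items above. It covers converting a __jax_array__ object explicitly before a jitted call, using method= with percentile and quantile, passing a boolean where= mask to a reduction, and calling jax.jit with shardings instead of pjit. It assumes a single-process run. Wrapper, double, and the data values are hypothetical illustrations, not JAX APIs.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


class Wrapper:
    """Hypothetical user type that exposes its data through __jax_array__."""

    def __init__(self, value):
        self.value = value

    def __jax_array__(self):
        return jnp.asarray(self.value)


@jax.jit
def double(x):
    return 2 * x


# __jax_array__ objects: convert explicitly before calling a jitted function.
doubled = double(jnp.asarray(Wrapper([1.0, 2.0])))

# percentile/quantile: pass method= where interpolation= was used before.
data = jnp.array([1.0, 2.0, 3.0, 4.0])
median = jnp.percentile(data, 50, method="linear")
lower_q = jnp.quantile(data, 0.5, method="lower")

# Reductions: the where= mask must be boolean, so cast 0/1 masks first.
mask01 = jnp.array([0, 1, 1, 0])
masked_sum = jnp.sum(data, where=mask01.astype(bool))

# pjit replacement: jax.jit accepts in_shardings/out_shardings directly.
mesh = Mesh(np.array(jax.devices()), ("x",))
replicated = NamedSharding(mesh, P())
scale = jax.jit(lambda v: v * 3, in_shardings=replicated, out_shardings=replicated)
scaled = scale(data)
```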

JAX v0.7.2

16 Sep 17:19

  • Breaking changes:

    • jax.dlpack.from_dlpack no longer accepts a DLPack capsule. This
      behavior was deprecated and is now removed. The function must be called
      with an array implementing __dlpack__ and __dlpack_device__.
  • Changes

    • The minimum supported NumPy version is now 2.0. Since SciPy 1.13 is required
      for NumPy 2.0 support, the minimum supported SciPy version is now 1.13.

    • JAX now represents constants in its internal jaxpr representation as a
      LiteralArray, which is a private JAX type that duck types as a
      numpy.ndarray. This type may be exposed to users via custom_jvp rules,
      for example, and may break code that uses isinstance(x, np.ndarray). If
      this breaks your code, you may convert these arrays to classic NumPy arrays
      using np.asarray(x).

  • Bug fixes

    • arr.view(dtype=None) now returns the array unchanged, matching NumPy's
      semantics. Previously it returned the array with a float dtype.
    • jax.random.randint now produces a less-biased distribution for 8-bit and
      16-bit integer types (#27742). To restore the previous biased
      behavior, you may temporarily set the jax_safer_randint configuration to
      False, but note this is a temporary config that will be removed in a
      future release.
  • Deprecations:

    • The parameters enable_xla and native_serialization for jax2tf.convert
      are deprecated and will be removed in a future version of JAX. These were
      used for jax2tf with non-native serialization, which has now been removed.
    • Setting the config state jax_pmap_no_rank_reduction to False is
      deprecated. By default, jax_pmap_no_rank_reduction will be set to True
      and jax.pmap shards will not have their rank reduced, keeping the same
      rank as their enclosing array.
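A minimal sketch of the call pattern jax.dlpack.from_dlpack expects after the capsule path was removed. The producer array itself is passed, rather than a capsule. It assumes NumPy 2.x as the producer, since NumPy arrays implement __dlpack__ and __dlpack_device__. The array contents are arbitrary.

```python
import numpy as np
import jax
import jax.dlpack

host = np.arange(6.0, dtype=np.float32)

# Pass the producer array (which implements __dlpack__ and __dlpack_device__),
# not a capsule obtained from host.__dlpack__().
device_array = jax.dlpack.from_dlpack(host)
print(device_array)
```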

JAX v0.7.1

20 Aug 16:04

  • New features

    • JAX now ships Python 3.14 and 3.14t wheels.
    • JAX now ships Python 3.13t and 3.14t wheels on Mac. Previously we only
      offered free-threading builds on Linux.
  • Changes

    • Exposed jax.set_mesh which acts as a global setter and a context manager.
      Removed jax.sharding.use_mesh in favor of jax.set_mesh.
    • JAX is now built using CUDA 12.9. All versions of CUDA 12.1 or newer remain
      supported.
    • jax.lax.dot now implements the general dot product via the optional
      dimension_numbers argument.
  • Deprecations:

    • jax.lax.zeros_like_array is deprecated. Please use
      jax.numpy.zeros_like instead.
    • Attempting to import jax.experimental.host_callback now results in
      a DeprecationWarning, and will result in an ImportError starting in JAX
      v0.8.0. Its APIs have raised NotImplementedError since JAX version 0.4.35.
    • In jax.lax.dot, passing the precision and preferred_element_type
      arguments by position is deprecated. Pass them by explicit keyword instead.
    • Several dozen internal APIs have been deprecated from jax.interpreters.ad,
      jax.interpreters.batching, and jax.interpreters.partial_eval; they
      are used rarely if ever outside JAX itself, and most are deprecated without any
      public replacement.
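A sketch of the keyword-argument style for jax.lax.dot that the deprecation above asks for, plus jnp.zeros_like as the replacement for jax.lax.zeros_like_array. The shapes, dtypes, and precision choice are arbitrary illustrations.

```python
import jax
import jax.numpy as jnp

a = jnp.ones((2, 3), dtype=jnp.bfloat16)
b = jnp.ones((3, 4), dtype=jnp.bfloat16)

# Pass precision and preferred_element_type by keyword, not by position.
c = jax.lax.dot(
    a,
    b,
    precision=jax.lax.Precision.HIGHEST,
    preferred_element_type=jnp.float32,
)

# Replacement for the deprecated jax.lax.zeros_like_array.
z = jnp.zeros_like(c)
```

Each entry of c is a sum of three products of ones, so it should be 3.0 in float32.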