Releases: jax-ml/jax
JAX v0.9.2
JAX 0.9.2 (March 2, 2026)
Changes:
- The semi-private type jax._src.literals.TypedNdArray is now a subclass of np.ndarray, rather than a duck type of it.
- jax.numpy.arange with step specified no longer generates the array on host. The benefit is more efficient code, though this can lead to less precise outputs for narrow-width floats (e.g. bfloat16). To recover the previous behavior in this case, use jnp.array(np.arange(...)).
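The arange workaround can be sketched as follows, assuming a bfloat16 target. The start, stop, and step values are arbitrary, and passing dtype=jnp.bfloat16 to jnp.array is this sketch's choice. NumPy builds the second array on host in float64, and it is rounded to bfloat16 only once, when it is converted.

```python
import jax.numpy as jnp
import numpy as np

start, stop, step = 0.0, 1.0, 0.1

# Default path since v0.9.2: JAX generates the sequence itself, not on host.
device_values = jnp.arange(start, stop, step, dtype=jnp.bfloat16)

# Workaround: NumPy builds the sequence on host in float64, and the
# values are rounded to bfloat16 once, when converted to a JAX array.
host_values = jnp.array(np.arange(start, stop, step), dtype=jnp.bfloat16)
```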
JAX v0.9.1
Changes:
- JAX tracers that are not of Array type (e.g., of Ref type) will no longer report themselves to be instances of Array.
- Using jax.shard_map in Explicit mode will raise an error if the PartitionSpec of the input does not match the PartitionSpec specified in in_specs. In other words, it will act like an assert instead of an implicit reshard. in_specs is an optional argument, so you can omit it and shard_map will infer the PartitionSpec from the argument. If you want to reshard your inputs, you can use jax.reshard on the arguments and then pass those arguments to shard_map.
New features:
- Added a debug config jax_compilation_cache_check_contents. If set, a get() call misses for any value that has not been put() by the current process, even if the value is actually in the disk cache. When a value is put(), we verify that its contents match.
JAX v0.9.0.1
JAX v0.9.0.1 is identical to v0.9.0 with the commits from the following four PRs patched in:
JAX v0.8.3
JAX v0.8.3 is identical to v0.8.2 with the following two bug fixes patched in:
JAX v0.9.0
New features:
- Added jax.thread_guard, a context manager that detects when devices are used by multiple threads in multi-controller JAX.
Bug fixes:
- Fixed a workspace size calculation error for pivoted QR (magma_zgeqp3_gpu) in MAGMA 2.9.0 when using use_magma=True and pivoting=True (#34145).
Deprecations:
- The flag jax_collectives_common_channel_id was removed.
- The jax_pmap_no_rank_reduction config state has been removed. The no-rank-reduction behavior is now the only supported behavior: a function f wrapped with jax.pmap sees inputs of the same rank as the input to jax.pmap(f). For example, if jax.pmap(f) receives shape (8, 128) on 8 devices, then f receives shape (1, 128).
- Setting the jax_pmap_shmap_merge config state is deprecated in JAX v0.9.0 and will be removed in JAX v0.10.0.
- jax.numpy.fix is deprecated, anticipating the deprecation of numpy.fix in NumPy v2.5.0. jax.numpy.trunc is a drop-in replacement.
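Here is a minimal sketch of the jax.numpy.fix replacement. jnp.trunc rounds toward zero, which is the operation fix performed. The input values are arbitrary, and NumPy's trunc serves only as a reference.

```python
import jax.numpy as jnp
import numpy as np

x = jnp.array([-2.7, -0.5, 0.5, 2.7])

# Round toward zero: the drop-in replacement for jnp.fix.
truncated = jnp.trunc(x)

# NumPy reference computed on the same values.
reference = np.trunc(np.asarray(x))
```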
Changes:
- jax.export now supports explicit sharding. This required a new export serialization format version that includes the NamedSharding, including the abstract mesh and the partition spec. As part of this change, we have added a restriction on the use of exported modules: when calling them, the abstract mesh must match the one used at export time, including the axis names. Previously, only the number of devices mattered.
JAX v0.8.2
Deprecations:
- jax.lax.pvary has been deprecated. Please use jax.lax.pcast(..., to='varying') as the replacement.
- Complex arguments passed to jax.numpy.arange now result in a deprecation warning, because the output is poorly defined.
- From jax.core, a number of symbols are newly deprecated, including call_impl, get_aval, mapped_aval, subjaxprs, set_current_trace, take_current_trace, traverse_jaxpr_params, unmapped_aval, AbstractToken, and TraceTag.
- All symbols in jax.interpreters.pxla are deprecated. These are primarily JAX internal APIs, and users should not rely on them.
Changes:
- jax's Tracer no longer inherits from jax.Array at runtime. However, jax.Array now uses a custom metaclass such that isinstance(x, Array) is true if an object x represents a traced Array. Only some Tracers represent Arrays, so it is not correct for Tracer to inherit from Array. For the moment, during Python type checking, we continue to declare Tracer as a subclass of Array; however, we expect to remove this in a future release.
- jax.experimental.si_vjp has been deleted. jax.vjp subsumes its functionality.
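This small check illustrates the isinstance behavior described above. At trace time, the function records whether its traced argument passes isinstance(x, jax.Array). The list named observed is a hypothetical helper for capturing that result, because the check runs only while jit traces the function.

```python
import jax
import jax.numpy as jnp

observed = []


@jax.jit
def f(x):
    # Under jit, x is a tracer; it represents a traced array, so the
    # metaclass-based check still reports it as a jax.Array.
    observed.append(isinstance(x, jax.Array))
    return x + 1


out = f(jnp.arange(3))
```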
JAX v0.8.1
New features:
- jax.jit now supports the decorator factory pattern; i.e., instead of writing

      @functools.partial(jax.jit, static_argnames=['n'])
      def f(x, n): ...

  you may write

      @jax.jit(static_argnames=['n'])
      def f(x, n): ...
Changes:
- jax.lax.linalg.eigh now accepts an implementation argument to select between QR (CPU/GPU), Jacobi (GPU/TPU), and QDWH (TPU) implementations. The EighImplementation enum is publicly exported from jax.lax.linalg.
- jax.lax.linalg.svd now implements an algorithm that uses the polar decomposition on CUDA GPUs. This is also an alias for the existing algorithm on TPUs.
Bug fixes:
- Fixed a bug introduced in JAX 0.7.2 where eigh failed for large matrices on
GPU (#33062).
Deprecations:
- jax.sharding.PmapSharding is now deprecated. Please use jax.NamedSharding instead.
- jax.device_put_replicated is now deprecated. Please use jax.device_put with the appropriate sharding instead.
- jax.device_put_sharded is now deprecated. Please use jax.device_put with the appropriate sharding instead.
- The default axis_types of jax.make_mesh will change in JAX v0.9.0 to jax.sharding.AxisType.Explicit. Leaving axis_types unspecified will raise a DeprecationWarning.
- jax.cloud_tpu_init and its contents were deprecated. There is no reason for a user to import or use the contents of this module; JAX handles this for you automatically if needed.
JAX v0.8.0
Breaking changes:
- JAX is changing the default jax.pmap implementation to one implemented in terms of jax.jit and jax.shard_map. jax.pmap is in maintenance mode, and we encourage all new code to use jax.shard_map directly. See the migration guide for more information.
- The auto= parameter of jax.experimental.shard_map.shard_map has been removed. This means that jax.experimental.shard_map.shard_map no longer supports nesting. If you want to nest shard_map calls, please use jax.shard_map.
- JAX no longer allows passing objects that support __jax_array__ directly to, e.g., jit-ed functions. Call jax.numpy.asarray on them first.
- jax.numpy.cov now returns NaN for empty arrays (#32305), and matches NumPy 2.2 behavior for single-row design matrices (#32308).
- JAX no longer accepts Array values where a dtype value is expected. Call .dtype on these values first.
- The deprecated function jax.interpreters.mlir.custom_call was removed.
- The jax.util, jax.extend.ffi, and jax.experimental.host_callback modules have been removed. All public APIs within these modules were deprecated and removed in v0.7.0 or earlier.
- The deprecated symbol jax.custom_derivatives.custom_jvp_call_jaxpr_p was removed.
- jax.experimental.multihost_utils.process_allgather raises an error when the input is a jax.Array that is not fully addressable and tiled=False. To fix this, pass tiled=True to your process_allgather invocation.
- From jax.experimental.compilation_cache, the deprecated symbols is_initialized and initialize_cache were removed.
- The deprecated function jax.interpreters.xla.canonicalize_dtype was removed.
- jaxlib.hlo_helpers has been removed. Use jax.ffi instead.
- The option jax_cpu_enable_gloo_collectives has been removed. Use jax_cpu_collectives_implementation instead.
- The previously deprecated interpolation argument to jax.numpy.percentile and jax.numpy.quantile has been removed; use method instead.
- The JAX-internal for_loop primitive was removed. Its functionality, reading from and writing to refs in the loop body, is now directly supported by jax.lax.fori_loop. If you need help updating your code, please file a bug.
- jax.numpy.trim_zeros now errors for non-1D input.
- The where argument to jax.numpy.sum and other reductions is now required to be boolean. Non-boolean values have resulted in a DeprecationWarning since JAX v0.5.0.
- The deprecated functions in jax.dlpack, jax.errors, jax.lib.xla_bridge, jax.lib.xla_client, and jax.lib.xla_extension were removed.
- jax.interpreters.mlir.dense_bool_array was removed. Use MLIR APIs to construct attributes instead.
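Several of the removals above have one-line migrations. The sketch below assumes JAX v0.8.0 or later and shows four of them: explicit jnp.asarray for __jax_array__ objects, method= for percentile, a boolean where= mask, and .dtype where a dtype is expected. Wrapper is a hypothetical user class, and the array values are arbitrary.

```python
import jax
import jax.numpy as jnp
import numpy as np


class Wrapper:
    """Hypothetical user type that exposes __jax_array__."""

    def __init__(self, value):
        self.value = value

    def __jax_array__(self):
        return jnp.asarray(self.value)


@jax.jit
def total(x):
    return x.sum()


w = Wrapper(np.arange(5.0))
# Convert explicitly instead of passing `w` straight into the jitted function.
s = total(jnp.asarray(w))

x = jnp.arange(10.0)
# `method=` replaces the removed `interpolation=` argument.
med = jnp.percentile(x, 50, method="linear")
# `where=` must now be a boolean mask.
pos_sum = jnp.sum(x, where=x > 4)
# Pass a dtype (x.dtype), not the array itself, where a dtype is expected.
z = jnp.zeros(3, dtype=x.dtype)
```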
Changes:
- jax.numpy.linalg.eig now returns a namedtuple (with attributes eigenvalues and eigenvectors) instead of a plain tuple.
- jax.grad and jax.vjp will now always round primals to float32 if float64 mode is not enabled.
- jax.dlpack.from_dlpack now accepts arrays with non-default layouts, for example, transposed arrays.
- The default nonsymmetric eigendecomposition on NVIDIA GPUs now uses cusolver. The magma and LAPACK implementations are still available via the new implementation argument to jax.lax.linalg.eig (#27265). The use_magma argument is now deprecated in favor of implementation.
- jax.numpy.trim_zeros now follows NumPy 2.2 in supporting multi-dimensional inputs.
Deprecations:
- jax.experimental.enable_x64 and jax.experimental.disable_x64 are deprecated in favor of the new non-experimental context manager jax.enable_x64.
- jax.experimental.shard_map.shard_map is deprecated; going forward, use jax.shard_map.
- jax.experimental.pjit.pjit is deprecated; going forward, use jax.jit.
JAX v0.7.2
Breaking changes:
- jax.dlpack.from_dlpack no longer accepts a DLPack capsule. This behavior was deprecated and is now removed. The function must be called with an array implementing __dlpack__ and __dlpack_device__.
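Here is a sketch of the array-based call. It assumes NumPy 2.0 or later (the minimum listed in the Changes for this release), whose arrays implement __dlpack__ and __dlpack_device__. The array contents are arbitrary.

```python
import jax
import jax.dlpack
import numpy as np

host = np.arange(6, dtype=np.float32).reshape(2, 3)

# Pass the array object itself; a raw DLPack capsule is no longer accepted.
arr = jax.dlpack.from_dlpack(host)
```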
Changes:
- The minimum supported NumPy version is now 2.0. Since SciPy 1.13 is required for NumPy 2.0 support, the minimum supported SciPy version is now 1.13.
- JAX now represents constants in its internal jaxpr representation as a LiteralArray, which is a private JAX type that duck types as a numpy.ndarray. This type may be exposed to users via custom_jvp rules, for example, and may break code that uses isinstance(x, np.ndarray). If this breaks your code, you may convert these arrays to classic NumPy arrays using np.asarray(x).
Bug fixes:
- arr.view(dtype=None) now returns the array unchanged, matching NumPy's semantics. Previously it returned the array with a float dtype.
- jax.random.randint now produces a less-biased distribution for 8-bit and 16-bit integer types (#27742). To restore the previous biased behavior, you may temporarily set the jax_safer_randint configuration to False, but note this is a temporary config that will be removed in a future release.
Deprecations:
- The parameters enable_xla and native_serialization for jax2tf.convert are deprecated and will be removed in a future version of JAX. These were used for jax2tf with non-native serialization, which has now been removed.
- Setting the config state jax_pmap_no_rank_reduction to False is deprecated. By default, jax_pmap_no_rank_reduction will be set to True and jax.pmap shards will not have their rank reduced, keeping the same rank as their enclosing array.
JAX v0.7.1
New features:
- JAX now ships Python 3.14 and 3.14t wheels.
- JAX now ships Python 3.13t and 3.14t wheels on Mac. Previously we only
offered free-threading builds on Linux.
Changes:
- Exposed jax.set_mesh, which acts as a global setter and a context manager. Removed jax.sharding.use_mesh in favor of jax.set_mesh.
- JAX is now built using CUDA 12.9. All versions of CUDA 12.1 or newer remain supported.
- jax.lax.dot now implements the general dot product via the optional dimension_numbers argument.
Deprecations:
- jax.lax.zeros_like_array is deprecated. Please use jax.numpy.zeros_like instead.
- Attempting to import jax.experimental.host_callback now results in a DeprecationWarning, and will result in an ImportError starting in JAX v0.8.0. Its APIs have raised NotImplementedError since JAX version 0.4.35.
- In jax.lax.dot, passing the precision and preferred_element_type arguments by position is deprecated. Pass them by explicit keyword instead.
- Several dozen internal APIs have been deprecated from jax.interpreters.ad, jax.interpreters.batching, and jax.interpreters.partial_eval; they are used rarely if ever outside JAX itself, and most are deprecated without any public replacement.