JAX v0.8.0

Breaking changes:

- JAX is changing the default `jax.pmap` implementation to one implemented in terms of `jax.jit` and `jax.shard_map`. `jax.pmap` is in maintenance mode, and we encourage all new code to use `jax.shard_map` directly. See the migration guide for more information.
- The `auto=` parameter of `jax.experimental.shard_map.shard_map` has been removed. This means that `jax.experimental.shard_map.shard_map` no longer supports nesting. If you want to nest `shard_map` calls, please use `jax.shard_map`.
- JAX no longer allows passing objects that support `__jax_array__` directly to, e.g., `jit`-ed functions. Call `jax.numpy.asarray` on them first.
- `jax.numpy.cov` now returns NaN for empty arrays ({jax-issue}`#32305`) and matches NumPy 2.2 behavior for single-row design matrices ({jax-issue}`#32308`).
- JAX no longer accepts `Array` values where a `dtype` value is expected. Call `.dtype` on these values first.
- The deprecated function `jax.interpreters.mlir.custom_call` was removed.
- The `jax.util`, `jax.extend.ffi`, and `jax.experimental.host_callback` modules have been removed. All public APIs within these modules were deprecated and removed in v0.7.0 or earlier.
- The deprecated symbol `jax.custom_derivatives.custom_jvp_call_jaxpr_p` was removed.
- `jax.experimental.multihost_utils.process_allgather` raises an error when the input is a `jax.Array` that is not fully addressable and `tiled=False`. To fix this, pass `tiled=True` to your `process_allgather` invocation.
- The deprecated symbols `is_initialized` and `initialize_cache` were removed from `jax.experimental.compilation_cache`.
- The deprecated function `jax.interpreters.xla.canonicalize_dtype` was removed.
- `jaxlib.hlo_helpers` has been removed. Use `jax.ffi` instead.
- The option `jax_cpu_enable_gloo_collectives` has been removed. Use `jax_cpu_collectives_implementation` instead.
- The previously deprecated `interpolation` argument to `jax.numpy.percentile` and `jax.numpy.quantile` has been removed; use `method` instead.
- The JAX-internal `for_loop` primitive was removed. Its functionality, reading from and writing to refs in the loop body, is now directly supported by `jax.lax.fori_loop`. If you need help updating your code, please file a bug.
- `jax.numpy.trimzeros` now errors for non-1D input.
- The `where` argument to `jax.numpy.sum` and other reductions is now required to be boolean. Non-boolean values have resulted in a `DeprecationWarning` since JAX v0.5.0.
- The deprecated functions in `jax.dlpack`, `jax.errors`, `jax.lib.xla_bridge`, `jax.lib.xla_client`, and `jax.lib.xla_extension` were removed.
- `jax.interpreters.mlir.dense_bool_array` was removed. Use MLIR APIs to construct attributes instead.
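
Several of the breaking changes above need only a one-line fix at the call site. The sketch below collects four of them: converting a `__jax_array__` object with `jax.numpy.asarray`, passing `.dtype` instead of an array, using `method` in place of `interpolation`, and casting a `where` mask to boolean. The `Meters` class and the `double` function are hypothetical stand-ins chosen for illustration.

```python
import jax
import jax.numpy as jnp


class Meters:
    """Hypothetical wrapper type that implements the __jax_array__ protocol."""

    def __init__(self, value):
        self.value = value

    def __jax_array__(self):
        return jnp.asarray(self.value)


@jax.jit
def double(x):
    return 2 * x


# 1. Convert __jax_array__ objects explicitly before calling a jit-ed function.
doubled = double(jnp.asarray(Meters([1.0, 2.0])))

# 2. Pass a dtype, not an array, where a dtype is expected.
template = jnp.arange(3, dtype=jnp.int32)
zeros = jnp.zeros(4, dtype=template.dtype)  # not dtype=template

# 3. `interpolation=` is gone from percentile/quantile; use `method=`.
data = jnp.array([1.0, 2.0, 3.0, 4.0])
median_lower = jnp.percentile(data, 50, method="lower")
median_linear = jnp.quantile(data, 0.5, method="linear")

# 4. Masks passed as `where` to reductions must be boolean.
values = jnp.array([-1.0, 2.0, 3.0])
int_mask = jnp.array([0, 1, 1])
masked_total = jnp.sum(values, where=int_mask.astype(bool))

print(doubled, zeros.dtype, median_lower, median_linear, masked_total)
```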

Changes:

- `jax.numpy.linalg.eig` now returns a namedtuple (with attributes `eigenvalues` and `eigenvectors`) instead of a plain tuple.
- `jax.grad` and `jax.vjp` now always round primals to `float32` if `float64` mode is not enabled.
- `jax.dlpack.from_dlpack` now accepts arrays with non-default layouts, for example, transposed arrays.
- The default nonsymmetric eigendecomposition on NVIDIA GPUs now uses cuSOLVER. The MAGMA and LAPACK implementations are still available via the new `implementation` argument to `jax.lax.linalg.eig` ({jax-issue}`#27265`). The `use_magma` argument is now deprecated in favor of `implementation`.
- `jax.numpy.trim_zeros` now follows NumPy 2.2 in supporting multi-dimensional inputs.

Deprecations:

- `jax.experimental.enable_x64` and `jax.experimental.disable_x64` are deprecated in favor of the new non-experimental context manager `jax.enable_x64`.
- `jax.experimental.shard_map.shard_map` is deprecated; going forward, use `jax.shard_map`.
- `jax.experimental.pjit.pjit` is deprecated; going forward, use `jax.jit`.