JAX v0.9.0
New features:
- Added jax.thread_guard, a context manager that detects when devices are used by multiple threads in multi-controller JAX.
Bug fixes:
- Fixed a workspace size calculation error for pivoted QR (magma_zgeqp3_gpu) in MAGMA 2.9.0 when using use_magma=True and pivoting=True (#34145).
Deprecations:
- The flag jax_collectives_common_channel_id was removed.
- The jax_pmap_no_rank_reduction config state has been removed. The no-rank-reduction behavior is now the only supported behavior: a function f wrapped with jax.pmap sees inputs of the same rank as the input to jax.pmap(f). For example, if jax.pmap(f) receives shape (8, 128) on 8 devices, then f receives shape (1, 128).
- Setting the jax_pmap_shmap_merge config state is deprecated in JAX v0.9.0 and will be removed in JAX v0.10.0.
- jax.numpy.fix is deprecated, anticipating the deprecation of numpy.fix in NumPy v2.5.0. jax.numpy.trunc is a drop-in replacement.
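Migrating away from the deprecated jax.numpy.fix is a one-name change, since both functions round toward zero. A minimal sketch with illustrative inputs:

```python
import jax.numpy as jnp

x = jnp.array([-2.7, -0.5, 0.5, 2.7])

# Before (deprecated): y = jnp.fix(x)
# After: jnp.trunc also rounds toward zero, unlike jnp.floor for negatives.
y = jnp.trunc(x)

print(y)  # values -2, -0, 0, 2
```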
Changes:
- jax.export now supports explicit sharding. This required a new export serialization format version that stores the NamedSharding, including its abstract mesh and partition spec. As part of this change, exported modules carry a new restriction: when calling them, the abstract mesh must match the one used at export time, including the axis names. Previously, only the number of devices mattered.