JAX v0.7.1
New features
- JAX now ships Python 3.14 and 3.14t wheels.
- JAX now ships Python 3.13t and 3.14t wheels on Mac. Previously we only offered free-threading builds on Linux.
Changes
- Exposed jax.set_mesh, which acts as both a global setter and a context manager. Removed jax.sharding.use_mesh in favor of jax.set_mesh.
- JAX is now built using CUDA 12.9. All versions of CUDA 12.1 or newer remain supported.
- jax.lax.dot now implements the general dot product via the optional dimension_numbers argument.
Deprecations:
- jax.lax.zeros_like_array is deprecated. Please use jax.numpy.zeros_like instead.
- Attempting to import jax.experimental.host_callback now results in a DeprecationWarning, and will result in an ImportError starting in JAX v0.8.0. Its APIs have raised NotImplementedError since JAX version 0.4.35.
- In jax.lax.dot, passing the precision and preferred_element_type arguments by position is deprecated. Pass them by explicit keyword instead.
- Several dozen internal APIs have been deprecated from jax.interpreters.ad, jax.interpreters.batching, and jax.interpreters.partial_eval; they are used rarely, if ever, outside JAX itself, and most are deprecated without any public replacement.
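A minimal migration sketch for two of the deprecations above, assuming JAX v0.7.1 or newer: jax.numpy.zeros_like stands in for jax.lax.zeros_like_array, and precision and preferred_element_type are passed to jax.lax.dot by keyword. The input arrays and dtypes are illustrative choices, not taken from the release notes.

```python
# Assumes JAX >= 0.7.1; shows keyword-style replacements for deprecated usages.
import jax.numpy as jnp
from jax import lax

x = jnp.ones((3, 4), dtype=jnp.float32)

# Deprecated: lax.zeros_like_array(x)
z = jnp.zeros_like(x)  # same shape and dtype as x, filled with zeros

a = jnp.arange(6, dtype=jnp.int8).reshape(2, 3)
b = jnp.arange(12, dtype=jnp.int8).reshape(3, 4)

# Deprecated: lax.dot(a, b, None, jnp.int32)  (positional precision/dtype)
c = lax.dot(a, b, preferred_element_type=jnp.int32)  # accumulate int8 inputs as int32

p = jnp.array([[1.0, 2.0], [3.0, 4.0]], dtype=jnp.float32)
# Deprecated: lax.dot(p, p, lax.Precision.HIGHEST)
d = lax.dot(p, p, precision=lax.Precision.HIGHEST)
print(c)
```

Passing these arguments by keyword also keeps call sites unambiguous now that jax.lax.dot accepts the additional dimension_numbers argument.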