JAX v0.7.0
- New features:
  - Added `jax.P`, which is an alias for `jax.sharding.PartitionSpec`.
  - Added `jax.tree.reduce_associative`.
  - Added
- Breaking changes:
  - JAX is migrating from GSPMD to Shardy by default. See the migration guide for more information.
  - JAX autodiff is switching to direct linearization by default (instead of implementing linearization via JVP and partial eval). See the migration guide for more information.
  - `jax.stages.OutInfo` has been replaced with `jax.ShapeDtypeStruct`.
  - `jax.jit` now requires `fun` to be passed by position, and additional arguments to be passed by keyword. Doing otherwise will result in an error starting in v0.7.x. This raised a `DeprecationWarning` in v0.6.x.
  - The minimum Python version is now 3.11. Python 3.11 will remain the minimum supported version until July 2026.
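The sketch below illustrates the calling convention that `jax.jit` now enforces: the function comes first, by position, and every option such as `static_argnames` is passed by keyword. The functions `scale` and `scale_deco` are hypothetical examples.

```python
import functools

import jax
import jax.numpy as jnp


def scale(x, factor):
    return x * factor


# Supported: `fun` by position, every option by keyword.
scale_jit = jax.jit(scale, static_argnames="factor")


# The decorator form follows the same rule via functools.partial.
@functools.partial(jax.jit, static_argnames="factor")
def scale_deco(x, factor):
    return x * factor


# Not supported (DeprecationWarning in v0.6.x, an error starting in v0.7.x):
#   jax.jit(fun=scale)   # `fun` passed by keyword
#   passing options such as static_argnums positionally after `fun`

x = jnp.arange(3.0)
print(scale_jit(x, factor=2.0))   # [0. 2. 4.]
print(scale_deco(x, factor=2.0))  # [0. 2. 4.]
```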
  - Layout API renames:
    - `Layout`, `.layout`, `.input_layouts` and `.output_layouts` have been renamed to `Format`, `.format`, `.input_formats` and `.output_formats`.
    - `DeviceLocalLayout` and `.device_local_layout` have been renamed to `Layout` and `.layout`.
  - The `jax.experimental.shard` module has been deleted, and all of its APIs have been moved to the `jax.sharding` endpoint. Use `jax.sharding.reshard`, `jax.sharding.auto_axes` and `jax.sharding.explicit_axes` instead of their experimental counterparts.
  - `lax.infeed` and `lax.outfeed` were removed, after being deprecated in JAX 0.6. The `transfer_to_infeed` and `transfer_from_outfeed` methods were also removed from the `Device` objects.
  - The `jax.extend.core.primitives.pjit_p` primitive has been renamed to `jit_p`, and its `name` attribute has changed from `"pjit"` to `"jit"`. This affects the string representations of jaxprs. The primitive is also no longer exported from the `jax.experimental.pjit` module.
  - The (undocumented) function `jax.extend.backend.add_clear_backends_callback` has been removed. Users should use `jax.extend.backend.register_backend_cache` instead.
- Deprecations:
  - `jax.dlpack.SUPPORTED_DTYPES` is deprecated; please use the new `jax.dlpack.is_supported_dtype` function.
  - `jax.scipy.special.sph_harm` has been deprecated following a similar deprecation in SciPy; use `jax.scipy.special.sph_harm_y` instead.
  - From `jax.interpreters.xla`, the previously deprecated symbols `abstractify` and `pytype_aval_mappings` have been removed.
  - `jax.interpreters.xla.canonicalize_dtype` is deprecated. For canonicalizing dtypes, prefer `jax.dtypes.canonicalize_dtype`. For checking whether an object is a valid JAX input, prefer `jax.core.valid_jaxtype`.
  - From `jax.core`, the previously deprecated symbols `AxisName`, `ConcretizationTypeError`, `axis_frame`, `call_p`, `closed_call_p`, `get_type`, `trace_state_clean`, `typematch`, and `typecheck` have been removed.
  - From `jax.lib.xla_client`, the previously deprecated symbols `DeviceAssignment`, `get_topology_for_devices`, and `mlir_api_version` have been removed.
  - `jax.extend.ffi` was removed after being deprecated in v0.5.0. Use `jax.ffi` instead.
  - `jax.lib.xla_bridge.get_compile_options` is deprecated and has been replaced by `jax.extend.backend.get_compile_options`.
- {obj}