JAX v0.8.2
- Deprecations
  - jax.lax.pvary has been deprecated. Please use jax.lax.pcast(..., to='varying') as the replacement.
  - Complex arguments passed to jax.numpy.arange now result in a deprecation warning, because the output is poorly defined.
  - From jax.core, a number of symbols are newly deprecated, including: call_impl, get_aval, mapped_aval, subjaxprs, set_current_trace, take_current_trace, traverse_jaxpr_params, unmapped_aval, AbstractToken, and TraceTag.
  - All symbols in jax.interpreters.pxla are deprecated. These are primarily JAX internal APIs, and users should not rely on them.
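For the jax.numpy.arange deprecation, one possible workaround is to build the range from real arguments and introduce the complex values afterwards. This is a sketch rather than an officially recommended migration. The variable names and the start and step values are illustrative assumptions.

```python
import jax.numpy as jnp

# Build the range from real arguments, then cast to a complex dtype.
z = jnp.arange(0, 4).astype(jnp.complex64)

# A complex start and step can be expressed by scaling a real index range,
# which makes the intended sequence explicit.
start, step = 1 + 2j, 0.5j  # illustrative values
w = start + step * jnp.arange(4)
```

Writing the sequence as start + step * k for integer k states which complex values are intended, instead of leaving that to arange.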
- Changes:
  - jax's Tracer no longer inherits from jax.Array at runtime. However, jax.Array now uses a custom metaclass such that isinstance(x, Array) is true if an object x represents a traced Array. Only some Tracers represent Arrays, so it is not correct for Tracer to inherit from Array.

    For the moment, during Python type checking, we continue to declare Tracer as a subclass of Array; however, we expect to remove this in a future release.
  - jax.experimental.si_vjp has been deleted. jax.vjp subsumes its functionality.
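The Tracer change can be illustrated with a small sketch that records isinstance results for a concrete array and for the value seen while tracing under jax.jit. The names describe, double, and observed are hypothetical. The sketch assumes Tracer is importable from jax.core.

```python
import jax
import jax.numpy as jnp
from jax.core import Tracer

observed = {}


def describe(x):
    # (is it an Array?, is it a Tracer?)
    return isinstance(x, jax.Array), isinstance(x, Tracer)


@jax.jit
def double(x):
    # Runs at trace time, so x here is a tracer standing in for an array.
    observed["traced"] = describe(x)
    return x * 2


x = jnp.ones(3)
observed["concrete"] = describe(x)
double(x)
```

Code that checks isinstance(x, jax.Array) should keep working for traced arrays. Code that depends on Tracer being a subclass of Array at runtime may need to change.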
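For reference, a basic use of jax.vjp looks like the sketch below. It shows only the standard jax.vjp call and does not reproduce the removed si_vjp interface. The function f and the input values are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.sin(x) * x


x = jnp.array([0.0, 1.0, 2.0])

# jax.vjp returns the primal output and a function that maps output
# cotangents to input cotangents.
y, f_vjp = jax.vjp(f, x)

# f_vjp returns one cotangent per positional input of f.
(x_bar,) = f_vjp(jnp.ones_like(y))
```

With a cotangent of ones, x_bar matches the gradient of the summed output of f with respect to x.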
-