v3.13.0
BREAKING changes
Starting with version 3.13.0, Keras requires Python 3.11 or higher. Make sure your environment runs Python 3.11+ before installing the latest version.
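Before upgrading, you can confirm that the interpreter meets the new minimum by comparing against sys.version_info. This is a minimal sketch. The helper name meets_keras_requirement is hypothetical and is not part of Keras.

```python
import sys

# Minimum Python version required by Keras 3.13.0 (per these release notes).
MIN_PYTHON = (3, 11)


def meets_keras_requirement(version_info=sys.version_info):
    """Return True if (major, minor) is at least 3.11.

    `meets_keras_requirement` is a hypothetical helper for illustration only.
    """
    return tuple(version_info[:2]) >= MIN_PYTHON


if __name__ == "__main__":
    if meets_keras_requirement():
        print("Python version OK for Keras 3.13.0")
    else:
        print("Upgrade to Python 3.11+ before installing Keras 3.13.0")
```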
Highlights
LiteRT Export
You can now export Keras models directly to the LiteRT format (formerly TensorFlow Lite) for on-device inference. This change also improves input signature handling and the export utility documentation. LiteRT export is only available when TensorFlow is installed. The export API and its documentation have been updated, and input signature inference has been improved for various model types.
Example:
```python
import keras
import numpy as np

# 1. Define a simple model
model = keras.Sequential([
    keras.layers.Input(shape=(10,)),
    keras.layers.Dense(10, activation="relu"),
    keras.layers.Dense(1, activation="sigmoid"),
])

# 2. Compile and train (optional, but recommended before export)
model.compile(optimizer="adam", loss="binary_crossentropy")
model.fit(np.random.rand(100, 10), np.random.randint(0, 2, 100), epochs=1)

# 3. Export the model to LiteRT format (requires TensorFlow to be installed)
model.export("my_model.tflite", format="litert")
print("Model exported successfully to 'my_model.tflite' using LiteRT format.")
```
GPTQ Quantization
- Introduced the keras.quantizers.QuantizationConfig API, which allows customizable weight and activation quantizers and provides greater flexibility in defining quantization schemes.
- Introduced a new filters argument to the Model.quantize method, allowing users to specify which layers should be quantized using regex strings, lists of regex strings, or a callable function. This provides fine-grained control over the quantization process.
- Refactored the GPTQ quantization process to remove heuristic-based model structure detection. Instead, the model's quantization structure can now be explicitly provided via GPTQConfig or by overriding a new Model.get_quantization_layer_structure method, enhancing flexibility and robustness for diverse model architectures.
- Core layers such as Dense, EinsumDense, Embedding, and ReversibleEmbedding have been updated to accept and use the new QuantizationConfig object, enabling fine-grained control over their quantization behavior.
- Added a new method get_quantization_layer_structure to the Model class, intended for model authors to define the topology required for structure-aware quantization modes like GPTQ.
- Introduced a new utility function should_quantize_layer to centralize the logic for determining whether a layer should be quantized based on the provided filters.
- Enabled serialization and deserialization of QuantizationConfig objects within Keras layers, allowing quantized models to be saved and loaded correctly.
- Modified the AbsMaxQuantizer to allow specifying the quantization axis dynamically in its __call__ method, rather than strictly defining it at initialization.
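The filters argument accepts a regex string, a list of regex strings, or a callable. The pure-Python sketch below shows one plausible way such a filter could be resolved against layers. It is not Keras's should_quantize_layer. The names matches_filter and FakeLayer are hypothetical. The following choices are assumptions: None selects every layer, patterns use re.search against the layer name, a list matches if any pattern matches, and a callable receives the layer object.

```python
import re
from dataclasses import dataclass
from typing import Callable, Sequence, Union


@dataclass
class FakeLayer:
    """Hypothetical stand-in for a Keras layer; only `name` is used here."""
    name: str


Filter = Union[None, str, Sequence[str], Callable[[FakeLayer], bool]]


def matches_filter(layer: FakeLayer, filters: Filter) -> bool:
    """Decide whether `layer` should be quantized under `filters` (sketch)."""
    if filters is None:
        return True  # assumption: no filter means quantize everything
    if callable(filters):
        return bool(filters(layer))  # assumption: callable gets the layer
    if isinstance(filters, str):
        filters = [filters]  # a single regex behaves like a one-item list
    # assumption: a layer is selected if any pattern matches its name
    return any(re.search(pattern, layer.name) for pattern in filters)


layers = [
    FakeLayer("dense"),
    FakeLayer("dense_1"),
    FakeLayer("embedding"),
    FakeLayer("output_head"),
]
selected = [l.name for l in layers if matches_filter(l, ["^dense", "head$"])]
print(selected)  # ['dense', 'dense_1', 'output_head']
```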
Examples (in each snippet, model is an existing Keras model):
- Default quantization (int8)
Applies the default AbsMaxQuantizer to both weights and activations.
```python
model.quantize("int8")
```
- Weight-only quantization (int8)
Disables activation quantization by setting the activation quantizer to None.
```python
from keras.quantizers import Int8QuantizationConfig, AbsMaxQuantizer

config = Int8QuantizationConfig(
    weight_quantizer=AbsMaxQuantizer(axis=0),
    activation_quantizer=None,
)
model.quantize(config=config)
```
- Custom quantization parameters
Customizes the value range or other parameters of specific quantizers.
```python
from keras.quantizers import Int8QuantizationConfig, AbsMaxQuantizer

config = Int8QuantizationConfig(
    # Restrict the range for symmetric quantization
    weight_quantizer=AbsMaxQuantizer(axis=0, value_range=(-127, 127)),
    activation_quantizer=AbsMaxQuantizer(axis=-1, value_range=(-127, 127)),
)
model.quantize(config=config)
```
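To show what axis and value_range control, here is a NumPy sketch of symmetric abs-max quantization to int8. It assumes the common formulation: scale = value_range max / max(|x|) along the reduced axis, then round and clip to value_range. It illustrates the idea and does not reproduce Keras's AbsMaxQuantizer internals. The function names are hypothetical.

```python
import numpy as np


def abs_max_quantize(x, axis, value_range=(-127, 127), epsilon=1e-7):
    """Symmetric abs-max quantization (illustrative sketch, not Keras code).

    The absolute maximum is reduced along `axis`, so each remaining slice
    gets its own scale: axis=0 on an (in, out) kernel gives one scale per
    output column.
    """
    abs_max = np.max(np.abs(x), axis=axis, keepdims=True)
    scale = value_range[1] / (abs_max + epsilon)
    q = np.clip(np.round(x * scale), value_range[0], value_range[1])
    return q.astype(np.int8), scale


def dequantize(q, scale):
    """Map int8 values back to approximate floats using the stored scale."""
    return q.astype(np.float32) / scale


kernel = np.array([[0.2, -2.0], [-1.0, 0.4]], dtype=np.float32)
q, scale = abs_max_quantize(kernel, axis=0)  # column scales ~127/1.0 and ~127/2.0
print(q)
print(dequantize(q, scale))  # close to `kernel`, up to rounding error
```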
Adaptive Pooling layers
Added adaptive pooling operations keras.ops.nn.adaptive_average_pool and keras.ops.nn.adaptive_max_pool for 1D, 2D, and 3D inputs. These operations transform inputs of varying spatial dimensions into a fixed target shape defined by output_size by dynamically inferring the required kernel size and stride. Added corresponding layers:
- keras.layers.AdaptiveAveragePooling1D
- keras.layers.AdaptiveAveragePooling2D
- keras.layers.AdaptiveAveragePooling3D
- keras.layers.AdaptiveMaxPooling1D
- keras.layers.AdaptiveMaxPooling2D
- keras.layers.AdaptiveMaxPooling3D
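Mapping any input length to a fixed output_size can be illustrated with the widely used bin formulation, which is the one PyTorch documents for its adaptive pooling. Output position i pools input indices from floor(i * L / out) up to ceil((i + 1) * L / out). Whether Keras uses exactly these bin boundaries is an assumption. The pure-Python helper below is a hypothetical 1D illustration, not the Keras implementation.

```python
def adaptive_pool_1d(values, output_size, reduce="avg"):
    """Pool a 1D sequence of any length down to `output_size` values.

    Bin i covers indices [floor(i * L / out), ceil((i + 1) * L / out)),
    so neighbouring bins may overlap when L is not a multiple of out.
    """
    length = len(values)
    if length == 0 or output_size <= 0:
        raise ValueError("need a non-empty input and a positive output_size")
    if reduce not in ("avg", "max"):
        raise ValueError("reduce must be 'avg' or 'max'")
    pooled = []
    for i in range(output_size):
        # Integer floor/ceil division avoids floating-point rounding issues.
        start = (i * length) // output_size
        end = -(-((i + 1) * length) // output_size)
        window = values[start:end]
        if reduce == "avg":
            pooled.append(sum(window) / len(window))
        else:
            pooled.append(max(window))
    return pooled


print(adaptive_pool_1d([1, 2, 3, 4, 5, 6], 3))             # [1.5, 3.5, 5.5]
print(adaptive_pool_1d([1, 2, 3, 4, 5], 3))                # [1.5, 3.0, 4.5]
print(adaptive_pool_1d([1, 2, 3, 4, 5], 3, reduce="max"))  # [2, 4, 5]
```

Because the bin edges are derived from the input length, the same call yields output_size values for any sequence length. This is the behavior the adaptive layers provide along each spatial dimension.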
New features
- Add keras.ops.numpy.array_split op, a fundamental building block for tensor parallelism.
- Add keras.ops.numpy.empty_like op.
- Add keras.ops.numpy.ldexp op.
- Add keras.ops.numpy.vander op, which constructs a Vandermonde matrix from a 1-D input tensor.
- Add keras.distribution.get_device_count utility function for the distribution API.
- keras.layers.JaxLayer and keras.layers.FlaxLayer now support the TensorFlow backend in addition to the JAX backend. This allows you to embed flax.linen.Module instances or JAX functions in your model. The TensorFlow support is based on jax2tf.
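The keras.ops.numpy ops are modeled on NumPy, so NumPy's own functions are a convenient reference for what the new ops compute. The snippet below uses NumPy only. That the Keras ops return matching results for these inputs is an assumption, not something verified here.

```python
import numpy as np

# array_split: unlike np.split, the sections need not divide the length evenly.
# 7 elements into 3 parts gives sizes 3, 2, 2 (e.g. sharding across devices).
parts = np.array_split(np.arange(7), 3)
print([p.tolist() for p in parts])  # [[0, 1, 2], [3, 4], [5, 6]]

# ldexp(x, e) computes x * 2**e.
print(np.ldexp(3.0, 2))  # 12.0

# vander builds a Vandermonde matrix; by default column j holds x**(N - 1 - j).
print(np.vander(np.array([1, 2, 3])))
# [[1 1 1]
#  [4 2 1]
#  [9 3 1]]

# empty_like returns an uninitialized array with the same shape and dtype.
buf = np.empty_like(np.zeros((2, 3), dtype=np.float32))
print(buf.shape, buf.dtype)  # (2, 3) float32
```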
OpenVINO Backend Support:
- Added numpy.digitize support.
- Added numpy.diag support.
- Added numpy.isin support.
- Added numpy.vdot support.
- Added numpy.floor_divide support.
- Added numpy.roll support.
- Added numpy.multi_hot support.
- Added numpy.psnr support.
- Added numpy.empty_like support.
Bug fixes and Improvements
- NNX Support: Improved compatibility and fixed tests for the NNX library (JAX), ensuring better stability for NNX-based Keras models.
- MultiHeadAttention: Fixed negative index handling in attention_axes for MultiHeadAttention layers.
- Softmax: Softmax mask handling was updated to improve numerical robustness. The change is based on a deep investigation led by Jaswanth Sreeram, who prototyped the solution with contributions from others.
- PyDataset Support: The Normalization layer's adapt method now supports PyDataset objects, allowing proper adaptation when using this data type.
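Negative entries in attention_axes count from the end of the shape, following the usual NumPy convention in which -1 is the last axis. The hypothetical helper below shows the standard normalization (axis % rank). It assumes axes are interpreted relative to the rank of the query input. It illustrates the semantics and is not the code from the fix.

```python
def normalize_attention_axes(attention_axes, rank):
    """Map possibly negative axes to non-negative indices (hypothetical helper).

    Assumes axes are interpreted relative to the rank of the query input.
    """
    normalized = []
    for axis in attention_axes:
        if not -rank <= axis < rank:
            raise ValueError(f"axis {axis} is out of range for rank {rank}")
        normalized.append(axis % rank)
    return tuple(normalized)


# Query of shape (batch, seq, features): -2 refers to the sequence axis.
print(normalize_attention_axes((-2,), rank=3))  # (1,)
# Rank-4 input (batch, height, width, features): attend over height and width.
print(normalize_attention_axes((-3, -2), rank=4))  # (1, 2)
```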
TPU Test setup
Configured the TPU testing infrastructure to enforce unit test coverage across the entire codebase. This ensures that both existing logic and all future contributions are validated for functionality and correctness within the TPU environment.
New Contributors
- @mattjj made their first contribution in https://github.com/keras-team/keras/pull/21776
- @GaetanLepage made their first contribution in https://github.com/keras-team/keras/pull/21790
- @MalyalaKarthik66 made their first contribution in https://github.com/keras-team/keras/pull/21733
- @Mithil27360 made their first contribution in https://github.com/keras-team/keras/pull/21816
- @erahulkulkarni made their first contribution in https://github.com/keras-team/keras/pull/21812
- @yashwantbezawada made their first contribution in https://github.com/keras-team/keras/pull/21827
- @Abhinavexists made their first contribution in https://github.com/keras-team/keras/pull/21819
- @khanhkhanhlele made their first contribution in https://github.com/keras-team/keras/pull/21830
- @kharshith-k made their first contribution in https://github.com/keras-team/keras/pull/21425
- @SamareshSingh made their first contribution in https://github.com/keras-team/keras/pull/21834
- @yangliyl-g made their first contribution in https://github.com/keras-team/keras/pull/21850
- @ssam18 made their first contribution in https://github.com/keras-team/keras/pull/21855
- @kathyfan made their first contribution in https://github.com/keras-team/keras/pull/21805
- @almilosz made their first contribution in https://github.com/keras-team/keras/pull/21826
- @RohitYandigeri made their first contribution in https://github.com/keras-team/keras/pull/21902
Full Changelog: https://github.com/keras-team/keras/compare/v3.12.0...v3.13.0