"""1Title: 8-bit Integer Quantization in Keras2Author: [Jyotinder Singh](https://x.com/Jyotinder_Singh)3Date created: 2025/10/144Last modified: 2025/10/145Description: Complete guide to using INT8 quantization in Keras and KerasHub.6Accelerator: GPU7"""89"""10## What is INT8 quantization?1112Quantization lowers the numerical precision of weights and activations to reduce memory use13and often speed up inference, at the cost of a small accuracy drop. Moving from `float32` to14`float16` halves the memory requirements; `float32` to INT8 is ~4x smaller (and ~2x vs15`float16`). On hardware with low-precision kernels (e.g., NVIDIA Tensor Cores), this can also16improve throughput and latency. Actual gains depend on your backend and device.1718### How it works1920Quantization maps real values to 8-bit integers with a scale:2122* Integer domain: `[-128, 127]` (256 levels).23* For a tensor (often per-output-channel for weights) with values `w`:24* Compute `a_max = max(abs(w))`.25* Set scale `s = (2 * a_max) / 256`.26* Quantize: `q = clip(round(w / s), -128, 127)` (stored as INT8) and keep `s`.27* Inference uses `q` and `s` to reconstruct effective weights on the fly28(`w ≈ s · q`) or folds `s` into the matmul/conv for efficiency.2930### Benefits3132* Memory / bandwidth bound models: When implementation spends most of its time on memory I/O,33reducing the computation time does not reduce their overall runtime. INT8 reduces bytes34moved by ~4x vs `float32`, improving cache behavior and reducing memory stalls;35this often helps more than increasing raw FLOPs.36* Compute bound layers on supported hardware: On NVIDIA GPUs, INT837[Tensor Cores](https://www.nvidia.com/en-us/data-center/tensor-cores/) speed up matmul/conv,38boosting throughput on compute-limited layers.39* Accuracy: Many models retain near-FP accuracy with `float16`; INT8 may introduce a modest40drop (often ~1-5% depending on task/model/data). Always validate on your own dataset.4142### What Keras does in INT8 mode4344* **Mapping**: Symmetric, linear quantization with INT8 plus a floating-point scale.45* **Weights**: per-output-channel scales to preserve accuracy.46* **Activations**: **dynamic AbsMax** scaling computed at runtime.47* **Graph rewrite**: Quantization is applied after weights are trained and built; the graph48is rewritten so you can run or save immediately.49"""5051"""52## Overview5354This guide shows how to use 8-bit integer post-training quantization (PTQ) in Keras:55561. [Quantizing a minimal functional model](#quantizing-a-minimal-functional-model)572. [Saving and reloading a quantized model](#saving-and-reloading-a-quantized-model)583. 
### Benefits

* Memory- / bandwidth-bound models: When a model spends most of its time on memory I/O,
making the arithmetic faster does little for overall runtime. INT8 reduces the bytes moved
by ~4x vs `float32`, improving cache behavior and reducing memory stalls; this often helps
more than increasing raw FLOPs.
* Compute-bound layers on supported hardware: On NVIDIA GPUs, INT8
[Tensor Cores](https://www.nvidia.com/en-us/data-center/tensor-cores/) speed up matmul/conv,
boosting throughput on compute-limited layers.
* Accuracy: Many models retain near-FP accuracy with `float16`; INT8 may introduce a modest
drop (often ~1-5%, depending on task, model, and data). Always validate on your own dataset.

### What Keras does in INT8 mode

* **Mapping**: Symmetric, linear quantization with INT8 plus a floating-point scale.
* **Weights**: Per-output-channel scales to preserve accuracy.
* **Activations**: **Dynamic AbsMax** scaling computed at runtime.
* **Graph rewrite**: Quantization is applied after the weights are built and trained; the
graph is rewritten so you can run or save the model immediately.
"""

"""
## Overview

This guide shows how to use 8-bit integer post-training quantization (PTQ) in Keras:

1. [Quantizing a minimal functional model](#quantizing-a-minimal-functional-model)
2. [Saving and reloading a quantized model](#saving-and-reloading-a-quantized-model)
3. [Quantizing a KerasHub model](#quantizing-a-kerashub-model)

## Quantizing a minimal functional model

We build a small functional model, capture a baseline output, quantize it to INT8 in place,
and then compare outputs with an MSE metric.
"""

import os
import numpy as np
import keras
from keras import layers

# Create a random number generator.
rng = np.random.default_rng()

# Create a simple functional model.
inputs = keras.Input(shape=(10,))
x = layers.Dense(32, activation="relu")(inputs)
outputs = layers.Dense(1, name="target")(x)
model = keras.Model(inputs, outputs)

# Compile and train briefly to materialize meaningful weights.
model.compile(optimizer="adam", loss="mse")
x_train = rng.random((256, 10)).astype("float32")
y_train = rng.random((256, 1)).astype("float32")
model.fit(x_train, y_train, epochs=1, batch_size=32, verbose=0)

# Sample inputs for evaluation.
x_eval = rng.random((32, 10)).astype("float32")

# Baseline (full-precision) outputs.
y_fp32 = model(x_eval)

# Quantize the model in place to INT8.
model.quantize("int8")

# INT8 outputs after quantization.
y_int8 = model(x_eval)

# Compute a simple MSE between the FP32 and INT8 outputs.
mse = keras.ops.mean(keras.ops.square(y_fp32 - y_int8))
print("Full-Precision vs INT8 MSE:", float(mse))


"""
The INT8 model's outputs stay close to those of the original FP32 model, as indicated by the
low MSE value.
"""
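
"""
To see what the in-place rewrite did, you can list the variables of the first `Dense` layer:
after quantization it should hold an INT8 kernel alongside a floating-point scale, matching
the scheme described earlier. The exact variable names and layout are internal details of
Keras and may change between versions.
"""

# Print the name, dtype, and shape of each variable in the first Dense layer.
for variable in model.layers[1].weights:
    print(variable.name, variable.dtype, variable.shape)
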
"""
## Saving and reloading a quantized model

You can use the standard Keras saving and loading APIs with quantized models. Quantization
is preserved when you save to `.keras` and load the model back.
"""

# Save the quantized model and reload it to verify the round trip.
model.save("int8.keras")
int8_reloaded = keras.saving.load_model("int8.keras")
y_int8_reloaded = int8_reloaded(x_eval)
roundtrip_mse = keras.ops.mean(keras.ops.square(y_int8 - y_int8_reloaded))
print("MSE (INT8 vs reloaded-INT8):", float(roundtrip_mse))

"""
## Quantizing a KerasHub model

All KerasHub models support the `.quantize(...)` API for post-training quantization and
follow the same workflow as above.

In this example, we will:

1. Load the [gemma3_1b](https://www.kaggle.com/models/keras/gemma3/keras/gemma3_1b)
preset from KerasHub.
2. Generate text using both the full-precision and quantized models, and compare outputs.
3. Save both models to disk and compute the storage savings.
4. Reload the INT8 model and verify output consistency with the original quantized model.
"""

from keras_hub.models import Gemma3CausalLM

# Load the Gemma3 preset.
gemma3 = Gemma3CausalLM.from_preset("gemma3_1b")

# Generate text for a single prompt.
output = gemma3.generate("Keras is a", max_length=50)
print("Full-precision output:", output)

# Save the FP32 Gemma3 model for the size comparison.
gemma3.save_to_preset("gemma3_fp32")

# Quantize in place to INT8 and generate again.
gemma3.quantize("int8")

output = gemma3.generate("Keras is a", max_length=50)
print("Quantized output:", output)

# Save the INT8 Gemma3 model.
gemma3.save_to_preset("gemma3_int8")

# Reload the INT8 model and compare outputs.
gemma3_int8 = Gemma3CausalLM.from_preset("gemma3_int8")

output = gemma3_int8.generate("Keras is a", max_length=50)
print("Quantized reloaded output:", output)


# Compute the storage savings.
def bytes_to_mib(n):
    return n / (1024**2)


gemma_fp32_size = os.path.getsize("gemma3_fp32/model.weights.h5")
gemma_int8_size = os.path.getsize("gemma3_int8/model.weights.h5")

gemma_reduction = 100.0 * (1.0 - (gemma_int8_size / max(gemma_fp32_size, 1)))
print(f"Gemma3: FP32 file size: {bytes_to_mib(gemma_fp32_size):.2f} MiB")
print(f"Gemma3: INT8 file size: {bytes_to_mib(gemma_int8_size):.2f} MiB")
print(f"Gemma3: Size reduction: {gemma_reduction:.1f}%")

"""
## Practical tips

* Post-training quantization (PTQ) is a one-time operation; you cannot continue training a
model after quantizing it to INT8.
* Always materialize the weights before quantization (e.g., via `build()` or a forward pass).
* Expect small numerical deltas; quantify them with a metric such as MSE on a validation batch.
* Storage savings are immediate; speedups depend on backend and device kernels.

## References

* [Milvus: How does 8-bit quantization or float16 affect the accuracy and speed of Sentence Transformer embeddings and similarity calculations?](https://milvus.io/ai-quick-reference/how-does-quantization-such-as-int8-quantization-or-using-float16-affect-the-accuracy-and-speed-of-sentence-transformer-embeddings-and-similarity-calculations)
* [NVIDIA Developer Blog: Achieving FP32 accuracy for INT8 inference using quantization-aware training with TensorRT](https://developer.nvidia.com/blog/achieving-fp32-accuracy-for-int8-inference-using-quantization-aware-training-with-tensorrt/)
"""