GitHub Repository: keras-team/keras-io
Path: blob/master/guides/int8_quantization_in_keras.py
"""
Title: 8-bit Integer Quantization in Keras
Author: [Jyotinder Singh](https://x.com/Jyotinder_Singh)
Date created: 2025/10/14
Last modified: 2025/10/14
Description: Complete guide to using INT8 quantization in Keras and KerasHub.
Accelerator: GPU
"""

"""
## What is INT8 quantization?

Quantization lowers the numerical precision of weights and activations to reduce memory use
and often speed up inference, at the cost of a small accuracy drop. Moving from `float32` to
`float16` halves the memory requirements; moving from `float32` to INT8 shrinks them by ~4x
(~2x relative to `float16`). On hardware with low-precision kernels (e.g., NVIDIA Tensor
Cores), this can also improve throughput and latency. Actual gains depend on your backend and
device.

### How it works

Quantization maps real values to 8-bit integers using a scale factor:

* Integer domain: `[-128, 127]` (256 levels).
* For a tensor with values `w` (for weights, usually with one scale per output channel):
  * Compute `a_max = max(abs(w))`.
  * Set the scale `s = (2 * a_max) / 256`.
  * Quantize: `q = clip(round(w / s), -128, 127)` (stored as INT8) and keep `s`.
* At inference time, `q` and `s` are either used to reconstruct effective weights on the fly
  (`w ≈ s · q`), or `s` is folded into the matmul/conv for efficiency.
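
The recipe above can be sketched in NumPy by following the formulas exactly as written. This is an illustrative re-implementation, not Keras's internal code, which may differ in details such as the integer range or how scales are stored. The helper names `quantize_per_channel` and `dequantize` are hypothetical. The kernel is assumed to have shape `(in_features, out_features)`, with one scale per output column.

```python
import numpy as np


def quantize_per_channel(w):
    # w: float kernel of shape (in_features, out_features); one scale per output column.
    a_max = np.max(np.abs(w), axis=0, keepdims=True)
    # Avoid a zero scale for all-zero channels (an extra safeguard, not part of the recipe).
    a_max = np.maximum(a_max, 1e-12)
    s = (2.0 * a_max) / 256.0
    q = np.clip(np.round(w / s), -128, 127).astype(np.int8)
    return q, s.astype(np.float32)


def dequantize(q, s):
    # Reconstruct effective weights: w is approximately s * q.
    return q.astype(np.float32) * s


rng = np.random.default_rng(0)
w = rng.normal(size=(10, 32)).astype(np.float32)
q, s = quantize_per_channel(w)
w_hat = dequantize(q, s)
print(q.dtype, q.shape, s.shape)  # int8 (10, 32) (1, 32)
print("max abs error:", float(np.max(np.abs(w - w_hat))))
```

With `s = (2 * a_max) / 256`, any value for which `w / s` rounds to 128 is clipped to 127. The reconstruction error of such a value can therefore reach one full step `s`, while every other value stays within half a step.
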
### Benefits

* Memory/bandwidth-bound models: When a model spends most of its time on memory I/O,
  speeding up computation does little to reduce its overall runtime. INT8 moves ~4x fewer
  bytes than `float32`, improving cache behavior and reducing memory stalls; this often helps
  more than increasing raw FLOPs.
* Compute-bound layers on supported hardware: On NVIDIA GPUs, INT8
  [Tensor Cores](https://www.nvidia.com/en-us/data-center/tensor-cores/) speed up matmul/conv,
  boosting throughput on compute-limited layers.
* Accuracy: Many models retain near-full-precision accuracy with `float16`; INT8 may introduce
  a modest drop (often ~1-5% depending on the task, model, and data). Always validate on your
  own dataset.
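
To make the ~4x figure concrete, here is the storage arithmetic for a single Dense kernel, worked out in plain Python. The `4096 x 4096` shape and the use of one 4-byte (`float32`) scale per output channel are assumptions chosen for illustration. `kernel_bytes` is a hypothetical helper.

```python
def kernel_bytes(in_features, out_features, bytes_per_weight, bytes_per_scale=0):
    # Kernel storage plus one scale per output channel (0 bytes if unquantized).
    return in_features * out_features * bytes_per_weight + out_features * bytes_per_scale


MIB = 1024**2
fp32_bytes = kernel_bytes(4096, 4096, bytes_per_weight=4)
fp16_bytes = kernel_bytes(4096, 4096, bytes_per_weight=2)
# Assumption: INT8 weights plus one float32 (4-byte) scale per output channel.
int8_bytes = kernel_bytes(4096, 4096, bytes_per_weight=1, bytes_per_scale=4)

print(f"float32: {fp32_bytes / MIB:.2f} MiB")  # float32: 64.00 MiB
print(f"float16: {fp16_bytes / MIB:.2f} MiB")  # float16: 32.00 MiB
print(f"int8 + scales: {int8_bytes / MIB:.2f} MiB")  # int8 + scales: 16.02 MiB
print(f"float32 / int8 ratio: {fp32_bytes / int8_bytes:.3f}")  # float32 / int8 ratio: 3.996
```

The scale overhead (here 16 KiB per kernel) is why the ratio lands slightly below 4x.
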
### What Keras does in INT8 mode

* **Mapping**: Symmetric, linear quantization to INT8 with a floating-point scale.
* **Weights**: Per-output-channel scales to preserve accuracy.
* **Activations**: **Dynamic AbsMax** scaling computed at runtime.
* **Graph rewrite**: Quantization is applied after the weights are trained and built; the graph
  is rewritten so you can run or save the model immediately.
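
The following NumPy sketch shows how per-channel weight scales and dynamic activation scales can combine in an INT8 Dense-style matmul. It folds both scales into the output instead of dequantizing the inputs first, and it reuses the scale formula from "How it works". Two details are assumptions made for illustration: computing one activation scale per row (per example) and accumulating in `int32`. Keras's backend kernels may organize this differently, and `absmax_quantize` is a hypothetical helper name.

```python
import numpy as np


def absmax_quantize(x, axis):
    # Symmetric AbsMax quantization along `axis`, using s = 2 * a_max / 256.
    a_max = np.maximum(np.max(np.abs(x), axis=axis, keepdims=True), 1e-12)
    s = (2.0 * a_max) / 256.0
    q = np.clip(np.round(x / s), -128, 127).astype(np.int8)
    return q, s


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 10)).astype(np.float32)  # activations: a batch of 4 examples
w = rng.normal(size=(10, 32)).astype(np.float32)  # Dense kernel

# Weights: quantized once, ahead of time, with one scale per output channel.
w_q, w_s = absmax_quantize(w, axis=0)  # w_s has shape (1, 32)
# Activations: scales are computed at runtime from the current batch, one per row.
x_q, x_s = absmax_quantize(x, axis=1)  # x_s has shape (4, 1)

# Multiply in integers, accumulating in int32 because products and sums of int8
# values quickly exceed the int8 range.
acc = x_q.astype(np.int32) @ w_q.astype(np.int32)
# Fold both scales into the result: y[i, j] is approximately acc[i, j] * x_s[i] * w_s[j].
y_int8 = acc.astype(np.float32) * x_s * w_s
y_ref = x @ w

rel_err = np.linalg.norm(y_int8 - y_ref) / np.linalg.norm(y_ref)
print(f"relative error vs float32 matmul: {rel_err:.4f}")
```

Folding the scales means the expensive inner loop runs entirely on integers, with only one floating-point rescale per output element.
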
"""

"""
## Overview

This guide shows how to use 8-bit integer post-training quantization (PTQ) in Keras:

1. [Quantizing a minimal functional model](#quantizing-a-minimal-functional-model)
2. [Saving and reloading a quantized model](#saving-and-reloading-a-quantized-model)
3. [Quantizing a KerasHub model](#quantizing-a-kerashub-model)

## Quantizing a minimal functional model

We build a small functional model, capture a baseline output, quantize it to INT8 in place,
and then compare the outputs using an MSE metric.
"""

import os
import numpy as np
import keras
from keras import layers

# Create a random number generator.
rng = np.random.default_rng()

# Create a simple functional model.
inputs = keras.Input(shape=(10,))
x = layers.Dense(32, activation="relu")(inputs)
outputs = layers.Dense(1, name="target")(x)
model = keras.Model(inputs, outputs)

# Compile and train briefly to materialize meaningful weights.
model.compile(optimizer="adam", loss="mse")
x_train = rng.random((256, 10)).astype("float32")
y_train = rng.random((256, 1)).astype("float32")
model.fit(x_train, y_train, epochs=1, batch_size=32, verbose=0)

# Sample inputs for evaluation.
x_eval = rng.random((32, 10)).astype("float32")

# Baseline (FP) outputs.
y_fp32 = model(x_eval)

# Quantize the model in-place to INT8.
model.quantize("int8")

# INT8 outputs after quantization.
y_int8 = model(x_eval)

# Compute a simple MSE between FP and INT8 outputs.
mse = keras.ops.mean(keras.ops.square(y_fp32 - y_int8))
print("Full-Precision vs INT8 MSE:", float(mse))


"""
The printed MSE should be small, indicating that the INT8-quantized model produces outputs
close to those of the original FP32 model.
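
An MSE value is easier to interpret when it is read against the scale of the outputs. The following NumPy sketch defines a hypothetical `compare_outputs` helper (not part of Keras) that reports MSE alongside the relative L2 error and the maximum absolute deviation. To apply it to the tensors above, you could pass `keras.ops.convert_to_numpy(y_fp32)` and `keras.ops.convert_to_numpy(y_int8)`. The demo itself uses synthetic data only.

```python
import numpy as np


def compare_outputs(y_ref, y_test):
    # Hypothetical helper: summarize how far test outputs drift from reference outputs.
    y_ref = np.asarray(y_ref, dtype=np.float64)
    y_test = np.asarray(y_test, dtype=np.float64)
    diff = y_test - y_ref
    # Relative L2 error: size of the error normalized by the size of the reference.
    ref_norm = max(float(np.linalg.norm(y_ref)), 1e-12)
    return {
        "mse": float(np.mean(diff**2)),
        "rel_l2": float(np.linalg.norm(diff)) / ref_norm,
        "max_abs": float(np.max(np.abs(diff))),
    }


# Synthetic stand-ins for full-precision and quantized outputs (illustration only).
demo_rng = np.random.default_rng(0)
ref = demo_rng.normal(size=(32, 1))
noisy = ref + demo_rng.normal(scale=1e-3, size=(32, 1))
print(compare_outputs(ref, noisy))
```

A relative error well below 1% is a more scale-independent signal than a raw MSE value.
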
## Saving and reloading a quantized model

You can use the standard Keras saving and loading APIs with quantized models. Quantization
is preserved when saving to `.keras` and loading back.
"""

# Save the quantized model and reload to verify round-trip.
model.save("int8.keras")
int8_reloaded = keras.saving.load_model("int8.keras")
y_int8_reloaded = int8_reloaded(x_eval)
roundtrip_mse = keras.ops.mean(keras.ops.square(y_int8 - y_int8_reloaded))
print("MSE (INT8 vs reloaded-INT8):", float(roundtrip_mse))

"""
## Quantizing a KerasHub model

All KerasHub models support the `.quantize(...)` API for post-training quantization
and follow the same workflow as above.

In this example, we will:

1. Load the [gemma3_1b](https://www.kaggle.com/models/keras/gemma3/keras/gemma3_1b)
   preset from KerasHub.
2. Generate text using both the full-precision and quantized models, and compare the outputs.
3. Save both models to disk and compute the storage savings.
4. Reload the INT8 model and verify output consistency with the original quantized model.
"""

from keras_hub.models import Gemma3CausalLM

# Load from Gemma3 preset
gemma3 = Gemma3CausalLM.from_preset("gemma3_1b")

# Generate text for a single prompt
output = gemma3.generate("Keras is a", max_length=50)
print("Full-precision output:", output)

# Save FP32 Gemma3 model for size comparison.
gemma3.save_to_preset("gemma3_fp32")

# Quantize in-place to INT8 and generate again
gemma3.quantize("int8")

output = gemma3.generate("Keras is a", max_length=50)
print("Quantized output:", output)

# Save INT8 Gemma3 model
gemma3.save_to_preset("gemma3_int8")

# Reload and compare outputs
gemma3_int8 = Gemma3CausalLM.from_preset("gemma3_int8")

output = gemma3_int8.generate("Keras is a", max_length=50)
print("Quantized reloaded output:", output)


# Compute storage savings
def bytes_to_mib(n):
    return n / (1024**2)


gemma_fp32_size = os.path.getsize("gemma3_fp32/model.weights.h5")
gemma_int8_size = os.path.getsize("gemma3_int8/model.weights.h5")

gemma_reduction = 100.0 * (1.0 - (gemma_int8_size / max(gemma_fp32_size, 1)))
print(f"Gemma3: FP32 file size: {bytes_to_mib(gemma_fp32_size):.2f} MiB")
print(f"Gemma3: INT8 file size: {bytes_to_mib(gemma_int8_size):.2f} MiB")
print(f"Gemma3: Size reduction: {gemma_reduction:.1f}%")

"""
## Practical tips

* Post-training quantization (PTQ) is a one-time operation applied to a trained model; the
  quantized INT8 weights cannot be trained further.
* Always materialize weights before quantization (e.g., via `build()` or a forward pass).
* Expect small numerical deltas; quantify them with a metric like MSE on a validation batch.
* Storage savings are immediate; speedups depend on backend and device kernels.
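
Speedups depend on the backend and device, so it is worth timing inference on your own setup. The following is a minimal, framework-agnostic timing sketch. `median_latency` and its `warmup`/`iters` parameters are hypothetical names, not a Keras API. On GPU backends, work may be dispatched asynchronously, so the timed callable should force its result to be materialized (for example, by converting the output to NumPy).

```python
import statistics
import time


def median_latency(fn, *args, warmup=3, iters=20):
    # Untimed warm-up calls exclude one-time costs such as tracing or compilation.
    for _ in range(warmup):
        fn(*args)
    timings = []
    for _ in range(iters):
        start = time.perf_counter()
        fn(*args)
        timings.append(time.perf_counter() - start)
    # The median is less sensitive to occasional slow outliers than the mean.
    return statistics.median(timings)


# Pure-Python stand-in workload; in the guide you might instead time something like
# lambda x: keras.ops.convert_to_numpy(model(x)) before and after model.quantize("int8").
latency = median_latency(lambda n: sum(range(n)), 10_000)
print(f"median latency: {latency * 1e3:.3f} ms")
```

Compare medians measured with the same batch size and input shapes; otherwise differences may reflect the workload rather than quantization.
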
## References

* [Milvus: How does 8-bit quantization or float16 affect the accuracy and speed of Sentence Transformer embeddings and similarity calculations?](https://milvus.io/ai-quick-reference/how-does-quantization-such-as-int8-quantization-or-using-float16-affect-the-accuracy-and-speed-of-sentence-transformer-embeddings-and-similarity-calculations)
* [NVIDIA Developer Blog: Achieving FP32 accuracy for INT8 inference using quantization-aware training with TensorRT](https://developer.nvidia.com/blog/achieving-fp32-accuracy-for-int8-inference-using-quantization-aware-training-with-tensorrt/)
"""