GitHub Repository: keras-team/keras-io
Path: blob/master/guides/quantization_overview.py
"""
Title: Quantization in Keras
Author: [Jyotinder Singh](https://x.com/Jyotinder_Singh)
Date created: 2025/10/09
Last modified: 2025/10/09
Description: Overview of quantization in Keras (int8, float8, int4, GPTQ).
Accelerator: GPU
"""

"""
## Introduction

Modern large models are often **memory- and bandwidth-bound**: most inference time is spent moving tensors between memory and compute units rather than doing math. Quantization reduces the number of bits used to represent the model's weights and (optionally) activations, which:

* Shrinks model size and VRAM/RAM footprint.
* Increases effective memory bandwidth (fewer bytes per value).
* Can improve throughput and sometimes latency on supporting hardware with low-precision kernels.

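As a rough illustration of the first bullet, the raw storage needed for weights scales linearly with the number of bits per value. The sketch below is a back-of-the-envelope calculation in plain Python. The 7-billion parameter count is a hypothetical example, `weight_bytes` is a helper made up for this sketch, and the figures ignore scales and other metadata that quantized checkpoints also store.

```python
# Bits used to store one weight in each format.
BITS_PER_WEIGHT = {"float32": 32, "float16": 16, "int8": 8, "int4": 4}


def weight_bytes(num_params, fmt):
    # Raw weight storage in bytes, ignoring scales and other metadata.
    return num_params * BITS_PER_WEIGHT[fmt] // 8


num_params = 7_000_000_000  # hypothetical 7B-parameter model

for fmt in BITS_PER_WEIGHT:
    gib = weight_bytes(num_params, fmt) / 2**30
    ratio = BITS_PER_WEIGHT["float32"] / BITS_PER_WEIGHT[fmt]
    print(f"{fmt:>8}: {gib:6.1f} GiB (compression vs float32: {ratio:.0f}x)")
```
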
Keras provides first-class **post-training quantization (PTQ)** workflows that support pretrained models and expose a uniform API at both the model and layer level.

At a high level, Keras supports:

* Joint weight + activation PTQ in `int4`, `int8`, and `float8`.
* Weight-only PTQ via **GPTQ** (2/3/4/8-bit) to maximize compression with minimal accuracy impact, especially for large language models (LLMs).

### Terminology

* *Scale / zero-point:* Quantization maps real values `x` to integers `q` using a scale (and optionally a zero-point). Symmetric schemes use only a scale.
* *Per-channel vs per-tensor:* A separate scale per output channel (e.g., per hidden unit) usually preserves accuracy better than a single scale for the whole tensor.
* *Calibration:* A short pass over sample data to estimate activation ranges (e.g., max absolute value).
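
To make these terms concrete, here is a minimal sketch of symmetric AbsMax quantization in plain Python. It is illustrative only and does not reflect how Keras implements quantization internally; the helper names (`absmax_scale`, `quantize_symmetric`, `dequantize`, `max_error`) and the toy weights are made up for this example. It compares one shared per-tensor scale against one scale per output channel (one per row here).

```python
def absmax_scale(values, qmax=127):
    # AbsMax estimate: the largest magnitude maps to qmax.
    return max(abs(v) for v in values) / qmax


def quantize_symmetric(values, scale, qmax=127):
    # Map real values to integers in [-qmax, qmax] using only a scale.
    return [max(-qmax, min(qmax, round(v / scale))) for v in values]


def dequantize(q_values, scale):
    # Recover approximate real values from the integers.
    return [q * scale for q in q_values]


def max_error(original, restored):
    return max(abs(a - b) for a, b in zip(original, restored))


# Two output channels with very different magnitudes.
weights = [
    [0.010, -0.020, 0.015],  # small-magnitude channel
    [1.000, -2.000, 1.500],  # large-magnitude channel
]

# Per-tensor: one scale shared by every value.
shared = absmax_scale([v for row in weights for v in row])
per_tensor_err = [
    max_error(row, dequantize(quantize_symmetric(row, shared), shared))
    for row in weights
]

# Per-channel: one scale per row.
per_channel_err = []
for row in weights:
    scale = absmax_scale(row)
    restored = dequantize(quantize_symmetric(row, scale), scale)
    per_channel_err.append(max_error(row, restored))

print("per-tensor :", per_tensor_err)
print("per-channel:", per_channel_err)
```

With these toy values, the small-magnitude channel loses far more precision under the shared scale, because its values span only a couple of integer levels.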
31
"""
32
33
"""
34
## Quantization Modes

Keras currently focuses on the following numeric formats. Each mode can be applied selectively to layers or to the whole model via the same API.

* **`int8` (8-bit integer)**: **joint weight + activation** PTQ.

    * **How it works:** Values are linearly mapped to 8-bit integers with per-channel scales. Activations are quantized dynamically at runtime (see the note on dynamic activation quantization below).
    * **Why use it:** Good accuracy for many architectures; broad hardware support.
    * **What to expect:** Parameters are ~4x smaller than FP32 (~2x vs FP16) and activation bandwidth is lower, with small accuracy loss on many tasks. Throughput gains depend on kernel availability and memory bandwidth.

* **`float8` (FP8: E4M3 / E5M2 variants)**: Low-precision floating-point useful for training and inference on FP8-capable hardware.

    * **How it works:** Values are quantized to FP8 with a dynamic scale. Fused FP8 kernels on supported devices yield speedups.
    * **Why use it:** Mixed-precision training/inference with hardware acceleration while keeping floating-point semantics, whose underflow/overflow characteristics differ from those of integer formats.
    * **What to expect:** Competitive speed and memory reductions where FP8 kernels are available; accuracy varies by model but is usually acceptable.

* **`int4`**: Ultra-low-bit **weights** for aggressive compression; activations remain in higher precision (int8).

    * **How it works:** Two signed 4-bit "nibbles" are packed per int8 byte. Keras uses symmetric per-output-channel scales to dequantize efficiently inside matmul.
    * **Why use it:** Significant VRAM/storage savings for LLMs with acceptable accuracy when combined with robust per-channel scaling.
    * **What to expect:** Weights are ~8x smaller than FP32 (~4x vs FP16); throughput gains depend on kernel availability and memory bandwidth. Accuracy deltas are competitive for encoder-only architectures but may be larger on decoder-only models.

* **`GPTQ` (weight-only 2/3/4/8 bits)**: *Second-order, post-training* method that minimizes layer output error.

    * **How it works (brief):** For each weight block (group), GPTQ solves a local least-squares problem using a Hessian approximation built from a small calibration set, then quantizes to low bit-width. The result is a packed weight tensor plus per-group parameters (e.g., scales).
    * **Why use it:** Strong accuracy retention at very low bit-widths without retraining; ideal for rapid LLM compression.
    * **What to expect:** Large storage/VRAM savings with small perplexity/accuracy deltas on many decoder-only models when calibrated on task-relevant samples.

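The `int8` path described above (per-channel weight scales fixed ahead of time, plus an activation scale computed at call time) can be simulated in a few lines of plain Python. This is a conceptual sketch under simplifying assumptions, not the Keras kernel: the layer has no bias, the input is a single vector, and all function names are hypothetical.

```python
def absmax_scale(values, qmax=127):
    # Symmetric scale: the largest magnitude maps to qmax.
    peak = max(abs(v) for v in values)
    return peak / qmax if peak > 0 else 1.0


def quantize(values, scale, qmax=127):
    # Round to the nearest integer and clip to [-qmax, qmax].
    return [max(-qmax, min(qmax, round(v / scale))) for v in values]


# Weights are quantized once, with one scale per output channel (row).
weights = [[0.2, -0.5, 0.1], [1.5, -2.0, 0.5]]  # 2 outputs x 3 inputs
w_scales = [absmax_scale(row) for row in weights]
w_q = [quantize(row, s) for row, s in zip(weights, w_scales)]


def int8_dense(x):
    # Dynamic activation quantization: the scale is recomputed per input.
    x_scale = absmax_scale(x)
    x_q = quantize(x, x_scale)
    outputs = []
    for row_q, w_scale in zip(w_q, w_scales):
        acc = sum(a * b for a, b in zip(x_q, row_q))  # integer accumulation
        outputs.append(acc * x_scale * w_scale)  # rescale back to real values
    return outputs


x = [0.3, -1.2, 0.8]
reference = [sum(a * b for a, b in zip(x, row)) for row in weights]
print("float:", reference)
print("int8 :", int8_dense(x))
```
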
### Implementation notes

* **Dynamic activation quantization**: In the `int4` and `int8` PTQ paths, activation scales are computed on the fly at runtime (per tensor and per batch) using an AbsMax estimator. This avoids maintaining a separate, fixed set of activation scales from a calibration pass and adapts to varying input ranges.
* **4-bit packing**: For `int4`, Keras packs signed 4-bit values (range [-8, 7]) and stores per-channel scales such as `kernel_scale`. Dequantization happens on the fly, and matmuls use 8-bit (unpacked) kernels.
* **Calibration strategy**: Activation scaling for `int4` / `int8` / `float8` uses **AbsMax calibration** by default (the range is set by the maximum absolute value observed). Alternative calibration methods (e.g., percentile) may be added in future releases.
* Per-channel scaling is the default for weights where supported, because it materially improves accuracy at negligible overhead.
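
The 4-bit packing note can be illustrated with plain bit operations. The sketch below packs two signed 4-bit values per byte and unpacks them again with sign extension. The nibble order (first value in the low bits), the use of Python `bytes` as the container, and the helper names are assumptions made for this example; the layout Keras actually uses may differ.

```python
def pack_int4(values):
    # Pack pairs of signed 4-bit integers (-8..7) into single bytes.
    if len(values) % 2:
        values = values + [0]  # pad to an even count without mutating input
    packed = []
    for low, high in zip(values[0::2], values[1::2]):
        assert -8 <= low <= 7 and -8 <= high <= 7
        packed.append((low & 0x0F) | ((high & 0x0F) << 4))
    return bytes(packed)


def unpack_int4(packed, count):
    # Recover `count` signed 4-bit integers from the packed bytes.
    values = []
    for byte in packed:
        for nibble in (byte & 0x0F, byte >> 4):
            # Sign-extend: nibbles 8..15 represent -8..-1.
            values.append(nibble - 16 if nibble >= 8 else nibble)
    return values[:count]


weights_q = [-8, 7, 3, -1, 0]
packed = pack_int4(weights_q)
print(len(weights_q), "values ->", len(packed), "bytes")
print(unpack_int4(packed, len(weights_q)))
```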
"""

"""
## Quantizing Keras Models

Quantization is applied explicitly after layers or models are built. The API is designed to be predictable: you call `quantize`, the graph is rewritten, the weights are replaced, and you can immediately run inference or save the model.

Typical workflow:

1. **Build / load your FP model.** Train if needed. Ensure `build()` or a forward pass has materialized weights.
2. **(GPTQ only) Provide calibration data.** Keras runs a short calibration pass to collect activation statistics, so you need to supply a small, representative dataset.
3. **Invoke quantization.** Call `model.quantize("<mode>")` or `layer.quantize("<mode>")` with `"int8"`, `"int4"`, `"float8"`, or `"gptq"` (weight-only).
4. **Use or save.** Run inference, or `model.save(...)`. Quantization state (packed weights, scales, metadata) is preserved on save/load.

### Model Quantization
"""

import keras
import numpy as np

# Create a random number generator.
rng = np.random.default_rng()

# Sample training data.
x_train = rng.random((100, 10)).astype("float32")
y_train = rng.random((100, 1)).astype("float32")


# Define a helper that builds the model.
def get_model():
    """
    Helper to build a simple sequential model.
    """
    return keras.Sequential(
        [
            keras.Input(shape=(10,)),
            keras.layers.Dense(32, activation="relu", name="dense_1"),
            keras.layers.Dense(1, name="output_head"),
        ]
    )


# Build the model.
model = get_model()

# Compile and fit the model.
model.compile(optimizer="adam", loss="mean_squared_error")
model.fit(x_train, y_train, epochs=1, verbose=0)

# Quantize the model.
model.quantize("int8")

"""
**What this does:** Quantizes the weights of the supported layers and rewires their forward passes to use the quantized kernels and the associated quantization scales.

**Note**: Throughput gains depend on backend/hardware kernels; when a kernel falls back to a dequantized matmul, you still get the memory savings but a smaller speedup.
"""

"""
### Selective Quantization

You can quantize only a subset of the model's layers by passing a `filters` argument to `quantize()`. This argument can be a single regex string, a list of regex strings, or a callable that takes a layer instance and returns a boolean.

**Using Regex Filters:**
"""

# Quantize only layers with "dense" in the name, but skip "output"
model = get_model()
model.quantize("int8", filters=["dense", "^((?!output).)*$"])

"""
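The second pattern in the regex example uses a negative lookahead, which is easy to misread. The sketch below checks both patterns against the two layer names from `get_model()` using Python's `re.search`. That patterns are matched against layer names in this way is an assumption made for illustration, and the sketch deliberately tests each pattern on its own rather than guessing how Keras combines a list of patterns.

```python
import re

layer_names = ["dense_1", "output_head"]
patterns = ["dense", "^((?!output).)*$"]

for pattern in patterns:
    matched = [name for name in layer_names if re.search(pattern, name)]
    print(f"{pattern!r:22} matches {matched}")
```

Each pattern selects only `dense_1`: the first because `output_head` does not contain `dense`, the second because the lookahead rejects any name that contains `output`.
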
**Using Callable Filters:**
"""


def my_filter(layer):
    # Only quantize Dense layers that aren't the output
    return isinstance(layer, keras.layers.Dense) and layer.name != "output_head"


model = get_model()
model.quantize("int8", filters=my_filter)

"""
This is particularly useful when you want to avoid quantizing sensitive layers (like the first or last layers of a network) to preserve accuracy.

### Layer-wise Quantization

The Keras quantization framework also lets you quantize individual layers with the same unified API, without quantizing the entire model.
"""

from keras import layers

input_shape = (10,)
layer = layers.Dense(32, activation="relu")
layer.build(input_shape)

layer.quantize("int4")  # Or "int8", "float8", etc.

"""
### When to use layer-wise quantization

* To keep numerically sensitive blocks (e.g., small residual paths, logits) at higher precision while quantizing large projection layers.
* To mix modes (e.g., attention projections in int4, feed-forward in int8) and measure trade-offs incrementally.
* Always validate on a small eval set after each step; mixing precisions across residual connections can shift distributions.
"""

"""
## Layer & model coverage

Keras supports the following core layers in its quantization framework:

* `Dense`
* `EinsumDense`
* `Embedding`
* `ReversibleEmbedding` (available in KerasHub)

Any composite layers that are built from the above (for example, `MultiHeadAttention`, `GroupedQueryAttention`, feed-forward blocks in Transformers) inherit quantization support by construction. This covers the majority of modern encoder-only and decoder-only Transformer architectures.

Since all KerasHub models subclass `keras.Model`, they automatically support the `model.quantize(...)` API. In practice, this means you can take a popular LLM preset, call a single function to obtain an int8/int4/GPTQ-quantized variant, and then save or serve it, without changing your training code.

## Practical guidance

* For GPTQ, use a calibration set that matches your inference domain (a few hundred to a few thousand tokens is often enough to see strong retention).
* Measure both **VRAM** and **throughput/latency**: memory savings are immediate; speedups depend on the availability of fused low-precision kernels on your device.
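
For the latency half of that measurement, a small timing helper is often enough. The sketch below is a generic harness built on the Python standard library. `measure_latency` is a hypothetical helper, not a Keras API, and the stand-in workload should be replaced with your own inference call, for example `lambda: model.predict(x, verbose=0)`.

```python
import statistics
import time


def measure_latency(fn, warmup=3, iters=20):
    # Return the median wall-clock time of fn() in milliseconds.
    for _ in range(warmup):
        fn()  # warm-up runs absorb one-time compilation and cache effects
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - start) * 1e3)
    return statistics.median(samples)


# Stand-in workload; swap in your own inference call.
latency_ms = measure_latency(lambda: sum(i * i for i in range(10_000)))
print(f"median latency: {latency_ms:.3f} ms")
```

Running the same helper on the model before and after quantization, with identical inputs and batch size, gives a like-for-like comparison on your device.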
"""