"""
2
Title: Define a Custom TPU/GPU Kernel
3
Author: [jeffcarp](https://www.jeffcarp.com/)
4
Date created: 2025/12/18
5
Last modified: 2025/12/18
6
Description: Write high-performance custom Keras layers for TPUs and GPUs.
7
Accelerator: TPU
8
"""
9
"""
# How to Write a Custom TPU or GPU Kernel in Keras

Keras has [many pre-made layers to choose from](/api/layers/), and the
ability to easily [create your
own](/guides/making_new_layers_and_models_via_subclassing/) if you can't
find the exact one you need. However, if you have a need for speed, or otherwise
need to customize the exact behavior of your model at the hardware level, you
may want to look into writing a custom kernel. A good way to know whether you need a
custom kernel is to look at a profile of your model and see if there are any
idle gaps caused by computation or memory transfer bottlenecks (see the
[TensorBoard callback](/api/callbacks/tensorboard/) for how to capture a profile).

This guide explores how to write a custom kernel and add it to your
Keras model. We will use **Pallas**, a library that lets you write
kernels in Python that can run on both TPU and GPU, where they're lowered
to Mosaic or Triton, respectively. You can learn more in the [Pallas
docs](https://docs.jax.dev/en/latest/pallas/index.html).

**Compatibility note:** Pallas is only available when using the JAX backend on
certain hardware:

- TPU v4 and above
- NVIDIA Ampere GPUs (compute capability 8.0) and above

If you're running in Colab, the TPU v5e-1 available in the free tier supports
running this guide.

First, make sure you're running the latest version of `libtpu`:
"""
"""shell
pip install --upgrade -q "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
"""
from functools import partial
import os
import time

os.environ["KERAS_BACKEND"] = "jax"

import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp
import keras
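
"""
Optionally, as a quick sanity check, you can confirm that the JAX backend is active
and see which accelerator is attached. This is a small optional snippet that uses only
the standard `keras.config.backend()` and `jax.devices()` APIs.
"""

# Should print "jax" and a list of TPU or GPU devices.
print("Keras backend:", keras.config.backend())
print("JAX devices:", jax.devices())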
"""
57
# Simple Example
58
59
Let's start with the example from the [Pallas
60
quickstart](https://docs.jax.dev/en/latest/pallas/quickstart.html): a simple
61
kernel to add two vectors together.
62
"""
63
64
def add_vectors_kernel(x_ref, y_ref, o_ref):
    """Pallas kernel for adding two vectors together."""
    # x_ref, y_ref, and o_ref are references to memory buffers; indexing with
    # [...] reads a full buffer, and assigning to o_ref[...] writes the output.
    x, y = x_ref[...], y_ref[...]
    o_ref[...] = x + y
"""
72
Now jit-compile the Pallas function into a function that can be used by JAX.
73
"""
74
75
76
@jax.jit
def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
    # out_shape tells Pallas the shape and dtype of the output buffer to allocate.
    return pl.pallas_call(
        add_vectors_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)
    )(x, y)


add_vectors(jnp.arange(8), jnp.arange(8))
"""
86
Now we can embed the jitted `add_vectors` function containing the Pallas kernel into a
87
Keras layer, just by calling it there.
88
"""
89
90
91
class PallasAddLayer(keras.Layer):
    def call(self, x, y):
        # Reuse the JIT-compiled Pallas function
        return add_vectors(x, y)


layer = PallasAddLayer()

x_data = jnp.arange(8, dtype=jnp.int32)
y_data = jnp.arange(8, dtype=jnp.int32)

layer(x_data, y_data)
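
"""
As a quick optional check, the layer's output should match plain `jnp` addition:
"""

assert jnp.array_equal(layer(x_data, y_data), x_data + y_data)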
"""
105
That's how to integrate a Pallas kernel into a Keras layer! Now for a more
106
in-depth example.
107
"""
108
109
"""
110
# Writing a Fused Linear Activation Layer
111
112
Some common reasons you might want to write a custom kernel is to take advantage of
113
**fusion** and **tiling**.
114
115
**Operator fusion** is the process of combining two or more ops into one "fused" op, for
116
example instead of calling `keras.ops.matmul` then `keras.ops.relu` sequentially, we
117
could write a custom op that combines both into one more efficient operator.
118
XLA already [does operator fusion when possible](https://arxiv.org/abs/2301.13062) for
119
certain use cases, but to squeeze even more performance out of the TPU or GPU, we need to
120
write a custom op to specify the fusion exactly.
121
122
**Tiling** is the ability to control how blocks of memory are loaded from the TPU or
123
GPU's larger High Bandwidth Memory (HBM) to the smaller, extremely fast on-chip
124
memory (called VMEM on TPU or SMEM on GPU) that the accelerator's computation
125
units (e.g., TPU's Matrix Units or a GPU's Tensor Cores) use directly. This is
126
critical for improving the performance of large matrix multiplications, for
127
example those in the MLP layer at the end of Transformer blocks.
128
129
In Pallas, tiling is controlled by the `BlockSpec`. Learn more in the
130
[Pallas BlockSpec guide
131
here](https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#blockspec-a-k-a-how-to-chunk-up-inputs).
132
133
In this section, we'll take two operations that commonly appear together: a
134
matrix multiplication (like in a `Dense` layer) and a ReLU activation. We will
135
write a new op that fuses them together for better performance.
136
137
## Original Unoptimized Implementation
138
"""
139
140
141
class StandardDenseReLU(keras.layers.Layer):
    """Standard Matmul and ReLU implementation using keras.ops."""

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        self.w = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer="glorot_uniform",
            trainable=True,
        )

    def call(self, inputs):
        # The standard implementation performs two separate operations.
        # Each one involves expensive data transfer with the main device memory (HBM).
        # 1. Matmul: inputs (HBM) -> compute -> intermediate (HBM)
        y = keras.ops.matmul(inputs, self.w)
        # 2. ReLU: intermediate (HBM) -> compute -> output (HBM)
        return keras.ops.relu(y)
"""
165
## 1. Define the Fused Kernel
166
167
First we create an inner kernel function that defines the fused computation that
168
combines both matmul (`pl.dot`) and activation (`jnp.maximum`).
169
"""
170
171
import jax.numpy as jnp
from jax.experimental import pallas as pl


def matmul_relu_kernel(a_ref, b_ref, c_ref):
    """Pallas kernel for fused matmul + ReLU."""
    # Perform the matrix multiplication on the local tile.
    # pl.dot leverages the hardware's Matrix Unit (MXU).
    acc = pl.dot(a_ref[...], b_ref[...])

    # Fusion happens here: apply the activation while the data is in VMEM.
    result = jnp.maximum(acc, 0)

    # Write the final result to the output reference.
    c_ref[...] = result
"""
189
## 2. Specify the Tiling (BlockSpec)
190
191
Since the input matrices are usually too large to fit into VMEM, Pallas needs ot
192
know how to "slice" them for loading from HBM to VMEM.
193
194
We define this using `BlockSpec` - this tells the hardware: "Take a 128-row
195
chunk of Matrix A and a 128-column chunk of Matrix B to produce a 128x128 tile
196
of Matrix C."
197
"""
@jax.jit
def fused_matmul(a, b):
    m, k = a.shape
    _, n = b.shape

    # Define tile sizes
    tile_m, tile_n = 128, 128
    assert (
        m % tile_m == 0 and n % tile_n == 0
    ), "Inputs must be multiples of 128 for this demo"

    return pl.pallas_call(
        matmul_relu_kernel,
        out_shape=jax.ShapeDtypeStruct((m, n), a.dtype),
        # The index maps below map each output tile index (i, j) to input blocks.
        in_specs=[
            # For each output tile, we take a slice of A of shape (tile_m, k)
            pl.BlockSpec(
                index_map=lambda i, j: (i, 0), block_shape=(tile_m, k)
            ),  # Matrix A
            # For each output tile, we take a slice of B of shape (k, tile_n)
            pl.BlockSpec(
                index_map=lambda i, j: (0, j), block_shape=(k, tile_n)
            ),  # Matrix B
        ],
        out_specs=pl.BlockSpec(
            index_map=lambda i, j: (i, j), block_shape=(tile_m, tile_n)
        ),  # Matrix C
        grid=(m // tile_m, n // tile_n),
    )(a, b)


fused_matmul(jnp.ones((256, 256)), jnp.ones((256, 256)))
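
"""
As a quick optional check, the fused kernel's output should match the reference
computation `relu(a @ b)` written with plain `jnp` ops (the variable names below are
just for this check):
"""

a_check = jnp.ones((256, 256))
b_check = jnp.ones((256, 256))
expected = jnp.maximum(a_check @ b_check, 0.0)
assert jnp.allclose(fused_matmul(a_check, b_check), expected)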
"""
235
## 3. Integrating into a Keras Layer
236
237
Now for the final step, call the jit-compiled `fused_matmul` kernel from a
238
`keras.Layer`.
239
"""
240
241
242
class FusedDense(keras.layers.Layer):
    """Custom Keras layer that applies the fused Dense and ReLU op."""

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        self.w = self.add_weight(
            shape=(input_shape[-1], self.units), initializer="glorot_uniform"
        )

    def call(self, inputs):
        # Dispatch to our Pallas kernel
        return fused_matmul(inputs, self.w.value)


FusedDense(256)(jnp.ones((256, 256)))
"""
262
## 4. Benchmarking the Speedup
263
"""
264
265
# 1. Setup Data
N = 8192  # Large enough to be memory bound
input_data = jnp.ones((N, N), dtype="float32")

# Initialize layers
standard_layer = StandardDenseReLU(units=N)
pallas_layer = FusedDense(units=N)

# Build layers by calling them once
standard_layer(input_data)
pallas_layer(input_data)


def benchmark(layer, x, name, iterations=100):
    # Warm up to ensure JIT compilation is finished
    for _ in range(10):
        layer(x).block_until_ready()

    start_time = time.perf_counter()
    for _ in range(iterations):
        layer(x).block_until_ready()
    end_time = time.perf_counter()

    avg_time = (end_time - start_time) / iterations * 1000  # convert to ms
    print(f"{name} Average Latency: {avg_time:.3f} ms")


# 2. Run Comparison
print(f"Benchmarking Matrix Size: {N}x{N}\n" + "-" * 30)
benchmark(standard_layer, input_data, "Standard Keras (Matmul + ReLU)")
benchmark(pallas_layer, input_data, "Pallas Fused (Matmul + ReLU)")
"""
299
### Why this Works
300
301
**Memory Bandwidth Efficiency:** By fusing the matrix multiplication and
302
activation, we perform the ReLU computation while data is still in the chip's
303
fast VMEM. This drastically reduces expensive read/write roundtrips to HBM.
304
305
**Automatic Parallelization:** Pallas handles the "grid" execution, meaning
306
it automatically parallelizes your defined tiles across the available hardware
307
cores (whether TPU MXUs or GPU Tensor Cores).
308
309
**Drop-in Inference Speed:** This `FusedDense` kernel can be integrated into any
310
Keras model, giving an example of improving serving/inference performance with
311
minimal code changes.
312
"""
313
314
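
"""
If you want to inspect the memory traffic and the fused kernel yourself, you can capture
a trace around one call and open it in TensorBoard's profile plugin. This is a minimal
sketch using `jax.profiler.trace`; the log directory below is arbitrary.
"""

# Capture a profiler trace of the fused layer for inspection in TensorBoard.
with jax.profiler.trace("/tmp/pallas_profile"):
    pallas_layer(input_data).block_until_ready()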
"""
315
## 5. Enabling Training
316
317
In order for a Pallas kernel to be trainable, you must also supply
318
a second kernel to define the custom backward pass, since JAX can't
319
[AutoGrad](https://docs.jax.dev/en/latest/automatic-differentiation.html)
320
through Pallas kernels. Without it, you might see an error like this:
321
322
```
323
model = keras.Sequential([FusedDense(256)])
324
model.compile(optimizer="adam", loss="mse")
325
model.fit(jnp.ones((256, 256)), jnp.ones((256, 256)))
326
>>> Linearization failed to produce known values for all output primals. This is
327
typically caused by attempting to differentiate a function uses an operation
328
that does not support reverse-mode autodiff.
329
```
330
331
To extend our fused matmul example above:
332
"""
333
334
335
# 1. Define the wrapper with `custom_vjp` using our original `fused_matmul`.
@jax.custom_vjp
def fused_matmul_trainable(x, w):
    return fused_matmul(x, w)


# 2. Define the Forward Pass
# It must return the output AND "residuals" (data needed for the backward pass)
def fused_matmul_fwd(x, w):
    y = fused_matmul_trainable(x, w)
    # We save inputs x, w and output y for the backward calculation
    return y, (x, w, y)


# 3. Define the Backward Pass
# JAX gives us the residuals and the incoming gradient (g)
def fused_matmul_bwd(residuals, g):
    x, w, y = residuals

    # Calculate the gradient of ReLU: 1 if y > 0, else 0
    # g is the gradient flowing back from the next layer
    grad_relu = g * (y > 0)

    # Standard backprop math for matmul:
    # grad_x = grad_relu @ w.T
    grad_x = jnp.dot(grad_relu, w.T)

    # grad_w = x.T @ grad_relu
    grad_w = jnp.dot(x.T, grad_relu)

    return grad_x, grad_w


# 4. Register the forward and backward functions
fused_matmul_trainable.defvjp(fused_matmul_fwd, fused_matmul_bwd)
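
"""
As a quick optional check, JAX can now differentiate through the kernel. This minimal
sketch takes gradients of a scalar loss with `jax.grad`; the shapes are chosen to
satisfy the multiple-of-128 assertion in `fused_matmul`.
"""

x_check = jnp.ones((128, 256))
w_check = jnp.ones((256, 128))
grads = jax.grad(
    lambda x, w: jnp.mean(fused_matmul_trainable(x, w)), argnums=(0, 1)
)(x_check, w_check)
print(grads[0].shape, grads[1].shape)  # (128, 256) and (256, 128)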
class FusedDenseTrainable(FusedDense):
    """Updated layer that contains Pallas forward and backward pass."""

    def call(self, inputs):
        # Dispatch to our trainable Pallas kernel
        return fused_matmul_trainable(inputs, self.w.value)


# Demonstrate trainability on dummy data
model = keras.Sequential([FusedDenseTrainable(256)])
model.compile(optimizer="adam", loss="mse")
model.fit(jnp.ones((256, 256)), jnp.ones((256, 256)), batch_size=128)
"""
386
# Followups
387
388
In this guide we covered how to define a simple custom Pallas kernel performing vector
389
addition to include in a Keras model. Then we followed up with a more in-depth
390
example of a fused matmul + activation kernel that you might use in a real-world
391
model to improve performance.
392
393
Please refer to the [Pallas
394
docs](https://docs.jax.dev/en/latest/pallas/index.html#) for further
395
documentation on writing custom kernels. Additionally to explore more examples
396
of Pallas kernels, including FlashAttention and MoE layers, check out the
397
[Tokamax](https://github.com/openxla/tokamax) library.
398
"""
399
400