Define a Custom TPU/GPU Kernel
Author: jeffcarp
Date created: 2025/12/18
Last modified: 2025/12/18
Description: Write high-performance custom Keras layers for TPUs and GPUs.
How to Write a Custom TPU or GPU Kernel in Keras
Keras has many pre-made layers to choose from, and the ability to easily create your own if you can't find the exact one you need. However, if you have a need for speed, or otherwise need to customize the exact behavior of your model at the hardware level, you may want to look into writing a custom kernel. A good way to know if you need a custom kernel is to look at the profile of your model and see if there are any idle gaps caused by computation or memory transfer bottlenecks (see the TensorBoard callback for how to get a profile).
This guide will explore how to write a custom kernel and add it to your Keras model. We will use Pallas, a library that lets you write kernels in Python that run on both TPU and GPU, where they're lowered to Mosaic or Triton, respectively. You can learn more in the Pallas docs.
Compatibility note: Pallas is only available when using the JAX backend on certain hardware:
TPU v4 and above
NVIDIA Ampere GPUs (compute capability 8.0) and above
If you're running in Colab, the v5e-1 in the free tier supports running this guide.
First, make sure you're running the latest version of libtpu:
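One way to do this (the exact pin used by the original notebook may differ) is to upgrade JAX with the TPU extra, which pulls in a matching libtpu release:

```python
# Upgrade JAX with TPU support; this also installs a compatible libtpu.
!pip install -q -U "jax[tpu]"
```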
Simple Example
Let's start with the example from the Pallas quickstart: a simple kernel to add two vectors together.
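Below is a sketch of that kernel, following the Pallas quickstart. The kernel reads full blocks of the two inputs from references and writes their elementwise sum to the output reference:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_vectors_kernel(x_ref, y_ref, o_ref):
    # The refs point to blocks already staged in fast on-chip memory.
    x, y = x_ref[...], y_ref[...]
    o_ref[...] = x + y
```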
Now jit-compile the Pallas function into a function that can be used by JAX.
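A sketch of the wrapper, using pl.pallas_call to turn the kernel into a JAX-callable function and jax.jit to compile it:

```python
@jax.jit
def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
    return pl.pallas_call(
        add_vectors_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    )(x, y)


add_vectors(jnp.arange(8), jnp.arange(8))
```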
Now we can embed the jitted add_vectors function containing the Pallas kernel into a Keras layer, just by calling it there.
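For example (the layer name below is illustrative, not from the original notebook):

```python
import keras


class AddVectorsLayer(keras.layers.Layer):
    def call(self, x, y):
        # Delegate the computation to the jitted Pallas kernel.
        return add_vectors(x, y)


layer = AddVectorsLayer()
print(layer(jnp.arange(8), jnp.arange(8)))
```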
That's how to integrate a Pallas kernel into a Keras layer! Now for a more in-depth example.
Writing a Fused Linear Activation Layer
Two common reasons you might want to write a custom kernel are to take advantage of fusion and tiling.
Operator fusion is the process of combining two or more ops into one "fused" op, for example instead of calling keras.ops.matmul then keras.ops.relu sequentially, we could write a custom op that combines both into one more efficient operator. XLA already does operator fusion when possible for certain use cases, but to squeeze even more performance out of the TPU or GPU, we need to write a custom op to specify the fusion exactly.
Tiling is the ability to control how blocks of memory are loaded from the TPU or GPU's larger High Bandwidth Memory (HBM) to the smaller, extremely fast on-chip memory (called VMEM on TPU or SMEM on GPU) that the accelerator's computation units (e.g., TPU's Matrix Units or a GPU's Tensor Cores) use directly. This is critical for improving the performance of large matrix multiplications, for example those in the MLP layer at the end of Transformer blocks.
In Pallas, tiling is controlled by the BlockSpec. Learn more in the Pallas BlockSpec guide here.
In this section, we'll take two operations that commonly appear together: a matrix multiplication (like in a Dense layer) and a ReLU activation. We will write a new op that fuses them together for better performance.
Original Unoptimized Implementation
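As a baseline, here is a sketch of the unfused version, expressed as two separate Keras ops (the original notebook's reference implementation may differ in details):

```python
import keras


@jax.jit
def unfused_matmul_relu(a, b):
    # Two separate ops: the matmul result is materialized before the ReLU runs.
    c = keras.ops.matmul(a, b)
    return keras.ops.relu(c)
```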
1. Define the Fused Kernel
First, we create an inner kernel function that defines the fused computation, combining the matmul (pl.dot) and the activation (jnp.maximum).
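A sketch of such a kernel (the names are illustrative):

```python
def fused_matmul_relu_kernel(a_ref, b_ref, o_ref):
    # Multiply the current tiles of A and B on the matrix unit...
    acc = pl.dot(a_ref[...], b_ref[...])
    # ...and apply the ReLU while the result is still in on-chip memory.
    o_ref[...] = jnp.maximum(acc, 0.0)
```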
2. Specify the Tiling (BlockSpec)
Since the input matrices are usually too large to fit into VMEM, Pallas needs to know how to "slice" them for loading from HBM to VMEM.
We define this using BlockSpec - this tells the hardware: "Take a 128-row chunk of Matrix A and a 128-column chunk of Matrix B to produce a 128x128 tile of Matrix C."
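A sketch of the wrapper with a grid and BlockSpecs matching that description (assuming dimensions divisible by 128; the tile sizes and index maps here are illustrative):

```python
BLOCK_M = 128
BLOCK_N = 128


@jax.jit
def fused_matmul(a, b):
    m, k = a.shape
    _, n = b.shape
    return pl.pallas_call(
        fused_matmul_relu_kernel,
        grid=(m // BLOCK_M, n // BLOCK_N),
        in_specs=[
            # Each grid step (i, j) loads a (BLOCK_M, k) slab of A...
            pl.BlockSpec((BLOCK_M, k), lambda i, j: (i, 0)),
            # ...and a (k, BLOCK_N) slab of B...
            pl.BlockSpec((k, BLOCK_N), lambda i, j: (0, j)),
        ],
        # ...to produce one (BLOCK_M, BLOCK_N) tile of the output.
        out_specs=pl.BlockSpec((BLOCK_M, BLOCK_N), lambda i, j: (i, j)),
        out_shape=jax.ShapeDtypeStruct((m, n), a.dtype),
    )(a, b)
```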
3. Integrating into a Keras Layer
Now for the final step, call the jit-compiled fused_matmul kernel from a keras.Layer.
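A sketch of such a layer, built on the fused_matmul function defined above:

```python
class FusedDense(keras.layers.Layer):
    """Dense layer whose forward pass calls the fused Pallas kernel.

    Assumes the input row count and `units` are multiples of the 128 tile size.
    """

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        self.kernel = self.add_weight(
            name="kernel",
            shape=(input_shape[-1], self.units),
            initializer="glorot_uniform",
        )

    def call(self, inputs):
        return fused_matmul(inputs, self.kernel)
```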
4. Benchmarking the Speedup
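A sketch of a simple timing comparison between the unfused and fused versions defined above (sizes and iteration counts are arbitrary; real numbers will vary by hardware):

```python
import time

a = jax.random.normal(jax.random.key(0), (4096, 4096), dtype=jnp.float32)
b = jax.random.normal(jax.random.key(1), (4096, 4096), dtype=jnp.float32)


def benchmark(fn, *args, iters=50):
    fn(*args).block_until_ready()  # Compile and warm up.
    start = time.perf_counter()
    for _ in range(iters):
        out = fn(*args)
    out.block_until_ready()
    return (time.perf_counter() - start) / iters


unfused_time = benchmark(unfused_matmul_relu, a, b)
fused_time = benchmark(fused_matmul, a, b)
print(f"unfused: {unfused_time * 1e3:.2f} ms  fused: {fused_time * 1e3:.2f} ms")
```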
Why this Works
Memory Bandwidth Efficiency: By fusing the matrix multiplication and activation, we perform the ReLU computation while data is still in the chip's fast VMEM. This drastically reduces expensive read/write roundtrips to HBM.
Automatic Parallelization: Pallas handles the "grid" execution, automatically iterating your kernel over the defined tiles and mapping the work onto the hardware's compute units (TPU MXUs or GPU Tensor Cores).
Drop-in Inference Speed: The FusedDense layer can be dropped into any Keras model, improving serving/inference performance with minimal code changes.
5. Enabling Training
In order for a Pallas kernel to be trainable, you must also supply a second kernel defining the custom backward pass, since JAX can't automatically differentiate through Pallas kernels. Without it, JAX will raise an error when it tries to compute gradients through the pallas_call.
To extend our fused matmul example above:
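Here is a sketch of how this could look using jax.custom_vjp (the original notebook likely writes the backward pass as Pallas kernels as well; plain jnp ops are used here for brevity):

```python
@jax.custom_vjp
def fused_matmul_trainable(a, b):
    return fused_matmul(a, b)


def _fused_matmul_fwd(a, b):
    out = fused_matmul(a, b)
    # Save what the backward pass needs.
    return out, (a, b, out)


def _fused_matmul_bwd(residuals, grad_out):
    a, b, out = residuals
    # ReLU gradient: pass gradients only where the forward output was positive.
    grad_pre = jnp.where(out > 0, grad_out, 0.0)
    grad_a = grad_pre @ b.T
    grad_b = a.T @ grad_pre
    return grad_a, grad_b


fused_matmul_trainable.defvjp(_fused_matmul_fwd, _fused_matmul_bwd)
```

The FusedDense layer can then call fused_matmul_trainable instead of fused_matmul so that gradients flow through it during training.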
Followups
In this guide, we covered how to define a simple custom Pallas kernel that performs vector addition and include it in a Keras model. We then followed up with a more in-depth example of a fused matmul + activation kernel that you might use in a real-world model to improve performance.
Please refer to the Pallas docs for further documentation on writing custom kernels. Additionally, to explore more examples of Pallas kernels, including FlashAttention and MoE layers, check out the Tokamax library.