Orbax Checkpointing in Keras
Author: Samaneh Saadat
Date created: 2025/08/20
Last modified: 2025/08/20
Description: A guide on how to save Orbax checkpoints during model training with the JAX backend.
Introduction
Orbax is the default checkpointing library recommended for JAX ecosystem users. It is a high-level library that provides both checkpoint management and composable, extensible serialization. This guide explains how to save Orbax checkpoints when training a Keras model with the JAX backend.
Note that you should use Orbax checkpointing for multi-host training with the Keras distribution API, because the default Keras checkpointing currently does not support multi-host training.
Setup
Let's start by installing the Orbax checkpointing library, which is published on PyPI as orbax-checkpoint.
Because this guide targets the JAX backend, we set the Keras backend to JAX before importing Keras. Then we import Keras and the other libraries we need, including the Orbax checkpointing library.
Orbax Callback
We need to create two main utilities to manage Orbax checkpointing in Keras:
KerasOrbaxCheckpointManager: A wrapper around orbax.checkpoint.CheckpointManager for Keras models. KerasOrbaxCheckpointManager uses Model's get_state_tree and set_state_tree APIs to save and restore the model variables.
OrbaxCheckpointCallback: A Keras callback that uses KerasOrbaxCheckpointManager to automatically save and restore model states during training.
Orbax checkpointing in Keras is as simple as copying these utilities to your own codebase and passing OrbaxCheckpointCallback to the fit method.
An Orbax checkpointing example
Let's look at how we can use OrbaxCheckpointCallback to save Orbax checkpoints during training. To get started, let's define a simple model and a toy training dataset.
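A simple model and toy dataset along these lines might look like the sketch below. The layer sizes, the `get_model` helper name, the optimizer, and the array shapes are illustrative assumptions; any small compiled model would serve.

```python
import os

os.environ["KERAS_BACKEND"] = "jax"  # must be set before importing keras

import keras
import numpy as np


def get_model():
    # One dense layer mapping 32 input features to a single output.
    inputs = keras.Input(shape=(32,))
    outputs = keras.layers.Dense(1, name="output_layer")(inputs)
    model = keras.Model(inputs, outputs)
    model.compile(optimizer=keras.optimizers.Adam(), loss="mean_squared_error")
    return model


model = get_model()

# Random features and targets: 128 samples with 32 features each.
x_train = np.random.random((128, 32))
y_train = np.random.random((128, 1))
```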
Then, we create an Orbax checkpointing callback and pass it to the callbacks argument of the fit method.
Now if you look at the Orbax checkpoint directory, you can see all the files saved as part of Orbax checkpointing.
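One way to inspect the directory from Python is a small recursive listing, sketched here with the standard library only. The helper name `list_checkpoint_files` is hypothetical, and the exact files you see depend on the Orbax version in use.

```python
from pathlib import Path


def list_checkpoint_files(checkpoint_dir):
    """Return every file under checkpoint_dir as sorted relative POSIX paths."""
    root = Path(checkpoint_dir)
    return sorted(
        path.relative_to(root).as_posix()
        for path in root.rglob("*")
        if path.is_file()
    )


# Usage, assuming `checkpoint_dir` is the directory given to the callback:
# for name in list_checkpoint_files(checkpoint_dir):
#     print(name)
```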