Orbax Checkpointing in Keras
Author: Samaneh Saadat
Date created: 2025/08/20
Last modified: 2025/08/20
Description: A guide on how to save Orbax checkpoints during model training with the JAX backend.
Introduction
Orbax is the checkpointing library recommended for users of the JAX ecosystem. It is a high-level library that provides both checkpoint management and composable, extensible serialization. This guide explains how to save Orbax checkpoints when training a Keras model with the JAX backend.
Note that you should use Orbax checkpointing for multi-host training with the Keras distribution API, because the default Keras checkpointing does not currently support multi-host training.
Setup
Let's start by installing the Orbax checkpointing library, which is published on PyPI as orbax-checkpoint.
Next, we set the Keras backend to JAX, since this guide targets the JAX backend. Then we import Keras and the other libraries we need, including the Orbax checkpointing library.
Orbax Callback
We need to create two main utilities to manage Orbax checkpointing in Keras:
KerasOrbaxCheckpointManager: A wrapper around orbax.checkpoint.CheckpointManager for Keras models. KerasOrbaxCheckpointManager uses Model's get_state_tree and set_state_tree APIs to save and restore the model variables.
OrbaxCheckpointCallback: A Keras callback that uses KerasOrbaxCheckpointManager to automatically save and restore model states during training.
Orbax checkpointing in Keras is as simple as copying these utilities to your own codebase and passing OrbaxCheckpointCallback to the fit method.
An Orbax checkpointing example
Let's look at how we can use OrbaxCheckpointCallback to save Orbax checkpoints during training. To get started, let's define a simple model and a toy training dataset.
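A simple model and toy dataset of the kind mentioned above might look like this sketch. The layer sizes, the optimizer, the loss, and the dataset shapes are arbitrary choices made for illustration.

```python
import os

os.environ["KERAS_BACKEND"] = "jax"  # must be set before importing Keras

import keras
import numpy as np


def get_model():
    # A single Dense layer mapping 32 input features to 1 output.
    inputs = keras.Input(shape=(32,))
    outputs = keras.layers.Dense(1, name="dense")(inputs)
    model = keras.Model(inputs, outputs)
    model.compile(optimizer=keras.optimizers.Adam(), loss="mean_squared_error")
    return model


model = get_model()

# Toy regression data: 128 random samples with 32 features each.
x_train = np.random.random((128, 32))
y_train = np.random.random((128, 1))
```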
Then, we create an Orbax checkpointing callback and pass it to the callbacks argument of the fit method.
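As a rough sketch of the wiring, the helper below forwards a checkpoint callback to fit. The helper name `fit_with_checkpoints` and its default arguments are hypothetical. Any `keras.callbacks.Callback` instance can be passed in, including an OrbaxCheckpointCallback copied into your codebase.

```python
import os

os.environ["KERAS_BACKEND"] = "jax"  # must be set before importing Keras

import keras


def fit_with_checkpoints(model, x, y, checkpoint_callback, epochs=3, batch_size=32):
    """Train `model`, letting `checkpoint_callback` run its hooks each epoch."""
    return model.fit(
        x,
        y,
        batch_size=batch_size,
        epochs=epochs,
        callbacks=[checkpoint_callback],  # the callback is attached here
        verbose=0,
    )
```

With three epochs, a callback that saves in on_epoch_end is invoked three times, once per epoch.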
Save Orbax checkpoint on_epoch_end.
Save Orbax checkpoint on_epoch_end.
Save Orbax checkpoint on_epoch_end.
/tmp/ckpt/2: _CHECKPOINT_METADATA default
/tmp/ckpt/2/default: array_metadatas d manifest.ocdbt _METADATA ocdbt.process_0 _sharding
/tmp/ckpt/2/default/array_metadatas: process_0
/tmp/ckpt/2/default/d: 18ec9a2094133d1aa1a3d7513dae3e8d
/tmp/ckpt/2/default/ocdbt.process_0: d manifest.ocdbt
/tmp/ckpt/2/default/ocdbt.process_0/d: 08372fc5734e445753b38235cb522988 c8af54d085d2d516444bd71f32a3787c 4601db15b67650f7c8818bfc8afeb9f5 cfe1e3ea313d637df6f6d2b2c66ca17a a6ca20e04d8fe161ed95f6f71e8fe113