"""
Title: Orbax Checkpointing in Keras
Author: [Samaneh Saadat](https://github.com/SamanehSaadat/)
Date created: 2025/08/20
Last modified: 2025/08/20
Description: A guide on how to save Orbax checkpoints during model training with the JAX backend.
Accelerator: GPU
"""

"""
11
## Introduction
12
13
Orbax is the default checkpointing library recommended for JAX ecosystem
14
users. It is a high-level checkpointing library which provides functionality
15
for both checkpoint management and composable and extensible serialization.
16
This guide explains how to do Orbax checkpointing when training a model in
17
the JAX backend.
18
19
Note that you should use Orbax checkpointing for multi-host training using
20
Keras distribution API as the default Keras checkpointing currently does not
21
support multi-host.
22
"""
23
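"""
For multi-host runs, you would typically configure the distribution before
building the model, and then checkpoint with Orbax exactly as shown in the
rest of this guide. A minimal sketch of a data-parallel setup (illustrative
only; it assumes `keras` has already been imported with the JAX backend, as
done in the Setup section below):

```python
# Shard training across all available devices with data parallelism.
devices = keras.distribution.list_devices()
keras.distribution.set_distribution(
    keras.distribution.DataParallel(devices=devices)
)
```
"""
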
"""
## Setup

Let's start by installing the Orbax checkpointing library:
"""

"""shell
pip install -q -U orbax-checkpoint
"""

"""
We need to set the Keras backend to JAX, as this guide is intended for the
JAX backend. Then we import Keras and the other libraries we need,
including the Orbax checkpointing library.
"""

import os

os.environ["KERAS_BACKEND"] = "jax"

import keras
import numpy as np
import orbax.checkpoint as ocp

"""
## Orbax Callback

We need to create two main utilities to manage Orbax checkpointing in Keras:

1. `KerasOrbaxCheckpointManager`: A wrapper around
`orbax.checkpoint.CheckpointManager` for Keras models.
`KerasOrbaxCheckpointManager` uses the `Model` APIs `get_state_tree` and
`set_state_tree` to save and restore the model variables.
2. `OrbaxCheckpointCallback`: A Keras callback that uses
`KerasOrbaxCheckpointManager` to automatically save and restore model states
during training.

Orbax checkpointing in Keras is then as simple as copying these utilities
into your own codebase and passing an `OrbaxCheckpointCallback` to the `fit`
method.
"""

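"""
For context, `get_state_tree` returns all of a model's variables as a nested
dict, and `set_state_tree` accepts the same structure back. A minimal sketch
of the round trip (the exact keys and nesting depend on your model,
optimizer, and metrics):

```python
state = model.get_state_tree()
# `state` is a nested dict of variable values, with top-level keys such as
# "trainable_variables", "optimizer_variables", and "metrics_variables".
model.set_state_tree(state)  # restore the variables from the same structure
```
"""

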
class KerasOrbaxCheckpointManager(ocp.CheckpointManager):
    """A wrapper over Orbax CheckpointManager for Keras with the JAX
    backend."""

    def __init__(
        self,
        model,
        checkpoint_dir,
        max_to_keep=5,
        steps_per_epoch=1,
        **kwargs,
    ):
        """Initialize the Keras Orbax Checkpoint Manager.

        Args:
            model: The Keras model to checkpoint.
            checkpoint_dir: Directory path where checkpoints will be saved.
            max_to_keep: Maximum number of checkpoints to keep in the
                directory. Default is 5.
            steps_per_epoch: Number of steps per epoch. Default is 1.
            **kwargs: Additional keyword arguments to pass to Orbax's
                CheckpointManagerOptions.
        """
        options = ocp.CheckpointManagerOptions(
            max_to_keep=max_to_keep, enable_async_checkpointing=False, **kwargs
        )
        self._model = model
        self._steps_per_epoch = steps_per_epoch
        self._checkpoint_dir = checkpoint_dir
        super().__init__(checkpoint_dir, options=options)

    def _get_state(self):
        """Gets the model state and metrics.

        This method retrieves the complete state tree from the model and
        separates the metrics variables from the rest of the state.

        Returns:
            A tuple containing:
            - state: A dictionary containing the model's state (weights,
              optimizer state, etc.)
            - metrics: The model's metrics variables, if any
        """
        state = self._model.get_state_tree().copy()
        metrics = state.pop("metrics_variables", None)
        return state, metrics

    def save_state(self, epoch):
        """Saves the model to the checkpoint directory.

        Args:
            epoch: The epoch number at which the state is saved.
        """
        state, metrics_value = self._get_state()
        self.save(
            epoch * self._steps_per_epoch,
            args=ocp.args.StandardSave(item=state),
            metrics=metrics_value,
        )

    def restore_state(self, step=None):
        """Restores the model from the checkpoint directory.

        Args:
            step: The step number to restore the state from. Default=None
                restores the latest step.
        """
        step = step or self.latest_step()
        if step is None:
            return
        # Restore the model state only, not metrics.
        state, _ = self._get_state()
        restored_state = self.restore(step, args=ocp.args.StandardRestore(item=state))
        self._model.set_state_tree(restored_state)


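"""
The `OrbaxCheckpointCallback` below wires this manager into `fit`, but you
can also drive `KerasOrbaxCheckpointManager` directly, e.g. from a custom
training loop. A minimal sketch, assuming a built and compiled `model`:

```python
manager = KerasOrbaxCheckpointManager(model, checkpoint_dir="/tmp/manual_ckpt")
manager.save_state(epoch=0)  # write a checkpoint at step 0
manager.restore_state()  # restore the latest saved step
```
"""

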
class OrbaxCheckpointCallback(keras.callbacks.Callback):
    """A callback for checkpointing and restoring state using Orbax."""

    def __init__(
        self,
        model,
        checkpoint_dir,
        max_to_keep=5,
        steps_per_epoch=1,
        **kwargs,
    ):
        """Initialize the Orbax checkpoint callback.

        Args:
            model: The Keras model to checkpoint.
            checkpoint_dir: Directory path where checkpoints will be saved.
            max_to_keep: Maximum number of checkpoints to keep in the
                directory. Default is 5.
            steps_per_epoch: Number of steps per epoch. Default is 1.
            **kwargs: Additional keyword arguments to pass to Orbax's
                CheckpointManagerOptions.
        """
        if keras.config.backend() != "jax":
            raise ValueError(
                "`OrbaxCheckpointCallback` is only supported on the "
                f"`jax` backend. Provided backend is {keras.config.backend()}."
            )
        self._checkpoint_manager = KerasOrbaxCheckpointManager(
            model, checkpoint_dir, max_to_keep, steps_per_epoch, **kwargs
        )

    def on_train_begin(self, logs=None):
        if not self.model.built or not self.model.optimizer.built:
            raise ValueError(
                "To use `OrbaxCheckpointCallback`, your model and "
                "optimizer must be built before you call `fit()`."
            )
        latest_epoch = self._checkpoint_manager.latest_step()
        if latest_epoch is not None:
            print("Load Orbax checkpoint on_train_begin.")
            self._checkpoint_manager.restore_state(step=latest_epoch)

    def on_epoch_end(self, epoch, logs=None):
        print("Save Orbax checkpoint on_epoch_end.")
        self._checkpoint_manager.save_state(epoch)


"""
## An Orbax checkpointing example

Let's look at how we can use `OrbaxCheckpointCallback` to save Orbax
checkpoints during training. To get started, let's define a simple model
and a toy training dataset.
"""


def get_model():
    # Create a simple model.
    inputs = keras.Input(shape=(32,))
    outputs = keras.layers.Dense(1, name="dense")(inputs)
    model = keras.Model(inputs, outputs)
    model.compile(optimizer=keras.optimizers.Adam(), loss="mean_squared_error")
    return model


model = get_model()

x_train = np.random.random((128, 32))
y_train = np.random.random((128, 1))

"""
Then, we create an Orbax checkpointing callback and pass it to the
`callbacks` argument of the `fit` method.
"""

orbax_callback = OrbaxCheckpointCallback(
    model,
    checkpoint_dir="/tmp/ckpt",
    max_to_keep=1,
    steps_per_epoch=1,
)
history = model.fit(
    x_train,
    y_train,
    batch_size=32,
    epochs=3,
    verbose=0,
    validation_split=0.2,
    callbacks=[orbax_callback],
)

"""
Now if you look at the Orbax checkpoint directory, you can see all the files
saved as part of Orbax checkpointing.
"""

"""shell
ls -R /tmp/ckpt
"""
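
"""
Because `OrbaxCheckpointCallback` restores the latest checkpoint in
`on_train_begin`, calling `fit` again with the same callback resumes training
from the saved state instead of from scratch. A minimal sketch, reusing the
`model`, data, and `orbax_callback` defined above (we pass `initial_epoch` so
that the epoch counter, and therefore the checkpoint step numbers, keep
increasing across runs):

```python
# The callback restores the latest step from /tmp/ckpt before training
# starts, so these epochs continue from the saved weights and optimizer
# state.
history = model.fit(
    x_train,
    y_train,
    batch_size=32,
    initial_epoch=3,  # continue the epoch numbering from the first run
    epochs=6,
    verbose=0,
    validation_split=0.2,
    callbacks=[orbax_callback],
)
```
"""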