GitHub Repository: keras-team/keras-io
Path: blob/master/guides/writing_a_custom_training_loop_in_jax.py
"""
2
Title: Writing a training loop from scratch in JAX
3
Author: [fchollet](https://twitter.com/fchollet)
4
Date created: 2023/06/25
5
Last modified: 2023/06/25
6
Description: Writing low-level training & evaluation loops in JAX.
7
Accelerator: None
8
"""
9
10
"""
11
## Setup
12
"""
13
14
import os
15
16
# This guide can only be run with the jax backend.
17
os.environ["KERAS_BACKEND"] = "jax"
18
19
import jax
20
21
# We import TF so we can use tf.data.
22
import tensorflow as tf
23
import keras
24
import numpy as np
25
26
"""
27
## Introduction
28
29
Keras provides default training and evaluation loops, `fit()` and `evaluate()`.
30
Their usage is covered in the guide
31
[Training & evaluation with the built-in methods](/guides/training_with_built_in_methods/).
32
33
If you want to customize the learning algorithm of your model while still leveraging
34
the convenience of `fit()`
35
(for instance, to train a GAN using `fit()`), you can subclass the `Model` class and
36
implement your own `train_step()` method, which
37
is called repeatedly during `fit()`.
38
39
Now, if you want very low-level control over training & evaluation, you should write
40
your own training & evaluation loops from scratch. This is what this guide is about.
41
"""
42
43
"""
44
## A first end-to-end example
45
46
To write a custom training loop, we need the following ingredients:
47
48
- A model to train, of course.
49
- An optimizer. You could either use an optimizer from `keras.optimizers`, or
50
one from the `optax` package.
51
- A loss function.
52
- A dataset. The standard in the JAX ecosystem is to load data via `tf.data`,
53
so that's what we'll use.
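
For reference, a single `optax` update step looks roughly like this (a minimal,
standalone sketch with toy arrays -- in this guide we'll stick with a Keras optimizer):

```python
import jax.numpy as jnp
import optax

# Toy parameters and gradients, standing in for a model's trainable variables.
params = [jnp.zeros((784, 64)), jnp.zeros((64,))]
grads = [jnp.ones((784, 64)), jnp.ones((64,))]

optimizer = optax.adam(learning_rate=1e-3)
opt_state = optimizer.init(params)

# One update step: turn gradients into parameter updates, then apply them.
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)
```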

Let's line them up.

First, let's get the model and the MNIST dataset:
"""


def get_model():
    inputs = keras.Input(shape=(784,), name="digits")
    x1 = keras.layers.Dense(64, activation="relu")(inputs)
    x2 = keras.layers.Dense(64, activation="relu")(x1)
    outputs = keras.layers.Dense(10, name="predictions")(x2)
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model


model = get_model()

# Load the MNIST data and preprocess it.
batch_size = 32
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = np.reshape(x_train, (-1, 784)).astype("float32")
x_test = np.reshape(x_test, (-1, 784)).astype("float32")
y_train = keras.utils.to_categorical(y_train)
y_test = keras.utils.to_categorical(y_test)

# Reserve 10,000 samples for validation.
x_val = x_train[-10000:]
y_val = y_train[-10000:]
x_train = x_train[:-10000]
y_train = y_train[:-10000]

# Prepare the training dataset.
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)

# Prepare the validation dataset.
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_dataset = val_dataset.batch(batch_size)

"""
Next, here's the loss function and the optimizer.
We'll use a Keras optimizer in this case.
"""

# Instantiate a loss function.
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)

# Instantiate an optimizer.
optimizer = keras.optimizers.Adam(learning_rate=1e-3)

"""
### Getting gradients in JAX

Let's train our model using mini-batch gradient descent with a custom training loop.

In JAX, gradients are computed via *metaprogramming*: you call `jax.grad` (or
`jax.value_and_grad`) on a function in order to create a gradient-computing function
for that first function.

So the first thing we need is a function that returns the loss value.
That's the function we'll use to generate the gradient function. Something like this:

```python
def compute_loss(x, y):
    ...
    return loss
```

Once you have such a function, you can compute gradients via metaprogramming like this:

```python
grad_fn = jax.grad(compute_loss)
grads = grad_fn(x, y)
```

Typically, you don't just want to get the gradient values, you also want to get
the loss value. You can do this by using `jax.value_and_grad` instead of `jax.grad`:

```python
grad_fn = jax.value_and_grad(compute_loss)
loss, grads = grad_fn(x, y)
```
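
By default, `jax.grad` and `jax.value_and_grad` differentiate with respect to the
*first* argument of the function. Here's a tiny self-contained illustration with toy
values (a standalone sketch, separate from our MNIST example):

```python
import jax
import jax.numpy as jnp


def toy_loss(w, x, y):
    y_pred = x @ w
    return jnp.mean((y_pred - y) ** 2)


grad_fn = jax.value_and_grad(toy_loss)  # gradients w.r.t. `w`, the first argument
w = jnp.zeros((3,))
x = jnp.ones((5, 3))
y = jnp.ones((5,))
loss, grads = grad_fn(w, x, y)  # `loss` is a scalar, `grads` has the same shape as `w`
```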

### JAX computation is purely stateless

In JAX, everything must be a stateless function -- so our loss computation function
must be stateless as well. That means that all Keras variables (e.g. weight tensors)
must be passed as function inputs, and any variable that has been updated during the
forward pass must be returned as function output. The function must have no side effects.

During the forward pass, the non-trainable variables of a Keras model might get
updated. These variables could be, for instance, RNG seed state variables or
BatchNormalization statistics. We're going to need to return those. So we need
something like this:

```python
def compute_loss_and_updates(trainable_variables, non_trainable_variables, x, y):
    ...
    return loss, non_trainable_variables
```

Once you have such a function, you can get the gradient function by
specifying `has_aux` in `value_and_grad`: it tells JAX that the loss
computation function returns more outputs than just the loss. Note that the loss
should always be the first output.

```python
grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)
(loss, non_trainable_variables), grads = grad_fn(
    trainable_variables, non_trainable_variables, x, y
)
```

Now that we have established the basics,
let's implement this `compute_loss_and_updates` function.
Keras models have a `stateless_call` method which will come in handy here.
It works just like `model.__call__`, but it requires you to explicitly
pass the value of all the variables in the model, and it returns not just
the `__call__` outputs but also the (potentially updated) non-trainable
variables.
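
For instance, here's what a single stateless forward pass looks like in isolation
(a quick sketch using the `model` defined above and a placeholder batch):

```python
x = np.zeros((2, 784), dtype="float32")  # placeholder batch
y_pred, non_trainable_variables = model.stateless_call(
    model.trainable_variables, model.non_trainable_variables, x
)
```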
"""


def compute_loss_and_updates(trainable_variables, non_trainable_variables, x, y):
    y_pred, non_trainable_variables = model.stateless_call(
        trainable_variables, non_trainable_variables, x, training=True
    )
    loss = loss_fn(y, y_pred)
    return loss, non_trainable_variables


"""
Let's get the gradient function:
"""

grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)

"""
### The training step function

Next, let's implement the end-to-end training step: the function that runs the
forward pass, computes the loss, computes the gradients, and also uses the optimizer
to update the trainable variables. This function
also needs to be stateless, so it will get as input a `state` tuple that
includes every state element we're going to use:

- `trainable_variables` and `non_trainable_variables`: the model's variables.
- `optimizer_variables`: the optimizer's state variables,
such as momentum accumulators.

To update the trainable variables, we use the optimizer's stateless method
`stateless_apply`. It's equivalent to `optimizer.apply()`, but it requires
always passing `trainable_variables` and `optimizer_variables`. It returns
both the updated trainable variables and the updated `optimizer_variables`.
"""


def train_step(state, data):
    trainable_variables, non_trainable_variables, optimizer_variables = state
    x, y = data
    (loss, non_trainable_variables), grads = grad_fn(
        trainable_variables, non_trainable_variables, x, y
    )
    trainable_variables, optimizer_variables = optimizer.stateless_apply(
        optimizer_variables, grads, trainable_variables
    )
    # Return updated state
    return loss, (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
    )


"""
### Make it fast with `jax.jit`

By default, JAX operations run eagerly,
just like in TensorFlow eager mode and PyTorch eager mode.
And just like TensorFlow eager mode and PyTorch eager mode, it's pretty slow
-- eager mode is better used as a debugging environment, not as a way to do
any actual work. So let's make our `train_step` fast by compiling it.

When you have a stateless JAX function, you can compile it to XLA via the
`@jax.jit` decorator. It will get traced during its first execution, and in
subsequent executions you will be executing the traced graph (this is just
like `@tf.function(jit_compile=True)`).
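
For intuition, here's a minimal standalone sketch of the tracing behavior (a toy
function, not our training step):

```python
import jax
import jax.numpy as jnp


@jax.jit
def scaled_sum(x):
    return jnp.sum(x * 2.0)


scaled_sum(jnp.ones((32, 784)))   # first call: the function is traced and compiled with XLA
scaled_sum(jnp.zeros((32, 784)))  # same shape & dtype: the compiled graph is reused
scaled_sum(jnp.ones((16, 784)))   # a new input shape triggers a re-trace and re-compilation
```

Now let's apply the decorator to our `train_step`: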
"""


@jax.jit
def train_step(state, data):
    trainable_variables, non_trainable_variables, optimizer_variables = state
    x, y = data
    (loss, non_trainable_variables), grads = grad_fn(
        trainable_variables, non_trainable_variables, x, y
    )
    trainable_variables, optimizer_variables = optimizer.stateless_apply(
        optimizer_variables, grads, trainable_variables
    )
    # Return updated state
    return loss, (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
    )


"""
We're now ready to train our model. The training loop itself
is trivial: we just repeatedly call `loss, state = train_step(state, data)`.

Note:

- We convert the TF tensors yielded by the `tf.data.Dataset` to NumPy
before passing them to our JAX function.
- All variables must be built beforehand:
the model must be built and the optimizer must be built. Since we're using a
Functional API model, it's already built, but if it were a subclassed model
you'd need to call it on a batch of data to build it (see the sketch below).
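
Here's what that might look like (a hypothetical subclassed toy model, shown only
to illustrate the point -- our Functional model doesn't need this):

```python
import numpy as np
import keras


class SubclassedMLP(keras.Model):
    # Hypothetical subclassed model, used only to illustrate building.
    def __init__(self):
        super().__init__()
        self.dense_1 = keras.layers.Dense(64, activation="relu")
        self.dense_2 = keras.layers.Dense(10)

    def call(self, x):
        return self.dense_2(self.dense_1(x))


subclassed_model = SubclassedMLP()
subclassed_model(np.zeros((1, 784), dtype="float32"))  # one forward pass builds all the weights
sub_optimizer = keras.optimizers.Adam(learning_rate=1e-3)
sub_optimizer.build(subclassed_model.trainable_variables)  # optimizer variables now exist
```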
"""

# Build optimizer variables.
optimizer.build(model.trainable_variables)

trainable_variables = model.trainable_variables
non_trainable_variables = model.non_trainable_variables
optimizer_variables = optimizer.variables
state = trainable_variables, non_trainable_variables, optimizer_variables

# Training loop
for step, data in enumerate(train_dataset):
    data = (data[0].numpy(), data[1].numpy())
    loss, state = train_step(state, data)
    # Log every 100 batches.
    if step % 100 == 0:
        print(f"Training loss (for 1 batch) at step {step}: {float(loss):.4f}")
        print(f"Seen so far: {(step + 1) * batch_size} samples")

"""
A key thing to notice here is that the loop is entirely stateless -- the variables
attached to the model (`model.weights`) are never getting updated during the loop.
Their new values are only stored in the `state` tuple. That means that at some point,
before saving the model, you should attach the new variable values back to the model.

Just call `variable.assign(new_value)` on each model variable you want to update:
"""

trainable_variables, non_trainable_variables, optimizer_variables = state
for variable, value in zip(model.trainable_variables, trainable_variables):
    variable.assign(value)
for variable, value in zip(model.non_trainable_variables, non_trainable_variables):
    variable.assign(value)

"""
## Low-level handling of metrics

Let's add metrics monitoring to this basic training loop.

You can readily reuse built-in Keras metrics (or custom ones you wrote) in such training
loops written from scratch. Here's the flow:

- Instantiate the metric at the start of the loop.
- Include `metric_variables` in the `train_step` arguments
and `compute_loss_and_updates` arguments.
- Call `metric.stateless_update_state()` in the `compute_loss_and_updates` function.
It's equivalent to `update_state()` -- only stateless.
- When you need to display the current value of the metric, outside the `train_step`
(in the eager scope), attach the new metric variable values to the metric object
and call `metric.result()`.
- Call `metric.reset_state()` when you need to clear the state of the metric
(typically at the end of an epoch). A compact sketch of this lifecycle follows below.
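
Here's the lifecycle in isolation (a standalone toy example with made-up values,
separate from our MNIST loop):

```python
import numpy as np
import keras

metric = keras.metrics.CategoricalAccuracy()
metric_variables = metric.variables  # initial (zeroed) state

y_true = np.array([[0.0, 1.0], [1.0, 0.0]])
y_pred = np.array([[0.1, 0.9], [0.2, 0.8]])

# Purely functional update: returns new values, doesn't modify `metric` in place.
metric_variables = metric.stateless_update_state(metric_variables, y_true, y_pred)

# To read the result eagerly, attach the values back to the metric object.
for variable, value in zip(metric.variables, metric_variables):
    variable.assign(value)
print(metric.result())  # 0.5 -- one of the two samples is classified correctly

metric.reset_state()  # clear the state, e.g. at the end of an epoch
```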

Let's use this knowledge to compute `CategoricalAccuracy` on training and
validation data at the end of training:
"""

# Get a fresh model
model = get_model()

# Instantiate an optimizer to train the model.
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
# Instantiate a loss function.
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)

# Prepare the metrics.
train_acc_metric = keras.metrics.CategoricalAccuracy()
val_acc_metric = keras.metrics.CategoricalAccuracy()


def compute_loss_and_updates(
    trainable_variables, non_trainable_variables, metric_variables, x, y
):
    y_pred, non_trainable_variables = model.stateless_call(
        trainable_variables, non_trainable_variables, x
    )
    loss = loss_fn(y, y_pred)
    metric_variables = train_acc_metric.stateless_update_state(
        metric_variables, y, y_pred
    )
    return loss, (non_trainable_variables, metric_variables)


grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)


@jax.jit
def train_step(state, data):
    (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
        metric_variables,
    ) = state
    x, y = data
    (loss, (non_trainable_variables, metric_variables)), grads = grad_fn(
        trainable_variables, non_trainable_variables, metric_variables, x, y
    )
    trainable_variables, optimizer_variables = optimizer.stateless_apply(
        optimizer_variables, grads, trainable_variables
    )
    # Return updated state
    return loss, (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
        metric_variables,
    )


"""
We'll also prepare an evaluation step function:
"""


@jax.jit
def eval_step(state, data):
    trainable_variables, non_trainable_variables, metric_variables = state
    x, y = data
    y_pred, non_trainable_variables = model.stateless_call(
        trainable_variables, non_trainable_variables, x
    )
    loss = loss_fn(y, y_pred)
    metric_variables = val_acc_metric.stateless_update_state(
        metric_variables, y, y_pred
    )
    return loss, (
        trainable_variables,
        non_trainable_variables,
        metric_variables,
    )


"""
Here are our loops:
"""

# Build optimizer variables.
optimizer.build(model.trainable_variables)

trainable_variables = model.trainable_variables
non_trainable_variables = model.non_trainable_variables
optimizer_variables = optimizer.variables
metric_variables = train_acc_metric.variables
state = (
    trainable_variables,
    non_trainable_variables,
    optimizer_variables,
    metric_variables,
)

# Training loop
for step, data in enumerate(train_dataset):
    data = (data[0].numpy(), data[1].numpy())
    loss, state = train_step(state, data)
    # Log every 100 batches.
    if step % 100 == 0:
        print(f"Training loss (for 1 batch) at step {step}: {float(loss):.4f}")
        _, _, _, metric_variables = state
        for variable, value in zip(train_acc_metric.variables, metric_variables):
            variable.assign(value)
        print(f"Training accuracy: {train_acc_metric.result()}")
        print(f"Seen so far: {(step + 1) * batch_size} samples")

# Build the evaluation state: keep the trained model variables, but start from the
# (fresh) validation metric's variables rather than the training metric's.
(
    trainable_variables,
    non_trainable_variables,
    optimizer_variables,
    metric_variables,
) = state
metric_variables = val_acc_metric.variables
state = trainable_variables, non_trainable_variables, metric_variables

# Eval loop
for step, data in enumerate(val_dataset):
    data = (data[0].numpy(), data[1].numpy())
    loss, state = eval_step(state, data)
    # Log every 100 batches.
    if step % 100 == 0:
        print(f"Validation loss (for 1 batch) at step {step}: {float(loss):.4f}")
        _, _, metric_variables = state
        for variable, value in zip(val_acc_metric.variables, metric_variables):
            variable.assign(value)
        print(f"Validation accuracy: {val_acc_metric.result()}")
        print(f"Seen so far: {(step + 1) * batch_size} samples")

"""
## Low-level handling of losses tracked by the model

Layers & models recursively track any losses created during the forward pass
by layers that call `self.add_loss(value)`. The resulting list of scalar loss
values is available via the property `model.losses`
at the end of the forward pass.

If you want to use these loss components, you should sum them
and add them to the main loss in your training step.

Consider this layer, which creates an activity regularization loss:
"""


class ActivityRegularizationLayer(keras.layers.Layer):
    def call(self, inputs):
        self.add_loss(1e-2 * jax.numpy.sum(inputs))
        return inputs


"""
Let's build a really simple model that uses it:
"""

inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu")(inputs)
# Insert activity regularization as a layer
x = ActivityRegularizationLayer()(x)
x = keras.layers.Dense(64, activation="relu")(x)
outputs = keras.layers.Dense(10, name="predictions")(x)

model = keras.Model(inputs=inputs, outputs=outputs)

"""
Here's what our `compute_loss_and_updates` function should look like now:

- Pass `return_losses=True` to `model.stateless_call()`.
- Sum the resulting `losses` and add them to the main loss.
"""


def compute_loss_and_updates(
    trainable_variables, non_trainable_variables, metric_variables, x, y
):
    y_pred, non_trainable_variables, losses = model.stateless_call(
        trainable_variables, non_trainable_variables, x, return_losses=True
    )
    loss = loss_fn(y, y_pred)
    if losses:
        loss += jax.numpy.sum(losses)
    metric_variables = train_acc_metric.stateless_update_state(
        metric_variables, y, y_pred
    )
    # Keep the (loss, aux) structure so this still works with
    # `jax.value_and_grad(..., has_aux=True)` and the same `train_step` as before.
    return loss, (non_trainable_variables, metric_variables)


"""
That's it!
"""