"""
Title: Customizing what happens in `fit()` with JAX
Author: [fchollet](https://twitter.com/fchollet)
Date created: 2023/06/27
Last modified: 2023/06/27
Description: Overriding the training step of the Model class with JAX.
Accelerator: GPU
"""

"""
## Introduction

When you're doing supervised learning, you can use `fit()` and everything works
smoothly.

When you need to take control of every little detail, you can write your own training
loop entirely from scratch.

But what if you need a custom training algorithm, yet still want to benefit from the
convenient features of `fit()`, such as callbacks, built-in distribution support,
or step fusing?

A core principle of Keras is **progressive disclosure of complexity**. You should
always be able to get into lower-level workflows in a gradual way. You shouldn't fall
off a cliff if the high-level functionality doesn't exactly match your use case. You
should be able to gain more control over the small details while retaining a
commensurate amount of high-level convenience.

When you need to customize what `fit()` does, you should **override the training step
function of the `Model` class**. This is the function that is called by `fit()` for
every batch of data. You will then be able to call `fit()` as usual -- and it will be
running your own learning algorithm.
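
In outline, the pattern looks like this (a minimal skeleton only -- the complete,
stateless JAX version is developed step by step below):

```python
class CustomModel(keras.Model):
    def train_step(self, state, data):
        # With the JAX backend, the training step is stateless: `state` holds
        # the trainable, non-trainable, optimizer, and metric variables, and
        # `data` is one batch. Run your own learning algorithm here, then
        # return a dict of logs together with the updated state.
        ...
```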

Note that this pattern does not prevent you from building models with the Functional
API. You can do this whether you're building `Sequential` models, Functional API
models, or subclassed models.

Let's see how that works.
"""

"""
## Setup
"""

import os

# This guide can only be run with the JAX backend.
os.environ["KERAS_BACKEND"] = "jax"

import jax
import keras
import numpy as np

"""
## A first simple example

Let's start from a simple example:

- We create a new class that subclasses `keras.Model`.
- We implement a fully-stateless `compute_loss_and_updates()` method
to compute the loss as well as the updated values for the non-trainable
variables of the model. Internally, it calls `stateless_call()` and
the built-in `compute_loss()`.
- We implement a fully-stateless `train_step()` method to compute current
metric values (including the loss) as well as updated values for the
trainable variables, the optimizer variables, and the metric variables.

Note that you can also take into account the `sample_weight` argument by:

- Unpacking the data as `x, y, sample_weight = data`
- Passing `sample_weight` to `compute_loss()`
- Passing `sample_weight` alongside `y` and `y_pred`
to metrics in `stateless_update_state()`, as sketched below
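
For instance, a `sample_weight`-aware version of the two methods could look like the
following sketch (the class name `WeightedCustomModel` is just illustrative, and
`fit()` would need to be called with sample weights):

```python
class WeightedCustomModel(keras.Model):
    def compute_loss_and_updates(
        self,
        trainable_variables,
        non_trainable_variables,
        x,
        y,
        sample_weight,
        training=False,
    ):
        y_pred, non_trainable_variables = self.stateless_call(
            trainable_variables, non_trainable_variables, x, training=training
        )
        # Forward the sample weights to the built-in loss computation.
        loss = self.compute_loss(x, y, y_pred, sample_weight=sample_weight)
        return loss, (y_pred, non_trainable_variables)

    def train_step(self, state, data):
        (
            trainable_variables,
            non_trainable_variables,
            optimizer_variables,
            metrics_variables,
        ) = state
        # Unpack the sample weights along with the inputs and targets.
        x, y, sample_weight = data

        grad_fn = jax.value_and_grad(self.compute_loss_and_updates, has_aux=True)
        (loss, (y_pred, non_trainable_variables)), grads = grad_fn(
            trainable_variables,
            non_trainable_variables,
            x,
            y,
            sample_weight,
            training=True,
        )
        # From here on, proceed exactly as in the example below, additionally
        # passing `sample_weight` along with `y` and `y_pred` when calling
        # `stateless_update_state()` on metrics that support it.
        ...
```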
"""
75
76
77
class CustomModel(keras.Model):
78
def compute_loss_and_updates(
79
self,
80
trainable_variables,
81
non_trainable_variables,
82
x,
83
y,
84
training=False,
85
):
86
y_pred, non_trainable_variables = self.stateless_call(
87
trainable_variables,
88
non_trainable_variables,
89
x,
90
training=training,
91
)
92
loss = self.compute_loss(x, y, y_pred)
93
return loss, (y_pred, non_trainable_variables)
94
95
def train_step(self, state, data):
96
(
97
trainable_variables,
98
non_trainable_variables,
99
optimizer_variables,
100
metrics_variables,
101
) = state
102
x, y = data
103
104
# Get the gradient function.
105
grad_fn = jax.value_and_grad(self.compute_loss_and_updates, has_aux=True)
106
107
# Compute the gradients.
108
(loss, (y_pred, non_trainable_variables)), grads = grad_fn(
109
trainable_variables,
110
non_trainable_variables,
111
x,
112
y,
113
training=True,
114
)
115
116
# Update trainable variables and optimizer variables.
117
(
118
trainable_variables,
119
optimizer_variables,
120
) = self.optimizer.stateless_apply(
121
optimizer_variables, grads, trainable_variables
122
)
123
124
# Update metrics.
125
new_metrics_vars = []
126
logs = {}
127
for metric in self.metrics:
128
this_metric_vars = metrics_variables[
129
len(new_metrics_vars) : len(new_metrics_vars) + len(metric.variables)
130
]
131
if metric.name == "loss":
132
this_metric_vars = metric.stateless_update_state(this_metric_vars, loss)
133
else:
134
this_metric_vars = metric.stateless_update_state(
135
this_metric_vars, y, y_pred
136
)
137
logs[metric.name] = metric.stateless_result(this_metric_vars)
138
new_metrics_vars += this_metric_vars
139
140
# Return metric logs and updated state variables.
141
state = (
142
trainable_variables,
143
non_trainable_variables,
144
optimizer_variables,
145
new_metrics_vars,
146
)
147
return logs, state
148
149
150
"""
151
Let's try this out:
152
"""
153
154
# Construct and compile an instance of CustomModel
155
inputs = keras.Input(shape=(32,))
156
outputs = keras.layers.Dense(1)(inputs)
157
model = CustomModel(inputs, outputs)
158
model.compile(optimizer="adam", loss="mse", metrics=["mae"])
159
160
# Just use `fit` as usual
161
x = np.random.random((1000, 32))
162
y = np.random.random((1000, 1))
163
model.fit(x, y, epochs=3)
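
"""
Since `fit()` runs as usual, built-in conveniences such as callbacks still apply. For
example (just an illustration -- any standard callback works), you could add early
stopping on the training loss:
"""

model.fit(
    x,
    y,
    epochs=3,
    callbacks=[keras.callbacks.EarlyStopping(monitor="loss", patience=2)],
)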


"""
## Going lower-level

Naturally, you could just skip passing a loss function in `compile()`, and instead do
everything *manually* in `train_step`. Likewise for metrics.

Here's a lower-level example that only uses `compile()` to configure the optimizer:
"""


class CustomModel(keras.Model):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.loss_tracker = keras.metrics.Mean(name="loss")
        self.mae_metric = keras.metrics.MeanAbsoluteError(name="mae")
        self.loss_fn = keras.losses.MeanSquaredError()

    def compute_loss_and_updates(
        self,
        trainable_variables,
        non_trainable_variables,
        x,
        y,
        training=False,
    ):
        y_pred, non_trainable_variables = self.stateless_call(
            trainable_variables,
            non_trainable_variables,
            x,
            training=training,
        )
        loss = self.loss_fn(y, y_pred)
        return loss, (y_pred, non_trainable_variables)

    def train_step(self, state, data):
        (
            trainable_variables,
            non_trainable_variables,
            optimizer_variables,
            metrics_variables,
        ) = state
        x, y = data

        # Get the gradient function.
        grad_fn = jax.value_and_grad(self.compute_loss_and_updates, has_aux=True)

        # Compute the gradients.
        (loss, (y_pred, non_trainable_variables)), grads = grad_fn(
            trainable_variables,
            non_trainable_variables,
            x,
            y,
            training=True,
        )

        # Update trainable variables and optimizer variables.
        (
            trainable_variables,
            optimizer_variables,
        ) = self.optimizer.stateless_apply(
            optimizer_variables, grads, trainable_variables
        )

        # Update metrics.
        loss_tracker_vars = metrics_variables[: len(self.loss_tracker.variables)]
        mae_metric_vars = metrics_variables[len(self.loss_tracker.variables) :]

        loss_tracker_vars = self.loss_tracker.stateless_update_state(
            loss_tracker_vars, loss
        )
        mae_metric_vars = self.mae_metric.stateless_update_state(
            mae_metric_vars, y, y_pred
        )

        logs = {}
        logs[self.loss_tracker.name] = self.loss_tracker.stateless_result(
            loss_tracker_vars
        )
        logs[self.mae_metric.name] = self.mae_metric.stateless_result(mae_metric_vars)

        new_metrics_vars = loss_tracker_vars + mae_metric_vars

        # Return metric logs and updated state variables.
        state = (
            trainable_variables,
            non_trainable_variables,
            optimizer_variables,
            new_metrics_vars,
        )
        return logs, state

    @property
    def metrics(self):
        # We list our `Metric` objects here so that `reset_state()` can be
        # called automatically at the start of each epoch
        # or at the start of `evaluate()`.
        return [self.loss_tracker, self.mae_metric]


# Construct an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)

# We don't pass a loss or metrics here.
model.compile(optimizer="adam")

# Just use `fit` as usual -- you can use callbacks, etc.
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=5)


"""
280
## Providing your own evaluation step
281
282
What if you want to do the same for calls to `model.evaluate()`? Then you would
283
override `test_step` in exactly the same way. Here's what it looks like:
284
"""
285
286
287
class CustomModel(keras.Model):
288
def test_step(self, state, data):
289
# Unpack the data.
290
x, y = data
291
(
292
trainable_variables,
293
non_trainable_variables,
294
metrics_variables,
295
) = state
296
297
# Compute predictions and loss.
298
y_pred, non_trainable_variables = self.stateless_call(
299
trainable_variables,
300
non_trainable_variables,
301
x,
302
training=False,
303
)
304
loss = self.compute_loss(x, y, y_pred)
305
306
# Update metrics.
307
new_metrics_vars = []
308
for metric in self.metrics:
309
this_metric_vars = metrics_variables[
310
len(new_metrics_vars) : len(new_metrics_vars) + len(metric.variables)
311
]
312
if metric.name == "loss":
313
this_metric_vars = metric.stateless_update_state(this_metric_vars, loss)
314
else:
315
this_metric_vars = metric.stateless_update_state(
316
this_metric_vars, y, y_pred
317
)
318
logs = metric.stateless_result(this_metric_vars)
319
new_metrics_vars += this_metric_vars
320
321
# Return metric logs and updated state variables.
322
state = (
323
trainable_variables,
324
non_trainable_variables,
325
new_metrics_vars,
326
)
327
return logs, state
328
329
330
# Construct an instance of CustomModel
331
inputs = keras.Input(shape=(32,))
332
outputs = keras.layers.Dense(1)(inputs)
333
model = CustomModel(inputs, outputs)
334
model.compile(loss="mse", metrics=["mae"])
335
336
# Evaluate with our custom test_step
337
x = np.random.random((1000, 32))
338
y = np.random.random((1000, 1))
339
model.evaluate(x, y)
340
341
342
"""
343
That's it!
344
"""
345
346