"""
Title: Customizing what happens in `fit()` with TensorFlow
Author: [fchollet](https://twitter.com/fchollet)
Date created: 2020/04/15
Last modified: 2023/06/27
Description: Overriding the training step of the Model class with TensorFlow.
Accelerator: GPU
"""

"""
## Introduction

When you're doing supervised learning, you can use `fit()` and everything works
smoothly.

When you need to take control of every little detail, you can write your own training
loop entirely from scratch.

But what if you need a custom training algorithm, but you still want to benefit from
the convenient features of `fit()`, such as callbacks, built-in distribution support,
or step fusing?

A core principle of Keras is **progressive disclosure of complexity**. You should
always be able to get into lower-level workflows in a gradual way. You shouldn't fall
off a cliff if the high-level functionality doesn't exactly match your use case. You
should be able to gain more control over the small details while retaining a
commensurate amount of high-level convenience.

When you need to customize what `fit()` does, you should **override the training step
function of the `Model` class**. This is the function that is called by `fit()` for
every batch of data. You will then be able to call `fit()` as usual -- and it will be
running your own learning algorithm.

Note that this pattern does not prevent you from building models with the Functional
API. You can do this whether you're building `Sequential` models, Functional API
models, or subclassed models.

Let's see how that works.
"""
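
"""
In skeleton form, the pattern looks like this (a minimal sketch only -- the full,
runnable versions are developed step by step below):

```python
class MyModel(keras.Model):
    def train_step(self, data):
        x, y = data
        # ...run the forward pass, compute a loss, apply gradients...
        # Finally, return a dict mapping metric names to current values:
        return {m.name: m.result() for m in self.metrics}
```

You then build, `compile()`, and `fit()` a `MyModel` instance exactly as you would any
other Keras model.
"""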

"""
## Setup
"""

import os

# This guide can only be run with the TF backend.
os.environ["KERAS_BACKEND"] = "tensorflow"

import tensorflow as tf
import keras
from keras import layers
import numpy as np

"""
## A first simple example

Let's start from a simple example:

- We create a new class that subclasses `keras.Model`.
- We just override the method `train_step(self, data)`.
- We return a dictionary mapping metric names (including the loss) to their current
value.

The input argument `data` is what gets passed to `fit()` as training data:

- If you pass NumPy arrays, by calling `fit(x, y, ...)`, then `data` will be the tuple
`(x, y)`.
- If you pass a `tf.data.Dataset`, by calling `fit(dataset, ...)`, then `data` will be
what gets yielded by `dataset` at each batch (see the sketch right after this list).
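
For instance (an illustrative sketch, reusing the `x`, `y`, and `model` objects defined
further below), both of the following calls go through the same `train_step()`:

```python
model.fit(x, y, batch_size=32)  # `data` is the tuple (x, y)

dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(32)
model.fit(dataset)  # `data` is whatever `dataset` yields: here, batched (x, y) tuples
```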

In the body of the `train_step()` method, we implement a regular training update,
similar to what you are already familiar with. Importantly, **we compute the loss via
`self.compute_loss()`**, which wraps the loss function(s) that were passed to
`compile()`.

Similarly, we call `metric.update_state(y, y_pred)` on metrics from `self.metrics`,
to update the state of the metrics that were passed in `compile()`,
and we query results from `self.metrics` at the end to retrieve their current value.
"""


class CustomModel(keras.Model):
    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute the loss value
            # (the loss function is configured in `compile()`)
            loss = self.compute_loss(y=y, y_pred=y_pred)

        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply(gradients, trainable_vars)

        # Update metrics (includes the metric that tracks the loss)
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred)

        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}


"""
Let's try this out:
"""

# Construct and compile an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])

# Just use `fit` as usual
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=3)
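
"""
Since this still goes through the regular `fit()` machinery, conveniences such as
callbacks keep working unchanged. For instance (an illustrative sketch using the
built-in `EarlyStopping` callback):

```python
model.fit(
    x,
    y,
    epochs=10,
    callbacks=[keras.callbacks.EarlyStopping(monitor="loss", patience=2)],
)
```
"""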

"""
## Going lower-level

Naturally, you could just skip passing a loss function in `compile()`, and instead do
everything *manually* in `train_step`. Likewise for metrics.

Here's a lower-level example that only uses `compile()` to configure the optimizer:

- We start by creating `Metric` instances to track our loss and a MAE score (in `__init__()`).
- We implement a custom `train_step()` that updates the state of these metrics
(by calling `update_state()` on them), then query them (via `result()`) to return their current average value,
to be displayed by the progress bar and to be passed to any callback.
- Note that we would need to call `reset_states()` on our metrics between epochs! Otherwise
calling `result()` would return an average since the start of training, whereas we usually work
with per-epoch averages. Thankfully, the framework can do that for us: just list any metric
you want to reset in the `metrics` property of the model. The model will call `reset_states()`
on any object listed here at the beginning of each `fit()` epoch or at the beginning of a call to
`evaluate()`.
"""


class CustomModel(keras.Model):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.loss_tracker = keras.metrics.Mean(name="loss")
        self.mae_metric = keras.metrics.MeanAbsoluteError(name="mae")
        self.loss_fn = keras.losses.MeanSquaredError()

    def train_step(self, data):
        x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute our own loss
            loss = self.loss_fn(y, y_pred)

        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply(gradients, trainable_vars)

        # Compute our own metrics
        self.loss_tracker.update_state(loss)
        self.mae_metric.update_state(y, y_pred)
        return {
            "loss": self.loss_tracker.result(),
            "mae": self.mae_metric.result(),
        }

    @property
    def metrics(self):
        # We list our `Metric` objects here so that `reset_states()` can be
        # called automatically at the start of each epoch
        # or at the start of `evaluate()`.
        return [self.loss_tracker, self.mae_metric]


# Construct an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)

# We don't pass a loss or metrics here.
model.compile(optimizer="adam")

# Just use `fit` as usual -- you can use callbacks, etc.
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=5)


"""
## Supporting `sample_weight` & `class_weight`

You may have noticed that our first basic example didn't make any mention of sample
weighting. If you want to support the `fit()` arguments `sample_weight` and
`class_weight`, you'd simply do the following:

- Unpack `sample_weight` from the `data` argument
- Pass it to `compute_loss` & `update_state` (of course, you could also just apply
it manually if you don't rely on `compile()` for losses & metrics)
- That's it.
"""


class CustomModel(keras.Model):
    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        if len(data) == 3:
            x, y, sample_weight = data
        else:
            sample_weight = None
            x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute the loss value.
            # The loss function is configured in `compile()`.
            loss = self.compute_loss(
                y=y,
                y_pred=y_pred,
                sample_weight=sample_weight,
            )

        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply(gradients, trainable_vars)

        # Update the metrics.
        # Metrics are configured in `compile()`.
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred, sample_weight=sample_weight)

        # Return a dict mapping metric names to current value.
        # Note that it will include the loss (tracked in self.metrics).
        return {m.name: m.result() for m in self.metrics}


# Construct and compile an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])

# You can now use the `sample_weight` argument
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
sw = np.random.random((1000, 1))
model.fit(x, y, sample_weight=sw, epochs=3)
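
"""
As an aside, `class_weight` is handled by the same code path: `fit()` converts the
`class_weight` dict into per-sample weights before calling `train_step()`, so the
unpacking logic above covers it as well. A hypothetical call (assuming integer class
labels `y_int`, which are not part of the example above) could look like this:

```python
y_int = np.random.randint(0, 2, size=(1000, 1))
model.fit(x, y_int, class_weight={0: 1.0, 1: 3.0}, epochs=3)
```
"""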

"""
## Providing your own evaluation step

What if you want to do the same for calls to `model.evaluate()`? Then you would
override `test_step` in exactly the same way. Here's what it looks like:
"""


class CustomModel(keras.Model):
    def test_step(self, data):
        # Unpack the data
        x, y = data
        # Compute predictions
        y_pred = self(x, training=False)
        # Compute the loss value
        # (the loss function is configured in `compile()`)
        loss = self.compute_loss(y=y, y_pred=y_pred)
        # Update the metrics.
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred)
        # Return a dict mapping metric names to current value.
        # Note that it will include the loss (tracked in self.metrics).
        return {m.name: m.result() for m in self.metrics}


# Construct an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(loss="mse", metrics=["mae"])

# Evaluate with our custom test_step
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.evaluate(x, y)


"""
## Wrapping up: an end-to-end GAN example

Let's walk through an end-to-end example that leverages everything you just learned.

Let's consider:

- A generator network meant to generate 28x28x1 images.
- A discriminator network meant to classify 28x28x1 images into two classes ("fake" and
"real").
- One optimizer for each.
- A loss function to train the discriminator.
"""

# Create the discriminator
discriminator = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.GlobalMaxPooling2D(),
        layers.Dense(1),
    ],
    name="discriminator",
)

# Create the generator
latent_dim = 128
generator = keras.Sequential(
    [
        keras.Input(shape=(latent_dim,)),
        # We want to generate 7 * 7 * 128 coefficients to reshape into a 7x7x128 map
        layers.Dense(7 * 7 * 128),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Reshape((7, 7, 128)),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Conv2D(1, (7, 7), padding="same", activation="sigmoid"),
    ],
    name="generator",
)

"""
Here's a feature-complete GAN class, overriding `compile()` to use its own signature,
and implementing the entire GAN algorithm in `train_step`:
"""


class GAN(keras.Model):
    def __init__(self, discriminator, generator, latent_dim):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        self.d_loss_tracker = keras.metrics.Mean(name="d_loss")
        self.g_loss_tracker = keras.metrics.Mean(name="g_loss")
        self.seed_generator = keras.random.SeedGenerator(1337)

    @property
    def metrics(self):
        return [self.d_loss_tracker, self.g_loss_tracker]

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def train_step(self, real_images):
        if isinstance(real_images, tuple):
            real_images = real_images[0]
        # Sample random points in the latent space
        batch_size = tf.shape(real_images)[0]
        random_latent_vectors = keras.random.normal(
            shape=(batch_size, self.latent_dim), seed=self.seed_generator
        )

        # Decode them to fake images
        generated_images = self.generator(random_latent_vectors)

        # Combine them with real images
        combined_images = tf.concat([generated_images, real_images], axis=0)

        # Assemble labels discriminating real from fake images
        labels = tf.concat(
            [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0
        )
        # Add random noise to the labels - important trick!
        labels += 0.05 * keras.random.uniform(
            tf.shape(labels), seed=self.seed_generator
        )

        # Train the discriminator
        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined_images)
            d_loss = self.loss_fn(labels, predictions)
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply(grads, self.discriminator.trainable_weights)

        # Sample random points in the latent space
        random_latent_vectors = keras.random.normal(
            shape=(batch_size, self.latent_dim), seed=self.seed_generator
        )

        # Assemble labels that say "all real images"
        misleading_labels = tf.zeros((batch_size, 1))

        # Train the generator (note that we should *not* update the weights
        # of the discriminator)!
        with tf.GradientTape() as tape:
            predictions = self.discriminator(self.generator(random_latent_vectors))
            g_loss = self.loss_fn(misleading_labels, predictions)
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply(grads, self.generator.trainable_weights)

        # Update metrics and return their value.
        self.d_loss_tracker.update_state(d_loss)
        self.g_loss_tracker.update_state(g_loss)
        return {
            "d_loss": self.d_loss_tracker.result(),
            "g_loss": self.g_loss_tracker.result(),
        }


"""
Let's test-drive it:
"""

# Prepare the dataset. We use both the training & test MNIST digits.
batch_size = 64
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
all_digits = np.concatenate([x_train, x_test])
all_digits = all_digits.astype("float32") / 255.0
all_digits = np.reshape(all_digits, (-1, 28, 28, 1))
dataset = tf.data.Dataset.from_tensor_slices(all_digits)
dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)

gan = GAN(discriminator=discriminator, generator=generator, latent_dim=latent_dim)
gan.compile(
    d_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
    g_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
    loss_fn=keras.losses.BinaryCrossentropy(from_logits=True),
)

# To limit the execution time, we only train on 100 batches. You can train on
# the entire dataset. You will need about 20 epochs to get nice results.
gan.fit(dataset.take(100), epochs=1)

"""
The ideas behind deep learning are simple, so why should their implementation be painful?
"""