"""
Title: Customizing what happens in `fit()`
Author: [fchollet](https://twitter.com/fchollet)
Date created: 2020/04/15
Last modified: 2023/06/14
Description: Complete guide to overriding the training step of the Model class.
Accelerator: GPU
"""

"""
## Introduction

When you're doing supervised learning, you can use `fit()` and everything works
smoothly.

When you need to write your own training loop from scratch, you can use
`GradientTape` and take control of every little detail.

But what if you need a custom training algorithm, yet still want to benefit from
the convenient features of `fit()`, such as callbacks, built-in distribution support,
or step fusing?

A core principle of Keras is **progressive disclosure of complexity**. You should
always be able to get into lower-level workflows in a gradual way. You shouldn't fall
off a cliff if the high-level functionality doesn't exactly match your use case. You
should be able to gain more control over the small details while retaining a
commensurate amount of high-level convenience.

When you need to customize what `fit()` does, you should **override the training step
function of the `Model` class**. This is the function that is called by `fit()` for
every batch of data. You will then be able to call `fit()` as usual -- and it will be
running your own learning algorithm.

Note that this pattern does not prevent you from building models with the Functional
API. You can do this whether you're building `Sequential` models, Functional API
models, or subclassed models.

Let's see how that works.
"""

"""
## Setup

Requires TensorFlow 2.8 or later.
"""

import tensorflow as tf
import keras

"""
## A first simple example

Let's start from a simple example:

- We create a new class that subclasses `keras.Model`.
- We just override the method `train_step(self, data)`.
- We return a dictionary mapping metric names (including the loss) to their current
value.

The input argument `data` is what gets passed to `fit()` as training data:

- If you pass NumPy arrays, by calling `fit(x, y, ...)`, then `data` will be the tuple
`(x, y)`.
- If you pass a `tf.data.Dataset`, by calling `fit(dataset, ...)`, then `data` will be
what gets yielded by `dataset` at each batch.

In the body of the `train_step()` method, we implement a regular training update,
similar to what you are already familiar with. Importantly, **we compute the loss via
`self.compute_loss()`**, which wraps the loss function(s) that were passed to
`compile()`.

Similarly, we call `metric.update_state(y, y_pred)` on metrics from `self.metrics`,
to update the state of the metrics that were passed in `compile()`,
and we query results from `self.metrics` at the end to retrieve their current value.
"""


class CustomModel(keras.Model):
    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute the loss value
            # (the loss function is configured in `compile()`)
            loss = self.compute_loss(y=y, y_pred=y_pred)

        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        # Update metrics (includes the metric that tracks the loss)
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred)
        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}


"""
Let's try this out:
"""

import numpy as np

# Construct and compile an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])

# Just use `fit` as usual
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=3)
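
"""
As a quick sketch of the `tf.data.Dataset` case mentioned above: if you wrap the same
arrays in a dataset of `(x, y)` batches, each yielded batch becomes the `data` argument
of `train_step()`, with no changes to the class (the `train_dataset` name below is
just for illustration):
"""

# Wrap the NumPy arrays in a dataset of (x, y) batches.
train_dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(32)
# Each element yielded by the dataset is passed to `train_step()` as `data`.
model.fit(train_dataset, epochs=1)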

"""
## Going lower-level

Naturally, you could just skip passing a loss function in `compile()`, and instead do
everything *manually* in `train_step`. Likewise for metrics.

Here's a lower-level example that only uses `compile()` to configure the optimizer:

- We start by creating `Metric` instances to track our loss and a MAE score (in `__init__()`).
- We implement a custom `train_step()` that updates the state of these metrics
(by calling `update_state()` on them), then query them (via `result()`) to return their current average value,
to be displayed by the progress bar and to be passed to any callback.
- Note that we would need to call `reset_states()` on our metrics between each epoch! Otherwise
calling `result()` would return an average since the start of training, whereas we usually work
with per-epoch averages. Thankfully, the framework can do that for us: just list any metric
you want to reset in the `metrics` property of the model. The model will call `reset_states()`
on any object listed here at the beginning of each `fit()` epoch or at the beginning of a call to
`evaluate()`.
"""


class CustomModel(keras.Model):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.loss_tracker = keras.metrics.Mean(name="loss")
        self.mae_metric = keras.metrics.MeanAbsoluteError(name="mae")

    def train_step(self, data):
        x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute our own loss
            loss = keras.losses.mean_squared_error(y, y_pred)

        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Compute our own metrics
        self.loss_tracker.update_state(loss)
        self.mae_metric.update_state(y, y_pred)
        return {"loss": self.loss_tracker.result(), "mae": self.mae_metric.result()}

    @property
    def metrics(self):
        # We list our `Metric` objects here so that `reset_states()` can be
        # called automatically at the start of each epoch
        # or at the start of `evaluate()`.
        # If you don't implement this property, you have to call
        # `reset_states()` yourself at the time of your choosing.
        return [self.loss_tracker, self.mae_metric]


# Construct an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)

# We don't pass a loss or metrics here.
model.compile(optimizer="adam")

# Just use `fit` as usual -- you can use callbacks, etc.
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=5)
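
"""
Because the loss is reported under the name `"loss"`, built-in callbacks can monitor
it as usual. Here's a minimal sketch with early stopping (the epoch count and patience
are arbitrary, chosen just for illustration):
"""

# Stop training if the tracked loss stops improving for two consecutive epochs.
early_stopping = keras.callbacks.EarlyStopping(monitor="loss", patience=2)
model.fit(x, y, epochs=10, callbacks=[early_stopping])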


"""
## Supporting `sample_weight` & `class_weight`

You may have noticed that our first basic example didn't make any mention of sample
weighting. If you want to support the `fit()` arguments `sample_weight` and
`class_weight`, you'd simply do the following:

- Unpack `sample_weight` from the `data` argument
- Pass it to `compute_loss` & `update_state` (of course, you could also just apply
it manually if you don't rely on `compile()` for losses & metrics)
- That's it.
"""


class CustomModel(keras.Model):
    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        if len(data) == 3:
            x, y, sample_weight = data
        else:
            sample_weight = None
            x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute the loss value.
            # The loss function is configured in `compile()`.
            loss = self.compute_loss(
                y=y,
                y_pred=y_pred,
                sample_weight=sample_weight,
            )

        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update the metrics.
        # Metrics are configured in `compile()`.
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred, sample_weight=sample_weight)

        # Return a dict mapping metric names to current value.
        # Note that it will include the loss (tracked in self.metrics).
        return {m.name: m.result() for m in self.metrics}


# Construct and compile an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])

# You can now use the `sample_weight` argument
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
sw = np.random.random((1000, 1))
model.fit(x, y, sample_weight=sw, epochs=3)
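
"""
The same unpacking logic also covers a `tf.data.Dataset` that yields
`(x, y, sample_weight)` triples. Here's a minimal sketch reusing the arrays defined
above (the `weighted_dataset` name is just for illustration):
"""

# Each yielded batch is a 3-tuple, so `sample_weight` gets unpacked in `train_step()`.
weighted_dataset = tf.data.Dataset.from_tensor_slices((x, y, sw)).batch(32)
model.fit(weighted_dataset, epochs=1)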

"""
## Providing your own evaluation step

What if you want to do the same for calls to `model.evaluate()`? Then you would
override `test_step` in exactly the same way. Here's what it looks like:
"""


class CustomModel(keras.Model):
    def test_step(self, data):
        # Unpack the data
        x, y = data
        # Compute predictions
        y_pred = self(x, training=False)
        # Updates the metrics tracking the loss
        self.compute_loss(y=y, y_pred=y_pred)
        # Update the metrics.
        for metric in self.metrics:
            if metric.name != "loss":
                metric.update_state(y, y_pred)
        # Return a dict mapping metric names to current value.
        # Note that it will include the loss (tracked in self.metrics).
        return {m.name: m.result() for m in self.metrics}


# Construct an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(loss="mse", metrics=["mae"])

# Evaluate with our custom test_step
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.evaluate(x, y)
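
"""
By default, `evaluate()` returns the loss and metric values as a list of scalars. If
you'd rather get back a dictionary shaped like the one `test_step()` returns, you can
pass the standard `return_dict=True` argument -- a small usage note, not a required
part of the pattern:
"""

# Returns a dict such as {"loss": ..., "mae": ...} instead of a list of scalars.
results = model.evaluate(x, y, return_dict=True)
print(results)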

"""
## Wrapping up: an end-to-end GAN example

Let's walk through an end-to-end example that leverages everything you just learned.

Let's consider:

- A generator network meant to generate 28x28x1 images.
- A discriminator network meant to classify 28x28x1 images into two classes ("fake" and
"real").
- One optimizer for each.
- A loss function to train the discriminator.
"""

from tensorflow.keras import layers

# Create the discriminator
discriminator = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.GlobalMaxPooling2D(),
        layers.Dense(1),
    ],
    name="discriminator",
)

# Create the generator
latent_dim = 128
generator = keras.Sequential(
    [
        keras.Input(shape=(latent_dim,)),
        # We want to generate 128 coefficients to reshape into a 7x7x128 map
        layers.Dense(7 * 7 * 128),
        layers.LeakyReLU(alpha=0.2),
        layers.Reshape((7, 7, 128)),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2D(1, (7, 7), padding="same", activation="sigmoid"),
    ],
    name="generator",
)

"""
Here's a feature-complete GAN class, overriding `compile()` to use its own signature,
and implementing the entire GAN algorithm in 17 lines in `train_step`:
"""


class GAN(keras.Model):
    def __init__(self, discriminator, generator, latent_dim):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        self.d_loss_tracker = keras.metrics.Mean(name="d_loss")
        self.g_loss_tracker = keras.metrics.Mean(name="g_loss")

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def train_step(self, real_images):
        if isinstance(real_images, tuple):
            real_images = real_images[0]
        # Sample random points in the latent space
        batch_size = tf.shape(real_images)[0]
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))

        # Decode them to fake images
        generated_images = self.generator(random_latent_vectors)

        # Combine them with real images
        combined_images = tf.concat([generated_images, real_images], axis=0)

        # Assemble labels discriminating real from fake images
        labels = tf.concat(
            [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0
        )
        # Add random noise to the labels - important trick!
        labels += 0.05 * tf.random.uniform(tf.shape(labels))

        # Train the discriminator
        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined_images)
            d_loss = self.loss_fn(labels, predictions)
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(
            zip(grads, self.discriminator.trainable_weights)
        )

        # Sample random points in the latent space
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))

        # Assemble labels that say "all real images"
        misleading_labels = tf.zeros((batch_size, 1))

        # Train the generator (note that we should *not* update the weights
        # of the discriminator)!
        with tf.GradientTape() as tape:
            predictions = self.discriminator(self.generator(random_latent_vectors))
            g_loss = self.loss_fn(misleading_labels, predictions)
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))

        # Update metrics and return their value.
        self.d_loss_tracker.update_state(d_loss)
        self.g_loss_tracker.update_state(g_loss)
        return {
            "d_loss": self.d_loss_tracker.result(),
            "g_loss": self.g_loss_tracker.result(),
        }


"""
Let's test-drive it:
"""

# Prepare the dataset. We use both the training & test MNIST digits.
batch_size = 64
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
all_digits = np.concatenate([x_train, x_test])
all_digits = all_digits.astype("float32") / 255.0
all_digits = np.reshape(all_digits, (-1, 28, 28, 1))
dataset = tf.data.Dataset.from_tensor_slices(all_digits)
dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)

gan = GAN(discriminator=discriminator, generator=generator, latent_dim=latent_dim)
gan.compile(
    d_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
    g_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
    loss_fn=keras.losses.BinaryCrossentropy(from_logits=True),
)

# To limit the execution time, we only train on 100 batches. You can train on
# the entire dataset. You will need about 20 epochs to get nice results.
gan.fit(dataset.take(100), epochs=1)
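
"""
Once trained (even briefly, as above), the generator can be sampled directly: feed it
random latent vectors and it outputs 28x28x1 images in `[0, 1]` thanks to the final
sigmoid. A minimal sketch (the number of samples is arbitrary):
"""

# Decode a few random latent vectors into images.
random_latent_vectors = tf.random.normal(shape=(3, latent_dim))
generated_images = gan.generator(random_latent_vectors)
print(generated_images.shape)  # (3, 28, 28, 1)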

"""
The ideas behind deep learning are simple, so why should their implementation be painful?
"""