"""
Title: Writing a training loop from scratch in TensorFlow
Author: [fchollet](https://twitter.com/fchollet)
Date created: 2019/03/01
Last modified: 2023/06/25
Description: Writing low-level training & evaluation loops in TensorFlow.
Accelerator: None
"""

"""
## Setup
"""

import time
import os

# This guide can only be run with the TensorFlow backend.
os.environ["KERAS_BACKEND"] = "tensorflow"

import tensorflow as tf
import keras
import numpy as np

"""
## Introduction

Keras provides default training and evaluation loops, `fit()` and `evaluate()`.
Their usage is covered in the guide
[Training & evaluation with the built-in methods](/guides/training_with_built_in_methods/).

If you want to customize the learning algorithm of your model while still leveraging
the convenience of `fit()` (for instance, to train a GAN using `fit()`), you can
subclass the `Model` class and implement your own `train_step()` method, which
is called repeatedly during `fit()`.
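
A minimal sketch of that pattern (assuming a model compiled with an optimizer, a loss,
and metrics, and batches of `(x, y)` data) looks roughly like this:

```python
class CustomModel(keras.Model):
    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compute_loss(y=y, y_pred=y_pred)
        grads = tape.gradient(loss, self.trainable_weights)
        self.optimizer.apply(grads, self.trainable_weights)
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}
```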

Now, if you want very low-level control over training & evaluation, you should write
your own training & evaluation loops from scratch. This is what this guide is about.
"""

"""
## A first end-to-end example

Let's consider a simple MNIST model:
"""


def get_model():
    inputs = keras.Input(shape=(784,), name="digits")
    x1 = keras.layers.Dense(64, activation="relu")(inputs)
    x2 = keras.layers.Dense(64, activation="relu")(x1)
    outputs = keras.layers.Dense(10, name="predictions")(x2)
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model


model = get_model()

"""
Let's train it using mini-batch gradient descent with a custom training loop.

First, we're going to need an optimizer, a loss function, and a dataset:
"""

# Instantiate an optimizer.
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
# Instantiate a loss function.
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Load the MNIST dataset and flatten the images.
batch_size = 32
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = np.reshape(x_train, (-1, 784))
x_test = np.reshape(x_test, (-1, 784))

# Reserve 10,000 samples for validation.
x_val = x_train[-10000:]
y_val = y_train[-10000:]
x_train = x_train[:-10000]
y_train = y_train[:-10000]

# Prepare the training dataset.
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)

# Prepare the validation dataset.
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_dataset = val_dataset.batch(batch_size)

"""
Calling a model inside a `GradientTape` scope enables you to retrieve the gradients of
the model's trainable weights with respect to a loss value. Using an optimizer
instance, you can use these gradients to update these variables (which you can
retrieve using `model.trainable_weights`).
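
In isolation, a single update step looks like this (an illustrative sketch on a
standalone `Dense` layer with dummy data, separate from the loop below):

```python
dense = keras.layers.Dense(4)
x_dummy = tf.random.normal((2, 3))
y_dummy = tf.random.normal((2, 4))
dense(x_dummy)  # Build the layer so its weights exist.
sgd = keras.optimizers.SGD(learning_rate=0.1)
with tf.GradientTape() as tape:
    loss = tf.reduce_mean(tf.square(dense(x_dummy) - y_dummy))
grads = tape.gradient(loss, dense.trainable_weights)
sgd.apply(grads, dense.trainable_weights)
```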

Here's our training loop, step by step:

- We open a `for` loop that iterates over epochs
- For each epoch, we open a `for` loop that iterates over the dataset, in batches
- For each batch, we open a `GradientTape()` scope
- Inside this scope, we call the model (forward pass) and compute the loss
- Outside the scope, we retrieve the gradients of the weights
of the model with respect to the loss
- Finally, we use the optimizer to update the weights of the model based on the
gradients
"""

epochs = 3
for epoch in range(epochs):
    print(f"\nStart of epoch {epoch}")

    # Iterate over the batches of the dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        # Open a GradientTape to record the operations run
        # during the forward pass, which enables auto-differentiation.
        with tf.GradientTape() as tape:
            # Run the forward pass of the model.
            # The operations that the model applies
            # to its inputs are going to be recorded
            # on the GradientTape.
            logits = model(x_batch_train, training=True)  # Logits for this minibatch

            # Compute the loss value for this minibatch.
            loss_value = loss_fn(y_batch_train, logits)

        # Use the gradient tape to automatically retrieve
        # the gradients of the trainable variables with respect to the loss.
        grads = tape.gradient(loss_value, model.trainable_weights)

        # Run one step of gradient descent by updating
        # the value of the variables to minimize the loss.
        optimizer.apply(grads, model.trainable_weights)

        # Log every 100 batches.
        if step % 100 == 0:
            print(
                f"Training loss (for 1 batch) at step {step}: {float(loss_value):.4f}"
            )
            print(f"Seen so far: {(step + 1) * batch_size} samples")

"""
## Low-level handling of metrics

Let's add metrics monitoring to this basic loop.

You can readily reuse the built-in metrics (or custom ones you wrote) in such training
loops written from scratch. Here's the flow (a short standalone sketch follows this list):

- Instantiate the metric at the start of the loop
- Call `metric.update_state()` after each batch
- Call `metric.result()` when you need to display the current value of the metric
- Call `metric.reset_state()` when you need to clear the state of the metric
(typically at the end of an epoch)
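
In isolation, and with a tiny made-up batch of labels and predictions, that flow looks
like this (illustrative sketch):

```python
metric = keras.metrics.SparseCategoricalAccuracy()
metric.update_state([0, 1], [[0.9, 0.1], [0.2, 0.8]])  # after each batch
print(float(metric.result()))  # current value -- here 1.0
metric.reset_state()  # typically at the end of an epoch
```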

Let's use this knowledge to compute `SparseCategoricalAccuracy` on training and
validation data at the end of each epoch:
"""

# Get a fresh model
model = get_model()

# Instantiate an optimizer to train the model.
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
# Instantiate a loss function.
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Prepare the metrics.
train_acc_metric = keras.metrics.SparseCategoricalAccuracy()
val_acc_metric = keras.metrics.SparseCategoricalAccuracy()

"""
Here's our training & evaluation loop:
"""

epochs = 2
for epoch in range(epochs):
    print(f"\nStart of epoch {epoch}")
    start_time = time.time()

    # Iterate over the batches of the dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        with tf.GradientTape() as tape:
            logits = model(x_batch_train, training=True)
            loss_value = loss_fn(y_batch_train, logits)
        grads = tape.gradient(loss_value, model.trainable_weights)
        optimizer.apply(grads, model.trainable_weights)

        # Update training metric.
        train_acc_metric.update_state(y_batch_train, logits)

        # Log every 100 batches.
        if step % 100 == 0:
            print(
                f"Training loss (for 1 batch) at step {step}: {float(loss_value):.4f}"
            )
            print(f"Seen so far: {(step + 1) * batch_size} samples")

    # Display metrics at the end of each epoch.
    train_acc = train_acc_metric.result()
    print(f"Training acc over epoch: {float(train_acc):.4f}")

    # Reset training metrics at the end of each epoch
    train_acc_metric.reset_state()

    # Run a validation loop at the end of each epoch.
    for x_batch_val, y_batch_val in val_dataset:
        val_logits = model(x_batch_val, training=False)
        # Update val metrics
        val_acc_metric.update_state(y_batch_val, val_logits)
    val_acc = val_acc_metric.result()
    val_acc_metric.reset_state()
    print(f"Validation acc: {float(val_acc):.4f}")
    print(f"Time taken: {time.time() - start_time:.2f}s")

"""
## Speeding up your training step with `tf.function`

The default runtime in TensorFlow is eager execution.
As such, our training loop above executes eagerly.

This is great for debugging, but graph compilation has a definite performance
advantage. Describing your computation as a static graph enables the framework
to apply global performance optimizations. This is impossible when
the framework is constrained to greedily execute one operation after another,
with no knowledge of what comes next.

You can compile into a static graph any function that takes tensors as input.
Just add a `@tf.function` decorator on it, like this:
"""


@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss_value = loss_fn(y, logits)
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply(grads, model.trainable_weights)
    train_acc_metric.update_state(y, logits)
    return loss_value


"""
Let's do the same with the evaluation step:
"""


@tf.function
def test_step(x, y):
    val_logits = model(x, training=False)
    val_acc_metric.update_state(y, val_logits)


"""
Now, let's re-run our training loop with this compiled training step:
"""

epochs = 2
for epoch in range(epochs):
    print(f"\nStart of epoch {epoch}")
    start_time = time.time()

    # Iterate over the batches of the dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        loss_value = train_step(x_batch_train, y_batch_train)

        # Log every 100 batches.
        if step % 100 == 0:
            print(
                f"Training loss (for 1 batch) at step {step}: {float(loss_value):.4f}"
            )
            print(f"Seen so far: {(step + 1) * batch_size} samples")

    # Display metrics at the end of each epoch.
    train_acc = train_acc_metric.result()
    print(f"Training acc over epoch: {float(train_acc):.4f}")

    # Reset training metrics at the end of each epoch
    train_acc_metric.reset_state()

    # Run a validation loop at the end of each epoch.
    for x_batch_val, y_batch_val in val_dataset:
        test_step(x_batch_val, y_batch_val)

    val_acc = val_acc_metric.result()
    val_acc_metric.reset_state()
    print(f"Validation acc: {float(val_acc):.4f}")
    print(f"Time taken: {time.time() - start_time:.2f}s")

"""
Much faster, isn't it?
"""

"""
## Low-level handling of losses tracked by the model

Layers & models recursively track any losses created during the forward pass
by layers that call `self.add_loss(value)`. The resulting list of scalar loss
values is available via the property `model.losses`
at the end of the forward pass.

If you want to use these loss components, you should sum them
and add them to the main loss in your training step.

Consider this layer, which creates an activity regularization loss:
"""


class ActivityRegularizationLayer(keras.layers.Layer):
    def call(self, inputs):
        self.add_loss(1e-2 * tf.reduce_sum(inputs))
        return inputs


"""
Let's build a really simple model that uses it:
"""

inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu")(inputs)
# Insert activity regularization as a layer
x = ActivityRegularizationLayer()(x)
x = keras.layers.Dense(64, activation="relu")(x)
outputs = keras.layers.Dense(10, name="predictions")(x)

model = keras.Model(inputs=inputs, outputs=outputs)
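
"""
As a quick check (illustrative only, not required for training), a single forward pass
populates `model.losses` with the term added by `ActivityRegularizationLayer`:
"""

_ = model(tf.zeros((2, 784)), training=True)
print(model.losses)  # A list containing one scalar loss tensor.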

"""
Here's what our training step should look like now:
"""


@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss_value = loss_fn(y, logits)
        # Add any extra losses created during the forward pass.
        loss_value += sum(model.losses)
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply(grads, model.trainable_weights)
    train_acc_metric.update_state(y, logits)
    return loss_value


"""
## Summary

Now you know everything there is to know about using built-in training loops and
writing your own from scratch.

To conclude, here's a simple end-to-end example that ties together everything
you've learned in this guide: a DCGAN trained on MNIST digits.
"""

"""
## End-to-end example: a GAN training loop from scratch

You may be familiar with Generative Adversarial Networks (GANs). GANs can generate new
images that look almost real, by learning the latent distribution of a training
dataset of images (the "latent space" of the images).

A GAN is made of two parts: a "generator" model that maps points in the latent
space to points in image space, and a "discriminator" model, a classifier
that can tell the difference between real images (from the training dataset)
and fake images (the output of the generator network).

A GAN training loop looks like this:

1) Train the discriminator.
- Sample a batch of random points in the latent space.
- Turn the points into fake images via the "generator" model.
- Get a batch of real images and combine them with the generated images.
- Train the "discriminator" model to classify generated vs. real images.

2) Train the generator.
- Sample random points in the latent space.
- Turn the points into fake images via the "generator" network.
- Train the "generator" model (without updating the discriminator's weights) to
"fool" the discriminator into classifying the fake images as real.

For a much more detailed overview of how GANs work, see
[Deep Learning with Python](https://www.manning.com/books/deep-learning-with-python).

Let's implement this training loop. First, create the discriminator, meant to classify
fake vs. real digits:
"""

discriminator = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        keras.layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"),
        keras.layers.LeakyReLU(negative_slope=0.2),
        keras.layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"),
        keras.layers.LeakyReLU(negative_slope=0.2),
        keras.layers.GlobalMaxPooling2D(),
        keras.layers.Dense(1),
    ],
    name="discriminator",
)
discriminator.summary()

"""
Then let's create a generator network,
which turns latent vectors into outputs of shape `(28, 28, 1)` (representing
MNIST digits):
"""

latent_dim = 128

generator = keras.Sequential(
    [
        keras.Input(shape=(latent_dim,)),
        # We want to generate 7 * 7 * 128 coefficients to reshape into a 7x7x128 map
        keras.layers.Dense(7 * 7 * 128),
        keras.layers.LeakyReLU(negative_slope=0.2),
        keras.layers.Reshape((7, 7, 128)),
        keras.layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        keras.layers.LeakyReLU(negative_slope=0.2),
        keras.layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        keras.layers.LeakyReLU(negative_slope=0.2),
        keras.layers.Conv2D(1, (7, 7), padding="same", activation="sigmoid"),
    ],
    name="generator",
)

"""
Here's the key bit: the training loop. As you can see, it is quite straightforward. The
training step function only takes 17 lines.
"""

# Instantiate one optimizer for the discriminator and another for the generator.
d_optimizer = keras.optimizers.Adam(learning_rate=0.0003)
g_optimizer = keras.optimizers.Adam(learning_rate=0.0004)

# Instantiate a loss function.
loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)


@tf.function
def train_step(real_images):
    # Sample random points in the latent space
    random_latent_vectors = tf.random.normal(shape=(batch_size, latent_dim))
    # Decode them to fake images
    generated_images = generator(random_latent_vectors)
    # Combine them with real images
    combined_images = tf.concat([generated_images, real_images], axis=0)

    # Assemble labels discriminating real from fake images
    labels = tf.concat(
        [tf.ones((batch_size, 1)), tf.zeros((real_images.shape[0], 1))], axis=0
    )
    # Add random noise to the labels - important trick!
    labels += 0.05 * tf.random.uniform(labels.shape)

    # Train the discriminator
    with tf.GradientTape() as tape:
        predictions = discriminator(combined_images)
        d_loss = loss_fn(labels, predictions)
    grads = tape.gradient(d_loss, discriminator.trainable_weights)
    d_optimizer.apply(grads, discriminator.trainable_weights)

    # Sample random points in the latent space
    random_latent_vectors = tf.random.normal(shape=(batch_size, latent_dim))
    # Assemble labels that say "all real images"
    misleading_labels = tf.zeros((batch_size, 1))

    # Train the generator (note that we should *not* update the weights
    # of the discriminator)!
    with tf.GradientTape() as tape:
        predictions = discriminator(generator(random_latent_vectors))
        g_loss = loss_fn(misleading_labels, predictions)
    grads = tape.gradient(g_loss, generator.trainable_weights)
    g_optimizer.apply(grads, generator.trainable_weights)
    return d_loss, g_loss, generated_images

"""
Let's train our GAN by repeatedly calling `train_step` on batches of images.

Since our discriminator and generator are convnets, you're going to want to
run this code on a GPU.
"""

# Prepare the dataset. We use both the training & test MNIST digits.
batch_size = 64
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
all_digits = np.concatenate([x_train, x_test])
all_digits = all_digits.astype("float32") / 255.0
all_digits = np.reshape(all_digits, (-1, 28, 28, 1))
dataset = tf.data.Dataset.from_tensor_slices(all_digits)
dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)

epochs = 1  # In practice you need at least 20 epochs to generate nice digits.
save_dir = "./"

for epoch in range(epochs):
    print(f"\nStart epoch {epoch}")

    for step, real_images in enumerate(dataset):
        # Train the discriminator & generator on one batch of real images.
        d_loss, g_loss, generated_images = train_step(real_images)

        # Logging.
        if step % 100 == 0:
            # Print metrics
            print(f"discriminator loss at step {step}: {d_loss:.2f}")
            print(f"adversarial loss at step {step}: {g_loss:.2f}")

            # Save one generated image
            img = keras.utils.array_to_img(generated_images[0] * 255.0, scale=False)
            img.save(os.path.join(save_dir, f"generated_img_{step}.png"))

        # To limit execution time, we stop after 10 steps.
        # Remove the lines below to actually train the model!
        if step > 10:
            break

"""
That's it! You'll get nice-looking fake MNIST digits after just ~30s of training on a
Colab GPU.
"""