# Writing a training loop from scratch

**Author:** fchollet<br>
**Date created:** 2019/03/01<br>
**Last modified:** 2023/07/10<br>
**Description:** Complete guide to writing low-level training & evaluation loops.
## Setup

```python
import tensorflow as tf
import keras
from keras import layers
import numpy as np
```
## Introduction

Keras provides default training and evaluation loops, `fit()` and `evaluate()`.
Their usage is covered in the guide *Training & evaluation with the built-in methods*.

If you want to customize the learning algorithm of your model while still leveraging
the convenience of `fit()` (for instance, to train a GAN using `fit()`), you can subclass
the `Model` class and implement your own `train_step()` method, which
is called repeatedly during `fit()`. This is covered in the guide
*Customizing what happens in `fit()`*.

Now, if you want very low-level control over training & evaluation, you should write
your own training & evaluation loops from scratch. This is what this guide is about.
## Using the `GradientTape`: a first end-to-end example

Calling a model inside a `GradientTape` scope enables you to retrieve the gradients of
the trainable weights of the layer with respect to a loss value. Using an optimizer
instance, you can use these gradients to update these variables (which you can
retrieve using `model.trainable_weights`).
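Here's a minimal sketch of that pattern in isolation, using a single `Dense` layer and a mean-squared-error loss picked purely for illustration (this snippet is not part of the original guide):

```python
# Illustrative only: gradients of a toy loss w.r.t. one layer's weights.
layer = layers.Dense(2)
opt = keras.optimizers.SGD(learning_rate=1e-3)
x = tf.random.normal((4, 3))  # made-up inputs
y = tf.zeros((4, 2))  # made-up targets

with tf.GradientTape() as tape:
    preds = layer(x)  # forward pass, recorded on the tape
    loss = tf.reduce_mean(tf.square(preds - y))

# Retrieve gradients and apply one optimizer update.
grads = tape.gradient(loss, layer.trainable_weights)
opt.apply_gradients(zip(grads, layer.trainable_weights))
```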
Let's consider a simple MNIST model:
```python
inputs = keras.Input(shape=(784,), name="digits")
x1 = layers.Dense(64, activation="relu")(inputs)
x2 = layers.Dense(64, activation="relu")(x1)
outputs = layers.Dense(10, name="predictions")(x2)
model = keras.Model(inputs=inputs, outputs=outputs)
```
Let's train it using mini-batch gradient descent with a custom training loop.
First, we're going to need an optimizer, a loss function, and a dataset:
```python
# Instantiate an optimizer.
optimizer = keras.optimizers.SGD(learning_rate=1e-3)
# Instantiate a loss function.
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Prepare the MNIST dataset as flattened 784-dim vectors.
batch_size = 64
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = np.reshape(x_train, (-1, 784))
x_test = np.reshape(x_test, (-1, 784))

# Reserve 10,000 samples for validation.
x_val = x_train[-10000:]
y_val = y_train[-10000:]
x_train = x_train[:-10000]
y_train = y_train[:-10000]

# Prepare the training dataset.
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)

# Prepare the validation dataset.
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_dataset = val_dataset.batch(batch_size)
```
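As a quick sanity check (an illustrative snippet, not from the original guide), you can peek at the shapes of one batch:

```python
# Each training batch should be 64 flattened images plus 64 integer labels.
for x_batch, y_batch in train_dataset.take(1):
    print(x_batch.shape, y_batch.shape)  # (64, 784) (64,)
```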
Here's our training loop:

- We open a `for` loop that iterates over epochs
- For each epoch, we open a `for` loop that iterates over the dataset, in batches
- For each batch, we open a `GradientTape()` scope
- Inside this scope, we call the model (forward pass) and compute the loss
- Outside the scope, we retrieve the gradients of the weights of the model with regard to the loss
- Finally, we use the optimizer to update the weights of the model based on the gradients
```python
epochs = 2
for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))

    # Iterate over the batches of the dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):

        # Open a GradientTape to record the operations run during the
        # forward pass, which enables auto-differentiation.
        with tf.GradientTape() as tape:
            # Run the forward pass of the model.
            logits = model(x_batch_train, training=True)
            # Compute the loss value for this minibatch.
            loss_value = loss_fn(y_batch_train, logits)

        # Use the gradient tape to automatically retrieve the gradients
        # of the trainable variables with respect to the loss.
        grads = tape.gradient(loss_value, model.trainable_weights)

        # Run one step of gradient descent by updating the value of the
        # variables to minimize the loss.
        optimizer.apply_gradients(zip(grads, model.trainable_weights))

        # Log every 200 batches.
        if step % 200 == 0:
            print(
                "Training loss (for one batch) at step %d: %.4f"
                % (step, float(loss_value))
            )
            print("Seen so far: %s samples" % ((step + 1) * batch_size))
```
```
Start of epoch 0
Training loss (for one batch) at step 0: 120.0656
Seen so far: 64 samples
Training loss (for one batch) at step 200: 1.4296
Seen so far: 12864 samples
Training loss (for one batch) at step 400: 1.0072
Seen so far: 25664 samples
Training loss (for one batch) at step 600: 0.8556
Seen so far: 38464 samples
```
```
Start of epoch 1
Training loss (for one batch) at step 0: 0.6670
Seen so far: 64 samples
Training loss (for one batch) at step 200: 0.3697
Seen so far: 12864 samples
Training loss (for one batch) at step 400: 0.3445
Seen so far: 25664 samples
Training loss (for one batch) at step 600: 0.4279
Seen so far: 38464 samples
```
---
## Low-level handling of metrics
Let's add metrics monitoring to this basic loop.
You can readily reuse the built-in metrics (or custom ones you wrote) in such training
loops written from scratch. Here's the flow:
- Instantiate the metric at the start of the loop
- Call `metric.update_state()` after each batch
- Call `metric.result()` when you need to display the current value of the metric
- Call `metric.reset_states()` when you need to clear the state of the metric
(typically at the end of an epoch)
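In isolation, that lifecycle looks like this (a minimal sketch using a standalone `SparseCategoricalAccuracy` metric and made-up tensors, not part of the original guide):

```python
# Illustrative metric lifecycle on made-up data.
metric = keras.metrics.SparseCategoricalAccuracy()
y_true = tf.constant([0, 1, 2])  # integer labels
y_pred = tf.random.normal((3, 10))  # stand-in logits
metric.update_state(y_true, y_pred)  # after each batch
print(float(metric.result()))  # current value, whenever you need it
metric.reset_states()  # clear state, typically at end of epoch
```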
Let's use this knowledge to compute `SparseCategoricalAccuracy` on validation data at
the end of each epoch:
```python
# Get model
inputs = keras.Input(shape=(784,), name="digits")
x = layers.Dense(64, activation="relu", name="dense_1")(inputs)
x = layers.Dense(64, activation="relu", name="dense_2")(x)
outputs = layers.Dense(10, name="predictions")(x)
model = keras.Model(inputs=inputs, outputs=outputs)

# Instantiate an optimizer to train the model.
optimizer = keras.optimizers.SGD(learning_rate=1e-3)
# Instantiate a loss function.
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Prepare the metrics.
train_acc_metric = keras.metrics.SparseCategoricalAccuracy()
val_acc_metric = keras.metrics.SparseCategoricalAccuracy()
```
Here's our training & evaluation loop:
```python
import time

epochs = 2
for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))
    start_time = time.time()

    # Iterate over the batches of the dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        with tf.GradientTape() as tape:
            logits = model(x_batch_train, training=True)
            loss_value = loss_fn(y_batch_train, logits)
        grads = tape.gradient(loss_value, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))

        # Update training metric.
        train_acc_metric.update_state(y_batch_train, logits)

        # Log every 200 batches.
        if step % 200 == 0:
            print(
                "Training loss (for one batch) at step %d: %.4f"
                % (step, float(loss_value))
            )
            print("Seen so far: %d samples" % ((step + 1) * batch_size))

    # Display metrics at the end of each epoch.
    train_acc = train_acc_metric.result()
    print("Training acc over epoch: %.4f" % (float(train_acc),))

    # Reset training metrics at the end of each epoch.
    train_acc_metric.reset_states()

    # Run a validation loop at the end of each epoch.
    for x_batch_val, y_batch_val in val_dataset:
        val_logits = model(x_batch_val, training=False)
        # Update val metrics.
        val_acc_metric.update_state(y_batch_val, val_logits)
    val_acc = val_acc_metric.result()
    val_acc_metric.reset_states()
    print("Validation acc: %.4f" % (float(val_acc),))
    print("Time taken: %.2fs" % (time.time() - start_time))
```
```
Start of epoch 0
Training loss (for one batch) at step 0: 154.5849
Seen so far: 64 samples
Training loss (for one batch) at step 200: 1.2994
Seen so far: 12864 samples
Training loss (for one batch) at step 400: 1.0750
Seen so far: 25664 samples
Training loss (for one batch) at step 600: 1.1264
Seen so far: 38464 samples
Training acc over epoch: 0.7203
Validation acc: 0.8233
Time taken: 7.95s
```
```
Start of epoch 1
Training loss (for one batch) at step 0: 1.0552
Seen so far: 64 samples
Training loss (for one batch) at step 200: 0.8037
Seen so far: 12864 samples
Training loss (for one batch) at step 400: 0.2875
Seen so far: 25664 samples
Training loss (for one batch) at step 600: 0.5536
Seen so far: 38464 samples
Training acc over epoch: 0.8370
Validation acc: 0.8622
Time taken: 7.97s
```
---
## Speeding-up your training step with `tf.function`

The default runtime in TensorFlow 2 is
[eager execution](https://www.tensorflow.org/guide/eager).
As such, our training loop above executes eagerly.
This is great for debugging, but graph compilation has a definite performance
advantage. Describing your computation as a static graph enables the framework
to apply global performance optimizations. This is impossible when
the framework is constrained to greedily execute one operation after another,
with no knowledge of what comes next.
You can compile into a static graph any function that takes tensors as input.
Just add a `@tf.function` decorator on it, like this:
```python
@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss_value = loss_fn(y, logits)
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    train_acc_metric.update_state(y, logits)
    return loss_value
```
Let's do the same with the evaluation step:
```python
@tf.function
def test_step(x, y):
    val_logits = model(x, training=False)
    val_acc_metric.update_state(y, val_logits)
```
Now, let's re-run our training loop with this compiled training step:
```python
import time

epochs = 2
for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))
    start_time = time.time()

    # Iterate over the batches of the dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        loss_value = train_step(x_batch_train, y_batch_train)

        # Log every 200 batches.
        if step % 200 == 0:
            print(
                "Training loss (for one batch) at step %d: %.4f"
                % (step, float(loss_value))
            )
            print("Seen so far: %d samples" % ((step + 1) * batch_size))

    # Display metrics at the end of each epoch.
    train_acc = train_acc_metric.result()
    print("Training acc over epoch: %.4f" % (float(train_acc),))
    train_acc_metric.reset_states()

    # Run a validation loop at the end of each epoch.
    for x_batch_val, y_batch_val in val_dataset:
        test_step(x_batch_val, y_batch_val)
    val_acc = val_acc_metric.result()
    val_acc_metric.reset_states()
    print("Validation acc: %.4f" % (float(val_acc),))
    print("Time taken: %.2fs" % (time.time() - start_time))
```
```
Start of epoch 0
Training loss (for one batch) at step 0: 0.4807
Seen so far: 64 samples
Training loss (for one batch) at step 200: 0.4289
Seen so far: 12864 samples
Training loss (for one batch) at step 400: 0.6062
Seen so far: 25664 samples
Training loss (for one batch) at step 600: 0.5791
Seen so far: 38464 samples
Training acc over epoch: 0.8666
Validation acc: 0.8798
Time taken: 1.45s
```
```
Start of epoch 1
Training loss (for one batch) at step 0: 0.5122
Seen so far: 64 samples
Training loss (for one batch) at step 200: 0.4184
Seen so far: 12864 samples
Training loss (for one batch) at step 400: 0.2736
Seen so far: 25664 samples
Training loss (for one batch) at step 600: 0.5048
Seen so far: 38464 samples
Training acc over epoch: 0.8823
Validation acc: 0.8872
Time taken: 1.11s
```
Much faster, isn't it?
---
## Low-level handling of losses tracked by the model
Layers & models recursively track any losses created during the forward pass
by layers that call `self.add_loss(value)`. The resulting list of scalar loss
values is available via the property `model.losses`
at the end of the forward pass.

If you want to use these loss components, you should sum them
and add them to the main loss in your training step.

Consider this layer, which creates an activity regularization loss:
```python
@keras.saving.register_keras_serializable()
class ActivityRegularizationLayer(layers.Layer):
    def call(self, inputs):
        self.add_loss(1e-2 * tf.reduce_sum(inputs))
        return inputs
```
Let's build a really simple model that uses it:
```python
inputs = keras.Input(shape=(784,), name="digits")
x = layers.Dense(64, activation="relu")(inputs)
# Insert activity regularization as a layer
x = ActivityRegularizationLayer()(x)
x = layers.Dense(64, activation="relu")(x)
outputs = layers.Dense(10, name="predictions")(x)
model = keras.Model(inputs=inputs, outputs=outputs)
```
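You can verify the tracking with a quick forward pass (an illustrative check, not from the original guide): after calling the model, `model.losses` holds one scalar per `add_loss()` call made during that pass.

```python
# Illustrative: a forward pass populates `model.losses`.
_ = model(tf.zeros((1, 784)), training=True)
print(model.losses)  # a list containing one scalar regularization term
```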
Here's what our training step should look like now:
```python
@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss_value = loss_fn(y, logits)
        # Add any extra losses created during the forward pass.
        loss_value += sum(model.losses)
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    train_acc_metric.update_state(y, logits)
    return loss_value
```
## Summary
Now you know everything there is to know about using built-in training loops and writing your own from scratch.
To conclude, here's a simple end-to-end example that ties together everything you've learned in this guide: a DCGAN trained on MNIST digits.
## End-to-end example: a GAN training loop from scratch
You may be familiar with Generative Adversarial Networks (GANs). GANs can generate new images that look almost real, by learning the latent distribution of a training dataset of images (the "latent space" of the images).
A GAN is made of two parts: a "generator" model that maps points in the latent space to points in image space, and a "discriminator" model, a classifier that can tell the difference between real images (from the training dataset) and fake images (the output of the generator network).
A GAN training loop looks like this:

1) Train the discriminator.

- Sample a batch of random points in the latent space.
- Turn the points into fake images via the "generator" model.
- Get a batch of real images and combine them with the generated images.
- Train the "discriminator" model to classify generated vs. real images.

2) Train the generator.

- Sample random points in the latent space.
- Turn the points into fake images via the "generator" network.
- Get a batch of real images and combine them with the generated images.
- Train the "generator" model to "fool" the discriminator and classify the fake images as real.
For a much more detailed overview of how GANs work, see *Deep Learning with Python*.

Let's implement this training loop. First, create the discriminator meant to classify fake vs. real digits:
```python
discriminator = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.GlobalMaxPooling2D(),
        layers.Dense(1),
    ],
    name="discriminator",
)
discriminator.summary()
```
```
Model: "discriminator"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 conv2d (Conv2D)             (None, 14, 14, 64)        640
 leaky_re_lu (LeakyReLU)     (None, 14, 14, 64)        0
 conv2d_1 (Conv2D)           (None, 7, 7, 128)         73856
 leaky_re_lu_1 (LeakyReLU)   (None, 7, 7, 128)         0
 global_max_pooling2d        (None, 128)               0
 (GlobalMaxPooling2D)
 dense_4 (Dense)             (None, 1)                 129
=================================================================
Total params: 74625 (291.50 KB)
Trainable params: 74625 (291.50 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
```
Then let's create a generator network,
that turns latent vectors into outputs of shape `(28, 28, 1)` (representing
MNIST digits):
```python
latent_dim = 128

generator = keras.Sequential(
    [
        keras.Input(shape=(latent_dim,)),
        # We want to generate 128 coefficients to reshape into a 7x7x128 map
        layers.Dense(7 * 7 * 128),
        layers.LeakyReLU(alpha=0.2),
        layers.Reshape((7, 7, 128)),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2D(1, (7, 7), padding="same", activation="sigmoid"),
    ],
    name="generator",
)
```
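A quick shape check (illustrative, not part of the original guide) confirms that the two stride-2 transposed convolutions upsample the 7x7 map to 28x28:

```python
# Illustrative: one latent vector maps to a single 28x28x1 image.
z = tf.random.normal(shape=(1, latent_dim))
print(generator(z).shape)  # (1, 28, 28, 1)
```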
Here's the key bit: the training loop. As you can see, it is quite straightforward. The training step function only takes 17 lines.
```python
# Instantiate one optimizer for the discriminator and another for the generator.
d_optimizer = keras.optimizers.Adam(learning_rate=0.0003)
g_optimizer = keras.optimizers.Adam(learning_rate=0.0004)

# Instantiate a loss function.
loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)


@tf.function
def train_step(real_images):
    # Sample random points in the latent space.
    random_latent_vectors = tf.random.normal(shape=(batch_size, latent_dim))
    # Decode them to fake images.
    generated_images = generator(random_latent_vectors)
    # Combine them with real images.
    combined_images = tf.concat([generated_images, real_images], axis=0)

    # Assemble labels discriminating real from fake images.
    labels = tf.concat(
        [tf.ones((batch_size, 1)), tf.zeros((real_images.shape[0], 1))], axis=0
    )
    # Add random noise to the labels - important trick!
    labels += 0.05 * tf.random.uniform(labels.shape)

    # Train the discriminator.
    with tf.GradientTape() as tape:
        predictions = discriminator(combined_images)
        d_loss = loss_fn(labels, predictions)
    grads = tape.gradient(d_loss, discriminator.trainable_weights)
    d_optimizer.apply_gradients(zip(grads, discriminator.trainable_weights))

    # Sample random points in the latent space.
    random_latent_vectors = tf.random.normal(shape=(batch_size, latent_dim))
    # Assemble labels that say "all real images".
    misleading_labels = tf.zeros((batch_size, 1))

    # Train the generator (note that we should *not* update the weights
    # of the discriminator)!
    with tf.GradientTape() as tape:
        predictions = discriminator(generator(random_latent_vectors))
        g_loss = loss_fn(misleading_labels, predictions)
    grads = tape.gradient(g_loss, generator.trainable_weights)
    g_optimizer.apply_gradients(zip(grads, generator.trainable_weights))
    return d_loss, g_loss, generated_images
```
Let's train our GAN, by repeatedly calling `train_step` on batches of images.

Since our discriminator and generator are convnets, you're going to want to run this code on a GPU.
```python
import os

# Prepare the dataset. We use both the training & test MNIST digits.
batch_size = 64
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
all_digits = np.concatenate([x_train, x_test])
all_digits = all_digits.astype("float32") / 255.0
all_digits = np.reshape(all_digits, (-1, 28, 28, 1))
dataset = tf.data.Dataset.from_tensor_slices(all_digits)
dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)

epochs = 1  # In practice you need at least 20 epochs to generate nice digits.
save_dir = "./"

for epoch in range(epochs):
    print("\nStart epoch", epoch)

    for step, real_images in enumerate(dataset):
        # Train the discriminator & generator on one batch of real images.
        d_loss, g_loss, generated_images = train_step(real_images)

        # Logging.
        if step % 200 == 0:
            # Print metrics
            print("discriminator loss at step %d: %.2f" % (step, d_loss))
            print("adversarial loss at step %d: %.2f" % (step, g_loss))

            # Save one generated image
            img = keras.utils.array_to_img(generated_images[0] * 255.0, scale=False)
            img.save(os.path.join(save_dir, "generated_img" + str(step) + ".png"))

        # To limit execution time we stop after 10 steps.
        # Remove the lines below to actually train the model!
        if step > 10:
            break
```
```
Start epoch 0
discriminator loss at step 0: 0.68
adversarial loss at step 0: 0.69
```
That's it! You'll get nice-looking fake MNIST digits after just ~30s of training on the
Colab GPU.