"""
Title: Writing a training loop from scratch in TensorFlow
Author: [fchollet](https://twitter.com/fchollet)
Date created: 2019/03/01
Last modified: 2023/06/25
Description: Writing low-level training & evaluation loops in TensorFlow.
Accelerator: None
"""

"""
## Setup
"""

import time
import os

# This guide can only be run with the TensorFlow backend.
os.environ["KERAS_BACKEND"] = "tensorflow"

import tensorflow as tf
import keras
import numpy as np

"""
## Introduction

Keras provides default training and evaluation loops, `fit()` and `evaluate()`.
Their usage is covered in the guide
[Training & evaluation with the built-in methods](/guides/training_with_built_in_methods/).

If you want to customize the learning algorithm of your model while still leveraging
the convenience of `fit()`
(for instance, to train a GAN using `fit()`), you can subclass the `Model` class and
implement your own `train_step()` method, which
is called repeatedly during `fit()`.

Now, if you want very low-level control over training & evaluation, you should write
your own training & evaluation loops from scratch. This is what this guide is about.
"""

"""
## A first end-to-end example

Let's consider a simple MNIST model:
"""


def get_model():
    inputs = keras.Input(shape=(784,), name="digits")
    x1 = keras.layers.Dense(64, activation="relu")(inputs)
    x2 = keras.layers.Dense(64, activation="relu")(x1)
    outputs = keras.layers.Dense(10, name="predictions")(x2)
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model


model = get_model()

"""
Let's train it using mini-batch gradient descent with a custom training loop.

First, we're going to need an optimizer, a loss function, and a dataset:
"""

# Instantiate an optimizer.
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
# Instantiate a loss function.
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Load the MNIST data and flatten each image into a 784-dimensional vector.
batch_size = 32
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = np.reshape(x_train, (-1, 784))
x_test = np.reshape(x_test, (-1, 784))

# Reserve 10,000 samples for validation.
x_val = x_train[-10000:]
y_val = y_train[-10000:]
x_train = x_train[:-10000]
y_train = y_train[:-10000]

# Prepare the training dataset.
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)

# Prepare the validation dataset.
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_dataset = val_dataset.batch(batch_size)

"""
Calling a model inside a `GradientTape` scope enables you to retrieve the gradients of
a loss value with respect to the trainable weights of the model. Using an optimizer
instance, you can use these gradients to update these variables (which you can
retrieve using `model.trainable_weights`).

Here's our training loop, step by step:

- We open a `for` loop that iterates over epochs
- For each epoch, we open a `for` loop that iterates over the dataset, in batches
- For each batch, we open a `GradientTape()` scope
- Inside this scope, we call the model (forward pass) and compute the loss
- Outside the scope, we retrieve the gradients of the loss
with respect to the weights of the model
- Finally, we use the optimizer to update the weights of the model based on the
gradients
"""

epochs = 3
for epoch in range(epochs):
    print(f"\nStart of epoch {epoch}")

    # Iterate over the batches of the dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        # Open a GradientTape to record the operations run
        # during the forward pass, which enables auto-differentiation.
        with tf.GradientTape() as tape:
            # Run the forward pass of the model.
            # The operations that the model applies
            # to its inputs are going to be recorded
            # on the GradientTape.
            logits = model(x_batch_train, training=True)  # Logits for this minibatch

            # Compute the loss value for this minibatch.
            loss_value = loss_fn(y_batch_train, logits)

        # Use the gradient tape to automatically retrieve
        # the gradients of the loss with respect to the trainable variables.
        grads = tape.gradient(loss_value, model.trainable_weights)

        # Run one step of gradient descent by updating
        # the value of the variables to minimize the loss.
        optimizer.apply(grads, model.trainable_weights)

        # Log every 100 batches.
        if step % 100 == 0:
            print(
                f"Training loss (for 1 batch) at step {step}: {float(loss_value):.4f}"
            )
            print(f"Seen so far: {(step + 1) * batch_size} samples")

"""
## Low-level handling of metrics

Let's add metrics monitoring to this basic loop.

You can readily reuse the built-in metrics (or custom ones you wrote) in such training
loops written from scratch. Here's the flow:

- Instantiate the metric at the start of the loop
- Call `metric.update_state()` after each batch
- Call `metric.result()` when you need to display the current value of the metric
- Call `metric.reset_state()` when you need to clear the state of the metric
(typically at the end of an epoch)

Let's use this knowledge to compute `SparseCategoricalAccuracy` on training and
validation data at the end of each epoch:
"""

# Get a fresh model
model = get_model()

# Instantiate an optimizer to train the model.
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
# Instantiate a loss function.
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Prepare the metrics.
train_acc_metric = keras.metrics.SparseCategoricalAccuracy()
val_acc_metric = keras.metrics.SparseCategoricalAccuracy()

"""
Here's our training & evaluation loop:
"""

epochs = 2
for epoch in range(epochs):
    print(f"\nStart of epoch {epoch}")
    start_time = time.time()

    # Iterate over the batches of the dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        with tf.GradientTape() as tape:
            logits = model(x_batch_train, training=True)
            loss_value = loss_fn(y_batch_train, logits)
        grads = tape.gradient(loss_value, model.trainable_weights)
        optimizer.apply(grads, model.trainable_weights)

        # Update training metric.
        train_acc_metric.update_state(y_batch_train, logits)

        # Log every 100 batches.
        if step % 100 == 0:
            print(
                f"Training loss (for 1 batch) at step {step}: {float(loss_value):.4f}"
            )
            print(f"Seen so far: {(step + 1) * batch_size} samples")

    # Display metrics at the end of each epoch.
    train_acc = train_acc_metric.result()
    print(f"Training acc over epoch: {float(train_acc):.4f}")

    # Reset training metrics at the end of each epoch
    train_acc_metric.reset_state()

    # Run a validation loop at the end of each epoch.
    for x_batch_val, y_batch_val in val_dataset:
        val_logits = model(x_batch_val, training=False)
        # Update val metrics
        val_acc_metric.update_state(y_batch_val, val_logits)
    val_acc = val_acc_metric.result()
    val_acc_metric.reset_state()
    print(f"Validation acc: {float(val_acc):.4f}")
    print(f"Time taken: {time.time() - start_time:.2f}s")

"""
## Speeding up your training step with `tf.function`

The default runtime in TensorFlow is eager execution.
As such, our training loop above executes eagerly.

This is great for debugging, but graph compilation has a definite performance
advantage. Describing your computation as a static graph enables the framework
to apply global performance optimizations. This is impossible when
the framework is constrained to greedily execute one operation after another,
with no knowledge of what comes next.

You can compile into a static graph any function that takes tensors as input.
Just add a `@tf.function` decorator on it, like this:
"""


@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss_value = loss_fn(y, logits)
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply(grads, model.trainable_weights)
    train_acc_metric.update_state(y, logits)
    return loss_value


"""
Let's do the same with the evaluation step:
"""


@tf.function
def test_step(x, y):
    val_logits = model(x, training=False)
    val_acc_metric.update_state(y, val_logits)


"""
Now, let's re-run our training loop with this compiled training step:
"""

epochs = 2
for epoch in range(epochs):
    print(f"\nStart of epoch {epoch}")
    start_time = time.time()

    # Iterate over the batches of the dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        loss_value = train_step(x_batch_train, y_batch_train)

        # Log every 100 batches.
        if step % 100 == 0:
            print(
                f"Training loss (for 1 batch) at step {step}: {float(loss_value):.4f}"
            )
            print(f"Seen so far: {(step + 1) * batch_size} samples")

    # Display metrics at the end of each epoch.
    train_acc = train_acc_metric.result()
    print(f"Training acc over epoch: {float(train_acc):.4f}")

    # Reset training metrics at the end of each epoch
    train_acc_metric.reset_state()

    # Run a validation loop at the end of each epoch.
    for x_batch_val, y_batch_val in val_dataset:
        test_step(x_batch_val, y_batch_val)

    val_acc = val_acc_metric.result()
    val_acc_metric.reset_state()
    print(f"Validation acc: {float(val_acc):.4f}")
    print(f"Time taken: {time.time() - start_time:.2f}s")

"""
Much faster, isn't it?
"""

"""
## Low-level handling of losses tracked by the model

Layers & models recursively track any losses created during the forward pass
by layers that call `self.add_loss(value)`. The resulting list of scalar loss
values is available via the property `model.losses`
at the end of the forward pass.

If you want to use these loss components, you should sum them
and add them to the main loss in your training step.

Consider this layer, which creates an activity regularization loss:
"""


class ActivityRegularizationLayer(keras.layers.Layer):
    def call(self, inputs):
        self.add_loss(1e-2 * tf.reduce_sum(inputs))
        return inputs


"""
Let's build a really simple model that uses it:
"""

inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu")(inputs)
# Insert activity regularization as a layer
x = ActivityRegularizationLayer()(x)
x = keras.layers.Dense(64, activation="relu")(x)
outputs = keras.layers.Dense(10, name="predictions")(x)

model = keras.Model(inputs=inputs, outputs=outputs)

"""
Here's what our training step should look like now:
"""


@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss_value = loss_fn(y, logits)
        # Add any extra losses created during the forward pass.
        loss_value += sum(model.losses)
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply(grads, model.trainable_weights)
    train_acc_metric.update_state(y, logits)
    return loss_value


"""
## Summary

Now you know everything there is to know about using built-in training loops and
writing your own from scratch.

To conclude, here's a simple end-to-end example that ties together everything
you've learned in this guide: a DCGAN trained on MNIST digits.
"""

"""
## End-to-end example: a GAN training loop from scratch

You may be familiar with Generative Adversarial Networks (GANs). GANs can generate new
images that look almost real, by learning the latent distribution of a training
dataset of images (the "latent space" of the images).

A GAN is made of two parts: a "generator" model that maps points in the latent
space to points in image space, and a "discriminator" model, a classifier
that can tell the difference between real images (from the training dataset)
and fake images (the output of the generator network).

A GAN training loop looks like this:

1) Train the discriminator.
- Sample a batch of random points in the latent space.
- Turn the points into fake images via the "generator" model.
- Get a batch of real images and combine them with the generated images.
- Train the "discriminator" model to classify generated vs. real images.

2) Train the generator.
- Sample random points in the latent space.
- Turn the points into fake images via the "generator" network.
- Pass the fake images through the "discriminator" model, without updating its weights.
- Train the "generator" model to "fool" the discriminator into classifying the fake
images as real.

For a much more detailed overview of how GANs work, see
[Deep Learning with Python](https://www.manning.com/books/deep-learning-with-python).

Let's implement this training loop. First, create the discriminator meant to classify
fake vs real digits:
"""

discriminator = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        keras.layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"),
        keras.layers.LeakyReLU(negative_slope=0.2),
        keras.layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"),
        keras.layers.LeakyReLU(negative_slope=0.2),
        keras.layers.GlobalMaxPooling2D(),
        keras.layers.Dense(1),
    ],
    name="discriminator",
)
discriminator.summary()

"""
Then let's create a generator network,
that turns latent vectors into outputs of shape `(28, 28, 1)` (representing
MNIST digits):
"""

latent_dim = 128

generator = keras.Sequential(
    [
        keras.Input(shape=(latent_dim,)),
        # We want to generate 7 * 7 * 128 coefficients to reshape into a 7x7x128 map
        keras.layers.Dense(7 * 7 * 128),
        keras.layers.LeakyReLU(negative_slope=0.2),
        keras.layers.Reshape((7, 7, 128)),
        keras.layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        keras.layers.LeakyReLU(negative_slope=0.2),
        keras.layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        keras.layers.LeakyReLU(negative_slope=0.2),
        keras.layers.Conv2D(1, (7, 7), padding="same", activation="sigmoid"),
    ],
    name="generator",
)

"""
Here's the key bit: the training loop. As you can see it is quite straightforward. The
training step function only takes 17 lines.
"""

# Instantiate one optimizer for the discriminator and another for the generator.
d_optimizer = keras.optimizers.Adam(learning_rate=0.0003)
g_optimizer = keras.optimizers.Adam(learning_rate=0.0004)

# Instantiate a loss function.
loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)


@tf.function
def train_step(real_images):
    # Sample random points in the latent space
    random_latent_vectors = tf.random.normal(shape=(batch_size, latent_dim))
    # Decode them to fake images
    generated_images = generator(random_latent_vectors)
    # Combine them with real images
    combined_images = tf.concat([generated_images, real_images], axis=0)

    # Assemble labels discriminating real from fake images
    # (1 for generated images, 0 for real images)
    labels = tf.concat(
        [tf.ones((batch_size, 1)), tf.zeros((real_images.shape[0], 1))], axis=0
    )
    # Add random noise to the labels - important trick!
    labels += 0.05 * tf.random.uniform(labels.shape)

    # Train the discriminator
    with tf.GradientTape() as tape:
        predictions = discriminator(combined_images)
        d_loss = loss_fn(labels, predictions)
    grads = tape.gradient(d_loss, discriminator.trainable_weights)
    d_optimizer.apply(grads, discriminator.trainable_weights)

    # Sample random points in the latent space
    random_latent_vectors = tf.random.normal(shape=(batch_size, latent_dim))
    # Assemble labels that say "all real images" (real images are labeled 0)
    misleading_labels = tf.zeros((batch_size, 1))

    # Train the generator (note that we should *not* update the weights
    # of the discriminator)!
    with tf.GradientTape() as tape:
        predictions = discriminator(generator(random_latent_vectors))
        g_loss = loss_fn(misleading_labels, predictions)
    grads = tape.gradient(g_loss, generator.trainable_weights)
    g_optimizer.apply(grads, generator.trainable_weights)
    return d_loss, g_loss, generated_images


"""
Let's train our GAN, by repeatedly calling `train_step` on batches of images.

Since our discriminator and generator are convnets, you're going to want to
run this code on a GPU.
"""

# Prepare the dataset. We use both the training & test MNIST digits.
batch_size = 64
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
all_digits = np.concatenate([x_train, x_test])
all_digits = all_digits.astype("float32") / 255.0
all_digits = np.reshape(all_digits, (-1, 28, 28, 1))
dataset = tf.data.Dataset.from_tensor_slices(all_digits)
dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)

epochs = 1  # In practice you need at least 20 epochs to generate nice digits.
save_dir = "./"

for epoch in range(epochs):
    print(f"\nStart epoch {epoch}")

    for step, real_images in enumerate(dataset):
        # Train the discriminator & generator on one batch of real images.
        d_loss, g_loss, generated_images = train_step(real_images)

        # Logging.
        if step % 100 == 0:
            # Print metrics
            print(f"discriminator loss at step {step}: {float(d_loss):.2f}")
            print(f"adversarial loss at step {step}: {float(g_loss):.2f}")

            # Save one generated image
            img = keras.utils.array_to_img(generated_images[0] * 255.0, scale=False)
            img.save(os.path.join(save_dir, f"generated_img_{step}.png"))

        # To limit execution time we stop after 10 steps.
        # Remove the lines below to actually train the model!
        if step > 10:
            break

"""
That's it! You'll get nice-looking fake MNIST digits after just ~30s of training on the
Colab GPU.
"""
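"""
As a recap, the loop structure used throughout this guide (epochs, then batches, then
forward pass, loss, gradients, and weight update) does not depend on TensorFlow. The
snippet below is a minimal, framework-free sketch of that structure, not the code
this guide runs: `train_linear_model` is a hypothetical helper, the gradients are
derived by hand where the guide relies on `tf.GradientTape`, and a plain SGD update
stands in for the Adam optimizer.

```python
import random


def train_linear_model(xs, ys, epochs=50, batch_size=4, learning_rate=0.05, seed=0):
    # Hypothetical helper: fits y = w * x + b with mini-batch gradient descent.
    rng = random.Random(seed)
    w, b = 0.0, 0.0
    indices = list(range(len(xs)))
    epoch_losses = []
    for epoch in range(epochs):
        rng.shuffle(indices)  # Plays the role of dataset.shuffle().
        total_loss, num_batches = 0.0, 0
        for start in range(0, len(indices), batch_size):
            batch = indices[start : start + batch_size]
            # Forward pass and mean squared error for this batch.
            errors = [w * xs[i] + b - ys[i] for i in batch]
            loss = sum(e * e for e in errors) / len(batch)
            # Hand-derived gradients of the loss with respect to w and b.
            grad_w = sum(2 * e * xs[i] for e, i in zip(errors, batch)) / len(batch)
            grad_b = sum(2 * e for e in errors) / len(batch)
            # Plain SGD update (an assumption; the guide uses Adam).
            w -= learning_rate * grad_w
            b -= learning_rate * grad_b
            total_loss += loss
            num_batches += 1
        epoch_losses.append(total_loss / num_batches)
    return w, b, epoch_losses


xs = [i / 10 for i in range(20)]  # Inputs in [0, 1.9]
ys = [3.0 * x + 1.0 for x in xs]  # Noise-free targets from y = 3x + 1
w, b, epoch_losses = train_linear_model(xs, ys, epochs=200)
print(f"w={w:.2f}, b={b:.2f}, final loss={epoch_losses[-1]:.6f}")
```

Swapping the hand-written gradient lines for a `GradientTape` scope and the update
lines for `optimizer.apply(...)` gives back the loops written earlier in this guide.
"""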