"""
Title: Customizing what happens in `fit()` with TensorFlow
Author: [fchollet](https://twitter.com/fchollet)
Date created: 2020/04/15
Last modified: 2023/06/27
Description: Overriding the training step of the Model class with TensorFlow.
Accelerator: GPU
"""

"""
## Introduction

When you're doing supervised learning, you can use `fit()` and everything works
smoothly.

When you need to take control of every little detail, you can write your own training
loop entirely from scratch.

But what if you need a custom training algorithm, but you still want to benefit from
the convenient features of `fit()`, such as callbacks, built-in distribution support,
or step fusing?

A core principle of Keras is **progressive disclosure of complexity**. You should
always be able to get into lower-level workflows in a gradual way. You shouldn't fall
off a cliff if the high-level functionality doesn't exactly match your use case. You
should be able to gain more control over the small details while retaining a
commensurate amount of high-level convenience.

When you need to customize what `fit()` does, you should **override the training step
function of the `Model` class**. This is the function that is called by `fit()` for
every batch of data. You will then be able to call `fit()` as usual -- and it will be
running your own learning algorithm.

Note that this pattern does not prevent you from building models with the Functional
API. You can do this whether you're building `Sequential` models, Functional API
models, or subclassed models.

Let's see how that works.
"""

"""
## Setup
"""

import os

# This guide can only be run with the TF backend.
os.environ["KERAS_BACKEND"] = "tensorflow"

import tensorflow as tf
import keras
from keras import layers
import numpy as np

"""
## A first simple example

Let's start from a simple example:

- We create a new class that subclasses `keras.Model`.
- We just override the method `train_step(self, data)`.
- We return a dictionary mapping metric names (including the loss) to their current
value.

The input argument `data` is what gets passed to `fit()` as training data:

- If you pass NumPy arrays, by calling `fit(x, y, ...)`, then `data` will be the tuple
`(x, y)`.
- If you pass a `tf.data.Dataset`, by calling `fit(dataset, ...)`, then `data` will be
what gets yielded by `dataset` at each batch.

In the body of the `train_step()` method, we implement a regular training update,
similar to what you are already familiar with. Importantly, **we compute the loss via
`self.compute_loss()`**, which wraps the loss function(s) that were passed to
`compile()`.

Similarly, we call `metric.update_state(y, y_pred)` on metrics from `self.metrics`,
to update the state of the metrics that were passed in `compile()`,
and we query results from `self.metrics` at the end to retrieve their current value.
"""


class CustomModel(keras.Model):
    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute the loss value
            # (the loss function is configured in `compile()`)
            loss = self.compute_loss(y=y, y_pred=y_pred)

        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply(gradients, trainable_vars)

        # Update metrics (includes the metric that tracks the loss)
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred)

        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}


"""
Let's try this out:
"""

# Construct and compile an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])

# Just use `fit` as usual
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=3)
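
"""
As mentioned above, if you pass a `tf.data.Dataset`, `data` will be whatever the
dataset yields at each batch -- here, `(x, y)` tuples -- so the same `train_step()`
works unchanged. A quick sketch, reusing the model we just compiled:
"""

# Wrap the same NumPy arrays in a `tf.data.Dataset` and fit on it directly.
dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(32)
model.fit(dataset, epochs=1)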

"""
## Going lower-level

Naturally, you could just skip passing a loss function in `compile()`, and instead do
everything *manually* in `train_step`. Likewise for metrics.

Here's a lower-level example that only uses `compile()` to configure the optimizer:

- We start by creating `Metric` instances to track our loss and a MAE score (in `__init__()`).
- We implement a custom `train_step()` that updates the state of these metrics
(by calling `update_state()` on them), then queries them (via `result()`) to return their current average value,
to be displayed by the progress bar and to be passed to any callback.
- Note that we would need to call `reset_state()` on our metrics between each epoch! Otherwise
calling `result()` would return an average since the start of training, whereas we usually work
with per-epoch averages. Thankfully, the framework can do that for us: just list any metric
you want to reset in the `metrics` property of the model. The model will call `reset_state()`
on any object listed here at the beginning of each `fit()` epoch or at the beginning of a call to
`evaluate()`.
"""


class CustomModel(keras.Model):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.loss_tracker = keras.metrics.Mean(name="loss")
        self.mae_metric = keras.metrics.MeanAbsoluteError(name="mae")
        self.loss_fn = keras.losses.MeanSquaredError()

    def train_step(self, data):
        x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute our own loss
            loss = self.loss_fn(y, y_pred)

        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply(gradients, trainable_vars)

        # Compute our own metrics
        self.loss_tracker.update_state(loss)
        self.mae_metric.update_state(y, y_pred)
        return {
            "loss": self.loss_tracker.result(),
            "mae": self.mae_metric.result(),
        }

    @property
    def metrics(self):
        # We list our `Metric` objects here so that `reset_state()` can be
        # called automatically at the start of each epoch
        # or at the start of `evaluate()`.
        return [self.loss_tracker, self.mae_metric]


# Construct an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)

# We don't pass a loss or metrics here.
model.compile(optimizer="adam")

# Just use `fit` as usual -- you can use callbacks, etc.
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=5)

"""
## Supporting `sample_weight` & `class_weight`

You may have noticed that our first basic example didn't make any mention of sample
weighting. If you want to support the `fit()` arguments `sample_weight` and
`class_weight`, you'd simply do the following:

- Unpack `sample_weight` from the `data` argument
- Pass it to `compute_loss` & `update_state` (of course, you could also just apply
it manually if you don't rely on `compile()` for losses & metrics)
- That's it.
"""


class CustomModel(keras.Model):
    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        if len(data) == 3:
            x, y, sample_weight = data
        else:
            sample_weight = None
            x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute the loss value.
            # The loss function is configured in `compile()`.
            loss = self.compute_loss(
                y=y,
                y_pred=y_pred,
                sample_weight=sample_weight,
            )

        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply(gradients, trainable_vars)

        # Update the metrics.
        # Metrics are configured in `compile()`.
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred, sample_weight=sample_weight)

        # Return a dict mapping metric names to current value.
        # Note that it will include the loss (tracked in self.metrics).
        return {m.name: m.result() for m in self.metrics}


# Construct and compile an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])

# You can now use the `sample_weight` argument
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
sw = np.random.random((1000, 1))
model.fit(x, y, sample_weight=sw, epochs=3)
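
"""
`class_weight` flows through the same code path: `fit()` turns it into per-sample
weights before calling `train_step()`, so `data` again arrives as an
`(x, y, sample_weight)` tuple. A quick sketch with toy integer labels (the
classification setup below is purely illustrative):
"""

# Reuse the same `CustomModel`, but with a binary classification head this time.
clf = CustomModel(inputs, keras.layers.Dense(1, activation="sigmoid")(inputs))
clf.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

y_cls = np.random.randint(0, 2, size=(1000, 1))
# Weight class 1 three times as heavily as class 0.
clf.fit(x, y_cls, class_weight={0: 1.0, 1: 3.0}, epochs=1)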

"""
## Providing your own evaluation step

What if you want to do the same for calls to `model.evaluate()`? Then you would
override `test_step` in exactly the same way. Here's what it looks like:
"""


class CustomModel(keras.Model):
    def test_step(self, data):
        # Unpack the data
        x, y = data
        # Compute predictions
        y_pred = self(x, training=False)
        # Compute the loss value
        # (the loss function is configured in `compile()`)
        loss = self.compute_loss(y=y, y_pred=y_pred)
        # Update the metrics.
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred)
        # Return a dict mapping metric names to current value.
        # Note that it will include the loss (tracked in self.metrics).
        return {m.name: m.result() for m in self.metrics}


# Construct an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(loss="mse", metrics=["mae"])

# Evaluate with our custom test_step
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.evaluate(x, y)
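
"""
Note that `test_step` is also what `fit()` runs for the validation pass. A quick
sketch, recompiling the model above with an optimizer so that the default
`train_step` can run:
"""

model.compile(optimizer="adam", loss="mse", metrics=["mae"])
x_val = np.random.random((200, 32))
y_val = np.random.random((200, 1))
# The `val_loss` / `val_mae` reported here come from our overridden `test_step`.
model.fit(x, y, validation_data=(x_val, y_val), epochs=1)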

"""
## Wrapping up: an end-to-end GAN example

Let's walk through an end-to-end example that leverages everything you just learned.

Let's consider:

- A generator network meant to generate 28x28x1 images.
- A discriminator network meant to classify 28x28x1 images into two classes ("fake" and
"real").
- One optimizer for each.
- A loss function to train the discriminator.
"""

# Create the discriminator
discriminator = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.GlobalMaxPooling2D(),
        layers.Dense(1),
    ],
    name="discriminator",
)

# Create the generator
latent_dim = 128
generator = keras.Sequential(
    [
        keras.Input(shape=(latent_dim,)),
        # We want to generate 128 coefficients to reshape into a 7x7x128 map
        layers.Dense(7 * 7 * 128),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Reshape((7, 7, 128)),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Conv2D(1, (7, 7), padding="same", activation="sigmoid"),
    ],
    name="generator",
)

"""
Here's a feature-complete GAN class, overriding `compile()` to use its own signature,
and implementing the entire GAN algorithm in 17 lines in `train_step`:
"""


class GAN(keras.Model):
    def __init__(self, discriminator, generator, latent_dim):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        self.d_loss_tracker = keras.metrics.Mean(name="d_loss")
        self.g_loss_tracker = keras.metrics.Mean(name="g_loss")
        self.seed_generator = keras.random.SeedGenerator(1337)

    @property
    def metrics(self):
        return [self.d_loss_tracker, self.g_loss_tracker]

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def train_step(self, real_images):
        if isinstance(real_images, tuple):
            real_images = real_images[0]
        # Sample random points in the latent space
        batch_size = tf.shape(real_images)[0]
        random_latent_vectors = keras.random.normal(
            shape=(batch_size, self.latent_dim), seed=self.seed_generator
        )

        # Decode them to fake images
        generated_images = self.generator(random_latent_vectors)

        # Combine them with real images
        combined_images = tf.concat([generated_images, real_images], axis=0)

        # Assemble labels discriminating real from fake images
        labels = tf.concat(
            [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0
        )
        # Add random noise to the labels - important trick!
        labels += 0.05 * keras.random.uniform(
            tf.shape(labels), seed=self.seed_generator
        )

        # Train the discriminator
        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined_images)
            d_loss = self.loss_fn(labels, predictions)
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply(grads, self.discriminator.trainable_weights)

        # Sample random points in the latent space
        random_latent_vectors = keras.random.normal(
            shape=(batch_size, self.latent_dim), seed=self.seed_generator
        )

        # Assemble labels that say "all real images"
        misleading_labels = tf.zeros((batch_size, 1))

        # Train the generator (note that we should *not* update the weights
        # of the discriminator)!
        with tf.GradientTape() as tape:
            predictions = self.discriminator(self.generator(random_latent_vectors))
            g_loss = self.loss_fn(misleading_labels, predictions)
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply(grads, self.generator.trainable_weights)

        # Update metrics and return their value.
        self.d_loss_tracker.update_state(d_loss)
        self.g_loss_tracker.update_state(g_loss)
        return {
            "d_loss": self.d_loss_tracker.result(),
            "g_loss": self.g_loss_tracker.result(),
        }


"""
Let's test-drive it:
"""

# Prepare the dataset. We use both the training & test MNIST digits.
batch_size = 64
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
all_digits = np.concatenate([x_train, x_test])
all_digits = all_digits.astype("float32") / 255.0
all_digits = np.reshape(all_digits, (-1, 28, 28, 1))
dataset = tf.data.Dataset.from_tensor_slices(all_digits)
dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)

gan = GAN(discriminator=discriminator, generator=generator, latent_dim=latent_dim)
gan.compile(
    d_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
    g_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
    loss_fn=keras.losses.BinaryCrossentropy(from_logits=True),
)

# To limit the execution time, we only train on 100 batches. You can train on
# the entire dataset. You will need about 20 epochs to get nice results.
gan.fit(dataset.take(100), epochs=1)
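
"""
The same pattern extends to `predict()`: overriding `predict_step()` lets `predict()`
return freshly sampled images from the generator. A minimal sketch (the subclass name
is just illustrative):
"""


class SamplingGAN(GAN):
    def predict_step(self, data):
        # We only use the incoming batch to decide how many images to sample.
        if isinstance(data, tuple):
            data = data[0]
        batch_size = tf.shape(data)[0]
        random_latent_vectors = keras.random.normal(
            shape=(batch_size, self.latent_dim), seed=self.seed_generator
        )
        return self.generator(random_latent_vectors)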

"""
The ideas behind deep learning are simple, so why should their implementation be painful?
"""