Path: blob/master/guides/_customizing_what_happens_in_fit.py
3273 views
"""1Title: Customizing what happens in `fit()`2Author: [fchollet](https://twitter.com/fchollet)3Date created: 2020/04/154Last modified: 2023/06/145Description: Complete guide to overriding the training step of the Model class.6Accelerator: GPU7"""89"""10## Introduction1112When you're doing supervised learning, you can use `fit()` and everything works13smoothly.1415When you need to write your own training loop from scratch, you can use the16`GradientTape` and take control of every little detail.1718But what if you need a custom training algorithm, but you still want to benefit from19the convenient features of `fit()`, such as callbacks, built-in distribution support,20or step fusing?2122A core principle of Keras is **progressive disclosure of complexity**. You should23always be able to get into lower-level workflows in a gradual way. You shouldn't fall24off a cliff if the high-level functionality doesn't exactly match your use case. You25should be able to gain more control over the small details while retaining a26commensurate amount of high-level convenience.2728When you need to customize what `fit()` does, you should **override the training step29function of the `Model` class**. This is the function that is called by `fit()` for30every batch of data. You will then be able to call `fit()` as usual -- and it will be31running your own learning algorithm.3233Note that this pattern does not prevent you from building models with the Functional34API. You can do this whether you're building `Sequential` models, Functional API35models, or subclassed models.3637Let's see how that works.38"""3940"""41## Setup4243Requires TensorFlow 2.8 or later.44"""4546import tensorflow as tf47import keras4849"""50## A first simple example5152Let's start from a simple example:5354- We create a new class that subclasses `keras.Model`.55- We just override the method `train_step(self, data)`.56- We return a dictionary mapping metric names (including the loss) to their current57value.5859The input argument `data` is what gets passed to fit as training data:6061- If you pass Numpy arrays, by calling `fit(x, y, ...)`, then `data` will be the tuple62`(x, y)`63- If you pass a `tf.data.Dataset`, by calling `fit(dataset, ...)`, then `data` will be64what gets yielded by `dataset` at each batch.6566In the body of the `train_step` method, we implement a regular training update,67similar to what you are already familiar with. Importantly, **we compute the loss via68`self.compute_loss()`**, which wraps the loss(es) function(s) that were passed to69`compile()`.7071Similarly, we call `metric.update_state(y, y_pred)` on metrics from `self.metrics`,72to update the state of the metrics that were passed in `compile()`,73and we query results from `self.metrics` at the end to retrieve their current value.74"""757677class CustomModel(keras.Model):78def train_step(self, data):79# Unpack the data. Its structure depends on your model and80# on what you pass to `fit()`.81x, y = data8283with tf.GradientTape() as tape:84y_pred = self(x, training=True) # Forward pass85# Compute the loss value86# (the loss function is configured in `compile()`)87loss = self.compute_loss(y=y, y_pred=y_pred)8889# Compute gradients90trainable_vars = self.trainable_variables91gradients = tape.gradient(loss, trainable_vars)92# Update weights93self.optimizer.apply_gradients(zip(gradients, trainable_vars))94# Update metrics (includes the metric that tracks the loss)95for metric in self.metrics:96if metric.name == "loss":97metric.update_state(loss)98else:99metric.update_state(y, y_pred)100# Return a dict mapping metric names to current value101return {m.name: m.result() for m in self.metrics}102103104"""105Let's try this out:106"""107108import numpy as np109110# Construct and compile an instance of CustomModel111inputs = keras.Input(shape=(32,))112outputs = keras.layers.Dense(1)(inputs)113model = CustomModel(inputs, outputs)114model.compile(optimizer="adam", loss="mse", metrics=["mae"])115116# Just use `fit` as usual117x = np.random.random((1000, 32))118y = np.random.random((1000, 1))119model.fit(x, y, epochs=3)120121"""122## Going lower-level123124Naturally, you could just skip passing a loss function in `compile()`, and instead do125everything *manually* in `train_step`. Likewise for metrics.126127Here's a lower-level example, that only uses `compile()` to configure the optimizer:128129- We start by creating `Metric` instances to track our loss and a MAE score (in `__init__()`).130- We implement a custom `train_step()` that updates the state of these metrics131(by calling `update_state()` on them), then query them (via `result()`) to return their current average value,132to be displayed by the progress bar and to be pass to any callback.133- Note that we would need to call `reset_states()` on our metrics between each epoch! Otherwise134calling `result()` would return an average since the start of training, whereas we usually work135with per-epoch averages. Thankfully, the framework can do that for us: just list any metric136you want to reset in the `metrics` property of the model. The model will call `reset_states()`137on any object listed here at the beginning of each `fit()` epoch or at the beginning of a call to138`evaluate()`.139"""140141142class CustomModel(keras.Model):143def __init__(self, *args, **kwargs):144super().__init__(*args, **kwargs)145self.loss_tracker = keras.metrics.Mean(name="loss")146self.mae_metric = keras.metrics.MeanAbsoluteError(name="mae")147148def train_step(self, data):149x, y = data150151with tf.GradientTape() as tape:152y_pred = self(x, training=True) # Forward pass153# Compute our own loss154loss = keras.losses.mean_squared_error(y, y_pred)155156# Compute gradients157trainable_vars = self.trainable_variables158gradients = tape.gradient(loss, trainable_vars)159160# Update weights161self.optimizer.apply_gradients(zip(gradients, trainable_vars))162163# Compute our own metrics164self.loss_tracker.update_state(loss)165self.mae_metric.update_state(y, y_pred)166return {"loss": self.loss_tracker.result(), "mae": self.mae_metric.result()}167168@property169def metrics(self):170# We list our `Metric` objects here so that `reset_states()` can be171# called automatically at the start of each epoch172# or at the start of `evaluate()`.173# If you don't implement this property, you have to call174# `reset_states()` yourself at the time of your choosing.175return [self.loss_tracker, self.mae_metric]176177178# Construct an instance of CustomModel179inputs = keras.Input(shape=(32,))180outputs = keras.layers.Dense(1)(inputs)181model = CustomModel(inputs, outputs)182183# We don't pass a loss or metrics here.184model.compile(optimizer="adam")185186# Just use `fit` as usual -- you can use callbacks, etc.187x = np.random.random((1000, 32))188y = np.random.random((1000, 1))189model.fit(x, y, epochs=5)190191192"""193## Supporting `sample_weight` & `class_weight`194195You may have noticed that our first basic example didn't make any mention of sample196weighting. If you want to support the `fit()` arguments `sample_weight` and197`class_weight`, you'd simply do the following:198199- Unpack `sample_weight` from the `data` argument200- Pass it to `compute_loss` & `update_state` (of course, you could also just apply201it manually if you don't rely on `compile()` for losses & metrics)202- That's it.203"""204205206class CustomModel(keras.Model):207def train_step(self, data):208# Unpack the data. Its structure depends on your model and209# on what you pass to `fit()`.210if len(data) == 3:211x, y, sample_weight = data212else:213sample_weight = None214x, y = data215216with tf.GradientTape() as tape:217y_pred = self(x, training=True) # Forward pass218# Compute the loss value.219# The loss function is configured in `compile()`.220loss = self.compute_loss(221y=y,222y_pred=y_pred,223sample_weight=sample_weight,224)225226# Compute gradients227trainable_vars = self.trainable_variables228gradients = tape.gradient(loss, trainable_vars)229230# Update weights231self.optimizer.apply_gradients(zip(gradients, trainable_vars))232233# Update the metrics.234# Metrics are configured in `compile()`.235for metric in self.metrics:236if metric.name == "loss":237metric.update_state(loss)238else:239metric.update_state(y, y_pred, sample_weight=sample_weight)240241# Return a dict mapping metric names to current value.242# Note that it will include the loss (tracked in self.metrics).243return {m.name: m.result() for m in self.metrics}244245246# Construct and compile an instance of CustomModel247inputs = keras.Input(shape=(32,))248outputs = keras.layers.Dense(1)(inputs)249model = CustomModel(inputs, outputs)250model.compile(optimizer="adam", loss="mse", metrics=["mae"])251252# You can now use sample_weight argument253x = np.random.random((1000, 32))254y = np.random.random((1000, 1))255sw = np.random.random((1000, 1))256model.fit(x, y, sample_weight=sw, epochs=3)257258"""259## Providing your own evaluation step260261What if you want to do the same for calls to `model.evaluate()`? Then you would262override `test_step` in exactly the same way. Here's what it looks like:263"""264265266class CustomModel(keras.Model):267def test_step(self, data):268# Unpack the data269x, y = data270# Compute predictions271y_pred = self(x, training=False)272# Updates the metrics tracking the loss273self.compute_loss(y=y, y_pred=y_pred)274# Update the metrics.275for metric in self.metrics:276if metric.name != "loss":277metric.update_state(y, y_pred)278# Return a dict mapping metric names to current value.279# Note that it will include the loss (tracked in self.metrics).280return {m.name: m.result() for m in self.metrics}281282283# Construct an instance of CustomModel284inputs = keras.Input(shape=(32,))285outputs = keras.layers.Dense(1)(inputs)286model = CustomModel(inputs, outputs)287model.compile(loss="mse", metrics=["mae"])288289# Evaluate with our custom test_step290x = np.random.random((1000, 32))291y = np.random.random((1000, 1))292model.evaluate(x, y)293294"""295## Wrapping up: an end-to-end GAN example296297Let's walk through an end-to-end example that leverages everything you just learned.298299Let's consider:300301- A generator network meant to generate 28x28x1 images.302- A discriminator network meant to classify 28x28x1 images into two classes ("fake" and303"real").304- One optimizer for each.305- A loss function to train the discriminator.306307308"""309310from tensorflow.keras import layers311312# Create the discriminator313discriminator = keras.Sequential(314[315keras.Input(shape=(28, 28, 1)),316layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"),317layers.LeakyReLU(alpha=0.2),318layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"),319layers.LeakyReLU(alpha=0.2),320layers.GlobalMaxPooling2D(),321layers.Dense(1),322],323name="discriminator",324)325326# Create the generator327latent_dim = 128328generator = keras.Sequential(329[330keras.Input(shape=(latent_dim,)),331# We want to generate 128 coefficients to reshape into a 7x7x128 map332layers.Dense(7 * 7 * 128),333layers.LeakyReLU(alpha=0.2),334layers.Reshape((7, 7, 128)),335layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),336layers.LeakyReLU(alpha=0.2),337layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),338layers.LeakyReLU(alpha=0.2),339layers.Conv2D(1, (7, 7), padding="same", activation="sigmoid"),340],341name="generator",342)343344"""345Here's a feature-complete GAN class, overriding `compile()` to use its own signature,346and implementing the entire GAN algorithm in 17 lines in `train_step`:347"""348349350class GAN(keras.Model):351def __init__(self, discriminator, generator, latent_dim):352super().__init__()353self.discriminator = discriminator354self.generator = generator355self.latent_dim = latent_dim356self.d_loss_tracker = keras.metrics.Mean(name="d_loss")357self.g_loss_tracker = keras.metrics.Mean(name="g_loss")358359def compile(self, d_optimizer, g_optimizer, loss_fn):360super().compile()361self.d_optimizer = d_optimizer362self.g_optimizer = g_optimizer363self.loss_fn = loss_fn364365def train_step(self, real_images):366if isinstance(real_images, tuple):367real_images = real_images[0]368# Sample random points in the latent space369batch_size = tf.shape(real_images)[0]370random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))371372# Decode them to fake images373generated_images = self.generator(random_latent_vectors)374375# Combine them with real images376combined_images = tf.concat([generated_images, real_images], axis=0)377378# Assemble labels discriminating real from fake images379labels = tf.concat(380[tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0381)382# Add random noise to the labels - important trick!383labels += 0.05 * tf.random.uniform(tf.shape(labels))384385# Train the discriminator386with tf.GradientTape() as tape:387predictions = self.discriminator(combined_images)388d_loss = self.loss_fn(labels, predictions)389grads = tape.gradient(d_loss, self.discriminator.trainable_weights)390self.d_optimizer.apply_gradients(391zip(grads, self.discriminator.trainable_weights)392)393394# Sample random points in the latent space395random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))396397# Assemble labels that say "all real images"398misleading_labels = tf.zeros((batch_size, 1))399400# Train the generator (note that we should *not* update the weights401# of the discriminator)!402with tf.GradientTape() as tape:403predictions = self.discriminator(self.generator(random_latent_vectors))404g_loss = self.loss_fn(misleading_labels, predictions)405grads = tape.gradient(g_loss, self.generator.trainable_weights)406self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))407408# Update metrics and return their value.409self.d_loss_tracker.update_state(d_loss)410self.g_loss_tracker.update_state(g_loss)411return {412"d_loss": self.d_loss_tracker.result(),413"g_loss": self.g_loss_tracker.result(),414}415416417"""418Let's test-drive it:419"""420421# Prepare the dataset. We use both the training & test MNIST digits.422batch_size = 64423(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()424all_digits = np.concatenate([x_train, x_test])425all_digits = all_digits.astype("float32") / 255.0426all_digits = np.reshape(all_digits, (-1, 28, 28, 1))427dataset = tf.data.Dataset.from_tensor_slices(all_digits)428dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)429430gan = GAN(discriminator=discriminator, generator=generator, latent_dim=latent_dim)431gan.compile(432d_optimizer=keras.optimizers.Adam(learning_rate=0.0003),433g_optimizer=keras.optimizers.Adam(learning_rate=0.0003),434loss_fn=keras.losses.BinaryCrossentropy(from_logits=True),435)436437# To limit the execution time, we only train on 100 batches. You can train on438# the entire dataset. You will need about 20 epochs to get nice results.439gan.fit(dataset.take(100), epochs=1)440441"""442The ideas behind deep learning are simple, so why should their implementation be painful?443"""444445446