"""1Title: Customizing what happens in `fit()` with PyTorch2Author: [fchollet](https://twitter.com/fchollet)3Date created: 2023/06/274Last modified: 2024/08/015Description: Overriding the training step of the Model class with PyTorch.6Accelerator: GPU7"""89"""10## Introduction1112When you're doing supervised learning, you can use `fit()` and everything works13smoothly.1415When you need to take control of every little detail, you can write your own training16loop entirely from scratch.1718But what if you need a custom training algorithm, but you still want to benefit from19the convenient features of `fit()`, such as callbacks, built-in distribution support,20or step fusing?2122A core principle of Keras is **progressive disclosure of complexity**. You should23always be able to get into lower-level workflows in a gradual way. You shouldn't fall24off a cliff if the high-level functionality doesn't exactly match your use case. You25should be able to gain more control over the small details while retaining a26commensurate amount of high-level convenience.2728When you need to customize what `fit()` does, you should **override the training step29function of the `Model` class**. This is the function that is called by `fit()` for30every batch of data. You will then be able to call `fit()` as usual -- and it will be31running your own learning algorithm.3233Note that this pattern does not prevent you from building models with the Functional34API. You can do this whether you're building `Sequential` models, Functional API35models, or subclassed models.3637Let's see how that works.38"""3940"""41## Setup42"""4344import os4546# This guide can only be run with the torch backend.47os.environ["KERAS_BACKEND"] = "torch"4849import torch50import keras51from keras import layers52import numpy as np5354"""55## A first simple example5657Let's start from a simple example:5859- We create a new class that subclasses `keras.Model`.60- We just override the method `train_step(self, data)`.61- We return a dictionary mapping metric names (including the loss) to their current62value.6364The input argument `data` is what gets passed to fit as training data:6566- If you pass NumPy arrays, by calling `fit(x, y, ...)`, then `data` will be the tuple67`(x, y)`68- If you pass a `torch.utils.data.DataLoader` or a `tf.data.Dataset`,69by calling `fit(dataset, ...)`, then `data` will be what gets yielded70by `dataset` at each batch.7172In the body of the `train_step()` method, we implement a regular training update,73similar to what you are already familiar with. Importantly, **we compute the loss via74`self.compute_loss()`**, which wraps the loss(es) function(s) that were passed to75`compile()`.7677Similarly, we call `metric.update_state(y, y_pred)` on metrics from `self.metrics`,78to update the state of the metrics that were passed in `compile()`,79and we query results from `self.metrics` at the end to retrieve their current value.80"""818283class CustomModel(keras.Model):84def train_step(self, data):85# Unpack the data. 
"""
## Going lower-level

Naturally, you could just skip passing a loss function in `compile()`, and instead do
everything *manually* in `train_step`. Likewise for metrics.

Here's a lower-level example that only uses `compile()` to configure the optimizer:

- We start by creating `Metric` instances to track our loss and a MAE score (in `__init__()`).
- We implement a custom `train_step()` that updates the state of these metrics
(by calling `update_state()` on them), then queries them (via `result()`) to return their current average value,
to be displayed by the progress bar and to be passed to any callback.
- Note that we would need to call `reset_state()` on our metrics between each epoch! Otherwise
calling `result()` would return an average since the start of training, whereas we usually work
with per-epoch averages. Thankfully, the framework can do that for us: just list any metric
you want to reset in the `metrics` property of the model. The model will call `reset_state()`
on any object listed here at the beginning of each `fit()` epoch or at the beginning of a call to
`evaluate()`.
"""


class CustomModel(keras.Model):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.loss_tracker = keras.metrics.Mean(name="loss")
        self.mae_metric = keras.metrics.MeanAbsoluteError(name="mae")
        self.loss_fn = keras.losses.MeanSquaredError()

    def train_step(self, data):
        x, y = data

        # Call torch.nn.Module.zero_grad() to clear the leftover gradients
        # for the weights from the previous train step.
        self.zero_grad()

        # Compute loss
        y_pred = self(x, training=True)  # Forward pass
        loss = self.loss_fn(y, y_pred)

        # Call torch.Tensor.backward() on the loss to compute gradients
        # for the weights.
        loss.backward()

        trainable_weights = [v for v in self.trainable_weights]
        gradients = [v.value.grad for v in trainable_weights]

        # Update weights
        with torch.no_grad():
            self.optimizer.apply(gradients, trainable_weights)

        # Compute our own metrics
        self.loss_tracker.update_state(loss)
        self.mae_metric.update_state(y, y_pred)
        return {
            "loss": self.loss_tracker.result(),
            "mae": self.mae_metric.result(),
        }

    @property
    def metrics(self):
        # We list our `Metric` objects here so that `reset_state()` can be
        # called automatically at the start of each epoch
        # or at the start of `evaluate()`.
        return [self.loss_tracker, self.mae_metric]


# Construct an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)

# We don't pass a loss or metrics here.
model.compile(optimizer="adam")

# Just use `fit` as usual -- you can use callbacks, etc.
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=5)
"""
## Supporting `sample_weight` & `class_weight`

You may have noticed that our first basic example didn't make any mention of sample
weighting. If you want to support the `fit()` arguments `sample_weight` and
`class_weight`, you'd simply do the following:

- Unpack `sample_weight` from the `data` argument
- Pass it to `compute_loss` & `update_state` (of course, you could also just apply
it manually if you don't rely on `compile()` for losses & metrics)
- That's it.
"""


class CustomModel(keras.Model):
    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        if len(data) == 3:
            x, y, sample_weight = data
        else:
            sample_weight = None
            x, y = data

        # Call torch.nn.Module.zero_grad() to clear the leftover gradients
        # for the weights from the previous train step.
        self.zero_grad()

        # Compute loss
        y_pred = self(x, training=True)  # Forward pass
        loss = self.compute_loss(
            y=y,
            y_pred=y_pred,
            sample_weight=sample_weight,
        )

        # Call torch.Tensor.backward() on the loss to compute gradients
        # for the weights.
        loss.backward()

        trainable_weights = [v for v in self.trainable_weights]
        gradients = [v.value.grad for v in trainable_weights]

        # Update weights
        with torch.no_grad():
            self.optimizer.apply(gradients, trainable_weights)

        # Update metrics (includes the metric that tracks the loss)
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred, sample_weight=sample_weight)

        # Return a dict mapping metric names to current value.
        # Note that it will include the loss (tracked in self.metrics).
        return {m.name: m.result() for m in self.metrics}


# Construct and compile an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])

# You can now use the `sample_weight` argument
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
sw = np.random.random((1000, 1))
model.fit(x, y, sample_weight=sw, epochs=3)
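"""
The same `train_step()` also covers `class_weight`: `fit()` turns the `class_weight`
dict into per-sample weights before calling `train_step()`, so they arrive through the
same `sample_weight` slot. A quick sketch of that usage (this snippet is an addition
to the original example, and uses integer class labels purely to demonstrate the
mechanics on this toy regression model):
"""

# Integer labels in {0, 1}, with class 1 weighted 3x more heavily than class 0.
y_class = np.random.randint(0, 2, size=(1000, 1))
model.fit(x, y_class, class_weight={0: 1.0, 1: 3.0}, epochs=1)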
"""
## Providing your own evaluation step

What if you want to do the same for calls to `model.evaluate()`? Then you would
override `test_step` in exactly the same way. Here's what it looks like:
"""


class CustomModel(keras.Model):
    def test_step(self, data):
        # Unpack the data
        x, y = data
        # Compute predictions
        y_pred = self(x, training=False)
        # Compute the loss
        loss = self.compute_loss(y=y, y_pred=y_pred)
        # Update the metrics (includes the metric that tracks the loss).
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred)
        # Return a dict mapping metric names to current value.
        # Note that it will include the loss (tracked in self.metrics).
        return {m.name: m.result() for m in self.metrics}


# Construct an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(loss="mse", metrics=["mae"])

# Evaluate with our custom test_step
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.evaluate(x, y)
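"""
The same pattern extends to `model.predict()`: you can override `predict_step(self, data)`
to control what happens when predictions are generated for one batch. Here is a minimal
sketch of that idea (this snippet is an addition to the original guide, and simply
reproduces the default behavior of running a forward pass in inference mode):
"""


class CustomModel(keras.Model):
    def predict_step(self, data):
        # For `predict()`, `data` only contains the inputs
        # (possibly wrapped in a tuple).
        x = data[0] if isinstance(data, (tuple, list)) else data
        # Run the forward pass in inference mode and return the predictions.
        return self(x, training=False)


# Construct an instance of CustomModel and predict with our custom predict_step
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
predictions = model.predict(np.random.random((8, 32)))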
"""
## Wrapping up: an end-to-end GAN example

Let's walk through an end-to-end example that leverages everything you just learned.

Let's consider:

- A generator network meant to generate 28x28x1 images.
- A discriminator network meant to classify 28x28x1 images into two classes ("fake" and
"real").
- One optimizer for each.
- A loss function to train the discriminator.
"""

# Create the discriminator
discriminator = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.GlobalMaxPooling2D(),
        layers.Dense(1),
    ],
    name="discriminator",
)

# Create the generator
latent_dim = 128
generator = keras.Sequential(
    [
        keras.Input(shape=(latent_dim,)),
        # We want to generate 128 coefficients to reshape into a 7x7x128 map
        layers.Dense(7 * 7 * 128),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Reshape((7, 7, 128)),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Conv2D(1, (7, 7), padding="same", activation="sigmoid"),
    ],
    name="generator",
)

"""
Here's a feature-complete GAN class, overriding `compile()` to use its own signature,
and implementing the entire GAN algorithm in `train_step`:
"""


class GAN(keras.Model):
    def __init__(self, discriminator, generator, latent_dim):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        self.d_loss_tracker = keras.metrics.Mean(name="d_loss")
        self.g_loss_tracker = keras.metrics.Mean(name="g_loss")
        self.seed_generator = keras.random.SeedGenerator(1337)
        self.built = True

    @property
    def metrics(self):
        return [self.d_loss_tracker, self.g_loss_tracker]

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def train_step(self, real_images):
        device = "cuda" if torch.cuda.is_available() else "cpu"
        if isinstance(real_images, (tuple, list)):
            real_images = real_images[0]
        # Sample random points in the latent space
        batch_size = real_images.shape[0]
        random_latent_vectors = keras.random.normal(
            shape=(batch_size, self.latent_dim), seed=self.seed_generator
        )

        # Decode them to fake images
        generated_images = self.generator(random_latent_vectors)

        # Combine them with real images
        real_images = torch.tensor(real_images, device=device)
        combined_images = torch.concat([generated_images, real_images], axis=0)

        # Assemble labels discriminating real from fake images
        labels = torch.concat(
            [
                torch.ones((batch_size, 1), device=device),
                torch.zeros((batch_size, 1), device=device),
            ],
            axis=0,
        )
        # Add random noise to the labels - important trick!
        labels += 0.05 * keras.random.uniform(labels.shape, seed=self.seed_generator)

        # Train the discriminator
        self.zero_grad()
        predictions = self.discriminator(combined_images)
        d_loss = self.loss_fn(labels, predictions)
        d_loss.backward()
        grads = [v.value.grad for v in self.discriminator.trainable_weights]
        with torch.no_grad():
            self.d_optimizer.apply(grads, self.discriminator.trainable_weights)

        # Sample random points in the latent space
        random_latent_vectors = keras.random.normal(
            shape=(batch_size, self.latent_dim), seed=self.seed_generator
        )

        # Assemble labels that say "all real images"
        misleading_labels = torch.zeros((batch_size, 1), device=device)

        # Train the generator (note that we should *not* update the weights
        # of the discriminator)!
        self.zero_grad()
        predictions = self.discriminator(self.generator(random_latent_vectors))
        g_loss = self.loss_fn(misleading_labels, predictions)
        g_loss.backward()
        grads = [v.value.grad for v in self.generator.trainable_weights]
        with torch.no_grad():
            self.g_optimizer.apply(grads, self.generator.trainable_weights)

        # Update metrics and return their value.
        self.d_loss_tracker.update_state(d_loss)
        self.g_loss_tracker.update_state(g_loss)
        return {
            "d_loss": self.d_loss_tracker.result(),
            "g_loss": self.g_loss_tracker.result(),
        }


"""
Let's test-drive it:
"""

# Prepare the dataset. We use both the training & test MNIST digits.
batch_size = 64
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
all_digits = np.concatenate([x_train, x_test])
all_digits = all_digits.astype("float32") / 255.0
all_digits = np.reshape(all_digits, (-1, 28, 28, 1))

# Create a TensorDataset
dataset = torch.utils.data.TensorDataset(
    torch.from_numpy(all_digits), torch.from_numpy(all_digits)
)
# Create a DataLoader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

gan = GAN(discriminator=discriminator, generator=generator, latent_dim=latent_dim)
gan.compile(
    d_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
    g_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
    loss_fn=keras.losses.BinaryCrossentropy(from_logits=True),
)

gan.fit(dataloader, epochs=1)
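"""
After training, the generator can be used on its own to produce new images. A quick
usage sketch (this snippet is an addition, not part of the original example):
"""

# Sample a few latent vectors and decode them into images in the [0, 1] range.
random_latent_vectors = keras.random.normal(
    shape=(4, latent_dim), seed=gan.seed_generator
)
generated_images = gan.generator(random_latent_vectors)
print(generated_images.shape)  # (4, 28, 28, 1)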
"""
The ideas behind deep learning are simple, so why should their implementation be painful?
"""