Path: blob/master/guides/writing_a_custom_training_loop_in_torch.py
"""1Title: Writing a training loop from scratch in PyTorch2Author: [fchollet](https://twitter.com/fchollet)3Date created: 2023/06/254Last modified: 2023/06/255Description: Writing low-level training & evaluation loops in PyTorch.6Accelerator: None7"""89"""10## Setup11"""1213import os1415# This guide can only be run with the torch backend.16os.environ["KERAS_BACKEND"] = "torch"1718import torch19import keras20import numpy as np2122"""23## Introduction2425Keras provides default training and evaluation loops, `fit()` and `evaluate()`.26Their usage is covered in the guide27[Training & evaluation with the built-in methods](/guides/training_with_built_in_methods/).2829If you want to customize the learning algorithm of your model while still leveraging30the convenience of `fit()`31(for instance, to train a GAN using `fit()`), you can subclass the `Model` class and32implement your own `train_step()` method, which33is called repeatedly during `fit()`.3435Now, if you want very low-level control over training & evaluation, you should write36your own training & evaluation loops from scratch. This is what this guide is about.37"""3839"""40## A first end-to-end example4142To write a custom training loop, we need the following ingredients:4344- A model to train, of course.45- An optimizer. You could either use a `keras.optimizers` optimizer,46or a native PyTorch optimizer from `torch.optim`.47- A loss function. You could either use a `keras.losses` loss,48or a native PyTorch loss from `torch.nn`.49- A dataset. You could use any format: a `tf.data.Dataset`,50a PyTorch `DataLoader`, a Python generator, etc.5152Let's line them up. We'll use torch-native objects in each case --53except, of course, for the Keras model.5455First, let's get the model and the MNIST dataset:56"""575859# Let's consider a simple MNIST model60def get_model():61inputs = keras.Input(shape=(784,), name="digits")62x1 = keras.layers.Dense(64, activation="relu")(inputs)63x2 = keras.layers.Dense(64, activation="relu")(x1)64outputs = keras.layers.Dense(10, name="predictions")(x2)65model = keras.Model(inputs=inputs, outputs=outputs)66return model676869# Create load up the MNIST dataset and put it in a torch DataLoader70# Prepare the training dataset.71batch_size = 3272(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()73x_train = np.reshape(x_train, (-1, 784)).astype("float32")74x_test = np.reshape(x_test, (-1, 784)).astype("float32")75y_train = keras.utils.to_categorical(y_train)76y_test = keras.utils.to_categorical(y_test)7778# Reserve 10,000 samples for validation.79x_val = x_train[-10000:]80y_val = y_train[-10000:]81x_train = x_train[:-10000]82y_train = y_train[:-10000]8384# Create torch Datasets85train_dataset = torch.utils.data.TensorDataset(86torch.from_numpy(x_train), torch.from_numpy(y_train)87)88val_dataset = torch.utils.data.TensorDataset(89torch.from_numpy(x_val), torch.from_numpy(y_val)90)9192# Create DataLoaders for the Datasets93train_dataloader = torch.utils.data.DataLoader(94train_dataset, batch_size=batch_size, shuffle=True95)96val_dataloader = torch.utils.data.DataLoader(97val_dataset, batch_size=batch_size, shuffle=False98)99100"""101Next, here's our PyTorch optimizer and our PyTorch loss function:102"""103104# Instantiate a torch optimizer105model = get_model()106optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)107108# Instantiate a torch loss function109loss_fn = torch.nn.CrossEntropyLoss()110111"""112Let's train our model using mini-batch gradient with a custom training loop.113114Calling 
"""
## A first end-to-end example

To write a custom training loop, we need the following ingredients:

- A model to train, of course.
- An optimizer. You could either use a `keras.optimizers` optimizer,
or a native PyTorch optimizer from `torch.optim`.
- A loss function. You could either use a `keras.losses` loss,
or a native PyTorch loss from `torch.nn`.
- A dataset. You could use any format: a `tf.data.Dataset`,
a PyTorch `DataLoader`, a Python generator, etc.

Let's line them up. We'll use torch-native objects in each case --
except, of course, for the Keras model.

First, let's get the model and the MNIST dataset:
"""


# Let's consider a simple MNIST model
def get_model():
    inputs = keras.Input(shape=(784,), name="digits")
    x1 = keras.layers.Dense(64, activation="relu")(inputs)
    x2 = keras.layers.Dense(64, activation="relu")(x1)
    outputs = keras.layers.Dense(10, name="predictions")(x2)
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model


# Load up the MNIST dataset and put it in a torch DataLoader
# Prepare the training dataset.
batch_size = 32
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = np.reshape(x_train, (-1, 784)).astype("float32")
x_test = np.reshape(x_test, (-1, 784)).astype("float32")
y_train = keras.utils.to_categorical(y_train)
y_test = keras.utils.to_categorical(y_test)

# Reserve 10,000 samples for validation.
x_val = x_train[-10000:]
y_val = y_train[-10000:]
x_train = x_train[:-10000]
y_train = y_train[:-10000]

# Create torch Datasets
train_dataset = torch.utils.data.TensorDataset(
    torch.from_numpy(x_train), torch.from_numpy(y_train)
)
val_dataset = torch.utils.data.TensorDataset(
    torch.from_numpy(x_val), torch.from_numpy(y_val)
)

# Create DataLoaders for the Datasets
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True
)
val_dataloader = torch.utils.data.DataLoader(
    val_dataset, batch_size=batch_size, shuffle=False
)

"""
Next, here's our PyTorch optimizer and our PyTorch loss function:
"""

# Instantiate a torch optimizer
model = get_model()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Instantiate a torch loss function
loss_fn = torch.nn.CrossEntropyLoss()

"""
Let's train our model using mini-batch gradient descent with a custom training loop.

Calling `loss.backward()` on a loss tensor triggers backpropagation.
Once that's done, your optimizer is magically aware of the gradients for each variable
and can update its variables, which is done via `optimizer.step()`.
Tensors, variables, and optimizers are all interconnected to one another via hidden
global state. Also, don't forget to call `model.zero_grad()` before `loss.backward()`,
or you won't get the right gradients for your variables.

Here's our training loop, step by step:

- We open a `for` loop that iterates over epochs
- For each epoch, we open a `for` loop that iterates over the dataset, in batches
- For each batch, we call the model on the input data to retrieve the predictions,
then we use them to compute a loss value
- We call `model.zero_grad()` to clear previously accumulated gradients, then
`loss.backward()` to compute the gradients of the loss with regard to the model weights
- Finally, we use the optimizer to update the weights of the model based on the
gradients, via `optimizer.step()`
"""

epochs = 3
for epoch in range(epochs):
    for step, (inputs, targets) in enumerate(train_dataloader):
        # Forward pass
        logits = model(inputs)
        loss = loss_fn(logits, targets)

        # Backward pass
        model.zero_grad()
        loss.backward()

        # Optimizer variable updates
        optimizer.step()

        # Log every 100 batches.
        if step % 100 == 0:
            print(
                f"Training loss (for 1 batch) at step {step}: {loss.detach().numpy():.4f}"
            )
            print(f"Seen so far: {(step + 1) * batch_size} samples")

"""
As an alternative, let's look at what the loop looks like when using a Keras optimizer
and a Keras loss function.

Important differences:

- You retrieve the gradients for the variables via `v.value.grad`,
called on each trainable variable.
- You update your variables via `optimizer.apply()`, which must be
called in a `torch.no_grad()` scope.

**Also, a big gotcha:** while all NumPy/TensorFlow/JAX/Keras APIs
as well as Python `unittest` APIs use the argument order convention
`fn(y_true, y_pred)` (reference values first, predicted values second),
PyTorch actually uses `fn(y_pred, y_true)` for its losses.
So make sure to invert the order of `logits` and `targets`.
"""

model = get_model()
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)

for epoch in range(epochs):
    print(f"\nStart of epoch {epoch}")
    for step, (inputs, targets) in enumerate(train_dataloader):
        # Forward pass
        logits = model(inputs)
        loss = loss_fn(targets, logits)

        # Backward pass
        model.zero_grad()
        trainable_weights = [v for v in model.trainable_weights]

        # Call torch.Tensor.backward() on the loss to compute gradients
        # for the weights.
        loss.backward()
        gradients = [v.value.grad for v in trainable_weights]

        # Update weights
        with torch.no_grad():
            optimizer.apply(gradients, trainable_weights)

        # Log every 100 batches.
        if step % 100 == 0:
            print(
                f"Training loss (for 1 batch) at step {step}: {loss.detach().numpy():.4f}"
            )
            print(f"Seen so far: {(step + 1) * batch_size} samples")
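"""
To make the argument-order gotcha above concrete in isolation, here's a tiny hedged
sketch (the dummy tensors below are illustrative, not part of the original guide):
a PyTorch loss takes `(y_pred, y_true)`, a Keras loss takes `(y_true, y_pred)`, and with
the arguments in the right order both compute the same mean cross-entropy.
"""

# Two fake samples with one-hot labels for classes 3 and 7.
dummy_targets = torch.nn.functional.one_hot(
    torch.tensor([3, 7]), num_classes=10
).float()
dummy_logits = torch.randn(2, 10)

torch_xent = torch.nn.CrossEntropyLoss()
keras_xent = keras.losses.CategoricalCrossentropy(from_logits=True)

# PyTorch convention: predictions first, targets second.
print(float(torch_xent(dummy_logits, dummy_targets)))
# Keras convention: targets first, predictions second.
print(float(keras_xent(dummy_targets, dummy_logits)))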
"""
## Low-level handling of metrics

Let's add metrics monitoring to this basic training loop.

You can readily reuse built-in Keras metrics (or custom ones you wrote) in such training
loops written from scratch. Here's the flow:

- Instantiate the metric at the start of the loop
- Call `metric.update_state()` after each batch
- Call `metric.result()` when you need to display the current value of the metric
- Call `metric.reset_state()` when you need to clear the state of the metric
(typically at the end of an epoch)

Let's use this knowledge to compute `CategoricalAccuracy` on training and
validation data at the end of each epoch:
"""

# Get a fresh model
model = get_model()

# Instantiate an optimizer to train the model.
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
# Instantiate a loss function.
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)

# Prepare the metrics.
train_acc_metric = keras.metrics.CategoricalAccuracy()
val_acc_metric = keras.metrics.CategoricalAccuracy()

"""
Here's our training & evaluation loop:
"""

for epoch in range(epochs):
    print(f"\nStart of epoch {epoch}")
    for step, (inputs, targets) in enumerate(train_dataloader):
        # Forward pass
        logits = model(inputs)
        loss = loss_fn(targets, logits)

        # Backward pass
        model.zero_grad()
        trainable_weights = [v for v in model.trainable_weights]

        # Call torch.Tensor.backward() on the loss to compute gradients
        # for the weights.
        loss.backward()
        gradients = [v.value.grad for v in trainable_weights]

        # Update weights
        with torch.no_grad():
            optimizer.apply(gradients, trainable_weights)

        # Update training metric.
        train_acc_metric.update_state(targets, logits)

        # Log every 100 batches.
        if step % 100 == 0:
            print(
                f"Training loss (for 1 batch) at step {step}: {loss.detach().numpy():.4f}"
            )
            print(f"Seen so far: {(step + 1) * batch_size} samples")

    # Display metrics at the end of each epoch.
    train_acc = train_acc_metric.result()
    print(f"Training acc over epoch: {float(train_acc):.4f}")

    # Reset training metrics at the end of each epoch
    train_acc_metric.reset_state()

    # Run a validation loop at the end of each epoch.
    for x_batch_val, y_batch_val in val_dataloader:
        val_logits = model(x_batch_val, training=False)
        # Update val metrics
        val_acc_metric.update_state(y_batch_val, val_logits)
    val_acc = val_acc_metric.result()
    val_acc_metric.reset_state()
    print(f"Validation acc: {float(val_acc):.4f}")
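"""
As noted at the top of this section, the same flow also works for custom metrics you
wrote yourself. Here's a hedged sketch (the class name and implementation below are
illustrative, not from the original guide) of a from-scratch categorical accuracy built
on `keras.metrics.Metric`; it could be swapped in for `train_acc_metric` or
`val_acc_metric` above without touching the rest of the loop:
"""


class SketchCategoricalAccuracy(keras.metrics.Metric):
    """Illustrative reimplementation of categorical accuracy for one-hot targets."""

    def __init__(self, name="sketch_categorical_accuracy", **kwargs):
        super().__init__(name=name, **kwargs)
        # State variables that accumulate across update_state() calls.
        self.correct = self.add_weight(shape=(), initializer="zeros", name="correct")
        self.total = self.add_weight(shape=(), initializer="zeros", name="total")

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Compare the argmax of the one-hot targets with the argmax of the logits.
        matches = keras.ops.equal(
            keras.ops.argmax(y_true, axis=-1), keras.ops.argmax(y_pred, axis=-1)
        )
        matches = keras.ops.cast(matches, "float32")
        self.correct.assign_add(keras.ops.sum(matches))
        self.total.assign_add(keras.ops.cast(keras.ops.size(matches), "float32"))

    def result(self):
        return keras.ops.divide(self.correct, self.total)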
"""
## Low-level handling of losses tracked by the model

Layers & models recursively track any losses created during the forward pass
by layers that call `self.add_loss(value)`. The resulting list of scalar loss
values is available via the property `model.losses`
at the end of the forward pass.

If you want to use these loss components, you should sum them
and add them to the main loss in your training step.

Consider this layer, which creates an activity regularization loss:
"""


class ActivityRegularizationLayer(keras.layers.Layer):
    def call(self, inputs):
        self.add_loss(1e-2 * torch.sum(inputs))
        return inputs


"""
Let's build a really simple model that uses it:
"""

inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu")(inputs)
# Insert activity regularization as a layer
x = ActivityRegularizationLayer()(x)
x = keras.layers.Dense(64, activation="relu")(x)
outputs = keras.layers.Dense(10, name="predictions")(x)

model = keras.Model(inputs=inputs, outputs=outputs)

"""
Here's what our training loop should look like now:
"""

# Keep using the model defined just above, which includes the
# activity regularization layer.

# Instantiate an optimizer to train the model.
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
# Instantiate a loss function.
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)

# Prepare the metrics.
train_acc_metric = keras.metrics.CategoricalAccuracy()
val_acc_metric = keras.metrics.CategoricalAccuracy()

for epoch in range(epochs):
    print(f"\nStart of epoch {epoch}")
    for step, (inputs, targets) in enumerate(train_dataloader):
        # Forward pass
        logits = model(inputs)
        loss = loss_fn(targets, logits)
        if model.losses:
            # Sum the tracked loss components and add them to the main loss.
            loss = loss + sum(model.losses)

        # Backward pass
        model.zero_grad()
        trainable_weights = [v for v in model.trainable_weights]

        # Call torch.Tensor.backward() on the loss to compute gradients
        # for the weights.
        loss.backward()
        gradients = [v.value.grad for v in trainable_weights]

        # Update weights
        with torch.no_grad():
            optimizer.apply(gradients, trainable_weights)

        # Update training metric.
        train_acc_metric.update_state(targets, logits)

        # Log every 100 batches.
        if step % 100 == 0:
            print(
                f"Training loss (for 1 batch) at step {step}: {loss.detach().numpy():.4f}"
            )
            print(f"Seen so far: {(step + 1) * batch_size} samples")

    # Display metrics at the end of each epoch.
    train_acc = train_acc_metric.result()
    print(f"Training acc over epoch: {float(train_acc):.4f}")

    # Reset training metrics at the end of each epoch
    train_acc_metric.reset_state()

    # Run a validation loop at the end of each epoch.
    for x_batch_val, y_batch_val in val_dataloader:
        val_logits = model(x_batch_val, training=False)
        # Update val metrics
        val_acc_metric.update_state(y_batch_val, val_logits)
    val_acc = val_acc_metric.result()
    val_acc_metric.reset_state()
    print(f"Validation acc: {float(val_acc):.4f}")

"""
That's it!
"""
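"""
As an optional last step -- a hedged sketch, not part of the original guide -- you can
reuse the `x_test` / `y_test` arrays prepared during setup (which the loops above never
touch) to measure accuracy on the test split, following exactly the same pattern as the
validation loop:
"""

# Build a Dataset and DataLoader over the held-out test split.
test_dataset = torch.utils.data.TensorDataset(
    torch.from_numpy(x_test), torch.from_numpy(y_test)
)
test_dataloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False
)

test_acc_metric = keras.metrics.CategoricalAccuracy()
for x_batch_test, y_batch_test in test_dataloader:
    test_logits = model(x_batch_test, training=False)
    test_acc_metric.update_state(y_batch_test, test_logits)
print(f"Test acc: {float(test_acc_metric.result()):.4f}")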