Copyright 2019 The TensorFlow Authors.
Custom training with tf.distribute.Strategy
This tutorial demonstrates how to use tf.distribute.Strategy—a TensorFlow API that provides an abstraction for distributing your training across multiple processing units (GPUs, multiple machines, or TPUs)—with custom training loops. In this example, you will train a simple convolutional neural network on the Fashion MNIST dataset containing 70,000 images of size 28 x 28.
Custom training loops provide flexibility and greater control over training. They also make it easier to debug the model and the training loop.
Download the Fashion MNIST dataset
Create a strategy to distribute the variables and the graph
How does the tf.distribute.MirroredStrategy strategy work?
All the variables and the model graph are replicated across the replicas.
Input is evenly distributed across the replicas.
Each replica calculates the loss and gradients for the input it received.
The gradients are synced across all the replicas by summing them.
After the sync, the same update is made to the copies of the variables on each replica.
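The five steps above can be illustrated without TensorFlow. The following is a minimal plain-Python sketch, not the MirroredStrategy implementation: the one-weight model, the data, and the names grad_sum and replica_weights are all hypothetical. It assumes each replica scales its gradient by the global batch size, as the loss section of this tutorial describes, and checks that the summed gradients give the same update as a single device processing the whole batch.

```python
# Toy simulation of synchronous data-parallel training (plain Python, no TensorFlow).
# Hypothetical model: prediction = w * x, per-example loss = (w * x - y) ** 2.

def grad_sum(w, examples):
    """Sum over examples of d/dw (w * x - y) ** 2."""
    return sum(2.0 * (w * x - y) * x for x, y in examples)

global_batch = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]
global_batch_size = len(global_batch)
num_replicas = 2
learning_rate = 0.01
initial_w = 0.5

# 1. Every replica holds its own copy of the variable.
replica_weights = [initial_w] * num_replicas

# 2. The input is split evenly across the replicas.
shard_size = global_batch_size // num_replicas
shards = [global_batch[i * shard_size:(i + 1) * shard_size]
          for i in range(num_replicas)]

# 3. Each replica computes the gradient for the input it received,
#    scaled by the global batch size (an assumption taken from the loss section).
local_grads = [grad_sum(w, shard) / global_batch_size
               for w, shard in zip(replica_weights, shards)]

# 4. The gradients are synced across replicas by summing them.
synced_grad = sum(local_grads)

# 5. The same update is applied to every copy of the variable.
replica_weights = [w - learning_rate * synced_grad for w in replica_weights]

# Reference: one device processing the whole batch.
reference_w = initial_w - learning_rate * grad_sum(initial_w, global_batch) / global_batch_size

print(replica_weights, reference_w)
```

Because every replica applies the same summed gradient, the copies of the variable stay identical after each step.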
Note: You can put all the code below inside a single scope. This example divides it into several code cells for illustration purposes.
Setup input pipeline
Create the datasets and distribute them:
Create the model
Create a model using tf.keras.Sequential. You can also use the Model Subclassing API or the functional API to do this.
Define the loss function
Recall that the loss function consists of one or two parts:
The prediction loss measures how far off the model's predictions are from the training labels for a batch of training examples. It is computed for each labeled example and then reduced across the batch by computing the average value.
Optionally, regularization loss terms can be added to the prediction loss, to steer the model away from overfitting the training data. A common choice is L2 regularization, which adds a small fixed multiple of the sum of squares of all model weights, independent of the number of examples. The model above uses L2 regularization to demonstrate its handling in the training loop below.
For training on a single machine with a single GPU/CPU, this works as follows:
The prediction loss is computed for each example in the batch, summed across the batch, and then divided by the batch size.
The regularization loss is added to the prediction loss.
The gradient of the total loss is computed w.r.t. each model weight, and the optimizer updates each model weight from the corresponding gradient.
With tf.distribute.Strategy, the input batch is split between replicas. For example, let's say you have 4 GPUs, each with one replica of the model. One batch of 256 input examples is distributed evenly across the 4 replicas, so each replica gets a batch of size 64: We have 256 = 4*64, or generally GLOBAL_BATCH_SIZE = num_replicas_in_sync * BATCH_SIZE_PER_REPLICA.
Each replica computes the loss from the training examples it gets and computes the gradients of the loss w.r.t. each model weight. The optimizer takes care that these gradients are summed up across replicas before using them to update the copies of the model weights on each replica.
So, how should the loss be calculated when using a tf.distribute.Strategy?
Each replica computes the prediction loss for all examples distributed to it, sums up the results, and divides the sum by num_replicas_in_sync * BATCH_SIZE_PER_REPLICA, or equivalently, GLOBAL_BATCH_SIZE.
Each replica computes the regularization loss(es) and divides them by num_replicas_in_sync.
Compared to non-distributed training, all per-replica loss terms are scaled down by a factor of 1/num_replicas_in_sync. On the other hand, all loss terms (or rather, their gradients) are summed across that number of replicas before the optimizer applies them. In effect, the optimizer on each replica uses the same gradients as if a non-distributed computation with GLOBAL_BATCH_SIZE had happened. This is consistent with the distributed and undistributed behavior of Keras Model.fit. See the Distributed training with Keras tutorial on how a larger global batch size makes it possible to scale up the learning rate.
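The scaling argument above can be checked with a few lines of arithmetic. This is a plain-Python sketch with made-up numbers: the per-example losses and the regularization value are hypothetical, and no TensorFlow API is involved. It shows that the scaled per-replica losses, summed across replicas, equal the loss of a non-distributed run over the same global batch. Since gradients are linear in the loss, the same holds for the gradients.

```python
# Plain-Python check of the loss scaling rule (hypothetical numbers).
num_replicas_in_sync = 2
BATCH_SIZE_PER_REPLICA = 2
GLOBAL_BATCH_SIZE = num_replicas_in_sync * BATCH_SIZE_PER_REPLICA

# Hypothetical per-example prediction losses, one list per replica.
per_replica_example_losses = [[2.0, 3.0], [4.0, 5.0]]

# Hypothetical regularization loss. It depends only on the weights,
# so every replica computes the same value.
regularization_loss = 0.8

# Non-distributed reference: mean prediction loss plus regularization loss.
all_losses = [loss for losses in per_replica_example_losses for loss in losses]
undistributed_loss = sum(all_losses) / len(all_losses) + regularization_loss

# Distributed: each replica divides its summed prediction loss by
# GLOBAL_BATCH_SIZE and its regularization loss by num_replicas_in_sync.
scaled_losses = [
    sum(losses) / GLOBAL_BATCH_SIZE + regularization_loss / num_replicas_in_sync
    for losses in per_replica_example_losses
]

# Summing across replicas recovers the non-distributed loss.
summed_across_replicas = sum(scaled_losses)

print(undistributed_loss, summed_across_replicas)
```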
How to do this in TensorFlow?
Loss reduction and scaling is done automatically in Keras Model.compile and Model.fit.
If you're writing a custom training loop, as in this tutorial, you should sum the per-example losses and divide the sum by the global batch size using tf.nn.compute_average_loss, which takes the per-example losses and optional sample weights as arguments and returns the scaled loss.
If using tf.keras.losses classes (as in the example below), the loss reduction needs to be explicitly specified to be one of NONE or SUM. The default AUTO and SUM_OVER_BATCH_SIZE are disallowed outside Model.fit.
AUTO is disallowed because the user should explicitly think about what reduction they want to make sure it is correct in the distributed case.
SUM_OVER_BATCH_SIZE is disallowed because currently it would only divide by per replica batch size, and leave the dividing by number of replicas to the user, which might be easy to miss. So, instead, you need to do the reduction yourself explicitly.
If you're writing a custom training loop for a model with a non-empty list of Model.losses (e.g., weight regularizers), you should sum them up and divide the sum by the number of replicas. You can do this by using the tf.nn.scale_regularization_loss function. The model code itself remains unaware of the number of replicas.
However, models can define input-dependent regularization losses with Keras APIs such as Layer.add_loss(...) and Layer(activity_regularizer=...). For Layer.add_loss(...), it falls on the modeling code to perform the division of the summed per-example terms by the per-replica(!) batch size, e.g., by using tf.math.reduce_mean().
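A loss function following the rules above might look like the sketch below. It is an illustration under stated assumptions, not necessarily the code this tutorial uses: the compute_loss signature, the global_batch_size argument, and the toy labels, logits, and regularization value are hypothetical, and the string "none" is used as the reduction value corresponding to NONE. The example runs outside any strategy scope, where the default strategy has a single replica.

```python
import tensorflow as tf

# reduction="none" keeps one loss value per example; the reduction is done
# explicitly in compute_loss.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction="none")

def compute_loss(labels, predictions, model_losses, global_batch_size=None):
    """Hypothetical helper: scaled prediction loss plus scaled regularization."""
    per_example_loss = loss_object(labels, predictions)
    # Sums the per-example losses and divides by the global batch size.
    loss = tf.nn.compute_average_loss(
        per_example_loss, global_batch_size=global_batch_size)
    if model_losses:
        # Sums the regularization losses and divides by the number of replicas.
        loss += tf.nn.scale_regularization_loss(tf.add_n(model_losses))
    return loss

# Toy inputs: two examples on this replica, with an assumed global batch of 4.
labels = tf.constant([0, 1])
logits = tf.constant([[2.0, 0.0], [0.0, 2.0]])
regularization_losses = [tf.constant(0.5)]  # stand-in for Model.losses

loss = compute_loss(labels, logits, regularization_losses, global_batch_size=4)
print(float(loss))
```

Inside a training step run under a strategy, global_batch_size can be left as None when all batches are full, in which case tf.nn.compute_average_loss derives it from the per-replica batch size and the number of replicas.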
Special cases
Advanced users should also consider the following special cases.
Input batches shorter than GLOBAL_BATCH_SIZE create unpleasant corner cases in several places. In practice, it often works best to avoid them by allowing batches to span epoch boundaries using Dataset.repeat().batch() and defining approximate epochs by step counts, not dataset ends. Alternatively, Dataset.batch(drop_remainder=True) maintains the notion of epoch but drops the last few examples.
For illustration, this example goes the harder route and allows short batches, so that each training epoch contains each training example exactly once.
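The three batching choices just described can be compared on a tiny dataset. This is a minimal sketch using a hypothetical 10-element dataset and a batch size of 4. It only inspects batch sizes and contents and does not involve a distribution strategy.

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(10)  # hypothetical 10-example dataset
BATCH_SIZE = 4

# Plain batching keeps every example but ends the epoch with a short batch.
short_batches = [int(b.shape[0]) for b in dataset.batch(BATCH_SIZE)]

# drop_remainder=True keeps the notion of an epoch but drops the last 2 examples.
dropped = [int(b.shape[0])
           for b in dataset.batch(BATCH_SIZE, drop_remainder=True)]

# repeat().batch() lets a batch span the epoch boundary, so every batch is full.
# The repeated dataset is infinite, so take() bounds it by a step count.
spanning = [b.numpy().tolist()
            for b in dataset.repeat().batch(BATCH_SIZE).take(3)]

print(short_batches)  # [4, 4, 2]
print(dropped)        # [4, 4]
print(spanning)       # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 0, 1]]
```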
Which denominator should be used by tf.nn.compute_average_loss()?
Both options are equivalent if short batches are avoided, as suggested above.
Multi-dimensional labels require you to average the per_example_loss across the number of predictions in each example. Consider a classification task for all pixels of an input image, with predictions of shape (batch_size, H, W, n_classes) and labels of shape (batch_size, H, W). You will need to update per_example_loss like:
per_example_loss /= tf.cast(tf.reduce_prod(tf.shape(labels)[1:]), tf.float32)
Caution: Verify the shape of your loss. Loss functions in tf.losses/tf.keras.losses typically return the average over the last dimension of the input. The loss classes wrap these functions. Passing reduction=Reduction.NONE when creating an instance of a loss class means "no additional reduction". For categorical losses with an example input shape of [batch, W, H, n_classes], the n_classes dimension is reduced. For pointwise losses like losses.mean_squared_error or losses.binary_crossentropy, include a dummy axis so that [batch, W, H, 1] is reduced to [batch, W, H]. Without the dummy axis, [batch, W, H] will be incorrectly reduced to [batch, W].
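The shape behavior described above can be verified on small tensors. The sketch below assumes toy shapes (batch of 2, 3 x 3 pixels, 5 classes), all-zero logits, and the string "none" as the reduction value corresponding to NONE; none of these come from the tutorial's model. It first normalizes a per-pixel categorical loss by the number of pixels, then shows the effect of the dummy axis on a pointwise loss.

```python
import tensorflow as tf

batch, H, W, n_classes = 2, 3, 3, 5  # hypothetical toy shapes

# Per-pixel classification: labels (batch, H, W), logits (batch, H, W, n_classes).
labels = tf.zeros([batch, H, W], dtype=tf.int32)
logits = tf.zeros([batch, H, W, n_classes])

scce = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction="none")
per_pixel_loss = scce(labels, logits)  # shape (batch, H, W)

# Average over the predictions in each example before the batch-level scaling.
per_example_loss = per_pixel_loss / tf.cast(
    tf.reduce_prod(tf.shape(labels)[1:]), tf.float32)
loss = tf.nn.compute_average_loss(per_example_loss, global_batch_size=batch)

# Pointwise loss: the last axis is averaged away unless a dummy axis is added.
y_true = tf.zeros([batch, H, W])
y_pred = tf.ones([batch, H, W])
mse = tf.keras.losses.MeanSquaredError(reduction="none")
mse_without_axis = mse(y_true, y_pred)                      # shape (batch, H)
mse_with_axis = mse(y_true[..., tf.newaxis],
                    y_pred[..., tf.newaxis])                # shape (batch, H, W)

print(per_pixel_loss.shape, float(loss))
print(mse_without_axis.shape, mse_with_axis.shape)
```

With uniform logits over 5 classes, every pixel contributes a loss of ln(5), so the final scaled loss is also ln(5), about 1.609.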
Define the metrics to track loss and accuracy
These metrics track the test loss and training and test accuracy. You can use .result() to get the accumulated statistics at any time.
Training loop
Things to note in the example above
Iterate over the train_dist_dataset and test_dist_dataset using a for x in ... construct.
The scaled loss is the return value of the distributed_train_step. This value is aggregated across replicas using the tf.distribute.Strategy.reduce call and then across batches by summing the return value of the tf.distribute.Strategy.reduce calls.
tf.keras.Metrics should be updated inside train_step and test_step that get executed by tf.distribute.Strategy.run.
tf.distribute.Strategy.run returns results from each local replica in the strategy, and there are multiple ways to consume this result. You can do tf.distribute.Strategy.reduce to get an aggregated value. You can also do tf.distribute.Strategy.experimental_local_results to get the list of values contained in the result, one per local replica.
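The run/reduce pattern in these notes can be sketched in isolation. The example below is a stripped-down stand-in, not the tutorial's training loop: step_fn replaces train_step and simply treats its input as per-example losses, and the data, batch sizes, and function names are hypothetical. It assumes tf.distribute.MirroredStrategy falls back to a single CPU replica when no GPUs are available.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
BATCH_SIZE_PER_REPLICA = 4  # hypothetical
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

# Stand-in data: two full global batches in which every "loss" is 3.0.
values = tf.fill([2 * GLOBAL_BATCH_SIZE], 3.0)
dataset = tf.data.Dataset.from_tensor_slices(values).batch(GLOBAL_BATCH_SIZE)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

def step_fn(per_example_loss):
    # Stand-in for train_step: returns this replica's scaled loss.
    return tf.nn.compute_average_loss(
        per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)

@tf.function
def distributed_step(dist_inputs):
    per_replica_losses = strategy.run(step_fn, args=(dist_inputs,))
    # Aggregate across replicas by summing the scaled losses.
    return strategy.reduce(
        tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)

# Aggregate across batches by summing, then divide by the batch count.
total_loss = 0.0
num_batches = 0
for x in dist_dataset:
    total_loss += distributed_step(x)
    num_batches += 1
epoch_loss = float(total_loss) / num_batches

# Alternative to reduce: inspect the per-replica values directly.
first_batch = next(iter(dist_dataset))
local_results = strategy.experimental_local_results(
    strategy.run(step_fn, args=(first_batch,)))

print(epoch_loss, len(local_results))
```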
Restore the latest checkpoint and test
A model checkpointed with a tf.distribute.Strategy can be restored with or without a strategy.
Alternate ways of iterating over a dataset
Using iterators
If you want to iterate over a given number of steps and not through the entire dataset, you can create an iterator using the iter call and explicitly call next on the iterator. You can choose to iterate over the dataset both inside and outside the tf.function. Here is a small snippet demonstrating iteration of the dataset outside the tf.function using an iterator.
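As a hedged illustration of the iter/next pattern, the sketch below steps through a distributed dataset for a fixed number of steps outside any tf.function. The dataset, the STEPS constant, and the counting logic are hypothetical stand-ins for a real training step.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

# Hypothetical dataset: 100 examples in global batches of 10.
dataset = tf.data.Dataset.range(100).batch(10)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

STEPS = 3  # hypothetical number of steps, fewer than one full pass

iterator = iter(dist_dataset)
examples_seen = 0
for _ in range(STEPS):
    batch = next(iterator)  # one global batch, split across replicas
    # Stand-in for a training step: count the examples on each local replica.
    for local_batch in strategy.experimental_local_results(batch):
        examples_seen += int(local_batch.shape[0])

print(examples_seen)
```

Only 3 of the 10 available batches are consumed, so the loop ends without reaching the end of the dataset.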
Iterating inside a tf.function
You can also iterate over the entire input train_dist_dataset inside a tf.function using the for x in ... construct or by creating iterators like you did above. The example below demonstrates wrapping one epoch of training with a @tf.function decorator and iterating over train_dist_dataset inside the function.
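A minimal sketch of this pattern follows. It is an assumption-laden stand-in for a real training epoch: step_fn replaces train_step and treats its input as per-example losses, and the data and names are hypothetical.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
GLOBAL_BATCH_SIZE = 4 * strategy.num_replicas_in_sync  # hypothetical

# Stand-in data: three full global batches in which every "loss" is 2.0.
values = tf.fill([3 * GLOBAL_BATCH_SIZE], 2.0)
dist_dataset = strategy.experimental_distribute_dataset(
    tf.data.Dataset.from_tensor_slices(values).batch(GLOBAL_BATCH_SIZE))

def step_fn(per_example_loss):
    # Stand-in for train_step: returns this replica's scaled loss.
    return tf.nn.compute_average_loss(
        per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)

@tf.function
def distributed_epoch(dataset):
    total_loss = 0.0
    num_batches = 0
    # The whole loop over the distributed dataset is traced into one graph.
    for x in dataset:
        per_replica_losses = strategy.run(step_fn, args=(x,))
        total_loss += strategy.reduce(
            tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
        num_batches += 1
    return total_loss / tf.cast(num_batches, dtype=tf.float32)

epoch_loss = float(distributed_epoch(dist_dataset))
print(epoch_loss)
```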
Tracking training loss across replicas
Note: As a general rule, you should use tf.keras.Metrics to track per-sample values and avoid values that have been aggregated within a replica.
Because of the loss scaling computation that is carried out, it's not recommended to use tf.keras.metrics.Mean to track the training loss across different replicas.
For example, if you run a training job with the following characteristics:
Two replicas
Two samples are processed on each replica
Resulting loss values: [2, 3] and [4, 5] on each replica
Global batch size = 4
With loss scaling, you calculate the per-sample value of loss on each replica by adding the loss values, and then dividing by the global batch size. In this case: (2 + 3) / 4 = 1.25 and (4 + 5) / 4 = 2.25.
If you use tf.keras.metrics.Mean to track loss across the two replicas, the result is different. In this example, you end up with a total of 3.50 and count of 2, which results in total/count = 1.75 when result() is called on the metric. Loss calculated with tf.keras.Metrics is scaled down by an additional factor that is equal to the number of replicas in sync (here, 1.75 instead of the per-sample mean of 3.5).
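The arithmetic of this example can be reproduced in plain Python. The sketch below does not call tf.keras.metrics.Mean. It only mimics the total/count bookkeeping described above, under the assumption that the metric receives one scaled loss value per replica.

```python
# Plain-Python reproduction of the example above (no TensorFlow).
num_replicas_in_sync = 2
GLOBAL_BATCH_SIZE = 4
per_replica_example_losses = [[2.0, 3.0], [4.0, 5.0]]

# Scaled loss on each replica: sum of its losses / global batch size.
scaled_losses = [sum(losses) / GLOBAL_BATCH_SIZE
                 for losses in per_replica_example_losses]  # [1.25, 2.25]

# A mean metric fed the scaled losses accumulates a total and a count.
metric_total = sum(scaled_losses)             # 3.5
metric_count = len(scaled_losses)             # 2
metric_result = metric_total / metric_count   # 1.75

# The actual per-sample mean loss over the global batch.
all_losses = [loss for losses in per_replica_example_losses for loss in losses]
per_sample_mean = sum(all_losses) / len(all_losses)  # 3.5

print(metric_result, per_sample_mean)
```

The metric result is smaller than the per-sample mean by exactly the number of replicas, which is why summing the scaled losses across replicas, rather than averaging them, recovers the correct value.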
Guide and examples
Here are some examples for using distribution strategy with custom training loops:
DenseNet example using MirroredStrategy.
BERT example trained using MirroredStrategy and TPUStrategy. This example is particularly helpful for understanding how to load from a checkpoint and generate periodic checkpoints during distributed training etc.
NCF example trained using MirroredStrategy that can be enabled using the keras_use_ctl flag.
NMT example trained using MirroredStrategy.
You can find more examples listed under Examples and tutorials in the Distribution strategy guide.
Next steps
Try out the new tf.distribute.Strategy API on your models.
Visit the Better performance with tf.function and TensorFlow Profiler guides to learn more about tools to optimize the performance of your TensorFlow models.
Check out the Distributed training in TensorFlow guide, which provides an overview of the available distribution strategies.