Copyright 2019 The TensorFlow Authors.
Distributed training with Keras
Overview
The tf.distribute.Strategy API provides an abstraction for distributing your training across multiple processing units. It allows you to carry out distributed training using existing models and training code with minimal changes.
This tutorial demonstrates how to use the tf.distribute.MirroredStrategy to perform in-graph replication with synchronous training on many GPUs on one machine. The strategy essentially copies all of the model's variables to each processor. Then, it uses all-reduce to combine the gradients from all processors, and applies the combined value to all copies of the model.
You will use the tf.keras APIs to build the model and Model.fit for training it. (To learn about distributed training with a custom training loop and the MirroredStrategy, check out this tutorial.)
MirroredStrategy trains your model on multiple GPUs on a single machine. For synchronous training on many GPUs on multiple workers, use the tf.distribute.MultiWorkerMirroredStrategy with the Keras Model.fit or a custom training loop. For other options, refer to the Distributed training guide.
To learn about other available strategies, check out the Distributed training with TensorFlow guide.
Setup
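A minimal setup cell might look like the following (it only assumes that `tensorflow` and `tensorflow-datasets` are installed):

```python
import os

import tensorflow_datasets as tfds
import tensorflow as tf

print(tf.__version__)
```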
Download the dataset
Load the MNIST dataset from TensorFlow Datasets. This returns a dataset in the tf.data format.
Setting the with_info argument to True includes the metadata for the entire dataset, which is being saved here to info. Among other things, this metadata object includes the number of train and test examples.
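As a sketch, loading MNIST with its metadata might look like this (the variable names are simply the ones reused later in this tutorial):

```python
# `with_info=True` also returns a tfds.core.DatasetInfo object with metadata
# such as the number of examples in each split; `as_supervised=True` yields
# (image, label) pairs.
datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)

mnist_train, mnist_test = datasets['train'], datasets['test']

num_train_examples = info.splits['train'].num_examples
num_test_examples = info.splits['test'].num_examples
```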
Define the distribution strategy
Create a MirroredStrategy object. This will handle distribution and provide a context manager (MirroredStrategy.scope) to build your model inside.
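For example:

```python
strategy = tf.distribute.MirroredStrategy()

# The number of replicas equals the number of GPUs visible to TensorFlow
# (or 1 if no GPU is available).
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
```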
Set up the input pipeline
When training a model with multiple GPUs, you can use the extra computing power effectively by increasing the batch size. In general, use the largest batch size that fits the GPU memory and tune the learning rate accordingly.
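One common pattern, sketched below with illustrative values, is to fix a per-replica batch size and scale the global batch size by the number of replicas:

```python
# Illustrative values: tune these for your own model and hardware.
BUFFER_SIZE = 10000

BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
```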
Define a function that normalizes the image pixel values from the [0, 255] range to the [0, 1] range (feature scaling):
Apply this scale function to the training and test data, and then use the tf.data.Dataset APIs to shuffle the training data (Dataset.shuffle), and batch it (Dataset.batch). Notice that you are also keeping an in-memory cache of the training data to improve performance (Dataset.cache).
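Put together, the input pipeline might look like this (reusing `BUFFER_SIZE` and `BATCH_SIZE` from above):

```python
def scale(image, label):
  """Cast the image to float32 and rescale pixel values from [0, 255] to [0, 1]."""
  image = tf.cast(image, tf.float32)
  image /= 255
  return image, label

# Cache, shuffle, and batch the training data; the test data only needs batching.
train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)
```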
Create the model and instantiate the optimizer
Within the context of Strategy.scope, create and compile the model using the Keras API:
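A small convolutional network is enough for MNIST; as a sketch:

```python
with strategy.scope():
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10)
  ])

  # The final layer outputs logits, so `from_logits=True` is set in the loss.
  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=['accuracy'])
```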
For this toy example with the MNIST dataset, you will be using the Adam optimizer's default learning rate of 0.001.
For larger datasets, the key benefit of distributed training is to learn more in each training step, because each step processes more training data in parallel, which allows for a larger learning rate (within the limits of the model and dataset).
Define the callbacks
Define the following Keras Callbacks:
- tf.keras.callbacks.TensorBoard: writes a log for TensorBoard, which allows you to visualize the graphs.
- tf.keras.callbacks.ModelCheckpoint: saves the model at a certain frequency, such as after every epoch.
- tf.keras.callbacks.BackupAndRestore: provides the fault tolerance functionality by backing up the model and the current epoch number. Learn more in the Fault tolerance section of the Multi-worker training with Keras tutorial.
- tf.keras.callbacks.LearningRateScheduler: schedules the learning rate to change after, for example, every epoch/batch.
For illustrative purposes, add a custom callback called PrintLR to display the learning rate in the notebook.
Note: Use the BackupAndRestore callback instead of ModelCheckpoint as the main mechanism to restore the training state upon a restart from a job failure. Since BackupAndRestore only supports eager mode, consider using ModelCheckpoint in graph mode.
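A sketch of these callbacks, including an illustrative learning rate schedule and the custom PrintLR callback, might look like the following (the checkpoint and log directories are arbitrary choices):

```python
# Define the checkpoint directory and the name pattern of the checkpoint files.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt_{epoch}')

# An illustrative stepwise learning rate schedule.
def decay(epoch):
  if epoch < 3:
    return 1e-3
  elif epoch < 7:
    return 1e-4
  else:
    return 1e-5

# A custom callback that prints the learning rate at the end of each epoch.
class PrintLR(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs=None):
    # Note: newer Keras versions expose this attribute as `learning_rate`.
    print('\nLearning rate for epoch {} is {}'.format(
        epoch + 1, self.model.optimizer.lr.numpy()))

callbacks = [
    tf.keras.callbacks.TensorBoard(log_dir='./logs'),
    tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix,
                                       save_weights_only=True),
    tf.keras.callbacks.LearningRateScheduler(decay),
    PrintLR()
]
```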
Train and evaluate
Now, train the model in the usual way by calling Keras Model.fit on the model and passing in the dataset created at the beginning of the tutorial. This step is the same whether you are distributing the training or not.
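For example (the number of epochs is illustrative):

```python
EPOCHS = 12

model.fit(train_dataset, epochs=EPOCHS, callbacks=callbacks)
```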
Check for saved checkpoints:
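For example, assuming the `checkpoint_dir` defined above:

```python
# List the checkpoint files written by the ModelCheckpoint callback.
print(sorted(os.listdir(checkpoint_dir)))
```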
To check how well the model performs, load the latest checkpoint and call Model.evaluate on the test data:
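As a sketch:

```python
# Restore the latest checkpoint into the compiled model and evaluate it.
model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))

eval_loss, eval_acc = model.evaluate(eval_dataset)

print('Eval loss: {}, Eval accuracy: {}'.format(eval_loss, eval_acc))
```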
To visualize the output, launch TensorBoard and view the logs:
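In a notebook, you can use the TensorBoard IPython magics, pointing them at the `./logs` directory used above:

```python
%load_ext tensorboard
%tensorboard --logdir=logs
```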
Save the model
Save the model to a .keras zip archive using Model.save. After your model is saved, you can load it with or without the Strategy.scope.
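For example (the file name is arbitrary):

```python
path = 'my_model.keras'

# Model.save writes the architecture, weights, and compile state
# to a single .keras zip archive.
model.save(path)
```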
Now, load the model without Strategy.scope:
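As a sketch:

```python
# Loading outside of any strategy scope gives you a regular, unreplicated model.
unreplicated_model = tf.keras.models.load_model(path)

unreplicated_model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=tf.keras.optimizers.Adam(),
    metrics=['accuracy'])

eval_loss, eval_acc = unreplicated_model.evaluate(eval_dataset)

print('Eval loss: {}, Eval accuracy: {}'.format(eval_loss, eval_acc))
```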
Load the model with Strategy.scope:
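Loading inside Strategy.scope distributes the restored variables across the replicas; as a sketch:

```python
with strategy.scope():
  replicated_model = tf.keras.models.load_model(path)
  replicated_model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      optimizer=tf.keras.optimizers.Adam(),
      metrics=['accuracy'])

  eval_loss, eval_acc = replicated_model.evaluate(eval_dataset)

  print('Eval loss: {}, Eval accuracy: {}'.format(eval_loss, eval_acc))
```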
Additional resources
More examples that use different distribution strategies with the Keras Model.fit API:
- The Solve GLUE tasks using BERT on TPU tutorial uses tf.distribute.MirroredStrategy for training on GPUs and tf.distribute.TPUStrategy on TPUs.
- The Save and load a model using a distribution strategy tutorial demonstrates how to use the SavedModel APIs with tf.distribute.Strategy.
- The official TensorFlow models can be configured to run multiple distribution strategies.
To learn more about TensorFlow distribution strategies:
- The Custom training with tf.distribute.Strategy tutorial shows how to use the tf.distribute.MirroredStrategy for single-worker training with a custom training loop.
- The Multi-worker training with Keras tutorial shows how to use the MultiWorkerMirroredStrategy with Model.fit.
- The Custom training loop with Keras and MultiWorkerMirroredStrategy tutorial shows how to use the MultiWorkerMirroredStrategy with Keras and a custom training loop.
- The Distributed training in TensorFlow guide provides an overview of the available distribution strategies.
- The Better performance with tf.function guide provides information about other strategies and tools, such as the TensorFlow Profiler, which you can use to optimize the performance of your TensorFlow models.
Note: tf.distribute.Strategy is actively under development and TensorFlow will be adding more examples and tutorials in the near future. Please give it a try. Your feedback is welcome—feel free to submit it via issues on GitHub.