"""1Title: Multi-GPU distributed training with TensorFlow2Author: [fchollet](https://twitter.com/fchollet)3Date created: 2020/04/284Last modified: 2023/06/295Description: Guide to multi-GPU training for Keras models with TensorFlow.6Accelerator: GPU7"""89"""10## Introduction1112There are generally two ways to distribute computation across multiple devices:1314**Data parallelism**, where a single model gets replicated on multiple devices or15multiple machines. Each of them processes different batches of data, then they merge16their results. There exist many variants of this setup, that differ in how the different17model replicas merge results, in whether they stay in sync at every batch or whether they18are more loosely coupled, etc.1920**Model parallelism**, where different parts of a single model run on different devices,21processing a single batch of data together. This works best with models that have a22naturally-parallel architecture, such as models that feature multiple branches.2324This guide focuses on data parallelism, in particular **synchronous data parallelism**,25where the different replicas of the model stay in sync after each batch they process.26Synchronicity keeps the model convergence behavior identical to what you would see for27single-device training.2829Specifically, this guide teaches you how to use the `tf.distribute` API to train Keras30models on multiple GPUs, with minimal changes to your code,31on multiple GPUs (typically 2 to 16) installed on a single machine (single host,32multi-device training). This is the most common setup for researchers and small-scale33industry workflows.34"""3536"""37## Setup38"""3940import os4142os.environ["KERAS_BACKEND"] = "tensorflow"4344import tensorflow as tf45import keras4647"""48## Single-host, multi-device synchronous training4950In this setup, you have one machine with several GPUs on it (typically 2 to 16). Each51device will run a copy of your model (called a **replica**). For simplicity, in what52follows, we'll assume we're dealing with 8 GPUs, at no loss of generality.5354**How it works**5556At each step of training:5758- The current batch of data (called **global batch**) is split into 8 different59sub-batches (called **local batches**). For instance, if the global batch has 51260samples, each of the 8 local batches will have 64 samples.61- Each of the 8 replicas independently processes a local batch: they run a forward pass,62then a backward pass, outputting the gradient of the weights with respect to the loss of63the model on the local batch.64- The weight updates originating from local gradients are efficiently merged across the 865replicas. Because this is done at the end of every step, the replicas always stay in66sync.6768In practice, the process of synchronously updating the weights of the model replicas is69handled at the level of each individual weight variable. This is done through a **mirrored70variable** object.7172**How to use it**7374To do single-host, multi-device synchronous training with a Keras model, you would use75the [`tf.distribute.MirroredStrategy` API](76https://www.tensorflow.org/api_docs/python/tf/distribute/MirroredStrategy).77Here's how it works:7879- Instantiate a `MirroredStrategy`, optionally configuring which specific devices you80want to use (by default the strategy will use all GPUs available).81- Use the strategy object to open a scope, and within this scope, create all the Keras82objects you need that contain variables. 

Importantly, we recommend that you use `tf.data.Dataset` objects to load data
in a multi-device or distributed workflow.

Schematically, it looks like this:

```python
# Create a MirroredStrategy.
strategy = tf.distribute.MirroredStrategy()
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))

# Open a strategy scope.
with strategy.scope():
    # Everything that creates variables should be under the strategy scope.
    # In general this is only model construction & `compile()`.
    model = Model(...)
    model.compile(...)

# Train the model on all available devices.
model.fit(train_dataset, validation_data=val_dataset, ...)

# Test the model on all available devices.
model.evaluate(test_dataset)
```

Here's a simple end-to-end runnable example:
"""


def get_compiled_model():
    # Make a simple 2-layer densely-connected neural network.
    inputs = keras.Input(shape=(784,))
    x = keras.layers.Dense(256, activation="relu")(inputs)
    x = keras.layers.Dense(256, activation="relu")(x)
    outputs = keras.layers.Dense(10)(x)
    model = keras.Model(inputs, outputs)
    model.compile(
        optimizer=keras.optimizers.Adam(),
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[keras.metrics.SparseCategoricalAccuracy()],
    )
    return model


def get_dataset():
    batch_size = 32
    num_val_samples = 10000

    # Return the MNIST dataset in the form of a `tf.data.Dataset`.
    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

    # Preprocess the data (these are NumPy arrays).
    x_train = x_train.reshape(-1, 784).astype("float32") / 255
    x_test = x_test.reshape(-1, 784).astype("float32") / 255
    y_train = y_train.astype("float32")
    y_test = y_test.astype("float32")

    # Reserve num_val_samples samples for validation.
    x_val = x_train[-num_val_samples:]
    y_val = y_train[-num_val_samples:]
    x_train = x_train[:-num_val_samples]
    y_train = y_train[:-num_val_samples]
    return (
        tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(batch_size),
        tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(batch_size),
        tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(batch_size),
    )


# Create a MirroredStrategy.
strategy = tf.distribute.MirroredStrategy()
print("Number of devices: {}".format(strategy.num_replicas_in_sync))

# Open a strategy scope.
with strategy.scope():
    # Everything that creates variables should be under the strategy scope.
    # In general this is only model construction & `compile()`.
    model = get_compiled_model()

# Train the model on all available devices.
train_dataset, val_dataset, test_dataset = get_dataset()
model.fit(train_dataset, epochs=2, validation_data=val_dataset)

# Test the model on all available devices.
model.evaluate(test_dataset)

"""
## Using callbacks to ensure fault tolerance

When using distributed training, you should always make sure you have a strategy to
recover from failure (fault tolerance). The simplest way to handle this is to pass a
`ModelCheckpoint` callback to `fit()`, to save your model at regular intervals (e.g.
every 100 batches or every epoch; see the note after the example below). You can then
restart training from your saved model.

Here's a simple example:
"""

# Prepare a directory to store all the checkpoints.
checkpoint_dir = "./ckpt"
if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)


def make_or_restore_model():
    # Either restore the latest model, or create a fresh one
    # if there is no checkpoint available.
    checkpoints = [checkpoint_dir + "/" + name for name in os.listdir(checkpoint_dir)]
    if checkpoints:
        latest_checkpoint = max(checkpoints, key=os.path.getctime)
        print("Restoring from", latest_checkpoint)
        return keras.models.load_model(latest_checkpoint)
    print("Creating a new model")
    return get_compiled_model()


def run_training(epochs=1):
    # Create a MirroredStrategy.
    strategy = tf.distribute.MirroredStrategy()

    # Open a strategy scope and create/restore the model.
    with strategy.scope():
        model = make_or_restore_model()

    callbacks = [
        # This callback saves the whole model as a `.keras` file every epoch.
        # We include the current epoch in the file name.
        keras.callbacks.ModelCheckpoint(
            filepath=checkpoint_dir + "/ckpt-{epoch}.keras",
            save_freq="epoch",
        )
    ]
    model.fit(
        train_dataset,
        epochs=epochs,
        callbacks=callbacks,
        validation_data=val_dataset,
        verbose=2,
    )


# Running the first time creates the model.
run_training(epochs=1)

# Calling the same function again will resume from where we left off.
run_training(epochs=1)
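
"""
If losing more than an epoch of work on failure is a concern, note that `ModelCheckpoint`
also accepts an integer `save_freq`, counted in batches. Here is a minimal sketch (the
frequency of 100 batches is just an illustration, not a recommendation):

```python
callbacks = [
    # Save a full `.keras` model file every 100 training batches.
    keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_dir + "/ckpt-{epoch}.keras",
        save_freq=100,
    )
]
```

Saving more often increases I/O overhead, so pick a frequency that matches how costly a
restart would be for your workload.
"""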

"""
## `tf.data` performance tips

When doing distributed training, the efficiency with which you load data can often become
critical. Here are a few tips to make sure your `tf.data` pipelines
run as fast as possible.

**Note about dataset batching**

When creating your dataset, make sure it is batched with the global batch size.
For instance, if each of your 8 GPUs is capable of running a batch of 64 samples, you
would use a global batch size of 512.

**Calling `dataset.cache()`**

If you call `.cache()` on a dataset, its data will be cached after running through the
first iteration over the data. Every subsequent iteration will use the cached data. The
cache can be in memory (default) or to a local file you specify.

This can improve performance when:

- Your data is not expected to change from iteration to iteration
- You are reading data from a remote distributed filesystem
- You are reading data from local disk, but your data would fit in memory and your
workflow is significantly IO-bound (e.g. reading & decoding image files).

**Calling `dataset.prefetch(buffer_size)`**

You should almost always call `.prefetch(buffer_size)` after creating a dataset. It means
your data pipeline will run asynchronously from your model,
with new samples being preprocessed and stored in a buffer while the current batch
samples are used to train the model. The next batch will be prefetched in GPU memory by
the time the current batch is over. A short sketch combining these tips follows at the
end of the guide.
"""

"""
That's it!
"""
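
"""
As a quick recap of the `tf.data` tips above, here is a minimal sketch of an input
pipeline that batches with the global batch size, caches, and prefetches. The
per-replica batch size of 64 and the `x`/`y` arrays are placeholders, not part of the
runnable example above:

```python
per_replica_batch_size = 64
global_batch_size = per_replica_batch_size * strategy.num_replicas_in_sync

dataset = (
    tf.data.Dataset.from_tensor_slices((x, y))
    .batch(global_batch_size)  # Batch with the *global* batch size.
    .cache()  # Cache after the first full iteration over the data.
    .prefetch(tf.data.AUTOTUNE)  # Overlap data preprocessing and training.
)
```
"""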