"""
Title: Multi-GPU distributed training with PyTorch
Author: [fchollet](https://twitter.com/fchollet)
Date created: 2023/06/29
Last modified: 2023/06/29
Description: Guide to multi-GPU training for Keras models with PyTorch.
Accelerator: GPU
"""

"""
## Introduction

There are generally two ways to distribute computation across multiple devices:

**Data parallelism**, where a single model gets replicated on multiple devices or
multiple machines. Each replica processes different batches of data, then the replicas
merge their results. Many variants of this setup exist. They differ in how the model
replicas merge results, in whether they stay in sync at every batch or are more
loosely coupled, and so on.

**Model parallelism**, where different parts of a single model run on different devices,
processing a single batch of data together. This works best with models that have a
naturally parallel architecture, such as models that feature multiple branches.

This guide focuses on data parallelism, in particular **synchronous data parallelism**,
where the different replicas of the model stay in sync after each batch they process.
Synchronicity keeps the model convergence behavior identical to what you would see for
single-device training.

Specifically, this guide teaches you how to use PyTorch's `DistributedDataParallel`
module wrapper to train Keras models, with minimal changes to your code,
on multiple GPUs (typically 2 to 16) installed on a single machine (single host,
multi-device training).
This is the most common setup for researchers and small-scale
industry workflows.
"""

"""
## Setup

Let's start by defining the function that creates the model that we will train,
and the function that creates the dataset we will train on (MNIST in this case).
"""

import os

os.environ["KERAS_BACKEND"] = "torch"

import torch
import numpy as np
import keras


def get_model():
    # Make a simple convnet with batch normalization and dropout.
    inputs = keras.Input(shape=(28, 28, 1))
    x = keras.layers.Rescaling(1.0 / 255.0)(inputs)
    x = keras.layers.Conv2D(filters=12, kernel_size=3, padding="same", use_bias=False)(
        x
    )
    x = keras.layers.BatchNormalization(scale=False, center=True)(x)
    x = keras.layers.ReLU()(x)
    x = keras.layers.Conv2D(
        filters=24,
        kernel_size=6,
        use_bias=False,
        strides=2,
    )(x)
    x = keras.layers.BatchNormalization(scale=False, center=True)(x)
    x = keras.layers.ReLU()(x)
    x = keras.layers.Conv2D(
        filters=32,
        kernel_size=6,
        padding="same",
        strides=2,
        name="large_k",
    )(x)
    x = keras.layers.BatchNormalization(scale=False, center=True)(x)
    x = keras.layers.ReLU()(x)
    x = keras.layers.GlobalAveragePooling2D()(x)
    x = keras.layers.Dense(256, activation="relu")(x)
    x = keras.layers.Dropout(0.5)(x)
    outputs = keras.layers.Dense(10)(x)
    model = keras.Model(inputs, outputs)
    return model


def get_dataset():
    # Load the data and split it between train and test sets
    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

    # Cast images to float32 (the model's Rescaling layer maps them to [0, 1])
    x_train = x_train.astype("float32")
    x_test = x_test.astype("float32")
    # Make sure images have shape (28, 28, 1)
    x_train = np.expand_dims(x_train, -1)
    x_test = np.expand_dims(x_test, -1)
    print("x_train shape:", x_train.shape)

    # Create a TensorDataset
    dataset = torch.utils.data.TensorDataset(
        torch.from_numpy(x_train), torch.from_numpy(y_train)
    )
    return dataset


"""
Next, let's define a simple PyTorch training loop that targets
a GPU
(note the calls to `.cuda()`).
"""


def train_model(model, dataloader, num_epochs, optimizer, loss_fn):
    for epoch in range(num_epochs):
        running_loss = 0.0
        running_loss_count = 0
        for batch_idx, (inputs, targets) in enumerate(dataloader):
            # Move the batch to the GPU assigned to this process
            inputs = inputs.cuda(non_blocking=True)
            targets = targets.cuda(non_blocking=True)

            # Forward pass
            outputs = model(inputs)
            loss = loss_fn(outputs, targets)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            running_loss_count += 1

        # Print loss statistics
        print(
            f"Epoch {epoch + 1}/{num_epochs}, "
            f"Loss: {running_loss / running_loss_count}"
        )


"""
## Single-host, multi-device synchronous training

In this setup, you have one machine with several GPUs on it (typically 2 to 16). Each
device will run a copy of your model (called a **replica**). For simplicity, in what
follows, we'll assume we're dealing with 8 GPUs, without loss of generality.

**How it works**

At each step of training:

- The current batch of data (called **global batch**) is split into 8 different
sub-batches (called **local batches**). For instance, if the global batch has 512
samples, each of the 8 local batches will have 64 samples.
- Each of the 8 replicas independently processes a local batch: they run a forward pass,
then a backward pass, outputting the gradient of the loss of the model on the local
batch with respect to the weights.
- The weight updates originating from local gradients are efficiently merged across the 8
replicas. Because this is done at the end of every step, the replicas always stay in
sync.

In practice, the process of synchronously updating the weights of the model replicas is
handled at the level of each individual weight variable.
With `DistributedDataParallel`, this is done by averaging each parameter's gradient
across the replicas (an all-reduce) during the backward pass, so that every replica
applies the same update.

**How to use it**

To do single-host, multi-device synchronous training with a Keras model, you would use
the `torch.nn.parallel.DistributedDataParallel` module wrapper.
Here's how it works:

- We use `torch.multiprocessing.start_processes` to start multiple Python processes, one
per device. Each process will run the `per_device_launch_fn` function.
- The `per_device_launch_fn` function does the following:
    - It uses `torch.distributed.init_process_group` and `torch.cuda.set_device`
    to configure the device to be used for that process.
    - It uses `torch.utils.data.distributed.DistributedSampler`
    and `torch.utils.data.DataLoader` to turn our data into a distributed data loader.
    - It also uses `torch.nn.parallel.DistributedDataParallel` to turn our model into
    a distributed PyTorch module.
    - It then calls the `train_model` function.
- The `train_model` function will then run in each process, with the model using
a separate device in each process.

Here's the flow, where each step is split into its own utility function:
"""

# Config
num_gpu = torch.cuda.device_count()
num_epochs = 2
batch_size = 64
print(f"Running on {num_gpu} GPUs")


def setup_device(current_gpu_index, num_gpus):
    # Device setup
    # Address and port that the processes use to find each other
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "56492"
    device = torch.device("cuda:{}".format(current_gpu_index))
    torch.distributed.init_process_group(
        backend="nccl",
        init_method="env://",
        world_size=num_gpus,
        rank=current_gpu_index,
    )
    torch.cuda.set_device(device)


def cleanup():
    torch.distributed.destroy_process_group()


def prepare_dataloader(dataset, current_gpu_index, num_gpus, batch_size):
    # The sampler gives each process its own subset of the dataset
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset,
        num_replicas=num_gpus,
        rank=current_gpu_index,
        shuffle=False,
    )
    dataloader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=batch_size,
        shuffle=False,
    )
    return dataloader


def per_device_launch_fn(current_gpu_index, num_gpu):
    # Set up the process group
    setup_device(current_gpu_index, num_gpu)

    dataset = get_dataset()
    model = get_model()

    # Prepare the dataloader
    dataloader = prepare_dataloader(dataset, current_gpu_index, num_gpu, batch_size)

    # Instantiate the torch optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Instantiate the torch loss function
    loss_fn = torch.nn.CrossEntropyLoss()

    # Put model on device
    model = model.to(current_gpu_index)
    ddp_model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[current_gpu_index], output_device=current_gpu_index
    )

    train_model(ddp_model, dataloader, num_epochs, optimizer, loss_fn)

    cleanup()


"""
Time to start multiple processes:
"""

if __name__ == "__main__":
    # We use the "fork" method rather than "spawn" to support notebooks
    torch.multiprocessing.start_processes(
        per_device_launch_fn,
        args=(num_gpu,),
        nprocs=num_gpu,
        join=True,
        start_method="fork",
    )

"""
That's it!
"""
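
"""
To make the "How it works" steps above more concrete, here is a minimal pure-Python
sketch (no GPU or PyTorch required) of one synchronous data-parallel gradient
computation on a toy one-weight model. The helper names `shard_indices`, `mse_grad`
and `data_parallel_grad` are hypothetical and exist only for this illustration. The
strided split roughly mirrors how `DistributedSampler` assigns samples to ranks when
shuffling is off, and the plain average stands in for the gradient all-reduce that
`DistributedDataParallel` performs. It is a sketch under those assumptions, not what
the library does internally.

```python
def shard_indices(num_samples, num_replicas, rank):
    # Strided split (illustrative): replica `rank` takes samples
    # rank, rank + num_replicas, rank + 2 * num_replicas, ...
    return list(range(rank, num_samples, num_replicas))


def mse_grad(w, xs, ys):
    # Gradient with respect to w of mean((w * x - y) ** 2) over the samples.
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def data_parallel_grad(w, xs, ys, num_replicas):
    # Each simulated replica computes a gradient on its own local batch...
    local_grads = []
    for rank in range(num_replicas):
        idx = shard_indices(len(xs), num_replicas, rank)
        local_xs = [xs[i] for i in idx]
        local_ys = [ys[i] for i in idx]
        local_grads.append(mse_grad(w, local_xs, local_ys))
    # ...then the local gradients are averaged (stand-in for the all-reduce).
    return sum(local_grads) / num_replicas


# Toy global batch of 8 samples, split across 4 simulated replicas.
xs = [float(i) for i in range(1, 9)]
ys = [2.0 * x for x in xs]
print("global-batch gradient:", mse_grad(0.5, xs, ys))
print("averaged local gradients:", data_parallel_grad(0.5, xs, ys, num_replicas=4))
```

When all local batches have the same size, the average of the local gradients equals
the gradient on the global batch, which is why synchronous data parallelism behaves
like single-device training on the larger batch.
"""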