GitHub Repository: keras-team/keras-io
Path: blob/master/guides/distributed_training_with_torch.py
"""
Title: Multi-GPU distributed training with PyTorch
Author: [fchollet](https://twitter.com/fchollet)
Date created: 2023/06/29
Last modified: 2023/06/29
Description: Guide to multi-GPU training for Keras models with PyTorch.
Accelerator: GPU
"""

"""
## Introduction

There are generally two ways to distribute computation across multiple devices:

**Data parallelism**, where a single model gets replicated on multiple devices or
multiple machines. Each replica processes a different batch of data, then the replicas
merge their results. There are many variants of this setup, which differ in how the
replicas merge results, in whether they stay in sync at every batch or whether they
are more loosely coupled, etc.

**Model parallelism**, where different parts of a single model run on different devices,
processing a single batch of data together. This works best with models that have a
naturally parallel architecture, such as models that feature multiple branches.

This guide focuses on data parallelism, in particular **synchronous data parallelism**,
where the different replicas of the model stay in sync after each batch they process.
Synchronicity keeps the model convergence behavior identical to what you would see for
single-device training.

Specifically, this guide teaches you how to use PyTorch's `DistributedDataParallel`
module wrapper to train Keras models, with minimal changes to your code,
on multiple GPUs (typically 2 to 16) installed on a single machine (single host,
multi-device training). This is the most common setup for researchers and small-scale
industry workflows.
"""

"""
## Setup

Let's start by defining the function that creates the model that we will train,
and the function that creates the dataset we will train on (MNIST in this case).
"""

import os

os.environ["KERAS_BACKEND"] = "torch"

import torch
import numpy as np
import keras


def get_model():
    # Make a simple convnet with batch normalization and dropout.
    inputs = keras.Input(shape=(28, 28, 1))
    x = keras.layers.Rescaling(1.0 / 255.0)(inputs)
    x = keras.layers.Conv2D(filters=12, kernel_size=3, padding="same", use_bias=False)(
        x
    )
    x = keras.layers.BatchNormalization(scale=False, center=True)(x)
    x = keras.layers.ReLU()(x)
    x = keras.layers.Conv2D(
        filters=24,
        kernel_size=6,
        use_bias=False,
        strides=2,
    )(x)
    x = keras.layers.BatchNormalization(scale=False, center=True)(x)
    x = keras.layers.ReLU()(x)
    x = keras.layers.Conv2D(
        filters=32,
        kernel_size=6,
        padding="same",
        strides=2,
        name="large_k",
    )(x)
    x = keras.layers.BatchNormalization(scale=False, center=True)(x)
    x = keras.layers.ReLU()(x)
    x = keras.layers.GlobalAveragePooling2D()(x)
    x = keras.layers.Dense(256, activation="relu")(x)
    x = keras.layers.Dropout(0.5)(x)
    outputs = keras.layers.Dense(10)(x)
    model = keras.Model(inputs, outputs)
    return model

def get_dataset():
    # Load the data and split it between train and test sets
    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

    # Cast images to float32 (scaling to the [0, 1] range is handled by the
    # `Rescaling` layer inside the model)
    x_train = x_train.astype("float32")
    x_test = x_test.astype("float32")
    # Cast labels to int64 so they can be used as class indices by `CrossEntropyLoss`
    y_train = y_train.astype("int64")
    # Make sure images have shape (28, 28, 1)
    x_train = np.expand_dims(x_train, -1)
    x_test = np.expand_dims(x_test, -1)
    print("x_train shape:", x_train.shape)

    # Create a TensorDataset
    dataset = torch.utils.data.TensorDataset(
        torch.from_numpy(x_train), torch.from_numpy(y_train)
    )
    return dataset

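"""
Each element of this dataset is an `(image, label)` pair of tensors. If you want to
sanity-check what `get_dataset()` returns, something like the following works:

```python
dataset = get_dataset()
image, label = dataset[0]
print(image.shape, image.dtype)  # torch.Size([28, 28, 1]) torch.float32
print(label.shape, label.dtype)  # torch.Size([]) torch.int64
```
"""
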
"""
Next, let's define a simple PyTorch training loop that targets
a GPU (note the calls to `.cuda()`).
"""


def train_model(model, dataloader, num_epochs, optimizer, loss_fn):
    for epoch in range(num_epochs):
        running_loss = 0.0
        running_loss_count = 0
        for batch_idx, (inputs, targets) in enumerate(dataloader):
            inputs = inputs.cuda(non_blocking=True)
            targets = targets.cuda(non_blocking=True)

            # Forward pass
            outputs = model(inputs)
            loss = loss_fn(outputs, targets)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            running_loss_count += 1

        # Print loss statistics
        print(
            f"Epoch {epoch + 1}/{num_epochs}, "
            f"Loss: {running_loss / running_loss_count}"
        )

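"""
Before distributing anything, these pieces already work on a single device. As a
minimal sketch (assuming at least one CUDA-capable GPU is available, and using the
same batch size and learning rate as the distributed run below):

```python
model = get_model()
model = model.to(0)  # put the model's weights on GPU 0
dataset = get_dataset()
dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = torch.nn.CrossEntropyLoss()
train_model(model, dataloader, num_epochs=2, optimizer=optimizer, loss_fn=loss_fn)
```
"""
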
"""
## Single-host, multi-device synchronous training

In this setup, you have one machine with several GPUs on it (typically 2 to 16). Each
device will run a copy of your model (called a **replica**). For simplicity, in what
follows, we'll assume we're dealing with 8 GPUs, without loss of generality.

**How it works**

At each step of training:

- The current batch of data (called the **global batch**) is split into 8 different
sub-batches (called **local batches**). For instance, if the global batch has 512
samples, each of the 8 local batches will have 64 samples.
- Each of the 8 replicas independently processes a local batch: they run a forward pass,
then a backward pass, outputting the gradient of the weights with respect to the loss of
the model on the local batch.
- The weight updates originating from local gradients are efficiently merged across the 8
replicas. Because this is done at the end of every step, the replicas always stay in
sync.

In practice, the process of synchronously updating the weights of the model replicas is
handled at the level of each individual weight variable: after the backward pass, the
local gradients of each variable are averaged across the replicas, so every replica
applies the same update and all copies of the weights stay identical.

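Conceptually, this synchronization step is just a gradient average. You never write it
yourself (`DistributedDataParallel` performs an equivalent, much more efficient bucketed
all-reduce during `loss.backward()`), but as a rough sketch of the idea, assuming an
already-initialized process group:

```python
# For illustration only -- DistributedDataParallel does this for you.
world_size = torch.distributed.get_world_size()
for param in model.parameters():
    if param.grad is not None:
        # Sum this parameter's gradient across all replicas, then average it.
        torch.distributed.all_reduce(param.grad, op=torch.distributed.ReduceOp.SUM)
        param.grad /= world_size
```
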
**How to use it**

To do single-host, multi-device synchronous training with a Keras model, you would use
the `torch.nn.parallel.DistributedDataParallel` module wrapper.
Here's how it works:

- We use `torch.multiprocessing.start_processes` to start multiple Python processes, one
per device. Each process will run the `per_device_launch_fn` function.
- The `per_device_launch_fn` function does the following:
    - It uses `torch.distributed.init_process_group` and `torch.cuda.set_device`
    to configure the device to be used for that process.
    - It uses `torch.utils.data.distributed.DistributedSampler`
    and `torch.utils.data.DataLoader` to turn our data into a distributed data loader.
    - It also uses `torch.nn.parallel.DistributedDataParallel` to turn our model into
    a distributed PyTorch module.
    - It then calls the `train_model` function.
- The `train_model` function will then run in each process, with the model using
a separate device in each process.

Here's the flow, where each step is split into its own utility function:
"""

# Config
num_gpu = torch.cuda.device_count()
num_epochs = 2
batch_size = 64
print(f"Running on {num_gpu} GPUs")


def setup_device(current_gpu_index, num_gpus):
    # Device setup
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "56492"
    device = torch.device("cuda:{}".format(current_gpu_index))
    torch.distributed.init_process_group(
        backend="nccl",
        init_method="env://",
        world_size=num_gpus,
        rank=current_gpu_index,
    )
    torch.cuda.set_device(device)


def cleanup():
    torch.distributed.destroy_process_group()


def prepare_dataloader(dataset, current_gpu_index, num_gpus, batch_size):
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset,
        num_replicas=num_gpus,
        rank=current_gpu_index,
        shuffle=False,
    )
    dataloader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=batch_size,
        shuffle=False,
    )
    return dataloader

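"""
The `DistributedSampler` is what keeps the replicas from seeing the same samples: each
rank iterates over its own disjoint subset of dataset indices. A rough illustration with
a hypothetical toy dataset of 8 samples split across 2 ranks:

```python
toy_dataset = torch.utils.data.TensorDataset(torch.arange(8))
for rank in range(2):
    sampler = torch.utils.data.distributed.DistributedSampler(
        toy_dataset, num_replicas=2, rank=rank, shuffle=False
    )
    print(f"rank {rank} gets indices: {list(sampler)}")
# rank 0 gets indices: [0, 2, 4, 6]
# rank 1 gets indices: [1, 3, 5, 7]
```

Note that we pass `shuffle=False` to both the sampler and the `DataLoader` above. If you
enable shuffling in the sampler instead, you would also call `sampler.set_epoch(epoch)`
at the start of each epoch so that the shuffling differs from epoch to epoch.
"""

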
def per_device_launch_fn(current_gpu_index, num_gpu):
    # Setup the process groups
    setup_device(current_gpu_index, num_gpu)

    dataset = get_dataset()
    model = get_model()

    # prepare the dataloader
    dataloader = prepare_dataloader(dataset, current_gpu_index, num_gpu, batch_size)

    # Instantiate the torch optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Instantiate the torch loss function
    loss_fn = torch.nn.CrossEntropyLoss()

    # Put model on device
    model = model.to(current_gpu_index)
    ddp_model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[current_gpu_index], output_device=current_gpu_index
    )

    train_model(ddp_model, dataloader, num_epochs, optimizer, loss_fn)

    cleanup()


"""
Time to start multiple processes:
"""

if __name__ == "__main__":
    # We use the "fork" method rather than "spawn" to support notebooks
    torch.multiprocessing.start_processes(
        per_device_launch_fn,
        args=(num_gpu,),
        nprocs=num_gpu,
        join=True,
        start_method="fork",
    )
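
"""
We use the `"fork"` start method above because it plays well with notebooks. If you run
this as a standalone script instead, `"spawn"` is a reasonable alternative (it avoids
problems when CUDA has already been initialized in the parent process); since `"spawn"`
re-imports the module in each worker, the `if __name__ == "__main__":` guard is what
makes that safe:

```python
torch.multiprocessing.start_processes(
    per_device_launch_fn,
    args=(num_gpu,),
    nprocs=num_gpu,
    join=True,
    start_method="spawn",
)
```
"""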

"""
That's it!
"""