"""
Title: Multi-GPU and distributed training
Author: [fchollet](https://twitter.com/fchollet)
Date created: 2020/04/28
Last modified: 2020/04/29
Description: Guide to multi-GPU & distributed training for Keras models.
Accelerator: GPU
"""

"""
## Introduction

There are generally two ways to distribute computation across multiple devices:

**Data parallelism**, where a single model gets replicated on multiple devices or
multiple machines. Each of them processes different batches of data, then they merge
their results. There are many variants of this setup, which differ in how the different
model replicas merge results, in whether they stay in sync at every batch or whether they
are more loosely coupled, etc.

**Model parallelism**, where different parts of a single model run on different devices,
processing a single batch of data together. This works best with models that have a
naturally-parallel architecture, such as models that feature multiple branches.

This guide focuses on data parallelism, in particular **synchronous data parallelism**,
where the different replicas of the model stay in sync after each batch they process.
Synchronicity keeps the model convergence behavior identical to what you would see for
single-device training.

Specifically, this guide teaches you how to use the `tf.distribute` API to train Keras
models on multiple GPUs, with minimal changes to your code, in the following two setups:

- On multiple GPUs (typically 2 to 8) installed on a single machine (single host,
multi-device training). This is the most common setup for researchers and small-scale
industry workflows.
- On a cluster of many machines, each hosting one or multiple GPUs (multi-worker
distributed training). This is a good setup for large-scale industry workflows, e.g.
training high-resolution image classification models on tens of millions of images using
20-100 GPUs.
"""

"""
## Setup
"""

import tensorflow as tf
import keras

"""
## Single-host, multi-device synchronous training

In this setup, you have one machine with several GPUs on it (typically 2 to 8). Each
device will run a copy of your model (called a **replica**). For simplicity, in what
follows, we'll assume we're dealing with 8 GPUs, at no loss of generality.

**How it works**

At each step of training:

- The current batch of data (called **global batch**) is split into 8 different
sub-batches (called **local batches**). For instance, if the global batch has 512
samples, each of the 8 local batches will have 64 samples.
- Each of the 8 replicas independently processes a local batch: they run a forward pass,
then a backward pass, outputting the gradient of the weights with respect to the loss of
the model on the local batch.
- The weight updates originating from local gradients are efficiently merged across the 8
replicas. Because this is done at the end of every step, the replicas always stay in
sync.

In practice, the process of synchronously updating the weights of the model replicas is
handled at the level of each individual weight variable. This is done through a **mirrored
variable** object.
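
For instance, here's a quick way to see this in action, using the `MirroredStrategy` API
introduced just below (a sketch; the exact printed class name may vary across TensorFlow
versions):

```python
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    v = tf.Variable(1.0)
# Under the scope, the variable is created as a mirrored (per-replica) variable.
print(type(v))
```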

**How to use it**

To do single-host, multi-device synchronous training with a Keras model, you would use
the [`tf.distribute.MirroredStrategy` API](
https://www.tensorflow.org/api_docs/python/tf/distribute/MirroredStrategy).
Here's how it works:

- Instantiate a `MirroredStrategy`, optionally configuring which specific devices you
want to use (by default the strategy will use all GPUs available).
- Use the strategy object to open a scope, and within this scope, create all the Keras
objects you need that contain variables. Typically, that means **creating & compiling the
model** inside the distribution scope.
- Train the model via `fit()` as usual.

Importantly, we recommend that you use `tf.data.Dataset` objects to load data
in a multi-device or distributed workflow.

Schematically, it looks like this:

```python
# Create a MirroredStrategy.
strategy = tf.distribute.MirroredStrategy()
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))

# Open a strategy scope.
with strategy.scope():
    # Everything that creates variables should be under the strategy scope.
    # In general this is only model construction & `compile()`.
    model = Model(...)
    model.compile(...)

# Train the model on all available devices.
model.fit(train_dataset, validation_data=val_dataset, ...)

# Test the model on all available devices.
model.evaluate(test_dataset)
```
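
As mentioned in the list above, you can optionally restrict the strategy to a subset of
the available devices by passing an explicit device list. A minimal sketch; the device
names below are illustrative and depend on your machine:

```python
# Only mirror the model across the first two GPUs (illustrative device names).
strategy = tf.distribute.MirroredStrategy(devices=["/gpu:0", "/gpu:1"])
print("Number of devices: {}".format(strategy.num_replicas_in_sync))
```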

Here's a simple end-to-end runnable example:

"""


def get_compiled_model():
    # Make a simple 2-layer densely-connected neural network.
    inputs = keras.Input(shape=(784,))
    x = keras.layers.Dense(256, activation="relu")(inputs)
    x = keras.layers.Dense(256, activation="relu")(x)
    outputs = keras.layers.Dense(10)(x)
    model = keras.Model(inputs, outputs)
    model.compile(
        optimizer=keras.optimizers.Adam(),
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[keras.metrics.SparseCategoricalAccuracy()],
    )
    return model


def get_dataset():
    batch_size = 32
    num_val_samples = 10000

    # Return the MNIST dataset in the form of a `tf.data.Dataset`.
    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

    # Preprocess the data (these are Numpy arrays)
    x_train = x_train.reshape(-1, 784).astype("float32") / 255
    x_test = x_test.reshape(-1, 784).astype("float32") / 255
    y_train = y_train.astype("float32")
    y_test = y_test.astype("float32")

    # Reserve num_val_samples samples for validation
    x_val = x_train[-num_val_samples:]
    y_val = y_train[-num_val_samples:]
    x_train = x_train[:-num_val_samples]
    y_train = y_train[:-num_val_samples]
    return (
        tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(batch_size),
        tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(batch_size),
        tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(batch_size),
    )


# Create a MirroredStrategy.
strategy = tf.distribute.MirroredStrategy()
print("Number of devices: {}".format(strategy.num_replicas_in_sync))

# Open a strategy scope.
with strategy.scope():
    # Everything that creates variables should be under the strategy scope.
    # In general this is only model construction & `compile()`.
    model = get_compiled_model()

# Train the model on all available devices.
train_dataset, val_dataset, test_dataset = get_dataset()
model.fit(train_dataset, epochs=2, validation_data=val_dataset)

# Test the model on all available devices.
model.evaluate(test_dataset)

"""
## Using callbacks to ensure fault tolerance

When using distributed training, you should always make sure you have a strategy to
recover from failure (fault tolerance). The simplest way to handle this is to pass a
`ModelCheckpoint` callback to `fit()`, to save your model
at regular intervals (e.g. every 100 batches or every epoch). You can then restart
training from your saved model.

Here's a simple example:
"""

import os

# Prepare a directory to store all the checkpoints.
checkpoint_dir = "./ckpt"
if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)


def make_or_restore_model():
    # Either restore the latest model, or create a fresh one
    # if there is no checkpoint available.
    checkpoints = [checkpoint_dir + "/" + name for name in os.listdir(checkpoint_dir)]
    if checkpoints:
        latest_checkpoint = max(checkpoints, key=os.path.getctime)
        print("Restoring from", latest_checkpoint)
        return keras.models.load_model(latest_checkpoint)
    print("Creating a new model")
    return get_compiled_model()


def run_training(epochs=1):
    # Create a MirroredStrategy.
    strategy = tf.distribute.MirroredStrategy()

    # Open a strategy scope and create/restore the model
    with strategy.scope():
        model = make_or_restore_model()

    callbacks = [
        # This callback saves a SavedModel every epoch
        # We include the current epoch in the folder name.
        keras.callbacks.ModelCheckpoint(
            filepath=checkpoint_dir + "/ckpt-{epoch}", save_freq="epoch"
        )
    ]
    model.fit(
        train_dataset,
        epochs=epochs,
        callbacks=callbacks,
        validation_data=val_dataset,
        verbose=2,
    )


# Running the first time creates the model
run_training(epochs=1)

# Calling the same function again will resume from where we left off
run_training(epochs=1)

"""
## `tf.data` performance tips

When doing distributed training, the efficiency with which you load data can often become
critical. Here are a few tips to make sure your `tf.data` pipelines
run as fast as possible.

**Note about dataset batching**

When creating your dataset, make sure it is batched with the global batch size.
For instance, if each of your 8 GPUs is capable of running a batch of 64 samples, you
should use a global batch size of 512.
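
For example, a minimal sketch of deriving the global batch size from the `strategy` object
created earlier (the per-replica batch size of 64 and the name `unbatched_dataset` are
illustrative placeholders):

```python
per_replica_batch_size = 64  # Illustrative; pick what fits in a single GPU's memory.
global_batch_size = per_replica_batch_size * strategy.num_replicas_in_sync

# Batch with the *global* batch size.
dataset = unbatched_dataset.batch(global_batch_size)
```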

**Calling `dataset.cache()`**

If you call `.cache()` on a dataset, its data will be cached after running through the
first iteration over the data. Every subsequent iteration will use the cached data. The
cache can be in memory (default) or in a local file you specify.

This can improve performance when:

- Your data is not expected to change from iteration to iteration
- You are reading data from a remote distributed filesystem
- You are reading data from local disk, but your data would fit in memory and your
workflow is significantly IO-bound (e.g. reading & decoding image files).

**Calling `dataset.prefetch(buffer_size)`**

You should almost always call `.prefetch(buffer_size)` after creating a dataset. It means
your data pipeline will run asynchronously from your model,
with new samples being preprocessed and stored in a buffer while the current batch
samples are used to train the model. The next batch will be prefetched in GPU memory by
the time the current batch is over.
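
Putting these two tips together, a minimal sketch (using `tf.data.AUTOTUNE` to let
TensorFlow pick the buffer size; on older TF versions this lives at
`tf.data.experimental.AUTOTUNE`):

```python
# Cache after the first full iteration, and overlap preprocessing with training.
train_dataset = train_dataset.cache().prefetch(buffer_size=tf.data.AUTOTUNE)
```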
"""

"""
## Multi-worker distributed synchronous training

**How it works**

In this setup, you have multiple machines (called **workers**), each with one or several
GPUs on them. Much like what happens for single-host training,
each available GPU will run one model replica, and the value of the variables of each
replica is kept in sync after each batch.

Importantly, the current implementation assumes that all workers have the same number of
GPUs (homogeneous cluster).

**How to use it**

1. Set up a cluster (we provide pointers below).
2. Set up an appropriate `TF_CONFIG` environment variable on each worker. This tells the
worker what its role is and how to communicate with its peers.
3. On each worker, run your model construction & compilation code within the scope of a
[`MultiWorkerMirroredStrategy` object](
https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/MultiWorkerMirroredStrategy),
similarly to what we did for single-host training.
4. Run evaluation code on a designated evaluator machine.

**Setting up a cluster**

First, set up a cluster (collective of machines). Each machine individually should be
set up so that it is able to run your model (typically, each machine will run the same
Docker image) and is able to access your data source (e.g. GCS).

Cluster management is beyond the scope of this guide.
[Here is a document](
https://cloud.google.com/ai-platform/training/docs/distributed-training-containers)
to help you get started.
You can also take a look at [Kubeflow](https://www.kubeflow.org/).

**Setting up the `TF_CONFIG` environment variable**

While the code running on each worker is almost the same as the code used in the
single-host workflow (except with a different `tf.distribute` strategy object), one
significant difference between the single-host workflow and the multi-worker workflow is
that you need to set a `TF_CONFIG` environment variable on each machine running in your
cluster.

The `TF_CONFIG` environment variable is a JSON string that specifies:

- The cluster configuration, i.e. the list of addresses & ports of the machines that
make up the cluster
- The worker's "task", which is the role that this specific machine has to play within
the cluster.

One example of `TF_CONFIG` is:

```
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'worker': ["localhost:12345", "localhost:23456"]
    },
    'task': {'type': 'worker', 'index': 0}
})
```

In the multi-worker synchronous training setup, valid roles (task types) for the machines
are "worker" and "evaluator".

For example, if you have 8 machines with 4 GPUs each, you could have 7 workers and one
evaluator.

- The workers train the model, each one processing sub-batches of a global batch.
- One of the workers (worker 0) will serve as "chief", a particular kind of worker that
is responsible for saving logs and checkpoints for later reuse (typically to a Cloud
storage location).
- The evaluator runs a continuous loop that loads the latest checkpoint saved by the
chief worker, runs evaluation on it (asynchronously from the other workers) and writes
evaluation logs (e.g. TensorBoard logs).

**Running code on each worker**

You would run training code on each worker (including the chief) and evaluation code on
the evaluator.

The training code is basically the same as what you would use in the single-host setup,
except using `MultiWorkerMirroredStrategy` instead of `MirroredStrategy`.

Each worker would run the same code (minus the difference explained in the note below),
including the same callbacks.

**Note:** Callbacks that save model checkpoints or logs should save to a different
directory for each worker. It is standard practice that all workers should save to local
disk (which is typically temporary), **except worker 0**, which would save TensorBoard
logs and checkpoints to a Cloud storage location for later access & reuse.
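
One simple way to implement this is to parse `TF_CONFIG` and branch on whether the current
worker is the chief. This is only a sketch, not part of the runnable example above, and the
save paths below are illustrative:

```python
import json
import os


def is_chief():
    # In this setup, worker 0 acts as the chief.
    task = json.loads(os.environ["TF_CONFIG"])["task"]
    return task["type"] == "worker" and task["index"] == 0


# The chief saves to durable Cloud storage; other workers save to scratch space.
save_dir = "gs://my-bucket/ckpt" if is_chief() else "/tmp/ckpt"  # Illustrative paths.
```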

The evaluator would simply use `MirroredStrategy` (since it runs on a single machine and
does not need to communicate with other machines) and call `model.evaluate()`. It would
load the latest checkpoint saved by the chief worker from its Cloud storage location, and
would save evaluation logs to the same location as the chief logs.
"""

"""
### Example: code running in a multi-worker setup

On the chief (worker 0):

```python
# Set TF_CONFIG
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'worker': ["localhost:12345", "localhost:23456"]
    },
    'task': {'type': 'worker', 'index': 0}
})


# Open a strategy scope and create/restore the model.
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
with strategy.scope():
    model = make_or_restore_model()

callbacks = [
    # This callback saves a SavedModel every 100 batches
    keras.callbacks.ModelCheckpoint(filepath='path/to/cloud/location/ckpt',
                                    save_freq=100),
    keras.callbacks.TensorBoard('path/to/cloud/location/tb/')
]
model.fit(train_dataset,
          callbacks=callbacks,
          ...)
```

On other workers:

```python
# Set TF_CONFIG
worker_index = 1  # For instance
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'worker': ["localhost:12345", "localhost:23456"]
    },
    'task': {'type': 'worker', 'index': worker_index}
})


# Open a strategy scope and create/restore the model.
# You can restore from the checkpoint saved by the chief.
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
with strategy.scope():
    model = make_or_restore_model()

callbacks = [
    keras.callbacks.ModelCheckpoint(filepath='local/path/ckpt', save_freq=100),
    keras.callbacks.TensorBoard('local/path/tb/')
]
model.fit(train_dataset,
          callbacks=callbacks,
          ...)
```

On the evaluator:

```python
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    model = make_or_restore_model()  # Restore from the checkpoint saved by the chief.

results = model.evaluate(val_dataset)
# Then, log the results on a shared location, write TensorBoard logs, etc
```
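
If you want the evaluator to keep re-evaluating as training progresses, as described
above, you can wrap this in a simple polling loop. A minimal sketch; the polling interval
is arbitrary, and `make_or_restore_model()` is assumed to point at the shared checkpoint
location written by the chief:

```python
import time

strategy = tf.distribute.MirroredStrategy()

while True:
    with strategy.scope():
        # Reload the most recent checkpoint saved by the chief.
        model = make_or_restore_model()
    results = model.evaluate(val_dataset)
    # Log the results to a shared location, write TensorBoard logs, etc.
    time.sleep(300)  # Illustrative polling interval (seconds).
```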

"""

"""
### Further reading

1. [TensorFlow distributed training guide](
https://www.tensorflow.org/guide/distributed_training)
2. [Tutorial on multi-worker training with Keras](
https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras)
3. [MirroredStrategy docs](
https://www.tensorflow.org/api_docs/python/tf/distribute/MirroredStrategy)
4. [MultiWorkerMirroredStrategy docs](
https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/MultiWorkerMirroredStrategy)
5. [Distributed training in tf.keras with Weights & Biases](
https://towardsdatascience.com/distributed-training-in-tf-keras-with-w-b-ccf021f9322e)
"""