"""
Title: Multi-GPU distributed training with TensorFlow
Author: [fchollet](https://twitter.com/fchollet)
Date created: 2020/04/28
Last modified: 2023/06/29
Description: Guide to multi-GPU training for Keras models with TensorFlow.
Accelerator: GPU
"""

"""
## Introduction

There are generally two ways to distribute computation across multiple devices:

**Data parallelism**, where a single model gets replicated on multiple devices or
multiple machines. Each of them processes different batches of data, then they merge
their results. There exist many variants of this setup, which differ in how the different
model replicas merge results, in whether they stay in sync at every batch or whether they
are more loosely coupled, etc.

**Model parallelism**, where different parts of a single model run on different devices,
processing a single batch of data together. This works best with models that have a
naturally-parallel architecture, such as models that feature multiple branches.

This guide focuses on data parallelism, in particular **synchronous data parallelism**,
where the different replicas of the model stay in sync after each batch they process.
Synchronicity keeps the model convergence behavior identical to what you would see for
single-device training.

Specifically, this guide teaches you how to use the `tf.distribute` API to train Keras
models, with minimal changes to your code, on multiple GPUs (typically 2 to 16) installed
on a single machine (single host, multi-device training). This is the most common setup
for researchers and small-scale industry workflows.
"""

"""
## Setup
"""

import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import tensorflow as tf
import keras

"""
## Single-host, multi-device synchronous training

In this setup, you have one machine with several GPUs on it (typically 2 to 16). Each
device will run a copy of your model (called a **replica**). For simplicity, in what
follows, we'll assume we're dealing with 8 GPUs, at no loss of generality.

**How it works**

At each step of training:

- The current batch of data (called **global batch**) is split into 8 different
sub-batches (called **local batches**). For instance, if the global batch has 512
samples, each of the 8 local batches will have 64 samples.
- Each of the 8 replicas independently processes a local batch: they run a forward pass,
then a backward pass, outputting the gradient of the weights with respect to the loss of
the model on the local batch.
- The weight updates originating from local gradients are efficiently merged across the 8
replicas. Because this is done at the end of every step, the replicas always stay in
sync.

In practice, the process of synchronously updating the weights of the model replicas is
handled at the level of each individual weight variable. This is done through a **mirrored
variable** object.
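
To make this concrete, here is a minimal sketch: any variable created inside the scope of
a `tf.distribute.MirroredStrategy` (introduced just below) becomes a mirrored, distributed
variable rather than a plain `tf.Variable`. This snippet is only illustrative and is not
part of the runnable example further down; on a CPU-only machine the strategy simply falls
back to a single replica.

```python
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    v = tf.Variable(1.0)
# `v` is kept in sync across all replicas; its concrete class is a
# distributed/mirrored variable type, not a plain `tf.Variable`.
print(type(v).__name__)
```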

**How to use it**

To do single-host, multi-device synchronous training with a Keras model, you would use
the [`tf.distribute.MirroredStrategy` API](
https://www.tensorflow.org/api_docs/python/tf/distribute/MirroredStrategy).
Here's how it works:

- Instantiate a `MirroredStrategy`, optionally configuring which specific devices you
want to use (by default the strategy will use all GPUs available).
- Use the strategy object to open a scope, and within this scope, create all the Keras
objects you need that contain variables. Typically, that means **creating & compiling the
model** inside the distribution scope. In some cases, the first call to `fit()` may also
create variables, so it's a good idea to put your `fit()` call in the scope as well.
- Train the model via `fit()` as usual.

Importantly, we recommend that you use `tf.data.Dataset` objects to load data
in a multi-device or distributed workflow.

Schematically, it looks like this:

```python
# Create a MirroredStrategy.
strategy = tf.distribute.MirroredStrategy()
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))

# Open a strategy scope.
with strategy.scope():
    # Everything that creates variables should be under the strategy scope.
    # In general this is only model construction & `compile()`.
    model = Model(...)
    model.compile(...)

    # Train the model on all available devices.
    model.fit(train_dataset, validation_data=val_dataset, ...)

    # Test the model on all available devices.
    model.evaluate(test_dataset)
```
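
If you only want to mirror across a subset of the available GPUs, you can pass the device
list explicitly when constructing the strategy. A minimal sketch (the device names below
assume a machine with at least two GPUs):

```python
# Restrict the strategy to the first two GPUs instead of all visible devices.
strategy = tf.distribute.MirroredStrategy(devices=["/gpu:0", "/gpu:1"])
```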

Here's a simple end-to-end runnable example:
"""


def get_compiled_model():
    # Make a simple 2-layer densely-connected neural network.
    inputs = keras.Input(shape=(784,))
    x = keras.layers.Dense(256, activation="relu")(inputs)
    x = keras.layers.Dense(256, activation="relu")(x)
    outputs = keras.layers.Dense(10)(x)
    model = keras.Model(inputs, outputs)
    model.compile(
        optimizer=keras.optimizers.Adam(),
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[keras.metrics.SparseCategoricalAccuracy()],
    )
    return model


def get_dataset():
    batch_size = 32
    num_val_samples = 10000

    # Return the MNIST dataset in the form of a `tf.data.Dataset`.
    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

    # Preprocess the data (these are Numpy arrays)
    x_train = x_train.reshape(-1, 784).astype("float32") / 255
    x_test = x_test.reshape(-1, 784).astype("float32") / 255
    y_train = y_train.astype("float32")
    y_test = y_test.astype("float32")

    # Reserve num_val_samples samples for validation
    x_val = x_train[-num_val_samples:]
    y_val = y_train[-num_val_samples:]
    x_train = x_train[:-num_val_samples]
    y_train = y_train[:-num_val_samples]
    return (
        tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(batch_size),
        tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(batch_size),
        tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(batch_size),
    )


# Create a MirroredStrategy.
strategy = tf.distribute.MirroredStrategy()
print("Number of devices: {}".format(strategy.num_replicas_in_sync))

# Open a strategy scope.
with strategy.scope():
    # Everything that creates variables should be under the strategy scope.
    # In general this is only model construction & `compile()`.
    model = get_compiled_model()

    # Train the model on all available devices.
    train_dataset, val_dataset, test_dataset = get_dataset()
    model.fit(train_dataset, epochs=2, validation_data=val_dataset)

    # Test the model on all available devices.
    model.evaluate(test_dataset)

"""
## Using callbacks to ensure fault tolerance

When using distributed training, you should always make sure you have a strategy to
recover from failure (fault tolerance). The simplest way to handle this is to pass a
`ModelCheckpoint` callback to `fit()`, to save your model
at regular intervals (e.g. every 100 batches or every epoch). You can then restart
training from your saved model.

Here's a simple example:
"""

# Prepare a directory to store all the checkpoints.
checkpoint_dir = "./ckpt"
if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)


def make_or_restore_model():
    # Either restore the latest model, or create a fresh one
    # if there is no checkpoint available.
    checkpoints = [checkpoint_dir + "/" + name for name in os.listdir(checkpoint_dir)]
    if checkpoints:
        latest_checkpoint = max(checkpoints, key=os.path.getctime)
        print("Restoring from", latest_checkpoint)
        return keras.models.load_model(latest_checkpoint)
    print("Creating a new model")
    return get_compiled_model()


def run_training(epochs=1):
    # Create a MirroredStrategy.
    strategy = tf.distribute.MirroredStrategy()

    # Open a strategy scope and create/restore the model
    with strategy.scope():
        model = make_or_restore_model()

        callbacks = [
            # This callback saves the model every epoch.
            # We include the current epoch in the file name.
            keras.callbacks.ModelCheckpoint(
                filepath=checkpoint_dir + "/ckpt-{epoch}.keras",
                save_freq="epoch",
            )
        ]
        model.fit(
            train_dataset,
            epochs=epochs,
            callbacks=callbacks,
            validation_data=val_dataset,
            verbose=2,
        )


# Running the first time creates the model
run_training(epochs=1)

# Calling the same function again will resume from where we left off
run_training(epochs=1)
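
"""
The prose above also mentions checkpointing "every 100 batches" as an alternative to
checkpointing every epoch. `save_freq` accepts an integer number of batches as well, so a
sub-epoch interval looks like the following sketch (the 100-batch interval is just an
illustrative choice):

```python
keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_dir + "/ckpt-{epoch}.keras",
    save_freq=100,  # save every 100 training batches instead of every epoch
)
```
"""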

"""
## `tf.data` performance tips

When doing distributed training, the efficiency with which you load data can often become
critical. Here are a few tips to make sure your `tf.data` pipelines
run as fast as possible.

**Note about dataset batching**

When creating your dataset, make sure it is batched with the global batch size.
For instance, if each of your 8 GPUs is capable of running a batch of 64 samples, you
should use a global batch size of 512.
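
In other words, batch with the per-replica batch size times the number of replicas. A
minimal sketch (assuming `strategy` is the `MirroredStrategy` created earlier, `dataset`
stands for your not-yet-batched dataset, and 64 is just the per-replica figure from the
example above):

```python
per_replica_batch_size = 64
global_batch_size = per_replica_batch_size * strategy.num_replicas_in_sync
dataset = dataset.batch(global_batch_size)
```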

**Calling `dataset.cache()`**

If you call `.cache()` on a dataset, its data will be cached after running through the
first iteration over the data. Every subsequent iteration will use the cached data. The
cache can be in memory (the default) or in a local file you specify.

This can improve performance when:

- Your data is not expected to change from iteration to iteration
- You are reading data from a remote distributed filesystem
- You are reading data from local disk, but your data would fit in memory and your
workflow is significantly IO-bound (e.g. reading & decoding image files).
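
For example, caching is typically added right after any expensive preprocessing step. In
the sketch below, `preprocess` is a placeholder for your own mapping function:

```python
dataset = dataset.map(preprocess)
dataset = dataset.cache()  # keep the preprocessed data in memory after the first epoch
# Or cache to a local file instead of memory:
# dataset = dataset.cache("/tmp/my_dataset_cache")
```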

**Calling `dataset.prefetch(buffer_size)`**

You should almost always call `.prefetch(buffer_size)` after creating a dataset. It means
your data pipeline will run asynchronously from your model,
with new samples being preprocessed and stored in a buffer while the current batch
samples are used to train the model. The next batch will be prefetched in GPU memory by
the time the current batch is over.
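
In practice, this usually means ending your pipeline like the following sketch (reusing
the `global_batch_size` from the batching note above; `tf.data.AUTOTUNE` lets `tf.data`
pick a suitable buffer size for you):

```python
dataset = dataset.batch(global_batch_size)
dataset = dataset.prefetch(tf.data.AUTOTUNE)
```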
"""

"""
That's it!
"""