GitHub Repository: keras-team/keras-io
Path: blob/master/guides/distributed_training_with_jax.py
"""
Title: Multi-GPU distributed training with JAX
Author: [fchollet](https://twitter.com/fchollet)
Date created: 2023/07/11
Last modified: 2023/07/11
Description: Guide to multi-GPU/TPU training for Keras models with JAX.
Accelerator: GPU
"""

"""
## Introduction

There are generally two ways to distribute computation across multiple devices:

**Data parallelism**, where a single model gets replicated on multiple devices or
multiple machines. Each of them processes a different batch of data, then they merge
their results. There are many variants of this setup, which differ in how the different
model replicas merge results, in whether they stay in sync at every batch or whether they
are more loosely coupled, etc.

**Model parallelism**, where different parts of a single model run on different devices,
processing a single batch of data together. This works best with models that have a
naturally parallel architecture, such as models that feature multiple branches.

This guide focuses on data parallelism, in particular **synchronous data parallelism**,
where the different replicas of the model stay in sync after each batch they process.
Synchronicity keeps the model convergence behavior identical to what you would see for
single-device training.

Specifically, this guide teaches you how to use the `jax.sharding` APIs to train Keras
models, with minimal changes to your code, on multiple GPUs or TPUs (typically 2 to 16)
installed on a single machine (single host, multi-device training). This is the
most common setup for researchers and small-scale industry workflows.
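
You can quickly check which accelerators JAX can see on the current host. This is a
minimal standalone sketch, not part of the training code that follows:

```python
import jax

print(jax.devices())        # all devices visible to the JAX process
print(jax.local_devices())  # devices attached to this host
```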
"""

"""
## Setup

Let's start by defining the function that creates the model that we will train,
and the function that creates the dataset we will train on (MNIST in this case).
"""

import os

os.environ["KERAS_BACKEND"] = "jax"

import jax
import numpy as np
import tensorflow as tf
import keras

from jax.experimental import mesh_utils
from jax.sharding import Mesh
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P


def get_model():
    # Make a simple convnet with batch normalization and dropout.
    inputs = keras.Input(shape=(28, 28, 1))
    x = keras.layers.Rescaling(1.0 / 255.0)(inputs)
    x = keras.layers.Conv2D(filters=12, kernel_size=3, padding="same", use_bias=False)(
        x
    )
    x = keras.layers.BatchNormalization(scale=False, center=True)(x)
    x = keras.layers.ReLU()(x)
    x = keras.layers.Conv2D(
        filters=24,
        kernel_size=6,
        use_bias=False,
        strides=2,
    )(x)
    x = keras.layers.BatchNormalization(scale=False, center=True)(x)
    x = keras.layers.ReLU()(x)
    x = keras.layers.Conv2D(
        filters=32,
        kernel_size=6,
        padding="same",
        strides=2,
        name="large_k",
    )(x)
    x = keras.layers.BatchNormalization(scale=False, center=True)(x)
    x = keras.layers.ReLU()(x)
    x = keras.layers.GlobalAveragePooling2D()(x)
    x = keras.layers.Dense(256, activation="relu")(x)
    x = keras.layers.Dropout(0.5)(x)
    outputs = keras.layers.Dense(10)(x)
    model = keras.Model(inputs, outputs)
    return model


def get_datasets():
    # Load the data and split it between train and test sets
    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

    # Cast images to float32; scaling to the [0, 1] range is handled by the
    # Rescaling layer inside the model.
    x_train = x_train.astype("float32")
    x_test = x_test.astype("float32")
    # Make sure images have shape (28, 28, 1)
    x_train = np.expand_dims(x_train, -1)
    x_test = np.expand_dims(x_test, -1)
    print("x_train shape:", x_train.shape)
    print(x_train.shape[0], "train samples")
    print(x_test.shape[0], "test samples")

    # Create TF Datasets
    train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    eval_data = tf.data.Dataset.from_tensor_slices((x_test, y_test))
    return train_data, eval_data


"""
## Single-host, multi-device synchronous training

In this setup, you have one machine with several GPUs or TPUs on it (typically 2 to 16).
Each device will run a copy of your model (called a **replica**). For simplicity, in
what follows, we'll assume we're dealing with 8 GPUs, without loss of generality.

**How it works**

At each step of training:

- The current batch of data (called the **global batch**) is split into 8 different
sub-batches (called **local batches**). For instance, if the global batch has 512
samples, each of the 8 local batches will have 64 samples.
- Each of the 8 replicas independently processes a local batch: they run a forward pass,
then a backward pass, outputting the gradient of the weights with respect to the loss of
the model on the local batch.
- The weight updates originating from the local gradients are efficiently merged across
the 8 replicas. Because this is done at the end of every step, the replicas always stay
in sync.

In practice, the process of synchronously updating the weights of the model replicas is
handled at the level of each individual weight variable. This is done using
a `jax.sharding.NamedSharding` that is configured to replicate the variables.
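
For intuition, here is a minimal, self-contained sketch (separate from the model we
train below) of how these shardings behave: arrays placed with an empty
`PartitionSpec` are copied to every device, arrays placed with `P("batch")` are split
along their first axis, and a gradient computed from them inside a jitted function
comes back replicated, because the cross-device reduction is inserted automatically.

```python
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

num_devices = len(jax.local_devices())
mesh = Mesh(mesh_utils.create_device_mesh((num_devices,)), axis_names=("batch",))
replicated = NamedSharding(mesh, P())            # no axes named -> full copy on every device
batch_sharded = NamedSharding(mesh, P("batch"))  # split along the first (batch) axis

w = jax.device_put(jnp.ones((4,)), replicated)                     # stand-in for a weight
x = jax.device_put(jnp.ones((num_devices * 2, 4)), batch_sharded)  # stand-in for a batch


@jax.jit
def grad_of_mean_loss(w, x):
    # Each device differentiates the loss on its local shard of `x`; the partial
    # gradients are then summed across devices by the compiler.
    return jax.grad(lambda w: jnp.mean(x @ w))(w)


print(grad_of_mean_loss(w, x).sharding)  # the full gradient, identical on every device
```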

**How to use it**

To do single-host, multi-device synchronous training with a Keras model, you
would use the `jax.sharding` features. Here's how it works:

- We first create a device mesh using `mesh_utils.create_device_mesh`.
- We use `jax.sharding.Mesh`, `jax.sharding.NamedSharding` and
`jax.sharding.PartitionSpec` to define how to partition JAX arrays.
- We specify that we want to replicate the model and optimizer variables
across all devices by using a spec with no axis.
- We specify that we want to shard the data across devices by using a spec
that splits along the batch dimension.
- We use `jax.device_put` to replicate the model and optimizer variables across
devices. This happens once at the beginning.
- In the training loop, for each batch that we process, we use `jax.device_put`
to split the batch across devices before invoking the train step.

Here's the flow, where each step is split into its own utility function:
"""

# Config
num_epochs = 2
batch_size = 64

train_data, eval_data = get_datasets()
train_data = train_data.batch(batch_size, drop_remainder=True)

model = get_model()
optimizer = keras.optimizers.Adam(1e-3)
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Initialize all state with .build()
(one_batch, one_batch_labels) = next(iter(train_data))
model.build(one_batch)
optimizer.build(model.trainable_variables)


# This is the loss function that will be differentiated.
# Keras provides a pure functional forward pass: model.stateless_call
def compute_loss(trainable_variables, non_trainable_variables, x, y):
    y_pred, updated_non_trainable_variables = model.stateless_call(
        trainable_variables, non_trainable_variables, x, training=True
    )
    loss_value = loss(y, y_pred)
    return loss_value, updated_non_trainable_variables


# Function to compute gradients; has_aux=True because compute_loss also returns
# the updated non-trainable variables alongside the loss value.
compute_gradients = jax.value_and_grad(compute_loss, has_aux=True)


# Training step; Keras provides a pure functional optimizer.stateless_apply
@jax.jit
def train_step(train_state, x, y):
    trainable_variables, non_trainable_variables, optimizer_variables = train_state
    (loss_value, non_trainable_variables), grads = compute_gradients(
        trainable_variables, non_trainable_variables, x, y
    )

    trainable_variables, optimizer_variables = optimizer.stateless_apply(
        optimizer_variables, grads, trainable_variables
    )

    return loss_value, (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
    )


# Replicate the model and optimizer variables on all devices
def get_replicated_train_state(devices):
    # All variables will be replicated on all devices
    var_mesh = Mesh(devices, axis_names=("_",))
    # In NamedSharding, axes not mentioned are replicated (all axes here)
    var_replication = NamedSharding(var_mesh, P())

    # Apply the distribution settings to the model variables
    trainable_variables = jax.device_put(model.trainable_variables, var_replication)
    non_trainable_variables = jax.device_put(
        model.non_trainable_variables, var_replication
    )
    optimizer_variables = jax.device_put(optimizer.variables, var_replication)

    # Combine all state in a tuple
    return (trainable_variables, non_trainable_variables, optimizer_variables)


num_devices = len(jax.local_devices())
print(f"Running on {num_devices} devices: {jax.local_devices()}")
devices = mesh_utils.create_device_mesh((num_devices,))

# Data will be split along the batch axis
data_mesh = Mesh(devices, axis_names=("batch",))  # naming axes of the mesh
data_sharding = NamedSharding(
    data_mesh,
    P(
        "batch",
    ),
)  # naming axes of the sharded partition

# Display data sharding
x, y = next(iter(train_data))
sharded_x = jax.device_put(x.numpy(), data_sharding)
print("Data sharding")
jax.debug.visualize_array_sharding(jax.numpy.reshape(sharded_x, [-1, 28 * 28]))

train_state = get_replicated_train_state(devices)

# Custom training loop
for epoch in range(num_epochs):
    data_iter = iter(train_data)
    for data in data_iter:
        x, y = data
        sharded_x = jax.device_put(x.numpy(), data_sharding)
        loss_value, train_state = train_step(train_state, sharded_x, y.numpy())
    print("Epoch", epoch, "loss:", loss_value)

# Post-processing: write the updated model state back into the Keras variables
trainable_variables, non_trainable_variables, optimizer_variables = train_state
for variable, value in zip(model.trainable_variables, trainable_variables):
    variable.assign(value)
for variable, value in zip(model.non_trainable_variables, non_trainable_variables):
    variable.assign(value)
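
# Optional sanity check, as a minimal sketch (not part of the original flow): with the
# updated state assigned back, the model can be used through the regular Keras APIs,
# e.g. to run inference on a batch of test samples.
sample_images, _ = next(iter(eval_data.batch(batch_size)))
preds = model.predict(sample_images.numpy(), verbose=0)
print("Predictions shape:", preds.shape)  # (batch_size, 10) logits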

"""
That's it!
"""