"""
2
Title: Distributed training with Keras 3
3
Author: [Qianli Zhu](https://github.com/qlzh727)
4
Date created: 2023/11/07
5
Last modified: 2023/11/07
6
Description: Complete guide to the distribution API for multi-backend Keras.
7
Accelerator: GPU
8
"""
9
10
"""
11
## Introduction
12
13
The Keras distribution API is a new interface designed to facilitate
14
distributed deep learning across a variety of backends like JAX, TensorFlow and
15
PyTorch. This powerful API introduces a suite of tools enabling data and model
16
parallelism, allowing for efficient scaling of deep learning models on multiple
17
accelerators and hosts. Whether leveraging the power of GPUs or TPUs, the API
18
provides a streamlined approach to initializing distributed environments,
19
defining device meshes, and orchestrating the layout of tensors across
20
computational resources. Through classes like `DataParallel` and
21
`ModelParallel`, it abstracts the complexity involved in parallel computation,
22
making it easier for developers to accelerate their machine learning
23
workflows.
24
25
"""
26
27
"""
28
## How it works
29
30
The Keras distribution API provides a global programming model that allows
31
developers to compose applications that operate on tensors in a global context
32
(as if working with a single device) while
33
automatically managing distribution across many devices. The API leverages the
34
underlying framework (e.g. JAX) to distribute the program and tensors according to the
35
sharding directives through a procedure called single program, multiple data
36
(SPMD) expansion.
37
38
By decoupling the application from sharding directives, the API enables running
39
the same application on a single device, multiple devices, or even multiple
40
clients, while preserving its global semantics.
41
"""

"""
## Setup
"""

import os

# The distribution API is only implemented for the JAX backend for now.
os.environ["KERAS_BACKEND"] = "jax"

import keras
from keras import layers
import jax
import numpy as np
from tensorflow import data as tf_data  # For dataset input.

"""
## `DeviceMesh` and `TensorLayout`

The `keras.distribution.DeviceMesh` class in the Keras distribution API represents a cluster of
computational devices configured for distributed computation. It aligns with
similar concepts in [`jax.sharding.Mesh`](https://jax.readthedocs.io/en/latest/jax.sharding.html#jax.sharding.Mesh) and
[`tf.dtensor.Mesh`](https://www.tensorflow.org/api_docs/python/tf/experimental/dtensor/Mesh),
where it's used to map the physical devices to a logical mesh structure.

The `TensorLayout` class then specifies how tensors are distributed across the
`DeviceMesh`, detailing the sharding of tensors along specified axes that
correspond to the names of the axes in the `DeviceMesh`.

You can find more detailed concept explainers in the
[TensorFlow DTensor guide](https://www.tensorflow.org/guide/dtensor_overview#dtensors_model_of_distributed_tensors).
"""

# Retrieve the locally available GPU devices.
devices = jax.devices("gpu")  # Assume it has 8 local GPUs.

# Define a 2x4 device mesh with data and model parallel axes
mesh = keras.distribution.DeviceMesh(
    shape=(2, 4), axis_names=["data", "model"], devices=devices
)

# A 2D layout, which describes how a tensor is distributed across the
# mesh. The layout can be visualized as a 2D grid with "model" as rows and
# "data" as columns, and it is a [4, 2] grid when mapped to the physical
# devices on the mesh.
layout_2d = keras.distribution.TensorLayout(axes=("model", "data"), device_mesh=mesh)

# A 4D layout which could be used for data parallelism of an image input.
replicated_layout_4d = keras.distribution.TensorLayout(
    axes=("data", None, None, None), device_mesh=mesh
)
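
"""
To make these objects a bit more tangible, here is a small illustrative sketch
(not part of the original guide). It assumes `DeviceMesh` exposes `shape` and
`axis_names` and `TensorLayout` exposes `axes`, and uses
`keras.distribution.list_devices()` as a backend-agnostic alternative to
`jax.devices()` for discovering accelerators.
"""

# Inspect the mesh and layouts defined above (illustrative only).
print("Mesh shape:", mesh.shape)  # (2, 4)
print("Mesh axis names:", mesh.axis_names)  # ["data", "model"]
print("2D layout axes:", layout_2d.axes)  # ("model", "data")

# Backend-agnostic device discovery via the Keras distribution API.
print("Devices seen by Keras:", keras.distribution.list_devices())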

"""
## Distribution

The `Distribution` class in Keras serves as a foundational abstract class designed
for developing custom distribution strategies. It encapsulates the core logic
needed to distribute a model's variables, input data, and intermediate
computations across a device mesh. As an end user, you won't interact with this
class directly; instead, you will use its subclasses, such as `DataParallel` or
`ModelParallel`.
"""

"""
## DataParallel

The `DataParallel` class in the Keras distribution API is designed for the
data parallelism strategy in distributed training, where the model weights are
replicated across all devices in the `DeviceMesh`, and each device processes a
portion of the input data.

Here is a sample usage of this class.
"""

# Create DataParallel with list of devices.
# As a shortcut, the devices can be skipped,
# and Keras will detect all locally available devices.
# E.g. data_parallel = DataParallel()
data_parallel = keras.distribution.DataParallel(devices=devices)

# Or you can choose to create DataParallel with a 1D `DeviceMesh`.
mesh_1d = keras.distribution.DeviceMesh(
    shape=(8,), axis_names=["data"], devices=devices
)
data_parallel = keras.distribution.DataParallel(device_mesh=mesh_1d)

inputs = np.random.normal(size=(128, 28, 28, 1))
labels = np.random.normal(size=(128, 10))
dataset = tf_data.Dataset.from_tensor_slices((inputs, labels)).batch(16)

# Set the global distribution.
keras.distribution.set_distribution(data_parallel)

# Note that all the model weights from here on are replicated to
# all the devices of the `DeviceMesh`. This includes the RNG
# state, optimizer states, metrics, etc. The dataset fed into `model.fit` or
# `model.evaluate` will be split evenly on the batch dimension, and sent to
# all the devices. You don't have to do any manual aggregation of losses,
# since all the computation happens in a global context.
inputs = layers.Input(shape=(28, 28, 1))
y = layers.Flatten()(inputs)
y = layers.Dense(units=200, use_bias=False, activation="relu")(y)
y = layers.Dropout(0.4)(y)
y = layers.Dense(units=10, activation="softmax")(y)
model = keras.Model(inputs=inputs, outputs=y)

model.compile(loss="mse")
model.fit(dataset, epochs=3)
model.evaluate(dataset)

"""
## `ModelParallel` and `LayoutMap`

`ModelParallel` will be mostly useful when model weights are too large to fit
on a single accelerator. This setting allows you to split your model weights or
activation tensors across all the devices on the `DeviceMesh`, and enables
horizontal scaling for large models.

Unlike the `DataParallel` model where all weights are fully replicated,
the weights layout under `ModelParallel` usually needs some customization for
best performance. We introduce `LayoutMap` to let you specify the
`TensorLayout` for any weights and intermediate tensors from a global perspective.

`LayoutMap` is a dict-like object that maps a string to `TensorLayout`
instances. It behaves differently from a normal Python dict in that the string
key is treated as a regex when retrieving the value. The class allows you to
define the naming schema of `TensorLayout` and then retrieve the corresponding
`TensorLayout` instance. Typically, the key used to query
is the `variable.path` attribute, which is the identifier of the variable.
As a shortcut, a tuple or list of axis
names is also allowed when inserting a value, and it will be converted to
`TensorLayout`.

The `LayoutMap` can also optionally contain a `DeviceMesh` to populate the
`TensorLayout.device_mesh` if it is not set. When retrieving a layout with a
key, if there isn't an exact match, all existing keys in the layout map will
be treated as regexes and matched against the input key again. If there are
multiple matches, a `ValueError` is raised. If no matches are found, `None` is
returned.
"""

mesh_2d = keras.distribution.DeviceMesh(
    shape=(2, 4), axis_names=["data", "model"], devices=devices
)
layout_map = keras.distribution.LayoutMap(mesh_2d)
# The rules below mean that any weight whose path matches `d1/kernel` or
# `d1/bias` will be sharded along the "model" dimension (4 devices).
# All other weights will be fully replicated.
layout_map["d1/kernel"] = (None, "model")
layout_map["d1/bias"] = ("model",)

# You can also set the layout for the layer output like this:
layout_map["d2/output"] = ("data", None)

model_parallel = keras.distribution.ModelParallel(
    layout_map=layout_map, batch_dim_name="data"
)

keras.distribution.set_distribution(model_parallel)

inputs = layers.Input(shape=(28, 28, 1))
y = layers.Flatten()(inputs)
y = layers.Dense(units=200, use_bias=False, activation="relu", name="d1")(y)
y = layers.Dropout(0.4)(y)
y = layers.Dense(units=10, activation="softmax", name="d2")(y)
model = keras.Model(inputs=inputs, outputs=y)

# The data will be sharded across the "data" dimension of the mesh, which
# has 2 devices.
model.compile(loss="mse")
model.fit(dataset, epochs=3)
model.evaluate(dataset)
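
"""
To see which `LayoutMap` rule applies to which weight, you can inspect the
`variable.path` attribute mentioned earlier; it is the string that the
`LayoutMap` keys are matched against. (This is an illustrative sketch, not
part of the original guide.)
"""

# Print the path of every weight: "d1/kernel" and "d1/bias" match the sharding
# rules above, while all other weights fall back to full replication.
for variable in model.weights:
    print(variable.path)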

"""
It is also easy to change the mesh structure to tune the computation between
more data parallelism or more model parallelism. You can do this by adjusting
the shape of the mesh, and no changes are needed for any other code.
"""

full_data_parallel_mesh = keras.distribution.DeviceMesh(
    shape=(8, 1), axis_names=["data", "model"], devices=devices
)
more_data_parallel_mesh = keras.distribution.DeviceMesh(
    shape=(4, 2), axis_names=["data", "model"], devices=devices
)
more_model_parallel_mesh = keras.distribution.DeviceMesh(
    shape=(2, 4), axis_names=["data", "model"], devices=devices
)
full_model_parallel_mesh = keras.distribution.DeviceMesh(
    shape=(1, 8), axis_names=["data", "model"], devices=devices
)
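
"""
For example, switching the earlier `ModelParallel` setup to the more
data-parallel 4x2 mesh only requires rebuilding the `LayoutMap` on the new
mesh; the model-building and training code stays exactly the same. (This is an
illustrative sketch, not part of the original guide.)
"""

# Reuse the same sharding rules on the 4x2 mesh: now 4-way data parallel and
# 2-way model parallel.
layout_map_4x2 = keras.distribution.LayoutMap(more_data_parallel_mesh)
layout_map_4x2["d1/kernel"] = (None, "model")
layout_map_4x2["d1/bias"] = ("model",)
keras.distribution.set_distribution(
    keras.distribution.ModelParallel(
        layout_map=layout_map_4x2, batch_dim_name="data"
    )
)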

"""
### Further reading

1. [JAX Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html)
2. [JAX sharding module](https://jax.readthedocs.io/en/latest/jax.sharding.html)
3. [TensorFlow Distributed training with DTensors](https://www.tensorflow.org/tutorials/distribute/dtensor_ml_tutorial)
4. [TensorFlow DTensor concepts](https://www.tensorflow.org/guide/dtensor_overview)
5. [Using DTensors with tf.keras](https://www.tensorflow.org/tutorials/distribute/dtensor_keras_tutorial)
"""