Distributed training with Keras 3
Author: Qianli Zhu
Date created: 2023/11/07
Last modified: 2023/11/07
Description: Complete guide to the distribution API for multi-backend Keras.
Introduction
The Keras distribution API is a new interface designed to facilitate distributed deep learning across a variety of backends like JAX, TensorFlow and PyTorch. This powerful API introduces a suite of tools enabling data and model parallelism, allowing for efficient scaling of deep learning models on multiple accelerators and hosts. Whether leveraging the power of GPUs or TPUs, the API provides a streamlined approach to initializing distributed environments, defining device meshes, and orchestrating the layout of tensors across computational resources. Through classes like DataParallel and ModelParallel, it abstracts the complexity involved in parallel computation, making it easier for developers to accelerate their machine learning workflows.
How it works
The Keras distribution API provides a global programming model that allows developers to compose applications that operate on tensors in a global context (as if working with a single device) while automatically managing distribution across many devices. The API leverages the underlying framework (e.g. JAX) to distribute the program and tensors according to the sharding directives through a procedure called single program, multiple data (SPMD) expansion.
By decoupling the application from sharding directives, the API enables running the same application on a single device, multiple devices, or even multiple clients, while preserving its global semantics.
Setup
import os

# Select the backend before importing Keras.
os.environ["KERAS_BACKEND"] = "jax"

import keras
from keras import layers
import jax
import numpy as np
from tensorflow import data as tf_data  # For dataset input.
DeviceMesh and TensorLayout
The keras.distribution.DeviceMesh class in the Keras distribution API represents a cluster of computational devices configured for distributed computation. It aligns with similar concepts in jax.sharding.Mesh and tf.dtensor.Mesh, where it is used to map the physical devices to a logical mesh structure.
The TensorLayout class then specifies how tensors are distributed across the DeviceMesh, detailing the sharding of tensors along specified axes that correspond to the names of the axes in the DeviceMesh.
You can find more detailed concept explainers in the TensorFlow DTensor guide.
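# Retrieve the locally available GPU devices.
devices = jax.devices("gpu")  # Assumes 8 local GPUs.

# Define a 2x4 device mesh with a data-parallel and a model-parallel axis.
mesh = keras.distribution.DeviceMesh(
    shape=(2, 4), axis_names=["data", "model"], devices=devices
)

# A 2D layout: dimension 0 is sharded over "model", dimension 1 over "data".
layout_2d = keras.distribution.TensorLayout(axes=("model", "data"), device_mesh=mesh)

# A 4D layout that shards only the batch dimension, e.g. for image inputs.
replicated_layout_4d = keras.distribution.TensorLayout(
    axes=("data", None, None, None), device_mesh=mesh
)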
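To make the relationship between a layout and a mesh concrete, here is a minimal pure-Python sketch that computes the per-device shard shape implied by a layout. It does not call Keras: `shard_shape` is a hypothetical helper written for this illustration, and it assumes every sharded dimension divides evenly by the size of its mesh axis.

```python
def shard_shape(global_shape, layout_axes, mesh_shape, axis_names):
    """Return the per-device shard shape of a tensor under a layout.

    Hypothetical helper for illustration; not part of the Keras API.
    """
    if len(global_shape) != len(layout_axes):
        raise ValueError("The layout needs one axis name (or None) per tensor dimension.")
    axis_size = dict(zip(axis_names, mesh_shape))
    local = []
    for dim, axis in zip(global_shape, layout_axes):
        if axis is None:
            local.append(dim)  # None: the dimension is not split
        else:
            size = axis_size[axis]
            if dim % size != 0:
                raise ValueError(f"Dimension {dim} is not divisible by mesh axis {axis!r} ({size}).")
            local.append(dim // size)
    return tuple(local)


mesh_shape, axis_names = (2, 4), ("data", "model")

# Same axes as layout_2d: dim 0 is split over "model" (4), dim 1 over "data" (2).
print(shard_shape((8, 16), ("model", "data"), mesh_shape, axis_names))  # (2, 8)

# Same axes as replicated_layout_4d: only the batch dim is split, over "data" (2).
print(shard_shape((16, 28, 28, 1), ("data", None, None, None), mesh_shape, axis_names))  # (8, 28, 28, 1)
```

The tensor shapes `(8, 16)` and `(16, 28, 28, 1)` are made up for the example; only the axis names and the mesh shape come from the code above.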
Distribution
The Distribution class in Keras serves as a foundational abstract class designed for developing custom distribution strategies. It encapsulates the core logic needed to distribute a model's variables, input data, and intermediate computations across a device mesh. As an end user, you won't have to interact with this class directly, but rather with its subclasses, such as DataParallel or ModelParallel.
DataParallel
The DataParallel class in the Keras distribution API is designed for the data parallelism strategy in distributed training, where the model weights are replicated across all devices in the DeviceMesh, and each device processes a portion of the input data.
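The idea behind data parallelism can be sketched without any accelerator. The pure-Python simulation below is an illustration only, not how Keras or JAX implement it: it uses a scalar linear model, eight simulated devices, and equal-sized batch slices. Under those assumptions, averaging the per-device gradients gives the same gradient as processing the whole batch on one device.

```python
import math


def grad_mse(w, xs, ys):
    """Gradient of mean((w * x - y) ** 2) with respect to the scalar weight w."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


w = 0.5
xs = [float(i) for i in range(16)]   # one global batch of 16 examples
ys = [3.0 * x + 1.0 for x in xs]

num_devices = 8
per_device = len(xs) // num_devices  # 2 examples per simulated device

# Each simulated device holds a full copy of `w` and sees only its batch slice.
local_grads = [
    grad_mse(
        w,
        xs[i * per_device:(i + 1) * per_device],
        ys[i * per_device:(i + 1) * per_device],
    )
    for i in range(num_devices)
]

# Averaging the local gradients (an all-reduce on real hardware) recovers
# the gradient of the full batch.
synced_grad = sum(local_grads) / num_devices
global_grad = grad_mse(w, xs, ys)
print(math.isclose(synced_grad, global_grad))  # True
```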
Here is a sample usage of this class.
# Create with a list of devices.
data_parallel = keras.distribution.DataParallel(devices=devices)

# Alternatively, create DataParallel from a 1D DeviceMesh.
mesh_1d = keras.distribution.DeviceMesh(
    shape=(8,), axis_names=["data"], devices=devices
)
data_parallel = keras.distribution.DataParallel(device_mesh=mesh_1d)

# Random data for illustration: 128 samples in batches of 16.
inputs = np.random.normal(size=(128, 28, 28, 1))
labels = np.random.normal(size=(128, 10))
dataset = tf_data.Dataset.from_tensor_slices((inputs, labels)).batch(16)

# Set the global distribution.
keras.distribution.set_distribution(data_parallel)

# Model weights created from here on are replicated to all devices of the
# DeviceMesh, and each batch is split along the batch dimension.
inputs = layers.Input(shape=(28, 28, 1))
y = layers.Flatten()(inputs)
y = layers.Dense(units=200, use_bias=False, activation="relu")(y)
y = layers.Dropout(0.4)(y)
y = layers.Dense(units=10, activation="softmax")(y)
model = keras.Model(inputs=inputs, outputs=y)

model.compile(loss="mse")
model.fit(dataset, epochs=3)
model.evaluate(dataset)
```
Epoch 1/3
8/8 ━━━━━━━━━━━━━━━━━━━━ 8s 30ms/step - loss: 1.0116
Epoch 2/3
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.9237
Epoch 3/3
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.8736
8/8 ━━━━━━━━━━━━━━━━━━━━ 5s 5ms/step - loss: 0.8349
0.842325747013092
```
ModelParallel and LayoutMap
`ModelParallel` is mostly useful when model weights are too large to fit
on a single accelerator. This setting allows you to split your model weights or
activation tensors across all the devices on the `DeviceMesh`, enabling
horizontal scaling for large models.
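As a rough illustration of what splitting a weight means, the pure-Python sketch below simulates a dense kernel whose output columns are divided across four devices. It is a simplified model rather than the Keras or JAX implementation; the matrix sizes and the helper names `matmul` and `split_columns` are invented for the example. Each simulated device multiplies the full input by its own column block, and concatenating the partial outputs reproduces the unsharded result.

```python
def matmul(x, w):
    """Multiply a (batch, in) matrix by an (in, out) matrix given as nested lists."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*w)] for row in x]


def split_columns(w, num_shards):
    """Split an (in, out) matrix into `num_shards` equal column blocks."""
    width = len(w[0]) // num_shards
    return [[row[i * width:(i + 1) * width] for row in w] for i in range(num_shards)]


x = [[1, 2, 3], [4, 5, 6]]                             # (batch=2, in=3)
w = [[r * 8 + c for c in range(8)] for r in range(3)]  # (in=3, out=8)

# With four shards, each simulated device keeps 8 / 4 = 2 kernel columns.
shards = split_columns(w, num_shards=4)

# Every simulated device multiplies the full input by its own column block.
partials = [matmul(x, shard) for shard in shards]

# Concatenating the partial outputs along the feature axis rebuilds the result.
combined = [sum((p[i] for p in partials), []) for i in range(len(x))]
print(combined == matmul(x, w))  # True
```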
Unlike the `DataParallel` model where all weights are fully replicated,
the weights layout under `ModelParallel` usually needs some customization for
best performance. We introduce `LayoutMap` to let you specify the
`TensorLayout` for any weights and intermediate tensors from a global perspective.

`LayoutMap` is a dict-like object that maps a string to `TensorLayout`
instances. It behaves differently from a normal Python dict in that the string
key is treated as a regex when retrieving the value. The class allows you to
define the naming schema of `TensorLayout` and then retrieve the corresponding
`TensorLayout` instance. Typically, the key used to query is the
`variable.path` attribute, which is the identifier of the variable.
As a shortcut, a tuple or list of axis names is also allowed when inserting a
value, and it will be converted to `TensorLayout`.

The `LayoutMap` can also optionally contain a `DeviceMesh` to populate the
`TensorLayout.device_mesh` if it is not set. When retrieving a layout with a
key, if there isn't an exact match, all existing keys in the layout map are
treated as regexes and matched against the input key again. If there are
multiple matches, a `ValueError` is raised. If no matches are found, `None` is
returned.
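The lookup rules described above can be sketched with a small dict subclass. `RegexLayoutLookup` is a hypothetical stand-in written for this illustration, not the Keras class. It stores plain tuples instead of `TensorLayout` objects, and it assumes that regex matching means a search anywhere in the queried path.

```python
import re


class RegexLayoutLookup(dict):
    """Hypothetical stand-in for the lookup rules described above (not Keras code)."""

    def __getitem__(self, key):
        # An exact key match takes priority.
        if key in self:
            return super().__getitem__(key)
        # Otherwise every stored key is treated as a regex and searched in `key`.
        matches = [pattern for pattern in self if re.search(pattern, key)]
        if len(matches) > 1:
            raise ValueError(f"{key!r} matches multiple keys: {matches}")
        if matches:
            return super().__getitem__(matches[0])
        return None  # no match at all


lookup = RegexLayoutLookup()
lookup["dense.*kernel"] = (None, "model")
lookup["dense.*bias"] = ("model",)

print(lookup["dense_1/kernel"])  # (None, 'model')
print(lookup["conv2d/kernel"])   # None

# Adding an overlapping pattern makes the same query ambiguous.
lookup["dense_1.*"] = (None, None)
try:
    lookup["dense_1/kernel"]
except ValueError as err:
    print("ambiguous:", err)
```

One consequence of this design is that overly broad patterns can collide, so keeping keys as specific as the variable paths allow avoids the `ValueError` case.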
```python
mesh_2d = keras.distribution.DeviceMesh(
    shape=(2, 4), axis_names=["data", "model"], devices=devices
)
layout_map = keras.distribution.LayoutMap(mesh_2d)

# Weights whose path matches d1/kernel or d1/bias are sharded over the
# "model" axis (4 devices). All other weights are fully replicated.
layout_map["d1/kernel"] = (None, "model")
layout_map["d1/bias"] = ("model",)

# A layout can also be set for a layer output.
layout_map["d2/output"] = ("data", None)

model_parallel = keras.distribution.ModelParallel(
    layout_map=layout_map, batch_dim_name="data"
)

keras.distribution.set_distribution(model_parallel)

inputs = layers.Input(shape=(28, 28, 1))
y = layers.Flatten()(inputs)
y = layers.Dense(units=200, use_bias=False, activation="relu", name="d1")(y)
y = layers.Dropout(0.4)(y)
y = layers.Dense(units=10, activation="softmax", name="d2")(y)
model = keras.Model(inputs=inputs, outputs=y)

model.compile(loss="mse")
model.fit(dataset, epochs=3)
model.evaluate(dataset)
```
```
/opt/conda/envs/keras-jax/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:761: UserWarning: Some donated buffers were not usable: ShapedArray(float32[784,50]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation.
  warnings.warn("Some donated buffers were not usable:"
8/8 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - loss: 1.0266
Epoch 2/3
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - loss: 0.9181
Epoch 3/3
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - loss: 0.8725
8/8 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - loss: 0.8381
0.8502610325813293
```
It is also easy to change the mesh structure to shift the balance between
data parallelism and model parallelism. You can do this by adjusting the shape
of the mesh; no other code changes are needed.
full_data_parallel_mesh = keras.distribution.DeviceMesh(
    shape=(8, 1), axis_names=["data", "model"], devices=devices
)
more_data_parallel_mesh = keras.distribution.DeviceMesh(
    shape=(4, 2), axis_names=["data", "model"], devices=devices
)
more_model_parallel_mesh = keras.distribution.DeviceMesh(
    shape=(2, 4), axis_names=["data", "model"], devices=devices
)
full_model_parallel_mesh = keras.distribution.DeviceMesh(
    shape=(1, 8), axis_names=["data", "model"], devices=devices
)
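To see what each of these shapes trades off, the sketch below enumerates every 2D mesh shape for eight devices and reports how the work would be divided. It is plain Python arithmetic, not a Keras call. `mesh_shapes_2d` is a hypothetical helper, and the figures assume the global batch of 16 and the 200-unit `d1` kernel from the earlier examples, with the batch split over "data" and the kernel columns split over "model".

```python
def mesh_shapes_2d(num_devices):
    """List every (data, model) mesh shape that uses exactly `num_devices` devices."""
    return [
        (data, num_devices // data)
        for data in range(num_devices, 0, -1)
        if num_devices % data == 0
    ]


global_batch = 16     # matches .batch(16) in the dataset above
kernel_columns = 200  # matches units=200 in the first Dense layer

for data, model in mesh_shapes_2d(8):
    print(
        f"mesh=({data}, {model})  "
        f"batch per data shard={global_batch // data}  "
        f"kernel columns per model shard={kernel_columns // model}"
    )
```

Under these assumptions, the (2, 4) mesh leaves 50 kernel columns per shard, which is consistent with the `float32[784,50]` buffer named in the warning printed by the `ModelParallel` run.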
Further reading
JAX Distributed arrays and automatic parallelization
JAX sharding module
TensorFlow Distributed training with DTensors
TensorFlow DTensor concepts
Using DTensors with tf.keras