Multi-GPU distributed training with JAX
Author: fchollet
Date created: 2023/07/11
Last modified: 2023/07/11
Description: Guide to multi-GPU/TPU training for Keras models with JAX.
Introduction
There are generally two ways to distribute computation across multiple devices:
Data parallelism, where a single model gets replicated on multiple devices or multiple machines. Each of them processes different batches of data, then they merge their results. Many variants of this setup exist; they differ in how the different model replicas merge results, whether they stay in sync at every batch, or whether they are more loosely coupled, etc.
Model parallelism, where different parts of a single model run on different devices, processing a single batch of data together. This works best with models that have a naturally-parallel architecture, such as models that feature multiple branches.
This guide focuses on data parallelism, in particular synchronous data parallelism, where the different replicas of the model stay in sync after each batch they process. Synchronicity keeps the model convergence behavior identical to what you would see for single-device training.
Specifically, this guide teaches you how to use jax.sharding APIs to train Keras models, with minimal changes to your code, on multiple GPUs or TPUs (typically 2 to 16) installed on a single machine (single host, multi-device training). This is the most common setup for researchers and small-scale industry workflows.
Setup
Let's start by defining the function that creates the model that we will train, and the function that creates the dataset we will train on (MNIST in this case).
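As a minimal sketch of that setup (the specific convnet architecture and the use of a tf.data pipeline here are illustrative assumptions, not fixed by the guide text), it could look like this:

```python
import os

# The jax.sharding workflow shown in this guide requires the JAX backend.
os.environ["KERAS_BACKEND"] = "jax"

import jax
import numpy as np
import tensorflow as tf  # used only for the tf.data input pipeline
import keras


def get_model():
    # A small convnet for MNIST; any Keras model would work the same way.
    inputs = keras.Input(shape=(28, 28, 1))
    x = keras.layers.Conv2D(32, kernel_size=(3, 3), activation="relu")(inputs)
    x = keras.layers.MaxPooling2D(pool_size=(2, 2))(x)
    x = keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu")(x)
    x = keras.layers.MaxPooling2D(pool_size=(2, 2))(x)
    x = keras.layers.Flatten()(x)
    x = keras.layers.Dropout(0.5)(x)
    outputs = keras.layers.Dense(10)(x)  # logits
    return keras.Model(inputs, outputs)


def get_dataset():
    # Load MNIST, scale images to [0, 1], and one-hot encode the labels.
    (x_train, y_train), _ = keras.datasets.mnist.load_data()
    x_train = x_train.astype("float32") / 255.0
    x_train = np.expand_dims(x_train, -1)  # shape: (60000, 28, 28, 1)
    y_train = keras.utils.to_categorical(y_train, 10)
    return tf.data.Dataset.from_tensor_slices((x_train, y_train))
```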
Single-host, multi-device synchronous training
In this setup, you have one machine with several GPUs or TPUs on it (typically 2 to 16). Each device will run a copy of your model (called a replica). For simplicity, in what follows, we'll assume we're dealing with 8 GPUs, with no loss of generality.
How it works
At each step of training:
The current batch of data (called global batch) is split into 8 different sub-batches (called local batches). For instance, if the global batch has 512 samples, each of the 8 local batches will have 64 samples.
Each of the 8 replicas independently processes a local batch: they run a forward pass, then a backward pass, computing the gradient of the loss with respect to the weights of the model on the local batch.
The weight updates originating from local gradients are efficiently merged across the 8 replicas. Because this is done at the end of every step, the replicas always stay in sync.
In practice, the process of synchronously updating the weights of the model replicas is handled at the level of each individual weight variable. This is done using a jax.sharding.NamedSharding that is configured to replicate the variables.
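For instance, a NamedSharding built from a PartitionSpec with no axes places a full, identical copy of an array on every device in the mesh. A minimal sketch (the mesh shape, axis name, and array are illustrative):

```python
import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Build a 1D mesh over all local devices (e.g. 8 GPUs).
devices = mesh_utils.create_device_mesh((len(jax.local_devices()),))
mesh = Mesh(devices, axis_names=("batch",))

# An empty PartitionSpec means no axis is sharded, i.e. fully replicated.
replicated = NamedSharding(mesh, PartitionSpec())

# Placing a weight under this sharding puts an identical copy on every device.
weights = jax.numpy.zeros((784, 64))
replicated_weights = jax.device_put(weights, replicated)
```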
How to use it
To do single-host, multi-device synchronous training with a Keras model, you would use the jax.sharding features. Here's how it works:
We first create a device mesh using mesh_utils.create_device_mesh.
We use jax.sharding.Mesh, jax.sharding.NamedSharding and jax.sharding.PartitionSpec to define how to partition JAX arrays.
We specify that we want to replicate the model and optimizer variables across all devices by using a spec with no axis.
We specify that we want to shard the data across devices by using a spec that splits along the batch dimension.
We use jax.device_put to replicate the model and optimizer variables across devices. This happens once at the beginning.
In the training loop, for each batch that we process, we use jax.device_put to split the batch across devices before invoking the train step.
Here's the flow, where each step is split into its own utility function:
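Below is a minimal sketch of that flow, building on the get_model() and get_dataset() helpers sketched in the Setup section and on Keras 3's stateless APIs (model.stateless_call, optimizer.stateless_apply). The hyperparameters and the exact helper boundaries are illustrative assumptions rather than a verbatim reproduction of the notebook's code:

```python
import jax
import keras
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Illustrative hyperparameters. The global batch must be divisible by the
# number of devices so it can be split into equal local batches.
num_epochs = 2
batch_size = 512

train_data = get_dataset().batch(batch_size, drop_remainder=True)
model = get_model()
optimizer = keras.optimizers.Adam(1e-3)
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)
optimizer.build(model.trainable_variables)


def compute_loss(trainable_variables, non_trainable_variables, x, y):
    # Pure functional forward pass: no state is mutated in place.
    y_pred, updated_non_trainable = model.stateless_call(
        trainable_variables, non_trainable_variables, x
    )
    return loss_fn(y, y_pred), updated_non_trainable


compute_gradients = jax.value_and_grad(compute_loss, has_aux=True)


@jax.jit
def train_step(train_state, x, y):
    # One synchronous update: gradients are computed on each data shard and
    # the replicated variables are updated consistently on every device.
    trainable_variables, non_trainable_variables, optimizer_variables = train_state
    (loss_value, non_trainable_variables), grads = compute_gradients(
        trainable_variables, non_trainable_variables, x, y
    )
    trainable_variables, optimizer_variables = optimizer.stateless_apply(
        optimizer_variables, grads, trainable_variables
    )
    return loss_value, (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
    )


def get_replicated_train_state(devices):
    # An empty PartitionSpec replicates every variable on every device.
    var_mesh = Mesh(devices, axis_names=("_",))
    var_replication = NamedSharding(var_mesh, PartitionSpec())
    trainable_variables = jax.device_put(model.trainable_variables, var_replication)
    non_trainable_variables = jax.device_put(
        model.non_trainable_variables, var_replication
    )
    optimizer_variables = jax.device_put(optimizer.variables, var_replication)
    return (trainable_variables, non_trainable_variables, optimizer_variables)


num_devices = len(jax.local_devices())
devices = mesh_utils.create_device_mesh((num_devices,))

# Data is sharded along the batch dimension: each device gets a local batch.
data_mesh = Mesh(devices, axis_names=("batch",))
data_sharding = NamedSharding(data_mesh, PartitionSpec("batch"))

# Replicate the model and optimizer variables once, before training starts.
train_state = get_replicated_train_state(devices)

# Custom training loop.
for epoch in range(num_epochs):
    for x, y in train_data:
        sharded_x = jax.device_put(x.numpy(), data_sharding)
        loss_value, train_state = train_step(train_state, sharded_x, y.numpy())
    print("Epoch", epoch, "loss:", loss_value)

# Write the updated values back into the Keras model.
trainable_variables, non_trainable_variables, _ = train_state
for variable, value in zip(model.trainable_variables, trainable_variables):
    variable.assign(value)
for variable, value in zip(model.non_trainable_variables, non_trainable_variables):
    variable.assign(value)
```

Note that the gradient merge across replicas is never written out explicitly: because the variables are replicated and the input data is sharded along the batch axis, jax.jit (via XLA's GSPMD partitioner) inserts the necessary cross-device reductions automatically.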
That's it!