Path: blob/master/notebooks/book1/13/multi_gpu_training_jax.ipynb
Please find the PyTorch implementation of this notebook here: https://colab.research.google.com/github/probml/pyprobml/blob/master/notebooks/book1/13/multi_gpu_training_torch.ipynb
# Train a CNN on multiple GPUs using data parallelism
Based on sec 12.5 of http://d2l.ai/chapter_computational-performance/multiple-gpus.html.
Since writing JAX requires a completely different mindset from that of PyTorch, translating the notebook word-for-word would inevitably lead to JAX code with a PyTorch "accent". To avoid that, I created an idiomatic JAX/Flax implementation of multi-device training from scratch. It borrows some code from the official Parallel Evaluation in JAX notebook (which trains a linear regression model) and follows roughly the same narrative as the original D2L notebook.
Preparation
Not everyone can enjoy the luxury of a TPU runtime on Colab, and the GPU runtime only provides a single GPU, so we are going to stick with a humble CPU runtime to make the notebook more accessible. When JAX is running on CPU, we can emulate an arbitrary number of devices with an XLA flag. Of course, this is for illustration purposes only, and won't make your code run any faster.
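One way to do this is to set the flag before importing JAX (a minimal sketch; the device count of 8 is arbitrary):

```python
import os

# Ask XLA to expose 8 virtual CPU devices; this must be set before jax is imported.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
print(jax.local_device_count())  # 8
```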
Now we have 8 devices
Model
We use the CNN from the Flax README as a demonstration. Training on multiple devices does not require any changes to the model, so you can replace it with your favourite network, as long as it outputs a 10-dimensional logit vector for each input instance.
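For reference, here is a sketch along the lines of the Flax README CNN; the exact layer sizes are illustrative, what matters is the 10-dimensional output:

```python
import flax.linen as nn

class CNN(nn.Module):
    """A simple CNN that maps NHWC images to 10 logits."""

    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)     # one logit per Fashion-MNIST class
        return x
```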
Data loading
We load the Fashion MNIST dataset with TFDS.
Note that TFDS uses the NHWC format, which means the "channel" dimension is the last axis. This is different from the NCHW format used by other data loaders, e.g. torchvision.datasets.FashionMNIST.
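A minimal loading sketch (the array names train_images/train_labels are ours):

```python
import numpy as np
import tensorflow_datasets as tfds

# Load the whole split into memory as NumPy arrays (images are NHWC uint8).
train_ds = tfds.as_numpy(tfds.load("fashion_mnist", split="train", batch_size=-1))
train_images = train_ds["image"].astype(np.float32) / 255.0  # (60000, 28, 28, 1)
train_labels = train_ds["label"]                             # (60000,)
```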
We split the training set into mini-batches of 256 instances, so that each of the 8 devices will handle 32 instances within each batch. Note that the "total batch size" must be divisible by the number of devices. However, it is okay if the number of training instances cannot be evenly divided by the total batch size (256). While one can handle the leftover with a bit more work, we skip the incomplete batches for simplicity.
We are going to use jax.pmap, which requires us to distribute the data across the devices manually. More specifically, we need to reshape both the images and their labels to the shape [num_batches, num_devices, batch_per_device, ...]. That way, the leading dimension of each "local batch" will be equal to num_devices.
To clarify, we have 8 devices, and each of them will handle 234 local batches of 32 instances per epoch (60,000 // 256 = 234 full batches).
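Concretely, the reshaping could look like the following sketch, reusing the train_images/train_labels arrays from the loading sketch above:

```python
import jax

num_devices = jax.local_device_count()             # 8
batch_size = 256                                   # total batch size
per_device_batch = batch_size // num_devices       # 32

# Drop the incomplete final batch, then add the device axis.
num_batches = train_images.shape[0] // batch_size  # 234
train_images = train_images[: num_batches * batch_size].reshape(
    num_batches, num_devices, per_device_batch, 28, 28, 1)
train_labels = train_labels[: num_batches * batch_size].reshape(
    num_batches, num_devices, per_device_batch)
```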
Parallel training
Parallel training with JAX is as simple as decorating the update function with jax.pmap and synchronising gradients across devices with jax.lax.pmean. See the comments for more details.
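Here is a sketch of what such an update function can look like, assuming a flax.training.train_state.TrainState with an optax optimizer; the axis name "devices" is an arbitrary label:

```python
import functools

import jax
import optax

@functools.partial(jax.pmap, axis_name="devices")
def update(state, images, labels):
    """One training step, executed on every device in parallel."""

    def loss_fn(params):
        logits = state.apply_fn({"params": params}, images)
        one_hot = jax.nn.one_hot(labels, 10)
        # Sum (not average) the per-example losses within the local batch.
        return optax.softmax_cross_entropy(logits=logits, labels=one_hot).sum()

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    # Average loss and gradients across devices; this is the only communication step.
    loss = jax.lax.pmean(loss, axis_name="devices")
    grads = jax.lax.pmean(grads, axis_name="devices")
    return state.apply_gradients(grads=grads), loss
```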
We are summing the gradients over the instances within a batch and averaging them across devices; keep this in mind when choosing the learning rate.
Learning curve
The training loop mostly stays the same. Note that we need to replicate the parameters across all devices when creating the train state, and remove that extra leading dimension when training is done.
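A minimal sketch of the replication step using flax.jax_utils, where state is the single-copy train state created earlier:

```python
from flax import jax_utils

# Give every device its own copy of the parameters and optimizer state:
# each leaf gains a leading axis of size jax.local_device_count().
replicated_state = jax_utils.replicate(state)

# ... run the pmapped update over replicated_state for a few epochs ...

# Strip the device axis again once training is done.
state = jax_utils.unreplicate(replicated_state)
```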
We ignore the first item in losses, which represents the initial loss of the randomly-initialized parameters.
Evaluation
We can evaluate the model on the test set as usual. The performance is not bad, given that we only trained for 5 epochs.
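For instance, the accuracy could be computed with a sketch like this, assuming the test split is loaded into test_images/test_labels the same way as the training data and state has been un-replicated:

```python
import jax
import jax.numpy as jnp

@jax.jit
def accuracy(state, images, labels):
    # Run the un-replicated model on the full test set and compare predictions.
    logits = state.apply_fn({"params": state.params}, images)
    return jnp.mean(jnp.argmax(logits, axis=-1) == labels)

# test_accuracy = accuracy(state, test_images, test_labels)
```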