Path: blob/master/notebooks/book1/13/multi_gpu_training_torch.ipynb
Please find the JAX implementation of this notebook here: https://colab.research.google.com/github/probml/pyprobml/blob/master/notebooks/book1/13/multi_gpu_training_jax.ipynb
# Train a CNN on multiple GPUs using data parallelism.
Based on sec 12.5 of http://d2l.ai/chapter_computational-performance/multiple-gpus.html.
Note: in Colab we only have access to 1 GPU, so the code below merely simulates the effect of multiple GPUs and will not run faster. You may not see a speedup even on a machine that really does have multiple GPUs, because the model and data are small. Nevertheless, the example should illustrate the key ideas.
Model
We use a slightly modified version of the LeNet CNN.
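Below is a minimal sketch of this kind of LeNet variant, with the parameters kept in an explicit list so they can later be copied to individual devices; the exact layer sizes here follow the d2l.ai reference and may differ slightly from the notebook's code.

```python
import torch
from torch import nn
from torch.nn import functional as F

# Initialize model parameters for a small LeNet-style CNN.
# Parameters are stored in a plain list so we can copy them to each device explicitly.
scale = 0.01
W1 = torch.randn(size=(20, 1, 3, 3)) * scale
b1 = torch.zeros(20)
W2 = torch.randn(size=(50, 20, 5, 5)) * scale
b2 = torch.zeros(50)
W3 = torch.randn(size=(800, 128)) * scale
b3 = torch.zeros(128)
W4 = torch.randn(size=(128, 10)) * scale
b4 = torch.zeros(10)
params = [W1, b1, W2, b2, W3, b3, W4, b4]

def lenet(X, params):
    # Two conv + average-pool stages followed by two fully connected layers.
    h1 = F.relu(F.conv2d(X, params[0], params[1]))
    h1 = F.avg_pool2d(h1, kernel_size=2, stride=2)
    h2 = F.relu(F.conv2d(h1, params[2], params[3]))
    h2 = F.avg_pool2d(h2, kernel_size=2, stride=2)
    h2 = h2.reshape(h2.shape[0], -1)
    h3 = F.relu(h2 @ params[4] + params[5])
    return h3 @ params[6] + params[7]

loss = nn.CrossEntropyLoss(reduction='none')
```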
Copying parameters across devices
All-reduce will copy data (e.g., gradients) from all devices to device 0, add them, and then broadcast the result back to each device.
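A sketch of the two helpers this requires, following the d2l.ai approach (the names `get_params` and `allreduce` come from that reference and may differ from the notebook): one copies the parameter list to a given device, the other performs a naive all-reduce by summing onto the first device and copying the result back out.

```python
def get_params(params, device):
    # Copy each parameter tensor to the target device and enable gradient tracking there.
    new_params = [p.to(device) for p in params]
    for p in new_params:
        p.requires_grad_()
    return new_params

def allreduce(data):
    # Sum all copies into data[0] (on device 0) ...
    for i in range(1, len(data)):
        data[0][:] += data[i].to(data[0].device)
    # ... then broadcast the summed result back to every device.
    for i in range(1, len(data)):
        data[i][:] = data[0].to(data[i].device)
```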
Distribute data across GPUs
Split data and labels.
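This can be done with `torch.nn.parallel.scatter`, which slices a minibatch along the batch dimension and places one shard on each device. A minimal sketch (the helper name `split_batch` follows d2l.ai and is an assumption about the notebook's code):

```python
from torch import nn

def split_batch(X, y, devices):
    """Split a minibatch of inputs X and labels y evenly across the given devices."""
    assert X.shape[0] == y.shape[0]
    return (nn.parallel.scatter(X, devices),
            nn.parallel.scatter(y, devices))
```

For example, with `devices = [torch.device('cuda:0'), torch.device('cuda:1')]`, a batch of 128 examples would be split into two shards of 64, one per GPU.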