Path: blob/master/notebooks/tutorials/tpu_colab_tutorial.ipynb
Running JAX on Cloud TPU VMs from Colab
Authors
Gerardo Durán-Martín
Mahmoud Soliman
Kevin Murphy
Define some global variables
We create a commands.sh file that defines some macros. Edit the values in this file to match your credentials.
This file must be sourced in every cell below that begins with %%bash.
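A minimal sketch of what commands.sh might contain — every value below is a placeholder, and the exact variable names in the original notebook may differ:

```shell
# commands.sh -- placeholder values; edit to match your credentials.
export gcp_project="probml"        # GCP project used in this tutorial
export instance_name="gsoc-tpu-01" # must be unique; placeholder name
export tpu_zone="europe-west4-a"   # a zone with v3 TPU availability (assumption)
export accelerator_type="v3-32"    # the slice type used in this tutorial
export runtime_version="v2-alpha"  # TPU VM runtime version (assumption)
```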
Setup GCP
First we authenticate GCP to our current session
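In Colab this is typically a single call (it only works inside a Colab session, since the google.colab module is Colab-specific):

```python
# Authenticate the current Colab session with your Google account.
# This opens an interactive prompt; it cannot run outside Colab.
from google.colab import auth

auth.authenticate_user()
```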
Next, we install the GCloud SDK.
This installs the gcloud command-line interface, which lets us work with TPUs on Google Cloud. Run the following command:
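One way to install the SDK non-interactively (the install directory here is an assumption):

```shell
# Download and run the official Cloud SDK installer without prompts.
curl -sSL https://sdk.cloud.google.com > /tmp/install_gcloud.sh
bash /tmp/install_gcloud.sh --disable-prompts --install-dir=/root
# Make gcloud visible to subsequent cells.
export PATH=$PATH:/root/google-cloud-sdk/bin
```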
Next, we set the project to probml
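Setting the active project is a single gcloud command:

```shell
# Point gcloud at the probml project for all subsequent commands.
gcloud config set project probml
```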
Verify installation
Finally, we verify that you've successfully installed the gcloud alpha component by running the following command. Make sure you have alpha version 2021.06.25 or later.
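A way to check (and, if needed, install) the alpha component:

```shell
# Print component versions; the alpha component should be 2021.06.25 or later.
gcloud version
# If the alpha component is missing, add it:
gcloud components install alpha --quiet
```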
Setup TPUs
Creating an instance
Each GSoC member obtains a v3-32 slice (32 TPU v3 cores) when following the instructions outlined below.
To create our first TPU instance, we run the following command. Note that instance_name must be unique (it was defined at the top of this tutorial).
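The creation command looks like this — $instance_name and $tpu_zone are assumed to come from commands.sh, and the runtime version string is an assumption:

```shell
# Create a v3-32 TPU VM slice.
gcloud alpha compute tpus tpu-vm create $instance_name \
  --zone=$tpu_zone \
  --accelerator-type=v3-32 \
  --version=v2-alpha
```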
You can verify that your instance has been created by running the following cell.
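For example, by listing the TPU VMs in the zone ($tpu_zone is assumed to come from commands.sh):

```shell
# List TPU VMs in the zone; your instance should appear with state READY.
gcloud alpha compute tpus tpu-vm list --zone=$tpu_zone
```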
Deleting an instance
To avoid extra costs, it is important to delete the instance after use (training, testing, experimenting, etc.).
To delete an instance, we create and run a cell with the following content.
Make sure to delete your instance once you finish!!
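The deletion command mirrors the creation one ($instance_name and $tpu_zone are assumed to come from commands.sh):

```shell
# Delete the TPU VM; --quiet skips the confirmation prompt.
gcloud alpha compute tpus tpu-vm delete $instance_name --zone=$tpu_zone --quiet
```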
Setup JAX
When connecting to an instance directly via ssh, it is important to note that any JAX command will block, waiting for the other hosts to become active. To avoid this, we must run the desired code on all hosts simultaneously: to run JAX code on a TPU Pod slice, you must run it on each host in the slice.
In the next cell, we install JAX on each host of our slice.
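This can be done with a single ssh command dispatched to every worker ($instance_name and $tpu_zone are assumed to come from commands.sh):

```shell
# Run the same pip install on every host of the slice (--worker=all).
gcloud alpha compute tpus tpu-vm ssh $instance_name --zone=$tpu_zone --worker=all \
  --command="pip install 'jax[tpu]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"
```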
JAX examples
Example 1: Hello, TPUs!
In this example, we create a hello_tpu.sh script that checks that we can connect to all of the hosts. First, we create the .sh file that will be run on each of the workers.
The content of $gist_url is the following:
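The gist itself is not reproduced here; a minimal script in its spirit might look like this (names are illustrative):

```python
# hello_tpu.py (sketch): each host reports which devices it can see.
import jax

# process_index identifies the host; device_count aggregates over all hosts.
print(f"Hello from process {jax.process_index()}: "
      f"{jax.local_device_count()} local devices, "
      f"{jax.device_count()} total")
```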
You do not need to store this file yourself; our script hello_tpu.sh will download it to each of the hosts and run it.
Next, we run the code across all workers
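A sketch of what hello_tpu.sh might look like — $gist_url, $instance_name, and $tpu_zone are assumed to be defined in commands.sh:

```shell
#!/bin/bash
# hello_tpu.sh (sketch): fetch the gist and run it on every host of the slice.
gcloud alpha compute tpus tpu-vm ssh $instance_name --zone=$tpu_zone --worker=all \
  --command="wget -q $gist_url -O hello_tpu.py && python3 hello_tpu.py"
```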
Example 2: 🚧K-nearest neighbours🚧
In this example, we classify the MNIST dataset using the K-nearest-neighbours (KNN) algorithm, parallelised with pmap. Our program clones a GitHub gist onto each of the hosts. We use the multi-device availability of our slice to delegate part of the work to each of the workers.
First, we create the script that will be run on each of the workers
Next, we run the script
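The gist is not reproduced here, but the core pattern can be sketched as follows: shard the test set across devices and let pmap run the same KNN kernel on each shard. All names and sizes below are illustrative, not the tutorial's actual code:

```python
import jax
import jax.numpy as jnp

def knn_predict(train_x, train_y, test_shard, k=3, num_classes=10):
    # Pairwise squared Euclidean distances: (num_test, num_train).
    dists = ((test_shard[:, None, :] - train_x[None, :, :]) ** 2).sum(-1)
    # Indices of the k nearest training points (top_k of negated distances).
    _, idx = jax.lax.top_k(-dists, k)
    # Majority vote: sum the one-hot labels of the neighbours, take the argmax.
    votes = jax.nn.one_hot(train_y[idx], num_classes).sum(axis=1)
    return votes.argmax(-1)

# Toy data standing in for MNIST features.
key = jax.random.PRNGKey(0)
train_x = jax.random.normal(key, (100, 8))
train_y = jax.random.randint(key, (100,), 0, 10)

# One test shard per local device; pmap maps over the leading axis.
n_dev = jax.local_device_count()
test_x = jax.random.normal(key, (n_dev, 4, 8))
preds = jax.pmap(knn_predict, in_axes=(None, None, 0))(train_x, train_y, test_x)
print(preds.shape)  # one block of 4 predictions per device
```

On a v3-32 slice the same pattern runs on each host, with `jax.device_count()` devices participating in total.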
🔪TPUs - The Sharp Bits🔪
Service accounts
Before creating a new TPU instance, make sure that the project Admin grants your service account the correct IAM roles:
TPU Admin
Service Account User
This prevents you from running into the following error
Running JAX on a Pod
When creating an instance, we obtain multiple hosts. A parallel operation started on a single host will not perform any computation until all of the hosts have run the same code in sync. In JAX, this is done using the jax.pmap function.
pmap-ing a function
The mapped axis size must be less than or equal to the number of local XLA devices available, as returned by jax.local_device_count() (unless devices is specified, [...])
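A small illustration of that constraint (the device count refers to whatever machine runs the cell):

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()
# OK: the mapped (leading) axis has exactly n elements, one per device.
out = jax.pmap(lambda x: x * 2)(jnp.arange(n))
print(out)  # each element is doubled on its own device
# jnp.arange(n + 1) would raise an error, because the mapped axis size
# would exceed the number of local devices.
```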