
Running JAX on Cloud TPU VMs from Colab

Authors

  • Gerardo Durán-Martín

  • Mahmoud Soliman

  • Kevin Murphy

Define some global variables

We create a commands.sh file that defines some macros. Edit the values in this file to match your credentials.

This file must be sourced in every cell below that begins with %%bash

%%writefile commands.sh
gcloud="/root/google-cloud-sdk/bin/gcloud"
gtpu="gcloud alpha compute tpus tpu-vm"
jax_install="pip install 'jax[tpu]>=0.2.16' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"

# edit lines below
#instance_name="murphyk-v3-8"
#tpu_zone="us-central1-a"
#accelerator_type="v3-8"
instance_name="murphyk-tpu"
tpu_zone="us-east1-d"
accelerator_type="v3-32"

Setup GCP

First, we authenticate GCP in our current session

from google.colab import auth
auth.authenticate_user()

Next, we install the Google Cloud SDK

%%capture
!curl -S https://sdk.cloud.google.com | bash

Now we install the gcloud alpha components, which allow us to work with TPU VMs on Google Cloud. Run the following command

%%bash
source /content/commands.sh
$gcloud components install alpha
All components are up to date.

Next, we set the project to probml

%%bash
source /content/commands.sh
$gcloud config set project probml
Updated property [core/project].
Verify installation

Finally, we verify that gcloud alpha was installed successfully by running the following command. Make sure you have version alpha 2021.06.25 or later.

%%bash
source /content/commands.sh
$gcloud -v
Google Cloud SDK 358.0.0
alpha 2021.09.17
bq 2.0.71
core 2021.09.17
gsutil 4.68

Setup TPUs

Creating an instance

Each GSoC member obtains one v3-32 slice (4 hosts with 8 TPU v3 cores each, 32 cores in total) when following the instructions outlined below.

To create our first TPU instance, we run the following command. Note that instance_name should be unique (it was defined at the top of this tutorial)

%%bash
source /content/commands.sh
$gtpu create $instance_name \
    --accelerator-type $accelerator_type \
    --version v2-alpha \
    --zone $tpu_zone
ERROR: (gcloud.alpha.compute.tpus.tpu-vm.create) INVALID_ARGUMENT: Cloud TPU received a bad request. the accelerator v3-8 was not found in zone us-east1-d [EID: 0x72c898a0fe1c2eef]

You can verify whether your instance has been created by running the following cell

%%bash
source /content/commands.sh
$gcloud alpha compute tpus list --zone $tpu_zone
NAME         ZONE        ACCELERATOR_TYPE  NETWORK  RANGE          STATUS  API_VERSION
murphyk-tpu  us-east1-d  v3-32             default  10.142.0.0/20  READY   V2_ALPHA1
mjsml-tpu    us-east1-d  v3-32             default  10.142.0.0/20  READY   V2_ALPHA1
mjsml-tpu2   us-east1-d  v3-128            default  10.142.0.0/20  READY   V2_ALPHA1

Deleting an instance

To avoid extra costs, it is important to delete the instance after use (training, testing, experimenting, etc.).

To delete an instance, we create and run a cell with the following content

%%bash
source /content/commands.sh
$gtpu delete --quiet $instance_name --zone=$tpu_zone

Make sure to delete your instance once you finish!!
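
To double-check that nothing is left running, you can list the TPUs in your zone again. The following cell is a small sketch reusing the list command from above; once deletion completes, the instance should no longer appear in the output.

%%bash
source /content/commands.sh
# After deletion, $instance_name should no longer show up in this list
$gcloud alpha compute tpus list --zone $tpu_zone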

Setup JAX

When connecting to an instance directly via SSH, note that any JAX command will block waiting for the other hosts to become active. To avoid this, we have to run the desired code simultaneously on all of the hosts: to run JAX code on a TPU Pod slice, you must run it on every host in the slice.

In the next cell, we install JAX on each host of our slice.

%%bash
source /content/commands.sh
$gtpu ssh $instance_name \
    --zone $tpu_zone \
    --command "$jax_install" \
    --worker all  # or a single worker index, e.g. 0..3
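
As a quick sanity check of the point above, the next cell is a sketch that runs a one-line JAX probe on every host at once (the inline python3 -c command is only an illustrative probe, not part of the original tutorial).

%%bash
source /content/commands.sh
# Run the same one-line JAX probe on every host of the slice.
# (Targeting a single worker instead, e.g. --worker 0, would typically block:
# as noted above, JAX waits for all hosts of the Pod slice to join.)
$gtpu ssh $instance_name --zone $tpu_zone \
    --command "python3 -c 'import jax; print(jax.device_count())'" \
    --worker all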

JAX examples

Example 1: Hello, TPUs!

In this example, we create a hello_tpu.sh script that checks whether we can connect to all of the hosts. First, we create the .sh file that will be run on each of the workers.

%%writefile hello_tpu.sh
#!/bin/bash
# file: hello_tpu.sh
export gist_url="https://gist.github.com/1e8d226e7a744d22d010ca4980456c3a.git"
git clone $gist_url hello_gsoc
python3 hello_gsoc/hello_tpu.py
Writing hello_tpu.sh

The content of $gist_url is the following

You do not need to store the following file. Our script hello_tpu.sh will download the file to each of the hosts and run it.

# Taken from https://cloud.google.com/tpu/docs/jax-pods
# To be used by the Pyprobml GSoC 2021 team
# The following code snippet will be run on all TPU hosts
import jax

# The total number of TPU cores in the pod
device_count = jax.device_count()
# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()

# The psum is performed over all mapped devices across the pod
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)

# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
    print('global device count:', jax.device_count())
    print('local device count:', jax.local_device_count())
    print('pmap result:', r)

Next, we run the code across all workers

%%bash
source /content/commands.sh
$gtpu ssh $instance_name \
    --zone $tpu_zone \
    --command "$(<./hello_tpu.sh)" \
    --worker all
global device count: 32
local device count: 8
pmap result: [32. 32. 32. 32. 32. 32. 32. 32.]
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
Cloning into 'hello_gsoc'...
Cloning into 'hello_gsoc'...
Cloning into 'hello_gsoc'...
Cloning into 'hello_gsoc'...

Example 2: 🚧K-nearest neighbours🚧

In this example we classify the MNIST dataset with a K-nearest-neighbours (KNN) algorithm parallelised via pmap. Our program clones a GitHub gist onto each of the hosts. We use the multi-device availability of our slice to delegate part of the work to each of the workers.

First, we create the script that will be run on each of the workers

%%writefile knn_tpu.sh
#!/bin/bash
# file: knn_tpu.sh
export gist_url="https://gist.github.com/716a7bfd4c5c0c0e1949072f7b2e03a6.git"
pip3 install -q tensorflow_datasets
git clone $gist_url demo
python3 demo/knn_tpu.py
Writing knn_tpu.sh

Next, we run the script

%%bash
source /content/commands.sh
$gtpu ssh $instance_name \
    --zone $tpu_zone \
    --command "$(<./knn_tpu.sh)" \
    --worker all
(8, 10, 20) class_rate=0.9125
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
WARNING: Running pip as root will break packages and permissions. You should install packages reliably by using venv: https://pip.pypa.io/warnings/venv (repeated on each worker)
WARNING: You are using pip version 21.1.2; however, version 21.2.2 is available. You should consider upgrading via the '/usr/bin/python3 -m pip install --upgrade pip' command. (repeated on each worker)
fatal: destination path 'demo' already exists and is not an empty directory. (repeated on each worker)
WARNING:tensorflow:From /usr/local/lib/python3.8/dist-packages/tensorflow_datasets/core/dataset_builder.py:622: get_single_element (from tensorflow.python.data.experimental.ops.get_single_element) is deprecated and will be removed in a future version. Instructions for updating: Use `tf.data.Dataset.get_single_element()`. (repeated on each worker)
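
For reference, the gist roughly follows the pattern sketched below. This is a hypothetical sketch of a pmap-based KNN classifier, not the actual knn_tpu.py: the random data, variable names, and value of k are placeholders. The idea is that each host shards its test points across its 8 local devices and looks up the k nearest training points per device.

# Hypothetical sketch of a pmap-based KNN classifier; the real knn_tpu.py may differ.
import jax
import jax.numpy as jnp
import numpy as np

def knn_predict(test_chunk, train_x, train_y, k=10):
    # Euclidean distances between this device's test chunk and all training points
    dists = jnp.linalg.norm(test_chunk[:, None, :] - train_x[None, :, :], axis=-1)
    # Indices of the k nearest neighbours of each test point
    _, knn_idx = jax.lax.top_k(-dists, k)
    return train_y[knn_idx]  # neighbour labels, shape (chunk_size, k)

# Random data standing in for flattened 28x28 MNIST images (placeholder only)
rng = np.random.default_rng(0)
train_x = jnp.asarray(rng.normal(size=(1000, 784)), dtype=jnp.float32)
train_y = jnp.asarray(rng.integers(0, 10, size=1000))
test_x = jnp.asarray(rng.normal(size=(jax.local_device_count() * 20, 784)), dtype=jnp.float32)

# Shard the test points across the local devices of this host and run KNN in parallel
test_shards = test_x.reshape(jax.local_device_count(), -1, 784)
neighbours = jax.pmap(knn_predict, in_axes=(0, None, None))(test_shards, train_x, train_y)
print(neighbours.shape)  # (local devices, test points per device, k)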

🔪TPUs - The Sharp Bits🔪

Service accounts

Before creating a new TPU instance, make sure that the project admin grants the correct IAM user/group roles for your service account:

  • TPU Admin

  • Service Account User

This prevents you from running into the following error

[Screenshot of the resulting permissions error]
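
For reference, a project admin could grant these roles with commands along the following lines. This is a hedged sketch: PROJECT_ID and the member e-mail are placeholders, and roles/tpu.admin and roles/iam.serviceAccountUser are the role IDs corresponding to TPU Admin and Service Account User.

# Hypothetical sketch: grant the two roles to a user
# (PROJECT_ID and the e-mail address are placeholders)
gcloud projects add-iam-policy-binding PROJECT_ID \
    --member="user:someone@example.com" \
    --role="roles/tpu.admin"
gcloud projects add-iam-policy-binding PROJECT_ID \
    --member="user:someone@example.com" \
    --role="roles/iam.serviceAccountUser"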

Running JAX on a Pod

When creating an instance, we obtain several hosts (workers), each attached to its own set of TPU cores. A parallel (collective) operation launched on a single host will not make progress until the same program has been launched on all hosts in sync. In JAX, this multi-host parallelism is expressed with the jax.pmap function

pmapping a function

The mapped axis size must be less than or equal to the number of local XLA devices available, as returned by jax.local_device_count() (unless devices is specified, [...])
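
As a minimal illustration of this constraint (a sketch, to be launched on all hosts via --worker all as in the examples above; nothing here is from the original tutorial): the leading axis of the mapped argument must not exceed jax.local_device_count(), which is 8 per host on a v3-32.

import jax
import jax.numpy as jnp

n = jax.local_device_count()  # 8 per host on a v3-32
xs = jnp.arange(n)

# OK: the mapped axis has exactly n entries, one per local device
print(jax.pmap(lambda x: x * 2)(xs))

# Would fail: the mapped axis (n + 1) exceeds the number of local devices
# jax.pmap(lambda x: x * 2)(jnp.arange(n + 1))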

Misc