GitHub Repository: probml/pyprobml
Path: blob/master/notebooks/misc/GCP_CC_TPU_Pod_Slice_JAX.ipynb
Kernel: Python 3


# Hints from:
# https://medium.com/analytics-vidhya/how-to-access-files-from-google-cloud-storage-in-colab-notebooks-8edaf9e6c020
# https://stackoverflow.com/questions/57772453/login-on-colab-with-gcloud-without-service-account

Authenticate GCP

from google.colab import auth
auth.authenticate_user()

Install GCloud SDK into a new directory

!curl https://sdk.cloud.google.com | bash

# Run the following commands in Colab's terminal

Install GCloud Alpha components

gcloud1="/root/google-cloud-sdk/bin/gcloud"
$gcloud1 components install alpha

Set your GCP Project ID

project_id="YOUR_PROJECT_ID"
$gcloud1 config set project $project_id

Create your TPU VM per the instructions

$gcloud1 alpha compute tpus tpu-vm create *YOUR_TPU_VM_NAME* \
  --zone us-east1-d \
  --accelerator-type v3-32 \
  --version v2-alpha

Install JAX on the pod slice

$gcloud1 alpha compute tpus tpu-vm ssh *YOUR_TPU_VM_NAME* \
  --zone us-east1-d \
  --worker=all \
  --command="pip install 'jax[tpu]>=0.2.16' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"

Run the cell below to write example.py

%%file example.py
# The following code snippet will be run on all TPU hosts
import jax

# The total number of TPU cores in the Pod
device_count = jax.device_count()

# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()

# The psum is performed over all mapped devices across the Pod
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)

# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
    print('global device count:', jax.device_count())
    print('local device count:', jax.local_device_count())
    print('pmap result:', r)
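
To get a feel for what example.py should print, the cross-host psum can be imitated in plain Python without JAX or a TPU. This is only a rough sketch: the function name `simulate_psum` is hypothetical, and the device counts assume that a v3-32 slice exposes 32 cores split evenly over 4 hosts (8 cores per host).

```python
# Plain-Python imitation of the psum in example.py (no JAX or TPU required).
# Assumption: a v3-32 slice has 32 cores spread evenly over 4 hosts (8 per host).
GLOBAL_DEVICES = 32
LOCAL_DEVICES = 8


def simulate_psum(global_devices, local_devices):
    """Return the pmap result one host would see for an all-ones input."""
    num_hosts = global_devices // local_devices
    # Each host contributes an array of ones, one entry per local device.
    per_host_inputs = [[1.0] * local_devices for _ in range(num_hosts)]
    # psum adds the value held by every mapped device across all hosts.
    total = sum(sum(host) for host in per_host_inputs)
    # Every device receives the same total, so a host sees it once per local device.
    return [total] * local_devices


result = simulate_psum(GLOBAL_DEVICES, LOCAL_DEVICES)
print("global device count:", GLOBAL_DEVICES)
print("local device count:", LOCAL_DEVICES)
print("pmap result:", result)
```

Under these assumptions every entry of the result equals the global device count, because each device contributes a single 1 to the sum. On a single-host setup (for example 8 devices in total) the same logic would give 8 in every entry.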

Copy example.py to all the TPU pod slice hosts

$gcloud1 alpha compute tpus tpu-vm scp /content/example.py *YOUR_TPU_VM_NAME*: \
  --worker=all --zone=*YOUR_ZONE*

Run example.py on the TPU-VM pod slice hosts

$gcloud1 alpha compute tpus tpu-vm ssh *YOUR_TPU_VM_NAME* \
  --zone *YOUR_ZONE* --worker=all --command "python3 example.py"
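
If you would rather launch the run from a notebook cell than from Colab's terminal, one possible approach is to assemble the same command as an argument list in Python. The sketch below only builds and prints the command. The helper `tpu_vm_ssh_args` is hypothetical, it reuses only the flags shown above, and the TPU VM name and zone are placeholders to replace with your own.

```python
import shlex

# Assumption: the SDK lives at the install location used earlier in this notebook.
GCLOUD = "/root/google-cloud-sdk/bin/gcloud"


def tpu_vm_ssh_args(name, zone, command, worker="all"):
    """Build the argument list for running `command` on the TPU VM workers."""
    return [
        GCLOUD, "alpha", "compute", "tpus", "tpu-vm", "ssh", name,
        "--zone", zone,
        f"--worker={worker}",
        "--command", command,
    ]


# Placeholder values; substitute your own TPU VM name and zone.
args = tpu_vm_ssh_args("YOUR_TPU_VM_NAME", "us-east1-d", "python3 example.py")

# Print the shell-quoted command for inspection before running anything.
print(shlex.join(args))

# To actually execute it: import subprocess; subprocess.run(args, check=True)
```

Passing a list rather than a single string avoids a layer of shell quoting around the remote command, which may make it easier to swap in longer commands later.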