Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
probml
GitHub Repository: probml/pyprobml
Path: blob/master/notebooks/misc/clip_make_dataset_tpu_jax.ipynb
1192 views
Kernel: Python 3

Open In Colab

Required Installations and Environment

import os

# Fail fast with an actionable message. Indexing os.environ directly would
# raise a bare KeyError (hiding this hint) when no TPU runtime is attached.
if not os.environ.get("COLAB_TPU_ADDR"):
    raise RuntimeError(
        "Make sure to select TPU from Edit > Notebook settings > Hardware accelerator"
    )
import os

# Attach to the Colab TPU only when running inside Colab with a TPU runtime.
if not ("google.colab" in str(get_ipython()) and "COLAB_TPU_ADDR" in os.environ):
    print('No TPU detected. Can be changed under "Runtime/Change runtime type".')
else:
    import jax
    import jax.tools.colab_tpu

    jax.tools.colab_tpu.setup_tpu()
    print("Connected to TPU.")
Connected to TPU.
import jax
import jax.numpy as jnp

# Report the JAX version, the active backend and the device counts.
# jax.default_backend() / jax.device_count() are the public equivalents of the
# internal jax.lib.xla_bridge.get_backend().platform / device_count() calls.
print(f"jax version {jax.__version__}")
print(f"jax backend {jax.default_backend()}")
print(jax.device_count())
print(jax.local_device_count())

devices = jax.local_devices()
print("jax devices:")
devices  # shown as the cell's output in a notebook
jax version 0.2.13 jax backend tpu 8 8 jax devices:
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

Cloning CLIP_JAX

and loading the JAX version of the CLIP model.

# Start from /content so the repository below is cloned into /content/CLIP_JAX.
%cd /content/
/content
# Fetch the CLIP_JAX repository, which provides the `clip_jax` module imported later.
!git clone https://github.com/kingoflolz/CLIP_JAX.git
Cloning into 'CLIP_JAX'... remote: Enumerating objects: 94, done. remote: Counting objects: 100% (3/3), done. remote: Compressing objects: 100% (2/2), done. remote: Total 94 (delta 0), reused 2 (delta 0), pack-reused 91 Unpacking objects: 100% (94/94), done.
cd / content / CLIP_JAX
/content/CLIP_JAX
pip install ftfy regex tqdm dm-haiku
Collecting ftfy Downloading https://files.pythonhosted.org/packages/af/da/d215a091986e5f01b80f5145cff6f22e2dc57c6b048aab2e882a07018473/ftfy-6.0.3.tar.gz (64kB) |████████████████████████████████| 71kB 4.1MB/s Requirement already satisfied: regex in /usr/local/lib/python3.7/dist-packages (2019.12.20) Requirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (4.41.1) Collecting dm-haiku Downloading https://files.pythonhosted.org/packages/72/08/639371b979cb2c0bf2f67c832a7c3a358ca3d717a8c38563a2e2435c41c9/dm_haiku-0.0.4-py3-none-any.whl (284kB) |████████████████████████████████| 286kB 7.6MB/s Requirement already satisfied: wcwidth in /usr/local/lib/python3.7/dist-packages (from ftfy) (0.2.5) Requirement already satisfied: absl-py>=0.7.1 in /usr/local/lib/python3.7/dist-packages (from dm-haiku) (0.12.0) Requirement already satisfied: numpy>=1.18.0 in /usr/local/lib/python3.7/dist-packages (from dm-haiku) (1.19.5) Requirement already satisfied: typing-extensions; python_version < "3.8" in /usr/local/lib/python3.7/dist-packages (from dm-haiku) (3.7.4.3) Requirement already satisfied: tabulate>=0.8.9 in /usr/local/lib/python3.7/dist-packages (from dm-haiku) (0.8.9) Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from absl-py>=0.7.1->dm-haiku) (1.15.0) Building wheels for collected packages: ftfy Building wheel for ftfy (setup.py) ... done Created wheel for ftfy: filename=ftfy-6.0.3-cp37-none-any.whl size=41935 sha256=29488162a550e5759efea0c435630be251692c5988dddeba2bbfa847e8b6bdf5 Stored in directory: /root/.cache/pip/wheels/99/2c/e6/109c8a28fef7a443f67ba58df21fe1d0067ac3322e75e6b0b7 Successfully built ftfy Installing collected packages: ftfy, dm-haiku Successfully installed dm-haiku-0.0.4 ftfy-6.0.3
import numpy as np
from PIL import Image
import time
import clip_jax

# Load the ViT-B/32 CLIP model through the cloned CLIP_JAX package. The call
# yields the image encoder, the text encoder, their parameters and a
# preprocessing function (text_fn and jax_preprocess are not used in the
# visible cells). NOTE(review): "cpu" looks like the device the original
# checkpoint is loaded on -- confirm against clip_jax.load.
image_fn, text_fn, jax_params, jax_preprocess = clip_jax.load("ViT-B/32", "cpu", jit=True)
100%|███████████████████████████████████████| 354M/354M [00:10<00:00, 32.7MiB/s]
jax devices:
/usr/local/lib/python3.7/dist-packages/torchvision/transforms/transforms.py:281: UserWarning: Argument interpolation should be of type InterpolationMode instead of int. Please, use InterpolationMode enum. "Argument interpolation should be of type InterpolationMode instead of int. "
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

`pmap` the image-encoding function and replicate its parameters on every device.

# Data-parallel version of the image encoder: jax.pmap maps over the leading
# axis of its inputs, one slice per device.
image_fn_pmapped = jax.pmap(image_fn)
# One copy of the parameters per device (adds a leading device axis).
jax_params_repl = jax.device_put_replicated(jax_params, devices)

Dataset

Download the dataset here once, so that later cells load the local copy instead of downloading it again.

Change `ds_name` to use a different dataset.

# TFDS dataset identifier and the local directory TFDS reads/writes it under.
ds_name, data_dir = "imagenette/160px-v2", "/root/tensorflow_datasets"
# @title Choose whether you want to make a copy of the dataset in the drive
# @markdown Drive can be mounted to download the tfds into the drive for future uses,
# @markdown downloaded ds can be found in `your_drive_path/MyDrive/$ds_name`
to_load_into_drive = False  # @param ["False", "True"] {type:"raw"}

if to_load_into_drive:
    import os

    from google.colab import drive

    drive.mount("/content/drive")
    data_dir = f"/content/drive/MyDrive/{ds_name}"  # your_drive_path
    # ds_name may contain "/" (e.g. "imagenette/160px-v2"), so parent
    # directories must be created too; a plain `mkdir` fails in that case.
    os.makedirs(data_dir, exist_ok=True)

Loading the dataset with tfds

import tensorflow as tf
import tensorflow_datasets as tfds

# Download and prepare the dataset once so later tfds.load calls reuse it.
# The original retried once after any failure (the recorded output shows the
# download starting twice). Keep that single retry, but catch only ordinary
# exceptions so Ctrl-C still stops the cell, and report the first error
# instead of silently swallowing it.
_max_attempts = 2
for _attempt in range(1, _max_attempts + 1):
    try:
        tfds.load(ds_name, data_dir=data_dir)
        break
    except Exception as err:
        if _attempt == _max_attempts:
            raise
        print(f"tfds.load attempt {_attempt} failed ({err!r}); retrying...")
Downloading and preparing dataset imagenette/160px-v2/0.1.0 (download: 94.36 MiB, generated: 102.10 MiB, total: 196.46 MiB) to /root/tensorflow_datasets/imagenette/160px-v2/0.1.0...
HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Completed...', max=1.0, style=Progre…
HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Size...', max=1.0, style=ProgressSty…
HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Extraction completed...', max=1.0, styl…
Downloading and preparing dataset imagenette/160px-v2/0.1.0 (download: 94.36 MiB, generated: 102.10 MiB, total: 196.46 MiB) to /root/tensorflow_datasets/imagenette/160px-v2/0.1.0...
HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Completed...', max=1.0, style=Progre…
HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Size...', max=1.0, style=ProgressSty…
HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Extraction completed...', max=1.0, styl…
HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))
Shuffling and writing examples to /root/tensorflow_datasets/imagenette/160px-v2/0.1.0.incomplete0PLOWR/imagenette-train.tfrecord
HBox(children=(FloatProgress(value=0.0, max=9469.0), HTML(value='')))
HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))
Shuffling and writing examples to /root/tensorflow_datasets/imagenette/160px-v2/0.1.0.incomplete0PLOWR/imagenette-validation.tfrecord
HBox(children=(FloatProgress(value=0.0, max=3925.0), HTML(value='')))
Dataset imagenette downloaded and prepared to /root/tensorflow_datasets/imagenette/160px-v2/0.1.0. Subsequent calls will reuse this data.

Model

# Number of local devices; full batches from the loaders built below use it as their leading dimension.
len(devices)
8

A data module that builds NumPy data loaders for the dataset; every batch they yield has a leading dimension of `len(devices)` (except possibly the last one).

class Tpu_data_loader:
    """Bundle a NumPy batch iterator with the metadata needed to consume it.

    Attributes:
        loader: iterable of NumPy batches (the output of ``tfds.as_numpy``).
        split: name of the tfds split the batches come from.
        batch_per_core: examples per device in a full group.
        no_of_cores: number of per-device batches in a full group.
        batch_size: examples in a full group (``batch_per_core * no_of_cores``).
    """

    def __init__(self, loader, split, batch_per_core, no_of_cores):
        self.loader = loader
        self.split = split
        self.batch_per_core = batch_per_core
        self.no_of_cores = no_of_cores
        self.batch_size = batch_per_core * no_of_cores


class NumpyDataModule:
    """Build NumPy data loaders from a TFDS dataset, preprocessed for CLIP.

    Each image follows the reference CLIP pipeline: resize the shorter side to
    ``image_size`` (bicubic), center-crop to ``image_size`` x ``image_size``,
    scale pixels to [0, 1], normalize with CLIP's per-channel mean/std, and
    convert HWC -> CHW.
    """

    # Per-channel RGB statistics of OpenAI's CLIP preprocessing ([0, 1] scale).
    CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
    CLIP_STD = (0.26862954, 0.26130258, 0.27577711)

    def __init__(self, ds_name: str, data_dir: str):
        self.ds_name = ds_name
        self.data_dir = data_dir
        self.image_size = 224
        self.mean = list(self.CLIP_MEAN)
        # Was previously a copy of the mean by mistake; CLIP's std differs.
        self.std = list(self.CLIP_STD)
        self.ds = None  # dict of splits, filled in by prepare_data()

    def preprocess(self, sample):
        """Map a TFDS example to a normalized ``float32`` CHW image tensor."""
        image = tf.cast(sample["image"], tf.float32)  # `uint8` -> `float32`
        image = self._resize_shorter_side(image)
        # Center-crop the longer side so the image is image_size x image_size.
        image = tf.image.resize_with_crop_or_pad(image, self.image_size, self.image_size)
        # CLIP's mean/std are defined on [0, 1] pixels, not [0, 255].
        image = image / 255.0
        image = (image - self.mean) / self.std
        return tf.transpose(image, perm=[2, 0, 1])  # HWC -> CHW

    def _resize_shorter_side(self, image):
        """Resize an HWC image so its shorter side equals ``image_size``."""
        hw = tf.cast(tf.shape(image)[:2], tf.float32)
        scale = self.image_size / tf.reduce_min(hw)
        new_hw = tf.cast(tf.round(hw * scale), tf.int32)
        image = tf.image.resize(image, new_hw, method="bicubic")
        # Bicubic interpolation can overshoot the valid pixel range.
        return tf.clip_by_value(image, 0.0, 255.0)

    def make_dataset(self, split, batch_per_core, no_of_cores):
        """Return a ``Tpu_data_loader`` over ``split``.

        Examples are batched by ``batch_per_core`` and those batches grouped by
        ``no_of_cores``, so each yielded array is shaped
        ``(no_of_cores, batch_per_core, C, image_size, image_size)``, except
        that the final group (and its final batch) may be smaller.

        Raises:
            RuntimeError: if ``prepare_data`` has not been called yet.
        """
        if self.ds is None:
            raise RuntimeError("Call prepare_data() before make_dataset().")
        ds = self.ds[split]
        ds = ds.map(self.preprocess, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        ds = ds.batch(batch_per_core).batch(no_of_cores)
        ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
        return Tpu_data_loader(tfds.as_numpy(ds), split, batch_per_core, no_of_cores)

    def prepare_data(self):
        """Load every split of ``ds_name`` into ``self.ds``; return the tfds info."""
        self.ds, ds_info = tfds.load(
            self.ds_name,
            with_info=True,
            data_dir=self.data_dir,
        )
        return ds_info
# Create the data module and load (downloading if needed) every split.
dm = NumpyDataModule(data_dir=data_dir, ds_name=ds_name)
ds_info = dm.prepare_data()

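`batch_per_core` must satisfy `(n_examples // batch_per_core) % no_of_cores == 0`, so that the final, smaller batch (if any) ends up alone in the last group instead of being stacked with full-size batches.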

# Per-core batch sizes chosen so each split meets the divisibility condition
# above for 8 devices (9469 // 62 = 152 and 3925 // 61 = 64, both multiples of 8).
train_loader = dm.make_dataset(
    "train",
    no_of_cores=len(devices),
    batch_per_core=62,
)
test_loader = dm.make_dataset(
    "validation",
    no_of_cores=len(devices),
    batch_per_core=61,
)
# Print how many examples each loader's split contains.
for _loader in (train_loader, test_loader):
    print(ds_info.splits[_loader.split].num_examples)
9469 3925
import tqdm


def clip_extract(tpu_loader):
    """Encode every image yielded by ``tpu_loader`` with the CLIP image encoder.

    A group whose leading dimension equals the number of local devices is
    run through ``image_fn_pmapped`` (one inner batch per device). Any other
    group -- in practice the final, partial one -- is flattened into a single
    batch and encoded with the un-pmapped ``image_fn``.

    Uses the notebook globals ``ds_info``, ``image_fn``, ``jax_params``,
    ``image_fn_pmapped`` and ``jax_params_repl``.

    Args:
        tpu_loader: a ``Tpu_data_loader`` whose ``loader`` yields NumPy arrays
            shaped ``(groups, batch_per_core, ...)``.

    Returns:
        np.ndarray of shape ``(num_images, embedding_dim)``.
    """
    n_replicas = jax.local_device_count()
    num_examples = ds_info.splits[tpu_loader.split].num_examples
    # Ceiling division: the number of (possibly partial) groups yielded.
    steps = -(-num_examples // tpu_loader.batch_size)

    clip_features = []
    for batch in tqdm.tqdm(tpu_loader.loader, total=steps):
        if batch.shape[0] == n_replicas:
            encoded = image_fn_pmapped(jax_params_repl, batch)
        else:
            # Not enough inner batches to feed every device: merge them into
            # one batch and encode it on a single device.
            merged = batch.reshape((-1,) + batch.shape[2:])
            encoded = image_fn(jax_params, merged)
        encoded = jax.device_get(encoded)
        # Collapse (devices, batch_per_core, dim) or (batch, dim) to (n, dim).
        clip_features.append(np.reshape(encoded, (-1, encoded.shape[-1])))
    return np.concatenate(clip_features)
# CLIP image embeddings for the training split.
clip_train = clip_extract(tpu_loader=train_loader)
0%| | 0/20 [00:00<?, ?it/s] 5%|▌ | 1/20 [00:07<02:28, 7.82s/it] 10%|█ | 2/20 [00:10<01:53, 6.33s/it] 15%|█▌ | 3/20 [00:13<01:29, 5.29s/it] 20%|██ | 4/20 [00:16<01:12, 4.54s/it] 25%|██▌ | 5/20 [00:19<01:00, 4.02s/it] 30%|███ | 6/20 [00:21<00:51, 3.66s/it] 35%|███▌ | 7/20 [00:24<00:44, 3.41s/it] 40%|████ | 8/20 [00:27<00:38, 3.23s/it] 45%|████▌ | 9/20 [00:30<00:34, 3.10s/it] 50%|█████ | 10/20 [00:33<00:30, 3.02s/it] 55%|█████▌ | 11/20 [00:36<00:26, 2.95s/it] 60%|██████ | 12/20 [00:38<00:23, 2.92s/it] 65%|██████▌ | 13/20 [00:41<00:20, 2.89s/it] 70%|███████ | 14/20 [00:44<00:17, 2.87s/it] 75%|███████▌ | 15/20 [00:47<00:14, 2.86s/it] 80%|████████ | 16/20 [00:50<00:11, 2.85s/it] 85%|████████▌ | 17/20 [00:52<00:08, 2.84s/it] 90%|█████████ | 18/20 [00:55<00:05, 2.83s/it] 95%|█████████▌| 19/20 [00:58<00:02, 2.81s/it] 100%|██████████| 20/20 [01:05<00:00, 3.29s/it]
# CLIP image embeddings for the validation split.
clip_eval = clip_extract(tpu_loader=test_loader)
0%| | 0/9 [00:00<?, ?it/s] 11%|█ | 1/9 [00:07<01:02, 7.85s/it] 22%|██▏ | 2/9 [00:10<00:44, 6.32s/it] 33%|███▎ | 3/9 [00:13<00:31, 5.25s/it] 44%|████▍ | 4/9 [00:16<00:22, 4.51s/it] 56%|█████▌ | 5/9 [00:18<00:15, 3.98s/it] 67%|██████▋ | 6/9 [00:21<00:10, 3.62s/it] 78%|███████▊ | 7/9 [00:24<00:06, 3.36s/it] 89%|████████▉ | 8/9 [00:27<00:03, 3.17s/it] 100%|██████████| 9/9 [00:35<00:00, 3.91s/it]
def make_tfds_and_save(numpy_data, name, save_dir="/content"):
    """Wrap ``numpy_data`` in a ``tf.data.Dataset`` and save it to disk.

    Args:
        numpy_data: array (or nested structure of arrays) sliced along its
            first axis, one dataset element per slice.
        name: sub-directory name for the saved dataset.
        save_dir: parent directory; defaults to Colab's ``/content`` so
            existing calls keep writing to ``/content/{name}``.

    Returns:
        The ``tf.data.Dataset`` that was saved.
    """
    import os

    tf_ds = tf.data.Dataset.from_tensor_slices(numpy_data)
    tf.data.experimental.save(tf_ds, os.path.join(save_dir, name))
    return tf_ds
# Persist the extracted CLIP features of both splits as tf.data datasets.
clip_train_ds = make_tfds_and_save(numpy_data=clip_train, name="clip_train_ds")
clip_test_ds = make_tfds_and_save(numpy_data=clip_eval, name="clip_test_ds")