Path: blob/master/notebooks/misc/clip_make_dataset_tpu_jax.ipynb
Kernel: Python 3
Required Installations and Environment
In [ ]:
In [ ]:
Connected to TPU.
In [ ]:
jax version 0.2.13
jax backend tpu
8
8
jax devices:
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
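The environment check above can be reproduced with a few JAX calls. This is a minimal sketch, not the notebook's hidden cell. It assumes that the two "8" lines are jax.device_count() and jax.local_device_count(), which are equal on a single-host TPU.

```python
import jax

# Library version and the platform JAX is using ("tpu" on a Colab TPU runtime).
print("jax version", jax.__version__)
print("jax backend", jax.default_backend())

# Total devices across all hosts, then devices attached to this host.
print(jax.device_count())
print(jax.local_device_count())

print("jax devices:")
print(jax.devices())
```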
Cloning CLIP_JAX and loading the JAX version of the CLIP model.
In [ ]:
/content
In [ ]:
Cloning into 'CLIP_JAX'...
remote: Enumerating objects: 94, done.
remote: Counting objects: 100% (3/3), done.
remote: Compressing objects: 100% (2/2), done.
remote: Total 94 (delta 0), reused 2 (delta 0), pack-reused 91
Unpacking objects: 100% (94/94), done.
In [ ]:
/content/CLIP_JAX
In [ ]:
Collecting ftfy
Downloading https://files.pythonhosted.org/packages/af/da/d215a091986e5f01b80f5145cff6f22e2dc57c6b048aab2e882a07018473/ftfy-6.0.3.tar.gz (64kB)
|████████████████████████████████| 71kB 4.1MB/s
Requirement already satisfied: regex in /usr/local/lib/python3.7/dist-packages (2019.12.20)
Requirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (4.41.1)
Collecting dm-haiku
Downloading https://files.pythonhosted.org/packages/72/08/639371b979cb2c0bf2f67c832a7c3a358ca3d717a8c38563a2e2435c41c9/dm_haiku-0.0.4-py3-none-any.whl (284kB)
|████████████████████████████████| 286kB 7.6MB/s
Requirement already satisfied: wcwidth in /usr/local/lib/python3.7/dist-packages (from ftfy) (0.2.5)
Requirement already satisfied: absl-py>=0.7.1 in /usr/local/lib/python3.7/dist-packages (from dm-haiku) (0.12.0)
Requirement already satisfied: numpy>=1.18.0 in /usr/local/lib/python3.7/dist-packages (from dm-haiku) (1.19.5)
Requirement already satisfied: typing-extensions; python_version < "3.8" in /usr/local/lib/python3.7/dist-packages (from dm-haiku) (3.7.4.3)
Requirement already satisfied: tabulate>=0.8.9 in /usr/local/lib/python3.7/dist-packages (from dm-haiku) (0.8.9)
Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from absl-py>=0.7.1->dm-haiku) (1.15.0)
Building wheels for collected packages: ftfy
Building wheel for ftfy (setup.py) ... done
Created wheel for ftfy: filename=ftfy-6.0.3-cp37-none-any.whl size=41935 sha256=29488162a550e5759efea0c435630be251692c5988dddeba2bbfa847e8b6bdf5
Stored in directory: /root/.cache/pip/wheels/99/2c/e6/109c8a28fef7a443f67ba58df21fe1d0067ac3322e75e6b0b7
Successfully built ftfy
Installing collected packages: ftfy, dm-haiku
Successfully installed dm-haiku-0.0.4 ftfy-6.0.3
In [ ]:
100%|███████████████████████████████████████| 354M/354M [00:10<00:00, 32.7MiB/s]
/usr/local/lib/python3.7/dist-packages/torchvision/transforms/transforms.py:281: UserWarning: Argument interpolation should be of type InterpolationMode instead of int. Please, use InterpolationMode enum.
"Argument interpolation should be of type InterpolationMode instead of int. "
jax devices:
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
Wrapping the encoding function in jax.pmap and replicating the parameters on every device.
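Here is a minimal sketch of this pattern. It uses a hypothetical linear encode function instead of CLIP's image encoder. jax.pmap runs the same compiled function on every local device and maps over the leading axis of each argument. Because of that, the parameters get a leading axis of size jax.local_device_count() (one copy per device), and each input batch is shaped (devices, batch_per_core, ...).

```python
import jax
import jax.numpy as jnp
import numpy as np


def encode(params, images):
    """Hypothetical stand-in encoder: (batch, features) -> (batch, embed_dim)."""
    return images @ params["w"]


n_dev = jax.local_device_count()

# Toy parameters; replicate them by stacking one copy per device along axis 0.
params = {"w": jnp.ones((4, 2), dtype=jnp.float32)}
replicated_params = jax.tree_util.tree_map(lambda x: jnp.stack([x] * n_dev), params)

# pmap maps over the leading (device) axis of both the params and the batch.
p_encode = jax.pmap(encode)

batch = np.ones((n_dev, 3, 4), dtype=np.float32)  # (devices, batch_per_core, features)
emb = p_encode(replicated_params, batch)
print(emb.shape)  # (n_dev, 3, 2)
```

Replicating once, outside the loop, avoids re-sending the parameters to the devices on every step.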
In [ ]:
Dataset
Download the dataset once here, so that later cells load the local copy instead of downloading it again.
Change ds_name to the dataset you need.
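This download-once pattern can be sketched with TensorFlow Datasets. The sketch assumes the default name is the imagenette/160px-v2 config that appears in this notebook's download log. prepare_dataset and DS_NAME are hypothetical names, not the notebook's code. tfds.builder(...).download_and_prepare() writes the prepared data to the local data directory, and later tfds.load calls with the same name reuse that copy.

```python
DS_NAME = "imagenette/160px-v2"  # change this to the dataset you need


def prepare_dataset(ds_name=DS_NAME, data_dir=None):
    """Download and prepare ds_name once; later tfds.load calls reuse the local copy.

    Returns the number of examples per split (split names as in imagenette).
    """
    # Imported lazily so that defining this helper does not require tensorflow-datasets.
    import tensorflow_datasets as tfds

    builder = tfds.builder(ds_name, data_dir=data_dir)
    builder.download_and_prepare()  # skipped if the prepared data already exists
    return {split: builder.info.splits[split].num_examples for split in ("train", "validation")}
```

For imagenette/160px-v2, the split sizes reported in this notebook are 9469 (train) and 3925 (validation).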
In [ ]:
In [ ]:
In [ ]:
Loading the dataset with tfds
In [ ]:
Downloading and preparing dataset imagenette/160px-v2/0.1.0 (download: 94.36 MiB, generated: 102.10 MiB, total: 196.46 MiB) to /root/tensorflow_datasets/imagenette/160px-v2/0.1.0...
Shuffling and writing examples to /root/tensorflow_datasets/imagenette/160px-v2/0.1.0.incomplete0PLOWR/imagenette-train.tfrecord
Shuffling and writing examples to /root/tensorflow_datasets/imagenette/160px-v2/0.1.0.incomplete0PLOWR/imagenette-validation.tfrecord
Dataset imagenette downloaded and prepared to /root/tensorflow_datasets/imagenette/160px-v2/0.1.0. Subsequent calls will reuse this data.
Model
In [ ]:
8
A data module that builds NumPy dataloaders for the dataset. Each loader returns batches whose leading dimension is len(devices), so every batch can be split across the cores by pmap.
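The core reshaping step of such a loader looks like this. It is a minimal NumPy sketch using a hypothetical helper, shard_batches. It assumes the data are already in memory as arrays and that the incomplete final batch is dropped, which may differ from how the notebook's data module handles leftover examples.

```python
import numpy as np


def shard_batches(images, labels, n_devices, batch_per_core):
    """Yield (images, labels) batches shaped (n_devices, batch_per_core, ...)."""
    global_batch = n_devices * batch_per_core
    n_steps = len(images) // global_batch  # assumption: drop the incomplete last batch
    for step in range(n_steps):
        sl = slice(step * global_batch, (step + 1) * global_batch)
        # Split the flat global batch into one block of batch_per_core examples per device.
        x = images[sl].reshape(n_devices, batch_per_core, *images.shape[1:])
        y = labels[sl].reshape(n_devices, batch_per_core)
        yield x, y


images = np.zeros((100, 8, 8, 3), dtype=np.float32)
labels = np.arange(100)
batches = list(shard_batches(images, labels, n_devices=8, batch_per_core=4))
print(len(batches), batches[0][0].shape)  # 3 (8, 4, 8, 8, 3)
```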
In [ ]:
In [ ]:
batch_per_core should be chosen so that (n_examples // batch_per_core) % no_of_cores == 0.
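The condition above can be checked directly. In this sketch, valid_batch_per_core is a hypothetical helper that lists every batch_per_core up to a chosen maximum satisfying it. The example uses the train-split size printed in this notebook (9469) and 8 TPU cores.

```python
def valid_batch_per_core(n_examples, n_cores, max_bpc):
    """Return each bpc in [1, max_bpc] with (n_examples // bpc) % n_cores == 0."""
    return [bpc for bpc in range(1, max_bpc + 1) if (n_examples // bpc) % n_cores == 0]


candidates = valid_batch_per_core(9469, 8, 64)
print(candidates)
```

For example, batch_per_core = 59 gives 9469 // 59 = 160 per-core batches, and 160 % 8 == 0.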
In [ ]:
In [ ]:
9469
3925
In [ ]:
In [ ]:
100%|██████████| 20/20 [01:05<00:00, 3.29s/it]
In [ ]:
100%|██████████| 9/9 [00:35<00:00, 3.91s/it]
In [ ]:
In [ ]: