Path: blob/master/deprecated/notebooks/mix_PPCA_celeba.ipynb
Kernel: Python 3
Get the CelebA dataset
In [1]:
Get helpers
In [2]:
Get your Kaggle API token (`kaggle.json`) and upload it to Colab. Follow the instructions here.
In [3]:
Out[3]:
Requirement already satisfied: kaggle in /usr/local/lib/python3.7/dist-packages (1.5.12)
Requirement already satisfied: python-dateutil in /usr/local/lib/python3.7/dist-packages (from kaggle) (2.8.1)
Requirement already satisfied: urllib3 in /usr/local/lib/python3.7/dist-packages (from kaggle) (1.24.3)
Requirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (from kaggle) (4.41.1)
Requirement already satisfied: six>=1.10 in /usr/local/lib/python3.7/dist-packages (from kaggle) (1.15.0)
Requirement already satisfied: certifi in /usr/local/lib/python3.7/dist-packages (from kaggle) (2021.5.30)
Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from kaggle) (2.23.0)
Requirement already satisfied: python-slugify in /usr/local/lib/python3.7/dist-packages (from kaggle) (5.0.2)
Requirement already satisfied: text-unidecode>=1.3 in /usr/local/lib/python3.7/dist-packages (from python-slugify->kaggle) (1.3)
Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->kaggle) (2.10)
Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->kaggle) (3.0.4)
In [4]:
Out[4]:
Saving kaggle.json to kaggle.json
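Once `kaggle.json` is uploaded, the `kaggle` command-line tool needs to find it. The following is a minimal sketch of that step, not necessarily what this notebook's cells do: it assumes the token holds the usual `username` and `key` fields and that the CLI reads from `~/.kaggle/kaggle.json`. The helper name `install_kaggle_token` is hypothetical.

```python
import json
import stat
from pathlib import Path


def install_kaggle_token(token_path, config_dir=None):
    """Copy a Kaggle API token into the directory the kaggle CLI reads from."""
    token_path = Path(token_path)
    # Assumption: the CLI looks in ~/.kaggle unless told otherwise.
    config_dir = Path(config_dir) if config_dir is not None else Path.home() / ".kaggle"

    # Fail early if the uploaded file is not a well-formed token.
    creds = json.loads(token_path.read_text())
    missing = {"username", "key"} - set(creds)
    if missing:
        raise ValueError(f"kaggle.json is missing fields: {sorted(missing)}")

    config_dir.mkdir(parents=True, exist_ok=True)
    dest = config_dir / "kaggle.json"
    dest.write_text(json.dumps(creds))
    # Owner read/write only, so other users on the machine cannot read the key.
    dest.chmod(stat.S_IRUSR | stat.S_IWUSR)
    return dest
```

In Colab this would typically be called as `install_kaggle_token("kaggle.json")` right after the upload.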
In [5]:
In [6]:
In [7]:
Get the model checkpoint from the Google Cloud Storage bucket.
In [8]:
In [9]:
In [10]:
In [11]:
Out[11]:
Copying gs://probml_data/mix_PPCA/model_c_300_l_10_init_rnd_samples.pth...
/ [1 files][168.8 MiB/168.8 MiB]
Operation completed over 1 objects/168.8 MiB.
Main
In [12]:
Out[12]:
Collecting pytorch-lightning
Downloading pytorch_lightning-1.4.1-py3-none-any.whl (915 kB)
|████████████████████████████████| 915 kB 7.6 MB/s
Collecting tensorboard!=2.5.0,>=2.2.0
Downloading tensorboard-2.4.1-py3-none-any.whl (10.6 MB)
|████████████████████████████████| 10.6 MB 65.0 MB/s
Collecting PyYAML>=5.1
Downloading PyYAML-5.4.1-cp37-cp37m-manylinux1_x86_64.whl (636 kB)
|████████████████████████████████| 636 kB 46.7 MB/s
Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning) (3.7.4.3)
Collecting pyDeprecate==0.3.1
Downloading pyDeprecate-0.3.1-py3-none-any.whl (10 kB)
Collecting future>=0.17.1
Downloading future-0.18.2.tar.gz (829 kB)
|████████████████████████████████| 829 kB 49.0 MB/s
Requirement already satisfied: torch>=1.6 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning) (1.9.0+cu102)
Collecting torchmetrics>=0.4.0
Downloading torchmetrics-0.4.1-py3-none-any.whl (234 kB)
|████████████████████████████████| 234 kB 62.4 MB/s
Collecting fsspec[http]!=2021.06.0,>=2021.05.0
Downloading fsspec-2021.7.0-py3-none-any.whl (118 kB)
|████████████████████████████████| 118 kB 76.6 MB/s
Requirement already satisfied: packaging>=17.0 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning) (21.0)
Requirement already satisfied: numpy>=1.17.2 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning) (1.19.5)
Requirement already satisfied: tqdm>=4.41.0 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning) (4.41.1)
Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (2.23.0)
Collecting aiohttp
Downloading aiohttp-3.7.4.post0-cp37-cp37m-manylinux2014_x86_64.whl (1.3 MB)
|████████████████████████████████| 1.3 MB 52.2 MB/s
Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>=17.0->pytorch-lightning) (2.4.7)
Requirement already satisfied: werkzeug>=0.11.15 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning) (1.0.1)
Requirement already satisfied: grpcio>=1.24.3 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning) (1.34.1)
Requirement already satisfied: google-auth<2,>=1.6.3 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning) (1.32.1)
Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning) (1.8.0)
Requirement already satisfied: wheel>=0.26 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning) (0.36.2)
Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning) (0.4.4)
Requirement already satisfied: protobuf>=3.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning) (3.17.3)
Requirement already satisfied: setuptools>=41.0.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning) (57.2.0)
Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning) (3.3.4)
Requirement already satisfied: six>=1.10.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning) (1.15.0)
Requirement already satisfied: absl-py>=0.4 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning) (0.12.0)
Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.7/dist-packages (from google-auth<2,>=1.6.3->tensorboard!=2.5.0,>=2.2.0->pytorch-lightning) (0.2.8)
Requirement already satisfied: rsa<5,>=3.1.4 in /usr/local/lib/python3.7/dist-packages (from google-auth<2,>=1.6.3->tensorboard!=2.5.0,>=2.2.0->pytorch-lightning) (4.7.2)
Requirement already satisfied: cachetools<5.0,>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from google-auth<2,>=1.6.3->tensorboard!=2.5.0,>=2.2.0->pytorch-lightning) (4.2.2)
Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.7/dist-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard!=2.5.0,>=2.2.0->pytorch-lightning) (1.3.0)
Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from markdown>=2.6.8->tensorboard!=2.5.0,>=2.2.0->pytorch-lightning) (4.6.1)
Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in /usr/local/lib/python3.7/dist-packages (from pyasn1-modules>=0.2.1->google-auth<2,>=1.6.3->tensorboard!=2.5.0,>=2.2.0->pytorch-lightning) (0.4.8)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (2021.5.30)
Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (3.0.4)
Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (2.10)
Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (1.24.3)
Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.7/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard!=2.5.0,>=2.2.0->pytorch-lightning) (3.1.1)
Collecting yarl<2.0,>=1.0
Downloading yarl-1.6.3-cp37-cp37m-manylinux2014_x86_64.whl (294 kB)
|████████████████████████████████| 294 kB 55.8 MB/s
Collecting multidict<7.0,>=4.5
Downloading multidict-5.1.0-cp37-cp37m-manylinux2014_x86_64.whl (142 kB)
|████████████████████████████████| 142 kB 75.8 MB/s
Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (21.2.0)
Collecting async-timeout<4.0,>=3.0
Downloading async_timeout-3.0.1-py3-none-any.whl (8.2 kB)
Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata->markdown>=2.6.8->tensorboard!=2.5.0,>=2.2.0->pytorch-lightning) (3.5.0)
Building wheels for collected packages: future
Building wheel for future (setup.py) ... done
Created wheel for future: filename=future-0.18.2-py3-none-any.whl size=491070 sha256=77259c2b1469c74bda6592e2f1a7ecc3201f9071999b6f97c6f3017f01fdeeac
Stored in directory: /root/.cache/pip/wheels/56/b0/fe/4410d17b32f1f0c3cf54cdfb2bc04d7b4b8f4ae377e2229ba0
Successfully built future
Installing collected packages: multidict, yarl, async-timeout, fsspec, aiohttp, torchmetrics, tensorboard, PyYAML, pyDeprecate, future, pytorch-lightning
Attempting uninstall: tensorboard
Found existing installation: tensorboard 2.5.0
Uninstalling tensorboard-2.5.0:
Successfully uninstalled tensorboard-2.5.0
Attempting uninstall: PyYAML
Found existing installation: PyYAML 3.13
Uninstalling PyYAML-3.13:
Successfully uninstalled PyYAML-3.13
Attempting uninstall: future
Found existing installation: future 0.16.0
Uninstalling future-0.16.0:
Successfully uninstalled future-0.16.0
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tensorflow 2.5.0 requires tensorboard~=2.5, but you have tensorboard 2.4.1 which is incompatible.
Successfully installed PyYAML-5.4.1 aiohttp-3.7.4.post0 async-timeout-3.0.1 fsspec-2021.7.0 future-0.18.2 multidict-5.1.0 pyDeprecate-0.3.1 pytorch-lightning-1.4.1 tensorboard-2.4.1 torchmetrics-0.4.1 yarl-1.6.3
In [13]:
Out[13]:
Requirement already satisfied: torchvision in /usr/local/lib/python3.7/dist-packages (0.10.0+cu102)
Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from torchvision) (1.19.5)
Requirement already satisfied: pillow>=5.3.0 in /usr/local/lib/python3.7/dist-packages (from torchvision) (7.1.2)
Requirement already satisfied: torch==1.9.0 in /usr/local/lib/python3.7/dist-packages (from torchvision) (1.9.0+cu102)
Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch==1.9.0->torchvision) (3.7.4.3)
In [14]:
Inference
Preparing dataset
In [15]:
Out[15]:
Preparing dataset and parameters for celeba ...
Downloading dataset. Please while while the download and extraction processes complete
0%| | 5.00M/1.33G [00:00<00:40, 35.4MB/s]
Downloading celeba-dataset.zip to ./data
100%|██████████| 1.33G/1.33G [00:09<00:00, 149MB/s]
100%|██████████| 2.02M/2.02M [00:00<00:00, 42.0MB/s]
Downloading list_attr_celeba.csv.zip to ./data
100%|██████████| 1.54M/1.54M [00:00<00:00, 132MB/s]
Downloading list_bbox_celeba.csv.zip to ./data
100%|██████████| 466k/466k [00:00<00:00, 126MB/s]
Downloading list_eval_partition.csv.zip to ./data
100%|██████████| 2.07M/2.07M [00:00<00:00, 171MB/s]
Downloading list_landmarks_align_celeba.csv.zip to ./data
Done!
/usr/local/lib/python3.7/dist-packages/torch/_tensor.py:575: UserWarning: floor_divide is deprecated, and will be removed in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values.
To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor'). (Triggered internally at /pytorch/aten/src/ATen/native/BinaryOps.cpp:467.)
return torch.floor_divide(self, other)
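The `floor_divide` warning above turns on the difference between rounding toward zero ("trunc") and rounding toward negative infinity ("floor"). The two only disagree for negative quotients. A small plain-Python illustration (the helper names are made up for this sketch; in PyTorch the corresponding choices are the `rounding_mode='trunc'` and `rounding_mode='floor'` arguments named in the warning):

```python
import math


def trunc_div(a, b):
    """Divide and round toward zero (the behaviour the warning calls 'trunc')."""
    return math.trunc(a / b)


def floor_div(a, b):
    """Divide and round toward negative infinity (true floor division)."""
    return math.floor(a / b)


for a, b in [(7, 2), (-7, 2)]:
    print(f"{a} / {b}: trunc={trunc_div(a, b)}, floor={floor_div(a, b)}")
# 7 / 2: trunc=3, floor=3
# -7 / 2: trunc=-3, floor=-4
```

For the non-negative index arithmetic that usually triggers this warning, both modes give the same result, so the warning is harmless there.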
In [16]:
Loading pre-trained MFA model
In [17]:
Out[17]:
Loading pre-trained MFA model...
<All keys matched successfully>
Samples
In [18]:
Out[18]:
Enter gird size: 3 4
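The prompt above takes two whitespace-separated integers. A minimal sketch of how such input could be parsed and validated, assuming the first number is the row count and the second the column count (the page does not say which is which); `parse_grid_size` is a hypothetical helper:

```python
def parse_grid_size(text):
    """Parse a string such as '3 4' into a (rows, cols) pair of positive ints."""
    parts = text.split()
    if len(parts) != 2:
        raise ValueError("expected two integers, e.g. '3 4'")
    rows, cols = (int(p) for p in parts)
    if rows < 1 or cols < 1:
        raise ValueError("grid dimensions must be positive")
    return rows, cols


rows, cols = parse_grid_size("3 4")
n_samples = rows * cols  # number of images needed to fill the grid
```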
In [19]:
Out[19]:
WARNING:root:Lossy conversion from float64 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Visualizing the trained model...
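The "Lossy conversion" warning above is raised when a float image in the range [0, 1] is handed to an image writer that stores 8-bit values. As the message suggests, converting explicitly beforehand suppresses it. A minimal NumPy sketch, assuming pixel values are meant to lie in [0, 1]; `to_uint8` is a hypothetical helper, not a function from this notebook:

```python
import numpy as np


def to_uint8(image):
    """Map a float image with values in [0, 1] to uint8 values in [0, 255]."""
    image = np.asarray(image, dtype=np.float64)
    # Clip first so small numerical overshoots do not wrap around in uint8.
    scaled = np.clip(image, 0.0, 1.0) * 255.0
    return np.round(scaled).astype(np.uint8)
```

Passing `to_uint8(img)` instead of `img` to the saving routine makes the conversion explicit and keeps the rounding under your control.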
In [20]:
Out[20]:
WARNING:root:Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Generating random samples...
Showing outliers
In [21]:
Out[21]:
Enter gird size: 1 4
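A 1 x 4 grid means four outliers are displayed. One common way to pick them, and an assumption here since the page does not show the selection code, is to score every image under the model and keep those with the lowest log-likelihood. A sketch with made-up scores (`lowest_scoring_indices` is a hypothetical helper):

```python
import heapq


def lowest_scoring_indices(scores, k):
    """Return the indices of the k smallest scores, lowest score first."""
    return heapq.nsmallest(k, range(len(scores)), key=scores.__getitem__)


# Made-up per-image log-likelihoods, for illustration only.
log_likelihoods = [-310.2, -295.7, -842.1, -301.4, -650.0, -299.9]
rows, cols = 1, 4
outlier_ids = lowest_scoring_indices(log_likelihoods, rows * cols)
```

`heapq.nsmallest` avoids sorting the whole score list, which helps when the dataset is large and only a handful of outliers are shown.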
In [22]:
Out[22]:
/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py:481: UserWarning: This DataLoader will create 8 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
cpuset_checked))
0%| | 0/156 [00:00<?, ?it/s]
Finding dataset outliers...
100%|██████████| 156/156 [00:30<00:00, 5.14it/s]
WARNING:root:Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
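The `DataLoader` warning in this output says that 8 worker processes were requested on a machine where only 2 are suggested. One way to avoid it is to cap the requested count at the number of CPUs available to the process. A minimal sketch using only the standard library (both helper names are hypothetical):

```python
import os


def usable_cpu_count():
    """Number of CPUs this process may run on (falls back to the total count)."""
    if hasattr(os, "sched_getaffinity"):  # available on Linux, e.g. Colab
        return len(os.sched_getaffinity(0))
    return os.cpu_count() or 1


def cap_num_workers(requested, available=None):
    """Clamp a requested worker count to the CPUs that are actually available."""
    if available is None:
        available = usable_cpu_count()
    return max(0, min(requested, available))
```

The result can be passed as the `num_workers` argument of `torch.utils.data.DataLoader`; on the machine in this output, `cap_num_workers(8)` would be expected to give 2.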
Reconstructing original masked images
In [23]:
Out[23]:
Enter the type of mask from following options: (a)centre (b)bottom (c)right (d)left (e)top: centre
Enter gird size: 3 5
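The five mask options each hide a different region of the image before the model fills it back in. The exact regions used by the notebook are not shown on this page; the sketch below assumes that the side masks cover half the image and that the centre mask covers the middle half along each axis. `make_mask` is a hypothetical helper.

```python
import numpy as np


def make_mask(height, width, kind):
    """Return a boolean array where True marks the pixels to hide."""
    mask = np.zeros((height, width), dtype=bool)
    if kind == "centre":
        # Assumed region: the middle half of the rows and columns.
        mask[height // 4 : 3 * height // 4, width // 4 : 3 * width // 4] = True
    elif kind == "bottom":
        mask[height // 2 :, :] = True
    elif kind == "top":
        mask[: height // 2, :] = True
    elif kind == "left":
        mask[:, : width // 2] = True
    elif kind == "right":
        mask[:, width // 2 :] = True
    else:
        raise ValueError(f"unknown mask type: {kind!r}")
    return mask
```

For a single-channel image `img` of the same shape, `np.where(mask, 0.0, img)` blanks out the hidden region; the reconstruction step then has to infer those pixels from the visible ones.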
In [24]:
Out[24]:
WARNING:root:Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Reconstructing images from the trained model...
WARNING:root:Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.