GitHub Repository: probml/pyprobml
Path: blob/master/deprecated/notebooks/mix_PPCA_celeba.ipynb
Kernel: Python 3

Get the CelebA data module

!wget -q https://raw.githubusercontent.com/sayantanauddy/vae_lightning/main/data.py

Get helpers

!wget -q https://raw.githubusercontent.com/probml/pyprobml/master/scripts/mfa_celeba_helpers.py

Get a Kaggle API token and upload it to Colab. Follow the instructions in the Kaggle API documentation (https://github.com/Kaggle/kaggle-api).

!pip install kaggle
Requirement already satisfied: kaggle in /usr/local/lib/python3.7/dist-packages (1.5.12)
from google.colab import files

uploaded = files.upload()
Saving kaggle.json to kaggle.json
!mkdir /root/.kaggle
!cp kaggle.json /root/.kaggle/kaggle.json
!chmod 600 /root/.kaggle/kaggle.json
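
To confirm the token is being picked up, an optional quick check is to run any authenticated query against the Kaggle CLI, for example:

!kaggle datasets list -s celeba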

Get the model checkpoint from a Google Cloud Storage bucket.

from google.colab import auth

auth.authenticate_user()
bucket_name = "probml_data"
!mkdir /content/models
!gsutil cp -r gs://{bucket_name}/mix_PPCA /content/models/
Copying gs://probml_data/mix_PPCA/model_c_300_l_10_init_rnd_samples.pth... / [1 files][168.8 MiB/168.8 MiB] Operation completed over 1 objects/168.8 MiB.
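
An optional sanity check that the checkpoint landed where the loading code below expects it:

!ls -lh /content/models/mix_PPCA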

Main

!pip install pytorch-lightning
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts. tensorflow 2.5.0 requires tensorboard~=2.5, but you have tensorboard 2.4.1 which is incompatible.
Successfully installed PyYAML-5.4.1 aiohttp-3.7.4.post0 async-timeout-3.0.1 fsspec-2021.7.0 future-0.18.2 multidict-5.1.0 pyDeprecate-0.3.1 pytorch-lightning-1.4.1 tensorboard-2.4.1 torchmetrics-0.4.1 yarl-1.6.3
!pip install torchvision
Requirement already satisfied: torchvision in /usr/local/lib/python3.7/dist-packages (0.10.0+cu102)
import sys, os

import numpy as np
import torch
import torchvision.transforms as transforms
from imageio import imwrite
from IPython.display import Image
from matplotlib import pyplot as plt
from packaging import version
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from torch.utils.data import DataLoader, random_split, SequentialSampler, RandomSampler
from torchvision.datasets import CelebA, MNIST
from tqdm import tqdm

from data import CelebADataset, CelebADataModule
from mfa_celeba_helpers import *


def main(argv):
    assert version.parse(torch.__version__) >= version.parse("1.2.0")
    dataset = argv[1] if len(argv) == 2 else "celeba"
    print("Preparing dataset and parameters for", dataset, "...")

    if dataset == "celeba":
        image_shape = [64, 64, 3]    # The input image shape
        n_components = 300           # Number of components in the mixture model
        n_factors = 10               # Number of factors - the latent dimension (same for all components)
        batch_size = 1000            # The EM batch size
        num_iterations = 30          # Number of EM iterations (= epochs)
        feature_sampling = 0.2       # For faster responsibility calculation, randomly sample coordinates (or False)
        mfa_sgd_epochs = 0           # Perform additional training with diagonal (per-pixel) covariance, using SGD
        init_method = "rnd_samples"  # Initialize each component from a few random samples using PPCA
        trans = transforms.Compose(
            [
                CropTransform((25, 50, 25 + 128, 50 + 128)),
                transforms.Resize(image_shape[0]),
                transforms.ToTensor(),
                ReshapeTransform([-1]),
            ]
        )
        train_set = CelebADataset(root="./data", split="train", transform=trans, download=True)
        test_set = CelebADataset(root="./data", split="test", transform=trans, download=True)
    elif dataset == "mnist":
        image_shape = [28, 28]       # The input image shape
        n_components = 50            # Number of components in the mixture model
        n_factors = 6                # Number of factors - the latent dimension (same for all components)
        batch_size = 1000            # The EM batch size
        num_iterations = 1           # Number of EM iterations (= epochs)
        feature_sampling = False     # For faster responsibility calculation, randomly sample coordinates (or False)
        mfa_sgd_epochs = 0           # Perform additional training with diagonal (per-pixel) covariance, using SGD
        init_method = "kmeans"       # Initialize by k-means clustering
        trans = transforms.Compose([transforms.ToTensor(), ReshapeTransform([-1])])
        train_set = MNIST(root="./data", train=True, transform=trans, download=True)
        test_set = MNIST(root="./data", train=False, transform=trans, download=True)
    else:
        assert False, "Unknown dataset: " + dataset
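
For reference, a mixture of PPCA is a mixture of factor analyzers (MFA) in which each component has isotropic noise. Each flattened image $x \in \mathbb{R}^d$ is modeled as

$$p(x) = \sum_{c=1}^{C} \pi_c \, \mathcal{N}\!\left(x \mid \mu_c,\; W_c W_c^{\top} + \sigma_c^{2} I\right),$$

where $C$ is n_components, each loading matrix $W_c$ is $d \times l$ with $l$ equal to n_factors, and the parameters are fit by EM for num_iterations epochs over batches of size batch_size.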

Inference

Preparing dataset

""" Examples for inference using the trained MFA model - likelihood evaluation and (conditional) reconstruction """ if __name__ == "__main__": dataset = "celeba" find_outliers = True reconstruction = True inpainting = True print("Preparing dataset and parameters for", dataset, "...") if dataset == "celeba": image_shape = [64, 64, 3] # The input image shape n_components = 300 # Number of components in the mixture model n_factors = 10 # Number of factors - the latent dimension (same for all components) batch_size = 128 # The EM batch size num_iterations = 30 # Number of EM iterations (=epochs) feature_sampling = 0.2 # For faster responsibilities calculation, randomly sample the coordinates (or False) mfa_sgd_epochs = 0 # Perform additional training with diagonal (per-pixel) covariance, using SGD trans = transforms.Compose( [ CropTransform((25, 50, 25 + 128, 50 + 128)), transforms.Resize(image_shape[0]), transforms.ToTensor(), ReshapeTransform([-1]), ] ) test_dataset = CelebADataset(root="./data", split="test", transform=trans, download=True) # The train set has more interesting outliers... # test_dataset = CelebA(root='./data', split='train', transform=trans, download=True) else: assert False, "Unknown dataset: " + dataset
Preparing dataset and parameters for celeba ...
Downloading dataset. Please wait while the download and extraction processes complete
Downloading celeba-dataset.zip to ./data
100%|██████████| 1.33G/1.33G [00:09<00:00, 149MB/s]
100%|██████████| 2.02M/2.02M [00:00<00:00, 42.0MB/s]
Downloading list_attr_celeba.csv.zip to ./data
100%|██████████| 1.54M/1.54M [00:00<00:00, 132MB/s]
Downloading list_bbox_celeba.csv.zip to ./data
100%|██████████| 466k/466k [00:00<00:00, 126MB/s]
Downloading list_eval_partition.csv.zip to ./data
100%|██████████| 2.07M/2.07M [00:00<00:00, 171MB/s]
Downloading list_landmarks_align_celeba.csv.zip to ./data
Done!
/usr/local/lib/python3.7/dist-packages/torch/_tensor.py:575: UserWarning: floor_divide is deprecated, and will be removed in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor'). (Triggered internally at /pytorch/aten/src/ATen/native/BinaryOps.cpp:467.) return torch.floor_divide(self, other)
def samples_to_mosaic_any_size(grid_size, samples, image_shape=[64, 64, 3]):
    images = samples_to_np_images(samples, image_shape)
    num_rows, num_cols = grid_size
    rows = []
    for i in range(num_rows):
        rows.append(np.hstack([images[j] for j in range(i * num_cols, (i + 1) * num_cols)]))
    return np.vstack(rows)
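
A minimal usage sketch with hypothetical inputs (this assumes samples_to_np_images from the helpers returns one 64x64x3 image per row of its input):

demo = torch.rand(12, 64 * 64 * 3)  # stand-in for 12 flattened samples
mosaic = samples_to_mosaic_any_size([3, 4], demo)  # 3 rows x 4 columns
print(mosaic.shape)  # expected: (192, 256, 3)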

Loading pre-trained MFA model

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model_dir = "./models/mix_PPCA"
figures_dir = "./figures/" + dataset
os.makedirs(figures_dir, exist_ok=True)

print("Loading pre-trained MFA model...")
model = MFA(n_components=n_components, n_features=np.prod(image_shape), n_factors=n_factors).to(device=device)
model.load_state_dict(torch.load(os.path.join(model_dir, "model_c_300_l_10_init_rnd_samples.pth")))
Loading pre-trained MFA model...
<All keys matched successfully>
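
As an optional sanity check using only the state dict loaded above (no MFA-specific API assumed), the total parameter count can be printed:

n_params = sum(t.numel() for t in model.state_dict().values())
print(f"MFA on {device}: {n_params:,} parameters")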

Samples

grid_size = [int(x) for x in input("Enter grid size: ").split()]
Enter grid size: 3 4
print("Visualizing the trained model...") model_image = visualize_model(model, image_shape=image_shape, end_component=10) fname = os.path.join(figures_dir, "model.jpg") imwrite(fname, model_image) display(Image(fname))
WARNING:root:Lossy conversion from float64 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Visualizing the trained model...
[Image output: visualization of the first 10 trained mixture components]
print("Generating random samples...") rnd_samples, _ = model.sample(gird_size[0] * gird_size[1], with_noise=False) # 100->n #gird_size[0]*gird_size[1] mosaic = samples_to_mosaic_any_size(gird_size, samples=rnd_samples, image_shape=image_shape) fname = os.path.join(figures_dir, "samples.jpg") imwrite(fname, mosaic) display(Image(fname))
WARNING:root:Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Generating random samples...
[Image output: 3x4 grid of random samples from the model]
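
Conceptually, sampling with with_noise=False picks a component c, draws latent factors z ~ N(0, I), and returns the noiseless mean W_c z + mu_c. A stand-alone sketch with hypothetical shapes (not the helper's actual internals):

# Hypothetical single-component sample; d = pixels, l = latent factors
d, l = 64 * 64 * 3, 10
W, mu = 0.01 * torch.randn(d, l), torch.zeros(d)  # stand-in loadings and mean
z = torch.randn(l)  # z ~ N(0, I)
x = W @ z + mu      # with_noise=False: the sigma^2 noise term is skipped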

Showing outliers

grid_size = [int(x) for x in input("Enter grid size: ").split()]
Enter grid size: 1 4
if find_outliers:
    print("Finding dataset outliers...")
    loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=8)
    all_ll = []
    for batch_x, _ in tqdm(loader):
        all_ll.append(model.log_prob(batch_x.to(device)))
    all_ll = torch.cat(all_ll, dim=0)
    ll_sorted = torch.argsort(all_ll).cpu().numpy()

    all_keys = [key for key in SequentialSampler(test_dataset)]
    outlier_samples, _ = zip(*[test_dataset[all_keys[ll_sorted[i]]] for i in range(grid_size[0] * grid_size[1])])
    mosaic = samples_to_mosaic_any_size(grid_size, torch.stack(outlier_samples), image_shape=image_shape)
    fname = os.path.join(figures_dir, "outliers.jpg")
    imwrite(fname, mosaic)
    display(Image(fname))
/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py:481: UserWarning: This DataLoader will create 8 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary. cpuset_checked))
Finding dataset outliers...
100%|██████████| 156/156 [00:30<00:00, 5.14it/s]
WARNING:root:Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
[Image output: 1x4 row of the lowest-likelihood (outlier) test images]
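
Since ll_sorted orders the test set from lowest to highest log-likelihood, flipping it shows the most typical faces instead; a small variation on the cell above:

# Most "typical" test images: highest-likelihood indices instead of lowest
inlier_idx = ll_sorted[::-1][: grid_size[0] * grid_size[1]]
inlier_samples, _ = zip(*[test_dataset[all_keys[i]] for i in inlier_idx])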

Reconstructing original masked images

mask_type = input("Enter the type of mask from the following options: (a)centre (b)bottom (c)right (d)left (e)top: ")
grid_size = [int(x) for x in input("Enter grid size: ").split()]
Enter the type of mask from the following options: (a)centre (b)bottom (c)right (d)left (e)top: centre
Enter grid size: 3 5
if reconstruction:
    print("Reconstructing images from the trained model...")
    n = grid_size[0] * grid_size[1]
    random_samples, _ = zip(*[test_dataset[k] for k in RandomSampler(test_dataset, replacement=True, num_samples=n)])
    random_samples = torch.stack(random_samples)

    if inpainting:
        w = image_shape[0]
        mask = np.ones([3, w, w], dtype=np.float32)
        # Hide part of each image
        if mask_type == "centre":
            mask[:, w // 4 : -w // 4, w // 4 : -w // 4] = 0  # Mask the centre
        elif mask_type == "bottom":
            mask[:, w // 2 :, :] = 0  # Mask the bottom half
        elif mask_type == "right":
            mask[:, :, w // 2 :] = 0  # Mask the right half
        elif mask_type == "left":
            mask[:, :, : w // 2] = 0  # Mask the left half
        else:
            mask[:, : w // 2, :] = 0  # Mask the top half
        mask = torch.from_numpy(mask.flatten()).reshape([1, -1])
        random_samples *= mask
        used_features = torch.nonzero(mask.flatten()).flatten()
        reconstructed_samples = model.conditional_reconstruct(random_samples.to(device), observed_features=used_features).cpu()
    else:
        reconstructed_samples = model.reconstruct(random_samples.to(device)).cpu()

    if inpainting:
        reconstructed_samples = random_samples * mask + reconstructed_samples * (1 - mask)

    mosaic_original = samples_to_mosaic_any_size(grid_size, random_samples, image_shape=image_shape)
    fname = os.path.join(figures_dir, "original_samples.jpg")
    imwrite(fname, mosaic_original)
    display(Image(fname))

    mosaic_reconstructed = samples_to_mosaic_any_size(grid_size, reconstructed_samples, image_shape=image_shape)
    fname = os.path.join(figures_dir, "reconstructed_samples.jpg")
    imwrite(fname, mosaic_reconstructed)
    display(Image(fname))
WARNING:root:Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Reconstructing images from the trained model...
[Image output: 3x5 grid of centre-masked original samples]
WARNING:root:Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
[Image output: 3x5 grid of reconstructed samples]
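
For reference, the inpainting above presumably relies on standard Gaussian conditioning within each component: partitioning a sample into observed pixels $x_o$ and masked pixels $x_m$, the reconstruction of the masked part is

$$\mathbb{E}[x_m \mid x_o] = \mu_m + \Sigma_{mo}\,\Sigma_{oo}^{-1}(x_o - \mu_o), \qquad \Sigma = W W^{\top} + \sigma^2 I,$$

where the low-rank-plus-diagonal structure of $\Sigma$ keeps the solve cheap (e.g. via the Woodbury identity).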