Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
probml
GitHub Repository: probml/pyprobml
Path: blob/master/notebooks/book1/20/vae_mnist_conv_lightning.ipynb
1192 views
Kernel: Python 3 (ipykernel)

Convolutional MNIST VAE

Installation

!mkdir figures !mkdir scripts %cd /content/scripts # !wget -q https://raw.githubusercontent.com/probml/pyprobml/master/vae/standalone/lvm_plots_utils.py # !wget -q https://raw.githubusercontent.com/probml/pyprobml/master/vae/standalone/vae_conv_mnist.py !wget -q https://github.com/probml/probml-data/raw/main/checkpoints/vae-mnist-conv-latent-dim-2.ckpt !wget -q https://github.com/probml/probml-data/raw/main/checkpoints/vae-mnist-conv-latent-dim-20.ckpt
mkdir: cannot create directory ‘figures’: File exists mkdir: cannot create directory ‘scripts’: File exists [Errno 2] No such file or directory: '/content/scripts' /home/patel_zeel/AQ-NewsArticles/ProbML/pyprobml-1/notebooks/book1/20
import matplotlib import matplotlib.pyplot as plt import numpy as np try: import torch except ModuleNotFoundError: %pip install -qq torch import torch try: from torchvision.utils import make_grid except ModuleNotFoundError: %pip install -qq torchvision from torchvision.utils import make_grid import torch.nn as nn from torchvision.datasets import MNIST import torch.nn.functional as F import torchvision.transforms as transforms from torch.utils.data import DataLoader try: from pytorch_lightning import LightningModule, Trainer except ModuleNotFoundError: %pip install -qq pytorch-lightning from pytorch_lightning import LightningModule, Trainer try: from einops import rearrange except ModuleNotFoundError: %pip install -qq einops from einops import rearrange try: import probml_utils except ModuleNotFoundError: %pip install -qq git+https://github.com/probml/probml-utils.git import probml_utils from probml_utils import savefig from probml_utils.lvm_plots_utils import ( get_random_samples, get_grid_samples, plot_scatter_plot, get_imrange, plot_grid_plot, plot_scatter_plot, ) import seaborn as sns from torchvision.utils import make_grid from probml_utils.vae_conv_mnist import ConvVAE %pip install -qq test-tube umap
Requirement already satisfied: test-tube in /home/patel_zeel/miniconda3/lib/python3.9/site-packages (0.7.5) Requirement already satisfied: umap in /home/patel_zeel/miniconda3/lib/python3.9/site-packages (0.1.1) Requirement already satisfied: future in /home/patel_zeel/miniconda3/lib/python3.9/site-packages (from test-tube) (0.18.2) Requirement already satisfied: imageio>=2.3.0 in /home/patel_zeel/miniconda3/lib/python3.9/site-packages (from test-tube) (2.18.0) Requirement already satisfied: torch>=1.1.0 in /home/patel_zeel/miniconda3/lib/python3.9/site-packages (from test-tube) (1.10.2) Requirement already satisfied: tensorboard>=1.15.0 in /home/patel_zeel/miniconda3/lib/python3.9/site-packages (from test-tube) (2.8.0) Requirement already satisfied: pandas>=0.20.3 in /home/patel_zeel/miniconda3/lib/python3.9/site-packages (from test-tube) (1.3.5) Requirement already satisfied: numpy>=1.13.3 in /home/patel_zeel/miniconda3/lib/python3.9/site-packages (from test-tube) (1.22.1) Requirement already satisfied: pillow>=8.3.2 in /home/patel_zeel/miniconda3/lib/python3.9/site-packages (from imageio>=2.3.0->test-tube) (9.1.0) Requirement already satisfied: python-dateutil>=2.7.3 in /home/patel_zeel/miniconda3/lib/python3.9/site-packages (from pandas>=0.20.3->test-tube) (2.8.2) Requirement already satisfied: pytz>=2017.3 in /home/patel_zeel/miniconda3/lib/python3.9/site-packages (from pandas>=0.20.3->test-tube) (2021.3) Requirement already satisfied: six>=1.5 in /home/patel_zeel/miniconda3/lib/python3.9/site-packages (from python-dateutil>=2.7.3->pandas>=0.20.3->test-tube) (1.16.0) Requirement already satisfied: google-auth<3,>=1.6.3 in /home/patel_zeel/miniconda3/lib/python3.9/site-packages (from tensorboard>=1.15.0->test-tube) (2.6.2) Requirement already satisfied: tensorboard-data-server<0.7.0,>=0.6.0 in /home/patel_zeel/miniconda3/lib/python3.9/site-packages (from tensorboard>=1.15.0->test-tube) (0.6.1) Requirement already satisfied: absl-py>=0.4 in 
/home/patel_zeel/miniconda3/lib/python3.9/site-packages (from tensorboard>=1.15.0->test-tube) (1.0.0) Requirement already satisfied: wheel>=0.26 in /home/patel_zeel/miniconda3/lib/python3.9/site-packages (from tensorboard>=1.15.0->test-tube) (0.37.1) Requirement already satisfied: markdown>=2.6.8 in /home/patel_zeel/miniconda3/lib/python3.9/site-packages (from tensorboard>=1.15.0->test-tube) (3.3.6) Requirement already satisfied: protobuf>=3.6.0 in /home/patel_zeel/miniconda3/lib/python3.9/site-packages (from tensorboard>=1.15.0->test-tube) (3.19.4) Requirement already satisfied: setuptools>=41.0.0 in /home/patel_zeel/miniconda3/lib/python3.9/site-packages (from tensorboard>=1.15.0->test-tube) (58.0.4) Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /home/patel_zeel/miniconda3/lib/python3.9/site-packages (from tensorboard>=1.15.0->test-tube) (1.8.1) Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /home/patel_zeel/miniconda3/lib/python3.9/site-packages (from tensorboard>=1.15.0->test-tube) (0.4.6) Requirement already satisfied: werkzeug>=0.11.15 in /home/patel_zeel/miniconda3/lib/python3.9/site-packages (from tensorboard>=1.15.0->test-tube) (2.1.0) Requirement already satisfied: requests<3,>=2.21.0 in /home/patel_zeel/miniconda3/lib/python3.9/site-packages (from tensorboard>=1.15.0->test-tube) (2.27.1) Requirement already satisfied: grpcio>=1.24.3 in /home/patel_zeel/miniconda3/lib/python3.9/site-packages (from tensorboard>=1.15.0->test-tube) (1.44.0) Requirement already satisfied: rsa<5,>=3.1.4 in /home/patel_zeel/miniconda3/lib/python3.9/site-packages (from google-auth<3,>=1.6.3->tensorboard>=1.15.0->test-tube) (4.8) Requirement already satisfied: cachetools<6.0,>=2.0.0 in /home/patel_zeel/miniconda3/lib/python3.9/site-packages (from google-auth<3,>=1.6.3->tensorboard>=1.15.0->test-tube) (5.0.0) Requirement already satisfied: pyasn1-modules>=0.2.1 in /home/patel_zeel/miniconda3/lib/python3.9/site-packages (from 
google-auth<3,>=1.6.3->tensorboard>=1.15.0->test-tube) (0.2.8) Requirement already satisfied: requests-oauthlib>=0.7.0 in /home/patel_zeel/miniconda3/lib/python3.9/site-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard>=1.15.0->test-tube) (1.3.1) Requirement already satisfied: importlib-metadata>=4.4 in /home/patel_zeel/miniconda3/lib/python3.9/site-packages (from markdown>=2.6.8->tensorboard>=1.15.0->test-tube) (4.11.2) Requirement already satisfied: zipp>=0.5 in /home/patel_zeel/miniconda3/lib/python3.9/site-packages (from importlib-metadata>=4.4->markdown>=2.6.8->tensorboard>=1.15.0->test-tube) (3.7.0) Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in /home/patel_zeel/miniconda3/lib/python3.9/site-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard>=1.15.0->test-tube) (0.4.8) Requirement already satisfied: charset-normalizer~=2.0.0 in /home/patel_zeel/miniconda3/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorboard>=1.15.0->test-tube) (2.0.4) Requirement already satisfied: urllib3<1.27,>=1.21.1 in /home/patel_zeel/miniconda3/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorboard>=1.15.0->test-tube) (1.26.7) Requirement already satisfied: certifi>=2017.4.17 in /home/patel_zeel/miniconda3/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorboard>=1.15.0->test-tube) (2021.10.8) Requirement already satisfied: idna<4,>=2.5 in /home/patel_zeel/miniconda3/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorboard>=1.15.0->test-tube) (3.3) Requirement already satisfied: oauthlib>=3.0.0 in /home/patel_zeel/miniconda3/lib/python3.9/site-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard>=1.15.0->test-tube) (3.2.0) Requirement already satisfied: typing_extensions in /home/patel_zeel/miniconda3/lib/python3.9/site-packages (from torch>=1.1.0->test-tube) (4.2.0) Note: you may need to restart the kernel to use updated packages.
# Use the GPU when PyTorch reports one as available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Preprocessing: convert each image to a tensor, then resize it to 32x32.
preprocess = transforms.Compose([transforms.ToTensor(), transforms.Resize((32, 32))])

# MNIST training split, downloaded into the current directory if needed.
mnist_full = MNIST(".", train=True, download=True, transform=preprocess)

# Loader yielding batches of 250 images.
dm = DataLoader(mnist_full, batch_size=250)

# Loader yielding batches of 5000 images; its first batch is cached in `batch`
# and reused by the visualisation cells.
vis_data = DataLoader(mnist_full, batch_size=5000)
batch = next(iter(vis_data))
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz Failed to download (trying next): HTTP Error 503: Service Unavailable Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz
0%| | 0/9912422 [00:00<?, ?it/s]
Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz Failed to download (trying next): HTTP Error 503: Service Unavailable Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz
0%| | 0/28881 [00:00<?, ?it/s]
Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz Failed to download (trying next): HTTP Error 503: Service Unavailable Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz
0%| | 0/1648877 [00:00<?, ?it/s]
Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz
0%| | 0/4542 [00:00<?, ?it/s]
Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw
/usr/local/lib/python3.7/dist-packages/torchvision/datasets/mnist.py:498: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /pytorch/torch/csrc/utils/tensor_numpy.cpp:180.) return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)

VAE

def _make_conv_vae(latent_dim):
    """Build a ConvVAE with the filter layout shared by both models; only the latent size varies."""
    return ConvVAE(
        (1, 28, 28),
        encoder_conv_filters=[28, 64, 64],
        decoder_conv_t_filters=[64, 28, 1],
        latent_dim=latent_dim,
        kl_coeff=5,
    )


# NOTE(review): the declared input shape is (1, 28, 28) while the data pipeline
# resizes images to 32x32 — confirm how ConvVAE uses this argument.
m = _make_conv_vae(20)
m2 = _make_conv_vae(2)
# Restore the pretrained weights for both models.
# map_location="cpu" remaps every stored tensor onto the CPU during
# deserialisation, so a checkpoint written on a GPU machine still loads on a
# host without CUDA.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
m.load_state_dict(torch.load("vae-mnist-conv-latent-dim-20.ckpt", map_location="cpu"))
m2.load_state_dict(torch.load("vae-mnist-conv-latent-dim-2.ckpt", map_location="cpu"))
<All keys matched successfully>
# Switch both models to evaluation mode and move them onto the selected device.
for model in (m, m2):
    model.eval()
    model.to(device)

# Last expression of the cell, so the notebook echoes the module summary of m2.
m2
ConvVAE( (vae): ConvVAEModule( (enc_convs): ModuleList( (0): Conv2d(1, 28, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) (1): LeakyReLU(negative_slope=0.01) (2): Conv2d(28, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) (3): LeakyReLU(negative_slope=0.01) (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) (5): LeakyReLU(negative_slope=0.01) ) (mu_linear): Linear(in_features=1024, out_features=2, bias=True) (log_var_linear): Linear(in_features=1024, out_features=2, bias=True) (decoder_linear): Linear(in_features=2, out_features=1024, bias=True) (dec_t_convs): ModuleList( (0): UpsamplingNearest2d(scale_factor=2.0, mode=nearest) (1): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (2): LeakyReLU(negative_slope=0.01) (3): UpsamplingNearest2d(scale_factor=2.0, mode=nearest) (4): ConvTranspose2d(64, 28, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (5): LeakyReLU(negative_slope=0.01) (6): UpsamplingNearest2d(scale_factor=2.0, mode=nearest) (7): ConvTranspose2d(28, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (8): Sigmoid() ) ) )

Reconstruction

ConvVAE with latent dim 20

# Show the first 16 images of the cached batch next to the 20-d model's output for them.
imgs, _ = batch
imgs = imgs[:16]
fig, axs = plt.subplots(2, 1)

# Top panel: the input images tiled into a single grid (channels moved last for imshow).
axs[0].imshow(rearrange(make_grid(imgs), "c h w -> h w c"))

imgs = imgs.to(device=device)
# Inference only: disable autograd so the forward pass does not build a graph.
with torch.no_grad():
    # NOTE(review): element 0 of the forward output is presumably the reconstruction — confirm against ConvVAE.
    recon = m.vae(imgs)[0].cpu()

# Bottom panel: the model output for the same images, tiled the same way.
axs[1].imshow(rearrange(make_grid(recon), "c h w -> h w c"))
savefig("vae_mnist_conv_20d_rec.pdf")
plt.show()
Image in a Jupyter notebook
!ls
figures pyprobml_utils.py vae_conv_mnist.py.2 lvm_plots_utils.py pyprobml_utils.py.1 vae_mnist_conv_20d_rec.pdf lvm_plots_utils.py.1 pyprobml_utils.py.2 vae-mnist-conv-latent-dim-20.ckpt lvm_plots_utils.py.2 scripts vae-mnist-conv-latent-dim-2.ckpt MNIST vae_conv_mnist.py vae-mnist-conv-latent-dim-2.ckpt.1 __pycache__ vae_conv_mnist.py.1 vae-mnist-conv-latent-dim-2.ckpt.2

ConvVAE with latent dim 2

# Show the first 16 images of the cached batch next to the 2-d model's output for them.
imgs, _ = batch
imgs = imgs[:16]
fig, axs = plt.subplots(2, 1)

# Top panel: the input images tiled into a single grid (channels moved last for imshow).
axs[0].imshow(rearrange(make_grid(imgs), "c h w -> h w c"))

imgs = imgs.to(device=device)
# Inference only: disable autograd so the forward pass does not build a graph.
with torch.no_grad():
    # NOTE(review): element 0 of the forward output is presumably the reconstruction — confirm against ConvVAE.
    recon = m2.vae(imgs)[0].cpu()

# Bottom panel: the model output for the same images, tiled the same way.
axs[1].imshow(rearrange(make_grid(recon), "c h w -> h w c"))
savefig("vae_mnist_conv_2d_rec.pdf")
plt.show()
Image in a Jupyter notebook

Sampling

Random samples from a truncated unit normal distribution

We sample $z \sim TN(0,1)$ from a truncated normal distribution with a threshold of 5.

ConvVAE with latent dim 20

def decoder(z):
    """Decode latent codes ``z`` with the 20-dimensional model."""
    return m.vae.decode(z)


plt.figure()
# Pass the latent size explicitly (as the 2-d sampling cell does) rather than
# relying on the helper's default happening to match this model.
imgs = get_random_samples(
    decoder,
    truncation_threshold=5,
    latent_dim=20,
    num_images_per_row=8,
    num_images=16,
)
plt.imshow(imgs)
savefig("vae_mnist_conv_20d_samples.pdf")
Image in a Jupyter notebook

ConvVAE with latent dim 2

def decoder(z):
    """Decode latent codes ``z`` with the 2-dimensional model."""
    decoded = m2.vae.decode(z)
    return decoded


# Sampling settings: truncation threshold 5, 2-d latents, 16 images shown 8 per row.
sample_kwargs = dict(truncation_threshold=5, latent_dim=2, num_images_per_row=8, num_images=16)

plt.figure()
imgs = get_random_samples(decoder, **sample_kwargs)
plt.imshow(imgs)
savefig("vae_mnist_conv_2d_samples.pdf")
Image in a Jupyter notebook

Grid Sampling

We let $z = [z_1, z_2, 0, \ldots, 0]$ and vary $z_1, z_2$ on a grid.

ConvVAE with latent dim 20

def decoder(z):
    """Return element 0 of the 20-d model's ``decode`` output for ``z``."""
    decoded = m.vae.decode(z)
    return decoded[0]


nimgs = 8
nlatents = 20

plt.figure()
# Decode the grid of latent points, then tile the results with `nimgs` images per row.
grid_samples = get_grid_samples(decoder, nlatents, nimgs)
tiled = make_grid(grid_samples, nimgs)
plt.imshow(rearrange(tiled, " c h w -> h w c").cpu())
plt.axis("off")
plt.tight_layout()
savefig("vae_mnist_conv_20d_grid.pdf")
Image in a Jupyter notebook

ConvVAE with latent dim 2

def decoder(z):
    """Return element 0 of the 2-d model's ``decode`` output for ``z``."""
    decoded = m2.vae.decode(z)
    return decoded[0]


nimgs = 8
nlatents = 2

plt.figure()
# Decode the grid of latent points, then tile the results with `nimgs` images per row.
grid_samples = get_grid_samples(decoder, nlatents, nimgs)
tiled = make_grid(grid_samples, nimgs)
plt.imshow(rearrange(tiled, " c h w -> h w c").cpu())
plt.axis("off")
plt.tight_layout()
savefig("vae_mnist_conv_2d_grid.pdf")
Image in a Jupyter notebook

2D Color embedding of latent space

ConvVAE with latent dim 20

def encoder(img):
    """Return element 0 of the 20-d model's ``encode`` output for ``img``."""
    encoded = m.vae.encode(img)
    return encoded[0]


def decoder(z):
    """Decode ``z`` on ``device`` and fold the channel axis into the height axis."""
    decoded = m.vae.decode(z.to(device))
    return rearrange(decoded, "b c h w -> b (c h) w")


# Scatter plot of the cached batch, using `encoder` to embed it.
plot_scatter_plot(batch, encoder)
Image in a Jupyter notebook
def encoder(img):
    """Return element 0 of the 20-d model's ``encode`` output for ``img``."""
    encoded = m.vae.encode(img)
    return encoded[0]


def decoder(z):
    """Decode ``z`` on ``device`` and fold the channel axis into the height axis."""
    decoded = m.vae.decode(z.to(device))
    return rearrange(decoded, "b c h w -> b (c h) w")


# Grid plot of the cached batch, using `encoder` to embed it; saved as a PDF.
embed_pdf = "vae_mnist_conv_20d_embed.pdf"
fig = plot_grid_plot(batch, encoder)
savefig(embed_pdf)
plt.show()
Image in a Jupyter notebook

ConvVAE with latent dim 2

def encoder(img):
    """Return element 0 of the 2-d model's ``encode`` output as a detached CPU NumPy array."""
    encoded = m2.vae.encode(img)
    return encoded[0].cpu().detach().numpy()


def decoder(z):
    """Decode ``z`` on ``device`` and fold the channel axis into the height axis."""
    decoded = m2.vae.decode(z.to(device))
    return rearrange(decoded, "b c h w -> b (c h) w")


# Scatter plot of the cached batch, using `encoder` to embed it.
plot_scatter_plot(batch, encoder)
Image in a Jupyter notebook
# Grid plot of the cached batch using the `encoder` currently defined in the
# notebook namespace; the figure is saved as a PDF.
embed_pdf = "vae_mnist_conv_2d_embed.pdf"
fig = plot_grid_plot(batch, encoder)
savefig(embed_pdf)
Image in a Jupyter notebook

Interpolation

Spherical Interpolation

ConvVAE with latent dim 20

def encoder(img):
    """Return element 0 of the 20-d model's ``encode`` output as a detached CPU tensor."""
    encoded = m.vae.encode(img)
    return encoded[0].cpu().detach()


def decoder(z):
    """Decode ``z`` on ``device`` and fold the channel axis into the height axis."""
    decoded = m.vae.decode(z.to(device))
    return rearrange(decoded, "b c h w -> b (c h) w")


# Encode the cached batch and take entries 5 and 0 as the interpolation endpoints.
imgs, _ = batch
imgs = imgs.to(device)
z_imgs = encoder(imgs)
start = z_imgs[5]
end = z_imgs[0]

plt.figure()
arr = get_imrange(decoder, start, end, interpolation="spherical")
plt.imshow(arr)
savefig("vae_mnist_conv_20d_spherical.pdf")
Image in a Jupyter notebook

ConvVAE with latent dim 2

def encoder(img):
    """Return element 0 of the 2-d model's ``encode`` output as a detached CPU tensor."""
    encoded = m2.vae.encode(img)
    return encoded[0].cpu().detach()


def decoder(z):
    """Decode ``z`` on ``device`` and fold the channel axis into the height axis."""
    decoded = m2.vae.decode(z.to(device))
    return rearrange(decoded, "b c h w -> b (c h) w")


# Encode the cached batch and take entries 5 and 0 as the interpolation endpoints.
imgs, _ = batch
imgs = imgs.to(device)
z_imgs = encoder(imgs)
start = z_imgs[5]
end = z_imgs[0]

plt.figure()
arr = get_imrange(decoder, start, end, interpolation="spherical")
plt.imshow(arr)
savefig("vae_mnist_conv_2d_spherical.pdf")
Image in a Jupyter notebook

Linear Interpolation

ConvVAE with latent dim 20

def encoder(img):
    """Return element 0 of the 20-d model's ``encode`` output as a detached CPU tensor."""
    encoded = m.vae.encode(img)
    return encoded[0].cpu().detach()


def decoder(z):
    """Decode ``z`` on ``device`` and fold the channel axis into the height axis."""
    decoded = m.vae.decode(z.to(device))
    return rearrange(decoded, "b c h w -> b (c h) w")


# Encode the cached batch and take entries 5 and 0 as the interpolation endpoints.
imgs, _ = batch
imgs = imgs.to(device)
z_imgs = encoder(imgs)
start = z_imgs[5]
end = z_imgs[0]

plt.figure()
arr = get_imrange(decoder, start, end, interpolation="linear")
plt.imshow(arr)
plt.tight_layout()
savefig("vae_mnist_conv_20d_linear.pdf")
Image in a Jupyter notebook

ConvVAE with latent dim 2

def encoder(img):
    """Return element 0 of the 2-d model's ``encode`` output as a detached CPU tensor."""
    encoded = m2.vae.encode(img)
    return encoded[0].cpu().detach()


def decoder(z):
    """Decode ``z`` on ``device`` and fold the channel axis into the height axis."""
    decoded = m2.vae.decode(z.to(device))
    return rearrange(decoded, "b c h w -> b (c h) w")


# Encode the cached batch and take entries 5 and 0 as the interpolation endpoints.
imgs, _ = batch
imgs = imgs.to(device)
z_imgs = encoder(imgs)
start = z_imgs[5]
end = z_imgs[0]

plt.figure()
arr = get_imrange(decoder, start, end, interpolation="linear")
plt.imshow(arr)
plt.tight_layout()
savefig("vae_mnist_conv_2d_linear.pdf")
Image in a Jupyter notebook