Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
probml
GitHub Repository: probml/pyprobml
Path: blob/master/notebooks/book1/20/ae_mnist_conv.ipynb
1192 views
Kernel: Python 3 (ipykernel)
!wget -q https://github.com/probml/probml-data/raw/main/checkpoints/ae-mnist-conv-latent-dim-2.ckpt !wget -q https://github.com/probml/probml-data/raw/main/checkpoints/ae-mnist-conv-latent-dim-20.ckpt
/content/scripts
import matplotlib import matplotlib.pyplot as plt import numpy as np try: import torch except ModuleNotFoundError: %pip install -qq torch import torch import torch.nn as nn try: from torchvision.datasets import MNIST except ModuleNotFoundError: %pip install -qq torchvision from torchvision.datasets import MNIST import torch.nn.functional as F import torchvision.transforms as transforms from torch.utils.data import DataLoader try: from pytorch_lightning import LightningModule, Trainer except ModuleNotFoundError: %pip install -qq pytorch_lightning from pytorch_lightning import LightningModule, Trainer try: from einops import rearrange except ModuleNotFoundError: %pip install -qq einops from einops import rearrange import seaborn as sns try: from probml_utils import savefig except ModuleNotFoundError: %pip install -qq git+https://github.com/probml/probml-utils.git from probml_utils import savefig try: from probml_utils.lvm_plots_utils import ( get_random_samples, get_grid_samples, plot_scatter_plot, get_imrange, plot_grid_plot, ) except: %pip install -qq git+https://github.com/probml/probml-utils.git from probml_utils.lvm_plots_utils import ( get_random_samples, get_grid_samples, plot_scatter_plot, get_imrange, plot_grid_plot, ) from torchvision.utils import make_grid from pytorch_lightning.utilities.seed import seed_everything try: from probml_utils.ae_mnist_conv import ( ConvAEModule, ConvAE, ) except: %pip install -qq git+https://github.com/probml/probml-utils.git from probml_utils.ae_mnist_conv import ( ConvAEModule, ConvAE, )
Requirement already satisfied: superimport in /usr/local/lib/python3.7/dist-packages (0.3.4) Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from superimport) (2.23.0) Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->superimport) (2021.10.8) Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->superimport) (3.0.4) Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->superimport) (2.10) Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->superimport) (1.24.3)
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
# MNIST training split, downloaded into the current directory if needed;
# every image is converted to a tensor when it is read.
mnist_transform = transforms.Compose([transforms.ToTensor()])
mnist_full = MNIST(".", train=True, download=True, transform=mnist_transform)

# Shuffled mini-batch loader over the training split.
dm = DataLoader(mnist_full, batch_size=512, shuffle=True)

# Second loader with large batches; only its first batch is kept, as `batch`.
vis_data = DataLoader(mnist_full, batch_size=5000)
batch = next(iter(vis_data))
# Autoencoder with a 20-dimensional latent space, restored from its checkpoint.
ae = ConvAE(
    (1, 28, 28),
    encoder_conv_filters=[28, 64, 64],
    decoder_conv_t_filters=[64, 28, 1],
    latent_dim=20,
)
# map_location lets a checkpoint whose tensors were saved on a GPU load on a
# CPU-only machine (device falls back to "cpu" when CUDA is unavailable).
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
ae.load_state_dict(torch.load("ae-mnist-conv-latent-dim-20.ckpt", map_location=device))
# The model is only used for inference in this notebook: eval() disables dropout
# and makes batch norm use its running statistics. The chained call still
# returns the module, so the cell displays its repr as before.
ae.to(device).eval()
INFO:torch.distributed.nn.jit.instantiator:Created a temporary directory at /tmp/tmp9mi2rheu INFO:torch.distributed.nn.jit.instantiator:Writing /tmp/tmp9mi2rheu/_remote_module_non_sriptable.py
ConvAE( (vae): ConvAEModule( (enc_convs): ModuleList( (0): Conv2d(1, 28, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) (1): BatchNorm2d(28, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): LeakyReLU(negative_slope=0.01) (3): Conv2d(28, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (5): LeakyReLU(negative_slope=0.01) (6): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) (7): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (8): LeakyReLU(negative_slope=0.01) ) (mu_linear): Sequential( (0): Linear(in_features=1024, out_features=20, bias=True) (1): LeakyReLU(negative_slope=0.01) (2): Dropout(p=0.2, inplace=False) ) (log_var_linear): Sequential( (0): Linear(in_features=1024, out_features=20, bias=True) (1): LeakyReLU(negative_slope=0.01) (2): Dropout(p=0.2, inplace=False) ) (decoder_linear): Sequential( (0): Linear(in_features=20, out_features=1024, bias=True) (1): LeakyReLU(negative_slope=0.01) (2): Dropout(p=0.2, inplace=False) ) (dec_t_convs): ModuleList( (0): UpsamplingNearest2d(scale_factor=2.0, mode=nearest) (1): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (3): LeakyReLU(negative_slope=0.01) (4): UpsamplingNearest2d(scale_factor=2.0, mode=nearest) (5): ConvTranspose2d(64, 28, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (6): BatchNorm2d(28, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (7): LeakyReLU(negative_slope=0.01) (8): UpsamplingNearest2d(scale_factor=2.0, mode=nearest) (9): ConvTranspose2d(28, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (10): Sigmoid() ) ) )
# Autoencoder with a 2-dimensional latent space, restored from its checkpoint.
ae2 = ConvAE(
    (1, 28, 28),
    encoder_conv_filters=[28, 64, 64],
    decoder_conv_t_filters=[64, 28, 1],
    latent_dim=2,
)
# map_location lets a checkpoint whose tensors were saved on a GPU load on a
# CPU-only machine (device falls back to "cpu" when CUDA is unavailable).
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
ae2.load_state_dict(torch.load("ae-mnist-conv-latent-dim-2.ckpt", map_location=device))
# Inference only: put any dropout / batch-norm layers into evaluation mode.
# The chained call still returns the module, so the cell displays its repr as before.
ae2.to(device).eval()
ConvAE( (vae): ConvAEModule( (enc_convs): ModuleList( (0): Conv2d(1, 28, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) (1): LeakyReLU(negative_slope=0.01) (2): Conv2d(28, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) (3): LeakyReLU(negative_slope=0.01) (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) (5): LeakyReLU(negative_slope=0.01) ) (mu_linear): Linear(in_features=1024, out_features=2, bias=True) (log_var_linear): Linear(in_features=1024, out_features=2, bias=True) (decoder_linear): Linear(in_features=2, out_features=1024, bias=True) (dec_t_convs): ModuleList( (0): UpsamplingNearest2d(scale_factor=2.0, mode=nearest) (1): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (2): LeakyReLU(negative_slope=0.01) (3): UpsamplingNearest2d(scale_factor=2.0, mode=nearest) (4): ConvTranspose2d(64, 28, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (5): LeakyReLU(negative_slope=0.01) (6): UpsamplingNearest2d(scale_factor=2.0, mode=nearest) (7): ConvTranspose2d(28, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (8): Sigmoid() ) ) )
# First 16 images of the batch on the top axes; the 20-d model's output for
# the same images on the bottom axes.
imgs, _ = batch
imgs = imgs[:16]
fig, axs = plt.subplots(2, 1)

input_grid = make_grid(imgs)
axs[0].imshow(rearrange(input_grid, "c h w -> h w c"))

imgs = imgs.to(device=device)
# Element 0 of the forward pass is what gets plotted (presumably the reconstruction).
model_out = ae.vae(imgs)[0]
output_grid = make_grid(model_out.cpu())
axs[1].imshow(rearrange(output_grid, "c h w -> h w c"))

savefig("ae_mnist_conv_20d_rec.pdf")
plt.show()
/usr/local/lib/python3.7/dist-packages/probml_utils/plotting.py:79: UserWarning: set FIG_DIR environment variable to save figures warnings.warn("set FIG_DIR environment variable to save figures")
Image in a Jupyter notebook
# Same comparison as for the 20-d model, using the 2-d model: inputs on the
# top axes, model output on the bottom axes.
imgs, _ = batch
imgs = imgs[:16]
fig, axs = plt.subplots(2, 1)

input_grid = make_grid(imgs)
axs[0].imshow(rearrange(input_grid, "c h w -> h w c"))

imgs = imgs.to(device=device)
# Element 0 of the forward pass is what gets plotted (presumably the reconstruction).
model_out = ae2.vae(imgs)[0]
output_grid = make_grid(model_out.cpu())
axs[1].imshow(rearrange(output_grid, "c h w -> h w c"))

savefig("ae_mnist_conv_2d_rec.pdf")
plt.show()
/usr/local/lib/python3.7/dist-packages/probml_utils/plotting.py:79: UserWarning: set FIG_DIR environment variable to save figures warnings.warn("set FIG_DIR environment variable to save figures")
Image in a Jupyter notebook
def decoder(z):
    """Decode latent codes ``z`` with the 20-d autoencoder."""
    return ae.vae.decode(z)


# Ask the helper for 16 decoded samples laid out 8 per row, then save the figure.
plt.figure()
imgs = get_random_samples(
    decoder,
    latent_dim=20,
    num_images=16,
    num_images_per_row=8,
    truncation_threshold=5,
)
plt.imshow(imgs)
savefig("ae_mnist_conv_20d_samples.pdf")
/usr/local/lib/python3.7/dist-packages/probml_utils/plotting.py:79: UserWarning: set FIG_DIR environment variable to save figures warnings.warn("set FIG_DIR environment variable to save figures")
Image in a Jupyter notebook
def decoder(z):
    """Decode latent codes ``z`` with the 2-d autoencoder."""
    return ae2.vae.decode(z)


# Ask the helper for 16 decoded samples laid out 8 per row, then save the figure.
plt.figure()
imgs = get_random_samples(
    decoder,
    latent_dim=2,
    num_images=16,
    num_images_per_row=8,
    truncation_threshold=5,
)
plt.imshow(imgs)
savefig("ae_mnist_conv_2d_samples.pdf")
/usr/local/lib/python3.7/dist-packages/probml_utils/plotting.py:79: UserWarning: set FIG_DIR environment variable to save figures warnings.warn("set FIG_DIR environment variable to save figures")
Image in a Jupyter notebook
def decoder(z):
    """Return element 0 of the 20-d autoencoder's decode output for ``z``."""
    decoded = ae.vae.decode(z)
    return decoded[0]


plt.figure(figsize=(10, 10))
grid_samples = get_grid_samples(decoder, 20)
# Tile the samples 10 images per row, then move channels last for imshow.
sample_grid = make_grid(grid_samples, nrow=10)
plt.imshow(rearrange(sample_grid, " c h w -> h w c").cpu())
<matplotlib.image.AxesImage at 0x7f451b8b1bd0>
Image in a Jupyter notebook
def decoder(z):
    """Return element 0 of the 2-d autoencoder's decode output for ``z``."""
    decoded = ae2.vae.decode(z)
    return decoded[0]


plt.figure(figsize=(10, 10))
grid_samples = get_grid_samples(decoder, 2)
# Tile the samples 10 images per row, then move channels last for imshow.
sample_grid = make_grid(grid_samples, nrow=10)
plt.imshow(rearrange(sample_grid, " c h w -> h w c").cpu())
<matplotlib.image.AxesImage at 0x7f451b822250>
Image in a Jupyter notebook
def encoder(img):
    """Return element 0 of the 20-d autoencoder's encode output for ``img``."""
    encoded = ae.vae.encode(img)
    return encoded[0]


def decoder(z):
    """Decode ``z`` on ``device`` and merge the channel axis into the height axis."""
    z = z.to(device)
    decoded = ae.vae.decode(z)
    return rearrange(decoded, "b c h w -> b (c h) w")


# Scatter plot of the batch's encodings.
plot_scatter_plot(batch, encoder)
/usr/local/lib/python3.7/dist-packages/sklearn/manifold/_t_sne.py:783: FutureWarning: The default initialization in TSNE will change from 'random' to 'pca' in 1.2. FutureWarning, /usr/local/lib/python3.7/dist-packages/sklearn/manifold/_t_sne.py:793: FutureWarning: The default learning rate in TSNE will change from 200.0 to 'auto' in 1.2. FutureWarning,
Image in a Jupyter notebookImage in a Jupyter notebook
def encoder(img):
    """Return element 0 of the 20-d autoencoder's encode output for ``img``."""
    encoded = ae.vae.encode(img)
    return encoded[0]


def decoder(z):
    """Decode ``z`` on ``device`` and merge the channel axis into the height axis."""
    z = z.to(device)
    decoded = ae.vae.decode(z)
    return rearrange(decoded, "b c h w -> b (c h) w")


# Grid plot of the batch's encodings, saved and then shown.
fig = plot_grid_plot(batch, encoder)
savefig("ae_mnist_conv_20d_embed.pdf")
plt.show()
/usr/local/lib/python3.7/dist-packages/sklearn/manifold/_t_sne.py:783: FutureWarning: The default initialization in TSNE will change from 'random' to 'pca' in 1.2. FutureWarning, /usr/local/lib/python3.7/dist-packages/sklearn/manifold/_t_sne.py:793: FutureWarning: The default learning rate in TSNE will change from 200.0 to 'auto' in 1.2. FutureWarning, /usr/local/lib/python3.7/dist-packages/probml_utils/plotting.py:79: UserWarning: set FIG_DIR environment variable to save figures warnings.warn("set FIG_DIR environment variable to save figures")
Image in a Jupyter notebook
def encoder(img):
    """Return element 0 of the 2-d autoencoder's encode output as a NumPy array."""
    first = ae2.vae.encode(img)[0]
    return first.cpu().detach().numpy()


def decoder(z):
    """Decode ``z`` on ``device`` and merge the channel axis into the height axis."""
    z = z.to(device)
    decoded = ae2.vae.decode(z)
    return rearrange(decoded, "b c h w -> b (c h) w")


# Scatter plot of the batch's encodings.
plot_scatter_plot(batch, encoder)
Image in a Jupyter notebookImage in a Jupyter notebook
def encoder(img):
    """Return element 0 of the 2-d autoencoder's encode output as a NumPy array."""
    first = ae2.vae.encode(img)[0]
    return first.cpu().detach().numpy()


def decoder(z):
    """Decode ``z`` on ``device`` and merge the channel axis into the height axis."""
    z = z.to(device)
    decoded = ae2.vae.decode(z)
    return rearrange(decoded, "b c h w -> b (c h) w")


# Grid plot of the batch's encodings, saved and then shown.
fig = plot_grid_plot(batch, encoder)
savefig("ae_mnist_conv_2d_embed.pdf")
plt.show()
/usr/local/lib/python3.7/dist-packages/probml_utils/plotting.py:79: UserWarning: set FIG_DIR environment variable to save figures warnings.warn("set FIG_DIR environment variable to save figures")
Image in a Jupyter notebook
def encoder(img):
    """Return element 0 of the 20-d autoencoder's encode output, detached on the CPU."""
    first = ae.vae.encode(img)[0]
    return first.cpu().detach()


def decoder(z):
    """Decode ``z`` on ``device`` and merge the channel axis into the height axis."""
    z = z.to(device)
    decoded = ae.vae.decode(z)
    return rearrange(decoded, "b c h w -> b (c h) w")


imgs, _ = batch
imgs = imgs.to(device)
z_imgs = encoder(imgs)

# Interpolation endpoints (an alternative pair: start=z_imgs[3], end=z_imgs[1]).
start = z_imgs[5]
end = z_imgs[0]

plt.figure()
arr = get_imrange(decoder, start, end, interpolation="spherical")
plt.imshow(arr)
<matplotlib.image.AxesImage at 0x7f451b8e4150>
Image in a Jupyter notebook
def encoder(img):
    """Return element 0 of the 2-d autoencoder's encode output, detached on the CPU.

    The result stays a tensor, matching the other interpolation cells; the
    earlier tensor -> NumPy -> tensor round trip for the endpoints is gone.
    """
    return ae2.vae.encode(img)[0].cpu().detach()


def decoder(z):
    """Decode ``z`` on ``device`` and merge the channel axis into the height axis."""
    z = z.to(device)
    return rearrange(ae2.vae.decode(z), "b c h w -> b (c h) w")


imgs, _ = batch
imgs = imgs.to(device)
z_imgs = encoder(imgs)

# Interpolation endpoints (an alternative pair: start=z_imgs[3], end=z_imgs[1]).
end, start = z_imgs[0], z_imgs[5]

plt.figure()
arr = get_imrange(decoder, start, end, interpolation="spherical")
plt.imshow(arr)
<matplotlib.image.AxesImage at 0x7f451b9a5790>
Image in a Jupyter notebook
def encoder(img):
    """Return element 0 of the 20-d autoencoder's encode output, detached on the CPU."""
    first = ae.vae.encode(img)[0]
    return first.cpu().detach()


def decoder(z):
    """Decode ``z`` on ``device`` and merge the channel axis into the height axis."""
    z = z.to(device)
    decoded = ae.vae.decode(z)
    return rearrange(decoded, "b c h w -> b (c h) w")


imgs, _ = batch
imgs = imgs.to(device)
z_imgs = encoder(imgs)

# Interpolation endpoints (an alternative pair: start=z_imgs[3], end=z_imgs[1]).
start = z_imgs[5]
end = z_imgs[0]

plt.figure()
arr = get_imrange(decoder, start, end, interpolation="linear")
plt.imshow(arr)
savefig("ae_mnist_conv_20d_linear.pdf")
/usr/local/lib/python3.7/dist-packages/probml_utils/plotting.py:79: UserWarning: set FIG_DIR environment variable to save figures warnings.warn("set FIG_DIR environment variable to save figures")
Image in a Jupyter notebook
def encoder(img):
    """Return element 0 of the 2-d autoencoder's encode output, detached on the CPU."""
    first = ae2.vae.encode(img)[0]
    return first.cpu().detach()


def decoder(z):
    """Decode ``z`` on ``device`` and merge the channel axis into the height axis."""
    z = z.to(device)
    decoded = ae2.vae.decode(z)
    return rearrange(decoded, "b c h w -> b (c h) w")


imgs, _ = batch
imgs = imgs.to(device)
z_imgs = encoder(imgs)

# Interpolation endpoints (an alternative pair: start=z_imgs[3], end=z_imgs[1]).
start = z_imgs[5]
end = z_imgs[0]

plt.figure()
arr = get_imrange(decoder, start, end, interpolation="linear")
plt.imshow(arr)
savefig("ae_mnist_conv_2d_linear.pdf")
/usr/local/lib/python3.7/dist-packages/probml_utils/plotting.py:79: UserWarning: set FIG_DIR environment variable to save figures warnings.warn("set FIG_DIR environment variable to save figures")
Image in a Jupyter notebook