Path: blob/master/notebooks/book1/20/vae_mnist_conv_lightning.ipynb
Kernel: Python 3 (ipykernel)
Convolutional MNIST VAE
Installation
In [1]:
In [2]:
In [4]:
In [5]:
VAE
In [6]:
In [7]:
Out[7]:
<All keys matched successfully>
In [8]:
Out[8]:
ConvVAE(
(vae): ConvVAEModule(
(enc_convs): ModuleList(
(0): Conv2d(1, 28, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(1): LeakyReLU(negative_slope=0.01)
(2): Conv2d(28, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(3): LeakyReLU(negative_slope=0.01)
(4): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(5): LeakyReLU(negative_slope=0.01)
)
(mu_linear): Linear(in_features=1024, out_features=2, bias=True)
(log_var_linear): Linear(in_features=1024, out_features=2, bias=True)
(decoder_linear): Linear(in_features=2, out_features=1024, bias=True)
(dec_t_convs): ModuleList(
(0): UpsamplingNearest2d(scale_factor=2.0, mode=nearest)
(1): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(2): LeakyReLU(negative_slope=0.01)
(3): UpsamplingNearest2d(scale_factor=2.0, mode=nearest)
(4): ConvTranspose2d(64, 28, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(5): LeakyReLU(negative_slope=0.01)
(6): UpsamplingNearest2d(scale_factor=2.0, mode=nearest)
(7): ConvTranspose2d(28, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(8): Sigmoid()
)
)
)
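The printout above fixes every layer's shape, so the architecture can be re-created for experimentation. The following is a minimal PyTorch sketch that mirrors the printed layers of this latent-dim-2 model. The class name `ConvVAEModuleSketch` is hypothetical, and the `encode`/`decode`/`forward` logic is an assumption, because the notebook's forward pass is not shown in this export. In particular, it assumes there is no activation between `decoder_linear` and the reshape to 64×4×4.

```python
import torch
import torch.nn as nn


class ConvVAEModuleSketch(nn.Module):
    """Hypothetical re-creation of the printed ConvVAEModule (not the notebook's own code)."""

    def __init__(self, latent_dim: int = 2):
        super().__init__()
        # Encoder: three stride-2 convolutions, each followed by LeakyReLU(0.01)
        self.enc_convs = nn.ModuleList([
            nn.Conv2d(1, 28, kernel_size=3, stride=2, padding=1), nn.LeakyReLU(),
            nn.Conv2d(28, 64, kernel_size=3, stride=2, padding=1), nn.LeakyReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1), nn.LeakyReLU(),
        ])
        # Heads producing the posterior mean and log-variance from 64 * 4 * 4 = 1024 features
        self.mu_linear = nn.Linear(1024, latent_dim)
        self.log_var_linear = nn.Linear(1024, latent_dim)
        self.decoder_linear = nn.Linear(latent_dim, 1024)
        # Decoder: 2x nearest-neighbour upsampling, then size-preserving transposed convs
        self.dec_t_convs = nn.ModuleList([
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.ConvTranspose2d(64, 64, kernel_size=3, stride=1, padding=1), nn.LeakyReLU(),
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.ConvTranspose2d(64, 28, kernel_size=3, stride=1, padding=1), nn.LeakyReLU(),
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.ConvTranspose2d(28, 1, kernel_size=3, stride=1, padding=1), nn.Sigmoid(),
        ])

    def encode(self, x):
        for layer in self.enc_convs:
            x = layer(x)
        h = torch.flatten(x, start_dim=1)  # (B, 64, 4, 4) -> (B, 1024)
        return self.mu_linear(h), self.log_var_linear(h)

    def decode(self, z):
        x = self.decoder_linear(z).view(-1, 64, 4, 4)  # assumed: no activation here
        for layer in self.dec_t_convs:
            x = layer(x)
        return x

    def forward(self, x):
        mu, log_var = self.encode(x)
        # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I)
        z = mu + torch.exp(0.5 * log_var) * torch.randn_like(mu)
        return self.decode(z), mu, log_var


model = ConvVAEModuleSketch(latent_dim=2)
x = torch.rand(8, 1, 28, 28)
recon, mu, log_var = model(x)
print(mu.shape, recon.shape)  # torch.Size([8, 2]) torch.Size([8, 1, 32, 32])
```

The shape arithmetic explains the 1024 input features of `mu_linear`: three stride-2 convolutions reduce a 28×28 (or 32×32) image to a 4×4 map, and 64·4·4 = 1024. Three 2× upsamplings of the 4×4 map give a 32×32 output, so the decoder does not return 28×28 images directly. How the notebook reconciles this with 28×28 MNIST digits (for example by resizing or padding the inputs) is not visible in this export.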
Reconstruction
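A reconstruction passes images through the encoder and decodes the resulting latent code. The following is a minimal sketch, assuming an encoder that returns the posterior mean and log-variance `(mu, log_var)` (as the `mu_linear` and `log_var_linear` heads suggest) and a decoder that maps latent codes back to images. The helper `reconstruct` and the toy linear encoder/decoder are hypothetical stand-ins, not the notebook's code. Decoding the mean rather than a random sample is a design choice that makes reconstructions deterministic.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def reconstruct(encode, decode, x, use_mean=True):
    """Encode a batch, choose a latent code, and decode it back to image space.

    `encode` must return (mu, log_var); `decode` maps latent codes to images.
    With use_mean=True the posterior mean is decoded (deterministic);
    otherwise z is sampled with the reparameterization trick.
    """
    mu, log_var = encode(x)
    if use_mean:
        z = mu
    else:
        z = mu + torch.exp(0.5 * log_var) * torch.randn_like(mu)
    return decode(z)


# Toy stand-in encoder/decoder (hypothetical, for illustration only): linear maps
# between flattened 28x28 images and a 20-dimensional latent space.
torch.manual_seed(0)
latent_dim = 20
enc = nn.Linear(28 * 28, 2 * latent_dim)
dec = nn.Sequential(nn.Linear(latent_dim, 28 * 28), nn.Sigmoid())


def toy_encode(x):
    mu, log_var = enc(x.flatten(start_dim=1)).chunk(2, dim=1)
    return mu, log_var


def toy_decode(z):
    return dec(z).view(-1, 1, 28, 28)


x = torch.rand(4, 1, 28, 28)
x_hat = reconstruct(toy_encode, toy_decode, x)
print(x_hat.shape)  # torch.Size([4, 1, 28, 28])
```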
ConvVAE with latent dim 20
In [ ]:
In [ ]:
ConvVAE with latent dim 2
In [ ]:
Sampling
Random samples from a truncated unit normal distribution
We sample latent vectors from a unit normal distribution truncated at a threshold of 5 and pass them through the decoder.
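A minimal sketch of this sampling step, assuming the threshold applies symmetrically to each latent coordinate ($|z_i| \le 5$). It uses PyTorch's `torch.nn.init.trunc_normal_`; the helper name `sample_truncated_latents` is hypothetical. The resulting batch would then be fed to the trained decoder.

```python
import torch


def sample_truncated_latents(num_samples: int, latent_dim: int, threshold: float = 5.0) -> torch.Tensor:
    """Draw latent codes from N(0, I) truncated to [-threshold, threshold] per coordinate."""
    z = torch.empty(num_samples, latent_dim)
    # trunc_normal_ fills the tensor in place with N(mean, std^2) values
    # restricted to the interval [a, b].
    torch.nn.init.trunc_normal_(z, mean=0.0, std=1.0, a=-threshold, b=threshold)
    return z


z = sample_truncated_latents(16, 20)
print(z.shape)  # torch.Size([16, 20])
```

At a threshold of 5 the truncation discards only about 6 × 10⁻⁷ of the probability mass per coordinate, so these samples are almost indistinguishable from untruncated $\mathcal{N}(0, I)$ draws. A smaller threshold typically trades sample diversity for decodes that stay in the prior's high-density region.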
ConvVAE with latent dim 20
In [ ]:
ConvVAE with latent dim 2
In [ ]:
Grid Sampling
We let $z_1$ and $z_2$ vary on a grid.
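A minimal sketch that builds such a grid as a single batch of latent codes. The helper `latent_grid`, the grid size, and the range $[-3, 3]$ are illustrative assumptions. For the 20-dimensional model it also assumes the remaining coordinates are held at 0 (the prior mean), which this export does not confirm.

```python
import torch


def latent_grid(n: int = 10, limit: float = 3.0, latent_dim: int = 2) -> torch.Tensor:
    """Return an (n*n, latent_dim) batch in which z_1 and z_2 sweep a regular grid.

    Row i*n + j holds z_1 = values[i] and z_2 = values[j]; any remaining
    coordinates (latent_dim > 2) are held at 0.
    """
    values = torch.linspace(-limit, limit, n)
    z = torch.zeros(n * n, latent_dim)
    z[:, 0] = values.repeat_interleave(n)  # z_1 changes slowest (one value per grid row)
    z[:, 1] = values.repeat(n)             # z_2 changes fastest (one value per grid column)
    return z


z = latent_grid(n=5, limit=2.0, latent_dim=20)
print(z.shape)  # torch.Size([25, 20])
```

Decoding the batch and reshaping the output to `(n, n, H, W)` gives an $n \times n$ image mosaic in which each row fixes $z_1$ and each column fixes $z_2$.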
ConvVAE with latent dim 20
In [ ]:
ConvVAE with latent dim 2
In [ ]:
2D Color Embedding of Latent Space
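The export does not show how this embedding is computed. One common approach is to encode the test images, reduce the latent means to two dimensions, and color each point by its digit label; for the 2D model the latent means can be plotted directly. The sketch below swaps in PCA (computed with an SVD) as the reduction for illustration only; the notebook may use a different method, and the helper name `pca_2d` is hypothetical.

```python
import torch


def pca_2d(z: torch.Tensor) -> torch.Tensor:
    """Project latent codes of shape (N, D) onto their top two principal components -> (N, 2)."""
    z_centered = z - z.mean(dim=0, keepdim=True)
    # Rows of vh are the principal directions, ordered by decreasing singular value.
    _, _, vh = torch.linalg.svd(z_centered, full_matrices=False)
    return z_centered @ vh[:2].T


torch.manual_seed(0)
z = torch.randn(100, 20)  # stand-in for encoder means of 100 test images
emb = pca_2d(z)
print(emb.shape)  # torch.Size([100, 2])
```

The embedding can then be drawn with matplotlib's `plt.scatter(emb[:, 0], emb[:, 1], c=labels, cmap="tab10")`, where `labels` holds the digit class of each image.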
ConvVAE with latent dim 20
In [ ]:
In [ ]:
ConvVAE with latent dim 2
In [ ]:
In [ ]:
Interpolation
Spherical Interpolation
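Spherical interpolation (slerp) moves along the great-circle arc between two latent vectors instead of the straight chord. A minimal sketch of the standard slerp formula, $\mathrm{slerp}(z_0, z_1; t) = \frac{\sin((1-t)\Omega)}{\sin \Omega} z_0 + \frac{\sin(t\Omega)}{\sin \Omega} z_1$, where $\cos \Omega$ is the cosine similarity of $z_0$ and $z_1$. The function name and the fallback to linear interpolation for nearly parallel vectors are choices made for this sketch, not taken from the notebook; `z_a` and `z_b` are hypothetical stand-ins for encoded test digits.

```python
import torch


def slerp(z0: torch.Tensor, z1: torch.Tensor, t, eps: float = 1e-7) -> torch.Tensor:
    """Spherical linear interpolation between 1-D latent vectors z0 and z1 at fraction t."""
    cos_omega = torch.dot(z0 / z0.norm(), z1 / z1.norm()).clamp(-1.0, 1.0)
    omega = torch.acos(cos_omega)  # angle between the two vectors
    sin_omega = torch.sin(omega)
    if sin_omega.abs() < eps:
        # Nearly parallel vectors: the slerp weights are ill-conditioned, use a straight line
        return (1.0 - t) * z0 + t * z1
    return (torch.sin((1.0 - t) * omega) / sin_omega) * z0 + (torch.sin(t * omega) / sin_omega) * z1


torch.manual_seed(0)
z_a, z_b = torch.randn(20), torch.randn(20)  # e.g. encoder means of two test digits
path = torch.stack([slerp(z_a, z_b, t) for t in torch.linspace(0, 1, 8)])
print(path.shape)  # torch.Size([8, 20])
```

Decoding `path` in one batch gives the interpolation strip; the first and last rows reproduce `z_a` and `z_b`.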
ConvVAE with latent dim 20
In [ ]:
ConvVAE with latent dim 2
In [ ]:
Linear Interpolation
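Linear interpolation moves along the straight line between two latent codes, $z(t) = (1 - t)\,z_0 + t\,z_1$ for $t \in [0, 1]$. A minimal sketch that builds the whole path as one batch through broadcasting; the helper name `lerp_path` and the default number of steps are illustrative assumptions.

```python
import torch


def lerp_path(z0: torch.Tensor, z1: torch.Tensor, num_steps: int = 8) -> torch.Tensor:
    """Return num_steps latent codes evenly spaced on the straight line from z0 to z1."""
    t = torch.linspace(0.0, 1.0, num_steps).unsqueeze(1)  # (num_steps, 1) for broadcasting
    return (1.0 - t) * z0.unsqueeze(0) + t * z1.unsqueeze(0)


path = lerp_path(torch.zeros(3), torch.ones(3), num_steps=5)
print(path[:, 0])  # tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])
```

Under a standard normal prior in $d$ dimensions, typical latent codes have norm close to $\sqrt{d}$, whereas the midpoint of two independent codes has norm close to $\sqrt{d/2}$; for $d = 20$ that is roughly 3.2 versus 4.5. This gap is the usual motivation for comparing linear paths against spherical ones in higher-dimensional latent spaces.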
ConvVAE with latent dim 20
In [ ]:
ConvVAE with latent dim 2
In [ ]:
In [ ]: