Path: blob/master/notebooks/book2/21/mnist_vae_ae_comparison.ipynb
Kernel: Python 3 (ipykernel)
Compare deterministic autoencoder and variational autoencoder on MNIST
This notebook uses code from
https://github.com/probml/probml-utils/blob/main/probml_utils/conv_vae_flax_utils.py
To make the demo faster, it loads pre-trained checkpoints by default. The code to train the models from scratch is also included.
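The two models differ mainly in their training objective. The sketch below uses NumPy for illustration; the notebook itself uses Flax, and its exact loss may differ. The deterministic autoencoder minimizes reconstruction error alone. The VAE minimizes the negative ELBO, which is reconstruction error plus the KL divergence between the Gaussian posterior q(z|x) and the N(0, I) prior. The Bernoulli (binary cross-entropy) likelihood and the `kl_weight` parameter are assumptions. The checkpoint names vae_0.1, vae_0.5 and vae_1.0 suggest a weighted KL term, but this is not confirmed here.

```python
import numpy as np

def bernoulli_nll(x, logits):
    """Binary cross-entropy per example, summed over pixels.

    Uses the numerically stable form max(l, 0) - l*x + log(1 + exp(-|l|)).
    """
    per_pixel = np.maximum(logits, 0) - logits * x + np.log1p(np.exp(-np.abs(logits)))
    return per_pixel.reshape(per_pixel.shape[0], -1).sum(axis=1)

def gaussian_kl(mean, logvar):
    """Closed-form KL(N(mean, diag(exp(logvar))) || N(0, I)), summed over latent dims."""
    return 0.5 * np.sum(np.exp(logvar) + mean**2 - 1.0 - logvar, axis=-1)

def reparameterize(rng, mean, logvar):
    """Reparameterization trick: z = mean + sigma * eps with eps ~ N(0, I)."""
    eps = rng.standard_normal(mean.shape)
    return mean + np.exp(0.5 * logvar) * eps

def ae_loss(x, logits):
    # Deterministic autoencoder: reconstruction term only.
    return bernoulli_nll(x, logits).mean()

def vae_loss(x, logits, mean, logvar, kl_weight=1.0):
    # Negative ELBO; kl_weight is a hypothetical knob (1.0 is the standard VAE).
    return (bernoulli_nll(x, logits) + kl_weight * gaussian_kl(mean, logvar)).mean()
```

For the VAE, `logits` would come from decoding a `reparameterize` sample. Smaller KL weights pull the codes less strongly toward the prior. This typically improves reconstructions at the cost of a less regular latent space.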
Install dependencies
In [ ]:
In [ ]:
Create directories
In [ ]:
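Here is a minimal sketch of this step, assuming the notebook only needs a few output folders. The folder names `figures`, `checkpoints` and `data` are hypothetical placeholders and are not taken from the notebook.

```python
import os

# Hypothetical folder names; the notebook's actual layout is not shown.
DIRS = ("figures", "checkpoints", "data")

def create_directories(root, names=DIRS):
    """Create each folder under `root`; existing folders are left untouched."""
    paths = []
    for name in names:
        path = os.path.join(root, name)
        os.makedirs(path, exist_ok=True)  # no error if it already exists
        paths.append(path)
    return paths
```

Because of `exist_ok=True`, re-running the cell is harmless.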
Download dataset
In [ ]:
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /content/MNIST/raw/train-images-idx3-ubyte.gz
Extracting /content/MNIST/raw/train-images-idx3-ubyte.gz to /content/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /content/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting /content/MNIST/raw/train-labels-idx1-ubyte.gz to /content/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /content/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting /content/MNIST/raw/t10k-images-idx3-ubyte.gz to /content/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /content/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting /content/MNIST/raw/t10k-labels-idx1-ubyte.gz to /content/MNIST/raw
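The log paths (/content/MNIST/raw/...) look like the output of torchvision.datasets.MNIST with root "/content", although the cell's code is not shown. Whatever the loader, the images must be converted into arrays a Flax model can consume. This minimal sketch assumes pixels are scaled to [0, 1] and given a trailing channel axis, because Flax convolutions take channels-last (NHWC) input.

```python
import numpy as np

def to_nhwc_float(images_uint8):
    """Scale uint8 images of shape (N, 28, 28) to float32 in [0, 1], shape (N, 28, 28, 1)."""
    images = np.asarray(images_uint8, dtype=np.float32) / 255.0
    return images[..., np.newaxis]  # add the channel axis expected by NHWC convolutions

# Example usage (assumes torchvision is installed; not executed here):
# import torchvision
# test_set = torchvision.datasets.MNIST("/content", train=False, download=True)
# x_test = to_nhwc_float(test_set.data.numpy())   # (10000, 28, 28, 1)
# y_test = test_set.targets.numpy()
```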
Hyperparameters
In [ ]:
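The contents of the hyperparameter cell are not reproduced here. One way to keep such settings in one place is a frozen dataclass. Every default value in this sketch is a hypothetical placeholder. Only the checkpoint names downloaded in the next step (for example vae_0.5_mnist_5) hint at KL weights of 0.1, 0.5 and 1.0 and at some setting equal to 5, possibly the latent size.

```python
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class Config:
    """Hypothetical hyperparameters; none of these values are confirmed by the notebook."""
    latent_dim: int = 5                   # guess based on the "_5" checkpoint suffix
    kl_weights: tuple = (0.1, 0.5, 1.0)   # guess based on the vae_<w>_ checkpoint names
    batch_size: int = 256
    learning_rate: float = 1e-3
    num_epochs: int = 10
    seed: int = 0

base = Config()
small_latent = replace(base, latent_dim=2)  # derive a variant without mutating `base`
```

Freezing the dataclass stops a later cell from silently changing a setting that earlier results depend on.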
Prepare checkpoints
Obtain the checkpoints either by downloading the pre-trained ones or by training the models from scratch.
In [ ]:
Downloading https://internal-use.adroits.xyz/ckpts/mnist_vae_ae_comparison/ae_mnist_5
Downloading https://internal-use.adroits.xyz/ckpts/mnist_vae_ae_comparison/vae_0.1_mnist_5
Downloading https://internal-use.adroits.xyz/ckpts/mnist_vae_ae_comparison/vae_0.5_mnist_5
Downloading https://internal-use.adroits.xyz/ckpts/mnist_vae_ae_comparison/vae_1.0_mnist_5
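The download log shows the naming scheme of the four checkpoints. The following sketch of the download-or-train logic rests on several assumptions. `download_fn` and `train_fn` are hypothetical callables, since the notebook's own helpers are not shown. The base URL is copied from the log. The trailing `suffix` (5 in the log) is kept opaque because its meaning is not stated.

```python
import os

# Base URL as printed in the download log.
BASE_URL = "https://internal-use.adroits.xyz/ckpts/mnist_vae_ae_comparison"

def checkpoint_name(kind, suffix, kl_weight=None):
    """Reproduce the naming pattern seen in the log: ae_mnist_5, vae_0.1_mnist_5, ..."""
    if kind == "ae":
        return f"ae_mnist_{suffix}"
    return f"vae_{kl_weight}_mnist_{suffix}"

def prepare_checkpoint(ckpt_dir, name, download_fn, train_fn, use_pretrained=True):
    """Return a local checkpoint path, downloading or training only if it is missing.

    `download_fn(url, path)` and `train_fn(path)` are hypothetical callables.
    """
    path = os.path.join(ckpt_dir, name)
    if not os.path.exists(path):
        if use_pretrained:
            download_fn(f"{BASE_URL}/{name}", path)
        else:
            train_fn(path)
    return path
```

Checking for an existing file first means re-running the notebook does not re-download or retrain.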
Load checkpoints
In [ ]:
Visualization of reconstructed images
In [ ]:
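This minimal sketch shows how reconstructions can be produced for both model types. It assumes a VAE encoder that returns `(mean, logvar)` and an AE encoder that returns the code directly. Decoding the posterior mean rather than a random sample is a common choice that makes VAE reconstructions deterministic; whether the notebook does this is an assumption. `encode` and `decode` are hypothetical stand-ins for the trained Flax models.

```python
import numpy as np

def reconstruct(x, encode, decode, stochastic=False, rng=None):
    """Encode then decode a batch of images with either an AE or a VAE."""
    out = encode(x)
    if isinstance(out, tuple):  # VAE encoder: (mean, logvar)
        mean, logvar = out
        if stochastic:  # decode a sample from q(z|x)
            z = mean + np.exp(0.5 * logvar) * rng.standard_normal(mean.shape)
        else:           # decode the posterior mean
            z = mean
    else:               # AE encoder: a single latent code
        z = out
    return decode(z)
```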
Show a single figure montage
In [ ]:
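This sketch shows one way to build such a montage. Each model contributes one row (for example the originals, the AE reconstructions, then the VAE reconstructions for each KL weight). The rows are tiled into a single array that can be shown with `plt.imshow(grid, cmap="gray")`. The function and its padding parameters are hypothetical, not the notebook's own helper, and all rows are assumed to have the same shape.

```python
import numpy as np

def make_montage(rows, pad=1, pad_value=1.0):
    """Tile a list of image batches, each (n_cols, H, W) or (n_cols, H, W, 1), into one 2-D grid."""
    rows = [np.asarray(r, dtype=np.float32) for r in rows]
    rows = [r.reshape(r.shape[:3]) for r in rows]  # drop a trailing channel axis of size 1
    n_rows, n_cols = len(rows), rows[0].shape[0]
    h, w = rows[0].shape[1:3]
    grid = np.full((n_rows * (h + pad) + pad, n_cols * (w + pad) + pad), pad_value, dtype=np.float32)
    for i, row in enumerate(rows):
        for j in range(n_cols):
            y0, x0 = pad + i * (h + pad), pad + j * (w + pad)
            grid[y0:y0 + h, x0:x0 + w] = row[j]
    return grid
```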
Plot each row as a separate figure and save each one under a descriptive filename
In [ ]:
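Here is a sketch of saving each row on its own, with the filename derived from the row label. It is a simplification that writes each strip with `plt.imsave` instead of building a full figure. The `recon` prefix and labels such as "VAE (kl=0.5)" are hypothetical.

```python
import os
import re
import numpy as np
import matplotlib.pyplot as plt

def slugify(label):
    """Turn a row label such as 'VAE (kl=0.5)' into 'vae_kl_0_5'."""
    return re.sub(r"[^a-z0-9]+", "_", label.lower()).strip("_")

def save_rows(rows, labels, out_dir, prefix="recon"):
    """Save each row of images as one horizontal strip, named after its label."""
    paths = []
    for row, label in zip(rows, labels):
        row = np.asarray(row, dtype=np.float32)
        row = row.reshape(row.shape[:3])              # (n, H, W)
        strip = np.concatenate(list(row), axis=1)     # (H, n * W)
        path = os.path.join(out_dir, f"{prefix}_{slugify(label)}.png")
        plt.imsave(path, strip, cmap="gray", vmin=0.0, vmax=1.0)
        paths.append(path)
    return paths
```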
Sampling
In [ ]:
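To generate new digits, latent vectors are drawn from the prior and decoded. This minimal sketch assumes a hypothetical `decode` function that returns logits. The same procedure can be applied to the AE for comparison. However, the AE is not trained toward any prior, so N(0, I) draws may land in regions its decoder never saw. That is one common explanation for AE samples looking worse than VAE samples.

```python
import numpy as np

def sample_from_prior(decode, num_samples, latent_dim, seed=0, logits=True):
    """Decode latent vectors drawn from the standard normal prior N(0, I).

    If `decode` returns Bernoulli logits (an assumption), a sigmoid maps them
    to pixel intensities in [0, 1].
    """
    rng = np.random.default_rng(seed)
    z = rng.standard_normal((num_samples, latent_dim))
    out = decode(z)
    return 1.0 / (1.0 + np.exp(-out)) if logits else out
```

Fixing `seed` makes the sample grid reproducible across runs and comparable across models.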
Show a single figure montage
In [ ]:
Plot each row as a separate figure and save each one under a descriptive filename
In [ ]:
Interpolation
In [ ]:
In [ ]:
In [ ]:
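This sketch covers latent-space interpolation: encode two test images, walk between their codes, and decode each intermediate point. Linear interpolation is the simplest choice. Spherical linear interpolation (slerp) is included as an alternative often used with Gaussian latents. The notebook's method is not shown, so both are assumptions.

```python
import numpy as np

def lerp(z0, z1, num_steps):
    """Linear interpolation: z(t) = (1 - t) z0 + t z1 for t evenly spaced in [0, 1]."""
    t = np.linspace(0.0, 1.0, num_steps)[:, None]
    return (1.0 - t) * z0[None, :] + t * z1[None, :]

def slerp(z0, z1, num_steps, eps=1e-7):
    """Spherical linear interpolation between z0 and z1."""
    t = np.linspace(0.0, 1.0, num_steps)[:, None]
    cos_omega = np.dot(z0, z1) / (np.linalg.norm(z0) * np.linalg.norm(z1))
    omega = np.arccos(np.clip(cos_omega, -1.0, 1.0))
    if omega < eps:  # nearly parallel vectors: fall back to linear interpolation
        return lerp(z0, z1, num_steps)
    return (np.sin((1.0 - t) * omega) * z0 + np.sin(t * omega) * z1) / np.sin(omega)
```

Usage with a hypothetical decoder is `images = decode(lerp(z_a, z_b, 10))`. The decoded images show one digit morphing into the other.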
Visualization of latent space
Calculate latent vectors for all test samples
In [ ]:
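Encoding all 10,000 test images at once can exhaust accelerator memory, so a common pattern is to encode in fixed-size chunks and stack the results. In this sketch, `encode` is a hypothetical stand-in for the trained encoder. For the VAE it is assumed to return `(mean, logvar)`, and the mean is kept as the embedding.

```python
import numpy as np

def encode_in_batches(encode, x, batch_size=1000):
    """Encode a dataset chunk by chunk and concatenate the latent vectors."""
    chunks = []
    for start in range(0, len(x), batch_size):
        out = encode(x[start:start + batch_size])
        z = out[0] if isinstance(out, tuple) else out  # VAE: keep the posterior mean
        chunks.append(np.asarray(z))
    return np.concatenate(chunks, axis=0)
```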
Visualization with UMAP
Perform UMAP on the embeddings
In [ ]:
Visualize the clusters
In [ ]:
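This sketch plots a 2-D embedding (from UMAP or t-SNE) with one colour per digit class. The function name and styling choices are hypothetical. Colours come from matplotlib's default colour cycle ("C0" to "C9").

```python
import numpy as np
import matplotlib.pyplot as plt

def plot_embedding(emb, labels, title, ax=None):
    """Scatter a 2-D embedding, one colour and legend entry per digit class."""
    if ax is None:
        _, ax = plt.subplots(figsize=(6, 6))
    for digit in range(10):
        mask = labels == digit
        ax.scatter(emb[mask, 0], emb[mask, 1], s=2, color=f"C{digit}", label=str(digit))
    ax.set_title(title)
    ax.legend(markerscale=5, fontsize="small")  # enlarge the tiny markers in the legend
    return ax
```

Passing an existing `ax` lets the four models be drawn side by side in one `plt.subplots(1, 4)` figure.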
Visualization with t-SNE
Perform t-SNE on the embeddings. Note that it takes ~5 minutes to run.
In [ ]:
Visualize the clusters
In [ ]:
Zip the plots for easier download
In [ ]:
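Here is a minimal sketch using `shutil.make_archive` from the standard library. The folder and archive names are hypothetical. Writing the archive outside the source folder keeps it from including itself.

```python
import os
import shutil

def zip_directory(src_dir, archive_base):
    """Write <archive_base>.zip with everything under src_dir; return the archive path."""
    return shutil.make_archive(archive_base, "zip", root_dir=src_dir)
```

For example, `zip_directory("figures", "/content/plots")` would produce /content/plots.zip.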