Path: blob/master/deprecated/notebooks/vae_compare_results.ipynb
Kernel: Python 3.7.10 64-bit ('dgflowenv': conda)
Compare various VAEs side-by-side on CelebA (GSOC 2021)
Author: Ang Ming Liang. Mentor: Kevin Murphy.
This notebook uses pre-trained models (checkpoints) that are stored on GitHub / GCS. For details on how these were created, see https://github.com/probml/pyprobml/tree/master/vae
Setup
In [1]:
In [2]:
Out[2]:
assembler.py models
assets pixel_cnn_celeba_conv.ckpt
beta_vae_celeba_conv.ckpt README.md
configs run_pixel.py
data.py run.py
download_celeba.py sample_data
experiment.py sigma_vae_celeba_conv.ckpt
flax standalone
hinge_vae_celeba_conv.ckpt two_stage_vae_celeba_conv.ckpt
info_vae_celeba_conv.ckpt utils
__init__.py vanilla_ae_celeba_conv.ckpt
kpmtest.py vanilla_vae_celeba_conv.ckpt
logcosh_vae_celeba_conv.ckpt vq_vae_celeba_conv.ckpt
mmd_vae_celeba_conv.ckpt
In [3]:
Out[3]:
Requirement already satisfied: absl-py in /usr/local/lib/python3.7/dist-packages (0.12.0)
Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from absl-py) (1.15.0)
Warning: If you are running this on your local machine, follow the README instructions to download the subdirectory first, instead of running the "setup for colab" cell.
In [4]:
Download data and load data module
Get a kaggle.json file so you can access the dataset
Follow these instructions to get a kaggle.json key file. Then upload it to colab using the following script.
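After uploading, the key file must end up where the Kaggle CLI looks for it (`~/.kaggle/kaggle.json`), with owner-only permissions so the CLI does not warn about an exposed credential. A minimal sketch of that step; `install_kaggle_key` is a hypothetical helper, not part of this repo:

```python
# Hypothetical helper: move an uploaded kaggle.json to the location the
# Kaggle CLI expects (~/.kaggle/kaggle.json) and restrict its permissions.
import os
import shutil
import stat
from pathlib import Path

def install_kaggle_key(src="kaggle.json", home=None):
    home = Path(home or Path.home())
    dest_dir = home / ".kaggle"
    dest_dir.mkdir(parents=True, exist_ok=True)
    dest = dest_dir / "kaggle.json"
    shutil.copy(src, dest)
    # The Kaggle CLI warns unless the key is readable by the owner only.
    os.chmod(dest, stat.S_IRUSR | stat.S_IWUSR)  # 0o600
    return dest
```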
In [10]:
Out[10]:
Saving kaggle.json to kaggle.json
In [ ]:
In [11]:
In [12]:
Out[12]:
Downloading dataset. Please wait while the download and extraction processes complete
Downloading celeba-dataset.zip to kaggle
100% 1.32G/1.33G [00:10<00:00, 184MB/s]
100% 1.33G/1.33G [00:10<00:00, 140MB/s]
Downloading list_attr_celeba.csv.zip to kaggle
0% 0.00/2.02M [00:00<?, ?B/s]
100% 2.02M/2.02M [00:00<00:00, 194MB/s]
Downloading list_bbox_celeba.csv.zip to kaggle
0% 0.00/1.54M [00:00<?, ?B/s]
100% 1.54M/1.54M [00:00<00:00, 274MB/s]
Downloading list_eval_partition.csv.zip to kaggle
0% 0.00/466k [00:00<?, ?B/s]
100% 466k/466k [00:00<00:00, 256MB/s]
Downloading list_landmarks_align_celeba.csv.zip to kaggle
0% 0.00/2.07M [00:00<?, ?B/s]
100% 2.07M/2.07M [00:00<00:00, 235MB/s]
Done!
I1228 01:27:50.859333 140631168378752 utils.py:145] Note: NumExpr detected 12 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
I1228 01:27:50.859508 140631168378752 utils.py:157] NumExpr defaulting to 8 threads.
/content/data.py:74: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').
self.attr = (self.attr + 1) // 2 # map from {-1, 1} to {0, 1}
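The deprecation warning above is benign for this line of `data.py`. PyTorch's old `//` truncated toward zero, while floor division rounds toward negative infinity; the two only disagree for negative operands, and `attr + 1` is always in `{0, 2}` here. A plain-Python sketch of the distinction (no PyTorch required):

```python
import math

# Old torch `//` behavior was trunc (round toward 0); the replacement
# torch.div(a, b, rounding_mode='floor') rounds toward -infinity.
# The two differ only when the quotient is negative:
assert math.trunc(-3 / 2) == -1   # trunc
assert math.floor(-3 / 2) == -2   # floor

# data.py maps attributes from {-1, 1} to {0, 1} via (attr + 1) // 2.
# Since attr + 1 is in {0, 2}, trunc and floor agree, so either
# rounding_mode silences the warning without changing results.
for attr in (-1, 1):
    assert (attr + 1) // 2 == math.trunc((attr + 1) / 2)
```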
In [13]:
Out[13]:
Global seed set to 99
Files exist already
Comparing results
In [14]:
In [15]:
Reconstruction
In [16]:
Out[16]:
Samples
In [17]:
Out[17]:
Global seed set to 42
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-17-ba974cadc887> in <module>()
6 num_of_images_per_row = 6
7
----> 8 plot_samples(vaes, num=num_imgs, num_of_images_per_row=num_of_images_per_row, figdir=figdir)
/content/utils/plot.py in plot_samples(vaes, num, figsize, num_of_images_per_row, figdir)
23 if figdir is not None:
24 filename = f'{figdir}/vae-samples-{vae.model_name}.png'
---> 25 plot_samples(vae, num, figsize, num_of_images_per_row, filename)
26 else:
27 vae = vaes # single model
/content/utils/plot.py in plot_samples(vaes, num, figsize, num_of_images_per_row, figdir)
26 else:
27 vae = vaes # single model
---> 28 model_samples = vae.get_samples(num)
29 title = f"Samples from {vae.model_name}"
30 if figdir is not None:
/content/experiment.py in get_samples(self, num)
130 u = torch.randn(num, self.latent_dim)
131 u = u.to(self.device)
--> 132 return self.decode(u)
133
134
/content/experiment.py in decode(self, u)
125
126 def decode(self, u):
--> 127 return self.stage1.decode(self.stage2.decode(u))
128
129 def get_samples(self, num):
/content/experiment.py in decode(self, z)
41
42 def decode(self, z):
---> 43 return self.model.decoder(z)
44
45 def get_samples(self, num):
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1101 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102 return forward_call(*input, **kwargs)
1103 # Do not call functions when jit is used
1104 full_backward_hooks, non_full_backward_hooks = [], []
/content/models/two_stage_vae.py in forward(self, z)
71
72 def forward(self, z):
---> 73 result = self.decoder(z)
74 result = self.final_layer(result)
75 return result
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1101 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102 return forward_call(*input, **kwargs)
1103 # Do not call functions when jit is used
1104 full_backward_hooks, non_full_backward_hooks = [], []
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/container.py in forward(self, input)
139 def forward(self, input):
140 for module in self:
--> 141 input = module(input)
142 return input
143
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1101 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102 return forward_call(*input, **kwargs)
1103 # Do not call functions when jit is used
1104 full_backward_hooks, non_full_backward_hooks = [], []
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/container.py in forward(self, input)
139 def forward(self, input):
140 for module in self:
--> 141 input = module(input)
142 return input
143
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1101 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102 return forward_call(*input, **kwargs)
1103 # Do not call functions when jit is used
1104 full_backward_hooks, non_full_backward_hooks = [], []
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/linear.py in forward(self, input)
101
102 def forward(self, input: Tensor) -> Tensor:
--> 103 return F.linear(input, self.weight, self.bias)
104
105 def extra_repr(self) -> str:
/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py in linear(input, weight, bias)
1846 if has_torch_function_variadic(input, weight, bias):
1847 return handle_torch_function(linear, (input, weight, bias), input, weight, bias=bias)
-> 1848 return torch._C._nn.linear(input, weight, bias)
1849
1850
RuntimeError: mat1 and mat2 shapes cannot be multiplied (6x256 and 64x150)
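The `RuntimeError` above means the two-stage VAE's stages disagree on latent size: `stage2.decode(u)` emits 256-dimensional vectors (mat1 is 6x256 for a batch of 6), but the first `Linear` layer of `stage1`'s decoder expects 64-dimensional input (its transposed weight is 64x150). This looks like a checkpoint/config mismatch between the two stages. A hypothetical helper illustrating the matrix-shape rule the error enforces:

```python
def can_matmul(a_shape, b_shape):
    """A @ B is defined only when A's column count equals B's row count."""
    return a_shape[1] == b_shape[0]

# Shapes from the traceback: 6 latent codes with 256 features hitting a
# layer whose (transposed) weight is 64x150.
assert not can_matmul((6, 256), (64, 150))  # -> the RuntimeError above
assert can_matmul((6, 64), (64, 150))       # what the layer expects
```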
Interpolation
In [18]:
Out[18]:
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Save figures
In [19]:
Out[19]:
vae-interpolate-beta_vae.png vae-recon-mmd_vae.png
vae-interpolate-hinge_vae.png vae-recon-original.png
vae-interpolate-info_vae.png vae-recon-sigma_vae.png
vae-interpolate-logcosh_vae.png vae-recon-two_stage_vae.png
vae-interpolate-mmd_vae.png vae-recon-vanilla_ae.png
vae-interpolate-original.png vae-recon-vanilla_vae.png
vae-interpolate-sigma_vae.png vae-recon-vq_vae.png
vae-interpolate-two_stage_vae.png vae-samples-beta_vae.png
vae-interpolate-vanilla_ae.png vae-samples-hinge_vae.png
vae-interpolate-vanilla_vae.png vae-samples-info_vae.png
vae-interpolate-vq_vae.png vae-samples-logcosh_vae.png
vae-recon-beta_vae.png vae-samples-mmd_vae.png
vae-recon-hinge_vae.png vae-samples-vanilla_ae.png
vae-recon-info_vae.png vae-samples-vanilla_vae.png
vae-recon-logcosh_vae.png
In [ ]:
adding: content/figures/ (stored 0%)
adding: content/figures/vae-recon-original.png (deflated 0%)
adding: content/figures/vae-interpolate-two_stage_vae.png (deflated 0%)
adding: content/figures/vae-samples-vq_vae.png (deflated 0%)
adding: content/figures/vae-interpolate-hinge_vae.png (deflated 0%)
adding: content/figures/vae-samples-mmd_vae.png (deflated 0%)
adding: content/figures/vae-interpolate-beta_vae.png (deflated 0%)
adding: content/figures/vae-recon-mmd_vae.png (deflated 0%)
adding: content/figures/vae-samples-beta_vae.png (deflated 1%)
adding: content/figures/vae-recon-info_vae.png (deflated 0%)
adding: content/figures/vae-interpolate-vanilla_ae.png (deflated 0%)
adding: content/figures/vae-recon-vanilla_ae.png (deflated 0%)
adding: content/figures/vae-interpolate-vanilla_vae.png (deflated 0%)
adding: content/figures/vae-samples-two_stage_vae.png (deflated 0%)
adding: content/figures/vae-recon-hinge_vae.png (deflated 0%)
adding: content/figures/vae-recon-logcosh_vae.png (deflated 0%)
adding: content/figures/vae-samples-vanilla_ae.png (deflated 0%)
adding: content/figures/vae-recon-vq_vae.png (deflated 0%)
adding: content/figures/vae-recon-two_stage_vae.png (deflated 0%)
adding: content/figures/vae-recon-beta_vae.png (deflated 0%)
adding: content/figures/vae-samples-vanilla_vae.png (deflated 0%)
adding: content/figures/vae-samples-info_vae.png (deflated 0%)
adding: content/figures/vae-interpolate-original.png (deflated 0%)
adding: content/figures/vae-interpolate-mmd_vae.png (deflated 0%)
adding: content/figures/vae-interpolate-logcosh_vae.png (deflated 0%)
adding: content/figures/vae-interpolate-sigma_vae.png (deflated 0%)
adding: content/figures/vae-recon-sigma_vae.png (deflated 0%)
adding: content/figures/vae-interpolate-info_vae.png (deflated 0%)
adding: content/figures/vae-interpolate-vq_vae.png (deflated 0%)
adding: content/figures/vae-recon-vanilla_vae.png (deflated 0%)
adding: content/figures/vae-samples-hinge_vae.png (deflated 0%)
adding: content/figures/vae-samples-sigma_vae.png (deflated 0%)
adding: content/figures/vae-samples-logcosh_vae.png (deflated 0%)
In [ ]: