Path: blob/master/notebooks/book1/20/vae_celeba_tf.ipynb
Kernel: Python 3
(Variational) Convolutional Autoencoder for CelebA and MNIST
Code uses TF 2.0 idioms and should work with images of any size and number of channels. It is based on various sources.
For a more recent implementation (summer 2022) that uses JAX/Flax, see https://colab.research.google.com/github/probml/pyprobml/blob/master/notebooks/book2/21/celeba_vae_ae_comparison.ipynb
Import TensorFlow and other libraries
In [ ]:
In [ ]:
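The import cells are not shown; a minimal sketch of the imports the rest of the notebook relies on (TF 2.x, tensorflow_datasets, NumPy, pandas, Matplotlib). The exact cell contents are an assumption:
In [ ]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds

print(tf.__version__)  # expects 2.x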
Load CelebA
Here we download a zipfile of images, together with a CSV of their attributes; the images have been preprocessed to 64x64 using the script at
https://github.com/probml/probml-data/blob/main/data/CelebA/celeba_kaggle_preprocess.py
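A sketch of shell commands consistent with the logs below; file names and URLs come from the wget output, while the unzip step is an assumption:
In [ ]:
# Remove any stale copies, then fetch the attribute CSV and the image zip.
!rm celeba_small_H64_W64_N20000.csv
!wget https://raw.githubusercontent.com/probml/pyprobml/master/data/CelebA/celeba_small_H64_W64_N20000.csv
!rm celeba_small_H64_W64_N20000.zip
!wget https://raw.githubusercontent.com/probml/pyprobml/master/data/CelebA/celeba_small_H64_W64_N20000.zip
!rm *.jpg
!unzip -q celeba_small_H64_W64_N20000.zip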
In [ ]:
In [ ]:
rm: cannot remove 'celeba_small_H64_W64_N20000.csv': No such file or directory
--2021-06-22 21:39:14-- https://raw.githubusercontent.com/probml/pyprobml/master/data/CelebA/celeba_small_H64_W64_N20000.csv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2440084 (2.3M) [text/plain]
Saving to: ‘celeba_small_H64_W64_N20000.csv’
celeba_small_H64_W6 100%[===================>] 2.33M --.-KB/s in 0.06s
2021-06-22 21:39:14 (38.8 MB/s) - ‘celeba_small_H64_W64_N20000.csv’ saved [2440084/2440084]
In [ ]:
In [ ]:
['image_id', '5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive', 'Bags_Under_Eyes', 'Bald', 'Bangs', 'Big_Lips', 'Big_Nose', 'Black_Hair', 'Blond_Hair', 'Blurry', 'Brown_Hair', 'Bushy_Eyebrows', 'Chubby', 'Double_Chin', 'Eyeglasses', 'Goatee', 'Gray_Hair', 'Heavy_Makeup', 'High_Cheekbones', 'Male', 'Mouth_Slightly_Open', 'Mustache', 'Narrow_Eyes', 'No_Beard', 'Oval_Face', 'Pale_Skin', 'Pointy_Nose', 'Receding_Hairline', 'Rosy_Cheeks', 'Sideburns', 'Smiling', 'Straight_Hair', 'Wavy_Hair', 'Wearing_Earrings', 'Wearing_Hat', 'Wearing_Lipstick', 'Wearing_Necklace', 'Wearing_Necktie', 'Young']
In [ ]:
20000
In [ ]:
rm: cannot remove 'celeba_small_H64_W64_N20000.zip': No such file or directory
--2021-06-22 21:39:15-- https://raw.githubusercontent.com/probml/pyprobml/master/data/CelebA/celeba_small_H64_W64_N20000.zip
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 34148268 (33M) [application/zip]
Saving to: ‘celeba_small_H64_W64_N20000.zip’
celeba_small_H64_W6 100%[===================>] 32.57M 122MB/s in 0.3s
2021-06-22 21:39:17 (122 MB/s) - ‘celeba_small_H64_W64_N20000.zip’ saved [34148268/34148268]
In [ ]:
rm: cannot remove '*.jpg': No such file or directory
celeba_small_H64_W64_N20000.csv celeba_small_H64_W64_N20000.zip sample_data
In [ ]:
In [ ]:
20000
In [ ]:
In [ ]:
Load MNIST
In [ ]:
Downloading and preparing dataset mnist/3.0.1 (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /root/tensorflow_datasets/mnist/3.0.1...
WARNING:absl:Dataset mnist is hosted on GCS. It will automatically be downloaded to your
local data directory. If you'd instead prefer to read directly from our public
GCS bucket (recommended if you're running on GCP), you can instead pass
`try_gcs=True` to `tfds.load` or set `data_dir=gs://tfds-data/datasets`.
Dataset mnist downloaded and prepared to /root/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.
tfds.core.DatasetInfo(
name='mnist',
version=3.0.1,
description='The MNIST database of handwritten digits.',
homepage='http://yann.lecun.com/exdb/mnist/',
features=FeaturesDict({
'image': Image(shape=(28, 28, 1), dtype=tf.uint8),
'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10),
}),
total_num_examples=70000,
splits={
'test': 10000,
'train': 60000,
},
supervised_keys=('image', 'label'),
citation="""@article{lecun2010mnist,
title={MNIST handwritten digit database},
author={LeCun, Yann and Cortes, Corinna and Burges, CJ},
journal={ATT Labs [Online]. Available: http://yann.lecun.com/exdb/mnist},
volume={2},
year={2010}
}""",
redistribution_info=,
)
(28, 28, 1)
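A minimal sketch of loading MNIST with tensorflow_datasets, consistent with the DatasetInfo printed above (not necessarily the author's exact cell):
In [ ]:
ds_train, info = tfds.load("mnist", split="train", with_info=True)
print(info)
print(info.features["image"].shape)  # (28, 28, 1)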
In [ ]:
In [ ]:
<TakeDataset shapes: (None, 28, 28, 1), types: tf.float32>
<TakeDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int64)>
In [ ]:
(10000, 28, 28, 1)
In [ ]:
(10000, 28, 28, 1)
Inspect TFDS dataset
In [ ]:
batch#0, X size (32, 64, 64, 3)
batch#1, X size (32, 64, 64, 3)
batch#2, X size (32, 64, 64, 3)
batch#0, X size (64, 28, 28, 1)
batch#1, X size (64, 28, 28, 1)
batch#2, X size (64, 28, 28, 1)
batch#0, X size (64, 28, 28, 1), Y size (64,)
batch#1, X size (64, 28, 28, 1), Y size (64,)
batch#2, X size (64, 28, 28, 1), Y size (64,)
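A sketch of the kind of loop that produces printouts like the ones above; the dataset names celeba_ds, mnist_ds, and mnist_labeled_ds are hypothetical:
In [ ]:
# Peek at a few batches from each (already batched) dataset.
for i, X in enumerate(celeba_ds.take(3)):
    print(f"batch#{i}, X size {X.shape}")
for i, X in enumerate(mnist_ds.take(3)):
    print(f"batch#{i}, X size {X.shape}")
for i, (X, Y) in enumerate(mnist_labeled_ds.take(3)):
    print(f"batch#{i}, X size {X.shape}, Y size {Y.shape}")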
In [ ]:
(4, 28, 28, 1)
(4,)
In [ ]:
Define model
In [ ]:
In [ ]:
In [ ]:
Model: "model"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
encoder_input (InputLayer) [(None, 64, 64, 3)] 0
__________________________________________________________________________________________________
encoder_conv_0 (Conv2D) (None, 32, 32, 32) 896 encoder_input[0][0]
__________________________________________________________________________________________________
leaky_re_lu (LeakyReLU) (None, 32, 32, 32) 0 encoder_conv_0[0][0]
__________________________________________________________________________________________________
encoder_conv_1 (Conv2D) (None, 16, 16, 64) 18496 leaky_re_lu[0][0]
__________________________________________________________________________________________________
leaky_re_lu_1 (LeakyReLU) (None, 16, 16, 64) 0 encoder_conv_1[0][0]
__________________________________________________________________________________________________
flatten (Flatten) (None, 16384) 0 leaky_re_lu_1[0][0]
__________________________________________________________________________________________________
mu (Dense) (None, 2) 32770 flatten[0][0]
__________________________________________________________________________________________________
log_var (Dense) (None, 2) 32770 flatten[0][0]
==================================================================================================
Total params: 84,932
Trainable params: 84,932
Non-trainable params: 0
__________________________________________________________________________________________________
None
Model: "model_1"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
encoder_input (InputLayer) [(None, 28, 28, 1)] 0
__________________________________________________________________________________________________
encoder_conv_0 (Conv2D) (None, 14, 14, 32) 320 encoder_input[0][0]
__________________________________________________________________________________________________
leaky_re_lu_2 (LeakyReLU) (None, 14, 14, 32) 0 encoder_conv_0[0][0]
__________________________________________________________________________________________________
encoder_conv_1 (Conv2D) (None, 7, 7, 64) 18496 leaky_re_lu_2[0][0]
__________________________________________________________________________________________________
leaky_re_lu_3 (LeakyReLU) (None, 7, 7, 64) 0 encoder_conv_1[0][0]
__________________________________________________________________________________________________
flatten_1 (Flatten) (None, 3136) 0 leaky_re_lu_3[0][0]
__________________________________________________________________________________________________
mu (Dense) (None, 2) 6274 flatten_1[0][0]
__________________________________________________________________________________________________
log_var (Dense) (None, 2) 6274 flatten_1[0][0]
==================================================================================================
Total params: 31,364
Trainable params: 31,364
Non-trainable params: 0
__________________________________________________________________________________________________
None
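A sketch of an encoder builder that reproduces the summaries above: 3x3 convolutions with stride 2 and LeakyReLU activations, followed by two dense heads (mu, log_var) for the Gaussian posterior. Layer names and parameter counts match the summaries; everything else is an assumption.
In [ ]:
from tensorflow.keras import layers, Model

def make_encoder(input_shape, latent_dim=2, filters=(32, 64)):
    x_in = layers.Input(shape=input_shape, name="encoder_input")
    x = x_in
    for i, f in enumerate(filters):
        # Each stride-2 convolution halves the spatial resolution.
        x = layers.Conv2D(f, 3, strides=2, padding="same", name=f"encoder_conv_{i}")(x)
        x = layers.LeakyReLU()(x)
    x = layers.Flatten()(x)
    mu = layers.Dense(latent_dim, name="mu")(x)
    log_var = layers.Dense(latent_dim, name="log_var")(x)
    return Model(x_in, [mu, log_var])

encoder_celeba = make_encoder((64, 64, 3))  # 84,932 params, as above
encoder_mnist = make_encoder((28, 28, 1))   # 31,364 params, as above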
In [ ]:
In [ ]:
Model: "model_3"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
decoder_input (InputLayer) [(None, 2)] 0
_________________________________________________________________
dense (Dense) (None, 16384) 49152
_________________________________________________________________
reshape (Reshape) (None, 16, 16, 64) 0
_________________________________________________________________
decoder_conv_t_0 (Conv2DTran (None, 32, 32, 64) 36928
_________________________________________________________________
leaky_re_lu_6 (LeakyReLU) (None, 32, 32, 64) 0
_________________________________________________________________
decoder_conv_t_1 (Conv2DTran (None, 64, 64, 32) 18464
_________________________________________________________________
leaky_re_lu_7 (LeakyReLU) (None, 64, 64, 32) 0
_________________________________________________________________
decoder_conv_t_2 (Conv2DTran (None, 64, 64, 3) 867
=================================================================
Total params: 105,411
Trainable params: 105,411
Non-trainable params: 0
_________________________________________________________________
None
(4, 64, 64, 3)
(4, 64, 64, 3)
Model: "model_5"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
decoder_input (InputLayer) [(None, 2)] 0
_________________________________________________________________
dense_1 (Dense) (None, 3136) 9408
_________________________________________________________________
reshape_1 (Reshape) (None, 7, 7, 64) 0
_________________________________________________________________
decoder_conv_t_0 (Conv2DTran (None, 14, 14, 64) 36928
_________________________________________________________________
leaky_re_lu_10 (LeakyReLU) (None, 14, 14, 64) 0
_________________________________________________________________
decoder_conv_t_1 (Conv2DTran (None, 28, 28, 32) 18464
_________________________________________________________________
leaky_re_lu_11 (LeakyReLU) (None, 28, 28, 32) 0
_________________________________________________________________
decoder_conv_t_2 (Conv2DTran (None, 28, 28, 1) 289
=================================================================
Total params: 65,089
Trainable params: 65,089
Non-trainable params: 0
_________________________________________________________________
None
(4, 28, 28, 1)
(4, 28, 28, 1)
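Likewise, a decoder sketch consistent with the summaries above: a dense projection, a reshape, two stride-2 transposed convolutions, and a final stride-1 transposed convolution back to the image channels. The sigmoid output activation is an assumption (it keeps pixels in [0, 1]); layers and Model are as imported in the encoder sketch.
In [ ]:
def make_decoder(latent_dim, start_shape, channels):
    z_in = layers.Input(shape=(latent_dim,), name="decoder_input")
    h, w, c = start_shape
    x = layers.Dense(h * w * c)(z_in)
    x = layers.Reshape(start_shape)(x)
    x = layers.Conv2DTranspose(64, 3, strides=2, padding="same", name="decoder_conv_t_0")(x)
    x = layers.LeakyReLU()(x)
    x = layers.Conv2DTranspose(32, 3, strides=2, padding="same", name="decoder_conv_t_1")(x)
    x = layers.LeakyReLU()(x)
    x = layers.Conv2DTranspose(channels, 3, strides=1, padding="same",
                               activation="sigmoid", name="decoder_conv_t_2")(x)  # sigmoid is an assumption
    return Model(z_in, x)

decoder_celeba = make_decoder(2, (16, 16, 64), 3)  # output (64, 64, 3)
decoder_mnist = make_decoder(2, (7, 7, 64), 1)     # output (28, 28, 1)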
In [ ]:
In [ ]:
Create models
In [ ]:
<__main__.ConvVAE object at 0x7fc117a820d0>
<__main__.ConvVAE object at 0x7fc1179db9d0>
<__main__.ConvVAE object at 0x7fc11795fed0>
<__main__.ConvVAE object at 0x7fc1178dfe90>
In [ ]:
<__main__.ConvVAE object at 0x7fc117891950>
<__main__.ConvVAE object at 0x7fc11788c150>
<__main__.ConvVAE object at 0x7fc117797350>
<__main__.ConvVAE object at 0x7fc1177d3ad0>
In [ ]:
testing 2d_det
testing 20d_det
testing 2d_stoch
testing 20d_stoch
testing 2d_det
WARNING:tensorflow:5 out of the last 5 calls to <function ConvVAE.compute_loss at 0x7fc117277200> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.
WARNING:tensorflow:5 out of the last 5 calls to <function ConvVAE.compute_gradients at 0x7fc1171239e0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.
testing 200d_det
WARNING:tensorflow:6 out of the last 6 calls to <function ConvVAE.compute_loss at 0x7fc1274eee60> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.
WARNING:tensorflow:6 out of the last 6 calls to <function ConvVAE.compute_gradients at 0x7fc1274ee8c0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.
testing 2d_stoch
testing 200d_stoch
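The ConvVAE class itself is not shown; the retracing warnings above confirm it has compute_loss and compute_gradients methods wrapped in tf.function. A sketch of the core computations such methods typically perform: the reparameterization trick, and a Monte Carlo ELBO with a closed-form KL term against a standard-normal prior (the binary cross-entropy reconstruction term is an assumption; the author may use a different likelihood):
In [ ]:
def reparameterize(mu, log_var):
    # z = mu + sigma * eps, eps ~ N(0, I); keeps sampling differentiable.
    eps = tf.random.normal(shape=tf.shape(mu))
    return mu + tf.exp(0.5 * log_var) * eps

def vae_loss(x, x_recon, mu, log_var):
    # Reconstruction: per-pixel binary cross-entropy, summed over pixels.
    recon = tf.reduce_sum(tf.keras.losses.binary_crossentropy(x, x_recon), axis=[1, 2])
    # KL(q(z|x) || N(0, I)), closed form for a diagonal Gaussian posterior.
    kl = -0.5 * tf.reduce_sum(1.0 + log_var - tf.square(mu) - tf.exp(log_var), axis=1)
    return tf.reduce_mean(recon + kl)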
Training
In [ ]:
In [ ]:
epoch 0, loss 23.106830596923828, X shape (64, 28, 28, 1), a 1, b 2
epoch 0, loss 211.47850036621094, X shape (32, 64, 64, 3), a 1, b 2
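A sketch of a custom training loop of the sort that produces the per-epoch losses in the MNIST and CelebA runs below, assuming a model exposing compute_loss as in the warnings above (the optimizer choice is an assumption):
In [ ]:
optimizer = tf.keras.optimizers.Adam(1e-3)

def train(model, dataset, num_epochs):
    for epoch in range(num_epochs):
        for X in dataset:
            with tf.GradientTape() as tape:
                loss = model.compute_loss(X)
            grads = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))
        # Report the loss of the final batch in each epoch.
        print(f"Epoch {epoch}, batch loss: {float(loss):.5f}")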
In [ ]:
1
0
0
1
In [ ]:
In [ ]:
Epoch 0, batch loss: 42.00000
MNIST
In [ ]:
Epoch 0, batch loss: 23.33114
Epoch 1, batch loss: 21.86356
CelebA
In [ ]:
Epoch 0, batch loss: 58.01910
Epoch 1, batch loss: 52.51512
Epoch 2, batch loss: 42.30551
Epoch 3, batch loss: 37.13955
Epoch 4, batch loss: 43.23111
Epoch 5, batch loss: 36.50845
Epoch 6, batch loss: 30.98376
Epoch 7, batch loss: 38.88124
Epoch 8, batch loss: 33.50891
Epoch 9, batch loss: 35.97668
Epoch 10, batch loss: 27.43246
Epoch 11, batch loss: 29.74890
Epoch 12, batch loss: 30.63859
Epoch 13, batch loss: 25.41174
Epoch 14, batch loss: 35.50402
Epoch 15, batch loss: 30.08551
Epoch 16, batch loss: 27.59299
Epoch 17, batch loss: 32.54832
Epoch 18, batch loss: 31.27162
Epoch 19, batch loss: 30.54428
In [ ]:
Epoch 0, batch loss: 159.00439
Epoch 1, batch loss: 139.15683
Epoch 2, batch loss: 122.02244
Epoch 3, batch loss: 128.59869
Epoch 4, batch loss: 120.33513
Epoch 5, batch loss: 129.05711
Epoch 6, batch loss: 125.26275
Epoch 7, batch loss: 127.63524
Epoch 8, batch loss: 122.34169
Epoch 9, batch loss: 109.03207
Epoch 10, batch loss: 111.89961
Epoch 11, batch loss: 118.08176
Epoch 12, batch loss: 127.77422
Epoch 13, batch loss: 132.63382
Epoch 14, batch loss: 115.25031
Epoch 15, batch loss: 123.18605
Epoch 16, batch loss: 119.97896
Epoch 17, batch loss: 116.96674
Epoch 18, batch loss: 106.08669
Epoch 19, batch loss: 111.20110
Post-training
Reconstructions
In [ ]:
In [ ]:
In [ ]:
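A sketch of how the reconstructions might be produced, assuming the encoder/decoder sketches above and a hypothetical test batch X: encode, decode the posterior mean, and show originals above reconstructions.
In [ ]:
mu, log_var = encoder_mnist(X)
X_recon = decoder_mnist(mu)  # decode the posterior mean (deterministic reconstruction)
fig, axes = plt.subplots(2, 8, figsize=(16, 4))
for i in range(8):
    axes[0, i].imshow(tf.squeeze(X[i]), cmap="gray")
    axes[1, i].imshow(tf.squeeze(X_recon[i]), cmap="gray")
    axes[0, i].axis("off")
    axes[1, i].axis("off")
plt.show()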
Generations
In [ ]:
In [ ]:
In [ ]:
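Unconditional generations come from decoding samples drawn from the N(0, I) prior; a sketch, again assuming the decoder above with a 2-dimensional latent space:
In [ ]:
z = tf.random.normal(shape=(8, 2))  # 8 draws from the prior
X_gen = decoder_celeba(z)
fig, axes = plt.subplots(1, 8, figsize=(16, 2))
for i in range(8):
    axes[i].imshow(X_gen[i])
    axes[i].axis("off")
plt.show()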
Latent space analysis
Illustrate latent-space embedding and arithmetic for a VAE fitted to CelebA face images. Code is based on https://nbviewer.jupyter.org/github/davidADSP/GDL_code/blob/master/03_06_vae_faces_analysis.ipynb
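For a 2-dimensional latent space the embedding can be plotted directly; a sketch assuming an encoder as above, a hypothetical batch of images X, and per-image attribute values y to color by:
In [ ]:
mu, _ = encoder_celeba(X)
plt.scatter(mu[:, 0], mu[:, 1], c=y, cmap="viridis", s=5)
plt.xlabel("z1")
plt.ylabel("z2")
plt.colorbar()
plt.show()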
In [ ]:
In [ ]:
(64, 64, 3)
In [ ]:
In [ ]:
In [ ]:
Arithmetic in latent space
In [ ]:
In [ ]:
In [ ]:
Eyeglasses Vector
Found 20000 validated image filenames.
label: Eyeglasses
images : POS move : NEG move : distance : Δ distance
26 : 5.672 : 4.865 : 4.928 : 4.928
65 : 2.957 : 0.57 : 3.584 : -1.343
94 : 1.372 : 0.304 : 3.584 : -0.001
121 : 0.707 : 0.282 : 3.384 : -0.2
152 : 0.61 : 0.17 : 3.3 : -0.084
185 : 0.533 : 0.142 : 3.308 : 0.008
222 : 0.576 : 0.097 : 3.293 : -0.015
259 : 0.37 : 0.144 : 3.287 : -0.005
293 : 0.351 : 0.083 : 3.252 : -0.035
331 : 0.319 : 0.077 : 3.29 : 0.038
360 : 0.288 : 0.094 : 3.304 : 0.013
392 : 0.263 : 0.055 : 3.372 : 0.068
421 : 0.317 : 0.072 : 3.353 : -0.019
450 : 0.282 : 0.056 : 3.31 : -0.043
477 : 0.296 : 0.085 : 3.33 : 0.02
506 : 0.198 : 0.051 : 3.347 : 0.017
539 : 0.186 : 0.052 : 3.331 : -0.016
572 : 0.158 : 0.044 : 3.358 : 0.027
610 : 0.188 : 0.036 : 3.326 : -0.032
639 : 0.14 : 0.045 : 3.333 : 0.007
671 : 0.147 : 0.039 : 3.33 : -0.003
705 : 0.144 : 0.036 : 3.337 : 0.006
737 : 0.125 : 0.027 : 3.328 : -0.009
761 : 0.106 : 0.034 : 3.337 : 0.009
793 : 0.111 : 0.028 : 3.354 : 0.017
833 : 0.13 : 0.03 : 3.367 : 0.013
868 : 0.144 : 0.033 : 3.388 : 0.021
899 : 0.11 : 0.028 : 3.38 : -0.008
933 : 0.102 : 0.029 : 3.421 : 0.041
963 : 0.091 : 0.022 : 3.425 : 0.004
995 : 0.126 : 0.023 : 3.431 : 0.006
1028 : 0.09 : 0.022 : 3.422 : -0.009
1064 : 0.138 : 0.021 : 3.432 : 0.011
1097 : 0.082 : 0.025 : 3.425 : -0.008
1133 : 0.092 : 0.025 : 3.42 : -0.004
1160 : 0.083 : 0.02 : 3.413 : -0.007
1188 : 0.07 : 0.023 : 3.408 : -0.005
1218 : 0.086 : 0.019 : 3.417 : 0.009
1256 : 0.08 : 0.025 : 3.404 : -0.012
1292 : 0.087 : 0.017 : 3.41 : 0.006
1328 : 0.078 : 0.024 : 3.423 : 0.012
1354 : 0.059 : 0.018 : 3.418 : -0.004
Found the Eyeglasses vector
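The attribute vector is the difference between the mean latent code of images with the attribute and the mean of those without, accumulated over batches until the "Δ distance" column above stabilizes. The original code streams batches with a Keras data generator (hence "Found 20000 validated image filenames"); a single-batch sketch of the core computation, with hypothetical names:
In [ ]:
def attribute_vector(encoder, images, has_attr):
    # has_attr: NumPy boolean array, e.g. the 'Eyeglasses' column of the CSV.
    mu, _ = encoder(images)
    z_pos = tf.reduce_mean(tf.boolean_mask(mu, has_attr), axis=0)
    z_neg = tf.reduce_mean(tf.boolean_mask(mu, ~has_attr), axis=0)
    return z_pos - z_neg

# Adding a scaled copy of the vector to a latent code (approximately) adds
# the attribute to the decoded face:
#   mu, _ = encoder(X)
#   X_glasses = decoder(mu + alpha * attr_vec)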
In [ ]:
In [ ]:
Male Vector
Found 20000 validated image filenames.
label: Male
images : POS move : NEG move : distance : Δ distance
196 : 4.363 : 5.328 : 3.292 : 3.292
413 : 0.871 : 0.589 : 3.081 : -0.211
633 : 0.479 : 0.356 : 3.005 : -0.076
827 : 0.288 : 0.283 : 2.889 : -0.116
1043 : 0.309 : 0.267 : 2.897 : 0.008
1258 : 0.223 : 0.173 : 2.956 : 0.058
1474 : 0.169 : 0.144 : 2.979 : 0.024
1678 : 0.152 : 0.156 : 2.981 : 0.002
1891 : 0.127 : 0.095 : 3.016 : 0.035
2109 : 0.123 : 0.085 : 3.01 : -0.006
2336 : 0.118 : 0.075 : 2.999 : -0.011
2547 : 0.102 : 0.083 : 2.997 : -0.002
2742 : 0.098 : 0.072 : 3.011 : 0.014
2959 : 0.097 : 0.07 : 3.027 : 0.016
3164 : 0.084 : 0.07 : 3.018 : -0.009
3351 : 0.073 : 0.057 : 3.038 : 0.02
3542 : 0.059 : 0.061 : 3.046 : 0.008
3742 : 0.057 : 0.049 : 3.051 : 0.004
3957 : 0.059 : 0.052 : 3.048 : -0.003
4163 : 0.057 : 0.049 : 3.042 : -0.006
4370 : 0.045 : 0.051 : 3.047 : 0.005
4584 : 0.058 : 0.044 : 3.039 : -0.008
4808 : 0.047 : 0.037 : 3.039 : 0.0
5005 : 0.055 : 0.046 : 3.038 : -0.001
5211 : 0.044 : 0.051 : 3.048 : 0.01
5433 : 0.044 : 0.033 : 3.042 : -0.006
Found the Male vector
In [ ]:
Face interpolation
In [ ]:
In [ ]:
In [ ]:
In [ ]:
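Face interpolation decodes points along the straight line between two latent codes; a sketch assuming the encoder/decoder above and two hypothetical images x1, x2 of shape (64, 64, 3):
In [ ]:
mu, _ = encoder_celeba(tf.stack([x1, x2]))
z1, z2 = mu[0], mu[1]
ts = tf.linspace(0.0, 1.0, 8)
fig, axes = plt.subplots(1, 8, figsize=(16, 2))
for i, t in enumerate(ts):
    z = (1.0 - t) * z1 + t * z2
    x = decoder_celeba(tf.expand_dims(z, 0))[0]
    axes[i].imshow(x)
    axes[i].axis("off")
plt.show()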