Path: blob/master/notebooks/book1/20/vae_mnist_gdl_tf.ipynb
Kernel: Python 3
(Variational) Autoencoder using a convolutional encoder/decoder
The code uses TF 2.0 idioms and should work with images of any size and number of channels. It is based on various sources, including
Import TensorFlow and other libraries
In [2]:
In [3]:
Load the dataset
In [34]:
Out[34]:
Downloading and preparing dataset fashion_mnist/3.0.1 (download: 29.45 MiB, generated: 36.42 MiB, total: 65.87 MiB) to /root/tensorflow_datasets/fashion_mnist/3.0.1...
Shuffling and writing examples to /root/tensorflow_datasets/fashion_mnist/3.0.1.incompleteCNA9M2/fashion_mnist-train.tfrecord
Shuffling and writing examples to /root/tensorflow_datasets/fashion_mnist/3.0.1.incompleteCNA9M2/fashion_mnist-test.tfrecord
Dataset fashion_mnist downloaded and prepared to /root/tensorflow_datasets/fashion_mnist/3.0.1. Subsequent calls will reuse this data.
tfds.core.DatasetInfo(
name='fashion_mnist',
version=3.0.1,
description='Fashion-MNIST is a dataset of Zalando's article images consisting of a training set of 60,000 examples and a test set of 10,000 examples. Each example is a 28x28 grayscale image, associated with a label from 10 classes.',
homepage='https://github.com/zalandoresearch/fashion-mnist',
features=FeaturesDict({
'image': Image(shape=(28, 28, 1), dtype=tf.uint8),
'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10),
}),
total_num_examples=70000,
splits={
'test': 10000,
'train': 60000,
},
supervised_keys=('image', 'label'),
citation="""@article{DBLP:journals/corr/abs-1708-07747,
author = {Han Xiao and
Kashif Rasul and
Roland Vollgraf},
title = {Fashion-MNIST: a Novel Image Dataset for Benchmarking Machine Learning
Algorithms},
journal = {CoRR},
volume = {abs/1708.07747},
year = {2017},
url = {http://arxiv.org/abs/1708.07747},
archivePrefix = {arXiv},
eprint = {1708.07747},
timestamp = {Mon, 13 Aug 2018 16:47:27 +0200},
biburl = {https://dblp.org/rec/bib/journals/corr/abs-1708-07747},
bibsource = {dblp computer science bibliography, https://dblp.org}
}""",
redistribution_info=,
)
(28, 28, 1)
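The dataset info above shows that each example is a dict with a `uint8` image of shape (28, 28, 1) and an integer label. The preprocessing cell is not shown. This sketch assumes a common choice: drop the label and scale the pixels to float32 in [0, 1]. `preprocess` is a hypothetical helper. In a `tf.data` pipeline the same logic would use `tf.cast` and be applied with `Dataset.map`.

```python
import numpy as np

def preprocess(example):
    """Keep only the image and scale uint8 pixels in [0, 255] to float32 in [0, 1].

    Assumption: labels are not needed for (V)AE training.
    """
    image = example["image"]
    return image.astype(np.float32) / 255.0
```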
In [35]:
In [36]:
Out[36]:
(10000, 28, 28, 1)
10000
In [37]:
Out[37]:
(64, 28, 28, 1)
(64, 28, 28, 1)
In [38]:
Out[38]:
(64, 28, 28, 1)
(3, 28, 28, 1)
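The outputs show a 10,000-image test set, batches of 64, and a 3-image batch used for quick shape checks. Below is a minimal sketch of that batching. It assumes shuffling with a fixed seed; the notebook presumably uses `tf.data` shuffling and batching instead. `iterate_minibatches` is a hypothetical helper.

```python
import numpy as np

def iterate_minibatches(images, batch_size=64, shuffle=True, seed=0):
    """Yield successive batches of `images`; the final batch may be smaller."""
    n = images.shape[0]
    # Assumption: a fixed-seed permutation stands in for tf.data's shuffle buffer.
    order = np.random.default_rng(seed).permutation(n) if shuffle else np.arange(n)
    for start in range(0, n, batch_size):
        yield images[order[start:start + batch_size]]
```

With 10,000 test images and a batch size of 64, this yields 156 full batches plus a final batch of 16. A small slice such as `images[:3]` gives the (3, 28, 28, 1) batch used for sanity checks.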
Define model
In [9]:
In [39]:
In [40]:
Out[40]:
Model: "functional_23"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
encoder_input (InputLayer) [(None, 28, 28, 1)] 0
__________________________________________________________________________________________________
encoder_conv_0 (Conv2D) (None, 14, 14, 32) 320 encoder_input[0][0]
__________________________________________________________________________________________________
leaky_re_lu_22 (LeakyReLU) (None, 14, 14, 32) 0 encoder_conv_0[0][0]
__________________________________________________________________________________________________
encoder_conv_1 (Conv2D) (None, 7, 7, 64) 18496 leaky_re_lu_22[0][0]
__________________________________________________________________________________________________
leaky_re_lu_23 (LeakyReLU) (None, 7, 7, 64) 0 encoder_conv_1[0][0]
__________________________________________________________________________________________________
flatten_6 (Flatten) (None, 3136) 0 leaky_re_lu_23[0][0]
__________________________________________________________________________________________________
mu (Dense) (None, 2) 6274 flatten_6[0][0]
__________________________________________________________________________________________________
log_var (Dense) (None, 2) 6274 flatten_6[0][0]
==================================================================================================
Total params: 31,364
Trainable params: 31,364
Non-trainable params: 0
__________________________________________________________________________________________________
None
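The parameter counts in the summary follow directly from the layer shapes. This quick check assumes 3x3 kernels, which is consistent with the (3, 3, 1, 32) kernel shape printed later. LeakyReLU and Flatten have no weights. The helper names are hypothetical.

```python
def conv2d_params(kernel, c_in, c_out):
    """Conv2D weights: kernel * kernel * c_in * c_out, plus one bias per output channel."""
    return kernel * kernel * c_in * c_out + c_out

def dense_params(n_in, n_out):
    """Dense weights plus one bias per output unit."""
    return n_in * n_out + n_out

flat = 7 * 7 * 64  # 3136 features after flattening the (7, 7, 64) feature map
stochastic_encoder = [
    conv2d_params(3, 1, 32),   # encoder_conv_0 -> 320
    conv2d_params(3, 32, 64),  # encoder_conv_1 -> 18496
    dense_params(flat, 2),     # mu -> 6274
    dense_params(flat, 2),     # log_var -> 6274
]
deterministic_encoder = stochastic_encoder[:3]  # the deterministic model has no log_var head
print(sum(stochastic_encoder), sum(deterministic_encoder))  # 31364 25090
```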
In [41]:
Out[41]:
Model: "functional_25"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
encoder_input (InputLayer) [(None, 28, 28, 1)] 0
_________________________________________________________________
encoder_conv_0 (Conv2D) (None, 14, 14, 32) 320
_________________________________________________________________
leaky_re_lu_24 (LeakyReLU) (None, 14, 14, 32) 0
_________________________________________________________________
encoder_conv_1 (Conv2D) (None, 7, 7, 64) 18496
_________________________________________________________________
leaky_re_lu_25 (LeakyReLU) (None, 7, 7, 64) 0
_________________________________________________________________
flatten_7 (Flatten) (None, 3136) 0
_________________________________________________________________
mu (Dense) (None, 2) 6274
=================================================================
Total params: 25,090
Trainable params: 25,090
Non-trainable params: 0
_________________________________________________________________
None
In [42]:
Out[42]:
(7, 7, 64)
(3, 28, 28, 1)
(3, 2)
(3, 2)
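The stochastic encoder returns two heads of shape (batch, latent_dim), `mu` and `log_var`, as the (3, 2) outputs show. A latent code is typically drawn with the reparameterization trick: z = mu + exp(0.5 * log_var) * eps, with eps ~ N(0, I). This keeps sampling differentiable with respect to the encoder outputs. The sketch below is NumPy; the notebook's own TensorFlow code is not shown, and `reparameterize` is a hypothetical name.

```python
import numpy as np

def reparameterize(mu, log_var, rng):
    """Sample z ~ N(mu, diag(exp(log_var))) as z = mu + sigma * eps, eps ~ N(0, I)."""
    eps = rng.standard_normal(mu.shape)
    sigma = np.exp(0.5 * log_var)  # log-variance -> standard deviation
    return mu + sigma * eps
```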
In [43]:
In [44]:
Out[44]:
Model: "functional_27"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
decoder_input (InputLayer) [(None, 2)] 0
_________________________________________________________________
dense_5 (Dense) (None, 3136) 9408
_________________________________________________________________
reshape_5 (Reshape) (None, 7, 7, 64) 0
_________________________________________________________________
decoder_conv_t_0 (Conv2DTran (None, 14, 14, 64) 36928
_________________________________________________________________
leaky_re_lu_26 (LeakyReLU) (None, 14, 14, 64) 0
_________________________________________________________________
decoder_conv_t_1 (Conv2DTran (None, 28, 28, 32) 18464
_________________________________________________________________
leaky_re_lu_27 (LeakyReLU) (None, 28, 28, 32) 0
_________________________________________________________________
decoder_conv_t_2 (Conv2DTran (None, 28, 28, 1) 289
=================================================================
Total params: 65,089
Trainable params: 65,089
Non-trainable params: 0
_________________________________________________________________
None
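The decoder mirrors the encoder. A Dense layer expands the latent code to 7 * 7 * 64 = 3136 units, Reshape restores (7, 7, 64), and transposed convolutions upsample back to 28x28. The sketch below reproduces the printed shapes and parameter counts under these assumptions, inferred from the summary:

- 3x3 kernels with padding='same'
- strides (2, 2, 1)
- filters (64, 32, 1)

`trace_decoder` is a hypothetical helper.

```python
def dense_params(n_in, n_out):
    return n_in * n_out + n_out

def conv_transpose_params(kernel, c_in, c_out):
    """Conv2DTranspose has the same count as Conv2D: kernel * kernel * c_in * c_out + c_out."""
    return kernel * kernel * c_in * c_out + c_out

def trace_decoder(latent_dim=2, start=(7, 7, 64), layers=((2, 64), (2, 32), (1, 1)), kernel=3):
    """Return (shapes, params) for Dense -> Reshape -> stacked Conv2DTranspose layers.

    Each entry of `layers` is (stride, filters). With padding='same', a transposed
    convolution multiplies the height and width by its stride.
    """
    h, w, c = start
    params = [dense_params(latent_dim, h * w * c)]  # the Dense layer feeding Reshape
    shapes = [(h, w, c)]
    for stride, filters in layers:
        params.append(conv_transpose_params(kernel, c, filters))
        h, w, c = h * stride, w * stride, filters
        shapes.append((h, w, c))
    return shapes, params
```

For the 20-dimensional models only the first Dense layer changes, to 20 * 3136 + 3136 = 65,856 parameters.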
In [45]:
Out[45]:
(5, 28, 28, 1)
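The introduction says the code should work with images of any size and number of channels. Shape arithmetic shows what that requires. This sketch assumes the same layout as above: two stride-2, 'same'-padded convolutions in the encoder and a mirrored decoder. It also assumes the pre-flatten shape, such as the (7, 7, 64) printed earlier, is used to size the decoder's Reshape. With padding='same', Conv2D maps a spatial size n to ceil(n / stride), and Conv2DTranspose maps it to n * stride. `roundtrip_shape` is a hypothetical helper.

```python
import math

def conv_same_out(size, stride):
    """Spatial size after Conv2D with padding='same'."""
    return math.ceil(size / stride)

def conv_transpose_same_out(size, stride):
    """Spatial size after Conv2DTranspose with padding='same'."""
    return size * stride

def roundtrip_shape(height, width, channels, strides=(2, 2), filters=(32, 64)):
    """Trace (H, W, C) through the encoder convs and back through a mirrored decoder.

    Returns (bottleneck_shape, reconstruction_shape). Assumes the decoder's last,
    stride-1 layer restores the original channel count.
    """
    h, w = height, width
    for s in strides:
        h, w = conv_same_out(h, s), conv_same_out(w, s)
    bottleneck = (h, w, filters[-1])
    for s in reversed(strides):
        h, w = conv_transpose_same_out(h, s), conv_transpose_same_out(w, s)
    return bottleneck, (h, w, channels)
```

Under these assumptions, the reconstruction matches the input's shape only when height and width are divisible by 4. A 30x30 input would come back as 32x32 and would need cropping or padding.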
In [18]:
In [25]:
In [46]:
Out[46]:
<__main__.ConvVAE object at 0x7f5f6cf70c50>
<__main__.ConvVAE object at 0x7f5f6cd6e208>
<__main__.ConvVAE object at 0x7f5f6d012f28>
<__main__.ConvVAE object at 0x7f5f6ceb5a90>
In [47]:
Out[47]:
size of batch (3, 28, 28, 1)
2d_det
size of encoding head: mean (3, 2)
size of encoding (3, 2)
size of decoding (3, 28, 28, 1)
20d_det
size of encoding head: mean (3, 20)
size of encoding (3, 20)
size of decoding (3, 28, 28, 1)
2d_stoch
size of encoding head: mean (3, 2), var (3, 2)
size of encoding (3, 2)
size of decoding (3, 28, 28, 1)
20d_stoch
size of encoding head: mean (3, 20), var (3, 20)
size of encoding (3, 20)
size of decoding (3, 28, 28, 1)
In [28]:
Out[28]:
2d_det
tf.Tensor(0.11379913, shape=(), dtype=float32)
(3, 3, 1, 32)
20d_det
tf.Tensor(0.113810055, shape=(), dtype=float32)
(3, 3, 1, 32)
2d_stoch
tf.Tensor(114.144844, shape=(), dtype=float32)
(3, 3, 1, 32)
20d_stoch
tf.Tensor(113.99734, shape=(), dtype=float32)
(3, 3, 1, 32)
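The initial losses differ by roughly a factor of 1000 between the deterministic models (about 0.114) and the stochastic ones (about 114). This is consistent with, though not confirmed by, the following loss for the stochastic models: a per-pixel mean squared error scaled by 1000, plus the analytic KL divergence to N(0, I). The NumPy sketch below works under that assumption. `R_LOSS_FACTOR` and the function names are hypothetical.

```python
import numpy as np

R_LOSS_FACTOR = 1000.0  # assumption: weight on the reconstruction term

def reconstruction_mse(x, x_hat):
    """Per-example mean squared error over all pixels and channels."""
    return np.mean((x - x_hat) ** 2, axis=(1, 2, 3))

def kl_to_standard_normal(mu, log_var):
    """Analytic KL(N(mu, exp(log_var)) || N(0, I)), summed over latent dimensions."""
    return -0.5 * np.sum(1.0 + log_var - mu ** 2 - np.exp(log_var), axis=1)

def deterministic_loss(x, x_hat):
    """Plain autoencoder: batch mean of the per-example MSE."""
    return np.mean(reconstruction_mse(x, x_hat))

def stochastic_loss(x, x_hat, mu, log_var, r_loss_factor=R_LOSS_FACTOR):
    """VAE: scaled reconstruction error plus KL term, averaged over the batch."""
    per_example = r_loss_factor * reconstruction_mse(x, x_hat) + kl_to_standard_normal(mu, log_var)
    return np.mean(per_example)
```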
Training
In [29]:
Out[29]:
WARNING:tensorflow:5 out of the last 5 calls to <function ConvVAE.compute_loss at 0x7f5fc6135950> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.
Epoch 0, Test loss: 0.11572, time 42.00
Epoch 0, Test loss: 0.11573, time 42.00
Epoch 0, Test loss: 116.01754, time 42.00
Epoch 0, Test loss: 115.99959, time 42.00
In [30]:
In [49]:
Out[49]:
2d_det
Epoch 0, Test loss: 0.01745, time 7.29
Epoch 1, Test loss: 0.01643, time 6.85
In [33]:
Out[33]:
2d_det
Epoch 0, Test loss: 0.02336, time 6.76
Epoch 1, Test loss: 0.02327, time 6.70
20d_det
Epoch 0, Test loss: 0.00370, time 6.76
Epoch 1, Test loss: 0.00347, time 6.68
2d_stoch
Epoch 0, Test loss: 28.37840, time 7.84
Epoch 1, Test loss: 27.79435, time 7.08
20d_stoch
Epoch 0, Test loss: 24.08021, time 7.88
Epoch 1, Test loss: 23.32022, time 7.21
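The stochastic test losses (about 23 to 28) cannot be compared directly with the deterministic ones (about 0.003 to 0.023). The stochastic objective includes a KL term and, judging by the initial values, a scaled reconstruction term. Within each family, the 20-d models reach a lower loss than the 2-d ones.

When an epoch's test loss is computed from per-batch means, weighting each batch by its size gives the exact per-example mean even when the last batch is smaller. Whether the notebook weights batches this way is not shown. Below is a sketch with a hypothetical `mean_over_batches` helper:

```python
import numpy as np

def mean_over_batches(batch_losses, batch_sizes):
    """Example-weighted mean of per-batch mean losses."""
    losses = np.asarray(batch_losses, dtype=np.float64)
    sizes = np.asarray(batch_sizes, dtype=np.float64)
    return float(np.sum(losses * sizes) / np.sum(sizes))
```

A plain average of the batch means would over-weight a short final batch. For example, with batches of 64 and 16 whose means are 1.0 and 3.0, the plain average is 2.0, while the per-example mean is 1.4.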
Post-training
Display images generated (and saved) during training
In [109]:
In [110]:
Out[110]:
[Images generated during training for 2d_det, 20d_det, 2d_stoch and 20d_stoch]
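The images saved during training are presumably grids of decoded samples. Below is a minimal NumPy sketch for tiling a batch of shape (N, H, W, C) into one image. It assumes a 1-pixel white border between tiles; `make_grid` is a hypothetical helper.

```python
import numpy as np

def make_grid(images, n_cols, pad=1, pad_value=1.0):
    """Tile an (N, H, W, C) batch into a single image, row by row.

    Empty trailing slots and the borders between tiles are filled with pad_value.
    """
    n, h, w, c = images.shape
    n_rows = -(-n // n_cols)  # ceiling division
    grid = np.full((n_rows * (h + pad) - pad, n_cols * (w + pad) - pad, c),
                   pad_value, dtype=images.dtype)
    for i, img in enumerate(images):
        r, col = divmod(i, n_cols)
        top, left = r * (h + pad), col * (w + pad)
        grid[top:top + h, left:left + w] = img
    return grid
```

For a single-channel grid, Matplotlib can display the result with `plt.imshow(grid[..., 0], cmap="gray")`.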
Reconstructions
In [113]:
Out[113]:
[Reconstructions for 2d_det, 20d_det, 2d_stoch and 20d_stoch]
2d Embeddings
In [123]:
Out[123]:
In [117]:
Out[117]:
In [119]:
Out[119]:
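For the 2-d models, one common way to inspect the latent space is to decode an evenly spaced grid of latent points and tile the results; this is not necessarily what these cells do. The sketch below only builds the grid. The range [-3, 3] and the helper name `latent_grid` are assumptions. Decoding would be done with the notebook's 2-d decoder.

```python
import numpy as np

def latent_grid(n=15, limit=3.0):
    """Return an (n*n, 2) array of evenly spaced latent points in [-limit, limit]^2.

    Points run left to right (increasing z1) within each row, and rows run top to
    bottom (decreasing z2). Decoding them and tiling with n columns gives a map of
    the latent space.
    """
    z1 = np.linspace(-limit, limit, n)
    z2 = np.linspace(limit, -limit, n)
    zz1, zz2 = np.meshgrid(z1, z2)  # default 'xy' indexing: zz1[i, j] = z1[j], zz2[i, j] = z2[i]
    return np.stack([zz1.ravel(), zz2.ravel()], axis=1)
```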