Variational AutoEncoder
Author: fchollet
Date created: 2020/05/03
Last modified: 2024/04/24
Description: Convolutional Variational AutoEncoder (VAE) trained on MNIST digits.
View in Colab โข
GitHub source
Setup
Create a sampling layer
Build the encoder
Model: "encoder"
โโโโโโโโโโโโโโโโโโโโโโโณโโโโโโโโโโโโโโโโโโโโณโโโโโโโโโโณโโโโโโโโโโโโโโโโโโโโโโโ โ Layer (type) โ Output Shape โ Param # โ Connected to โ โกโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโฉ โ input_layer โ (None, 28, 28, 1) โ 0 โ - โ โ (InputLayer) โ โ โ โ โโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโผโโโโโโโโโโโโโโโโโโโโโโโค โ conv2d (Conv2D) โ (None, 14, 14, โ 320 โ input_layer[0][0] โ โ โ 32) โ โ โ โโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโผโโโโโโโโโโโโโโโโโโโโโโโค โ conv2d_1 (Conv2D) โ (None, 7, 7, 64) โ 18,496 โ conv2d[0][0] โ โโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโผโโโโโโโโโโโโโโโโโโโโโโโค โ flatten (Flatten) โ (None, 3136) โ 0 โ conv2d_1[0][0] โ โโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโผโโโโโโโโโโโโโโโโโโโโโโโค โ dense (Dense) โ (None, 16) โ 50,192 โ flatten[0][0] โ โโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโผโโโโโโโโโโโโโโโโโโโโโโโค โ z_mean (Dense) โ (None, 2) โ 34 โ dense[0][0] โ โโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโผโโโโโโโโโโโโโโโโโโโโโโโค โ z_log_var (Dense) โ (None, 2) โ 34 โ dense[0][0] โ โโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโผโโโโโโโโโโโโโโโโโโโโโโโค โ sampling (Sampling) โ (None, 2) โ 0 โ z_mean[0][0], โ โ โ โ โ z_log_var[0][0] โ โโโโโโโโโโโโโโโโโโโโโโโดโโโโโโโโโโโโโโโโโโโโดโโโโโโโโโโดโโโโโโโโโโโโโโโโโโโโโโโ
Total params: 69,076 (269.83 KB)
Trainable params: 69,076 (269.83 KB)
Non-trainable params: 0 (0.00 B)
Build the decoder
Model: "decoder"
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโณโโโโโโโโโโโโโโโโโโโโโโโโโโโโณโโโโโโโโโโโโโ โ Layer (type) โ Output Shape โ Param # โ โกโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโฉ โ input_layer_1 (InputLayer) โ (None, 2) โ 0 โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโค โ dense_1 (Dense) โ (None, 3136) โ 9,408 โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโค โ reshape (Reshape) โ (None, 7, 7, 64) โ 0 โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโค โ conv2d_transpose โ (None, 14, 14, 64) โ 36,928 โ โ (Conv2DTranspose) โ โ โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโค โ conv2d_transpose_1 โ (None, 28, 28, 32) โ 18,464 โ โ (Conv2DTranspose) โ โ โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโค โ conv2d_transpose_2 โ (None, 28, 28, 1) โ 289 โ โ (Conv2DTranspose) โ โ โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโดโโโโโโโโโโโโโโโโโโโโโโโโโโโโดโโโโโโโโโโโโโ
Total params: 65,089 (254.25 KB)
Trainable params: 65,089 (254.25 KB)
Non-trainable params: 0 (0.00 B)
Define the VAE as a Model with a custom train_step
Train the VAE
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR I0000 00:00:1700704358.696643 3339857 device_compiler.h:186] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process. W0000 00:00:1700704358.714145 3339857 graph_launch.cc:671] Fallback to op-by-op mode because memset node breaks graph update W0000 00:00:1700704358.716080 3339857 graph_launch.cc:671] Fallback to op-by-op mode because memset node breaks graph update
547/547 โโโโโโโโโโโโโโโโโโโโ 0s 9ms/step - kl_loss: 2.9140 - loss: 262.3454 - reconstruction_loss: 259.4314
W0000 00:00:1700704363.390106 3339858 graph_launch.cc:671] Fallback to op-by-op mode because memset node breaks graph update W0000 00:00:1700704363.392582 3339858 graph_launch.cc:671] Fallback to op-by-op mode because memset node breaks graph update
547/547 โโโโโโโโโโโโโโโโโโโโ 11s 9ms/step - kl_loss: 2.9145 - loss: 262.3454 - reconstruction_loss: 259.3424 - total_loss: 213.8374 Epoch 2/30 547/547 โโโโโโโโโโโโโโโโโโโโ 2s 4ms/step - kl_loss: 5.2591 - loss: 177.2659 - reconstruction_loss: 171.9981 - total_loss: 172.5344 Epoch 3/30 547/547 โโโโโโโโโโโโโโโโโโโโ 2s 4ms/step - kl_loss: 6.0199 - loss: 166.4822 - reconstruction_loss: 160.4603 - total_loss: 165.3463 Epoch 4/30 547/547 โโโโโโโโโโโโโโโโโโโโ 2s 3ms/step - kl_loss: 6.1585 - loss: 163.0588 - reconstruction_loss: 156.8987 - total_loss: 162.2310 Epoch 5/30 547/547 โโโโโโโโโโโโโโโโโโโโ 2s 4ms/step - kl_loss: 6.2646 - loss: 160.6541 - reconstruction_loss: 154.3888 - total_loss: 160.2672 Epoch 6/30 547/547 โโโโโโโโโโโโโโโโโโโโ 2s 4ms/step - kl_loss: 6.3202 - loss: 159.1411 - reconstruction_loss: 152.8203 - total_loss: 158.8850 Epoch 7/30 547/547 โโโโโโโโโโโโโโโโโโโโ 2s 4ms/step - kl_loss: 6.3759 - loss: 157.8918 - reconstruction_loss: 151.5157 - total_loss: 157.8260 Epoch 8/30 547/547 โโโโโโโโโโโโโโโโโโโโ 2s 4ms/step - kl_loss: 6.3899 - loss: 157.2225 - reconstruction_loss: 150.8320 - total_loss: 156.8395 Epoch 9/30 547/547 โโโโโโโโโโโโโโโโโโโโ 2s 4ms/step - kl_loss: 6.4204 - loss: 156.0726 - reconstruction_loss: 149.6520 - total_loss: 156.0463 Epoch 10/30 547/547 โโโโโโโโโโโโโโโโโโโโ 2s 4ms/step - kl_loss: 6.4176 - loss: 155.6229 - reconstruction_loss: 149.2051 - total_loss: 155.4912 Epoch 11/30 547/547 โโโโโโโโโโโโโโโโโโโโ 3s 4ms/step - kl_loss: 6.4297 - loss: 155.0198 - reconstruction_loss: 148.5899 - total_loss: 154.9487 Epoch 12/30 547/547 โโโโโโโโโโโโโโโโโโโโ 2s 4ms/step - kl_loss: 6.4338 - loss: 154.1115 - reconstruction_loss: 147.6781 - total_loss: 154.3575 Epoch 13/30 547/547 โโโโโโโโโโโโโโโโโโโโ 2s 4ms/step - kl_loss: 6.4356 - loss: 153.9087 - reconstruction_loss: 147.4730 - total_loss: 153.8745 Epoch 14/30 547/547 โโโโโโโโโโโโโโโโโโโโ 2s 4ms/step - kl_loss: 6.4506 - loss: 153.7804 - reconstruction_loss: 147.3295 - total_loss: 153.6391 Epoch 15/30 547/547 โโโโโโโโโโโโโโโโโโโโ 2s 4ms/step - kl_loss: 6.4399 - loss: 152.7727 - reconstruction_loss: 146.3336 - total_loss: 153.2117 Epoch 16/30 547/547 โโโโโโโโโโโโโโโโโโโโ 2s 4ms/step - kl_loss: 6.4661 - loss: 152.7382 - reconstruction_loss: 146.2725 - total_loss: 152.9310 Epoch 17/30 547/547 โโโโโโโโโโโโโโโโโโโโ 2s 4ms/step - kl_loss: 6.4566 - loss: 152.3313 - reconstruction_loss: 145.8751 - total_loss: 152.5897 Epoch 18/30 547/547 โโโโโโโโโโโโโโโโโโโโ 2s 4ms/step - kl_loss: 6.4613 - loss: 152.4331 - reconstruction_loss: 145.9715 - total_loss: 152.2775 Epoch 19/30 547/547 โโโโโโโโโโโโโโโโโโโโ 2s 4ms/step - kl_loss: 6.4551 - loss: 151.9406 - reconstruction_loss: 145.4857 - total_loss: 152.0997 Epoch 20/30 547/547 โโโโโโโโโโโโโโโโโโโโ 2s 4ms/step - kl_loss: 6.4332 - loss: 152.1597 - reconstruction_loss: 145.7260 - total_loss: 151.8623 Epoch 21/30 547/547 โโโโโโโโโโโโโโโโโโโโ 2s 4ms/step - kl_loss: 6.4644 - loss: 151.4290 - reconstruction_loss: 144.9649 - total_loss: 151.6146 Epoch 22/30 547/547 โโโโโโโโโโโโโโโโโโโโ 2s 4ms/step - kl_loss: 6.4662 - loss: 151.1586 - reconstruction_loss: 144.6929 - total_loss: 151.4525 Epoch 23/30 547/547 โโโโโโโโโโโโโโโโโโโโ 2s 4ms/step - kl_loss: 6.4532 - loss: 150.9665 - reconstruction_loss: 144.5139 - total_loss: 151.2734 Epoch 24/30 547/547 โโโโโโโโโโโโโโโโโโโโ 2s 4ms/step - kl_loss: 6.4520 - loss: 151.2177 - reconstruction_loss: 144.7655 - total_loss: 151.1416 Epoch 25/30 547/547 โโโโโโโโโโโโโโโโโโโโ 2s 4ms/step - kl_loss: 6.4537 - loss: 150.8981 - reconstruction_loss: 144.4445 - total_loss: 151.0104 Epoch 26/30 547/547 โโโโโโโโโโโโโโโโโโโโ 2s 4ms/step - kl_loss: 6.4669 - loss: 150.5807 - reconstruction_loss: 144.1143 - total_loss: 150.8807 Epoch 27/30 547/547 โโโโโโโโโโโโโโโโโโโโ 2s 4ms/step - kl_loss: 6.4575 - loss: 150.3731 - reconstruction_loss: 143.9162 - total_loss: 150.7236 Epoch 28/30 547/547 โโโโโโโโโโโโโโโโโโโโ 2s 4ms/step - kl_loss: 6.4644 - loss: 150.7117 - reconstruction_loss: 144.2471 - total_loss: 150.6108 Epoch 29/30 547/547 โโโโโโโโโโโโโโโโโโโโ 2s 4ms/step - kl_loss: 6.4902 - loss: 150.1759 - reconstruction_loss: 143.6862 - total_loss: 150.4756 Epoch 30/30 547/547 โโโโโโโโโโโโโโโโโโโโ 2s 4ms/step - kl_loss: 6.4585 - loss: 150.6554 - reconstruction_loss: 144.1964 - total_loss: 150.3988
<keras.src.callbacks.history.History at 0x7fbe44614eb0>
