Copyright 2020 The TensorFlow Authors.
Convolutional Variational Autoencoder
This notebook demonstrates how to train a Variational Autoencoder (VAE) (1, 2) on the MNIST dataset. A VAE is a probabilistic take on the autoencoder, a model that takes high-dimensional input data and compresses it into a smaller representation. Unlike a traditional autoencoder, which maps the input onto a latent vector, a VAE maps the input data into the parameters of a probability distribution, such as the mean and variance of a Gaussian. This approach produces a continuous, structured latent space, which is useful for image generation.
Setup
Load the MNIST dataset
Each MNIST image is originally a vector of 784 integers, each of which is between 0 and 255 and represents the intensity of a pixel. Model each pixel with a Bernoulli distribution, and statically binarize the dataset.
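A minimal NumPy sketch of the static binarization step, assuming the images arrive as a uint8 array of shape (num_images, 28, 28), like the arrays returned by tf.keras.datasets.mnist.load_data(). The helper name binarize_images and the 0.5 threshold are assumptions for illustration; random integers stand in for real MNIST pixels so the snippet runs offline.

```python
import numpy as np


def binarize_images(images: np.ndarray) -> np.ndarray:
    """Statically binarize uint8 images with values 0-255 into {0.0, 1.0}.

    Hypothetical helper; assumes `images` has shape (num_images, 28, 28).
    """
    # Add a trailing channel axis (for Conv2D) and scale intensities into [0, 1].
    scaled = images.reshape((images.shape[0], 28, 28, 1)) / 255.0
    # Threshold once, before training: each pixel becomes a fixed 0/1 outcome
    # that a Bernoulli likelihood can model.
    return np.where(scaled > 0.5, 1.0, 0.0).astype("float32")


# Stand-in data: random pixel intensities instead of real MNIST digits.
rng = np.random.default_rng(0)
fake_images = rng.integers(0, 256, size=(4, 28, 28)).astype(np.uint8)
binary = binarize_images(fake_images)
print(binary.shape, binary.dtype)  # (4, 28, 28, 1) float32
```

Because the threshold is applied once, every epoch sees exactly the same binary images; that is what "static" refers to here.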
Use tf.data to batch and shuffle the data
Define the encoder and decoder networks with tf.keras.Sequential
In this VAE example, use two small ConvNets for the encoder and decoder networks. In the literature, these networks are also referred to as inference/recognition and generative models, respectively. Use tf.keras.Sequential to simplify implementation. Let $x$ and $z$ denote the observation and latent variable, respectively, in the following descriptions.
Encoder network
This defines the approximate posterior distribution $q(z|x)$, which takes as input an observation and outputs a set of parameters for specifying the conditional distribution of the latent representation $z$. In this example, simply model the distribution as a diagonal Gaussian; the network outputs the mean and log-variance parameters of a factorized Gaussian. It outputs the log-variance instead of the variance directly for numerical stability.
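The NumPy sketch below shows how a single encoder output can carry both Gaussian parameters. It assumes a hypothetical layout in which the final fully-connected layer emits 2 * latent_dim values per example, the first half read as the mean and the second half as the log-variance. Exponentiating half the log-variance always yields a positive standard deviation, which is why an unconstrained network output is a safe quantity to predict.

```python
import numpy as np

LATENT_DIM = 2  # assumption: a 2-D latent space


def split_encoder_output(encoder_output: np.ndarray):
    """Split a (batch, 2 * LATENT_DIM) array into mean and log-variance halves."""
    mean, logvar = np.split(encoder_output, 2, axis=1)
    return mean, logvar


# One example whose raw encoder output is [mean_1, mean_2, logvar_1, logvar_2].
raw = np.array([[0.5, -1.0, 0.0, np.log(4.0)]])
mean, logvar = split_encoder_output(raw)
# sigma = exp(0.5 * log(sigma^2)) is positive for any real-valued logvar.
std = np.exp(0.5 * logvar)
print(mean, std)
```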
Decoder network
This defines the conditional distribution of the observation $p(x|z)$, which takes a latent sample $z$ as input and outputs the parameters for a conditional distribution of the observation. Model the latent distribution prior $p(z)$ as a unit Gaussian.
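Because each pixel is modeled as a Bernoulli variable, the decoder only needs to output one logit per pixel. Below is a NumPy sketch of the resulting log-likelihood $\log p(x|z)$; the function name is hypothetical, and the numerically stable expression is the same one that tf.nn.sigmoid_cross_entropy_with_logits computes, with the sign flipped because that function returns a loss.

```python
import numpy as np


def bernoulli_log_likelihood(x: np.ndarray, logits: np.ndarray) -> np.ndarray:
    """Sum of log Bernoulli(x | sigmoid(logits)) over all non-batch axes.

    x holds binarized pixels (0.0 or 1.0); logits has the same shape.
    """
    # Stable form of x*log(sigmoid(l)) + (1-x)*log(1-sigmoid(l)) that
    # avoids overflow in exp() for large |l|.
    per_pixel = -(np.maximum(logits, 0) - logits * x + np.log1p(np.exp(-np.abs(logits))))
    return per_pixel.reshape(per_pixel.shape[0], -1).sum(axis=1)


x = np.array([[1.0, 0.0, 1.0]])
# Logits of 0 mean p = 0.5 for every pixel: 3 * log(0.5), about -2.0794.
print(bernoulli_log_likelihood(x, np.zeros_like(x)))
# Confident, correct logits give a log-likelihood very close to 0.
print(bernoulli_log_likelihood(x, np.array([[50.0, -50.0, 50.0]])))
```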
Reparameterization trick
To generate a sample $z$ for the decoder during training, you can sample from the latent distribution defined by the parameters output by the encoder, given an input observation $x$. However, this sampling operation creates a bottleneck because backpropagation cannot flow through a random node.
To address this, use a reparameterization trick. In this example, you approximate $z$ using the encoder parameters and another parameter $\epsilon$ as follows:
$$z = \mu + \sigma \odot \epsilon$$
where $\mu$ and $\sigma$ represent the mean and standard deviation of a Gaussian distribution, respectively. They can be derived from the encoder output ($\sigma = \exp(\tfrac{1}{2}\log\sigma^2)$ from the predicted log-variance). The $\epsilon$ can be thought of as random noise used to maintain the stochasticity of $z$. Generate $\epsilon$ from a standard normal distribution.
The latent variable $z$ is now generated by a function of $\mu$, $\sigma$ and $\epsilon$, which enables the model to backpropagate gradients in the encoder through $\mu$ and $\sigma$, respectively, while maintaining stochasticity through $\epsilon$.
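A NumPy sketch of the trick, assuming the encoder outputs are the mean and log-variance described above. NumPy has no automatic differentiation, so this only shows the sampling arithmetic; written with TensorFlow ops, the same expression is what lets gradients reach the mean and log-variance.

```python
import numpy as np


def reparameterize(mean: np.ndarray, logvar: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Draw z = mean + sigma * eps with eps ~ N(0, I)."""
    # All randomness lives in eps, which does not depend on the encoder outputs.
    eps = rng.standard_normal(size=mean.shape)
    # sigma = exp(0.5 * logvar); once eps is fixed, z is a deterministic,
    # differentiable function of mean and logvar (dz/dmean = 1, dz/dsigma = eps).
    return mean + np.exp(0.5 * logvar) * eps


rng = np.random.default_rng(42)
n = 100_000
mean = np.full((n, 1), 3.0)
logvar = np.full((n, 1), np.log(4.0))  # variance 4, so sigma = 2
z = reparameterize(mean, logvar, rng)
print(z.mean(), z.std())  # close to 3.0 and 2.0
```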
Network architecture
For the encoder network, use two convolutional layers followed by a fully-connected layer. In the decoder network, mirror this architecture by using a fully-connected layer followed by three transposed convolution layers (a.k.a. deconvolutional layers in some contexts). Note that it's common practice to avoid using batch normalization when training VAEs, since the additional stochasticity due to using mini-batches may aggravate instability on top of the stochasticity from sampling.
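The text does not pin down kernel sizes, strides, or padding, so the sketch below assumes one plausible configuration and only tracks spatial sizes: 3x3 kernels and stride-2 'valid' convolutions (32 then 64 filters) in the encoder, and a fully-connected layer reshaped to 7x7x32 followed by 'same'-padded transposed convolutions with strides 2, 2 and 1 in the decoder. It shows how the decoder can mirror the encoder to get back to 28x28 images.

```python
def conv_valid(size: int, kernel: int, stride: int) -> int:
    """Output width of a 'valid' (unpadded) convolution."""
    return (size - kernel) // stride + 1


def conv_transpose_same(size: int, stride: int) -> int:
    """Output width of a 'same'-padded transposed convolution."""
    return size * stride


# Encoder (assumed): 28x28x1 -> Conv2D(32, 3, stride 2) -> Conv2D(64, 3, stride 2) -> flatten.
encoder_sizes = [28]
for _ in range(2):
    encoder_sizes.append(conv_valid(encoder_sizes[-1], kernel=3, stride=2))
flattened = encoder_sizes[-1] ** 2 * 64

# Decoder (assumed): dense -> reshape to 7x7x32 -> three transposed convs (strides 2, 2, 1).
decoder_sizes = [7]
for stride in (2, 2, 1):
    decoder_sizes.append(conv_transpose_same(decoder_sizes[-1], stride))

print(encoder_sizes, flattened)  # [28, 13, 6] 2304
print(decoder_sizes)             # [7, 14, 28, 28]
```

Under these assumptions, the 2304 flattened encoder features feed the fully-connected layer that produces the mean and log-variance, and the last transposed convolution would use a single filter so that it emits one logit per pixel.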
Define the loss function and the optimizer
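VAEs train by maximizing the evidence lower bound (ELBO) on the marginal log-likelihood:
$$\log p(x) \ge \text{ELBO} = \mathbb{E}_{q(z|x)}\left[\log \frac{p(x, z)}{q(z|x)}\right].$$
In practice, optimize the single-sample Monte Carlo estimate of this expectation:
$$\log p(x|z) + \log p(z) - \log q(z|x),$$
where $z$ is sampled from $q(z|x)$.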
Note: You could also analytically compute the KL term, but here you incorporate all three terms in the Monte Carlo estimator for simplicity.
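A NumPy sketch of the single-sample estimate, assuming a diagonal Gaussian posterior parameterized by mean and log-variance and a unit Gaussian prior. The reconstruction term $\log p(x|z)$ is passed in as a number (it would come from the Bernoulli likelihood of the decoder's logits), and all helper names are hypothetical. Averaging the latent part over many samples should approach the analytic KL term mentioned in the note, which makes a handy sanity check.

```python
import numpy as np

LOG2PI = np.log(2.0 * np.pi)


def log_normal_pdf(sample, mean, logvar):
    """Log density of a diagonal Gaussian, summed over the last axis."""
    return np.sum(-0.5 * ((sample - mean) ** 2 * np.exp(-logvar) + logvar + LOG2PI), axis=-1)


def single_sample_loss(logpx_z, z, mean, logvar):
    """Negative single-sample ELBO: -(log p(x|z) + log p(z) - log q(z|x))."""
    logpz = log_normal_pdf(z, 0.0, 0.0)        # unit Gaussian prior p(z)
    logqz_x = log_normal_pdf(z, mean, logvar)  # approximate posterior q(z|x)
    return -(logpx_z + logpz - logqz_x)


def analytic_kl(mean, logvar):
    """Closed-form KL(q(z|x) || N(0, I)) for a diagonal Gaussian q."""
    return 0.5 * np.sum(np.exp(logvar) + mean ** 2 - 1.0 - logvar, axis=-1)


rng = np.random.default_rng(0)
mean = np.array([0.5, -1.0])
logvar = np.array([0.0, np.log(4.0)])
# Many reparameterized samples z = mean + sigma * eps from q(z|x).
z = mean + np.exp(0.5 * logvar) * rng.standard_normal(size=(200_000, 2))
# With log p(x|z) set to 0, the loss reduces to log q(z|x) - log p(z).
mc_kl = single_sample_loss(0.0, z, mean, logvar).mean()
print(mc_kl, analytic_kl(mean, logvar))  # both close to 1.43
```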
Training
Start by iterating over the dataset.
During each iteration, pass the image to the encoder to obtain a set of mean and log-variance parameters of the approximate posterior $q(z|x)$.
Then apply the reparameterization trick to sample from $q(z|x)$.
Finally, pass the reparameterized samples to the decoder to obtain the logits of the generative distribution $p(x|z)$.
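To make the three steps concrete, here is a NumPy forward pass for one batch. The encoder and decoder are replaced by hypothetical random linear maps (W_enc, W_dec) acting on flattened images, purely so the data flow runs without TensorFlow; they are not the ConvNets described above. The sketch also omits the gradient update, which in TensorFlow would be computed with tf.GradientTape and applied by the optimizer.

```python
import numpy as np

rng = np.random.default_rng(0)
BATCH, PIXELS, LATENT_DIM = 8, 28 * 28, 2

# Hypothetical stand-ins for the encoder and decoder networks.
W_enc = rng.normal(scale=0.01, size=(PIXELS, 2 * LATENT_DIM))
W_dec = rng.normal(scale=0.01, size=(LATENT_DIM, PIXELS))


def log_normal_pdf(sample, mean, logvar):
    """Log density of a diagonal Gaussian, summed over the latent axis."""
    return np.sum(-0.5 * ((sample - mean) ** 2 * np.exp(-logvar) + logvar + np.log(2 * np.pi)), axis=1)


def forward_loss(x):
    # 1. Encode: mean and log-variance of q(z|x).
    mean, logvar = np.split(x @ W_enc, 2, axis=1)
    # 2. Reparameterize: z = mean + sigma * eps.
    z = mean + np.exp(0.5 * logvar) * rng.standard_normal(size=mean.shape)
    # 3. Decode: one Bernoulli logit per pixel of p(x|z).
    logits = z @ W_dec
    logpx_z = -np.sum(np.maximum(logits, 0) - logits * x + np.log1p(np.exp(-np.abs(logits))), axis=1)
    # Negative single-sample ELBO, averaged over the batch.
    return -np.mean(logpx_z + log_normal_pdf(z, 0.0, 0.0) - log_normal_pdf(z, mean, logvar))


# A batch of binarized, flattened images (random here instead of MNIST).
x_batch = (rng.random((BATCH, PIXELS)) > 0.5).astype(np.float64)
print(forward_loss(x_batch))
```

With near-zero weights the logits are close to 0, so the loss sits near 784 * log 2 (about 543); training would push it down.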
Note: Since you use the dataset loaded by Keras, with 60k datapoints in the training set and 10k datapoints in the test set, the resulting ELBO on the test set is slightly higher than results reported in the literature, which use dynamic binarization of Larochelle's MNIST.
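For contrast with the static binarization used in this tutorial, here is a sketch of dynamic binarization: instead of thresholding once, each pixel is re-sampled as a Bernoulli draw with probability equal to its normalized intensity, typically every time the image is used. The function name is hypothetical, and this is not what the tutorial's pipeline does.

```python
import numpy as np


def dynamic_binarize(images: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Sample each pixel as Bernoulli(intensity / 255); call again for fresh samples."""
    probs = images.astype(np.float64) / 255.0
    return (rng.random(probs.shape) < probs).astype(np.float32)


rng = np.random.default_rng(0)
gray = np.full((10_000,), 64, dtype=np.uint8)  # a constant dark-gray "image"
first = dynamic_binarize(gray, rng)
second = dynamic_binarize(gray, rng)
# Each call gives a different 0/1 pattern, but roughly 64/255 of the pixels are on.
print(first.mean(), second.mean(), np.array_equal(first, second))
```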
Generating images
After training, it is time to generate some images.
Start by sampling a set of latent vectors from the unit Gaussian prior distribution $p(z)$.
The generator then converts the latent sample $z$ to logits of the observation, giving a distribution $p(x|z)$.
Here, plot the probabilities of the Bernoulli distributions.
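A sketch of the generation step, assuming a hypothetical decode function that maps latent vectors to per-pixel logits (here a fixed random linear map standing in for the trained network). Samples come from the unit Gaussian prior $p(z)$, and the sigmoid turns logits into the Bernoulli probabilities that get plotted; drawing actual 0/1 pixels from those probabilities would give noisier images.

```python
import numpy as np

rng = np.random.default_rng(1)
LATENT_DIM, NUM_EXAMPLES = 2, 16

# Hypothetical stand-in for a trained decoder: latent vector -> 28x28 logits.
W_dec = rng.normal(size=(LATENT_DIM, 28 * 28))


def decode(z: np.ndarray) -> np.ndarray:
    return (z @ W_dec).reshape(-1, 28, 28)


def sigmoid(logits: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-logits))


# 1. Sample latent vectors from the unit Gaussian prior p(z).
z = rng.standard_normal(size=(NUM_EXAMPLES, LATENT_DIM))
# 2. Decode to logits of p(x|z), then 3. convert them to Bernoulli probabilities.
probs = sigmoid(decode(z))
print(probs.shape, probs.min() >= 0.0, probs.max() <= 1.0)
```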
Display a generated image from the last training epoch
Display an animated GIF of all the saved images
Display a 2D manifold of digits from the latent space
Running the code below will show a continuous distribution of the different digit classes, with each digit morphing into another across the 2D latent space. Use TensorFlow Probability to generate a standard normal distribution for the latent space.
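The grid of latent points behind such a plot can be built from quantiles of the standard normal, so that evenly spaced probabilities map to latent coordinates that are denser where the prior has most of its mass. The text uses TensorFlow Probability for this; the sketch below swaps in Python's standard-library statistics.NormalDist, whose inv_cdf computes the same quantile function. The 0.05 to 0.95 probability range and the grid size n are assumptions.

```python
from statistics import NormalDist


def latent_grid(n: int, low: float = 0.05, high: float = 0.95):
    """Return an n x n grid of 2-D latent points built from standard normal quantiles."""
    std_normal = NormalDist(mu=0.0, sigma=1.0)
    # Evenly spaced probabilities, mapped through the inverse CDF.
    probs = [low + (high - low) * i / (n - 1) for i in range(n)]
    coords = [std_normal.inv_cdf(p) for p in probs]
    # Each grid point would be decoded into one digit image of the manifold plot.
    return [[(zx, zy) for zx in coords] for zy in coords], coords


grid, coords = latent_grid(5)
print([round(c, 3) for c in coords])  # [-1.645, -0.598, 0.0, 0.598, 1.645]
```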
Next steps
This tutorial has demonstrated how to implement a convolutional variational autoencoder using TensorFlow.
As a next step, you could try to improve the model output by increasing the network size. For instance, you could try setting the filter parameters for each of the Conv2D and Conv2DTranspose layers to 512. Note that in order to generate the final 2D latent image plot, you would need to keep latent_dim at 2. Also, the training time would increase as the network size increases.
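To get a feel for how quickly the network grows, the number of trainable parameters in a Conv2D or Conv2DTranspose layer is kernel_height x kernel_width x input_channels x filters, plus one bias per filter. The sketch below assumes 3x3 kernels (not stated in the text) and compares a 32-to-64-filter layer with a 512-to-512-filter one.

```python
def conv_param_count(kernel: int, in_channels: int, filters: int) -> int:
    """Weights plus biases of a square-kernel Conv2D or Conv2DTranspose layer."""
    return kernel * kernel * in_channels * filters + filters


small = conv_param_count(kernel=3, in_channels=32, filters=64)
large = conv_param_count(kernel=3, in_channels=512, filters=512)
print(small, large, round(large / small))  # 18496 2359808 128
```

The multiply-adds per layer scale roughly the same way (times the output's spatial size), which is why training slows down as the filter counts grow.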
You could also try implementing a VAE using a different dataset, such as CIFAR-10.
VAEs can be implemented in several different styles and of varying complexity. You can find additional implementations in the following sources:
If you'd like to learn more about the details of VAEs, please refer to An Introduction to Variational Autoencoders.