Path: blob/main/C2 - Build Better Generative Adversarial Networks/Week 2/C2W2_VAE.ipynb
Variational Autoencoder (VAE)
Check out the sister of the GAN: VAE. In this lab, you'll explore the components of a basic VAE to understand how it works.
The "AE" in VAE stands for autoencoder. As an autoencoder, the VAE has two parts: an encoder and a decoder. Instead of mapping each image to a single point in -space, the encoder outputs the means and covariance matrices of a multivariate normal distribution where all of the dimensions are independent. You should have had a chance to read more about multivariate normal distributions in last week's assignment, but you can think of the output of the encoder of a VAE this way: the means and standard deviations of a set of independent normal distributions, with one normal distribution (one mean and standard deviation) for each latent dimension.
VAE Architecture Drawing: The encoder outputs a distribution in $z$-space, and to generate an image you sample from the distribution and pass the $z$-space sample to the decoder, which returns an image. VAE latent space visualization from Hyperspherical Variational Auto-Encoders, by Davidson et al. in UAI 2018
Encoder and Decoder
For your encoder and decoder, you can use similar architectures to the ones you've seen before, with some tweaks. For example, for the decoder, you can use the DCGAN generator architecture. For the encoder, you can use a classifier that you used before, and instead of having it produce 1 classification output of whether something is a cat or not, for example, you can have it produce 2 different outputs, one for mean and one for standard deviation. Each of those outputs will have dimensionality $z_\text{dim}$ to model the $z_\text{dim}$ dimensions of the multivariate normal distribution.
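As a concrete illustration, here is a minimal sketch of what such an encoder and decoder might look like for MNIST in PyTorch. The class names, layer sizes, and channel counts are illustrative assumptions, not the assignment's exact architecture.

```python
import torch
from torch import nn

class Encoder(nn.Module):
    '''Maps an image to the mean and standard deviation of a z_dim-dimensional
    normal distribution (sketch only; sizes are illustrative assumptions).'''
    def __init__(self, im_chan=1, z_dim=32, hidden_dim=16):
        super().__init__()
        self.z_dim = z_dim
        self.encode = nn.Sequential(
            nn.Conv2d(im_chan, hidden_dim, kernel_size=4, stride=2),          # 28 -> 13
            nn.BatchNorm2d(hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Conv2d(hidden_dim, hidden_dim * 2, kernel_size=4, stride=2),   # 13 -> 5
            nn.BatchNorm2d(hidden_dim * 2),
            nn.LeakyReLU(0.2),
            # Final layer outputs 2 * z_dim channels: one half for means, one for log-variances
            nn.Conv2d(hidden_dim * 2, z_dim * 2, kernel_size=5, stride=1),    # 5 -> 1
        )

    def forward(self, image):
        out = self.encode(image).view(len(image), -1)
        mean = out[:, :self.z_dim]
        # Predict the log-variance for numerical stability, then convert to a standard deviation
        stddev = torch.exp(out[:, self.z_dim:] / 2)
        return mean, stddev

class Decoder(nn.Module):
    '''Maps a latent vector z back to an image, DCGAN-generator style (sketch only).'''
    def __init__(self, z_dim=32, im_chan=1, hidden_dim=64):
        super().__init__()
        self.decode = nn.Sequential(
            nn.ConvTranspose2d(z_dim, hidden_dim * 4, kernel_size=3, stride=2),           # 1 -> 3
            nn.BatchNorm2d(hidden_dim * 4),
            nn.ReLU(),
            nn.ConvTranspose2d(hidden_dim * 4, hidden_dim * 2, kernel_size=4, stride=1),  # 3 -> 6
            nn.BatchNorm2d(hidden_dim * 2),
            nn.ReLU(),
            nn.ConvTranspose2d(hidden_dim * 2, hidden_dim, kernel_size=3, stride=2),      # 6 -> 13
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(),
            nn.ConvTranspose2d(hidden_dim, im_chan, kernel_size=4, stride=2),             # 13 -> 28
            nn.Sigmoid(),  # pixel values in [0, 1], matching the BCE reconstruction loss used later
        )

    def forward(self, z):
        return self.decode(z.view(len(z), -1, 1, 1))
```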
VAE
You can define the VAE using the encoder and decoder as follows. In the forward pass, the VAE samples from the encoder's output distribution before passing a value to the decoder. A common mistake is to pass the mean to the decoder --- this leads to blurrier images and is not the way in which VAEs are designed to be used. So, the steps you'll take are:
Real image input to encoder
Encoder outputs mean and standard deviation
Sample from the distribution with the outputted mean and standard deviation
Take sampled value (vector/latent) as the input to the decoder
Get fake sample
Use reconstruction loss between the fake output of the decoder and the original real input to the encoder (more about this later - keep reading!)
Backpropagate the loss through the decoder and encoder (a minimal code sketch of these steps follows below)
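Putting those steps into a module might look like the following sketch, which assumes the illustrative Encoder and Decoder classes from above; it is not the official solution.

```python
from torch import nn
from torch.distributions.normal import Normal

class VAE(nn.Module):
    '''Sketch of a VAE that chains the encoder and decoder (illustrative only).'''
    def __init__(self, z_dim=32, im_chan=1):
        super().__init__()
        self.encode = Encoder(im_chan, z_dim)
        self.decode = Decoder(z_dim, im_chan)

    def forward(self, images):
        # Steps 1-2: encode the real images into a mean and standard deviation per latent dimension
        mean, stddev = self.encode(images)
        dist = Normal(mean, stddev)
        # Steps 3-4: sample a latent vector with rsample() so gradients flow (reparameterization trick)
        z = dist.rsample()
        # Step 5: decode the sampled latent into a reconstructed ("fake") image
        decoding = self.decode(z)
        return decoding, dist
```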
Quick Note on Implementation Notation ("Reparameterization Trick")
Most machine learning frameworks will not backpropagate through a random sample (Steps 3-4 above work in the forward pass, but the gradient is not readily available for the backward pass when written this way). In PyTorch, you can handle this by sampling with the rsample method, as in Normal(mean, stddev).rsample(). This is equivalent to torch.randn(z_dim) * stddev + mean, but do not use torch.normal(mean, stddev), as the optimizer will not backpropagate through the expectation of that sample. This is also known as the reparameterization trick, since you're moving the parameters of the random sample outside of the function to explicitly highlight that the gradient should be calculated through these parameters.
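For example, here is a hedged illustration of the difference; the mean and stddev tensors are placeholders standing in for encoder outputs.

```python
import torch
from torch.distributions.normal import Normal

mean = torch.zeros(4, 32, requires_grad=True)   # placeholder encoder means (batch of 4, z_dim = 32)
stddev = torch.ones(4, 32, requires_grad=True)  # placeholder encoder standard deviations

# Differentiable: rsample() draws eps ~ N(0, 1) and returns mean + eps * stddev,
# so gradients flow back to mean and stddev.
z = Normal(mean, stddev).rsample()

# Equivalent manual reparameterization:
z_manual = torch.randn_like(stddev) * stddev + mean

# Avoid this: gradients will not propagate back through the sampling operation.
# z_bad = torch.normal(mean, stddev)
```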
Evidence Lower Bound (ELBO)
When training a VAE, you're trying to maximize the likelihood of the real images. What this means is that you'd like the learned probability distribution to think it's likely that a real image (and the features in that real image) occurs -- as opposed to, say, random noise or weird-looking things. And you want to maximize the likelihood of the real stuff occurring and appropriately associate it with a point in the latent space distribution prior (more on this below), which is where your learned latent noise vectors will live. However, finding this likelihood explicitly is mathematically intractable. So, instead, you can get a good lower bound on the likelihood, meaning you can figure out what the worst-case scenario of the likelihood is (its lower bound which is mathematically tractable) and try to maximize that instead. Because if you maximize its lower bound, or worst-case, then you probably are making the likelihood better too. And this neat technique is known as maximizing the Evidence Lower Bound (ELBO).
Some notation before jumping into explaining ELBO: First, the prior latent space distribution $p(z)$ is the prior probability you have on the latent space $z$. This represents the likelihood of a given latent point $z$ in the latent space, and you know what this actually is because you set it in the beginning as a multivariate normal distribution. Additionally, $q(z|x)$ refers to the posterior probability, or the distribution of the encoded images. Keep in mind that when each image is passed through the encoder, its encoding is a probability distribution.
Knowing that notation, here's the mathematical notation for the ELBO of a VAE, which is the lower bound you want to maximize: $\mathbb{E}_{q(z|x)}\left[\log \frac{p(x, z)}{q(z|x)}\right]$, which is equivalent to $\mathbb{E}_{q(z|x)}\left[\log p(x|z)\right] - \mathrm{D_{KL}}\left(q(z|x)\ \|\ p(z)\right)$.
ELBO can be broken down into two parts: the reconstruction loss $\mathbb{E}_{q(z|x)}\left[\log p(x|z)\right]$ and the KL divergence term $\mathrm{D_{KL}}\left(q(z|x)\ \|\ p(z)\right)$. You'll explore each of these two terms in the next code and text sections.
Reconstruction Loss
Reconstruction loss refers to the distance between the real input image (that you put into the encoder) and the generated image (that comes out of the decoder). Explicitly, the reconstruction loss term is $\mathbb{E}_{q(z|x)}\left[\log p(x|z)\right]$, the log probability of the true image given the latent value.
For MNIST, you can treat each grayscale prediction as a binary random variable (also known as a Bernoulli distribution) with the value between 0 and 1 of a pixel corresponding to the output brightness, so you can use the binary cross entropy loss between the real input image and the generated image in order to represent the reconstruction loss term.
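As a hedged sketch, the reconstruction term might be computed like this; the reduction choice and the placeholder tensors standing in for decoder outputs and real images are assumptions.

```python
import torch
from torch import nn

# Per-pixel binary cross entropy, summed over pixels (a common convention for the
# Bernoulli log-likelihood view; the exact reduction is an assumption here).
reconstruction_loss = nn.BCELoss(reduction='sum')

# Placeholder tensors standing in for a batch of MNIST images and their reconstructions:
real = torch.rand(4, 1, 28, 28)            # real images with brightness in [0, 1]
reconstruction = torch.rand(4, 1, 28, 28)  # decoder output, e.g. after a Sigmoid
loss = reconstruction_loss(reconstruction, real) / len(real)  # average per image
```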
In general, different assumptions about the "distribution" of the pixel brightnesses in an image will lead to different loss functions. For example, if you assume that the brightnesses of the pixels actually follow a normal distribution instead of a binary random (Bernoulli) distribution, this corresponds to a mean squared error (MSE) reconstruction loss.
Why the mean squared error? Well, as a point moves away from the center, $\mu$, of a normal distribution, its negative log likelihood increases quadratically. You can also write this as $-\log p(x) \propto (x - \mu)^2$ for $x \sim \mathcal{N}(\mu, \sigma^2)$. As a result, assuming the pixel brightnesses are normally distributed implies an MSE reconstruction loss.
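To make that concrete, here is the negative log-likelihood of a single pixel value $x$ under $\mathcal{N}(\mu, \sigma^2)$, written out under the assumption that $\sigma$ is fixed:

$$-\log p(x) = \frac{(x - \mu)^2}{2\sigma^2} + \log\left(\sigma\sqrt{2\pi}\right) \;\propto\; (x - \mu)^2 + \text{const}$$

Minimizing this over the decoder's predicted mean $\mu$ is the same as minimizing the squared error between the prediction and the pixel value.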
KL Divergence
KL divergence, mentioned in a video (on Inception Score) last week, allows you to evaluate how different one probability distribution is from another. If you have two distributions and they are exactly the same, then KL divergence is equal to 0. KL divergence is close to the notion of distance between distributions, but notice that it's called a divergence, not a distance; this is because it is not symmetric, meaning that $\mathrm{D_{KL}}(p\ \|\ q)$ is usually not equal to the terms flipped, $\mathrm{D_{KL}}(q\ \|\ p)$. In contrast, a true distance function, like the Euclidean distance where you would take the squared difference between two points, is symmetric: comparing $p$ to $q$ gives the same value as comparing $q$ to $p$.
Now, you care about two distributions and finding how different they are: (1) the learned latent space $q(z|x)$ that your encoder is trying to model and (2) your prior on the latent space $p(z)$, which you want your learned latent space to be as close as possible to. If both of your distributions are normal distributions, you can calculate the KL divergence, or $\mathrm{D_{KL}}\left(q(z|x)\ \|\ p(z)\right)$, based on a simple closed-form formula. This makes KL divergence an attractive measure to use and the normal distribution a simultaneously attractive distribution to assume on your model and data.
Well, your encoder is learning $q(z|x)$, but what's your latent prior $p(z)$? It is actually a fairly simple distribution for the latent space with a mean of zero and a standard deviation of one in each dimension, or $\mathcal{N}(0, I)$. You might also come across this as the spherical normal distribution, where the $I$ in $\mathcal{N}(0, I)$ stands for the identity matrix, meaning its covariance is 1 along the entire diagonal of the matrix, and, if you like geometry, it forms a nice symmetric-looking hypersphere, or a sphere with many (here, $z_\text{dim}$) dimensions.
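As a hedged sketch, the KL term between the encoder's output distribution and the standard normal prior can be computed with torch.distributions; the tensor shapes below are illustrative placeholders.

```python
import torch
from torch.distributions.normal import Normal
from torch.distributions.kl import kl_divergence

mean = torch.zeros(4, 32)    # placeholder encoder means (batch of 4, z_dim = 32)
stddev = torch.ones(4, 32)   # placeholder encoder standard deviations

q = Normal(mean, stddev)                                       # learned posterior q(z|x)
p = Normal(torch.zeros_like(mean), torch.ones_like(stddev))    # prior p(z) = N(0, I)

# kl_divergence returns one value per latent dimension; sum over dimensions, average over the batch
kl = kl_divergence(q, p).sum(dim=-1).mean()
```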
Further Resources
An accessible but complete discussion and derivation of the evidence lower bound (ELBO) and the theory behind it can be found at this link and this lecture.
Training a VAE
Here you can train a VAE, once again using MNIST! First, define the dataloader:
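A minimal dataloader definition might look like this; the batch size and the choice to keep pixels in [0, 1] (matching the Bernoulli/BCE assumption) are assumptions.

```python
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

transform = transforms.ToTensor()  # keeps pixel brightnesses in [0, 1]

mnist = datasets.MNIST('.', train=True, download=True, transform=transform)
dataloader = DataLoader(mnist, batch_size=128, shuffle=True)
```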
Then, you can run the training loop to observe the training process:
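One possible training loop is sketched below, assuming the VAE class and dataloader sketched earlier; the epoch count, learning rate, and loss reduction are arbitrary illustrative choices rather than the assignment's exact settings.

```python
import torch
from torch import nn
from torch.distributions.normal import Normal
from torch.distributions.kl import kl_divergence

device = 'cuda' if torch.cuda.is_available() else 'cpu'
vae = VAE().to(device)
optimizer = torch.optim.Adam(vae.parameters(), lr=0.002)
reconstruction_loss = nn.BCELoss(reduction='sum')

for epoch in range(10):
    for images, _ in dataloader:
        images = images.to(device)
        optimizer.zero_grad()
        # Forward pass: reconstruct the images and get the encoder's output distribution
        recon, encoding_dist = vae(images)
        prior = Normal(torch.zeros_like(encoding_dist.mean),
                       torch.ones_like(encoding_dist.stddev))
        # Negative ELBO: reconstruction term + KL term, averaged over the batch
        loss = (reconstruction_loss(recon, images)
                + kl_divergence(encoding_dist, prior).sum()) / len(images)
        loss.backward()
        optimizer.step()
    print(f'Epoch {epoch}: loss {loss.item():.2f}')
```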
If you're interested in learning more about VAEs, here are some useful resources:
$\beta$-VAEs showed that you can weight the KL-divergence term differently to reward "exploration" by the model.
VQ-VAE-2 is a VAE-autoregressive hybrid generative model, and has been able to generate incredibly diverse images, keeping up with GANs. 😃
VAE-GAN is a VAE-GAN hybrid generative model that uses an adversarial loss (that is, the discriminator's judgments on real/fake) on a VAE.