Path: blob/master/first_edition/8.5-introduction-to-gans.ipynb
Introduction to generative adversarial networks
This notebook contains the second code sample found in Chapter 8, Section 5 of Deep Learning with Python. Note that the original text features far more content, in particular further explanations and figures: in this notebook, you will only find source code and related comments.
[...]
A schematic GAN implementation
In what follows, we explain how to implement a GAN in Keras, in its barest form -- since GANs are quite advanced, diving deeply into the technical details would be out of scope for us. Our specific implementation will be a deep convolutional GAN, or DCGAN: a GAN where the generator and discriminator are deep convnets. In particular, it leverages a `Conv2DTranspose` layer for image upsampling in the generator.
We will train our GAN on images from CIFAR10, a dataset of 50,000 32x32 RGB images belonging to 10 classes (5,000 images per class). To make things even easier, we will only use images belonging to the class "frog".
Schematically, our GAN looks like this:

- A `generator` network maps vectors of shape `(latent_dim,)` to images of shape `(32, 32, 3)`.
- A `discriminator` network maps images of shape `(32, 32, 3)` to a binary score estimating the probability that the image is real.
- A `gan` network chains the generator and the discriminator together: `gan(x) = discriminator(generator(x))`. Thus this `gan` network maps latent space vectors to the discriminator's assessment of the realism of these latent vectors as decoded by the generator.

We train the discriminator using examples of real and fake images along with "real"/"fake" labels, as we would train any regular image classification model.

To train the generator, we use the gradients of the generator's weights with regard to the loss of the `gan` model. This means that, at every step, we move the weights of the generator in a direction that makes the discriminator more likely to classify as "real" the images decoded by the generator, i.e. we train the generator to fool the discriminator.
A bag of tricks
Training GANs and tuning GAN implementations is notoriously difficult. There are a number of known "tricks" that one should keep in mind. Like most things in deep learning, it is more alchemy than science: these tricks are really just heuristics, not theory-backed guidelines. They are backed by some level of intuitive understanding of the phenomenon at hand, and they are known to work well empirically, albeit not necessarily in every context.
Here are a few of the tricks that we leverage in our own implementation of a GAN generator and discriminator below. It is not an exhaustive list of GAN-related tricks; you will find many more across the GAN literature.
- We use `tanh` as the last activation in the generator, instead of `sigmoid`, which would be more commonly found in other types of models.
- We sample points from the latent space using a normal (Gaussian) distribution, not a uniform distribution.
- Stochasticity is good for inducing robustness. Since GAN training results in a dynamic equilibrium, GANs are likely to get "stuck" in all sorts of ways. Introducing randomness during training helps prevent this. We introduce randomness in two ways: 1) we use dropout in the discriminator, 2) we add some random noise to the labels for the discriminator.
- Sparse gradients can hinder GAN training. In deep learning, sparsity is often a desirable property, but not in GANs. Two things can induce gradient sparsity: 1) max pooling operations, 2) ReLU activations. Instead of max pooling, we recommend using strided convolutions for downsampling, and we recommend using a `LeakyReLU` layer instead of a ReLU activation. It is similar to ReLU, but it relaxes sparsity constraints by allowing small negative activation values.
- In generated images, it is common to see "checkerboard artifacts" caused by unequal coverage of the pixel space in the generator. To fix this, we use a kernel size that is divisible by the stride size whenever we use a strided `Conv2DTranspose` or `Conv2D`, in both the generator and the discriminator.
The generator
First, we develop a `generator` model, which turns a vector (from the latent space -- during training it will be sampled at random) into a candidate image. One of the many issues that commonly arise with GANs is that the generator gets stuck with generated images that look like noise. A possible solution is to use dropout on both the discriminator and the generator.
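The original code cell is not reproduced in this extract. A generator along the lines described -- latent vector in, `tanh`-activated 32x32 RGB image out, with `LeakyReLU` activations and a stride-divisible `Conv2DTranspose` for upsampling -- might look like the following sketch (layer widths here are one reasonable choice, not canonical):

```python
import keras
from keras import layers

latent_dim = 32
height, width, channels = 32, 32, 3

generator_input = keras.Input(shape=(latent_dim,))

# Transform the input vector into a 16x16, 128-channel feature map
x = layers.Dense(128 * 16 * 16)(generator_input)
x = layers.LeakyReLU()(x)  # LeakyReLU rather than ReLU, to avoid sparse gradients
x = layers.Reshape((16, 16, 128))(x)

x = layers.Conv2D(256, 5, padding='same')(x)
x = layers.LeakyReLU()(x)

# Upsample to 32x32; kernel size (4) divisible by stride (2),
# to avoid checkerboard artifacts
x = layers.Conv2DTranspose(256, 4, strides=2, padding='same')(x)
x = layers.LeakyReLU()(x)

x = layers.Conv2D(256, 5, padding='same')(x)
x = layers.LeakyReLU()(x)
x = layers.Conv2D(256, 5, padding='same')(x)
x = layers.LeakyReLU()(x)

# Produce a 32x32, 3-channel image, with tanh as the last activation
x = layers.Conv2D(channels, 7, activation='tanh', padding='same')(x)
generator = keras.models.Model(generator_input, x)
```

Note how every downsampling/upsampling decision follows the bag of tricks above: no max pooling, no plain ReLU, `tanh` output.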
The discriminator
Then, we develop a `discriminator` model that takes as input a candidate image (real or synthetic) and classifies it into one of two classes: "generated image" or "real image that comes from the training set".
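Again, the code cell itself is missing here; a discriminator consistent with the tricks above (strided convolutions instead of pooling, `LeakyReLU`, one dropout layer, gradient clipping in the optimizer) could be sketched as follows -- the specific widths and the RMSprop hyperparameters are illustrative choices:

```python
import keras
from keras import layers

height, width, channels = 32, 32, 3

discriminator_input = keras.Input(shape=(height, width, channels))
x = layers.Conv2D(128, 3)(discriminator_input)
x = layers.LeakyReLU()(x)
# Strided convolutions instead of max pooling for downsampling
x = layers.Conv2D(128, 4, strides=2)(x)
x = layers.LeakyReLU()(x)
x = layers.Conv2D(128, 4, strides=2)(x)
x = layers.LeakyReLU()(x)
x = layers.Conv2D(128, 4, strides=2)(x)
x = layers.LeakyReLU()(x)
x = layers.Flatten()(x)

# One dropout layer: an important stability trick
x = layers.Dropout(0.4)(x)

# Binary classification: "real" vs. "generated"
x = layers.Dense(1, activation='sigmoid')(x)
discriminator = keras.models.Model(discriminator_input, x)

# Gradient clipping (clipvalue) in the optimizer is another common GAN trick
discriminator.compile(
    optimizer=keras.optimizers.RMSprop(learning_rate=0.0008, clipvalue=1.0),
    loss='binary_crossentropy')
```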
The adversarial network
Finally, we set up the GAN, which chains the generator and the discriminator. This is the model that, when trained, will move the generator in a direction that improves its ability to fool the discriminator. This model turns latent space points into a classification decision, "fake" or "real", and it is meant to be trained with labels that are always "these are real images". So training `gan` will update the weights of `generator` in a way that makes `discriminator` more likely to predict "real" when looking at fake images. Very importantly, we set the discriminator to be frozen during training (non-trainable): its weights will not be updated when training `gan`. If the discriminator weights could be updated during this process, then we would be training the discriminator to always predict "real", which is not what we want!
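The chaining step can be sketched as below. The compact generator and discriminator defined here are small stand-ins so the snippet runs on its own; in the notebook you would reuse the fuller models built in the previous cells. The key lines are the `discriminator.trainable = False` toggle and the `gan` model definition:

```python
import keras
from keras import layers

latent_dim = 32

# Compact stand-ins for the generator and discriminator built earlier
g_in = keras.Input(shape=(latent_dim,))
g = layers.Dense(16 * 16 * 32)(g_in)
g = layers.LeakyReLU()(g)
g = layers.Reshape((16, 16, 32))(g)
g = layers.Conv2DTranspose(32, 4, strides=2, padding='same')(g)
g = layers.LeakyReLU()(g)
g_out = layers.Conv2D(3, 7, activation='tanh', padding='same')(g)
generator = keras.models.Model(g_in, g_out)

d_in = keras.Input(shape=(32, 32, 3))
d = layers.Conv2D(32, 4, strides=2)(d_in)
d = layers.LeakyReLU()(d)
d = layers.Flatten()(d)
d_out = layers.Dense(1, activation='sigmoid')(d)
discriminator = keras.models.Model(d_in, d_out)
discriminator.compile(
    optimizer=keras.optimizers.RMSprop(learning_rate=0.0008, clipvalue=1.0),
    loss='binary_crossentropy')

# Freeze the discriminator *within the chained model*: training `gan` should
# update only the generator's weights. (Note: whether `trainable` is snapshotted
# at compile time varies across Keras versions; in recent versions you may need
# to set `discriminator.trainable = True` again before training it directly.)
discriminator.trainable = False

gan_input = keras.Input(shape=(latent_dim,))
gan_output = discriminator(generator(gan_input))
gan = keras.models.Model(gan_input, gan_output)
gan.compile(
    optimizer=keras.optimizers.RMSprop(learning_rate=0.0004, clipvalue=1.0),
    loss='binary_crossentropy')
```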
How to train your DCGAN
Now we can start training. To recapitulate, this is schematically what the training loop looks like:

1. Draw random points in the latent space (random noise).
2. Generate images with `generator` using this random noise.
3. Mix the generated images with real ones.
4. Train `discriminator` using these mixed images, with corresponding targets: either "real" (for the real images) or "fake" (for the generated images).
5. Draw new random points in the latent space.
6. Train `gan` using these random vectors, with targets that all say "these are real images". Since the discriminator is frozen inside `gan`, this updates only the weights of the generator, moving them towards making the discriminator predict "real" for generated images -- i.e. this trains the generator to fool the discriminator.
Let's implement it:
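The implementation cell is missing from this extract. A sketch of the loop, wrapped as a function for reuse (the function signature and parameter defaults are my framing, not the notebook's), assuming the `gan`, `generator`, and `discriminator` models built above:

```python
import numpy as np

def train_gan(gan, generator, discriminator, x_train, latent_dim,
              iterations=10000, batch_size=20):
    """One possible implementation of the DCGAN training loop sketched above."""
    start = 0
    for step in range(iterations):
        # 1. Sample random points in the latent space (normal distribution)
        random_latent_vectors = np.random.normal(size=(batch_size, latent_dim))

        # 2. Decode them into fake images
        generated_images = generator.predict(random_latent_vectors, verbose=0)

        # 3. Mix them with real images
        stop = start + batch_size
        real_images = x_train[start:stop]
        combined_images = np.concatenate([generated_images, real_images])

        # 4. Train the discriminator: label fake images 1 and real images 0,
        #    and add random noise to the labels -- an important trick
        labels = np.concatenate([np.ones((batch_size, 1)),
                                 np.zeros((batch_size, 1))])
        labels += 0.05 * np.random.random(labels.shape)
        d_loss = discriminator.train_on_batch(combined_images, labels)

        # 5. Sample new random points in the latent space
        random_latent_vectors = np.random.normal(size=(batch_size, latent_dim))

        # 6. Train the generator via `gan`, with misleading "these are all
        #    real" targets; the discriminator is frozen inside `gan`
        misleading_targets = np.zeros((batch_size, 1))
        a_loss = gan.train_on_batch(random_latent_vectors, misleading_targets)

        # Step through the real images, wrapping around at the end
        start += batch_size
        if start > len(x_train) - batch_size:
            start = 0
    return d_loss, a_loss
```

Here `x_train` would be the CIFAR10 frog images (class 6 of `keras.datasets.cifar10.load_data()`), rescaled to floats; a full run uses on the order of 10,000 iterations.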
Let's display a few of our fake images:
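The display cell is also missing here. A minimal sketch of sampling and decoding images -- with a tiny stand-in generator so the snippet runs alone (in the notebook you would use the trained generator from the cells above), and the `[-1, 1]` tanh range mapped back to `[0, 255]` for viewing:

```python
import numpy as np
import keras
from keras import layers

latent_dim = 32

# Stand-in generator; in practice, reuse the trained one
g_in = keras.Input(shape=(latent_dim,))
g = layers.Dense(32 * 32 * 3, activation='tanh')(g_in)
g_out = layers.Reshape((32, 32, 3))(g)
generator = keras.models.Model(g_in, g_out)

# Sample latent points from a normal distribution and decode them
random_latent_vectors = np.random.normal(size=(10, latent_dim))
generated_images = generator.predict(random_latent_vectors, verbose=0)

# The tanh output lies in [-1, 1]; rescale to [0, 255] for display
pixels = ((generated_images + 1.0) * 127.5).astype('uint8')
# e.g. keras.utils.array_to_img(pixels[0]).save('generated_frog_0.png')
```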
Froggy with some pixellated artifacts.