WGAN-GP overriding Model.train_step
Author: A_K_Nain
Date created: 2020/05/09
Last modified: 2023/08/03
Description: Implementation of Wasserstein GAN with Gradient Penalty.
Wasserstein GAN (WGAN) with Gradient Penalty (GP)
The original Wasserstein GAN leverages the Wasserstein distance to produce a value function that has better theoretical properties than the value function used in the original GAN paper. WGAN requires that the discriminator (aka the critic) lie within the space of 1-Lipschitz functions. The authors proposed the idea of weight clipping to achieve this constraint. Though weight clipping works, it can be a problematic way to enforce the 1-Lipschitz constraint and can cause undesirable behavior; for example, a very deep WGAN discriminator (critic) often fails to converge.
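For contrast, weight clipping as used in the original WGAN can be sketched in a few lines: after every critic update, each weight is clamped to a small interval. This is a minimal NumPy illustration (the clip value `c = 0.01` follows the original WGAN paper); it is not part of WGAN-GP.

```python
import numpy as np

def clip_weights(weights, c=0.01):
    """Return a clipped copy of a list of weight arrays,
    clamping every entry to the interval [-c, c]."""
    return [np.clip(w, -c, c) for w in weights]

weights = [np.array([[0.5, -0.002], [-0.7, 0.03]])]
clipped = clip_weights(weights)
# Every entry now lies in [-0.01, 0.01]; values already inside
# the interval (e.g. -0.002) pass through unchanged.
```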
The WGAN-GP method proposes an alternative to weight clipping to ensure smooth training. Instead of clipping the weights, the authors proposed a "gradient penalty" by adding a loss term that keeps the L2 norm of the discriminator gradients close to 1.
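The gradient penalty itself can be sketched as follows, assuming a TensorFlow backend. Gradients of the critic's output are taken at points interpolated between real and fake samples, and the squared deviation of their L2 norm from 1 is penalized. The toy linear critic here is only for demonstration: because it sums the pixels, its gradient is all ones, so the penalty can be checked by hand.

```python
import tensorflow as tf

def gradient_penalty(critic, real, fake):
    """Penalize (||grad_x critic(x)|| - 1)^2 at points interpolated
    between real and fake samples, as in WGAN-GP."""
    batch_size = tf.shape(real)[0]
    # One random interpolation coefficient per sample.
    eps = tf.random.uniform([batch_size, 1, 1, 1], 0.0, 1.0)
    interpolated = eps * real + (1.0 - eps) * fake
    with tf.GradientTape() as tape:
        tape.watch(interpolated)
        pred = critic(interpolated)
    grads = tape.gradient(pred, interpolated)
    # L2 norm of the gradient, computed per sample.
    norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
    return tf.reduce_mean((norm - 1.0) ** 2)

# Toy critic: sums the pixels, so the gradient is all ones and its norm is
# sqrt(num_pixels) = sqrt(4) = 2; the penalty is then (2 - 1)^2 = 1.
critic = lambda x: tf.reduce_sum(x, axis=[1, 2, 3])
real = tf.zeros([4, 2, 2, 1])
fake = tf.ones([4, 2, 2, 1])
gp = gradient_penalty(critic, real, fake)
```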
Setup
Prepare the Fashion-MNIST data
To demonstrate how to train WGAN-GP, we will be using the Fashion-MNIST dataset. Each sample in this dataset is a 28x28 grayscale image associated with a label from 10 classes (e.g. trouser, pullover, sneaker, etc.)
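The preprocessing boils down to scaling the uint8 pixels to [-1, 1] (matching a tanh generator output) and adding a channel axis. A minimal sketch on a dummy batch; in the real notebook, `keras.datasets.fashion_mnist.load_data()` supplies the images.

```python
import numpy as np

def preprocess(images):
    """Scale uint8 images from [0, 255] to [-1, 1] and add a channel axis."""
    images = images.astype("float32")
    images = (images - 127.5) / 127.5          # map [0, 255] -> [-1, 1]
    return np.reshape(images, (-1, 28, 28, 1))

dummy = np.random.randint(0, 256, size=(8, 28, 28), dtype="uint8")
batch = preprocess(dummy)
print(batch.shape)  # (8, 28, 28, 1)
```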
Create the discriminator (the critic in the original WGAN)
The samples in the dataset have a (28, 28, 1) shape. Because we will be using strided convolutions, this can result in a shape with odd dimensions. For example, (28, 28) -> Conv_s2 -> (14, 14) -> Conv_s2 -> (7, 7) -> Conv_s2 -> (3, 3).
While performing upsampling in the generator part of the network, we won't get the same input shape as the original images if we aren't careful. To avoid this, we will do something much simpler:
1. In the discriminator: "zero pad" the input to change the shape to (32, 32, 1) for each sample; and
2. In the generator: crop the final output to match the input shape.
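The critic's shape bookkeeping can be sketched like this: a `ZeroPadding2D` layer lifts the 28x28 input to 32x32 so that strided convolutions halve the spatial dimensions cleanly (32 -> 16 -> 8). The layer widths here are illustrative assumptions, not the exact ones from the example.

```python
import keras
from keras import layers

def build_critic():
    """Minimal critic sketch: zero-pad, downsample, score with one unit."""
    inputs = keras.Input(shape=(28, 28, 1))
    x = layers.ZeroPadding2D(padding=2)(inputs)              # (32, 32, 1)
    x = layers.Conv2D(64, 5, strides=2, padding="same")(x)   # (16, 16, 64)
    x = layers.LeakyReLU(0.2)(x)
    x = layers.Conv2D(128, 5, strides=2, padding="same")(x)  # (8, 8, 128)
    x = layers.LeakyReLU(0.2)(x)
    x = layers.Flatten()(x)
    # No sigmoid: the critic outputs an unbounded score, not a probability.
    outputs = layers.Dense(1)(x)
    return keras.Model(inputs, outputs, name="critic")

critic = build_critic()
```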
Create the generator
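Mirroring the critic, the generator can be sketched as upsampling from an 8x8 feature map to 32x32 and then cropping back to 28x28 with `Cropping2D`. The latent size (128) and filter counts are illustrative assumptions.

```python
import keras
from keras import layers

def build_generator(latent_dim=128):
    """Minimal generator sketch: project, upsample 8 -> 16 -> 32, crop to 28."""
    inputs = keras.Input(shape=(latent_dim,))
    x = layers.Dense(8 * 8 * 64)(inputs)
    x = layers.Reshape((8, 8, 64))(x)
    x = layers.Conv2DTranspose(64, 4, strides=2, padding="same")(x)  # 16x16
    x = layers.LeakyReLU(0.2)(x)
    x = layers.Conv2DTranspose(32, 4, strides=2, padding="same")(x)  # 32x32
    x = layers.LeakyReLU(0.2)(x)
    x = layers.Conv2D(1, 5, padding="same", activation="tanh")(x)    # 32x32x1
    # Crop 2 pixels from each border to recover the dataset's 28x28 shape.
    outputs = layers.Cropping2D(cropping=2)(x)
    return keras.Model(inputs, outputs, name="generator")

generator = build_generator()
```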
Create the WGAN-GP model
Now that we have defined our generator and discriminator, it's time to implement the WGAN-GP model. We will also override the train_step for training.
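The two loss terms computed inside the overridden `train_step` follow the standard WGAN-GP recipe: each step first updates the critic several times (this loss plus the weighted gradient penalty), then updates the generator once. A minimal sketch:

```python
import tensorflow as tf

def critic_loss(real_scores, fake_scores):
    # The critic tries to score real samples higher than fake ones.
    return tf.reduce_mean(fake_scores) - tf.reduce_mean(real_scores)

def generator_loss(fake_scores):
    # The generator tries to raise the critic's score on fake samples.
    return -tf.reduce_mean(fake_scores)

d_loss = critic_loss(tf.constant([1.0, 1.0]), tf.constant([0.0, 0.0]))  # -1.0
g_loss = generator_loss(tf.constant([0.0, 2.0]))                        # -1.0
```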
Create a Keras callback that periodically saves generated images
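Such a callback can be sketched as below. It assumes the model it is attached to exposes `generator` and `latent_dim`-sized noise inputs; `keras.utils.save_img` is used for writing the files.

```python
import numpy as np
import keras

class GANMonitor(keras.callbacks.Callback):
    """Sample the generator at the end of each epoch and save the images."""

    def __init__(self, num_img=3, latent_dim=128):
        super().__init__()
        self.num_img = num_img
        self.latent_dim = latent_dim

    def on_epoch_end(self, epoch, logs=None):
        noise = np.random.normal(size=(self.num_img, self.latent_dim))
        images = self.model.generator(noise)
        images = (np.asarray(images) * 127.5) + 127.5  # [-1, 1] -> [0, 255]
        for i in range(self.num_img):
            keras.utils.save_img(f"generated_img_{i}_{epoch}.png", images[i])

monitor = GANMonitor(num_img=3, latent_dim=128)
```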
Train the end-to-end model
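The optimizer settings commonly used for WGAN-GP are Adam with a small learning rate and non-default betas; the commented calls show how the pieces would be wired together (the `wgan` model, training images, and monitor callback come from the earlier cells).

```python
import keras

# Common WGAN-GP optimizer settings: Adam, lr=2e-4, beta_1=0.5, beta_2=0.9.
d_optimizer = keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5, beta_2=0.9)
g_optimizer = keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5, beta_2=0.9)

# Wiring (illustrative; depends on the model class defined earlier):
# wgan.compile(d_optimizer=d_optimizer, g_optimizer=g_optimizer)
# wgan.fit(train_images, batch_size=64, epochs=20, callbacks=[GANMonitor()])
```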
Display the last generated images: