WGAN-GP overriding Model.train_step
Author: A_K_Nain
Date created: 2020/05/9
Last modified: 2023/08/3
Description: Implementation of Wasserstein GAN with Gradient Penalty.
Wasserstein GAN (WGAN) with Gradient Penalty (GP)
The original Wasserstein GAN leverages the Wasserstein distance to produce a value function that has better theoretical properties than the value function used in the original GAN paper. WGAN requires that the discriminator (aka the critic) lie within the space of 1-Lipschitz functions. The authors proposed weight clipping to enforce this constraint. Although weight clipping works, it can be a problematic way to enforce a 1-Lipschitz constraint and can cause undesirable behavior; for example, a very deep WGAN discriminator (critic) often fails to converge.
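For reference, the WGAN value function can be written as follows. The symbols are our own labels: G is the generator, D the critic, P_r the real data distribution, and P_g the distribution of generated samples.

```latex
\min_{G} \max_{D \in \mathcal{D}}
  \mathbb{E}_{x \sim \mathbb{P}_r}\left[ D(x) \right]
  - \mathbb{E}_{\tilde{x} \sim \mathbb{P}_g}\left[ D(\tilde{x}) \right],
\qquad \tilde{x} = G(z), \; z \sim p(z)
```

Here \mathcal{D} is the set of 1-Lipschitz functions; weight clipping is one (crude) way of keeping D inside that set.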
The WGAN-GP method proposes an alternative to weight clipping to ensure smooth training. Instead of clipping the weights, the authors proposed a "gradient penalty": an extra loss term that keeps the L2 norm of the gradient of the discriminator's output with respect to its input close to 1.
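Written out in the same notation, the penalized critic loss from the WGAN-GP paper looks roughly like this. The penalty is evaluated at random points x̂ on straight lines between real and generated samples, and λ = 10 is the paper's default weight; treat this as a summary of the paper rather than of this example's exact code.

```latex
L_D = \mathbb{E}_{\tilde{x} \sim \mathbb{P}_g}\left[ D(\tilde{x}) \right]
    - \mathbb{E}_{x \sim \mathbb{P}_r}\left[ D(x) \right]
    + \lambda \, \mathbb{E}_{\hat{x}}\left[ \left( \lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1 \right)^2 \right]

\hat{x} = \epsilon x + (1 - \epsilon) \tilde{x}, \qquad \epsilon \sim U[0, 1]

L_G = -\,\mathbb{E}_{\tilde{x} \sim \mathbb{P}_g}\left[ D(\tilde{x}) \right]
```

The first two terms of L_D are the negated WGAN objective (the critic minimizes L_D, which maximizes the value function), and the generator keeps the plain WGAN loss L_G.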
Setup
Prepare the Fashion-MNIST data
To demonstrate how to train WGAN-GP, we will use the Fashion-MNIST dataset. Each sample in this dataset is a 28x28 grayscale image associated with a label from 10 classes (e.g. trouser, pullover, sneaker).
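As a hedged sketch of the preprocessing step, the snippet below reshapes uint8 images to (N, 28, 28, 1) and scales them to [-1, 1]. The [-1, 1] range assumes the generator ends in a tanh activation, and preprocess / IMG_SHAPE are illustrative names. Real data would come from keras.datasets.fashion_mnist.load_data(), which downloads the arrays, so a tiny synthetic batch stands in here.

```python
import numpy as np

IMG_SHAPE = (28, 28, 1)  # Fashion-MNIST: 28x28 grayscale, one channel


def preprocess(images):
    """Reshape uint8 images to (N, 28, 28, 1) float32 and scale them to [-1, 1].

    The [-1, 1] range is an assumption chosen to match a tanh generator output.
    """
    images = np.asarray(images).reshape(-1, *IMG_SHAPE).astype("float32")
    return (images - 127.5) / 127.5


# Real arrays would come from:
#   (train_images, _), (_, _) = keras.datasets.fashion_mnist.load_data()
# A tiny synthetic batch (one all-black, one all-white image) stands in here.
dummy_batch = np.stack([
    np.zeros((28, 28), dtype=np.uint8),
    np.full((28, 28), 255, dtype=np.uint8),
])
train_images = preprocess(dummy_batch)
print(train_images.shape, train_images.min(), train_images.max())
```

Labels are not needed: a GAN only learns the image distribution, so only the image array is kept.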
Create the discriminator (the critic)
Model: "discriminator"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape              ┃    Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ input_layer (InputLayer)        │ (None, 28, 28, 1)         │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ zero_padding2d (ZeroPadding2D)  │ (None, 32, 32, 1)         │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ conv2d (Conv2D)                 │ (None, 16, 16, 64)        │      1,664 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ leaky_re_lu (LeakyReLU)         │ (None, 16, 16, 64)        │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ conv2d_1 (Conv2D)               │ (None, 8, 8, 128)         │    204,928 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ leaky_re_lu_1 (LeakyReLU)       │ (None, 8, 8, 128)         │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ dropout (Dropout)               │ (None, 8, 8, 128)         │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ conv2d_2 (Conv2D)               │ (None, 4, 4, 256)         │    819,456 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ leaky_re_lu_2 (LeakyReLU)       │ (None, 4, 4, 256)         │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ dropout_1 (Dropout)             │ (None, 4, 4, 256)         │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ conv2d_3 (Conv2D)               │ (None, 2, 2, 512)         │  3,277,312 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ leaky_re_lu_3 (LeakyReLU)       │ (None, 2, 2, 512)         │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ flatten (Flatten)               │ (None, 2048)              │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ dropout_2 (Dropout)             │ (None, 2048)              │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ dense (Dense)                   │ (None, 1)                 │      2,049 │
└─────────────────────────────────┴───────────────────────────┴────────────┘
Total params: 4,305,409 (16.42 MB)
Trainable params: 4,305,409 (16.42 MB)
Non-trainable params: 0 (0.00 B)
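The summary above can be reproduced with a functional model along these lines. This is a sketch, not necessarily the example's exact code: 5x5 kernels, stride 2, "same" padding, and biases are inferred from the shapes and parameter counts (for example, 1,664 = 5 x 5 x 1 x 64 + 64), while the LeakyReLU slope (0.2) and dropout rate (0.3) are assumptions because the summary does not show them. build_discriminator is a hypothetical helper name.

```python
import keras
from keras import layers


def build_discriminator(img_shape=(28, 28, 1)):
    """Hypothetical critic whose layer shapes and parameter counts match the summary."""
    inputs = keras.Input(shape=img_shape)
    # Zero-pad 28x28 -> 32x32 so four stride-2 convolutions give 16, 8, 4, 2.
    x = layers.ZeroPadding2D((2, 2))(inputs)
    # (filters, add dropout after the activation?) for the four 5x5 convolutions.
    for filters, use_dropout in [(64, False), (128, True), (256, True), (512, False)]:
        x = layers.Conv2D(filters, 5, strides=2, padding="same")(x)
        x = layers.LeakyReLU(0.2)(x)  # slope 0.2 is an assumption
        if use_dropout:
            x = layers.Dropout(0.3)(x)  # rate 0.3 is an assumption
    x = layers.Flatten()(x)
    x = layers.Dropout(0.3)(x)
    # A single unbounded score per image (no sigmoid): a WGAN critic outputs a
    # score, not a real/fake probability.
    outputs = layers.Dense(1)(x)
    return keras.Model(inputs, outputs, name="discriminator")


critic = build_discriminator()
critic.summary()
```

Padding to 32x32 instead of working with 28x28 directly keeps every stride-2 step an exact halving, so no odd spatial sizes appear along the way.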
Create the generator
Model: "generator"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape              ┃    Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ input_layer_1 (InputLayer)      │ (None, 128)               │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ dense_1 (Dense)                 │ (None, 4096)              │    524,288 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ batch_normalization             │ (None, 4096)              │     16,384 │
│ (BatchNormalization)            │                           │            │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ leaky_re_lu_4 (LeakyReLU)       │ (None, 4096)              │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ reshape (Reshape)               │ (None, 4, 4, 256)         │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ up_sampling2d (UpSampling2D)    │ (None, 8, 8, 256)         │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ conv2d_4 (Conv2D)               │ (None, 8, 8, 128)         │    294,912 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ batch_normalization_1           │ (None, 8, 8, 128)         │        512 │
│ (BatchNormalization)            │                           │            │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ leaky_re_lu_5 (LeakyReLU)       │ (None, 8, 8, 128)         │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ up_sampling2d_1 (UpSampling2D)  │ (None, 16, 16, 128)       │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ conv2d_5 (Conv2D)               │ (None, 16, 16, 64)        │     73,728 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ batch_normalization_2           │ (None, 16, 16, 64)        │        256 │
│ (BatchNormalization)            │                           │            │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ leaky_re_lu_6 (LeakyReLU)       │ (None, 16, 16, 64)        │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ up_sampling2d_2 (UpSampling2D)  │ (None, 32, 32, 64)        │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ conv2d_6 (Conv2D)               │ (None, 32, 32, 1)         │        576 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ batch_normalization_3           │ (None, 32, 32, 1)         │          4 │
│ (BatchNormalization)            │                           │            │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ activation (Activation)         │ (None, 32, 32, 1)         │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ cropping2d (Cropping2D)         │ (None, 28, 28, 1)         │          0 │
└─────────────────────────────────┴───────────────────────────┴────────────┘
Total params: 910,660 (3.47 MB)
Trainable params: 902,082 (3.44 MB)
Non-trainable params: 8,578 (33.51 KB)
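A hedged sketch of a generator that matches this summary follows. The bias-free 3x3 convolutions follow from the counts (294,912 = 3 x 3 x 256 x 128, with no bias term), and the 8,578 non-trainable parameters are the BatchNormalization moving means and variances (2 x (4096 + 128 + 64 + 1)). The LeakyReLU slope and the tanh output activation are assumptions, since the summary only shows a generic Activation layer; upsample_block and build_generator are hypothetical helper names.

```python
import keras
from keras import layers


def upsample_block(x, filters, activation=True):
    """UpSampling2D -> bias-free 3x3 Conv2D -> BatchNorm (-> LeakyReLU)."""
    x = layers.UpSampling2D((2, 2))(x)
    # use_bias=False: BatchNorm's beta already provides a per-channel offset.
    x = layers.Conv2D(filters, 3, padding="same", use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    if activation:
        x = layers.LeakyReLU(0.2)(x)  # slope 0.2 is an assumption
    return x


def build_generator(latent_dim=128):
    """Hypothetical generator whose layer shapes and parameter counts match the summary."""
    inputs = keras.Input(shape=(latent_dim,))
    x = layers.Dense(4 * 4 * 256, use_bias=False)(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU(0.2)(x)
    x = layers.Reshape((4, 4, 256))(x)
    x = upsample_block(x, 128)                  # -> (8, 8, 128)
    x = upsample_block(x, 64)                   # -> (16, 16, 64)
    x = upsample_block(x, 1, activation=False)  # -> (32, 32, 1)
    x = layers.Activation("tanh")(x)            # tanh is an assumption
    # Crop 2 pixels from each border: 32x32 -> 28x28, the Fashion-MNIST size.
    outputs = layers.Cropping2D((2, 2))(x)
    return keras.Model(inputs, outputs, name="generator")


generator = build_generator()
generator.summary()
```

The generator mirrors the critic's padding trick in reverse: it works at 32x32 so that three 2x upsamplings from 4x4 land exactly, then crops back to 28x28.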
Create the WGAN-GP model
Now that we have defined our generator and discriminator, it's time to implement the WGAN-GP model. We will also override the train_step method of keras.Model so that each training step updates the critic and the generator with the WGAN-GP losses.
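A minimal sketch of such a model follows, assuming TensorFlow as the Keras backend (tf.GradientTape is TensorFlow-specific). The class name WGANGP, the n_critic and gp_weight arguments, and the custom compile signature are illustrative choices, not necessarily the ones this example uses. gp_weight=10.0 follows the WGAN-GP paper; n_critic=3 is an assumption (the paper uses 5).

```python
import keras
import tensorflow as tf


class WGANGP(keras.Model):
    """Hypothetical WGAN-GP wrapper that overrides train_step (TensorFlow backend)."""

    def __init__(self, discriminator, generator, latent_dim, n_critic=3, gp_weight=10.0):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        self.n_critic = n_critic    # assumed: critic updates per generator update
        self.gp_weight = gp_weight  # lambda; 10 in the WGAN-GP paper

    def compile(self, d_optimizer, g_optimizer):
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer

    def gradient_penalty(self, real, fake):
        # One mixing coefficient per sample, broadcast over height, width, channels.
        alpha = tf.random.uniform([tf.shape(real)[0], 1, 1, 1], 0.0, 1.0)
        interpolated = real + alpha * (fake - real)
        with tf.GradientTape() as tape:
            tape.watch(interpolated)
            pred = self.discriminator(interpolated, training=True)
        # Gradient of the critic score w.r.t. its *input*, not its weights.
        grads = tape.gradient(pred, interpolated)
        # Per-sample L2 norm; the tiny epsilon keeps sqrt differentiable at 0.
        norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]) + 1e-12)
        return tf.reduce_mean((norm - 1.0) ** 2)

    def train_step(self, data):
        # Accept either x or (x, y); labels are unused by the GAN.
        real_images = data[0] if isinstance(data, (tuple, list)) else data
        batch_size = tf.shape(real_images)[0]

        # 1) Critic updates: minimize E[D(fake)] - E[D(real)] + lambda * GP.
        for _ in range(self.n_critic):
            z = tf.random.normal([batch_size, self.latent_dim])
            with tf.GradientTape() as tape:
                fake_images = self.generator(z, training=True)
                real_score = self.discriminator(real_images, training=True)
                fake_score = self.discriminator(fake_images, training=True)
                d_loss = tf.reduce_mean(fake_score) - tf.reduce_mean(real_score)
                d_loss += self.gp_weight * self.gradient_penalty(real_images, fake_images)
            d_vars = self.discriminator.trainable_variables
            self.d_optimizer.apply_gradients(zip(tape.gradient(d_loss, d_vars), d_vars))

        # 2) Generator update: minimize -E[D(G(z))].
        z = tf.random.normal([batch_size, self.latent_dim])
        with tf.GradientTape() as tape:
            fake_score = self.discriminator(self.generator(z, training=True), training=True)
            g_loss = -tf.reduce_mean(fake_score)
        g_vars = self.generator.trainable_variables
        self.g_optimizer.apply_gradients(zip(tape.gradient(g_loss, g_vars), g_vars))
        return {"d_loss": d_loss, "g_loss": g_loss}


# Smoke test with tiny stand-in models (illustrative shapes only).
tiny_critic = keras.Sequential([
    keras.Input(shape=(2, 2, 1)),
    keras.layers.Flatten(),
    keras.layers.Dense(1, kernel_initializer="ones", use_bias=False),
])
tiny_gen = keras.Sequential([
    keras.Input(shape=(4,)),
    keras.layers.Dense(4),
    keras.layers.Reshape((2, 2, 1)),
])
wgan = WGANGP(discriminator=tiny_critic, generator=tiny_gen, latent_dim=4, n_critic=1)
wgan.compile(d_optimizer=keras.optimizers.Adam(1e-3), g_optimizer=keras.optimizers.Adam(1e-3))
# All-ones linear critic: input-gradient norm is sqrt(4) = 2, so the penalty is (2 - 1)^2 = 1.
demo_gp = float(wgan.gradient_penalty(tf.zeros((3, 2, 2, 1)), tf.ones((3, 2, 2, 1))))
demo_logs = wgan.train_step(tf.random.normal((8, 2, 2, 1)))
```

The penalty is computed inside the critic's outer GradientTape on purpose: the inner tape's gradient computation is recorded by the outer one, which is what lets a loss that itself contains a gradient be differentiated with respect to the critic weights. With real models, wgan.fit(...) calls this train_step once per batch.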
Create a Keras callback that periodically saves generated images
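One possible shape for such a callback is sketched below, assuming Pillow is installed (keras.utils.save_img relies on it) and that generator outputs lie in [-1, 1]. The class name SaveGeneratedImages, its arguments, the output directory, and the file naming scheme are all hypothetical.

```python
import os

import keras
import numpy as np


def to_uint8(images):
    """Map generator output from [-1, 1] (assumed tanh range) back to uint8 [0, 255]."""
    images = np.asarray(images, dtype="float32")
    return np.clip(images * 127.5 + 127.5, 0, 255).astype("uint8")


class SaveGeneratedImages(keras.callbacks.Callback):
    """Hypothetical callback: after each epoch, sample images and write them as PNGs."""

    def __init__(self, generator, latent_dim, num_img=6, out_dir="generated_images"):
        super().__init__()
        self.generator = generator
        self.latent_dim = latent_dim
        self.num_img = num_img
        self.out_dir = out_dir

    def on_epoch_end(self, epoch, logs=None):
        os.makedirs(self.out_dir, exist_ok=True)
        # Fresh random latent vectors each epoch.
        z = np.random.normal(size=(self.num_img, self.latent_dim)).astype("float32")
        images = to_uint8(self.generator.predict(z, verbose=0))
        for i, img in enumerate(images):
            path = os.path.join(self.out_dir, f"epoch_{epoch + 1:03d}_img_{i}.png")
            # scale=False: pixels are already uint8 in [0, 255]; saving needs Pillow.
            keras.utils.save_img(path, img, scale=False)
```

It would be passed to training as callbacks=[SaveGeneratedImages(generator, latent_dim=128)], where latent_dim=128 matches the generator's (None, 128) input in the summary. Reusing one fixed batch of latent vectors instead of fresh ones would make epoch-to-epoch progress easier to compare.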
Train the end-to-end model


