"""
Title: WGAN-GP overriding `Model.train_step`
Author: [A_K_Nain](https://twitter.com/A_K_Nain)
Date created: 2020/05/9
Last modified: 2023/08/3
Description: Implementation of Wasserstein GAN with Gradient Penalty.
Accelerator: GPU
"""

"""
## Wasserstein GAN (WGAN) with Gradient Penalty (GP)

The original [Wasserstein GAN](https://arxiv.org/abs/1701.07875) leverages the
Wasserstein distance to produce a value function that has better theoretical
properties than the value function used in the original GAN paper. WGAN requires
that the discriminator (aka the critic) lie within the space of 1-Lipschitz
functions. The authors proposed the idea of weight clipping to achieve this
constraint. Though weight clipping works, it can be a problematic way to enforce
the 1-Lipschitz constraint and can cause undesirable behavior: for example, a very
deep WGAN discriminator (critic) often fails to converge.

The [WGAN-GP](https://arxiv.org/abs/1704.00028) method proposes an
alternative to weight clipping to ensure smooth training. Instead of clipping
the weights, the authors add a "gradient penalty": a loss term
that keeps the L2 norm of the discriminator gradients close to 1.
"""

"""
## Setup
"""

import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
import tensorflow as tf
from keras import layers

"""
## Prepare the Fashion-MNIST data

To demonstrate how to train WGAN-GP, we will be using the
[Fashion-MNIST](https://github.com/zalandoresearch/fashion-mnist) dataset. Each
sample in this dataset is a 28x28 grayscale image associated with a label from
10 classes (e.g. trouser, pullover, sneaker, etc.).
"""

IMG_SHAPE = (28, 28, 1)
BATCH_SIZE = 512

# Size of the noise vector
noise_dim = 128

fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
print(f"Number of examples: {len(train_images)}")
print(f"Shape of the images in the dataset: {train_images.shape[1:]}")

# Reshape each sample to (28, 28, 1) and normalize the pixel values in the [-1, 1] range
train_images = train_images.reshape(train_images.shape[0], *IMG_SHAPE).astype("float32")
train_images = (train_images - 127.5) / 127.5
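
# Optional sanity check (not part of the original example): after rescaling, all
# pixel values should lie in [-1, 1], matching the tanh output range of the
# generator defined below.
assert train_images.min() >= -1.0 and train_images.max() <= 1.0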

"""
## Create the discriminator (the critic in the original WGAN)

The samples in the dataset have a (28, 28, 1) shape. Because we will be
using strided convolutions, this can result in a shape with odd dimensions.
For example,
`(28, 28) -> Conv_s2 -> (14, 14) -> Conv_s2 -> (7, 7) -> Conv_s2 -> (3, 3)`.

While performing upsampling in the generator part of the network, we won't get
the same input shape as the original images if we aren't careful. To avoid this,
we will do something much simpler:
- In the discriminator: "zero pad" the input to change the shape to `(32, 32, 1)`
for each sample; and
- In the generator: crop the final output to match the input shape.

The pad/crop round trip is illustrated right after this list.
"""


def conv_block(
    x,
    filters,
    activation,
    kernel_size=(3, 3),
    strides=(1, 1),
    padding="same",
    use_bias=True,
    use_bn=False,
    use_dropout=False,
    drop_value=0.5,
):
    x = layers.Conv2D(
        filters, kernel_size, strides=strides, padding=padding, use_bias=use_bias
    )(x)
    if use_bn:
        x = layers.BatchNormalization()(x)
    x = activation(x)
    if use_dropout:
        x = layers.Dropout(drop_value)(x)
    return x


def get_discriminator_model():
    img_input = layers.Input(shape=IMG_SHAPE)
    # Zero pad the input to make each image (32, 32, 1).
    x = layers.ZeroPadding2D((2, 2))(img_input)
    x = conv_block(
        x,
        64,
        kernel_size=(5, 5),
        strides=(2, 2),
        use_bn=False,
        use_bias=True,
        activation=layers.LeakyReLU(0.2),
        use_dropout=False,
        drop_value=0.3,
    )
    x = conv_block(
        x,
        128,
        kernel_size=(5, 5),
        strides=(2, 2),
        use_bn=False,
        activation=layers.LeakyReLU(0.2),
        use_bias=True,
        use_dropout=True,
        drop_value=0.3,
    )
    x = conv_block(
        x,
        256,
        kernel_size=(5, 5),
        strides=(2, 2),
        use_bn=False,
        activation=layers.LeakyReLU(0.2),
        use_bias=True,
        use_dropout=True,
        drop_value=0.3,
    )
    x = conv_block(
        x,
        512,
        kernel_size=(5, 5),
        strides=(2, 2),
        use_bn=False,
        activation=layers.LeakyReLU(0.2),
        use_bias=True,
        use_dropout=False,
        drop_value=0.3,
    )

    x = layers.Flatten()(x)
    x = layers.Dropout(0.2)(x)
    x = layers.Dense(1)(x)

    d_model = keras.models.Model(img_input, x, name="discriminator")
    return d_model


d_model = get_discriminator_model()
d_model.summary()
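
# Optional smoke test (not part of the original example): the critic maps a batch
# of images to one unbounded score per image. Note the final Dense(1) has no
# sigmoid -- WGAN critics output raw scores, not probabilities.
print(d_model(tf.zeros((4, *IMG_SHAPE))).shape)  # -> (4, 1)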

"""
## Create the generator
"""


def upsample_block(
    x,
    filters,
    activation,
    kernel_size=(3, 3),
    strides=(1, 1),
    up_size=(2, 2),
    padding="same",
    use_bn=False,
    use_bias=True,
    use_dropout=False,
    drop_value=0.3,
):
    x = layers.UpSampling2D(up_size)(x)
    x = layers.Conv2D(
        filters, kernel_size, strides=strides, padding=padding, use_bias=use_bias
    )(x)

    if use_bn:
        x = layers.BatchNormalization()(x)

    if activation:
        x = activation(x)
    if use_dropout:
        x = layers.Dropout(drop_value)(x)
    return x


def get_generator_model():
    noise = layers.Input(shape=(noise_dim,))
    x = layers.Dense(4 * 4 * 256, use_bias=False)(noise)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU(0.2)(x)

    x = layers.Reshape((4, 4, 256))(x)
    x = upsample_block(
        x,
        128,
        layers.LeakyReLU(0.2),
        strides=(1, 1),
        use_bias=False,
        use_bn=True,
        padding="same",
        use_dropout=False,
    )
    x = upsample_block(
        x,
        64,
        layers.LeakyReLU(0.2),
        strides=(1, 1),
        use_bias=False,
        use_bn=True,
        padding="same",
        use_dropout=False,
    )
    x = upsample_block(
        x, 1, layers.Activation("tanh"), strides=(1, 1), use_bias=False, use_bn=True
    )
    # At this point, we have an output of shape (32, 32, 1), matching the padded
    # discriminator input. We use a Cropping2D layer to make it (28, 28, 1).
    x = layers.Cropping2D((2, 2))(x)

    g_model = keras.models.Model(noise, x, name="generator")
    return g_model


g_model = get_generator_model()
g_model.summary()
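
# Optional smoke test (not part of the original example): a batch of latent
# vectors should map to (batch, 28, 28, 1) images with values in [-1, 1],
# thanks to the tanh activation and the final crop.
print(g_model(tf.random.normal((4, noise_dim))).shape)  # -> (4, 28, 28, 1)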

"""
## Create the WGAN-GP model

Now that we have defined our generator and discriminator, it's time to implement
the WGAN-GP model. We will also override the `train_step` method for training.
"""


class WGAN(keras.Model):
    def __init__(
        self,
        discriminator,
        generator,
        latent_dim,
        discriminator_extra_steps=3,
        gp_weight=10.0,
    ):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        self.d_steps = discriminator_extra_steps
        self.gp_weight = gp_weight

    def compile(self, d_optimizer, g_optimizer, d_loss_fn, g_loss_fn):
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.d_loss_fn = d_loss_fn
        self.g_loss_fn = g_loss_fn

    def gradient_penalty(self, batch_size, real_images, fake_images):
        """Calculates the gradient penalty.

        This loss is calculated on an interpolated image
        and added to the discriminator loss.
        """
        # Get the interpolated image
        alpha = tf.random.uniform([batch_size, 1, 1, 1], 0.0, 1.0)
        diff = fake_images - real_images
        interpolated = real_images + alpha * diff

        with tf.GradientTape() as gp_tape:
            gp_tape.watch(interpolated)
            # 1. Get the discriminator output for this interpolated image.
            pred = self.discriminator(interpolated, training=True)

        # 2. Calculate the gradients w.r.t. this interpolated image.
        grads = gp_tape.gradient(pred, [interpolated])[0]
        # 3. Calculate the norm of the gradients.
        norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
        gp = tf.reduce_mean((norm - 1.0) ** 2)
        return gp

    def train_step(self, real_images):
        if isinstance(real_images, tuple):
            real_images = real_images[0]

        # Get the batch size
        batch_size = tf.shape(real_images)[0]

        # For each batch, we are going to perform the
        # following steps as laid out in the original paper:
        # 1. Train the discriminator and get the discriminator loss
        # 2. Calculate the gradient penalty
        # 3. Multiply this gradient penalty with a constant weight factor
        # 4. Add the gradient penalty to the discriminator loss
        # 5. Train the generator and get the generator loss
        # 6. Return the generator and discriminator losses as a loss dictionary

        # Train the discriminator first. The original paper recommends training
        # the discriminator for several steps (typically 5) per single step of
        # the generator. Here we train it for 3 steps per generator step to
        # reduce the training time.
        for i in range(self.d_steps):
            # Get the latent vector
            random_latent_vectors = tf.random.normal(
                shape=(batch_size, self.latent_dim)
            )
            with tf.GradientTape() as tape:
                # Generate fake images from the latent vector
                fake_images = self.generator(random_latent_vectors, training=True)
                # Get the logits for the fake images
                fake_logits = self.discriminator(fake_images, training=True)
                # Get the logits for the real images
                real_logits = self.discriminator(real_images, training=True)

                # Calculate the discriminator loss using the fake and real image logits
                d_cost = self.d_loss_fn(real_img=real_logits, fake_img=fake_logits)
                # Calculate the gradient penalty
                gp = self.gradient_penalty(batch_size, real_images, fake_images)
                # Add the gradient penalty to the original discriminator loss
                d_loss = d_cost + gp * self.gp_weight

            # Get the gradients of the discriminator loss w.r.t. the discriminator weights
            d_gradient = tape.gradient(d_loss, self.discriminator.trainable_variables)
            # Update the weights of the discriminator using the discriminator optimizer
            self.d_optimizer.apply_gradients(
                zip(d_gradient, self.discriminator.trainable_variables)
            )

        # Train the generator
        # Get the latent vector
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        with tf.GradientTape() as tape:
            # Generate fake images using the generator
            generated_images = self.generator(random_latent_vectors, training=True)
            # Get the discriminator logits for fake images
            gen_img_logits = self.discriminator(generated_images, training=True)
            # Calculate the generator loss
            g_loss = self.g_loss_fn(gen_img_logits)

        # Get the gradients of the generator loss w.r.t. the generator weights
        gen_gradient = tape.gradient(g_loss, self.generator.trainable_variables)
        # Update the weights of the generator using the generator optimizer
        self.g_optimizer.apply_gradients(
            zip(gen_gradient, self.generator.trainable_variables)
        )
        return {"d_loss": d_loss, "g_loss": g_loss}
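
"""
To make the interpolation inside `gradient_penalty` concrete: each sample gets a
single blend factor `alpha ~ U[0, 1]` (hence the `[batch_size, 1, 1, 1]` shape,
which broadcasts over height, width, and channels), producing
`x_hat = real + alpha * (fake - real)` on the straight line between a real and a
fake image. A minimal standalone sketch using random tensors instead of the model
(illustrative only, not part of the original example):
"""

real = tf.random.normal((4, 28, 28, 1))
fake = tf.random.normal((4, 28, 28, 1))
alpha = tf.random.uniform((4, 1, 1, 1), 0.0, 1.0)  # one blend factor per sample
x_hat = real + alpha * (fake - real)
print(x_hat.shape)  # (4, 28, 28, 1)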


"""
## Create a Keras callback that periodically saves generated images
"""


class GANMonitor(keras.callbacks.Callback):
    def __init__(self, num_img=6, latent_dim=128):
        super().__init__()
        self.num_img = num_img
        self.latent_dim = latent_dim

    def on_epoch_end(self, epoch, logs=None):
        random_latent_vectors = tf.random.normal(shape=(self.num_img, self.latent_dim))
        generated_images = self.model.generator(random_latent_vectors)
        # Rescale the images from [-1, 1] back to [0, 255]
        generated_images = (generated_images * 127.5) + 127.5

        for i in range(self.num_img):
            img = generated_images[i].numpy()
            img = keras.utils.array_to_img(img)
            img.save("generated_img_{i}_{epoch}.png".format(i=i, epoch=epoch))

"""
## Train the end-to-end model
"""

# Instantiate the optimizer for both networks
# (learning_rate=0.0002, beta_1=0.5 are recommended)
generator_optimizer = keras.optimizers.Adam(
    learning_rate=0.0002, beta_1=0.5, beta_2=0.9
)
discriminator_optimizer = keras.optimizers.Adam(
    learning_rate=0.0002, beta_1=0.5, beta_2=0.9
)


# Define the loss function for the discriminator,
# which should be (fake_loss - real_loss).
# We will add the gradient penalty to this loss later.
def discriminator_loss(real_img, fake_img):
    real_loss = tf.reduce_mean(real_img)
    fake_loss = tf.reduce_mean(fake_img)
    return fake_loss - real_loss


# Define the loss function for the generator.
def generator_loss(fake_img):
    return -tf.reduce_mean(fake_img)
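
# Worked sign check (illustrative, not part of the original example): if the
# critic scores real images at +1 and fakes at -1, the critic loss is
# (-1) - (+1) = -2 (good for the critic), while the generator loss on those
# fakes is -(-1) = +1 (bad for the generator, which wants higher scores).
print(discriminator_loss(tf.ones((4, 1)), -tf.ones((4, 1))))  # -2.0
print(generator_loss(-tf.ones((4, 1))))  # 1.0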


# Set the number of epochs for training.
epochs = 20

# Instantiate the custom `GANMonitor` Keras callback.
cbk = GANMonitor(num_img=3, latent_dim=noise_dim)

# Instantiate the WGAN model.
wgan = WGAN(
    discriminator=d_model,
    generator=g_model,
    latent_dim=noise_dim,
    discriminator_extra_steps=3,
)

# Compile the WGAN model.
wgan.compile(
    d_optimizer=discriminator_optimizer,
    g_optimizer=generator_optimizer,
    g_loss_fn=generator_loss,
    d_loss_fn=discriminator_loss,
)

# Start training.
wgan.fit(train_images, batch_size=BATCH_SIZE, epochs=epochs, callbacks=[cbk])

"""
Display the last generated images:
"""

from IPython.display import Image, display

display(Image("generated_img_0_19.png"))
display(Image("generated_img_1_19.png"))
display(Image("generated_img_2_19.png"))

"""
## Relevant Chapters from Deep Learning with Python
- [Chapter 17: Image generation](https://deeplearningwithpython.io/chapters/chapter17_image-generation)
"""