GitHub Repository: keras-team/keras-io
Path: blob/master/examples/generative/dcgan_overriding_train_step.py
"""
Title: DCGAN to generate face images
Author: [fchollet](https://twitter.com/fchollet)
Date created: 2019/04/29
Last modified: 2023/12/21
Description: A simple DCGAN trained using `fit()` by overriding `train_step` on CelebA images.
Accelerator: GPU
"""

"""
## Setup
"""

import keras
import tensorflow as tf

from keras import layers
from keras import ops
import matplotlib.pyplot as plt
import os
import gdown
from zipfile import ZipFile

"""
25
## Prepare CelebA data
26
27
We'll use face images from the CelebA dataset, resized to 64x64.
28
"""
29
30
os.makedirs("celeba_gan")
31
32
url = "https://drive.google.com/uc?id=1O7m1010EJjLE5QxLZiM9Fpjs7Oj6e684"
33
output = "celeba_gan/data.zip"
34
gdown.download(url, output, quiet=True)
35
36
with ZipFile("celeba_gan/data.zip", "r") as zipobj:
37
zipobj.extractall("celeba_gan")
38
39
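"""
As a quick sanity check (an addition, not part of the original example), we can count how
many files were extracted into the `celeba_gan` folder created above.
"""

# Walk the extraction folder and count the files it contains.
num_files = sum(len(files) for _, _, files in os.walk("celeba_gan"))
print("Extracted files:", num_files)
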
"""
40
Create a dataset from our folder, and rescale the images to the [0-1] range:
41
"""
42
43
dataset = keras.utils.image_dataset_from_directory(
44
"celeba_gan", label_mode=None, image_size=(64, 64), batch_size=32
45
)
46
dataset = dataset.map(lambda x: x / 255.0)
47
48
49
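"""
A minimal check (an addition, not in the original example): each dataset element should be a
batch of float images with shape `(batch_size, 64, 64, 3)` and pixel values in [0, 1] after
the rescaling above.
"""

for batch in dataset.take(1):
    print("Batch shape:", batch.shape)
    print("Pixel range:", float(tf.reduce_min(batch)), "to", float(tf.reduce_max(batch)))
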
"""
50
Let's display a sample image:
51
"""
52
53
54
for x in dataset:
55
plt.axis("off")
56
plt.imshow((x.numpy() * 255).astype("int32")[0])
57
break
58
59
60
"""
61
## Create the discriminator
62
63
It maps a 64x64 image to a binary classification score.
64
"""
65
66
discriminator = keras.Sequential(
67
[
68
keras.Input(shape=(64, 64, 3)),
69
layers.Conv2D(64, kernel_size=4, strides=2, padding="same"),
70
layers.LeakyReLU(negative_slope=0.2),
71
layers.Conv2D(128, kernel_size=4, strides=2, padding="same"),
72
layers.LeakyReLU(negative_slope=0.2),
73
layers.Conv2D(128, kernel_size=4, strides=2, padding="same"),
74
layers.LeakyReLU(negative_slope=0.2),
75
layers.Flatten(),
76
layers.Dropout(0.2),
77
layers.Dense(1, activation="sigmoid"),
78
],
79
name="discriminator",
80
)
81
discriminator.summary()
82
83
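"""
A quick shape check (an addition, not in the original example): on a dummy batch of four
64x64 RGB images, the discriminator should output a single sigmoid score per image.
"""

# Random images in [0, 1], matching the discriminator's expected input.
dummy_images = keras.random.uniform((4, 64, 64, 3))
print("Discriminator output shape:", discriminator(dummy_images).shape)  # expected: (4, 1)
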
"""
84
## Create the generator
85
86
It mirrors the discriminator, replacing `Conv2D` layers with `Conv2DTranspose` layers.
87
"""
88
89
latent_dim = 128
90
91
generator = keras.Sequential(
92
[
93
keras.Input(shape=(latent_dim,)),
94
layers.Dense(8 * 8 * 128),
95
layers.Reshape((8, 8, 128)),
96
layers.Conv2DTranspose(128, kernel_size=4, strides=2, padding="same"),
97
layers.LeakyReLU(negative_slope=0.2),
98
layers.Conv2DTranspose(256, kernel_size=4, strides=2, padding="same"),
99
layers.LeakyReLU(negative_slope=0.2),
100
layers.Conv2DTranspose(512, kernel_size=4, strides=2, padding="same"),
101
layers.LeakyReLU(negative_slope=0.2),
102
layers.Conv2D(3, kernel_size=5, padding="same", activation="sigmoid"),
103
],
104
name="generator",
105
)
106
generator.summary()
107
108
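"""
Another shape check (an addition, not in the original example): decoding a few random latent
vectors should yield 64x64 RGB images, with values in [0, 1] thanks to the final sigmoid.
"""

dummy_latents = keras.random.normal((4, latent_dim))
fake_images = generator(dummy_latents)
print("Generator output shape:", fake_images.shape)  # expected: (4, 64, 64, 3)
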
"""
109
## Override `train_step`
110
"""
111
112
113
class GAN(keras.Model):
114
def __init__(self, discriminator, generator, latent_dim):
115
super().__init__()
116
self.discriminator = discriminator
117
self.generator = generator
118
self.latent_dim = latent_dim
119
self.seed_generator = keras.random.SeedGenerator(1337)
120
121
def compile(self, d_optimizer, g_optimizer, loss_fn):
122
super().compile()
123
self.d_optimizer = d_optimizer
124
self.g_optimizer = g_optimizer
125
self.loss_fn = loss_fn
126
self.d_loss_metric = keras.metrics.Mean(name="d_loss")
127
self.g_loss_metric = keras.metrics.Mean(name="g_loss")
128
129
@property
130
def metrics(self):
131
return [self.d_loss_metric, self.g_loss_metric]
132
133
def train_step(self, real_images):
134
# Sample random points in the latent space
135
batch_size = ops.shape(real_images)[0]
136
random_latent_vectors = keras.random.normal(
137
shape=(batch_size, self.latent_dim), seed=self.seed_generator
138
)
139
140
# Decode them to fake images
141
generated_images = self.generator(random_latent_vectors)
142
143
# Combine them with real images
144
combined_images = ops.concatenate([generated_images, real_images], axis=0)
145
146
# Assemble labels discriminating real from fake images
147
labels = ops.concatenate(
148
[ops.ones((batch_size, 1)), ops.zeros((batch_size, 1))], axis=0
149
)
150
# Add random noise to the labels - important trick!
151
labels += 0.05 * tf.random.uniform(tf.shape(labels))
152
153
# Train the discriminator
154
with tf.GradientTape() as tape:
155
predictions = self.discriminator(combined_images)
156
d_loss = self.loss_fn(labels, predictions)
157
grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
158
self.d_optimizer.apply_gradients(
159
zip(grads, self.discriminator.trainable_weights)
160
)
161
162
# Sample random points in the latent space
163
random_latent_vectors = keras.random.normal(
164
shape=(batch_size, self.latent_dim), seed=self.seed_generator
165
)
166
167
# Assemble labels that say "all real images"
168
misleading_labels = ops.zeros((batch_size, 1))
169
170
# Train the generator (note that we should *not* update the weights
171
# of the discriminator)!
172
with tf.GradientTape() as tape:
173
predictions = self.discriminator(self.generator(random_latent_vectors))
174
g_loss = self.loss_fn(misleading_labels, predictions)
175
grads = tape.gradient(g_loss, self.generator.trainable_weights)
176
self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))
177
178
# Update metrics
179
self.d_loss_metric.update_state(d_loss)
180
self.g_loss_metric.update_state(g_loss)
181
return {
182
"d_loss": self.d_loss_metric.result(),
183
"g_loss": self.g_loss_metric.result(),
184
}
185
186
187
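"""
The noisy-label trick used in `train_step` can be illustrated in isolation (a small addition,
not part of the original example): each hard 0/1 label is shifted by a small uniform amount,
a common way to keep the discriminator from becoming overconfident.
"""

# Three "fake" labels (ones) followed by three "real" labels (zeros), as in train_step.
example_labels = ops.concatenate([ops.ones((3, 1)), ops.zeros((3, 1))], axis=0)
example_labels += 0.05 * tf.random.uniform(tf.shape(example_labels))
print(example_labels)  # entries near 1.0 label generated images, entries near 0.0 real ones
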
"""
188
## Create a callback that periodically saves generated images
189
"""
190
191
192
class GANMonitor(keras.callbacks.Callback):
193
def __init__(self, num_img=3, latent_dim=128):
194
self.num_img = num_img
195
self.latent_dim = latent_dim
196
self.seed_generator = keras.random.SeedGenerator(42)
197
198
def on_epoch_end(self, epoch, logs=None):
199
random_latent_vectors = keras.random.normal(
200
shape=(self.num_img, self.latent_dim), seed=self.seed_generator
201
)
202
generated_images = self.model.generator(random_latent_vectors)
203
generated_images *= 255
204
generated_images.numpy()
205
for i in range(self.num_img):
206
img = keras.utils.array_to_img(generated_images[i])
207
img.save("generated_img_%03d_%d.png" % (epoch, i))
208
209
210
"""
211
## Train the end-to-end model
212
"""
213
214
epochs = 1 # In practice, use ~100 epochs
215
216
gan = GAN(discriminator=discriminator, generator=generator, latent_dim=latent_dim)
217
gan.compile(
218
d_optimizer=keras.optimizers.Adam(learning_rate=0.0001),
219
g_optimizer=keras.optimizers.Adam(learning_rate=0.0001),
220
loss_fn=keras.losses.BinaryCrossentropy(),
221
)
222
223
gan.fit(
224
dataset, epochs=epochs, callbacks=[GANMonitor(num_img=10, latent_dim=latent_dim)]
225
)
226
227
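"""
After training, we can also sample new faces directly from the generator (a small addition,
not in the original example): draw latent vectors, decode them, and plot the resulting grid.
"""

sample_latents = keras.random.normal((9, latent_dim))
sample_images = generator(sample_latents).numpy()  # values in [0, 1], ready for imshow

plt.figure(figsize=(6, 6))
for i in range(9):
    plt.subplot(3, 3, i + 1)
    plt.imshow(sample_images[i])
    plt.axis("off")
plt.show()
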
"""
228
Some of the last generated images around epoch 30
229
(results keep improving after that):
230
231
![results](https://i.imgur.com/h5MtQZ7l.png)
232
"""
233
234
"""
235
## Relevant Chapters from Deep Learning with Python
236
- [Chapter 17: Image generation](https://deeplearningwithpython.io/chapters/chapter17_image-generation)
237
"""
238
239