Path: blob/master/examples/generative/dcgan_overriding_train_step.py
8146 views
"""1Title: DCGAN to generate face images2Author: [fchollet](https://twitter.com/fchollet)3Date created: 2019/04/294Last modified: 2023/12/215Description: A simple DCGAN trained using `fit()` by overriding `train_step` on CelebA images.6Accelerator: GPU7"""89"""10## Setup11"""1213import keras14import tensorflow as tf1516from keras import layers17from keras import ops18import matplotlib.pyplot as plt19import os20import gdown21from zipfile import ZipFile2223"""24## Prepare CelebA data2526We'll use face images from the CelebA dataset, resized to 64x64.27"""2829os.makedirs("celeba_gan")3031url = "https://drive.google.com/uc?id=1O7m1010EJjLE5QxLZiM9Fpjs7Oj6e684"32output = "celeba_gan/data.zip"33gdown.download(url, output, quiet=True)3435with ZipFile("celeba_gan/data.zip", "r") as zipobj:36zipobj.extractall("celeba_gan")3738"""39Create a dataset from our folder, and rescale the images to the [0-1] range:40"""4142dataset = keras.utils.image_dataset_from_directory(43"celeba_gan", label_mode=None, image_size=(64, 64), batch_size=3244)45dataset = dataset.map(lambda x: x / 255.0)464748"""49Let's display a sample image:50"""515253for x in dataset:54plt.axis("off")55plt.imshow((x.numpy() * 255).astype("int32")[0])56break575859"""60## Create the discriminator6162It maps a 64x64 image to a binary classification score.63"""6465discriminator = keras.Sequential(66[67keras.Input(shape=(64, 64, 3)),68layers.Conv2D(64, kernel_size=4, strides=2, padding="same"),69layers.LeakyReLU(negative_slope=0.2),70layers.Conv2D(128, kernel_size=4, strides=2, padding="same"),71layers.LeakyReLU(negative_slope=0.2),72layers.Conv2D(128, kernel_size=4, strides=2, padding="same"),73layers.LeakyReLU(negative_slope=0.2),74layers.Flatten(),75layers.Dropout(0.2),76layers.Dense(1, activation="sigmoid"),77],78name="discriminator",79)80discriminator.summary()8182"""83## Create the generator8485It mirrors the discriminator, replacing `Conv2D` layers with `Conv2DTranspose` layers.86"""8788latent_dim = 1288990generator = keras.Sequential(91[92keras.Input(shape=(latent_dim,)),93layers.Dense(8 * 8 * 128),94layers.Reshape((8, 8, 128)),95layers.Conv2DTranspose(128, kernel_size=4, strides=2, padding="same"),96layers.LeakyReLU(negative_slope=0.2),97layers.Conv2DTranspose(256, kernel_size=4, strides=2, padding="same"),98layers.LeakyReLU(negative_slope=0.2),99layers.Conv2DTranspose(512, kernel_size=4, strides=2, padding="same"),100layers.LeakyReLU(negative_slope=0.2),101layers.Conv2D(3, kernel_size=5, padding="same", activation="sigmoid"),102],103name="generator",104)105generator.summary()106107"""108## Override `train_step`109"""110111112class GAN(keras.Model):113def __init__(self, discriminator, generator, latent_dim):114super().__init__()115self.discriminator = discriminator116self.generator = generator117self.latent_dim = latent_dim118self.seed_generator = keras.random.SeedGenerator(1337)119120def compile(self, d_optimizer, g_optimizer, loss_fn):121super().compile()122self.d_optimizer = d_optimizer123self.g_optimizer = g_optimizer124self.loss_fn = loss_fn125self.d_loss_metric = keras.metrics.Mean(name="d_loss")126self.g_loss_metric = keras.metrics.Mean(name="g_loss")127128@property129def metrics(self):130return [self.d_loss_metric, self.g_loss_metric]131132def train_step(self, real_images):133# Sample random points in the latent space134batch_size = ops.shape(real_images)[0]135random_latent_vectors = keras.random.normal(136shape=(batch_size, self.latent_dim), seed=self.seed_generator137)138139# Decode them to fake images140generated_images = self.generator(random_latent_vectors)141142# Combine them with real images143combined_images = ops.concatenate([generated_images, real_images], axis=0)144145# Assemble labels discriminating real from fake images146labels = ops.concatenate(147[ops.ones((batch_size, 1)), ops.zeros((batch_size, 1))], axis=0148)149# Add random noise to the labels - important trick!150labels += 0.05 * tf.random.uniform(tf.shape(labels))151152# Train the discriminator153with tf.GradientTape() as tape:154predictions = self.discriminator(combined_images)155d_loss = self.loss_fn(labels, predictions)156grads = tape.gradient(d_loss, self.discriminator.trainable_weights)157self.d_optimizer.apply_gradients(158zip(grads, self.discriminator.trainable_weights)159)160161# Sample random points in the latent space162random_latent_vectors = keras.random.normal(163shape=(batch_size, self.latent_dim), seed=self.seed_generator164)165166# Assemble labels that say "all real images"167misleading_labels = ops.zeros((batch_size, 1))168169# Train the generator (note that we should *not* update the weights170# of the discriminator)!171with tf.GradientTape() as tape:172predictions = self.discriminator(self.generator(random_latent_vectors))173g_loss = self.loss_fn(misleading_labels, predictions)174grads = tape.gradient(g_loss, self.generator.trainable_weights)175self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))176177# Update metrics178self.d_loss_metric.update_state(d_loss)179self.g_loss_metric.update_state(g_loss)180return {181"d_loss": self.d_loss_metric.result(),182"g_loss": self.g_loss_metric.result(),183}184185186"""187## Create a callback that periodically saves generated images188"""189190191class GANMonitor(keras.callbacks.Callback):192def __init__(self, num_img=3, latent_dim=128):193self.num_img = num_img194self.latent_dim = latent_dim195self.seed_generator = keras.random.SeedGenerator(42)196197def on_epoch_end(self, epoch, logs=None):198random_latent_vectors = keras.random.normal(199shape=(self.num_img, self.latent_dim), seed=self.seed_generator200)201generated_images = self.model.generator(random_latent_vectors)202generated_images *= 255203generated_images.numpy()204for i in range(self.num_img):205img = keras.utils.array_to_img(generated_images[i])206img.save("generated_img_%03d_%d.png" % (epoch, i))207208209"""210## Train the end-to-end model211"""212213epochs = 1 # In practice, use ~100 epochs214215gan = GAN(discriminator=discriminator, generator=generator, latent_dim=latent_dim)216gan.compile(217d_optimizer=keras.optimizers.Adam(learning_rate=0.0001),218g_optimizer=keras.optimizers.Adam(learning_rate=0.0001),219loss_fn=keras.losses.BinaryCrossentropy(),220)221222gan.fit(223dataset, epochs=epochs, callbacks=[GANMonitor(num_img=10, latent_dim=latent_dim)]224)225226"""227Some of the last generated images around epoch 30228(results keep improving after that):229230231"""232233"""234## Relevant Chapters from Deep Learning with Python235- [Chapter 17: Image generation](https://deeplearningwithpython.io/chapters/chapter17_image-generation)236"""237238239