CycleGAN
Goals
In this notebook, you will write a generative model based on the paper Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks by Zhu et al. 2017, commonly referred to as CycleGAN.
You will be training a model that can convert horses into zebras, and vice versa. Once again, the emphasis of the assignment will be on the loss functions. In order for you to see good outputs more quickly, you'll be training your model starting from a pre-trained checkpoint. You are also welcome to train it from scratch on your own, if you so choose.
Learning Objectives
Implement the loss functions of a CycleGAN model.
Observe the two GANs used in CycleGAN.
Getting Started
You will start by importing libraries, defining a visualization function, and getting the pre-trained CycleGAN checkpoint.
Generator
The code for a CycleGAN generator is much like Pix2Pix's U-Net with the addition of residual blocks between the encoding (contracting) and decoding (expanding) blocks.
Diagram of a CycleGAN generator: composed of encoding blocks, residual blocks, and then decoding blocks.
Residual Block
Perhaps the most notable architectural difference between the U-Net you used for Pix2Pix and the architecture you're using for CycleGAN is the residual blocks. In CycleGAN, after the contracting blocks, there are convolutional layers whose output is ultimately added to their original input, so that the network can change the image as little as possible. You can think of this transformation as a kind of skip connection: instead of being concatenated as new channels before a convolution that combines them, the input is added directly to the output of the convolution. In the visualization below, you can imagine the stripes being generated by the convolutions and then added to the original image of the horse to transform it into a zebra. These skip connections also allow the network to be deeper, because they mitigate the vanishing-gradient problem: when a network gets very deep, the gradients multiplied together during backpropagation can become very small, and the skip connections give the gradient a more direct path backward. A deeper network is often able to learn more complex features.
Example of a residual block.
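As a rough illustration, here is a minimal residual block. It assumes PyTorch, 3x3 convolutions with reflection padding, and instance normalization; the class name `ResidualBlock` and these layer choices are illustrative, not necessarily what the notebook's own code uses. Because the padding keeps the spatial size fixed, the convolutions' output can be added element-wise to the input.

```python
import torch
import torch.nn as nn


class ResidualBlock(nn.Module):
    """Two same-size convolutions whose output is added back onto the input."""

    def __init__(self, channels: int):
        super().__init__()
        # 3x3 kernels with padding=1 keep height and width unchanged,
        # so the block's output can be added element-wise to its input.
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, padding_mode="reflect")
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, padding_mode="reflect")
        self.norm = nn.InstanceNorm2d(channels)  # no learnable state, so it can be reused
        self.activation = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = self.activation(self.norm(self.conv1(x)))
        residual = self.norm(self.conv2(residual))
        # Skip connection: add (not concatenate) the learned change to the input.
        return x + residual


block = ResidualBlock(channels=8)
x = torch.randn(1, 8, 16, 16)
print(block(x).shape)  # torch.Size([1, 8, 16, 16])
```

If the second convolution's weights were all zero, the block would return its input unchanged. That is the "change as little as possible" behavior the additive skip connection makes easy to learn.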
Contracting and Expanding Blocks
The rest of the generator code will otherwise be much like the code you wrote for the last assignment: Pix2Pix's U-Net. The primary changes are the use of instance norm (which you may recall from StyleGAN) instead of batch norm, no dropout, and a stride-2 convolution instead of max pooling. Feel free to investigate the code if you're interested!
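A sketch of what such contracting and expanding blocks might look like, again assuming PyTorch. The class names, the 3x3 kernels, and the reflection padding are assumptions for illustration. The contracting block downsamples with a stride-2 convolution rather than max pooling, both blocks use instance norm, and neither uses dropout.

```python
import torch
import torch.nn as nn


class ContractingBlock(nn.Module):
    """Halves height/width and doubles channels with a stride-2 convolution."""

    def __init__(self, in_channels: int, use_norm: bool = True):
        super().__init__()
        # Stride 2 does the downsampling that max pooling did in the U-Net.
        self.conv = nn.Conv2d(in_channels, in_channels * 2, kernel_size=3, stride=2,
                              padding=1, padding_mode="reflect")
        self.norm = nn.InstanceNorm2d(in_channels * 2) if use_norm else nn.Identity()
        self.activation = nn.ReLU()

    def forward(self, x):
        return self.activation(self.norm(self.conv(x)))


class ExpandingBlock(nn.Module):
    """Doubles height/width and halves channels with a transposed convolution."""

    def __init__(self, in_channels: int, use_norm: bool = True):
        super().__init__()
        # output_padding=1 makes the output exactly twice the input size.
        self.conv = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=3, stride=2,
                                       padding=1, output_padding=1)
        self.norm = nn.InstanceNorm2d(in_channels // 2) if use_norm else nn.Identity()
        self.activation = nn.ReLU()

    def forward(self, x):
        # No dropout here, unlike the Pix2Pix U-Net.
        return self.activation(self.norm(self.conv(x)))


x = torch.randn(1, 16, 32, 32)
down = ContractingBlock(16)(x)   # -> (1, 32, 16, 16)
up = ExpandingBlock(32)(down)    # -> (1, 16, 32, 32)
```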
CycleGAN Generator
Finally, you can put all the blocks together to create your CycleGAN generator.
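Putting the pieces together, a compact generator might look like the sketch below. It assumes PyTorch, two downsampling and two upsampling stages, 7x7 convolutions at the input and output, and a tanh output (so images are expected in [-1, 1]). The default of 9 residual blocks follows the CycleGAN paper's setting for 256x256 images, but all names and defaults here are illustrative rather than the notebook's exact code.

```python
import torch
import torch.nn as nn


class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1, padding_mode="reflect"),
            nn.InstanceNorm2d(channels),
            nn.ReLU(),
            nn.Conv2d(channels, channels, 3, padding=1, padding_mode="reflect"),
            nn.InstanceNorm2d(channels),
        )

    def forward(self, x):
        return x + self.body(x)  # additive skip connection


class CycleGANGenerator(nn.Module):
    """Contracting blocks -> residual blocks -> expanding blocks."""

    def __init__(self, input_channels=3, output_channels=3, hidden_channels=64, n_residual=9):
        super().__init__()
        c = hidden_channels
        # Map the image to hidden_channels feature maps without changing its size.
        layers = [nn.Conv2d(input_channels, c, 7, padding=3, padding_mode="reflect")]
        for _ in range(2):  # contracting: stride-2 convs halve H and W
            layers += [nn.Conv2d(c, c * 2, 3, stride=2, padding=1, padding_mode="reflect"),
                       nn.InstanceNorm2d(c * 2), nn.ReLU()]
            c *= 2
        layers += [ResidualBlock(c) for _ in range(n_residual)]
        for _ in range(2):  # expanding: transposed convs double H and W
            layers += [nn.ConvTranspose2d(c, c // 2, 3, stride=2, padding=1, output_padding=1),
                       nn.InstanceNorm2d(c // 2), nn.ReLU()]
            c //= 2
        # Map back to image channels; tanh keeps outputs in [-1, 1].
        layers += [nn.Conv2d(c, output_channels, 7, padding=3, padding_mode="reflect"), nn.Tanh()]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


# A small configuration so the example runs quickly on a CPU.
gen_AB = CycleGANGenerator(hidden_channels=8, n_residual=2)
out = gen_AB(torch.randn(1, 3, 64, 64))
print(out.shape)  # torch.Size([1, 3, 64, 64])
```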
PatchGAN Discriminator
Next, you will define the discriminator—a PatchGAN. It will be very similar to what you saw in Pix2Pix.
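A minimal PatchGAN-style discriminator sketch, assuming PyTorch and a stack of stride-2 4x4 convolutions with instance norm and LeakyReLU; the layer count and the class name are assumptions. The key property is that it outputs a grid of scores, one per image patch, instead of a single number. Leaving off a final sigmoid is a design choice that pairs with the least-squares adversarial loss used later.

```python
import torch
import torch.nn as nn


class PatchDiscriminator(nn.Module):
    """Outputs a grid of scores, one per (overlapping) patch of the input image."""

    def __init__(self, input_channels=3, hidden_channels=64):
        super().__init__()
        h = hidden_channels
        self.model = nn.Sequential(
            # Each stride-2 4x4 conv halves H and W and widens the receptive field.
            nn.Conv2d(input_channels, h, 4, stride=2, padding=1), nn.LeakyReLU(0.2),
            nn.Conv2d(h, h * 2, 4, stride=2, padding=1), nn.InstanceNorm2d(h * 2), nn.LeakyReLU(0.2),
            nn.Conv2d(h * 2, h * 4, 4, stride=2, padding=1), nn.InstanceNorm2d(h * 4), nn.LeakyReLU(0.2),
            # 1x1 conv down to a single channel: one raw real/fake score per patch.
            nn.Conv2d(h * 4, 1, kernel_size=1),
        )

    def forward(self, x):
        return self.model(x)


disc_A = PatchDiscriminator(hidden_channels=8)
scores = disc_A(torch.randn(2, 3, 64, 64))
print(scores.shape)  # torch.Size([2, 1, 8, 8])
```

For a 64x64 input, three stride-2 layers give an 8x8 grid of patch scores per image.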
Training Preparation
Now you can put everything together for training! You will start by defining your parameters:
adv_criterion: an adversarial loss function to keep track of how well the generator is fooling the discriminator and how well the discriminator is catching the generator
recon_criterion: a loss function that rewards images similar to the ground truth, i.e. images that "reconstruct" it
n_epochs: the number of times you iterate through the entire dataset when training
dim_A: the number of channels of the images in pile A
dim_B: the number of channels of the images in pile B (note that in the visualization this is currently treated as equivalent to dim_A)
display_step: how often to display/visualize the images
batch_size: the number of images per forward/backward pass
lr: the learning rate
target_shape: the size of the input and output images (in pixels)
load_shape: the size for the dataset to load the images at before randomly cropping them to target_shape as a simple data augmentation
device: the device type
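For concreteness, here is one hypothetical way to set these parameters in PyTorch. Using `nn.MSELoss` for `adv_criterion` matches the least-squares adversarial loss described below, and `nn.L1Loss` for `recon_criterion` matches the L1 reconstruction penalty used in the CycleGAN paper. `batch_size = 1` and `lr = 0.0002` mirror the paper's reported settings; the remaining numbers are illustrative placeholders, not the notebook's values.

```python
import torch
import torch.nn as nn

# Illustrative values only; the notebook's actual settings may differ.
adv_criterion = nn.MSELoss()    # least-squares adversarial loss
recon_criterion = nn.L1Loss()   # pixel-wise reconstruction loss (identity and cycle terms)

n_epochs = 20
dim_A = 3            # RGB horse images
dim_B = 3            # RGB zebra images
display_step = 200
batch_size = 1
lr = 0.0002
target_shape = 256   # final crop size in pixels
load_shape = 286     # images are loaded larger, then randomly cropped to target_shape
device = "cuda" if torch.cuda.is_available() else "cpu"
```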
You will then load the images of the dataset while introducing some data augmentation (e.g. crops and random horizontal flips).
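A common way to do this is with torchvision's transforms (for example `transforms.RandomCrop` and `transforms.RandomHorizontalFlip`). To keep the example dependency-light, the sketch below implements the same resize, random crop, and random horizontal flip with plain PyTorch tensor operations. The function name `load_and_augment` and the 286/256 default sizes are assumptions for illustration.

```python
import torch
import torch.nn.functional as F


def load_and_augment(image: torch.Tensor, load_shape: int = 286, target_shape: int = 256) -> torch.Tensor:
    """Resize a (C, H, W) image to load_shape, randomly crop to target_shape, maybe flip."""
    # Resize to load_shape x load_shape (interpolate expects a batch dimension).
    image = F.interpolate(image.unsqueeze(0), size=(load_shape, load_shape),
                          mode="bilinear", align_corners=False).squeeze(0)
    # Random crop: pick a top-left corner so that the crop fits inside the image.
    max_offset = load_shape - target_shape
    top = torch.randint(0, max_offset + 1, (1,)).item()
    left = torch.randint(0, max_offset + 1, (1,)).item()
    image = image[:, top:top + target_shape, left:left + target_shape]
    # Random horizontal flip with probability 0.5 (flip along the width axis).
    if torch.rand(1).item() < 0.5:
        image = torch.flip(image, dims=[2])
    return image


photo = torch.rand(3, 300, 400)  # stand-in for a loaded RGB image with values in [0, 1]
sample = load_and_augment(photo)
print(sample.shape)  # torch.Size([3, 256, 256])
```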
Next, you can initialize your generators and discriminators, as well as their optimizers. For CycleGAN, you will have two generators and two discriminators since there are two GANs:
Generator for horse to zebra (gen_AB)
Generator for zebra to horse (gen_BA)
Discriminator for horse (disc_A)
Discriminator for zebra (disc_B)
You will also load your pre-trained model.
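One hedged sketch of this setup in PyTorch, using tiny hypothetical stand-in networks instead of the real architectures. The normal(0, 0.02) weight initialization, Adam with `betas=(0.5, 0.999)`, and a single optimizer shared by both generators are common CycleGAN choices but are assumptions here. The checkpoint's dictionary keys are hypothetical too, and an in-memory buffer stands in for the checkpoint file.

```python
import io

import torch
import torch.nn as nn


def weights_init(m):
    # Conv weights ~ N(0, 0.02), biases zero.
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, 0.0, 0.02)
        if m.bias is not None:
            nn.init.zeros_(m.bias)


def tiny_net(in_ch=3, out_ch=3):
    # Hypothetical stand-in for the real generator/discriminator architectures.
    return nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)


lr = 0.0002
gen_AB, gen_BA = tiny_net(), tiny_net()                  # horse -> zebra, zebra -> horse
disc_A, disc_B = tiny_net(out_ch=1), tiny_net(out_ch=1)  # horse discriminator, zebra discriminator
for net in (gen_AB, gen_BA, disc_A, disc_B):
    net.apply(weights_init)

# One optimizer updates both generators; each discriminator gets its own.
gen_opt = torch.optim.Adam(list(gen_AB.parameters()) + list(gen_BA.parameters()),
                           lr=lr, betas=(0.5, 0.999))
disc_A_opt = torch.optim.Adam(disc_A.parameters(), lr=lr, betas=(0.5, 0.999))
disc_B_opt = torch.optim.Adam(disc_B.parameters(), lr=lr, betas=(0.5, 0.999))

# Save and reload a checkpoint (hypothetical keys; a real run would use a file path).
buffer = io.BytesIO()
torch.save({"gen_AB": gen_AB.state_dict(), "gen_BA": gen_BA.state_dict(),
            "disc_A": disc_A.state_dict(), "disc_B": disc_B.state_dict(),
            "gen_opt": gen_opt.state_dict()}, buffer)
buffer.seek(0)
checkpoint = torch.load(buffer, map_location="cpu")
gen_AB.load_state_dict(checkpoint["gen_AB"])
gen_BA.load_state_dict(checkpoint["gen_BA"])
disc_A.load_state_dict(checkpoint["disc_A"])
disc_B.load_state_dict(checkpoint["disc_B"])
gen_opt.load_state_dict(checkpoint["gen_opt"])
```

Sharing one optimizer between `gen_AB` and `gen_BA` fits the fact that the identity and cycle-consistency terms tie the two generators together in a single loss.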
Discriminator Loss
First, you're going to be implementing the discriminator loss. This is the same as in previous assignments, so it should be a breeze 😃 Don't forget to detach the generator's output (the fake images) so that the discriminator loss does not update the generator!
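A minimal sketch of the discriminator loss for one domain X, assuming PyTorch and a least-squares `adv_criterion` such as `nn.MSELoss()`; the function name and argument order are illustrative. Real images are pushed toward a score of 1, fakes toward 0, and `.detach()` keeps this loss from reaching the generator.

```python
import torch
import torch.nn as nn


def get_disc_loss(real_X, fake_X, disc_X, adv_criterion):
    """Least-squares loss for the discriminator of domain X (real -> 1, fake -> 0)."""
    # detach() cuts the graph, so this loss sends no gradients into the generator.
    disc_fake_pred = disc_X(fake_X.detach())
    fake_loss = adv_criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))
    disc_real_pred = disc_X(real_X)
    real_loss = adv_criterion(disc_real_pred, torch.ones_like(disc_real_pred))
    # Average the real and fake terms.
    return (fake_loss + real_loss) / 2


# Tiny hypothetical stand-ins to exercise the function.
gen = nn.Conv2d(1, 1, kernel_size=1)
disc = nn.Conv2d(1, 1, kernel_size=1)
real = torch.randn(2, 1, 4, 4)
fake = gen(torch.randn(2, 1, 4, 4))
loss = get_disc_loss(real, fake, disc, nn.MSELoss())
loss.backward()
print(gen.weight.grad)  # None: the generator received no gradient
```

The same function serves both discriminators, e.g. `get_disc_loss(real_A, fake_A, disc_A, adv_criterion)` for horses and `get_disc_loss(real_B, fake_B, disc_B, adv_criterion)` for zebras.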
Generator Loss
While there are some changes to the CycleGAN architecture from Pix2Pix, the most important distinguishing feature of CycleGAN is its generator loss. You will be implementing that here!
Adversarial Loss
The first component of the generator's loss you're going to implement is its adversarial loss—this once again is pretty similar to the GAN loss that you've implemented in the past. The important thing to note is that the criterion now is based on least squares loss, rather than binary cross entropy loss or W-loss.
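A sketch of the adversarial term for one direction X -> Y, assuming PyTorch; the function name is illustrative. It returns the fake images too, so they can be reused for the cycle-consistency term. The toy call uses hypothetical stand-in callables to show the least-squares arithmetic.

```python
import torch
import torch.nn as nn


def get_gen_adversarial_loss(real_X, disc_Y, gen_XY, adv_criterion):
    """Adversarial loss for the X -> Y generator; also returns the fake Y images."""
    fake_Y = gen_XY(real_X)
    disc_fake_pred = disc_Y(fake_Y)
    # The generator is rewarded when the discriminator scores its fakes as real (1).
    adversarial_loss = adv_criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))
    return adversarial_loss, fake_Y


# With nn.MSELoss(), a discriminator that outputs 0.5 everywhere
# gives a loss of (0.5 - 1)^2 = 0.25.
loss, fake = get_gen_adversarial_loss(
    torch.zeros(1, 3, 4, 4),
    disc_Y=lambda y: torch.full_like(y, 0.5),  # hypothetical stand-in discriminator
    gen_XY=lambda x: x,                        # hypothetical identity "generator"
    adv_criterion=nn.MSELoss(),
)
print(loss.item())  # 0.25
```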
Identity Loss
Here you get to see some brand-new material! You'll want to measure the change in an image when you pass the generator an example from the target domain instead of the input domain it's expecting. The output should be the same as the input, since the input already belongs to the target domain. For example, if you put a horse through a zebra -> horse generator, you'd expect the output to be the same horse because nothing needs to be transformed. It's already a horse! You don't want your generator to transform it into anything else, so you want to encourage this behavior. The authors of CycleGAN found that, for some tasks, encouraging this identity mapping helped preserve the colors of an image during ordinary translation, i.e. when the expected input (here, a zebra) was put in. This was particularly useful for the photos <-> paintings mapping and, while it is an optional aesthetic component, you might find it useful for your applications down the line.
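A sketch of the identity loss, assuming PyTorch and an L1 `identity_criterion` (as in the CycleGAN paper); the function name is illustrative. Following the horse example, `real_X` would be horses (pile A) and `gen_YX` the zebra -> horse generator `gen_BA`. The lambdas below are hypothetical stand-in generators.

```python
import torch
import torch.nn as nn


def get_identity_loss(real_X, gen_YX, identity_criterion):
    """Penalize the Y -> X generator for changing an image that is already in X."""
    identity_X = gen_YX(real_X)
    identity_loss = identity_criterion(identity_X, real_X)
    return identity_loss, identity_X


horses = torch.zeros(1, 3, 4, 4)
# A generator that leaves horses alone incurs no identity loss...
print(get_identity_loss(horses, lambda x: x, nn.L1Loss())[0].item())      # 0.0
# ...while one that shifts every pixel by 1 is penalized by 1 per pixel on average.
print(get_identity_loss(horses, lambda x: x + 1, nn.L1Loss())[0].item())  # 1.0
```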
Cycle Consistency Loss
Now, you can implement the final generator loss and the part that puts the "cycle" in CycleGAN: cycle consistency loss. This loss encourages an image that is put through one generator and then transformed back into the input class by the opposite generator to come out the same as the original input image.
Since you've already generated a fake image for the adversarial part, you can pass that fake image back to produce a full cycle—this loss will encourage the cycle to preserve as much information as possible.
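A sketch of the cycle-consistency loss for one direction, assuming PyTorch and an L1 `cycle_criterion`; the function name is illustrative. The toy generators below are hypothetical: doubling and then halving the pixel values is a perfect cycle, so the loss is zero.

```python
import torch
import torch.nn as nn


def get_cycle_consistency_loss(real_X, fake_Y, gen_YX, cycle_criterion):
    """Map fake Y images back to X and compare them with the original X images."""
    cycle_X = gen_YX(fake_Y)
    cycle_loss = cycle_criterion(cycle_X, real_X)
    return cycle_loss, cycle_X


def toy_gen_AB(x):  # hypothetical "generator" A -> B
    return x * 2


def toy_gen_BA(x):  # its exact inverse, B -> A
    return x / 2


real_A = torch.full((1, 3, 4, 4), 3.0)
fake_B = toy_gen_AB(real_A)  # reuse the fake from the adversarial step
print(get_cycle_consistency_loss(real_A, fake_B, toy_gen_BA, nn.L1Loss())[0].item())  # 0.0
```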
Fun fact: Cycle consistency is a broader concept that's used a lot outside of CycleGAN too! It has helped with data augmentation and has also been used in text translation, e.g. French -> English -> French should give back the same phrase.
Generator Loss (Total)
Finally, you can put it all together! There are many components, so be careful as you go through this section.
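One way to combine all the pieces for both directions, as a hedged PyTorch sketch. The weights `lambda_identity = 0.1` and `lambda_cycle = 10.0` are assumptions (the CycleGAN paper weights the cycle term with 10); the function name and signature are illustrative.

```python
import torch
import torch.nn as nn


def get_gen_loss(real_A, real_B, gen_AB, gen_BA, disc_A, disc_B,
                 adv_criterion, identity_criterion, cycle_criterion,
                 lambda_identity=0.1, lambda_cycle=10.0):
    """Total CycleGAN generator loss for both directions (A -> B and B -> A)."""
    # Adversarial terms: each generator tries to make the other domain's
    # discriminator predict "real" (1) on its fakes.
    fake_B = gen_AB(real_A)
    fake_A = gen_BA(real_B)
    pred_fake_B = disc_B(fake_B)
    pred_fake_A = disc_A(fake_A)
    adv_loss = (adv_criterion(pred_fake_B, torch.ones_like(pred_fake_B))
                + adv_criterion(pred_fake_A, torch.ones_like(pred_fake_A)))

    # Identity terms: an image already in the target domain should pass through unchanged.
    identity_loss = (identity_criterion(gen_BA(real_A), real_A)
                     + identity_criterion(gen_AB(real_B), real_B))

    # Cycle terms: A -> B -> A and B -> A -> B should reproduce the originals.
    cycle_loss = (cycle_criterion(gen_BA(fake_B), real_A)
                  + cycle_criterion(gen_AB(fake_A), real_B))

    gen_loss = adv_loss + lambda_identity * identity_loss + lambda_cycle * cycle_loss
    return gen_loss, fake_A, fake_B
```

Usage follows the naming used earlier: `gen_loss, fake_A, fake_B = get_gen_loss(real_A, real_B, gen_AB, gen_BA, disc_A, disc_B, adv_criterion, recon_criterion, recon_criterion)`, reusing the reconstruction criterion for both the identity and cycle terms.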
CycleGAN Training
Lastly, you can train the model and see some of your zebras, horses, and some that might not quite look like either! Note that this training will take a long time, so feel free to use the pre-trained checkpoint as an example of what a pretty-good CycleGAN does.
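A sketch of a single training step, assuming PyTorch: both discriminators are updated first on fakes produced without gradient tracking, then both generators are updated on the combined loss. The tiny `nn.Conv2d` stand-in networks, the `train_step` and `make_opt` names, and the loss weights are hypothetical; a real run would loop this over the dataset for `n_epochs` and use the full architectures.

```python
import torch
import torch.nn as nn


def train_step(real_A, real_B, gen_AB, gen_BA, disc_A, disc_B,
               gen_opt, disc_A_opt, disc_B_opt,
               lambda_identity=0.1, lambda_cycle=10.0):
    """One CycleGAN update: both discriminators first, then both generators."""
    adv, recon = nn.MSELoss(), nn.L1Loss()

    # 1) Discriminators. The fakes are made without tracking gradients,
    #    so these updates cannot change the generators.
    with torch.no_grad():
        fake_A, fake_B = gen_BA(real_B), gen_AB(real_A)
    disc_losses = []
    for disc, opt, real, fake in ((disc_A, disc_A_opt, real_A, fake_A),
                                  (disc_B, disc_B_opt, real_B, fake_B)):
        opt.zero_grad()
        fake_pred, real_pred = disc(fake), disc(real)
        d_loss = (adv(fake_pred, torch.zeros_like(fake_pred))
                  + adv(real_pred, torch.ones_like(real_pred))) / 2
        d_loss.backward()
        opt.step()
        disc_losses.append(d_loss.item())

    # 2) Generators: adversarial + identity + cycle-consistency terms.
    gen_opt.zero_grad()
    fake_A, fake_B = gen_BA(real_B), gen_AB(real_A)
    pred_A, pred_B = disc_A(fake_A), disc_B(fake_B)
    adv_loss = adv(pred_A, torch.ones_like(pred_A)) + adv(pred_B, torch.ones_like(pred_B))
    identity_loss = recon(gen_BA(real_A), real_A) + recon(gen_AB(real_B), real_B)
    cycle_loss = recon(gen_BA(fake_B), real_A) + recon(gen_AB(fake_A), real_B)
    g_loss = adv_loss + lambda_identity * identity_loss + lambda_cycle * cycle_loss
    # backward() also leaves gradients on the discriminators; their
    # opt.zero_grad() calls clear them before the next discriminator update.
    g_loss.backward()
    gen_opt.step()
    return disc_losses, g_loss.item()


def make_opt(*nets, lr=2e-4):
    params = [p for net in nets for p in net.parameters()]
    return torch.optim.Adam(params, lr=lr, betas=(0.5, 0.999))


# Tiny hypothetical stand-in networks, just to exercise one step on random data.
gen_AB, gen_BA = nn.Conv2d(3, 3, 3, padding=1), nn.Conv2d(3, 3, 3, padding=1)
disc_A, disc_B = nn.Conv2d(3, 1, 3, padding=1), nn.Conv2d(3, 1, 3, padding=1)
gen_opt, disc_A_opt, disc_B_opt = make_opt(gen_AB, gen_BA), make_opt(disc_A), make_opt(disc_B)
real_A = torch.rand(1, 3, 16, 16) * 2 - 1  # images scaled to [-1, 1]
real_B = torch.rand(1, 3, 16, 16) * 2 - 1
disc_losses, gen_loss = train_step(real_A, real_B, gen_AB, gen_BA, disc_A, disc_B,
                                   gen_opt, disc_A_opt, disc_B_opt)
```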