Copyright 2019 The TensorFlow Authors.
CycleGAN
This notebook demonstrates unpaired image-to-image translation using conditional GANs, as described in Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks, also known as CycleGAN. The paper proposes a method that can capture the characteristics of one image domain and figure out how these characteristics could be translated into another image domain, all in the absence of any paired training examples.
This notebook assumes you are familiar with Pix2Pix, which you can learn about in the Pix2Pix tutorial. The code for CycleGAN is similar; the main differences are an additional loss function and the use of unpaired training data.
CycleGAN uses a cycle consistency loss to enable training without the need for paired data. In other words, it can translate from one domain to another without a one-to-one mapping between the source and target domain.
This opens up the possibility of performing many interesting tasks, such as photo enhancement, image colorization, and style transfer. All you need is the source and the target dataset (which is simply a directory of images).
Set up the input pipeline
Install the tensorflow_examples package that enables importing of the generator and the discriminator.
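A sketch of what this setup might look like in a notebook cell; the install command assumes the package is taken from the TensorFlow examples repository and may vary by environment:

```python
# Install the package that provides the Pix2Pix generator and discriminator,
# e.g. in a notebook cell:
# !pip install git+https://github.com/tensorflow/examples.git

import tensorflow as tf
import tensorflow_datasets as tfds

AUTOTUNE = tf.data.AUTOTUNE
```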
Input Pipeline
This tutorial trains a model to translate from images of horses to images of zebras. You can find this dataset and similar ones here.
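As an illustration, the horse-to-zebra data can be loaded through TensorFlow Datasets; the sketch below assumes the `cycle_gan/horse2zebra` dataset name and that `tensorflow_datasets` is imported as `tfds`:

```python
# Load the horse <-> zebra dataset from TensorFlow Datasets.
dataset, metadata = tfds.load('cycle_gan/horse2zebra',
                              with_info=True, as_supervised=True)

# 'trainA'/'testA' contain horses, 'trainB'/'testB' contain zebras.
train_horses, train_zebras = dataset['trainA'], dataset['trainB']
test_horses, test_zebras = dataset['testA'], dataset['testB']
```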
As mentioned in the paper, apply random jittering and mirroring to the training dataset. These are some of the image augmentation techniques that help avoid overfitting.
This is similar to what was done in pix2pix:

- In random jittering, the image is resized to `286 x 286` and then randomly cropped to `256 x 256`.
- In random mirroring, the image is randomly flipped horizontally, i.e., left to right.
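A minimal sketch of these preprocessing steps, assuming 256 x 256 RGB images, pixel values normalized to [-1, 1], and the `train_horses` dataset and `AUTOTUNE` constant from the sketches above:

```python
BUFFER_SIZE = 1000
BATCH_SIZE = 1
IMG_WIDTH = 256
IMG_HEIGHT = 256

def random_crop(image):
  # Randomly crop the image back to 256 x 256.
  return tf.image.random_crop(image, size=[IMG_HEIGHT, IMG_WIDTH, 3])

def normalize(image):
  # Scale pixel values from [0, 255] to [-1, 1].
  image = tf.cast(image, tf.float32)
  return (image / 127.5) - 1

def random_jitter(image):
  # Resize to 286 x 286, then randomly crop to 256 x 256.
  image = tf.image.resize(image, [286, 286],
                          method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
  image = random_crop(image)
  # Randomly mirror the image horizontally.
  image = tf.image.random_flip_left_right(image)
  return image

def preprocess_image_train(image, label):
  return normalize(random_jitter(image))

train_horses = (train_horses
                .cache()
                .map(preprocess_image_train, num_parallel_calls=AUTOTUNE)
                .shuffle(BUFFER_SIZE)
                .batch(BATCH_SIZE))
```

The same pipeline (without the jittering) can be applied to the test splits.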
Import and reuse the Pix2Pix models
Import the generator and the discriminator used in Pix2Pix via the installed tensorflow_examples package.
The model architecture used in this tutorial is very similar to what was used in pix2pix. Some of the differences are:
- CycleGAN uses instance normalization instead of batch normalization.
- The CycleGAN paper uses a modified `resnet`-based generator. This tutorial uses a modified `unet` generator for simplicity.
There are 2 generators (G and F) and 2 discriminators (X and Y) being trained here.
- Generator `G` learns to transform image `X` to image `Y`.
- Generator `F` learns to transform image `Y` to image `X`.
- Discriminator `D_X` learns to differentiate between image `X` and generated image `X` (`F(Y)`).
- Discriminator `D_Y` learns to differentiate between image `Y` and generated image `Y` (`G(X)`).
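Under these definitions, the four models can be created from the installed package roughly as follows (a sketch, assuming the `pix2pix.unet_generator` and `pix2pix.discriminator` helpers from the `tensorflow_examples` package installed above):

```python
from tensorflow_examples.models.pix2pix import pix2pix

OUTPUT_CHANNELS = 3

# U-Net generators with instance normalization.
generator_g = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
generator_f = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')

# PatchGAN discriminators; target=False because each discriminator sees a
# single image rather than an (input, target) pair as in Pix2Pix.
discriminator_x = pix2pix.discriminator(norm_type='instancenorm', target=False)
discriminator_y = pix2pix.discriminator(norm_type='instancenorm', target=False)
```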
Loss functions
In CycleGAN, there is no paired data to train on, hence there is no guarantee that the input `x` and the target `y` pair are meaningful during training. Thus, in order to enforce that the network learns the correct mapping, the authors propose the cycle consistency loss.
The discriminator loss and the generator loss are similar to the ones used in pix2pix.
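For reference, a sketch of what those adversarial losses might look like, assuming binary cross-entropy on the discriminator logits:

```python
loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def discriminator_loss(real, generated):
  # Real images should be classified as 1, generated images as 0.
  real_loss = loss_obj(tf.ones_like(real), real)
  generated_loss = loss_obj(tf.zeros_like(generated), generated)
  return (real_loss + generated_loss) * 0.5

def generator_loss(generated):
  # The generator tries to make the discriminator output 1 on generated images.
  return loss_obj(tf.ones_like(generated), generated)
```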
Cycle consistency means the result should be close to the original input. For example, if one translates a sentence from English to French, and then translates it back from French to English, then the resulting sentence should be the same as the original sentence.
In cycle consistency loss,

- Image $X$ is passed via generator $G$ that yields generated image $\hat{Y}$.
- Generated image $\hat{Y}$ is passed via generator $F$ that yields cycled image $\hat{X}$.
- Mean absolute error is calculated between $X$ and $\hat{X}$.
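A minimal sketch of the cycle consistency loss, assuming a weighting factor `LAMBDA` (the paper uses 10):

```python
LAMBDA = 10

def calc_cycle_loss(real_image, cycled_image):
  # Mean absolute error between the original image and the cycled image.
  loss = tf.reduce_mean(tf.abs(real_image - cycled_image))
  return LAMBDA * loss
```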
As shown above, generator $G$ is responsible for translating image $X$ to image $Y$. Identity loss says that, if you fed image $Y$ to generator $G$, it should yield the real image $Y$ or something close to image $Y$.
If you run the zebra-to-horse model on a horse or the horse-to-zebra model on a zebra, it should not modify the image much since the image already contains the target class.
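A corresponding sketch of the identity loss, reusing the same `LAMBDA` weight:

```python
def identity_loss(real_image, same_image):
  # Penalize the generator for changing an image that is already in its
  # target domain.
  loss = tf.reduce_mean(tf.abs(real_image - same_image))
  return LAMBDA * 0.5 * loss
```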
Initialize the optimizers for all the generators and the discriminators.
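For example, Adam optimizers with a learning rate of 2e-4 and a beta_1 of 0.5 (the values used in the paper) could be created like this:

```python
generator_g_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
generator_f_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

discriminator_x_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_y_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
```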
Checkpoints
Training
Note: This example model is trained for fewer epochs (10) than the paper (200) to keep training time reasonable for this tutorial. The generated images will have much lower quality.
Even though the training loop looks complicated, it consists of four basic steps:
Get the predictions.
Calculate the loss.
Calculate the gradients using backpropagation.
Apply the gradients to the optimizer.
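Putting those steps together, here is a sketch of a single training step, assuming the generators, discriminators, loss functions, and optimizers defined above:

```python
@tf.function
def train_step(real_x, real_y):
  # persistent=True because the tape is used to compute gradients for
  # more than one set of trainable variables.
  with tf.GradientTape(persistent=True) as tape:
    # Generator G translates X -> Y; generator F translates Y -> X.
    fake_y = generator_g(real_x, training=True)
    cycled_x = generator_f(fake_y, training=True)

    fake_x = generator_f(real_y, training=True)
    cycled_y = generator_g(fake_x, training=True)

    # same_x and same_y are used for the identity loss.
    same_x = generator_f(real_x, training=True)
    same_y = generator_g(real_y, training=True)

    disc_real_x = discriminator_x(real_x, training=True)
    disc_real_y = discriminator_y(real_y, training=True)
    disc_fake_x = discriminator_x(fake_x, training=True)
    disc_fake_y = discriminator_y(fake_y, training=True)

    # Adversarial, cycle consistency, and identity losses.
    gen_g_loss = generator_loss(disc_fake_y)
    gen_f_loss = generator_loss(disc_fake_x)
    total_cycle_loss = (calc_cycle_loss(real_x, cycled_x) +
                        calc_cycle_loss(real_y, cycled_y))

    total_gen_g_loss = gen_g_loss + total_cycle_loss + identity_loss(real_y, same_y)
    total_gen_f_loss = gen_f_loss + total_cycle_loss + identity_loss(real_x, same_x)

    disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)
    disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)

  # Calculate the gradients for each model.
  generator_g_gradients = tape.gradient(total_gen_g_loss,
                                        generator_g.trainable_variables)
  generator_f_gradients = tape.gradient(total_gen_f_loss,
                                        generator_f.trainable_variables)
  discriminator_x_gradients = tape.gradient(disc_x_loss,
                                            discriminator_x.trainable_variables)
  discriminator_y_gradients = tape.gradient(disc_y_loss,
                                            discriminator_y.trainable_variables)

  # Apply the gradients through the optimizers.
  generator_g_optimizer.apply_gradients(
      zip(generator_g_gradients, generator_g.trainable_variables))
  generator_f_optimizer.apply_gradients(
      zip(generator_f_gradients, generator_f.trainable_variables))
  discriminator_x_optimizer.apply_gradients(
      zip(discriminator_x_gradients, discriminator_x.trainable_variables))
  discriminator_y_optimizer.apply_gradients(
      zip(discriminator_y_gradients, discriminator_y.trainable_variables))
```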
Generate using test dataset
Next steps
This tutorial has shown how to implement CycleGAN starting from the generator and discriminator implemented in the Pix2Pix tutorial. As a next step, you could try using a different dataset from TensorFlow Datasets.
You could also train for a larger number of epochs to improve the results, or you could implement the modified ResNet generator used in the paper instead of the U-Net generator used here.