# Path: blob/master/labml_nn/gan/stylegan/experiment.py
r"""
---
title: StyleGAN 2 Model Training
summary: >
 An annotated PyTorch implementation of StyleGAN2 model training code.
---

# [StyleGAN 2](index.html) Model Training

This is the training code for [StyleGAN 2](index.html) model.

---*These are $64 \times 64$ images generated after training for about 80K steps.*---

*Our implementation is a minimalistic StyleGAN 2 model training code.
Only single GPU training is supported to keep the implementation simple.
We managed to shrink it to keep it at less than 500 lines of code, including the training loop.*

*Without DDP (distributed data parallel) and multi-gpu training it will not be possible to train the model
for large resolutions (128+).
If you want training code with fp16 and DDP take a look at
[lucidrains/stylegan2-pytorch](https://github.com/lucidrains/stylegan2-pytorch).*

We trained this on [CelebA-HQ dataset](https://github.com/tkarras/progressive_growing_of_gans).
You can find the download instruction in this
[discussion on fast.ai](https://forums.fast.ai/t/download-celeba-hq-dataset/45873/3).
Save the images inside [`data/stylegan2` folder](#dataset_path).
"""

import math
from pathlib import Path
from typing import Iterator, List, Optional, Tuple

import torchvision
from PIL import Image

import torch
import torch.utils.data
from labml import tracker, lab, monit, experiment
from labml.configs import BaseConfigs
from labml_nn.gan.stylegan import Discriminator, Generator, MappingNetwork, GradientPenalty, PathLengthPenalty
from labml_nn.gan.wasserstein import DiscriminatorLoss, GeneratorLoss
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.helpers.trainer import ModeState
from labml_nn.utils import cycle_dataloader


class Dataset(torch.utils.data.Dataset):
    """
    ## Dataset

    This loads the training dataset and resizes it to the given image size.
    """

    def __init__(self, path: str, image_size: int):
        """
        * `path` path to the folder containing the images
        * `image_size` size of the image
        """
        super().__init__()

        # Get the paths of all `jpg` files (searched recursively).
        # Sorting makes the index-to-image mapping deterministic, since the
        # file system enumeration order is not guaranteed.
        self.paths = sorted(Path(path).glob('**/*.jpg'))

        # Transformation
        self.transform = torchvision.transforms.Compose([
            # Resize the image.
            # NOTE(review): with an integer size torchvision matches the *smaller* edge,
            # so non-square source images stay non-square - confirm the dataset is square.
            torchvision.transforms.Resize(image_size),
            # Convert to PyTorch tensor
            torchvision.transforms.ToTensor(),
        ])

    def __len__(self) -> int:
        """Number of images"""
        return len(self.paths)

    def __getitem__(self, index):
        """Get the `index`-th image as a transformed tensor"""
        path = self.paths[index]
        # Use a context manager so the file handle is released deterministically,
        # and force three channels so that grayscale/CMYK `jpg` files do not produce
        # tensors with a different channel count than the rest of the batch.
        with Image.open(path) as img:
            return self.transform(img.convert('RGB'))


class Configs(BaseConfigs):
    """
    ## Configurations
    """

    # Device to train the model on.
    # [`DeviceConfigs`](../../helpers/device.html)
    # picks up an available CUDA device or defaults to CPU.
    device: torch.device = DeviceConfigs()

    # [StyleGAN2 Discriminator](index.html#discriminator)
    discriminator: Discriminator
    # [StyleGAN2 Generator](index.html#generator)
    generator: Generator
    # [Mapping network](index.html#mapping_network)
    mapping_network: MappingNetwork

    # Discriminator and generator loss functions.
    # We use [Wasserstein loss](../wasserstein/index.html)
    discriminator_loss: DiscriminatorLoss
    generator_loss: GeneratorLoss

    # Optimizers
    generator_optimizer: torch.optim.Adam
    discriminator_optimizer: torch.optim.Adam
    mapping_network_optimizer: torch.optim.Adam

    # [Gradient Penalty Regularization Loss](index.html#gradient_penalty)
    gradient_penalty = GradientPenalty()
    # Gradient penalty coefficient $\gamma$
    gradient_penalty_coefficient: float = 10.

    # [Path length penalty](index.html#path_length_penalty)
    path_length_penalty: PathLengthPenalty

    # Data loader
    loader: Iterator

    # Batch size
    batch_size: int = 32
    # Dimensionality of $z$ and $w$
    d_latent: int = 512
    # Height/width of the image
    image_size: int = 32
    # Number of layers in the mapping network
    mapping_network_layers: int = 8
    # Generator & Discriminator learning rate
    learning_rate: float = 1e-3
    # Mapping network learning rate ($100 \times$ lower than the others)
    mapping_network_learning_rate: float = 1e-5
    # Number of steps to accumulate gradients on. Use this to increase the effective batch size.
    gradient_accumulate_steps: int = 1
    # $\beta_1$ and $\beta_2$ for Adam optimizer
    adam_betas: Tuple[float, float] = (0.0, 0.99)
    # Probability of mixing styles
    style_mixing_prob: float = 0.9

    # Total number of training steps
    training_steps: int = 150_000

    # Number of blocks in the generator (calculated based on image resolution)
    n_gen_blocks: int

    # ### Lazy regularization
    # Instead of calculating the regularization losses, the paper proposes lazy regularization
    # where the regularization terms are calculated once in a while.
    # This improves the training efficiency a lot.

    # The interval at which to compute gradient penalty
    lazy_gradient_penalty_interval: int = 4
    # Path length penalty calculation interval
    lazy_path_penalty_interval: int = 32
    # Skip calculating path length penalty during the initial phase of training
    lazy_path_penalty_after: int = 5_000

    # How often to log generated images
    log_generated_interval: int = 500
    # How often to save model checkpoints
    save_checkpoint_interval: int = 2_000

    # Training mode state for logging activations
    mode: ModeState

    # <a id="dataset_path"></a>
    # We trained this on [CelebA-HQ dataset](https://github.com/tkarras/progressive_growing_of_gans).
    # You can find the download instruction in this
    # [discussion on fast.ai](https://forums.fast.ai/t/download-celeba-hq-dataset/45873/3).
    # Save the images inside `data/stylegan2` folder.
    dataset_path: str = str(lab.get_data_path() / 'stylegan2')

    def init(self):
        """
        ### Initialize

        Creates the data loader, the models, the losses and the optimizers.

        Raises `ValueError` if the dataset has fewer images than `batch_size`.
        """
        # Create dataset
        dataset = Dataset(self.dataset_path, self.image_size)
        # Fail early with a clear message. The loader below uses `drop_last=True`,
        # so with fewer images than one batch it would never produce a batch.
        if len(dataset) < self.batch_size:
            raise ValueError(
                f'Found {len(dataset)} `jpg` images in {self.dataset_path!r}, '
                f'but at least batch_size={self.batch_size} are required'
            )
        # Create data loader
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, num_workers=8,
                                                 shuffle=True, drop_last=True, pin_memory=True)
        # Continuous [cyclic loader](../../utils.html#cycle_dataloader)
        self.loader = cycle_dataloader(dataloader)

        # $\log_2$ of image resolution
        # NOTE(review): `int(...)` truncates, so an `image_size` that is not a power of two
        # is silently rounded down here - presumably only powers of two are intended.
        log_resolution = int(math.log2(self.image_size))

        # Create discriminator and generator
        self.discriminator = Discriminator(log_resolution).to(self.device)
        self.generator = Generator(log_resolution, self.d_latent).to(self.device)
        # Get number of generator blocks for creating style and noise inputs
        self.n_gen_blocks = self.generator.n_blocks
        # Create mapping network
        self.mapping_network = MappingNetwork(self.d_latent, self.mapping_network_layers).to(self.device)
        # Create path length penalty loss
        self.path_length_penalty = PathLengthPenalty(0.99).to(self.device)

        # Discriminator and generator losses
        self.discriminator_loss = DiscriminatorLoss().to(self.device)
        self.generator_loss = GeneratorLoss().to(self.device)

        # Create optimizers
        self.discriminator_optimizer = torch.optim.Adam(
            self.discriminator.parameters(),
            lr=self.learning_rate, betas=self.adam_betas
        )
        self.generator_optimizer = torch.optim.Adam(
            self.generator.parameters(),
            lr=self.learning_rate, betas=self.adam_betas
        )
        self.mapping_network_optimizer = torch.optim.Adam(
            self.mapping_network.parameters(),
            lr=self.mapping_network_learning_rate, betas=self.adam_betas
        )

        # Set tracker configurations
        tracker.set_image("generated", True)

    def get_w(self, batch_size: int) -> torch.Tensor:
        """
        ### Sample $w$

        This samples $z$ randomly and get $w$ from the mapping network.

        We also apply style mixing sometimes where we generate two latent variables
        $z_1$ and $z_2$ and get corresponding $w_1$ and $w_2$.
        Then we randomly sample a cross-over point and apply $w_1$ to
        the generator blocks before the cross-over point and
        $w_2$ to the blocks after.

        The returned tensor has `n_gen_blocks` as its first dimension.
        """

        # Mix styles
        if torch.rand(()).item() < self.style_mixing_prob:
            # Random cross-over point
            cross_over_point = int(torch.rand(()).item() * self.n_gen_blocks)
            # Sample $z_1$ and $z_2$
            z2 = torch.randn(batch_size, self.d_latent).to(self.device)
            z1 = torch.randn(batch_size, self.d_latent).to(self.device)
            # Get $w_1$ and $w_2$
            w1 = self.mapping_network(z1)
            w2 = self.mapping_network(z2)
            # Expand $w_1$ and $w_2$ for the generator blocks and concatenate
            w1 = w1[None, :, :].expand(cross_over_point, -1, -1)
            w2 = w2[None, :, :].expand(self.n_gen_blocks - cross_over_point, -1, -1)
            return torch.cat((w1, w2), dim=0)
        # Without mixing
        else:
            # Sample $z$
            z = torch.randn(batch_size, self.d_latent).to(self.device)
            # Get $w$
            w = self.mapping_network(z)
            # Expand $w$ for the generator blocks
            return w[None, :, :].expand(self.n_gen_blocks, -1, -1)

    def get_noise(self, batch_size: int) -> List[Tuple[Optional[torch.Tensor], torch.Tensor]]:
        """
        ### Generate noise

        This generates noise for each [generator block](index.html#generator_block)

        Returns a list with one `(n1, n2)` pair per generator block;
        `n1` is `None` for the first block.
        """
        # List to store noise
        noise = []
        # Noise resolution starts from $4$
        resolution = 4

        # Generate noise for each generator block
        for i in range(self.n_gen_blocks):
            # The first block has only one $3 \times 3$ convolution
            if i == 0:
                n1 = None
            # Generate noise to add after the first convolution layer
            else:
                n1 = torch.randn(batch_size, 1, resolution, resolution, device=self.device)
            # Generate noise to add after the second convolution layer
            n2 = torch.randn(batch_size, 1, resolution, resolution, device=self.device)

            # Add noise tensors to the list
            noise.append((n1, n2))

            # Next block has $2 \times$ resolution
            resolution *= 2

        # Return noise tensors
        return noise

    def generate_images(self, batch_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        ### Generate images

        This generates images using the generator and returns them together
        with the $w$ that was used.
        """

        # Get $w$
        w = self.get_w(batch_size)
        # Get noise
        noise = self.get_noise(batch_size)

        # Generate images
        images = self.generator(w, noise)

        # Return images and $w$
        return images, w

    def step(self, idx: int):
        """
        ### Training Step

        * `idx` is the zero-based index of the training step
        """

        # Train the discriminator
        with monit.section('Discriminator'):
            # Reset gradients
            self.discriminator_optimizer.zero_grad()

            # Accumulate gradients for `gradient_accumulate_steps`
            for _ in range(self.gradient_accumulate_steps):
                # Sample images from generator.
                # The generator is not updated in this phase, so skip building its
                # autograd graph instead of building it and detaching afterwards.
                with torch.no_grad():
                    generated_images, _ = self.generate_images(self.batch_size)
                # Discriminator classification for generated images
                fake_output = self.discriminator(generated_images)

                # Get real images from the data loader
                real_images = next(self.loader).to(self.device)
                # We need to calculate gradients w.r.t. real images for gradient penalty
                if (idx + 1) % self.lazy_gradient_penalty_interval == 0:
                    real_images.requires_grad_()
                # Discriminator classification for real images
                real_output = self.discriminator(real_images)

                # Get discriminator loss
                real_loss, fake_loss = self.discriminator_loss(real_output, fake_output)
                disc_loss = real_loss + fake_loss

                # Add gradient penalty
                if (idx + 1) % self.lazy_gradient_penalty_interval == 0:
                    # Calculate and log gradient penalty
                    gp = self.gradient_penalty(real_images, real_output)
                    tracker.add('loss.gp', gp)
                    # Multiply by coefficient and add gradient penalty.
                    # It is scaled by the interval because it is only applied once
                    # every `lazy_gradient_penalty_interval` steps.
                    disc_loss = disc_loss + (0.5 * self.gradient_penalty_coefficient * gp
                                             * self.lazy_gradient_penalty_interval)

                # Compute gradients
                disc_loss.backward()

                # Log discriminator loss
                tracker.add('loss.discriminator', disc_loss)

            if (idx + 1) % self.log_generated_interval == 0:
                # Log discriminator model parameters occasionally
                tracker.add('discriminator', self.discriminator)

            # Clip gradients for stabilization
            torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), max_norm=1.0)
            # Take optimizer step
            self.discriminator_optimizer.step()

        # Train the generator
        with monit.section('Generator'):
            # Reset gradients
            self.generator_optimizer.zero_grad()
            self.mapping_network_optimizer.zero_grad()

            # Accumulate gradients for `gradient_accumulate_steps`
            for _ in range(self.gradient_accumulate_steps):
                # Sample images from generator
                generated_images, w = self.generate_images(self.batch_size)
                # Discriminator classification for generated images
                fake_output = self.discriminator(generated_images)

                # Get generator loss
                gen_loss = self.generator_loss(fake_output)

                # Add path length penalty
                if idx > self.lazy_path_penalty_after and (idx + 1) % self.lazy_path_penalty_interval == 0:
                    # Calculate path length penalty
                    plp = self.path_length_penalty(w, generated_images)
                    # Ignore if `nan`
                    if not torch.isnan(plp):
                        tracker.add('loss.plp', plp)
                        gen_loss = gen_loss + plp

                # Calculate gradients
                gen_loss.backward()

                # Log generator loss
                tracker.add('loss.generator', gen_loss)

            if (idx + 1) % self.log_generated_interval == 0:
                # Log generator and mapping network parameters occasionally
                tracker.add('generator', self.generator)
                tracker.add('mapping_network', self.mapping_network)

            # Clip gradients for stabilization
            torch.nn.utils.clip_grad_norm_(self.generator.parameters(), max_norm=1.0)
            torch.nn.utils.clip_grad_norm_(self.mapping_network.parameters(), max_norm=1.0)

            # Take optimizer step
            self.generator_optimizer.step()
            self.mapping_network_optimizer.step()

        # Log generated images (six generated followed by three real ones)
        if (idx + 1) % self.log_generated_interval == 0:
            tracker.add('generated', torch.cat([generated_images[:6], real_images[:3]], dim=0))
        # Save model checkpoints
        if (idx + 1) % self.save_checkpoint_interval == 0:
            # Save checkpoint (currently a no-op placeholder)
            pass

        # Flush tracker
        tracker.save()

    def train(self):
        """
        ## Train model
        """

        # Loop for `training_steps`
        for i in monit.loop(self.training_steps):
            # Take a training step
            self.step(i)
            # Start a new line in the tracker output whenever images are logged
            if (i + 1) % self.log_generated_interval == 0:
                tracker.new_line()


def main():
    """
    ### Train StyleGAN2
    """

    # Create an experiment
    experiment.create(name='stylegan2')
    # Create configurations object
    configs = Configs()

    # Set configurations and override some
    experiment.configs(configs, {
        'device.cuda_device': 0,
        'image_size': 64,
        'log_generated_interval': 200
    })

    # Initialize
    configs.init()
    # Set models for saving and loading
    experiment.add_pytorch_models(mapping_network=configs.mapping_network,
                                  generator=configs.generator,
                                  discriminator=configs.discriminator)

    # Start the experiment
    with experiment.start():
        # Run the training loop
        configs.train()


# Run the training when executed as a script
if __name__ == '__main__':
    main()