GitHub Repository: labmlai/annotated_deep_learning_paper_implementations
Path: blob/master/labml_nn/gan/stylegan/experiment.py

"""
---
title: StyleGAN 2 Model Training
summary: >
 An annotated PyTorch implementation of StyleGAN2 model training code.
---

# [StyleGAN 2](index.html) Model Training

This is the training code for the [StyleGAN 2](index.html) model.

![Generated Images](generated_64.png)

*These are $64 \times 64$ images generated after training for about 80K steps.*

*Our implementation is a minimalistic StyleGAN 2 model training code.
Only single-GPU training is supported to keep the implementation simple.
We kept it under 500 lines of code, including the training loop.*

*Without DDP (distributed data parallel) and multi-GPU training it will not be possible to train the model
at large resolutions (128+).
If you want training code with fp16 and DDP, take a look at
[lucidrains/stylegan2-pytorch](https://github.com/lucidrains/stylegan2-pytorch).*

We trained this on the [CelebA-HQ dataset](https://github.com/tkarras/progressive_growing_of_gans).
You can find the download instructions in this
[discussion on fast.ai](https://forums.fast.ai/t/download-celeba-hq-dataset/45873/3).
Save the images inside the [`data/stylegan2` folder](#dataset_path).
"""

import math
from pathlib import Path
from typing import Iterator, Tuple

import torchvision
from PIL import Image

import torch
import torch.utils.data
from labml import tracker, lab, monit, experiment
from labml.configs import BaseConfigs
from labml_nn.gan.stylegan import Discriminator, Generator, MappingNetwork, GradientPenalty, PathLengthPenalty
from labml_nn.gan.wasserstein import DiscriminatorLoss, GeneratorLoss
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.helpers.trainer import ModeState
from labml_nn.utils import cycle_dataloader


class Dataset(torch.utils.data.Dataset):
    """
    ## Dataset

    This loads the training dataset and resizes the images to the given image size.
    """

    def __init__(self, path: str, image_size: int):
        """
        * `path` path to the folder containing the images
        * `image_size` size of the image
        """
        super().__init__()

        # Get the paths of all `jpg` files
        self.paths = [p for p in Path(path).glob('**/*.jpg')]

        # Transformation
        self.transform = torchvision.transforms.Compose([
            # Resize the image
            torchvision.transforms.Resize(image_size),
            # Convert to PyTorch tensor
            torchvision.transforms.ToTensor(),
        ])

    def __len__(self):
        """Number of images"""
        return len(self.paths)

    def __getitem__(self, index):
        """Get the `index`-th image"""
        path = self.paths[index]
        img = Image.open(path)
        return self.transform(img)
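

# A quick sanity check for the dataset (a sketch, not part of the training code;
# it assumes `jpg` images were saved under `data/stylegan2` and are square, like CelebA-HQ):
#
#     dataset = Dataset(str(lab.get_data_path() / 'stylegan2'), image_size=32)
#     print(len(dataset))      # number of images found
#     print(dataset[0].shape)  # torch.Size([3, 32, 32])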


class Configs(BaseConfigs):
    """
    ## Configurations
    """

    # Device to train the model on.
    # [`DeviceConfigs`](../../helpers/device.html)
    # picks up an available CUDA device or defaults to CPU.
    device: torch.device = DeviceConfigs()

    # [StyleGAN2 Discriminator](index.html#discriminator)
    discriminator: Discriminator
    # [StyleGAN2 Generator](index.html#generator)
    generator: Generator
    # [Mapping network](index.html#mapping_network)
    mapping_network: MappingNetwork

    # Discriminator and generator loss functions.
    # We use [Wasserstein loss](../wasserstein/index.html).
    discriminator_loss: DiscriminatorLoss
    generator_loss: GeneratorLoss

    # Optimizers
    generator_optimizer: torch.optim.Adam
    discriminator_optimizer: torch.optim.Adam
    mapping_network_optimizer: torch.optim.Adam

    # [Gradient penalty regularization loss](index.html#gradient_penalty)
    gradient_penalty = GradientPenalty()
    # Gradient penalty coefficient $\gamma$
    gradient_penalty_coefficient: float = 10.

    # [Path length penalty](index.html#path_length_penalty)
    path_length_penalty: PathLengthPenalty

    # Data loader
    loader: Iterator

    # Batch size
    batch_size: int = 32
    # Dimensionality of $z$ and $w$
    d_latent: int = 512
    # Height/width of the image
    image_size: int = 32
    # Number of layers in the mapping network
    mapping_network_layers: int = 8
    # Generator & discriminator learning rate
    learning_rate: float = 1e-3
    # Mapping network learning rate ($100 \times$ lower than the others)
    mapping_network_learning_rate: float = 1e-5
    # Number of steps to accumulate gradients on. Use this to increase the effective batch size.
    gradient_accumulate_steps: int = 1
    # $\beta_1$ and $\beta_2$ for the Adam optimizer
    adam_betas: Tuple[float, float] = (0.0, 0.99)
    # Probability of mixing styles
    style_mixing_prob: float = 0.9

    # Total number of training steps
    training_steps: int = 150_000

    # Number of blocks in the generator (calculated based on image resolution)
    n_gen_blocks: int

    # ### Lazy regularization
    # Instead of calculating the regularization losses at every training step,
    # the paper proposes lazy regularization,
    # where the regularization terms are calculated only once in a while.
    # This improves training efficiency a lot.

    # The interval at which to compute the gradient penalty
    lazy_gradient_penalty_interval: int = 4
    # Path length penalty calculation interval
    lazy_path_penalty_interval: int = 32
    # Skip calculating the path length penalty during the initial phase of training
    lazy_path_penalty_after: int = 5_000
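
    # For example, with the default `lazy_gradient_penalty_interval = 4` the
    # gradient penalty is computed only on every fourth step, and `step` below
    # scales it by the interval so that its average contribution per step stays
    # the same:
    #
    #     if (idx + 1) % lazy_gradient_penalty_interval == 0:
    #         disc_loss = disc_loss + 0.5 * gradient_penalty_coefficient * gp * lazy_gradient_penalty_interval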

    # How often to log generated images
    log_generated_interval: int = 500
    # How often to save model checkpoints
    save_checkpoint_interval: int = 2_000

    # Training mode state for logging activations
    mode: ModeState

    # <a id="dataset_path"></a>
    # We trained this on the [CelebA-HQ dataset](https://github.com/tkarras/progressive_growing_of_gans).
    # You can find the download instructions in this
    # [discussion on fast.ai](https://forums.fast.ai/t/download-celeba-hq-dataset/45873/3).
    # Save the images inside the `data/stylegan2` folder.
    dataset_path: str = str(lab.get_data_path() / 'stylegan2')

    def init(self):
        """
        ### Initialize
        """
        # Create dataset
        dataset = Dataset(self.dataset_path, self.image_size)
        # Create data loader
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, num_workers=8,
                                                 shuffle=True, drop_last=True, pin_memory=True)
        # Continuous [cyclic loader](../../utils.html#cycle_dataloader)
        self.loader = cycle_dataloader(dataloader)

        # $\log_2$ of the image resolution
        log_resolution = int(math.log2(self.image_size))

        # Create discriminator and generator
        self.discriminator = Discriminator(log_resolution).to(self.device)
        self.generator = Generator(log_resolution, self.d_latent).to(self.device)
        # Get the number of generator blocks for creating style and noise inputs
        self.n_gen_blocks = self.generator.n_blocks
        # Create mapping network
        self.mapping_network = MappingNetwork(self.d_latent, self.mapping_network_layers).to(self.device)
        # Create path length penalty loss
        self.path_length_penalty = PathLengthPenalty(0.99).to(self.device)

        # Discriminator and generator losses
        self.discriminator_loss = DiscriminatorLoss().to(self.device)
        self.generator_loss = GeneratorLoss().to(self.device)

        # Create optimizers
        self.discriminator_optimizer = torch.optim.Adam(
            self.discriminator.parameters(),
            lr=self.learning_rate, betas=self.adam_betas
        )
        self.generator_optimizer = torch.optim.Adam(
            self.generator.parameters(),
            lr=self.learning_rate, betas=self.adam_betas
        )
        self.mapping_network_optimizer = torch.optim.Adam(
            self.mapping_network.parameters(),
            lr=self.mapping_network_learning_rate, betas=self.adam_betas
        )

        # Set tracker configurations
        tracker.set_image("generated", True)
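
    # After `init`, `self.loader` yields batches forever, so the training loop
    # can simply call `next`. A sketch of the shapes (assuming square source
    # images): with the default `image_size = 32`, `log_resolution` is
    # $\log_2 32 = 5$ and
    #
    #     batch = next(self.loader)
    #     assert batch.shape == (self.batch_size, 3, self.image_size, self.image_size)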

    def get_w(self, batch_size: int):
        """
        ### Sample $w$

        This samples $z$ randomly and gets $w$ from the mapping network.

        We also apply style mixing sometimes, where we generate two latent variables
        $z_1$ and $z_2$ and get the corresponding $w_1$ and $w_2$.
        Then we randomly sample a cross-over point and apply $w_1$ to
        the generator blocks before the cross-over point and
        $w_2$ to the blocks after.
        """

        # Mix styles
        if torch.rand(()).item() < self.style_mixing_prob:
            # Random cross-over point
            cross_over_point = int(torch.rand(()).item() * self.n_gen_blocks)
            # Sample $z_1$ and $z_2$
            z2 = torch.randn(batch_size, self.d_latent).to(self.device)
            z1 = torch.randn(batch_size, self.d_latent).to(self.device)
            # Get $w_1$ and $w_2$
            w1 = self.mapping_network(z1)
            w2 = self.mapping_network(z2)
            # Expand $w_1$ and $w_2$ for the generator blocks and concatenate
            w1 = w1[None, :, :].expand(cross_over_point, -1, -1)
            w2 = w2[None, :, :].expand(self.n_gen_blocks - cross_over_point, -1, -1)
            return torch.cat((w1, w2), dim=0)
        # Without mixing
        else:
            # Sample $z$
            z = torch.randn(batch_size, self.d_latent).to(self.device)
            # Get $w$
            w = self.mapping_network(z)
            # Expand $w$ for the generator blocks
            return w[None, :, :].expand(self.n_gen_blocks, -1, -1)
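
    # A quick shape check (a sketch): whether or not styles were mixed, `get_w`
    # returns one $w$ per generator block,
    #
    #     w = self.get_w(batch_size=4)
    #     assert w.shape == (self.n_gen_blocks, 4, self.d_latent)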

    def get_noise(self, batch_size: int):
        """
        ### Generate noise

        This generates noise for each [generator block](index.html#generator_block).
        """
        # List to store noise
        noise = []
        # Noise resolution starts from $4$
        resolution = 4

        # Generate noise for each generator block
        for i in range(self.n_gen_blocks):
            # The first block has only one $3 \times 3$ convolution
            if i == 0:
                n1 = None
            # Generate noise to add after the first convolution layer
            else:
                n1 = torch.randn(batch_size, 1, resolution, resolution, device=self.device)
            # Generate noise to add after the second convolution layer
            n2 = torch.randn(batch_size, 1, resolution, resolution, device=self.device)

            # Add noise tensors to the list
            noise.append((n1, n2))

            # Next block has $2 \times$ the resolution
            resolution *= 2

        # Return noise tensors
        return noise
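
    # The structure of the result (a sketch of the tensor shapes, for
    # `n_gen_blocks = 4` and `batch_size = 4`):
    #
    #     noise = self.get_noise(batch_size=4)
    #     # [(None,           [4, 1, 4, 4]),
    #     #  ([4, 1, 8, 8],   [4, 1, 8, 8]),
    #     #  ([4, 1, 16, 16], [4, 1, 16, 16]),
    #     #  ([4, 1, 32, 32], [4, 1, 32, 32])]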

    def generate_images(self, batch_size: int):
        """
        ### Generate images

        This generates images using the generator.
        """

        # Get $w$
        w = self.get_w(batch_size)
        # Get noise
        noise = self.get_noise(batch_size)

        # Generate images
        images = self.generator(w, noise)

        # Return images and $w$
        return images, w
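
    # Usage (a sketch): the generated images are a
    # `[batch_size, 3, image_size, image_size]` tensor, and $w$ is returned as
    # well so that the path length penalty can be computed from it:
    #
    #     images, w = self.generate_images(self.batch_size)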

    def step(self, idx: int):
        """
        ### Training step
        """

        # Train the discriminator
        with monit.section('Discriminator'):
            # Reset gradients
            self.discriminator_optimizer.zero_grad()

            # Accumulate gradients for `gradient_accumulate_steps`
            for i in range(self.gradient_accumulate_steps):
                # Sample images from the generator
                generated_images, _ = self.generate_images(self.batch_size)
                # Discriminator classification for generated images
                fake_output = self.discriminator(generated_images.detach())

                # Get real images from the data loader
                real_images = next(self.loader).to(self.device)
                # We need to calculate gradients w.r.t. real images for the gradient penalty
                if (idx + 1) % self.lazy_gradient_penalty_interval == 0:
                    real_images.requires_grad_()
                # Discriminator classification for real images
                real_output = self.discriminator(real_images)

                # Get discriminator loss
                real_loss, fake_loss = self.discriminator_loss(real_output, fake_output)
                disc_loss = real_loss + fake_loss

                # Add gradient penalty
                if (idx + 1) % self.lazy_gradient_penalty_interval == 0:
                    # Calculate and log the gradient penalty
                    gp = self.gradient_penalty(real_images, real_output)
                    tracker.add('loss.gp', gp)
                    # Multiply by the coefficient and the lazy regularization interval,
                    # and add it to the discriminator loss
                    disc_loss = disc_loss + 0.5 * self.gradient_penalty_coefficient * gp * self.lazy_gradient_penalty_interval

                # Compute gradients
                disc_loss.backward()

                # Log discriminator loss
                tracker.add('loss.discriminator', disc_loss)

            if (idx + 1) % self.log_generated_interval == 0:
                # Log discriminator model parameters occasionally
                tracker.add('discriminator', self.discriminator)

            # Clip gradients for stabilization
            torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), max_norm=1.0)
            # Take an optimizer step
            self.discriminator_optimizer.step()

        # Train the generator
        with monit.section('Generator'):
            # Reset gradients
            self.generator_optimizer.zero_grad()
            self.mapping_network_optimizer.zero_grad()

            # Accumulate gradients for `gradient_accumulate_steps`
            for i in range(self.gradient_accumulate_steps):
                # Sample images from the generator
                generated_images, w = self.generate_images(self.batch_size)
                # Discriminator classification for generated images
                fake_output = self.discriminator(generated_images)

                # Get generator loss
                gen_loss = self.generator_loss(fake_output)

                # Add path length penalty
                if idx > self.lazy_path_penalty_after and (idx + 1) % self.lazy_path_penalty_interval == 0:
                    # Calculate the path length penalty
                    plp = self.path_length_penalty(w, generated_images)
                    # Ignore if `nan`
                    if not torch.isnan(plp):
                        tracker.add('loss.plp', plp)
                        gen_loss = gen_loss + plp

                # Calculate gradients
                gen_loss.backward()

                # Log generator loss
                tracker.add('loss.generator', gen_loss)

            if (idx + 1) % self.log_generated_interval == 0:
                # Log generator and mapping network parameters occasionally
                tracker.add('generator', self.generator)
                tracker.add('mapping_network', self.mapping_network)

            # Clip gradients for stabilization
            torch.nn.utils.clip_grad_norm_(self.generator.parameters(), max_norm=1.0)
            torch.nn.utils.clip_grad_norm_(self.mapping_network.parameters(), max_norm=1.0)

            # Take optimizer steps
            self.generator_optimizer.step()
            self.mapping_network_optimizer.step()

        # Log generated images
        if (idx + 1) % self.log_generated_interval == 0:
            tracker.add('generated', torch.cat([generated_images[:6], real_images[:3]], dim=0))
        # Save model checkpoints
        if (idx + 1) % self.save_checkpoint_interval == 0:
            experiment.save_checkpoint()

        # Flush tracker
        tracker.save()
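
    # With the default intervals, the zero-based step index `idx` triggers the
    # following (a summary of the conditions in `step`, for reference):
    #
    #     (idx + 1) % 4 == 0                    -> gradient penalty
    #     idx > 5_000 and (idx + 1) % 32 == 0   -> path length penalty
    #     (idx + 1) % 500 == 0                  -> log parameters and images
    #     (idx + 1) % 2_000 == 0                -> save a checkpoint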

    def train(self):
        """
        ## Train model
        """

        # Loop for `training_steps`
        for i in monit.loop(self.training_steps):
            # Take a training step
            self.step(i)
            # Start a new line in the logs after logging generated images
            if (i + 1) % self.log_generated_interval == 0:
                tracker.new_line()


def main():
    """
    ### Train StyleGAN2
    """

    # Create an experiment
    experiment.create(name='stylegan2')
    # Create the configurations object
    configs = Configs()

    # Set configurations and override some
    experiment.configs(configs, {
        'device.cuda_device': 0,
        'image_size': 64,
        'log_generated_interval': 200
    })

    # Initialize
    configs.init()
    # Set models for saving and loading
    experiment.add_pytorch_models(mapping_network=configs.mapping_network,
                                  generator=configs.generator,
                                  discriminator=configs.discriminator)

    # Start the experiment
    with experiment.start():
        # Run the training loop
        configs.train()


#
if __name__ == '__main__':
    main()