GitHub Repository: POSTECH-CVLab/PyTorch-StudioGAN
Path: blob/master/src/worker.py
1
# PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN
2
# The MIT License (MIT)
3
# See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details
4
5
# src/worker.py
6
7
from os.path import join
8
import sys
9
import glob
10
import random
11
import string
12
import pickle
13
import copy
14
15
from torch.nn import DataParallel
16
from torch.nn.parallel import DistributedDataParallel
17
from torchvision import transforms
18
from PIL import Image
19
from tqdm import tqdm
20
from scipy import ndimage
21
from utils.style_ops import conv2d_gradfix
22
from utils.style_ops import upfirdn2d
23
from sklearn.manifold import TSNE
24
from datetime import datetime
25
import torch
26
import torchvision
27
import torch.nn as nn
28
import torch.distributed as dist
29
import torch.nn.functional as F
30
import numpy as np
31
32
import metrics.features as features
33
import metrics.ins as ins
34
import metrics.fid as fid
35
import metrics.prdc as prdc
36
import metrics.resnet as resnet
37
import utils.ckpt as ckpt
38
import utils.sample as sample
39
import utils.misc as misc
40
import utils.losses as losses
41
import utils.sefa as sefa
42
import utils.ops as ops
43
import utils.resize as resize
44
import utils.apa_aug as apa_aug
45
import wandb
46
47
SAVE_FORMAT = "step={step:0>3}-Inception_mean={Inception_mean:<.4}-Inception_std={Inception_std:<.4}-FID={FID:<.5}.pth"
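# e.g., SAVE_FORMAT renders checkpoint names such as (illustrative values only):
#   step=038-Inception_mean=9.214-Inception_std=0.1024-FID=13.571.pth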
48
49
LOG_FORMAT = ("Step: {step:>6} "
50
"Progress: {progress:<.1%} "
51
"Elapsed: {elapsed} "
52
"Gen_loss: {gen_loss:<.4} "
53
"Dis_loss: {dis_loss:<.4} "
54
"Cls_loss: {cls_loss:<.4} "
55
"Topk: {topk:>4} "
56
"aa_p: {aa_p:<.4} ")
57
58
59
class WORKER(object):
60
def __init__(self, cfgs, run_name, Gen, Gen_mapping, Gen_synthesis, Dis, Gen_ema, Gen_ema_mapping, Gen_ema_synthesis,
61
ema, eval_model, train_dataloader, eval_dataloader, global_rank, local_rank, mu, sigma, real_feats, logger,
62
aa_p, best_step, best_fid, best_ckpt_path, lecam_emas, num_eval, loss_list_dict, metric_dict_during_train):
63
self.cfgs = cfgs
64
self.run_name = run_name
65
self.Gen = Gen
66
self.Gen_mapping = Gen_mapping
67
self.Gen_synthesis = Gen_synthesis
68
self.Dis = Dis
69
self.Gen_ema = Gen_ema
70
self.Gen_ema_mapping = Gen_ema_mapping
71
self.Gen_ema_synthesis = Gen_ema_synthesis
72
self.ema = ema
73
self.eval_model = eval_model
74
self.train_dataloader = train_dataloader
75
self.eval_dataloader = eval_dataloader
76
self.global_rank = global_rank
77
self.local_rank = local_rank
78
self.mu = mu
79
self.sigma = sigma
80
self.real_feats = real_feats
81
self.logger = logger
82
self.aa_p = aa_p
83
self.best_step = best_step
84
self.best_fid = best_fid
85
self.best_ckpt_path = best_ckpt_path
86
self.lecam_emas = lecam_emas
87
self.num_eval = num_eval
88
self.loss_list_dict = loss_list_dict
89
self.metric_dict_during_train = metric_dict_during_train
90
self.metric_dict_during_final_eval = {}
91
92
self.cfgs.define_augments(local_rank)
93
self.cfgs.define_losses()
94
self.DATA = cfgs.DATA
95
self.MODEL = cfgs.MODEL
96
self.LOSS = cfgs.LOSS
97
self.STYLEGAN = cfgs.STYLEGAN
98
self.OPTIMIZATION = cfgs.OPTIMIZATION
99
self.PRE = cfgs.PRE
100
self.AUG = cfgs.AUG
101
self.RUN = cfgs.RUN
102
self.MISC = cfgs.MISC
103
self.is_stylegan = cfgs.MODEL.backbone in ["stylegan2", "stylegan3"]
104
self.effective_batch_size = self.OPTIMIZATION.batch_size * self.OPTIMIZATION.acml_steps
105
self.blur_init_sigma = self.STYLEGAN.blur_init_sigma
106
self.blur_fade_kimg = self.effective_batch_size * 200/32
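# blur_fade_kimg drives the stylegan3-r image-blurring schedule used later in training;
# a sketch of the rule applied below (with images_seen = effective_batch_size * current_step):
#   blur_sigma = max(1 - images_seen / (blur_fade_kimg * 1e3), 0) * blur_init_sigma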
107
self.DDP = self.RUN.distributed_data_parallel
108
self.adc_fake = False
109
110
num_classes = self.DATA.num_classes
111
112
self.sampler = misc.define_sampler(self.DATA.name, self.MODEL.d_cond_mtd,
113
self.OPTIMIZATION.batch_size, self.DATA.num_classes)
114
115
self.pl_reg = losses.PathLengthRegularizer(device=local_rank, pl_weight=cfgs.STYLEGAN.pl_weight, pl_no_weight_grad=(cfgs.MODEL.backbone == "stylegan2"))
116
self.l2_loss = torch.nn.MSELoss()
117
self.ce_loss = torch.nn.CrossEntropyLoss()
118
self.fm_loss = losses.feature_matching_loss
119
self.lecam_ema = ops.LeCamEMA()
120
if self.lecam_emas is not None:
121
self.lecam_ema.__dict__ = self.lecam_emas
122
self.lecam_ema.decay, self.lecam_ema.start_itr = self.LOSS.lecam_ema_decay, self.LOSS.lecam_ema_start_iter
123
if self.LOSS.adv_loss == "MH":
124
self.lossy = torch.LongTensor(self.OPTIMIZATION.batch_size).to(self.local_rank)
125
self.lossy.data.fill_(self.DATA.num_classes)
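# for the multi-hinge (MH) loss, self.lossy holds the target indices used for generated
# samples; filling it with num_classes presumably treats "fake" as an extra class appended
# to the real classes.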
126
127
if self.AUG.apply_ada or self.AUG.apply_apa:
128
if self.AUG.apply_ada: self.AUG.series_augment.p.copy_(torch.as_tensor(self.aa_p))
129
self.aa_interval = self.AUG.ada_interval if self.AUG.ada_interval != "N/A" else self.AUG.apa_interval
130
self.aa_target = self.AUG.ada_target if self.AUG.ada_target != "N/A" else self.AUG.apa_target
131
self.aa_kimg = self.AUG.ada_kimg if self.AUG.ada_kimg != "N/A" else self.AUG.apa_kimg
132
self.dis_sign_real, self.dis_sign_fake = torch.zeros(2, device=self.local_rank), torch.zeros(2, device=self.local_rank)
133
self.dis_logit_real, self.dis_logit_fake = torch.zeros(2, device=self.local_rank), torch.zeros(2, device=self.local_rank)
134
self.dis_sign_real_log, self.dis_sign_fake_log = torch.zeros(2, device=self.local_rank), torch.zeros(2, device=self.local_rank)
135
self.dis_logit_real_log, self.dis_logit_fake_log = torch.zeros(2, device=self.local_rank), torch.zeros(2, device=self.local_rank)
136
137
if self.MODEL.aux_cls_type == "ADC":
138
num_classes = num_classes*2
139
self.adc_fake = True
140
141
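# select the conditioning loss used by the discriminator head: "AC" uses a plain
# cross-entropy classifier, "2C" a conditional contrastive loss (as in ContraGAN), and
# "D2DCE" a data-to-data cross-entropy loss (as in ReACGAN).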
if self.MODEL.d_cond_mtd == "AC":
142
self.cond_loss = losses.CrossEntropyLoss()
143
elif self.MODEL.d_cond_mtd == "2C":
144
self.cond_loss = losses.ConditionalContrastiveLoss(num_classes=num_classes,
145
temperature=self.LOSS.temperature,
146
master_rank="cuda",
147
DDP=self.DDP)
148
elif self.MODEL.d_cond_mtd == "D2DCE":
149
self.cond_loss = losses.Data2DataCrossEntropyLoss(num_classes=num_classes,
150
temperature=self.LOSS.temperature,
151
m_p=self.LOSS.m_p,
152
master_rank="cuda",
153
DDP=self.DDP)
154
else: pass
155
156
if self.MODEL.aux_cls_type == "TAC":
157
self.cond_loss_mi = copy.deepcopy(self.cond_loss)
158
159
self.gen_ctlr = misc.GeneratorController(generator=self.Gen_ema if self.MODEL.apply_g_ema else self.Gen,
160
generator_mapping=self.Gen_ema_mapping,
161
generator_synthesis=self.Gen_ema_synthesis,
162
batch_statistics=self.RUN.batch_statistics,
163
standing_statistics=False,
164
standing_max_batch="N/A",
165
standing_step="N/A",
166
cfgs=self.cfgs,
167
device=self.local_rank,
168
global_rank=self.global_rank,
169
logger=self.logger,
170
std_stat_counter=0)
171
172
if self.DDP:
173
self.group = dist.new_group([n for n in range(self.OPTIMIZATION.world_size)])
174
175
if self.RUN.mixed_precision and not self.is_stylegan:
176
self.scaler = torch.cuda.amp.GradScaler()
177
178
if self.global_rank == 0:
179
resume = False if self.RUN.freezeD > -1 else True
180
wandb.init(project=self.RUN.project,
181
entity=self.RUN.entity,
182
name=self.run_name,
183
dir=self.RUN.save_dir,
184
resume=self.best_step > 0 and resume)
185
186
self.start_time = datetime.now()
187
188
def prepare_train_iter(self, epoch_counter):
189
self.epoch_counter = epoch_counter
190
if self.DDP:
191
self.train_dataloader.sampler.set_epoch(self.epoch_counter)
192
self.train_iter = iter(self.train_dataloader)
193
194
def sample_data_basket(self):
195
try:
196
real_image_basket, real_label_basket = next(self.train_iter)
197
except StopIteration:
198
self.epoch_counter += 1
199
if self.RUN.train and self.DDP:
200
self.train_dataloader.sampler.set_epoch(self.epoch_counter)
201
else:
202
pass
203
self.train_iter = iter(self.train_dataloader)
204
real_image_basket, real_label_basket = next(self.train_iter)
205
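# the fetched batch is split into chunks of OPTIMIZATION.batch_size so that one optimizer
# step can accumulate gradients over OPTIMIZATION.acml_steps micro-batches (and several
# discriminator updates).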
206
real_image_basket = torch.split(real_image_basket, self.OPTIMIZATION.batch_size)
207
real_label_basket = torch.split(real_label_basket, self.OPTIMIZATION.batch_size)
208
return real_image_basket, real_label_basket
209
210
# -----------------------------------------------------------------------------
211
# train Discriminator
212
# -----------------------------------------------------------------------------
213
def train_discriminator(self, current_step):
214
batch_counter = 0
215
# make the GAN trainable before training starts
216
misc.make_GAN_trainable(self.Gen, self.Gen_ema, self.Dis)
217
# toggle gradients of the generator and discriminator
218
misc.toggle_grad(model=self.Gen, grad=False, num_freeze_layers=-1, is_stylegan=self.is_stylegan)
219
misc.toggle_grad(model=self.Dis, grad=True, num_freeze_layers=self.RUN.freezeD, is_stylegan=self.is_stylegan)
220
if self.MODEL.info_type in ["discrete", "both"]:
221
misc.toggle_grad(getattr(misc.peel_model(self.Dis), self.MISC.info_params[0]), grad=False, num_freeze_layers=-1, is_stylegan=False)
222
if self.MODEL.info_type in ["continuous", "both"]:
223
misc.toggle_grad(getattr(misc.peel_model(self.Dis), self.MISC.info_params[1]), grad=False, num_freeze_layers=-1, is_stylegan=False)
224
misc.toggle_grad(getattr(misc.peel_model(self.Dis), self.MISC.info_params[2]), grad=False, num_freeze_layers=-1, is_stylegan=False)
225
if not (self.DDP and self.RUN.mixed_precision and self.RUN.synchronized_bn): self.Gen.apply(misc.untrack_bn_statistics)
226
# sample real images and labels from the true data distribution
227
real_image_basket, real_label_basket = self.sample_data_basket()
228
for step_index in range(self.OPTIMIZATION.d_updates_per_step):
229
self.OPTIMIZATION.d_optimizer.zero_grad()
230
for acml_index in range(self.OPTIMIZATION.acml_steps):
231
with torch.cuda.amp.autocast() if self.RUN.mixed_precision and not self.is_stylegan else misc.dummy_context_mgr() as mpc:
232
# load real images and labels onto the GPU memory
233
real_images = real_image_basket[batch_counter].to(self.local_rank, non_blocking=True)
234
real_labels = real_label_basket[batch_counter].to(self.local_rank, non_blocking=True)
235
# sample fake images and labels from p(G(z), y)
236
fake_images, fake_labels, fake_images_eps, trsp_cost, ws, _, _ = sample.generate_images(
237
z_prior=self.MODEL.z_prior,
238
truncation_factor=-1.0,
239
batch_size=self.OPTIMIZATION.batch_size,
240
z_dim=self.MODEL.z_dim,
241
num_classes=self.DATA.num_classes,
242
y_sampler="totally_random",
243
radius=self.LOSS.radius,
244
generator=self.Gen,
245
discriminator=self.Dis,
246
is_train=True,
247
LOSS=self.LOSS,
248
RUN=self.RUN,
249
MODEL=self.MODEL,
250
device=self.local_rank,
251
generator_mapping=self.Gen_mapping,
252
generator_synthesis=self.Gen_synthesis,
253
is_stylegan=self.is_stylegan,
254
style_mixing_p=self.cfgs.STYLEGAN.style_mixing_p,
255
stylegan_update_emas=True,
256
cal_trsp_cost=True if self.LOSS.apply_lo else False)
257
258
# if LOSS.apply_r1_reg is True,
259
# let real images require gradients so that \nabla_{x} Dis(x) can be computed
260
if self.LOSS.apply_r1_reg and not self.is_stylegan:
261
real_images.requires_grad_(True)
262
263
# blur images for stylegan3-r
264
if self.MODEL.backbone == "stylegan3" and self.STYLEGAN.stylegan3_cfg == "stylegan3-r" and self.blur_init_sigma != "N/A":
265
blur_sigma = max(1 - (self.effective_batch_size * current_step) / (self.blur_fade_kimg * 1e3), 0) * self.blur_init_sigma
266
blur_size = np.floor(blur_sigma * 3)
267
if blur_size > 0:
268
f = torch.arange(-blur_size, blur_size + 1, device=real_images.device).div(blur_sigma).square().neg().exp2()
269
real_images = upfirdn2d.filter2d(real_images, f / f.sum())
270
fake_images = upfirdn2d.filter2d(fake_images, f / f.sum())
271
272
# adaptive pseudo augmentation (APA): with probability aa_p, replace real images with detached fake images
273
if self.AUG.apply_apa:
274
real_images = apa_aug.apply_apa_aug(real_images, fake_images.detach(), self.aa_p, self.local_rank)
275
276
# apply differentiable augmentations if "apply_diffaug" or "apply_ada" is True
277
real_images_ = self.AUG.series_augment(real_images)
278
fake_images_ = self.AUG.series_augment(fake_images)
279
280
# calculate adv_output, embed, proxy, and cls_output using the discriminator
281
real_dict = self.Dis(real_images_, real_labels)
282
fake_dict = self.Dis(fake_images_, fake_labels, adc_fake=self.adc_fake)
283
284
# accumulate discriminator output information for logging
285
if self.AUG.apply_ada or self.AUG.apply_apa:
286
self.dis_sign_real += torch.tensor((real_dict["adv_output"].sign().sum().item(),
287
self.OPTIMIZATION.batch_size),
288
device=self.local_rank)
289
self.dis_sign_fake += torch.tensor((fake_dict["adv_output"].sign().sum().item(),
290
self.OPTIMIZATION.batch_size),
291
device=self.local_rank)
292
self.dis_logit_real += torch.tensor((real_dict["adv_output"].sum().item(),
293
self.OPTIMIZATION.batch_size),
294
device=self.local_rank)
295
self.dis_logit_fake += torch.tensor((fake_dict["adv_output"].sum().item(),
296
self.OPTIMIZATION.batch_size),
297
device=self.local_rank)
298
299
# calculate adversarial loss defined by "LOSS.adv_loss"
300
if self.LOSS.adv_loss == "MH":
301
dis_acml_loss = self.LOSS.d_loss(DDP=self.DDP, **real_dict)
302
dis_acml_loss += self.LOSS.d_loss(fake_dict["adv_output"], self.lossy, DDP=self.DDP)
303
else:
304
dis_acml_loss = self.LOSS.d_loss(real_dict["adv_output"], fake_dict["adv_output"], DDP=self.DDP)
305
306
# calculate class conditioning loss defined by "MODEL.d_cond_mtd"
307
if self.MODEL.d_cond_mtd in self.MISC.classifier_based_GAN:
308
real_cond_loss = self.cond_loss(**real_dict)
309
dis_acml_loss += self.LOSS.cond_lambda * real_cond_loss
310
if self.MODEL.aux_cls_type == "TAC":
311
tac_dis_loss = self.cond_loss_mi(**fake_dict)
312
dis_acml_loss += self.LOSS.tac_dis_lambda * tac_dis_loss
313
elif self.MODEL.aux_cls_type == "ADC":
314
fake_cond_loss = self.cond_loss(**fake_dict)
315
dis_acml_loss += self.LOSS.cond_lambda * fake_cond_loss
316
else:
317
pass
318
else:
319
real_cond_loss = "N/A"
320
321
# add transport cost for latent optimization training
322
if self.LOSS.apply_lo:
323
dis_acml_loss += self.LOSS.lo_lambda * trsp_cost
324
325
# if LOSS.apply_cr is True, apply consistency regularization: the discriminator's outputs on augmented real images are pushed to match those on the originals
326
if self.LOSS.apply_cr:
327
real_prl_images = self.AUG.parallel_augment(real_images)
328
real_prl_dict = self.Dis(real_prl_images, real_labels)
329
real_consist_loss = self.l2_loss(real_dict["adv_output"], real_prl_dict["adv_output"])
330
if self.MODEL.d_cond_mtd == "AC":
331
real_consist_loss += self.l2_loss(real_dict["cls_output"], real_prl_dict["cls_output"])
332
elif self.MODEL.d_cond_mtd in ["2C", "D2DCE"]:
333
real_consist_loss += self.l2_loss(real_dict["embed"], real_prl_dict["embed"])
334
else:
335
pass
336
dis_acml_loss += self.LOSS.cr_lambda * real_consist_loss
337
338
# if LOSS.apply_bcr is True, apply balanced consistency regularization proposed in ICRGAN
339
if self.LOSS.apply_bcr:
340
real_prl_images = self.AUG.parallel_augment(real_images)
341
fake_prl_images = self.AUG.parallel_augment(fake_images)
342
real_prl_dict = self.Dis(real_prl_images, real_labels)
343
fake_prl_dict = self.Dis(fake_prl_images, fake_labels, adc_fake=self.adc_fake)
344
real_bcr_loss = self.l2_loss(real_dict["adv_output"], real_prl_dict["adv_output"])
345
fake_bcr_loss = self.l2_loss(fake_dict["adv_output"], fake_prl_dict["adv_output"])
346
if self.MODEL.d_cond_mtd == "AC":
347
real_bcr_loss += self.l2_loss(real_dict["cls_output"], real_prl_dict["cls_output"])
348
fake_bcr_loss += self.l2_loss(fake_dict["cls_output"], fake_prl_dict["cls_output"])
349
elif self.MODEL.d_cond_mtd in ["2C", "D2DCE"]:
350
real_bcr_loss += self.l2_loss(real_dict["embed"], real_prl_dict["embed"])
351
fake_bcr_loss += self.l2_loss(fake_dict["embed"], fake_prl_dict["embed"])
352
else:
353
pass
354
dis_acml_loss += self.LOSS.real_lambda * real_bcr_loss + self.LOSS.fake_lambda * fake_bcr_loss
355
356
# if LOSS.apply_zcr is True, apply latent consistency regularization proposed in ICRGAN
357
if self.LOSS.apply_zcr:
358
fake_eps_dict = self.Dis(fake_images_eps, fake_labels, adc_fake=self.adc_fake)
359
fake_zcr_loss = self.l2_loss(fake_dict["adv_output"], fake_eps_dict["adv_output"])
360
if self.MODEL.d_cond_mtd == "AC":
361
fake_zcr_loss += self.l2_loss(fake_dict["cls_output"], fake_eps_dict["cls_output"])
362
elif self.MODEL.d_cond_mtd in ["2C", "D2DCE"]:
363
fake_zcr_loss += self.l2_loss(fake_dict["embed"], fake_eps_dict["embed"])
364
else:
365
pass
366
dis_acml_loss += self.LOSS.d_lambda * fake_zcr_loss
367
368
# apply gradient penalty regularization to train Wasserstein GAN
369
if self.LOSS.apply_gp:
370
gp_loss = losses.cal_grad_penalty(real_images=real_images,
371
real_labels=real_labels,
372
fake_images=fake_images,
373
discriminator=self.Dis,
374
device=self.local_rank)
375
dis_acml_loss += self.LOSS.gp_lambda * gp_loss
376
377
# apply deep regret analysis regularization to train Wasserstein GAN
378
if self.LOSS.apply_dra:
379
dra_loss = losses.cal_dra_penalty(real_images=real_images,
380
real_labels=real_labels,
381
discriminator=self.Dis,
382
device=self.local_rank)
383
dis_acml_loss += self.LOSS.dra_lambda * dra_loss
384
385
# apply max gradient penalty regularization to train Lipschitz GAN
386
if self.LOSS.apply_maxgp:
387
maxgp_loss = losses.cal_maxgrad_penalty(real_images=real_images,
388
real_labels=real_labels,
389
fake_images=fake_images,
390
discriminator=self.Dis,
391
device=self.local_rank)
392
dis_acml_loss += self.LOSS.maxgp_lambda * maxgp_loss
393
394
# apply LeCam reg. for data-efficient training if self.LOSS.apply_lecam is set to True
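# losses.lecam_reg presumably implements the LeCam regularizer of Tseng et al. (2021),
# roughly E[relu(D(x_real) - ema_D_fake)^2] + E[relu(ema_D_real - D(x_fake))^2], where the
# EMAs of the discriminator outputs are tracked by self.lecam_ema below.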
395
if self.LOSS.apply_lecam:
396
if self.DDP:
397
real_adv_output = torch.cat(losses.GatherLayer.apply(real_dict["adv_output"]), dim=0)
398
fake_adv_output = torch.cat(losses.GatherLayer.apply(fake_dict["adv_output"]), dim=0)
399
else:
400
real_adv_output, fake_adv_output = real_dict["adv_output"], fake_dict["adv_output"]
401
self.lecam_ema.update(torch.mean(real_adv_output).item(), "D_real", current_step)
402
self.lecam_ema.update(torch.mean(fake_adv_output).item(), "D_fake", current_step)
403
if current_step > self.LOSS.lecam_ema_start_iter:
404
lecam_loss = losses.lecam_reg(real_adv_output, fake_adv_output, self.lecam_ema)
405
else:
406
lecam_loss = torch.tensor(0., device=self.local_rank)
407
dis_acml_loss += self.LOSS.lecam_lambda*lecam_loss
408
409
# apply r1_reg inside of training loop
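# a sketch of the zero-centered gradient penalty (R1) added below:
#   r1_penalty ~= E_x[ || \nabla_x Dis(x) ||^2 ], weighted by LOSS.r1_lambda;
# the StyleGAN branch additionally multiplies by d_reg_interval because the penalty is only
# evaluated every d_reg_interval steps (lazy regularization).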
410
if self.LOSS.apply_r1_reg and not self.is_stylegan:
411
self.r1_penalty = losses.cal_r1_reg(adv_output=real_dict["adv_output"], images=real_images, device=self.local_rank)
412
dis_acml_loss += self.LOSS.r1_lambda*self.r1_penalty
413
elif self.LOSS.apply_r1_reg and self.LOSS.r1_place == "inside_loop" and \
414
(self.OPTIMIZATION.d_updates_per_step*current_step + step_index) % self.STYLEGAN.d_reg_interval == 0:
415
real_images.requires_grad_(True)
416
real_dict = self.Dis(self.AUG.series_augment(real_images), real_labels)
417
self.r1_penalty = losses.stylegan_cal_r1_reg(adv_output=real_dict["adv_output"],
418
images=real_images)
419
dis_acml_loss += self.STYLEGAN.d_reg_interval*self.LOSS.r1_lambda*self.r1_penalty
420
if self.AUG.apply_ada or self.AUG.apply_apa:
421
self.dis_sign_real += torch.tensor((real_dict["adv_output"].sign().sum().item(),
422
self.OPTIMIZATION.batch_size),
423
device=self.local_rank)
424
self.dis_logit_real += torch.tensor((real_dict["adv_output"].sum().item(),
425
self.OPTIMIZATION.batch_size),
426
device=self.local_rank)
427
428
# adjust the loss for the gradient accumulation trick
429
dis_acml_loss = dis_acml_loss / self.OPTIMIZATION.acml_steps
430
batch_counter += 1
431
432
# accumulate gradients of the discriminator
433
if self.RUN.mixed_precision and not self.is_stylegan:
434
self.scaler.scale(dis_acml_loss).backward()
435
else:
436
dis_acml_loss.backward()
437
438
# update the discriminator using the pre-defined optimizer
439
if self.RUN.mixed_precision and not self.is_stylegan:
440
self.scaler.step(self.OPTIMIZATION.d_optimizer)
441
self.scaler.update()
442
else:
443
self.OPTIMIZATION.d_optimizer.step()
444
445
# apply r1_reg outside of training loop
446
if self.LOSS.apply_r1_reg and self.LOSS.r1_place == "outside_loop" and \
447
(self.OPTIMIZATION.d_updates_per_step*current_step + step_index) % self.STYLEGAN.d_reg_interval == 0:
448
self.OPTIMIZATION.d_optimizer.zero_grad()
449
for acml_index in range(self.OPTIMIZATION.acml_steps):
450
real_images = real_image_basket[batch_counter - acml_index - 1].to(self.local_rank, non_blocking=True)
451
real_labels = real_label_basket[batch_counter - acml_index - 1].to(self.local_rank, non_blocking=True)
452
# blur images for stylegan3-r
453
if self.MODEL.backbone == "stylegan3" and self.STYLEGAN.stylegan3_cfg == "stylegan3-r" and self.blur_init_sigma != "N/A":
454
blur_sigma = max(1 - (self.effective_batch_size * current_step) / (self.blur_fade_kimg * 1e3), 0) * self.blur_init_sigma
455
blur_size = np.floor(blur_sigma * 3)
456
if blur_size > 0:
457
f = torch.arange(-blur_size, blur_size + 1, device=real_images.device).div(blur_sigma).square().neg().exp2()
458
real_images = upfirdn2d.filter2d(real_images, f / f.sum())
459
if self.AUG.apply_apa:
460
real_images = apa_aug.apply_apa_aug(real_images, fake_images.detach(), self.aa_p, self.local_rank)
461
real_images.requires_grad_(True)
462
real_dict = self.Dis(self.AUG.series_augment(real_images), real_labels)
463
self.r1_penalty = losses.stylegan_cal_r1_reg(adv_output=real_dict["adv_output"], images=real_images) + \
464
misc.enable_allreduce(real_dict)
465
self.r1_penalty *= self.STYLEGAN.d_reg_interval*self.LOSS.r1_lambda/self.OPTIMIZATION.acml_steps
466
self.r1_penalty.backward()
467
468
if self.AUG.apply_ada or self.AUG.apply_apa:
469
self.dis_sign_real += torch.tensor((real_dict["adv_output"].sign().sum().item(),
470
self.OPTIMIZATION.batch_size),
471
device=self.local_rank)
472
self.dis_logit_real += torch.tensor((real_dict["adv_output"].sum().item(),
473
self.OPTIMIZATION.batch_size),
474
device=self.local_rank)
475
self.OPTIMIZATION.d_optimizer.step()
476
477
# apply the ADA/APA probability-adjustment heuristic
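# a sketch of the adjustment implemented below: r = E[sign(D(x_real))] accumulated since the
# last update is compared with aa_target, and the augmentation/APA probability is nudged by
#   aa_p <- clip(aa_p + sign(r - aa_target) * n_real / (aa_kimg * 1000), 0, 1),
# where n_real is the number of real samples seen since the previous adjustment.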
478
if (self.AUG.apply_ada or self.AUG.apply_apa) and self.aa_target is not None and current_step % self.aa_interval == 0:
479
if self.DDP: dist.all_reduce(self.dis_sign_real, op=dist.ReduceOp.SUM, group=self.group)
480
heuristic = (self.dis_sign_real[0] / self.dis_sign_real[1]).item()
481
adjust = np.sign(heuristic - self.aa_target) * (self.dis_sign_real[1].item()) / (self.aa_kimg * 1000)
482
self.aa_p = min(torch.as_tensor(1.), max(self.aa_p + adjust, torch.as_tensor(0.)))
483
if self.AUG.apply_ada: self.AUG.series_augment.p.copy_(torch.as_tensor(self.aa_p))
484
self.dis_sign_real_log.copy_(self.dis_sign_real), self.dis_sign_fake_log.copy_(self.dis_sign_fake)
485
self.dis_logit_real_log.copy_(self.dis_logit_real), self.dis_logit_fake_log.copy_(self.dis_logit_fake)
486
self.dis_sign_real.mul_(0), self.dis_sign_fake.mul_(0)
487
self.dis_logit_real.mul_(0), self.dis_logit_fake.mul_(0)
488
489
# clip weights so that the discriminator approximately satisfies a Lipschitz constraint (WGAN weight clipping)
490
if self.LOSS.apply_wc:
491
for p in self.Dis.parameters():
492
p.data.clamp_(-self.LOSS.wc_bound, self.LOSS.wc_bound)
493
494
# empty cache to discard used memory
495
if self.RUN.empty_cache:
496
torch.cuda.empty_cache()
497
return real_cond_loss, dis_acml_loss
498
499
# -----------------------------------------------------------------------------
500
# train Generator
501
# -----------------------------------------------------------------------------
502
def train_generator(self, current_step):
503
# make the GAN trainable before training starts
504
misc.make_GAN_trainable(self.Gen, self.Gen_ema, self.Dis)
505
# toggle gradients of the generator and discriminator
506
misc.toggle_grad(model=self.Dis, grad=False, num_freeze_layers=-1, is_stylegan=self.is_stylegan)
507
misc.toggle_grad(model=self.Gen, grad=True, num_freeze_layers=-1, is_stylegan=self.is_stylegan)
508
if self.MODEL.info_type in ["discrete", "both"]:
509
misc.toggle_grad(getattr(misc.peel_model(self.Dis), self.MISC.info_params[0]), grad=True, num_freeze_layers=-1, is_stylegan=False)
510
if self.MODEL.info_type in ["continuous", "both"]:
511
misc.toggle_grad(getattr(misc.peel_model(self.Dis), self.MISC.info_params[1]), grad=True, num_freeze_layers=-1, is_stylegan=False)
512
misc.toggle_grad(getattr(misc.peel_model(self.Dis), self.MISC.info_params[2]), grad=True, num_freeze_layers=-1, is_stylegan=False)
513
self.Gen.apply(misc.track_bn_statistics)
514
for step_index in range(self.OPTIMIZATION.g_updates_per_step):
515
self.OPTIMIZATION.g_optimizer.zero_grad()
516
for acml_step in range(self.OPTIMIZATION.acml_steps):
517
with torch.cuda.amp.autocast() if self.RUN.mixed_precision and not self.is_stylegan else misc.dummy_context_mgr() as mpc:
518
# sample fake images and labels from p(G(z), y)
519
fake_images, fake_labels, fake_images_eps, trsp_cost, ws, info_discrete_c, info_conti_c = sample.generate_images(
520
z_prior=self.MODEL.z_prior,
521
truncation_factor=-1.0,
522
batch_size=self.OPTIMIZATION.batch_size,
523
z_dim=self.MODEL.z_dim,
524
num_classes=self.DATA.num_classes,
525
y_sampler="totally_random",
526
radius=self.LOSS.radius,
527
generator=self.Gen,
528
discriminator=self.Dis,
529
is_train=True,
530
LOSS=self.LOSS,
531
RUN=self.RUN,
532
MODEL=self.MODEL,
533
device=self.local_rank,
534
generator_mapping=self.Gen_mapping,
535
generator_synthesis=self.Gen_synthesis,
536
is_stylegan=self.is_stylegan,
537
style_mixing_p=self.cfgs.STYLEGAN.style_mixing_p,
538
stylegan_update_emas=False,
539
cal_trsp_cost=True if self.LOSS.apply_lo else False)
540
541
# blur images for stylegan3-r
542
if self.MODEL.backbone == "stylegan3" and self.STYLEGAN.stylegan3_cfg == "stylegan3-r" and self.blur_init_sigma != "N/A":
543
blur_sigma = max(1 - (self.effective_batch_size * current_step) / (self.blur_fade_kimg * 1e3), 0) * self.blur_init_sigma
544
blur_size = np.floor(blur_sigma * 3)
545
if blur_size > 0:
546
f = torch.arange(-blur_size, blur_size + 1, device=fake_images.device).div(blur_sigma).square().neg().exp2()
547
fake_images = upfirdn2d.filter2d(fake_images, f / f.sum())
548
549
# apply differentiable augmentations if "apply_diffaug" is True
550
fake_images_ = self.AUG.series_augment(fake_images)
551
552
# calculate adv_output, embed, proxy, and cls_output using the discriminator
553
fake_dict = self.Dis(fake_images_, fake_labels)
554
555
# accumulate discriminator output information for logging
556
if self.AUG.apply_ada or self.AUG.apply_apa:
557
self.dis_sign_fake += torch.tensor((fake_dict["adv_output"].sign().sum().item(),
558
self.OPTIMIZATION.batch_size),
559
device=self.local_rank)
560
self.dis_logit_fake += torch.tensor((fake_dict["adv_output"].sum().item(),
561
self.OPTIMIZATION.batch_size),
562
device=self.local_rank)
563
564
# apply top-k training: keep only the top-k discriminator outputs for generated samples and discard the rest, which tend to lie 'in-between modes'
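# int(self.topk) is the number of generated samples kept; self.topk is presumably annealed
# elsewhere from the full batch size toward a smaller fraction (top-k training, Sinha et al.).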
565
if self.LOSS.apply_topk:
566
fake_dict["adv_output"] = torch.topk(fake_dict["adv_output"], int(self.topk)).values
567
568
# calculate adversarial loss defined by "LOSS.adv_loss"
569
if self.LOSS.adv_loss == "MH":
570
gen_acml_loss = self.LOSS.mh_lambda * self.LOSS.g_loss(DDP=self.DDP, **fake_dict)
571
else:
572
gen_acml_loss = self.LOSS.g_loss(fake_dict["adv_output"], DDP=self.DDP)
573
574
# calculate class conditioning loss defined by "MODEL.d_cond_mtd"
575
if self.MODEL.d_cond_mtd in self.MISC.classifier_based_GAN:
576
fake_cond_loss = self.cond_loss(**fake_dict)
577
gen_acml_loss += self.LOSS.cond_lambda * fake_cond_loss
578
if self.MODEL.aux_cls_type == "TAC":
579
tac_gen_loss = -self.cond_loss_mi(**fake_dict)
580
gen_acml_loss += self.LOSS.tac_gen_lambda * tac_gen_loss
581
elif self.MODEL.aux_cls_type == "ADC":
582
adc_fake_dict = self.Dis(fake_images_, fake_labels, adc_fake=self.adc_fake)
583
adc_fake_cond_loss = -self.cond_loss(**adc_fake_dict)
584
gen_acml_loss += self.LOSS.cond_lambda * adc_fake_cond_loss
585
pass
586
587
# apply feature matching regularization to stabilize adversarial dynamics
588
if self.LOSS.apply_fm:
589
real_image_basket, real_label_basket = self.sample_data_basket()
590
real_images = real_image_basket[0].to(self.local_rank, non_blocking=True)
591
real_labels = real_label_basket[0].to(self.local_rank, non_blocking=True)
592
real_images_ = self.AUG.series_augment(real_images)
593
real_dict = self.Dis(real_images_, real_labels)
594
595
mean_match_loss = self.fm_loss(real_dict["h"].detach(), fake_dict["h"])
596
gen_acml_loss += self.LOSS.fm_lambda * mean_match_loss
597
598
# add transport cost for latent optimization training
599
if self.LOSS.apply_lo:
600
gen_acml_loss += self.LOSS.lo_lambda * trsp_cost
601
602
# apply latent consistency regularization for generating diverse images
603
if self.LOSS.apply_zcr:
604
fake_zcr_loss = -1 * self.l2_loss(fake_images, fake_images_eps)
605
gen_acml_loss += self.LOSS.g_lambda * fake_zcr_loss
606
607
# compute the information loss for InfoGAN
608
if self.MODEL.info_type in ["discrete", "both"]:
609
dim = self.MODEL.info_dim_discrete_c
610
self.info_discrete_loss = 0.0
611
for info_c in range(self.MODEL.info_num_discrete_c):
612
self.info_discrete_loss += self.ce_loss(
613
fake_dict["info_discrete_c_logits"][:, info_c*dim: dim*(info_c+1)],
614
info_discrete_c[:, info_c: info_c+1].squeeze())
615
gen_acml_loss += self.LOSS.infoGAN_loss_discrete_lambda*self.info_discrete_loss + misc.enable_allreduce(fake_dict)
616
if self.MODEL.info_type in ["continuous", "both"]:
617
self.info_conti_loss = losses.normal_nll_loss(info_conti_c, fake_dict["info_conti_mu"], fake_dict["info_conti_var"])
618
gen_acml_loss += self.LOSS.infoGAN_loss_conti_lambda*self.info_conti_loss + misc.enable_allreduce(fake_dict)
619
620
# adjust the loss for the gradient accumulation trick
621
gen_acml_loss = gen_acml_loss / self.OPTIMIZATION.acml_steps
622
623
# accumulate gradients of the generator
624
if self.RUN.mixed_precision and not self.is_stylegan:
625
self.scaler.scale(gen_acml_loss).backward()
626
else:
627
gen_acml_loss.backward()
628
629
# update the generator using the pre-defined optimizer
630
if self.RUN.mixed_precision and not self.is_stylegan:
631
self.scaler.step(self.OPTIMIZATION.g_optimizer)
632
self.scaler.update()
633
else:
634
self.OPTIMIZATION.g_optimizer.step()
635
636
# apply path length regularization
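# pl_reg.cal_pl_reg presumably implements StyleGAN2's path length regularizer, penalizing
# deviations of ||J_w^T y|| (Jacobian of the images w.r.t. the latents ws, probed with random
# directions y) from its running mean; the "+ fake_images[:,0,0,0].mean()*0" term below keeps
# fake_images in the autograd graph (e.g., for DDP gradient synchronization) without changing
# the value, and the loss is rescaled by g_reg_interval (lazy regularization).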
637
if self.STYLEGAN.apply_pl_reg and (self.OPTIMIZATION.g_updates_per_step*current_step + step_index) % self.STYLEGAN.g_reg_interval == 0:
638
self.OPTIMIZATION.g_optimizer.zero_grad()
639
for acml_index in range(self.OPTIMIZATION.acml_steps):
640
fake_images, fake_labels, fake_images_eps, trsp_cost, ws, _, _ = sample.generate_images(
641
z_prior=self.MODEL.z_prior,
642
truncation_factor=-1.0,
643
batch_size=self.OPTIMIZATION.batch_size // 2,
644
z_dim=self.MODEL.z_dim,
645
num_classes=self.DATA.num_classes,
646
y_sampler="totally_random",
647
radius=self.LOSS.radius,
648
generator=self.Gen,
649
discriminator=self.Dis,
650
is_train=True,
651
LOSS=self.LOSS,
652
RUN=self.RUN,
653
MODEL=self.MODEL,
654
device=self.local_rank,
655
generator_mapping=self.Gen_mapping,
656
generator_synthesis=self.Gen_synthesis,
657
is_stylegan=self.is_stylegan,
658
style_mixing_p=self.cfgs.STYLEGAN.style_mixing_p,
659
stylegan_update_emas=False,
660
cal_trsp_cost=True if self.LOSS.apply_lo else False)
661
662
# blur images for stylegan3-r
663
if self.MODEL.backbone == "stylegan3" and self.STYLEGAN.stylegan3_cfg == "stylegan3-r" and self.blur_init_sigma != "N/A":
664
blur_sigma = max(1 - (self.effective_batch_size * current_step) / (self.blur_fade_kimg * 1e3), 0) * self.blur_init_sigma
665
blur_size = np.floor(blur_sigma * 3)
666
if blur_size > 0:
667
f = torch.arange(-blur_size, blur_size + 1, device=fake_images.device).div(blur_sigma).square().neg().exp2()
668
fake_images = upfirdn2d.filter2d(fake_images, f / f.sum())
669
self.pl_reg_loss = self.pl_reg.cal_pl_reg(fake_images=fake_images, ws=ws) + fake_images[:,0,0,0].mean()*0
670
self.pl_reg_loss *= self.STYLEGAN.g_reg_interval/self.OPTIMIZATION.acml_steps
671
self.pl_reg_loss.backward()
672
self.OPTIMIZATION.g_optimizer.step()
673
674
# if MODEL.apply_g_ema is True, update the parameters of Gen_ema with an exponential moving average
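# a typical update is theta_ema <- beta * theta_ema + (1 - beta) * theta_Gen; the exact decay
# schedule (and any warm-up) is defined by the ema helper passed into this worker.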
675
if self.MODEL.apply_g_ema:
676
self.ema.update(current_step)
677
678
# empty cache to discard used memory
679
if self.RUN.empty_cache:
680
torch.cuda.empty_cache()
681
return gen_acml_loss
682
683
# -----------------------------------------------------------------------------
684
# log training statistics
685
# -----------------------------------------------------------------------------
686
def log_train_statistics(self, current_step, real_cond_loss, gen_acml_loss, dis_acml_loss):
687
self.wandb_step = current_step + 1
688
if self.MODEL.d_cond_mtd in self.MISC.classifier_based_GAN:
689
cls_loss = real_cond_loss.item()
690
else:
691
cls_loss = "N/A"
692
693
log_message = LOG_FORMAT.format(
694
step=current_step + 1,
695
progress=(current_step + 1) / self.OPTIMIZATION.total_steps,
696
elapsed=misc.elapsed_time(self.start_time),
697
gen_loss=gen_acml_loss.item(),
698
dis_loss=dis_acml_loss.item(),
699
cls_loss=cls_loss,
700
topk=int(self.topk) if self.LOSS.apply_topk else "N/A",
701
aa_p=self.aa_p if self.AUG.apply_ada or self.AUG.apply_apa else "N/A",
702
)
703
self.logger.info(log_message)
704
705
# save loss values in wandb event file and .npz format
706
loss_dict = {
707
"gen_loss": gen_acml_loss.item(),
708
"dis_loss": dis_acml_loss.item(),
709
"cls_loss": 0.0 if cls_loss == "N/A" else cls_loss,
710
}
711
712
wandb.log(loss_dict, step=self.wandb_step)
713
714
save_dict = misc.accm_values_convert_dict(list_dict=self.loss_list_dict,
715
value_dict=loss_dict,
716
step=current_step + 1,
717
interval=self.RUN.print_freq)
718
misc.save_dict_npy(directory=join(self.RUN.save_dir, "statistics", self.run_name),
719
name="losses",
720
dictionary=save_dict)
721
722
if self.AUG.apply_ada or self.AUG.apply_apa:
723
dis_output_dict = {
724
"dis_sign_real": (self.dis_sign_real_log[0]/self.dis_sign_real_log[1]).item(),
725
"dis_sign_fake": (self.dis_sign_fake_log[0]/self.dis_sign_fake_log[1]).item(),
726
"dis_logit_real": (self.dis_logit_real_log[0]/self.dis_logit_real_log[1]).item(),
727
"dis_logit_fake": (self.dis_logit_fake_log[0]/self.dis_logit_fake_log[1]).item(),
728
}
729
wandb.log(dis_output_dict, step=self.wandb_step)
730
wandb.log({"aa_p": self.aa_p.item()}, step=self.wandb_step)
731
732
infoGAN_dict = {}
733
if self.MODEL.info_type in ["discrete", "both"]:
734
infoGAN_dict["info_discrete_loss"] = self.info_discrete_loss.item()
735
if self.MODEL.info_type in ["continuous", "both"]:
736
infoGAN_dict["info_conti_loss"] = self.info_conti_loss.item()
737
wandb.log(infoGAN_dict, step=self.wandb_step)
738
739
if self.LOSS.apply_r1_reg:
740
wandb.log({"r1_reg_loss": self.r1_penalty.item()}, step=self.wandb_step)
741
742
if self.STYLEGAN.apply_pl_reg:
743
wandb.log({"pl_reg_loss": self.pl_reg_loss.item()}, step=self.wandb_step)
744
745
# calculate the spectral norms of all weights in the generator for monitoring purposes
746
if self.MODEL.apply_g_sn:
747
gen_sigmas = misc.calculate_all_sn(self.Gen, prefix="Gen")
748
wandb.log(gen_sigmas, step=self.wandb_step)
749
750
# calculate the spectral norms of all weights in the discriminator for monitoring purposes
751
if self.MODEL.apply_d_sn:
752
dis_sigmas = misc.calculate_all_sn(self.Dis, prefix="Dis")
753
wandb.log(dis_sigmas, step=self.wandb_step)
754
755
# -----------------------------------------------------------------------------
756
# visualize fake images for monitoring purposes.
757
# -----------------------------------------------------------------------------
758
def visualize_fake_images(self, num_cols, current_step):
759
if self.global_rank == 0:
760
self.logger.info("Visualize (num_rows x 8) fake image canvans.")
761
if self.gen_ctlr.standing_statistics:
762
self.gen_ctlr.std_stat_counter += 1
763
764
requires_grad = self.LOSS.apply_lo or self.RUN.langevin_sampling
765
with torch.no_grad() if not requires_grad else misc.dummy_context_mgr() as ctx:
766
misc.make_GAN_untrainable(self.Gen, self.Gen_ema, self.Dis)
767
generator, generator_mapping, generator_synthesis = self.gen_ctlr.prepare_generator()
768
769
fake_images, fake_labels, _, _, _, _, _ = sample.generate_images(z_prior=self.MODEL.z_prior,
770
truncation_factor=self.RUN.truncation_factor,
771
batch_size=self.OPTIMIZATION.batch_size,
772
z_dim=self.MODEL.z_dim,
773
num_classes=self.DATA.num_classes,
774
y_sampler=self.sampler,
775
radius="N/A",
776
generator=generator,
777
discriminator=self.Dis,
778
is_train=False,
779
LOSS=self.LOSS,
780
RUN=self.RUN,
781
MODEL=self.MODEL,
782
device=self.local_rank,
783
is_stylegan=self.is_stylegan,
784
generator_mapping=generator_mapping,
785
generator_synthesis=generator_synthesis,
786
style_mixing_p=0.0,
787
stylegan_update_emas=False,
788
cal_trsp_cost=False)
789
790
misc.plot_img_canvas(images=fake_images.detach().cpu(),
791
save_path=join(self.RUN.save_dir,
792
"figures/{run_name}/generated_canvas_{step}.png".format(run_name=self.run_name, step=current_step)),
793
num_cols=num_cols,
794
logger=self.logger,
795
logging=self.global_rank == 0 and self.logger)
796
797
if self.RUN.train:
798
wandb.log({"generated_images": wandb.Image(fake_images)}, step=self.wandb_step)
799
800
misc.make_GAN_trainable(self.Gen, self.Gen_ema, self.Dis)
801
802
# -----------------------------------------------------------------------------
803
# evaluate the GAN using IS, FID, and improved precision & recall (PRDC).
804
# -----------------------------------------------------------------------------
805
def evaluate(self, step, metrics, writing=True, training=False):
806
if self.global_rank == 0:
807
self.logger.info("Start Evaluation ({step} Step): {run_name}".format(step=step, run_name=self.run_name))
808
if self.gen_ctlr.standing_statistics:
809
self.gen_ctlr.std_stat_counter += 1
810
811
is_best, num_splits, nearest_k = False, 1, 5
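# num_splits is the number of splits used when computing the Inception score, and nearest_k
# the neighbourhood size for the precision/recall/density/coverage (PRDC) manifold estimates.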
812
is_acc = True if "ImageNet" in self.DATA.name and "Tiny" not in self.DATA.name else False
813
requires_grad = self.LOSS.apply_lo or self.RUN.langevin_sampling
814
with torch.no_grad() if not requires_grad else misc.dummy_context_mgr() as ctx:
815
misc.make_GAN_untrainable(self.Gen, self.Gen_ema, self.Dis)
816
generator, generator_mapping, generator_synthesis = self.gen_ctlr.prepare_generator()
817
metric_dict = {}
818
819
fake_feats, fake_probs, fake_labels = features.generate_images_and_stack_features(
820
generator=generator,
821
discriminator=self.Dis,
822
eval_model=self.eval_model,
823
num_generate=self.num_eval[self.RUN.ref_dataset],
824
y_sampler="totally_random",
825
batch_size=self.OPTIMIZATION.batch_size,
826
z_prior=self.MODEL.z_prior,
827
truncation_factor=self.RUN.truncation_factor,
828
z_dim=self.MODEL.z_dim,
829
num_classes=self.DATA.num_classes,
830
LOSS=self.LOSS,
831
RUN=self.RUN,
832
MODEL=self.MODEL,
833
is_stylegan=self.is_stylegan,
834
generator_mapping=generator_mapping,
835
generator_synthesis=generator_synthesis,
836
quantize=True,
837
world_size=self.OPTIMIZATION.world_size,
838
DDP=self.DDP,
839
device=self.local_rank,
840
logger=self.logger,
841
disable_tqdm=self.global_rank != 0)
842
843
if ("fid" in metrics or "prdc" in metrics) and self.global_rank == 0:
844
self.logger.info("{num_images} real images is used for evaluation.".format(num_images=len(self.eval_dataloader.dataset)))
845
846
if "is" in metrics:
847
kl_score, kl_std, top1, top5 = ins.eval_features(probs=fake_probs,
848
labels=fake_labels,
849
data_loader=self.eval_dataloader,
850
num_features=self.num_eval[self.RUN.ref_dataset],
851
split=num_splits,
852
is_acc=is_acc,
853
is_torch_backbone=True if "torch" in self.RUN.eval_backbone else False)
854
if self.global_rank == 0:
855
self.logger.info("Inception score (Step: {step}, {num} generated images): {IS}".format(
856
step=step, num=str(self.num_eval[self.RUN.ref_dataset]), IS=kl_score))
857
if is_acc:
858
self.logger.info("{eval_model} Top1 acc: (Step: {step}, {num} generated images): {Top1}".format(
859
eval_model=self.RUN.eval_backbone, step=step, num=str(self.num_eval[self.RUN.ref_dataset]), Top1=top1))
860
self.logger.info("{eval_model} Top5 acc: (Step: {step}, {num} generated images): {Top5}".format(
861
eval_model=self.RUN.eval_backbone, step=step, num=str(self.num_eval[self.RUN.ref_dataset]), Top5=top5))
862
metric_dict.update({"IS": kl_score, "Top1_acc": top1, "Top5_acc": top5})
863
if writing:
864
wandb.log({"IS score": kl_score}, step=self.wandb_step)
865
if is_acc:
866
wandb.log({"{eval_model} Top1 acc".format(eval_model=self.RUN.eval_backbone): top1}, step=self.wandb_step)
867
wandb.log({"{eval_model} Top5 acc".format(eval_model=self.RUN.eval_backbone): top5}, step=self.wandb_step)
868
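# FID compares Gaussian fits to real and fake features; a sketch of the usual formula:
#   FID = ||mu_r - mu_f||^2 + Tr(C_r + C_f - 2 (C_r C_f)^(1/2)),
# where (mu, C) are the feature mean and covariance computed with the evaluation backbone.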
869
if "fid" in metrics:
870
fid_score, m1, c1 = fid.calculate_fid(data_loader=self.eval_dataloader,
871
eval_model=self.eval_model,
872
num_generate=self.num_eval[self.RUN.ref_dataset],
873
cfgs=self.cfgs,
874
pre_cal_mean=self.mu,
875
pre_cal_std=self.sigma,
876
fake_feats=fake_feats,
877
disable_tqdm=self.global_rank != 0)
878
if self.global_rank == 0:
879
self.logger.info("FID score (Step: {step}, Using {type} moments): {FID}".format(
880
step=step, type=self.RUN.ref_dataset, FID=fid_score))
881
if self.best_fid is None or fid_score <= self.best_fid:
882
self.best_fid, self.best_step, is_best = fid_score, step, True
883
metric_dict.update({"FID": fid_score})
884
if writing:
885
wandb.log({"FID score": fid_score}, step=self.wandb_step)
886
if training:
887
self.logger.info("Best FID score (Step: {step}, Using {type} moments): {FID}".format(
888
step=self.best_step, type=self.RUN.ref_dataset, FID=self.best_fid))
889
890
if "prdc" in metrics:
891
prc, rec, dns, cvg = prdc.calculate_pr_dc(real_feats=self.real_feats,
892
fake_feats=fake_feats,
893
data_loader=self.eval_dataloader,
894
eval_model=self.eval_model,
895
num_generate=self.num_eval[self.RUN.ref_dataset],
896
cfgs=self.cfgs,
897
quantize=True,
898
nearest_k=nearest_k,
899
world_size=self.OPTIMIZATION.world_size,
900
DDP=self.DDP,
901
disable_tqdm=True)
902
if self.global_rank == 0:
903
self.logger.info("Improved Precision (Step: {step}, Using {type} images): {prc}".format(
904
step=step, type=self.RUN.ref_dataset, prc=prc))
905
self.logger.info("Improved Recall (Step: {step}, Using {type} images): {rec}".format(
906
step=step, type=self.RUN.ref_dataset, rec=rec))
907
self.logger.info("Density (Step: {step}, Using {type} images): {dns}".format(
908
step=step, type=self.RUN.ref_dataset, dns=dns))
909
self.logger.info("Coverage (Step: {step}, Using {type} images): {cvg}".format(
910
step=step, type=self.RUN.ref_dataset, cvg=cvg))
911
metric_dict.update({"Improved_Precision": prc, "Improved_Recall": rec, "Density": dns, "Coverage": cvg})
912
if writing:
913
wandb.log({"Improved Precision": prc}, step=self.wandb_step)
914
wandb.log({"Improved Recall": rec}, step=self.wandb_step)
915
wandb.log({"Density": dns}, step=self.wandb_step)
916
wandb.log({"Coverage": cvg}, step=self.wandb_step)
917
918
if self.global_rank == 0:
919
if training:
920
save_dict = misc.accm_values_convert_dict(list_dict=self.metric_dict_during_train,
921
value_dict=metric_dict,
922
step=step,
923
interval=self.RUN.save_freq)
924
else:
925
save_dict = misc.accm_values_convert_dict(list_dict=self.metric_dict_during_final_eval,
926
value_dict=metric_dict,
927
step=None,
928
interval=None)
929
930
misc.save_dict_npy(directory=join(self.RUN.save_dir, "statistics", self.run_name, "train" if training else "eval"),
931
name="metrics",
932
dictionary=save_dict)
933
934
misc.make_GAN_trainable(self.Gen, self.Gen_ema, self.Dis)
935
return is_best
936
937
# -----------------------------------------------------------------------------
938
# save the trained generator, generator_ema, and discriminator.
939
# -----------------------------------------------------------------------------
940
def save(self, step, is_best):
941
when = "best" if is_best is True else "current"
942
misc.make_GAN_untrainable(self.Gen, self.Gen_ema, self.Dis)
943
Gen, Gen_ema, Dis = misc.peel_models(self.Gen, self.Gen_ema, self.Dis)
944
945
g_states = {"state_dict": Gen.state_dict(), "optimizer": self.OPTIMIZATION.g_optimizer.state_dict()}
946
947
d_states = {
948
"state_dict": Dis.state_dict(),
949
"optimizer": self.OPTIMIZATION.d_optimizer.state_dict(),
950
"seed": self.RUN.seed,
951
"run_name": self.run_name,
952
"step": step,
953
"epoch": self.epoch_counter,
954
"topk": self.topk,
955
"aa_p": self.aa_p,
956
"best_step": self.best_step,
957
"best_fid": self.best_fid,
958
"best_fid_ckpt": self.RUN.ckpt_dir,
959
"lecam_emas": self.lecam_ema.__dict__,
960
}
961
962
if self.Gen_ema is not None:
963
g_ema_states = {"state_dict": Gen_ema.state_dict()}
964
965
misc.save_model(model="G", when=when, step=step, ckpt_dir=self.RUN.ckpt_dir, states=g_states)
966
misc.save_model(model="D", when=when, step=step, ckpt_dir=self.RUN.ckpt_dir, states=d_states)
967
if self.Gen_ema is not None:
968
misc.save_model(model="G_ema", when=when, step=step, ckpt_dir=self.RUN.ckpt_dir, states=g_ema_states)
969
970
if when == "best":
971
misc.save_model(model="G", when="current", step=step, ckpt_dir=self.RUN.ckpt_dir, states=g_states)
972
misc.save_model(model="D", when="current", step=step, ckpt_dir=self.RUN.ckpt_dir, states=d_states)
973
if self.Gen_ema is not None:
974
misc.save_model(model="G_ema",
975
when="current",
976
step=step,
977
ckpt_dir=self.RUN.ckpt_dir,
978
states=g_ema_states)
979
980
if self.global_rank == 0 and self.logger:
981
self.logger.info("Save model to {}".format(self.RUN.ckpt_dir))
982
983
misc.make_GAN_trainable(self.Gen, self.Gen_ema, self.Dis)
984
985
# -----------------------------------------------------------------------------
986
# save real images to measure metrics for evaluation.
987
# -----------------------------------------------------------------------------
988
def save_real_images(self):
989
if self.global_rank == 0:
990
self.logger.info("save {num_images} real images in png format.".format(
991
num_images=len(self.eval_dataloader.dataset)))
992
993
misc.save_images_png(data_loader=self.eval_dataloader,
994
generator="N/A",
995
discriminator="N/A",
996
is_generate=False,
997
num_images=len(self.eval_dataloader.dataset),
998
y_sampler="N/A",
999
batch_size=self.OPTIMIZATION.batch_size,
1000
z_prior="N/A",
1001
truncation_factor="N/A",
1002
z_dim="N/A",
1003
num_classes=self.DATA.num_classes,
1004
LOSS=self.LOSS,
1005
OPTIMIZATION=self.OPTIMIZATION,
1006
RUN=self.RUN,
1007
MODEL=self.MODEL,
1008
is_stylegan=False,
1009
generator_mapping="N/A",
1010
generator_synthesis="N/A",
1011
directory=join(self.RUN.save_dir, "samples", self.run_name),
1012
device=self.local_rank)
1013
1014
# -----------------------------------------------------------------------------
1015
# save fake images to measure metrics for evaluation.
1016
# -----------------------------------------------------------------------------
1017
def save_fake_images(self, num_images):
1018
if self.global_rank == 0:
1019
self.logger.info("save {num_images} generated images in png format.".format(num_images=self.num_eval[self.RUN.ref_dataset]))
1020
if self.gen_ctlr.standing_statistics:
1021
self.gen_ctlr.std_stat_counter += 1
1022
1023
requires_grad = self.LOSS.apply_lo or self.RUN.langevin_sampling
1024
with torch.no_grad() if not requires_grad else misc.dummy_context_mgr() as ctx:
1025
misc.make_GAN_untrainable(self.Gen, self.Gen_ema, self.Dis)
1026
generator, generator_mapping, generator_synthesis = self.gen_ctlr.prepare_generator()
1027
1028
misc.save_images_png(data_loader=self.eval_dataloader,
1029
generator=generator,
1030
discriminator=self.Dis,
1031
is_generate=True,
1032
num_images=num_images,
1033
y_sampler="totally_random",
1034
batch_size=self.OPTIMIZATION.batch_size,
1035
z_prior=self.MODEL.z_prior,
1036
truncation_factor=self.RUN.truncation_factor,
1037
z_dim=self.MODEL.z_dim,
1038
num_classes=self.DATA.num_classes,
1039
LOSS=self.LOSS,
1040
OPTIMIZATION=self.OPTIMIZATION,
1041
RUN=self.RUN,
1042
MODEL=self.MODEL,
1043
is_stylegan=self.is_stylegan,
1044
generator_mapping=generator_mapping,
1045
generator_synthesis=generator_synthesis,
1046
directory=join(self.RUN.save_dir, "samples", self.run_name),
1047
device=self.local_rank)
1048
1049
misc.make_GAN_trainable(self.Gen, self.Gen_ema, self.Dis)
1050
1051
# -----------------------------------------------------------------------------
1052
# run k-nearest-neighbor analysis to check whether the GAN memorizes the training images.
1053
# -----------------------------------------------------------------------------
1054
def run_k_nearest_neighbor(self, dataset, num_rows, num_cols):
1055
if self.global_rank == 0:
1056
self.logger.info("Run K-nearest neighbor analysis using fake and {ref} dataset.".format(ref=self.RUN.ref_dataset))
1057
if self.gen_ctlr.standing_statistics: self.gen_ctlr.std_stat_counter += 1
1058
1059
requires_grad = self.LOSS.apply_lo or self.RUN.langevin_sampling
1060
with torch.no_grad() if not requires_grad else misc.dummy_context_mgr() as ctx:
1061
misc.make_GAN_untrainable(self.Gen, self.Gen_ema, self.Dis)
1062
generator, generator_mapping, generator_synthesis = self.gen_ctlr.prepare_generator()
1063
1064
res, mean, std = 224, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
1065
resizer = resize.build_resizer(resizer=self.RUN.post_resizer,
1066
backbone="ResNet50_torch",
1067
size=res)
1068
totensor = transforms.ToTensor()
1069
mean = torch.Tensor(mean).view(1, 3, 1, 1).to("cuda")
1070
std = torch.Tensor(std).view(1, 3, 1, 1).to("cuda")
1071
1072
resnet50_model = torch.hub.load("pytorch/vision:v0.6.0", "resnet50", pretrained=True)
1073
resnet50_conv = nn.Sequential(*list(resnet50_model.children())[:-1]).to(self.local_rank)
1074
if self.OPTIMIZATION.world_size > 1:
1075
resnet50_conv = DataParallel(resnet50_conv, output_device=self.local_rank)
1076
resnet50_conv.eval()
1077
1078
for c in tqdm(range(self.DATA.num_classes)):
1079
fake_images, fake_labels, _, _, _, _, _ = sample.generate_images(z_prior=self.MODEL.z_prior,
1080
truncation_factor=self.RUN.truncation_factor,
1081
batch_size=self.OPTIMIZATION.batch_size,
1082
z_dim=self.MODEL.z_dim,
1083
num_classes=self.DATA.num_classes,
1084
y_sampler=c,
1085
radius="N/A",
1086
generator=generator,
1087
discriminator=self.Dis,
1088
is_train=False,
1089
LOSS=self.LOSS,
1090
RUN=self.RUN,
1091
MODEL=self.MODEL,
1092
device=self.local_rank,
1093
is_stylegan=self.is_stylegan,
1094
generator_mapping=generator_mapping,
1095
generator_synthesis=generator_synthesis,
1096
style_mixing_p=0.0,
1097
stylegan_update_emas=False,
1098
cal_trsp_cost=False)
1099
fake_anchor = torch.unsqueeze(fake_images[0], dim=0)
1100
fake_anchor = ops.quantize_images(fake_anchor)
1101
fake_anchor = ops.resize_images(fake_anchor, resizer, totensor, mean, std, self.local_rank)
1102
fake_anchor_embed = torch.squeeze(resnet50_conv(fake_anchor))
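# nearest neighbours are searched in the feature space of an ImageNet-pretrained ResNet50
# (penultimate pooled features); below, distances are mean squared differences between this
# anchor embedding and the embeddings of real images from the same class.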
1103
1104
num_samples, target_sampler = sample.make_target_cls_sampler(dataset=dataset, target_class=c)
1105
batch_size = self.OPTIMIZATION.batch_size if num_samples >= self.OPTIMIZATION.batch_size else num_samples
1106
c_dataloader = torch.utils.data.DataLoader(dataset=dataset,
1107
batch_size=batch_size,
1108
shuffle=False,
1109
sampler=target_sampler,
1110
num_workers=self.RUN.num_workers,
1111
pin_memory=True)
1112
c_iter = iter(c_dataloader)
1113
for batch_idx in range(num_samples//batch_size):
1114
real_images, real_labels = next(c_iter)
1115
real_images = ops.quantize_images(real_images)
1116
real_images = ops.resize_images(real_images, resizer, totensor, mean, std, self.local_rank)
1117
real_embed = torch.squeeze(resnet50_conv(real_images))
1118
if batch_idx == 0:
1119
distances = torch.square(real_embed - fake_anchor_embed).mean(dim=1).detach().cpu().numpy()
1120
image_holder = real_images.detach().cpu().numpy()
1121
else:
1122
distances = np.concatenate([
1123
distances,
1124
torch.square(real_embed - fake_anchor_embed).mean(dim=1).detach().cpu().numpy()
1125
],
1126
axis=0)
1127
image_holder = np.concatenate([image_holder, real_images.detach().cpu().numpy()], axis=0)
1128
1129
nearest_indices = (-distances).argsort()[-(num_cols - 1):][::-1]
1130
if c % num_rows == 0:
1131
canvas = np.concatenate([fake_anchor.detach().cpu().numpy(), image_holder[nearest_indices]], axis=0)
1132
elif c % num_rows == num_rows - 1:
1133
row_images = np.concatenate([fake_anchor.detach().cpu().numpy(), image_holder[nearest_indices]], axis=0)
1134
canvas = np.concatenate((canvas, row_images), axis=0)
1135
misc.plot_img_canvas(images=torch.from_numpy(canvas),
1136
save_path=join(self.RUN.save_dir, "figures/{run_name}/fake_anchor_{num_cols}NN_{cls}_classes.png".\
1137
format(run_name=self.run_name, num_cols=num_cols, cls=c+1)),
1138
num_cols=num_cols,
1139
logger=self.logger,
1140
logging=self.global_rank == 0 and self.logger)
1141
else:
1142
row_images = np.concatenate([fake_anchor.detach().cpu().numpy(), image_holder[nearest_indices]], axis=0)
1143
canvas = np.concatenate((canvas, row_images), axis=0)
1144
1145
misc.make_GAN_trainable(self.Gen, self.Gen_ema, self.Dis)
1146
1147
# -----------------------------------------------------------------------------
1148
# conduct latent interpolation analysis to identify the quaility of latent space (Z)
1149
# -----------------------------------------------------------------------------
1150
def run_linear_interpolation(self, num_rows, num_cols, fix_z, fix_y, num_saves=100):
1151
assert int(fix_z) * int(fix_y) != 1, "fix_z and fix_y cannot both be True!"
1152
if self.global_rank == 0:
1153
flag = "fix_z" if fix_z else "fix_y"
1154
self.logger.info("Run linear interpolation analysis ({flag}) {num} times.".format(flag=flag, num=num_saves))
1155
if self.gen_ctlr.standing_statistics:
1156
self.gen_ctlr.std_stat_counter += 1
1157
1158
requires_grad = self.LOSS.apply_lo or self.RUN.langevin_sampling
1159
with torch.no_grad() if not requires_grad else misc.dummy_context_mgr() as ctx:
1160
misc.make_GAN_untrainable(self.Gen, self.Gen_ema, self.Dis)
1161
generator, generator_mapping, generator_synthesis = self.gen_ctlr.prepare_generator()
1162
1163
shared = misc.peel_model(generator).shared
1164
for ns in tqdm(range(num_saves)):
1165
if fix_z:
1166
zs = torch.randn(num_rows, 1, self.MODEL.z_dim, device=self.local_rank)
1167
zs = zs.repeat(1, num_cols, 1).view(-1, self.MODEL.z_dim)
1168
name = "fix_z"
1169
else:
1170
zs = misc.interpolate(torch.randn(num_rows, 1, self.MODEL.z_dim, device=self.local_rank),
1171
torch.randn(num_rows, 1, self.MODEL.z_dim, device=self.local_rank),
1172
num_cols - 2).view(-1, self.MODEL.z_dim)
1173
1174
if fix_y:
1175
ys = sample.sample_onehot(batch_size=num_rows,
1176
num_classes=self.DATA.num_classes,
1177
device=self.local_rank)
1178
ys = shared(ys).view(num_rows, 1, -1)
1179
ys = ys.repeat(1, num_cols, 1).view(num_rows * (num_cols), -1)
1180
name = "fix_y"
1181
else:
1182
ys = misc.interpolate(
1183
shared(sample.sample_onehot(num_rows, self.DATA.num_classes)).view(num_rows, 1, -1),
1184
shared(sample.sample_onehot(num_rows, self.DATA.num_classes)).view(num_rows, 1, -1),
1185
num_cols - 2).view(num_rows * (num_cols), -1)
1186
1187
interpolated_images = generator(zs, None, shared_label=ys)
1188
1189
misc.plot_img_canvas(images=interpolated_images.detach().cpu(),
1190
save_path=join(self.RUN.save_dir, "figures/{run_name}/{num}_Interpolated_images_{fix_flag}.png".\
1191
format(num=ns, run_name=self.run_name, fix_flag=name)),
1192
num_cols=num_cols,
1193
logger=self.logger,
1194
logging=False)
1195
1196
if self.global_rank == 0 and self.logger:
1197
print("Save figures to {}/*_Interpolated_images_{}.png".format(
1198
join(self.RUN.save_dir, "figures", self.run_name), flag))
1199
1200
misc.make_GAN_trainable(self.Gen, self.Gen_ema, self.Dis)
1201
1202
# -----------------------------------------------------------------------------
1203
# visualize shifted Fourier spectra of real and fake images
1204
# -----------------------------------------------------------------------------
1205
    def run_frequency_analysis(self, dataloader):
        if self.global_rank == 0:
            self.logger.info("Run frequency analysis (use {num} fake and {ref} images).".format(
                num=len(dataloader), ref=self.RUN.ref_dataset))
        if self.gen_ctlr.standing_statistics:
            self.gen_ctlr.std_stat_counter += 1

        requires_grad = self.LOSS.apply_lo or self.RUN.langevin_sampling
        with torch.no_grad() if not requires_grad else misc.dummy_context_mgr() as ctx:
            misc.make_GAN_untrainable(self.Gen, self.Gen_ema, self.Dis)
            generator, generator_mapping, generator_synthesis = self.gen_ctlr.prepare_generator()

            data_iter = iter(dataloader)
            num_batches = len(dataloader) // self.OPTIMIZATION.batch_size
            for i in range(num_batches):
                real_images, real_labels = next(data_iter)
                fake_images, fake_labels, _, _, _, _, _ = sample.generate_images(
                    z_prior=self.MODEL.z_prior,
                    truncation_factor=self.RUN.truncation_factor,
                    batch_size=self.OPTIMIZATION.batch_size,
                    z_dim=self.MODEL.z_dim,
                    num_classes=self.DATA.num_classes,
                    y_sampler="totally_random",
                    radius="N/A",
                    generator=generator,
                    discriminator=self.Dis,
                    is_train=False,
                    LOSS=self.LOSS,
                    RUN=self.RUN,
                    MODEL=self.MODEL,
                    device=self.local_rank,
                    is_stylegan=self.is_stylegan,
                    generator_mapping=generator_mapping,
                    generator_synthesis=generator_synthesis,
                    style_mixing_p=0.0,
                    stylegan_update_emas=False,
                    cal_trsp_cost=False)
                fake_images = fake_images.detach().cpu().numpy()

                real_images = np.asarray((real_images + 1) * 127.5, np.uint8)
                fake_images = np.asarray((fake_images + 1) * 127.5, np.uint8)

                if i == 0:
                    real_array = real_images
                    fake_array = fake_images
                else:
                    real_array = np.concatenate([real_array, real_images], axis=0)
                    fake_array = np.concatenate([fake_array, fake_images], axis=0)

            N, C, H, W = np.shape(real_array)
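            # NOTE: the RGB-to-grayscale conversion below uses the ITU-R BT.601 luma weights,
            #     gray = 0.2989 * R + 0.5870 * G + 0.1140 * B,
            # equivalent to, e.g., `arr.transpose(0, 2, 3, 1) @ np.array([0.2989, 0.5870, 0.1140])`.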
            real_r, real_g, real_b = real_array[:, 0, :, :], real_array[:, 1, :, :], real_array[:, 2, :, :]
            real_gray = 0.2989 * real_r + 0.5870 * real_g + 0.1140 * real_b
            fake_r, fake_g, fake_b = fake_array[:, 0, :, :], fake_array[:, 1, :, :], fake_array[:, 2, :, :]
            fake_gray = 0.2989 * fake_r + 0.5870 * fake_g + 0.1140 * fake_b
            for j in tqdm(range(N)):
                real_gray_f = np.fft.fft2(real_gray[j] - ndimage.median_filter(real_gray[j], size=H // 8))
                fake_gray_f = np.fft.fft2(fake_gray[j] - ndimage.median_filter(fake_gray[j], size=H // 8))

                real_gray_f_shifted = np.fft.fftshift(real_gray_f)
                fake_gray_f_shifted = np.fft.fftshift(fake_gray_f)

                if j == 0:
                    real_gray_spectrum = 20 * np.log(np.abs(real_gray_f_shifted)) / N
                    fake_gray_spectrum = 20 * np.log(np.abs(fake_gray_f_shifted)) / N
                else:
                    real_gray_spectrum += 20 * np.log(np.abs(real_gray_f_shifted)) / N
                    fake_gray_spectrum += 20 * np.log(np.abs(fake_gray_f_shifted)) / N

            misc.plot_spectrum_image(
                real_spectrum=real_gray_spectrum,
                fake_spectrum=fake_gray_spectrum,
                directory=join(self.RUN.save_dir, "figures", self.run_name),
                logger=self.logger,
                logging=self.global_rank == 0 and self.logger)

        misc.make_GAN_trainable(self.Gen, self.Gen_ema, self.Dis)

    # -----------------------------------------------------------------------------
    # visualize discriminator's embeddings of real or fake images using TSNE
    # -----------------------------------------------------------------------------
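    # NOTE: the embeddings are captured with a forward pre-hook on the discriminator's "linear1"
    # layer, i.e. presumably the features fed into its final linear head, so the scatter plots show
    # how the discriminator separates real and generated samples right before classification.
    # Only 10 classes are visualized to keep the plots readable.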
    def run_tsne(self, dataloader):
        if self.global_rank == 0:
            self.logger.info("Start TSNE analysis using 10 randomly sampled classes.")
            self.logger.info("Use {ref} dataset and the same amount of generated images for visualization.".format(
                ref=self.RUN.ref_dataset))
        if self.gen_ctlr.standing_statistics:
            self.gen_ctlr.std_stat_counter += 1

        requires_grad = self.LOSS.apply_lo or self.RUN.langevin_sampling
        with torch.no_grad() if not requires_grad else misc.dummy_context_mgr() as ctx:
            misc.make_GAN_untrainable(self.Gen, self.Gen_ema, self.Dis)
            generator, generator_mapping, generator_synthesis = self.gen_ctlr.prepare_generator()

            save_output, real, fake, hook_handles = misc.SaveOutput(), {}, {}, []
            for name, layer in misc.peel_model(self.Dis).named_children():
                if name == "linear1":
                    handle = layer.register_forward_pre_hook(save_output)
                    hook_handles.append(handle)

            tsne_iter = iter(dataloader)
            num_batches = len(dataloader.dataset) // self.OPTIMIZATION.batch_size
            for i in range(num_batches):
                real_images, real_labels = next(tsne_iter)
                real_images, real_labels = real_images.to(self.local_rank), real_labels.to(self.local_rank)

                real_dict = self.Dis(real_images, real_labels)
                if i == 0:
                    real["embeds"] = save_output.outputs[0][0].detach().cpu().numpy()
                    real["labels"] = real_labels.detach().cpu().numpy()
                else:
                    real["embeds"] = np.concatenate(
                        [real["embeds"], save_output.outputs[0][0].cpu().detach().numpy()], axis=0)
                    real["labels"] = np.concatenate([real["labels"], real_labels.detach().cpu().numpy()])

                save_output.clear()

                fake_images, fake_labels, _, _, _, _, _ = sample.generate_images(
                    z_prior=self.MODEL.z_prior,
                    truncation_factor=self.RUN.truncation_factor,
                    batch_size=self.OPTIMIZATION.batch_size,
                    z_dim=self.MODEL.z_dim,
                    num_classes=self.DATA.num_classes,
                    y_sampler="totally_random",
                    radius="N/A",
                    generator=generator,
                    discriminator=self.Dis,
                    is_train=False,
                    LOSS=self.LOSS,
                    RUN=self.RUN,
                    MODEL=self.MODEL,
                    device=self.local_rank,
                    is_stylegan=self.is_stylegan,
                    generator_mapping=generator_mapping,
                    generator_synthesis=generator_synthesis,
                    style_mixing_p=0.0,
                    stylegan_update_emas=False,
                    cal_trsp_cost=False)

                fake_dict = self.Dis(fake_images, fake_labels)
                if i == 0:
                    fake["embeds"] = save_output.outputs[0][0].detach().cpu().numpy()
                    fake["labels"] = fake_labels.detach().cpu().numpy()
                else:
                    fake["embeds"] = np.concatenate(
                        [fake["embeds"], save_output.outputs[0][0].cpu().detach().numpy()], axis=0)
                    fake["labels"] = np.concatenate([fake["labels"], fake_labels.detach().cpu().numpy()])

                save_output.clear()

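            # NOTE: scikit-learn's t-SNE is run with perplexity=40 and only 300 optimization steps
            # (the library default is 1000 and 250 is the allowed minimum), which keeps the
            # visualization cheap at the cost of a coarser embedding.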
            tsne = TSNE(n_components=2, verbose=1, perplexity=40, n_iter=300)
            if self.DATA.num_classes > 10:
                cls_indices = np.random.permutation(self.DATA.num_classes)[:10]
                real["embeds"] = real["embeds"][np.isin(real["labels"], cls_indices)]
                real["labels"] = real["labels"][np.isin(real["labels"], cls_indices)]
                fake["embeds"] = fake["embeds"][np.isin(fake["labels"], cls_indices)]
                fake["labels"] = fake["labels"][np.isin(fake["labels"], cls_indices)]

            real_tsne_results = tsne.fit_transform(real["embeds"])
            misc.plot_tsne_scatter_plot(
                df=real,
                tsne_results=real_tsne_results,
                flag="real",
                directory=join(self.RUN.save_dir, "figures", self.run_name),
                logger=self.logger,
                logging=self.global_rank == 0 and self.logger)

            fake_tsne_results = tsne.fit_transform(fake["embeds"])
            misc.plot_tsne_scatter_plot(
                df=fake,
                tsne_results=fake_tsne_results,
                flag="fake",
                directory=join(self.RUN.save_dir, "figures", self.run_name),
                logger=self.logger,
                logging=self.global_rank == 0 and self.logger)

        misc.make_GAN_trainable(self.Gen, self.Gen_ema, self.Dis)

    # -----------------------------------------------------------------------------
    # calculate intra-class FID (iFID) to identify intra-class diversity
    # -----------------------------------------------------------------------------
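    # NOTE: each per-class score below is a standard FID computed between eval_model features of
    # real and generated images restricted to a single class c,
    #     FID_c = ||mu_real - mu_fake||^2 + Tr(Sigma_real + Sigma_fake - 2 * (Sigma_real @ Sigma_fake)^(1/2)),
    # and the method reports the average of FID_c over all classes. Lower iFID indicates that the
    # generator covers each class with both fidelity and diversity.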
    def calculate_intra_class_fid(self, dataset):
        if self.global_rank == 0:
            self.logger.info("Start calculating iFID (use approx. {num} fake images per class and train images as the reference).".format(
                num=int(len(dataset) / self.DATA.num_classes)))

        if self.gen_ctlr.standing_statistics:
            self.gen_ctlr.std_stat_counter += 1

        fids = []
        requires_grad = self.LOSS.apply_lo or self.RUN.langevin_sampling
        with torch.no_grad() if not requires_grad else misc.dummy_context_mgr() as ctx:
            misc.make_GAN_untrainable(self.Gen, self.Gen_ema, self.Dis)
            generator, generator_mapping, generator_synthesis = self.gen_ctlr.prepare_generator()

            for c in tqdm(range(self.DATA.num_classes)):
                num_samples, target_sampler = sample.make_target_cls_sampler(dataset, c)
                batch_size = self.OPTIMIZATION.batch_size if num_samples >= self.OPTIMIZATION.batch_size else num_samples
                dataloader = torch.utils.data.DataLoader(
                    dataset,
                    batch_size=batch_size,
                    shuffle=False,
                    sampler=target_sampler,
                    num_workers=self.RUN.num_workers,
                    pin_memory=True,
                    drop_last=False)

                mu, sigma = fid.calculate_moments(
                    data_loader=dataloader,
                    eval_model=self.eval_model,
                    num_generate="N/A",
                    batch_size=batch_size,
                    quantize=True,
                    world_size=self.OPTIMIZATION.world_size,
                    DDP=self.DDP,
                    disable_tqdm=True,
                    fake_feats=None)

                c_fake_feats, _, _ = features.generate_images_and_stack_features(
                    generator=generator,
                    discriminator=self.Dis,
                    eval_model=self.eval_model,
                    num_generate=num_samples,
                    y_sampler=c,
                    batch_size=self.OPTIMIZATION.batch_size,
                    z_prior=self.MODEL.z_prior,
                    truncation_factor=self.RUN.truncation_factor,
                    z_dim=self.MODEL.z_dim,
                    num_classes=self.DATA.num_classes,
                    LOSS=self.LOSS,
                    RUN=self.RUN,
                    MODEL=self.MODEL,
                    is_stylegan=self.is_stylegan,
                    generator_mapping=generator_mapping,
                    generator_synthesis=generator_synthesis,
                    quantize=True,
                    world_size=self.OPTIMIZATION.world_size,
                    DDP=self.DDP,
                    device=self.local_rank,
                    logger=self.logger,
                    disable_tqdm=True)

                ifid_score, _, _ = fid.calculate_fid(
                    data_loader="N/A",
                    eval_model=self.eval_model,
                    num_generate=num_samples,
                    cfgs=self.cfgs,
                    pre_cal_mean=mu,
                    pre_cal_std=sigma,
                    quantize=False,
                    fake_feats=c_fake_feats,
                    disable_tqdm=True)

                fids.append(ifid_score)

                # save iFID values in .npz format
                metric_dict = {"iFID": ifid_score}

                save_dict = misc.accm_values_convert_dict(
                    list_dict={"iFID": []},
                    value_dict=metric_dict,
                    step=c,
                    interval=1)
                misc.save_dict_npy(
                    directory=join(self.RUN.save_dir, "statistics", self.run_name),
                    name="iFID",
                    dictionary=save_dict)

        if self.global_rank == 0 and self.logger:
            self.logger.info("Average iFID score: {iFID}".format(iFID=sum(fids, 0.0) / len(fids)))

        misc.make_GAN_trainable(self.Gen, self.Gen_ema, self.Dis)

    # -----------------------------------------------------------------------------
    # perform semantic (closed-form) factorization for latent navigation
    # -----------------------------------------------------------------------------
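    # NOTE: SeFa (closed-form factorization, Shen & Zhou, 2021) finds interpretable latent
    # directions without any training: the top eigenvectors of A^T A, where A is the weight matrix
    # of the first layer that projects the latent code, serve as semantic axes along which z is
    # shifted by up to `maximum_variations`. The actual computation is assumed to live in
    # utils/sefa.py (sefa.apply_sefa); each saved canvas shows `num_rows` axes over `num_cols` steps.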
    def run_semantic_factorization(self, num_rows, num_cols, maximum_variations):
        if self.global_rank == 0:
            self.logger.info("Perform semantic factorization for latent navigation.")

        if self.gen_ctlr.standing_statistics:
            self.gen_ctlr.std_stat_counter += 1

        requires_grad = self.LOSS.apply_lo or self.RUN.langevin_sampling
        with torch.no_grad() if not requires_grad else misc.dummy_context_mgr() as ctx:
            misc.make_GAN_untrainable(self.Gen, self.Gen_ema, self.Dis)
            generator, generator_mapping, generator_synthesis = self.gen_ctlr.prepare_generator()

            zs, fake_labels, _ = sample.sample_zy(
                z_prior=self.MODEL.z_prior,
                batch_size=self.OPTIMIZATION.batch_size,
                z_dim=self.MODEL.z_dim,
                num_classes=self.DATA.num_classes,
                truncation_factor=self.RUN.truncation_factor,
                y_sampler="totally_random",
                radius="N/A",
                device=self.local_rank)

            for i in tqdm(range(self.OPTIMIZATION.batch_size)):
                images_canvas = sefa.apply_sefa(
                    generator=generator,
                    backbone=self.MODEL.backbone,
                    z=zs[i],
                    fake_label=fake_labels[i],
                    num_semantic_axis=num_rows,
                    maximum_variations=maximum_variations,
                    num_cols=num_cols)

                misc.plot_img_canvas(
                    images=images_canvas.detach().cpu(),
                    save_path=join(
                        self.RUN.save_dir,
                        "figures/{run_name}/{idx}_sefa_images.png".format(idx=i, run_name=self.run_name)),
                    num_cols=num_cols,
                    logger=self.logger,
                    logging=False)

        if self.global_rank == 0 and self.logger:
            print("Save figures to {}/*_sefa_images.png".format(join(self.RUN.save_dir, "figures", self.run_name)))

        misc.make_GAN_trainable(self.Gen, self.Gen_ema, self.Dis)

    # -----------------------------------------------------------------------------
    # compute classifier accuracy score (CAS) to identify class-conditional precision and recall
    # -----------------------------------------------------------------------------
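    # NOTE: this follows the GAN-train / GAN-test protocol of Shmelkov et al., "How good is my
    # GAN?" (ECCV 2018). GAN_train fits the classifier on generated images and evaluates it on
    # real ones (a proxy for class-conditional recall/diversity), while GAN_test fits it on real
    # images and evaluates it on generated ones (a proxy for class-conditional precision/fidelity).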
    def compute_GAN_train_or_test_classifier_accuracy_score(self, GAN_train=False, GAN_test=False):
        assert GAN_train * GAN_test == 0, "cannot conduct GAN_train and GAN_test together."
        if self.global_rank == 0:
            if GAN_train:
                phase, metric = "train", "recall"
            else:
                phase, metric = "test", "precision"
            self.logger.info("Compute GAN_{phase} Classifier Accuracy Score (CAS) to identify class-conditional {metric}.".format(
                phase=phase, metric=metric))

        if self.gen_ctlr.standing_statistics:
            self.gen_ctlr.std_stat_counter += 1

        misc.make_GAN_untrainable(self.Gen, self.Gen_ema, self.Dis)
        generator, generator_mapping, generator_synthesis = self.gen_ctlr.prepare_generator()

        best_top1, best_top5, cas_setting = 0.0, 0.0, self.MISC.cas_setting[self.DATA.name]
        model = resnet.ResNet(
            dataset=self.DATA.name,
            depth=cas_setting["depth"],
            num_classes=self.DATA.num_classes,
            bottleneck=cas_setting["bottleneck"]).to("cuda")

        optimizer = torch.optim.SGD(
            params=model.parameters(),
            lr=cas_setting["lr"],
            momentum=cas_setting["momentum"],
            weight_decay=cas_setting["weight_decay"],
            nesterov=True)

        if self.OPTIMIZATION.world_size > 1:
            model = DataParallel(model, output_device=self.local_rank)

        epoch_trained = 0
        if self.RUN.ckpt_dir is not None and self.RUN.resume_classifier_train:
            is_pre_trained_model, mode = ckpt.check_is_pre_trained_model(
                ckpt_dir=self.RUN.ckpt_dir,
                GAN_train=GAN_train,
                GAN_test=GAN_test)
            if is_pre_trained_model:
                epoch_trained, best_top1, best_top5, best_epoch = ckpt.load_GAN_train_test_model(
                    model=model,
                    mode=mode,
                    optimizer=optimizer,
                    RUN=self.RUN)

        for current_epoch in tqdm(range(epoch_trained, cas_setting["epochs"])):
            model.train()
            optimizer.zero_grad()
            ops.adjust_learning_rate(
                optimizer=optimizer,
                lr_org=cas_setting["lr"],
                epoch=current_epoch,
                total_epoch=cas_setting["epochs"],
                dataset=self.DATA.name)

            train_top1_acc, train_top5_acc, train_loss = misc.AverageMeter(), misc.AverageMeter(), misc.AverageMeter()
            for i, (images, labels) in enumerate(self.train_dataloader):
                if GAN_train:
                    images, labels, _, _, _, _, _ = sample.generate_images(
                        z_prior=self.MODEL.z_prior,
                        truncation_factor=self.RUN.truncation_factor,
                        batch_size=self.OPTIMIZATION.batch_size,
                        z_dim=self.MODEL.z_dim,
                        num_classes=self.DATA.num_classes,
                        y_sampler="totally_random",
                        radius="N/A",
                        generator=generator,
                        discriminator=self.Dis,
                        is_train=False,
                        LOSS=self.LOSS,
                        RUN=self.RUN,
                        MODEL=self.MODEL,
                        device=self.local_rank,
                        is_stylegan=self.is_stylegan,
                        generator_mapping=generator_mapping,
                        generator_synthesis=generator_synthesis,
                        style_mixing_p=0.0,
                        stylegan_update_emas=False,
                        cal_trsp_cost=False)
                else:
                    images, labels = images.to(self.local_rank), labels.to(self.local_rank)

                logits = model(images)
                ce_loss = self.ce_loss(logits, labels)

                train_acc1, train_acc5 = misc.accuracy(logits.data, labels, topk=(1, 5))

                train_loss.update(ce_loss.item(), images.size(0))
                train_top1_acc.update(train_acc1.item(), images.size(0))
                train_top5_acc.update(train_acc5.item(), images.size(0))

                optimizer.zero_grad()  # clear gradients from the previous mini-batch
                ce_loss.backward()
                optimizer.step()

            valid_acc1, valid_acc5, valid_loss = self.validate_classifier(
                model=model,
                generator=generator,
                generator_mapping=generator_mapping,
                generator_synthesis=generator_synthesis,
                epoch=current_epoch,
                GAN_test=GAN_test,
                setting=cas_setting)

            is_best = valid_acc1 > best_top1
            best_top1 = max(valid_acc1, best_top1)
            if is_best:
                best_top5, best_epoch = valid_acc5, current_epoch
                model_ = misc.peel_model(model)
                states = {
                    "state_dict": model_.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "epoch": current_epoch + 1,
                    "best_top1": best_top1,
                    "best_top5": best_top5,
                    "best_epoch": best_epoch,
                }
                misc.save_model_c(states, mode, self.RUN)

            if self.local_rank == 0:
                self.logger.info("Current best accuracy: Top-1: {top1:.4f}% and Top-5 {top5:.4f}%".format(top1=best_top1, top5=best_top5))
                self.logger.info("Save model to {}".format(self.RUN.ckpt_dir))

    # -----------------------------------------------------------------------------
    # validate GAN_train or GAN_test classifier using generated or training dataset
    # -----------------------------------------------------------------------------
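    # NOTE: validation batches are drawn from self.train_dataloader; when GAN_test is True the real
    # batches are replaced with freshly generated images of random classes, otherwise the real
    # images themselves are scored. Top-1/Top-5 accuracies are tracked with running AverageMeters
    # and their epoch averages are returned.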
    def validate_classifier(self, model, generator, generator_mapping, generator_synthesis, epoch, GAN_test, setting):
        model.eval()
        valid_top1_acc, valid_top5_acc, valid_loss = misc.AverageMeter(), misc.AverageMeter(), misc.AverageMeter()
        for i, (images, labels) in enumerate(self.train_dataloader):
            if GAN_test:
                images, labels, _, _, _, _, _ = sample.generate_images(
                    z_prior=self.MODEL.z_prior,
                    truncation_factor=self.RUN.truncation_factor,
                    batch_size=self.OPTIMIZATION.batch_size,
                    z_dim=self.MODEL.z_dim,
                    num_classes=self.DATA.num_classes,
                    y_sampler="totally_random",
                    radius="N/A",
                    generator=generator,
                    discriminator=self.Dis,
                    is_train=False,
                    LOSS=self.LOSS,
                    RUN=self.RUN,
                    MODEL=self.MODEL,
                    device=self.local_rank,
                    is_stylegan=self.is_stylegan,
                    generator_mapping=generator_mapping,
                    generator_synthesis=generator_synthesis,
                    style_mixing_p=0.0,
                    stylegan_update_emas=False,
                    cal_trsp_cost=False)
            else:
                images, labels = images.to(self.local_rank), labels.to(self.local_rank)

            output = model(images)
            ce_loss = self.ce_loss(output, labels)

            valid_acc1, valid_acc5 = misc.accuracy(output.data, labels, topk=(1, 5))

            valid_loss.update(ce_loss.item(), images.size(0))
            valid_top1_acc.update(valid_acc1.item(), images.size(0))
            valid_top5_acc.update(valid_acc5.item(), images.size(0))

        if self.local_rank == 0:
            self.logger.info("Top 1-acc {top1.val:.4f} ({top1.avg:.4f})\t"
                             "Top 5-acc {top5.val:.4f} ({top5.avg:.4f})".format(top1=valid_top1_acc, top5=valid_top5_acc))

        return valid_top1_acc.avg, valid_top5_acc.avg, valid_loss.avg