# PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN
# The MIT License (MIT)
# See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details

# src/loader.py

from os.path import dirname, abspath, exists, join
import sys
import glob
import json
import os
import random
import warnings

from torch.backends import cudnn
from torch.utils.data import DataLoader
from torch.nn import DataParallel
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
import torch
import torch.distributed as dist
import wandb

from data_util import Dataset_
from utils.style_ops import grid_sample_gradfix
from utils.style_ops import conv2d_gradfix
from metrics.inception_net import InceptionV3
from sync_batchnorm.batchnorm import convert_model
from worker import WORKER
import utils.log as log
import utils.losses as losses
import utils.ckpt as ckpt
import utils.misc as misc
import utils.custom_ops as custom_ops
import models.model as model
import metrics.preparation as pp

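
# load_worker is the per-process entry point for training and evaluation.
# A minimal launch-side sketch (hypothetical; the actual call lives in main.py
# and may differ in detail):
#
#     import torch.multiprocessing as mp
#     mp.spawn(load_worker, args=(cfgs, gpus_per_node, run_name, hdf5_path),
#              nprocs=gpus_per_node)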
def load_worker(local_rank, cfgs, gpus_per_node, run_name, hdf5_path):
    # -----------------------------------------------------------------------------
    # define default variables for loading ckpt or evaluating the trained GAN model.
    # -----------------------------------------------------------------------------
    load_train_dataset = cfgs.RUN.train + cfgs.RUN.GAN_train + cfgs.RUN.GAN_test
    len_eval_metrics = 0 if cfgs.RUN.eval_metrics == ["none"] else len(cfgs.RUN.eval_metrics)
    load_eval_dataset = len_eval_metrics + cfgs.RUN.save_real_images + cfgs.RUN.k_nearest_neighbor + \
        cfgs.RUN.frequency_analysis + cfgs.RUN.tsne_analysis + cfgs.RUN.intra_class_fid
    train_sampler, eval_sampler = None, None
    step, epoch, topk, best_step, best_fid, best_ckpt_path, lecam_emas, is_best = \
        0, 0, cfgs.OPTIMIZATION.batch_size, 0, None, None, None, False
    mu, sigma, real_feats, eval_model, num_rows, num_cols = None, None, None, None, 10, 8
    # use the ADA augmentation probability when set; otherwise fall back to APA's.
    if cfgs.AUG.ada_initial_augment_p != "N/A":
        aa_p = cfgs.AUG.ada_initial_augment_p
    else:
        aa_p = cfgs.AUG.apa_initial_augment_p

    loss_list_dict = {"gen_loss": [], "dis_loss": [], "cls_loss": []}
    num_eval = {}
    metric_dict_during_train = {}
    if "none" in cfgs.RUN.eval_metrics:
        cfgs.RUN.eval_metrics = []
    if "is" in cfgs.RUN.eval_metrics:
        metric_dict_during_train.update({"IS": [], "Top1_acc": [], "Top5_acc": []})
    if "fid" in cfgs.RUN.eval_metrics:
        metric_dict_during_train.update({"FID": []})
    if "prdc" in cfgs.RUN.eval_metrics:
        metric_dict_during_train.update({"Improved_Precision": [], "Improved_Recall": [], "Density": [], "Coverage": []})

    # -----------------------------------------------------------------------------
    # determine cuda, cudnn, and backends settings.
    # -----------------------------------------------------------------------------
    if cfgs.RUN.fix_seed:
        cudnn.benchmark, cudnn.deterministic = False, True
    else:
        cudnn.benchmark, cudnn.deterministic = True, False

    if cfgs.MODEL.backbone in ["stylegan2", "stylegan3"]:
        # Improves training speed
        conv2d_gradfix.enabled = True
        # Avoids errors with the augmentation pipe
        grid_sample_gradfix.enabled = True
        if cfgs.RUN.mixed_precision:
            # Disable implicit tf32 for matmul to keep mixed-precision results consistent
            torch.backends.cuda.matmul.allow_tf32 = False
            # Disable implicit tf32 for convolutions for the same reason
            torch.backends.cudnn.allow_tf32 = False

    # -----------------------------------------------------------------------------
    # initialize all processes and fix the seed of each process.
    # -----------------------------------------------------------------------------
    if cfgs.RUN.distributed_data_parallel:
        global_rank = cfgs.RUN.current_node * gpus_per_node + local_rank
        print("Use GPU: {global_rank} for training.".format(global_rank=global_rank))
        misc.setup(global_rank, cfgs.OPTIMIZATION.world_size, cfgs.RUN.backend)
        torch.cuda.set_device(local_rank)
    else:
        global_rank = local_rank

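    # every process seeds differently: e.g. with 2 nodes x 4 GPUs per node,
    # global ranks run 0..7, so node 1 / GPU 2 uses cfgs.RUN.seed + 6 (= 1 * 4 + 2).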
    misc.fix_seed(cfgs.RUN.seed + global_rank)

    # -----------------------------------------------------------------------------
    # initialize python logger.
    # -----------------------------------------------------------------------------
    if local_rank == 0:
        logger = log.make_logger(cfgs.RUN.save_dir, run_name, None)
        if cfgs.RUN.ckpt_dir is not None and cfgs.RUN.freezeD == -1:
            folder_hier = cfgs.RUN.ckpt_dir.split("/")
            if folder_hier[-1] == "":
                folder_hier.pop()
            logger.info("Run name : {run_name}".format(run_name=folder_hier.pop()))
        else:
            logger.info("Run name : {run_name}".format(run_name=run_name))
        for k, v in cfgs.super_cfgs.items():
            logger.info("cfgs." + k + " =")
            logger.info(json.dumps(vars(v), indent=2))
    else:
        logger = None

    # -----------------------------------------------------------------------------
    # load train and evaluation datasets.
    # -----------------------------------------------------------------------------
    if load_train_dataset:
        if local_rank == 0:
            logger.info("Load {name} train dataset for training.".format(name=cfgs.DATA.name))
        train_dataset = Dataset_(data_name=cfgs.DATA.name,
                                 data_dir=cfgs.RUN.data_dir,
                                 train=True,
                                 crop_long_edge=cfgs.PRE.crop_long_edge,
                                 resize_size=cfgs.PRE.resize_size,
                                 resizer=None if hdf5_path is not None else cfgs.RUN.pre_resizer,
                                 random_flip=cfgs.PRE.apply_rflip,
                                 normalize=True,
                                 hdf5_path=hdf5_path,
                                 load_data_in_memory=cfgs.RUN.load_data_in_memory)
        if local_rank == 0:
            logger.info("Train dataset size: {dataset_size}".format(dataset_size=len(train_dataset)))
    else:
        train_dataset = None

    if load_eval_dataset:
        if local_rank == 0:
            logger.info("Load {name} {ref} dataset for evaluation.".format(name=cfgs.DATA.name, ref=cfgs.RUN.ref_dataset))
        eval_dataset = Dataset_(data_name=cfgs.DATA.name,
                                data_dir=cfgs.RUN.data_dir,
                                train=cfgs.RUN.ref_dataset == "train",
                                crop_long_edge=cfgs.DATA.name not in cfgs.MISC.no_proc_data,
                                resize_size=None if cfgs.DATA.name in cfgs.MISC.no_proc_data else cfgs.DATA.img_size,
                                resizer=cfgs.RUN.pre_resizer,
                                random_flip=False,
                                hdf5_path=None,
                                normalize=True,
                                load_data_in_memory=False)
        if local_rank == 0:
            logger.info("Eval dataset size: {dataset_size}".format(dataset_size=len(eval_dataset)))
    else:
        eval_dataset = None

    # -----------------------------------------------------------------------------
    # define distributed samplers for DDP training and evaluation.
    # -----------------------------------------------------------------------------
    if cfgs.RUN.distributed_data_parallel:
        cfgs.OPTIMIZATION.batch_size = cfgs.OPTIMIZATION.batch_size // cfgs.OPTIMIZATION.world_size
        if cfgs.RUN.train:
            # shard by global rank so multi-node runs do not feed duplicate samples.
            train_sampler = DistributedSampler(train_dataset,
                                               num_replicas=cfgs.OPTIMIZATION.world_size,
                                               rank=global_rank,
                                               shuffle=True,
                                               drop_last=True)
            topk = cfgs.OPTIMIZATION.batch_size

        if load_eval_dataset:
            eval_sampler = DistributedSampler(eval_dataset,
                                              num_replicas=cfgs.OPTIMIZATION.world_size,
                                              rank=global_rank,
                                              shuffle=False,
                                              drop_last=False)

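    # basket_size is the number of samples drawn per train-loader iteration: one
    # per-GPU batch for every gradient-accumulation step of every discriminator
    # update, e.g. batch_size=64, acml_steps=2, d_updates_per_step=2 -> 256.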
    cfgs.OPTIMIZATION.basket_size = cfgs.OPTIMIZATION.batch_size * \
                                    cfgs.OPTIMIZATION.acml_steps * \
                                    cfgs.OPTIMIZATION.d_updates_per_step

    # -----------------------------------------------------------------------------
    # define dataloaders for train and evaluation.
    # -----------------------------------------------------------------------------
    if load_train_dataset:
        train_dataloader = DataLoader(dataset=train_dataset,
                                      batch_size=cfgs.OPTIMIZATION.basket_size,
                                      shuffle=(train_sampler is None),
                                      pin_memory=True,
                                      num_workers=cfgs.RUN.num_workers,
                                      sampler=train_sampler,
                                      drop_last=True,
                                      persistent_workers=True)
    else:
        train_dataloader = None

    if load_eval_dataset:
        eval_dataloader = DataLoader(dataset=eval_dataset,
                                     batch_size=cfgs.OPTIMIZATION.batch_size,
                                     shuffle=False,
                                     pin_memory=True,
                                     num_workers=cfgs.RUN.num_workers,
                                     sampler=eval_sampler,
                                     drop_last=False)
    else:
        eval_dataloader = None

    # -----------------------------------------------------------------------------
    # load a generator and a discriminator.
    # if cfgs.MODEL.apply_g_ema is True, load an exponential moving average generator (Gen_ema).
    # -----------------------------------------------------------------------------
    Gen, Gen_mapping, Gen_synthesis, Dis, Gen_ema, Gen_ema_mapping, Gen_ema_synthesis, ema = \
        model.load_generator_discriminator(DATA=cfgs.DATA,
                                           OPTIMIZATION=cfgs.OPTIMIZATION,
                                           MODEL=cfgs.MODEL,
                                           STYLEGAN=cfgs.STYLEGAN,
                                           MODULES=cfgs.MODULES,
                                           RUN=cfgs.RUN,
                                           device=local_rank,
                                           logger=logger)

    # silence custom-op build logs on non-zero ranks.
    if local_rank != 0:
        custom_ops.verbosity = "none"

    # -----------------------------------------------------------------------------
    # define optimizers for adversarial training.
    # -----------------------------------------------------------------------------
    cfgs.define_optimizer(Gen, Dis)

    # -----------------------------------------------------------------------------
    # load the generator and the discriminator from a checkpoint if possible.
    # -----------------------------------------------------------------------------
    if cfgs.RUN.ckpt_dir is not None:
        if local_rank == 0:
            os.remove(join(cfgs.RUN.save_dir, "logs", run_name + ".log"))
        run_name, step, epoch, topk, aa_p, best_step, best_fid, best_ckpt_path, lecam_emas, logger = \
            ckpt.load_StudioGAN_ckpts(ckpt_dir=cfgs.RUN.ckpt_dir,
                                      load_best=cfgs.RUN.load_best,
                                      Gen=Gen,
                                      Dis=Dis,
                                      g_optimizer=cfgs.OPTIMIZATION.g_optimizer,
                                      d_optimizer=cfgs.OPTIMIZATION.d_optimizer,
                                      run_name=run_name,
                                      apply_g_ema=cfgs.MODEL.apply_g_ema,
                                      Gen_ema=Gen_ema,
                                      ema=ema,
                                      is_train=cfgs.RUN.train,
                                      RUN=cfgs.RUN,
                                      logger=logger,
                                      global_rank=global_rank,
                                      device=local_rank,
                                      cfg_file=cfgs.RUN.cfg_file)

        if topk == "initialize":
            topk = cfgs.OPTIMIZATION.batch_size
        if cfgs.MODEL.backbone in ["stylegan2", "stylegan3"]:
            ema.ema_rampup = "N/A"  # disable EMA rampup
        if cfgs.MODEL.backbone == "stylegan3" and cfgs.STYLEGAN.stylegan3_cfg == "stylegan3-r":
            cfgs.STYLEGAN.blur_init_sigma = "N/A"  # disable blur rampup
        if cfgs.AUG.apply_ada:
            cfgs.AUG.ada_kimg = 100  # make ADA react faster at the beginning

    if cfgs.RUN.ckpt_dir is None or cfgs.RUN.freezeD != -1:
        if local_rank == 0:
            cfgs.RUN.ckpt_dir = ckpt.make_ckpt_dir(join(cfgs.RUN.save_dir, "checkpoints", run_name))
    dict_dir = join(cfgs.RUN.save_dir, "statistics", run_name)
    loss_list_dict = misc.load_log_dicts(directory=dict_dir, file_name="losses.npy", ph=loss_list_dict)
    metric_dict_during_train = misc.load_log_dicts(directory=dict_dir, file_name="metrics.npy", ph=metric_dict_during_train)

    # -----------------------------------------------------------------------------
    # prepare parallel training.
    # -----------------------------------------------------------------------------
    if cfgs.OPTIMIZATION.world_size > 1:
        Gen, Gen_mapping, Gen_synthesis, Dis, Gen_ema, Gen_ema_mapping, Gen_ema_synthesis = \
            model.prepare_parallel_training(Gen=Gen,
                                            Gen_mapping=Gen_mapping,
                                            Gen_synthesis=Gen_synthesis,
                                            Dis=Dis,
                                            Gen_ema=Gen_ema,
                                            Gen_ema_mapping=Gen_ema_mapping,
                                            Gen_ema_synthesis=Gen_ema_synthesis,
                                            MODEL=cfgs.MODEL,
                                            world_size=cfgs.OPTIMIZATION.world_size,
                                            distributed_data_parallel=cfgs.RUN.distributed_data_parallel,
                                            synchronized_bn=cfgs.RUN.synchronized_bn,
                                            apply_g_ema=cfgs.MODEL.apply_g_ema,
                                            device=local_rank)

    # -----------------------------------------------------------------------------
    # load a pre-trained network (InceptionV3, SwAV, DINO, or Swin-T) for evaluation.
    # -----------------------------------------------------------------------------
    if cfgs.DATA.name in ["ImageNet", "Baby_ImageNet", "Papa_ImageNet", "Grandpa_ImageNet"]:
        num_eval = {"train": 50000, "valid": len(eval_dataloader.dataset)}
    else:
        if eval_dataloader is not None:
            num_eval[cfgs.RUN.ref_dataset] = len(eval_dataloader.dataset)
        else:
            num_eval["train"], num_eval["valid"], num_eval["test"] = 50000, 50000, 50000

    if len(cfgs.RUN.eval_metrics) or cfgs.RUN.intra_class_fid:
        eval_model = pp.LoadEvalModel(eval_backbone=cfgs.RUN.eval_backbone,
                                      post_resizer=cfgs.RUN.post_resizer,
                                      world_size=cfgs.OPTIMIZATION.world_size,
                                      distributed_data_parallel=cfgs.RUN.distributed_data_parallel,
                                      device=local_rank)

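    # FID compares Gaussian fits of real and fake features:
    #     FID = ||mu_r - mu_f||^2 + Tr(sigma_r + sigma_f - 2 * (sigma_r @ sigma_f)^(1/2))
    # so the real-side moments (mu, sigma) can be computed once here and reused at
    # every evaluation step.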
if "fid" in cfgs.RUN.eval_metrics:
308
mu, sigma = pp.prepare_moments(data_loader=eval_dataloader,
309
eval_model=eval_model,
310
quantize=True,
311
cfgs=cfgs,
312
logger=logger,
313
device=local_rank)
314
315
if "prdc" in cfgs.RUN.eval_metrics:
316
if cfgs.RUN.distributed_data_parallel:
317
prdc_sampler = DistributedSampler(eval_dataset,
318
num_replicas=cfgs.OPTIMIZATION.world_size,
319
rank=local_rank,
320
shuffle=True,
321
drop_last=False)
322
else:
323
prdc_sampler = None
324
325
prdc_dataloader = DataLoader(dataset=eval_dataset,
326
batch_size=cfgs.OPTIMIZATION.batch_size,
327
shuffle=(prdc_sampler is None),
328
pin_memory=True,
329
num_workers=cfgs.RUN.num_workers,
330
sampler=prdc_sampler,
331
drop_last=False)
332
333
real_feats = pp.prepare_real_feats(data_loader=prdc_dataloader,
334
eval_model=eval_model,
335
num_feats=num_eval[cfgs.RUN.ref_dataset],
336
quantize=True,
337
cfgs=cfgs,
338
logger=logger,
339
device=local_rank)
340
341
    if cfgs.RUN.calc_is_ref_dataset:
        pp.calculate_ins(data_loader=eval_dataloader,
                         eval_model=eval_model,
                         quantize=True,
                         splits=1,
                         cfgs=cfgs,
                         logger=logger,
                         device=local_rank)

    # -----------------------------------------------------------------------------
    # initialize WORKER for training and evaluating GAN.
    # -----------------------------------------------------------------------------
    worker = WORKER(
        cfgs=cfgs,
        run_name=run_name,
        Gen=Gen,
        Gen_mapping=Gen_mapping,
        Gen_synthesis=Gen_synthesis,
        Dis=Dis,
        Gen_ema=Gen_ema,
        Gen_ema_mapping=Gen_ema_mapping,
        Gen_ema_synthesis=Gen_ema_synthesis,
        ema=ema,
        eval_model=eval_model,
        train_dataloader=train_dataloader,
        eval_dataloader=eval_dataloader,
        global_rank=global_rank,
        local_rank=local_rank,
        mu=mu,
        sigma=sigma,
        real_feats=real_feats,
        logger=logger,
        aa_p=aa_p,
        best_step=best_step,
        best_fid=best_fid,
        best_ckpt_path=best_ckpt_path,
        lecam_emas=lecam_emas,
        num_eval=num_eval,
        loss_list_dict=loss_list_dict,
        metric_dict_during_train=metric_dict_during_train,
    )

    # -----------------------------------------------------------------------------
    # train GAN until "total_steps" generator updates.
    # -----------------------------------------------------------------------------
    if cfgs.RUN.train:
        if global_rank == 0:
            logger.info("Start training!")

        worker.training, worker.topk = True, topk
        worker.prepare_train_iter(epoch_counter=epoch)
        while step <= cfgs.OPTIMIZATION.total_steps:
            if cfgs.OPTIMIZATION.d_first:
                real_cond_loss, dis_acml_loss = worker.train_discriminator(current_step=step)
                gen_acml_loss = worker.train_generator(current_step=step)
            else:
                gen_acml_loss = worker.train_generator(current_step=step)
                real_cond_loss, dis_acml_loss = worker.train_discriminator(current_step=step)

            if global_rank == 0 and (step + 1) % cfgs.RUN.print_freq == 0:
                worker.log_train_statistics(current_step=step,
                                            real_cond_loss=real_cond_loss,
                                            gen_acml_loss=gen_acml_loss,
                                            dis_acml_loss=dis_acml_loss)
            step += 1

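            # top-k training (Sinha et al., 2020) backpropagates the generator loss
            # only through the k best-scoring fakes, decaying k once per epoch.
            # Assumed recurrence inside losses.adjust_k (see utils/losses.py):
            #     k <- max(topk_gamma * k, topk_nu * batch_size)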
            if cfgs.LOSS.apply_topk:
                if (epoch + 1) == worker.epoch_counter:
                    epoch += 1
                    worker.topk = losses.adjust_k(current_k=worker.topk,
                                                  topk_gamma=cfgs.LOSS.topk_gamma,
                                                  inf_k=int(cfgs.OPTIMIZATION.batch_size * cfgs.LOSS.topk_nu))

            if step % cfgs.RUN.save_freq == 0:
                # visualize fake images
                if global_rank == 0:
                    worker.visualize_fake_images(num_cols=num_cols, current_step=step)

                # evaluate GAN for monitoring purposes
                if len(cfgs.RUN.eval_metrics):
                    is_best = worker.evaluate(step=step, metrics=cfgs.RUN.eval_metrics, writing=True, training=True)

                # save GAN in "./checkpoints/RUN_NAME/*"
                if global_rank == 0:
                    worker.save(step=step, is_best=is_best)

                # block until all processes arrive
                if cfgs.RUN.distributed_data_parallel:
                    dist.barrier(worker.group)

        if global_rank == 0:
            logger.info("End of training!")

    # -----------------------------------------------------------------------------
    # re-evaluate the best GAN and conduct ordered analyses.
    # -----------------------------------------------------------------------------
    worker.training, worker.epoch_counter = False, epoch
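    # standing statistics (cf. BigGAN) re-estimate BatchNorm statistics with large
    # batches before sampling; the schedule here is controlled by standing_max_batch
    # and standing_step below.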
    worker.gen_ctlr.standing_statistics = cfgs.RUN.standing_statistics
    worker.gen_ctlr.standing_max_batch = cfgs.RUN.standing_max_batch
    worker.gen_ctlr.standing_step = cfgs.RUN.standing_step

    if global_rank == 0:
        best_step = ckpt.load_best_model(ckpt_dir=cfgs.RUN.ckpt_dir,
                                         Gen=Gen,
                                         Dis=Dis,
                                         apply_g_ema=cfgs.MODEL.apply_g_ema,
                                         Gen_ema=Gen_ema,
                                         ema=ema)
    if len(cfgs.RUN.eval_metrics):
        for e in range(cfgs.RUN.num_eval):
            if global_rank == 0:
                print(""), logger.info("-" * 80)
            _ = worker.evaluate(step=best_step, metrics=cfgs.RUN.eval_metrics, writing=False, training=False)

    if cfgs.RUN.save_real_images:
        if global_rank == 0:
            print(""), logger.info("-" * 80)
        worker.save_real_images()

    if cfgs.RUN.save_fake_images:
        if global_rank == 0:
            print(""), logger.info("-" * 80)
        worker.save_fake_images(num_images=cfgs.RUN.save_fake_images_num)

    if cfgs.RUN.vis_fake_images:
        if global_rank == 0:
            print(""), logger.info("-" * 80)
        worker.visualize_fake_images(num_cols=num_cols, current_step=best_step)

    if cfgs.RUN.k_nearest_neighbor:
        if global_rank == 0:
            print(""), logger.info("-" * 80)
        worker.run_k_nearest_neighbor(dataset=eval_dataset, num_rows=num_rows, num_cols=num_cols)

    if cfgs.RUN.interpolation:
        if global_rank == 0:
            print(""), logger.info("-" * 80)
        worker.run_linear_interpolation(num_rows=num_rows, num_cols=num_cols, fix_z=True, fix_y=False)
        worker.run_linear_interpolation(num_rows=num_rows, num_cols=num_cols, fix_z=False, fix_y=True)

    if cfgs.RUN.frequency_analysis:
        if global_rank == 0:
            print(""), logger.info("-" * 80)
        worker.run_frequency_analysis(dataloader=eval_dataloader)

    if cfgs.RUN.tsne_analysis:
        if global_rank == 0:
            print(""), logger.info("-" * 80)
        worker.run_tsne(dataloader=eval_dataloader)

    if cfgs.RUN.intra_class_fid:
        if global_rank == 0:
            print(""), logger.info("-" * 80)
        worker.calculate_intra_class_fid(dataset=eval_dataset)

    if cfgs.RUN.semantic_factorization:
        if global_rank == 0:
            print(""), logger.info("-" * 80)
        worker.run_semantic_factorization(num_rows=cfgs.RUN.num_semantic_axis,
                                          num_cols=num_cols,
                                          maximum_variations=cfgs.RUN.maximum_variations)

    if cfgs.RUN.GAN_train:
        if global_rank == 0:
            print(""), logger.info("-" * 80)
        worker.compute_GAN_train_or_test_classifier_accuracy_score(GAN_train=True, GAN_test=False)

    if cfgs.RUN.GAN_test:
        if global_rank == 0:
            print(""), logger.info("-" * 80)
        worker.compute_GAN_train_or_test_classifier_accuracy_score(GAN_train=False, GAN_test=True)

    if global_rank == 0:
        wandb.finish()