# PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN
# The MIT License (MIT)
# See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details

# src/config.py

from itertools import chain
import json
import os
import random
import sys
import yaml

import torch
import torch.nn as nn

import utils.misc as misc
import utils.losses as losses
import utils.ops as ops
import utils.diffaug as diffaug
import utils.cr as cr
import utils.simclr_aug as simclr_aug
import utils.ada_aug as ada_aug


class make_empty_object(object):
    pass


class Configurations(object):
    def __init__(self, cfg_file):
        self.cfg_file = cfg_file
        self.load_base_cfgs()
        self._overwrite_cfgs(self.cfg_file)
        self.define_modules()

    def load_base_cfgs(self):
        # -----------------------------------------------------------------------------
        # Data settings
        # -----------------------------------------------------------------------------
        self.DATA = misc.make_empty_object()

        # dataset name \in ["CIFAR10", "CIFAR100", "Tiny_ImageNet", "CUB200", "ImageNet", "MY_DATASET"]
        self.DATA.name = "CIFAR10"
        # image size for training
        self.DATA.img_size = 32
        # number of classes in the training dataset; if there is no explicit class label, DATA.num_classes = 1
        self.DATA.num_classes = 10
        # number of image channels in the dataset (image_shape[0])
        self.DATA.img_channels = 3

        # -----------------------------------------------------------------------------
        # Model settings
        # -----------------------------------------------------------------------------
        self.MODEL = misc.make_empty_object()

        # type of backbone architecture for the generator and discriminator \in
        # ["deep_conv", "resnet", "big_resnet", "big_resnet_deep_legacy", "big_resnet_deep_studiogan", "stylegan2", "stylegan3"]
        self.MODEL.backbone = "resnet"
        # conditioning method of the generator \in ["W/O", "cBN", "cAdaIN"]
        self.MODEL.g_cond_mtd = "W/O"
        # conditioning method of the discriminator \in ["W/O", "AC", "PD", "MH", "MD", "2C", "D2DCE", "SPD"]
        self.MODEL.d_cond_mtd = "W/O"
        # type of auxiliary classifier \in ["W/O", "TAC", "ADC"]
        self.MODEL.aux_cls_type = "W/O"
        # whether to normalize feature maps from the discriminator or not
        self.MODEL.normalize_d_embed = False
        # dimension of feature maps from the discriminator
        # only applicable when MODEL.d_cond_mtd \in ["2C", "D2DCE"]
        self.MODEL.d_embed_dim = "N/A"
        # whether to apply spectral normalization on the generator
        self.MODEL.apply_g_sn = False
        # whether to apply spectral normalization on the discriminator
        self.MODEL.apply_d_sn = False
        # type of activation function in the generator \in ["ReLU", "Leaky_ReLU", "ELU", "GELU"]
        self.MODEL.g_act_fn = "ReLU"
        # type of activation function in the discriminator \in ["ReLU", "Leaky_ReLU", "ELU", "GELU"]
        self.MODEL.d_act_fn = "ReLU"
        # whether to apply the self-attention layer proposed by Zhang et al. (SAGAN)
        self.MODEL.apply_attn = False
        # locations of the self-attention layer in the generator (should be a list)
        self.MODEL.attn_g_loc = ["N/A"]
        # locations of the self-attention layer in the discriminator (should be a list)
        self.MODEL.attn_d_loc = ["N/A"]
        # prior distribution for noise sampling \in ["gaussian", "uniform"]
        self.MODEL.z_prior = "gaussian"
        # dimension of noise vectors
        self.MODEL.z_dim = 128
        # dimension of the intermediate latent space (W); used only for StyleGAN
        self.MODEL.w_dim = "N/A"
        # dimension of a shared latent embedding
        self.MODEL.g_shared_dim = "N/A"
        # base channel for the resnet style generator architecture
        self.MODEL.g_conv_dim = 64
        # base channel for the resnet style discriminator architecture
        self.MODEL.d_conv_dim = 64
        # generator's depth for "models/big_resnet_deep_*.py"
        self.MODEL.g_depth = "N/A"
        # discriminator's depth for "models/big_resnet_deep_*.py"
        self.MODEL.d_depth = "N/A"
        # whether to apply an exponential moving average (EMA) update for the generator
        self.MODEL.apply_g_ema = False
        # decay rate for the EMA generator
        self.MODEL.g_ema_decay = "N/A"
        # starting step for the g_ema update
        self.MODEL.g_ema_start = "N/A"
        # weight initialization method for the generator \in ["ortho", "N02", "glorot", "xavier"]
        self.MODEL.g_init = "ortho"
        # weight initialization method for the discriminator \in ["ortho", "N02", "glorot", "xavier"]
        self.MODEL.d_init = "ortho"
        # type of information for InfoGAN training \in ["N/A", "discrete", "continuous", "both"]
        self.MODEL.info_type = "N/A"
        # way to inject information into the generator \in ["N/A", "concat", "cBN"]
        self.MODEL.g_info_injection = "N/A"
        # number of discrete c to use in InfoGAN
        self.MODEL.info_num_discrete_c = "N/A"
        # number of continuous c to use in InfoGAN
        self.MODEL.info_num_conti_c = "N/A"
        # dimension of discrete c to use in InfoGAN (one-hot)
        self.MODEL.info_dim_discrete_c = "N/A"

        # -----------------------------------------------------------------------------
        # Loss settings
        # -----------------------------------------------------------------------------
        self.LOSS = misc.make_empty_object()

        # type of adversarial loss \in ["vanilla", "logistic", "least_square", "wasserstein", "hinge", "MH"]
        self.LOSS.adv_loss = "vanilla"
        # balancing hyperparameter for conditional image generation
        self.LOSS.cond_lambda = "N/A"
        # strength of the conditioning loss induced by the twin auxiliary classifier for generator training
        self.LOSS.tac_gen_lambda = "N/A"
        # strength of the conditioning loss induced by the twin auxiliary classifier for discriminator training
        self.LOSS.tac_dis_lambda = "N/A"
        # strength of the multi-hinge loss (MH) for generator training
        self.LOSS.mh_lambda = "N/A"
        # whether to apply feature matching regularization
        self.LOSS.apply_fm = False
        # strength of feature matching regularization
        self.LOSS.fm_lambda = "N/A"
        # whether to apply r1 regularization used in multiple-discriminator training (FUNIT)
        self.LOSS.apply_r1_reg = False
        # where to apply the R1 regularization \in ["N/A", "inside_loop", "outside_loop"]
        self.LOSS.r1_place = "N/A"
        # strength of r1 regularization (does not apply to the r1_reg used in StyleGAN2)
        self.LOSS.r1_lambda = "N/A"
        # positive margin for D2DCE
        self.LOSS.m_p = "N/A"
        # temperature scalar for [2C, D2DCE]
        self.LOSS.temperature = "N/A"
        # whether to apply weight clipping regularization to let the discriminator satisfy Lipschitzness
        self.LOSS.apply_wc = False
        # clipping bound for weight clipping regularization
        self.LOSS.wc_bound = "N/A"
        # whether to apply gradient penalty regularization
        self.LOSS.apply_gp = False
        # strength of the gradient penalty regularization
        self.LOSS.gp_lambda = "N/A"
        # whether to apply deep regret analysis regularization
        self.LOSS.apply_dra = False
        # strength of the deep regret analysis regularization
        self.LOSS.dra_lambda = "N/A"
        # whether to apply max gradient penalty to let the discriminator satisfy Lipschitzness
        self.LOSS.apply_maxgp = False
        # strength of the maxgp regularization
        self.LOSS.maxgp_lambda = "N/A"
        # whether to apply consistency regularization
        self.LOSS.apply_cr = False
        # strength of the consistency regularization
        self.LOSS.cr_lambda = "N/A"
        # whether to apply balanced consistency regularization
        self.LOSS.apply_bcr = False
        # attraction strength between logits of real and augmented real samples
        self.LOSS.real_lambda = "N/A"
        # attraction strength between logits of fake and augmented fake samples
        self.LOSS.fake_lambda = "N/A"
        # whether to apply latent consistency regularization
        self.LOSS.apply_zcr = False
        # radius of the ball used to generate a perturbed fake image G(z + radius)
        self.LOSS.radius = "N/A"
        # repulsion strength between fake images (G(z), G(z + radius))
        self.LOSS.g_lambda = "N/A"
        # attraction strength between logits of fake images (G(z), G(z + radius))
        self.LOSS.d_lambda = "N/A"
        # whether to apply latent optimization for stable training
        self.LOSS.apply_lo = False
        # latent step size for latent optimization
        self.LOSS.lo_alpha = "N/A"
        # damping factor for calculating the Fisher information matrix
        self.LOSS.lo_beta = "N/A"
        # portion of z for latent optimization (c)
        self.LOSS.lo_rate = "N/A"
        # strength of latent optimization (w_{r})
        self.LOSS.lo_lambda = "N/A"
        # number of latent optimization iterations for a single sample during training
        self.LOSS.lo_steps4train = "N/A"
        # number of latent optimization iterations for a single sample during evaluation
        self.LOSS.lo_steps4eval = "N/A"
        # whether to apply topk training for the generator update
        self.LOSS.apply_topk = False
        # hyperparameter for the batch_size decay rate in topk training \in [0, 1]
        self.LOSS.topk_gamma = "N/A"
        # hyperparameter for the infimum of the number of topk samples \in [0, 1],
        # inf_batch_size = int(topk_nu*batch_size)
        self.LOSS.topk_nu = "N/A"
        # strength lambda for the InfoGAN loss in case of discrete c (typically 0.1)
        self.LOSS.infoGAN_loss_discrete_lambda = "N/A"
        # strength lambda for the InfoGAN loss in case of continuous c (typically 1)
        self.LOSS.infoGAN_loss_conti_lambda = "N/A"
        # whether to apply LeCam regularization or not
        self.LOSS.apply_lecam = False
        # strength of the LeCam regularization
        self.LOSS.lecam_lambda = "N/A"
        # start iteration for the EMALosses in src/utils/EMALosses
        self.LOSS.lecam_ema_start_iter = "N/A"
        # decay rate for the EMALosses
        self.LOSS.lecam_ema_decay = "N/A"

        # -----------------------------------------------------------------------------
        # Optimizer settings
        # -----------------------------------------------------------------------------
        self.OPTIMIZATION = misc.make_empty_object()

        # type of the optimizer for GAN training \in ["SGD", "RMSprop", "Adam"]
        self.OPTIMIZATION.type_ = "Adam"
        # batch size for GAN training,
        # typically {CIFAR10: 64, CIFAR100: 64, Tiny_ImageNet: 1024, CUB200: 256, ImageNet: 512 (batch_size) * 4 (accm_step)}
        self.OPTIMIZATION.batch_size = 64
        # accumulation steps for large batch training (effective batch_size = batch_size * accm_step)
        self.OPTIMIZATION.acml_steps = 1
        # learning rate for the generator update
        self.OPTIMIZATION.g_lr = 0.0002
        # learning rate for the discriminator update
        self.OPTIMIZATION.d_lr = 0.0002
        # weight decay strength for the generator update
        self.OPTIMIZATION.g_weight_decay = 0.0
        # weight decay strength for the discriminator update
        self.OPTIMIZATION.d_weight_decay = 0.0
        # momentum value for SGD and RMSprop optimizers
        self.OPTIMIZATION.momentum = "N/A"
        # nesterov value for SGD optimizer
        self.OPTIMIZATION.nesterov = "N/A"
        # alpha value for RMSprop optimizer
        self.OPTIMIZATION.alpha = "N/A"
        # beta values for Adam optimizer
        self.OPTIMIZATION.beta1 = 0.5
        self.OPTIMIZATION.beta2 = 0.999
        # whether to optimize the discriminator first,
        # if True: optimize D -> optimize G
        self.OPTIMIZATION.d_first = True
        # the number of generator updates per step
        self.OPTIMIZATION.g_updates_per_step = 1
        # the number of discriminator updates per step
        self.OPTIMIZATION.d_updates_per_step = 5
        # the total number of steps for GAN training
        self.OPTIMIZATION.total_steps = 100000

        # -----------------------------------------------------------------------------
        # Preprocessing settings
        # -----------------------------------------------------------------------------
        self.PRE = misc.make_empty_object()

        # whether to apply random flip preprocessing before training
        self.PRE.apply_rflip = True

        # -----------------------------------------------------------------------------
        # Differentiable augmentation settings
        # -----------------------------------------------------------------------------
        self.AUG = misc.make_empty_object()

        # whether to apply differentiable augmentations for limited data training
        self.AUG.apply_diffaug = False

        # whether to apply adaptive discriminator augmentation (ADA)
        self.AUG.apply_ada = False
        # initial value of the augmentation probability
        self.AUG.ada_initial_augment_p = "N/A"
        # target probability for adaptive differentiable augmentations, None = fixed p (keep ada_initial_augment_p)
        self.AUG.ada_target = "N/A"
        # ADA adjustment speed, measured in how many kimg it takes for p to increase/decrease by one unit
        self.AUG.ada_kimg = "N/A"
        # how often to perform ADA adjustment
        self.AUG.ada_interval = "N/A"
        # whether to apply adaptive pseudo augmentation (APA)
        self.AUG.apply_apa = False
        # initial value of the augmentation probability
        self.AUG.apa_initial_augment_p = "N/A"
        # target probability for adaptive pseudo augmentations, None = fixed p (keep apa_initial_augment_p)
        self.AUG.apa_target = "N/A"
        # APA adjustment speed, measured in how many kimg it takes for p to increase/decrease by one unit
        self.AUG.apa_kimg = "N/A"
        # how often to perform APA adjustment
        self.AUG.apa_interval = "N/A"
        # type of differentiable augmentation for cr, bcr, or limited data training
        # \in ["W/O", "cr", "bcr", "diffaug", "simclr_basic", "simclr_hq", "simclr_hq_cutout", "byol",
        # "blit", "geom", "color", "filter", "noise", "cutout", "bg", "bgc", "bgcf", "bgcfn", "bgcfnc"]
        # cr (bcr, diffaug, ada, simclr, byol) indicates the differentiable augmentations used in the original paper
        self.AUG.cr_aug_type = "W/O"
        self.AUG.bcr_aug_type = "W/O"
        self.AUG.diffaug_type = "W/O"
        self.AUG.ada_aug_type = "W/O"

        # -----------------------------------------------------------------------------
        # StyleGAN settings
        # -----------------------------------------------------------------------------
        self.STYLEGAN = misc.make_empty_object()

        # type of generator used in stylegan3; stylegan3-t: translation equiv., stylegan3-r: translation & rotation equiv.
        # \in ["stylegan3-t", "stylegan3-r"]
        self.STYLEGAN.stylegan3_cfg = "N/A"
        # conditioning types that utilize embedding proxies for conditional stylegan2, stylegan3
        self.STYLEGAN.cond_type = ["PD", "SPD", "2C", "D2DCE"]
        # lazy regularization interval for the generator, default 4
        self.STYLEGAN.g_reg_interval = "N/A"
        # lazy regularization interval for the discriminator, default 16
        self.STYLEGAN.d_reg_interval = "N/A"
        # number of layers for the mapping network, default 8 except for cifar (2)
        self.STYLEGAN.mapping_network = "N/A"
        # style_mixing_p in the stylegan generator, default 0.9 except for cifar (0)
        self.STYLEGAN.style_mixing_p = "N/A"
        # half-life of the exponential moving average (EMA) of generator weights, default 500
        self.STYLEGAN.g_ema_kimg = "N/A"
        # EMA ramp-up coefficient, default "N/A" except for cifar (0.05)
        self.STYLEGAN.g_ema_rampup = "N/A"
        # whether to apply path length regularization, default True except for cifar
        self.STYLEGAN.apply_pl_reg = False
        # path length regularization strength, default 2
        self.STYLEGAN.pl_weight = "N/A"
        # discriminator architecture for StyleGAN, 'resnet' except for cifar10 ('orig')
        self.STYLEGAN.d_architecture = "N/A"
        # group size for the minibatch standard deviation layer, None = entire minibatch
        self.STYLEGAN.d_epilogue_mbstd_group_size = "N/A"
        # whether to blur the images seen by the discriminator; only used for stylegan3-r with value 10
        self.STYLEGAN.blur_init_sigma = "N/A"

        # -----------------------------------------------------------------------------
        # Run settings
        # -----------------------------------------------------------------------------
        self.RUN = misc.make_empty_object()

        # -----------------------------------------------------------------------------
        # Misc settings
        # -----------------------------------------------------------------------------
        self.MISC = misc.make_empty_object()

        self.MISC.no_proc_data = ["CIFAR10", "CIFAR100", "Tiny_ImageNet"]
        self.MISC.base_folders = ["checkpoints", "figures", "logs", "moments", "samples", "values"]
        self.MISC.classifier_based_GAN = ["AC", "2C", "D2DCE"]
        self.MISC.info_params = ["info_discrete_linear", "info_conti_mu_linear", "info_conti_var_linear"]
        self.MISC.cas_setting = {
            "CIFAR10": {
                "batch_size": 128,
                "epochs": 90,
                "depth": 32,
                "lr": 0.1,
                "momentum": 0.9,
                "weight_decay": 1e-4,
                "print_freq": 1,
                "bottleneck": True
            },
            "Tiny_ImageNet": {
                "batch_size": 128,
                "epochs": 90,
                "depth": 34,
                "lr": 0.1,
                "momentum": 0.9,
                "weight_decay": 1e-4,
                "print_freq": 1,
                "bottleneck": True
            },
            "ImageNet": {
                "batch_size": 128,
                "epochs": 90,
                "depth": 34,
                "lr": 0.1,
                "momentum": 0.9,
                "weight_decay": 1e-4,
                "print_freq": 1,
                "bottleneck": True
            },
        }
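
        # Note: cas_setting holds the training hyperparameters of the classifiers used for the
        # Classification Accuracy Score (CAS; see the RUN.GAN_train / RUN.GAN_test checks in
        # check_compatability below). The classifier training itself lives outside this file.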

        # -----------------------------------------------------------------------------
        # Module settings
        # -----------------------------------------------------------------------------
        self.MODULES = misc.make_empty_object()

        self.super_cfgs = {
            "DATA": self.DATA,
            "MODEL": self.MODEL,
            "LOSS": self.LOSS,
            "OPTIMIZATION": self.OPTIMIZATION,
            "PRE": self.PRE,
            "AUG": self.AUG,
            "RUN": self.RUN,
            "STYLEGAN": self.STYLEGAN
        }

    def update_cfgs(self, cfgs, super="RUN"):
        for attr, value in cfgs.items():
            setattr(self.super_cfgs[super], attr, value)
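
    # Example (illustrative only; the values are hypothetical):
    #   cfgs.update_cfgs({"mixed_precision": True, "train": True}, super="RUN")
    # Note that update_cfgs does not validate attribute names; only _overwrite_cfgs below
    # checks them against the defaults declared in load_base_cfgs.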

    def _overwrite_cfgs(self, cfg_file):
        with open(cfg_file, 'r') as f:
            yaml_cfg = yaml.load(f, Loader=yaml.FullLoader)
            for super_cfg_name, attr_value in yaml_cfg.items():
                for attr, value in attr_value.items():
                    if hasattr(self.super_cfgs[super_cfg_name], attr):
                        setattr(self.super_cfgs[super_cfg_name], attr, value)
                    else:
                        raise AttributeError("There is no '{cls}.{attr}' attribute defined in config.py.".format(
                            cls=super_cfg_name, attr=attr))
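
    # A minimal, illustrative YAML override consumed by _overwrite_cfgs (hypothetical values;
    # the top-level keys must be the namespace names registered in self.super_cfgs and every
    # attribute must already exist in load_base_cfgs, otherwise an AttributeError is raised):
    #
    #   DATA:
    #     name: "CIFAR10"
    #     img_size: 32
    #     num_classes: 10
    #   MODEL:
    #     backbone: "big_resnet"
    #     g_cond_mtd: "cBN"
    #     d_cond_mtd: "PD"
    #     apply_g_ema: True
    #     g_ema_decay: 0.9999
    #     g_ema_start: 1000
    #   LOSS:
    #     adv_loss: "hinge"
    #   OPTIMIZATION:
    #     batch_size: 64
    #     total_steps: 100000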

    def define_losses(self):
        if self.MODEL.d_cond_mtd == "MH" and self.LOSS.adv_loss == "MH":
            self.LOSS.g_loss = losses.crammer_singer_loss
            self.LOSS.d_loss = losses.crammer_singer_loss
        else:
            g_losses = {
                "vanilla": losses.g_vanilla,
                "logistic": losses.g_logistic,
                "least_square": losses.g_ls,
                "hinge": losses.g_hinge,
                "wasserstein": losses.g_wasserstein,
            }

            d_losses = {
                "vanilla": losses.d_vanilla,
                "logistic": losses.d_logistic,
                "least_square": losses.d_ls,
                "hinge": losses.d_hinge,
                "wasserstein": losses.d_wasserstein,
            }

            self.LOSS.g_loss = g_losses[self.LOSS.adv_loss]
            self.LOSS.d_loss = d_losses[self.LOSS.adv_loss]
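
    # For reference, the objective selected by LOSS.adv_loss = "hinge" (the standard GAN hinge
    # loss; a sketch only, the actual implementations live in utils/losses.py) is
    #   L_D = E_x[max(0, 1 - D(x))] + E_z[max(0, 1 + D(G(z)))]
    #   L_G = -E_z[D(G(z))]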

    def define_modules(self):
        if self.MODEL.apply_g_sn:
            self.MODULES.g_conv2d = ops.snconv2d
            self.MODULES.g_deconv2d = ops.sndeconv2d
            self.MODULES.g_linear = ops.snlinear
            self.MODULES.g_embedding = ops.sn_embedding
        else:
            self.MODULES.g_conv2d = ops.conv2d
            self.MODULES.g_deconv2d = ops.deconv2d
            self.MODULES.g_linear = ops.linear
            self.MODULES.g_embedding = ops.embedding

        if self.MODEL.apply_d_sn:
            self.MODULES.d_conv2d = ops.snconv2d
            self.MODULES.d_deconv2d = ops.sndeconv2d
            self.MODULES.d_linear = ops.snlinear
            self.MODULES.d_embedding = ops.sn_embedding
        else:
            self.MODULES.d_conv2d = ops.conv2d
            self.MODULES.d_deconv2d = ops.deconv2d
            self.MODULES.d_linear = ops.linear
            self.MODULES.d_embedding = ops.embedding

        if self.MODEL.g_cond_mtd == "cBN" or self.MODEL.g_info_injection == "cBN" or self.MODEL.backbone == "big_resnet":
            self.MODULES.g_bn = ops.ConditionalBatchNorm2d
        elif self.MODEL.g_cond_mtd == "W/O":
            self.MODULES.g_bn = ops.batchnorm_2d
        elif self.MODEL.g_cond_mtd == "cAdaIN":
            pass
        else:
            raise NotImplementedError

        if not self.MODEL.apply_d_sn:
            self.MODULES.d_bn = ops.batchnorm_2d

        if self.MODEL.g_act_fn == "ReLU":
            self.MODULES.g_act_fn = nn.ReLU(inplace=True)
        elif self.MODEL.g_act_fn == "Leaky_ReLU":
            self.MODULES.g_act_fn = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        elif self.MODEL.g_act_fn == "ELU":
            self.MODULES.g_act_fn = nn.ELU(alpha=1.0, inplace=True)
        elif self.MODEL.g_act_fn == "GELU":
            self.MODULES.g_act_fn = nn.GELU()
        elif self.MODEL.g_act_fn == "Auto":
            pass
        else:
            raise NotImplementedError

        if self.MODEL.d_act_fn == "ReLU":
            self.MODULES.d_act_fn = nn.ReLU(inplace=True)
        elif self.MODEL.d_act_fn == "Leaky_ReLU":
            self.MODULES.d_act_fn = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        elif self.MODEL.d_act_fn == "ELU":
            self.MODULES.d_act_fn = nn.ELU(alpha=1.0, inplace=True)
        elif self.MODEL.d_act_fn == "GELU":
            self.MODULES.d_act_fn = nn.GELU()
        elif self.MODEL.d_act_fn == "Auto":
            pass
        else:
            raise NotImplementedError
        return self.MODULES

    def define_optimizer(self, Gen, Dis):
        Gen_params, Dis_params = [], []
        for g_name, g_param in Gen.named_parameters():
            Gen_params.append(g_param)
        if self.MODEL.info_type in ["discrete", "both"]:
            for info_name, info_param in Dis.info_discrete_linear.named_parameters():
                Gen_params.append(info_param)
        if self.MODEL.info_type in ["continuous", "both"]:
            for info_name, info_param in Dis.info_conti_mu_linear.named_parameters():
                Gen_params.append(info_param)
            for info_name, info_param in Dis.info_conti_var_linear.named_parameters():
                Gen_params.append(info_param)

        for d_name, d_param in Dis.named_parameters():
            if self.MODEL.info_type in ["discrete", "continuous", "both"]:
                if "info_discrete" in d_name or "info_conti" in d_name:
                    pass
                else:
                    Dis_params.append(d_param)
            else:
                Dis_params.append(d_param)

        if self.OPTIMIZATION.type_ == "SGD":
            self.OPTIMIZATION.g_optimizer = torch.optim.SGD(params=Gen_params,
                                                            lr=self.OPTIMIZATION.g_lr,
                                                            weight_decay=self.OPTIMIZATION.g_weight_decay,
                                                            momentum=self.OPTIMIZATION.momentum,
                                                            nesterov=self.OPTIMIZATION.nesterov)
            self.OPTIMIZATION.d_optimizer = torch.optim.SGD(params=Dis_params,
                                                            lr=self.OPTIMIZATION.d_lr,
                                                            weight_decay=self.OPTIMIZATION.d_weight_decay,
                                                            momentum=self.OPTIMIZATION.momentum,
                                                            nesterov=self.OPTIMIZATION.nesterov)
        elif self.OPTIMIZATION.type_ == "RMSprop":
            self.OPTIMIZATION.g_optimizer = torch.optim.RMSprop(params=Gen_params,
                                                                lr=self.OPTIMIZATION.g_lr,
                                                                weight_decay=self.OPTIMIZATION.g_weight_decay,
                                                                momentum=self.OPTIMIZATION.momentum,
                                                                alpha=self.OPTIMIZATION.alpha)
            self.OPTIMIZATION.d_optimizer = torch.optim.RMSprop(params=Dis_params,
                                                                lr=self.OPTIMIZATION.d_lr,
                                                                weight_decay=self.OPTIMIZATION.d_weight_decay,
                                                                momentum=self.OPTIMIZATION.momentum,
                                                                alpha=self.OPTIMIZATION.alpha)
        elif self.OPTIMIZATION.type_ == "Adam":
            if self.MODEL.backbone in ["stylegan2", "stylegan3"]:
                g_ratio = (self.STYLEGAN.g_reg_interval / (self.STYLEGAN.g_reg_interval + 1)) if self.STYLEGAN.g_reg_interval != 1 else 1
                d_ratio = (self.STYLEGAN.d_reg_interval / (self.STYLEGAN.d_reg_interval + 1)) if self.STYLEGAN.d_reg_interval != 1 else 1
                self.OPTIMIZATION.g_lr *= g_ratio
                self.OPTIMIZATION.d_lr *= d_ratio
                betas_g = [self.OPTIMIZATION.beta1**g_ratio, self.OPTIMIZATION.beta2**g_ratio]
                betas_d = [self.OPTIMIZATION.beta1**d_ratio, self.OPTIMIZATION.beta2**d_ratio]
                eps_ = 1e-8
            else:
                betas_g = betas_d = [self.OPTIMIZATION.beta1, self.OPTIMIZATION.beta2]
                eps_ = 1e-6
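
            # Note on the branch above: with lazy regularization (as in the official StyleGAN2
            # implementation), the regularization term is only added every *_reg_interval steps,
            # so the learning rate and Adam betas are rescaled by ratio = interval / (interval + 1)
            # (e.g. g_reg_interval = 4 gives g_ratio = 0.8, hence lr * 0.8 and beta ** 0.8) to keep
            # the effective optimization dynamics roughly unchanged.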
            self.OPTIMIZATION.g_optimizer = torch.optim.Adam(params=Gen_params,
                                                             lr=self.OPTIMIZATION.g_lr,
                                                             betas=betas_g,
                                                             weight_decay=self.OPTIMIZATION.g_weight_decay,
                                                             eps=eps_)
            self.OPTIMIZATION.d_optimizer = torch.optim.Adam(params=Dis_params,
                                                             lr=self.OPTIMIZATION.d_lr,
                                                             betas=betas_d,
                                                             weight_decay=self.OPTIMIZATION.d_weight_decay,
                                                             eps=eps_)
        else:
            raise NotImplementedError

    def define_augments(self, device):
        self.AUG.series_augment = misc.identity
        ada_augpipe = {
            'blit': dict(xflip=1, rotate90=1, xint=1),
            'geom': dict(scale=1, rotate=1, aniso=1, xfrac=1),
            'color': dict(brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1),
            'filter': dict(imgfilter=1),
            'noise': dict(noise=1),
            'cutout': dict(cutout=1),
            'bg': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1),
            'bgc': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1),
            'bgcf': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1),
            'bgcfn': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1, noise=1),
            'bgcfnc': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1, noise=1, cutout=1),
        }
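        # Each pipeline above composes the single-transform groups: "bg" = blit + geom,
        # "bgc" adds color, "bgcf" adds filter, "bgcfn" adds noise, and "bgcfnc" additionally
        # enables cutout.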
        if self.AUG.apply_diffaug:
            assert self.AUG.diffaug_type != "W/O", "Please select a differentiable augmentation type!"
            if self.AUG.diffaug_type == "cr":
                self.AUG.series_augment = cr.apply_cr_aug
            elif self.AUG.diffaug_type == "diffaug":
                self.AUG.series_augment = diffaug.apply_diffaug
            elif self.AUG.diffaug_type in ["simclr_basic", "simclr_hq", "simclr_hq_cutout", "byol"]:
                self.AUG.series_augment = simclr_aug.SimclrAugment(aug_type=self.AUG.diffaug_type).train().to(device).requires_grad_(False)
            elif self.AUG.diffaug_type in ["blit", "geom", "color", "filter", "noise", "cutout", "bg", "bgc", "bgcf", "bgcfn", "bgcfnc"]:
                self.AUG.series_augment = ada_aug.AdaAugment(**ada_augpipe[self.AUG.diffaug_type]).train().to(device).requires_grad_(False)
                self.AUG.series_augment.p = 1.0
            else:
                raise NotImplementedError

        if self.AUG.apply_ada:
            assert self.AUG.ada_aug_type in ["blit", "geom", "color", "filter", "noise", "cutout", "bg", "bgc", "bgcf", "bgcfn",
                                             "bgcfnc"], "Please select ADA-supported augmentations."
            self.AUG.series_augment = ada_aug.AdaAugment(**ada_augpipe[self.AUG.ada_aug_type]).train().to(device).requires_grad_(False)

        if self.LOSS.apply_cr:
            assert self.AUG.cr_aug_type != "W/O", "Please select an augmentation type for cr!"
            if self.AUG.cr_aug_type == "cr":
                self.AUG.parallel_augment = cr.apply_cr_aug
            elif self.AUG.cr_aug_type == "diffaug":
                self.AUG.parallel_augment = diffaug.apply_diffaug
            elif self.AUG.cr_aug_type in ["simclr_basic", "simclr_hq", "simclr_hq_cutout", "byol"]:
                self.AUG.parallel_augment = simclr_aug.SimclrAugment(aug_type=self.AUG.cr_aug_type).train().to(device).requires_grad_(False)
            elif self.AUG.cr_aug_type in ["blit", "geom", "color", "filter", "noise", "cutout", "bg", "bgc", "bgcf", "bgcfn", "bgcfnc"]:
                self.AUG.parallel_augment = ada_aug.AdaAugment(**ada_augpipe[self.AUG.cr_aug_type]).train().to(device).requires_grad_(False)
                self.AUG.parallel_augment.p = 1.0
            else:
                raise NotImplementedError

        if self.LOSS.apply_bcr:
            assert self.AUG.bcr_aug_type != "W/O", "Please select an augmentation type for bcr!"
            if self.AUG.bcr_aug_type == "bcr":
                self.AUG.parallel_augment = cr.apply_cr_aug
            elif self.AUG.bcr_aug_type == "diffaug":
                self.AUG.parallel_augment = diffaug.apply_diffaug
            elif self.AUG.bcr_aug_type in ["simclr_basic", "simclr_hq", "simclr_hq_cutout", "byol"]:
                self.AUG.parallel_augment = simclr_aug.SimclrAugment(aug_type=self.AUG.bcr_aug_type).train().to(device).requires_grad_(False)
            elif self.AUG.bcr_aug_type in ["blit", "geom", "color", "filter", "noise", "cutout", "bg", "bgc", "bgcf", "bgcfn", "bgcfnc"]:
                self.AUG.parallel_augment = ada_aug.AdaAugment(
                    **ada_augpipe[self.AUG.bcr_aug_type]).train().to(device).requires_grad_(False)
                self.AUG.parallel_augment.p = 1.0
            else:
                raise NotImplementedError

    def check_compatability(self):
        if self.RUN.distributed_data_parallel and self.RUN.mixed_precision:
            print("-"*120)
            print("Please use standing statistics (-std_stat) with -std_max and -std_step options for reliable evaluation!")
            print("-"*120)

        if len(self.RUN.eval_metrics):
            for item in self.RUN.eval_metrics:
                assert item in ["is", "fid", "prdc", "none"], "-metrics option can only contain is, fid, prdc, or none for skipping evaluation."

        if self.RUN.load_data_in_memory:
            assert self.RUN.load_train_hdf5, "load_data_in_memory option is applicable only with the load_train_hdf5 (-hdf5) option."

        if self.MODEL.backbone == "deep_conv":
            assert self.DATA.img_size == 32, "StudioGAN does not support the deep_conv backbone for datasets whose spatial resolution is not 32."

        if self.MODEL.backbone in ["big_resnet_deep_legacy", "big_resnet_deep_studiogan"]:
            msg = "StudioGAN does not support the big_resnet_deep backbones without applying spectral normalization to the generator and discriminator."
            assert self.MODEL.apply_g_sn and self.MODEL.apply_d_sn, msg

        if self.RUN.langevin_sampling or self.LOSS.apply_lo:
            assert self.RUN.langevin_sampling * self.LOSS.apply_lo == 0, "Langevin sampling and latent optimization cannot be used simultaneously."

        if isinstance(self.MODEL.g_depth, int) or isinstance(self.MODEL.d_depth, int):
            assert self.MODEL.backbone in ["big_resnet_deep_legacy", "big_resnet_deep_studiogan"], \
                "MODEL.g_depth and MODEL.d_depth are hyperparameters for the big_resnet_deep backbones."

        if self.RUN.langevin_sampling:
            msg = "Langevin sampling requires at least one evaluation or analysis option; it cannot be used for training only."
            assert self.RUN.vis_fake_images + \
                self.RUN.k_nearest_neighbor + \
                self.RUN.interpolation + \
                self.RUN.frequency_analysis + \
                self.RUN.tsne_analysis + \
                self.RUN.intra_class_fid + \
                self.RUN.semantic_factorization + \
                self.RUN.GAN_train + \
                self.RUN.GAN_test != 0, \
                msg

        if self.RUN.langevin_sampling:
            assert self.MODEL.z_prior == "gaussian", "Langevin sampling is defined only if z_prior is gaussian."

        if self.RUN.freezeD > -1:
            msg = "Freezing the discriminator needs a pre-trained model. Please specify the checkpoint directory (using -ckpt) for loading a pre-trained discriminator."
            assert self.RUN.ckpt_dir is not None, msg

        if not self.RUN.train and self.RUN.eval_metrics != "none":
            assert self.RUN.ckpt_dir is not None, "Specify -ckpt CHECKPOINT_FOLDER to evaluate a GAN without training."

        if self.RUN.GAN_train + self.RUN.GAN_test > 1:
            msg = "Please turn off the -DDP option to calculate CAS. It is possible to train a GAN using the DDP option and then compute CAS using DP."
            assert not self.RUN.distributed_data_parallel, msg

        if self.RUN.distributed_data_parallel:
            msg = "StudioGAN does not support image visualization, k_nearest_neighbor, interpolation, frequency, tsne analysis, DDLS, SeFa, and CAS with DDP. " + \
                "Please change DDP to single-GPU training or DataParallel instead."
            assert self.RUN.vis_fake_images + \
                self.RUN.k_nearest_neighbor + \
                self.RUN.interpolation + \
                self.RUN.frequency_analysis + \
                self.RUN.tsne_analysis + \
                self.RUN.semantic_factorization + \
                self.RUN.langevin_sampling + \
                self.RUN.GAN_train + \
                self.RUN.GAN_test == 0, \
                msg

        if self.RUN.intra_class_fid:
            assert self.RUN.load_data_in_memory*self.RUN.load_train_hdf5 or not self.RUN.load_train_hdf5, \
                "StudioGAN does not support calculating iFID using the hdf5 data format without the load_data_in_memory option."

        if self.RUN.vis_fake_images + self.RUN.k_nearest_neighbor + self.RUN.interpolation + self.RUN.intra_class_fid + \
                self.RUN.GAN_train + self.RUN.GAN_test >= 1:
            assert self.OPTIMIZATION.batch_size % 8 == 0, "batch_size should be divisible by 8."

        if self.MODEL.aux_cls_type != "W/O":
            assert self.MODEL.d_cond_mtd in self.MISC.classifier_based_GAN, \
                "TAC and ADC are only applicable to classifier-based GANs."

        if self.MODEL.d_cond_mtd == "MH" or self.LOSS.adv_loss == "MH":
            assert self.MODEL.d_cond_mtd == "MH" and self.LOSS.adv_loss == "MH", \
                "To train a GAN with the multi-hinge loss, both d_cond_mtd and adv_loss must be 'MH'."

        if self.MODEL.d_cond_mtd == "MH" or self.LOSS.adv_loss == "MH":
            assert not self.LOSS.apply_topk, "StudioGAN does not support topk training for MHGAN."

        if self.RUN.train * self.RUN.standing_statistics:
            print("StudioGAN does not support standing_statistics during training.")
            print("After training is done, StudioGAN will accumulate batchnorm statistics to evaluate the GAN.")

        if self.OPTIMIZATION.world_size > 1 and self.RUN.synchronized_bn:
            assert not self.RUN.batch_statistics, "batch_statistics cannot be used with synchronized_bn."

        if self.DATA.name in ["CIFAR10", "CIFAR100"]:
            assert self.RUN.ref_dataset in ["train", "test"], "There is no data for validation."

        if self.RUN.interpolation:
            assert self.MODEL.backbone in ["big_resnet", "big_resnet_deep_legacy", "big_resnet_deep_studiogan"], \
                "StudioGAN does not support interpolation analysis except for the big_resnet and big_resnet_deep backbones."

        if self.RUN.semantic_factorization:
            assert self.RUN.num_semantic_axis > 0, "To apply SeFa, please set num_semantic_axis to a natural number greater than 0."

        if self.OPTIMIZATION.world_size == 1:
            assert not self.RUN.distributed_data_parallel, "Cannot perform distributed training with a single GPU."

        if self.MODEL.backbone == "stylegan3":
            assert self.STYLEGAN.stylegan3_cfg in ["stylegan3-t", "stylegan3-r"], "You must choose which type of stylegan3 generator to use (-r or -t)."

        if self.MODEL.g_cond_mtd == "cAdaIN":
            assert self.MODEL.backbone in ["stylegan2", "stylegan3"], "cAdaIN is only applicable to stylegan2 and stylegan3."

        if self.MODEL.d_cond_mtd == "SPD":
            assert self.MODEL.backbone in ["stylegan2", "stylegan3"], "StyleGAN Projection Discriminator (SPD) is only applicable to stylegan2 and stylegan3."

        if self.MODEL.backbone in ["stylegan2", "stylegan3"]:
            assert self.MODEL.g_act_fn == "Auto" and self.MODEL.d_act_fn == "Auto", \
                "g_act_fn and d_act_fn should be 'Auto' to build the StyleGAN2, StyleGAN3 generator and discriminator."

        if self.MODEL.backbone in ["stylegan2", "stylegan3"]:
            assert not self.MODEL.apply_g_sn and not self.MODEL.apply_d_sn, \
                "StudioGAN does not support spectral normalization on stylegan2, stylegan3."

        if self.MODEL.backbone in ["stylegan2", "stylegan3"]:
            assert self.MODEL.g_cond_mtd in ["W/O", "cAdaIN"], \
                "stylegan2 and stylegan3 only support 'W/O' or 'cAdaIN' as g_cond_mtd."

        if self.LOSS.apply_r1_reg and self.MODEL.backbone in ["stylegan2", "stylegan3"]:
            assert self.LOSS.r1_place in ["inside_loop", "outside_loop"], "LOSS.r1_place should be one of ['inside_loop', 'outside_loop']."

        if self.MODEL.g_act_fn == "Auto" or self.MODEL.d_act_fn == "Auto":
            assert self.MODEL.backbone in ["stylegan2", "stylegan3"], \
                "StudioGAN does not support the act_fn auto selection option except for stylegan2 and stylegan3."

        if self.MODEL.backbone == "stylegan3" and self.STYLEGAN.stylegan3_cfg == "stylegan3-r":
            assert self.STYLEGAN.blur_init_sigma != "N/A", "With stylegan3-r, you need to specify blur_init_sigma."

        if self.MODEL.backbone in ["stylegan2", "stylegan3"] and self.MODEL.apply_g_ema:
            assert self.MODEL.g_ema_decay == "N/A" and self.MODEL.g_ema_start == "N/A", \
                "Please specify the g_ema parameters via STYLEGAN.g_ema_kimg and STYLEGAN.g_ema_rampup instead of MODEL.g_ema_decay and MODEL.g_ema_start."

        if self.MODEL.backbone in ["stylegan2", "stylegan3"]:
            assert self.STYLEGAN.d_epilogue_mbstd_group_size <= (self.OPTIMIZATION.batch_size / self.OPTIMIZATION.world_size), \
                "The number of images that go to each GPU must be no smaller than d_epilogue_mbstd_group_size."

        if self.MODEL.backbone not in ["stylegan2", "stylegan3"] and self.MODEL.apply_g_ema:
            assert isinstance(self.MODEL.g_ema_decay, float) and isinstance(self.MODEL.g_ema_start, int), \
                "Please specify the g_ema parameters via MODEL.g_ema_decay and MODEL.g_ema_start."
            assert self.STYLEGAN.g_ema_kimg == "N/A" and self.STYLEGAN.g_ema_rampup == "N/A", \
                "The g_ema_kimg and g_ema_rampup hyperparameters are only valid for the stylegan2 backbone."

        if isinstance(self.MODEL.g_shared_dim, int):
            assert self.MODEL.backbone in ["big_resnet", "big_resnet_deep_legacy", "big_resnet_deep_studiogan"], \
                "Hierarchical embedding is only applicable to the big_resnet or big_resnet_deep backbones."

        if isinstance(self.MODEL.g_conv_dim, int) or isinstance(self.MODEL.d_conv_dim, int):
            assert self.MODEL.backbone in ["resnet", "big_resnet", "big_resnet_deep_legacy", "big_resnet_deep_studiogan"], \
                "g_conv_dim and d_conv_dim are hyperparameters for controlling the dimensions of resnet, big_resnet, and big_resnet_deep backbones."

        if self.MODEL.backbone in ["stylegan2", "stylegan3"]:
            assert self.LOSS.apply_fm + \
                self.LOSS.apply_gp + \
                self.LOSS.apply_dra + \
                self.LOSS.apply_maxgp + \
                self.LOSS.apply_zcr + \
                self.LOSS.apply_lo + \
                self.RUN.synchronized_bn + \
                self.RUN.batch_statistics + \
                self.RUN.standing_statistics + \
                self.RUN.freezeD + \
                self.RUN.langevin_sampling + \
                self.RUN.interpolation + \
                self.RUN.semantic_factorization == -1, \
                "StudioGAN does not support some options for stylegan2, stylegan3. Please refer to config.py for more details."
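            # The sum above must equal -1 because RUN.freezeD appears to use -1 when disabled
            # (compare the "freezeD > -1" check earlier in this method), while every other term
            # is a boolean flag that contributes 0 when off; enabling any listed option raises
            # the sum above -1 and trips the assertion.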

        if self.MODEL.backbone in ["stylegan2", "stylegan3"]:
            assert not self.MODEL.apply_attn, "Cannot apply attention layers to the stylegan2 and stylegan3 generators."

        if self.RUN.GAN_train or self.RUN.GAN_test:
            assert not self.MODEL.d_cond_mtd == "W/O", \
                "Classification Accuracy Score (CAS) is defined only when the GAN is trained in a class-conditional way."

        if self.MODEL.info_type == "N/A":
            assert self.MODEL.info_num_discrete_c == "N/A" and self.MODEL.info_num_conti_c == "N/A" and self.MODEL.info_dim_discrete_c == "N/A" and \
                self.MODEL.g_info_injection == "N/A" and self.LOSS.infoGAN_loss_discrete_lambda == "N/A" and self.LOSS.infoGAN_loss_conti_lambda == "N/A", \
                "MODEL.info_num_discrete_c, MODEL.info_num_conti_c, MODEL.info_dim_discrete_c, MODEL.g_info_injection, LOSS.infoGAN_loss_discrete_lambda, and LOSS.infoGAN_loss_conti_lambda should be 'N/A'."
        elif self.MODEL.info_type == "continuous":
            assert self.MODEL.info_num_conti_c != "N/A" and self.LOSS.infoGAN_loss_conti_lambda != "N/A", \
                "MODEL.info_num_conti_c and LOSS.infoGAN_loss_conti_lambda should be an integer and a float, respectively."
        elif self.MODEL.info_type == "discrete":
            assert self.MODEL.info_num_discrete_c != "N/A" and self.MODEL.info_dim_discrete_c != "N/A" and self.LOSS.infoGAN_loss_discrete_lambda != "N/A", \
                "MODEL.info_num_discrete_c, MODEL.info_dim_discrete_c, and LOSS.infoGAN_loss_discrete_lambda should be an integer, an integer, and a float, respectively."
        elif self.MODEL.info_type == "both":
            assert self.MODEL.info_num_discrete_c != "N/A" and self.MODEL.info_num_conti_c != "N/A" and self.MODEL.info_dim_discrete_c != "N/A" and \
                self.LOSS.infoGAN_loss_discrete_lambda != "N/A" and self.LOSS.infoGAN_loss_conti_lambda != "N/A", \
                "MODEL.info_num_discrete_c, MODEL.info_num_conti_c, MODEL.info_dim_discrete_c, LOSS.infoGAN_loss_discrete_lambda, and LOSS.infoGAN_loss_conti_lambda should not be 'N/A'."
        else:
            raise NotImplementedError

        if self.MODEL.info_type in ["discrete", "both"]:
            assert self.MODEL.info_num_discrete_c > 0 and self.MODEL.info_dim_discrete_c > 0, \
                "MODEL.info_num_discrete_c and MODEL.info_dim_discrete_c should be over 0."

        if self.MODEL.info_type in ["continuous", "both"]:
            assert self.MODEL.info_num_conti_c > 0, "MODEL.info_num_conti_c should be over 0."

        if self.MODEL.info_type in ["discrete", "continuous", "both"] and self.MODEL.backbone in ["stylegan2", "stylegan3"]:
            assert self.MODEL.g_info_injection == "concat", "StyleGAN2 and StyleGAN3 only allow 'concat' as the g_info_injection method."

        if self.MODEL.info_type in ["discrete", "continuous", "both"]:
            assert self.MODEL.g_info_injection in ["concat", "cBN"], "MODEL.g_info_injection should be 'concat' or 'cBN'."

        if self.AUG.apply_ada and self.AUG.apply_apa:
            assert self.AUG.ada_initial_augment_p == self.AUG.apa_initial_augment_p and \
                self.AUG.ada_target == self.AUG.apa_target and \
                self.AUG.ada_kimg == self.AUG.apa_kimg and \
                self.AUG.ada_interval == self.AUG.apa_interval, \
                "ADA and APA specifications must be exactly the same."

        assert self.RUN.eval_backbone in ["InceptionV3_tf", "InceptionV3_torch", "ResNet50_torch", "SwAV_torch", "DINO_torch", "Swin-T_torch"], \
            "eval_backbone should be in [InceptionV3_tf, InceptionV3_torch, ResNet50_torch, SwAV_torch, DINO_torch, Swin-T_torch]."

        assert self.RUN.post_resizer in ["legacy", "clean", "friendly"], "The resizing flag should be in [legacy, clean, friendly]."

        assert self.RUN.data_dir is not None or self.RUN.save_fake_images, \
            "Please specify data_dir if the dataset is prepared. In the case of CIFAR10 or CIFAR100, just specify the directory where you want the dataset to be downloaded."

        assert self.RUN.batch_statistics*self.RUN.standing_statistics == 0, \
            "You can't turn on batch_statistics and standing_statistics simultaneously."

        assert self.OPTIMIZATION.batch_size % self.OPTIMIZATION.world_size == 0, \
            "batch_size should be divisible by the number of GPUs."

        assert int(self.LOSS.apply_cr)*int(self.LOSS.apply_bcr) == 0 and \
            int(self.LOSS.apply_cr)*int(self.LOSS.apply_zcr) == 0, \
            "You can't simultaneously turn on consistency reg. and improved consistency reg."

        assert int(self.LOSS.apply_gp)*int(self.LOSS.apply_dra)*int(self.LOSS.apply_maxgp) == 0, \
            "You can't simultaneously apply gradient penalty regularization, deep regret analysis, and max gradient penalty."

        assert self.RUN.save_freq % self.RUN.print_freq == 0, \
            "RUN.save_freq should be divisible by RUN.print_freq for wandb logging."

        assert self.RUN.pre_resizer in ["wo_resize", "nearest", "bilinear", "bicubic", "lanczos"], \
            "The interpolation filter for pre-processing should be in ['wo_resize', 'nearest', 'bilinear', 'bicubic', 'lanczos']."
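

# Typical usage (a minimal sketch under assumptions; the repository's actual entry point may
# wire these calls differently, and RUN options are normally populated from the command-line
# parser before check_compatability is called):
#
#   cfgs = Configurations(cfg_file="./src/configs/CIFAR10/ContraGAN.yaml")  # hypothetical path
#   cfgs.update_cfgs({"train": True, "mixed_precision": False}, super="RUN")
#   cfgs.check_compatability()
#   cfgs.define_losses()
#   modules = cfgs.define_modules()
#   # once the Generator/Discriminator have been built from MODEL and MODULES:
#   # cfgs.define_optimizer(Gen, Dis)
#   # cfgs.define_augments(device)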