Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
POSTECH-CVLab
GitHub Repository: POSTECH-CVLab/PyTorch-StudioGAN
Path: blob/master/src/main.py
809 views
1
# PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN
2
# The MIT License (MIT)
3
# See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details
4
5
# src/main.py
6
7
from argparse import ArgumentParser
8
from warnings import simplefilter
9
import json
10
import os
11
import random
12
import sys
13
import tempfile
14
15
from torch.multiprocessing import Process
16
import torch
17
import torch.multiprocessing as mp
18
19
import config
20
import loader
21
import utils.hdf5 as hdf5
22
import utils.log as log
23
import utils.misc as misc
24
25
# Template for a unique run identifier: dataset name, config-file stem
# (framework), phase, and a timestamp (filled in by log.make_run_name).
RUN_NAME_FORMAT = ("{data_name}-" "{framework}-" "{phase}-" "{timestamp}")
26
27
28
def load_configs_initialize_training():
    """Parse CLI arguments, merge them into the YAML configuration, and
    prepare everything needed to launch a training/evaluation run.

    Returns:
        tuple: ``(cfgs, gpus_per_node, run_name, hdf5_path, rank)`` where
            ``cfgs`` is the merged configuration object,
            ``gpus_per_node`` the number of CUDA devices on this node,
            ``run_name`` a unique identifier built from RUN_NAME_FORMAT,
            ``hdf5_path`` the pre-built HDF5 dataset path (or None), and
            ``rank`` the index of the current CUDA device.
    """
    parser = ArgumentParser(add_help=True)
    parser.add_argument("--entity", type=str, default=None, help="entity for wandb logging")
    parser.add_argument("--project", type=str, default=None, help="project name for wandb logging")

    parser.add_argument("-cfg", "--cfg_file", type=str, default="./src/configs/CIFAR10/ContraGAN.yaml")
    parser.add_argument("-data", "--data_dir", type=str, default=None)
    parser.add_argument("-save", "--save_dir", type=str, default="./")
    parser.add_argument("-ckpt", "--ckpt_dir", type=str, default=None)
    parser.add_argument("-best", "--load_best", action="store_true", help="load the best performed checkpoint")

    parser.add_argument("--seed", type=int, default=-1, help="seed for generating random numbers")
    parser.add_argument("-DDP", "--distributed_data_parallel", action="store_true")
    parser.add_argument("--backend", type=str, default="nccl", help="cuda backend for DDP training \in ['nccl', 'gloo']")
    parser.add_argument("-tn", "--total_nodes", default=1, type=int, help="total number of nodes for training")
    parser.add_argument("-cn", "--current_node", default=0, type=int, help="rank of the current node")
    parser.add_argument("--num_workers", type=int, default=8)
    parser.add_argument("-sync_bn", "--synchronized_bn", action="store_true", help="turn on synchronized batchnorm")
    parser.add_argument("-mpc", "--mixed_precision", action="store_true", help="turn on mixed precision training")

    parser.add_argument("--truncation_factor", type=float, default=-1.0,
                        help="truncation factor for applying truncation trick (-1.0 means not applying truncation trick)")
    parser.add_argument("--truncation_cutoff", type=float, default=None,
                        help="truncation cutoff for stylegan (apply truncation for only w[:truncation_cutoff]")
    parser.add_argument("-batch_stat", "--batch_statistics", action="store_true",
                        help="use the statistics of a batch when evaluating GAN "
                             "(if false, use the moving average updated statistics)")
    parser.add_argument("-std_stat", "--standing_statistics", action="store_true",
                        help="apply standing statistics for evaluation")
    parser.add_argument("-std_max", "--standing_max_batch", type=int, default=-1,
                        help="maximum batch_size for calculating standing statistics "
                             "(-1.0 means not applying standing statistics trick for evaluation)")
    parser.add_argument("-std_step", "--standing_step", type=int, default=-1,
                        help="# of steps for standing statistics "
                             "(-1.0 means not applying standing statistics trick for evaluation)")
    parser.add_argument("--freezeD", type=int, default=-1,
                        help="# of freezed blocks in the discriminator for transfer learning")

    # parser arguments to apply langevin sampling for GAN evaluation
    # In the arguments regarding 'decay', -1 means not applying the decay trick by default
    parser.add_argument("-lgv", "--langevin_sampling", action="store_true",
                        help="apply langevin sampling to generate images from an Energy-Based Model")
    parser.add_argument("-lgv_rate", "--langevin_rate", type=float, default=-1,
                        help="an initial update rate for langevin sampling (\epsilon)")
    parser.add_argument("-lgv_std", "--langevin_noise_std", type=float, default=-1,
                        help="standard deviation of a gaussian noise used in langevin sampling (std of n_i)")
    parser.add_argument("-lgv_decay", "--langevin_decay", type=float, default=-1,
                        help="decay strength for langevin_rate and langevin_noise_std")
    parser.add_argument("-lgv_decay_steps", "--langevin_decay_steps", type=int, default=-1,
                        help="langevin_rate and langevin_noise_std decrease every 'langevin_decay_steps'")
    parser.add_argument("-lgv_steps", "--langevin_steps", type=int, default=-1,
                        help="total steps of langevin sampling")

    parser.add_argument("-t", "--train", action="store_true")
    parser.add_argument("-hdf5", "--load_train_hdf5", action="store_true",
                        help="load train images from a hdf5 file for fast I/O")
    parser.add_argument("-l", "--load_data_in_memory", action="store_true",
                        help="put the whole train dataset on the main memory for fast I/O")
    parser.add_argument("-metrics", "--eval_metrics", nargs='+', default=['fid'],
                        help="evaluation metrics to use during training, a subset list of ['fid', 'is', 'prdc'] or none")
    parser.add_argument("--pre_resizer", type=str, default="wo_resize",
                        help="which resizer will you use to pre-process images "
                             "in ['wo_resize', 'nearest', 'bilinear', 'bicubic', 'lanczos']")
    parser.add_argument("--post_resizer", type=str, default="legacy",
                        help="which resizer will you use to evaluate GANs in ['legacy', 'clean', 'friendly']")
    parser.add_argument("--num_eval", type=int, default=1, help="number of runs for final evaluation.")
    parser.add_argument("-sr", "--save_real_images", action="store_true",
                        help="save images sampled from the reference dataset")
    parser.add_argument("-sf", "--save_fake_images", action="store_true",
                        help="save fake images generated by the GAN.")
    parser.add_argument("-sf_num", "--save_fake_images_num", type=int, default=1,
                        help="number of fake images to save")
    parser.add_argument("-v", "--vis_fake_images", action="store_true", help="visualize image canvas")
    parser.add_argument("-knn", "--k_nearest_neighbor", action="store_true",
                        help="conduct k-nearest neighbor analysis")
    parser.add_argument("-itp", "--interpolation", action="store_true", help="conduct interpolation analysis")
    parser.add_argument("-fa", "--frequency_analysis", action="store_true", help="conduct frequency analysis")
    parser.add_argument("-tsne", "--tsne_analysis", action="store_true", help="conduct tsne analysis")
    parser.add_argument("-ifid", "--intra_class_fid", action="store_true", help="calculate intra-class fid")
    parser.add_argument('--GAN_train', action='store_true', help="whether to calculate CAS (Recall)")
    parser.add_argument('--GAN_test', action='store_true', help="whether to calculate CAS (Precision)")
    parser.add_argument('-resume_ct', '--resume_classifier_train', action='store_true',
                        help="whether to resume classifier training for CAS")
    parser.add_argument("-sefa", "--semantic_factorization", action="store_true",
                        help="perform semantic (closed-form) factorization")
    parser.add_argument("-sefa_axis", "--num_semantic_axis", type=int, default=-1,
                        help="number of semantic axis for sefa")
    parser.add_argument("-sefa_max", "--maximum_variations", type=float, default=-1,
                        help="interpolate between z and z + maximum_variations*eigen-vector")
    parser.add_argument("-empty_cache", "--empty_cache", action="store_true",
                        help="empty cuda caches after training step of generator and discriminator, "
                             "slightly reduces memory usage but slows training speed. (not recommended for normal use)")

    parser.add_argument("--print_freq", type=int, default=100, help="logging interval")
    parser.add_argument("--save_freq", type=int, default=2000, help="save interval")
    parser.add_argument('--eval_backbone', type=str, default='InceptionV3_tf',
                        help="[InceptionV3_tf, InceptionV3_torch, ResNet50_torch, SwAV_torch, DINO_torch, Swin-T_torch]")
    parser.add_argument("-ref", "--ref_dataset", type=str, default="train",
                        help="reference dataset for evaluation[train/valid/test]")
    parser.add_argument("--calc_is_ref_dataset", action="store_true",
                        help="whether to calculate an inception score of the ref dataset.")
    args = parser.parse_args()
    run_cfgs = vars(args)

    # Abort with usage info when neither training nor any evaluation/analysis
    # mode was requested: there would be nothing to do.
    nothing_to_do = (not args.train and
                     "none" in args.eval_metrics and
                     not args.save_real_images and
                     not args.save_fake_images and
                     not args.vis_fake_images and
                     not args.k_nearest_neighbor and
                     not args.interpolation and
                     not args.frequency_analysis and
                     not args.tsne_analysis and
                     not args.intra_class_fid and
                     not args.GAN_train and
                     not args.GAN_test and
                     not args.semantic_factorization)
    if nothing_to_do:
        parser.print_help(sys.stderr)
        sys.exit(1)

    gpus_per_node, rank = torch.cuda.device_count(), torch.cuda.current_device()

    # Merge CLI overrides into the YAML configuration under the "RUN" super-key
    # and derive the DDP world size from local GPUs x participating nodes.
    cfgs = config.Configurations(args.cfg_file)
    cfgs.update_cfgs(run_cfgs, super="RUN")
    cfgs.OPTIMIZATION.world_size = gpus_per_node * cfgs.RUN.total_nodes
    cfgs.check_compatability()

    # Framework name is the config-file stem (".yaml" suffix stripped).
    run_name = log.make_run_name(RUN_NAME_FORMAT,
                                 data_name=cfgs.DATA.name,
                                 framework=cfgs.RUN.cfg_file.split("/")[-1][:-5],
                                 phase="train")

    # Datasets listed in MISC.no_proc_data are used as-is (no crop/resize).
    crop_long_edge = cfgs.DATA.name not in cfgs.MISC.no_proc_data
    resize_size = None if cfgs.DATA.name in cfgs.MISC.no_proc_data else cfgs.DATA.img_size
    cfgs.RUN.pre_resizer = "wo_resize" if cfgs.DATA.name in cfgs.MISC.no_proc_data else cfgs.RUN.pre_resizer
    if cfgs.RUN.load_train_hdf5:
        # Building the HDF5 file bakes the pre-processing in, so it reports
        # back the (possibly changed) crop/resize settings.
        hdf5_path, crop_long_edge, resize_size = hdf5.make_hdf5(
            name=cfgs.DATA.name,
            img_size=cfgs.DATA.img_size,
            crop_long_edge=crop_long_edge,
            resize_size=resize_size,
            resizer=cfgs.RUN.pre_resizer,
            data_dir=cfgs.RUN.data_dir,
            DATA=cfgs.DATA,
            RUN=cfgs.RUN)
    else:
        hdf5_path = None
    cfgs.PRE.crop_long_edge, cfgs.PRE.resize_size = crop_long_edge, resize_size

    misc.prepare_folder(names=cfgs.MISC.base_folders, save_dir=cfgs.RUN.save_dir)
    try:
        misc.download_data_if_possible(data_name=cfgs.DATA.name, data_dir=cfgs.RUN.data_dir)
    except Exception:
        # Best-effort download: a failure here is non-fatal (the data may
        # already be on disk), so continue. Narrowed from a bare `except:` so
        # KeyboardInterrupt/SystemExit still propagate.
        pass

    # seed == -1 requests a fresh random seed; otherwise the user's seed is
    # fixed for reproducibility.
    if cfgs.RUN.seed == -1:
        cfgs.RUN.seed = random.randint(1, 4096)
        cfgs.RUN.fix_seed = False
    else:
        cfgs.RUN.fix_seed = True

    if cfgs.OPTIMIZATION.world_size == 1:
        print("You have chosen a specific GPU. This will completely disable data parallelism.")
    return cfgs, gpus_per_node, run_name, hdf5_path, rank
173
174
175
if __name__ == "__main__":
    cfgs, gpus_per_node, run_name, hdf5_path, rank = load_configs_initialize_training()

    use_ddp = cfgs.RUN.distributed_data_parallel and cfgs.OPTIMIZATION.world_size > 1
    if use_ddp:
        # One worker process per local GPU; "spawn" is required so CUDA state
        # is not inherited from the parent process.
        mp.set_start_method("spawn", force=True)
        print("Train the models through DistributedDataParallel (DDP) mode.")
        worker_args = (cfgs, gpus_per_node, run_name, hdf5_path)
        ctx = mp.spawn(fn=loader.load_worker,
                       args=worker_args,
                       nprocs=gpus_per_node,
                       join=False)
        ctx.join()
        # Make sure no straggler worker outlives the run.
        for worker in ctx.processes:
            worker.kill()
    else:
        # Single-process path: run the worker directly on the current device.
        loader.load_worker(local_rank=rank,
                           cfgs=cfgs,
                           gpus_per_node=gpus_per_node,
                           run_name=run_name,
                           hdf5_path=hdf5_path)
197
198