# PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN
# The MIT License (MIT)
# See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details

# src/models/model.py

import copy

from torch.nn import DataParallel
from torch.nn.parallel import DistributedDataParallel as DDP
import torch

from sync_batchnorm.batchnorm import convert_model
from utils.ema import Ema, EmaStylegan2
import utils.misc as misc


def load_generator_discriminator(DATA, OPTIMIZATION, MODEL, STYLEGAN, MODULES, RUN, device, logger):
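    """Build the generator/discriminator pair described by the config namespaces.

    Summary comment (added for readability): constructs either a StyleGAN2/3
    generator paired with a StyleGAN2 discriminator, or one of StudioGAN's
    other backbones, optionally together with an exponential-moving-average
    (EMA) copy of the generator, and returns the networks plus the EMA helper.
    """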
    if device == 0:
        logger.info("Build a Generative Adversarial Network.")
    module = __import__("models.{backbone}".format(backbone=MODEL.backbone), fromlist=["something"])
    if device == 0:
        logger.info("Modules are located in './src/models/{backbone}.py'.".format(backbone=MODEL.backbone))

    if MODEL.backbone in ["stylegan2", "stylegan3"]:
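        # Wider networks (channel_base) for StyleGAN3, resolutions >= 512, and
        # CIFAR; channel_max caps the per-layer channel count at 512.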
        channel_base = 32768 if (MODEL.backbone == "stylegan3" or DATA.img_size >= 512 or
                                 DATA.name in ["CIFAR10", "CIFAR100"]) else 16384
        channel_max = 512
        gen_c_dim = DATA.num_classes if MODEL.g_cond_mtd == "cAdaIN" else 0
        dis_c_dim = DATA.num_classes if MODEL.d_cond_mtd in STYLEGAN.cond_type else 0
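        # Mixed-precision settings follow the StyleGAN2-ADA defaults: run the
        # four highest resolutions in FP16 and clamp conv outputs to +-256.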
        if RUN.mixed_precision:
            num_fp16_res = 4
            conv_clamp = 256
        else:
            num_fp16_res = 0
            conv_clamp = None
        if MODEL.backbone == "stylegan2":
            Gen = module.Generator(z_dim=MODEL.z_dim,
                                   c_dim=gen_c_dim,
                                   w_dim=MODEL.w_dim,
                                   img_resolution=DATA.img_size,
                                   img_channels=DATA.img_channels,
                                   MODEL=MODEL,
                                   mapping_kwargs={"num_layers": STYLEGAN.mapping_network},
                                   synthesis_kwargs={"channel_base": channel_base, "channel_max": channel_max,
                                                     "num_fp16_res": num_fp16_res, "conv_clamp": conv_clamp}).to(device)
        else:
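            # StyleGAN3's input-magnitude EMA: a per-step decay chosen so the
            # moving average has a half-life of 20k images at this batch size.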
            magnitude_ema_beta = 0.5 ** (OPTIMIZATION.batch_size * OPTIMIZATION.acml_steps / (20 * 1e3))
            g_channel_base, g_channel_max, conv_kernel, use_radial_filters = channel_base, channel_max, 3, False
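            # The stylegan3-r config doubles the channel budget and uses 1x1
            # convolutions with radially symmetric filters.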
            if STYLEGAN.stylegan3_cfg == "stylegan3-r":
                g_channel_base, g_channel_max, conv_kernel, use_radial_filters = channel_base * 2, channel_max * 2, 1, True
            Gen = module.Generator(z_dim=MODEL.z_dim,
                                   c_dim=gen_c_dim,
                                   w_dim=MODEL.w_dim,
                                   img_resolution=DATA.img_size,
                                   img_channels=DATA.img_channels,
                                   MODEL=MODEL,
                                   mapping_kwargs={"num_layers": STYLEGAN.mapping_network},
                                   synthesis_kwargs={"channel_base": g_channel_base, "channel_max": g_channel_max,
                                                     "num_fp16_res": num_fp16_res, "conv_clamp": conv_clamp, "conv_kernel": conv_kernel,
                                                     "use_radial_filters": use_radial_filters, "magnitude_ema_beta": magnitude_ema_beta}).to(device)

        Gen_mapping, Gen_synthesis = Gen.mapping, Gen.synthesis

        module = __import__("models.stylegan2", fromlist=["something"])  # always use StyleGAN2 discriminator
        Dis = module.Discriminator(c_dim=dis_c_dim,
                                   img_resolution=DATA.img_size,
                                   img_channels=DATA.img_channels,
                                   architecture=STYLEGAN.d_architecture,
                                   channel_base=channel_base,
                                   channel_max=channel_max,
                                   num_fp16_res=num_fp16_res,
                                   conv_clamp=conv_clamp,
                                   cmap_dim=None,
                                   d_cond_mtd=MODEL.d_cond_mtd,
                                   aux_cls_type=MODEL.aux_cls_type,
                                   d_embed_dim=MODEL.d_embed_dim,
                                   num_classes=DATA.num_classes,
                                   normalize_d_embed=MODEL.normalize_d_embed,
                                   block_kwargs={},
                                   mapping_kwargs={},
                                   epilogue_kwargs={"mbstd_group_size": STYLEGAN.d_epilogue_mbstd_group_size},
                                   MODEL=MODEL).to(device)
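        # Keep an EMA copy of the generator for evaluation; EmaStylegan2
        # schedules the averaging by images seen (ema_kimg/ema_rampup) rather
        # than a fixed decay rate.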
        if MODEL.apply_g_ema:
            if device == 0:
                logger.info("Prepare an exponential moving average generator with a decay rate of {decay}.".format(decay=MODEL.g_ema_decay))
            Gen_ema = copy.deepcopy(Gen)
            Gen_ema_mapping, Gen_ema_synthesis = Gen_ema.mapping, Gen_ema.synthesis

            ema = EmaStylegan2(source=Gen,
                               target=Gen_ema,
                               ema_kimg=STYLEGAN.g_ema_kimg,
                               ema_rampup=STYLEGAN.g_ema_rampup,
                               effective_batch_size=OPTIMIZATION.batch_size * OPTIMIZATION.acml_steps)
        else:
            Gen_ema, Gen_ema_mapping, Gen_ema_synthesis, ema = None, None, None, None

    else:
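        # Non-StyleGAN backbones use a monolithic generator, so the separate
        # mapping/synthesis handles stay None.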
        Gen = module.Generator(z_dim=MODEL.z_dim,
                               g_shared_dim=MODEL.g_shared_dim,
                               img_size=DATA.img_size,
                               g_conv_dim=MODEL.g_conv_dim,
                               apply_attn=MODEL.apply_attn,
                               attn_g_loc=MODEL.attn_g_loc,
                               g_cond_mtd=MODEL.g_cond_mtd,
                               num_classes=DATA.num_classes,
                               g_init=MODEL.g_init,
                               g_depth=MODEL.g_depth,
                               mixed_precision=RUN.mixed_precision,
                               MODULES=MODULES,
                               MODEL=MODEL).to(device)

        Gen_mapping, Gen_synthesis = None, None

        Dis = module.Discriminator(img_size=DATA.img_size,
                                   d_conv_dim=MODEL.d_conv_dim,
                                   apply_d_sn=MODEL.apply_d_sn,
                                   apply_attn=MODEL.apply_attn,
                                   attn_d_loc=MODEL.attn_d_loc,
                                   d_cond_mtd=MODEL.d_cond_mtd,
                                   aux_cls_type=MODEL.aux_cls_type,
                                   d_embed_dim=MODEL.d_embed_dim,
                                   num_classes=DATA.num_classes,
                                   normalize_d_embed=MODEL.normalize_d_embed,
                                   d_init=MODEL.d_init,
                                   d_depth=MODEL.d_depth,
                                   mixed_precision=RUN.mixed_precision,
                                   MODULES=MODULES,
                                   MODEL=MODEL).to(device)
        if MODEL.apply_g_ema:
            if device == 0:
                logger.info("Prepare an exponential moving average generator with a decay rate of {decay}.".format(decay=MODEL.g_ema_decay))
            Gen_ema = copy.deepcopy(Gen)
            Gen_ema_mapping, Gen_ema_synthesis = None, None

            ema = Ema(source=Gen, target=Gen_ema, decay=MODEL.g_ema_decay, start_iter=MODEL.g_ema_start)
        else:
            Gen_ema, Gen_ema_mapping, Gen_ema_synthesis, ema = None, None, None, None

    if device == 0:
        logger.info(misc.count_parameters(Gen))
        logger.info(Gen)
        logger.info(misc.count_parameters(Dis))
        logger.info(Dis)
    return Gen, Gen_mapping, Gen_synthesis, Dis, Gen_ema, Gen_ema_mapping, Gen_ema_synthesis, ema


def prepare_parallel_training(Gen, Gen_mapping, Gen_synthesis, Dis, Gen_ema, Gen_ema_mapping, Gen_ema_synthesis,
                              MODEL, world_size, distributed_data_parallel, synchronized_bn, apply_g_ema, device):
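    """Wrap the networks for single- or multi-GPU training.

    Summary comment (added for readability): with distributed_data_parallel,
    modules are wrapped in DistributedDataParallel, optionally after converting
    BatchNorm layers to torch's SyncBatchNorm; otherwise DataParallel is used,
    together with the bundled sync_batchnorm conversion when synchronized_bn is
    set. StyleGAN2/3 generators are wrapped as separate mapping/synthesis parts.
    """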
    if distributed_data_parallel:
        if synchronized_bn:
            process_group = torch.distributed.new_group(list(range(world_size)))
            Gen = torch.nn.SyncBatchNorm.convert_sync_batchnorm(Gen, process_group)
            Dis = torch.nn.SyncBatchNorm.convert_sync_batchnorm(Dis, process_group)
            if apply_g_ema:
                Gen_ema = torch.nn.SyncBatchNorm.convert_sync_batchnorm(Gen_ema, process_group)

        if MODEL.backbone in ["stylegan2", "stylegan3"]:
            Gen_mapping = DDP(Gen.mapping, device_ids=[device], broadcast_buffers=False)
            Gen_synthesis = DDP(Gen.synthesis, device_ids=[device], broadcast_buffers=False)
        else:
            Gen = DDP(Gen, device_ids=[device], broadcast_buffers=synchronized_bn)
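        # Discriminators with InfoGAN heads (info_type "discrete", "continuous",
        # or "both") may not touch all parameters on every forward pass, so DDP
        # must be told to tolerate unused parameters.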
        Dis = DDP(Dis, device_ids=[device],
                  broadcast_buffers=False if MODEL.backbone in ["stylegan2", "stylegan3"] else synchronized_bn,
                  find_unused_parameters=MODEL.info_type in ["discrete", "continuous", "both"])
        if apply_g_ema:
            if MODEL.backbone in ["stylegan2", "stylegan3"]:
                Gen_ema_mapping = DDP(Gen_ema.mapping, device_ids=[device], broadcast_buffers=False)
                Gen_ema_synthesis = DDP(Gen_ema.synthesis, device_ids=[device], broadcast_buffers=False)
            else:
                Gen_ema = DDP(Gen_ema, device_ids=[device], broadcast_buffers=synchronized_bn)
    else:
        if MODEL.backbone in ["stylegan2", "stylegan3"]:
            Gen_mapping = DataParallel(Gen.mapping, output_device=device)
            Gen_synthesis = DataParallel(Gen.synthesis, output_device=device)
        else:
            Gen = DataParallel(Gen, output_device=device)
        Dis = DataParallel(Dis, output_device=device)
        if apply_g_ema:
            if MODEL.backbone in ["stylegan2", "stylegan3"]:
                Gen_ema_mapping = DataParallel(Gen_ema.mapping, output_device=device)
                Gen_ema_synthesis = DataParallel(Gen_ema.synthesis, output_device=device)
            else:
                Gen_ema = DataParallel(Gen_ema, output_device=device)
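        # torch's SyncBatchNorm requires DDP, so the DataParallel path uses the
        # bundled sync_batchnorm conversion instead.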
        if synchronized_bn:
            Gen = convert_model(Gen).to(device)
            Dis = convert_model(Dis).to(device)
            if apply_g_ema:
                Gen_ema = convert_model(Gen_ema).to(device)
    return Gen, Gen_mapping, Gen_synthesis, Dis, Gen_ema, Gen_ema_mapping, Gen_ema_synthesis
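

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): roughly how a training script
# would call the two helpers above. The config namespaces (DATA, OPTIMIZATION,
# MODEL, STYLEGAN, MODULES, RUN) and local_rank/world_size are hypothetical
# stand-ins for whatever the caller loads from its config and DDP setup.
#
#   Gen, Gen_mapping, Gen_synthesis, Dis, Gen_ema, Gen_ema_mapping, \
#       Gen_ema_synthesis, ema = load_generator_discriminator(
#           DATA, OPTIMIZATION, MODEL, STYLEGAN, MODULES, RUN,
#           device=local_rank, logger=logger)
#
#   Gen, Gen_mapping, Gen_synthesis, Dis, Gen_ema, Gen_ema_mapping, Gen_ema_synthesis = \
#       prepare_parallel_training(Gen, Gen_mapping, Gen_synthesis, Dis, Gen_ema,
#                                 Gen_ema_mapping, Gen_ema_synthesis, MODEL,
#                                 world_size, RUN.distributed_data_parallel,
#                                 RUN.synchronized_bn, MODEL.apply_g_ema, local_rank)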