Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
POSTECH-CVLab
GitHub Repository: POSTECH-CVLab/PyTorch-StudioGAN
Path: blob/master/src/models/big_resnet.py
809 views
1
# PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN
2
# The MIT License (MIT)
3
# See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details
4
5
# models/big_resnet.py
6
7
import torch
8
import torch.nn as nn
9
import torch.nn.functional as F
10
11
import utils.ops as ops
12
import utils.misc as misc
13
14
15
class GenBlock(nn.Module):
    # BigGAN-style generator residual block: 2x nearest-neighbor upsampling on
    # both paths; conditional BN -> act -> conv on the main path, 1x1 conv on
    # the shortcut. Conditioning enters through the `affine` vector fed to the
    # MODULES.g_bn layers.
    def __init__(self, in_channels, out_channels, g_cond_mtd, affine_input_dim, MODULES):
        super(GenBlock, self).__init__()
        # Conditioning method is recorded for reference; the conditioning
        # itself happens inside the bn layers.
        self.g_cond_mtd = g_cond_mtd

        self.bn1 = MODULES.g_bn(affine_input_dim, in_channels, MODULES)
        self.bn2 = MODULES.g_bn(affine_input_dim, out_channels, MODULES)
        self.activation = MODULES.g_act_fn

        # 1x1 projection for the shortcut, two 3x3 convs for the main path.
        self.conv2d0 = MODULES.g_conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)
        self.conv2d1 = MODULES.g_conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
        self.conv2d2 = MODULES.g_conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x, affine):
        # Main path: cBN -> act -> upsample -> conv -> cBN -> act -> conv.
        h = self.activation(self.bn1(x, affine))
        h = F.interpolate(h, scale_factor=2, mode="nearest")
        h = self.conv2d1(h)
        h = self.activation(self.bn2(h, affine))
        h = self.conv2d2(h)

        # Shortcut path: upsample -> 1x1 conv, then residual sum.
        shortcut = F.interpolate(x, scale_factor=2, mode="nearest")
        shortcut = self.conv2d0(shortcut)
        return h + shortcut
class Generator(nn.Module):
    """BigGAN generator (big_resnet variant).

    Maps a latent vector z (optionally carrying InfoGAN codes) to an RGB image
    in [-1, 1] via tanh. z is split into (num_blocks + 1) equal chunks: the
    first chunk seeds the initial 4x4 feature map through linear0, and each
    remaining chunk is (optionally concatenated with the class/info embedding
    and) used as the affine conditioning input of one GenBlock.
    """
    def __init__(self, z_dim, g_shared_dim, img_size, g_conv_dim, apply_attn, attn_g_loc, g_cond_mtd, num_classes, g_init, g_depth,
                 mixed_precision, MODULES, MODEL):
        # NOTE(review): g_depth is accepted but never read in this class —
        # presumably kept for config compatibility with other backbones; confirm.
        super(Generator, self).__init__()
        # Per-resolution channel plans, keyed by target image size (as string):
        # input widths of each GenBlock ...
        g_in_dims_collection = {
            "32": [g_conv_dim * 4, g_conv_dim * 4, g_conv_dim * 4],
            "64": [g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2],
            "128": [g_conv_dim * 16, g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2],
            "256": [g_conv_dim * 16, g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2],
            "512": [g_conv_dim * 16, g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2, g_conv_dim]
        }

        # ... and the corresponding output widths of each GenBlock.
        g_out_dims_collection = {
            "32": [g_conv_dim * 4, g_conv_dim * 4, g_conv_dim * 4],
            "64": [g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2, g_conv_dim],
            "128": [g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2, g_conv_dim],
            "256": [g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2, g_conv_dim],
            "512": [g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2, g_conv_dim, g_conv_dim]
        }

        # Spatial size of the initial (bottom) feature map; 4x4 for all sizes.
        bottom_collection = {"32": 4, "64": 4, "128": 4, "256": 4, "512": 4}

        self.z_dim = z_dim
        self.g_shared_dim = g_shared_dim
        self.g_cond_mtd = g_cond_mtd
        self.num_classes = num_classes
        self.mixed_precision = mixed_precision
        self.MODEL = MODEL
        self.in_dims = g_in_dims_collection[str(img_size)]
        self.out_dims = g_out_dims_collection[str(img_size)]
        self.bottom = bottom_collection[str(img_size)]
        self.num_blocks = len(self.in_dims)
        # One chunk seeds linear0; each remaining chunk conditions one block.
        self.chunk_size = z_dim // (self.num_blocks + 1)
        self.affine_input_dim = self.chunk_size
        assert self.z_dim % (self.num_blocks + 1) == 0, "z_dim should be divided by the number of blocks"

        # Total dimensionality of the InfoGAN latent codes (0 if disabled).
        info_dim = 0
        if self.MODEL.info_type in ["discrete", "both"]:
            info_dim += self.MODEL.info_num_discrete_c*self.MODEL.info_dim_discrete_c
        if self.MODEL.info_type in ["continuous", "both"]:
            info_dim += self.MODEL.info_num_conti_c

        if self.MODEL.info_type != "N/A":
            if self.MODEL.g_info_injection == "concat":
                # Mix [z, info codes] back down to z_dim before chunking.
                self.info_mix_linear = MODULES.g_linear(in_features=self.z_dim + info_dim, out_features=self.z_dim, bias=True)
            elif self.MODEL.g_info_injection == "cBN":
                # Project info codes into the affine (cBN) conditioning vector.
                self.affine_input_dim += self.g_shared_dim
                self.info_proj_linear = MODULES.g_linear(in_features=info_dim, out_features=self.g_shared_dim, bias=True)

        # First z-chunk -> initial in_dims[0] x bottom x bottom feature map.
        self.linear0 = MODULES.g_linear(in_features=self.chunk_size, out_features=self.in_dims[0]*self.bottom*self.bottom, bias=True)

        if self.g_cond_mtd != "W/O":
            # Shared class embedding, concatenated into every block's affine input.
            self.affine_input_dim += self.g_shared_dim
            self.shared = ops.embedding(num_embeddings=self.num_classes, embedding_dim=self.g_shared_dim)

        # Build the tower of GenBlocks, optionally interleaving self-attention
        # after the block indices listed in attn_g_loc (1-based).
        self.blocks = []
        for index in range(self.num_blocks):
            self.blocks += [[
                GenBlock(in_channels=self.in_dims[index],
                         out_channels=self.out_dims[index],
                         g_cond_mtd=self.g_cond_mtd,
                         affine_input_dim=self.affine_input_dim,
                         MODULES=MODULES)
            ]]

            if index + 1 in attn_g_loc and apply_attn:
                self.blocks += [[ops.SelfAttention(self.out_dims[index], is_generator=True, MODULES=MODULES)]]

        self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])

        # Output head: BN -> act -> 3x3 conv to RGB -> tanh.
        self.bn4 = ops.batchnorm_2d(in_features=self.out_dims[-1])
        self.activation = MODULES.g_act_fn
        self.conv2d5 = MODULES.g_conv2d(in_channels=self.out_dims[-1], out_channels=3, kernel_size=3, stride=1, padding=1)
        self.tanh = nn.Tanh()

        # NOTE(review): passes the bound method `self.modules` (not its call
        # result); ops.init_weights presumably invokes it internally — confirm.
        ops.init_weights(self.modules, g_init)

    def forward(self, z, label, shared_label=None, eval=False):
        """Generate images from latents z and class labels.

        shared_label: optional precomputed class embedding (skips self.shared).
        eval: disables autocast even when mixed_precision is on. (NOTE(review):
        this parameter shadows the `eval` builtin inside the method.)
        """
        affine_list = []
        # The ternary selects the context manager for the whole block: autocast
        # only during mixed-precision training, otherwise a no-op context.
        with torch.cuda.amp.autocast() if self.mixed_precision and not eval else misc.dummy_context_mgr() as mp:
            if self.MODEL.info_type != "N/A":
                if self.MODEL.g_info_injection == "concat":
                    z = self.info_mix_linear(z)
                elif self.MODEL.g_info_injection == "cBN":
                    # Split the trailing info codes off z and project them for cBN.
                    z, z_info = z[:, :self.z_dim], z[:, self.z_dim:]
                    affine_list.append(self.info_proj_linear(z_info))

            # Chunk z: zs[0] seeds the feature map, zs[1:] condition the blocks.
            zs = torch.split(z, self.chunk_size, 1)
            z = zs[0]
            if self.g_cond_mtd != "W/O":
                if shared_label is None:
                    shared_label = self.shared(label)
                affine_list.append(shared_label)
            if len(affine_list) == 0:
                affines = [item for item in zs[1:]]
            else:
                # Each block sees [info/class embedding(s), its own z-chunk].
                affines = [torch.cat(affine_list + [item], 1) for item in zs[1:]]

            act = self.linear0(z)
            act = act.view(-1, self.in_dims[0], self.bottom, self.bottom)
            # counter advances only on GenBlocks; attention blocks take no affine.
            counter = 0
            for index, blocklist in enumerate(self.blocks):
                for block in blocklist:
                    if isinstance(block, ops.SelfAttention):
                        act = block(act)
                    else:
                        act = block(act, affines[counter])
                        counter += 1

            act = self.bn4(act)
            act = self.activation(act)
            act = self.conv2d5(act)
            out = self.tanh(act)
        return out
class DiscOptBlock(nn.Module):
    # First ("optimized") discriminator block: conv -> act -> conv -> avgpool
    # on the main path, avgpool -> 1x1 conv on the shortcut. BatchNorm layers
    # exist only when spectral norm is disabled.
    def __init__(self, in_channels, out_channels, apply_d_sn, MODULES):
        super(DiscOptBlock, self).__init__()
        self.apply_d_sn = apply_d_sn

        # 1x1 shortcut projection plus two 3x3 convs for the main path.
        self.conv2d0 = MODULES.d_conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)
        self.conv2d1 = MODULES.d_conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
        self.conv2d2 = MODULES.d_conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)

        if not apply_d_sn:
            self.bn0 = MODULES.d_bn(in_features=in_channels)
            self.bn1 = MODULES.d_bn(in_features=out_channels)

        self.activation = MODULES.d_act_fn
        self.average_pooling = nn.AvgPool2d(2)

    def forward(self, x):
        # Main path: conv -> (BN) -> act -> conv -> downsample.
        h = self.conv2d1(x)
        if not self.apply_d_sn:
            h = self.bn1(h)
        h = self.conv2d2(self.activation(h))
        h = self.average_pooling(h)

        # Shortcut path: downsample -> (BN) -> 1x1 conv, then residual sum.
        shortcut = self.average_pooling(x)
        if not self.apply_d_sn:
            shortcut = self.bn0(shortcut)
        shortcut = self.conv2d0(shortcut)
        return h + shortcut
class DiscBlock(nn.Module):
    # Pre-activation discriminator residual block: (BN) -> act -> conv twice
    # on the main path, with optional average-pool downsampling. The shortcut
    # needs a 1x1 conv whenever the channel count changes or we downsample.
    def __init__(self, in_channels, out_channels, apply_d_sn, MODULES, downsample=True):
        super(DiscBlock, self).__init__()
        self.apply_d_sn = apply_d_sn
        self.downsample = downsample

        self.activation = MODULES.d_act_fn

        self.ch_mismatch = in_channels != out_channels

        if self.ch_mismatch or downsample:
            self.conv2d0 = MODULES.d_conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)
            if not apply_d_sn:
                self.bn0 = MODULES.d_bn(in_features=in_channels)

        self.conv2d1 = MODULES.d_conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
        self.conv2d2 = MODULES.d_conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)

        if not apply_d_sn:
            self.bn1 = MODULES.d_bn(in_features=in_channels)
            self.bn2 = MODULES.d_bn(in_features=out_channels)

        self.average_pooling = nn.AvgPool2d(2)

    def forward(self, x):
        # Main path: (BN) -> act -> conv, twice, then optional downsample.
        h = x
        if not self.apply_d_sn:
            h = self.bn1(h)
        h = self.conv2d1(self.activation(h))
        if not self.apply_d_sn:
            h = self.bn2(h)
        h = self.conv2d2(self.activation(h))
        if self.downsample:
            h = self.average_pooling(h)

        # Shortcut path: identity unless channels change or we downsample.
        shortcut = x
        if self.downsample or self.ch_mismatch:
            if not self.apply_d_sn:
                shortcut = self.bn0(shortcut)
            shortcut = self.conv2d0(shortcut)
            if self.downsample:
                shortcut = self.average_pooling(shortcut)
        return h + shortcut
class Discriminator(nn.Module):
    """BigGAN discriminator (big_resnet variant).

    A tower of DiscOptBlock/DiscBlock (optionally with self-attention), global
    sum pooling, and a set of heads selected by d_cond_mtd: adversarial logit
    (linear1), class conditioning (AC/PD/MD/MH/2C/D2DCE), optional TAC/ADC
    auxiliary classifiers, and an optional InfoGAN Q head.
    """
    def __init__(self, img_size, d_conv_dim, apply_d_sn, apply_attn, attn_d_loc, d_cond_mtd, aux_cls_type, d_embed_dim, normalize_d_embed,
                 num_classes, d_init, d_depth, mixed_precision, MODULES, MODEL):
        # NOTE(review): d_depth is accepted but never read in this class —
        # presumably kept for config compatibility with other backbones; confirm.
        super(Discriminator, self).__init__()
        # Per-resolution channel plans; index 0 is the RGB input (3 channels).
        d_in_dims_collection = {
            "32": [3] + [d_conv_dim * 2, d_conv_dim * 2, d_conv_dim * 2],
            "64": [3] + [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8],
            "128": [3] + [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 16],
            "256": [3] + [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 8, d_conv_dim * 16],
            "512": [3] + [d_conv_dim, d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 8, d_conv_dim * 16]
        }

        d_out_dims_collection = {
            "32": [d_conv_dim * 2, d_conv_dim * 2, d_conv_dim * 2, d_conv_dim * 2],
            "64": [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 16],
            "128": [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 16, d_conv_dim * 16],
            "256": [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 8, d_conv_dim * 16, d_conv_dim * 16],
            "512":
            [d_conv_dim, d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 8, d_conv_dim * 16, d_conv_dim * 16]
        }

        # Whether each block downsamples; the final block never does.
        d_down = {
            "32": [True, True, False, False],
            "64": [True, True, True, True, False],
            "128": [True, True, True, True, True, False],
            "256": [True, True, True, True, True, True, False],
            "512": [True, True, True, True, True, True, True, False]
        }

        self.d_cond_mtd = d_cond_mtd
        self.aux_cls_type = aux_cls_type
        self.normalize_d_embed = normalize_d_embed
        self.num_classes = num_classes
        self.mixed_precision = mixed_precision
        self.in_dims = d_in_dims_collection[str(img_size)]
        self.out_dims = d_out_dims_collection[str(img_size)]
        self.MODEL = MODEL
        down = d_down[str(img_size)]

        # Build the tower: an "optimized" entry block, then residual blocks,
        # with optional self-attention after the indices in attn_d_loc (1-based).
        self.blocks = []
        for index in range(len(self.in_dims)):
            if index == 0:
                self.blocks += [[
                    DiscOptBlock(in_channels=self.in_dims[index], out_channels=self.out_dims[index], apply_d_sn=apply_d_sn, MODULES=MODULES)
                ]]
            else:
                self.blocks += [[
                    DiscBlock(in_channels=self.in_dims[index],
                              out_channels=self.out_dims[index],
                              apply_d_sn=apply_d_sn,
                              MODULES=MODULES,
                              downsample=down[index])
                ]]

            if index + 1 in attn_d_loc and apply_attn:
                self.blocks += [[ops.SelfAttention(self.out_dims[index], is_generator=False, MODULES=MODULES)]]

        self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])

        self.activation = MODULES.d_act_fn

        # linear layer for adversarial training
        if self.d_cond_mtd == "MH":
            # Multi-hinge: one real/fake logit plus one logit per class.
            self.linear1 = MODULES.d_linear(in_features=self.out_dims[-1], out_features=1 + num_classes, bias=True)
        elif self.d_cond_mtd == "MD":
            self.linear1 = MODULES.d_linear(in_features=self.out_dims[-1], out_features=num_classes, bias=True)
        else:
            self.linear1 = MODULES.d_linear(in_features=self.out_dims[-1], out_features=1, bias=True)

        # double num_classes for Auxiliary Discriminative Classifier
        # (ADC distinguishes real-class-c from fake-class-c; see forward).
        if self.aux_cls_type == "ADC":
            num_classes = num_classes * 2

        # linear and embedding layers for discriminator conditioning
        if self.d_cond_mtd == "AC":
            self.linear2 = MODULES.d_linear(in_features=self.out_dims[-1], out_features=num_classes, bias=False)
        elif self.d_cond_mtd == "PD":
            self.embedding = MODULES.d_embedding(num_classes, self.out_dims[-1])
        elif self.d_cond_mtd in ["2C", "D2DCE"]:
            self.linear2 = MODULES.d_linear(in_features=self.out_dims[-1], out_features=d_embed_dim, bias=True)
            self.embedding = MODULES.d_embedding(num_classes, d_embed_dim)

        # linear and embedding layers for evolved classifier-based GAN
        if self.aux_cls_type == "TAC":
            if self.d_cond_mtd == "AC":
                self.linear_mi = MODULES.d_linear(in_features=self.out_dims[-1], out_features=num_classes, bias=False)
            elif self.d_cond_mtd in ["2C", "D2DCE"]:
                self.linear_mi = MODULES.d_linear(in_features=self.out_dims[-1], out_features=d_embed_dim, bias=True)
                self.embedding_mi = MODULES.d_embedding(num_classes, d_embed_dim)
            else:
                raise NotImplementedError

        # Q head network for infoGAN
        if self.MODEL.info_type in ["discrete", "both"]:
            out_features = self.MODEL.info_num_discrete_c*self.MODEL.info_dim_discrete_c
            self.info_discrete_linear = MODULES.d_linear(in_features=self.out_dims[-1], out_features=out_features, bias=False)
        if self.MODEL.info_type in ["continuous", "both"]:
            out_features = self.MODEL.info_num_conti_c
            self.info_conti_mu_linear = MODULES.d_linear(in_features=self.out_dims[-1], out_features=out_features, bias=False)
            self.info_conti_var_linear = MODULES.d_linear(in_features=self.out_dims[-1], out_features=out_features, bias=False)

        if d_init:
            # NOTE(review): passes the bound method `self.modules` (not its call
            # result); ops.init_weights presumably invokes it internally — confirm.
            ops.init_weights(self.modules, d_init)

    def forward(self, x, label, eval=False, adc_fake=False):
        """Return a dict of discriminator outputs for images x and labels.

        eval: disables autocast even when mixed_precision is on (shadows the
        `eval` builtin inside this method). adc_fake: marks samples as fake for
        the ADC label remapping below.
        """
        # The ternary selects the context manager for the whole block.
        with torch.cuda.amp.autocast() if self.mixed_precision and not eval else misc.dummy_context_mgr() as mp:
            embed, proxy, cls_output = None, None, None
            mi_embed, mi_proxy, mi_cls_output = None, None, None
            info_discrete_c_logits, info_conti_mu, info_conti_var = None, None, None
            h = x
            for index, blocklist in enumerate(self.blocks):
                for block in blocklist:
                    h = block(h)
            bottom_h, bottom_w = h.shape[2], h.shape[3]
            h = self.activation(h)
            # Global sum pooling over the spatial dimensions.
            h = torch.sum(h, dim=[2, 3])

            # adversarial training
            adv_output = torch.squeeze(self.linear1(h))

            # make class labels odd (for fake) or even (for real) for ADC
            if self.aux_cls_type == "ADC":
                if adc_fake:
                    label = label*2 + 1
                else:
                    label = label*2

            # forward pass through InfoGAN Q head
            # (dividing by bottom_h*bottom_w turns the sum pool into a mean pool)
            if self.MODEL.info_type in ["discrete", "both"]:
                info_discrete_c_logits = self.info_discrete_linear(h/(bottom_h*bottom_w))
            if self.MODEL.info_type in ["continuous", "both"]:
                info_conti_mu = self.info_conti_mu_linear(h/(bottom_h*bottom_w))
                # exp keeps the predicted variance strictly positive.
                info_conti_var = torch.exp(self.info_conti_var_linear(h/(bottom_h*bottom_w)))

            # class conditioning
            if self.d_cond_mtd == "AC":
                if self.normalize_d_embed:
                    # NOTE(review): rebinding W is a no-op — F.normalize returns
                    # a new tensor, so linear2's weights are left unchanged.
                    # Confirm whether in-place normalization was intended.
                    for W in self.linear2.parameters():
                        W = F.normalize(W, dim=1)
                    h = F.normalize(h, dim=1)
                cls_output = self.linear2(h)
            elif self.d_cond_mtd == "PD":
                # Projection discriminator: add <class embedding, h> to the logit.
                adv_output = adv_output + torch.sum(torch.mul(self.embedding(label), h), 1)
            elif self.d_cond_mtd in ["2C", "D2DCE"]:
                embed = self.linear2(h)
                proxy = self.embedding(label)
                if self.normalize_d_embed:
                    embed = F.normalize(embed, dim=1)
                    proxy = F.normalize(proxy, dim=1)
            elif self.d_cond_mtd == "MD":
                # Select each sample's own-class logit.
                idx = torch.LongTensor(range(label.size(0))).to(label.device)
                adv_output = adv_output[idx, label]
            elif self.d_cond_mtd in ["W/O", "MH"]:
                pass
            else:
                raise NotImplementedError

            # extra conditioning for TACGAN and ADCGAN
            if self.aux_cls_type == "TAC":
                if self.d_cond_mtd == "AC":
                    if self.normalize_d_embed:
                        # NOTE(review): same no-op rebinding as above — confirm.
                        for W in self.linear_mi.parameters():
                            W = F.normalize(W, dim=1)
                    mi_cls_output = self.linear_mi(h)
                elif self.d_cond_mtd in ["2C", "D2DCE"]:
                    mi_embed = self.linear_mi(h)
                    mi_proxy = self.embedding_mi(label)
                    if self.normalize_d_embed:
                        mi_embed = F.normalize(mi_embed, dim=1)
                        mi_proxy = F.normalize(mi_proxy, dim=1)
        # Unused heads remain None in the returned dict.
        return {
            "h": h,
            "adv_output": adv_output,
            "embed": embed,
            "proxy": proxy,
            "cls_output": cls_output,
            "label": label,
            "mi_embed": mi_embed,
            "mi_proxy": mi_proxy,
            "mi_cls_output": mi_cls_output,
            "info_discrete_c_logits": info_discrete_c_logits,
            "info_conti_mu": info_conti_mu,
            "info_conti_var": info_conti_var
        }