# PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN
# The MIT License (MIT)
# See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details

# models/resnet.py

import torch
import torch.nn as nn
import torch.nn.functional as F

import utils.ops as ops
import utils.misc as misc


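# GenBlock: residual up-sampling block used by the Generator below. The main path is
# (conditional) batch norm -> activation -> 2x nearest-neighbor upsample -> 3x3 conv,
# applied twice; the shortcut is a 1x1 conv on the upsampled input.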
class GenBlock(nn.Module):
    def __init__(self, in_channels, out_channels, g_cond_mtd, g_info_injection, affine_input_dim, MODULES):
        super(GenBlock, self).__init__()
        self.g_cond_mtd = g_cond_mtd
        self.g_info_injection = g_info_injection

        if self.g_cond_mtd == "W/O" and self.g_info_injection in ["N/A", "concat"]:
            self.bn1 = MODULES.g_bn(in_features=in_channels)
            self.bn2 = MODULES.g_bn(in_features=out_channels)
        elif self.g_cond_mtd == "cBN" or self.g_info_injection == "cBN":
            self.bn1 = MODULES.g_bn(affine_input_dim, in_channels, MODULES)
            self.bn2 = MODULES.g_bn(affine_input_dim, out_channels, MODULES)
        else:
            raise NotImplementedError

        self.activation = MODULES.g_act_fn
        self.conv2d0 = MODULES.g_conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)
        self.conv2d1 = MODULES.g_conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
        self.conv2d2 = MODULES.g_conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x, affine):
        x0 = x
        if self.g_cond_mtd == "W/O" and self.g_info_injection in ["N/A", "concat"]:
            x = self.bn1(x)
        elif self.g_cond_mtd == "cBN" or self.g_info_injection == "cBN":
            x = self.bn1(x, affine)
        else:
            raise NotImplementedError
        x = self.activation(x)
        x = F.interpolate(x, scale_factor=2, mode="nearest")
        x = self.conv2d1(x)

        if self.g_cond_mtd == "W/O" and self.g_info_injection in ["N/A", "concat"]:
            x = self.bn2(x)
        elif self.g_cond_mtd == "cBN" or self.g_info_injection == "cBN":
            x = self.bn2(x, affine)
        else:
            raise NotImplementedError
        x = self.activation(x)
        x = self.conv2d2(x)

        x0 = F.interpolate(x0, scale_factor=2, mode="nearest")
        x0 = self.conv2d0(x0)
        out = x + x0
        return out


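# Generator: maps a latent vector z (optionally mixed with InfoGAN codes) to an image
# through a stack of GenBlocks, with optional self-attention at the resolutions given
# by attn_g_loc; per-resolution channel widths come from the dictionaries below.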
class Generator(nn.Module):
    def __init__(self, z_dim, g_shared_dim, img_size, g_conv_dim, apply_attn, attn_g_loc, g_cond_mtd, num_classes, g_init, g_depth,
                 mixed_precision, MODULES, MODEL):
        super(Generator, self).__init__()
        g_in_dims_collection = {
            "32": [g_conv_dim * 4, g_conv_dim * 4, g_conv_dim * 4],
            "64": [g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2],
            "128": [g_conv_dim * 16, g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2],
            "256": [g_conv_dim * 16, g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2],
            "512": [g_conv_dim * 16, g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2, g_conv_dim]
        }

        g_out_dims_collection = {
            "32": [g_conv_dim * 4, g_conv_dim * 4, g_conv_dim * 4],
            "64": [g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2, g_conv_dim],
            "128": [g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2, g_conv_dim],
            "256": [g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2, g_conv_dim],
            "512": [g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2, g_conv_dim, g_conv_dim]
        }

        bottom_collection = {"32": 4, "64": 4, "128": 4, "256": 4, "512": 4}

        self.z_dim = z_dim
        self.num_classes = num_classes
        self.g_cond_mtd = g_cond_mtd
        self.mixed_precision = mixed_precision
        self.MODEL = MODEL
        self.in_dims = g_in_dims_collection[str(img_size)]
        self.out_dims = g_out_dims_collection[str(img_size)]
        self.bottom = bottom_collection[str(img_size)]
        self.num_blocks = len(self.in_dims)
        self.affine_input_dim = 0

        info_dim = 0
        if self.MODEL.info_type in ["discrete", "both"]:
            info_dim += self.MODEL.info_num_discrete_c*self.MODEL.info_dim_discrete_c
        if self.MODEL.info_type in ["continuous", "both"]:
            info_dim += self.MODEL.info_num_conti_c

        self.g_info_injection = self.MODEL.g_info_injection
        if self.MODEL.info_type != "N/A":
            if self.g_info_injection == "concat":
                self.info_mix_linear = MODULES.g_linear(in_features=self.z_dim + info_dim, out_features=self.z_dim, bias=True)
            elif self.g_info_injection == "cBN":
                self.affine_input_dim += self.z_dim
                self.info_proj_linear = MODULES.g_linear(in_features=info_dim, out_features=self.z_dim, bias=True)

        self.linear0 = MODULES.g_linear(in_features=self.z_dim, out_features=self.in_dims[0] * self.bottom * self.bottom, bias=True)

        if self.g_cond_mtd != "W/O" and self.g_cond_mtd == "cBN":
            self.affine_input_dim += self.num_classes

        self.blocks = []
        for index in range(self.num_blocks):
            self.blocks += [[
                GenBlock(in_channels=self.in_dims[index],
                         out_channels=self.out_dims[index],
                         g_cond_mtd=self.g_cond_mtd,
                         g_info_injection=self.g_info_injection,
                         affine_input_dim=self.affine_input_dim,
                         MODULES=MODULES)
            ]]

            if index + 1 in attn_g_loc and apply_attn:
                self.blocks += [[ops.SelfAttention(self.out_dims[index], is_generator=True, MODULES=MODULES)]]

        self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])

        self.bn4 = ops.batchnorm_2d(in_features=self.out_dims[-1])
        self.activation = MODULES.g_act_fn
        self.conv2d5 = MODULES.g_conv2d(in_channels=self.out_dims[-1], out_channels=3, kernel_size=3, stride=1, padding=1)
        self.tanh = nn.Tanh()

        ops.init_weights(self.modules, g_init)

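    # forward: builds the conditioning ("affine") vector from the one-hot label and/or
    # the projected info code, runs z through linear0 and the GenBlock stack, and
    # finishes with BN -> activation -> 3x3 conv -> tanh to produce an image in [-1, 1].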
    def forward(self, z, label, shared_label=None, eval=False):
        affine_list = []
        if self.g_cond_mtd != "W/O":
            label = F.one_hot(label, num_classes=self.num_classes).to(torch.float32)
        with torch.cuda.amp.autocast() if self.mixed_precision and not eval else misc.dummy_context_mgr() as mp:
            if self.MODEL.info_type != "N/A":
                if self.g_info_injection == "concat":
                    z = self.info_mix_linear(z)
                elif self.g_info_injection == "cBN":
                    z, z_info = z[:, :self.z_dim], z[:, self.z_dim:]
                    affine_list.append(self.info_proj_linear(z_info))

            if self.g_cond_mtd != "W/O":
                affine_list.append(label)
            if len(affine_list) > 0:
                affines = torch.cat(affine_list, 1)
            else:
                affines = None

            act = self.linear0(z)
            act = act.view(-1, self.in_dims[0], self.bottom, self.bottom)
            for index, blocklist in enumerate(self.blocks):
                for block in blocklist:
                    if isinstance(block, ops.SelfAttention):
                        act = block(act)
                    else:
                        act = block(act, affines)

            act = self.bn4(act)
            act = self.activation(act)
            act = self.conv2d5(act)
            out = self.tanh(act)
        return out


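# DiscOptBlock: first discriminator block. The main path is conv -> (BN) -> act ->
# conv -> avg-pool; the shortcut is avg-pool -> (BN) -> 1x1 conv. Batch norm is
# skipped when spectral norm is applied (apply_d_sn=True).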
class DiscOptBlock(nn.Module):
    def __init__(self, in_channels, out_channels, apply_d_sn, MODULES):
        super(DiscOptBlock, self).__init__()
        self.apply_d_sn = apply_d_sn

        self.conv2d0 = MODULES.d_conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)
        self.conv2d1 = MODULES.d_conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
        self.conv2d2 = MODULES.d_conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)

        if not apply_d_sn:
            self.bn0 = MODULES.d_bn(in_features=in_channels)
            self.bn1 = MODULES.d_bn(in_features=out_channels)

        self.activation = MODULES.d_act_fn

        self.average_pooling = nn.AvgPool2d(2)

    def forward(self, x):
        x0 = x
        x = self.conv2d1(x)
        if not self.apply_d_sn:
            x = self.bn1(x)
        x = self.activation(x)

        x = self.conv2d2(x)
        x = self.average_pooling(x)

        x0 = self.average_pooling(x0)
        if not self.apply_d_sn:
            x0 = self.bn0(x0)
        x0 = self.conv2d0(x0)
        out = x + x0
        return out


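# DiscBlock: standard residual discriminator block with optional down-sampling; the
# 1x1 shortcut convolution is created only when the channel count changes or the
# block downsamples, and batch norm is skipped under spectral norm.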
class DiscBlock(nn.Module):
    def __init__(self, in_channels, out_channels, apply_d_sn, MODULES, downsample=True):
        super(DiscBlock, self).__init__()
        self.apply_d_sn = apply_d_sn
        self.downsample = downsample

        self.activation = MODULES.d_act_fn

        self.ch_mismatch = False
        if in_channels != out_channels:
            self.ch_mismatch = True

        if self.ch_mismatch or downsample:
            self.conv2d0 = MODULES.d_conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)
            if not apply_d_sn:
                self.bn0 = MODULES.d_bn(in_features=in_channels)

        self.conv2d1 = MODULES.d_conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
        self.conv2d2 = MODULES.d_conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)

        if not apply_d_sn:
            self.bn1 = MODULES.d_bn(in_features=in_channels)
            self.bn2 = MODULES.d_bn(in_features=out_channels)

        self.average_pooling = nn.AvgPool2d(2)

    def forward(self, x):
        x0 = x
        if not self.apply_d_sn:
            x = self.bn1(x)
        x = self.activation(x)
        x = self.conv2d1(x)

        if not self.apply_d_sn:
            x = self.bn2(x)
        x = self.activation(x)
        x = self.conv2d2(x)
        if self.downsample:
            x = self.average_pooling(x)

        if self.downsample or self.ch_mismatch:
            if not self.apply_d_sn:
                x0 = self.bn0(x0)
            x0 = self.conv2d0(x0)
            if self.downsample:
                x0 = self.average_pooling(x0)
        out = x + x0
        return out


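# Discriminator: ResNet backbone built from DiscOptBlock/DiscBlock (with optional
# self-attention), followed by global sum pooling and the heads required by the chosen
# conditioning method (d_cond_mtd), auxiliary classifier type, and InfoGAN setting.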
class Discriminator(nn.Module):
    def __init__(self, img_size, d_conv_dim, apply_d_sn, apply_attn, attn_d_loc, d_cond_mtd, aux_cls_type, d_embed_dim, normalize_d_embed,
                 num_classes, d_init, d_depth, mixed_precision, MODULES, MODEL):
        super(Discriminator, self).__init__()
        d_in_dims_collection = {
            "32": [3] + [d_conv_dim * 2, d_conv_dim * 2, d_conv_dim * 2],
            "64": [3] + [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8],
            "128": [3] + [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 16],
            "256": [3] + [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 8, d_conv_dim * 16],
            "512": [3] + [d_conv_dim, d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 8, d_conv_dim * 16]
        }

        d_out_dims_collection = {
            "32": [d_conv_dim * 2, d_conv_dim * 2, d_conv_dim * 2, d_conv_dim * 2],
            "64": [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 16],
            "128": [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 16, d_conv_dim * 16],
            "256": [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 8, d_conv_dim * 16, d_conv_dim * 16],
            "512": [d_conv_dim, d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 8, d_conv_dim * 16, d_conv_dim * 16]
        }

        d_down = {
            "32": [True, True, False, False],
            "64": [True, True, True, True, False],
            "128": [True, True, True, True, True, False],
            "256": [True, True, True, True, True, True, False],
            "512": [True, True, True, True, True, True, True, False]
        }

        self.d_cond_mtd = d_cond_mtd
        self.aux_cls_type = aux_cls_type
        self.normalize_d_embed = normalize_d_embed
        self.num_classes = num_classes
        self.mixed_precision = mixed_precision
        self.in_dims = d_in_dims_collection[str(img_size)]
        self.out_dims = d_out_dims_collection[str(img_size)]
        self.MODEL = MODEL
        down = d_down[str(img_size)]

        self.blocks = []
        for index in range(len(self.in_dims)):
            if index == 0:
                self.blocks += [[
                    DiscOptBlock(in_channels=self.in_dims[index], out_channels=self.out_dims[index], apply_d_sn=apply_d_sn, MODULES=MODULES)
                ]]
            else:
                self.blocks += [[
                    DiscBlock(in_channels=self.in_dims[index],
                              out_channels=self.out_dims[index],
                              apply_d_sn=apply_d_sn,
                              MODULES=MODULES,
                              downsample=down[index])
                ]]

            if index + 1 in attn_d_loc and apply_attn:
                self.blocks += [[ops.SelfAttention(self.out_dims[index], is_generator=False, MODULES=MODULES)]]

        self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])

        self.activation = MODULES.d_act_fn

        # linear layer for adversarial training
        if self.d_cond_mtd == "MH":
            self.linear1 = MODULES.d_linear(in_features=self.out_dims[-1], out_features=1 + num_classes, bias=True)
        elif self.d_cond_mtd == "MD":
            self.linear1 = MODULES.d_linear(in_features=self.out_dims[-1], out_features=num_classes, bias=True)
        else:
            self.linear1 = MODULES.d_linear(in_features=self.out_dims[-1], out_features=1, bias=True)

        # double num_classes for Auxiliary Discriminative Classifier
        if self.aux_cls_type == "ADC":
            num_classes = num_classes * 2

        # linear and embedding layers for discriminator conditioning
        if self.d_cond_mtd == "AC":
            self.linear2 = MODULES.d_linear(in_features=self.out_dims[-1], out_features=num_classes, bias=False)
        elif self.d_cond_mtd == "PD":
            self.embedding = MODULES.d_embedding(num_classes, self.out_dims[-1])
        elif self.d_cond_mtd in ["2C", "D2DCE"]:
            self.linear2 = MODULES.d_linear(in_features=self.out_dims[-1], out_features=d_embed_dim, bias=True)
            self.embedding = MODULES.d_embedding(num_classes, d_embed_dim)
        else:
            pass

        # linear and embedding layers for evolved classifier-based GAN
        if self.aux_cls_type == "TAC":
            if self.d_cond_mtd == "AC":
                self.linear_mi = MODULES.d_linear(in_features=self.out_dims[-1], out_features=num_classes, bias=False)
            elif self.d_cond_mtd in ["2C", "D2DCE"]:
                self.linear_mi = MODULES.d_linear(in_features=self.out_dims[-1], out_features=d_embed_dim, bias=True)
                self.embedding_mi = MODULES.d_embedding(num_classes, d_embed_dim)
            else:
                raise NotImplementedError

        # Q head network for infoGAN
        if self.MODEL.info_type in ["discrete", "both"]:
            out_features = self.MODEL.info_num_discrete_c*self.MODEL.info_dim_discrete_c
            self.info_discrete_linear = MODULES.d_linear(in_features=self.out_dims[-1], out_features=out_features, bias=False)
        if self.MODEL.info_type in ["continuous", "both"]:
            out_features = self.MODEL.info_num_conti_c
            self.info_conti_mu_linear = MODULES.d_linear(in_features=self.out_dims[-1], out_features=out_features, bias=False)
            self.info_conti_var_linear = MODULES.d_linear(in_features=self.out_dims[-1], out_features=out_features, bias=False)

        if d_init:
            ops.init_weights(self.modules, d_init)

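    # forward: extracts sum-pooled features h, computes the adversarial logits, and
    # fills in only the conditioning/auxiliary/InfoGAN outputs that apply to the
    # configured method; unused entries of the returned dict stay None.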
    def forward(self, x, label, eval=False, adc_fake=False):
        with torch.cuda.amp.autocast() if self.mixed_precision and not eval else misc.dummy_context_mgr() as mp:
            embed, proxy, cls_output = None, None, None
            mi_embed, mi_proxy, mi_cls_output = None, None, None
            info_discrete_c_logits, info_conti_mu, info_conti_var = None, None, None
            h = x
            for index, blocklist in enumerate(self.blocks):
                for block in blocklist:
                    h = block(h)
            bottom_h, bottom_w = h.shape[2], h.shape[3]
            h = self.activation(h)
            h = torch.sum(h, dim=[2, 3])

            # adversarial training
            adv_output = torch.squeeze(self.linear1(h))

            # make class labels odd (for fake) or even (for real) for ADC
            if self.aux_cls_type == "ADC":
                if adc_fake:
                    label = label*2 + 1
                else:
                    label = label*2

            # forward pass through InfoGAN Q head
            if self.MODEL.info_type in ["discrete", "both"]:
                info_discrete_c_logits = self.info_discrete_linear(h/(bottom_h*bottom_w))
            if self.MODEL.info_type in ["continuous", "both"]:
                info_conti_mu = self.info_conti_mu_linear(h/(bottom_h*bottom_w))
                info_conti_var = torch.exp(self.info_conti_var_linear(h/(bottom_h*bottom_w)))

            # class conditioning
            if self.d_cond_mtd == "AC":
                if self.normalize_d_embed:
                    for W in self.linear2.parameters():
                        W = F.normalize(W, dim=1)
                    h = F.normalize(h, dim=1)
                cls_output = self.linear2(h)
            elif self.d_cond_mtd == "PD":
                adv_output = adv_output + torch.sum(torch.mul(self.embedding(label), h), 1)
            elif self.d_cond_mtd in ["2C", "D2DCE"]:
                embed = self.linear2(h)
                proxy = self.embedding(label)
                if self.normalize_d_embed:
                    embed = F.normalize(embed, dim=1)
                    proxy = F.normalize(proxy, dim=1)
            elif self.d_cond_mtd == "MD":
                idx = torch.LongTensor(range(label.size(0))).to(label.device)
                adv_output = adv_output[idx, label]
            elif self.d_cond_mtd in ["W/O", "MH"]:
                pass
            else:
                raise NotImplementedError

            # extra conditioning for TACGAN and ADCGAN
            if self.aux_cls_type == "TAC":
                if self.d_cond_mtd == "AC":
                    if self.normalize_d_embed:
                        for W in self.linear_mi.parameters():
                            W = F.normalize(W, dim=1)
                    mi_cls_output = self.linear_mi(h)
                elif self.d_cond_mtd in ["2C", "D2DCE"]:
                    mi_embed = self.linear_mi(h)
                    mi_proxy = self.embedding_mi(label)
                    if self.normalize_d_embed:
                        mi_embed = F.normalize(mi_embed, dim=1)
                        mi_proxy = F.normalize(mi_proxy, dim=1)
        return {
            "h": h,
            "adv_output": adv_output,
            "embed": embed,
            "proxy": proxy,
            "cls_output": cls_output,
            "label": label,
            "mi_embed": mi_embed,
            "mi_proxy": mi_proxy,
            "mi_cls_output": mi_cls_output,
            "info_discrete_c_logits": info_discrete_c_logits,
            "info_conti_mu": info_conti_mu,
            "info_conti_var": info_conti_var
        }