GitHub Repository: POSTECH-CVLab/PyTorch-StudioGAN
Path: blob/master/src/models/big_resnet_deep_studiogan.py

# PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN
# The MIT License (MIT)
# See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details

# src/models/big_resnet_deep_studiogan.py

import torch
import torch.nn as nn
import torch.nn.functional as F

import utils.ops as ops
import utils.misc as misc
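# NOTE: utils.ops and utils.misc are StudioGAN-internal helper modules (they live
# under src/utils in the repository), so this file is meant to be imported from
# inside the src tree.

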
class GenBlock(nn.Module):
    def __init__(self, in_channels, out_channels, g_cond_mtd, affine_input_dim, upsample,
                 MODULES, channel_ratio=4):
        super(GenBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.g_cond_mtd = g_cond_mtd
        self.upsample = upsample
        self.hidden_channels = self.in_channels // channel_ratio

        self.bn1 = MODULES.g_bn(affine_input_dim, self.in_channels, MODULES)
        self.bn2 = MODULES.g_bn(affine_input_dim, self.hidden_channels, MODULES)
        self.bn3 = MODULES.g_bn(affine_input_dim, self.hidden_channels, MODULES)
        self.bn4 = MODULES.g_bn(affine_input_dim, self.hidden_channels, MODULES)

        self.activation = MODULES.g_act_fn
        self.conv2d0 = MODULES.g_conv2d(in_channels=self.in_channels,
                                        out_channels=self.out_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)
        self.conv2d1 = MODULES.g_conv2d(in_channels=self.in_channels,
                                        out_channels=self.hidden_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)
        self.conv2d2 = MODULES.g_conv2d(in_channels=self.hidden_channels,
                                        out_channels=self.hidden_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)
        self.conv2d3 = MODULES.g_conv2d(in_channels=self.hidden_channels,
                                        out_channels=self.hidden_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)
        self.conv2d4 = MODULES.g_conv2d(in_channels=self.hidden_channels,
                                        out_channels=self.out_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)
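
    # forward: pre-activation bottleneck residual path of BigGAN-deep:
    # 1x1 reduce -> 3x3 -> 3x3 -> 1x1 expand, each conv preceded by conditional
    # BN (driven by `affine`) and an activation; the shortcut is a 1x1 conv, and
    # nearest-neighbor upsampling is applied on both branches when requested.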
    def forward(self, x, affine):
        x0 = x
        x = self.bn1(x, affine)
        x = self.conv2d1(self.activation(x))

        x = self.bn2(x, affine)
        x = self.activation(x)
        if self.upsample:
            x = F.interpolate(x, scale_factor=2, mode="nearest")  # upsample
        x = self.conv2d2(x)

        x = self.bn3(x, affine)
        x = self.conv2d3(self.activation(x))

        x = self.bn4(x, affine)
        x = self.conv2d4(self.activation(x))

        if self.upsample:
            x0 = F.interpolate(x0, scale_factor=2, mode="nearest")  # upsample
        x0 = self.conv2d0(x0)
        out = x + x0
        return out
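

# Illustrative smoke test for GenBlock; it is NOT part of the original StudioGAN
# file. The real MODULES object is assembled from the training config, so the
# pieces below are hypothetical stand-ins: _UncondBN mimics the (x, affine) call
# signature of MODULES.g_bn with a plain unconditional BatchNorm, and nn.Conv2d
# stands in for MODULES.g_conv2d.
class _UncondBN(nn.Module):
    def __init__(self, affine_input_dim, num_features, MODULES=None):
        super().__init__()
        self.bn = nn.BatchNorm2d(num_features)

    def forward(self, x, affine):
        # ignore the conditioning vector; the real g_bn maps it to per-channel
        # gains and biases
        return self.bn(x)


def _demo_genblock():
    import types
    mock_modules = types.SimpleNamespace(g_bn=_UncondBN, g_act_fn=nn.ReLU(), g_conv2d=nn.Conv2d)
    block = GenBlock(in_channels=64, out_channels=32, g_cond_mtd="W/O",
                     affine_input_dim=80, upsample=True, MODULES=mock_modules)
    out = block(torch.randn(2, 64, 8, 8), torch.randn(2, 80))
    assert out.shape == (2, 32, 16, 16)  # out_channels, spatial size doubled
    return out

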
class Generator(nn.Module):
    def __init__(self, z_dim, g_shared_dim, img_size, g_conv_dim, apply_attn, attn_g_loc, g_cond_mtd, num_classes, g_init, g_depth,
                 mixed_precision, MODULES, MODEL):
        super(Generator, self).__init__()
        g_in_dims_collection = {
            "32": [g_conv_dim * 4, g_conv_dim * 4, g_conv_dim * 4],
            "64": [g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2],
            "128": [g_conv_dim * 16, g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2],
            "256": [g_conv_dim * 16, g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2],
            "512": [g_conv_dim * 16, g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2, g_conv_dim]
        }

        g_out_dims_collection = {
            "32": [g_conv_dim * 4, g_conv_dim * 4, g_conv_dim * 4],
            "64": [g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2, g_conv_dim],
            "128": [g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2, g_conv_dim],
            "256": [g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2, g_conv_dim],
            "512": [g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2, g_conv_dim, g_conv_dim]
        }

        bottom_collection = {"32": 4, "64": 4, "128": 4, "256": 4, "512": 4}
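
        # Keys are target image resolutions; each list entry is the channel width
        # of one generator stage (multiples of g_conv_dim, following the
        # BigGAN-deep width schedule), and `bottom` is the spatial size of the
        # initial feature map produced by linear0.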

        self.z_dim = z_dim
        self.g_shared_dim = g_shared_dim
        self.g_cond_mtd = g_cond_mtd
        self.num_classes = num_classes
        self.mixed_precision = mixed_precision
        self.MODEL = MODEL
        self.in_dims = g_in_dims_collection[str(img_size)]
        self.out_dims = g_out_dims_collection[str(img_size)]
        self.bottom = bottom_collection[str(img_size)]
        self.num_blocks = len(self.in_dims)
        self.affine_input_dim = self.z_dim

        info_dim = 0
        if self.MODEL.info_type in ["discrete", "both"]:
            info_dim += self.MODEL.info_num_discrete_c * self.MODEL.info_dim_discrete_c
        if self.MODEL.info_type in ["continuous", "both"]:
            info_dim += self.MODEL.info_num_conti_c

        if self.MODEL.info_type != "N/A":
            if self.MODEL.g_info_injection == "concat":
                self.info_mix_linear = MODULES.g_linear(in_features=self.z_dim + info_dim, out_features=self.z_dim, bias=True)
            elif self.MODEL.g_info_injection == "cBN":
                self.affine_input_dim += self.g_shared_dim
                self.info_proj_linear = MODULES.g_linear(in_features=info_dim, out_features=self.g_shared_dim, bias=True)

        if self.g_cond_mtd != "W/O":
            self.affine_input_dim += self.g_shared_dim
            self.shared = ops.embedding(num_embeddings=self.num_classes, embedding_dim=self.g_shared_dim)
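
        # The conditioning vector fed to every conditional BN starts as z and is
        # extended by a g_shared_dim-wide projection of the InfoGAN code (when
        # g_info_injection == "cBN") and/or by the shared class embedding (when
        # g_cond_mtd != "W/O"), so affine_input_dim grows by g_shared_dim for
        # each enabled branch.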

        self.linear0 = MODULES.g_linear(in_features=self.affine_input_dim, out_features=self.in_dims[0] * self.bottom * self.bottom, bias=True)

        self.blocks = []
        for index in range(self.num_blocks):
            self.blocks += [[
                GenBlock(in_channels=self.in_dims[index],
                         out_channels=self.in_dims[index] if g_index == 0 else self.out_dims[index],
                         g_cond_mtd=g_cond_mtd,
                         affine_input_dim=self.affine_input_dim,
                         upsample=(g_index == g_depth - 1),
                         MODULES=MODULES)
            ] for g_index in range(g_depth)]

            if index + 1 in attn_g_loc and apply_attn:
                self.blocks += [[ops.SelfAttention(self.out_dims[index], is_generator=True, MODULES=MODULES)]]

        self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])

        self.bn4 = ops.batchnorm_2d(in_features=self.out_dims[-1])
        self.activation = MODULES.g_act_fn
        self.conv2d5 = MODULES.g_conv2d(in_channels=self.out_dims[-1], out_channels=3, kernel_size=3, stride=1, padding=1)
        self.tanh = nn.Tanh()

        ops.init_weights(self.modules, g_init)
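
    # forward: build the conditioning vector, project it to a bottom x bottom
    # feature map with linear0, run the GenBlock/SelfAttention stages, and decode
    # to an RGB image in [-1, 1] via BN -> activation -> 3x3 conv -> tanh.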
    def forward(self, z, label, shared_label=None, eval=False):
        affine_list = []
        with torch.cuda.amp.autocast() if self.mixed_precision and not eval else misc.dummy_context_mgr() as mp:
            if self.MODEL.info_type != "N/A":
                if self.MODEL.g_info_injection == "concat":
                    z = self.info_mix_linear(z)
                elif self.MODEL.g_info_injection == "cBN":
                    z, z_info = z[:, :self.z_dim], z[:, self.z_dim:]
                    affine_list.append(self.info_proj_linear(z_info))

            if self.g_cond_mtd != "W/O":
                if shared_label is None:
                    shared_label = self.shared(label)
                affine_list.append(shared_label)
            if len(affine_list) > 0:
                z = torch.cat(affine_list + [z], 1)

            affine = z
            act = self.linear0(z)
            act = act.view(-1, self.in_dims[0], self.bottom, self.bottom)
            for index, blocklist in enumerate(self.blocks):
                for block in blocklist:
                    if isinstance(block, ops.SelfAttention):
                        act = block(act)
                    else:
                        act = block(act, affine)

            act = self.bn4(act)
            act = self.activation(act)
            act = self.conv2d5(act)
            out = self.tanh(act)
        return out
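

# Illustrative sketch of the conditioning pathway in Generator.forward for any
# g_cond_mtd other than "W/O"; it is NOT part of the original file. The shared
# class embedding is concatenated in front of z, exactly as in
# torch.cat(affine_list + [z], 1) above. Dimensions are hypothetical, and plain
# nn.Embedding stands in for ops.embedding.
def _demo_generator_conditioning():
    z_dim, g_shared_dim, num_classes, batch = 120, 128, 10, 4
    z = torch.randn(batch, z_dim)
    shared = nn.Embedding(num_classes, g_shared_dim)
    label = torch.randint(num_classes, (batch,))
    affine = torch.cat([shared(label), z], 1)
    assert affine.shape == (batch, g_shared_dim + z_dim)  # == affine_input_dim
    return affine

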
class DiscBlock(nn.Module):
    def __init__(self, in_channels, out_channels, MODULES, optblock, downsample=True, channel_ratio=4):
        super(DiscBlock, self).__init__()
        self.optblock = optblock
        self.downsample = downsample
        hidden_channels = out_channels // channel_ratio
        self.ch_mismatch = in_channels != out_channels
        if self.optblock:
            assert self.downsample and self.ch_mismatch, "downsample and ch_mismatch should be True."

        self.activation = MODULES.d_act_fn
        self.conv2d1 = MODULES.d_conv2d(in_channels=in_channels,
                                        out_channels=hidden_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)
        self.conv2d2 = MODULES.d_conv2d(in_channels=hidden_channels,
                                        out_channels=hidden_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)
        self.conv2d3 = MODULES.d_conv2d(in_channels=hidden_channels,
                                        out_channels=hidden_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)
        self.conv2d4 = MODULES.d_conv2d(in_channels=hidden_channels,
                                        out_channels=out_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)

        if self.ch_mismatch or self.downsample:
            self.conv2d0 = MODULES.d_conv2d(in_channels=in_channels,
                                            out_channels=out_channels,
                                            kernel_size=1,
                                            stride=1,
                                            padding=0)

        if self.downsample:
            self.average_pooling = nn.AvgPool2d(2)
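
    # forward: pre-activation bottleneck: 1x1 reduce, two 3x3 convs, average
    # pooling (when downsampling), then a 1x1 expand. The shortcut applies
    # conv2d0 whenever channels change or resolution drops; in the optimized
    # first block (optblock) the shortcut pools *before* its 1x1 conv, mirroring
    # the SNGAN/BigGAN input block.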
    def forward(self, x):
        x0 = x
        x = self.conv2d1(self.activation(x))
        x = self.conv2d2(self.activation(x))
        x = self.conv2d3(self.activation(x))
        if self.downsample:
            x = self.average_pooling(x)
        x = self.conv2d4(self.activation(x))

        if self.optblock:
            x0 = self.average_pooling(x0)
            x0 = self.conv2d0(x0)
        else:
            if self.downsample or self.ch_mismatch:
                x0 = self.conv2d0(x0)
            if self.downsample:
                x0 = self.average_pooling(x0)
        out = x + x0
        return out
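

# Illustrative smoke test for DiscBlock; NOT part of the original file. Plain
# nn.Conv2d stands in for the config-dependent MODULES.d_conv2d factory.
def _demo_discblock():
    import types
    mock_modules = types.SimpleNamespace(d_act_fn=nn.ReLU(), d_conv2d=nn.Conv2d)
    block = DiscBlock(in_channels=3, out_channels=64, MODULES=mock_modules,
                      optblock=True, downsample=True)
    out = block(torch.randn(2, 3, 32, 32))
    assert out.shape == (2, 64, 16, 16)  # channels widened, resolution halved
    return out

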
class Discriminator(nn.Module):
    def __init__(self, img_size, d_conv_dim, apply_d_sn, apply_attn, attn_d_loc, d_cond_mtd, aux_cls_type, d_embed_dim, normalize_d_embed,
                 num_classes, d_init, d_depth, mixed_precision, MODULES, MODEL):
        super(Discriminator, self).__init__()
        d_in_dims_collection = {
            "32": [d_conv_dim, d_conv_dim * 4, d_conv_dim * 4],
            "64": [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8],
            "128": [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 16],
            "256": [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 8, d_conv_dim * 16],
            "512": [d_conv_dim, d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 8, d_conv_dim * 16]
        }

        d_out_dims_collection = {
            "32": [d_conv_dim * 4, d_conv_dim * 4, d_conv_dim * 4],
            "64": [d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 16],
            "128": [d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 16, d_conv_dim * 16],
            "256": [d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 8, d_conv_dim * 16, d_conv_dim * 16],
            "512": [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 8, d_conv_dim * 16, d_conv_dim * 16]
        }

        d_down = {
            "32": [True, True, False, False],
            "64": [True, True, True, True, False],
            "128": [True, True, True, True, True, False],
            "256": [True, True, True, True, True, True, False],
            "512": [True, True, True, True, True, True, True, False]
        }
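
        # One downsampling flag per stage: every stage halves the resolution
        # except the last used one; a list may carry unused trailing entries
        # (e.g. "32" defines four flags but only three stages exist).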

        self.d_cond_mtd = d_cond_mtd
        self.aux_cls_type = aux_cls_type
        self.normalize_d_embed = normalize_d_embed
        self.num_classes = num_classes
        self.mixed_precision = mixed_precision
        self.in_dims = d_in_dims_collection[str(img_size)]
        self.out_dims = d_out_dims_collection[str(img_size)]
        self.MODEL = MODEL
        down = d_down[str(img_size)]

        self.input_conv = MODULES.d_conv2d(in_channels=3, out_channels=self.in_dims[0], kernel_size=3, stride=1, padding=1)
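
        # Each stage holds d_depth DiscBlocks: only the first block of a stage
        # downsamples, and only the very first block of the network
        # (index == 0 and d_index == 0) uses the optimized shortcut ordering.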
        self.blocks = []
        for index in range(len(self.in_dims)):
            self.blocks += [[
                DiscBlock(in_channels=self.in_dims[index] if d_index == 0 else self.out_dims[index],
                          out_channels=self.out_dims[index],
                          MODULES=MODULES,
                          optblock=index == 0 and d_index == 0,
                          downsample=(down[index] and d_index == 0))
            ] for d_index in range(d_depth)]

            if (index + 1) in attn_d_loc and apply_attn:
                self.blocks += [[ops.SelfAttention(self.out_dims[index], is_generator=False, MODULES=MODULES)]]

        self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])

        self.activation = MODULES.d_act_fn

        # linear layer for adversarial training
        if self.d_cond_mtd == "MH":
            self.linear1 = MODULES.d_linear(in_features=self.out_dims[-1], out_features=1 + num_classes, bias=True)
        elif self.d_cond_mtd == "MD":
            self.linear1 = MODULES.d_linear(in_features=self.out_dims[-1], out_features=num_classes, bias=True)
        else:
            self.linear1 = MODULES.d_linear(in_features=self.out_dims[-1], out_features=1, bias=True)

        # double num_classes for Auxiliary Discriminative Classifier
        if self.aux_cls_type == "ADC":
            num_classes = num_classes * 2

        # linear and embedding layers for discriminator conditioning
        if self.d_cond_mtd == "AC":
            self.linear2 = MODULES.d_linear(in_features=self.out_dims[-1], out_features=num_classes, bias=False)
        elif self.d_cond_mtd == "PD":
            self.embedding = MODULES.d_embedding(num_classes, self.out_dims[-1])
        elif self.d_cond_mtd in ["2C", "D2DCE"]:
            self.linear2 = MODULES.d_linear(in_features=self.out_dims[-1], out_features=d_embed_dim, bias=True)
            self.embedding = MODULES.d_embedding(num_classes, d_embed_dim)
        else:
            pass

        # linear and embedding layers for evolved classifier-based GAN
        if self.aux_cls_type == "TAC":
            if self.d_cond_mtd == "AC":
                self.linear_mi = MODULES.d_linear(in_features=self.out_dims[-1], out_features=num_classes, bias=False)
            elif self.d_cond_mtd in ["2C", "D2DCE"]:
                self.linear_mi = MODULES.d_linear(in_features=self.out_dims[-1], out_features=d_embed_dim, bias=True)
                self.embedding_mi = MODULES.d_embedding(num_classes, d_embed_dim)
            else:
                raise NotImplementedError

        # Q head network for infoGAN
        if self.MODEL.info_type in ["discrete", "both"]:
            out_features = self.MODEL.info_num_discrete_c * self.MODEL.info_dim_discrete_c
            self.info_discrete_linear = MODULES.d_linear(in_features=self.out_dims[-1], out_features=out_features, bias=False)
        if self.MODEL.info_type in ["continuous", "both"]:
            out_features = self.MODEL.info_num_conti_c
            self.info_conti_mu_linear = MODULES.d_linear(in_features=self.out_dims[-1], out_features=out_features, bias=False)
            self.info_conti_var_linear = MODULES.d_linear(in_features=self.out_dims[-1], out_features=out_features, bias=False)

        if d_init:
            ops.init_weights(self.modules, d_init)
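
    # forward: run the conv trunk, global-sum-pool the activations into a feature
    # vector h, then route h through the adversarial head plus whichever
    # conditioning / InfoGAN heads the config enabled; all heads are returned in
    # a single dict so the loss side can pick what it needs.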
    def forward(self, x, label, eval=False, adc_fake=False):
        with torch.cuda.amp.autocast() if self.mixed_precision and not eval else misc.dummy_context_mgr() as mp:
            embed, proxy, cls_output = None, None, None
            mi_embed, mi_proxy, mi_cls_output = None, None, None
            info_discrete_c_logits, info_conti_mu, info_conti_var = None, None, None
            h = self.input_conv(x)
            for index, blocklist in enumerate(self.blocks):
                for block in blocklist:
                    h = block(h)
            bottom_h, bottom_w = h.shape[2], h.shape[3]
            h = self.activation(h)
            h = torch.sum(h, dim=[2, 3])

            # adversarial training
            adv_output = torch.squeeze(self.linear1(h))

            # make class labels odd (for fake) or even (for real) for ADC
            if self.aux_cls_type == "ADC":
                if adc_fake:
                    label = label * 2 + 1
                else:
                    label = label * 2

            # forward pass through InfoGAN Q head
            if self.MODEL.info_type in ["discrete", "both"]:
                info_discrete_c_logits = self.info_discrete_linear(h / (bottom_h * bottom_w))
            if self.MODEL.info_type in ["continuous", "both"]:
                info_conti_mu = self.info_conti_mu_linear(h / (bottom_h * bottom_w))
                info_conti_var = torch.exp(self.info_conti_var_linear(h / (bottom_h * bottom_w)))

            # class conditioning
            if self.d_cond_mtd == "AC":
                if self.normalize_d_embed:
                    for W in self.linear2.parameters():
                        # normalize the classifier weights in place; a plain
                        # `W = F.normalize(W, dim=1)` would only rebind the loop
                        # variable and leave the parameter untouched
                        W.data = F.normalize(W.data, dim=1)
                    h = F.normalize(h, dim=1)
                cls_output = self.linear2(h)
            elif self.d_cond_mtd == "PD":
                adv_output = adv_output + torch.sum(torch.mul(self.embedding(label), h), 1)
            elif self.d_cond_mtd in ["2C", "D2DCE"]:
                embed = self.linear2(h)
                proxy = self.embedding(label)
                if self.normalize_d_embed:
                    embed = F.normalize(embed, dim=1)
                    proxy = F.normalize(proxy, dim=1)
            elif self.d_cond_mtd == "MD":
                idx = torch.LongTensor(range(label.size(0))).to(label.device)
                adv_output = adv_output[idx, label]
            elif self.d_cond_mtd in ["W/O", "MH"]:
                pass
            else:
                raise NotImplementedError

            # extra conditioning for TACGAN and ADCGAN
            if self.aux_cls_type == "TAC":
                if self.d_cond_mtd == "AC":
                    if self.normalize_d_embed:
                        for W in self.linear_mi.parameters():
                            # same in-place weight normalization as for linear2
                            W.data = F.normalize(W.data, dim=1)
                    mi_cls_output = self.linear_mi(h)
                elif self.d_cond_mtd in ["2C", "D2DCE"]:
                    mi_embed = self.linear_mi(h)
                    mi_proxy = self.embedding_mi(label)
                    if self.normalize_d_embed:
                        mi_embed = F.normalize(mi_embed, dim=1)
                        mi_proxy = F.normalize(mi_proxy, dim=1)
        return {
            "h": h,
            "adv_output": adv_output,
            "embed": embed,
            "proxy": proxy,
            "cls_output": cls_output,
            "label": label,
            "mi_embed": mi_embed,
            "mi_proxy": mi_proxy,
            "mi_cls_output": mi_cls_output,
            "info_discrete_c_logits": info_discrete_c_logits,
            "info_conti_mu": info_conti_mu,
            "info_conti_var": info_conti_var
        }
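

# Hypothetical entry point wiring up the illustrative smoke tests added above; it
# is NOT part of the original StudioGAN source and assumes the file is imported
# from inside the repo's src tree (so the utils.* imports at the top resolve).
if __name__ == "__main__":
    print(_demo_genblock().shape)
    print(_demo_generator_conditioning().shape)
    print(_demo_discblock().shape)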