Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
POSTECH-CVLab
GitHub Repository: POSTECH-CVLab/PyTorch-StudioGAN
Path: blob/master/src/models/big_resnet_deep_legacy.py
809 views
1
# PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN
2
# The MIT License (MIT)
3
# See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details
4
5
# models/deep_big_resnet.py
6
7
import torch
8
import torch.nn as nn
9
import torch.nn.functional as F
10
11
import utils.ops as ops
12
import utils.misc as misc
13
14
15
class GenBlock(nn.Module):
    """Bottleneck residual generator block in the BigGAN-deep style.

    The block runs a 1x1 -> 3x3 -> 3x3 -> 1x1 convolution stack, each
    preceded by a conditional batch-norm (driven by ``affine``) and the
    generator activation.  The residual shortcut simply drops the surplus
    channels when the block narrows the width, and both paths are
    nearest-neighbor upsampled when ``upsample`` is set.
    """
    def __init__(self, in_channels, out_channels, g_cond_mtd, affine_input_dim, upsample,
                 MODULES, channel_ratio=4):
        super(GenBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.g_cond_mtd = g_cond_mtd
        self.upsample = upsample
        # Bottleneck width: the inner 3x3 convs operate on a reduced channel count.
        self.hidden_channels = self.in_channels // channel_ratio

        # Conditional batch-norms, one per convolution in the stack.
        self.bn1 = MODULES.g_bn(affine_input_dim, self.in_channels, MODULES)
        self.bn2 = MODULES.g_bn(affine_input_dim, self.hidden_channels, MODULES)
        self.bn3 = MODULES.g_bn(affine_input_dim, self.hidden_channels, MODULES)
        self.bn4 = MODULES.g_bn(affine_input_dim, self.hidden_channels, MODULES)

        self.activation = MODULES.g_act_fn

        # 1x1 reduce -> two 3x3 convs at bottleneck width -> 1x1 expand.
        self.conv2d1 = MODULES.g_conv2d(in_channels=self.in_channels, out_channels=self.hidden_channels, kernel_size=1, stride=1, padding=0)
        self.conv2d2 = MODULES.g_conv2d(in_channels=self.hidden_channels,
                                        out_channels=self.hidden_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)
        self.conv2d3 = MODULES.g_conv2d(in_channels=self.hidden_channels,
                                        out_channels=self.hidden_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)
        self.conv2d4 = MODULES.g_conv2d(in_channels=self.hidden_channels,
                                        out_channels=self.out_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)

    def forward(self, x, affine):
        # Shortcut path: keep only the first out_channels maps when narrowing.
        shortcut = x if self.in_channels == self.out_channels else x[:, :self.out_channels]

        # 1x1 reduction.
        h = self.conv2d1(self.activation(self.bn1(x, affine)))

        # First 3x3; upsampling (when enabled) happens before this conv.
        h = self.activation(self.bn2(h, affine))
        if self.upsample:
            h = F.interpolate(h, scale_factor=2, mode="nearest")
        h = self.conv2d2(h)

        # Second 3x3.
        h = self.conv2d3(self.activation(self.bn3(h, affine)))

        # 1x1 expansion back to out_channels.
        h = self.conv2d4(self.activation(self.bn4(h, affine)))

        # Upsample the shortcut to match, then add.
        if self.upsample:
            shortcut = F.interpolate(shortcut, scale_factor=2, mode="nearest")
        return h + shortcut
74
75
76
class Generator(nn.Module):
    """BigGAN-deep style generator (legacy variant).

    Maps a latent ``z`` (optionally concatenated with class embeddings and
    InfoGAN codes) to an RGB image in ``[-1, 1]`` via a stack of ``GenBlock``
    residual blocks, with optional self-attention at configured resolutions.
    """
    def __init__(self, z_dim, g_shared_dim, img_size, g_conv_dim, apply_attn, attn_g_loc, g_cond_mtd, num_classes, g_init, g_depth,
                 mixed_precision, MODULES, MODEL):
        super(Generator, self).__init__()
        # Per-resolution input channel widths for each block stage.
        g_in_dims_collection = {
            "32": [g_conv_dim * 4, g_conv_dim * 4, g_conv_dim * 4],
            "64": [g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2],
            "128": [g_conv_dim * 16, g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2],
            "256": [g_conv_dim * 16, g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2],
            "512": [g_conv_dim * 16, g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2, g_conv_dim]
        }

        # Per-resolution output channel widths, paired with the lists above.
        g_out_dims_collection = {
            "32": [g_conv_dim * 4, g_conv_dim * 4, g_conv_dim * 4],
            "64": [g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2, g_conv_dim],
            "128": [g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2, g_conv_dim],
            "256": [g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2, g_conv_dim],
            "512": [g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2, g_conv_dim, g_conv_dim]
        }

        # Spatial size of the initial feature map produced by linear0.
        bottom_collection = {"32": 4, "64": 4, "128": 4, "256": 4, "512": 4}

        self.z_dim = z_dim
        self.g_shared_dim = g_shared_dim
        self.g_cond_mtd = g_cond_mtd
        self.num_classes = num_classes
        self.mixed_precision = mixed_precision
        self.MODEL = MODEL
        self.in_dims = g_in_dims_collection[str(img_size)]
        self.out_dims = g_out_dims_collection[str(img_size)]
        self.bottom = bottom_collection[str(img_size)]
        self.num_blocks = len(self.in_dims)
        # Width of the conditioning vector fed to conditional BN; grown below
        # as class-embedding / InfoGAN channels are enabled.
        self.affine_input_dim = self.z_dim

        # Total dimensionality of InfoGAN latent codes (discrete + continuous).
        info_dim = 0
        if self.MODEL.info_type in ["discrete", "both"]:
            info_dim += self.MODEL.info_num_discrete_c*self.MODEL.info_dim_discrete_c
        if self.MODEL.info_type in ["continuous", "both"]:
            info_dim += self.MODEL.info_num_conti_c

        if self.MODEL.info_type != "N/A":
            if self.MODEL.g_info_injection == "concat":
                # Codes arrive concatenated to z; mix them back to z_dim.
                self.info_mix_linear = MODULES.g_linear(in_features=self.z_dim + info_dim, out_features=self.z_dim, bias=True)
            elif self.MODEL.g_info_injection == "cBN":
                # Codes are projected and injected through conditional BN.
                self.affine_input_dim += self.g_shared_dim
                self.info_proj_linear = MODULES.g_linear(in_features=info_dim, out_features=self.g_shared_dim, bias=True)

        if self.g_cond_mtd != "W/O":
            # Shared class-embedding table for conditional generation.
            self.affine_input_dim += self.g_shared_dim
            self.shared = ops.embedding(num_embeddings=self.num_classes, embedding_dim=self.g_shared_dim)

        self.linear0 = MODULES.g_linear(in_features=self.affine_input_dim, out_features=self.in_dims[0]*self.bottom*self.bottom, bias=True)


        # Each stage is a list of g_depth GenBlocks; only the last one in a
        # stage upsamples, and only the first changes nothing but identity
        # width (out = in for g_index == 0).
        # NOTE(review): in_channels is fixed to in_dims[index] for every
        # g_index — correct for g_depth == 2 (the BigGAN-deep pairing);
        # confirm before using deeper g_depth settings.
        self.blocks = []
        for index in range(self.num_blocks):
            self.blocks += [[
                GenBlock(in_channels=self.in_dims[index],
                         out_channels=self.in_dims[index] if g_index == 0 else self.out_dims[index],
                         g_cond_mtd=g_cond_mtd,
                         affine_input_dim=self.affine_input_dim,
                         upsample=True if g_index == (g_depth - 1) else False,
                         MODULES=MODULES)
            ] for g_index in range(g_depth)]

            if index + 1 in attn_g_loc and apply_attn:
                self.blocks += [[ops.SelfAttention(self.out_dims[index], is_generator=True, MODULES=MODULES)]]

        self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])

        # Output head: BN -> activation -> 3x3 conv to RGB -> tanh.
        self.bn4 = ops.batchnorm_2d(in_features=self.out_dims[-1])
        self.activation = MODULES.g_act_fn
        self.conv2d5 = MODULES.g_conv2d(in_channels=self.out_dims[-1], out_channels=3, kernel_size=3, stride=1, padding=1)
        self.tanh = nn.Tanh()

        # NOTE(review): passes the bound method `self.modules` (not called);
        # ops.init_weights presumably invokes it — confirm against utils.ops.
        ops.init_weights(self.modules, g_init)

    def forward(self, z, label, shared_label=None, eval=False):
        """Generate images from latent ``z`` and class ``label``.

        ``shared_label`` optionally supplies precomputed class embeddings;
        ``eval`` disables mixed-precision autocast.  Returns a tensor of
        shape (N, 3, img_size, img_size) in [-1, 1].
        """
        affine_list = []
        # Autocast only during mixed-precision training; otherwise a no-op
        # context manager keeps the block shape identical.
        with torch.cuda.amp.autocast() if self.mixed_precision and not eval else misc.dummy_context_mgr() as mp:
            if self.MODEL.info_type != "N/A":
                if self.MODEL.g_info_injection == "concat":
                    z = self.info_mix_linear(z)
                elif self.MODEL.g_info_injection == "cBN":
                    # Split off the InfoGAN codes appended after z_dim.
                    z, z_info = z[:, :self.z_dim], z[:, self.z_dim:]
                    affine_list.append(self.info_proj_linear(z_info))

            if self.g_cond_mtd != "W/O":
                if shared_label is None:
                    shared_label = self.shared(label)
                affine_list.append(shared_label)
            if len(affine_list) > 0:
                # Conditioning vectors are prepended to z; the result feeds
                # both linear0 and every conditional BN (as `affine`).
                z = torch.cat(affine_list + [z], 1)

            affine = z
            act = self.linear0(z)
            act = act.view(-1, self.in_dims[0], self.bottom, self.bottom)
            for index, blocklist in enumerate(self.blocks):
                for block in blocklist:
                    if isinstance(block, ops.SelfAttention):
                        act = block(act)
                    else:
                        act = block(act, affine)

            act = self.bn4(act)
            act = self.activation(act)
            act = self.conv2d5(act)
            out = self.tanh(act)
        return out
185
186
187
class DiscBlock(nn.Module):
    """Bottleneck residual discriminator block in the BigGAN-deep style.

    Four convolutions (1x1 reduce, two 3x3 at bottleneck width, 1x1 expand)
    with pre-activation.  When the block widens the channel count, the
    shortcut concatenates a learned 1x1 projection of the extra channels
    onto the (possibly pooled) input instead of replacing it.
    """
    def __init__(self, in_channels, out_channels, MODULES, downsample=True, channel_ratio=4):
        super(DiscBlock, self).__init__()
        self.downsample = downsample
        # Inner 3x3 convs run at a reduced (bottleneck) width.
        hidden_channels = out_channels // channel_ratio

        self.activation = MODULES.d_act_fn
        self.conv2d1 = MODULES.d_conv2d(in_channels=in_channels, out_channels=hidden_channels, kernel_size=1, stride=1, padding=0)
        self.conv2d2 = MODULES.d_conv2d(in_channels=hidden_channels, out_channels=hidden_channels, kernel_size=3, stride=1, padding=1)
        self.conv2d3 = MODULES.d_conv2d(in_channels=hidden_channels, out_channels=hidden_channels, kernel_size=3, stride=1, padding=1)
        self.conv2d4 = MODULES.d_conv2d(in_channels=hidden_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)

        # A learnable shortcut projection is only needed when widening.
        self.learnable_sc = in_channels != out_channels
        if self.learnable_sc:
            self.conv2d0 = MODULES.d_conv2d(in_channels=in_channels,
                                            out_channels=out_channels - in_channels,
                                            kernel_size=1,
                                            stride=1,
                                            padding=0)

        if self.downsample:
            self.average_pooling = nn.AvgPool2d(2)

    def forward(self, x):
        shortcut = x

        # Pre-activation bottleneck stack.
        h = self.conv2d1(self.activation(x))
        h = self.conv2d2(self.activation(h))
        h = self.conv2d3(self.activation(h))
        h = self.activation(h)

        # Pool before the final 1x1 expansion.
        if self.downsample:
            h = self.average_pooling(h)
        h = self.conv2d4(h)

        # Shortcut: pool, then append projected extra channels if widening.
        if self.downsample:
            shortcut = self.average_pooling(shortcut)
        if self.learnable_sc:
            shortcut = torch.cat([shortcut, self.conv2d0(shortcut)], 1)

        return h + shortcut
230
231
232
class Discriminator(nn.Module):
    """BigGAN-deep style discriminator (legacy variant).

    Produces an adversarial score plus, depending on ``d_cond_mtd`` /
    ``aux_cls_type`` / InfoGAN settings, class logits, embedding/proxy
    pairs, and InfoGAN Q-head outputs.  ``forward`` returns everything in
    a single dict.
    """
    def __init__(self, img_size, d_conv_dim, apply_d_sn, apply_attn, attn_d_loc, d_cond_mtd, aux_cls_type, d_embed_dim, normalize_d_embed,
                 num_classes, d_init, d_depth, mixed_precision, MODULES, MODEL):
        super(Discriminator, self).__init__()
        # Per-resolution input channel widths for each block stage.
        d_in_dims_collection = {
            "32": [d_conv_dim * 4, d_conv_dim * 4, d_conv_dim * 4],
            "64": [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8],
            "128": [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 16],
            "256": [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 8, d_conv_dim * 16],
            "512": [d_conv_dim, d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 8, d_conv_dim * 16]
        }

        # Per-resolution output channel widths, paired with the lists above.
        d_out_dims_collection = {
            "32": [d_conv_dim * 4, d_conv_dim * 4, d_conv_dim * 4],
            "64": [d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 16],
            "128": [d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 16, d_conv_dim * 16],
            "256": [d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 8, d_conv_dim * 16, d_conv_dim * 16],
            "512":
            [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 8, d_conv_dim * 16, d_conv_dim * 16]
        }

        # Whether the FIRST sub-block of each stage downsamples.
        d_down = {
            "32": [True, True, False, False],
            "64": [True, True, True, True, False],
            "128": [True, True, True, True, True, False],
            "256": [True, True, True, True, True, True, False],
            "512": [True, True, True, True, True, True, True, False]
        }

        self.d_cond_mtd = d_cond_mtd
        self.aux_cls_type = aux_cls_type
        self.normalize_d_embed = normalize_d_embed
        self.num_classes = num_classes
        self.mixed_precision = mixed_precision
        self.in_dims = d_in_dims_collection[str(img_size)]
        self.out_dims = d_out_dims_collection[str(img_size)]
        self.MODEL = MODEL
        down = d_down[str(img_size)]

        # Stem conv: RGB -> first stage width.
        self.input_conv = MODULES.d_conv2d(in_channels=3, out_channels=self.in_dims[0], kernel_size=3, stride=1, padding=1)

        # Each stage is d_depth DiscBlocks; only the first one in a stage
        # changes width/resolution, the rest operate at out_dims[index].
        self.blocks = []
        for index in range(len(self.in_dims)):
            self.blocks += [[
                DiscBlock(in_channels=self.in_dims[index] if d_index == 0 else self.out_dims[index],
                          out_channels=self.out_dims[index],
                          MODULES=MODULES,
                          downsample=True if down[index] and d_index == 0 else False)
            ] for d_index in range(d_depth)]

            if (index+1) in attn_d_loc and apply_attn:
                self.blocks += [[ops.SelfAttention(self.out_dims[index], is_generator=False, MODULES=MODULES)]]

        self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])

        self.activation = MODULES.d_act_fn

        # linear layer for adversarial training
        if self.d_cond_mtd == "MH":
            # Multi-hinge: one adversarial logit plus one per class.
            self.linear1 = MODULES.d_linear(in_features=self.out_dims[-1], out_features=1 + num_classes, bias=True)
        elif self.d_cond_mtd == "MD":
            # One logit per class; the label's column is selected in forward.
            self.linear1 = MODULES.d_linear(in_features=self.out_dims[-1], out_features=num_classes, bias=True)
        else:
            self.linear1 = MODULES.d_linear(in_features=self.out_dims[-1], out_features=1, bias=True)

        # double num_classes for Auxiliary Discriminative Classifier
        if self.aux_cls_type == "ADC":
            num_classes = num_classes * 2

        # linear and embedding layers for discriminator conditioning
        if self.d_cond_mtd == "AC":
            self.linear2 = MODULES.d_linear(in_features=self.out_dims[-1], out_features=num_classes, bias=False)
        elif self.d_cond_mtd == "PD":
            self.embedding = MODULES.d_embedding(num_classes, self.out_dims[-1])
        elif self.d_cond_mtd in ["2C", "D2DCE"]:
            self.linear2 = MODULES.d_linear(in_features=self.out_dims[-1], out_features=d_embed_dim, bias=True)
            self.embedding = MODULES.d_embedding(num_classes, d_embed_dim)
        else:
            pass

        # linear and embedding layers for evolved classifier-based GAN
        if self.aux_cls_type == "TAC":
            if self.d_cond_mtd == "AC":
                self.linear_mi = MODULES.d_linear(in_features=self.out_dims[-1], out_features=num_classes, bias=False)
            elif self.d_cond_mtd in ["2C", "D2DCE"]:
                self.linear_mi = MODULES.d_linear(in_features=self.out_dims[-1], out_features=d_embed_dim, bias=True)
                self.embedding_mi = MODULES.d_embedding(num_classes, d_embed_dim)
            else:
                raise NotImplementedError

        # Q head network for infoGAN
        if self.MODEL.info_type in ["discrete", "both"]:
            out_features = self.MODEL.info_num_discrete_c*self.MODEL.info_dim_discrete_c
            self.info_discrete_linear = MODULES.d_linear(in_features=self.out_dims[-1], out_features=out_features, bias=False)
        if self.MODEL.info_type in ["continuous", "both"]:
            out_features = self.MODEL.info_num_conti_c
            self.info_conti_mu_linear = MODULES.d_linear(in_features=self.out_dims[-1], out_features=out_features, bias=False)
            self.info_conti_var_linear = MODULES.d_linear(in_features=self.out_dims[-1], out_features=out_features, bias=False)

        if d_init:
            # NOTE(review): passes the bound method `self.modules` (not
            # called); ops.init_weights presumably invokes it — confirm.
            ops.init_weights(self.modules, d_init)

    def forward(self, x, label, eval=False, adc_fake=False):
        """Score images ``x`` (optionally conditioned on ``label``).

        Returns a dict with pooled features ``h``, ``adv_output``, and the
        conditioning / InfoGAN outputs enabled by the configuration (unused
        entries are None).
        """
        with torch.cuda.amp.autocast() if self.mixed_precision and not eval else misc.dummy_context_mgr() as mp:
            embed, proxy, cls_output = None, None, None
            mi_embed, mi_proxy, mi_cls_output = None, None, None
            info_discrete_c_logits, info_conti_mu, info_conti_var = None, None, None
            h = self.input_conv(x)
            for index, blocklist in enumerate(self.blocks):
                for block in blocklist:
                    h = block(h)
            bottom_h, bottom_w = h.shape[2], h.shape[3]
            h = self.activation(h)
            # Global sum pooling over spatial dims -> (N, out_dims[-1]).
            h = torch.sum(h, dim=[2, 3])

            # adversarial training
            adv_output = torch.squeeze(self.linear1(h))

            # make class labels odd (for fake) or even (for real) for ADC
            if self.aux_cls_type == "ADC":
                if adc_fake:
                    label = label*2 + 1
                else:
                    label = label*2

            # forward pass through InfoGAN Q head
            # (dividing by bottom_h*bottom_w converts the sum-pool to a mean)
            if self.MODEL.info_type in ["discrete", "both"]:
                info_discrete_c_logits = self.info_discrete_linear(h/(bottom_h*bottom_w))
            if self.MODEL.info_type in ["continuous", "both"]:
                info_conti_mu = self.info_conti_mu_linear(h/(bottom_h*bottom_w))
                info_conti_var = torch.exp(self.info_conti_var_linear(h/(bottom_h*bottom_w)))

            # class conditioning
            if self.d_cond_mtd == "AC":
                if self.normalize_d_embed:
                    # NOTE(review): rebinding the loop variable `W` does not
                    # modify the parameters in place — this loop looks like a
                    # no-op; only the normalization of `h` below takes effect.
                    # Confirm intent against upstream StudioGAN.
                    for W in self.linear2.parameters():
                        W = F.normalize(W, dim=1)
                    h = F.normalize(h, dim=1)
                cls_output = self.linear2(h)
            elif self.d_cond_mtd == "PD":
                # Projection discriminator: add <embedding(label), h>.
                adv_output = adv_output + torch.sum(torch.mul(self.embedding(label), h), 1)
            elif self.d_cond_mtd in ["2C", "D2DCE"]:
                embed = self.linear2(h)
                proxy = self.embedding(label)
                if self.normalize_d_embed:
                    embed = F.normalize(embed, dim=1)
                    proxy = F.normalize(proxy, dim=1)
            elif self.d_cond_mtd == "MD":
                # Select each sample's own class column from the logits.
                idx = torch.LongTensor(range(label.size(0))).to(label.device)
                adv_output = adv_output[idx, label]
            elif self.d_cond_mtd in ["W/O", "MH"]:
                pass
            else:
                raise NotImplementedError

            # extra conditioning for TACGAN and ADCGAN
            if self.aux_cls_type == "TAC":
                if self.d_cond_mtd == "AC":
                    if self.normalize_d_embed:
                        # NOTE(review): same apparent no-op as above — `W` is
                        # rebound, parameters are not modified in place.
                        for W in self.linear_mi.parameters():
                            W = F.normalize(W, dim=1)
                    mi_cls_output = self.linear_mi(h)
                elif self.d_cond_mtd in ["2C", "D2DCE"]:
                    mi_embed = self.linear_mi(h)
                    mi_proxy = self.embedding_mi(label)
                    if self.normalize_d_embed:
                        mi_embed = F.normalize(mi_embed, dim=1)
                        mi_proxy = F.normalize(mi_proxy, dim=1)
        return {
            "h": h,
            "adv_output": adv_output,
            "embed": embed,
            "proxy": proxy,
            "cls_output": cls_output,
            "label": label,
            "mi_embed": mi_embed,
            "mi_proxy": mi_proxy,
            "mi_cls_output": mi_cls_output,
            "info_discrete_c_logits": info_discrete_c_logits,
            "info_conti_mu": info_conti_mu,
            "info_conti_var": info_conti_var
        }
414
415