# PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN
# The MIT License (MIT)
# See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details

# models/deep_conv.py

import torch
import torch.nn as nn
import torch.nn.functional as F

import utils.ops as ops
import utils.misc as misc

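# Note (added for readability, not in the original file): MODULES bundles the layer
# factories used below -- g_deconv2d, g_bn, g_act_fn, g_linear, g_conv2d on the
# generator side and d_conv2d, d_bn, d_act_fn, d_linear, d_embedding on the
# discriminator side. In StudioGAN these are assembled elsewhere from the training
# configuration (e.g. with or without spectral normalization; see apply_d_sn), and
# MODEL carries config fields such as info_type and g_info_injection read below.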
class GenBlock(nn.Module):
    def __init__(self, in_channels, out_channels, g_cond_mtd, g_info_injection, affine_input_dim, MODULES):
        super(GenBlock, self).__init__()
        self.g_cond_mtd = g_cond_mtd
        self.g_info_injection = g_info_injection

        self.deconv0 = MODULES.g_deconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=4, stride=2, padding=1)

        if self.g_cond_mtd == "W/O" and self.g_info_injection in ["N/A", "concat"]:
            self.bn0 = MODULES.g_bn(in_features=out_channels)
        elif self.g_cond_mtd == "cBN" or self.g_info_injection == "cBN":
            self.bn0 = MODULES.g_bn(affine_input_dim, out_channels, MODULES)
        else:
            raise NotImplementedError

        self.activation = MODULES.g_act_fn

    def forward(self, x, affine):
        x = self.deconv0(x)
        if self.g_cond_mtd == "W/O" and self.g_info_injection in ["N/A", "concat"]:
            x = self.bn0(x)
        elif self.g_cond_mtd == "cBN" or self.g_info_injection == "cBN":
            x = self.bn0(x, affine)
        out = self.activation(x)
        return out


class Generator(nn.Module):
    def __init__(self, z_dim, g_shared_dim, img_size, g_conv_dim, apply_attn, attn_g_loc, g_cond_mtd, num_classes, g_init, g_depth,
                 mixed_precision, MODULES, MODEL):
        super(Generator, self).__init__()
        self.in_dims = [512, 256, 128]
        self.out_dims = [256, 128, 64]

        self.z_dim = z_dim
        self.num_classes = num_classes
        self.g_cond_mtd = g_cond_mtd
        self.mixed_precision = mixed_precision
        self.MODEL = MODEL
        self.affine_input_dim = 0

        info_dim = 0
        if self.MODEL.info_type in ["discrete", "both"]:
            info_dim += self.MODEL.info_num_discrete_c*self.MODEL.info_dim_discrete_c
        if self.MODEL.info_type in ["continuous", "both"]:
            info_dim += self.MODEL.info_num_conti_c

        self.g_info_injection = self.MODEL.g_info_injection
        if self.MODEL.info_type != "N/A":
            if self.g_info_injection == "concat":
                self.info_mix_linear = MODULES.g_linear(in_features=self.z_dim + info_dim, out_features=self.z_dim, bias=True)
            elif self.g_info_injection == "cBN":
                self.affine_input_dim += self.z_dim
                self.info_proj_linear = MODULES.g_linear(in_features=info_dim, out_features=self.z_dim, bias=True)

        if self.g_cond_mtd != "W/O" and self.g_cond_mtd == "cBN":
            self.affine_input_dim += self.num_classes

        self.linear0 = MODULES.g_linear(in_features=self.z_dim, out_features=self.in_dims[0]*4*4, bias=True)

        self.blocks = []
        for index in range(len(self.in_dims)):
            self.blocks += [[
                GenBlock(in_channels=self.in_dims[index],
                         out_channels=self.out_dims[index],
                         g_cond_mtd=self.g_cond_mtd,
                         g_info_injection=self.g_info_injection,
                         affine_input_dim=self.affine_input_dim,
                         MODULES=MODULES)
            ]]

            if index + 1 in attn_g_loc and apply_attn:
                self.blocks += [[ops.SelfAttention(self.out_dims[index], is_generator=True, MODULES=MODULES)]]

        self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])

        self.conv4 = MODULES.g_conv2d(in_channels=self.out_dims[-1], out_channels=3, kernel_size=3, stride=1, padding=1)
        self.tanh = nn.Tanh()

        ops.init_weights(self.modules, g_init)

    def forward(self, z, label, shared_label=None, eval=False):
        affine_list = []
        if self.g_cond_mtd != "W/O":
            label = F.one_hot(label, num_classes=self.num_classes).to(torch.float32)
        with torch.cuda.amp.autocast() if self.mixed_precision and not eval else misc.dummy_context_mgr() as mp:
            if self.MODEL.info_type != "N/A":
                if self.g_info_injection == "concat":
                    z = self.info_mix_linear(z)
                elif self.g_info_injection == "cBN":
                    z, z_info = z[:, :self.z_dim], z[:, self.z_dim:]
                    affine_list.append(self.info_proj_linear(z_info))

            if self.g_cond_mtd != "W/O":
                affine_list.append(label)
            if len(affine_list) > 0:
                affines = torch.cat(affine_list, 1)
            else:
                affines = None

            act = self.linear0(z)
            act = act.view(-1, self.in_dims[0], 4, 4)
            for index, blocklist in enumerate(self.blocks):
                for block in blocklist:
                    if isinstance(block, ops.SelfAttention):
                        act = block(act)
                    else:
                        act = block(act, affines)

            act = self.conv4(act)
            out = self.tanh(act)
            return out


class DiscBlock(nn.Module):
    def __init__(self, in_channels, out_channels, apply_d_sn, MODULES):
        super(DiscBlock, self).__init__()
        self.apply_d_sn = apply_d_sn

        self.conv0 = MODULES.d_conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
        self.conv1 = MODULES.d_conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=4, stride=2, padding=1)

        if not apply_d_sn:
            self.bn0 = MODULES.d_bn(in_features=out_channels)
            self.bn1 = MODULES.d_bn(in_features=out_channels)

        self.activation = MODULES.d_act_fn

    def forward(self, x):
        x = self.conv0(x)
        if not self.apply_d_sn:
            x = self.bn0(x)
        x = self.activation(x)

        x = self.conv1(x)
        if not self.apply_d_sn:
            x = self.bn1(x)
        out = self.activation(x)
        return out


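# Note (added for readability, not in the original file): each DiscBlock halves the
# spatial resolution with its stride-2 conv1, so the three blocks built by the
# Discriminator below reduce, e.g., a 3x32x32 input to a 256x4x4 map; a final 3x3
# conv lifts it to 512 channels before the spatial sum pooling in forward().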
class Discriminator(nn.Module):
    def __init__(self, img_size, d_conv_dim, apply_d_sn, apply_attn, attn_d_loc, d_cond_mtd, aux_cls_type, d_embed_dim, normalize_d_embed,
                 num_classes, d_init, d_depth, mixed_precision, MODULES, MODEL):
        super(Discriminator, self).__init__()
        self.in_dims = [3] + [64, 128]
        self.out_dims = [64, 128, 256]

        self.apply_d_sn = apply_d_sn
        self.d_cond_mtd = d_cond_mtd
        self.aux_cls_type = aux_cls_type
        self.normalize_d_embed = normalize_d_embed
        self.num_classes = num_classes
        self.mixed_precision = mixed_precision
        self.MODEL = MODEL

        self.blocks = []
        for index in range(len(self.in_dims)):
            self.blocks += [[
                DiscBlock(in_channels=self.in_dims[index], out_channels=self.out_dims[index], apply_d_sn=self.apply_d_sn, MODULES=MODULES)
            ]]

            if index + 1 in attn_d_loc and apply_attn:
                self.blocks += [[ops.SelfAttention(self.out_dims[index], is_generator=False, MODULES=MODULES)]]

        self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])

        self.activation = MODULES.d_act_fn
        self.conv1 = MODULES.d_conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1)

        if not self.apply_d_sn:
            self.bn1 = MODULES.d_bn(in_features=512)

        # linear layer for adversarial training
        if self.d_cond_mtd == "MH":
            self.linear1 = MODULES.d_linear(in_features=512, out_features=1 + num_classes, bias=True)
        elif self.d_cond_mtd == "MD":
            self.linear1 = MODULES.d_linear(in_features=512, out_features=num_classes, bias=True)
        else:
            self.linear1 = MODULES.d_linear(in_features=512, out_features=1, bias=True)

        # double num_classes for Auxiliary Discriminative Classifier
        if self.aux_cls_type == "ADC":
            num_classes = num_classes * 2

        # linear and embedding layers for discriminator conditioning
        if self.d_cond_mtd == "AC":
            self.linear2 = MODULES.d_linear(in_features=512, out_features=num_classes, bias=False)
        elif self.d_cond_mtd == "PD":
            self.embedding = MODULES.d_embedding(num_classes, 512)
        elif self.d_cond_mtd in ["2C", "D2DCE"]:
            self.linear2 = MODULES.d_linear(in_features=512, out_features=d_embed_dim, bias=True)
            self.embedding = MODULES.d_embedding(num_classes, d_embed_dim)
        else:
            pass

        # linear and embedding layers for evolved classifier-based GAN
        if self.aux_cls_type == "TAC":
            if self.d_cond_mtd == "AC":
                self.linear_mi = MODULES.d_linear(in_features=512, out_features=num_classes, bias=False)
            elif self.d_cond_mtd in ["2C", "D2DCE"]:
                self.linear_mi = MODULES.d_linear(in_features=512, out_features=d_embed_dim, bias=True)
                self.embedding_mi = MODULES.d_embedding(num_classes, d_embed_dim)
            else:
                raise NotImplementedError

        # Q head network for infoGAN
        if self.MODEL.info_type in ["discrete", "both"]:
            out_features = self.MODEL.info_num_discrete_c*self.MODEL.info_dim_discrete_c
            self.info_discrete_linear = MODULES.d_linear(in_features=512, out_features=out_features, bias=False)
        if self.MODEL.info_type in ["continuous", "both"]:
            out_features = self.MODEL.info_num_conti_c
            self.info_conti_mu_linear = MODULES.d_linear(in_features=512, out_features=out_features, bias=False)
            self.info_conti_var_linear = MODULES.d_linear(in_features=512, out_features=out_features, bias=False)

        if d_init:
            ops.init_weights(self.modules, d_init)

    def forward(self, x, label, eval=False, adc_fake=False):
        with torch.cuda.amp.autocast() if self.mixed_precision and not eval else misc.dummy_context_mgr() as mp:
            embed, proxy, cls_output = None, None, None
            mi_embed, mi_proxy, mi_cls_output = None, None, None
            info_discrete_c_logits, info_conti_mu, info_conti_var = None, None, None
            h = x
            for index, blocklist in enumerate(self.blocks):
                for block in blocklist:
                    h = block(h)
            h = self.conv1(h)
            if not self.apply_d_sn:
                h = self.bn1(h)
            bottom_h, bottom_w = h.shape[2], h.shape[3]
            h = self.activation(h)
            h = torch.sum(h, dim=[2, 3])

            # adversarial training
            adv_output = torch.squeeze(self.linear1(h))

            # make class labels odd (for fake) or even (for real) for ADC
            if self.aux_cls_type == "ADC":
                if adc_fake:
                    label = label*2 + 1
                else:
                    label = label*2

            # forward pass through InfoGAN Q head
            if self.MODEL.info_type in ["discrete", "both"]:
                info_discrete_c_logits = self.info_discrete_linear(h/(bottom_h*bottom_w))
            if self.MODEL.info_type in ["continuous", "both"]:
                info_conti_mu = self.info_conti_mu_linear(h/(bottom_h*bottom_w))
                info_conti_var = torch.exp(self.info_conti_var_linear(h/(bottom_h*bottom_w)))

            # class conditioning
            if self.d_cond_mtd == "AC":
                if self.normalize_d_embed:
                    for W in self.linear2.parameters():
                        W = F.normalize(W, dim=1)
                    h = F.normalize(h, dim=1)
                cls_output = self.linear2(h)
            elif self.d_cond_mtd == "PD":
                adv_output = adv_output + torch.sum(torch.mul(self.embedding(label), h), 1)
            elif self.d_cond_mtd in ["2C", "D2DCE"]:
                embed = self.linear2(h)
                proxy = self.embedding(label)
                if self.normalize_d_embed:
                    embed = F.normalize(embed, dim=1)
                    proxy = F.normalize(proxy, dim=1)
            elif self.d_cond_mtd == "MD":
                idx = torch.LongTensor(range(label.size(0))).to(label.device)
                adv_output = adv_output[idx, label]
            elif self.d_cond_mtd in ["W/O", "MH"]:
                pass
            else:
                raise NotImplementedError

            # extra conditioning for TACGAN and ADCGAN
            if self.aux_cls_type == "TAC":
                if self.d_cond_mtd == "AC":
                    if self.normalize_d_embed:
                        for W in self.linear_mi.parameters():
                            W = F.normalize(W, dim=1)
                    mi_cls_output = self.linear_mi(h)
                elif self.d_cond_mtd in ["2C", "D2DCE"]:
                    mi_embed = self.linear_mi(h)
                    mi_proxy = self.embedding_mi(label)
                    if self.normalize_d_embed:
                        mi_embed = F.normalize(mi_embed, dim=1)
                        mi_proxy = F.normalize(mi_proxy, dim=1)
            return {
                "h": h,
                "adv_output": adv_output,
                "embed": embed,
                "proxy": proxy,
                "cls_output": cls_output,
                "label": label,
                "mi_embed": mi_embed,
                "mi_proxy": mi_proxy,
                "mi_cls_output": mi_cls_output,
                "info_discrete_c_logits": info_discrete_c_logits,
                "info_conti_mu": info_conti_mu,
                "info_conti_var": info_conti_var
            }
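

# -------------------------------------------------------------------------------------
# Illustrative smoke test (added here as a sketch; not part of the original source).
# It shows how the Generator/Discriminator above might be instantiated for a quick
# shape check when run from inside the repository (so utils.ops / utils.misc import).
# Plain torch.nn layers stand in for the factories MODULES normally supplies, and the
# MODEL namespace is reduced to only the fields this file reads; real training builds
# both from the config system, so treat every value below as an assumption.
# -------------------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    # stand-in layer factories; lambdas adapt the in_features keyword to BatchNorm2d
    MODULES = SimpleNamespace(
        g_deconv2d=nn.ConvTranspose2d,
        g_conv2d=nn.Conv2d,
        g_linear=nn.Linear,
        g_bn=lambda in_features: nn.BatchNorm2d(in_features),
        g_act_fn=nn.ReLU(inplace=True),
        d_conv2d=nn.Conv2d,
        d_linear=nn.Linear,
        d_embedding=nn.Embedding,
        d_bn=lambda in_features: nn.BatchNorm2d(in_features),
        d_act_fn=nn.LeakyReLU(0.1, inplace=True),
    )
    # minimal MODEL config: no InfoGAN codes, no info injection
    MODEL = SimpleNamespace(info_type="N/A", g_info_injection="N/A")

    G = Generator(z_dim=128, g_shared_dim=None, img_size=32, g_conv_dim=64, apply_attn=False,
                  attn_g_loc=[], g_cond_mtd="W/O", num_classes=10, g_init="ortho", g_depth=None,
                  mixed_precision=False, MODULES=MODULES, MODEL=MODEL)
    D = Discriminator(img_size=32, d_conv_dim=64, apply_d_sn=False, apply_attn=False,
                      attn_d_loc=[], d_cond_mtd="W/O", aux_cls_type="W/O", d_embed_dim=256,
                      normalize_d_embed=False, num_classes=10, d_init="ortho", d_depth=None,
                      mixed_precision=False, MODULES=MODULES, MODEL=MODEL)

    z = torch.randn(4, 128)
    fake = G(z, label=None)       # label is unused on the unconditional ("W/O") path
    out = D(fake, label=None)
    print(fake.shape, out["adv_output"].shape)  # expected: (4, 3, 32, 32) and (4,)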