GitHub Repository: POSTECH-CVLab/PyTorch-StudioGAN
Path: blob/master/src/utils/ops.py
# PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN
# The MIT License (MIT)
# See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details

# src/utils/ops.py

from torch.nn.utils import spectral_norm
from torch.nn import init
import torch
import torch.nn as nn
import numpy as np


class ConditionalBatchNorm2d(nn.Module):
    # https://github.com/voletiv/self-attention-GAN-pytorch
    def __init__(self, in_features, out_features, MODULES):
        super().__init__()
        self.in_features = in_features
        self.bn = batchnorm_2d(out_features, eps=1e-4, momentum=0.1, affine=False)

        self.gain = MODULES.g_linear(in_features=in_features, out_features=out_features, bias=False)
        self.bias = MODULES.g_linear(in_features=in_features, out_features=out_features, bias=False)

    def forward(self, x, y):
        gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
        bias = self.bias(y).view(y.size(0), -1, 1, 1)
        out = self.bn(x)
        return out * gain + bias
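

# Illustrative usage sketch (not part of the upstream StudioGAN file):
# ConditionalBatchNorm2d applies affine-free BatchNorm and then a per-sample
# gain/bias predicted from the conditioning vector y (BigGAN-style conditional BN).
# `_CBNModules` below is a hypothetical stand-in for the MODULES namespace that
# the repo assembles from its model config; only `g_linear` is needed here, and
# the plain `linear` factory defined later in this file is used for it.
def _example_conditional_batchnorm():
    class _CBNModules:
        g_linear = staticmethod(linear)

    cbn = ConditionalBatchNorm2d(in_features=128, out_features=64, MODULES=_CBNModules)
    x = torch.randn(4, 64, 32, 32)   # feature maps
    y = torch.randn(4, 128)          # conditioning embedding (e.g. a class embedding)
    return cbn(x, y)                 # same shape as x: (4, 64, 32, 32)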


class SelfAttention(nn.Module):
    """
    https://github.com/voletiv/self-attention-GAN-pytorch
    MIT License

    Copyright (c) 2019 Vikram Voleti

    Permission is hereby granted, free of charge, to any person obtaining a copy
    of this software and associated documentation files (the "Software"), to deal
    in the Software without restriction, including without limitation the rights
    to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    copies of the Software, and to permit persons to whom the Software is
    furnished to do so, subject to the following conditions:

    The above copyright notice and this permission notice shall be included in all
    copies or substantial portions of the Software.

    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    SOFTWARE.
    """
    def __init__(self, in_channels, is_generator, MODULES):
        super(SelfAttention, self).__init__()
        self.in_channels = in_channels

        if is_generator:
            self.conv1x1_theta = MODULES.g_conv2d(in_channels=in_channels, out_channels=in_channels // 8, kernel_size=1,
                                                  stride=1, padding=0, bias=False)
            self.conv1x1_phi = MODULES.g_conv2d(in_channels=in_channels, out_channels=in_channels // 8, kernel_size=1,
                                                stride=1, padding=0, bias=False)
            self.conv1x1_g = MODULES.g_conv2d(in_channels=in_channels, out_channels=in_channels // 2, kernel_size=1,
                                              stride=1, padding=0, bias=False)
            self.conv1x1_attn = MODULES.g_conv2d(in_channels=in_channels // 2, out_channels=in_channels, kernel_size=1,
                                                 stride=1, padding=0, bias=False)
        else:
            self.conv1x1_theta = MODULES.d_conv2d(in_channels=in_channels, out_channels=in_channels // 8, kernel_size=1,
                                                  stride=1, padding=0, bias=False)
            self.conv1x1_phi = MODULES.d_conv2d(in_channels=in_channels, out_channels=in_channels // 8, kernel_size=1,
                                                stride=1, padding=0, bias=False)
            self.conv1x1_g = MODULES.d_conv2d(in_channels=in_channels, out_channels=in_channels // 2, kernel_size=1,
                                              stride=1, padding=0, bias=False)
            self.conv1x1_attn = MODULES.d_conv2d(in_channels=in_channels // 2, out_channels=in_channels, kernel_size=1,
                                                 stride=1, padding=0, bias=False)

        self.maxpool = nn.MaxPool2d(2, stride=2, padding=0)
        self.softmax = nn.Softmax(dim=-1)
        self.sigma = nn.Parameter(torch.zeros(1), requires_grad=True)

    def forward(self, x):
        _, ch, h, w = x.size()
        # Theta path
        theta = self.conv1x1_theta(x)
        theta = theta.view(-1, ch // 8, h * w)
        # Phi path
        phi = self.conv1x1_phi(x)
        phi = self.maxpool(phi)
        phi = phi.view(-1, ch // 8, h * w // 4)
        # Attn map
        attn = torch.bmm(theta.permute(0, 2, 1), phi)
        attn = self.softmax(attn)
        # g path
        g = self.conv1x1_g(x)
        g = self.maxpool(g)
        g = g.view(-1, ch // 2, h * w // 4)
        # Attn_g
        attn_g = torch.bmm(g, attn.permute(0, 2, 1))
        attn_g = attn_g.view(-1, ch // 2, h, w)
        attn_g = self.conv1x1_attn(attn_g)
        return x + self.sigma * attn_g
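

# Illustrative usage sketch (not part of the upstream StudioGAN file):
# SAGAN-style self-attention: theta/phi/g are 1x1 projections, phi and g are
# downsampled 2x by max pooling, and the attention-weighted g is mixed back into
# the input with the learnable scalar `sigma` (initialized to 0, so the block
# starts as an identity mapping). `_SAModules` is a hypothetical stand-in for the
# MODULES namespace; only `g_conv2d` is needed, backed by the plain `conv2d`
# factory defined later in this file.
def _example_self_attention():
    class _SAModules:
        g_conv2d = staticmethod(conv2d)

    attn = SelfAttention(in_channels=64, is_generator=True, MODULES=_SAModules)
    x = torch.randn(2, 64, 16, 16)   # spatial size must be even (2x2 max pooling)
    return attn(x)                   # same shape as x: (2, 64, 16, 16)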


class LeCamEMA(object):
    # Simple wrapper that applies EMA to losses.
    # https://github.com/google/lecam-gan/blob/master/third_party/utils.py
    def __init__(self, init=7777, decay=0.9, start_iter=0):
        self.G_loss = init
        self.D_loss_real = init
        self.D_loss_fake = init
        self.D_real = init
        self.D_fake = init
        self.decay = decay
        self.start_itr = start_iter

    def update(self, cur, mode, itr):
        if itr < self.start_itr:
            decay = 0.0
        else:
            decay = self.decay
        if mode == "G_loss":
            self.G_loss = self.G_loss*decay + cur*(1 - decay)
        elif mode == "D_loss_real":
            self.D_loss_real = self.D_loss_real*decay + cur*(1 - decay)
        elif mode == "D_loss_fake":
            self.D_loss_fake = self.D_loss_fake*decay + cur*(1 - decay)
        elif mode == "D_real":
            self.D_real = self.D_real*decay + cur*(1 - decay)
        elif mode == "D_fake":
            self.D_fake = self.D_fake*decay + cur*(1 - decay)
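

# Illustrative usage sketch (not part of the upstream StudioGAN file):
# LeCamEMA keeps exponential moving averages of discriminator outputs, which the
# LeCam regularizer (Tseng et al., 2021) uses to anchor D's predictions on real
# and fake samples. Before `start_iter` the decay is 0, so the tracker simply
# copies the latest value instead of averaging; afterwards it follows
# ema <- ema * decay + cur * (1 - decay).
def _example_lecam_ema():
    ema = LeCamEMA(init=0.0, decay=0.99, start_iter=1000)
    for itr, d_real in enumerate([0.8, 0.7, 0.9]):
        ema.update(d_real, mode="D_real", itr=itr)   # itr < start_iter: raw copy
    return ema.D_real                                # 0.9 in this toy example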


def init_weights(modules, initialize):
    for module in modules():
        if (isinstance(module, nn.Conv2d) or isinstance(module, nn.ConvTranspose2d) or isinstance(module, nn.Linear)):
            if initialize == "ortho":
                init.orthogonal_(module.weight)
                if module.bias is not None:
                    module.bias.data.fill_(0.)
            elif initialize == "N02":
                init.normal_(module.weight, 0, 0.02)
                if module.bias is not None:
                    module.bias.data.fill_(0.)
            elif initialize in ["glorot", "xavier"]:
                init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    module.bias.data.fill_(0.)
            else:
                pass
        elif isinstance(module, nn.Embedding):
            if initialize == "ortho":
                init.orthogonal_(module.weight)
            elif initialize == "N02":
                init.normal_(module.weight, 0, 0.02)
            elif initialize in ["glorot", "xavier"]:
                init.xavier_uniform_(module.weight)
            else:
                pass
        else:
            pass
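

# Illustrative usage sketch (not part of the upstream StudioGAN file):
# init_weights expects `modules` to be a callable that yields modules (e.g. the
# bound method `model.modules`, not `model.modules()`), and applies orthogonal,
# N(0, 0.02), or Xavier initialization to conv, linear, and embedding layers.
def _example_init_weights():
    model = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
                          nn.Flatten(), nn.Linear(16 * 32 * 32, 10))
    init_weights(model.modules, initialize="ortho")
    return model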


def conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
    return nn.Conv2d(in_channels=in_channels,
                     out_channels=out_channels,
                     kernel_size=kernel_size,
                     stride=stride,
                     padding=padding,
                     dilation=dilation,
                     groups=groups,
                     bias=bias)


def deconv2d(in_channels, out_channels, kernel_size, stride=2, padding=0, dilation=1, groups=1, bias=True):
    return nn.ConvTranspose2d(in_channels=in_channels,
                              out_channels=out_channels,
                              kernel_size=kernel_size,
                              stride=stride,
                              padding=padding,
                              dilation=dilation,
                              groups=groups,
                              bias=bias)


def linear(in_features, out_features, bias=True):
    return nn.Linear(in_features=in_features, out_features=out_features, bias=bias)


def embedding(num_embeddings, embedding_dim):
    return nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)


def snconv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
    return spectral_norm(nn.Conv2d(in_channels=in_channels,
                                   out_channels=out_channels,
                                   kernel_size=kernel_size,
                                   stride=stride,
                                   padding=padding,
                                   dilation=dilation,
                                   groups=groups,
                                   bias=bias),
                         eps=1e-6)


def sndeconv2d(in_channels, out_channels, kernel_size, stride=2, padding=0, dilation=1, groups=1, bias=True):
    return spectral_norm(nn.ConvTranspose2d(in_channels=in_channels,
                                            out_channels=out_channels,
                                            kernel_size=kernel_size,
                                            stride=stride,
                                            padding=padding,
                                            dilation=dilation,
                                            groups=groups,
                                            bias=bias),
                         eps=1e-6)


def snlinear(in_features, out_features, bias=True):
    return spectral_norm(nn.Linear(in_features=in_features, out_features=out_features, bias=bias), eps=1e-6)


def sn_embedding(num_embeddings, embedding_dim):
    return spectral_norm(nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim), eps=1e-6)
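

# Illustrative usage sketch (not part of the upstream StudioGAN file):
# The sn* factories wrap the corresponding layers with spectral normalization
# (torch.nn.utils.spectral_norm), which rescales the weight by an estimate of its
# largest singular value on every forward pass. The MODULES hooks used by the
# classes above (g_conv2d, g_linear, ...) are presumably how the repo swaps
# between the plain and sn* variants per model config.
def _example_spectral_norm_layers():
    plain = conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
    sn = snconv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
    x = torch.randn(2, 3, 32, 32)
    return plain(x).shape, sn(x).shape   # both: torch.Size([2, 16, 32, 32])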


def batchnorm_2d(in_features, eps=1e-4, momentum=0.1, affine=True):
    return nn.BatchNorm2d(in_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=True)


def conv3x3(in_planes, out_planes, stride=1):
    "3x3 convolution with padding"
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


def adjust_learning_rate(optimizer, lr_org, epoch, total_epoch, dataset):
    """Step-decays the learning rate: 10x at 50% and again at 75% of total_epoch
    for CIFAR, and 10x every 75 (or 30) epochs for Tiny_ImageNet/ImageNet."""
    if dataset in ["CIFAR10", "CIFAR100"]:
        lr = lr_org * (0.1 ** (epoch // (total_epoch * 0.5))) * (0.1 ** (epoch // (total_epoch * 0.75)))
    elif dataset in ["Tiny_ImageNet", "ImageNet"]:
        if total_epoch == 300:
            lr = lr_org * (0.1 ** (epoch // 75))
        else:
            lr = lr_org * (0.1 ** (epoch // 30))

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
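

# Illustrative usage sketch (not part of the upstream StudioGAN file):
# for CIFAR, the LR drops 10x at 50% and again at 75% of total_epoch.
def _example_adjust_learning_rate():
    model = nn.Linear(10, 10)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    adjust_learning_rate(opt, lr_org=0.1, epoch=80, total_epoch=100, dataset="CIFAR10")
    return opt.param_groups[0]["lr"]   # 0.1 * 0.1 * 0.1 = 0.001 after both drops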


def quantize_images(x):
    # Map images from [-1, 1] to uint8 in [0, 255], returned as a NCHW numpy array.
    x = (x + 1)/2
    x = (255.0*x + 0.5).clamp(0.0, 255.0)
    x = x.detach().cpu().numpy().astype(np.uint8)
    return x


def resize_images(x, resizer, ToTensor, mean, std, device):
    # Apply `resizer` to each image in HWC layout, stack the results on `device`,
    # and normalize with the given mean/std.
    x = x.transpose((0, 2, 3, 1))
    x = list(map(lambda x: ToTensor(resizer(x)), list(x)))
    x = torch.stack(x, 0).to(device)
    x = (x/255.0 - mean)/std
    return x
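

# Illustrative usage sketch (not part of the upstream StudioGAN file):
# quantize_images and resize_images are meant to be chained when preparing
# generated samples for evaluation: generator outputs in [-1, 1] are quantized to
# uint8, resized by an external `resizer`, and re-normalized. The identity
# resizer and the `to_tensor` helper below are hypothetical placeholders, used
# only to show the expected shapes and dtypes; the mean/std are the standard
# ImageNet normalization statistics.
def _example_quantize_and_resize(device="cpu"):
    fake = torch.tanh(torch.randn(2, 3, 32, 32))        # pretend generator output in [-1, 1]
    arr = quantize_images(fake)                         # (2, 3, 32, 32) uint8 numpy array
    identity_resizer = lambda img: img                  # stand-in for a real resizer
    to_tensor = lambda img: torch.from_numpy(img).permute(2, 0, 1).float()
    mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
    return resize_images(arr, identity_resizer, to_tensor, mean, std, device)   # (2, 3, 32, 32) float tensor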