GitHub Repository: POSTECH-CVLab/PyTorch-StudioGAN
Path: blob/master/src/utils/losses.py
# PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN
# The MIT License (MIT)
# See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details

# src/utils/losses.py


from torch.nn import DataParallel
from torch import autograd
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.nn.functional as F
import numpy as np

from utils.style_ops import conv2d_gradfix
import utils.ops as ops
class GatherLayer(torch.autograd.Function):
    """
    This class is copied from
    https://github.com/open-mmlab/OpenSelfSup/blob/master/openselfsup/models/utils/gather_layer.py
    Gather tensors from all processes, supporting backward propagation.
    """
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        output = [torch.zeros_like(input) for _ in range(dist.get_world_size())]
        dist.all_gather(output, input)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        input, = ctx.saved_tensors
        grad_out = torch.zeros_like(input)
        grad_out[:] = grads[dist.get_rank()]
        return grad_out
class CrossEntropyLoss(torch.nn.Module):
    def __init__(self):
        super(CrossEntropyLoss, self).__init__()
        self.ce_loss = torch.nn.CrossEntropyLoss()

    def forward(self, cls_output, label, **_):
        return self.ce_loss(cls_output, label).mean()
class ConditionalContrastiveLoss(torch.nn.Module):
    def __init__(self, num_classes, temperature, master_rank, DDP):
        super(ConditionalContrastiveLoss, self).__init__()
        self.num_classes = num_classes
        self.temperature = temperature
        self.master_rank = master_rank
        self.DDP = DDP
        self.calculate_similarity_matrix = self._calculate_similarity_matrix()
        self.cosine_similarity = torch.nn.CosineSimilarity(dim=-1)

    def _make_neg_removal_mask(self, labels):
        labels = labels.detach().cpu().numpy()
        n_samples = labels.shape[0]
        mask_multi, target = np.zeros([self.num_classes, n_samples]), 1.0
        for c in range(self.num_classes):
            c_indices = np.where(labels == c)
            mask_multi[c, c_indices] = target
        return torch.tensor(mask_multi).type(torch.long).to(self.master_rank)

    def _calculate_similarity_matrix(self):
        return self._cosine_similarity_matrix

    def _remove_diag(self, M):
        h, w = M.shape
        assert h == w, "h and w should be same"
        mask = np.ones((h, w)) - np.eye(h)
        mask = torch.from_numpy(mask)
        mask = (mask).type(torch.bool).to(self.master_rank)
        return M[mask].view(h, -1)

    def _cosine_similarity_matrix(self, x, y):
        v = self.cosine_similarity(x.unsqueeze(1), y.unsqueeze(0))
        return v

    def forward(self, embed, proxy, label, **_):
        # if training a GAN through DDP, gather embeddings, proxies, and labels from all processes
        if self.DDP:
            embed = torch.cat(GatherLayer.apply(embed), dim=0)
            proxy = torch.cat(GatherLayer.apply(proxy), dim=0)
            label = torch.cat(GatherLayer.apply(label), dim=0)

        sim_matrix = self.calculate_similarity_matrix(embed, embed)
        sim_matrix = torch.exp(self._remove_diag(sim_matrix) / self.temperature)
        neg_removal_mask = self._remove_diag(self._make_neg_removal_mask(label)[label])
        sim_pos_only = neg_removal_mask * sim_matrix

        emb2proxy = torch.exp(self.cosine_similarity(embed, proxy) / self.temperature)

        numerator = emb2proxy + sim_pos_only.sum(dim=1)
        denominator = torch.cat([torch.unsqueeze(emb2proxy, dim=1), sim_matrix], dim=1).sum(dim=1)
        return -torch.log(numerator / denominator).mean()
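# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original file): a minimal way to
# drive the conditional contrastive (2C) loss above in isolation. The tensor
# shapes, the device, and the hyper-parameter values are assumptions for
# demonstration only; the real StudioGAN training loop lives elsewhere in the
# repository.
# ---------------------------------------------------------------------------
def _example_cond_contrastive_usage(device="cpu"):
    cond_loss = ConditionalContrastiveLoss(num_classes=10, temperature=1.0,
                                           master_rank=device, DDP=False)
    embed = torch.randn(8, 128, device=device)          # discriminator embeddings for a batch
    proxy = torch.randn(8, 128, device=device)          # class-proxy embeddings matched to each label
    label = torch.randint(0, 10, (8,), device=device)   # class labels
    return cond_loss(embed, proxy, label)               # scalar contrastive loss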
class Data2DataCrossEntropyLoss(torch.nn.Module):
    def __init__(self, num_classes, temperature, m_p, master_rank, DDP):
        super(Data2DataCrossEntropyLoss, self).__init__()
        self.num_classes = num_classes
        self.temperature = temperature
        self.m_p = m_p
        self.master_rank = master_rank
        self.DDP = DDP
        self.calculate_similarity_matrix = self._calculate_similarity_matrix()
        self.cosine_similarity = torch.nn.CosineSimilarity(dim=-1)

    def _calculate_similarity_matrix(self):
        return self._cosine_similarity_matrix

    def _cosine_similarity_matrix(self, x, y):
        v = self.cosine_similarity(x.unsqueeze(1), y.unsqueeze(0))
        return v

    def make_index_matrix(self, labels):
        labels = labels.detach().cpu().numpy()
        num_samples = labels.shape[0]
        mask_multi, target = np.ones([self.num_classes, num_samples]), 0.0

        for c in range(self.num_classes):
            c_indices = np.where(labels == c)
            mask_multi[c, c_indices] = target
        return torch.tensor(mask_multi).type(torch.long).to(self.master_rank)

    def remove_diag(self, M):
        h, w = M.shape
        assert h == w, "h and w should be same"
        mask = np.ones((h, w)) - np.eye(h)
        mask = torch.from_numpy(mask)
        mask = (mask).type(torch.bool).to(self.master_rank)
        return M[mask].view(h, -1)

    def forward(self, embed, proxy, label, **_):
        # if training a GAN through DDP, gather all data on the master rank
        if self.DDP:
            embed = torch.cat(GatherLayer.apply(embed), dim=0)
            proxy = torch.cat(GatherLayer.apply(proxy), dim=0)
            label = torch.cat(GatherLayer.apply(label), dim=0)

        # calculate similarities between sample embeddings
        sim_matrix = self.calculate_similarity_matrix(embed, embed) + self.m_p - 1
        # remove diagonal terms
        sim_matrix = self.remove_diag(sim_matrix / self.temperature)
        # for numerical stability
        sim_max, _ = torch.max(sim_matrix, dim=1, keepdim=True)
        sim_matrix = F.relu(sim_matrix) - sim_max.detach()

        # calculate similarities between sample embeddings and the corresponding proxies
        smp2proxy = self.cosine_similarity(embed, proxy)
        # build the false-negative removal mask
        removal_fn = self.remove_diag(self.make_index_matrix(label)[label])
        # apply the negative removal to the similarity matrix
        improved_sim_matrix = removal_fn * torch.exp(sim_matrix)

        # compute the positive attraction term
        pos_attr = F.relu((self.m_p - smp2proxy) / self.temperature)
        # compute the negative repulsion term
        neg_repul = torch.log(torch.exp(-pos_attr) + improved_sim_matrix.sum(dim=1))
        # compute the data-to-data cross-entropy criterion
        criterion = pos_attr + neg_repul
        return criterion.mean()
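# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original file): calling the
# data-to-data cross-entropy (D2D-CE) loss above on dummy tensors. The
# temperature, margin m_p, shapes, and device are assumptions for
# demonstration; see the configuration files in the repository for the values
# actually used in training.
# ---------------------------------------------------------------------------
def _example_d2d_ce_usage(device="cpu"):
    d2d_ce = Data2DataCrossEntropyLoss(num_classes=10, temperature=0.5, m_p=0.98,
                                       master_rank=device, DDP=False)
    embed = torch.randn(8, 128, device=device)          # discriminator embeddings
    proxy = torch.randn(8, 128, device=device)          # class-proxy embeddings gathered per label
    label = torch.randint(0, 10, (8,), device=device)   # class labels
    return d2d_ce(embed, proxy, label)                  # scalar D2D-CE loss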
class PathLengthRegularizer:
    def __init__(self, device, pl_decay=0.01, pl_weight=2, pl_no_weight_grad=False):
        self.pl_decay = pl_decay
        self.pl_weight = pl_weight
        self.pl_mean = torch.zeros([], device=device)
        self.pl_no_weight_grad = pl_no_weight_grad

    def cal_pl_reg(self, fake_images, ws):
        # ws refers to the intermediate (style) latent codes of the generator
        # receives new fake_images for the original batch (in the original implementation,
        # the fake_images used for calculating g_loss and pl_loss are generated independently)
        pl_noise = torch.randn_like(fake_images) / np.sqrt(fake_images.shape[2] * fake_images.shape[3])
        with conv2d_gradfix.no_weight_gradients(self.pl_no_weight_grad):
            pl_grads = torch.autograd.grad(outputs=[(fake_images * pl_noise).sum()], inputs=[ws], create_graph=True, only_inputs=True)[0]
        pl_lengths = pl_grads.square().sum(2).mean(1).sqrt()
        pl_mean = self.pl_mean.lerp(pl_lengths.mean(), self.pl_decay)
        self.pl_mean.copy_(pl_mean.detach())
        pl_penalty = (pl_lengths - pl_mean).square()
        loss_Gpl = (pl_penalty * self.pl_weight).mean(0)
        return loss_Gpl
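# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original file): the path-length
# regularizer expects StyleGAN-style intermediate latents ws of shape
# [batch, num_ws, w_dim] and fake images that are a differentiable function of
# ws. The toy linear "generator head" below is a stand-in assumption, not the
# StudioGAN generator.
# ---------------------------------------------------------------------------
def _example_path_length_reg(device="cpu"):
    ws = torch.randn(4, 14, 512, device=device, requires_grad=True)
    toy_head = nn.Linear(14 * 512, 3 * 32 * 32).to(device)
    fake_images = toy_head(ws.flatten(1)).view(4, 3, 32, 32)   # differentiable w.r.t. ws
    plr = PathLengthRegularizer(device=device)
    return plr.cal_pl_reg(fake_images, ws)                     # scalar path-length penalty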
def enable_allreduce(dict_):
    # sum zero-weighted means of every non-label tensor in the dict so that all outputs
    # stay attached to the graph; this keeps DDP gradient all-reduce from stalling on
    # branches that would otherwise be unused
    loss = 0
    for key, value in dict_.items():
        if value is not None and key != "label":
            loss += value.mean() * 0
    return loss
def d_vanilla(d_logit_real, d_logit_fake, DDP):
    d_loss = torch.mean(F.softplus(-d_logit_real)) + torch.mean(F.softplus(d_logit_fake))
    return d_loss


def g_vanilla(d_logit_fake, DDP):
    return torch.mean(F.softplus(-d_logit_fake))


def d_logistic(d_logit_real, d_logit_fake, DDP):
    d_loss = F.softplus(-d_logit_real) + F.softplus(d_logit_fake)
    return d_loss.mean()


def g_logistic(d_logit_fake, DDP):
    # basically same as g_vanilla.
    return F.softplus(-d_logit_fake).mean()


def d_ls(d_logit_real, d_logit_fake, DDP):
    d_loss = 0.5 * (d_logit_real - torch.ones_like(d_logit_real))**2 + 0.5 * (d_logit_fake)**2
    return d_loss.mean()


def g_ls(d_logit_fake, DDP):
    gen_loss = 0.5 * (d_logit_fake - torch.ones_like(d_logit_fake))**2
    return gen_loss.mean()


def d_hinge(d_logit_real, d_logit_fake, DDP):
    return torch.mean(F.relu(1. - d_logit_real)) + torch.mean(F.relu(1. + d_logit_fake))


def g_hinge(d_logit_fake, DDP):
    return -torch.mean(d_logit_fake)


def d_wasserstein(d_logit_real, d_logit_fake, DDP):
    return torch.mean(d_logit_fake - d_logit_real)


def g_wasserstein(d_logit_fake, DDP):
    return -torch.mean(d_logit_fake)
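# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original file): all adversarial
# losses above share the same signature, so a training step can swap them
# freely. The logits below are dummy tensors; in StudioGAN they come from the
# discriminator's "adv_output".
# ---------------------------------------------------------------------------
def _example_adversarial_step(device="cpu"):
    d_logit_real = torch.randn(8, device=device)
    d_logit_fake = torch.randn(8, device=device)
    d_loss = d_hinge(d_logit_real, d_logit_fake, DDP=False)   # discriminator objective
    g_loss = g_hinge(d_logit_fake, DDP=False)                 # generator objective
    return d_loss, g_loss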
def crammer_singer_loss(adv_output, label, DDP, **_):
    # https://github.com/ilyakava/BigGAN-PyTorch/blob/master/train_fns.py
    # Crammer-Singer criterion
    num_real_classes = adv_output.shape[1] - 1
    mask = torch.ones_like(adv_output).to(adv_output.device)
    mask.scatter_(1, label.unsqueeze(-1), 0)
    wrongs = torch.masked_select(adv_output, mask.bool()).reshape(adv_output.shape[0], num_real_classes)
    max_wrong, _ = wrongs.max(1)
    max_wrong = max_wrong.unsqueeze(-1)
    target = adv_output.gather(1, label.unsqueeze(-1))
    return torch.mean(F.relu(1 + max_wrong - target))


def feature_matching_loss(real_embed, fake_embed):
    # https://github.com/ilyakava/BigGAN-PyTorch/blob/master/train_fns.py
    # feature matching criterion
    fm_loss = torch.mean(torch.abs(torch.mean(fake_embed, 0) - torch.mean(real_embed, 0)))
    return fm_loss
def lecam_reg(d_logit_real, d_logit_fake, ema):
    reg = torch.mean(F.relu(d_logit_real - ema.D_fake).pow(2)) + \
          torch.mean(F.relu(ema.D_real - d_logit_fake).pow(2))
    return reg
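# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original file): lecam_reg expects
# an object holding exponential moving averages of the real/fake discriminator
# logits. The tiny holder below is a hypothetical stand-in for StudioGAN's EMA
# tracker, and lecam_lambda is an assumed weighting hyper-parameter.
# ---------------------------------------------------------------------------
def _example_lecam_reg(device="cpu"):
    class _EmaLogits:
        pass
    ema = _EmaLogits()
    ema.D_real = torch.tensor(0.7, device=device)    # EMA of real logits (dummy value)
    ema.D_fake = torch.tensor(-0.3, device=device)   # EMA of fake logits (dummy value)
    d_logit_real = torch.randn(8, device=device)
    d_logit_fake = torch.randn(8, device=device)
    lecam_lambda = 0.01                              # assumed regularization weight
    return d_hinge(d_logit_real, d_logit_fake, DDP=False) + \
        lecam_lambda * lecam_reg(d_logit_real, d_logit_fake, ema)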
def cal_deriv(inputs, outputs, device):
    grads = autograd.grad(outputs=outputs,
                          inputs=inputs,
                          grad_outputs=torch.ones(outputs.size()).to(device),
                          create_graph=True,
                          retain_graph=True,
                          only_inputs=True)[0]
    return grads
def latent_optimise(zs, fake_labels, generator, discriminator, batch_size, lo_rate, lo_steps, lo_alpha, lo_beta, eval,
                    cal_trsp_cost, device):
    # iteratively refine the latent codes zs by following the gradient of the
    # discriminator's adversarial output with respect to zs
    for step in range(lo_steps - 1):
        # randomly choose which samples are updated at this step
        drop_mask = (torch.FloatTensor(batch_size, 1).uniform_() > 1 - lo_rate).to(device)

        zs = autograd.Variable(zs, requires_grad=True)
        fake_images = generator(zs, fake_labels, eval=eval)
        fake_dict = discriminator(fake_images, fake_labels, eval=eval)
        z_grads = cal_deriv(inputs=zs, outputs=fake_dict["adv_output"], device=device)
        z_grads_norm = torch.unsqueeze((z_grads.norm(2, dim=1)**2), dim=1)
        # damp the update by the squared gradient norm and keep zs inside [-1, 1]
        delta_z = lo_alpha * z_grads / (lo_beta + z_grads_norm)
        zs = torch.clamp(zs + drop_mask * delta_z, -1.0, 1.0)

        if cal_trsp_cost:
            if step == 0:
                trsf_cost = (delta_z.norm(2, dim=1)**2).mean()
            else:
                trsf_cost += (delta_z.norm(2, dim=1)**2).mean()
        else:
            trsf_cost = None
    return zs, trsf_cost
def cal_grad_penalty(real_images, real_labels, fake_images, discriminator, device):
    batch_size, c, h, w = real_images.shape
    alpha = torch.rand(batch_size, 1)
    alpha = alpha.expand(batch_size, real_images.nelement() // batch_size).contiguous().view(batch_size, c, h, w)
    alpha = alpha.to(device)

    real_images = real_images.to(device)
    interpolates = alpha * real_images + ((1 - alpha) * fake_images)
    interpolates = interpolates.to(device)
    interpolates = autograd.Variable(interpolates, requires_grad=True)
    fake_dict = discriminator(interpolates, real_labels, eval=False)
    grads = cal_deriv(inputs=interpolates, outputs=fake_dict["adv_output"], device=device)
    grads = grads.view(grads.size(0), -1)

    # the trailing *0 term is a no-op that keeps interpolates attached to the loss graph
    grad_penalty = ((grads.norm(2, dim=1) - 1)**2).mean() + interpolates[:, 0, 0, 0].mean() * 0
    return grad_penalty
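# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original file): the gradient
# penalty helpers expect a discriminator that takes (images, labels, eval=...)
# and returns a dict with an "adv_output" entry, as elsewhere in StudioGAN.
# The toy discriminator and the penalty weight gp_lambda below are assumptions
# for demonstration (a WGAN-GP-style discriminator objective).
# ---------------------------------------------------------------------------
def _example_wgan_gp(device="cpu"):
    class _ToyDisc(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(3 * 32 * 32, 1)

        def forward(self, images, labels, eval=False):
            return {"adv_output": self.fc(images.flatten(1)).squeeze(1)}

    disc = _ToyDisc().to(device)
    real_images = torch.randn(4, 3, 32, 32, device=device)
    fake_images = torch.randn(4, 3, 32, 32, device=device)
    real_labels = torch.randint(0, 10, (4,), device=device)
    d_adv_real = disc(real_images, real_labels)["adv_output"]
    d_adv_fake = disc(fake_images, real_labels)["adv_output"]
    gp_lambda = 10.0                                             # assumed penalty weight
    d_loss = d_wasserstein(d_adv_real, d_adv_fake, DDP=False)
    d_loss += gp_lambda * cal_grad_penalty(real_images, real_labels, fake_images, disc, device)
    return d_loss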
def cal_dra_penalty(real_images, real_labels, discriminator, device):
    batch_size, c, h, w = real_images.shape
    alpha = torch.rand(batch_size, 1, 1, 1)
    alpha = alpha.to(device)

    real_images = real_images.to(device)
    differences = 0.5 * real_images.std() * torch.rand(real_images.size()).to(device)
    interpolates = real_images + (alpha * differences)
    interpolates = interpolates.to(device)
    interpolates = autograd.Variable(interpolates, requires_grad=True)
    fake_dict = discriminator(interpolates, real_labels, eval=False)
    grads = cal_deriv(inputs=interpolates, outputs=fake_dict["adv_output"], device=device)
    grads = grads.view(grads.size(0), -1)

    grad_penalty = ((grads.norm(2, dim=1) - 1)**2).mean() + interpolates[:, 0, 0, 0].mean() * 0
    return grad_penalty
def cal_maxgrad_penalty(real_images, real_labels, fake_images, discriminator, device):
    batch_size, c, h, w = real_images.shape
    alpha = torch.rand(batch_size, 1)
    alpha = alpha.expand(batch_size, real_images.nelement() // batch_size).contiguous().view(batch_size, c, h, w)
    alpha = alpha.to(device)

    real_images = real_images.to(device)
    interpolates = alpha * real_images + ((1 - alpha) * fake_images)
    interpolates = interpolates.to(device)
    interpolates = autograd.Variable(interpolates, requires_grad=True)
    fake_dict = discriminator(interpolates, real_labels, eval=False)
    grads = cal_deriv(inputs=interpolates, outputs=fake_dict["adv_output"], device=device)
    grads = grads.view(grads.size(0), -1)

    maxgrad_penalty = torch.max(grads.norm(2, dim=1)**2) + interpolates[:, 0, 0, 0].mean() * 0
    return maxgrad_penalty
def cal_r1_reg(adv_output, images, device):
    batch_size = images.size(0)
    grad_dout = cal_deriv(inputs=images, outputs=adv_output.sum(), device=device)
    grad_dout2 = grad_dout.pow(2)
    assert (grad_dout2.size() == images.size())
    r1_reg = 0.5 * grad_dout2.contiguous().view(batch_size, -1).sum(1).mean(0) + images[:, 0, 0, 0].mean() * 0
    return r1_reg
def adjust_k(current_k, topk_gamma, inf_k):
    current_k = max(current_k * topk_gamma, inf_k)
    return current_k
def normal_nll_loss(x, mu, var):
    # https://github.com/Natsu6767/InfoGAN-PyTorch/blob/master/utils.py
    # Calculate the negative log-likelihood of a normal distribution.
    # Needs to be minimized in InfoGAN. (Treats Q(c|x) as a factored Gaussian.)
    logli = -0.5 * (var.mul(2 * np.pi) + 1e-6).log() - (x - mu).pow(2).div(var.mul(2.0) + 1e-6)
    nll = -(logli.sum(1).mean())
    return nll
def stylegan_cal_r1_reg(adv_output, images):
    with conv2d_gradfix.no_weight_gradients():
        r1_grads = torch.autograd.grad(outputs=[adv_output.sum()], inputs=[images], create_graph=True, only_inputs=True)[0]
    r1_penalty = r1_grads.square().sum([1, 2, 3]) / 2
    return r1_penalty.mean()