# Source: GitHub repository ai-forever/sber-swap
# Path: models/networks/loss.py
1
"""
2
Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4
"""
5
6
import torch
7
import torch.nn as nn
8
import torch.nn.functional as F
9
from models.networks.architecture import VGG19
10
11
12
# Defines the GAN loss, which uses either LSGAN or the regular GAN objective.
# When LSGAN is used, it is essentially the same as MSELoss,
# but it abstracts away the need to create a target label tensor
# that has the same size as the input.
16
class GANLoss(nn.Module):
    """Adversarial loss with a selectable objective.

    Supported ``gan_mode`` values:

    * ``'original'`` -- binary cross-entropy on logits,
    * ``'ls'``       -- least-squares (MSE against the target label),
    * ``'hinge'``    -- hinge loss (separate discriminator / generator forms),
    * ``'w'``        -- Wasserstein-style loss (signed mean of the prediction).

    The class hides the creation of target tensors: a single-element constant
    is cached and expanded (without copying) to the shape of the prediction.

    Args:
        gan_mode: one of the mode strings listed above.
        target_real_label: value used as the target for "real" predictions.
        target_fake_label: value used as the target for "fake" predictions.
        tensor: legacy tensor constructor. It is stored as ``self.Tensor`` for
            backward compatibility, but constants are now created on the
            device/dtype of the prediction instead.
        opt: opaque options object, stored as ``self.opt`` and not used here.

    Raises:
        ValueError: if ``gan_mode`` is not one of the supported modes.
    """

    _SUPPORTED_MODES = ('ls', 'original', 'w', 'hinge')

    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0,
                 tensor=torch.FloatTensor, opt=None):
        super(GANLoss, self).__init__()
        if gan_mode not in self._SUPPORTED_MODES:
            raise ValueError('Unexpected gan_mode {}'.format(gan_mode))
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        # Lazily created single-element constants (see _constant_like).
        self.real_label_tensor = None
        self.fake_label_tensor = None
        self.zero_tensor = None
        self.Tensor = tensor
        self.gan_mode = gan_mode
        self.opt = opt

    @staticmethod
    def _constant_like(cached, value, input):
        """Return a 1-element tensor holding ``value`` that matches ``input``.

        ``cached`` is reused when it already lives on the device of ``input``
        and has the right dtype; otherwise a fresh tensor is created. This
        prevents device/dtype mismatches between prediction and target.
        """
        # Non-floating predictions cannot hold fractional labels, so fall back
        # to the default floating dtype for them.
        if input.is_floating_point():
            dtype = input.dtype
        else:
            dtype = torch.get_default_dtype()
        stale = (cached is None
                 or cached.device != input.device
                 or cached.dtype != dtype)
        if stale:
            # torch.full creates a tensor that does not require grad.
            cached = torch.full((1,), float(value), dtype=dtype,
                                device=input.device)
        return cached

    def get_target_tensor(self, input, target_is_real):
        """Return the real/fake label broadcast to the shape of ``input``."""
        if target_is_real:
            self.real_label_tensor = self._constant_like(
                self.real_label_tensor, self.real_label, input)
            return self.real_label_tensor.expand_as(input)
        self.fake_label_tensor = self._constant_like(
            self.fake_label_tensor, self.fake_label, input)
        return self.fake_label_tensor.expand_as(input)

    def get_zero_tensor(self, input):
        """Return a zero tensor broadcast to the shape of ``input``."""
        self.zero_tensor = self._constant_like(self.zero_tensor, 0, input)
        return self.zero_tensor.expand_as(input)

    def loss(self, input, target_is_real, for_discriminator=True):
        """Compute the loss for a single prediction tensor.

        Args:
            input: prediction tensor (discriminator output).
            target_is_real: whether the prediction should be pushed to "real".
            for_discriminator: only consulted in ``'hinge'`` mode, where the
                discriminator and generator use different formulas.

        Returns:
            A scalar tensor with the loss value.
        """
        if self.gan_mode == 'original':  # cross entropy loss
            target_tensor = self.get_target_tensor(input, target_is_real)
            return F.binary_cross_entropy_with_logits(input, target_tensor)
        elif self.gan_mode == 'ls':
            target_tensor = self.get_target_tensor(input, target_is_real)
            return F.mse_loss(input, target_tensor)
        elif self.gan_mode == 'hinge':
            if for_discriminator:
                # Real: penalize predictions below +1; fake: above -1.
                margin = input - 1 if target_is_real else -input - 1
                minval = torch.min(margin, self.get_zero_tensor(input))
                return -torch.mean(minval)
            assert target_is_real, "The generator's hinge loss must be aiming for real"
            return -torch.mean(input)
        else:
            # wgan
            if target_is_real:
                return -input.mean()
            return input.mean()

    def __call__(self, input, target_is_real, for_discriminator=True):
        """Compute the loss for a tensor or a list of predictions.

        ``input`` may be a list of tensors (multiscale discriminator). Each
        list element may itself be a list, in which case only its last entry
        is used. Per-element losses are averaged over the list.

        NOTE: this overrides ``nn.Module.__call__`` directly (there is no
        ``forward``); kept as-is so existing callers behave the same.
        """
        if not isinstance(input, list):
            return self.loss(input, target_is_real, for_discriminator)

        loss = 0
        for pred_i in input:
            if isinstance(pred_i, list):
                pred_i = pred_i[-1]
            loss_tensor = self.loss(pred_i, target_is_real, for_discriminator)
            # Scalar losses are treated as a batch of one.
            bs = 1 if len(loss_tensor.size()) == 0 else loss_tensor.size(0)
            new_loss = torch.mean(loss_tensor.view(bs, -1), dim=1)
            loss += new_loss
        return loss / len(input)
99
100
101
# Perceptual loss that uses a pretrained VGG network
102
class VGGLoss(nn.Module):
    """Perceptual loss: weighted L1 distance between VGG19 features.

    Both inputs are passed through ``self.vgg``; the L1 distance between
    corresponding outputs is scaled by ``self.weights`` and summed.
    """

    def __init__(self, gpu_ids):
        """Build the feature extractor, moving it to CUDA if ``gpu_ids`` is non-empty."""
        super(VGGLoss, self).__init__()
        on_gpu = len(gpu_ids) > 0
        backbone = VGG19()
        self.vgg = backbone.cuda() if on_gpu else backbone
        self.criterion = nn.L1Loss()
        # One weight per feature level returned by the extractor.
        self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]

    def forward(self, x, y):
        """Return the weighted feature-space L1 distance between ``x`` and ``y``.

        If the last two dimensions of ``y`` differ from those of ``x``, ``y``
        is resized to match using ``interpolate`` with its default mode.
        Features of ``y`` are detached, so no gradient flows into ``y``.
        """
        if y.shape[-2:] != x.shape[-2:]:
            y = F.interpolate(y, tuple(x.shape[-2:]))

        # Assumes self.vgg returns an indexable sequence of feature maps
        # (VGG19 is defined in models.networks.architecture) -- TODO confirm.
        feats_x = self.vgg(x)
        feats_y = self.vgg(y)

        total = 0
        for level, (feat_x, feat_y) in enumerate(zip(feats_x, feats_y)):
            level_loss = self.criterion(feat_x, feat_y.detach())
            total = total + self.weights[level] * level_loss
        return total
117
118
119
# KL Divergence loss used in VAE with an image encoder
120
class KLDLoss(nn.Module):
    """KL-divergence term used in a VAE with an image encoder."""

    def forward(self, mu, logvar):
        """Return ``-0.5 * sum(1 + logvar - mu^2 - exp(logvar))`` over all elements."""
        per_element = 1 + logvar - mu.pow(2) - logvar.exp()
        return -0.5 * torch.sum(per_element)
123
124