Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
probml
GitHub Repository: probml/pyprobml
Path: blob/master/deprecated/gan/models/logan.py
1192 views
1
import torch
2
from torch import nn
3
from torch import Tensor
4
from typing import Callable
5
import torch.nn.functional as F
6
7
8
class Discriminator(nn.Module):
    """DCGAN-style discriminator mapping an image batch to per-sample real/fake probabilities."""

    def __init__(self, feature_maps: int, image_channels: int) -> None:
        """
        Args:
            feature_maps: Number of feature maps to use
            image_channels: Number of channels of the images from the dataset
        """
        super().__init__()
        blocks = [
            # No batch norm on the first block, per DCGAN convention.
            self._make_disc_block(image_channels, feature_maps, batch_norm=False),
            self._make_disc_block(feature_maps, feature_maps * 2),
            self._make_disc_block(feature_maps * 2, feature_maps * 4),
            self._make_disc_block(feature_maps * 4, feature_maps * 8),
            # Final 4x4 conv collapses the spatial map to a single logit, then sigmoid.
            self._make_disc_block(feature_maps * 8, 1, kernel_size=4, stride=1, padding=0, last_block=True),
        ]
        self.disc = nn.Sequential(*blocks)

    @staticmethod
    def _make_disc_block(
        in_channels: int,
        out_channels: int,
        kernel_size: int = 4,
        stride: int = 2,
        padding: int = 1,
        bias: bool = False,
        batch_norm: bool = True,
        last_block: bool = False,
    ) -> nn.Sequential:
        """Build Conv -> (BatchNorm) -> LeakyReLU, or Conv -> Sigmoid for the final block."""
        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias)
        if last_block:
            return nn.Sequential(conv, nn.Sigmoid())
        norm = nn.BatchNorm2d(out_channels) if batch_norm else nn.Identity()
        return nn.Sequential(conv, norm, nn.LeakyReLU(0.2, inplace=True))

    def forward(self, x: Tensor) -> Tensor:
        """Return a 1-D tensor with one probability-of-real per batch element."""
        return self.disc(x).view(-1, 1).squeeze(1)
51
52
53
class Generator(nn.Module):
    """DCGAN-style generator mapping latent noise (B, latent_dim, 1, 1) to image batches."""

    def __init__(self, latent_dim: int, feature_maps: int, image_channels: int) -> None:
        """
        Args:
            latent_dim: Dimension of the latent space
            feature_maps: Number of feature maps to use
            image_channels: Number of channels of the images from the dataset
        """
        super().__init__()
        self.latent_dim = latent_dim
        blocks = [
            # First block expands the 1x1 latent to a 4x4 spatial map.
            self._make_gen_block(latent_dim, feature_maps * 8, kernel_size=4, stride=1, padding=0),
            self._make_gen_block(feature_maps * 8, feature_maps * 4),
            self._make_gen_block(feature_maps * 4, feature_maps * 2),
            self._make_gen_block(feature_maps * 2, feature_maps),
            # Last block emits image channels squashed to [0, 1] by a sigmoid.
            self._make_gen_block(feature_maps, image_channels, last_block=True),
        ]
        self.gen = nn.Sequential(*blocks)

    @staticmethod
    def _make_gen_block(
        in_channels: int,
        out_channels: int,
        kernel_size: int = 4,
        stride: int = 2,
        padding: int = 1,
        bias: bool = False,
        last_block: bool = False,
    ) -> nn.Sequential:
        """Build ConvTranspose -> BatchNorm -> ReLU, or ConvTranspose -> Sigmoid for the last block."""
        deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias)
        if last_block:
            return nn.Sequential(deconv, nn.Sigmoid())
        return nn.Sequential(deconv, nn.BatchNorm2d(out_channels), nn.ReLU(True))

    def forward(self, noise: Tensor) -> Tensor:
        """Generate a batch of images from latent noise."""
        return self.gen(noise)
97
98
99
def get_noise(real: Tensor, configs: dict) -> Tensor:
    """Draw a batch of standard-normal latent vectors shaped for a conv generator.

    Args:
        real: Batch of real samples; only its batch size and device are used.
        configs: Dict providing the "latent_dim" key.

    Returns:
        Tensor of shape (batch, latent_dim, 1, 1) on the same device as ``real``.
    """
    flat = torch.randn(len(real), configs["latent_dim"], device=real.device)
    # Append singleton spatial dims so the noise can feed a ConvTranspose2d stack.
    return flat.view(*flat.shape, 1, 1)
105
106
107
def get_sample(discriminator: Callable, generator: Callable, real: Tensor, configs: dict) -> Tensor:
    """Generate fakes via LOGAN-style latent optimisation.

    Takes one gradient-ascent step on the latent noise in the direction that
    increases the discriminator's score, then regenerates from the updated
    noise.

    Args:
        discriminator: Callable mapping images to real/fake scores.
        generator: Callable mapping latent noise to images.
        real: Batch of real samples (used only for batch size / device).
        configs: Dict holding "latent_dim" (consumed by get_noise) and
            "alpha", the latent step size.

    Returns:
        Batch of fake images generated from the optimised noise.
    """
    noise = get_noise(real, configs)
    noise_t = noise.clone()
    noise_t.requires_grad_(True)
    # Freshly created tensors have .grad is None; this guard only matters if
    # autograd has already populated a gradient on noise_t.
    if noise_t.grad is not None:
        noise_t.grad.zero_()
    fake = generator(noise_t)
    fake_pred = discriminator(fake)
    # Backprop a vector of ones so every batch element contributes d(pred)/d(noise).
    # NOTE(review): this backward also accumulates gradients into the generator
    # and discriminator parameters — presumably the caller zeroes those before
    # its own optimiser step; verify against the training loop.
    fake_pred.backward(torch.ones_like(fake_pred).to(noise.device))
    # Gradient ASCENT on the latent: move noise toward a higher discriminator score.
    noise_t = noise_t + configs["alpha"] * noise_t.grad.data
    fake = generator(noise_t)
    return fake
119
120
121
def instance_noise(configs: dict, epoch_num: int, real: Tensor):
    """Optionally add (possibly annealed) Gaussian instance noise to a batch.

    When configs["loss_params"]["instance_noise"] is truthy, Gaussian noise is
    added to ``real``. If configs["instance_noise_params"]["gamma"] is set, the
    noise level is annealed as gamma ** epoch_num * noise_level; otherwise the
    configured noise_level is used unchanged.

    Args:
        configs: Config dict with "loss_params" and "instance_noise_params".
        epoch_num: Current epoch, used as the annealing exponent.
        real: Batch the noise is added to.

    Returns:
        ``real`` with noise added, or the original tensor when disabled.
    """
    if not configs["loss_params"]["instance_noise"]:
        return real
    params = configs["instance_noise_params"]
    level = params["noise_level"]
    if params["gamma"] is not None:
        # Exponential decay of the noise magnitude over training.
        level = params["gamma"] ** epoch_num * level
    return real + level * torch.randn_like(real)
131
132
133
def top_k(configs: dict, epoch_num: int, preds: Tensor):
    """Keep only the k largest predictions (top-k GAN training).

    Args:
        configs: Config dict; configs["loss_params"]["top_k"] enables the
            feature (when truthy) and provides the base k.
        epoch_num: Current epoch, used as the decay exponent.
        preds: 1-D tensor of discriminator predictions.

    Returns:
        The k largest values of ``preds`` (sorted descending) when enabled,
        otherwise ``preds`` unchanged.
    """
    top_k_cfg = configs["loss_params"]["top_k"]
    if top_k_cfg:
        # The original nested check `if top_k is not None` was always true here
        # (we are inside a truthiness test of the same key), so its else branch
        # was dead code and has been removed.
        # NOTE(review): decaying k by its own value (top_k ** epoch_num * top_k)
        # looks suspicious — a separate decay factor (gamma), mirroring
        # instance_noise, was probably intended. TODO: confirm against the
        # config schema before changing the schedule.
        k = int(top_k_cfg ** epoch_num * top_k_cfg)
        # torch.topk returns a (values, indices) named tuple; downstream BCE in
        # disc_loss/gen_loss needs a plain tensor, so unpack .values here (the
        # original returned the whole tuple, which crashes torch.zeros_like).
        preds = torch.topk(preds, k).values
    return preds
141
142
143
def get_fake_preds(configs: dict, discriminator: Callable, generator: Callable, epoch_num: int, real: Tensor):
    """Score generated fakes: sample, add instance noise, discriminate, top-k filter.

    Args:
        configs: Full config dict; its "loss_params" sub-dict is what
            get_sample receives.
        discriminator: Callable scoring images as real/fake.
        generator: Callable producing images from latent noise.
        epoch_num: Current epoch (drives the noise / top-k schedules).
        real: Batch of real samples (batch size / device reference).

    Returns:
        Discriminator predictions on the fakes, after optional top-k filtering.
    """
    fakes = get_sample(discriminator, generator, real, configs["loss_params"])
    noisy_fakes = instance_noise(configs, epoch_num, fakes)
    preds = discriminator(noisy_fakes)
    return top_k(configs, epoch_num, preds)
150
151
152
def disc_loss(configs: dict, discriminator: Callable, generator: Callable, epoch_num: int, real: Tensor) -> Tensor:
    """Discriminator BCE loss: real samples labelled 1, generated samples labelled 0.

    Args:
        configs: Full config dict (loss / noise / top-k settings).
        discriminator: Callable scoring images as real/fake.
        generator: Callable producing images from latent noise.
        epoch_num: Current epoch (drives the noise / top-k schedules).
        real: Batch of real samples.

    Returns:
        Sum of the real-batch and fake-batch BCE losses.
    """
    # Real branch: discriminator should output 1 on (optionally noised) real data.
    noisy_real = instance_noise(configs, epoch_num, real)
    pred_real = discriminator(noisy_real)
    loss_real = F.binary_cross_entropy(pred_real, torch.ones_like(pred_real))

    # Fake branch: discriminator should output 0 on generated data.
    # Note: the noised real batch is passed on, matching the original flow.
    pred_fake = get_fake_preds(configs, discriminator, generator, epoch_num, noisy_real)
    loss_fake = F.binary_cross_entropy(pred_fake, torch.zeros_like(pred_fake))

    return loss_real + loss_fake
167
168
169
def gen_loss(configs: dict, discriminator: Callable, generator: Callable, epoch_num: int, real: Tensor) -> Tensor:
    """Generator BCE loss: fakes should be scored as real (label 1).

    Args:
        configs: Full config dict (loss / noise / top-k settings).
        discriminator: Callable scoring images as real/fake.
        generator: Callable producing images from latent noise.
        epoch_num: Current epoch (drives the noise / top-k schedules).
        real: Batch of real samples (batch size / device reference).

    Returns:
        BCE loss pushing discriminator outputs on fakes toward 1.
    """
    pred_fake = get_fake_preds(configs, discriminator, generator, epoch_num, real)
    return F.binary_cross_entropy(pred_fake, torch.ones_like(pred_fake))
176
177