Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
probml
GitHub Repository: probml/pyprobml
Path: blob/master/deprecated/gan/models/dcgan.py
1192 views
1
import torch
2
from torch import nn
3
from torch import Tensor
4
from typing import Callable
5
import torch.nn.functional as F
6
from torch.nn.utils import spectral_norm
7
8
9
class Discriminator(nn.Module):
    """DCGAN discriminator producing one sigmoid score per scored position.

    Four down-sampling blocks (Conv2d -> BatchNorm2d / Identity -> LeakyReLU(0.2))
    double the channel count each time, starting from ``feature_maps``. A final
    4x4, stride-1, unpadded convolution maps to a single channel and is followed
    by a sigmoid. For 64x64 inputs the final map is 1x1, i.e. one score per image.
    """

    def __init__(self, feature_maps: int, image_channels: int) -> None:
        """
        Args:
            feature_maps: Number of feature maps to use
            image_channels: Number of channels of the images from the dataset
        """
        super().__init__()
        # Channel widths: image -> fm -> 2fm -> 4fm -> 8fm.
        widths = [image_channels] + [feature_maps * mult for mult in (1, 2, 4, 8)]
        # The very first block skips batch norm, as in the DCGAN architecture.
        hidden = [
            self._make_disc_block(c_in, c_out, batch_norm=(idx > 0))
            for idx, (c_in, c_out) in enumerate(zip(widths, widths[1:]))
        ]
        head = self._make_disc_block(
            feature_maps * 8, 1, kernel_size=4, stride=1, padding=0, last_block=True
        )
        self.disc = nn.Sequential(*hidden, head)

    @staticmethod
    def _make_disc_block(
        in_channels: int,
        out_channels: int,
        kernel_size: int = 4,
        stride: int = 2,
        padding: int = 1,
        bias: bool = False,
        batch_norm: bool = True,
        last_block: bool = False,
    ) -> nn.Sequential:
        """Build one conv block; the last block ends in a sigmoid instead of LeakyReLU."""
        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias)
        if last_block:
            return nn.Sequential(conv, nn.Sigmoid())
        norm = nn.BatchNorm2d(out_channels) if batch_norm else nn.Identity()
        return nn.Sequential(conv, norm, nn.LeakyReLU(0.2, inplace=True))

    def forward(self, x: Tensor) -> Tensor:
        """Return the discriminator scores flattened to a 1-D tensor."""
        scores = self.disc(x)
        return scores.view(-1)
52
53
54
class SpectralDiscriminator(nn.Module):
    """DCGAN discriminator whose convolutions are wrapped in spectral normalization.

    The layer layout matches the plain DCGAN discriminator. There are four
    stride-2 blocks (spectral-normed Conv2d -> BatchNorm2d / Identity ->
    LeakyReLU(0.2)), then a spectral-normed 4x4 stride-1 convolution to one
    channel and a sigmoid.
    """

    def __init__(self, feature_maps: int, image_channels: int) -> None:
        """
        Args:
            feature_maps: Number of feature maps to use
            image_channels: Number of channels of the images from the dataset
        """
        super().__init__()
        layers = []
        in_ch = image_channels
        for mult in (1, 2, 4, 8):
            out_ch = feature_maps * mult
            # Only the first block (mult == 1) omits batch norm.
            layers.append(self._make_disc_block(in_ch, out_ch, batch_norm=mult != 1))
            in_ch = out_ch
        layers.append(
            self._make_disc_block(in_ch, 1, kernel_size=4, stride=1, padding=0, last_block=True)
        )
        self.disc = nn.Sequential(*layers)

    @staticmethod
    def _make_disc_block(
        in_channels: int,
        out_channels: int,
        kernel_size: int = 4,
        stride: int = 2,
        padding: int = 1,
        bias: bool = False,
        batch_norm: bool = True,
        last_block: bool = False,
    ) -> nn.Sequential:
        """Build one spectral-normed conv block; the last block ends in a sigmoid."""
        conv = spectral_norm(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias)
        )
        if last_block:
            tail = [nn.Sigmoid()]
        else:
            tail = [
                nn.BatchNorm2d(out_channels) if batch_norm else nn.Identity(),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        return nn.Sequential(conv, *tail)

    def forward(self, x: Tensor) -> Tensor:
        """Return the discriminator scores flattened to a 1-D tensor."""
        return self.disc(x).view(-1)
97
98
99
class Generator(nn.Module):
    """DCGAN generator: transposed convolutions that upsample a latent code to an image.

    The first block maps ``latent_dim`` channels to ``8 * feature_maps`` with a
    4x4 stride-1 unpadded transposed conv (1x1 -> 4x4). Each following block
    halves the channels and doubles the spatial size. The final block outputs
    ``image_channels`` channels through a sigmoid. A (batch, latent_dim, 1, 1)
    input therefore yields a (batch, image_channels, 64, 64) output.
    """

    def __init__(self, latent_dim: int, feature_maps: int, image_channels: int) -> None:
        """
        Args:
            latent_dim: Dimension of the latent space
            feature_maps: Number of feature maps to use
            image_channels: Number of channels of the images from the dataset
        """
        super().__init__()
        self.latent_dim = latent_dim
        make = self._make_gen_block
        self.gen = nn.Sequential(
            make(latent_dim, feature_maps * 8, kernel_size=4, stride=1, padding=0),
            # 8fm -> 4fm -> 2fm -> fm
            *(make(feature_maps * hi, feature_maps * lo) for hi, lo in ((8, 4), (4, 2), (2, 1))),
            make(feature_maps, image_channels, last_block=True),
        )

    @staticmethod
    def _make_gen_block(
        in_channels: int,
        out_channels: int,
        kernel_size: int = 4,
        stride: int = 2,
        padding: int = 1,
        bias: bool = False,
        last_block: bool = False,
    ) -> nn.Sequential:
        """Build one transposed-conv block; the last block ends in a sigmoid."""
        layers = [nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias)]
        if last_block:
            layers.append(nn.Sigmoid())
        else:
            layers += [nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True)]
        return nn.Sequential(*layers)

    def forward(self, noise: Tensor) -> Tensor:
        """Map latent noise to generated images."""
        images = self.gen(noise)
        return images
143
144
145
def get_noise(real: Tensor, configs: dict) -> Tensor:
    """Sample standard-normal latent codes matching the batch of ``real``.

    Args:
        real: Tensor used only as a template; its length gives the batch size
            and its ``device`` gives where the noise is allocated.
        configs: Mapping providing ``"latent_dim"``.

    Returns:
        Noise of shape ``(len(real), configs["latent_dim"], 1, 1)``.
    """
    shape = (len(real), configs["latent_dim"])
    # Trailing 1x1 spatial dims let the noise feed a ConvTranspose2d stack directly.
    return torch.randn(*shape, device=real.device).view(*shape, 1, 1)
151
152
153
def get_sample(generator: Callable, real: Tensor, configs: dict) -> Tensor:
    """Generate a batch of fake samples sized like ``real``.

    Args:
        generator: Callable mapping latent noise to samples.
        real: Template tensor; passed to ``get_noise`` for batch size and device.
        configs: Mapping forwarded to ``get_noise`` (must contain ``"latent_dim"``).

    Returns:
        Whatever ``generator`` returns for freshly drawn noise.
    """
    return generator(get_noise(real, configs))
158
159
160
def instance_noise(configs: dict, epoch_num: int, real: Tensor):
    """Optionally add Gaussian instance noise to a batch of inputs.

    Nothing happens unless ``configs["loss_params"]["instance_noise"]`` is truthy;
    in that case ``real`` is returned as-is (same object). Otherwise the noise
    scale is ``configs["instance_noise_params"]["noise_level"]``. If that section's
    ``"gamma"`` is not None, the scale is multiplied by ``gamma ** epoch_num``.

    Args:
        configs: Nested config dict (see keys above).
        epoch_num: Exponent for the optional geometric decay of the noise scale.
        real: Tensor to perturb.

    Returns:
        ``real + noise_level * N(0, 1)`` noise, or ``real`` when disabled.
    """
    if not configs["loss_params"]["instance_noise"]:
        return real
    params = configs["instance_noise_params"]
    gamma = params["gamma"]
    noise_level = params["noise_level"] if gamma is None else gamma ** epoch_num * params["noise_level"]
    return real + noise_level * torch.randn_like(real)
170
171
172
def top_k(configs: dict, epoch_num: int, preds: Tensor):
    """Keep only the ``k`` largest predictions (top-k GAN training).

    When ``configs["loss_params"]["top_k"]`` is falsy, ``preds`` is returned
    unchanged. Otherwise ``k`` comes from ``configs["top_k_params"]["k"]``. If
    ``configs["top_k_params"]["gamma"]`` is not None, ``k`` is decayed as
    ``int(gamma ** epoch_num * k)``.

    ``k`` is clamped to ``[1, n]``, where ``n`` is the size of the last dimension
    of ``preds``. This matters in two cases:
    * A fully decayed ``k`` truncates to 0. ``topk`` would then return an empty
      tensor, whose BCE mean is NaN.
    * A batch smaller than ``k`` (e.g. the final batch of an epoch) would make
      ``torch.topk`` raise.

    Args:
        configs: Nested config dict (see keys above).
        epoch_num: Exponent for the optional geometric decay of ``k``.
        preds: Predictions to filter. ``torch.topk`` selects along the last
            dimension; presumably a 1-D tensor of per-sample scores.

    Returns:
        The selected predictions in descending order, or ``preds`` unchanged when
        top-k is disabled or ``preds`` is empty along the last dimension.
    """
    if not configs["loss_params"]["top_k"]:
        return preds
    params = configs["top_k_params"]
    k = params["k"]
    if params["gamma"] is not None:
        k = params["gamma"] ** epoch_num * k
    n = preds.shape[-1] if preds.dim() > 0 else 1
    if n == 0:
        # Nothing to select from; torch.topk would raise for any k >= 1.
        return preds
    k = min(max(int(k), 1), n)
    return torch.topk(preds, k)[0]
180
181
182
def disc_loss(configs: dict, discriminator: Callable, generator: Callable, epoch_num: int, real: Tensor) -> Tensor:
    """Binary cross-entropy loss for the discriminator.

    Real inputs (optionally with instance noise) are labelled 1. Generated
    samples (optionally with instance noise and top-k filtering of their
    predictions) are labelled 0. The two BCE terms are summed.

    Args:
        configs: Nested config dict. ``configs["loss_params"]`` is forwarded to
            ``get_sample``; the whole dict goes to ``instance_noise`` / ``top_k``.
        discriminator: Callable returning probabilities in [0, 1] (BCE input).
        generator: Callable producing fake samples from latent noise.
        epoch_num: Current epoch, used by the noise / top-k decay schedules.
        real: Batch of real data.

    Returns:
        Scalar loss ``BCE(D(real), 1) + BCE(D(fake), 0)``.
    """
    # Train with real
    noisy_real = instance_noise(configs, epoch_num, real)
    real_pred = discriminator(noisy_real)
    real_loss = F.binary_cross_entropy(real_pred, torch.ones_like(real_pred))

    # Train with fake. `real` only serves as a batch-size/device template for the
    # noise. The sample is detached so this loss neither backpropagates through
    # the generator nor accumulates gradients on its parameters.
    fake = get_sample(generator, real, configs["loss_params"]).detach()
    fake = instance_noise(configs, epoch_num, fake)
    fake_pred = top_k(configs, epoch_num, discriminator(fake))
    fake_loss = F.binary_cross_entropy(fake_pred, torch.zeros_like(fake_pred))

    return real_loss + fake_loss
200
201
202
def gen_loss(configs: dict, discriminator: Callable, generator: Callable, epoch_num: int, real: Tensor) -> Tensor:
    """Binary cross-entropy loss for the generator.

    Fake samples are generated, optionally perturbed with instance noise, and
    scored by the discriminator. The scores are optionally filtered with top-k,
    then compared against target 1 ("real"). Minimizing this pushes the
    discriminator's scores on generated samples upward.

    Args:
        configs: Nested config dict. ``configs["loss_params"]`` is forwarded to
            ``get_sample``; the whole dict goes to ``instance_noise`` / ``top_k``.
        discriminator: Callable returning probabilities in [0, 1] (BCE input).
        generator: Callable producing fake samples from latent noise.
        epoch_num: Current epoch, used by the noise / top-k decay schedules.
        real: Batch of real data, used only as a template for ``get_sample``.

    Returns:
        Scalar loss ``BCE(D(fake), 1)``.
    """
    samples = instance_noise(configs, epoch_num, get_sample(generator, real, configs["loss_params"]))
    scores = top_k(configs, epoch_num, discriminator(samples))
    targets = torch.ones_like(scores)
    return F.binary_cross_entropy(scores, targets)
212
213