GitHub Repository: probml/pyprobml
Path: blob/master/deprecated/gan/models/wgan.py

import torch
from torch import nn
from torch import Tensor
from typing import Callable
from torch.nn.utils import spectral_norm


class Discriminator(nn.Module):
    def __init__(self, feature_maps: int, image_channels: int) -> None:
        """
        Args:
            feature_maps: Number of feature maps to use
            image_channels: Number of channels of the images from the dataset
        """
        super().__init__()
        # DCGAN-style critic: four stride-2 convolutions halve the spatial
        # resolution while widening the channels, then a final 4x4 valid
        # convolution maps the 4x4 feature map to a single score per image.
        self.disc = nn.Sequential(
            self._make_disc_block(image_channels, feature_maps, batch_norm=False),
            self._make_disc_block(feature_maps, feature_maps * 2),
            self._make_disc_block(feature_maps * 2, feature_maps * 4),
            self._make_disc_block(feature_maps * 4, feature_maps * 8),
            self._make_disc_block(feature_maps * 8, 1, kernel_size=4, stride=1, padding=0, last_block=True),
        )

    @staticmethod
    def _make_disc_block(
        in_channels: int,
        out_channels: int,
        kernel_size: int = 4,
        stride: int = 2,
        padding: int = 1,
        bias: bool = False,
        batch_norm: bool = True,
        last_block: bool = False,
    ) -> nn.Sequential:
        # Conv -> (optional) BatchNorm -> LeakyReLU for hidden blocks; the
        # last block instead applies a sigmoid to squash the score into (0, 1).
        if not last_block:
            disc_block = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias),
                nn.BatchNorm2d(out_channels) if batch_norm else nn.Identity(),
                nn.LeakyReLU(0.2, inplace=True),
            )
        else:
            disc_block = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias),
                nn.Sigmoid(),
            )

        return disc_block

    def forward(self, x: Tensor) -> Tensor:
        # Flatten the (N, 1, 1, 1) output to a vector of N scores.
        return self.disc(x).view(-1, 1).squeeze(1)


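# Illustrative shape check (an assumption: the four stride-2 blocks plus the
# final 4x4 valid convolution are sized for 64x64 inputs):
#
#   disc = Discriminator(feature_maps=64, image_channels=3)
#   scores = disc(torch.randn(16, 3, 64, 64))
#   assert scores.shape == (16,)

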
class SpectralNormDiscriminator(nn.Module):
    """Discriminator variant with spectral normalization on every convolution,
    defined under a distinct name so it does not shadow the plain
    ``Discriminator`` above."""

    def __init__(self, feature_maps: int, image_channels: int) -> None:
        """
        Args:
            feature_maps: Number of feature maps to use
            image_channels: Number of channels of the images from the dataset
        """
        super().__init__()
        self.disc = nn.Sequential(
            self._make_disc_block(image_channels, feature_maps, batch_norm=False),
            self._make_disc_block(feature_maps, feature_maps * 2),
            self._make_disc_block(feature_maps * 2, feature_maps * 4),
            self._make_disc_block(feature_maps * 4, feature_maps * 8),
            self._make_disc_block(feature_maps * 8, 1, kernel_size=4, stride=1, padding=0, last_block=True),
        )

    @staticmethod
    def _make_disc_block(
        in_channels: int,
        out_channels: int,
        kernel_size: int = 4,
        stride: int = 2,
        padding: int = 1,
        bias: bool = False,
        batch_norm: bool = True,
        last_block: bool = False,
    ) -> nn.Sequential:
        # Same layout as Discriminator._make_disc_block, but every convolution
        # is wrapped in spectral_norm.
        if not last_block:
            disc_block = nn.Sequential(
                spectral_norm(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias)),
                nn.BatchNorm2d(out_channels) if batch_norm else nn.Identity(),
                nn.LeakyReLU(0.2, inplace=True),
            )
        else:
            disc_block = nn.Sequential(
                spectral_norm(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias)),
                nn.Sigmoid(),
            )

        return disc_block

    def forward(self, x: Tensor) -> Tensor:
        # Flatten the (N, 1, 1, 1) output to a vector of N scores.
        return self.disc(x).view(-1, 1).squeeze(1)


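# Usage mirrors ``Discriminator``. Spectral normalization rescales each
# convolution's weight by its largest singular value, a soft Lipschitz
# constraint commonly used as an alternative to WGAN weight clipping:
#
#   sn_disc = SpectralNormDiscriminator(feature_maps=64, image_channels=3)
#   scores = sn_disc(torch.randn(16, 3, 64, 64))  # shape (16,)

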
class Generator(nn.Module):
    def __init__(self, latent_dim: int, feature_maps: int, image_channels: int) -> None:
        """
        Args:
            latent_dim: Dimension of the latent space
            feature_maps: Number of feature maps to use
            image_channels: Number of channels of the images from the dataset
        """
        super().__init__()
        self.latent_dim = latent_dim
        # Mirror image of the discriminator: a 4x4 valid transposed convolution
        # lifts the latent vector to a 4x4 map, then four stride-2 transposed
        # convolutions double the resolution up to 64x64.
        self.gen = nn.Sequential(
            self._make_gen_block(latent_dim, feature_maps * 8, kernel_size=4, stride=1, padding=0),
            self._make_gen_block(feature_maps * 8, feature_maps * 4),
            self._make_gen_block(feature_maps * 4, feature_maps * 2),
            self._make_gen_block(feature_maps * 2, feature_maps),
            self._make_gen_block(feature_maps, image_channels, last_block=True),
        )

    @staticmethod
    def _make_gen_block(
        in_channels: int,
        out_channels: int,
        kernel_size: int = 4,
        stride: int = 2,
        padding: int = 1,
        bias: bool = False,
        last_block: bool = False,
    ) -> nn.Sequential:
        # ConvTranspose -> BatchNorm -> ReLU for hidden blocks; the last block
        # applies a sigmoid so pixel values land in (0, 1).
        if not last_block:
            gen_block = nn.Sequential(
                nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(True),
            )
        else:
            gen_block = nn.Sequential(
                nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias),
                nn.Sigmoid(),
            )

        return gen_block

    def forward(self, noise: Tensor) -> Tensor:
        return self.gen(noise)


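# Illustrative round trip; latent vectors enter as (N, latent_dim, 1, 1):
#
#   gen = Generator(latent_dim=100, feature_maps=64, image_channels=3)
#   images = gen(torch.randn(16, 100, 1, 1))
#   assert images.shape == (16, 3, 64, 64)

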
def get_noise(real: Tensor, configs: dict) -> Tensor:
    # Sample one latent vector per image in the batch, on the same device,
    # reshaped to (N, latent_dim, 1, 1) for the transposed convolutions.
    batch_size = len(real)
    device = real.device
    noise = torch.randn(batch_size, configs["latent_dim"], device=device)
    noise = noise.view(*noise.shape, 1, 1)
    return noise


def get_sample(generator: Callable, real: Tensor, configs: dict) -> Tensor:
    # Generate a batch of fakes matching the real batch in size and device.
    noise = get_noise(real, configs)
    fake = generator(noise)

    return fake


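# Hypothetical call, assuming the config dict carries the latent dimension:
#
#   fake = get_sample(gen, real, {"latent_dim": 100})
#   # fake.shape matches real.shape when the generator's image_channels and
#   # output resolution agree with the real batch.

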
def instance_noise(configs: dict, epoch_num: int, real: Tensor):
    # Optionally add Gaussian instance noise to the images; if a decay factor
    # gamma is given, the noise level shrinks geometrically with the epoch.
    if configs["loss_params"]["instance_noise"]:
        if configs["instance_noise_params"]["gamma"] is not None:
            noise_level = (configs["instance_noise_params"]["gamma"]) ** epoch_num * configs["instance_noise_params"][
                "noise_level"
            ]
        else:
            noise_level = configs["instance_noise_params"]["noise_level"]
        real = real + noise_level * torch.randn_like(real)
    return real


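# Worked decay example: with noise_level=0.1 and gamma=0.9, epoch 10 uses
# 0.9 ** 10 * 0.1 ≈ 0.035, so the perturbation fades as training progresses.

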
def top_k(configs: dict, epoch_num: int, preds: Tensor):
    # Optionally keep only the k highest discriminator scores ("top-k
    # training"), with k itself optionally decayed geometrically over epochs.
    if configs["loss_params"]["top_k"]:
        if configs["top_k_params"]["gamma"] is not None:
            k = int((configs["top_k_params"]["gamma"]) ** epoch_num * configs["top_k_params"]["k"])
        else:
            k = configs["top_k_params"]["k"]
        preds = torch.topk(preds, k)[0]
    return preds


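# torch.topk keeps the largest entries, e.g.:
#
#   torch.topk(torch.tensor([0.2, 0.9, 0.5]), 2)[0]  # tensor([0.9000, 0.5000])

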
def disc_loss(configs: dict, discriminator: Callable, generator: Callable, epoch_num: int, real: Tensor) -> Tensor:
    # Train with real
    real = instance_noise(configs, epoch_num, real)
    real_pred = discriminator(real).mean()

    # Train with fake
    fake = get_sample(generator, real, configs["loss_params"])
    fake = instance_noise(configs, epoch_num, fake)
    fake_pred = discriminator(fake)
    fake_pred = top_k(configs, epoch_num, fake_pred).mean()

    # WGAN critic objective: L_D = -E[D(x)] + E[D(G(z))], with weight clipping
    # as the original WGAN's crude way of enforcing a Lipschitz constraint.
    disc_loss = -real_pred + fake_pred
    clip_value = configs["loss_params"]["clip_value"]
    for p in discriminator.parameters():
        p.data.clamp_(-clip_value, clip_value)
    return disc_loss


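# Worked sign convention: if the critic scores real images at 2.0 on average
# and fakes at -1.0, disc_loss = -2.0 + (-1.0) = -3.0; minimising it drives
# the real/fake score gap wider, which is exactly what the critic should do.

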
def gen_loss(configs: dict, discriminator: Callable, generator: Callable, epoch_num: int, real: Tensor) -> Tensor:
    # Train with fake
    fake = get_sample(generator, real, configs["loss_params"])
    fake = instance_noise(configs, epoch_num, fake)
    fake_pred = discriminator(fake)
    fake_pred = top_k(configs, epoch_num, fake_pred).mean()

    # Generator objective: L_G = -E[D(G(z))], i.e. raise the critic's score
    # on generated samples.
    return -fake_pred


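if __name__ == "__main__":
    # Minimal smoke test. The config layout below is an assumption inferred
    # from how the loss functions index into `configs`; the actual training
    # configs used in pyprobml may differ.
    configs = {
        "loss_params": {
            "latent_dim": 100,
            "instance_noise": True,
            "top_k": True,
            "clip_value": 0.01,
        },
        "instance_noise_params": {"gamma": 0.99, "noise_level": 0.1},
        "top_k_params": {"gamma": 0.99, "k": 8},
    }
    gen = Generator(latent_dim=100, feature_maps=64, image_channels=3)
    disc = SpectralNormDiscriminator(feature_maps=64, image_channels=3)
    real = torch.randn(8, 3, 64, 64)  # stand-in batch of 64x64 RGB images
    print("disc loss:", disc_loss(configs, disc, gen, epoch_num=0, real=real).item())
    print("gen loss:", gen_loss(configs, disc, gen, epoch_num=0, real=real).item())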