Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
probml
GitHub Repository: probml/pyprobml
Path: blob/master/deprecated/gan/models/sngan.py
1192 views
1
import torch
2
from torch import nn
3
from torch import Tensor
4
from typing import Callable
5
import torch.nn.functional as F
6
from torch.nn.utils import spectral_norm
7
8
9
class Discriminator(nn.Module):
    """DCGAN-style discriminator with spectral normalization on every conv layer.

    Downsamples the input image through four strided conv blocks, then a final
    4x4 conv + sigmoid collapses the spatial map to a single probability per
    sample. For a 64x64 input the output is a 1-D tensor of shape (N,) with
    values in [0, 1].
    """

    def __init__(self, feature_maps: int, image_channels: int) -> None:
        """
        Args:
            feature_maps: Number of feature maps to use
            image_channels: Number of channels of the images from the dataset
        """
        super().__init__()
        # Channel widths double at each downsampling stage.
        widths = [image_channels, feature_maps, feature_maps * 2, feature_maps * 4, feature_maps * 8]
        # First stage skips batch norm (conventional for DCGAN discriminators).
        stages = [self._make_disc_block(widths[0], widths[1], batch_norm=False)]
        for c_in, c_out in zip(widths[1:-1], widths[2:]):
            stages.append(self._make_disc_block(c_in, c_out))
        # Final 4x4 conv reduces the spatial map to 1x1 and applies sigmoid.
        stages.append(
            self._make_disc_block(widths[-1], 1, kernel_size=4, stride=1, padding=0, last_block=True)
        )
        self.disc = nn.Sequential(*stages)

    @staticmethod
    def _make_disc_block(
        in_channels: int,
        out_channels: int,
        kernel_size: int = 4,
        stride: int = 2,
        padding: int = 1,
        bias: bool = False,
        batch_norm: bool = True,
        last_block: bool = False,
    ) -> nn.Sequential:
        """Build one spectrally-normalized conv block.

        Intermediate blocks are conv -> (optional) batch norm -> LeakyReLU;
        the last block is conv -> sigmoid.
        """
        conv = spectral_norm(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias))
        if last_block:
            return nn.Sequential(conv, nn.Sigmoid())
        norm = nn.BatchNorm2d(out_channels) if batch_norm else nn.Identity()
        return nn.Sequential(conv, norm, nn.LeakyReLU(0.2, inplace=True))

    def forward(self, x: Tensor) -> Tensor:
        """Return a flat (N,) tensor of per-sample real/fake probabilities."""
        scores = self.disc(x).view(-1, 1)
        return scores.squeeze(1)
52
53
54
class Generator(nn.Module):
    """DCGAN-style generator built from transposed-conv upsampling blocks.

    Maps a latent tensor of shape (N, latent_dim, 1, 1) to an image batch of
    shape (N, image_channels, 64, 64); the final sigmoid bounds pixel values
    to [0, 1].
    """

    def __init__(self, latent_dim: int, feature_maps: int, image_channels: int) -> None:
        """
        Args:
            latent_dim: Dimension of the latent space
            feature_maps: Number of feature maps to use
            image_channels: Number of channels of the images from the dataset
        """
        super().__init__()
        self.latent_dim = latent_dim
        fm = feature_maps
        # First block projects the 1x1 latent to a 4x4 map; each subsequent
        # block doubles the spatial size while halving the channel count.
        self.gen = nn.Sequential(
            self._make_gen_block(latent_dim, fm * 8, kernel_size=4, stride=1, padding=0),
            self._make_gen_block(fm * 8, fm * 4),
            self._make_gen_block(fm * 4, fm * 2),
            self._make_gen_block(fm * 2, fm),
            self._make_gen_block(fm, image_channels, last_block=True),
        )

    @staticmethod
    def _make_gen_block(
        in_channels: int,
        out_channels: int,
        kernel_size: int = 4,
        stride: int = 2,
        padding: int = 1,
        bias: bool = False,
        last_block: bool = False,
    ) -> nn.Sequential:
        """Build one upsampling block.

        Intermediate blocks are transposed conv -> batch norm -> ReLU; the
        last block is transposed conv -> sigmoid.
        """
        deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias)
        if last_block:
            return nn.Sequential(deconv, nn.Sigmoid())
        return nn.Sequential(deconv, nn.BatchNorm2d(out_channels), nn.ReLU(True))

    def forward(self, noise: Tensor) -> Tensor:
        """Generate a batch of images from latent noise of shape (N, latent_dim, 1, 1)."""
        return self.gen(noise)
98
99
100
def get_noise(real: Tensor, configs: dict) -> Tensor:
    """Draw latent noise matching the batch size and device of `real`.

    Args:
        real: Reference batch; only its length and device are used.
        configs: Dict with key "latent_dim" giving the latent dimensionality.

    Returns:
        Standard-normal noise of shape (N, latent_dim, 1, 1).
    """
    flat = torch.randn(len(real), configs["latent_dim"], device=real.device)
    # Append two singleton spatial dims so the noise feeds a conv generator.
    return flat.unsqueeze(-1).unsqueeze(-1)
106
107
108
def get_sample(generator: Callable, real: Tensor, configs: dict) -> Tensor:
    """Produce a fake batch the same size (and on the same device) as `real`.

    Args:
        generator: Callable mapping (N, latent_dim, 1, 1) noise to images.
        real: Reference batch used only for its length and device.
        configs: Dict with key "latent_dim" (forwarded to get_noise).

    Returns:
        The generator's output on freshly sampled noise.
    """
    return generator(get_noise(real, configs))
113
114
115
def instance_noise(configs: dict, epoch_num: int, real: Tensor):
    """Optionally add Gaussian instance noise to a batch.

    When configs["loss_params"]["instance_noise"] is falsy, `real` is
    returned untouched. Otherwise noise with std dev `noise_level` is added;
    if a decay factor `gamma` is set, the level is annealed as
    gamma ** epoch_num * noise_level.
    """
    if not configs["loss_params"]["instance_noise"]:
        return real
    params = configs["instance_noise_params"]
    level = params["noise_level"]
    if params["gamma"] is not None:
        # Exponentially anneal the noise level over epochs.
        level = params["gamma"] ** epoch_num * level
    return real + level * torch.randn_like(real)
125
126
127
def top_k(configs: dict, epoch_num: int, preds: Tensor):
    """Optionally keep only the k largest predictions (top-k GAN training).

    When configs["loss_params"]["top_k"] is falsy, `preds` is returned
    untouched. Otherwise the k largest values (along the last dim) are kept;
    if a decay factor `gamma` is set, k is annealed as
    int(gamma ** epoch_num * k).
    """
    if configs["loss_params"]["top_k"]:
        params = configs["top_k_params"]
        k = params["k"]
        if params["gamma"] is not None:
            k = int(params["gamma"] ** epoch_num * k)
        # Clamp k to [1, batch]: the annealed int(...) can reach 0 (an empty
        # selection makes the downstream BCE loss NaN) and a configured k
        # larger than the batch would make torch.topk raise.
        k = max(1, min(k, preds.shape[-1]))
        preds = torch.topk(preds, k)[0]
    return preds
135
136
137
def disc_loss(configs: dict, discriminator: Callable, generator: Callable, epoch_num: int, real: Tensor) -> Tensor:
    """Binary-cross-entropy discriminator loss (real -> 1, fake -> 0).

    Applies optional instance noise to both branches and optional top-k
    filtering to the fake-branch predictions before the BCE terms are summed.
    """
    # Real branch: label the (possibly noised) real batch as 1.
    noisy_real = instance_noise(configs, epoch_num, real)
    pred_real = discriminator(noisy_real)
    loss_real = F.binary_cross_entropy(pred_real, torch.ones_like(pred_real))

    # Fake branch: label generated samples as 0.
    # NOTE(review): configs["loss_params"] is handed to get_sample, whose
    # get_noise reads key "latent_dim" from it — confirm latent_dim actually
    # lives under loss_params in the config schema.
    fake = get_sample(generator, noisy_real, configs["loss_params"])
    fake = instance_noise(configs, epoch_num, fake)
    pred_fake = discriminator(fake)
    pred_fake = top_k(configs, epoch_num, pred_fake)
    loss_fake = F.binary_cross_entropy(pred_fake, torch.zeros_like(pred_fake))

    return loss_real + loss_fake
155
156
157
def gen_loss(configs: dict, discriminator: Callable, generator: Callable, epoch_num: int, real: Tensor) -> Tensor:
    """Binary-cross-entropy generator loss (non-saturating: fake -> 1).

    Generates a fake batch sized like `real`, optionally adds instance noise
    and top-k filtering, then scores the fakes against the "real" label.
    """
    # NOTE(review): configs["loss_params"] is handed to get_sample, whose
    # get_noise reads key "latent_dim" from it — confirm latent_dim actually
    # lives under loss_params in the config schema.
    fake = get_sample(generator, real, configs["loss_params"])
    fake = instance_noise(configs, epoch_num, fake)
    pred_fake = discriminator(fake)
    pred_fake = top_k(configs, epoch_num, pred_fake)
    # The generator wants the discriminator to output 1 on its samples.
    target = torch.ones_like(pred_fake)
    return F.binary_cross_entropy(pred_fake, target)
167
168