Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
probml
GitHub Repository: probml/pyprobml
Path: blob/master/deprecated/scripts/ae_mnist_conv.py
1192 views
1
import superimport
2
3
import torch
4
import numpy as np
5
import torch.nn as nn
6
from torchvision.datasets import MNIST
7
import torch.nn.functional as F
8
import torchvision.transforms as transforms
9
from torch.utils.data import DataLoader
10
from pytorch_lightning import LightningModule, Trainer
11
from einops import rearrange
12
from argparse import ArgumentParser
13
14
class ConvAEModule(nn.Module):
    """Convolutional (variational) autoencoder network.

    Encoder: a stack of stride-2 3x3 convolutions, flattened into two
    linear heads producing the latent mean ``mu`` and ``log_var``.
    Decoder: a linear projection back to the pre-flatten feature shape,
    followed by nearest-neighbour 2x upsampling + transposed convolutions,
    ending in a sigmoid so outputs lie in [0, 1].

    Special-casing: when ``latent_dim == 2`` the linear heads are plain
    ``nn.Linear`` layers and batch-norm is omitted; otherwise each head is
    Linear -> LeakyReLU -> Dropout(0.2) and every conv is followed by
    batch-norm.  With ``deterministic=True`` the forward pass decodes
    ``mu`` directly instead of sampling.
    """

    def __init__(self, input_shape,
                 encoder_conv_filters,
                 decoder_conv_t_filters,
                 latent_dim,
                 deterministic=False):
        super().__init__()

        self.input_shape = input_shape
        self.latent_dim = latent_dim
        self.deterministic = deterministic

        # Channel progression for the encoder, starting at the input channels.
        enc_channels = [self.input_shape[0]] + encoder_conv_filters

        self.enc_convs = nn.ModuleList([])
        for c_in, c_out in zip(enc_channels, enc_channels[1:]):
            self.enc_convs.append(
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1))
            if self.latent_dim != 2:
                self.enc_convs.append(nn.BatchNorm2d(c_out))
            self.enc_convs.append(nn.LeakyReLU())

        # Flattened encoder-output size; also records
        # self.shape_before_flattening as a side effect.
        self.flatten_out_size = self.flatten_enc_out_shape(input_shape)

        self.mu_linear = self._make_head(self.flatten_out_size, self.latent_dim)
        self.log_var_linear = self._make_head(self.flatten_out_size, self.latent_dim)
        self.decoder_linear = self._make_head(self.latent_dim, self.flatten_out_size)

        # Decoder starts from the last encoder channel count.
        dec_channels = [encoder_conv_filters[-1]] + decoder_conv_t_filters
        self.dec_t_convs = nn.ModuleList([])
        last = len(decoder_conv_t_filters)

        # All but the final transposed-conv stage get (optional) batch-norm
        # and a LeakyReLU activation.
        for i in range(last - 1):
            self.dec_t_convs.append(nn.UpsamplingNearest2d(scale_factor=2))
            self.dec_t_convs.append(
                nn.ConvTranspose2d(dec_channels[i], dec_channels[i + 1],
                                   3, stride=1, padding=1))
            if self.latent_dim != 2:
                self.dec_t_convs.append(nn.BatchNorm2d(dec_channels[i + 1]))
            self.dec_t_convs.append(nn.LeakyReLU())

        # Final stage maps to the output channel count and squashes to [0, 1].
        self.dec_t_convs.append(nn.UpsamplingNearest2d(scale_factor=2))
        self.dec_t_convs.append(
            nn.ConvTranspose2d(dec_channels[last - 1], dec_channels[last],
                               3, stride=1, padding=1))
        self.dec_t_convs.append(nn.Sigmoid())

    def _make_head(self, in_features, out_features):
        """Build a linear head: bare Linear for latent_dim == 2, otherwise
        Linear -> LeakyReLU -> Dropout(0.2)."""
        if self.latent_dim == 2:
            return nn.Linear(in_features, out_features)
        return nn.Sequential(
            nn.Linear(in_features, out_features),
            nn.LeakyReLU(),
            nn.Dropout(0.2),
        )

    def reparameterize(self, mu, log_var):
        """Sample z ~ N(mu, sigma^2) via the reparameterization trick."""
        sigma = torch.exp(0.5 * log_var)          # standard deviation
        noise = torch.randn_like(sigma)           # same size as sigma
        return mu + noise * sigma

    def _run_step(self, x):
        """Encode, sample and decode x.

        Returns (z, reconstruction, prior N(0, 1), posterior N(mu, sigma)).
        """
        mu, log_var = self.encode(x)
        sigma = torch.exp(0.5 * log_var)
        p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(sigma))
        q = torch.distributions.Normal(mu, sigma)
        z = self.reparameterize(mu, log_var)
        recon = self.decode(z)
        return z, recon, p, q

    def flatten_enc_out_shape(self, input_shape):
        """Feature count after the conv encoder for one input of input_shape.

        Side effect: stores the pre-flatten tensor shape in
        self.shape_before_flattening (used by decode()).
        """
        probe = torch.zeros(1, *input_shape)
        for layer in self.enc_convs:
            probe = layer(probe)
        self.shape_before_flattening = probe.shape
        return int(np.prod(self.shape_before_flattening))

    def encode(self, x):
        """Map a batch of images to latent parameters (mu, log_var)."""
        for layer in self.enc_convs:
            x = layer(x)
        x = torch.flatten(x, start_dim=1)
        return self.mu_linear(x), self.log_var_linear(x)

    def decode(self, z):
        """Map latent codes back to image space."""
        out = self.decoder_linear(z)
        out = out.view(out.size(0), *self.shape_before_flattening[1:])
        for layer in self.dec_t_convs:
            out = layer(out)
        return out

    def forward(self, x):
        """Return (reconstruction, mu, log_var).

        In deterministic mode the decoder receives mu directly and
        log_var is returned as None.
        """
        mu, log_var = self.encode(x)
        if self.deterministic:
            return self.decode(mu), mu, None
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var
132
133
class ConvAE(LightningModule):
    """Lightning wrapper that trains ConvAEModule on a reconstruction loss.

    NOTE(review): ``kl_coeff`` is accepted and stored but never used — the
    objective below is pure summed binary cross-entropy with no KL term, so
    this trains as a plain autoencoder even though the wrapped module
    samples z stochastically.  Confirm whether a beta-weighted KL term was
    intended before relying on this loss.
    """

    def __init__(self, input_shape,
                 encoder_conv_filters,
                 decoder_conv_t_filters,
                 latent_dim,
                 kl_coeff=0.1,
                 lr=0.001):
        super().__init__()
        self.kl_coeff = kl_coeff
        self.lr = lr
        self.vae = ConvAEModule(input_shape, encoder_conv_filters,
                                decoder_conv_t_filters, latent_dim)

    def step(self, batch, batch_idx):
        """Shared train/val computation: summed BCE reconstruction loss."""
        x, _ = batch
        z, x_hat, p, q = self.vae._run_step(x)
        loss = F.binary_cross_entropy(x_hat, x, reduction='sum')
        logs = {"loss": loss}
        return loss, logs

    def training_step(self, batch, batch_idx):
        loss, logs = self.step(batch, batch_idx)
        # Log per-step only during training.
        prefixed = {f"train_{k}": v for k, v in logs.items()}
        self.log_dict(prefixed, on_step=True, on_epoch=False)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, logs = self.step(batch, batch_idx)
        self.log_dict({f"val_{k}": v for k, v in logs.items()})
        return loss

    def configure_optimizers(self):
        """Adam over all parameters at the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)
168
169
if __name__ == "__main__":
    parser = ArgumentParser(description='Hyperparameters for our experiments')
    parser.add_argument('--bs', type=int, default=500, help="batch size")
    parser.add_argument('--epochs', type=int, default=50, help="num epochs")
    parser.add_argument('--latent-dim', type=int, default=20, help="size of latent dim for our vae")
    parser.add_argument('--lr', type=float, default=0.001, help="learning rate")
    # BUG FIX: was type=int, which rejected fractional command-line values
    # (e.g. --kl-coeff 0.5 raised an argparse error) even though the
    # default is the float 0.75 and the help text describes a beta term.
    parser.add_argument('--kl-coeff', type=float, default=0.75, help="kl coeff aka beta term in the elbo loss function")
    hparams = parser.parse_args()

    ae = ConvAE((1, 28, 28),
                encoder_conv_filters=[28, 64, 64],
                decoder_conv_t_filters=[64, 28, 1],
                latent_dim=hparams.latent_dim, kl_coeff=hparams.kl_coeff, lr=hparams.lr)

    # Images are resized to 32x32 so that three stride-2 convs (32->16->8->4)
    # followed by three 2x upsamplings (4->8->16->32) reproduce the input
    # resolution; the flatten size (64*4*4) matches the 28x28 input_shape
    # the model was built with, so the shapes stay consistent.
    mnist_full = MNIST(".", train=True, download=True,
                       transform=transforms.Compose([transforms.ToTensor(),
                                                     transforms.Resize((32, 32))]))
    dm = DataLoader(mnist_full, batch_size=hparams.bs)
    trainer = Trainer(gpus=1, weights_summary='full', max_epochs=hparams.epochs)
    trainer.fit(ae, dm)

    torch.save(ae.state_dict(), f"ae-mnist-conv-latent-dim-{hparams.latent_dim}.ckpt")
191
192