Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
probml
GitHub Repository: probml/pyprobml
Path: blob/master/deprecated/vae/standalone/vae_conv_mnist.py
1192 views
1
"""
2
Install pytorch lightning and einops
3
4
pip install pytorch_lightning einops
5
"""
6
import torch
7
import torch.nn as nn
8
import numpy as np
9
from torch.nn import functional as F
10
from torchvision.datasets import MNIST
11
from torch.utils.data import DataLoader
12
import torchvision.transforms as transforms
13
from pytorch_lightning import LightningModule, Trainer
14
from argparse import ArgumentParser
15
16
17
class ConvVAEModule(nn.Module):
    """Convolutional VAE: stride-2 conv encoder -> (mu, log_var) heads -> upsampling
    transposed-conv decoder ending in a Sigmoid, so reconstructions lie in [0, 1].

    A 2-d latent uses plain linear heads and no batch-norm (convenient for
    visualising the latent space); any other latent size adds BatchNorm,
    LeakyReLU and Dropout regularisation. With ``deterministic=True`` the
    forward pass decodes the mean directly (a plain autoencoder).
    """

    def __init__(self, input_shape, encoder_conv_filters, decoder_conv_t_filters, latent_dim, deterministic=False):
        """
        Args:
            input_shape: (channels, height, width) of a single input image.
            encoder_conv_filters: output channels of each stride-2 encoder conv.
            decoder_conv_t_filters: output channels of each decoder stage; the
                last entry should equal input_shape[0].
            latent_dim: dimensionality of the latent code.
            deterministic: if True, forward() skips sampling and decodes mu.
        """
        super().__init__()
        self.input_shape = input_shape
        self.latent_dim = latent_dim
        self.deterministic = deterministic

        # Encoder: each stride-2 conv halves the spatial resolution.
        all_channels = [self.input_shape[0]] + encoder_conv_filters
        self.enc_convs = nn.ModuleList([])
        for i in range(len(encoder_conv_filters)):
            self.enc_convs.append(nn.Conv2d(all_channels[i], all_channels[i + 1], kernel_size=3, stride=2, padding=1))
            if self.latent_dim != 2:  # regularise only for larger latents
                self.enc_convs.append(nn.BatchNorm2d(all_channels[i + 1]))
            self.enc_convs.append(nn.LeakyReLU())

        # Probe the encoder once with a dummy batch to learn the flattened size.
        self.flatten_out_size = self.flatten_enc_out_shape(input_shape)

        # The two latent heads and the decoder's input projection share one
        # recipe, so build them with a single helper (same creation order as
        # before: mu, log_var, decoder).
        self.mu_linear = self._latent_linear(self.flatten_out_size, self.latent_dim)
        self.log_var_linear = self._latent_linear(self.flatten_out_size, self.latent_dim)
        self.decoder_linear = self._latent_linear(self.latent_dim, self.flatten_out_size)

        # Decoder: per stage, nearest-neighbour upsample x2 followed by a
        # shape-preserving transposed conv (stride 1, padding 1, kernel 3).
        all_t_channels = [encoder_conv_filters[-1]] + decoder_conv_t_filters
        self.dec_t_convs = nn.ModuleList([])
        num = len(decoder_conv_t_filters)
        for i in range(num - 1):
            self.dec_t_convs.append(nn.UpsamplingNearest2d(scale_factor=2))
            self.dec_t_convs.append(nn.ConvTranspose2d(all_t_channels[i], all_t_channels[i + 1], 3, stride=1, padding=1))
            if self.latent_dim != 2:
                self.dec_t_convs.append(nn.BatchNorm2d(all_t_channels[i + 1]))
            self.dec_t_convs.append(nn.LeakyReLU())

        # Final stage: reach the input resolution and squash to pixel range.
        self.dec_t_convs.append(nn.UpsamplingNearest2d(scale_factor=2))
        self.dec_t_convs.append(nn.ConvTranspose2d(all_t_channels[num - 1], all_t_channels[num], 3, stride=1, padding=1))
        self.dec_t_convs.append(nn.Sigmoid())

    def _latent_linear(self, in_features, out_features):
        """Linear layer; adds LeakyReLU + Dropout unless the latent is 2-d."""
        if self.latent_dim == 2:
            return nn.Linear(in_features, out_features)
        return nn.Sequential(nn.Linear(in_features, out_features), nn.LeakyReLU(), nn.Dropout(0.2))

    def reparameterize(self, mu, log_var):
        """Sample z ~ N(mu, exp(log_var)) via the reparameterization trick."""
        std = torch.exp(0.5 * log_var)  # standard deviation
        eps = torch.randn_like(std)  # `randn_like` as we need the same size
        return mu + eps * std

    def _run_step(self, x):
        """Encode, sample and decode; also return the prior p and posterior q
        as distributions so the caller can form a Monte-Carlo KL estimate."""
        mu, log_var = self.encode(x)
        std = torch.exp(0.5 * log_var)
        p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))  # prior N(0, I)
        q = torch.distributions.Normal(mu, std)  # approximate posterior
        z = self.reparameterize(mu, log_var)
        recon = self.decode(z)
        return z, recon, p, q

    def flatten_enc_out_shape(self, input_shape):
        """Run a dummy zero batch through the encoder and return the flattened
        feature count.

        Side effect: stores ``self.shape_before_flattening`` for decode().
        """
        x = torch.zeros(1, *input_shape)
        for layer in self.enc_convs:
            x = layer(x)
        self.shape_before_flattening = x.shape
        return int(np.prod(self.shape_before_flattening))

    def encode(self, x):
        """Map a batch of images to latent Gaussian parameters (mu, log_var)."""
        for layer in self.enc_convs:
            x = layer(x)
        x = x.view(x.size(0), -1)  # flatten
        return self.mu_linear(x), self.log_var_linear(x)

    def decode(self, z):
        """Map latent codes back to images in [0, 1]."""
        z = self.decoder_linear(z)
        recon = z.view(z.size(0), *self.shape_before_flattening[1:])
        for layer in self.dec_t_convs:
            recon = layer(recon)
        return recon

    def forward(self, x):
        """Return (reconstruction, mu, log_var); log_var is None when
        ``deterministic`` because no sampling distribution is used."""
        mu, log_var = self.encode(x)
        if self.deterministic:
            return self.decode(mu), mu, None
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var
126
127
128
class ConvVAE(LightningModule):
    """LightningModule that trains a ConvVAEModule with a beta-weighted ELBO."""

    def __init__(self, input_shape, encoder_conv_filters, decoder_conv_t_filters, latent_dim, kl_coeff=0.1, lr=0.001):
        super(ConvVAE, self).__init__()
        self.kl_coeff = kl_coeff  # beta weight on the KL term
        self.lr = lr
        self.vae = ConvVAEModule(input_shape, encoder_conv_filters, decoder_conv_t_filters, latent_dim)

    def step(self, batch, batch_idx):
        """Compute the summed ELBO loss plus a dict of loggable scalars."""
        x, _labels = batch
        z, x_hat, p, q = self.vae._run_step(x)

        # Bernoulli reconstruction term, summed over every pixel in the batch.
        recon_loss = F.binary_cross_entropy(x_hat, x, reduction="sum")

        # Single-sample Monte-Carlo KL: log q(z|x) - log p(z), summed (not
        # averaged) over the batch, then scaled by the beta coefficient.
        kl = (q.log_prob(z) - p.log_prob(z)).sum() * self.kl_coeff

        loss = kl + recon_loss
        logs = {"recon_loss": recon_loss, "kl": kl, "loss": loss}
        return loss, logs

    def training_step(self, batch, batch_idx):
        loss, logs = self.step(batch, batch_idx)
        renamed = {"train_" + key: value for key, value in logs.items()}
        self.log_dict(renamed, on_step=True, on_epoch=False)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, logs = self.step(batch, batch_idx)
        self.log_dict({"val_" + key: value for key, value in logs.items()})
        return loss

    def configure_optimizers(self):
        """Plain Adam over every parameter of the wrapped VAE."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)
169
170
171
if __name__ == "__main__":
    parser = ArgumentParser(description="Hyperparameters for our experiments")
    parser.add_argument("--bs", type=int, default=500, help="batch size")
    parser.add_argument("--epochs", type=int, default=50, help="num epochs")
    parser.add_argument("--latent-dim", type=int, default=2, help="size of latent dim for our vae")
    parser.add_argument("--lr", type=float, default=0.001, help="learning rate")
    # float, not int: beta is routinely fractional (e.g. 0.5); default unchanged.
    parser.add_argument("--kl-coeff", type=float, default=5, help="kl coeff aka beta term in the elbo loss function")
    hparams = parser.parse_args()

    # Images are resized to 32x32 below, so declare (1, 32, 32) here so the
    # encoder's probed shape matches the data it will actually see.
    # (Three stride-2 convs: 32 -> 16 -> 8 -> 4.)
    m = ConvVAE(
        (1, 32, 32),
        encoder_conv_filters=[28, 64, 64],
        decoder_conv_t_filters=[64, 28, 1],
        latent_dim=hparams.latent_dim,
        kl_coeff=hparams.kl_coeff,
        lr=hparams.lr,
    )

    # Resize 28x28 MNIST digits to 32x32 so repeated halving/doubling by the
    # encoder and decoder round-trips exactly.
    mnist_full = MNIST(
        ".",
        train=True,
        download=True,
        transform=transforms.Compose([transforms.ToTensor(), transforms.Resize((32, 32))]),
    )
    dm = DataLoader(mnist_full, batch_size=hparams.bs)

    # NOTE(review): `gpus` and `weights_summary` are pytorch-lightning < 2.0
    # Trainer arguments; migrate to `accelerator`/`devices` on upgrade.
    trainer = Trainer(gpus=1, weights_summary="full", max_epochs=hparams.epochs)
    trainer.fit(m, dm)
    torch.save(m.state_dict(), "vae-mnist-conv.ckpt")
199
200