GitHub Repository: probml/pyprobml
Path: blob/master/deprecated/vae/standalone/vae_mlp_mnist.py
"""
Variational autoencoder (VAE) with an MLP encoder and decoder, trained on
MNIST with PyTorch Lightning.

Install pytorch lightning and einops:

    pip install pytorch_lightning einops
"""

import torch
from torch import nn
import pytorch_lightning as pl
from torch.nn import functional as F
from torchvision.datasets import MNIST
from torchvision import transforms
from argparse import ArgumentParser
from torch.utils.data import DataLoader
from einops import rearrange


class VAE(nn.Module):
    def __init__(self, n_z, model_name="vae"):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 512), nn.ReLU())
        self.model_name = model_name
        self.fc_mu = nn.Linear(512, n_z)
        self.fc_var = nn.Linear(512, n_z)
        self.decoder = nn.Sequential(nn.Linear(n_z, 512), nn.ReLU(), nn.Linear(512, 28 * 28), nn.Sigmoid())

    def forward(self, x):
        # forward defines the prediction/inference action: encode, sample a
        # latent code, and decode it back to pixel space
        x = self.encoder(x)
        mu = self.fc_mu(x)
        log_var = self.fc_var(x)
        p, q, z = self.sample(mu, log_var)
        return self.decoder(z)

    def _run_step(self, x):
        x = self.encoder(x)
        mu = self.fc_mu(x)
        log_var = self.fc_var(x)
        p, q, z = self.sample(mu, log_var)
        return z, self.decoder(z), p, q

    def encode(self, x):
        x = self.encoder(x)
        mu = self.fc_mu(x)
        return mu

    def sample(self, mu, log_var):
        # reparameterization trick: rsample() draws z = mu + std * eps with
        # eps ~ N(0, I), so gradients flow through mu and std
        std = torch.exp(log_var / 2)
        p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))
        q = torch.distributions.Normal(mu, std)
        z = q.rsample()
        return p, q, z
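
# Quick shape check (a hypothetical sketch, not part of the training script):
#
#   vae = VAE(n_z=2)
#   x = torch.rand(4, 28 * 28)           # a batch of 4 flattened images
#   assert vae(x).shape == (4, 28 * 28)  # reconstruction has the same shape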


class BasicVAEModule(pl.LightningModule):
    def __init__(self, n_z=2, kl_coeff=0.1, lr=0.001):
        super().__init__()
        self.vae = VAE(n_z)
        self.kl_coeff = kl_coeff
        self.lr = lr

    def forward(self, x):
        return self.vae(x)

    def step(self, batch, batch_idx):
        x, y = batch
        z, x_hat, p, q = self.vae._run_step(x)

        recon_loss = F.binary_cross_entropy(x_hat, x, reduction="sum")

        log_qz = q.log_prob(z)
        log_pz = p.log_prob(z)

        # single-sample Monte Carlo estimate of KL(q(z|x) || p(z)),
        # summed over the batch and the latent dimensions
        kl = log_qz - log_pz
        kl = kl.sum()
        kl *= self.kl_coeff

        loss = kl + recon_loss

        logs = {
            "recon_loss": recon_loss,
            "kl": kl,
            "loss": loss,
        }
        return loss, logs

    def training_step(self, batch, batch_idx):
        loss, logs = self.step(batch, batch_idx)
        self.log_dict({f"train_{k}": v for k, v in logs.items()}, on_step=True, on_epoch=False)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, logs = self.step(batch, batch_idx)
        self.log_dict({f"val_{k}": v for k, v in logs.items()})
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)
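
# Note on the loss: `step` estimates KL(q(z|x) || p(z)) with a single Monte
# Carlo sample, log q(z) - log p(z). For a diagonal Gaussian q and a standard
# normal p, the KL also has a closed form; an equivalent analytic version
# (a sketch, not what this script uses) would be:
#
#   kl = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())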


if __name__ == "__main__":
    parser = ArgumentParser(description="Hyperparameters for our experiments")
    parser.add_argument("--latent-dim", type=int, default=12, help="size of the VAE's latent dimension")
    parser.add_argument("--epochs", type=int, default=50, help="number of training epochs")
    parser.add_argument("--gpus", type=int, default=1, help="number of GPUs; set to 0 for CPU, -1 for all GPUs")
    parser.add_argument("--bs", type=int, default=500, help="batch size")
    hparams = parser.parse_args()

    # images are flattened to 784-dim vectors to match the MLP encoder
    mnist_full = MNIST(
        ".",
        download=True,
        train=True,
        transform=transforms.Compose([transforms.ToTensor(), lambda x: rearrange(x, "c h w -> (c h w)")]),
    )
    train_loader = DataLoader(mnist_full, batch_size=hparams.bs, shuffle=True)
    vae = BasicVAEModule(hparams.latent_dim)

    # note: `gpus` and `weights_summary` are the pre-1.7 Lightning Trainer API;
    # newer releases use `accelerator`/`devices` and a ModelSummary callback
    trainer = pl.Trainer(gpus=hparams.gpus, weights_summary="full", max_epochs=hparams.epochs)
    trainer.fit(vae, train_loader)
    torch.save(vae.state_dict(), "vae-mnist-mlp.ckpt")
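
    # Optional post-training check (an added sketch, not in the original
    # script): decode samples from the standard normal prior N(0, I) to
    # inspect what the trained decoder generates.
    vae.eval()
    with torch.no_grad():
        z = torch.randn(16, hparams.latent_dim)
        samples = vae.vae.decoder(z)  # (16, 784), values in [0, 1] from the Sigmoid
        samples = rearrange(samples, "b (h w) -> b h w", h=28, w=28)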