GitHub Repository: probml/pyprobml
Path: blob/master/deprecated/scripts/ae_mnist_mlp.py
import superimport  # pyprobml helper that auto-installs missing packages on import

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from pytorch_lightning import LightningModule, Trainer
from einops import rearrange
from argparse import ArgumentParser


class AE(nn.Module):
    """MLP autoencoder for flattened 28x28 MNIST digits."""

    def __init__(self, n_z, model_name="ae"):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU()
        )
        self.model_name = model_name
        self.fc_mu = nn.Linear(512, n_z)  # projects encoder features to the n_z-dim latent code
        self.decoder = nn.Sequential(
            nn.Linear(n_z, 512),
            nn.ReLU(),
            nn.Linear(512, 28 * 28),
            nn.Sigmoid()  # keeps reconstructions in (0, 1), matching ToTensor-scaled pixels
        )

    def forward(self, x):
        # in Lightning, forward defines the prediction/inference actions
        x = self.encoder(x)
        mu = self.fc_mu(x)
        return self.decoder(mu)

    def encode(self, x):
        x = self.encoder(x)
        mu = self.fc_mu(x)
        return mu
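

# --- Added sketch (not part of the original script): a quick shape check for
# AE; the helper name `_sanity_check_ae` is hypothetical. It confirms that a
# batch of flattened digits round-trips through the encoder and decoder.
def _sanity_check_ae():
    ae = AE(n_z=2)
    x = torch.rand(4, 28 * 28)           # dummy batch of 4 flattened 28x28 digits
    assert ae(x).shape == (4, 28 * 28)   # reconstructions match the input shape
    assert ae.encode(x).shape == (4, 2)  # latent codes have n_z dimensions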


class BasicAEModule(LightningModule):
    """LightningModule that trains AE with a summed reconstruction loss."""

    def __init__(self, n_z=2, kl_coeff=0.1, lr=0.001):
        super().__init__()
        self.ae = AE(n_z)
        self.kl_coeff = kl_coeff  # unused here; leftover from the VAE variant of this script
        self.lr = lr

    def forward(self, x):
        return self.ae(x)

    def step(self, batch, batch_idx):
        x, _ = batch  # labels are not needed for reconstruction
        x = rearrange(x, 'b c h w -> b (c h w)')  # flatten (c, h, w) images to vectors
        x_hat = self.ae(x)

        # summed binary cross-entropy between reconstructed and input pixels
        loss = F.binary_cross_entropy(x_hat, x, reduction='sum')

        logs = {
            "loss": loss,
        }

        return loss, logs

    def training_step(self, batch, batch_idx):
        loss, logs = self.step(batch, batch_idx)
        self.log_dict({f"train_{k}": v for k, v in logs.items()}, on_step=True, on_epoch=False)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, logs = self.step(batch, batch_idx)
        self.log_dict({f"val_{k}": v for k, v in logs.items()})
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)
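

# --- Added sketch (not part of the original script): embedding a batch of
# images into the latent space of a trained module, e.g. to scatter-plot the
# 2-D codes by class label when n_z=2. The helper name `_embed_batch` is
# hypothetical.
@torch.no_grad()
def _embed_batch(module, batch):
    x, y = batch
    x = rearrange(x, 'b c h w -> b (c h w)')
    z = module.ae.encode(x)  # (batch_size, n_z) latent codes
    return z, y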
if __name__ == "__main__":
83
parser = ArgumentParser(description='Hyperparameters for our experiments')
84
parser.add_argument('--latent-dim', type=int, default=20, help="size of latent dim for our vae")
85
parser.add_argument('--epochs', type=int, default=50, help="num epochs")
86
parser.add_argument('--gpus', type=int, default=1, help="gpus, if no gpu set to 0, to run on all gpus set to -1")
87
parser.add_argument('--bs', type=int, default=500, help="batch size")
88
hparams = parser.parse_args()
89
90
mnist_full = MNIST(".", download=True, train=True,
91
transform=transforms.Compose([transforms.ToTensor()]))
92
dm = DataLoader(mnist_full, batch_size=hparams.bs, shuffle=True)
93
ae = BasicAEModule(hparams.latent_dim)
94
95
trainer = Trainer(gpus=hparams.gpus, weights_summary='full', max_epochs=hparams.epochs)
96
trainer.fit(ae, dm)
97
torch.save(ae.state_dict(), f"ae-mnist-mlp-latent-dim-{hparams.latent_dim}.ckpt")
98
99
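
# --- Added usage note: running, e.g.,
#
#   python ae_mnist_mlp.py --latent-dim 20 --epochs 50 --gpus 1 --bs 500
#
# downloads MNIST into the working directory, trains the autoencoder, and
# saves the weights to ae-mnist-mlp-latent-dim-20.ckpt. A minimal reload
# sketch (the latent dim must match the one used at training time):
#
#   ae = BasicAEModule(20)
#   ae.load_state_dict(torch.load("ae-mnist-mlp-latent-dim-20.ckpt"))
#   ae.eval()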