Path: blob/master/deprecated/vae/standalone/vae_mlp_mnist.py
"""1Install pytorch lightning and einops23pip install pytorch_lightning einops4"""56import torch7from torch import nn8import pytorch_lightning as pl9from torch.nn import functional as F10from torchvision.datasets import MNIST11from torchvision import transforms12from argparse import ArgumentParser13from torch.utils.data import DataLoader14from einops import rearrange151617class VAE(nn.Module):18def __init__(self, n_z, model_name="vae"):19super().__init__()20self.encoder = nn.Sequential(nn.Linear(28 * 28, 512), nn.ReLU())21self.model_name = model_name22self.fc_mu = nn.Linear(512, n_z)23self.fc_var = nn.Linear(512, n_z)24self.decoder = nn.Sequential(nn.Linear(n_z, 512), nn.ReLU(), nn.Linear(512, 28 * 28), nn.Sigmoid())2526def forward(self, x):27# in lightning, forward defines the prediction/inference actions28x = self.encoder(x)29mu = self.fc_mu(x)30log_var = self.fc_var(x)31p, q, z = self.sample(mu, log_var)32return self.decoder(z)3334def _run_step(self, x):35x = self.encoder(x)36mu = self.fc_mu(x)37log_var = self.fc_var(x)38p, q, z = self.sample(mu, log_var)39return z, self.decoder(z), p, q4041def encode(self, x):42x = self.encoder(x)43mu = self.fc_mu(x)44return mu4546def sample(self, mu, log_var):47std = torch.exp(log_var / 2)48p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))49q = torch.distributions.Normal(mu, std)50z = q.rsample()51return p, q, z525354class BasicVAEModule(pl.LightningModule):55def __init__(self, n_z=2, kl_coeff=0.1, lr=0.001):56super().__init__()57self.vae = VAE(n_z)58self.kl_coeff = kl_coeff59self.lr = lr6061def forward(self, x):62return self.vae(x)6364def step(self, batch, batch_idx):65x, y = batch66z, x_hat, p, q = self.vae._run_step(x)6768recon_loss = F.binary_cross_entropy(x_hat, x, reduction="sum")6970log_qz = q.log_prob(z)71log_pz = p.log_prob(z)7273kl = log_qz - log_pz74kl = kl.sum() # I tried sum, here75kl *= self.kl_coeff7677loss = kl + recon_loss7879logs = {80"recon_loss": recon_loss,81"kl": kl,82"loss": loss,83}84return loss, logs8586def training_step(self, batch, batch_idx):87loss, logs = self.step(batch, batch_idx)88self.log_dict({f"train_{k}": v for k, v in logs.items()}, on_step=True, on_epoch=False)89return loss9091def validation_step(self, batch, batch_idx):92loss, logs = self.step(batch, batch_idx)93self.log_dict({f"val_{k}": v for k, v in logs.items()})94return loss9596def configure_optimizers(self):97return torch.optim.Adam(self.parameters(), lr=self.lr)9899100if __name__ == "__main__":101parser = ArgumentParser(description="Hyperparameters for our experiments")102parser.add_argument("--latent-dim", type=int, default=12, help="size of latent dim for our vae")103parser.add_argument("--epochs", type=int, default=50, help="num epochs")104parser.add_argument("--gpus", type=int, default=1, help="gpus, if no gpu set to 0, to run on all gpus set to -1")105parser.add_argument("--bs", type=int, default=500, help="batch size")106hparams = parser.parse_args()107108mnist_full = MNIST(109".",110download=True,111train=True,112transform=transforms.Compose([transforms.ToTensor(), lambda x: rearrange(x, "c h w -> (c h w)")]),113)114dm = DataLoader(mnist_full, batch_size=hparams.bs, shuffle=True)115vae = BasicVAEModule(hparams.latent_dim)116117trainer = pl.Trainer(gpus=hparams.gpus, weights_summary="full", max_epochs=hparams.epochs)118trainer.fit(vae, dm)119torch.save(vae.state_dict(), "vae-mnist-mlp.ckpt")120121122