# Path: blob/master/deprecated/vae/standalone/vae_conv_mnist.py
# 1192 views
"""Convolutional VAE on MNIST, trained with PyTorch Lightning.

Install PyTorch Lightning and einops first:

    pip install pytorch_lightning einops
"""
from argparse import ArgumentParser

import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from pytorch_lightning import LightningModule, Trainer
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST


class ConvVAEModule(nn.Module):
    """Convolutional VAE: strided-conv encoder, upsample + transposed-conv decoder.

    Args:
        input_shape: ``(channels, height, width)`` of one input image. Its channel
            count feeds the first conv; the full shape is used to size the
            flattened encoder output.
        encoder_conv_filters: output channels of each encoder Conv2d
            (kernel 3, stride 2, padding 1, so each maps H -> ceil(H / 2)).
        decoder_conv_t_filters: output channels of each decoder stage. Each stage
            doubles H and W; the last entry is the reconstruction's channel count.
        latent_dim: size of the latent code. When it equals 2, BatchNorm is not
            used and the mu / log_var / decoder projections are plain Linear
            layers; otherwise they are Linear -> LeakyReLU -> Dropout(0.2).
        deterministic: if True, ``forward`` decodes the posterior mean instead of
            a reparameterized sample.

    Raises:
        ValueError: if either filter list is empty.
    """

    def __init__(self, input_shape, encoder_conv_filters, decoder_conv_t_filters, latent_dim, deterministic=False):
        super().__init__()
        # Accept any sequence (e.g. tuples) and fail early with a clear message.
        encoder_conv_filters = list(encoder_conv_filters)
        decoder_conv_t_filters = list(decoder_conv_t_filters)
        if not encoder_conv_filters or not decoder_conv_t_filters:
            raise ValueError("encoder_conv_filters and decoder_conv_t_filters must be non-empty")

        self.input_shape = input_shape
        self.latent_dim = latent_dim
        self.deterministic = deterministic
        # The 2-D latent configuration drops BatchNorm and uses bare Linear heads.
        plain = self.latent_dim == 2

        # Encoder: Conv2d -> [BatchNorm2d] -> LeakyReLU per entry of encoder_conv_filters.
        all_channels = [self.input_shape[0]] + encoder_conv_filters
        self.enc_convs = nn.ModuleList([])
        for c_in, c_out in zip(all_channels[:-1], all_channels[1:]):
            self.enc_convs.append(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1))
            if not plain:
                self.enc_convs.append(nn.BatchNorm2d(c_out))
            self.enc_convs.append(nn.LeakyReLU())

        self.flatten_out_size = self.flatten_enc_out_shape(input_shape)

        # NOTE(review): in the non-2-D configuration mu and log_var pass through
        # LeakyReLU + Dropout, which damps negative means / log-variances. Kept
        # as-is so existing checkpoints behave the same — confirm this is intended.
        self.mu_linear = self._make_linear(self.flatten_out_size, self.latent_dim, plain)
        self.log_var_linear = self._make_linear(self.flatten_out_size, self.latent_dim, plain)
        self.decoder_linear = self._make_linear(self.latent_dim, self.flatten_out_size, plain)

        # Decoder: each stage is a nearest-neighbour x2 upsample followed by a
        # size-preserving ConvTranspose2d(kernel 3, stride 1, padding 1).
        all_t_channels = [encoder_conv_filters[-1]] + decoder_conv_t_filters
        stages = list(zip(all_t_channels[:-1], all_t_channels[1:]))
        self.dec_t_convs = nn.ModuleList([])
        for c_in, c_out in stages[:-1]:
            self.dec_t_convs.append(nn.UpsamplingNearest2d(scale_factor=2))
            self.dec_t_convs.append(nn.ConvTranspose2d(c_in, c_out, 3, stride=1, padding=1))
            if not plain:
                self.dec_t_convs.append(nn.BatchNorm2d(c_out))
            self.dec_t_convs.append(nn.LeakyReLU())

        # Final stage: no normalisation; the sigmoid keeps outputs in (0, 1) as
        # required by the binary cross-entropy reconstruction loss.
        last_in, last_out = stages[-1]
        self.dec_t_convs.append(nn.UpsamplingNearest2d(scale_factor=2))
        self.dec_t_convs.append(nn.ConvTranspose2d(last_in, last_out, 3, stride=1, padding=1))
        self.dec_t_convs.append(nn.Sigmoid())

    @staticmethod
    def _make_linear(in_features, out_features, plain):
        """Return ``Linear`` if ``plain``, else ``Linear -> LeakyReLU -> Dropout(0.2)``."""
        if plain:
            return nn.Linear(in_features, out_features)
        return nn.Sequential(nn.Linear(in_features, out_features), nn.LeakyReLU(), nn.Dropout(0.2))

    def reparameterize(self, mu, log_var):
        """Draw ``z = mu + eps * exp(0.5 * log_var)`` with ``eps ~ N(0, I)`` (differentiable in mu, log_var)."""
        std = torch.exp(0.5 * log_var)  # standard deviation
        eps = torch.randn_like(std)  # noise with the same shape/device/dtype as std
        return mu + eps * std

    def _run_step(self, x):
        """Encode, sample and decode ``x``.

        Returns:
            ``(z, recon, p, q)`` where ``p = N(0, I)`` is the prior and
            ``q = N(mu, std)`` the approximate posterior, both as
            ``torch.distributions.Normal`` with the latent's shape.
        """
        mu, log_var = self.encode(x)
        std = torch.exp(0.5 * log_var)
        p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))
        q = torch.distributions.Normal(mu, std)
        z = self.reparameterize(mu, log_var)
        recon = self.decode(z)
        return z, recon, p, q

    def flatten_enc_out_shape(self, input_shape):
        """Return the flattened encoder output size for one input of ``input_shape``.

        Also stores ``self.shape_before_flattening`` (batch dimension included),
        which ``decode`` uses to un-flatten the latent projection.
        """
        was_training = self.enc_convs.training
        # Probe in eval mode without autograd so the dummy zero batch neither
        # updates BatchNorm running statistics nor builds a graph.
        self.enc_convs.eval()
        try:
            with torch.no_grad():
                x = torch.zeros(1, *input_shape)
                for layer in self.enc_convs:
                    x = layer(x)
        finally:
            self.enc_convs.train(was_training)
        self.shape_before_flattening = x.shape
        return int(np.prod(self.shape_before_flattening))

    def encode(self, x):
        """Map a batch of images to ``(mu, log_var)``, each of shape ``(batch, latent_dim)``."""
        for layer in self.enc_convs:
            x = layer(x)
        x = torch.flatten(x, start_dim=1)
        mu = self.mu_linear(x)
        log_var = self.log_var_linear(x)
        return mu, log_var

    def decode(self, z):
        """Map latent codes ``(batch, latent_dim)`` back to images in (0, 1)."""
        z = self.decoder_linear(z)
        recon = z.view(z.size(0), *self.shape_before_flattening[1:])
        for layer in self.dec_t_convs:
            recon = layer(recon)
        return recon

    def forward(self, x):
        """Return ``(recon, mu, log_var)``; in deterministic mode ``(decode(mu), mu, None)``."""
        mu, log_var = self.encode(x)
        if self.deterministic:
            return self.decode(mu), mu, None
        z = self.reparameterize(mu, log_var)
        recon = self.decode(z)
        return recon, mu, log_var


class ConvVAE(LightningModule):
    """Lightning wrapper that trains :class:`ConvVAEModule` with a beta-weighted ELBO.

    Args:
        input_shape, encoder_conv_filters, decoder_conv_t_filters, latent_dim:
            forwarded to :class:`ConvVAEModule`.
        kl_coeff: weight (beta) applied to the KL term of the loss.
        lr: Adam learning rate.
    """

    def __init__(self, input_shape, encoder_conv_filters, decoder_conv_t_filters, latent_dim, kl_coeff=0.1, lr=0.001):
        super().__init__()
        self.kl_coeff = kl_coeff
        self.lr = lr
        self.vae = ConvVAEModule(input_shape, encoder_conv_filters, decoder_conv_t_filters, latent_dim)

    def step(self, batch, batch_idx):
        """Compute the loss for one ``(images, labels)`` batch; labels are ignored.

        ``loss = BCE(x_hat, x, reduction="sum") + kl_coeff * KL``, where KL is a
        single-sample Monte Carlo estimate ``log q(z|x) - log p(z)`` summed over
        the batch and latent dimensions.

        Returns:
            ``(loss, logs)``; ``logs["kl"]`` is the already *weighted* KL term.
        """
        x, _ = batch
        z, x_hat, p, q = self.vae._run_step(x)

        recon_loss = F.binary_cross_entropy(x_hat, x, reduction="sum")
        # Summed (not averaged) to match the sum-reduced reconstruction loss.
        kl = (q.log_prob(z) - p.log_prob(z)).sum() * self.kl_coeff
        loss = kl + recon_loss

        logs = {
            "recon_loss": recon_loss,
            "kl": kl,
            "loss": loss,
        }
        return loss, logs

    def training_step(self, batch, batch_idx):
        """Log per-step ``train_*`` metrics and return the training loss."""
        loss, logs = self.step(batch, batch_idx)
        self.log_dict({f"train_{k}": v for k, v in logs.items()}, on_step=True, on_epoch=False)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log ``val_*`` metrics and return the validation loss."""
        loss, logs = self.step(batch, batch_idx)
        self.log_dict({f"val_{k}": v for k, v in logs.items()})
        return loss

    def configure_optimizers(self):
        """Optimise all parameters with Adam at ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)


if __name__ == "__main__":
    parser = ArgumentParser(description="Hyperparameters for our experiments")
    parser.add_argument("--bs", type=int, default=500, help="batch size")
    parser.add_argument("--epochs", type=int, default=50, help="num epochs")
    parser.add_argument("--latent-dim", type=int, default=2, help="size of latent dim for our vae")
    parser.add_argument("--lr", type=float, default=0.001, help="learning rate")
    # float, not int: fractional beta values such as 0.5 must be accepted.
    parser.add_argument("--kl-coeff", type=float, default=5.0, help="kl coeff aka beta term in the elbo loss function")
    hparams = parser.parse_args()

    # The transform below resizes images to 32x32, so declare that input shape.
    # Encoder: 32 -> 16 -> 8 -> 4 (64 x 4 x 4 features, the same size a 28x28
    # input would give); decoder: 4 -> 8 -> 16 -> 32, matching the targets.
    model = ConvVAE(
        (1, 32, 32),
        encoder_conv_filters=[28, 64, 64],
        decoder_conv_t_filters=[64, 28, 1],
        latent_dim=hparams.latent_dim,
        kl_coeff=hparams.kl_coeff,
        lr=hparams.lr,
    )

    mnist_full = MNIST(
        ".",
        train=True,
        download=True,
        transform=transforms.Compose([transforms.ToTensor(), transforms.Resize((32, 32))]),
    )
    # Shuffle so every epoch sees differently composed batches.
    train_loader = DataLoader(mnist_full, batch_size=hparams.bs, shuffle=True)
    # NOTE(review): `gpus` and `weights_summary` are Trainer arguments from older
    # PyTorch Lightning releases (removed in 2.x, where `accelerator`/`devices`
    # and `enable_model_summary` replace them) — pin the Lightning version or port.
    trainer = Trainer(gpus=1, weights_summary="full", max_epochs=hparams.epochs)
    trainer.fit(model, train_loader)
    torch.save(model.state_dict(), "vae-mnist-conv.ckpt")