Path: blob/master/deprecated/vae/standalone/vae_info_celeba_lightning.py
# -*- coding: utf-8 -*-
"""
Author: Ang Ming Liang

Please run one of the following commands before running this script:

wget -q https://raw.githubusercontent.com/sayantanauddy/vae_lightning/main/data.py
or
curl https://raw.githubusercontent.com/sayantanauddy/vae_lightning/main/data.py > data.py

Then download your kaggle.json from kaggle.com and run

mkdir /root/.kaggle
cp kaggle.json /root/.kaggle/kaggle.json
chmod 600 /root/.kaggle/kaggle.json
rm kaggle.json

so that the Kaggle API can authenticate and download the CelebA dataset.
"""

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.utils as vutils
from pytorch_lightning import LightningModule, Trainer
from data import CelebADataModule
from argparse import ArgumentParser
from einops import rearrange

IMAGE_SIZE = 64
CROP = 128
DATA_PATH = "kaggle"

# Preprocessing: random flip, center crop to CROP, resize to IMAGE_SIZE, convert to tensor.
trans = []
trans.append(transforms.RandomHorizontalFlip())
if CROP > 0:
    trans.append(transforms.CenterCrop(CROP))
trans.append(transforms.Resize(IMAGE_SIZE))
trans.append(transforms.ToTensor())
transform = transforms.Compose(trans)


def compute_kernel(x1: torch.Tensor, x2: torch.Tensor, kernel_type: str = "rbf") -> torch.Tensor:
    D = x1.size(1)
    N = x1.size(0)

    # Turn the inputs into column and row tensors so broadcasting yields all
    # pairwise differences.
    x1 = x1.unsqueeze(-2)  # column tensor, shape [N, 1, D]
    x2 = x2.unsqueeze(-3)  # row tensor, shape [1, N, D]

    # The expands below are usually not required, but they are useful when
    # x1 and x2 have different sizes along the 0th dimension.
    x1 = x1.expand(N, N, D)
    x2 = x2.expand(N, N, D)

    if kernel_type == "rbf":
        result = compute_rbf(x1, x2)
    elif kernel_type == "imq":
        result = compute_inv_mult_quad(x1, x2)
    else:
        raise ValueError("Undefined kernel type.")

    return result


def compute_rbf(x1: torch.Tensor, x2: torch.Tensor, latent_var: float = 2.0, eps: float = 1e-7) -> torch.Tensor:
    """
    Computes the RBF kernel between x1 and x2.
    :param x1: (Tensor)
    :param x2: (Tensor)
    :param latent_var: (Float) latent-prior variance used to set the bandwidth
    :param eps: (Float)
    :return: (Tensor)
    """
    z_dim = x2.size(-1)
    sigma = 2.0 * z_dim * latent_var

    result = torch.exp(-((x1 - x2).pow(2).mean(-1) / sigma))
    return result


def compute_inv_mult_quad(
    x1: torch.Tensor, x2: torch.Tensor, latent_var: float = 2.0, eps: float = 1e-7
) -> torch.Tensor:
    """
    Computes the Inverse Multi-Quadratics kernel between x1 and x2, given by
    k(x_1, x_2) = \sum \frac{C}{C + \|x_1 - x_2\|^2}
    :param x1: (Tensor)
    :param x2: (Tensor)
    :param latent_var: (Float) latent-prior variance used to set C
    :param eps: (Float)
    :return: (Tensor)
    """
    z_dim = x2.size(-1)
    C = 2 * z_dim * latent_var
    kernel = C / (eps + C + (x1 - x2).pow(2).sum(dim=-1))

    # Exclude diagonal elements
    result = kernel.sum() - kernel.diag().sum()

    return result


def MMD(prior_z: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    # Maximum Mean Discrepancy between samples from the prior and the
    # aggregated posterior.
    prior_z__kernel = compute_kernel(prior_z, prior_z)
    z__kernel = compute_kernel(z, z)
    priorz_z__kernel = compute_kernel(prior_z, z)

    mmd = prior_z__kernel.mean() + z__kernel.mean() - 2 * priorz_z__kernel.mean()
    return mmd


def kl_divergence(mean, logvar):
    # KL(N(mean, exp(logvar)) || N(0, I)), averaged over all elements.
    return -0.5 * torch.mean(1 + logvar - torch.square(mean) - torch.exp(logvar))

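# The kernels and divergences above feed the InfoVAE objective used in
# VAE.step() further below:
#
#     loss = recon + (1 - alpha) * KL + (alpha + beta - 1) * MMD
#
# With alpha = 0 and beta = 1 this reduces to the usual VAE ELBO; with
# alpha = 1 and beta = 1 (the CLI defaults) the KL term drops out and only the
# MMD regularizer remains. A quick sanity check of the MMD estimate
# (a minimal sketch, assuming a 32-dimensional latent space):
#
#     z = torch.randn(256, 32)
#     MMD(torch.randn(256, 32), z)        # close to 0 for samples from the same prior
#     MMD(torch.randn(256, 32) + 3.0, z)  # clearly positive for shifted samples
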
class VAE(LightningModule):
    """
    Convolutional VAE with a Gaussian prior and approximate posterior,
    trained with the InfoVAE (MMD) objective.
    """

    # Mapping from checkpoint name to URL; empty unless populated by the user.
    pretrained_urls = {}

    def __init__(
        self,
        input_height: int,
        hidden_dims=None,
        in_channels=3,
        enc_out_dim: int = 512,
        alpha: float = 0.99,
        beta: float = 2,
        latent_dim: int = 256,
        lr: float = 1e-3,
    ):
        """
        Args:
            input_height: height (and width) of the input images
            hidden_dims: channel counts for the encoder conv blocks
                (defaults to [32, 64, 128, 256, 512])
            in_channels: number of channels of the input images
            enc_out_dim: flattened output dim of the encoder
                (not used by the default conv encoder)
            alpha: InfoVAE coefficient; the KL term is weighted by (1 - alpha)
            beta: InfoVAE coefficient; the MMD term is weighted by (alpha + beta - 1)
            latent_dim: dim of latent space
            lr: learning rate for Adam
        """

        super().__init__()

        self.save_hyperparameters()

        self.lr = lr
        self.alpha = alpha
        self.beta = beta
        self.enc_out_dim = enc_out_dim
        self.latent_dim = latent_dim
        self.input_height = input_height

        modules = []
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]

        # Build encoder: each block halves the spatial resolution.
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim, kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU(),
                )
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(hidden_dims[-1] * 4, latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1] * 4, latent_dim)

        # Build decoder: mirror of the encoder, each block doubles the resolution.
        modules = []

        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)

        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(
                        hidden_dims[i], hidden_dims[i + 1], kernel_size=3, stride=2, padding=1, output_padding=1
                    ),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU(),
                )
            )

        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[-1], hidden_dims[-1], kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(hidden_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(hidden_dims[-1], out_channels=3, kernel_size=3, padding=1),
            nn.Sigmoid(),
        )

    @staticmethod
    def pretrained_weights_available():
        return list(VAE.pretrained_urls.keys())

    def from_pretrained(self, checkpoint_name):
        if checkpoint_name not in VAE.pretrained_urls:
            raise KeyError(str(checkpoint_name) + " not present in pretrained weights.")

        return self.load_from_checkpoint(VAE.pretrained_urls[checkpoint_name], strict=False)

    def encode(self, x):
        x = self.encoder(x)
        x = torch.flatten(x, start_dim=1)
        mu = self.fc_mu(x)
        log_var = self.fc_var(x)
        return mu, log_var

    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.sample(mu, log_var)
        return self.decode(z)

    def _run_step(self, x):
        mu, log_var = self.encode(x)
        z = self.sample(mu, log_var)
        return z, self.decode(z), mu, log_var

    def sample(self, mu, log_var):
        # Reparameterization trick: z = mu + sigma * eps, with eps ~ N(0, I).
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps * std + mu

    def step(self, batch, batch_idx):
        x, y = batch
        z, x_hat, mu, logvar = self._run_step(x)

        recon_loss = F.mse_loss(x_hat, x, reduction="mean")
        kld_loss = kl_divergence(mu, logvar)

        # InfoVAE objective: recon + (1 - alpha) * KL + (alpha + beta - 1) * MMD.
        mmd = MMD(torch.randn_like(z), z)
        loss = recon_loss + (1 - self.alpha) * kld_loss + (self.alpha + self.beta - 1) * mmd

        logs = {
            "recon_loss": recon_loss,
            "mmd": mmd,
            "loss": loss,
        }
        return loss, logs

    def step_sample(self, batch, batch_idx):
        x, y = batch
        z, x_hat, _, _ = self._run_step(x)
        return z, x_hat

    def decode(self, z):
        result = self.decoder_input(z)
        result = result.view(-1, 512, 2, 2)  # reshape to the encoder's final feature-map shape
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def training_step(self, batch, batch_idx):
        loss, logs = self.step(batch, batch_idx)
        self.log_dict({f"train_{k}": v for k, v in logs.items()}, on_step=True, on_epoch=False)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, logs = self.step(batch, batch_idx)
        self.log_dict({f"val_{k}": v for k, v in logs.items()})
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

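# Quick shape check of the architecture (a minimal sketch, assuming the
# module-level IMAGE_SIZE of 64 and the default hidden_dims):
#
#     model = VAE(input_height=IMAGE_SIZE, latent_dim=128)
#     x = torch.randn(4, 3, IMAGE_SIZE, IMAGE_SIZE)
#     mu, log_var = model.encode(x)   # each of shape [4, 128]
#     x_hat = model(x)                # reconstruction of shape [4, 3, 64, 64]
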
if __name__ == "__main__":
    parser = ArgumentParser(description="Hyperparameters for our experiments")
    parser.add_argument("--latent-dim", type=int, default=128, help="size of the latent dim for our vae")
    parser.add_argument("--epochs", type=int, default=50, help="num epochs")
    parser.add_argument("--gpus", type=int, default=1, help="gpus; set to 0 for cpu, -1 to run on all gpus")
    parser.add_argument("--bs", type=int, default=256, help="batch size")
    parser.add_argument("--alpha", type=float, default=1, help="InfoVAE alpha (the KL term is weighted by 1 - alpha)")
    parser.add_argument("--beta", type=float, default=1, help="InfoVAE beta (weights the MMD term)")
    parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
    hparams = parser.parse_args()

    m = VAE(
        input_height=IMAGE_SIZE,
        latent_dim=hparams.latent_dim,
        alpha=hparams.alpha,
        beta=hparams.beta,
        lr=hparams.lr,
    )
    dm = CelebADataModule(
        data_dir=DATA_PATH,
        target_type="attr",
        train_transform=transform,
        val_transform=transform,
        download=True,
        batch_size=hparams.bs,
    )
    trainer = Trainer(gpus=hparams.gpus, max_epochs=hparams.epochs)
    trainer.fit(m, datamodule=dm)
    torch.save(m.state_dict(), "infovae-celeba-conv.ckpt")
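
# After training, the saved weights can be reloaded and new faces sampled from
# the prior (a minimal sketch, assuming the default --latent-dim of 128):
#
#     model = VAE(input_height=IMAGE_SIZE, latent_dim=128)
#     model.load_state_dict(torch.load("infovae-celeba-conv.ckpt"))
#     model.eval()
#     with torch.no_grad():
#         samples = model.decode(torch.randn(16, 128))  # [16, 3, 64, 64]
#     grid = vutils.make_grid(samples, nrow=4)
#     plt.imshow(rearrange(grid, "c h w -> h w c").numpy())
#     plt.show()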