Path: blob/master/deprecated/vae/standalone/vae_mmd_celeba_lightning.py
"""1Author: Ang Ming Liang23Please run the following command before running the script45wget -q https://raw.githubusercontent.com/sayantanauddy/vae_lightning/main/data.py6or curl https://raw.githubusercontent.com/sayantanauddy/vae_lightning/main/data.py > data.py78Then, make sure to get your kaggle.json from kaggle.com then run910mkdir /root/.kaggle11cp kaggle.json /root/.kaggle/kaggle.json12chmod 600 /root/.kaggle/kaggle.json13rm kaggle.json1415to copy kaggle.json into a folder first16"""1718import torch19import torch.nn as nn20import torch.nn.functional as F21import torchvision.transforms as transforms22from pytorch_lightning import LightningModule, Trainer23from data import CelebADataModule24from argparse import ArgumentParser25from einops import rearrange2627IMAGE_SIZE = 6428CROP = 12829DATA_PATH = "kaggle"3031trans = []32trans.append(transforms.RandomHorizontalFlip())33if CROP > 0:34trans.append(transforms.CenterCrop(CROP))35trans.append(transforms.Resize(IMAGE_SIZE))36trans.append(transforms.ToTensor())37transform = transforms.Compose(trans)383940def compute_kernel(x1: torch.Tensor, x2: torch.Tensor, kernel_type: str = "rbf") -> torch.Tensor:41# Convert the tensors into row and column vectors42D = x1.size(1)43N = x1.size(0)4445x1 = x1.unsqueeze(-2) # Make it into a column tensor46x2 = x2.unsqueeze(-3) # Make it into a row tensor4748"""49Usually the below lines are not required, especially in our case,50but this is useful when x1 and x2 have different sizes51along the 0th dimension.52"""53x1 = x1.expand(N, N, D)54x2 = x2.expand(N, N, D)5556if kernel_type == "rbf":57result = compute_rbf(x1, x2)58elif kernel_type == "imq":59result = compute_inv_mult_quad(x1, x2)60else:61raise ValueError("Undefined kernel type.")6263return result646566def compute_rbf(x1: torch.Tensor, x2: torch.Tensor, latent_var: float = 2.0, eps: float = 1e-7) -> torch.Tensor:67"""68Computes the RBF Kernel between x1 and x2.69:param x1: (Tensor)70:param x2: (Tensor)71:param eps: (Float)72:return:73"""74z_dim = x2.size(-1)75sigma = 2.0 * z_dim * latent_var7677result = torch.exp(-((x1 - x2).pow(2).mean(-1) / sigma))78return result798081def compute_inv_mult_quad(82x1: torch.Tensor, x2: torch.Tensor, latent_var: float = 2.0, eps: float = 1e-783) -> torch.Tensor:84"""85Computes the Inverse Multi-Quadratics Kernel between x1 and x2,86given by87k(x_1, x_2) = \sum \frac{C}{C + \|x_1 - x_2 \|^2}88:param x1: (Tensor)89:param x2: (Tensor)90:param eps: (Float)91:return:92"""93z_dim = x2.size(-1)94C = 2 * z_dim * latent_var95kernel = C / (eps + C + (x1 - x2).pow(2).sum(dim=-1))9697# Exclude diagonal elements98result = kernel.sum() - kernel.diag().sum()99100return result101102103def MMD(prior_z: torch.Tensor, z: torch.Tensor):104105prior_z__kernel = compute_kernel(prior_z, prior_z)106z__kernel = compute_kernel(z, z)107priorz_z__kernel = compute_kernel(prior_z, z)108109mmd = prior_z__kernel.mean() + z__kernel.mean() - 2 * priorz_z__kernel.mean()110return mmd111112113class VAE(LightningModule):114"""115Standard VAE with Gaussian Prior and approx posterior.116"""117118def __init__(119self,120input_height: int,121hidden_dims=None,122in_channels=3,123enc_out_dim: int = 512,124beta: float = 1,125latent_dim: int = 256,126lr: float = 1e-3,127):128"""129Args:130input_height: height of the images131enc_type: option between resnet18 or resnet50132first_conv: use standard kernel_size 7, stride 2 at start or133replace it with kernel_size 3, stride 1 conv134maxpool1: use standard maxpool to reduce spatial dim of feat by a factor of 


class VAE(LightningModule):
    """
    VAE with a Gaussian prior. The encoder is used deterministically (only mu is
    returned), and the usual KL term is replaced by an MMD penalty between the
    encoded batch and samples from the prior (MMD-VAE / InfoVAE).
    """

    # No pretrained checkpoints are registered for this script; the empty dict keeps
    # pretrained_weights_available/from_pretrained below from raising AttributeError.
    pretrained_urls: dict = {}

    def __init__(
        self,
        input_height: int,
        hidden_dims=None,
        in_channels=3,
        enc_out_dim: int = 512,
        beta: float = 1,
        latent_dim: int = 256,
        lr: float = 1e-3,
    ):
        """
        Args:
            input_height: height of the input images
            hidden_dims: channel counts of the convolutional encoder blocks
                (the decoder mirrors them); defaults to [32, 64, 128, 256, 512]
            in_channels: number of input image channels
            enc_out_dim: output dimension of the encoder (512 for the default stack)
            beta: coefficient of the MMD term in the loss
            latent_dim: dimension of the latent space
            lr: learning rate for Adam
        """

        super().__init__()

        self.save_hyperparameters()

        self.lr = lr
        self.beta = beta
        self.enc_out_dim = enc_out_dim
        self.latent_dim = latent_dim
        self.input_height = input_height

        modules = []
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]

        # Build Encoder
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim, kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU(),
                )
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(hidden_dims[-1] * 4, latent_dim)
        # fc_var is kept for compatibility but never used: the encoder is deterministic.
        self.fc_var = nn.Linear(hidden_dims[-1] * 4, latent_dim)

        # Build Decoder
        modules = []

        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)

        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(
                        hidden_dims[i], hidden_dims[i + 1], kernel_size=3, stride=2, padding=1, output_padding=1
                    ),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU(),
                )
            )

        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[-1], hidden_dims[-1], kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(hidden_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(hidden_dims[-1], out_channels=3, kernel_size=3, padding=1),
            nn.Sigmoid(),
        )

    @staticmethod
    def pretrained_weights_available():
        return list(VAE.pretrained_urls.keys())

    def from_pretrained(self, checkpoint_name):
        if checkpoint_name not in VAE.pretrained_urls:
            raise KeyError(str(checkpoint_name) + " not present in pretrained weights.")

        return self.load_from_checkpoint(VAE.pretrained_urls[checkpoint_name], strict=False)

    def encode(self, x):
        x = self.encoder(x)
        x = torch.flatten(x, start_dim=1)
        mu = self.fc_mu(x)
        return mu

    def forward(self, x):
        z = self.encode(x)
        return self.decode(z)

    def _run_step(self, x):
        z = self.encode(x)
        return z, self.decode(z)

    def step(self, batch, batch_idx):
        x, y = batch  # the attribute labels y are unused
        z, x_hat = self._run_step(x)

        recon_loss = F.mse_loss(x_hat, x, reduction="mean")
        # MMD between samples from the standard normal prior and the encoded batch
        # replaces the usual KL term.
        mmd = MMD(torch.randn_like(z), z)

        loss = recon_loss + self.beta * mmd

        logs = {
            "recon_loss": recon_loss,
            "mmd": mmd,
        }
        return loss, logs

    def decode(self, z):
        result = self.decoder_input(z)
        # 512 x 2 x 2 matches the default hidden_dims and 64x64 inputs.
        result = result.view(-1, 512, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def training_step(self, batch, batch_idx):
        loss, logs = self.step(batch, batch_idx)
        self.log_dict({f"train_{k}": v for k, v in logs.items()}, on_step=True, on_epoch=False)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, logs = self.step(batch, batch_idx)
        self.log_dict({f"val_{k}": v for k, v in logs.items()})
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)
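

# Optional shape check for the model above (a minimal sketch, not used by the
# training pipeline): with the default hidden_dims and 64x64 RGB inputs, a batch
# should round-trip through encode/decode with unchanged shape.
def _vae_shape_check() -> None:
    model = VAE(input_height=IMAGE_SIZE, latent_dim=16)
    x = torch.rand(4, 3, IMAGE_SIZE, IMAGE_SIZE)  # fake batch of CelebA-sized images in [0, 1]
    x_hat = model(x)
    assert x_hat.shape == x.shape, (x_hat.shape, x.shape)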
help="gpus, if no gpu set to 0, to run on all gpus set to -1")265parser.add_argument("--bs", type=int, default=256, help="batch size")266parser.add_argument("--beta", type=int, default=1, help="kl coeff aka beta term in the elbo loss function")267parser.add_argument("--lr", type=int, default=1e-3, help="learning rate")268hparams = parser.parse_args()269270m = VAE(input_height=IMAGE_SIZE, latent_dim=hparams.latent_dim, beta=hparams.beta, lr=hparams.lr)271dm = CelebADataModule(272data_dir=DATA_PATH,273target_type="attr",274train_transform=transform,275val_transform=transform,276download=True,277batch_size=hparams.bs,278)279trainer = Trainer(gpus=1, weights_summary="full", max_epochs=10, auto_lr_find=True)280281# Run learning rate finder282lr_finder = trainer.tuner.lr_find(m, dm)283284# Results can be found in285lr_finder.results286287# Plot with288fig = lr_finder.plot(suggest=True)289fig.show()290291# Pick point based on plot, or get suggestion292new_lr = lr_finder.suggestion()293294# update hparams of the model295m.lr = new_lr296297trainer = Trainer(gpus=hparams.gpus, max_epochs=hparams.epochs)298trainer.fit(m, datamodule=dm)299torch.save(m.state_dict(), "mmd-vae-celeba-conv.ckpt")300301302