Path: blob/master/deprecated/vae/standalone/vae_celeba_lightning.py
"""1Author: Ang Ming Liang23Please run the following command before running the script45wget -q https://raw.githubusercontent.com/sayantanauddy/vae_lightning/main/data.py6or curl https://raw.githubusercontent.com/sayantanauddy/vae_lightning/main/data.py > data.py78Then, make sure to get your kaggle.json from kaggle.com then run910mkdir /root/.kaggle11cp kaggle.json /root/.kaggle/kaggle.json12chmod 600 /root/.kaggle/kaggle.json13rm kaggle.json1415to copy kaggle.json into a folder first16"""1718import torch19import torch.nn as nn20import torch.nn.functional as F21import torchvision.transforms as transforms22from pytorch_lightning import LightningModule, Trainer23from data import CelebADataModule24from argparse import ArgumentParser2526IMAGE_SIZE = 6427CROP = 12828DATA_PATH = "kaggle"2930trans = []31trans.append(transforms.RandomHorizontalFlip())32if CROP > 0:33trans.append(transforms.CenterCrop(CROP))34trans.append(transforms.Resize(IMAGE_SIZE))35trans.append(transforms.ToTensor())36transform = transforms.Compose(trans)373839class VAE(LightningModule):40"""41Standard VAE with Gaussian Prior and approx posterior.42"""4344def __init__(45self,46input_height: int,47hidden_dims=None,48in_channels=3,49enc_out_dim: int = 512,50kl_coeff: float = 0.1,51latent_dim: int = 256,52lr: float = 1e-4,53):54"""55Args:56input_height: height of the images57enc_type: option between resnet18 or resnet5058first_conv: use standard kernel_size 7, stride 2 at start or59replace it with kernel_size 3, stride 1 conv60maxpool1: use standard maxpool to reduce spatial dim of feat by a factor of 261enc_out_dim: set according to the out_channel count of62encoder used (512 for resnet18, 2048 for resnet50)63kl_coeff: coefficient for kl term of the loss64latent_dim: dim of latent space65lr: learning rate for Adam66"""6768super(VAE, self).__init__()6970self.save_hyperparameters()7172self.lr = lr73self.kl_coeff = kl_coeff74self.enc_out_dim = enc_out_dim75self.latent_dim = latent_dim76self.input_height = input_height7778modules = []79if hidden_dims is None:80hidden_dims = [32, 64, 128, 256, 512]8182# Build Encoder83for h_dim in hidden_dims:84modules.append(85nn.Sequential(86nn.Conv2d(in_channels, out_channels=h_dim, kernel_size=3, stride=2, padding=1),87nn.BatchNorm2d(h_dim),88nn.LeakyReLU(),89)90)91in_channels = h_dim9293self.encoder = nn.Sequential(*modules)94self.fc_mu = nn.Linear(hidden_dims[-1] * 4, latent_dim)95self.fc_var = nn.Linear(hidden_dims[-1] * 4, latent_dim)9697# Build Decoder98modules = []99100self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)101102hidden_dims.reverse()103104for i in range(len(hidden_dims) - 1):105modules.append(106nn.Sequential(107nn.ConvTranspose2d(108hidden_dims[i], hidden_dims[i + 1], kernel_size=3, stride=2, padding=1, output_padding=1109),110nn.BatchNorm2d(hidden_dims[i + 1]),111nn.LeakyReLU(),112)113)114115self.decoder = nn.Sequential(*modules)116117self.final_layer = nn.Sequential(118nn.ConvTranspose2d(hidden_dims[-1], hidden_dims[-1], kernel_size=3, stride=2, padding=1, output_padding=1),119nn.BatchNorm2d(hidden_dims[-1]),120nn.LeakyReLU(),121nn.Conv2d(hidden_dims[-1], out_channels=3, kernel_size=3, padding=1),122nn.Sigmoid(),123)124125@staticmethod126def pretrained_weights_available():127return list(VAE.pretrained_urls.keys())128129def from_pretrained(self, checkpoint_name):130if checkpoint_name not in VAE.pretrained_urls:131raise KeyError(str(checkpoint_name) + " not present in pretrained weights.")132133return self.load_from_checkpoint(VAE.pretrained_urls[checkpoint_name], 
strict=False)134135def forward(self, x):136mu, log_var = self.encode(x)137p, q, z = self.sample(mu, log_var)138139return self.decode(z)140141def encode(self, x):142x = self.encoder(x)143x = torch.flatten(x, start_dim=1)144mu = self.fc_mu(x)145log_var = self.fc_var(x)146return mu, log_var147148def _run_step(self, x):149mu, log_var = self.encode(x)150p, q, z = self.sample(mu, log_var)151152return z, self.decode(z), p, q153154def decode(self, z):155result = self.decoder_input(z)156result = result.view(-1, 512, 2, 2)157result = self.decoder(result)158result = self.final_layer(result)159return result160161def sample(self, mu, log_var):162std = torch.exp(log_var / 2)163p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))164q = torch.distributions.Normal(mu, std)165z = q.rsample()166return p, q, z167168def step(self, batch, batch_idx):169x, y = batch170z, x_hat, p, q = self._run_step(x)171172recon_loss = F.mse_loss(x_hat, x, reduction="mean")173174log_qz = q.log_prob(z)175log_pz = p.log_prob(z)176177kl = log_qz - log_pz178kl = kl.mean()179kl *= self.kl_coeff180181loss = kl + recon_loss182183logs = {184"recon_loss": recon_loss,185"kl": kl,186"loss": loss,187}188return loss, logs189190def training_step(self, batch, batch_idx):191loss, logs = self.step(batch, batch_idx)192self.log_dict({f"train_{k}": v for k, v in logs.items()}, on_step=True, on_epoch=False)193return loss194195def validation_step(self, batch, batch_idx):196loss, logs = self.step(batch, batch_idx)197self.log_dict({f"val_{k}": v for k, v in logs.items()})198return loss199200def configure_optimizers(self):201return torch.optim.Adam(self.parameters(), lr=self.lr)202203204if __name__ == "__main__":205parser = ArgumentParser(description="Hyperparameters for our experiments")206parser.add_argument("--latent-dim", type=int, default=256, help="size of latent dim for our vae")207parser.add_argument("--epochs", type=int, default=50, help="num epochs")208parser.add_argument("--gpus", type=int, default=1, help="gpus, if no gpu set to 0, to run on all gpus set to -1")209parser.add_argument("--bs", type=int, default=500, help="batch size")210parser.add_argument("--kl-coeff", type=int, default=5, help="kl coeff aka beta term in the elbo loss function")211parser.add_argument("--lr", type=int, default=0.01, help="learning rate")212hparams = parser.parse_args()213214m = VAE(input_height=IMAGE_SIZE, latent_dim=hparams.latent_dim, kl_coeff=hparams.kl_coeff, lr=hparams.lr)215runner = Trainer(gpus=hparams.gpus, max_epochs=hparams.epochs)216dm = CelebADataModule(217data_dir=DATA_PATH,218target_type="attr",219train_transform=transform,220val_transform=transform,221download=True,222batch_size=hparams.bs,223)224runner.fit(m, datamodule=dm)225torch.save(m.state_dict(), "vae-celeba-conv.ckpt")226227228
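

# Illustrative sketch, not part of the original script: a minimal way to reload
# the weights saved above and draw unconditional samples from the N(0, I) prior.
# The checkpoint name matches the torch.save call above; the helper name and the
# assumption that training used the default --latent-dim of 256 are mine.
def sample_from_checkpoint(ckpt_path="vae-celeba-conv.ckpt", num_samples=16, latent_dim=256):
    model = VAE(input_height=IMAGE_SIZE, latent_dim=latent_dim)
    model.load_state_dict(torch.load(ckpt_path, map_location="cpu"))
    model.eval()
    with torch.no_grad():
        z = torch.randn(num_samples, latent_dim)  # latents drawn from the standard-normal prior
        return model.decode(z)  # (num_samples, 3, IMAGE_SIZE, IMAGE_SIZE) tensor with values in [0, 1]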