Path: blob/master/deprecated/vae/standalone/vae_logcosh_celeb_lightning.py
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.utils as vutils
from pytorch_lightning import LightningModule, Trainer
from data import CelebADataModule
from argparse import ArgumentParser
from einops import rearrange

IMAGE_SIZE = 64
CROP = 128
DATA_PATH = "kaggle"

# Preprocessing: random horizontal flip, optional center crop, resize to IMAGE_SIZE, convert to tensor.
trans = []
trans.append(transforms.RandomHorizontalFlip())
if CROP > 0:
    trans.append(transforms.CenterCrop(CROP))
trans.append(transforms.Resize(IMAGE_SIZE))
trans.append(transforms.ToTensor())
transform = transforms.Compose(trans)


def kl_divergence(mean, logvar):
    # KL(q(z|x) || N(0, I)), averaged over both the batch and the latent dimensions.
    return -0.5 * torch.mean(1 + logvar - torch.square(mean) - torch.exp(logvar))


class VAE(LightningModule):
    """
    Convolutional VAE with a Gaussian prior and approximate posterior,
    trained with a log-cosh reconstruction loss.
    """

    def __init__(
        self,
        input_height: int,
        hidden_dims=None,
        in_channels=3,
        enc_out_dim: int = 512,
        kl_coeff: float = 2.0,
        alpha: float = 10,
        latent_dim: int = 256,
        lr: float = 1e-3,
    ):
        """
        Args:
            input_height: height of the input images
            hidden_dims: channel counts of the encoder conv blocks (reversed for the decoder)
            in_channels: number of channels of the input images
            enc_out_dim: number of channels produced by the last encoder block
            kl_coeff: coefficient for the KL term of the loss
            alpha: sharpness of the log-cosh reconstruction loss (larger values approach L1)
            latent_dim: dim of latent space
            lr: learning rate for Adam
        """

        super(VAE, self).__init__()

        self.save_hyperparameters()

        self.lr = lr
        self.kl_coeff = kl_coeff
        self.alpha = alpha
        self.enc_out_dim = enc_out_dim
        self.latent_dim = latent_dim
        self.input_height = input_height

        modules = []
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]

        # Build encoder: each block halves the spatial resolution (64 -> 2 with the default dims).
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim, kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU(),
                )
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(hidden_dims[-1] * 4, latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1] * 4, latent_dim)

        # Build decoder: mirror of the encoder, each block doubles the spatial resolution.
        modules = []

        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)

        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(
                        hidden_dims[i], hidden_dims[i + 1], kernel_size=3, stride=2, padding=1, output_padding=1
                    ),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU(),
                )
            )

        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[-1], hidden_dims[-1], kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(hidden_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(hidden_dims[-1], out_channels=3, kernel_size=3, padding=1),
            nn.Sigmoid(),
        )

    # NOTE: the two helpers below expect a VAE.pretrained_urls dict, which is not defined in this file.
    @staticmethod
    def pretrained_weights_available():
        return list(VAE.pretrained_urls.keys())

    def from_pretrained(self, checkpoint_name):
        if checkpoint_name not in VAE.pretrained_urls:
            raise KeyError(str(checkpoint_name) + " not present in pretrained weights.")

        return self.load_from_checkpoint(VAE.pretrained_urls[checkpoint_name], strict=False)

    def encode(self, x):
        x = self.encoder(x)
        x = torch.flatten(x, start_dim=1)
        mu = self.fc_mu(x)
        log_var = self.fc_var(x)
        return mu, log_var

    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.sample(mu, log_var)
        return self.decode(z)

    def _run_step(self, x):
        mu, log_var = self.encode(x)
        z = self.sample(mu, log_var)
        return z, self.decode(z), mu, log_var

    def sample(self, mu, log_var):
        # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I).
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps * std + mu

    def step(self, batch, batch_idx):
        x, y = batch
        z, x_hat, mu, logvar = self._run_step(x)

        # Log-cosh reconstruction loss, using the identity
        # log(cosh(alpha * t)) = alpha * t + log(1 + exp(-2 * alpha * t)) - log(2).
        t = x_hat - x
        recons_loss = self.alpha * t + torch.log(1.0 + torch.exp(-2 * self.alpha * t)) - torch.log(torch.tensor(2.0))
        recons_loss = (1.0 / self.alpha) * recons_loss.mean()

        kld_loss = kl_divergence(mu, logvar)

        loss = recons_loss + self.kl_coeff * kld_loss
        logs = {
            "recon_loss": recons_loss,
            "loss": loss,
        }
        return loss, logs

    def step_sample(self, batch, batch_idx):
        x, y = batch
        z, x_hat, mu, log_var = self._run_step(x)
        return z, x_hat

    def decode(self, z):
        result = self.decoder_input(z)
        # Reshape to the 2x2 feature map expected by the first decoder block (512 channels with the default hidden_dims).
        result = result.view(-1, 512, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def training_step(self, batch, batch_idx):
        loss, logs = self.step(batch, batch_idx)
        self.log_dict({f"train_{k}": v for k, v in logs.items()}, on_step=True, on_epoch=False)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, logs = self.step(batch, batch_idx)
        self.log_dict({f"val_{k}": v for k, v in logs.items()})
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)


if __name__ == "__main__":
    parser = ArgumentParser(description="Hyperparameters for our experiments")
    parser.add_argument("--latent-dim", type=int, default=128, help="size of the latent dimension of the VAE")
    parser.add_argument("--epochs", type=int, default=50, help="number of epochs")
    parser.add_argument("--gpus", type=int, default=1, help="number of GPUs; set to 0 to run on CPU, -1 to use all GPUs")
    parser.add_argument("--bs", type=int, default=256, help="batch size")
    parser.add_argument("--beta", type=float, default=1, help="coefficient for the KL term")
    parser.add_argument(
        "--alpha",
        type=float,
        default=10,
        help="the larger the value of alpha, the closer the log-cosh reconstruction loss is to L1",
    )
    parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
    hparams = parser.parse_args()

    m = VAE(
        input_height=IMAGE_SIZE,
        latent_dim=hparams.latent_dim,
        kl_coeff=hparams.beta,
        alpha=hparams.alpha,
        lr=hparams.lr,
    )
    dm = CelebADataModule(
        data_dir=DATA_PATH,
        target_type="attr",
        train_transform=transform,
        val_transform=transform,
        download=True,
        batch_size=hparams.bs,
    )
    trainer = Trainer(gpus=hparams.gpus, max_epochs=hparams.epochs)
    trainer.fit(m, datamodule=dm)
    torch.save(m.state_dict(), "logcoshvae-celeba-conv.ckpt")
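
    # Usage sketch added for illustration (not part of the original script): decode a
    # handful of random latents with the freshly trained model and write them out as an
    # image grid. The sample count (64) and the output filename ("samples.png") are
    # assumptions, not something the original training script does.
    m.eval()
    with torch.no_grad():
        z = torch.randn(64, hparams.latent_dim, device=m.device)
        samples = m.decode(z)  # shape (64, 3, IMAGE_SIZE, IMAGE_SIZE), values in [0, 1] from the final Sigmoid
    vutils.save_image(samples.cpu(), "samples.png", nrow=8)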