Path: blob/master/deprecated/vae/standalone/vae_hinge_celeb_lightning.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from pytorch_lightning import LightningModule, Trainer
from data import CelebADataModule
from argparse import ArgumentParser

IMAGE_SIZE = 64
CROP = 128
DATA_PATH = "kaggle"

trans = []
trans.append(transforms.RandomHorizontalFlip())
if CROP > 0:
    trans.append(transforms.CenterCrop(CROP))
trans.append(transforms.Resize(IMAGE_SIZE))
trans.append(transforms.ToTensor())
transform = transforms.Compose(trans)


def kl_divergence(mean, logvar):
    # KL(N(mean, exp(logvar)) || N(0, I)), averaged over batch and latent dimensions.
    return -0.5 * torch.mean(1 + logvar - torch.square(mean) - torch.exp(logvar))


class VAE(LightningModule):
    """
    Standard VAE with Gaussian prior and approximate posterior,
    trained with a hinge on the KL term.
    """

    def __init__(
        self,
        input_height: int,
        hidden_dims=None,
        in_channels=3,
        enc_out_dim: int = 512,
        kl_coeff: float = 2.0,
        delta: float = 10,
        latent_dim: int = 256,
        lr: float = 1e-3,
    ):
        """
        Args:
            input_height: height of the input images
            hidden_dims: channel widths of the encoder conv blocks (reversed for the decoder)
            in_channels: number of channels of the input images
            enc_out_dim: number of channels produced by the last encoder block
            kl_coeff: coefficient for the KL term of the loss
            delta: hinge threshold for the KL term
            latent_dim: dim of latent space
            lr: learning rate for Adam
        """

        super(VAE, self).__init__()

        self.save_hyperparameters()

        self.lr = lr
        self.delta = delta
        self.kl_coeff = kl_coeff
        self.enc_out_dim = enc_out_dim
        self.latent_dim = latent_dim
        self.input_height = input_height

        modules = []
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]

        # Build encoder: each block halves the spatial resolution.
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim, kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU(),
                )
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(hidden_dims[-1] * 4, latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1] * 4, latent_dim)

        # Build decoder
        modules = []

        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)

        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(
                        hidden_dims[i], hidden_dims[i + 1], kernel_size=3, stride=2, padding=1, output_padding=1
                    ),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU(),
                )
            )

        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[-1], hidden_dims[-1], kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(hidden_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(hidden_dims[-1], out_channels=3, kernel_size=3, padding=1),
            nn.Sigmoid(),
        )

    # The two methods below expect a `pretrained_urls` class attribute mapping
    # checkpoint names to URLs; it is not defined in this standalone script.
    @staticmethod
    def pretrained_weights_available():
        return list(VAE.pretrained_urls.keys())

    def from_pretrained(self, checkpoint_name):
        if checkpoint_name not in VAE.pretrained_urls:
            raise KeyError(str(checkpoint_name) + " not present in pretrained weights.")

        return self.load_from_checkpoint(VAE.pretrained_urls[checkpoint_name], strict=False)

    def encode(self, x):
        x = self.encoder(x)
        x = torch.flatten(x, start_dim=1)
        mu = self.fc_mu(x)
        log_var = self.fc_var(x)
        return mu, log_var

    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.sample(mu, log_var)
        return self.decode(z)

    def _run_step(self, x):
        mu, log_var = self.encode(x)
        z = self.sample(mu, log_var)
        return z, self.decode(z), mu, log_var

    def sample(self, mu, log_var):
        # Reparameterization trick: z = mu + std * eps with eps ~ N(0, I).
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps * std + mu

    def step(self, batch, batch_idx):
        x, y = batch
        z, x_hat, mu, logvar = self._run_step(x)

        recon_loss = F.mse_loss(x_hat, x, reduction="mean")
        kld_loss = kl_divergence(mu, logvar)

        # Hinge on the KL term: below delta it contributes a constant, so the KL
        # is only penalized once it exceeds the threshold (free-bits style).
        loss = recon_loss + self.kl_coeff * torch.max(kld_loss, self.delta * torch.ones_like(kld_loss))

        logs = {
            "recon_loss": recon_loss,
            "loss": loss,
        }
        return loss, logs

    def step_sample(self, batch, batch_idx):
        x, y = batch
        z, x_hat, mu, logvar = self._run_step(x)
        return z, x_hat

    def decode(self, z):
        result = self.decoder_input(z)
        # Reshape to (N, 512, 2, 2); 512 * 2 * 2 matches decoder_input's output
        # size for the default hidden_dims.
        result = result.view(-1, 512, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def training_step(self, batch, batch_idx):
        loss, logs = self.step(batch, batch_idx)
        self.log_dict({f"train_{k}": v for k, v in logs.items()}, on_step=True, on_epoch=False)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, logs = self.step(batch, batch_idx)
        self.log_dict({f"val_{k}": v for k, v in logs.items()})
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)


if __name__ == "__main__":
    parser = ArgumentParser(description="Hyperparameters for our experiments")
    parser.add_argument("--latent-dim", type=int, default=128, help="size of latent dim for our vae")
    parser.add_argument("--epochs", type=int, default=50, help="num epochs")
    parser.add_argument("--gpus", type=int, default=1, help="gpus, if no gpu set to 0, to run on all gpus set to -1")
    parser.add_argument("--bs", type=int, default=256, help="batch size")
    parser.add_argument("--alpha", type=float, default=1.0, help="kl coeff")
    parser.add_argument("--beta", type=float, default=1.0, help="mmd coeff (parsed but unused by this model)")
    parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
    hparams = parser.parse_args()

    # --alpha is the KL coefficient of this hinge VAE; --beta is parsed but not used here.
    m = VAE(input_height=IMAGE_SIZE, latent_dim=hparams.latent_dim, kl_coeff=hparams.alpha, lr=hparams.lr)
    dm = CelebADataModule(
        data_dir=DATA_PATH,
        target_type="attr",
        train_transform=transform,
        val_transform=transform,
        download=True,
        batch_size=hparams.bs,
    )
    trainer = Trainer(gpus=hparams.gpus, max_epochs=hparams.epochs)
    trainer.fit(m, datamodule=dm)
    torch.save(m.state_dict(), "hingevae-celeba-conv.ckpt")
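
A quick way to sanity-check the architecture without downloading CelebA is to push random tensors through the model. The snippet below is not part of the original script; it is a minimal sketch that assumes the module name matches the file path above and that the default 64x64 input size is used.

# Minimal shape/loss smoke test (run separately from the training script).
import torch
from vae_hinge_celeb_lightning import VAE, IMAGE_SIZE  # assumed module name, from the file path

model = VAE(input_height=IMAGE_SIZE, latent_dim=128)
x = torch.rand(4, 3, IMAGE_SIZE, IMAGE_SIZE)   # fake batch of 64x64 RGB images in [0, 1]
y = torch.zeros(4, dtype=torch.long)           # dummy labels; step() ignores them

x_hat = model(x)                               # encode -> reparameterize -> decode
assert x_hat.shape == x.shape

loss, logs = model.step((x, y), batch_idx=0)   # MSE reconstruction + hinged KL
print({k: float(v) for k, v in logs.items()})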