Path: blob/master/deprecated/scripts/ae_celeba_lightning.py
# -*- coding: utf-8 -*-
"""
Author: Ang Ming Liang

Please run one of the following commands before running the script:

wget -q https://raw.githubusercontent.com/sayantanauddy/vae_lightning/main/data.py
or curl https://raw.githubusercontent.com/sayantanauddy/vae_lightning/main/data.py > data.py

Then get your kaggle.json from kaggle.com and run

mkdir /root/.kaggle
cp kaggle.json /root/.kaggle/kaggle.json
chmod 600 /root/.kaggle/kaggle.json
rm kaggle.json

to copy kaggle.json into the folder the Kaggle API expects.
"""

import superimport

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from pytorch_lightning import LightningModule, Trainer
from data import CelebADataModule


IMAGE_SIZE = 64
BATCH_SIZE = 256
CROP = 128
DATA_PATH = "kaggle"

trans = []
trans.append(transforms.RandomHorizontalFlip())
if CROP > 0:
    trans.append(transforms.CenterCrop(CROP))
trans.append(transforms.Resize(IMAGE_SIZE))
trans.append(transforms.ToTensor())
transform = transforms.Compose(trans)


class AE(LightningModule):
    """
    Standard autoencoder with a convolutional encoder/decoder
    (adapted from a VAE implementation, hence the unused VAE-style arguments).
    """

    # No pretrained checkpoints are registered for this model; the empty dict keeps
    # pretrained_weights_available()/from_pretrained() from raising AttributeError.
    pretrained_urls = {}

    def __init__(
        self,
        input_height: int,
        enc_type: str = 'resnet18',
        first_conv: bool = False,
        maxpool1: bool = False,
        hidden_dims=None,
        in_channels=3,
        enc_out_dim: int = 512,
        kl_coeff: float = 0.1,
        latent_dim: int = 256,
        lr: float = 1e-4,
        **kwargs
    ):
        """
        Args:
            input_height: height of the images
            enc_type: option between resnet18 or resnet50 (unused by this conv encoder)
            first_conv: use standard kernel_size 7, stride 2 at start or
                replace it with kernel_size 3, stride 1 conv (unused)
            maxpool1: use standard maxpool to reduce spatial dim of feat by a factor of 2 (unused)
            enc_out_dim: set according to the out_channel count of
                encoder used (512 for resnet18, 2048 for resnet50)
            kl_coeff: coefficient for kl term of the loss (unused; this is a plain AE)
            latent_dim: dim of latent space
            lr: learning rate for Adam
        """

        super(AE, self).__init__()

        self.save_hyperparameters()

        self.lr = lr
        self.kl_coeff = kl_coeff
        self.enc_out_dim = enc_out_dim
        self.latent_dim = latent_dim
        self.input_height = input_height

        modules = []
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]

        # Build Encoder
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(hidden_dims[-1] * 4, latent_dim)
        # kept for parity with the VAE version; not used by this deterministic AE
        self.fc_var = nn.Linear(hidden_dims[-1] * 4, latent_dim)

        # Build Decoder
        modules = []

        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)

        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU())
            )

        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[-1],
                               hidden_dims[-1],
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1),
            nn.BatchNorm2d(hidden_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(hidden_dims[-1], out_channels=3,
                      kernel_size=3, padding=1),
            nn.Sigmoid())

    @staticmethod
    def pretrained_weights_available():
        return list(AE.pretrained_urls.keys())

    def from_pretrained(self, checkpoint_name):
        if checkpoint_name not in AE.pretrained_urls:
            raise KeyError(str(checkpoint_name) + ' not present in pretrained weights.')

        return self.load_from_checkpoint(AE.pretrained_urls[checkpoint_name], strict=False)

    def encode(self, x):
        x = self.encoder(x)
        x = torch.flatten(x, start_dim=1)
        mu = self.fc_mu(x)
        return mu

    def decode(self, z):
        result = self.decoder_input(z)
        result = result.view(-1, 512, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def forward(self, x):
        z = self.encode(x)
        return self.decode(z)

    def step(self, batch, batch_idx):
        x, y = batch
        x_hat = self(x)

        loss = F.mse_loss(x_hat, x, reduction='mean')

        logs = {
            "loss": loss,
        }
        return loss, logs

    def training_step(self, batch, batch_idx):
        loss, logs = self.step(batch, batch_idx)
        self.log_dict({f"train_{k}": v for k, v in logs.items()}, on_step=True, on_epoch=False)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, logs = self.step(batch, batch_idx)
        self.log_dict({f"val_{k}": v for k, v in logs.items()})
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)


if __name__ == "__main__":
    m = AE(input_height=IMAGE_SIZE)
    runner = Trainer(gpus=2, gradient_clip_val=0.5,
                     max_epochs=15)
    dm = CelebADataModule(data_dir=DATA_PATH,
                          target_type='attr',
                          train_transform=transform,
                          val_transform=transform,
                          download=True,
                          batch_size=BATCH_SIZE,
                          num_workers=3)
    runner.fit(m, datamodule=dm)
    torch.save(m.state_dict(), "ae-celeba-latent-dim-256.ckpt")
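The block below is a minimal usage sketch, not part of the original script. It assumes the training run above completed and wrote the weights to "ae-celeba-latent-dim-256.ckpt" via torch.save(state_dict); the input image path is a hypothetical placeholder.

# Minimal inference sketch (assumes the trained weights file from the script above;
# "some_celeba_face.jpg" is a hypothetical example image).
import torch
from PIL import Image

model = AE(input_height=IMAGE_SIZE)
model.load_state_dict(torch.load("ae-celeba-latent-dim-256.ckpt", map_location="cpu"))
model.eval()

img = transform(Image.open("some_celeba_face.jpg").convert("RGB"))  # (3, 64, 64) in [0, 1]
with torch.no_grad():
    z = model.encode(img.unsqueeze(0))   # (1, 256) latent code
    recon = model.decode(z)              # (1, 3, 64, 64) reconstruction in [0, 1]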