GitHub Repository: probml/pyprobml
Path: blob/master/deprecated/vae/standalone/vae_logcosh_celeb_lightning.py
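
"""Convolutional VAE with a log-cosh reconstruction loss, trained on CelebA using PyTorch Lightning."""
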
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.utils as vutils
from pytorch_lightning import LightningModule, Trainer
from data import CelebADataModule
from argparse import ArgumentParser
from einops import rearrange

IMAGE_SIZE = 64
CROP = 128
DATA_PATH = "kaggle"

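# CelebA preprocessing: random horizontal flip, optional center crop to CROP pixels,
# resize to IMAGE_SIZE, and conversion to a [0, 1] tensor.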
trans = []
trans.append(transforms.RandomHorizontalFlip())
if CROP > 0:
    trans.append(transforms.CenterCrop(CROP))
trans.append(transforms.Resize(IMAGE_SIZE))
trans.append(transforms.ToTensor())
transform = transforms.Compose(trans)


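# Closed-form KL divergence between the approximate posterior N(mean, diag(exp(logvar)))
# and the standard normal prior: -0.5 * (1 + logvar - mean^2 - exp(logvar)),
# averaged here over both the batch and the latent dimensions.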
def kl_divergence(mean, logvar):
    return -0.5 * torch.mean(1 + logvar - torch.square(mean) - torch.exp(logvar))


class VAE(LightningModule):
    """
    Standard VAE with Gaussian Prior and approx posterior.
    """

    def __init__(
        self,
        input_height: int,
        hidden_dims=None,
        in_channels=3,
        enc_out_dim: int = 512,
        kl_coeff: float = 2.0,
        alpha: float = 10,
        latent_dim: int = 256,
        lr: float = 1e-3,
    ):
        """
        Args:
            input_height: height of the input images
            hidden_dims: channel counts for the encoder conv blocks (reversed for the decoder)
            in_channels: number of input image channels
            enc_out_dim: dimensionality of the encoder output (stored but not used by this convolutional encoder)
            kl_coeff: coefficient for the KL term of the loss
            alpha: sharpness of the log-cosh reconstruction loss (larger values approach an L1 loss)
            latent_dim: dim of latent space
            lr: learning rate for Adam
        """

        super(VAE, self).__init__()

        self.save_hyperparameters()

        self.lr = lr
        self.kl_coeff = kl_coeff
        self.alpha = alpha
        self.enc_out_dim = enc_out_dim
        self.latent_dim = latent_dim
        self.input_height = input_height

        modules = []
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]

        # Build Encoder
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim, kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU(),
                )
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
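        # With 64x64 inputs (IMAGE_SIZE above), the five stride-2 conv blocks shrink the spatial
        # size to 2x2, so the flattened encoder features have hidden_dims[-1] * 2 * 2 = hidden_dims[-1] * 4
        # dimensions; decode() reverses this with a view(-1, 512, 2, 2).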
        self.fc_mu = nn.Linear(hidden_dims[-1] * 4, latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1] * 4, latent_dim)

        # Build Decoder
        modules = []

        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)

        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(
                        hidden_dims[i], hidden_dims[i + 1], kernel_size=3, stride=2, padding=1, output_padding=1
                    ),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU(),
                )
            )

        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[-1], hidden_dims[-1], kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(hidden_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(hidden_dims[-1], out_channels=3, kernel_size=3, padding=1),
            nn.Sigmoid(),
        )

    # Mapping from checkpoint name to checkpoint URL, used by the two helpers below (empty by default).
    pretrained_urls = {}

    @staticmethod
    def pretrained_weights_available():
        return list(VAE.pretrained_urls.keys())

    def from_pretrained(self, checkpoint_name):
        if checkpoint_name not in VAE.pretrained_urls:
            raise KeyError(str(checkpoint_name) + " not present in pretrained weights.")

        return self.load_from_checkpoint(VAE.pretrained_urls[checkpoint_name], strict=False)

    def encode(self, x):
        x = self.encoder(x)
        x = torch.flatten(x, start_dim=1)
        mu = self.fc_mu(x)
        log_var = self.fc_var(x)
        return mu, log_var

    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.sample(mu, log_var)
        return self.decode(z)

    def _run_step(self, x):
        mu, log_var = self.encode(x)
        z = self.sample(mu, log_var)
        return z, self.decode(z), mu, log_var

    def sample(self, mu, log_var):
        # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I).
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps * std + mu

    def step(self, batch, batch_idx):
        x, y = batch
        z, x_hat, mu, logvar = self._run_step(x)

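        # Log-cosh reconstruction loss, written via the identity
        #   (1 / alpha) * log(cosh(alpha * t)) = t + (1 / alpha) * (log(1 + exp(-2 * alpha * t)) - log(2)).
        # For small |t| this behaves like a scaled squared error; for large |t| it approaches an L1 loss,
        # with alpha controlling the transition (see the --alpha help text below).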
        t = x_hat - x
        recons_loss = self.alpha * t + torch.log(1.0 + torch.exp(-2 * self.alpha * t)) - torch.log(torch.tensor(2.0))
        recons_loss = (1.0 / self.alpha) * recons_loss.mean()

        kld_loss = kl_divergence(mu, logvar)

        loss = recons_loss + self.kl_coeff * kld_loss
        logs = {
            "recon_loss": recons_loss,
            "loss": loss,
        }
        return loss, logs

    def step_sample(self, batch, batch_idx):
        x, y = batch
        z, x_hat, mu, log_var = self._run_step(x)
        return z, x_hat

    def decode(self, z):
        result = self.decoder_input(z)
        result = result.view(-1, 512, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def training_step(self, batch, batch_idx):
        loss, logs = self.step(batch, batch_idx)
        self.log_dict({f"train_{k}": v for k, v in logs.items()}, on_step=True, on_epoch=False)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, logs = self.step(batch, batch_idx)
        self.log_dict({f"val_{k}": v for k, v in logs.items()})
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)


if __name__ == "__main__":
    parser = ArgumentParser(description="Hyperparameters for our experiments")
    parser.add_argument("--latent-dim", type=int, default=128, help="size of latent dim for our vae")
    parser.add_argument("--epochs", type=int, default=50, help="num epochs")
    parser.add_argument("--gpus", type=int, default=1, help="number of gpus; set to 0 for cpu, -1 to use all gpus")
    parser.add_argument("--bs", type=int, default=256, help="batch size")
    parser.add_argument("--beta", type=float, default=1, help="kl coeff")
    parser.add_argument(
        "--alpha",
        type=float,
        default=10,
        help="the bigger the value of alpha, the closer the reconstruction loss is to an l1 loss",
    )
    parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
    hparams = parser.parse_args()

    m = VAE(
        input_height=IMAGE_SIZE,
        latent_dim=hparams.latent_dim,
        kl_coeff=hparams.beta,
        alpha=hparams.alpha,
        lr=hparams.lr,
    )
    dm = CelebADataModule(
        data_dir=DATA_PATH,
        target_type="attr",
        train_transform=transform,
        val_transform=transform,
        download=True,
        batch_size=hparams.bs,
    )
    trainer = Trainer(gpus=hparams.gpus, max_epochs=hparams.epochs)
    trainer.fit(m, datamodule=dm)
    torch.save(m.state_dict(), "logcoshvae-celeba-conv.ckpt")
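
    # A minimal post-training sketch (not part of the original training loop): decode a few
    # samples from the N(0, I) prior and save them as an image grid; the filename and grid
    # size are arbitrary choices here.
    m.eval()
    with torch.no_grad():
        z = torch.randn(64, hparams.latent_dim, device=m.device)
        samples = m.decode(z)
    vutils.save_image(samples.cpu(), "samples.png", nrow=8)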