GitHub Repository: probml/pyprobml
Path: blob/master/deprecated/vae/standalone/vae_info_celeba_lightning.py
# -*- coding: utf-8 -*-
"""
Author: Ang Ming Liang

Before running this script, download the CelebA data module with either

    wget -q https://raw.githubusercontent.com/sayantanauddy/vae_lightning/main/data.py
or
    curl https://raw.githubusercontent.com/sayantanauddy/vae_lightning/main/data.py > data.py

Then download your kaggle.json API token from kaggle.com and run

    mkdir /root/.kaggle
    cp kaggle.json /root/.kaggle/kaggle.json
    chmod 600 /root/.kaggle/kaggle.json
    rm kaggle.json

so that the Kaggle API can find your credentials when the dataset is downloaded.
"""
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.utils as vutils
from pytorch_lightning import LightningModule, Trainer
from data import CelebADataModule
from argparse import ArgumentParser
from einops import rearrange

IMAGE_SIZE = 64
CROP = 128
DATA_PATH = "kaggle"

# Preprocessing: random horizontal flip, optional center crop to CROP pixels, then resize to IMAGE_SIZE.
trans = []
trans.append(transforms.RandomHorizontalFlip())
if CROP > 0:
    trans.append(transforms.CenterCrop(CROP))
trans.append(transforms.Resize(IMAGE_SIZE))
trans.append(transforms.ToTensor())
transform = transforms.Compose(trans)


def compute_kernel(x1: torch.Tensor, x2: torch.Tensor, kernel_type: str = "rbf") -> torch.Tensor:
    """
    Computes the pairwise kernel values between the rows of x1 and x2.

    :param x1: (Tensor) [N x D]
    :param x2: (Tensor) [N x D]
    :param kernel_type: (str) "rbf" or "imq"
    """
    D = x1.size(1)
    N = x1.size(0)

    x1 = x1.unsqueeze(-2)  # [N x 1 x D], a column of row vectors
    x2 = x2.unsqueeze(-3)  # [1 x N x D], a row of row vectors

    # The expands below are usually not required (broadcasting would handle it),
    # but they are useful when x1 and x2 have different sizes along the 0th dimension.
    x1 = x1.expand(N, N, D)
    x2 = x2.expand(N, N, D)

    if kernel_type == "rbf":
        result = compute_rbf(x1, x2)
    elif kernel_type == "imq":
        result = compute_inv_mult_quad(x1, x2)
    else:
        raise ValueError("Undefined kernel type.")

    return result


def compute_rbf(x1: torch.Tensor, x2: torch.Tensor, latent_var: float = 2.0, eps: float = 1e-7) -> torch.Tensor:
    """
    Computes the RBF (Gaussian) kernel between x1 and x2.

    :param x1: (Tensor)
    :param x2: (Tensor)
    :param latent_var: (Float) variance of the latent prior, used to set the kernel bandwidth
    :param eps: (Float) unused here; kept for a uniform signature with compute_inv_mult_quad
    :return: (Tensor)
    """
    z_dim = x2.size(-1)
    sigma = 2.0 * z_dim * latent_var

    result = torch.exp(-((x1 - x2).pow(2).mean(-1) / sigma))
    return result


def compute_inv_mult_quad(
    x1: torch.Tensor, x2: torch.Tensor, latent_var: float = 2.0, eps: float = 1e-7
) -> torch.Tensor:
    r"""
    Computes the Inverse Multi-Quadratics Kernel between x1 and x2, given by

        k(x_1, x_2) = \sum \frac{C}{C + \|x_1 - x_2\|^2}

    :param x1: (Tensor)
    :param x2: (Tensor)
    :param latent_var: (Float) variance of the latent prior, used to set the scale C
    :param eps: (Float) small constant for numerical stability
    :return: (Tensor)
    """
    z_dim = x2.size(-1)
    C = 2 * z_dim * latent_var
    kernel = C / (eps + C + (x1 - x2).pow(2).sum(dim=-1))

    # Exclude diagonal elements (self-similarities) from the sum.
    result = kernel.sum() - kernel.diag().sum()

    return result


def MMD(prior_z: torch.Tensor, z: torch.Tensor):
    """
    Estimates the maximum mean discrepancy between samples from the prior and
    samples from the aggregated posterior:

        MMD^2 = E[k(p, p')] + E[k(q, q')] - 2 E[k(p, q)]
    """
    prior_z__kernel = compute_kernel(prior_z, prior_z)
    z__kernel = compute_kernel(z, z)
    priorz_z__kernel = compute_kernel(prior_z, z)

    mmd = prior_z__kernel.mean() + z__kernel.mean() - 2 * priorz_z__kernel.mean()
    return mmd


def kl_divergence(mean, logvar):
    # Analytic KL between N(mean, exp(logvar)) and N(0, I), averaged over batch and latent dims.
    return -0.5 * torch.mean(1 + logvar - torch.square(mean) - torch.exp(logvar))
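

# The training loss assembled in VAE.step below combines these pieces as
#
#   loss = recon_loss + (1 - alpha) * KL(q(z|x) || p(z)) + (alpha + beta - 1) * MMD(q(z), p(z)),
#
# which is the InfoVAE-style weighting: with alpha = 1, beta = 1 the KL term drops out and the
# model trains on reconstruction + MMD only, while alpha = 0, beta = 1 recovers a standard VAE.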


class VAE(LightningModule):
    """
    VAE with a Gaussian prior and approximate posterior, trained with the
    InfoVAE objective (reconstruction + KL + MMD).
    """

    # No pretrained checkpoints are registered; from_pretrained will raise KeyError for any name.
    pretrained_urls: dict = {}

    def __init__(
        self,
        input_height: int,
        hidden_dims=None,
        in_channels=3,
        enc_out_dim: int = 512,
        alpha: float = 0.99,
        beta: float = 2,
        latent_dim: int = 256,
        lr: float = 1e-3,
    ):
        """
        Args:
            input_height: height of the input images
            hidden_dims: list of channel counts for the encoder/decoder conv blocks
            in_channels: number of channels in the input images
            enc_out_dim: output dimension of the encoder
            alpha: coefficient controlling the KL term, which is weighted by (1 - alpha)
            beta: coefficient controlling the MMD term, which is weighted by (alpha + beta - 1)
            latent_dim: dim of latent space
            lr: learning rate for Adam
        """

        super(VAE, self).__init__()

        self.save_hyperparameters()

        self.lr = lr
        self.alpha = alpha
        self.beta = beta
        self.enc_out_dim = enc_out_dim
        self.latent_dim = latent_dim
        self.input_height = input_height

        modules = []
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]

        # Build Encoder: each conv block halves the spatial resolution (64 -> 2 for 64x64 inputs).
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim, kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU(),
                )
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        # hidden_dims[-1] * 4 features after flattening the final 2x2 feature map.
        self.fc_mu = nn.Linear(hidden_dims[-1] * 4, latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1] * 4, latent_dim)

        # Build Decoder: mirror the encoder with transposed convs that double the resolution.
        modules = []

        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)

        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(
                        hidden_dims[i], hidden_dims[i + 1], kernel_size=3, stride=2, padding=1, output_padding=1
                    ),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU(),
                )
            )

        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[-1], hidden_dims[-1], kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(hidden_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(hidden_dims[-1], out_channels=3, kernel_size=3, padding=1),
            nn.Sigmoid(),
        )

    @staticmethod
    def pretrained_weights_available():
        return list(VAE.pretrained_urls.keys())

    def from_pretrained(self, checkpoint_name):
        if checkpoint_name not in VAE.pretrained_urls:
            raise KeyError(str(checkpoint_name) + " not present in pretrained weights.")

        return self.load_from_checkpoint(VAE.pretrained_urls[checkpoint_name], strict=False)

    def encode(self, x):
        x = self.encoder(x)
        x = torch.flatten(x, start_dim=1)
        mu = self.fc_mu(x)
        log_var = self.fc_var(x)
        return mu, log_var

    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.sample(mu, log_var)
        return self.decode(z)

    def _run_step(self, x):
        mu, log_var = self.encode(x)
        z = self.sample(mu, log_var)
        return z, self.decode(z), mu, log_var

    def sample(self, mu, log_var):
        # Reparameterization trick: z = mu + sigma * eps, with eps ~ N(0, I).
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps * std + mu

    def step(self, batch, batch_idx):
        x, y = batch
        z, x_hat, mu, logvar = self._run_step(x)

        recon_loss = F.mse_loss(x_hat, x, reduction="mean")
        kld_loss = kl_divergence(mu, logvar)

        # Compare the posterior samples z against fresh samples from the N(0, I) prior.
        mmd = MMD(torch.randn_like(z), z)
        loss = recon_loss + (1 - self.alpha) * kld_loss + (self.alpha + self.beta - 1) * mmd

        logs = {
            "recon_loss": recon_loss,
            "mmd": mmd,
            "loss": loss,
        }
        return loss, logs

    def step_sample(self, batch, batch_idx):
        x, y = batch
        z, x_hat, mu, logvar = self._run_step(x)
        return z, x_hat

    def decode(self, z):
        result = self.decoder_input(z)
        # Reshape to the encoder's final feature-map shape (512 x 2 x 2 for the default hidden_dims).
        result = result.view(-1, 512, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def training_step(self, batch, batch_idx):
        loss, logs = self.step(batch, batch_idx)
        self.log_dict({f"train_{k}": v for k, v in logs.items()}, on_step=True, on_epoch=False)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, logs = self.step(batch, batch_idx)
        self.log_dict({f"val_{k}": v for k, v in logs.items()})
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)
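

# A minimal sketch (not part of the original script) of how one might draw new CelebA-like
# samples from a trained model: decode latents drawn from the N(0, I) prior. The decoder's
# final Sigmoid keeps pixel values in [0, 1], so the result can be passed directly to
# torchvision.utils (e.g. vutils.save_image(sample_images(m), "samples.png", nrow=4)).
def sample_images(model: VAE, num_samples: int = 16) -> torch.Tensor:
    model.eval()
    with torch.no_grad():
        z = torch.randn(num_samples, model.latent_dim, device=next(model.parameters()).device)
        imgs = model.decode(z)  # (num_samples, 3, IMAGE_SIZE, IMAGE_SIZE) for the default config
    return imgs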


if __name__ == "__main__":
    parser = ArgumentParser(description="Hyperparameters for our experiments")
    parser.add_argument("--latent-dim", type=int, default=128, help="size of latent dim for our vae")
    parser.add_argument("--epochs", type=int, default=50, help="num epochs")
    parser.add_argument("--gpus", type=int, default=1, help="number of gpus; set to 0 to run on cpu, -1 to use all gpus")
    parser.add_argument("--bs", type=int, default=256, help="batch size")
    parser.add_argument("--alpha", type=float, default=1, help="kl coeff")
    parser.add_argument("--beta", type=float, default=1, help="mmd coeff")
    parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
    hparams = parser.parse_args()

    m = VAE(
        input_height=IMAGE_SIZE,
        latent_dim=hparams.latent_dim,
        alpha=hparams.alpha,
        beta=hparams.beta,
        lr=hparams.lr,
    )
    dm = CelebADataModule(
        data_dir=DATA_PATH,
        target_type="attr",
        train_transform=transform,
        val_transform=transform,
        download=True,
        batch_size=hparams.bs,
    )
    trainer = Trainer(gpus=hparams.gpus, max_epochs=hparams.epochs)
    trainer.fit(m, datamodule=dm)
    torch.save(m.state_dict(), "infovae-celeba-conv.ckpt")
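
# Example invocation (assuming data.py has been downloaded and the Kaggle credentials are set up
# as described in the module docstring); the values shown are the argparse defaults:
#
#   python vae_info_celeba_lightning.py --latent-dim 128 --epochs 50 --gpus 1 --bs 256 --alpha 1 --beta 1 --lr 1e-3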