"""
Author: Ang Ming Liang

Before running this script, fetch the CelebA data module:

    wget -q https://raw.githubusercontent.com/sayantanauddy/vae_lightning/main/data.py
or
    curl https://raw.githubusercontent.com/sayantanauddy/vae_lightning/main/data.py > data.py

Then get your kaggle.json API token from kaggle.com and install it where the
Kaggle API expects it:

    mkdir -p /root/.kaggle
    cp kaggle.json /root/.kaggle/kaggle.json
    chmod 600 /root/.kaggle/kaggle.json
    rm kaggle.json
"""

from argparse import ArgumentParser

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from pytorch_lightning import LightningModule, Trainer

from data import CelebADataModule

IMAGE_SIZE = 64
CROP = 128
DATA_PATH = "kaggle"

trans = []
trans.append(transforms.RandomHorizontalFlip())
if CROP > 0:
    trans.append(transforms.CenterCrop(CROP))
trans.append(transforms.Resize(IMAGE_SIZE))
trans.append(transforms.ToTensor())
transform = transforms.Compose(trans)
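
# Quick sanity check (illustrative only, not needed for training): the composed
# transform maps a 178 x 218 CelebA PIL image to a float tensor of shape
# (3, 64, 64) with values in [0, 1].
#
#   from PIL import Image
#   img = Image.new("RGB", (178, 218))  # CelebA's native aligned resolution
#   x = transform(img)                  # x.shape == (3, IMAGE_SIZE, IMAGE_SIZE)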


def compute_kernel(x1: torch.Tensor, x2: torch.Tensor, kernel_type: str = "rbf") -> torch.Tensor:
    D = x1.size(1)
    N = x1.size(0)

    # Convert the tensors into row and column tensors
    x1 = x1.unsqueeze(-2)  # Make it into a column tensor
    x2 = x2.unsqueeze(-3)  # Make it into a row tensor

    # Usually the expands below are not required, especially in our case, but
    # they are useful when x1 and x2 have different sizes along the 0th dimension.
    x1 = x1.expand(N, N, D)
    x2 = x2.expand(N, N, D)

    if kernel_type == "rbf":
        result = compute_rbf(x1, x2)
    elif kernel_type == "imq":
        result = compute_inv_mult_quad(x1, x2)
    else:
        raise ValueError("Undefined kernel type.")

    return result
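
# Shape sketch (illustrative only): for a batch of N latent vectors of
# dimension D, the default "rbf" path returns the full N x N Gram matrix.
#
#   z = torch.randn(8, 16)
#   K = compute_kernel(z, z)  # K.shape == (8, 8)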


def compute_rbf(x1: torch.Tensor, x2: torch.Tensor, latent_var: float = 2.0, eps: float = 1e-7) -> torch.Tensor:
    """
    Computes the RBF kernel between x1 and x2.

    :param x1: (Tensor) expanded to shape (N, N, D) by compute_kernel
    :param x2: (Tensor) expanded to shape (N, N, D) by compute_kernel
    :param latent_var: (Float) prior variance, used to set the bandwidth sigma
    :param eps: (Float) unused; kept for signature parity with compute_inv_mult_quad
    :return: (Tensor) the (N, N) kernel matrix
    """
    z_dim = x2.size(-1)
    sigma = 2.0 * z_dim * latent_var

    result = torch.exp(-((x1 - x2).pow(2).mean(-1) / sigma))
    return result


def compute_inv_mult_quad(
    x1: torch.Tensor, x2: torch.Tensor, latent_var: float = 2.0, eps: float = 1e-7
) -> torch.Tensor:
    r"""
    Computes the inverse multi-quadratics kernel between x1 and x2, given by

        k(x_1, x_2) = \sum \frac{C}{C + \|x_1 - x_2\|^2}

    :param x1: (Tensor) expanded to shape (N, N, D) by compute_kernel
    :param x2: (Tensor) expanded to shape (N, N, D) by compute_kernel
    :param latent_var: (Float) prior variance, used to set the constant C
    :param eps: (Float) small constant added for numerical stability
    :return: (Tensor) a 0-dim tensor: the kernel values summed over all pairs
    """
    z_dim = x2.size(-1)
    C = 2 * z_dim * latent_var
    kernel = C / (eps + C + (x1 - x2).pow(2).sum(dim=-1))

    # Exclude diagonal elements. Unlike compute_rbf, this returns a summed
    # scalar rather than the (N, N) matrix, so MMD() below weights the two
    # kernel types inconsistently; only the default "rbf" path is exercised.
    result = kernel.sum() - kernel.diag().sum()

    return result
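
# Illustration (hypothetical usage): because the IMQ path sums over pairs,
# it yields a 0-dim tensor rather than a matrix.
#
#   z = torch.randn(8, 16)
#   compute_kernel(z, z, kernel_type="imq").shape  # torch.Size([])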


def MMD(prior_z: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    """
    Biased estimate of the squared maximum mean discrepancy between the
    encodings z and samples prior_z from the prior:

        MMD^2(p, q) = E_p[k(z', z'')] + E_q[k(z, z~)] - 2 E_{p,q}[k(z', z)]
    """
    prior_z__kernel = compute_kernel(prior_z, prior_z)
    z__kernel = compute_kernel(z, z)
    priorz_z__kernel = compute_kernel(prior_z, z)

    mmd = prior_z__kernel.mean() + z__kernel.mean() - 2 * priorz_z__kernel.mean()
    return mmd
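
# Sanity check (hypothetical usage, not executed by this script): two batches
# drawn from the same standard normal should give an MMD close to zero, while
# a shifted batch should give a clearly larger value.
#
#   z1 = torch.randn(256, 128)
#   z2 = torch.randn(256, 128)
#   MMD(z1, z2)        # close to 0
#   MMD(z1, z2 + 2.0)  # substantially larger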


class VAE(LightningModule):
    """
    MMD-VAE (InfoVAE-style) with a Gaussian prior: the encoder is
    deterministic, and the usual KL term is replaced by an MMD penalty
    between the encodings and samples from the prior.
    """

    def __init__(
        self,
        input_height: int,
        hidden_dims=None,
        in_channels=3,
        enc_out_dim: int = 512,
        beta: float = 1,
        latent_dim: int = 256,
        lr: float = 1e-3,
    ):
        """
        Args:
            input_height: height of the input images
            hidden_dims: channel counts for the encoder conv blocks
                (defaults to [32, 64, 128, 256, 512]); the decoder mirrors them
            in_channels: number of channels in the input images
            enc_out_dim: output channel count of the encoder
                (512 for the default conv stack)
            beta: coefficient for the MMD term of the loss
            latent_dim: dim of latent space
            lr: learning rate for Adam
        """

        super().__init__()

        self.save_hyperparameters()

        self.lr = lr
        self.beta = beta
        self.enc_out_dim = enc_out_dim
        self.latent_dim = latent_dim
        self.input_height = input_height

        modules = []
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]

        # Build Encoder: with the default hidden_dims, five stride-2 conv blocks
        # take 3 x 64 x 64 inputs down to 512 x 2 x 2, i.e. hidden_dims[-1] * 4
        # features once flattened.
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim, kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU(),
                )
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(hidden_dims[-1] * 4, latent_dim)
        # fc_var is kept for API compatibility but never used: the encoder in
        # this MMD-VAE is deterministic.
        self.fc_var = nn.Linear(hidden_dims[-1] * 4, latent_dim)

        # Build Decoder
        modules = []

        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)

        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(
                        hidden_dims[i], hidden_dims[i + 1], kernel_size=3, stride=2, padding=1, output_padding=1
                    ),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU(),
                )
            )

        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[-1], hidden_dims[-1], kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(hidden_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(hidden_dims[-1], out_channels=3, kernel_size=3, padding=1),
            nn.Sigmoid(),
        )

    # NOTE: these two helpers assume a `pretrained_urls` class attribute
    # (mapping names to checkpoint URLs) that is not defined in this
    # standalone script; calling them as-is raises AttributeError.
    @staticmethod
    def pretrained_weights_available():
        return list(VAE.pretrained_urls.keys())

    def from_pretrained(self, checkpoint_name):
        if checkpoint_name not in VAE.pretrained_urls:
            raise KeyError(str(checkpoint_name) + " not present in pretrained weights.")

        return self.load_from_checkpoint(VAE.pretrained_urls[checkpoint_name], strict=False)

    def encode(self, x):
        # Deterministic encoding: only the mean head is used, no sampling.
        x = self.encoder(x)
        x = torch.flatten(x, start_dim=1)
        mu = self.fc_mu(x)
        return mu

    def forward(self, x):
        z = self.encode(x)
        return self.decode(z)

    def _run_step(self, x):
        z = self.encode(x)

        return z, self.decode(z)

    def step(self, batch, batch_idx):
        x, _ = batch
        z, x_hat = self._run_step(x)

        recon_loss = F.mse_loss(x_hat, x, reduction="mean")
        # Match the aggregated posterior to the N(0, I) prior via MMD.
        mmd = MMD(torch.randn_like(z), z)

        loss = recon_loss + self.beta * mmd

        logs = {
            "recon_loss": recon_loss,
            "mmd": mmd,
        }
        return loss, logs

    def decode(self, z):
        result = self.decoder_input(z)
        # Reshape to the encoder's final feature map: hidden_dims[-1] x 2 x 2.
        # Note the 512 is hardcoded, so it only matches the default hidden_dims.
        result = result.view(-1, 512, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def training_step(self, batch, batch_idx):
        loss, logs = self.step(batch, batch_idx)
        self.log_dict({f"train_{k}": v for k, v in logs.items()}, on_step=True, on_epoch=False)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, logs = self.step(batch, batch_idx)
        self.log_dict({f"val_{k}": v for k, v in logs.items()})
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

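
# Quick shape check (hypothetical usage): with the default conv stack, a
# 64 x 64 RGB batch is encoded to latent_dim and reconstructed at the input
# resolution.
#
#   vae = VAE(input_height=IMAGE_SIZE, latent_dim=128)
#   x = torch.rand(4, 3, 64, 64)
#   vae.encode(x).shape  # torch.Size([4, 128])
#   vae(x).shape         # torch.Size([4, 3, 64, 64])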

if __name__ == "__main__":
    parser = ArgumentParser(description="Hyperparameters for our experiments")
    parser.add_argument("--latent-dim", type=int, default=128, help="size of latent dim for our vae")
    parser.add_argument("--epochs", type=int, default=50, help="num epochs")
    parser.add_argument("--gpus", type=int, default=1, help="gpus; set to 0 to run on cpu, or -1 to run on all gpus")
    parser.add_argument("--bs", type=int, default=256, help="batch size")
    parser.add_argument("--beta", type=float, default=1, help="coefficient for the mmd term of the loss")
    parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
    hparams = parser.parse_args()

    m = VAE(input_height=IMAGE_SIZE, latent_dim=hparams.latent_dim, beta=hparams.beta, lr=hparams.lr)
    dm = CelebADataModule(
        data_dir=DATA_PATH,
        target_type="attr",
        train_transform=transform,
        val_transform=transform,
        download=True,
        batch_size=hparams.bs,
    )
    trainer = Trainer(gpus=hparams.gpus, weights_summary="full", max_epochs=10, auto_lr_find=True)

    # Run learning rate finder; results can be inspected via lr_finder.results
    lr_finder = trainer.tuner.lr_find(m, dm)

    # Plot with
    fig = lr_finder.plot(suggest=True)
    fig.show()

    # Pick a point based on the plot, or take the automatic suggestion
    new_lr = lr_finder.suggestion()

    # Update the model's learning rate before the real training run
    m.lr = new_lr

    trainer = Trainer(gpus=hparams.gpus, max_epochs=hparams.epochs)
    trainer.fit(m, datamodule=dm)
    torch.save(m.state_dict(), "mmd-vae-celeba-conv.ckpt")
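
# Example invocation (assuming data.py and the kaggle.json credentials from the
# module docstring are in place):
#
#   python vae_mmd_celeba_lightning.py --latent-dim 128 --epochs 50 --bs 256 --gpus 1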