# GitHub Repository: probml/pyprobml
# Path: blob/master/deprecated/vae/standalone/vae_celeba_lightning.py
"""
Author: Ang Ming Liang

Please run one of the following commands before running this script:

wget -q https://raw.githubusercontent.com/sayantanauddy/vae_lightning/main/data.py
or curl https://raw.githubusercontent.com/sayantanauddy/vae_lightning/main/data.py > data.py

Then get your kaggle.json API token from kaggle.com and run

mkdir /root/.kaggle
cp kaggle.json /root/.kaggle/kaggle.json
chmod 600 /root/.kaggle/kaggle.json
rm kaggle.json

to place kaggle.json where the Kaggle API expects it.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from pytorch_lightning import LightningModule, Trainer
from data import CelebADataModule
from argparse import ArgumentParser

IMAGE_SIZE = 64
CROP = 128
DATA_PATH = "kaggle"

# Preprocessing: random horizontal flip, center crop, resize to IMAGE_SIZE, convert to [0, 1] tensors
trans = []
trans.append(transforms.RandomHorizontalFlip())
if CROP > 0:
    trans.append(transforms.CenterCrop(CROP))
trans.append(transforms.Resize(IMAGE_SIZE))
trans.append(transforms.ToTensor())
transform = transforms.Compose(trans)


class VAE(LightningModule):
    """
    Standard VAE with a Gaussian prior and a Gaussian approximate posterior.
    """

    def __init__(
        self,
        input_height: int,
        hidden_dims=None,
        in_channels=3,
        enc_out_dim: int = 512,
        kl_coeff: float = 0.1,
        latent_dim: int = 256,
        lr: float = 1e-4,
    ):
        """
        Args:
            input_height: height (and width) of the input images
            hidden_dims: channel counts for the encoder conv blocks; defaults to [32, 64, 128, 256, 512]
            in_channels: number of channels of the input images
            enc_out_dim: encoder output dimension (stored as a hyperparameter; the conv encoder itself is defined by hidden_dims)
            kl_coeff: coefficient for the KL term of the loss
            latent_dim: dimension of the latent space
            lr: learning rate for Adam
        """

        super().__init__()

        self.save_hyperparameters()

        self.lr = lr
        self.kl_coeff = kl_coeff
        self.enc_out_dim = enc_out_dim
        self.latent_dim = latent_dim
        self.input_height = input_height

        modules = []
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]

        # Build Encoder: each block halves the spatial resolution
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim, kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU(),
                )
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        # For 64x64 inputs the encoder output is hidden_dims[-1] x 2 x 2, i.e. hidden_dims[-1] * 4 features
        self.fc_mu = nn.Linear(hidden_dims[-1] * 4, latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1] * 4, latent_dim)

        # Build Decoder
        modules = []

        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)

        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(
                        hidden_dims[i], hidden_dims[i + 1], kernel_size=3, stride=2, padding=1, output_padding=1
                    ),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU(),
                )
            )

        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[-1], hidden_dims[-1], kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(hidden_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(hidden_dims[-1], out_channels=3, kernel_size=3, padding=1),
            nn.Sigmoid(),
        )
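
        # Shape flow for the default configuration (a sketch, assuming 3x64x64 inputs):
        #   encoder: 64 -> 32 -> 16 -> 8 -> 4 -> 2, giving a hidden_dims[-1] x 2 x 2 feature map (512 * 4 features)
        #   decoder: 2 -> 4 -> 8 -> 16 -> 32, then final_layer upsamples to 64 and maps back to 3 channels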

    # Note: these two helpers assume a `VAE.pretrained_urls` dict (checkpoint name -> URL),
    # which is not defined in this standalone script.
    @staticmethod
    def pretrained_weights_available():
        return list(VAE.pretrained_urls.keys())

    def from_pretrained(self, checkpoint_name):
        if checkpoint_name not in VAE.pretrained_urls:
            raise KeyError(str(checkpoint_name) + " not present in pretrained weights.")

        return self.load_from_checkpoint(VAE.pretrained_urls[checkpoint_name], strict=False)

    def forward(self, x):
        mu, log_var = self.encode(x)
        p, q, z = self.sample(mu, log_var)

        return self.decode(z)

    def encode(self, x):
        x = self.encoder(x)
        x = torch.flatten(x, start_dim=1)
        mu = self.fc_mu(x)
        log_var = self.fc_var(x)
        return mu, log_var

    def _run_step(self, x):
        mu, log_var = self.encode(x)
        p, q, z = self.sample(mu, log_var)

        return z, self.decode(z), p, q

    def decode(self, z):
        result = self.decoder_input(z)
        # Reshape the flat features back to the encoder's 512 x 2 x 2 output shape
        result = result.view(-1, 512, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def sample(self, mu, log_var):
        # Reparameterization trick: z = mu + std * eps with eps ~ N(0, I), via rsample()
        std = torch.exp(log_var / 2)
        p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))
        q = torch.distributions.Normal(mu, std)
        z = q.rsample()
        return p, q, z

    def step(self, batch, batch_idx):
        x, y = batch
        z, x_hat, p, q = self._run_step(x)

        recon_loss = F.mse_loss(x_hat, x, reduction="mean")

        # Single-sample Monte Carlo estimate of KL(q(z|x) || p(z))
        log_qz = q.log_prob(z)
        log_pz = p.log_prob(z)

        kl = log_qz - log_pz
        kl = kl.mean()
        kl *= self.kl_coeff

        loss = kl + recon_loss

        logs = {
            "recon_loss": recon_loss,
            "kl": kl,
            "loss": loss,
        }
        return loss, logs
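
    # Note: the KL term in step() is a single-sample Monte Carlo estimate of KL(q(z|x) || p(z)).
    # For diagonal Gaussians it could equivalently be computed in closed form, e.g.
    #   kl = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp()).mean()
    # but this script keeps the sampled estimate.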

    def training_step(self, batch, batch_idx):
        loss, logs = self.step(batch, batch_idx)
        self.log_dict({f"train_{k}": v for k, v in logs.items()}, on_step=True, on_epoch=False)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, logs = self.step(batch, batch_idx)
        self.log_dict({f"val_{k}": v for k, v in logs.items()})
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)
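
    # A minimal usage sketch, not part of the training script below (shapes assume the default
    # configuration and 64x64 RGB inputs):
    #   model = VAE(input_height=IMAGE_SIZE)
    #   x = torch.randn(8, 3, IMAGE_SIZE, IMAGE_SIZE)
    #   x_hat = model(x)                        # reconstruction, shape (8, 3, 64, 64)
    #   loss, logs = model.step((x, None), 0)   # loss and its components ("recon_loss", "kl", "loss")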


if __name__ == "__main__":
    parser = ArgumentParser(description="Hyperparameters for our experiments")
    parser.add_argument("--latent-dim", type=int, default=256, help="size of the latent dim of our VAE")
    parser.add_argument("--epochs", type=int, default=50, help="number of epochs")
    parser.add_argument("--gpus", type=int, default=1, help="number of gpus; set to 0 to run on CPU, -1 to use all gpus")
    parser.add_argument("--bs", type=int, default=500, help="batch size")
    parser.add_argument("--kl-coeff", type=float, default=5, help="KL coefficient, aka the beta term in the ELBO loss")
    parser.add_argument("--lr", type=float, default=0.01, help="learning rate")
    hparams = parser.parse_args()

    m = VAE(input_height=IMAGE_SIZE, latent_dim=hparams.latent_dim, kl_coeff=hparams.kl_coeff, lr=hparams.lr)
    runner = Trainer(gpus=hparams.gpus, max_epochs=hparams.epochs)
    dm = CelebADataModule(
        data_dir=DATA_PATH,
        target_type="attr",
        train_transform=transform,
        val_transform=transform,
        download=True,
        batch_size=hparams.bs,
    )
    runner.fit(m, datamodule=dm)
    torch.save(m.state_dict(), "vae-celeba-conv.ckpt")
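
    # To reload the saved weights later (a sketch; assumes the model is rebuilt with the same
    # constructor arguments used for training):
    #   m = VAE(input_height=IMAGE_SIZE, latent_dim=hparams.latent_dim)
    #   m.load_state_dict(torch.load("vae-celeba-conv.ckpt"))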