GitHub Repository: probml/pyprobml
Path: blob/master/deprecated/vae/standalone/vae_hinge_celeb_lightning.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from pytorch_lightning import LightningModule, Trainer
from data import CelebADataModule
from argparse import ArgumentParser

IMAGE_SIZE = 64
CROP = 128
DATA_PATH = "kaggle"

trans = []
trans.append(transforms.RandomHorizontalFlip())
if CROP > 0:
    trans.append(transforms.CenterCrop(CROP))
trans.append(transforms.Resize(IMAGE_SIZE))
trans.append(transforms.ToTensor())
transform = transforms.Compose(trans)
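
# The pipeline above applies a random horizontal flip, center-crops CelebA
# images to CROP x CROP (128 x 128), resizes them to IMAGE_SIZE x IMAGE_SIZE
# (64 x 64), and converts them to float tensors in [0, 1], matching the
# Sigmoid output range of the decoder defined below.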


def kl_divergence(mean, logvar):
    return -0.5 * torch.mean(1 + logvar - torch.square(mean) - torch.exp(logvar))
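
# For reference: with prior p(z) = N(0, I) and diagonal Gaussian posterior
# q(z|x) = N(mean, exp(logvar)), the closed-form KL divergence is
#   KL(q || p) = -0.5 * sum(1 + logvar - mean**2 - exp(logvar)),
# where the sum runs over latent dimensions.  The function above uses
# torch.mean over all elements, so it returns the KL averaged per latent
# dimension and per batch element rather than summed.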


class VAE(LightningModule):
    """
    Standard VAE with a Gaussian prior and approximate posterior, trained
    with a hinged (free-bits style) KL term.
    """

    def __init__(
        self,
        input_height: int,
        hidden_dims=None,
        in_channels=3,
        enc_out_dim: int = 512,
        kl_coeff: float = 2.0,
        delta: float = 10,
        latent_dim: int = 256,
        lr: float = 1e-3,
    ):
        """
        Args:
            input_height: height of the input images
            hidden_dims: channel counts for the encoder conv blocks
                (defaults to [32, 64, 128, 256, 512])
            in_channels: number of channels of the input images
            enc_out_dim: nominal encoder output dimension (stored but not
                used by the convolutional encoder built below)
            kl_coeff: coefficient for the KL term of the loss
            delta: hinge threshold; the KL term is clamped below at this value
            latent_dim: dimension of the latent space
            lr: learning rate for Adam
        """

        super(VAE, self).__init__()

        self.save_hyperparameters()

        self.lr = lr
        self.delta = delta
        self.kl_coeff = kl_coeff
        self.enc_out_dim = enc_out_dim
        self.latent_dim = latent_dim
        self.input_height = input_height

        modules = []
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]

        # Build Encoder
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim, kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU(),
                )
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(hidden_dims[-1] * 4, latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1] * 4, latent_dim)
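
        # NOTE: with 64 x 64 inputs and five stride-2 conv blocks, the encoder
        # output is hidden_dims[-1] x 2 x 2, which flattens to
        # hidden_dims[-1] * 4 = 2048 features; this is why fc_mu / fc_var,
        # decoder_input, and the view(-1, 512, 2, 2) in decode() use that size.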

        # Build Decoder
        modules = []

        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)

        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(
                        hidden_dims[i], hidden_dims[i + 1], kernel_size=3, stride=2, padding=1, output_padding=1
                    ),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU(),
                )
            )

        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[-1], hidden_dims[-1], kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(hidden_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(hidden_dims[-1], out_channels=3, kernel_size=3, padding=1),
            nn.Sigmoid(),
        )

    # NOTE: the two helpers below expect a class-level `pretrained_urls`
    # mapping (checkpoint name -> checkpoint URL), which is not defined in
    # this file; calling them without adding such a mapping raises
    # AttributeError.
    @staticmethod
    def pretrained_weights_available():
        return list(VAE.pretrained_urls.keys())

    def from_pretrained(self, checkpoint_name):
        if checkpoint_name not in VAE.pretrained_urls:
            raise KeyError(str(checkpoint_name) + " not present in pretrained weights.")

        return self.load_from_checkpoint(VAE.pretrained_urls[checkpoint_name], strict=False)

    def encode(self, x):
        x = self.encoder(x)
        x = torch.flatten(x, start_dim=1)
        mu = self.fc_mu(x)
        log_var = self.fc_var(x)
        return mu, log_var

    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.sample(mu, log_var)
        return self.decode(z)

    def _run_step(self, x):
        mu, log_var = self.encode(x)
        z = self.sample(mu, log_var)
        return z, self.decode(z), mu, log_var

    def sample(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps * std + mu
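
    # `sample` implements the reparameterization trick: z = mu + sigma * eps
    # with eps ~ N(0, I), so the stochastic draw stays differentiable with
    # respect to mu and log_var.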

    def step(self, batch, batch_idx):
        x, y = batch
        z, x_hat, mu, logvar = self._run_step(x)

        recon_loss = F.mse_loss(x_hat, x, reduction="mean")
        kld_loss = kl_divergence(mu, logvar)

        loss = recon_loss + self.kl_coeff * torch.max(kld_loss, self.delta * torch.ones_like(kld_loss))

        logs = {
            "recon_loss": recon_loss,
            "loss": loss,
        }
        return loss, logs
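
    # The objective above is a "hinged" ELBO: the KL term is clamped below at
    # `delta` via torch.max, so once the KL drops below delta it contributes a
    # constant and is no longer minimized (a free-bits style trick commonly
    # used to mitigate posterior collapse).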

    def step_sample(self, batch, batch_idx):
        x, y = batch
        # _run_step returns (z, x_hat, mu, log_var); keep the latent code and
        # the reconstruction.
        z, x_hat, mu, log_var = self._run_step(x)
        return z, x_hat

    def decode(self, z):
        result = self.decoder_input(z)
        result = result.view(-1, 512, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def training_step(self, batch, batch_idx):
        loss, logs = self.step(batch, batch_idx)
        self.log_dict({f"train_{k}": v for k, v in logs.items()}, on_step=True, on_epoch=False)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, logs = self.step(batch, batch_idx)
        self.log_dict({f"val_{k}": v for k, v in logs.items()})
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)
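

# Illustrative sketch (not called anywhere in this script): one way to draw
# new images from a trained checkpoint.  The default file name matches the
# state dict saved at the bottom of this file, and `latent_dim` must match
# the value used during training (128 with the default CLI arguments).
def sample_from_prior(checkpoint_path="hingevae-celeba-conv.ckpt", latent_dim=128, n_samples=16):
    model = VAE(input_height=IMAGE_SIZE, latent_dim=latent_dim)
    model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
    model.eval()
    with torch.no_grad():
        z = torch.randn(n_samples, latent_dim)
        samples = model.decode(z)  # (n_samples, 3, 64, 64) images in [0, 1]
    return samples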


if __name__ == "__main__":
    parser = ArgumentParser(description="Hyperparameters for our experiments")
    parser.add_argument("--latent-dim", type=int, default=128, help="size of latent dim for our vae")
    parser.add_argument("--epochs", type=int, default=50, help="num epochs")
    parser.add_argument("--gpus", type=int, default=1, help="gpus, if no gpu set to 0, to run on all gpus set to -1")
    parser.add_argument("--bs", type=int, default=256, help="batch size")
    parser.add_argument("--alpha", type=float, default=1, help="kl coeff")
    parser.add_argument("--beta", type=float, default=1, help="mmd coeff (unused by this model)")
    parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
    hparams = parser.parse_args()

    # NOTE: the VAE constructor takes `kl_coeff`, not `beta`; `--alpha` is the
    # KL coefficient and `--beta` is not used by this model.
    m = VAE(input_height=IMAGE_SIZE, latent_dim=hparams.latent_dim, kl_coeff=hparams.alpha, lr=hparams.lr)
    dm = CelebADataModule(
        data_dir=DATA_PATH,
        target_type="attr",
        train_transform=transform,
        val_transform=transform,
        download=True,
        batch_size=hparams.bs,
    )
    trainer = Trainer(gpus=hparams.gpus, max_epochs=hparams.epochs)
    trainer.fit(m, datamodule=dm)
    torch.save(m.state_dict(), "hingevae-celeba-conv.ckpt")