GitHub Repository: probml/pyprobml
Path: blob/master/deprecated/scripts/ae_celeba_lightning.py
# -*- coding: utf-8 -*-
"""
Author: Ang Ming Liang

Please run one of the following commands before running this script:

wget -q https://raw.githubusercontent.com/sayantanauddy/vae_lightning/main/data.py
or curl https://raw.githubusercontent.com/sayantanauddy/vae_lightning/main/data.py > data.py

Then get your kaggle.json from kaggle.com and run

mkdir /root/.kaggle
cp kaggle.json /root/.kaggle/kaggle.json
chmod 600 /root/.kaggle/kaggle.json
rm kaggle.json

to place kaggle.json in the folder where the Kaggle API expects it.
"""

import superimport

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from pytorch_lightning import LightningModule, Trainer
from data import CelebADataModule

IMAGE_SIZE = 64
BATCH_SIZE = 256
CROP = 128
DATA_PATH = "kaggle"

trans = []
trans.append(transforms.RandomHorizontalFlip())
if CROP > 0:
    trans.append(transforms.CenterCrop(CROP))
trans.append(transforms.Resize(IMAGE_SIZE))
trans.append(transforms.ToTensor())
transform = transforms.Compose(trans)
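
# Quick sanity check of the transform pipeline (an illustrative addition, not
# in the original script): CelebA's aligned face images are 178x218, so after
# the center crop and resize we expect a 3 x IMAGE_SIZE x IMAGE_SIZE tensor.
from PIL import Image

assert transform(Image.new("RGB", (178, 218))).shape == (3, IMAGE_SIZE, IMAGE_SIZE)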


class AE(LightningModule):
    """
    Standard convolutional autoencoder: a deterministic encoder/decoder pair
    trained with an MSE reconstruction loss. Some constructor arguments
    (e.g. enc_type, kl_coeff) are accepted for interface compatibility but are
    not used by this implementation.
    """

    # Mapping from checkpoint names to URLs, referenced by the pretrained
    # helpers below. Left empty here; populate it if pretrained weights exist.
    pretrained_urls = {}

    def __init__(
        self,
        input_height: int,
        enc_type: str = 'resnet18',
        first_conv: bool = False,
        maxpool1: bool = False,
        hidden_dims=None,
        in_channels=3,
        enc_out_dim: int = 512,
        kl_coeff: float = 0.1,
        latent_dim: int = 256,
        lr: float = 1e-4,
        **kwargs
    ):
        """
        Args:
            input_height: height of the images
            enc_type: option between resnet18 or resnet50
            first_conv: use standard kernel_size 7, stride 2 at start or
                replace it with kernel_size 3, stride 1 conv
            maxpool1: use standard maxpool to reduce spatial dim of feat by a factor of 2
            enc_out_dim: set according to the out_channel count of
                encoder used (512 for resnet18, 2048 for resnet50)
            kl_coeff: coefficient for kl term of the loss
            latent_dim: dim of latent space
            lr: learning rate for Adam
        """

        super(AE, self).__init__()

        self.save_hyperparameters()

        self.lr = lr
        self.kl_coeff = kl_coeff
        self.enc_out_dim = enc_out_dim
        self.latent_dim = latent_dim
        self.input_height = input_height

        modules = []
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]

        # Build Encoder
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        # For 64x64 inputs and the default hidden_dims, the encoder output is
        # 512 x 2 x 2, i.e. hidden_dims[-1] * 4 features after flattening.
        self.fc_mu = nn.Linear(hidden_dims[-1] * 4, latent_dim)
        # fc_var is kept from the VAE layout but is never used by this autoencoder.
        self.fc_var = nn.Linear(hidden_dims[-1] * 4, latent_dim)

        # Build Decoder
        modules = []

        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)

        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU())
            )

        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[-1],
                               hidden_dims[-1],
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1),
            nn.BatchNorm2d(hidden_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(hidden_dims[-1], out_channels=3,
                      kernel_size=3, padding=1),
            nn.Sigmoid())

    @staticmethod
    def pretrained_weights_available():
        return list(AE.pretrained_urls.keys())

    def from_pretrained(self, checkpoint_name):
        if checkpoint_name not in AE.pretrained_urls:
            raise KeyError(str(checkpoint_name) + ' not present in pretrained weights.')

        return self.load_from_checkpoint(AE.pretrained_urls[checkpoint_name], strict=False)

    def encode(self, x):
        x = self.encoder(x)
        x = torch.flatten(x, start_dim=1)
        mu = self.fc_mu(x)
        return mu

    def decode(self, z):
        result = self.decoder_input(z)
        # Hard-coded to the default configuration: 512 channels on a 2x2 grid.
        result = result.view(-1, 512, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def forward(self, x):
        z = self.encode(x)
        return self.decode(z)
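
    # Shape walk-through for the default configuration (added note; assumes
    # 64x64 inputs, hidden_dims=[32, 64, 128, 256, 512], latent_dim=256):
    #   encode: (B, 3, 64, 64) -> conv stack -> (B, 512, 2, 2)
    #           -> flatten -> (B, 2048) -> fc_mu -> (B, 256)
    #   decode: (B, 256) -> decoder_input -> (B, 2048) -> view -> (B, 512, 2, 2)
    #           -> deconv stack -> (B, 3, 64, 64)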

    def step(self, batch, batch_idx):
        x, y = batch
        x_hat = self(x)

        loss = F.mse_loss(x_hat, x, reduction='mean')

        logs = {
            "loss": loss,
        }
        return loss, logs

    def training_step(self, batch, batch_idx):
        loss, logs = self.step(batch, batch_idx)
        self.log_dict({f"train_{k}": v for k, v in logs.items()}, on_step=True, on_epoch=False)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, logs = self.step(batch, batch_idx)
        self.log_dict({f"val_{k}": v for k, v in logs.items()})
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)
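

# Minimal shape smoke test (an illustrative sketch, not part of the original
# script). It is not called automatically; run it manually to verify that the
# autoencoder reconstructs inputs at their original resolution.
def smoke_test():
    model = AE(input_height=IMAGE_SIZE)
    x = torch.randn(4, 3, IMAGE_SIZE, IMAGE_SIZE)
    with torch.no_grad():
        x_hat = model(x)
    assert x_hat.shape == x.shape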


if __name__ == "__main__":
    m = AE(input_height=IMAGE_SIZE)
    runner = Trainer(gpus=2, gradient_clip_val=0.5,
                     max_epochs=15)
    dm = CelebADataModule(data_dir=DATA_PATH,
                          target_type='attr',
                          train_transform=transform,
                          val_transform=transform,
                          download=True,
                          batch_size=BATCH_SIZE,
                          num_workers=3)
    runner.fit(m, datamodule=dm)
    torch.save(m.state_dict(), "ae-celeba-latent-dim-256.ckpt")
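    # To reload the weights saved above later (illustrative note; the filename
    # matches the torch.save call above):
    #   m = AE(input_height=IMAGE_SIZE)
    #   m.load_state_dict(torch.load("ae-celeba-latent-dim-256.ckpt"))
    #   m.eval()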