GitHub Repository: probml/pyprobml
Path: blob/master/deprecated/vae/experiment.py
import torch
import warnings
from models.pixel_cnn import PixelCNN
from pytorch_lightning import LightningModule
from torch.nn import functional as F


class VAEModule(LightningModule):
    """
    Standard Lightning training wrapper for a single-stage VAE.

    The wrapped `model` is expected to expose `name`, `encoder` (returning
    (mu, log_var)), `sample(mu, log_var)`, `decoder`, and `compute_loss`.
    """

    def __init__(self, model, lr: float = 1e-3, latent_dim: int = 256):

        super(VAEModule, self).__init__()

        self.lr = lr
        self.model = model
        self.model_name = model.name
        self.latent_dim = latent_dim

    def forward(self, x):
        x = x.to(self.device)
        return self.model(x)

    def det_encode(self, x):
        # Deterministic encoding: return the posterior mean.
        x = x.to(self.device)
        mu, _ = self.model.encoder(x)
        return mu

    def stoch_encode(self, x):
        # Stochastic encoding: sample z from the approximate posterior.
        x = x.to(self.device)
        mu, log_var = self.model.encoder(x)
        z = self.model.sample(mu, log_var)
        return z

    def decode(self, z):
        return self.model.decoder(z)

    def get_samples(self, num):
        # Draw latents from the standard normal prior and decode them.
        z = torch.randn(num, self.latent_dim)
        z = z.to(self.device)
        return self.model.decoder(z)

    def step(self, batch, batch_idx):
        x, _ = batch

        loss = self.model.compute_loss(x)

        logs = {
            "loss": loss,
        }
        return loss, logs

    def training_step(self, batch, batch_idx):
        loss, logs = self.step(batch, batch_idx)
        self.log_dict({f"train_{k}": v for k, v in logs.items()}, on_step=True, on_epoch=False)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, logs = self.step(batch, batch_idx)
        self.log_dict({f"val_{k}": v for k, v in logs.items()})
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def load_model(self):
        try:
            self.load_state_dict(torch.load(f"{self.model.name}_celeba_conv.ckpt"))
        except FileNotFoundError:
            print(f"Please train the model using python run.py -c ./configs/{self.model.name}.yaml")


class VAE2stageModule(LightningModule):
    """
    Standard Lightning training wrapper for a two-stage VAE: `stage2` is a VAE
    trained on the latent space of `stage1`.
    """

    def __init__(self, stage1, stage2, lr: float = 1e-3, latent_dim: int = 256):

        super(VAE2stageModule, self).__init__()

        self.lr = lr
        self.stage1 = stage1
        self.stage2 = stage2
        self.model_name = stage2.model_name
        self.latent_dim = latent_dim

    @staticmethod
    def load_model_from_checkpoint(vae):
        try:
            vae.load_state_dict(torch.load(f"{vae.model.name}_celeba_conv.ckpt"))
        except FileNotFoundError:
            print(f"Please train the model using python run.py -c ./configs/{vae.model.name}.yaml")

    def load_model(self):
        self.load_model_from_checkpoint(self.stage1)
        self.load_model_from_checkpoint(self.stage2)

    def forward(self, x):
        u = self.stoch_encode(x)
        return self.decode(u)

    def det_encode(self, x):
        x = x.to(self.device)
        u = self.stage2.det_encode(self.stage1.det_encode(x))
        return u

    def stoch_encode(self, x):
        x = x.to(self.device)
        u = self.stage2.stoch_encode(self.stage1.stoch_encode(x))
        return u

    def decode(self, u):
        return self.stage1.decode(self.stage2.decode(u))

    def get_samples(self, num):
        u = torch.randn(num, self.latent_dim)
        u = u.to(self.device)
        return self.decode(u)
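
# Note on the two-stage setup above: encoding composes x -> z (stage 1) -> u (stage 2),
# and decoding reverses the composition (u -> z -> x), so `get_samples` draws
# u ~ N(0, I) in the stage-2 latent space and decodes through both stages.
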


class VQVAEModule(LightningModule):
    """
    Standard Lightning training wrapper for a VQ-VAE. Sampling relies on a
    PixelCNN prior over codebook indices (see set_pixel_cnn and get_samples).
    """

    def __init__(self, model, config):

        super(VQVAEModule, self).__init__()

        self.lr = config["exp_params"]["LR"]
        self.model = model
        self.config = config
        self.model_name = config["exp_params"]["model_name"]
        self.latent_dim = config["encoder_params"]["latent_dim"]
        self.pixel_cnn = None

    def forward(self, x):
        x = x.to(self.device)
        return self.model(x)

    def det_encode(self, x):
        x = x.to(self.device)
        z = self.model.encoder(x)[0]
        return z

    def quantize_encode(self, x):
        x = x.to(self.device)
        z = self.model.encoder(x)[0]
        quantized_inputs, _ = self.model.vq_layer(z)
        return quantized_inputs

    def decode(self, z):
        return self.model.decoder(z)

    def set_pixel_cnn(self):
        num_residual_blocks = self.config["pixel_params"]["num_residual_blocks"]
        num_pixelcnn_layers = self.config["pixel_params"]["num_pixelcnn_layers"]
        num_embeddings = self.config["vq_params"]["num_embeddings"]
        hidden_dim = self.config["pixel_params"]["hidden_dim"]
        pixel_cnn_raw = PixelCNN(hidden_dim, num_residual_blocks, num_pixelcnn_layers, num_embeddings)
        pixel_cnn = PixelCNNModule(
            pixel_cnn_raw,
            self,
            self.config["pixel_params"]["height"],
            self.config["pixel_params"]["width"],
            self.config["pixel_params"]["LR"],
        )
        self.pixel_cnn = pixel_cnn
        self.pixel_cnn.to(self.device)

    def get_samples(self, num):
        # Warning: these sizes are hardcoded for the default architecture.
        if self.pixel_cnn is None:
            raise RuntimeError("PixelCNN prior not defined; call set_pixel_cnn() first.")

        priors = self.pixel_cnn.get_priors(num)
        generated_samples = self.pixel_cnn.generate_samples_from_priors(priors)
        return generated_samples

    def step(self, batch, batch_idx):
        x, _ = batch

        loss = self.model.compute_loss(x)

        logs = {
            "loss": loss,
        }
        return loss, logs

    def training_step(self, batch, batch_idx):
        loss, logs = self.step(batch, batch_idx)
        self.log_dict({f"train_{k}": v for k, v in logs.items()}, on_step=True, on_epoch=False)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, logs = self.step(batch, batch_idx)
        self.log_dict({f"val_{k}": v for k, v in logs.items()})
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def load_model_vq_vae(self):
        try:
            self.load_state_dict(torch.load(f"{self.model.name}_celeba_conv.ckpt"))
        except FileNotFoundError:
            print(f"Please train the model using python run.py -c ./configs/{self.model.name}.yaml")

    def load_model_pixel_cnn(self):
        try:
            fpath = self.config["pixel_params"]["save_path"]
            self.pixel_cnn.load_state_dict(torch.load(f"{fpath}"))
        except FileNotFoundError:
            print(f"Please train the model using python run_pixel.py -c ./configs/{self.model.name}.yaml")

    def load_model(self):
        self.load_model_vq_vae()
        self.set_pixel_cnn()
        self.load_model_pixel_cnn()
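
# Sampling pipeline for the VQ-VAE: train the VQ-VAE itself (run.py), then train
# the PixelCNN prior over its codebook indices (run_pixel.py). `get_samples` draws
# index maps from the PixelCNN prior, embeds them with the codebook, and decodes.
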


class PixelCNNModule(LightningModule):
    """
    Lightning wrapper that trains a PixelCNN prior over VQ-VAE codebook indices.
    """

    def __init__(self, pixel_cnn, vq_vae, height=None, width=None, lr=1e-3):
        super().__init__()

        self.model = pixel_cnn
        self.encoder = vq_vae.model.encoder
        self.decoder = vq_vae.model.decoder
        self.vector_quantizer = vq_vae.model.vq_layer
        self.lr = lr
        self.height, self.width = height, width
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        encoding_one_hot = F.one_hot(x, num_classes=self.vector_quantizer.K)
        encoding_one_hot = (
            encoding_one_hot.view(-1, self.height, self.width, self.vector_quantizer.K).permute(0, 3, 1, 2).float()
        )
        output = self.model(encoding_one_hot)  # 256x512x16x16
        return output

    def training_step(self, batch, batch_idx):
        x_train, _ = batch  # 256x3x16x16
        with torch.no_grad():
            encoded_outputs = self.encoder(x_train)[0]  # go through the encoder
            codebook_indices = self.vector_quantizer.get_codebook_indices(encoded_outputs)  # BHW x 1

        output = self(codebook_indices)
        loss = self.loss_fn(output, codebook_indices.view(-1, self.height, self.width))
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)

        return loss

    def sample(self, inputs):
        device = inputs.device
        self.model.eval()
        with torch.no_grad():
            inputs_ohe = F.one_hot(inputs.long(), num_classes=self.vector_quantizer.K).to(device)
            inputs_ohe = (
                inputs_ohe.view(-1, self.height, self.width, self.vector_quantizer.K).permute(0, 3, 1, 2).float()
            )
            x = self.model(inputs_ohe)
            dist = torch.distributions.Categorical(logits=x.permute(0, 2, 3, 1))
            sampled = dist.sample()

        self.model.train()
        return sampled

    def get_priors(self, batch):
        priors = torch.zeros(size=(batch,) + (1, self.height, self.width), device=self.device)
        # Iterate over the priors because generation has to be done sequentially, pixel by pixel.
        for row in range(self.height):
            for col in range(self.width):
                # Feed the whole array and retrieve the pixel value probabilities for the next pixel.
                probs = self.sample(priors.view((-1, 1)))
                # Use the probabilities to pick pixel values and append the values to the priors.
                priors[:, 0, row, col] = probs[:, row, col]

        priors = priors.squeeze()
        return priors

    def generate_samples_from_priors(self, priors):
        priors = priors.to(self.device)
        priors_ohe = F.one_hot(priors.view(-1, 1).long(), num_classes=self.vector_quantizer.K).squeeze().float()
        quantized = torch.matmul(priors_ohe, self.vector_quantizer.embedding.weight)  # [BHW, D]
        quantized = quantized.view(-1, self.height, self.width, self.vector_quantizer.D).permute(0, 3, 1, 2)
        with torch.no_grad():
            return self.decoder(quantized)

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=self.lr)

    def load_model(self, path):
        try:
            self.load_state_dict(torch.load(path))
        except FileNotFoundError:
            print(f"Please train the model using python run.py -c ./configs/{self.model.name}.yaml")

    def save(self, path="./pixelcnn_model.ckpt"):
        torch.save(self.state_dict(), path)
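

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original experiment code).
# It builds a toy fully connected Gaussian VAE that exposes the interface
# VAEModule expects (`name`, `encoder` -> (mu, log_var), `sample`, `decoder`,
# `compute_loss`) and runs a single fast_dev_run batch. The toy model, random
# dataset, and all shapes below are illustrative assumptions only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import pytorch_lightning as pl
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset

    class ToyGaussianVAE(nn.Module):
        """Tiny stand-in model; the real experiments use the models/ package."""

        name = "toy_vae"

        def __init__(self, x_dim=784, latent_dim=256):
            super().__init__()
            self.enc = nn.Linear(x_dim, 2 * latent_dim)
            self.dec = nn.Linear(latent_dim, x_dim)

        def encoder(self, x):
            # Split the encoder output into the posterior mean and log-variance.
            return self.enc(x.flatten(1)).chunk(2, dim=-1)

        def decoder(self, z):
            return torch.sigmoid(self.dec(z))

        def sample(self, mu, log_var):
            # Reparameterization trick: z = mu + sigma * eps.
            return mu + torch.exp(0.5 * log_var) * torch.randn_like(mu)

        def forward(self, x):
            mu, log_var = self.encoder(x)
            return self.decoder(self.sample(mu, log_var))

        def compute_loss(self, x):
            mu, log_var = self.encoder(x)
            recon = self.decoder(self.sample(mu, log_var))
            recon_loss = F.mse_loss(recon, x.flatten(1))
            kl = -0.5 * torch.mean(1 + log_var - mu.pow(2) - log_var.exp())
            return recon_loss + kl

    data = TensorDataset(torch.rand(64, 784), torch.zeros(64))
    module = VAEModule(ToyGaussianVAE(), lr=1e-3, latent_dim=256)
    trainer = pl.Trainer(fast_dev_run=True, logger=False)
    trainer.fit(module, DataLoader(data, batch_size=32))
    print(module.get_samples(4).shape)  # expected: torch.Size([4, 784])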