GitHub Repository: probml/pyprobml
Path: blob/master/deprecated/vae/models/vanilla_vae.py

# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional


def kl_divergence(mean, logvar):
    # Analytic KL(N(mean, exp(logvar)) || N(0, I)), averaged over batch and latent dims.
    return -0.5 * torch.mean(1 + logvar - torch.square(mean) - torch.exp(logvar))
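# Sanity-check sketch (an addition, not part of the original file): the closed form
# above is KL(N(mu, sigma^2) || N(0, 1)) = -0.5 * (1 + log sigma^2 - mu^2 - sigma^2)
# per element. It can be cross-checked against torch.distributions:
#
#   from torch.distributions import Normal, kl_divergence as torch_kl
#   mu, logvar = torch.randn(8, 256), torch.randn(8, 256)
#   reference = torch_kl(Normal(mu, torch.exp(0.5 * logvar)), Normal(0.0, 1.0)).mean()
#   # reference matches kl_divergence(mu, logvar) up to floating-point error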


def loss(config, x, x_hat, z, mu, logvar):
    # beta-VAE-style objective: pixel-wise reconstruction error plus a weighted KL term.
    recons_loss = F.mse_loss(x_hat, x, reduction="mean")

    kld_loss = kl_divergence(mu, logvar)

    loss = recons_loss + config["kl_coeff"] * kld_loss
    return loss
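# Note (an assumption about usage, not stated in this file): config["kl_coeff"] plays
# the role of beta in a beta-VAE; kl_coeff = 1.0 gives equal weighting of the two
# terms, and larger values trade reconstruction fidelity for a more regularized code.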


class Encoder(nn.Module):
    def __init__(self, in_channels: int = 3, hidden_dims: Optional[list] = None, latent_dim: int = 256):
        super(Encoder, self).__init__()

        modules = []
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]

        # Build Encoder: each stride-2 conv halves the spatial resolution.
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim, kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU(),
                )
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        # The factor of 4 assumes a final 2 x 2 feature map (e.g. a 64 x 64 input
        # halved five times by the default hidden_dims).
        self.fc_mu = nn.Linear(hidden_dims[-1] * 4, latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1] * 4, latent_dim)

    def forward(self, x):
        x = self.encoder(x)
        x = torch.flatten(x, start_dim=1)
        mu = self.fc_mu(x)
        log_var = self.fc_var(x)
        return mu, log_var
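# Shape trace (assuming the default hidden_dims and a 3 x 64 x 64 input):
#   (B, 3, 64, 64) -> conv stack -> (B, 512, 2, 2) -> flatten -> (B, 2048)
#   -> fc_mu / fc_var -> (B, 256) each,
# which is why the linear layers take hidden_dims[-1] * 4 input features.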


class Decoder(nn.Module):
    def __init__(self, hidden_dims: Optional[list] = None, latent_dim: int = 256):
        super(Decoder, self).__init__()

        # Build Decoder
        modules = []

        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]
        # hidden_dims is given in encoder order; reverse it so the channel count
        # narrows while the transposed convs widen the spatial resolution.
        hidden_dims.reverse()

        self.hidden_dims = hidden_dims
        self.decoder_input = nn.Linear(latent_dim, hidden_dims[0] * 4)

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(
                        hidden_dims[i], hidden_dims[i + 1], kernel_size=3, stride=2, padding=1, output_padding=1
                    ),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU(),
                )
            )

        self.decoder = nn.Sequential(*modules)
        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[-1], hidden_dims[-1], kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(hidden_dims[-1]),
            nn.LeakyReLU(),
            # Map back to 3 image channels; Sigmoid squashes pixels into [0, 1].
            nn.Conv2d(hidden_dims[-1], out_channels=3, kernel_size=3, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, z):
        result = self.decoder_input(z)
        # Unflatten to the widest (2 x 2) feature map; self.hidden_dims[0] is 512
        # for the default configuration.
        result = result.view(-1, self.hidden_dims[0], 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result
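

if __name__ == "__main__":
    # Minimal smoke-test sketch (an addition, not part of the original module): the
    # reparameterization step and training wrapper live elsewhere in the repo, so the
    # wiring below is an assumed composition of the pieces defined in this file.
    encoder = Encoder(in_channels=3, latent_dim=256)
    decoder = Decoder(latent_dim=256)
    config = {"kl_coeff": 1.0}  # hypothetical weighting

    x = torch.randn(8, 3, 64, 64)
    mu, log_var = encoder(x)
    std = torch.exp(0.5 * log_var)
    z = mu + std * torch.randn_like(std)  # reparameterization trick
    x_hat = decoder(z)
    print("x_hat:", tuple(x_hat.shape), "loss:", loss(config, x, x_hat, z, mu, log_var).item())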