Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
probml
GitHub Repository: probml/pyprobml
Path: blob/master/deprecated/vae/models/vanilla_ae.py
1192 views
1
# -*- coding: utf-8 -*-
2
3
import torch
4
import torch.nn as nn
5
import torch.nn.functional as F
6
from typing import Optional
7
8
9
def kl_divergence(mean, logvar):
    """Return the KL divergence between N(mean, exp(logvar)) and N(0, 1), averaged.

    Each element contributes -0.5 * (1 + logvar - mean**2 - exp(logvar)); the
    result is the mean of that quantity over *all* elements (it is not summed
    over the latent dimension).
    """
    per_element = 1 + logvar - mean.pow(2) - logvar.exp()
    return per_element.mean() * -0.5
11
12
13
def loss(config, x, x_hat, z, mu, logvar):
    """Reconstruction loss: mean squared error between ``x_hat`` and ``x``.

    ``config``, ``z``, ``mu`` and ``logvar`` are accepted but not used here
    (presumably so this matches a shared loss signature — confirm with callers).
    """
    return F.mse_loss(x_hat, x, reduction="mean")
17
18
19
class Encoder(nn.Module):
    """Convolutional encoder producing a latent mean and a zero log-variance.

    Each stage is Conv2d(kernel 3, stride 2, padding 1) -> BatchNorm2d -> LeakyReLU,
    so every stage halves the spatial size. The linear heads expect the final
    feature map to be 2x2 (``hidden_dims[-1] * 4`` inputs), i.e. a square input
    of side 2 ** (len(hidden_dims) + 1) — 64 for the default widths.

    Args:
        in_channels: channels of the input image.
        hidden_dims: output width of each conv stage; defaults to
            [32, 64, 128, 256, 512]. Not modified.
        latent_dim: size of the latent vector.
    """

    def __init__(self, in_channels: int = 3, hidden_dims: Optional[list] = None, latent_dim: int = 256):
        super().__init__()

        widths = [32, 64, 128, 256, 512] if hidden_dims is None else hidden_dims
        # Input width of each stage: the image channels, then the previous stage's width.
        stage_inputs = [in_channels, *widths[:-1]]
        stages = [
            nn.Sequential(
                nn.Conv2d(c_in, out_channels=c_out, kernel_size=3, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(),
            )
            for c_in, c_out in zip(stage_inputs, widths)
        ]
        self.encoder = nn.Sequential(*stages)

        flat_features = widths[-1] * 4  # channels * (2 x 2) spatial positions
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        # fc_var is built (and so appears in state_dict) but forward() does not use it.
        self.fc_var = nn.Linear(flat_features, latent_dim)

    def forward(self, x):
        """Encode ``x`` and return ``(mu, logvar)``, where ``logvar`` is all zeros."""
        features = self.encoder(x).flatten(start_dim=1)
        mean = self.fc_mu(features)
        # No variance head is applied; the log-variance slot is filled with zeros,
        # consistent with a plain (non-variational) autoencoder.
        return mean, torch.zeros_like(mean)
47
48
49
class Decoder(nn.Module):
    """Convolutional decoder mapping a latent vector to an image in [0, 1].

    The latent code is projected by a linear layer and reshaped to a
    (C, 2, 2) seed feature map, where C is the last entry of ``hidden_dims``.
    Each ConvTranspose2d(kernel 3, stride 2, padding 1, output_padding 1) stage
    doubles the spatial size. A final transposed conv (another doubling) and a
    3x3 conv with sigmoid then produce the image. The output side is
    2 ** (len(hidden_dims) + 1), which is 64 for the default widths.

    Args:
        hidden_dims: channel widths in encoder order (shallow to deep); the
            decoder walks them in reverse. Defaults to [32, 64, 128, 256, 512].
            The caller's list is NOT modified.
        latent_dim: size of the latent vector.
        out_channels: channels of the reconstructed image (default 3).
    """

    def __init__(self, hidden_dims: Optional[list] = None, latent_dim: int = 256, out_channels: int = 3):
        super(Decoder, self).__init__()

        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]
        # Work on a reversed copy. Calling .reverse() in place would corrupt a list
        # the caller shares with the Encoder or reuses for another model.
        hidden_dims = list(reversed(hidden_dims))

        # Projects z to a flattened (hidden_dims[0], 2, 2) seed map; forward() derives
        # the channel count from this layer's out_features.
        self.decoder_input = nn.Linear(latent_dim, hidden_dims[0] * 4)

        # Build Decoder: one upsampling stage per consecutive pair of widths.
        modules = []
        for c_in, c_out in zip(hidden_dims[:-1], hidden_dims[1:]):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(c_in, c_out, kernel_size=3, stride=2, padding=1, output_padding=1),
                    nn.BatchNorm2d(c_out),
                    nn.LeakyReLU(),
                )
            )

        self.decoder = nn.Sequential(*modules)
        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[-1], hidden_dims[-1], kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(hidden_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(hidden_dims[-1], out_channels=out_channels, kernel_size=3, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, z):
        """Decode latent vectors ``z`` of shape (N, latent_dim) into images."""
        result = self.decoder_input(z)
        # Reshape to the 2x2 seed map. The channel count comes from the projection
        # layer, not a hard-coded 512, so custom hidden_dims work.
        seed_channels = self.decoder_input.out_features // 4
        result = result.view(-1, seed_channels, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result
88
89