GitHub Repository: probml/pyprobml
Path: blob/master/deprecated/vae/models/logcosh_vae.py
# -*- coding: utf-8 -*-

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional


def kl_divergence(mean, logvar):
    # KL(N(mean, exp(logvar)) || N(0, I)), averaged over both the batch and the
    # latent dimensions (many formulations sum over latent dimensions instead).
    return -0.5 * torch.mean(1 + logvar - torch.square(mean) - torch.exp(logvar))


def loss(config, x, x_hat, z, mu, logvar):
    # Log-cosh reconstruction loss plus a weighted KL term. Uses the identity
    #   log(cosh(a * t)) = a * t + log(1 + exp(-2 * a * t)) - log(2),
    # with softplus(-2 * a * t) = log(1 + exp(-2 * a * t)) evaluated stably,
    # which avoids overflow in exp() for large |a * t|.
    # z is unused here but kept for a uniform loss signature.
    t = x_hat - x
    alpha = config["alpha"]
    recons_loss = alpha * t + F.softplus(-2 * alpha * t) - math.log(2.0)
    recons_loss = (1.0 / alpha) * recons_loss.mean()

    kld_loss = kl_divergence(mu, logvar)

    return recons_loss + config["kl_coeff"] * kld_loss
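

# Illustrative check (not part of the original module): for moderate alpha, the
# softplus form above agrees with evaluating log(cosh(.)) directly, e.g.
#
#   cfg = {"alpha": 2.0, "kl_coeff": 1.0}
#   x, x_hat = torch.rand(4, 3, 8, 8), torch.rand(4, 3, 8, 8)
#   direct = (1.0 / cfg["alpha"]) * torch.log(torch.cosh(cfg["alpha"] * (x_hat - x))).mean()
#
# `direct` matches the reconstruction term of loss() up to floating-point
# error, but the direct form overflows for large alpha * |t|, which is why the
# softplus identity is used instead.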


class Encoder(nn.Module):
    def __init__(self, in_channels: int = 3, hidden_dims: Optional[list] = None, latent_dim: int = 256):
        super().__init__()

        modules = []
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]

        # Build the encoder: each block halves the spatial resolution.
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim, kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU(),
                )
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        # The factor of 4 assumes a 2x2 feature map at the bottleneck,
        # i.e. 64x64 inputs with the default five stride-2 blocks.
        self.fc_mu = nn.Linear(hidden_dims[-1] * 4, latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1] * 4, latent_dim)

    def forward(self, x):
        x = self.encoder(x)
        x = torch.flatten(x, start_dim=1)
        mu = self.fc_mu(x)
        log_var = self.fc_var(x)
        return mu, log_var
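

# Shape sketch (illustrative, assuming 64x64 RGB inputs): with the default
# hidden_dims, five stride-2 convolutions reduce 64x64 to 2x2, so the flattened
# feature vector has 512 * 2 * 2 = 2048 elements, matching hidden_dims[-1] * 4:
#
#   enc = Encoder()
#   mu, log_var = enc(torch.rand(2, 3, 64, 64))
#   assert mu.shape == log_var.shape == (2, 256)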


class Decoder(nn.Module):
    def __init__(self, hidden_dims: Optional[list] = None, latent_dim: int = 256):
        super().__init__()

        # Build the decoder as the mirror image of the encoder.
        modules = []

        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]
        hidden_dims = hidden_dims[::-1]  # reversed copy; avoids mutating the caller's list
        self.hidden_dims = hidden_dims

        self.decoder_input = nn.Linear(latent_dim, hidden_dims[0] * 4)

        # Each block doubles the spatial resolution: 2x2 -> 4x4 -> ... -> 32x32.
        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(
                        hidden_dims[i], hidden_dims[i + 1], kernel_size=3, stride=2, padding=1, output_padding=1
                    ),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU(),
                )
            )

        self.decoder = nn.Sequential(*modules)
        # Final upsampling to the input resolution, then a sigmoid to map
        # pixel values into [0, 1].
        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[-1], hidden_dims[-1], kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(hidden_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(hidden_dims[-1], out_channels=3, kernel_size=3, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, z):
        result = self.decoder_input(z)
        # Reshape to a 2x2 feature map; uses hidden_dims[0] rather than a
        # hard-coded 512 so custom hidden_dims also work.
        result = result.view(-1, self.hidden_dims[0], 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result
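

if __name__ == "__main__":
    # Minimal end-to-end sketch (not in the original file): wires Encoder, a
    # reparameterization step, Decoder, and loss() into one VAE forward pass.
    # Assumes 64x64 RGB inputs; the config values are hypothetical.
    config = {"alpha": 10.0, "kl_coeff": 0.00025}
    encoder = Encoder(in_channels=3, latent_dim=256)
    decoder = Decoder(latent_dim=256)

    x = torch.rand(8, 3, 64, 64)
    mu, log_var = encoder(x)

    # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I).
    std = torch.exp(0.5 * log_var)
    z = mu + std * torch.randn_like(std)

    x_hat = decoder(z)
    print(x_hat.shape)  # torch.Size([8, 3, 64, 64])
    print(loss(config, x, x_hat, z, mu, log_var))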