GitHub Repository: probml/pyprobml
Path: blob/master/deprecated/vae/models/hinge_vae.py

# -*- coding: utf-8 -*-
"""Hinge-loss VAE components: the KL term in the ELBO is floored at a
constant `delta` (see `loss` below), plus a convolutional encoder/decoder."""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional


def kl_divergence(mean, logvar):
    """KL(N(mean, exp(logvar)) || N(0, I)), averaged over all elements
    (batch and latent dimensions alike)."""
    return -0.5 * torch.mean(1 + logvar - torch.square(mean) - torch.exp(logvar))


def loss(config, x, x_hat, z, mu, logvar):
    """Hinge VAE objective: reconstruction + kl_coeff * max(KL, delta).

    `z` is unused here; it is presumably accepted for interface
    compatibility with the other models in this directory.
    """
    recons_loss = F.mse_loss(x_hat, x, reduction="mean")
    kld_loss = kl_divergence(mu, logvar)
    # Flooring the KL term at `delta` removes any incentive to shrink it
    # further, which guards against posterior collapse.
    total = recons_loss + config["kl_coeff"] * torch.max(kld_loss, config["delta"] * torch.ones_like(kld_loss))
    return total
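

# Example usage (a sketch; the kl_coeff/delta values are illustrative, not
# taken from the repo's training configs):
#
#   config = {"kl_coeff": 1.0, "delta": 0.5}
#   total = loss(config, x, x_hat, z, mu, logvar)
#
# With delta = 0.5, any KL below 0.5 contributes exactly 0.5 to the
# objective, so only the reconstruction term produces gradients there.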


class Encoder(nn.Module):
    """Convolutional encoder mapping images to the mean and log-variance of
    a diagonal Gaussian posterior."""

    def __init__(self, in_channels: int = 3, hidden_dims: Optional[list] = None, latent_dim: int = 256):
        super().__init__()

        modules = []
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]

        # Build encoder: each block halves the spatial resolution (stride-2 conv).
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim, kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU(),
                )
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        # The `* 4` assumes the conv stack ends in a 2x2 feature map, i.e.
        # 64x64 inputs with the default five stride-2 blocks.
        self.fc_mu = nn.Linear(hidden_dims[-1] * 4, latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1] * 4, latent_dim)

    def forward(self, x):
        x = self.encoder(x)
        x = torch.flatten(x, start_dim=1)
        mu = self.fc_mu(x)
        log_var = self.fc_var(x)
        return mu, log_var
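

# Shape check (a sketch; 64x64 inputs are an assumption implied by the
# hidden_dims[-1] * 4 flattened size, not stated in this file):
#
#   enc = Encoder(in_channels=3, latent_dim=256)
#   mu, log_var = enc(torch.randn(8, 3, 64, 64))   # both (8, 256)
#
# 64 -> 32 -> 16 -> 8 -> 4 -> 2 across the five stride-2 blocks, so the
# flattened feature size is 512 * 2 * 2 = 2048.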


class Decoder(nn.Module):
    """Convolutional decoder mapping latent codes back to images in [0, 1]."""

    def __init__(self, hidden_dims: Optional[list] = None, latent_dim: int = 256):
        super().__init__()

        modules = []

        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]
        # Reverse a copy so a caller-supplied list is not mutated in place.
        hidden_dims = list(reversed(hidden_dims))
        self.first_dim = hidden_dims[0]

        # Project the latent code to a (hidden_dims[0], 2, 2) feature map.
        self.decoder_input = nn.Linear(latent_dim, hidden_dims[0] * 4)

        # Build decoder: each block doubles the spatial resolution
        # (stride-2 transposed conv).
        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(
                        hidden_dims[i], hidden_dims[i + 1], kernel_size=3, stride=2, padding=1, output_padding=1
                    ),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU(),
                )
            )

        self.decoder = nn.Sequential(*modules)
        # Final upsampling block, then a conv to 3 RGB channels with a
        # sigmoid to produce pixel values in [0, 1].
        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[-1], hidden_dims[-1], kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(hidden_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(hidden_dims[-1], out_channels=3, kernel_size=3, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, z):
        result = self.decoder_input(z)
        # Reshape to a 2x2 map with the widest channel count (512 by
        # default) rather than hardcoding 512, so custom hidden_dims work.
        result = result.view(-1, self.first_dim, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result
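

if __name__ == "__main__":
    # Smoke test (a sketch: this file defines no VAE wrapper, so the standard
    # reparameterization step below is an assumption about how the pieces
    # are meant to compose).
    enc = Encoder()
    dec = Decoder()
    x = torch.rand(4, 3, 64, 64)  # fake batch of 64x64 RGB images
    mu, log_var = enc(x)
    z = mu + torch.randn_like(mu) * torch.exp(0.5 * log_var)  # reparameterization trick
    x_hat = dec(z)
    config = {"kl_coeff": 1.0, "delta": 0.5}  # illustrative values
    print(x_hat.shape, loss(config, x, x_hat, z, mu, log_var).item())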