Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
probml
GitHub Repository: probml/pyprobml
Path: blob/master/deprecated/vae/models/sigma_vae.py
1192 views
1
# -*- coding: utf-8 -*-
2
3
import torch
4
import torch.nn as nn
5
import torch.nn.functional as F
6
import numpy as np
7
from typing import Optional
8
9
10
def softclip(tensor, min):
    """Softly clip ``tensor`` from below at the value ``min``.

    Computes ``min + softplus(tensor - min)``. Values well above ``min`` pass
    through nearly unchanged. Values below it are smoothly squashed towards
    (but never under) ``min``. Taken from "Handful of Trials".
    """
    shifted = tensor - min
    return F.softplus(shifted) + min
15
16
17
def kl_divergence(mean, logvar):
18
return -0.5 * torch.mean(1 + logvar - torch.square(mean) - torch.exp(logvar))
19
20
21
def gaussian_nll(mu, log_sigma, x):
    """Elementwise negative log-density of ``x`` under N(mu, sigma**2).

    ``sigma = exp(log_sigma)``. ``log_sigma`` must be a tensor because its
    ``.exp()`` method is used, and it broadcasts against ``x - mu``.
    Nothing is reduced: the result has the broadcast shape of the inputs.
    """
    standardized = (x - mu) / log_sigma.exp()
    log_normalizer = 0.5 * np.log(2 * np.pi)
    return 0.5 * torch.pow(standardized, 2) + log_sigma + log_normalizer
23
24
25
def loss(config, x, x_hat, z, mu, logvar):
    """Optimal-sigma VAE loss: mean Gaussian reconstruction NLL plus weighted KL.

    The decoder's standard deviation is not learned. A single sigma shared by
    the whole batch is set to the root-mean-square reconstruction error (the
    maximum-likelihood value for a shared Gaussian variance), then softly
    clipped from below.

    Args:
        config: mapping that provides ``"kl_coeff"``, the weight of the KL term.
        x: target batch.
        x_hat: reconstruction of ``x``, same shape as ``x``.
        z: sampled latent code. Unused; kept for interface compatibility.
        mu: posterior mean produced by the encoder.
        logvar: posterior log-variance produced by the encoder.

    Returns:
        Scalar tensor ``mean(NLL) + config["kl_coeff"] * KL``.
    """
    # Reduce over every dimension (batch included) to get one shared sigma.
    # keepdim=True makes the result broadcast against x.
    all_dims = tuple(range(x.dim()))
    # Detach so no gradient flows through the sigma estimate. This matches the
    # previous ``torch.tensor(<tensor>)`` call, which also detached but is a
    # deprecated copy-construct that emits a UserWarning.
    log_sigma = ((x - x_hat) ** 2).mean(dim=all_dims, keepdim=True).sqrt().log().detach()
    # Softly bound log_sigma from below at -6, presumably to keep the NLL
    # stable when the reconstruction error becomes very small.
    log_sigma = softclip(log_sigma, -6)
    recons_loss = gaussian_nll(x_hat, log_sigma, x).mean()

    kld_loss = kl_divergence(mu, logvar)

    return recons_loss + config["kl_coeff"] * kld_loss
35
36
37
class Encoder(nn.Module):
    """Convolutional VAE encoder producing posterior parameters ``(mu, log_var)``.

    Each entry of ``hidden_dims`` adds one stage: a stride-2 3x3 Conv2d, then
    BatchNorm2d, then LeakyReLU. Each stage halves the spatial size. The final
    feature map is flattened and fed to two linear heads.

    NOTE(review): the heads expect ``hidden_dims[-1] * 4`` features, i.e. a 2x2
    final feature map (for example 64x64 inputs with the default five stages).
    Other input sizes will fail in ``forward``.
    """

    def __init__(self, in_channels: int = 3, hidden_dims: Optional[list] = None, latent_dim: int = 256):
        super().__init__()

        dims = [32, 64, 128, 256, 512] if hidden_dims is None else hidden_dims

        # Channel widths seen by successive stages: input channels, then each hidden width.
        widths = [in_channels, *dims]
        self.encoder = nn.Sequential(
            *[
                nn.Sequential(
                    nn.Conv2d(c_in, out_channels=c_out, kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(c_out),
                    nn.LeakyReLU(),
                )
                for c_in, c_out in zip(widths[:-1], widths[1:])
            ]
        )

        # Last-stage channel count times a 2x2 spatial map.
        flat_features = dims[-1] * 4
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_var = nn.Linear(flat_features, latent_dim)

    def forward(self, x):
        """Encode ``x`` and return the tuple ``(mu, log_var)``."""
        features = torch.flatten(self.encoder(x), start_dim=1)
        return self.fc_mu(features), self.fc_var(features)
66
67
68
class Decoder(nn.Module):
    """Convolutional VAE decoder mapping latent codes to images with values in [0, 1].

    ``hidden_dims`` is given in encoder order (first stage first) and is walked
    in reverse:

    1. The latent code is linearly projected and reshaped to a
       ``hidden_dims[-1]`` x 2 x 2 feature map.
    2. Each consecutive pair of reversed widths adds a stage: a stride-2 3x3
       ConvTranspose2d, then BatchNorm2d, then LeakyReLU. Each stage doubles
       the spatial size.
    3. A final stage upsamples once more, then applies a 3x3 Conv2d and a
       sigmoid to produce ``out_channels`` channels.

    With the defaults, a 256-dim latent decodes to a 3 x 64 x 64 image.

    Args:
        hidden_dims: channel widths in encoder order. Defaults to
            ``[32, 64, 128, 256, 512]``. The caller's list is not modified.
        latent_dim: size of the latent code.
        out_channels: number of output image channels (default 3).
    """

    def __init__(self, hidden_dims: Optional[list] = None, latent_dim: int = 256, out_channels: int = 3):
        super(Decoder, self).__init__()

        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]
        # Reverse a copy so the caller's list is left untouched. The list may be
        # shared with an Encoder, and reversing it in place would corrupt it.
        dims = list(reversed(hidden_dims))
        # Channel count of the initial 2x2 feature map. forward() reshapes to this
        # instead of a hard-coded 512, so non-default hidden_dims work.
        self.init_channels = dims[0]

        self.decoder_input = nn.Linear(latent_dim, dims[0] * 4)

        self.decoder = nn.Sequential(
            *[
                nn.Sequential(
                    nn.ConvTranspose2d(c_in, c_out, kernel_size=3, stride=2, padding=1, output_padding=1),
                    nn.BatchNorm2d(c_out),
                    nn.LeakyReLU(),
                )
                for c_in, c_out in zip(dims[:-1], dims[1:])
            ]
        )
        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(dims[-1], dims[-1], kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(dims[-1], out_channels=out_channels, kernel_size=3, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, z):
        """Decode a batch of latent codes ``z`` into images with values in [0, 1]."""
        result = self.decoder_input(z)
        result = result.view(-1, self.init_channels, 2, 2)
        result = self.decoder(result)
        return self.final_layer(result)
107
108