GitHub Repository: probml/pyprobml
Path: blob/master/deprecated/vae/models/info_vae.py
# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional

def compute_kernel(x1: torch.Tensor, x2: torch.Tensor, kernel_type: str = "rbf") -> torch.Tensor:
    # Convert the tensors into row and column vectors
    D = x1.size(1)
    N = x1.size(0)

    x1 = x1.unsqueeze(-2)  # Make it into a column tensor
    x2 = x2.unsqueeze(-3)  # Make it into a row tensor

    # Usually the two expands below are not required, especially in our case,
    # but they are useful when x1 and x2 have different sizes along the
    # 0th dimension.
    x1 = x1.expand(N, N, D)
    x2 = x2.expand(N, N, D)

    if kernel_type == "rbf":
        result = compute_rbf(x1, x2)
    elif kernel_type == "imq":
        result = compute_inv_mult_quad(x1, x2)
    else:
        raise ValueError("Undefined kernel type.")

    return result

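# Hedged usage sketch (added; not part of the original file): for a batch of
# N latent codes of dimension D, compute_kernel returns an (N, N) Gram matrix
# for the default RBF kernel, while the "imq" path returns a scalar (see the
# note in compute_inv_mult_quad below), e.g.
#
#     z = torch.randn(4, 8)
#     K = compute_kernel(z, z)                     # shape (4, 4)
#     s = compute_kernel(z, z, kernel_type="imq")  # 0-dim tensor
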
def compute_rbf(x1: torch.Tensor, x2: torch.Tensor, latent_var: float = 2.0, eps: float = 1e-7) -> torch.Tensor:
    """
    Computes the RBF kernel between x1 and x2.
    :param x1: (Tensor) of shape (N, N, D), as expanded by compute_kernel
    :param x2: (Tensor) of shape (N, N, D), as expanded by compute_kernel
    :param latent_var: (Float) scales the kernel bandwidth sigma
    :param eps: (Float) unused here; kept for a uniform kernel signature
    :return: (Tensor) the (N, N) kernel matrix
    """
    z_dim = x2.size(-1)
    sigma = 2.0 * z_dim * latent_var

    result = torch.exp(-((x1 - x2).pow(2).mean(-1) / sigma))
    return result

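# Bandwidth note (added; arithmetic read off the defaults above): with
# latent_var = 2.0 and latent dimension D, sigma = 2.0 * D * 2.0 = 4D, so the
# kernel evaluates to exp(-mean_squared_diff / (4D)).
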
def compute_inv_mult_quad(
    x1: torch.Tensor, x2: torch.Tensor, latent_var: float = 2.0, eps: float = 1e-7
) -> torch.Tensor:
    r"""
    Computes the Inverse Multi-Quadratics kernel between x1 and x2, given by

        k(x_1, x_2) = \sum \frac{C}{C + \|x_1 - x_2\|^2}

    :param x1: (Tensor) of shape (N, N, D), as expanded by compute_kernel
    :param x2: (Tensor) of shape (N, N, D), as expanded by compute_kernel
    :param latent_var: (Float) scales the kernel constant C
    :param eps: (Float) small constant in the denominator for numerical stability
    :return: (Tensor) scalar: the sum of the kernel matrix, excluding its diagonal
    """
    z_dim = x2.size(-1)
    C = (2 / z_dim) * latent_var
    kernel = C / (eps + C + (x1 - x2).pow(2).sum(dim=-1))

    # Exclude diagonal elements. Note that, unlike compute_rbf, this returns
    # a scalar rather than the full (N, N) kernel matrix.
    result = kernel.sum() - kernel.diag().sum()

    return result

def MMD(prior_z: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    # Empirical estimate of the (squared) maximum mean discrepancy between
    # samples from the prior and samples from the encoder.
    prior_z__kernel = compute_kernel(prior_z, prior_z)
    z__kernel = compute_kernel(z, z)
    priorz_z__kernel = compute_kernel(prior_z, z)

    mmd = prior_z__kernel.mean() + z__kernel.mean() - 2 * priorz_z__kernel.mean()
    return mmd

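# Hedged sanity check (added; not part of the original file): the MMD between
# two batches drawn from the same distribution should be close to zero and
# grow as the distributions separate, e.g.
#
#     z = torch.randn(64, 16)
#     MMD(torch.randn_like(z), z)        # near 0
#     MMD(torch.randn_like(z), z + 3.0)  # clearly positive
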
def kl_divergence(mean, logvar):
    # KL(q(z|x) || N(0, I)) for a diagonal Gaussian posterior, averaged over
    # both the batch and the latent dimensions.
    return -0.5 * torch.mean(1 + logvar - torch.square(mean) - torch.exp(logvar))


def loss(config, x, x_hat, z, mu, logvar):
    recon_loss = F.mse_loss(x_hat, x, reduction="mean")
    kld_loss = kl_divergence(mu, logvar)

    # InfoVAE objective: alpha and beta weight the KL and MMD terms; the MMD
    # is estimated against fresh samples from the N(0, I) prior.
    mmd = MMD(torch.randn_like(z), z)
    loss = recon_loss + (1 - config["alpha"]) * kld_loss + (config["alpha"] + config["beta"] - 1) * mmd
    return loss

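# Weighting note (added; read directly off the formula above):
# config = {"alpha": 0, "beta": 1} reduces this to the standard VAE loss
# (reconstruction + KL; the MMD weight vanishes), while alpha = 1 drops the
# KL term and leaves an MMD-VAE-style objective, reconstruction + beta * MMD.
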
class Encoder(nn.Module):
    def __init__(self, in_channels: int = 3, hidden_dims: Optional[list] = None, latent_dim: int = 256):
        super(Encoder, self).__init__()

        modules = []
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]

        # Build Encoder
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim, kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU(),
                )
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(hidden_dims[-1] * 4, latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1] * 4, latent_dim)

    def forward(self, x):
        x = self.encoder(x)
        x = torch.flatten(x, start_dim=1)
        mu = self.fc_mu(x)
        log_var = self.fc_var(x)
        return mu, log_var

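# Shape note (added; inferred from the defaults above, not in the original
# file): fc_mu and fc_var expect hidden_dims[-1] * 4 = 2048 features, i.e. a
# 512-channel 2 x 2 map after the five stride-2 convolutions, which
# corresponds to 64 x 64 inputs.
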
class Decoder(nn.Module):
    def __init__(self, hidden_dims: Optional[list] = None, latent_dim: int = 256):
        super(Decoder, self).__init__()

        # Build Decoder
        modules = []

        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]
        hidden_dims.reverse()

        self.decoder_input = nn.Linear(latent_dim, hidden_dims[0] * 4)

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(
                        hidden_dims[i], hidden_dims[i + 1], kernel_size=3, stride=2, padding=1, output_padding=1
                    ),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU(),
                )
            )

        self.decoder = nn.Sequential(*modules)
        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[-1], hidden_dims[-1], kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(hidden_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(hidden_dims[-1], out_channels=3, kernel_size=3, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, z):
        result = self.decoder_input(z)
        # Note: hardcodes hidden_dims[0] = 512 (after the reverse above);
        # reshape the linear output to a 512-channel 2 x 2 feature map.
        result = result.view(-1, 512, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

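if __name__ == "__main__":
    # Hedged end-to-end smoke test (added; not part of the original file).
    # The module defines no reparameterization step, so the mu + eps * sigma
    # sampling below is inlined here as an assumption, and the alpha/beta
    # values are illustrative rather than prescribed.
    config = {"alpha": 0.5, "beta": 1.0}
    x = torch.rand(2, 3, 64, 64)  # 64x64 RGB inputs in [0, 1], matching the Sigmoid output
    encoder = Encoder()
    decoder = Decoder()
    mu, log_var = encoder(x)
    z = mu + torch.randn_like(mu) * torch.exp(0.5 * log_var)  # reparameterization trick
    x_hat = decoder(z)
    assert x_hat.shape == x.shape  # torch.Size([2, 3, 64, 64])
    print(loss(config, x, x_hat, z, mu, log_var))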