# GitHub Repository: probml/pyprobml
# Path: blob/master/deprecated/vae/models/mmd_vae.py
# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional


def compute_kernel(x1: torch.Tensor, x2: torch.Tensor, kernel_type: str = "rbf") -> torch.Tensor:
    D = x1.size(1)
    N = x1.size(0)

    # Reshape for a pairwise comparison: x1 becomes a (N, 1, D) "column"
    # tensor and x2 a (1, N, D) "row" tensor.
    x1 = x1.unsqueeze(-2)
    x2 = x2.unsqueeze(-3)

    # Broadcasting would handle the pairwise subtraction inside the kernels
    # without these expands; they are kept to make the (N, N, D) shape
    # explicit. Note that this assumes x1 and x2 have the same size along
    # the 0th (batch) dimension.
    x1 = x1.expand(N, N, D)
    x2 = x2.expand(N, N, D)

    if kernel_type == "rbf":
        result = compute_rbf(x1, x2)
    elif kernel_type == "imq":
        result = compute_inv_mult_quad(x1, x2)
    else:
        raise ValueError("Undefined kernel type.")

    return result


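# Illustrative shape sketch (an addition, not part of the original file): for
# a batch of five 3-dimensional latents,
#
#   compute_kernel(torch.randn(5, 3), torch.randn(5, 3))          # (5, 5) RBF matrix
#   compute_kernel(torch.randn(5, 3), torch.randn(5, 3), "imq")   # scalar; see compute_inv_mult_quad

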
def compute_rbf(x1: torch.Tensor, x2: torch.Tensor, latent_var: float = 2.0, eps: float = 1e-7) -> torch.Tensor:
    """
    Computes the RBF kernel between x1 and x2.

    :param x1: (Tensor) of shape (N, N, D)
    :param x2: (Tensor) of shape (N, N, D)
    :param latent_var: (Float) variance of the latent prior, used to set the bandwidth
    :param eps: (Float) unused; kept for signature parity with compute_inv_mult_quad
    :return: (Tensor) of shape (N, N)
    """
    z_dim = x2.size(-1)
    sigma = (2.0 / z_dim) * latent_var

    result = torch.exp(-((x1 - x2).pow(2).mean(-1) / sigma))
    return result


def compute_inv_mult_quad(
    x1: torch.Tensor, x2: torch.Tensor, latent_var: float = 2.0, eps: float = 1e-7
) -> torch.Tensor:
    r"""
    Computes the Inverse Multi-Quadratics kernel between x1 and x2, given by

        k(x_1, x_2) = \sum \frac{C}{C + \| x_1 - x_2 \|^2}

    Note that, unlike compute_rbf, this returns a summed scalar rather than
    an (N, N) matrix.

    :param x1: (Tensor) of shape (N, N, D)
    :param x2: (Tensor) of shape (N, N, D)
    :param latent_var: (Float) variance of the latent prior, used to set C
    :param eps: (Float) small constant for numerical stability
    :return: (Tensor) scalar: the kernel summed over all off-diagonal pairs
    """
    z_dim = x2.size(-1)
    C = 2 * z_dim * latent_var
    kernel = C / (eps + C + (x1 - x2).pow(2).sum(dim=-1))

    # Exclude diagonal elements, i.e. the k(z_i, z_i) self-similarity terms.
    result = kernel.sum() - kernel.diag().sum()

    return result


def MMD(prior_z: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    """
    Estimates the squared Maximum Mean Discrepancy between the encoded
    latents z and samples prior_z drawn from the prior:

        MMD^2 = mean(k(prior_z, prior_z)) + mean(k(z, z)) - 2 * mean(k(prior_z, z))
    """
    prior_z__kernel = compute_kernel(prior_z, prior_z)
    z__kernel = compute_kernel(z, z)
    priorz_z__kernel = compute_kernel(prior_z, z)

    mmd = prior_z__kernel.mean() + z__kernel.mean() - 2 * priorz_z__kernel.mean()
    return mmd


def loss(config, x, x_hat, z, mu, logvar):
    """
    MMD-VAE (InfoVAE) objective: pixel-wise reconstruction error plus the MMD
    between the encoded latents and samples from a standard normal prior.
    mu and logvar are accepted for interface compatibility but unused here.
    """
    recon_loss = F.mse_loss(x_hat, x, reduction="mean")
    mmd = MMD(torch.randn_like(z), z)

    total_loss = recon_loss + config["beta"] * mmd
    return total_loss


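# Usage sketch (an addition, not part of the original file): one hypothetical
# training step wiring the loss above to the Encoder/Decoder defined below.
# The names `encoder`, `decoder`, `optimizer`, and `dataloader` are placeholders.
#
#   config = {"beta": 10.0}
#   x = next(iter(dataloader))
#   mu, logvar = encoder(x)
#   z = mu                        # the encoder below is deterministic
#   x_hat = decoder(z)
#   total = loss(config, x, x_hat, z, mu, logvar)
#   optimizer.zero_grad()
#   total.backward()
#   optimizer.step()

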
class Encoder(nn.Module):
    def __init__(self, in_channels: int = 3, hidden_dims: Optional[list] = None, latent_dim: int = 256):
        super().__init__()

        modules = []
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]

        # Build the encoder: each block halves the spatial resolution.
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim, kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU(),
                )
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        # The * 4 assumes a 2x2 spatial map after the conv stack (e.g. 64x64 inputs).
        self.fc_mu = nn.Linear(hidden_dims[-1] * 4, latent_dim)
        # fc_var is kept for interface parity with a standard VAE encoder,
        # but it is never used: the MMD-VAE encoder is deterministic.
        self.fc_var = nn.Linear(hidden_dims[-1] * 4, latent_dim)

    def forward(self, x):
        x = self.encoder(x)
        x = torch.flatten(x, start_dim=1)
        mu = self.fc_mu(x)
        # Return a zero log-variance so the interface matches a Gaussian encoder.
        return mu, torch.zeros_like(mu)


class Decoder(nn.Module):
    def __init__(self, hidden_dims: Optional[list] = None, latent_dim: int = 256):
        super().__init__()

        # Build the decoder as the mirror image of the encoder.
        modules = []

        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]
        hidden_dims.reverse()

        self.decoder_input = nn.Linear(latent_dim, hidden_dims[0] * 4)
        # Remember the channel count for the reshape in forward(); the
        # original hard-coded 512, which only matches the default hidden_dims.
        self.first_channels = hidden_dims[0]

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(
                        hidden_dims[i], hidden_dims[i + 1], kernel_size=3, stride=2, padding=1, output_padding=1
                    ),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU(),
                )
            )

        self.decoder = nn.Sequential(*modules)
        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[-1], hidden_dims[-1], kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(hidden_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(hidden_dims[-1], out_channels=3, kernel_size=3, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, z):
        result = self.decoder_input(z)
        # Reshape the linear output to a (C, 2, 2) feature map for the conv stack.
        result = result.view(-1, self.first_channels, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result
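

# Minimal smoke test: a sketch added here, not part of the original file. It
# assumes 64x64 RGB inputs, for which five stride-2 convolutions give a 2x2
# feature map, matching the (C, 2, 2) reshape in Decoder.forward and the
# hidden_dims[-1] * 4 input size of the Encoder's linear heads.
if __name__ == "__main__":
    torch.manual_seed(0)

    encoder = Encoder(in_channels=3, latent_dim=256)
    decoder = Decoder(latent_dim=256)

    x = torch.rand(4, 3, 64, 64)  # fake batch in [0, 1], like normalized images
    mu, logvar = encoder(x)       # both (4, 256); logvar is all zeros
    z = mu                        # deterministic encoder, so z = mu
    x_hat = decoder(z)            # (4, 3, 64, 64)

    config = {"beta": 10.0}       # hypothetical weight on the MMD term
    total = loss(config, x, x_hat, z, mu, logvar)
    print(x_hat.shape, total.item())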