Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
probml
GitHub Repository: probml/pyprobml
Path: blob/master/deprecated/vae/models/vq_vae.py
1192 views
1
import torch
2
from torch import nn
3
from torch.nn import functional as F
4
from torch import Tensor
5
from typing import Optional, Callable
6
7
8
class VectorQuantizer(nn.Module):
    """Vector-quantization layer (VQ-VAE) with a straight-through estimator.

    Maps each D-dimensional latent vector to its nearest codebook entry and
    returns the quantized latents together with the VQ training loss.

    Reference:
    [1] https://github.com/deepmind/sonnet/blob/v2/sonnet/src/nets/vqvae.py
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, beta: float = 0.25):
        """
        Args:
            num_embeddings: number of codebook vectors (K).
            embedding_dim: dimensionality of each codebook vector (D).
            beta: weight of the commitment-loss term.
        """
        super(VectorQuantizer, self).__init__()
        self.K = num_embeddings
        self.D = embedding_dim
        self.beta = beta

        self.embedding = nn.Embedding(self.K, self.D)
        # Uniform init in [-1/K, 1/K], as in the Sonnet reference implementation.
        self.embedding.weight.data.uniform_(-1 / self.K, 1 / self.K)

    def get_codebook_indices(self, latents: Tensor) -> Tensor:
        """Return the index of the nearest codebook vector for every latent.

        Args:
            latents: channel-last latents whose trailing dimension is D.

        Returns:
            LongTensor of shape [BHW, 1] holding codebook indices.
        """
        flat_latents = latents.view(-1, self.D)  # [BHW x D]

        # ||z - e||^2 = ||z||^2 + ||e||^2 - 2 z.e, expanded to avoid
        # materializing the full [BHW x K x D] difference tensor.
        dist = (
            torch.sum(flat_latents**2, dim=1, keepdim=True)
            + torch.sum(self.embedding.weight**2, dim=1)
            - 2 * torch.matmul(flat_latents, self.embedding.weight.t())
        )  # [BHW x K]

        # Get the encoding that has the min distance
        encoding_inds = torch.argmin(dist, dim=1).unsqueeze(1)  # [BHW, 1]
        return encoding_inds

    def forward(self, latents: Tensor) -> "tuple[Tensor, Tensor]":
        """Quantize `latents` against the codebook.

        Args:
            latents: encoder output of shape [B x D x H x W].

        Returns:
            Tuple of (quantized latents, [B x D x H x W]; scalar VQ loss).
        """
        latents = latents.permute(0, 2, 3, 1).contiguous()  # [B x D x H x W] -> [B x H x W x D]
        latents_shape = latents.shape
        encoding_inds = self.get_codebook_indices(latents)

        # Convert to one-hot encodings
        device = latents.device
        encoding_one_hot = torch.nn.functional.one_hot(encoding_inds, num_classes=self.K).float().to(device)

        # Quantize the latents: one-hot @ codebook selects the nearest entry.
        quantized_latents = torch.matmul(encoding_one_hot, self.embedding.weight)  # [BHW, D]
        quantized_latents = quantized_latents.view(latents_shape)  # [B x H x W x D]

        # Commitment loss pulls the encoder toward the codebook; embedding
        # loss pulls the codebook toward the encoder outputs.
        commitment_loss = F.mse_loss(quantized_latents.detach(), latents)
        embedding_loss = F.mse_loss(quantized_latents, latents.detach())

        vq_loss = commitment_loss * self.beta + embedding_loss

        # Straight-through estimator: forward value is the quantized latent,
        # but gradients flow to the encoder as if quantization were identity.
        quantized_latents = latents + (quantized_latents - latents).detach()

        return quantized_latents.permute(0, 3, 1, 2).contiguous(), vq_loss  # [B x D x H x W]
60
61
62
class ResidualLayer(nn.Module):
    """Residual block: ``x + Conv3x3 -> ReLU -> Conv1x1`` (both convs bias-free)."""

    def __init__(self, in_channels: int, out_channels: int):
        super(ResidualLayer, self).__init__()
        conv3x3 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False)
        conv1x1 = nn.Conv2d(out_channels, out_channels, kernel_size=1, bias=False)
        self.resblock = nn.Sequential(conv3x3, nn.ReLU(True), conv1x1)

    def forward(self, input: Tensor) -> Tensor:
        # Identity shortcut around the two-convolution branch.
        branch = self.resblock(input)
        return input + branch
73
74
75
class Encoder(nn.Module):
    """Convolutional encoder: strided downsampling, residual stack, 1x1 projection."""

    def __init__(self, in_channels: int = 3, hidden_dims: Optional[list] = None, latent_dim: int = 256):
        super(Encoder, self).__init__()

        if hidden_dims is None:
            hidden_dims = [128, 256]

        layers = []
        channels = in_channels

        # Each strided convolution halves the spatial resolution.
        for width in hidden_dims:
            layers.append(
                nn.Sequential(
                    nn.Conv2d(channels, out_channels=width, kernel_size=4, stride=2, padding=1), nn.LeakyReLU()
                )
            )
            channels = width

        layers.append(
            nn.Sequential(nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1), nn.LeakyReLU())
        )

        # Six residual blocks, each followed by a nonlinearity.
        for _ in range(6):
            layers.append(ResidualLayer(channels, channels))
            layers.append(nn.LeakyReLU())

        # Project to the latent dimensionality expected by the quantizer.
        layers.append(nn.Sequential(nn.Conv2d(channels, latent_dim, kernel_size=1, stride=1), nn.LeakyReLU()))

        self.encoder = nn.Sequential(*layers)

    def forward(self, x):
        # Result is wrapped in a list to mirror the multi-output encoder convention.
        return [self.encoder(x)]
107
108
109
class Decoder(nn.Module):
    """Convolutional decoder: 1x1-style expansion, residual stack, transposed-conv upsampling.

    Mirrors ``Encoder``; the final layer maps back to 3 image channels with tanh.
    """

    def __init__(self, hidden_dims: Optional[list] = None, latent_dim: int = 256):
        """
        Args:
            hidden_dims: channel widths in encoder order; defaults to [128, 256].
                The list is copied, so the caller's argument is never mutated.
            latent_dim: channel dimensionality of the quantized latents.
        """
        super(Decoder, self).__init__()

        modules = []

        if hidden_dims is None:
            hidden_dims = [128, 256]
        else:
            # Copy before reversing below: the original implementation called
            # hidden_dims.reverse() on the caller's list, mutating it in place.
            hidden_dims = list(hidden_dims)

        modules.append(
            nn.Sequential(nn.Conv2d(latent_dim, hidden_dims[-1], kernel_size=3, stride=1, padding=1), nn.LeakyReLU())
        )

        # Residual stack at the deepest width, then a single nonlinearity.
        for _ in range(6):
            modules.append(ResidualLayer(hidden_dims[-1], hidden_dims[-1]))

        modules.append(nn.LeakyReLU())

        # Walk widths from deepest to shallowest, upsampling 2x per step.
        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i], hidden_dims[i + 1], kernel_size=4, stride=2, padding=1),
                    nn.LeakyReLU(),
                )
            )

        modules.append(
            nn.Sequential(
                nn.ConvTranspose2d(hidden_dims[-1], out_channels=3, kernel_size=4, stride=2, padding=1), nn.Tanh()
            )
        )

        self.decoder = nn.Sequential(*modules)

    def forward(self, z):
        """Decode quantized latents ``z`` ([B x latent_dim x H x W]) to an image."""
        result = self.decoder(z)
        return result
148
149
150
def loss(config, x_hat, x, vq_loss):
    """Total VQ-VAE loss: pixel reconstruction MSE plus the quantizer loss.

    ``config`` is accepted for interface compatibility but is not used here.
    """
    recons_loss = F.mse_loss(x_hat, x)
    return recons_loss + vq_loss
156
157
158
class VQVAE(nn.Module):
    """VQ-VAE assembled from injected encoder/decoder callables and a VectorQuantizer.

    NOTE(review): ``compute_loss`` calls ``self.loss(x_hat, x, vq_loss)`` with
    three arguments, while the module-level ``loss`` function takes four
    (``config`` first) — presumably callers inject a partially-applied loss;
    verify at the call site.
    """

    def __init__(self, name: str, loss: Callable, encoder: Callable, decoder: Callable, config: dict) -> None:
        super(VQVAE, self).__init__()

        self.name = name
        self.loss = loss
        self.encoder = encoder
        self.decoder = decoder
        # Quantizer hyper-parameters come straight from the config dict.
        self.vq_layer = VectorQuantizer(config["num_embeddings"], config["embedding_dim"], config["beta"])

    def forward(self, x: Tensor):
        """Encode, quantize, and decode ``x``; the VQ loss is discarded here."""
        return self._run_step(x)[0]

    def _run_step(self, x: Tensor):
        """Run one full pass; return (reconstruction, input, vq_loss)."""
        encoding = self.encoder(x)[0]
        quantized, vq_loss = self.vq_layer(encoding)
        return self.decoder(quantized), x, vq_loss

    def compute_loss(self, x):
        """Run one step and reduce it with the injected loss callable."""
        x_hat, x, vq_loss = self._run_step(x)
        return self.loss(x_hat, x, vq_loss)
184
185