Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
probml
GitHub Repository: probml/pyprobml
Path: blob/master/deprecated/vae/models/guassian_vae.py
1192 views
1
import torch
2
from torch import nn
3
from typing import Callable
4
5
6
class VAE(nn.Module):
    """
    Standard VAE with a Gaussian approximate posterior.

    The encoder maps an input ``x`` to the parameters ``(mu, log_var)`` of a
    diagonal Gaussian q(z | x). A latent ``z`` is drawn with the
    reparameterization trick (see ``sample``) and passed to the decoder.
    Any prior / KL term is the responsibility of the supplied ``loss``
    callable; this class only forwards the quantities it needs.

    Args:
        name: Identifier for this model instance (stored on ``self.name``;
            not used by the methods below).
        loss: Callable invoked as ``loss(x, x_hat, z, mu, log_var)`` by
            ``compute_loss``; its return value is returned unchanged.
        encoder: Callable ``encoder(x) -> (mu, log_var)``.
        decoder: Callable ``decoder(z) -> x_hat``.
        **kwargs: Extra configuration, stored unchanged on ``self.kwargs``.
    """

    def __init__(self, name: str, loss: Callable, encoder: Callable, decoder: Callable, **kwargs) -> None:
        super().__init__()

        self.name = name
        self.loss = loss
        self.kwargs = kwargs
        # If encoder/decoder are nn.Module instances, assigning them as
        # attributes registers them (and their parameters) as submodules.
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x):
        """Encode ``x``, draw one latent sample and return its reconstruction."""
        # Delegate to _run_step so the encode/sample/decode pipeline is
        # defined in exactly one place.
        _, x_hat, _, _ = self._run_step(x)
        return x_hat

    def _run_step(self, x):
        """
        Run the full pipeline and expose all intermediate quantities.

        Returns:
            Tuple ``(z, x_hat, mu, log_var)``: the sampled latent, the decoder
            output for that latent, and the encoder's posterior parameters.
        """
        mu, log_var = self.encoder(x)
        z = self.sample(mu, log_var)
        return z, self.decoder(z), mu, log_var

    def sample(self, mu, log_var):
        """
        Draw ``z ~ N(mu, exp(log_var))`` via the reparameterization trick.

        Writing the sample as ``mu + eps * std`` with ``eps ~ N(0, I)`` keeps
        ``z`` differentiable w.r.t. ``mu`` and ``log_var``.
        """
        # std = sqrt(exp(log_var)) = exp(0.5 * log_var)
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps * std + mu

    def compute_loss(self, x):
        """Run one stochastic pass on ``x`` and return ``self.loss`` of its outputs."""
        z, x_hat, mu, log_var = self._run_step(x)
        return self.loss(x, x_hat, z, mu, log_var)
42
43