# Path: blob/master/deprecated/vae/models/guassian_vae.py
# 1192 views
import torch
from torch import nn
from typing import Callable


class VAE(nn.Module):
    """Standard VAE with Gaussian prior and approximate posterior.

    The encoder maps an input to the parameters ``(mu, log_var)`` of a
    diagonal Gaussian. A latent ``z`` is drawn from that Gaussian with the
    reparameterization trick (see :meth:`sample`), and the decoder maps
    ``z`` back to a reconstruction.
    """

    def __init__(
        self,
        name: str,
        loss: Callable,
        encoder: Callable,
        decoder: Callable,
        **kwargs,
    ) -> None:
        """Store the model components.

        Args:
            name: Identifier for this model instance.
            loss: Called as ``loss(x, x_hat, z, mu, log_var)`` by
                :meth:`compute_loss`; its return value is returned as-is.
            encoder: Called as ``encoder(x)`` and must return a
                ``(mu, log_var)`` pair.
            decoder: Called as ``decoder(z)`` to produce the reconstruction.
            **kwargs: Extra options kept on ``self.kwargs``. They are not
                read by the methods in this class.
        """
        super().__init__()

        self.name = name
        self.loss = loss
        self.kwargs = kwargs
        # If these are nn.Module instances, attribute assignment on an
        # nn.Module registers them as submodules, so their parameters are
        # exposed via self.parameters().
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x):
        """Encode ``x``, sample a latent, and return the decoded output."""
        # Reuse the shared pipeline so forward and compute_loss cannot drift
        # apart. The call order (encoder, sample, decoder) is identical.
        _, x_hat, _, _ = self._run_step(x)
        return x_hat

    def _run_step(self, x):
        """Run encoder -> sample -> decoder.

        Returns:
            Tuple ``(z, x_hat, mu, log_var)``: the sampled latent, the
            decoder output, and the encoder's Gaussian parameters.
        """
        mu, log_var = self.encoder(x)
        z = self.sample(mu, log_var)
        return z, self.decoder(z), mu, log_var

    def sample(self, mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        ``std`` is ``exp(0.5 * log_var)``, i.e. ``log_var`` is the log
        variance. Expressing the sample as a deterministic function of
        ``mu``/``log_var`` plus independent noise keeps it differentiable
        with respect to the encoder outputs (reparameterization trick).
        """
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps * std + mu

    def compute_loss(self, x):
        """Run one forward pass and return ``loss(x, x_hat, z, mu, log_var)``."""
        z, x_hat, mu, log_var = self._run_step(x)
        return self.loss(x, x_hat, z, mu, log_var)