Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
probml
GitHub Repository: probml/pyprobml
Path: blob/master/deprecated/gan/sampling/dgflow.py
1192 views
1
from typing import Callable
2
import torch
3
import numpy as np
4
from torch import Tensor
5
from typing import Callable
6
7
8
def sampling(config: dict, G: Callable, D: Callable, z_img: Tensor) -> Tensor:
    """Refine latent codes with Discriminator Gradient Flow (DGflow).

    Starting from ``z_img``, each step moves the latents along the gradient of
    the discriminator score ``D(G(z))`` with respect to ``z`` and injects
    Gaussian noise::

        z <- z + eta * grad_z D(G(z)) + sqrt(2 * eta) * noise_factor * eps,
        eps ~ N(0, I)

    Args:
        config: Dict providing ``"eta"`` (step size), ``"noise_factor"``
            (multiplier on the injected noise) and ``"num_steps"``.
        G: Generator, called as ``G(z)``.
        D: Discriminator, called on the generator output. Its output may have
            any shape; every element is differentiated with weight one (i.e.
            the gradient of ``D(G(z)).sum()`` is used).
        z_img: Initial latent tensor. It may be part of an autograd graph.

    Returns:
        The refined latents, detached from any autograd graph. The loop runs
        over ``range(1, num_steps)``, i.e. it performs ``num_steps - 1``
        updates; with ``num_steps <= 1`` the input tensor is returned as-is.
    """
    eta = config["eta"]
    noise_factor = config["noise_factor"]
    num_steps = config["num_steps"]
    # Loop-invariant scale of the injected Gaussian noise.
    noise_scale = (2 * eta) ** 0.5 * noise_factor

    def _velocity(z_img, D, G):
        """Return d(sum D(G(z)))/dz without touching G's or D's parameter grads."""
        # enable_grad: the caller may be running under torch.no_grad(), in
        # which case no graph would be built and differentiation would fail.
        with torch.enable_grad():
            # Differentiate w.r.t. a fresh leaf so the gradient is well defined
            # even when the caller's tensor is itself a non-leaf graph node.
            z_leaf = z_img.detach().requires_grad_(True)
            d_score = D(G(z_leaf))
            # autograd.grad (unlike .backward) does not accumulate into the
            # .grad fields of the generator/discriminator parameters.
            (grad,) = torch.autograd.grad(
                d_score, z_leaf, grad_outputs=torch.ones_like(d_score)
            )
        return grad

    def refine_samples(z_img, G, D):
        # NOTE(review): range(1, num_steps) performs num_steps - 1 updates;
        # kept as in the original to preserve behaviour -- confirm intent.
        for _ in range(1, num_steps):
            v = _velocity(z_img, D, G)
            z_img = (
                z_img.detach()
                + eta * v
                + noise_scale * torch.randn_like(z_img)
            )
        return z_img

    return refine_samples(z_img, G, D)
29
30