Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
probml
GitHub Repository: probml/pyprobml
Path: blob/master/deprecated/gan/utils/plotting.py
1192 views
1
import torch
2
import matplotlib.pyplot as plt
3
from einops import rearrange
4
from scipy.stats import truncnorm
5
from torchvision.utils import make_grid
6
7
8
def sample_from_truncated_normal(
    gans, num, threshold=1, num_images_per_row=8, figsize=(10, 10), random_state=None
):
    """Plot samples from every GAN in ``gans`` using one shared truncated-normal latent batch.

    A single latent batch ``z`` is drawn once and passed to each model, so the
    resulting grids are generated from identical latent inputs.

    Args:
        gans: Sequence of models. The latent size is read from
            ``gans[0].generator.latent_dim``; all models are assumed to share it
            (TODO confirm with callers).
        num: Number of latent vectors (and therefore images) per model.
        threshold: Latents are drawn from a standard normal truncated to
            ``[-threshold, threshold]``. Must be positive.
        num_images_per_row: Forwarded to ``plotting`` as the grid row width.
        figsize: Matplotlib figure size forwarded to ``plotting``.
        random_state: Optional seed / ``numpy.random.Generator`` /
            ``RandomState`` forwarded to ``scipy.stats.truncnorm.rvs`` so the
            latent batch can be reproduced. ``None`` keeps the original
            (unseeded) behaviour.

    Raises:
        ValueError: If ``threshold`` is not positive (the truncation interval
            would be empty).
    """
    if threshold <= 0:
        raise ValueError(f"threshold must be positive, got {threshold!r}")
    if not gans:
        # Nothing to plot; also avoids an opaque IndexError on gans[0] below.
        return

    latent_dim = gans[0].generator.latent_dim
    values = truncnorm.rvs(
        -threshold, threshold, size=(num, latent_dim), random_state=random_state
    )
    # scipy returns float64; cast to torch's default float32.
    z = torch.from_numpy(values).float()
    for gan in gans:
        plotting(gan, z, num_images_per_row, figsize=figsize)
13
14
15
def plotting(gan, z, num_row=8, figsize=(10, 10)):
    """Render ``gan(z)`` as an image grid, save it to ``<gan.name>.png`` and show it.

    Args:
        gan: Callable model with a ``name`` attribute. ``gan(z)`` presumably
            returns an image batch shaped (batch, channels, height, width), as
            ``make_grid`` expects (TODO confirm).
        z: Latent batch passed unchanged to ``gan``.
        num_row: Number of images placed in each row of the grid (this is
            ``make_grid``'s ``nrow``). Despite the name, it sets the grid width.
        figsize: Matplotlib figure size in inches.
    """
    # Inference only: don't build an autograd graph for the generated batch.
    with torch.no_grad():
        samples = gan(z)
    grid = make_grid(samples, num_row)
    # make_grid yields a channels-first (c, h, w) tensor; imshow wants channels-last.
    imgs = rearrange(grid, "c h w -> h w c").cpu().detach().numpy()

    fig = plt.figure(figsize=figsize)
    plt.imshow(imgs)
    plt.title(f"{gan.name}")
    plt.savefig(f"{gan.name}.png")
    plt.show()
    # Release the figure so repeated calls don't accumulate open pyplot figures.
    plt.close(fig)
22
23