Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
probml
GitHub Repository: probml/pyprobml
Path: blob/master/deprecated/vae/utils/plot.py
1192 views
1
import matplotlib.pyplot as plt
2
import torchvision.utils as vutils
3
from einops import rearrange
4
5
6
def plot(model_samples, title, figsize=(10, 30), num_of_images_per_row=5, filename=None):
    """Draw a batch of images as one grid figure and optionally save it.

    Args:
        model_samples: Images handed directly to ``torchvision.utils.make_grid``
            (presumably a batch tensor of images -- confirm against callers).
        title: Text placed above the grid.
        figsize: Size passed to ``plt.figure``.
        num_of_images_per_row: Forwarded to ``make_grid`` as ``nrow``.
        filename: When not ``None``, the figure is written to this path
            before being shown.
    """
    plt.figure(figsize=figsize)

    # Assemble the grid, then bring it to a NumPy array for matplotlib.
    grid = vutils.make_grid(model_samples, nrow=num_of_images_per_row)
    grid_chw = grid.cpu().detach().numpy()

    plt.title(title)

    # The grid is channel-first; reorder to height-width-channel for imshow.
    grid_hwc = rearrange(grid_chw, "c h w -> h w c")
    plt.imshow(grid_hwc)
    plt.axis("off")

    if filename is not None:
        # Save before show() so the written file contains the drawn figure.
        plt.savefig(filename, bbox_inches="tight")

    plt.show()
16
17
18
def plot_samples(vaes, num=25, figsize=(10, 30), num_of_images_per_row=5, figdir=None):
    """Plot samples drawn from one model or from each model in a collection.

    Any ``vaes`` object exposing ``__iter__`` is treated as a collection of
    models and each element is plotted in turn; otherwise ``vaes`` is treated
    as a single model.

    Args:
        vaes: A model, or an iterable of models. Each model must provide
            ``get_samples(num)`` and a ``model_name`` attribute.
        num: Number of samples requested from each model.
        figsize: Figure size forwarded to ``plot``.
        num_of_images_per_row: Grid row width forwarded to ``plot``.
        figdir: Directory in which to save figures. When not ``None``, each
            model's figure is saved as ``{figdir}/vae-samples-{model_name}.png``;
            when ``None``, nothing is saved.
    """
    if hasattr(vaes, "__iter__"):  # collection of models
        for vae in vaes:
            # Forward the directory unchanged: the single-model path below
            # derives the per-model file name itself.
            plot_samples(vae, num, figsize, num_of_images_per_row, figdir)
        return

    vae = vaes  # single model
    model_samples = vae.get_samples(num)
    title = f"Samples from {vae.model_name}"

    filename = None
    if figdir is not None:
        filename = f"{figdir}/vae-samples-{vae.model_name}.png"

    # Pass the computed file name (not the directory) so savefig writes
    # to the intended PNG path.
    plot(model_samples, title, figsize, num_of_images_per_row, filename)
32
33
34
def plot_reconstruction(vaes, batch, num_of_samples=5, num_of_images_per_row=5, figsize=(10, 30), figdir=None):
    """Plot original images followed by each model's reconstruction of them.

    Args:
        vaes: A model, or an iterable of models (anything exposing
            ``__iter__``). Each model is called on the selected images and
            must provide a ``model_name`` attribute.
        batch: A two-element pair; only the first element is used, and it is
            indexed along four axes.
        num_of_samples: How many leading items of the first batch element to
            plot and reconstruct.
        num_of_images_per_row: Grid row width forwarded to ``plot``.
        figsize: Figure size forwarded to ``plot``.
        figdir: Directory in which to save figures. When not ``None``, the
            originals go to ``{figdir}/vae-recon-original.png`` and each
            reconstruction to ``{figdir}/vae-recon-{model_name}.png``.
    """
    inputs, _ = batch
    originals = inputs[:num_of_samples, :, :, :]

    saving = figdir is not None
    out_path = f"{figdir}/vae-recon-original.png" if saving else None
    plot(originals, "Original", figsize, num_of_images_per_row, out_path)

    # Treat a single model as a one-element collection so both cases
    # share the same loop.
    models = vaes if hasattr(vaes, "__iter__") else [vaes]
    for model in models:
        caption = f"Reconstruction from {model.model_name}"
        if saving:
            out_path = f"{figdir}/vae-recon-{model.model_name}.png"
        plot(model(originals), caption, figsize, num_of_images_per_row, out_path)
54
55