Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
aamini
GitHub Repository: aamini/introtodeeplearning
Path: blob/master/mitdeeplearning/util.py
547 views
1
import matplotlib.pyplot as plt
2
import time
3
import numpy as np
4
5
from IPython import display as ipythondisplay
6
from string import Formatter
7
8
9
def display_model(model):
    """Draw a Keras model diagram and return it as an IPython image.

    The diagram (with layer shapes shown) is written to ``tmp.png`` in the
    current working directory, and that file is then wrapped in an
    ``IPython.display.Image``.

    Args:
        model: model object passed to ``tf.keras.utils.plot_model``

    Returns:
        An ``IPython.display.Image`` pointing at the rendered ``tmp.png``.
    """
    # TensorFlow is imported lazily so the module can be imported without it.
    import tensorflow as tf

    image_path = "tmp.png"
    tf.keras.utils.plot_model(model, to_file=image_path, show_shapes=True)
    return ipythondisplay.Image(image_path)
13
14
15
def plot_sample(x, y, vae, backend='tf'):
    """Plot an original image and its VAE reconstruction side by side.

    The sample shown is the first one whose label equals 1. When no label
    equals 1, the first sample of the batch (index 0) is shown instead; this
    fallback applies to both backends.

    Args:
        x: Input images array of shape [B, H, W, C] (TF) or [B, C, H, W] (PT)
        y: Labels array of shape [B] where 1 indicates a face
        vae: VAE model (TensorFlow or PyTorch); ``vae(x)`` must return four
            values, the last of which is the reconstruction
        backend: 'tf' or 'pt' indicating which framework to use

    Raises:
        ValueError: If ``backend`` is neither 'tf' nor 'pt'.
    """
    # Validate before touching matplotlib so a bad argument does not leave an
    # empty figure behind.
    if backend not in ('tf', 'pt'):
        raise ValueError("backend must be 'tf' or 'pt'")

    plt.figure(figsize=(2, 1))

    if backend == 'tf':
        _, _, _, recon = vae(x)
        recon = np.clip(recon, 0, 1)
    else:
        # PyTorch is imported lazily so TF-only users do not need it installed.
        import torch

        y = y.detach().cpu().numpy()

        with torch.inference_mode():
            _, _, _, recon = vae(x)
            recon = torch.clamp(recon, 0, 1)
            # Channels-first -> channels-last, which is what imshow expects.
            recon = recon.permute(0, 2, 3, 1).detach().cpu().numpy()
            x = x.permute(0, 2, 3, 1).detach().cpu().numpy()

    # Pick the first face; fall back to the first sample when there is none.
    face_indices = np.where(y == 1)[0]
    idx = face_indices[0] if len(face_indices) > 0 else 0

    plt.subplot(1, 2, 1)
    plt.imshow(x[idx])
    plt.grid(False)

    plt.subplot(1, 2, 2)
    plt.imshow(recon[idx])
    plt.grid(False)

    if backend == 'pt':
        plt.show()
56
57
58
class LossHistory:
    """Record a sequence of loss values with optional exponential smoothing.

    Each appended value after the first is stored as
    ``alpha * previous_stored + (1 - alpha) * value``; the first value is
    stored unchanged.
    """

    def __init__(self, smoothing_factor=0.0):
        """Create an empty history.

        Args:
            smoothing_factor: weight given to the previously stored value;
                with the default of 0.0 the values are stored as given
        """
        self.loss = []
        self.alpha = smoothing_factor

    def append(self, value):
        """Add ``value`` to the history, blended with the last stored entry."""
        if len(self.loss) > 0:
            # Blend the new observation with the most recent smoothed value.
            value = self.alpha * self.loss[-1] + (1 - self.alpha) * value
        self.loss.append(value)

    def get(self):
        """Return the list of stored (smoothed) loss values."""
        return self.loss
72
73
74
class PeriodicPlotter:
    """Redraw a line plot in the notebook output, at most once per interval."""

    def __init__(self, sec, xlabel="", ylabel="", scale=None):
        """Store the plot settings and start the refresh timer.

        Args:
            sec: number of seconds that must elapse between two redraws
            xlabel: label for the x axis
            ylabel: label for the y axis
            scale: None for linear axes, or one of "semilogx", "semilogy",
                "loglog"
        """
        self.xlabel, self.ylabel = xlabel, ylabel
        self.sec, self.scale = sec, scale

        self.tic = time.time()

    def plot(self, data):
        """Redraw ``data`` if more than ``sec`` seconds passed since the last draw.

        Raises:
            ValueError: If a redraw is due and ``scale`` is not a recognized
                value.
        """
        elapsed = time.time() - self.tic
        if not elapsed > self.sec:
            return

        plt.cla()

        if self.scale is None:
            plt.plot(data)
        else:
            # The log-scale options share their name with the pyplot function.
            for scale_name in ("semilogx", "semilogy", "loglog"):
                if self.scale == scale_name:
                    getattr(plt, scale_name)(data)
                    break
            else:
                raise ValueError("unrecognized parameter scale {}".format(self.scale))

        plt.xlabel(self.xlabel)
        plt.ylabel(self.ylabel)
        # Replace the previous output with the refreshed figure.
        ipythondisplay.clear_output(wait=True)
        ipythondisplay.display(plt.gcf())

        self.tic = time.time()
104
105
106
def create_grid_of_images(xs, size=(5, 5)):
    """Tile the leading images of `xs` into one image of `size[0]` rows by `size[1]` columns.

    Images are taken from `xs` in order and laid out row by row: each row is
    joined horizontally, then the rows are joined vertically. Only the first
    `size[0] * size[1]` entries of `xs` are used.
    """
    n_rows, n_cols = size[0], size[1]

    stacked_rows = [
        np.hstack([xs[r * n_cols + c] for c in range(n_cols)])
        for r in range(n_rows)
    ]
    return np.vstack(stacked_rows)
120
121