Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
probml
GitHub Repository: probml/pyprobml
Path: blob/master/deprecated/vae/standalone/lvm_plots_utils.py
1192 views
1
import umap
2
from typing import Callable, Tuple
3
import torch
4
import numpy as np
5
import matplotlib
6
import matplotlib.pyplot as plt
7
from einops import rearrange
8
from torchvision.utils import make_grid
9
from scipy.stats import truncnorm
10
from scipy.stats import norm
11
from sklearn.manifold import TSNE
12
13
# Default torch device for this module: the current CUDA device when PyTorch
# can see a GPU, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
14
15
16
def get_interpolation(interpolation):
    """Resolve an interpolation scheme to a callable ``f(val, low, high)``.

    Args:
        interpolation: either the string ``"spherical"`` (resolves to ``slerp``),
            the string ``"linear"`` (resolves to ``lerp``), or any callable,
            which is returned unchanged.

    Returns:
        The interpolation function.

    Raises:
        ValueError: if ``interpolation`` is neither a supported name nor callable.
            (Previously this case returned None, which only failed later with a
            confusing "'NoneType' object is not callable".)
    """
    if callable(interpolation):
        return interpolation
    if interpolation == "spherical":
        return slerp
    if interpolation == "linear":
        return lerp
    raise ValueError(
        f"Unsupported interpolation {interpolation!r}; expected 'spherical', 'linear' or a callable"
    )
26
27
28
def get_embedder(encoder, X_data, y_data=None, use_embedder="TSNE"):
    """Encode ``X_data`` and return a 2-D embedding as a NumPy array.

    Args:
        encoder: callable mapping ``X_data`` to latent codes whose last
            dimension is the latent size (presumably a torch tensor of shape
            (batch, latent_dim) — confirm for new encoders).
        X_data: input batch passed straight to ``encoder``.
        y_data: optional labels, forwarded to UMAP as its ``y`` argument
            (ignored by TSNE).
        use_embedder: ``"UMAP"`` or ``"TSNE"``; only consulted when the latent
            is not already 2-D.

    Returns:
        A NumPy array of shape (batch, 2). A latent whose last dimension is
        already 2 is returned as-is (converted to NumPy). Otherwise it is
        reduced with the chosen embedder.

    Raises:
        ValueError: if a reduction is needed and ``use_embedder`` is unknown.
    """
    # Encoding is only used for plotting, so skip building an autograd graph.
    with torch.no_grad():
        encoded = encoder(X_data)
    # Always hand back NumPy so callers can use numpy/scipy ops regardless of
    # which branch was taken or which device the encoder ran on.
    if isinstance(encoded, torch.Tensor):
        encoded = encoded.detach().cpu().numpy()

    if encoded.shape[-1] == 2:
        return encoded
    if use_embedder == "UMAP":
        return umap.UMAP().fit_transform(encoded, y_data)
    if use_embedder == "TSNE":
        return TSNE().fit_transform(encoded)
    raise ValueError(f"Unknown use_embedder {use_embedder!r}; expected 'UMAP' or 'TSNE'")
39
40
41
def lerp(val, low, high):
    """Linearly interpolate from ``low`` (at ``val`` = 0) to ``high`` (at ``val`` = 1).

    Works for Python numbers and for tensors that support ``+``, ``-`` and ``*``.
    """
    delta = high - low
    return delta * val + low
44
45
46
def slerp(val, low, high):
    """Spherical interpolation between two 1-D vectors.

    Args:
        val: interpolation weight in [0, 1]. Values <= 0 return ``low`` and
            values >= 1 return ``high``.
        low, high: 1-D tensors (``torch.dot`` only accepts 1-D inputs).

    Returns:
        ``sin((1-val)*omega)/sin(omega) * low + sin(val*omega)/sin(omega) * high``,
        where ``omega`` is the angle between ``low`` and ``high``. When that
        formula is degenerate, linear interpolation is used instead. This
        happens when either vector has zero norm, or the vectors are parallel
        or anti-parallel so that ``sin(omega)`` is ~0; the formula would then
        produce NaN/inf.
    """
    if val <= 0:
        return low
    if val >= 1:
        return high
    if torch.allclose(low, high):
        return low

    eps = 1e-6
    low_norm = torch.norm(low)
    high_norm = torch.norm(high)
    if low_norm == 0 or high_norm == 0:
        return low + (high - low) * val
    # Floating-point rounding can push the cosine just outside [-1, 1], where
    # arccos returns NaN.
    cos_omega = torch.clamp(torch.dot(low / low_norm, high / high_norm), -1.0, 1.0)
    omega = torch.arccos(cos_omega)
    so = torch.sin(omega)
    if torch.abs(so) < eps:
        return low + (high - low) * val
    return torch.sin((1.0 - val) * omega) / so * low + torch.sin(val * omega) / so * high
57
58
59
def make_imrange(arr: list, nrow: int = 11):
    """Tile a list of equally-shaped image tensors into one grid image.

    Args:
        arr: image tensors, stacked with ``torch.stack`` into a batch
            (presumably each is (C, H, W) so the stack is 4-D — confirm).
        nrow: number of images per grid row, forwarded to ``make_grid``
            (defaults to 11, the previously hard-coded value).

    Returns:
        The tiled grid as a NumPy array laid out (H, W, C) for ``imshow``.
    """
    grid = make_grid(torch.stack(arr), nrow)
    # .cpu() is a no-op for CPU tensors, so no CUDA-availability branch is needed.
    return rearrange(grid, "c h w -> h w c").detach().cpu().numpy()
64
65
66
def get_imrange(
    G: Callable[[torch.tensor], torch.tensor],
    start: torch.tensor,
    end: torch.tensor,
    nums: int = 8,
    interpolation="spherical",
) -> np.ndarray:
    """Decode ``nums`` latents interpolated from ``start`` to ``end`` and tile them.

    The decoder ``G`` must return a 3-D tensor for each call so that the
    results can be stacked into a batch and tiled into a single grid.

    Args:
        G: decoder applied to each interpolated latent (given a leading batch
            dim of 1 via ``unsqueeze``).
        start, end: endpoint latent vectors.
        nums: number of evenly spaced interpolation weights in [0, 1].
        interpolation: ``"spherical"``, ``"linear"`` or a callable
            ``f(val, low, high)``; see ``get_interpolation``.

    Returns:
        The tiled images as a NumPy array laid out (H, W, C), as produced by
        ``make_imrange``.

    Raises:
        ValueError: if ``interpolation`` cannot be resolved to a function.
    """
    inter = get_interpolation(interpolation)
    if inter is None:
        raise ValueError(f"Unsupported interpolation {interpolation!r}")
    frames = []
    # Decoding is for visualisation only, so no autograd graph is needed.
    with torch.no_grad():
        for val in torch.linspace(0, 1, nums):
            new_z = torch.unsqueeze(inter(val, start, end), 0)
            frames.append(G(new_z))
    return make_imrange(frames)
83
84
85
def get_random_samples(
    decoder: Callable[[torch.tensor], torch.tensor],
    truncation_threshold=1,
    latent_dim=20,
    num_images=64,
    num_images_per_row=8,
) -> np.ndarray:
    """Decode random truncated-normal latents and tile the results into a grid.

    Latents are drawn with ``scipy.stats.truncnorm.rvs`` from a standard
    normal truncated to ``[-truncation_threshold, truncation_threshold]``
    (via NumPy's global RNG), with shape ``(num_images, latent_dim)``.

    Args:
        decoder: maps a (num_images, latent_dim) float tensor on ``device`` to
            a 4-D image batch accepted by ``make_grid``.
        truncation_threshold: truncation bound for the latent samples.
        latent_dim: latent dimensionality.
        num_images: number of samples to decode.
        num_images_per_row: images per grid row, forwarded to ``make_grid``.

    Returns:
        The tiled grid as a NumPy array laid out (H, W, C).
    """
    values = truncnorm.rvs(-truncation_threshold, truncation_threshold, size=(num_images, latent_dim))
    z = torch.from_numpy(values).float().to(device)
    # Sampling is for visualisation only, so skip autograd bookkeeping.
    with torch.no_grad():
        decoded = decoder(z)
    grid = make_grid(decoded, num_images_per_row)
    return rearrange(grid, "c h w -> h w c").cpu().detach().numpy()
100
101
102
def get_grid_samples(
    decoder: Callable[[torch.tensor], torch.tensor], latent_size: int = 2, size: int = 10, max_z: float = 3.1
) -> torch.tensor:
    """Decode a regular ``size`` x ``size`` grid over the first two latent dims.

    Both of the first two coordinates sweep ``size`` evenly spaced values from
    ``-max_z`` to ``max_z``. Any remaining ``latent_size - 2`` coordinates are 0.
    With ``size == 1`` the single grid point is the origin. The decoder is
    called once per grid point with a (1, latent_size) tensor on ``device``.
    It must produce a 3-D result per call so that the outputs can be stacked.

    Args:
        decoder: callable applied to each latent point.
        latent_size: total latent dimensionality (must be >= 2).
        size: number of grid points per axis (must be >= 1).
        max_z: half-width of the sampled range on each axis.

    Returns:
        ``torch.stack`` of the decoder outputs, in row-major order (first
        coordinate outer, second inner).

    Raises:
        ValueError: if ``latent_size < 2`` or ``size < 1``.
    """
    if latent_size < 2:
        raise ValueError(f"latent_size must be >= 2, got {latent_size}")
    if size < 1:
        raise ValueError(f"size must be >= 1, got {size}")

    # size - 1 would be a zero divisor for a single-point grid; centre it instead.
    if size == 1:
        coords = [0.0]
    else:
        coords = [(((k / (size - 1)) * max_z) * 2) - max_z for k in range(size)]
    padding = (latent_size - 2) * [0]

    arr = []
    for z1 in coords:
        for z2 in coords:
            z_ = torch.tensor([[z1, z2] + padding], device=device)
            arr.append(decoder(z_))
    return torch.stack(arr)
117
118
119
def plot_scatter_plot(batch, encoder, use_embedder="TSNE", min_distance=0.03):
    """Scatter-plot a 2-D embedding of ``batch`` with image thumbnails overlaid.

    Adapted from the scikit-learn "plot_lle_digits" example. The embedding is
    min-max normalised using its global min and max. Points are then walked in
    order, and a thumbnail is drawn whenever the point's squared distance to
    every previously drawn thumbnail exceeds ``min_distance``.

    Args:
        batch: ``(X_data, y_data)`` pair. ``X_data`` holds images laid out
            (C, H, W) per sample; only 1- and 3-channel images get thumbnails.
            ``y_data`` holds integer class labels used for colouring.
        encoder: passed to ``get_embedder``.
        use_embedder: ``"TSNE"`` or ``"UMAP"``, passed to ``get_embedder``.
        min_distance: threshold on the squared normalised distance between
            thumbnails. This parameter used to be ignored in favour of a
            hard-coded 0.04.

    Returns:
        The matplotlib Figure.
    """
    from matplotlib import offsetbox

    X_data, y_data = batch
    X_data = X_data.to(device)
    labels = y_data.detach().cpu().numpy() if isinstance(y_data, torch.Tensor) else np.asarray(y_data)

    # Seeds NumPy's global RNG before embedding; presumably so the TSNE/UMAP
    # layout is reproducible between calls — confirm.
    np.random.seed(42)
    X_data_2D = get_embedder(encoder, X_data, labels, use_embedder)
    if isinstance(X_data_2D, torch.Tensor):
        X_data_2D = X_data_2D.detach().cpu().numpy()
    low = X_data_2D.min()
    span = X_data_2D.max() - low
    # Guard a degenerate (constant) embedding against division by zero.
    X_data_2D = (X_data_2D - low) / (span if span > 0 else 1.0)

    fig = plt.figure(figsize=(10, 8))
    cmap = plt.cm.tab10
    plt.scatter(X_data_2D[:, 0], X_data_2D[:, 1], c=labels, s=10, cmap=cmap)

    # Start with a dummy thumbnail at the (1, 1) corner, as in the sklearn example.
    image_positions = np.array([[1.0, 1.0]])
    for index, position in enumerate(X_data_2D):
        sq_dist = np.sum((position - image_positions) ** 2, axis=1)
        if np.min(sq_dist) <= min_distance:
            continue  # too close to an existing thumbnail

        image = X_data[index].detach().cpu()
        if image.shape[0] == 3:
            pixels = rearrange(image, "c h w -> h w c")
        elif image.shape[0] == 1:
            pixels = rearrange(image, "c h w -> (c h) w")
        else:
            # Unsupported channel count: skip rather than reuse a stale artist.
            continue

        image_positions = np.r_[image_positions, [position]]
        imagebox = offsetbox.AnnotationBbox(
            offsetbox.OffsetImage(pixels.numpy(), cmap="binary"),
            position,
            bboxprops={"edgecolor": tuple(cmap([labels[index]])[0]), "lw": 2},
        )
        plt.gca().add_artist(imagebox)

    plt.axis("off")
    return fig
153
154
155
def plot_grid_plot(batch, encoder, use_cdf=False, use_embedder="TSNE", model_name="VAE mnist"):
    """Scatter-plot a 2-D embedding of a labelled batch, coloured by label.

    Args:
        batch: ``(example_images, example_labels)`` pair. The images are moved
            to ``device`` and encoded; the labels colour the points.
        encoder: passed to ``get_embedder``.
        use_cdf: if True, map each coordinate through the standard normal CDF
            (``scipy.stats.norm.cdf``), squashing it into (0, 1), before plotting.
        use_embedder: ``"TSNE"`` or ``"UMAP"``, passed to ``get_embedder``.
        model_name: prefix for the plot title.

    Returns:
        The matplotlib Figure.
    """
    figsize = 8
    example_images, example_labels = batch
    example_images = example_images.to(device=device)

    z_points = get_embedder(encoder, example_images, use_embedder=use_embedder)
    # A 2-D latent may come back as a (possibly CUDA) tensor; scipy/matplotlib need NumPy.
    if isinstance(z_points, torch.Tensor):
        z_points = z_points.detach().cpu().numpy()

    fig = plt.figure(figsize=(figsize, figsize))
    if use_cdf:
        p_points = norm.cdf(z_points)
        plt.scatter(p_points[:, 0], p_points[:, 1], cmap="rainbow", c=example_labels, alpha=0.5, s=5)
    else:
        plt.scatter(z_points[:, 0], z_points[:, 1], cmap="rainbow", c=example_labels, alpha=0.5, s=2)
    plt.colorbar()
    plt.title(f"{model_name} embedding")
    return fig
178
179
180
def plot_grid_plot_with_sample(batch, encoder, decoder, use_embedder="TSNE", model_name="VAE mnist"):
    """Plot a labelled 2-D embedding, then decode random points inside its bounding box.

    Two figures are drawn:

    1. The embedding scatter, coloured by label, with 30 random sample points
       (uniform over the embedding's bounding box, NumPy global RNG seeded
       with 42) overlaid in red and numbered.
    2. A 2 x 15 grid showing the decoder's output for each numbered point.

    The decoder receives a (30, 2) float32 tensor on ``device``. Each
    ``reconst[i, :]`` is passed to ``imshow``, so every decoded sample must be
    a 2-D image. For a model that outputs (b, c, h, w), wrap it, e.g.::

        def decoder(z):
            return rearrange(m.decode(z), "b c h w -> b (c h) w")

    Args:
        batch: ``(example_images, example_labels)`` pair.
        encoder: passed to ``get_embedder``.
        decoder: callable mapping 2-D latents to images, as above.
        use_embedder: ``"TSNE"`` or ``"UMAP"``, passed to ``get_embedder``.
        model_name: prefix for the first figure's title.

    Returns:
        ``(embedding_fig, samples_fig)``. This function previously returned
        None, so callers that ignore the result are unaffected.
    """
    figsize = 8
    grid_size = 15
    grid_depth = 2
    n_samples = grid_size * grid_depth

    example_images, example_labels = batch
    example_images = example_images.to(device=device)

    z_points = get_embedder(encoder, example_images, use_embedder=use_embedder)
    # A 2-D latent may come back as a (possibly CUDA) tensor; np.min etc. need NumPy.
    if isinstance(z_points, torch.Tensor):
        z_points = z_points.detach().cpu().numpy()

    embedding_fig = plt.figure(figsize=(figsize, figsize))
    plt.scatter(z_points[:, 0], z_points[:, 1], cmap="rainbow", c=example_labels, alpha=0.5, s=2)
    plt.colorbar()

    # Reseed so the sampled latent points are the same on every call.
    np.random.seed(42)
    x_min, x_max = np.min(z_points[:, 0]), np.max(z_points[:, 0])
    y_min, y_max = np.min(z_points[:, 1]), np.max(z_points[:, 1])
    xs = np.random.uniform(low=x_min, high=x_max, size=n_samples)
    ys = np.random.uniform(low=y_min, high=y_max, size=n_samples)
    z_grid = np.column_stack((xs, ys))

    t_z_grid = torch.as_tensor(z_grid, dtype=torch.float32, device=device)
    # Decoding is for visualisation only, so no autograd graph is needed.
    with torch.no_grad():
        reconst = decoder(t_z_grid)
    reconst = reconst.detach().cpu()

    plt.scatter(z_grid[:, 0], z_grid[:, 1], c="red", alpha=1, s=20)
    for i, (zx, zy) in enumerate(z_grid):
        plt.text(zx, zy, i)
    plt.title(f"{model_name} embedding with samples")

    samples_fig = plt.figure(figsize=(figsize, grid_depth))
    samples_fig.subplots_adjust(hspace=0.4, wspace=0.4)
    for i in range(n_samples):
        ax = samples_fig.add_subplot(grid_depth, grid_size, i + 1)
        ax.axis("off")
        ax.text(0.5, -0.35, str(i))
        ax.imshow(reconst[i, :], cmap="Greys")
    return embedding_fig, samples_fig
228
229