# Path: blob/master/deprecated/vae/standalone/lvm_plots_utils.py
# 1192 views
"""Plotting helpers for latent-variable models (interpolations, samples, embeddings)."""
from typing import Callable, Tuple

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import umap
from einops import rearrange
from scipy.stats import norm
from scipy.stats import truncnorm
from sklearn.manifold import TSNE
from torchvision.utils import make_grid

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _to_numpy(values):
    """Return *values* as a NumPy array, detaching tensors and moving them to the CPU."""
    if isinstance(values, torch.Tensor):
        return values.detach().cpu().numpy()
    return np.asarray(values)


def get_interpolation(interpolation):
    """Resolve an interpolation spec to a callable ``f(val, low, high)``.

    Args:
        interpolation: ``"spherical"`` (returns :func:`slerp`), ``"linear"``
            (returns :func:`lerp`), or any callable, which is returned as is.

    Raises:
        ValueError: if ``interpolation`` is neither a known name nor callable.
    """
    if interpolation == "spherical":
        return slerp
    if interpolation == "linear":
        return lerp
    if callable(interpolation):
        return interpolation
    # Failing here is clearer than returning None and crashing at the call site.
    raise ValueError(
        f"Unknown interpolation {interpolation!r}; expected 'spherical', 'linear' or a callable"
    )


def get_embedder(encoder, X_data, y_data=None, use_embedder="TSNE"):
    """Encode ``X_data`` and, if needed, reduce the codes with UMAP or t-SNE.

    Args:
        encoder: callable applied to ``X_data``; its output must expose
            ``.shape`` and, when a reduction is run, the tensor methods
            ``.cpu().detach().numpy()``.
        X_data: input passed to ``encoder``.
        y_data: optional labels, forwarded to UMAP's ``fit_transform`` only.
        use_embedder: ``"UMAP"`` or ``"TSNE"``.

    Returns:
        The raw encoder output when its last dimension is already 2, or when
        ``use_embedder`` is not one of the two known names; otherwise the
        array returned by the chosen reducer's ``fit_transform``.
    """
    X_data_2D = encoder(X_data)
    if X_data_2D.shape[-1] == 2:
        # Already two-dimensional: nothing to reduce.
        return X_data_2D
    if use_embedder == "UMAP":
        X_data_2D = umap.UMAP().fit_transform(_to_numpy(X_data_2D), y_data)
    elif use_embedder == "TSNE":
        X_data_2D = TSNE().fit_transform(_to_numpy(X_data_2D))
    return X_data_2D


def lerp(val, low, high):
    """Linear interpolation: ``low`` at ``val == 0`` and ``high`` at ``val == 1``."""
    return low + (high - low) * val


def slerp(val, low, high):
    """Spherical interpolation between two 1-D tensors.

    ``val`` has a range of 0 to 1; values outside that range return the
    nearest endpoint. When the spherical formula is ill-conditioned (the
    endpoints are parallel, anti-parallel or zero-norm) the result falls back
    to :func:`lerp` instead of producing NaN or inf.
    """
    if val <= 0:
        return low
    if val >= 1:
        return high
    if torch.allclose(low, high):
        return low

    cos_omega = torch.dot(low / torch.norm(low), high / torch.norm(high))
    # Rounding can push the cosine slightly outside [-1, 1], where arccos is NaN.
    omega = torch.arccos(torch.clamp(cos_omega, -1.0, 1.0))
    so = torch.sin(omega)
    if not torch.isfinite(so) or torch.abs(so) < 1e-6:
        # Dividing by sin(omega) would blow up here.
        return lerp(val, low, high)
    return torch.sin((1.0 - val) * omega) / so * low + torch.sin(val * omega) / so * high


def make_imrange(arr: list, nrow: int = 11):
    """Stack decoded images into one grid image in ``h w c`` layout.

    Args:
        arr: list of tensors accepted by ``torch.stack``.
        nrow: number of images per grid row, forwarded to ``make_grid``
            (defaults to the previously hard-coded 11).

    Returns:
        The grid as a NumPy array.
    """
    stacked = torch.stack(arr)
    grid = rearrange(make_grid(stacked, nrow), "c h w -> h w c")
    return grid.detach().cpu().numpy()


def get_imrange(
    G: Callable[[torch.Tensor], torch.Tensor],
    start: torch.Tensor,
    end: torch.Tensor,
    nums: int = 8,
    interpolation="spherical",
) -> np.ndarray:
    """Decode ``nums`` evenly spaced interpolations between ``start`` and ``end``.

    Each interpolated latent gets a leading batch axis before being passed to
    ``G``. The decoder must produce a 3d tensor per call so that the outputs
    can be stacked and tiled into a grid.

    Args:
        G: decoder applied to each interpolated latent.
        start: latent at interpolation value 0.
        end: latent at interpolation value 1.
        nums: number of interpolation steps (both endpoints included).
        interpolation: ``"spherical"``, ``"linear"`` or a callable
            ``f(val, low, high)``.

    Returns:
        The image grid built by :func:`make_imrange`.
    """
    inter = get_interpolation(interpolation)
    decoded = []
    for val in torch.linspace(0, 1, nums):
        new_z = torch.unsqueeze(inter(val, start, end), 0)
        decoded.append(G(new_z))
    return make_imrange(decoded)


def get_random_samples(
    decoder: Callable[[torch.Tensor], torch.Tensor],
    truncation_threshold=1,
    latent_dim=20,
    num_images=64,
    num_images_per_row=8,
) -> np.ndarray:
    """Decode latents drawn from a truncated standard normal into one grid image.

    The decoder must produce a 4d tensor that can be fed into ``make_grid``.

    Args:
        decoder: maps a ``(num_images, latent_dim)`` float tensor to images.
        truncation_threshold: latents are drawn from
            ``[-truncation_threshold, truncation_threshold]``.
        latent_dim: size of each latent vector.
        num_images: number of latents to draw.
        num_images_per_row: grid row length, forwarded to ``make_grid``.

    Returns:
        The grid as a NumPy array in ``h w c`` layout.
    """
    values = truncnorm.rvs(-truncation_threshold, truncation_threshold, size=(num_images, latent_dim))
    z = torch.from_numpy(values).float().to(device)
    grid = rearrange(make_grid(decoder(z), num_images_per_row), "c h w -> h w c")
    return grid.cpu().detach().numpy()


def get_grid_samples(
    decoder: Callable[[torch.Tensor], torch.Tensor], latent_size: int = 2, size: int = 10, max_z: float = 3.1
) -> torch.Tensor:
    """Decode a regular ``size`` x ``size`` grid over the first two latent axes.

    The first two latent coordinates sweep ``[-max_z, max_z]``; any remaining
    ``latent_size - 2`` coordinates are set to 0. The decoder must produce
    tensors of equal shape so that they can be stacked.

    Args:
        decoder: called once per grid point with a ``(1, latent)`` tensor.
        latent_size: total latent dimensionality.
        size: number of grid points per axis.
        max_z: half-width of the swept interval.

    Returns:
        ``torch.stack`` of the ``size * size`` decoder outputs, in row-major
        order (first coordinate varies slowest).
    """
    if size == 1:
        # The spacing formula divides by (size - 1); a single point sits at the centre.
        coords = [0.0]
    else:
        coords = [(((k / (size - 1)) * max_z) * 2) - max_z for k in range(size)]

    padding = (latent_size - 2) * [0]
    decoded = []
    for z1 in coords:
        for z2 in coords:
            z_ = torch.tensor([[z1, z2] + padding], device=device)
            decoded.append(decoder(z_))
    return torch.stack(decoded)


def plot_scatter_plot(batch, encoder, use_embedder="TSNE", min_distance=0.03):
    """Plot a scatter plot of embeddings, annotated with input thumbnails.

    Coordinates are min-max normalised with the global minimum and maximum. A
    thumbnail is drawn for a point only when its squared distance to every
    thumbnail already placed exceeds 0.04. Inputs whose channel count is not 1
    or 3 get no thumbnail.

    Args:
        batch: ``(X_data, y_data)`` pair; ``X_data`` is moved to ``device``.
        encoder: forwarded to :func:`get_embedder`.
        use_embedder: forwarded to :func:`get_embedder`.
        min_distance: accepted for interface compatibility.
            NOTE(review): this value is not used; the spacing threshold is the
            constant 0.04 below. Confirm the intended semantics before wiring
            it in, since doing so changes the default output.

    Returns:
        The matplotlib figure.
    """
    X_data, y_data = batch
    X_data = X_data.to(device)
    np.random.seed(42)
    # get_embedder may return a tensor (2-D latents) or an ndarray (UMAP/t-SNE).
    X_data_2D = _to_numpy(get_embedder(encoder, X_data, y_data, use_embedder))
    X_data_2D = (X_data_2D - X_data_2D.min()) / (X_data_2D.max() - X_data_2D.min())

    # adapted from https://scikit-learn.org/stable/auto_examples/manifold/plot_lle_digits.html
    fig = plt.figure(figsize=(10, 8))
    cmap = plt.cm.tab10
    plt.scatter(X_data_2D[:, 0], X_data_2D[:, 1], c=y_data, s=10, cmap=cmap)
    image_positions = np.array([[1.0, 1.0]])
    for index, position in enumerate(X_data_2D):
        dist = np.sum((position - image_positions) ** 2, axis=1)
        if not (np.min(dist) > 0.04):  # too close to an existing thumbnail
            continue
        image_positions = np.r_[image_positions, [position]]

        channels = X_data[index].shape[0]
        if channels == 3:
            pattern = "c h w -> h w c"
        elif channels == 1:
            pattern = "c h w -> (c h) w"
        else:
            # Previously an unbound `imagebox` raised NameError here; skip instead.
            continue
        imagebox = matplotlib.offsetbox.AnnotationBbox(
            matplotlib.offsetbox.OffsetImage(rearrange(X_data[index].cpu(), pattern), cmap="binary"),
            position,
            bboxprops={"edgecolor": tuple(cmap([y_data[index]])[0]), "lw": 2},
        )
        plt.gca().add_artist(imagebox)
    plt.axis("off")
    return fig


def plot_grid_plot(batch, encoder, use_cdf=False, use_embedder="TSNE", model_name="VAE mnist"):
    """Scatter-plot the 2-D embedding of a batch, coloured by label.

    Args:
        batch: ``(images, labels)`` pair; images are moved to ``device``.
        encoder: forwarded to :func:`get_embedder`.
        use_cdf: when True, plot the standard-normal CDF of the embedding
            instead of the raw coordinates.
        use_embedder: forwarded to :func:`get_embedder`.
        model_name: used in the figure title.

    Returns:
        The matplotlib figure.
    """
    figsize = 8
    example_images, example_labels = batch
    example_images = example_images.to(device=device)

    z_points = _to_numpy(get_embedder(encoder, example_images, use_embedder=use_embedder))

    fig = plt.figure(figsize=(figsize, figsize))
    if use_cdf:
        # Only evaluate the CDF when it is actually plotted.
        p_points = norm.cdf(z_points)
        plt.scatter(p_points[:, 0], p_points[:, 1], cmap="rainbow", c=example_labels, alpha=0.5, s=5)
    else:
        plt.scatter(z_points[:, 0], z_points[:, 1], cmap="rainbow", c=example_labels, alpha=0.5, s=2)
    plt.colorbar()
    plt.title(f"{model_name} embedding")
    return fig


def plot_grid_plot_with_sample(batch, encoder, decoder, use_embedder="TSNE", model_name="VAE mnist"):
    """Plot the batch embedding with random probe points, then decode the probes.

    Two figures are created: the embedding scatter with 30 numbered red probe
    points drawn uniformly inside the embedding's bounding box, and a 2 x 15
    panel of the decoder's outputs for those points.

    The decoder receives a ``(30, 2)`` float tensor and each output slice must
    be displayable by ``imshow``. For a model that outputs images with a
    channel dim along with a batch dim, rearrange the output, for example::

        def decoder(z):
            return rearrange(m.decode(z), "b c h w -> b (c h) w")

    Args:
        batch: ``(images, labels)`` pair; images are moved to ``device``.
        encoder: forwarded to :func:`get_embedder`.
        decoder: maps the probe points to images.
        use_embedder: forwarded to :func:`get_embedder`.
        model_name: used in the first figure's title.
    """
    figsize = 8
    example_images, example_labels = batch
    example_images = example_images.to(device=device)

    z_points = _to_numpy(get_embedder(encoder, example_images, use_embedder=use_embedder))
    plt.figure(figsize=(figsize, figsize))
    plt.scatter(z_points[:, 0], z_points[:, 1], cmap="rainbow", c=example_labels, alpha=0.5, s=2)
    plt.colorbar()

    grid_size = 15
    grid_depth = 2
    num_probes = grid_size * grid_depth
    np.random.seed(42)
    x_min = np.min(z_points[:, 0])
    x_max = np.max(z_points[:, 0])
    y_min = np.min(z_points[:, 1])
    y_max = np.max(z_points[:, 1])
    # Draw x before y so the seeded random stream is consumed in the same order.
    probe_x = np.random.uniform(low=x_min, high=x_max, size=num_probes)
    probe_y = np.random.uniform(low=y_min, high=y_max, size=num_probes)

    z_grid = np.array(list(zip(probe_x, probe_y)))
    t_z_grid = torch.FloatTensor(z_grid).to(device)
    reconst = decoder(t_z_grid).detach().cpu()
    plt.scatter(z_grid[:, 0], z_grid[:, 1], c="red", alpha=1, s=20)
    for i, (px, py) in enumerate(z_grid):
        plt.text(px, py, i)
    plt.title(f"{model_name} embedding with samples")

    fig = plt.figure(figsize=(figsize, grid_depth))
    fig.subplots_adjust(hspace=0.4, wspace=0.4)
    for i in range(num_probes):
        ax = fig.add_subplot(grid_depth, grid_size, i + 1)
        ax.axis("off")
        ax.text(0.5, -0.35, str(i))
        ax.imshow(reconst[i, :], cmap="Greys")