Path: blob/master/deprecated/vae/utils/interpolation.py
1192 views
from typing import Callable

import numpy as np
import pandas as pd
import torch
from einops import rearrange
from torchvision.utils import make_grid

# Column names for the 40 binary attributes, in the order the columns appear
# in the attribute tensor of a batch (these are the CelebA attribute names).
CELEBA_ATTRIBUTES = (
    "5_o_Clock_Shadow",
    "Arched_Eyebrows",
    "Attractive",
    "Bags_Under_Eyes",
    "Bald",
    "Bangs",
    "Big_Lips",
    "Big_Nose",
    "Black_Hair",
    "Blond_Hair",
    "Blurry",
    "Brown_Hair",
    "Bushy_Eyebrows",
    "Chubby",
    "Double_Chin",
    "Eyeglasses",
    "Goatee",
    "Gray_Hair",
    "Heavy_Makeup",
    "High_Cheekbones",
    "Male",
    "Mouth_Slightly_Open",
    "Mustache",
    "Narrow_Eyes",
    "No_Beard",
    "Oval_Face",
    "Pale_Skin",
    "Pointy_Nose",
    "Receding_Hairline",
    "Rosy_Cheeks",
    "Sideburns",
    "Smiling",
    "Straight_Hair",
    "Wavy_Hair",
    "Wearing_Earrings",
    "Wearing_Hat",
    "Wearing_Lipstick",
    "Wearing_Necklace",
    "Wearing_Necktie",
    "Young",
)


def get_imgs_and_attr(batch):
    """Split an ``(imgs, attr)`` batch and label the attribute columns.

    Args:
        batch: A pair ``(imgs, attr)``. ``attr`` is a tensor with one column
            per entry of CELEBA_ATTRIBUTES (presumably shape ``(batch, 40)``
            -- confirm against the data loader).

    Returns:
        ``(imgs, df)``: ``imgs`` unchanged, and ``df`` a DataFrame of the
        attributes with named columns and a default RangeIndex, so row ``i``
        describes ``imgs[i]``.
    """
    imgs, attr = batch
    # detach()/cpu() let .numpy() work for grad-tracking or GPU tensors too;
    # both are no-ops for a plain CPU tensor.
    df = pd.DataFrame(
        attr.detach().cpu().numpy(),
        columns=list(CELEBA_ATTRIBUTES),
    )
    return imgs, df


def vector_of_interest(vae, batch, feature_of_interest="Male"):
    """Compute a latent direction for one binary attribute.

    The direction is the mean latent code of the images where the attribute
    is present minus the mean latent code of the images where it is absent.

    Args:
        vae: Model exposing ``det_encode(imgs)``; assumed to return one latent
            row per image (the result is averaged over axis 0) -- confirm
            against the VAE class.
        batch: ``(imgs, attr)`` pair, as accepted by get_imgs_and_attr.
        feature_of_interest: Column name from CELEBA_ATTRIBUTES. Rows equal to
            1 count as "present" and rows equal to 0 as "absent"; any other
            value is ignored.

    Returns:
        ``(label_vector, present, absent)``: the direction vector and the two
        image subsets it was computed from.

    Raises:
        ValueError: If the batch contains no present or no absent example, in
            which case one of the means would be undefined (NaN).
    """
    imgs, attr = get_imgs_and_attr(batch)
    column = attr[feature_of_interest].to_numpy()
    # The DataFrame has a RangeIndex, so row positions index `imgs` directly.
    present_idx = np.flatnonzero(column == 1)
    absent_idx = np.flatnonzero(column == 0)
    if present_idx.size == 0 or absent_idx.size == 0:
        raise ValueError(
            f"batch needs examples both with and without {feature_of_interest!r} "
            f"(present={present_idx.size}, absent={absent_idx.size})"
        )
    present = imgs[present_idx]
    absent = imgs[absent_idx]
    z_present = vae.det_encode(present).mean(axis=0)
    z_absent = vae.det_encode(absent).mean(axis=0)
    label_vector = z_present - z_absent
    return label_vector, present, absent


def get_interpolation(interpolation):
    """Resolve an interpolation spec to a function ``f(val, low, high)``.

    Args:
        interpolation: ``"spherical"`` (slerp), ``"linear"`` (lerp), or any
            callable with that signature, which is returned unchanged.

    Raises:
        ValueError: For any other value, instead of silently returning None
            and failing later when the result is called.
    """
    if callable(interpolation):
        return interpolation
    if interpolation == "spherical":
        return slerp
    if interpolation == "linear":
        return lerp
    raise ValueError(
        f"unknown interpolation {interpolation!r}; "
        "expected 'spherical', 'linear' or a callable"
    )


def lerp(val, low, high):
    """Linear interpolation: returns ``low`` at ``val=0`` and ``high`` at ``val=1``."""
    return low + (high - low) * val


def slerp(val, low, high, eps=1e-6):
    """Spherical linear interpolation between ``low`` and ``high``.

    Args:
        val: Interpolation fraction in ``[0, 1]`` (Python float or 0-d
            tensor); values at or beyond either end return that endpoint.
        low, high: Tensors of the same shape. The angle between them is
            measured on their flattened values.
        eps: When ``|sin(angle)|`` is below this, the inputs are (anti)parallel,
            so the slerp formula divides by ~0. Linear interpolation is used
            instead.

    Returns:
        The interpolated tensor, with the same shape as ``low``.
    """
    if val <= 0:
        return low
    if val >= 1:
        return high
    if torch.allclose(low, high):
        return low

    low_norm = torch.norm(low)
    high_norm = torch.norm(high)
    if low_norm == 0 or high_norm == 0:
        # A zero vector has no direction, so the angle is undefined.
        return lerp(val, low, high)

    cos_omega = torch.dot(low.flatten() / low_norm, high.flatten() / high_norm)
    # Rounding can push the cosine just outside [-1, 1], where arccos is NaN.
    omega = torch.arccos(torch.clamp(cos_omega, -1.0, 1.0))
    so = torch.sin(omega)
    if so.abs() < eps:
        return lerp(val, low, high)
    return torch.sin((1.0 - val) * omega) / so * low + torch.sin(val * omega) / so * high


def make_imrange(arr: list, nrow: int = 11):
    """Tile a list of images into one grid and return it as an HWC array.

    Args:
        arr: Non-empty list of same-shaped ``(c, h, w)`` image tensors; they
            are stacked into one batch for make_grid.
        nrow: Maximum number of images per grid row (make_grid's ``nrow``).

    Returns:
        ``np.ndarray`` of shape ``(H, W, C)`` holding the whole grid.
    """
    grid = make_grid(torch.stack(arr), nrow=nrow)
    grid = rearrange(grid, "c h w -> h w c")
    # The tensor's device, not CUDA availability, decides whether a copy is
    # needed; .cpu() is a no-op for tensors already on the CPU.
    return grid.detach().cpu().numpy()


def get_imrange(
    G: Callable[[torch.Tensor], torch.Tensor],
    start: torch.Tensor,
    end: torch.Tensor,
    nums: int = 8,
    interpolation="spherical",
) -> np.ndarray:
    """Decode evenly spaced interpolations between two latents into an image grid.

    Args:
        G: Decoder mapping a batch of latents to a batch of images. Only the
            first output (``G(z)[0]``) is kept; it must be a 3-D ``(c, h, w)``
            tensor so the frames can be tiled together by make_imrange.
        start, end: Latent batches; only ``start[0]`` and ``end[0]`` are used.
        nums: Number of frames, including both endpoints.
        interpolation: ``"spherical"``, ``"linear"`` or a callable
            ``f(val, low, high)``; see get_interpolation.

    Returns:
        The grid built by make_imrange, as an ``(H, W, C)`` numpy array.
    """
    inter = get_interpolation(interpolation)
    frames = []
    for val in torch.linspace(0, 1, nums):
        # Re-add the batch dimension the decoder expects.
        new_z = torch.unsqueeze(inter(val, start[0], end[0]), 0)
        frames.append(G(new_z)[0])
    return make_imrange(frames)