Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
probml
GitHub Repository: probml/pyprobml
Path: blob/master/deprecated/vae/utils/interpolation.py
1192 views
1
import torch
2
import numpy as np
3
import pandas as pd
4
from einops import rearrange
5
from typing import Callable
6
from torchvision.utils import make_grid
7
8
9
# Column names for the attribute table built by get_imgs_and_attr, in the
# order in which the attribute columns of a batch are labelled.
CELEBA_ATTRIBUTES = (
    "5_o_Clock_Shadow",
    "Arched_Eyebrows",
    "Attractive",
    "Bags_Under_Eyes",
    "Bald",
    "Bangs",
    "Big_Lips",
    "Big_Nose",
    "Black_Hair",
    "Blond_Hair",
    "Blurry",
    "Brown_Hair",
    "Bushy_Eyebrows",
    "Chubby",
    "Double_Chin",
    "Eyeglasses",
    "Goatee",
    "Gray_Hair",
    "Heavy_Makeup",
    "High_Cheekbones",
    "Male",
    "Mouth_Slightly_Open",
    "Mustache",
    "Narrow_Eyes",
    "No_Beard",
    "Oval_Face",
    "Pale_Skin",
    "Pointy_Nose",
    "Receding_Hairline",
    "Rosy_Cheeks",
    "Sideburns",
    "Smiling",
    "Straight_Hair",
    "Wavy_Hair",
    "Wearing_Earrings",
    "Wearing_Hat",
    "Wearing_Lipstick",
    "Wearing_Necklace",
    "Wearing_Necktie",
    "Young",
)


def get_imgs_and_attr(batch):
    """Split a batch into its images and a labelled attribute DataFrame.

    Args:
        batch: an ``(imgs, attr)`` pair. ``attr`` must be a 2-D tensor with one
            column per name in ``CELEBA_ATTRIBUTES`` (40 columns); it may live on
            any device and may track gradients.

    Returns:
        ``(imgs, df)``: ``imgs`` unchanged, and ``df`` a DataFrame holding the
        attribute values with one named column per attribute and the default
        integer row index (row ``i`` corresponds to ``imgs[i]``).
    """
    imgs, attr = batch
    # Tensor.numpy() refuses tensors that require grad or are not on the CPU;
    # detach()/cpu() are no-ops for plain CPU tensors, so this is always safe.
    values = attr.detach().cpu().numpy()
    df = pd.DataFrame(values, columns=list(CELEBA_ATTRIBUTES))
    return imgs, df
57
58
59
def vector_of_interest(vae, batch, feature_of_interest="Male"):
    """Estimate a latent direction for one binary attribute from a single batch.

    The batch is split into images whose ``feature_of_interest`` value is 1
    ("present") and those whose value is 0 ("absent"); rows with any other
    value belong to neither group. Both groups are encoded with
    ``vae.det_encode`` and the direction is the difference of the two group
    means.

    Args:
        vae: model exposing ``det_encode(imgs)``; its output is averaged over
            axis 0 (presumably one latent code per image — confirm).
        batch: ``(imgs, attr)`` pair, as accepted by ``get_imgs_and_attr``.
        feature_of_interest: attribute column name to split on.

    Returns:
        ``(label_vector, present, absent)`` where ``label_vector`` is
        ``mean(det_encode(present)) - mean(det_encode(absent))`` and
        ``present`` / ``absent`` are the selected image subsets.

    Raises:
        ValueError: if either group is empty. The mean over an empty group
            would otherwise silently produce a NaN direction.
    """
    imgs, attr = get_imgs_and_attr(batch)
    row_ids = np.array(attr.index)
    absent_ids = row_ids[attr[feature_of_interest] == 0]
    present_ids = row_ids[attr[feature_of_interest] == 1]
    if len(present_ids) == 0 or len(absent_ids) == 0:
        raise ValueError(
            f"feature {feature_of_interest!r} needs at least one present and one "
            f"absent example in the batch (got {len(present_ids)} present, "
            f"{len(absent_ids)} absent)"
        )
    present = imgs[present_ids]
    absent = imgs[absent_ids]
    z_present = vae.det_encode(present).mean(axis=0)
    z_absent = vae.det_encode(absent).mean(axis=0)
    label_vector = z_present - z_absent
    return label_vector, present, absent
70
71
72
def get_interpolation(interpolation):
    """Resolve an interpolation spec into a function ``f(val, low, high)``.

    Args:
        interpolation: the string ``"spherical"`` (selects ``slerp``), the
            string ``"linear"`` (selects ``lerp``), or any callable, which is
            returned unchanged.

    Returns:
        The interpolation function.

    Raises:
        ValueError: if ``interpolation`` is neither a known name nor callable.
    """
    if callable(interpolation):
        return interpolation
    if interpolation == "spherical":
        return slerp
    if interpolation == "linear":
        return lerp
    # Fail here with a clear message rather than returning None, which would
    # only blow up later as "'NoneType' object is not callable".
    raise ValueError(
        f"unknown interpolation {interpolation!r}; "
        "expected 'spherical', 'linear' or a callable"
    )
82
83
84
def lerp(val, low, high):
    """Linearly interpolate from ``low`` (``val == 0``) to ``high`` (``val == 1``).

    Evaluates ``low + (high - low) * val``. ``val`` is not clamped, so values
    outside ``[0, 1]`` extrapolate along the same line. Works with any operands
    that support ``+``, ``-`` and ``*`` (numbers, tensors, arrays).
    """
    span = high - low
    return low + span * val
87
88
89
def slerp(val, low, high):
    """Spherical interpolation between two 1-D vectors. ``val`` ranges from 0 to 1.

    Args:
        val: interpolation weight; ``val <= 0`` returns ``low`` and
            ``val >= 1`` returns ``high`` (the same objects, not copies).
        low, high: 1-D tensors (``torch.dot`` requires 1-D inputs).

    Returns:
        ``sin((1 - val) * omega) / sin(omega) * low
        + sin(val * omega) / sin(omega) * high``, where ``omega`` is the angle
        between ``low`` and ``high``. If ``low`` and ``high`` are allclose,
        ``low`` is returned. When ``omega`` is degenerate (parallel or
        anti-parallel directions, or a zero-length input), ``sin(omega)`` is
        about 0 and the formula is undefined. In that case the linear
        interpolation ``low + (high - low) * val`` is returned instead.
    """
    if val <= 0:
        return low
    elif val >= 1:
        return high
    elif torch.allclose(low, high):
        return low
    cos_omega = torch.dot(low / torch.norm(low), high / torch.norm(high))
    # Rounding can push the cosine of (anti-)parallel vectors just outside
    # [-1, 1], where arccos returns NaN.
    omega = torch.arccos(torch.clamp(cos_omega, -1.0, 1.0))
    so = torch.sin(omega)
    # sin(omega) ~ 0 (or NaN from normalizing a zero vector) would make the
    # weights below divide by zero; use linear interpolation instead.
    eps = 1e-6
    if not torch.isfinite(so) or torch.abs(so) < eps:
        return low + (high - low) * val
    return torch.sin((1.0 - val) * omega) / so * low + torch.sin(val * omega) / so * high
100
101
102
def make_imrange(arr: list, nrow: int = 11):
    """Tile a list of equally-shaped image tensors into one channels-last NumPy grid.

    Args:
        arr: image tensors of identical shape; they are stacked with
            ``torch.stack`` and passed to ``make_grid`` (presumably each is
            ``(C, H, W)`` — confirm against callers).
        nrow: number of images per grid row, forwarded to ``make_grid``.
            The default of 11 matches the previously hard-coded value.

    Returns:
        The grid as a NumPy array of shape ``(H, W, C)``.
    """
    interpolation = torch.stack(arr)
    grid = make_grid(interpolation, nrow=nrow)
    imgs = rearrange(grid, "c h w -> h w c")
    # .cpu() is a no-op for CPU tensors, so move unconditionally. Gating it on
    # torch.cuda.is_available() broke tensors on other accelerators (e.g. MPS).
    return imgs.detach().cpu().numpy()
107
108
109
def get_imrange(
    G: Callable[[torch.Tensor], torch.Tensor],
    start: torch.Tensor,
    end: torch.Tensor,
    nums: int = 8,
    interpolation="spherical",
) -> np.ndarray:
    """Decode latents interpolated between ``start`` and ``end`` into one image grid.

    The decoder must produce a 3-D image per latent so that the frames can be
    stacked together into a grid.

    Args:
        G: decoder called on a batch holding one latent code; only the first
            element of its output (``G(z)[0]``) is kept.
        start, end: latent tensors; only their first entries (``start[0]`` and
            ``end[0]``) are interpolated.
        nums: number of evenly spaced interpolation weights in ``[0, 1]``,
            both endpoints included.
        interpolation: ``"spherical"``, ``"linear"`` or a callable
            ``f(val, low, high)``, resolved via ``get_interpolation``.

    Returns:
        The channels-last NumPy image grid built by ``make_imrange``.
    """
    inter = get_interpolation(interpolation)
    frames = []
    for val in torch.linspace(0, 1, nums):
        # Indexing start[0]/end[0] drops the batch dimension; add it back for G.
        new_z = torch.unsqueeze(inter(val, start[0], end[0]), 0)
        frames.append(G(new_z)[0])
    return make_imrange(frames)
126
127