Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
POSTECH-CVLab
GitHub Repository: POSTECH-CVLab/PyTorch-StudioGAN
Path: blob/master/src/utils/sefa.py
809 views
1
# PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN
2
# The MIT License (MIT)
3
# See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details
4
5
# src/utils/sefa.py
6
7
import torch
8
9
import utils.misc as misc
10
11
12
def apply_sefa(generator, backbone, z, fake_label, num_semantic_axis, maximum_variations, num_cols):
    """Render images that walk ``z`` along SeFa semantic directions.

    The directions are the right singular vectors of the generator's first
    linear layer weight ``W`` (``generator.linear0.weight``), i.e. the
    eigenvectors of ``W^T W``, taken in order of decreasing singular value.
    For each direction, a start latent (``z`` itself) and an end latent
    (``z + maximum_variations * direction``) are built and interpolated
    with ``misc.interpolate``. The whole batch is then rendered by the
    generator.

    Args:
        generator: Generator module. It is first passed through
            ``misc.peel_model`` (presumably to strip a parallel wrapper;
            TODO confirm).
        backbone: Backbone name. For ``"big_resnet"`` only the first
            ``generator.chunk_size`` entries of ``z`` are edited. The
            remaining entries are appended unchanged to every latent.
        z: A single 1-D latent vector. The broadcasting below requires it
            to be 1-D.
        fake_label: Label tensor, repeated once per generated latent.
        num_semantic_axis: Number of leading directions to traverse. If more
            are requested than singular vectors exist, only the available
            ones are used.
        maximum_variations: Scale applied to each (unit-norm) direction to
            obtain the end point of its row.
        num_cols: Passed to ``misc.interpolate`` as ``num_cols - 2``
            midpoints (presumably giving ``num_cols`` points per direction
            including both endpoints; TODO confirm against
            ``misc.interpolate``).

    Returns:
        The output of ``generator(zs_canvas, labels, eval=True)``, where
        ``zs_canvas`` is the flattened batch of interpolated latents.
    """
    generator = misc.peel_model(generator)
    # The directions are a fixed analysis of the weights. Detaching keeps
    # torch.svd from recording an autograd graph back to linear0.weight.
    w = generator.linear0.weight.detach()

    if backbone == "big_resnet":
        # Only the first chunk of z is edited; the remainder is re-attached below.
        zs = z
        z = torch.split(zs, generator.chunk_size, 0)[0]

    # Columns of V are orthonormal right singular vectors, ordered by
    # descending singular value.
    eigen_vectors = torch.svd(w).V.to(z.device)[:, :num_semantic_axis]
    # The slice can return fewer columns than requested. Size everything from
    # the real count so the start and end batches line up.
    num_axes = eigen_vectors.shape[1]

    z_dim = len(z)
    # (num_axes, 1, z_dim): the unmodified latent, once per direction.
    zs_start = z.repeat(num_axes).view(-1, 1, z_dim)
    # (z_dim, num_axes) -> transpose -> (num_axes, 1, z_dim): z shifted along
    # each direction. reshape (not view) because the transpose is non-contiguous.
    zs_end = (z.unsqueeze(1) + maximum_variations * eigen_vectors).T.reshape(-1, 1, z_dim)

    if backbone == "big_resnet":
        # Append the untouched tail of the original latent to both endpoints.
        zs_shard = zs[z_dim:].expand([1, 1, -1]).repeat(num_axes, 1, 1)
        zs_start = torch.cat([zs_start, zs_shard], dim=2)
        zs_end = torch.cat([zs_end, zs_shard], dim=2)

    zs_canvas = misc.interpolate(x0=zs_start, x1=zs_end, num_midpoints=num_cols - 2).view(-1, zs_start.shape[-1])
    images_canvas = generator(zs_canvas, fake_label.repeat(len(zs_canvas)), eval=True)
    return images_canvas
30
31