# PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN
# The MIT License (MIT)
# See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details

# src/metrics/features.py
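# Helpers that run generated or real images through an evaluation model
# (eval_model.get_outputs -> (features, logits)) and stack the resulting features,
# softmax probabilities, and labels, optionally gathered across DDP ranks.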

import math

from tqdm import tqdm
import torch
import numpy as np

import utils.sample as sample
import utils.losses as losses


def generate_images_and_stack_features(generator, discriminator, eval_model, num_generate, y_sampler, batch_size, z_prior,
                                       truncation_factor, z_dim, num_classes, LOSS, RUN, MODEL, is_stylegan, generator_mapping,
                                       generator_synthesis, quantize, world_size, DDP, device, logger, disable_tqdm):
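    """Generate `num_generate` fake images in batches of `batch_size`, run them through
    `eval_model`, and stack the resulting features and softmax probabilities.
    Under DDP the per-rank results are gathered across `world_size` processes.
    Returns features and probs as tensors and the sampled fake labels as a list."""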
    eval_model.eval()
    feature_holder, prob_holder, fake_label_holder = [], [], []

    if device == 0 and not disable_tqdm:
        logger.info("generate images and stack features ({} images).".format(num_generate))
    num_batches = int(math.ceil(float(num_generate) / float(batch_size)))
    if DDP: num_batches = num_batches//world_size + 1
    for i in tqdm(range(num_batches), disable=disable_tqdm):
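        # sample.generate_images also returns several auxiliary outputs (discarded here);
        # only the fake images and their sampled labels are used for feature stacking.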
        fake_images, fake_labels, _, _, _, _, _ = sample.generate_images(z_prior=z_prior,
                                                                         truncation_factor=truncation_factor,
                                                                         batch_size=batch_size,
                                                                         z_dim=z_dim,
                                                                         num_classes=num_classes,
                                                                         y_sampler=y_sampler,
                                                                         radius="N/A",
                                                                         generator=generator,
                                                                         discriminator=discriminator,
                                                                         is_train=False,
                                                                         LOSS=LOSS,
                                                                         RUN=RUN,
                                                                         MODEL=MODEL,
                                                                         is_stylegan=is_stylegan,
                                                                         generator_mapping=generator_mapping,
                                                                         generator_synthesis=generator_synthesis,
                                                                         style_mixing_p=0.0,
                                                                         device=device,
                                                                         stylegan_update_emas=False,
                                                                         cal_trsp_cost=False)

        with torch.no_grad():
            features, logits = eval_model.get_outputs(fake_images, quantize=quantize)
            probs = torch.nn.functional.softmax(logits, dim=1)

        feature_holder.append(features)
        prob_holder.append(probs)
        fake_label_holder.append(fake_labels)

    feature_holder = torch.cat(feature_holder, 0)
    prob_holder = torch.cat(prob_holder, 0)
    fake_label_holder = torch.cat(fake_label_holder, 0)

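    # Under DDP, GatherLayer collects each rank's tensors so that the concatenated
    # holders cover the images generated on every process.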
    if DDP:
        feature_holder = torch.cat(losses.GatherLayer.apply(feature_holder), dim=0)
        prob_holder = torch.cat(losses.GatherLayer.apply(prob_holder), dim=0)
        fake_label_holder = torch.cat(losses.GatherLayer.apply(fake_label_holder), dim=0)
    return feature_holder, prob_holder, list(fake_label_holder.detach().cpu().numpy())


def sample_images_from_loader_and_stack_features(dataloader, eval_model, batch_size, quantize,
                                                 world_size, DDP, device, disable_tqdm):
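    """Draw batches from `dataloader`, run them through `eval_model`, and stack the
    resulting features, softmax probabilities, and labels, gathering across DDP
    ranks when enabled. Returns features and probs as tensors and labels as a list."""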
    eval_model.eval()
    total_instance = len(dataloader.dataset)
    num_batches = math.ceil(float(total_instance) / float(batch_size))
    if DDP: num_batches = int(math.ceil(float(total_instance) / float(batch_size*world_size)))
    data_iter = iter(dataloader)

    if device == 0 and not disable_tqdm:
        print("Sample images and stack features ({} images).".format(total_instance))

    feature_holder, prob_holder, label_holder = [], [], []
    for i in tqdm(range(0, num_batches), disable=disable_tqdm):
        try:
            images, labels = next(data_iter)
        except StopIteration:
            break

        images, labels = images.to(device), labels.to(device)

        with torch.no_grad():
            features, logits = eval_model.get_outputs(images, quantize=quantize)
            probs = torch.nn.functional.softmax(logits, dim=1)

        feature_holder.append(features)
        prob_holder.append(probs)
        label_holder.append(labels.to("cuda"))

    feature_holder = torch.cat(feature_holder, 0)
    prob_holder = torch.cat(prob_holder, 0)
    label_holder = torch.cat(label_holder, 0)

    if DDP:
        feature_holder = torch.cat(losses.GatherLayer.apply(feature_holder), dim=0)
        prob_holder = torch.cat(losses.GatherLayer.apply(prob_holder), dim=0)
        label_holder = torch.cat(losses.GatherLayer.apply(label_holder), dim=0)
    return feature_holder, prob_holder, list(label_holder.detach().cpu().numpy())


def stack_features(data_loader, eval_model, num_feats, batch_size, quantize, world_size, DDP, device, disable_tqdm):
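    """Stack approximately `num_feats` real features: draw batches from `data_loader`,
    run them through `eval_model`, gather across DDP ranks if enabled, and return
    features, softmax probabilities, and labels as numpy arrays (features and probs
    as float64)."""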
    eval_model.eval()
    data_iter = iter(data_loader)
    num_batches = math.ceil(float(num_feats) / float(batch_size))
    if DDP: num_batches = num_batches//world_size + 1

    real_feats, real_probs, real_labels = [], [], []
    for i in tqdm(range(0, num_batches), disable=disable_tqdm):
        start = i * batch_size
        end = start + batch_size
        try:
            images, labels = next(data_iter)
        except StopIteration:
            break

        images, labels = images.to(device), labels.to(device)

        with torch.no_grad():
            embeddings, logits = eval_model.get_outputs(images, quantize=quantize)
            probs = torch.nn.functional.softmax(logits, dim=1)
            real_feats.append(embeddings)
            real_probs.append(probs)
            real_labels.append(labels)

    real_feats = torch.cat(real_feats, dim=0)
    real_probs = torch.cat(real_probs, dim=0)
    real_labels = torch.cat(real_labels, dim=0)
    if DDP:
        real_feats = torch.cat(losses.GatherLayer.apply(real_feats), dim=0)
        real_probs = torch.cat(losses.GatherLayer.apply(real_probs), dim=0)
        real_labels = torch.cat(losses.GatherLayer.apply(real_labels), dim=0)

    real_feats = real_feats.detach().cpu().numpy().astype(np.float64)
    real_probs = real_probs.detach().cpu().numpy().astype(np.float64)
    real_labels = real_labels.detach().cpu().numpy()
    return real_feats, real_probs, real_labels
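
# A minimal usage sketch (illustrative only): `train_loader` and `inception_model`
# are hypothetical stand-ins, and `eval_model` is assumed to expose
# get_outputs(images, quantize=...) returning (features, logits) as used above.
#
#     real_feats, real_probs, real_labels = stack_features(
#         data_loader=train_loader, eval_model=inception_model, num_feats=10000,
#         batch_size=50, quantize=True, world_size=1, DDP=False, device=0,
#         disable_tqdm=False)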