Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
POSTECH-CVLab
GitHub Repository: POSTECH-CVLab/PyTorch-StudioGAN
Path: blob/master/src/metrics/prdc.py
809 views
1
"""
2
Copyright (c) 2020-present NAVER Corp.
3
4
Permission is hereby granted, free of charge, to any person obtaining a copy
5
of this software and associated documentation files (the "Software"), to deal
6
in the Software without restriction, including without limitation the rights
7
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8
copies of the Software, and to permit persons to whom the Software is
9
furnished to do so, subject to the following conditions:
10
11
The above copyright notice and this permission notice shall be included in
12
all copies or substantial portions of the Software.
13
14
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20
THE SOFTWARE.
21
"""
22
23
### Reliable Fidelity and Diversity Metrics for Generative Models (https://arxiv.org/abs/2002.09797)
24
### Muhammad Ferjad Naeem, Seong Joon Oh, Youngjung Uh, Yunjey Choi, Jaejun Yoo
25
### https://github.com/clovaai/generative-evaluation-prdc
26
27
28
from tqdm import tqdm
29
import math
30
31
import torch
32
import numpy as np
33
import sklearn.metrics
34
35
import utils.sample as sample
36
import utils.losses as losses
37
38
# Names exported by a star import of this module.
__all__ = ["compute_prdc"]
39
40
41
def compute_real_embeddings(data_loader, batch_size, eval_model, quantize, world_size, DDP, disable_tqdm):
    """
    Embed the images served by ``data_loader`` with ``eval_model``.

    Args:
        data_loader: iterable yielding ``(images, labels)`` batches; it must
            expose a ``dataset`` attribute that supports ``len``.
        batch_size: batch size used to work out how many batches to draw.
        eval_model: object whose ``get_outputs(images, quantize=...)`` returns
            a pair; only the first element (the embeddings) is kept.
        quantize: forwarded unchanged to ``eval_model.get_outputs``.
        world_size: number of processes; only read when ``DDP`` is truthy.
        DDP: when truthy, the batch count is divided by ``world_size`` and the
            embeddings are gathered through ``losses.GatherLayer``.
        disable_tqdm: forwarded to ``tqdm`` to switch the progress bar off.
    Returns:
        numpy.ndarray (float64) of embeddings, truncated to at most
        ``len(data_loader.dataset)`` rows.
    """
    dataset_size = len(data_loader.dataset)
    # NOTE(review): under DDP one step presumably covers batch_size samples on
    # each of the world_size processes -- confirm against the sampler in use.
    samples_per_step = batch_size * world_size if DDP else batch_size
    num_batches = int(math.ceil(float(dataset_size) / float(samples_per_step)))

    data_iter = iter(data_loader)
    real_embeds = []
    for _ in tqdm(range(num_batches), disable=disable_tqdm):
        try:
            # Labels are not needed for the embeddings, so they are dropped
            # here instead of being copied to the GPU.
            real_images, _ = next(data_iter)
        except StopIteration:
            break

        real_images = real_images.to("cuda")

        with torch.no_grad():
            real_embeddings, _ = eval_model.get_outputs(real_images, quantize=quantize)
        real_embeds.append(real_embeddings)

    real_embeds = torch.cat(real_embeds, dim=0)
    if DDP:
        real_embeds = torch.cat(losses.GatherLayer.apply(real_embeds), dim=0)
    real_embeds = np.array(real_embeds.detach().cpu().numpy(), dtype=np.float64)
    # The last batch (or the DDP gather) can yield more rows than the dataset has.
    return real_embeds[:dataset_size]
63
64
65
def calculate_pr_dc(real_feats, fake_feats, data_loader, eval_model, num_generate, cfgs, quantize, nearest_k,
                    world_size, DDP, disable_tqdm):
    """
    Compute precision, recall, density and coverage of fake vs. real features.

    Args:
        real_feats: precomputed real features handed to ``compute_prdc`` as
            is, or ``None`` to compute them from ``data_loader``.
        fake_feats: tensor of generated features; converted to a float64
            numpy array and truncated to ``num_generate`` rows.
        data_loader: loader of real images, used only when ``real_feats`` is None.
        eval_model: feature extractor; switched to eval mode here.
        num_generate: maximum number of fake feature rows to use.
        cfgs: configuration object; ``cfgs.OPTIMIZATION.batch_size`` is read
            when real features have to be computed.
        quantize: forwarded to ``compute_real_embeddings``.
        nearest_k: int, neighbourhood size forwarded to ``compute_prdc``.
        world_size: forwarded to ``compute_real_embeddings``.
        DDP: forwarded to ``compute_real_embeddings``.
        disable_tqdm: forwarded to ``compute_real_embeddings``.
    Returns:
        Tuple ``(precision, recall, density, coverage)``.
    """
    eval_model.eval()

    if real_feats is None:
        real_embeds = compute_real_embeddings(data_loader=data_loader,
                                              batch_size=cfgs.OPTIMIZATION.batch_size,
                                              eval_model=eval_model,
                                              quantize=quantize,
                                              world_size=world_size,
                                              DDP=DDP,
                                              disable_tqdm=disable_tqdm)
    else:
        # Only reuse the caller's features when they were actually supplied;
        # an unconditional assignment here would discard the embeddings
        # computed above and pass None on to compute_prdc.
        real_embeds = real_feats
    fake_embeds = np.array(fake_feats.detach().cpu().numpy(), dtype=np.float64)[:num_generate]

    metrics = compute_prdc(real_features=real_embeds, fake_features=fake_embeds, nearest_k=nearest_k)

    prc, rec, dns, cvg = metrics["precision"], metrics["recall"], metrics["density"], metrics["coverage"]
    return prc, rec, dns, cvg
85
86
87
def compute_pairwise_distance(data_x, data_y=None):
    """
    Euclidean distances between the rows of ``data_x`` and the rows of ``data_y``.

    Args:
        data_x: numpy.ndarray([N, feature_dim])
        data_y: numpy.ndarray([M, feature_dim]), or None to compare
            ``data_x`` against itself.
    Returns:
        numpy.ndarray([N, M]) of pairwise distances ([N, N] when ``data_y``
        is None).
    """
    other = data_x if data_y is None else data_y
    return sklearn.metrics.pairwise_distances(data_x, other, metric='euclidean', n_jobs=8)
100
101
102
def get_kth_value(unsorted, k, axis=-1):
    """
    Return the k-th smallest value along ``axis``.

    Args:
        unsorted: numpy.ndarray (or array-like) of any dimensionality >= 1.
        k: int, 1-based rank; must satisfy 1 <= k <= size of ``axis``.
        axis: int, axis along which to rank the values.
    Returns:
        kth values along the designated axis (that axis is removed).
    Raises:
        ValueError: if ``k`` is outside ``[1, size of axis]``.
    """
    values = np.asarray(unsorted)
    axis_len = values.shape[axis]
    if not 1 <= k <= axis_len:
        raise ValueError(
            "k must be between 1 and {} (size of axis {}), got {}".format(axis_len, axis, k))
    # Partitioning at index k - 1 places the k-th smallest element at that
    # index along `axis`. This honours `axis` (slicing with [..., :k] only
    # ever touched the last axis) and allows k equal to the axis length.
    partitioned = np.partition(values, k - 1, axis=axis)
    return np.take(partitioned, k - 1, axis=axis)
114
115
116
def compute_nearest_neighbour_distances(input_features, nearest_k):
    """
    Radius of each sample's k-nearest-neighbour ball within ``input_features``.

    Args:
        input_features: numpy.ndarray([N, feature_dim])
        nearest_k: int
    Returns:
        Distances to kth nearest neighbours, one per input row.
    """
    self_distances = compute_pairwise_distance(input_features)
    # Rank nearest_k + 1 is used because each row of the self-distance matrix
    # also contains the sample's distance to itself.
    return get_kth_value(self_distances, k=nearest_k + 1, axis=-1)
127
128
129
def compute_prdc(real_features, fake_features, nearest_k):
    """
    Computes precision, recall, density, and coverage given two manifolds.

    Args:
        real_features: numpy.ndarray([N, feature_dim])
        fake_features: numpy.ndarray([M, feature_dim])
        nearest_k: int.
    Returns:
        dict with keys "precision", "recall", "density" and "coverage".
    """
    real_radii = compute_nearest_neighbour_distances(real_features, nearest_k)
    fake_radii = compute_nearest_neighbour_distances(fake_features, nearest_k)
    # Rows index real samples, columns index fake samples.
    real_to_fake = compute_pairwise_distance(real_features, fake_features)

    # [i, j] is True when fake j lies inside the k-NN ball of real i.
    inside_real_ball = real_to_fake < real_radii[:, np.newaxis]
    # [i, j] is True when real i lies inside the k-NN ball of fake j.
    inside_fake_ball = real_to_fake < fake_radii[np.newaxis, :]

    # Fraction of fake samples covered by at least one real ball.
    precision = inside_real_ball.any(axis=0).mean()
    # Fraction of real samples covered by at least one fake ball.
    recall = inside_fake_ball.any(axis=1).mean()
    # Mean number of real balls containing each fake sample, scaled by 1 / k.
    density = (1. / float(nearest_k)) * inside_real_ball.sum(axis=0).mean()
    # Fraction of real samples whose nearest fake sample falls inside their ball.
    coverage = (real_to_fake.min(axis=1) < real_radii).mean()

    return {
        "precision": precision,
        "recall": recall,
        "density": density,
        "coverage": coverage,
    }
169
170