Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
POSTECH-CVLab
GitHub Repository: POSTECH-CVLab/PyTorch-StudioGAN
Path: blob/master/src/metrics/fid.py
809 views
1
#!/usr/bin/env python3
2
"""
3
Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead of Tensorflow
4
Copyright 2018 Institute of Bioinformatics, JKU Linz
5
Licensed under the Apache License, Version 2.0 (the "License");
6
you may not use this file except in compliance with the License.
7
8
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
9
10
Unless required by applicable law or agreed to in writing, software
11
distributed under the License is distributed on an "AS IS" BASIS,
12
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
See the License for the specific language governing permissions and
14
limitations under the License.
15
"""
16
17
from os.path import dirname, abspath, exists, join
18
import math
19
import os
20
import shutil
21
22
from torch.nn import DataParallel
23
from torch.nn.parallel import DistributedDataParallel
24
from torchvision.utils import save_image
25
from scipy import linalg
26
from tqdm import tqdm
27
import torch
28
import numpy as np
29
30
import utils.sample as sample
31
import utils.losses as losses
32
33
34
def frechet_inception_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Compute the Frechet distance between two Gaussians.

    The returned value is
    ``||mu1 - mu2||^2 + Tr(sigma1) + Tr(sigma2) - 2 * Tr(sqrt(sigma1 . sigma2))``.

    Args:
        mu1: Mean of the first distribution (scalar or 1-D array-like).
        sigma1: Covariance of the first distribution (scalar or 2-D array-like).
        mu2: Mean of the second distribution; must match the shape of ``mu1``.
        sigma2: Covariance of the second distribution; must match the shape
            of ``sigma1``.
        eps: Value added to the diagonal of both covariances before retrying
            the matrix square root when the first attempt yields non-finite
            entries.

    Returns:
        The Frechet distance as a scalar.

    Raises:
        AssertionError: If the means or the covariances differ in shape.
        ValueError: If the diagonal of the matrix square root has an
            imaginary part that is not within 1e-3 of zero.
    """
    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)

    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)

    assert mu1.shape == mu2.shape, \
        "Training and test mean vectors have different lengths."
    assert sigma1.shape == sigma2.shape, \
        "Training and test covariances have different dimensions."

    diff = mu1 - mu2

    # Product might be almost singular.
    # sqrtm is called without the `disp` keyword so that it returns only the
    # matrix; the error estimate that `disp=False` additionally produced was
    # never used here, and the retry below already relied on this call form.
    covmean = linalg.sqrtm(sigma1.dot(sigma2))
    if not np.isfinite(covmean).all():
        # Nudge both covariances away from singularity and try once more.
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # Numerical error might give slight imaginary component
    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            m = np.max(np.abs(covmean.imag))
            raise ValueError("Imaginary component {}".format(m))
        covmean = covmean.real

    tr_covmean = np.trace(covmean)
    return (diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean)
63
64
65
def calculate_moments(data_loader, eval_model, num_generate, batch_size, quantize, world_size,
                      DDP, disable_tqdm, fake_feats=None):
    """Return the mean vector and covariance matrix of a set of features.

    The features are either taken from ``fake_feats`` or extracted by running
    ``eval_model`` over ``data_loader``.

    Args:
        data_loader: Iterable yielding ``(images, labels)`` pairs and exposing
            ``.dataset`` with a length. Ignored when ``fake_feats`` is given.
        eval_model: Object providing ``eval()`` and
            ``get_outputs(images, quantize=...)`` returning
            ``(embeddings, logits)``. Ignored when ``fake_feats`` is given.
        num_generate: Number of leading rows of ``fake_feats`` to use. Only
            read when ``fake_feats`` is given.
        batch_size: Batch size used to derive the number of loader steps.
        quantize: Forwarded to ``eval_model.get_outputs``.
        world_size: Number of processes; only used when ``DDP`` is true.
        DDP: Whether features are gathered through ``losses.GatherLayer``
            (presumably across distributed ranks -- confirm against that helper).
        disable_tqdm: Forwarded to ``tqdm`` to silence the progress bar.
        fake_feats: Optional tensor of precomputed features, one row per sample.

    Returns:
        Tuple ``(mu, sigma)``: the mean over axis 0 and the covariance computed
        with ``np.cov(..., rowvar=False)`` of the float64 feature matrix.
    """
    if fake_feats is not None:
        # Cast to float64 like the data-loader branch below so that the mean
        # is accumulated with the same precision on both paths.
        acts = fake_feats.detach().cpu().numpy()[:num_generate].astype(np.float64)
    else:
        eval_model.eval()
        total_instance = len(data_loader.dataset)
        data_iter = iter(data_loader)
        # With DDP the per-step sample count is scaled by world_size.
        samples_per_step = batch_size * world_size if DDP else batch_size
        num_batches = int(math.ceil(float(total_instance) / float(samples_per_step)))

        acts = []
        for _ in tqdm(range(0, num_batches), disable=disable_tqdm):
            try:
                images, _labels = next(data_iter)
            except StopIteration:
                # The loader ran out before the computed number of steps.
                break

            # Labels are not needed for feature extraction, so only the
            # images are moved to the GPU.
            images = images.to("cuda")

            with torch.no_grad():
                embeddings, _logits = eval_model.get_outputs(images, quantize=quantize)
                acts.append(embeddings)

        acts = torch.cat(acts, dim=0)
        if DDP:
            acts = torch.cat(losses.GatherLayer.apply(acts), dim=0)
        # Drop any surplus rows beyond the dataset size before computing statistics.
        acts = acts.detach().cpu().numpy()[:total_instance].astype(np.float64)

    mu = np.mean(acts, axis=0)
    sigma = np.cov(acts, rowvar=False)
    return mu, sigma
99
100
101
def calculate_fid(data_loader,
                  eval_model,
                  num_generate,
                  cfgs,
                  pre_cal_mean=None,
                  pre_cal_std=None,
                  quantize=True,
                  fake_feats=None,
                  disable_tqdm=False):
    """Compute the FID between reference statistics and generated features.

    Reference moments come from ``pre_cal_mean`` / ``pre_cal_std`` when both
    are supplied; otherwise they are computed from ``data_loader`` via
    ``calculate_moments``. Moments for the generated side are always computed
    by ``calculate_moments`` with ``fake_feats`` and ``num_generate``.

    Args:
        data_loader: Loader used for the reference moments when no
            precomputed statistics are given.
        eval_model: Evaluation model; switched to eval mode on entry.
        num_generate: Number of generated samples to take into account.
        cfgs: Configuration object providing ``OPTIMIZATION.batch_size``,
            ``OPTIMIZATION.world_size`` and ``RUN.distributed_data_parallel``.
        pre_cal_mean: Optional precomputed reference mean.
        pre_cal_std: Optional precomputed reference statistic, passed on as
            the covariance argument of the distance computation.
        quantize: Forwarded to ``calculate_moments``.
        fake_feats: Features of generated samples, forwarded to
            ``calculate_moments``.
        disable_tqdm: Forwarded to ``calculate_moments``.

    Returns:
        Tuple ``(fid_value, m1, s1)`` where ``m1`` and ``s1`` are the
        reference moments that were used.
    """
    eval_model.eval()

    # Arguments that are identical for the reference and the generated pass.
    common = dict(eval_model=eval_model,
                  batch_size=cfgs.OPTIMIZATION.batch_size,
                  quantize=quantize,
                  world_size=cfgs.OPTIMIZATION.world_size,
                  DDP=cfgs.RUN.distributed_data_parallel,
                  disable_tqdm=disable_tqdm)

    if pre_cal_mean is None or pre_cal_std is None:
        # No complete set of cached statistics: derive them from the loader.
        ref_mu, ref_sigma = calculate_moments(data_loader=data_loader,
                                              num_generate="N/A",
                                              fake_feats=None,
                                              **common)
    else:
        ref_mu, ref_sigma = pre_cal_mean, pre_cal_std

    gen_mu, gen_sigma = calculate_moments(data_loader="N/A",
                                          num_generate=num_generate,
                                          fake_feats=fake_feats,
                                          **common)

    score = frechet_inception_distance(ref_mu, ref_sigma, gen_mu, gen_sigma)
    return score, ref_mu, ref_sigma
137
138