Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
POSTECH-CVLab
GitHub Repository: POSTECH-CVLab/PyTorch-StudioGAN
Path: blob/master/src/evaluate.py
809 views
1
# PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN
2
# The MIT License (MIT)
3
# See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details
4
5
# src/evaluate.py
6
7
from argparse import ArgumentParser
8
import os
9
import random
10
11
from torch.utils.data import Dataset
12
from torch.utils.data import DataLoader
13
from torch.utils.data.distributed import DistributedSampler
14
from torchvision.transforms import InterpolationMode
15
from torchvision.datasets import ImageFolder
16
from torch.backends import cudnn
17
from PIL import Image
18
import torch
19
import torch.multiprocessing as mp
20
import torchvision.transforms as transforms
21
import numpy as np
22
import pickle
23
24
import utils.misc as misc
25
import metrics.preparation as pp
26
import metrics.features as features
27
import metrics.ins as ins
28
import metrics.fid as fid
29
import metrics.prdc as prdc
30
31
32
# Lookup table from user-facing resizer names to torchvision interpolation
# modes; every mode's attribute name is simply the upper-cased key.
resizer_collection = {name: getattr(InterpolationMode, name.upper())
                      for name in ("nearest", "box", "bilinear",
                                   "hamming", "bicubic", "lanczos")}
38
39
40
class CenterCropLongEdge(object):
    """Center-crop a PIL image to a square along its shorter edge.

    this code is borrowed from https://github.com/ajbrock/BigGAN-PyTorch
    MIT License
    Copyright (c) 2019 Andy Brock
    """

    def __call__(self, img):
        # The crop side equals the image's shorter dimension, so the long
        # edge is trimmed symmetrically around the center.
        crop_size = min(img.size)
        return transforms.functional.center_crop(img, crop_size)

    def __repr__(self):
        return self.__class__.__name__
51
52
53
class Dataset_(Dataset):
    """Thin ImageFolder wrapper yielding ``(uint8 image tensor, int label)`` pairs.

    Images are expected to be arranged in class-named subdirectories under
    ``data_dir`` (the standard ``torchvision.datasets.ImageFolder`` layout).
    """

    def __init__(self, data_dir):
        super(Dataset_, self).__init__()
        self.data_dir = data_dir
        # PILToTensor keeps pixel values as uint8 (no scaling to [0, 1]).
        self.trsf_list = [transforms.PILToTensor()]
        self.trsf = transforms.Compose(self.trsf_list)

        self.load_dataset()

    def load_dataset(self):
        # Delegate directory scanning and label assignment to ImageFolder.
        self.data = ImageFolder(root=self.data_dir)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        img, label = self.data[index]
        return self.trsf(img), int(label)
72
73
74
def prepare_evaluation():
    """Parse command-line arguments and derive the distributed topology.

    Returns:
        args: parsed argparse namespace.
        world_size: total number of GPUs across all nodes
            (``gpus_per_node * total_nodes``).
        gpus_per_node: number of CUDA devices visible on this node.
        rank: index of the current CUDA device.

    Raises:
        AssertionError: if the requested metrics cannot be computed from the
            supplied dataset/feature/moment paths.
    """
    parser = ArgumentParser(add_help=True)
    parser.add_argument("-metrics", "--eval_metrics", nargs='+', default=['fid'],
                        help="evaluation metrics to use during training, a subset list of ['fid', 'is', 'prdc'] or none")
    parser.add_argument("--post_resizer", type=str, default="legacy",
                        help="which resizer will you use to evaluate GANs in ['legacy', 'clean', 'friendly']")
    parser.add_argument('--eval_backbone', type=str, default='InceptionV3_tf',
                        help="[InceptionV3_tf, InceptionV3_torch, ResNet50_torch, SwAV_torch, DINO_torch, Swin-T_torch]")
    parser.add_argument("--dset1", type=str, default=None,
                        help="specify the directory of the folder that contains dset1 images (real).")
    parser.add_argument("--dset1_feats", type=str, default=None,
                        help="specify the path of *.npy that contains features of dset1 (real). "
                             "If not specified, StudioGAN will automatically extract feat1 using the whole dset1.")
    parser.add_argument("--dset1_moments", type=str, default=None,
                        help="specify the path of *.npy that contains moments (mu, sigma) of dset1 (real). "
                             "If not specified, StudioGAN will automatically extract moments using the whole dset1.")
    parser.add_argument("--dset2", type=str, default=None,
                        help="specify the directory of the folder that contains dset2 images (fake).")
    parser.add_argument("--batch_size", default=256, type=int, help="batch_size for evaluation")

    parser.add_argument("--seed", type=int, default=-1, help="seed for generating random numbers")
    parser.add_argument("-DDP", "--distributed_data_parallel", action="store_true")
    # Fixed: the original help string contained an invalid "\i" escape ("\in").
    parser.add_argument("--backend", type=str, default="nccl",
                        help="cuda backend for DDP training in ['nccl', 'gloo']")
    parser.add_argument("-tn", "--total_nodes", default=1, type=int, help="total number of nodes for training")
    parser.add_argument("-cn", "--current_node", default=0, type=int, help="rank of the current node")
    parser.add_argument("--num_workers", type=int, default=8)
    args = parser.parse_args()

    # Validate that each requested metric can actually be computed from the
    # given inputs ("is"/"is not None" replaces the original "=="/"!=" checks).
    if args.dset1_feats is None and args.dset1_moments is None:
        assert args.dset1 is not None, "dset1 should be specified!"
    if "fid" in args.eval_metrics:
        assert args.dset1 is not None or args.dset1_moments is not None, \
            "Either dset1 or dset1_moments should be given to compute FID."
    if "prdc" in args.eval_metrics:
        assert args.dset1 is not None or args.dset1_feats is not None, \
            "Either dset1 or dset1_feats should be given to compute PRDC."

    gpus_per_node, rank = torch.cuda.device_count(), torch.cuda.current_device()
    world_size = gpus_per_node * args.total_nodes
    if args.seed == -1:
        # Pick a random seed so repeated runs differ unless one is pinned.
        args.seed = random.randint(1, 4096)
    if world_size == 1:
        print("You have chosen a specific GPU. This will completely disable data parallelism.")
    return args, world_size, gpus_per_node, rank
110
111
112
def evaluate(local_rank, args, world_size, gpus_per_node):
    """Run the requested GAN evaluation metrics (IS / FID / PRDC) on one process.

    Args:
        local_rank: CUDA device index on this node (also DDP process index).
        args: parsed namespace from ``prepare_evaluation``.
        world_size: total number of participating GPUs across all nodes.
        gpus_per_node: number of GPUs on the current node.

    Only rank 0 prints results and populates ``metric_dict``.
    """
    # -----------------------------------------------------------------------------
    # determine cuda, cudnn, and backends settings.
    # -----------------------------------------------------------------------------
    # deterministic=True for reproducible evaluation (benchmark autotuning off).
    cudnn.benchmark, cudnn.deterministic = False, True

    # -----------------------------------------------------------------------------
    # initialize all processes and fix seed of each process
    # -----------------------------------------------------------------------------
    if args.distributed_data_parallel:
        # Global rank spans nodes: node index * GPUs-per-node + local index.
        global_rank = args.current_node * (gpus_per_node) + local_rank
        print("Use GPU: {global_rank} for training.".format(global_rank=global_rank))
        misc.setup(global_rank, world_size, args.backend)
        torch.cuda.set_device(local_rank)
    else:
        global_rank = local_rank

    # Offset the seed by rank so processes draw different random streams.
    misc.fix_seed(args.seed + global_rank)

    # -----------------------------------------------------------------------------
    # load dset1 and dset2.
    # -----------------------------------------------------------------------------
    # dset1 (real images) is only needed when FID moments / PRDC features were
    # not pre-computed and supplied via --dset1_moments / --dset1_feats.
    load_dset1 = ("fid" in args.eval_metrics and args.dset1_moments == None) or \
                 ("prdc" in args.eval_metrics and args.dset1_feats == None)
    if load_dset1:
        dset1 = Dataset_(data_dir=args.dset1)
        if local_rank == 0:
            print("Size of dset1: {dataset_size}".format(dataset_size=len(dset1)))

    dset2 = Dataset_(data_dir=args.dset2)
    if local_rank == 0:
        print("Size of dset2: {dataset_size}".format(dataset_size=len(dset2)))

    # -----------------------------------------------------------------------------
    # define a distributed sampler for DDP evaluation.
    # -----------------------------------------------------------------------------
    if args.distributed_data_parallel:
        # Per-process batch size; global effective batch stays args.batch_size.
        batch_size = args.batch_size//world_size
        if load_dset1:
            # NOTE(review): rank=local_rank (not global_rank) — with multiple
            # nodes this repeats shards across nodes; confirm intended for
            # single-node use.
            dset1_sampler = DistributedSampler(dset1,
                                               num_replicas=world_size,
                                               rank=local_rank,
                                               shuffle=False,
                                               drop_last=False)

        dset2_sampler = DistributedSampler(dset2,
                                           num_replicas=world_size,
                                           rank=local_rank,
                                           shuffle=False,
                                           drop_last=False)
    else:
        batch_size = args.batch_size
        dset1_sampler, dset2_sampler = None, None

    # -----------------------------------------------------------------------------
    # define dataloaders for dset1 and dset2.
    # -----------------------------------------------------------------------------
    if load_dset1:
        dset1_dataloader = DataLoader(dataset=dset1,
                                      batch_size=batch_size,
                                      shuffle=False,
                                      pin_memory=True,
                                      num_workers=args.num_workers,
                                      sampler=dset1_sampler,
                                      drop_last=False)

    dset2_dataloader = DataLoader(dataset=dset2,
                                  batch_size=batch_size,
                                  shuffle=False,
                                  pin_memory=True,
                                  num_workers=args.num_workers,
                                  sampler=dset2_sampler,
                                  drop_last=False)

    # -----------------------------------------------------------------------------
    # load a pre-trained network (InceptionV3 or ResNet50 trained using SwAV).
    # -----------------------------------------------------------------------------
    eval_model = pp.LoadEvalModel(eval_backbone=args.eval_backbone,
                                  post_resizer=args.post_resizer,
                                  world_size=world_size,
                                  distributed_data_parallel=args.distributed_data_parallel,
                                  device=local_rank)

    # -----------------------------------------------------------------------------
    # extract features, probabilities, and labels to calculate metrics.
    # -----------------------------------------------------------------------------
    if load_dset1:
        dset1_feats, dset1_probs, dset1_labels = features.sample_images_from_loader_and_stack_features(
            dataloader=dset1_dataloader,
            eval_model=eval_model,
            batch_size=batch_size,
            quantize=False,
            world_size=world_size,
            DDP=args.distributed_data_parallel,
            device=local_rank,
            disable_tqdm=local_rank != 0)  # progress bar only on rank 0

    dset2_feats, dset2_probs, dset2_labels = features.sample_images_from_loader_and_stack_features(
        dataloader=dset2_dataloader,
        eval_model=eval_model,
        batch_size=batch_size,
        quantize=False,
        world_size=world_size,
        DDP=args.distributed_data_parallel,
        device=local_rank,
        disable_tqdm=local_rank != 0)

    # -----------------------------------------------------------------------------
    # calculate metrics.
    # -----------------------------------------------------------------------------
    metric_dict = {}
    # --- Inception Score (plus top-1/top-5 accuracy placeholders) -------------
    if "is" in args.eval_metrics:
        num_splits = 1
        if load_dset1:
            dset1_kl_score, dset1_kl_std, dset1_top1, dset1_top5 = ins.eval_features(probs=dset1_probs,
                                                                                     labels=dset1_labels,
                                                                                     data_loader=dset1_dataloader,
                                                                                     num_features=len(dset1),
                                                                                     split=num_splits,
                                                                                     is_acc=False,
                                                                                     is_torch_backbone=True if "torch" in args.eval_backbone else False)

        dset2_kl_score, dset2_kl_std, dset2_top1, dset2_top5 = ins.eval_features(
            probs=dset2_probs,
            labels=dset2_labels,
            data_loader=dset2_dataloader,
            num_features=len(dset2),
            split=num_splits,
            is_acc=False,
            is_torch_backbone=True if "torch" in args.eval_backbone else False)
        if local_rank == 0:
            metric_dict.update({"IS": dset2_kl_score, "Top1_acc": dset2_top1, "Top5_acc": dset2_top5})
            if load_dset1:
                print("Inception score of dset1 ({num} images): {IS}".format(num=str(len(dset1)), IS=dset1_kl_score))
            print("Inception score of dset2 ({num} images): {IS}".format(num=str(len(dset2)), IS=dset2_kl_score))

    # --- Frechet Inception Distance -------------------------------------------
    if "fid" in args.eval_metrics:
        if args.dset1_moments is None:
            # Compute reference moments from extracted dset1 features
            # (truncated to len(dset1) to drop any DDP gather padding).
            mu1 = np.mean(dset1_feats.detach().cpu().numpy().astype(np.float64)[:len(dset1)], axis=0)
            sigma1 = np.cov(dset1_feats.detach().cpu().numpy().astype(np.float64)[:len(dset1)], rowvar=False)
        else:
            # Use pre-calculated reference moments from disk.
            mu1, sigma1 = np.load(args.dset1_moments)["mu"], np.load(args.dset1_moments)["sigma"]

        mu2 = np.mean(dset2_feats.detach().cpu().numpy().astype(np.float64)[:len(dset2)], axis=0)
        sigma2 = np.cov(dset2_feats.detach().cpu().numpy().astype(np.float64)[:len(dset2)], rowvar=False)

        fid_score = fid.frechet_inception_distance(mu1, sigma1, mu2, sigma2)
        if local_rank == 0:
            metric_dict.update({"FID": fid_score})
            if args.dset1_moments is None:
                print("FID between dset1 and dset2 (dset1: {num1} images, dset2: {num2} images): {fid}".\
                    format(num1=str(len(dset1)), num2=str(len(dset2)), fid=fid_score))
            else:
                print("FID between pre-calculated dset1 moments and dset2 (dset2: {num2} images): {fid}".\
                    format(num2=str(len(dset2)), fid=fid_score))

    # --- Precision / Recall / Density / Coverage -------------------------------
    if "prdc" in args.eval_metrics:
        # k for the k-nearest-neighbour manifolds used by PRDC.
        nearest_k = 5
        if args.dset1_feats is None:
            dset1_feats_np = np.array(dset1_feats.detach().cpu().numpy(), dtype=np.float64)[:len(dset1)]
            dset1_mode = "dset1"
        else:
            # Memory-map pre-extracted real features instead of recomputing.
            dset1_feats_np = np.load(args.dset1_feats, mmap_mode='r')["real_feats"]
            dset1_mode = "pre-calculated dset1_feats"
        dset2_feats_np = np.array(dset2_feats.detach().cpu().numpy(), dtype=np.float64)[:len(dset2)]
        metrics = prdc.compute_prdc(real_features=dset1_feats_np, fake_features=dset2_feats_np, nearest_k=nearest_k)
        prc, rec, dns, cvg = metrics["precision"], metrics["recall"], metrics["density"], metrics["coverage"]
        if local_rank == 0:
            metric_dict.update({"Improved_Precision": prc, "Improved_Recall": rec, "Density": dns, "Coverage": cvg})
            print("Improved Precision between {dset1_mode} (ref) and dset2 (target) ({dset1_mode}: {num1} images, dset2: {num2} images): {prc}".\
                format(dset1_mode=str(dset1_mode), num1=str(len(dset1_feats_np)), num2=str(len(dset2_feats_np)), prc=prc))
            print("Improved Recall between {dset1_mode} (ref) and dset2 (target) ({dset1_mode}: {num1} images, dset2: {num2} images): {rec}".\
                format(dset1_mode=str(dset1_mode), num1=str(len(dset1_feats_np)), num2=str(len(dset2_feats_np)), rec=rec))
            print("Density between {dset1_mode} (ref) and dset2 (target) ({dset1_mode}: {num1} images, dset2: {num2} images): {dns}".\
                format(dset1_mode=str(dset1_mode), num1=str(len(dset1_feats_np)), num2=str(len(dset2_feats_np)), dns=dns))
            print("Coverage between {dset1_mode} (ref) and dset2 (target) ({dset1_mode}: {num1} images, dset2: {num2} images): {cvg}".\
                format(dset1_mode=str(dset1_mode), num1=str(len(dset1_feats_np)), num2=str(len(dset2_feats_np)), cvg=cvg))
289
290
291
if __name__ == "__main__":
    args, world_size, gpus_per_node, rank = prepare_evaluation()

    if args.distributed_data_parallel and world_size > 1:
        # "spawn" is required so each worker re-initializes CUDA cleanly.
        mp.set_start_method("spawn", force=True)
        # Fixed: this script evaluates models; the old message said "Train".
        print("Evaluate the models through DistributedDataParallel (DDP) mode.")
        try:
            # Use the existing `mp` alias consistently (same function as
            # torch.multiprocessing.spawn). Each worker receives its local
            # rank as the first positional argument.
            mp.spawn(fn=evaluate,
                     args=(args,
                           world_size,
                           gpus_per_node),
                     nprocs=gpus_per_node)
        except KeyboardInterrupt:
            # Tear down the process group on Ctrl-C.
            misc.cleanup()
    else:
        # Single-process evaluation on the current device.
        evaluate(local_rank=rank,
                 args=args,
                 world_size=world_size,
                 gpus_per_node=gpus_per_node)
310
311