Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
POSTECH-CVLab
GitHub Repository: POSTECH-CVLab/PyTorch-StudioGAN
Path: blob/master/src/metrics/preparation.py
809 views
1
# PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN
2
# The MIT License (MIT)
3
# See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details
4
5
# src/metrics/preparation.py
6
7
from os.path import exists, join
8
import os
9
10
try:
11
from torchvision.models.utils import load_state_dict_from_url
12
except ImportError:
13
from torch.utils.model_zoo import load_url as load_state_dict_from_url
14
from torch.nn import DataParallel
15
from torch.nn.parallel import DistributedDataParallel as DDP
16
from PIL import Image
17
import torch
18
import torch.nn as nn
19
import torch.nn.functional as F
20
import torchvision.transforms as transforms
21
import numpy as np
22
23
from metrics.inception_net import InceptionV3
24
from metrics.swin_transformer import SwinTransformer
25
import metrics.features as features
26
import metrics.vit as vits
27
import metrics.fid as fid
28
import metrics.ins as ins
29
import utils.misc as misc
30
import utils.ops as ops
31
import utils.resize as resize
32
33
# torch.hub repository tag providing pretrained weights for each torch backbone.
# NOTE: the key must match the backbone names checked in LoadEvalModel
# ("InceptionV3_torch", "ResNet50_torch", "SwAV_torch"); the previous
# "ResNet_torch" key caused a KeyError when loading the ResNet50 backbone.
model_versions = {"InceptionV3_torch": "pytorch/vision:v0.10.0",
                  "ResNet50_torch": "pytorch/vision:v0.10.0",
                  "SwAV_torch": "facebookresearch/swav:main"}
# torch.hub model entry-point name for each torch backbone.
model_names = {"InceptionV3_torch": "inception_v3",
               "ResNet50_torch": "resnet50",
               "SwAV_torch": "resnet50"}
# Pretrained linear classification head for the SwAV ResNet50 backbone.
SWAV_CLASSIFIER_URL = "https://dl.fbaipublicfiles.com/deepcluster/swav_800ep_eval_linear.pth.tar"
# Swin-B checkpoint (ImageNet-22k pretrained, fine-tuned to 1k classes).
SWIN_URL = "https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth"
41
42
43
class LoadEvalModel(object):
    """Wrapper that loads a pretrained backbone for GAN evaluation metrics.

    Supported backbones: "InceptionV3_tf", "InceptionV3_torch",
    "ResNet50_torch", "SwAV_torch", "DINO_torch", and "Swin-T_torch".
    The wrapper also builds the image resizer/normalizer matching the
    backbone and, when world_size > 1, wraps the model in DDP or
    DataParallel.
    """
    def __init__(self, eval_backbone, post_resizer, world_size, distributed_data_parallel, device):
        # eval_backbone: one of the backbone names listed in the class docstring.
        # post_resizer: name of the resizing scheme passed to resize.build_resizer.
        # world_size: number of participating processes/GPUs.
        # distributed_data_parallel: True -> DDP wrapping, False -> DataParallel (when world_size > 1).
        # device: torch device (or rank index) the model is moved to.
        super(LoadEvalModel, self).__init__()
        self.eval_backbone = eval_backbone
        self.post_resizer = post_resizer
        self.device = device
        # Forward-pre-hook sink; for the torch.hub backbones it captures the
        # input of the final "fc" layer (i.e. penultimate features), which
        # get_outputs() later reads back as the representation.
        self.save_output = misc.SaveOutput()

        if self.eval_backbone == "InceptionV3_tf":
            # TF-style InceptionV3 expects 299x299 inputs normalized to [-1, 1].
            self.res, mean, std = 299, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
            self.model = InceptionV3(resize_input=False, normalize_input=False).to(self.device)
        elif self.eval_backbone in ["InceptionV3_torch", "ResNet50_torch", "SwAV_torch"]:
            # torchvision-style backbones use ImageNet mean/std normalization.
            self.res = 299 if "InceptionV3" in self.eval_backbone else 224
            mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
            self.model = torch.hub.load(model_versions[self.eval_backbone],
                                        model_names[self.eval_backbone],
                                        pretrained=True)
            if self.eval_backbone == "SwAV_torch":
                # SwAV's hub checkpoint lacks a classifier; load the separately
                # released linear head and strip its "module.linear." prefix.
                linear_state_dict = load_state_dict_from_url(SWAV_CLASSIFIER_URL, progress=True)["state_dict"]
                linear_state_dict = {k.replace("module.linear.", ""): v for k, v in linear_state_dict.items()}
                self.model.fc.load_state_dict(linear_state_dict, strict=True)
            self.model = self.model.to(self.device)
            # Hook the final fully-connected layer so its *input* (the
            # penultimate feature vector) is recorded on every forward pass.
            hook_handles = []
            for name, layer in self.model.named_children():
                if name == "fc":
                    handle = layer.register_forward_pre_hook(self.save_output)
                    hook_handles.append(handle)
        elif self.eval_backbone == "DINO_torch":
            self.res, mean, std = 224, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
            # ViT-S/8 that returns features from the last 4 blocks plus a
            # 1000-way linear head loaded below.
            self.model = vits.__dict__["vit_small"](patch_size=8, num_classes=1000, num_last_blocks=4)
            misc.load_pretrained_weights(self.model, "", "teacher", "vit_small", 8)
            misc.load_pretrained_linear_weights(self.model.linear, "vit_small", 8)
            self.model = self.model.to(self.device)
        elif self.eval_backbone == "Swin-T_torch":
            self.res, mean, std = 224, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
            self.model = SwinTransformer()
            model_state_dict = load_state_dict_from_url(SWIN_URL, progress=True)["model"]
            self.model.load_state_dict(model_state_dict, strict=True)
            self.model = self.model.to(self.device)
        else:
            raise NotImplementedError

        # Backbone-specific resizer plus ToTensor; mean/std are kept as
        # (1, 3, 1, 1) tensors for broadcast normalization on self.device.
        self.resizer = resize.build_resizer(resizer=self.post_resizer, backbone=self.eval_backbone, size=self.res)
        self.totensor = transforms.ToTensor()
        self.mean = torch.Tensor(mean).view(1, 3, 1, 1).to(self.device)
        self.std = torch.Tensor(std).view(1, 3, 1, 1).to(self.device)

        if world_size > 1 and distributed_data_parallel:
            misc.make_model_require_grad(self.model)
            self.model = DDP(self.model,
                             device_ids=[self.device],
                             # Swin buffers (e.g. relative-position tables) are
                             # not broadcast; other backbones keep the default.
                             broadcast_buffers=False if self.eval_backbone=="Swin-T_torch" else True)
        elif world_size > 1 and distributed_data_parallel is False:
            self.model = DataParallel(self.model, output_device=self.device)
        else:
            # Single-process / single-GPU: use the bare model.
            pass

    def eval(self):
        """Switch the wrapped backbone to evaluation mode."""
        self.model.eval()

    def get_outputs(self, x, quantize=False):
        """Return (representations, logits) for a batch of images.

        x: batch of images; when quantize is False it is assumed to already be
        in a uint8-compatible range (cast via numpy) — confirm against callers.
        """
        if quantize:
            x = ops.quantize_images(x)
        else:
            x = x.detach().cpu().numpy().astype(np.uint8)
        # Resize + normalize with the backbone-specific pipeline built in __init__.
        x = ops.resize_images(x, self.resizer, self.totensor, self.mean, self.std, device=self.device)

        if self.eval_backbone in ["InceptionV3_tf", "DINO_torch", "Swin-T_torch"]:
            # These models return both features and logits directly.
            repres, logits = self.model(x)
        elif self.eval_backbone in ["InceptionV3_torch", "ResNet50_torch", "SwAV_torch"]:
            logits = self.model(x)
            # Representations come from the fc forward-pre-hook; multiple
            # entries presumably correspond to DataParallel replicas — the
            # per-replica feature chunks are concatenated back together.
            if len(self.save_output.outputs) > 1:
                repres = []
                for rank in range(len(self.save_output.outputs)):
                    repres.append(self.save_output.outputs[rank][0].detach().cpu())
                repres = torch.cat(repres, dim=0).to(self.device)
            else:
                repres = self.save_output.outputs[0][0].to(self.device)
            # Clear captured activations so they don't leak into the next batch.
            self.save_output.clear()
        return repres, logits
123
124
125
def prepare_moments(data_loader, eval_model, quantize, cfgs, logger, device):
    """Load or compute the (mu, sigma) feature moments of the reference dataset.

    The moments are cached on disk under ``<save_dir>/moments`` with a filename
    encoding dataset, image size, resizers, reference split, and backbone; a
    cache hit skips the expensive feature extraction entirely.

    Args:
        data_loader: loader over the reference dataset.
        eval_model: LoadEvalModel instance used to extract features.
        quantize: whether images are quantized before feature extraction.
        cfgs: experiment configuration object.
        logger: logger used on the rank-0 process.
        device: process rank / device index (0 == main process).

    Returns:
        Tuple (mu, sigma) of numpy arrays with the feature mean and covariance.
    """
    disable_tqdm = device != 0  # only the main process shows progress bars
    eval_model.eval()
    moment_dir = join(cfgs.RUN.save_dir, "moments")
    if not exists(moment_dir):
        os.makedirs(moment_dir)
    moment_path = join(moment_dir, cfgs.DATA.name + "_" + str(cfgs.DATA.img_size) + "_" + cfgs.RUN.pre_resizer + "_" + \
                       cfgs.RUN.ref_dataset + "_" + cfgs.RUN.post_resizer + "_" + cfgs.RUN.eval_backbone + "_moments.npz")

    if os.path.isfile(moment_path):
        # Cache hit: open the .npz once and read both arrays from it
        # (previously the file was loaded twice, once per array).
        with np.load(moment_path) as stat:
            mu = stat["mu"]
            sigma = stat["sigma"]
    else:
        if device == 0:
            logger.info("Calculate moments of {ref} dataset using {eval_backbone} model.".\
                        format(ref=cfgs.RUN.ref_dataset, eval_backbone=cfgs.RUN.eval_backbone))
        mu, sigma = fid.calculate_moments(data_loader=data_loader,
                                          eval_model=eval_model,
                                          num_generate="N/A",
                                          batch_size=cfgs.OPTIMIZATION.batch_size,
                                          quantize=quantize,
                                          world_size=cfgs.OPTIMIZATION.world_size,
                                          DDP=cfgs.RUN.distributed_data_parallel,
                                          disable_tqdm=disable_tqdm,
                                          fake_feats=None)

        if device == 0:
            # Persist so subsequent runs with the same configuration reuse them.
            logger.info("Save calculated means and covariances to disk.")
            np.savez(moment_path, **{"mu": mu, "sigma": sigma})
    return mu, sigma
156
157
158
def prepare_real_feats(data_loader, eval_model, num_feats, quantize, cfgs, logger, device):
    """Load or compute per-image features of the reference ("real") dataset.

    Features are cached under ``<save_dir>/feats`` in an .npz keyed by dataset,
    image size, resizers, reference split, and backbone. On a cache miss the
    features (plus probs and labels) are computed and saved by the main
    process; only the features are returned.
    """
    disable_tqdm = device != 0  # progress bar only on the main process
    eval_model.eval()
    feat_dir = join(cfgs.RUN.save_dir, "feats")
    if not exists(feat_dir):
        os.makedirs(feat_dir)
    # Cache filename encodes every option that affects the extracted features.
    feat_path = join(feat_dir, cfgs.DATA.name + "_" + str(cfgs.DATA.img_size) + "_"+ cfgs.RUN.pre_resizer + "_" + \
                     cfgs.RUN.ref_dataset + "_" + cfgs.RUN.post_resizer + "_" + cfgs.RUN.eval_backbone + "_feats.npz")

    is_file = os.path.isfile(feat_path)
    if is_file:
        # Cache hit: probs/labels stored alongside are not needed here.
        real_feats = np.load(feat_path)["real_feats"]
    else:
        if device == 0:
            logger.info("Calculate features of {ref} dataset using {eval_backbone} model.".\
                        format(ref=cfgs.RUN.ref_dataset, eval_backbone=cfgs.RUN.eval_backbone))
        real_feats, real_probs, real_labels = features.stack_features(data_loader=data_loader,
                                                                      eval_model=eval_model,
                                                                      num_feats=num_feats,
                                                                      batch_size=cfgs.OPTIMIZATION.batch_size,
                                                                      quantize=quantize,
                                                                      world_size=cfgs.OPTIMIZATION.world_size,
                                                                      DDP=cfgs.RUN.distributed_data_parallel,
                                                                      device=device,
                                                                      disable_tqdm=disable_tqdm)
        if device == 0:
            # Save probs and labels too, so other metrics can reuse the file.
            logger.info("Save real_features to disk.")
            np.savez(feat_path, **{"real_feats": real_feats,
                                   "real_probs": real_probs,
                                   "real_labels": real_labels})
    return real_feats
189
190
191
def calculate_ins(data_loader, eval_model, quantize, splits, cfgs, logger, device):
    """Compute and log the Inception Score of the reference dataset.

    For full ImageNet (not Tiny-ImageNet) datasets, top-1/top-5 classification
    accuracies are also computed and logged. Results are only logged on the
    main process (device == 0); nothing is returned.

    Args:
        data_loader: loader over the reference dataset.
        eval_model: LoadEvalModel instance used for scoring.
        quantize: whether images are quantized before scoring.
        splits: number of splits used for the Inception Score estimate.
        cfgs: experiment configuration object.
        logger: logger used on the rank-0 process.
        device: process rank / device index (0 == main process).
    """
    disable_tqdm = device != 0  # progress bar only on the main process
    # Accuracy is meaningful only against the 1000-class ImageNet label space.
    is_acc = True if "ImageNet" in cfgs.DATA.name and "Tiny" not in cfgs.DATA.name else False
    if device == 0:
        # Fixed typo in the log message ("uisng" -> "using").
        logger.info("Calculate inception score of the {ref} dataset using pre-trained {eval_backbone} model.".\
                    format(ref=cfgs.RUN.ref_dataset, eval_backbone=cfgs.RUN.eval_backbone))
    is_score, is_std, top1, top5 = ins.eval_dataset(data_loader=data_loader,
                                                    eval_model=eval_model,
                                                    quantize=quantize,
                                                    splits=splits,
                                                    batch_size=cfgs.OPTIMIZATION.batch_size,
                                                    world_size=cfgs.OPTIMIZATION.world_size,
                                                    DDP=cfgs.RUN.distributed_data_parallel,
                                                    is_acc=is_acc,
                                                    is_torch_backbone=True if "torch" in cfgs.RUN.eval_backbone else False,
                                                    disable_tqdm=disable_tqdm)
    if device == 0:
        logger.info("Inception score={is_score}-Inception_std={is_std}".format(is_score=is_score, is_std=is_std))
        if is_acc:
            logger.info("{eval_model} Top1 acc: ({num} images): {Top1}".format(
                eval_model=cfgs.RUN.eval_backbone, num=str(len(data_loader.dataset)), Top1=top1))
            logger.info("{eval_model} Top5 acc: ({num} images): {Top5}".format(
                eval_model=cfgs.RUN.eval_backbone, num=str(len(data_loader.dataset)), Top5=top5))
214
215