Path: blob/master/src/metrics/preparation.py
809 views
# PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN
# The MIT License (MIT)
# See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details

# src/metrics/preparation.py

from os.path import exists, join
import os

try:
    from torchvision.models.utils import load_state_dict_from_url
except ImportError:
    # torchvision moved/removed this helper; fall back to torch's model zoo loader.
    from torch.utils.model_zoo import load_url as load_state_dict_from_url
from torch.nn import DataParallel
from torch.nn.parallel import DistributedDataParallel as DDP
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import numpy as np

from metrics.inception_net import InceptionV3
from metrics.swin_transformer import SwinTransformer
import metrics.features as features
import metrics.vit as vits
import metrics.fid as fid
import metrics.ins as ins
import utils.misc as misc
import utils.ops as ops
import utils.resize as resize

# torch.hub repository tags and entry-point names for the torch-based backbones.
# Keys must match the eval_backbone strings used throughout this module
# ("InceptionV3_torch", "ResNet50_torch", "SwAV_torch"); a mismatched key makes
# torch.hub.load below raise KeyError. (Fixed: "ResNet_torch" -> "ResNet50_torch".)
model_versions = {"InceptionV3_torch": "pytorch/vision:v0.10.0",
                  "ResNet50_torch": "pytorch/vision:v0.10.0",
                  "SwAV_torch": "facebookresearch/swav:main"}
model_names = {"InceptionV3_torch": "inception_v3",
               "ResNet50_torch": "resnet50",
               "SwAV_torch": "resnet50"}
SWAV_CLASSIFIER_URL = "https://dl.fbaipublicfiles.com/deepcluster/swav_800ep_eval_linear.pth.tar"
SWIN_URL = "https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth"


class LoadEvalModel(object):
    """Wrapper that loads a pre-trained backbone used for evaluation metrics.

    Supported backbones: "InceptionV3_tf", "InceptionV3_torch", "ResNet50_torch",
    "SwAV_torch", "DINO_torch", and "Swin-T_torch". The wrapper also stores the
    input resolution, normalization statistics, and a resizer matched to the
    chosen backbone, and wraps the model for multi-GPU execution when requested.

    Args:
        eval_backbone (str): which pre-trained feature extractor to load.
        post_resizer (str): resizing strategy passed to resize.build_resizer.
        world_size (int): number of participating processes/GPUs.
        distributed_data_parallel (bool): wrap with DDP if True (and world_size > 1),
            otherwise with DataParallel.
        device (torch.device or int): device the model is moved to.
    """

    def __init__(self, eval_backbone, post_resizer, world_size, distributed_data_parallel, device):
        super(LoadEvalModel, self).__init__()
        self.eval_backbone = eval_backbone
        self.post_resizer = post_resizer
        self.device = device
        # Forward-pre-hook sink used to capture penultimate features of hub models.
        self.save_output = misc.SaveOutput()

        if self.eval_backbone == "InceptionV3_tf":
            self.res, mean, std = 299, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
            self.model = InceptionV3(resize_input=False, normalize_input=False).to(self.device)
        elif self.eval_backbone in ["InceptionV3_torch", "ResNet50_torch", "SwAV_torch"]:
            self.res = 299 if "InceptionV3" in self.eval_backbone else 224
            mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
            self.model = torch.hub.load(model_versions[self.eval_backbone],
                                        model_names[self.eval_backbone],
                                        pretrained=True)
            if self.eval_backbone == "SwAV_torch":
                # SwAV's hub model ships without the linear classifier; load it separately.
                linear_state_dict = load_state_dict_from_url(SWAV_CLASSIFIER_URL, progress=True)["state_dict"]
                linear_state_dict = {k.replace("module.linear.", ""): v for k, v in linear_state_dict.items()}
                self.model.fc.load_state_dict(linear_state_dict, strict=True)
            self.model = self.model.to(self.device)
            # Hook the fc layer's *input* so get_outputs can read penultimate features.
            hook_handles = []
            for name, layer in self.model.named_children():
                if name == "fc":
                    handle = layer.register_forward_pre_hook(self.save_output)
                    hook_handles.append(handle)
        elif self.eval_backbone == "DINO_torch":
            self.res, mean, std = 224, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
            self.model = vits.__dict__["vit_small"](patch_size=8, num_classes=1000, num_last_blocks=4)
            misc.load_pretrained_weights(self.model, "", "teacher", "vit_small", 8)
            misc.load_pretrained_linear_weights(self.model.linear, "vit_small", 8)
            self.model = self.model.to(self.device)
        elif self.eval_backbone == "Swin-T_torch":
            self.res, mean, std = 224, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
            self.model = SwinTransformer()
            model_state_dict = load_state_dict_from_url(SWIN_URL, progress=True)["model"]
            self.model.load_state_dict(model_state_dict, strict=True)
            self.model = self.model.to(self.device)
        else:
            raise NotImplementedError

        self.resizer = resize.build_resizer(resizer=self.post_resizer, backbone=self.eval_backbone, size=self.res)
        self.totensor = transforms.ToTensor()
        self.mean = torch.Tensor(mean).view(1, 3, 1, 1).to(self.device)
        self.std = torch.Tensor(std).view(1, 3, 1, 1).to(self.device)

        if world_size > 1 and distributed_data_parallel:
            misc.make_model_require_grad(self.model)
            self.model = DDP(self.model,
                             device_ids=[self.device],
                             # Swin's buffers (e.g. relative-position tables) must not be broadcast.
                             broadcast_buffers=False if self.eval_backbone == "Swin-T_torch" else True)
        elif world_size > 1 and distributed_data_parallel is False:
            self.model = DataParallel(self.model, output_device=self.device)
        else:
            pass

    def eval(self):
        """Put the wrapped backbone into evaluation mode."""
        self.model.eval()

    def get_outputs(self, x, quantize=False):
        """Return (features, logits) of images ``x`` from the chosen backbone.

        Args:
            x (torch.Tensor): batch of images; quantized to uint8 range either via
                ops.quantize_images (quantize=True) or a raw uint8 cast.
            quantize (bool): whether to quantize float images before resizing.

        Returns:
            tuple: (repres, logits) — penultimate features and classifier logits,
            both on ``self.device``.
        """
        if quantize:
            x = ops.quantize_images(x)
        else:
            x = x.detach().cpu().numpy().astype(np.uint8)
        x = ops.resize_images(x, self.resizer, self.totensor, self.mean, self.std, device=self.device)

        if self.eval_backbone in ["InceptionV3_tf", "DINO_torch", "Swin-T_torch"]:
            # These models directly return (features, logits).
            repres, logits = self.model(x)
        elif self.eval_backbone in ["InceptionV3_torch", "ResNet50_torch", "SwAV_torch"]:
            logits = self.model(x)
            # Features were captured by the fc pre-hook; under DataParallel there
            # is one captured tensor per replica, so concatenate them.
            if len(self.save_output.outputs) > 1:
                repres = []
                for rank in range(len(self.save_output.outputs)):
                    repres.append(self.save_output.outputs[rank][0].detach().cpu())
                repres = torch.cat(repres, dim=0).to(self.device)
            else:
                repres = self.save_output.outputs[0][0].to(self.device)
            self.save_output.clear()
        return repres, logits


def prepare_moments(data_loader, eval_model, quantize, cfgs, logger, device):
    """Compute (or load cached) mean/covariance of reference features for FID.

    Moments are cached as an .npz file keyed by dataset, image size, resizers,
    and backbone; subsequent calls load from disk instead of recomputing.

    Returns:
        tuple: (mu, sigma) — feature mean vector and covariance matrix.
    """
    disable_tqdm = device != 0  # only rank 0 shows progress
    eval_model.eval()
    moment_dir = join(cfgs.RUN.save_dir, "moments")
    if not exists(moment_dir):
        os.makedirs(moment_dir)
    moment_path = join(moment_dir, cfgs.DATA.name + "_" + str(cfgs.DATA.img_size) + "_" + cfgs.RUN.pre_resizer + "_" +
                       cfgs.RUN.ref_dataset + "_" + cfgs.RUN.post_resizer + "_" + cfgs.RUN.eval_backbone + "_moments.npz")

    is_file = os.path.isfile(moment_path)
    if is_file:
        mu = np.load(moment_path)["mu"]
        sigma = np.load(moment_path)["sigma"]
    else:
        if device == 0:
            logger.info("Calculate moments of {ref} dataset using {eval_backbone} model.".
                        format(ref=cfgs.RUN.ref_dataset, eval_backbone=cfgs.RUN.eval_backbone))
        mu, sigma = fid.calculate_moments(data_loader=data_loader,
                                          eval_model=eval_model,
                                          num_generate="N/A",
                                          batch_size=cfgs.OPTIMIZATION.batch_size,
                                          quantize=quantize,
                                          world_size=cfgs.OPTIMIZATION.world_size,
                                          DDP=cfgs.RUN.distributed_data_parallel,
                                          disable_tqdm=disable_tqdm,
                                          fake_feats=None)

        if device == 0:
            logger.info("Save calculated means and covariances to disk.")
        np.savez(moment_path, **{"mu": mu, "sigma": sigma})
    return mu, sigma


def prepare_real_feats(data_loader, eval_model, num_feats, quantize, cfgs, logger, device):
    """Compute (or load cached) stacked real-image features for PRDC-style metrics.

    Features (plus probs and labels) are cached as an .npz file keyed the same
    way as in prepare_moments.

    Returns:
        np.ndarray: stacked real features.
    """
    disable_tqdm = device != 0  # only rank 0 shows progress
    eval_model.eval()
    feat_dir = join(cfgs.RUN.save_dir, "feats")
    if not exists(feat_dir):
        os.makedirs(feat_dir)
    feat_path = join(feat_dir, cfgs.DATA.name + "_" + str(cfgs.DATA.img_size) + "_" + cfgs.RUN.pre_resizer + "_" +
                     cfgs.RUN.ref_dataset + "_" + cfgs.RUN.post_resizer + "_" + cfgs.RUN.eval_backbone + "_feats.npz")

    is_file = os.path.isfile(feat_path)
    if is_file:
        real_feats = np.load(feat_path)["real_feats"]
    else:
        if device == 0:
            logger.info("Calculate features of {ref} dataset using {eval_backbone} model.".
                        format(ref=cfgs.RUN.ref_dataset, eval_backbone=cfgs.RUN.eval_backbone))
        real_feats, real_probs, real_labels = features.stack_features(data_loader=data_loader,
                                                                      eval_model=eval_model,
                                                                      num_feats=num_feats,
                                                                      batch_size=cfgs.OPTIMIZATION.batch_size,
                                                                      quantize=quantize,
                                                                      world_size=cfgs.OPTIMIZATION.world_size,
                                                                      DDP=cfgs.RUN.distributed_data_parallel,
                                                                      device=device,
                                                                      disable_tqdm=disable_tqdm)
        if device == 0:
            logger.info("Save real_features to disk.")
            np.savez(feat_path, **{"real_feats": real_feats,
                                   "real_probs": real_probs,
                                   "real_labels": real_labels})
    return real_feats


def calculate_ins(data_loader, eval_model, quantize, splits, cfgs, logger, device):
    """Evaluate the Inception Score (and top-1/top-5 accuracy on full ImageNet)
    of the reference dataset, logging the results on rank 0."""
    disable_tqdm = device != 0  # only rank 0 shows progress
    # Accuracy is only meaningful for full ImageNet (not Tiny-ImageNet).
    is_acc = True if "ImageNet" in cfgs.DATA.name and "Tiny" not in cfgs.DATA.name else False
    if device == 0:
        # Fixed typo in log message: "uisng" -> "using".
        logger.info("Calculate inception score of the {ref} dataset using pre-trained {eval_backbone} model.".
                    format(ref=cfgs.RUN.ref_dataset, eval_backbone=cfgs.RUN.eval_backbone))
    is_score, is_std, top1, top5 = ins.eval_dataset(data_loader=data_loader,
                                                    eval_model=eval_model,
                                                    quantize=quantize,
                                                    splits=splits,
                                                    batch_size=cfgs.OPTIMIZATION.batch_size,
                                                    world_size=cfgs.OPTIMIZATION.world_size,
                                                    DDP=cfgs.RUN.distributed_data_parallel,
                                                    is_acc=is_acc,
                                                    is_torch_backbone=True if "torch" in cfgs.RUN.eval_backbone else False,
                                                    disable_tqdm=disable_tqdm)
    if device == 0:
        logger.info("Inception score={is_score}-Inception_std={is_std}".format(is_score=is_score, is_std=is_std))
        if is_acc:
            logger.info("{eval_model} Top1 acc: ({num} images): {Top1}".format(
                eval_model=cfgs.RUN.eval_backbone, num=str(len(data_loader.dataset)), Top1=top1))
            logger.info("{eval_model} Top5 acc: ({num} images): {Top5}".format(
                eval_model=cfgs.RUN.eval_backbone, num=str(len(data_loader.dataset)), Top5=top5))