Path: blob/master/src/metrics/features.py
809 views
# PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN
# The MIT License (MIT)
# See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details

# src/metrics/features.py

import math

from tqdm import tqdm
import torch
import numpy as np

import utils.sample as sample
import utils.losses as losses


def _features_and_probs(eval_model, images, quantize):
    """Run ``eval_model.get_outputs`` on ``images`` with gradients disabled.

    Returns:
        Tuple ``(features, probs)`` where ``probs`` is the softmax of the
        logits returned by ``get_outputs``, taken over ``dim=1``.
    """
    with torch.no_grad():
        features, logits = eval_model.get_outputs(images, quantize=quantize)
        probs = torch.nn.functional.softmax(logits, dim=1)
    return features, probs


def _gather_from_all_ranks(tensor):
    """Concatenate along dim 0 the pieces ``losses.GatherLayer`` yields for ``tensor``."""
    return torch.cat(losses.GatherLayer.apply(tensor), dim=0)


def generate_images_and_stack_features(generator, discriminator, eval_model, num_generate, y_sampler, batch_size, z_prior,
                                       truncation_factor, z_dim, num_classes, LOSS, RUN, MODEL, is_stylegan, generator_mapping,
                                       generator_synthesis, quantize, world_size, DDP, device, logger, disable_tqdm):
    """Generate fake images batch by batch and stack the evaluator's outputs.

    Every batch is produced by ``sample.generate_images`` (called with
    ``is_train=False``) and fed to ``eval_model.get_outputs``; the per-batch
    features, softmax probabilities and fake labels are concatenated along
    dim 0.

    The batch count is ``ceil(num_generate / batch_size)``; when ``DDP`` is
    true it becomes ``num_batches // world_size + 1`` per process and the
    stacked tensors are gathered through ``losses.GatherLayer``. Because whole
    batches are always generated, the result can hold more than
    ``num_generate`` samples.

    Args:
        num_generate: Requested number of images; only used for the batch
            count and the log message.
        world_size: Number of processes; only used when ``DDP`` is true.
        DDP: Whether to gather the results across processes.
        device: Forwarded to ``sample.generate_images``; the progress message
            is logged only when ``device == 0``.
        logger: Object whose ``info`` method receives the progress message.
        disable_tqdm: Disables the progress bar and the progress message.
        (All remaining arguments are forwarded unchanged to
        ``sample.generate_images`` / ``eval_model.get_outputs``.)

    Returns:
        Tuple ``(features, probs, labels)``: two tensors and a Python list
        built from the fake-label array.
    """
    eval_model.eval()
    feature_holder, prob_holder, fake_label_holder = [], [], []

    if device == 0 and not disable_tqdm:
        logger.info("generate images and stack features ({} images).".format(num_generate))

    num_batches = int(math.ceil(float(num_generate) / float(batch_size)))
    if DDP:
        num_batches = num_batches // world_size + 1

    for _ in tqdm(range(num_batches), disable=disable_tqdm):
        # Only the images and their labels are needed from the sampler output.
        fake_images, fake_labels, _, _, _, _, _ = sample.generate_images(z_prior=z_prior,
                                                                         truncation_factor=truncation_factor,
                                                                         batch_size=batch_size,
                                                                         z_dim=z_dim,
                                                                         num_classes=num_classes,
                                                                         y_sampler=y_sampler,
                                                                         radius="N/A",
                                                                         generator=generator,
                                                                         discriminator=discriminator,
                                                                         is_train=False,
                                                                         LOSS=LOSS,
                                                                         RUN=RUN,
                                                                         MODEL=MODEL,
                                                                         is_stylegan=is_stylegan,
                                                                         generator_mapping=generator_mapping,
                                                                         generator_synthesis=generator_synthesis,
                                                                         style_mixing_p=0.0,
                                                                         device=device,
                                                                         stylegan_update_emas=False,
                                                                         cal_trsp_cost=False)

        features, probs = _features_and_probs(eval_model, fake_images, quantize)

        feature_holder.append(features)
        prob_holder.append(probs)
        fake_label_holder.append(fake_labels)

    feature_holder = torch.cat(feature_holder, 0)
    prob_holder = torch.cat(prob_holder, 0)
    fake_label_holder = torch.cat(fake_label_holder, 0)

    if DDP:
        feature_holder = _gather_from_all_ranks(feature_holder)
        prob_holder = _gather_from_all_ranks(prob_holder)
        fake_label_holder = _gather_from_all_ranks(fake_label_holder)
    return feature_holder, prob_holder, list(fake_label_holder.detach().cpu().numpy())


def sample_images_from_loader_and_stack_features(dataloader, eval_model, batch_size, quantize,
                                                 world_size, DDP, device, disable_tqdm):
    """Stack evaluator outputs for the images served by ``dataloader``.

    At most ``ceil(len(dataloader.dataset) / batch_size)`` batches are read
    (divided by ``world_size`` as well when ``DDP`` is true); reading stops
    early if the loader is exhausted. With ``DDP`` the stacked tensors are
    gathered through ``losses.GatherLayer``.

    Args:
        dataloader: Iterable yielding ``(images, labels)`` pairs and exposing
            a ``dataset`` attribute that supports ``len``.
        eval_model: Model providing ``eval()`` and ``get_outputs``.
        batch_size: Batch size used to derive the number of batches.
        quantize: Forwarded to ``eval_model.get_outputs``.
        world_size: Number of processes; only used when ``DDP`` is true.
        DDP: Whether to gather the results across processes.
        device: Target device for images and labels; the progress message is
            printed only when ``device == 0``.
        disable_tqdm: Disables the progress bar and the progress message.

    Returns:
        Tuple ``(features, probs, labels)``: two tensors and a Python list
        built from the label array.
    """
    eval_model.eval()
    total_instance = len(dataloader.dataset)
    num_batches = math.ceil(float(total_instance) / float(batch_size))
    if DDP:
        num_batches = int(math.ceil(float(total_instance) / float(batch_size * world_size)))
    data_iter = iter(dataloader)

    if device == 0 and not disable_tqdm:
        print("Sample images and stack features ({} images).".format(total_instance))

    feature_holder, prob_holder, label_holder = [], [], []
    for _ in tqdm(range(0, num_batches), disable=disable_tqdm):
        try:
            images, labels = next(data_iter)
        except StopIteration:
            break

        images, labels = images.to(device), labels.to(device)

        features, probs = _features_and_probs(eval_model, images, quantize)

        feature_holder.append(features)
        prob_holder.append(probs)
        # Labels already live on `device`; a hard-coded "cuda" target would
        # fail on hosts without CUDA and diverge from the sibling functions.
        label_holder.append(labels)

    feature_holder = torch.cat(feature_holder, 0)
    prob_holder = torch.cat(prob_holder, 0)
    label_holder = torch.cat(label_holder, 0)

    if DDP:
        feature_holder = _gather_from_all_ranks(feature_holder)
        prob_holder = _gather_from_all_ranks(prob_holder)
        label_holder = _gather_from_all_ranks(label_holder)
    return feature_holder, prob_holder, list(label_holder.detach().cpu().numpy())


def stack_features(data_loader, eval_model, num_feats, batch_size, quantize, world_size, DDP, device, disable_tqdm):
    """Stack evaluator outputs for ``data_loader`` and return them as NumPy arrays.

    The batch count is ``ceil(num_feats / batch_size)``; when ``DDP`` is true
    it becomes ``num_batches // world_size + 1`` per process and the stacked
    tensors are gathered through ``losses.GatherLayer``. Reading stops early
    if the loader is exhausted.

    Args:
        data_loader: Iterable yielding ``(images, labels)`` pairs.
        eval_model: Model providing ``eval()`` and ``get_outputs``.
        num_feats: Requested number of samples; only used for the batch count.
        batch_size: Batch size used to derive the number of batches.
        quantize: Forwarded to ``eval_model.get_outputs``.
        world_size: Number of processes; only used when ``DDP`` is true.
        DDP: Whether to gather the results across processes.
        device: Target device for images and labels.
        disable_tqdm: Disables the progress bar.

    Returns:
        Tuple ``(real_feats, real_probs, real_labels)`` of NumPy arrays; the
        first two are cast to ``float64``.
    """
    eval_model.eval()
    data_iter = iter(data_loader)
    num_batches = math.ceil(float(num_feats) / float(batch_size))
    if DDP:
        num_batches = num_batches // world_size + 1

    real_feats, real_probs, real_labels = [], [], []
    for _ in tqdm(range(0, num_batches), disable=disable_tqdm):
        try:
            images, labels = next(data_iter)
        except StopIteration:
            break

        images, labels = images.to(device), labels.to(device)

        embeddings, probs = _features_and_probs(eval_model, images, quantize)
        real_feats.append(embeddings)
        real_probs.append(probs)
        real_labels.append(labels)

    real_feats = torch.cat(real_feats, dim=0)
    real_probs = torch.cat(real_probs, dim=0)
    real_labels = torch.cat(real_labels, dim=0)
    if DDP:
        real_feats = _gather_from_all_ranks(real_feats)
        real_probs = _gather_from_all_ranks(real_probs)
        real_labels = _gather_from_all_ranks(real_labels)

    real_feats = real_feats.detach().cpu().numpy().astype(np.float64)
    real_probs = real_probs.detach().cpu().numpy().astype(np.float64)
    real_labels = real_labels.detach().cpu().numpy()
    return real_feats, real_probs, real_labels