# utils/training/image_processing.py
import cv2
import numpy as np
from PIL import Image

import torch
import torchvision.transforms as transforms
import torch.nn.functional as F


# ArcFace preprocessing: PIL image -> float tensor in [-1, 1].
# ToTensor scales to [0, 1]; Normalize with mean=std=0.5 maps that to [-1, 1].
transformer_Arcface = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])


def torch2image(torch_image: torch.Tensor) -> np.ndarray:
    """Convert a normalized CHW tensor (or NCHW batch) to a uint8 HWC image.

    Undoes the mean=std=0.5 normalization, scales to [0, 255] and clips.
    For a 4-D batch, only the first 8 images are kept and the converted
    images are concatenated horizontally into a single strip.
    """
    batch = torch_image.dim() == 4
    if batch:
        # Cap the visualization strip at 8 images.
        torch_image = torch_image[:8]

    device = torch_image.device
    # mean = torch.tensor([0.485, 0.456, 0.406]).unsqueeze(1).unsqueeze(2)
    # std = torch.tensor([0.229, 0.224, 0.225]).unsqueeze(1).unsqueeze(2)
    # Inverse of Normalize([0.5]*3, [0.5]*3): x * std + mean.
    mean = torch.tensor([0.5, 0.5, 0.5], device=device).view(3, 1, 1)
    std = torch.tensor([0.5, 0.5, 0.5], device=device).view(3, 1, 1)

    denorm_image = (std * torch_image) + mean

    # Move channels last for numpy/OpenCV consumption.
    if batch:
        denorm_image = denorm_image.permute(0, 2, 3, 1)
    else:
        denorm_image = denorm_image.permute(1, 2, 0)

    np_image = denorm_image.detach().cpu().numpy()
    np_image = np.clip(np_image * 255., 0, 255).astype(np.uint8)

    if batch:
        # Stack the batch side by side along the width axis.
        return np.concatenate(np_image, axis=1)
    return np_image


def make_image_list(images) -> np.ndarray:
    """Convert each tensor via torch2image and stack the results vertically."""
    return np.concatenate([torch2image(t) for t in images], axis=0)


def read_torch_image(path: str) -> torch.Tensor:
    """Load an image file as a 1x3x256x256 normalized RGB tensor.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file at ``path``.
    """
    image = cv2.imread(path)
    if image is None:
        # cv2.imread silently returns None on missing/unreadable files,
        # which would otherwise crash confusingly inside cv2.resize.
        raise FileNotFoundError(f"Could not read image: {path}")
    image = cv2.resize(image, (256, 256))
    # OpenCV loads BGR; flip to RGB before handing to PIL/torchvision.
    image = Image.fromarray(image[:, :, ::-1])
    image = transformer_Arcface(image)
    # Add a leading batch dimension: (3, 256, 256) -> (1, 3, 256, 256).
    return image.view(-1, image.shape[0], image.shape[1], image.shape[2])


def get_faceswap(source_path: str, target_path: str,
                 G, netArc, device) -> np.ndarray:
    """Run a face swap: identity from the source image onto the target image.

    Args:
        source_path: path to the identity (source) image.
        target_path: path to the attribute (target) image.
        G: generator model mapping (target, identity embedding) -> swapped face.
        netArc: ArcFace model producing identity embeddings from 112x112 input.
        device: torch device both models live on.

    Returns:
        A uint8 image strip ``[source | target | swapped]``, each 256x256 wide,
        concatenated horizontally.
    """
    source = read_torch_image(source_path).to(device)

    # Pure inference: no_grad avoids building an autograd graph.
    with torch.no_grad():
        # ArcFace expects a 112x112 input.
        embeds = netArc(F.interpolate(source, [112, 112], mode='bilinear',
                                      align_corners=False))
        # embeds = F.normalize(embeds, p=2, dim=1)

        # BUG FIX: was `target.cuda()`, which ignored the `device` argument
        # and crashed on CPU-only setups; use the same device as `source`.
        target = read_torch_image(target_path).to(device)

        Yt, _ = G(target, embeds)
        Yt = torch2image(Yt)

    source = torch2image(source)
    target = torch2image(target)

    return np.concatenate((cv2.resize(source, (256, 256)), target, Yt), axis=1)