Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
ai-forever
GitHub Repository: ai-forever/sber-swap
Path: blob/main/utils/training/image_processing.py
1286 views
1
import cv2
2
import numpy as np
3
from PIL import Image
4
5
import torch
6
import torchvision.transforms as transforms
7
import torch.nn.functional as F
8
9
10
# Preprocessing used by read_torch_image: ToTensor turns a PIL image into a
# CxHxW float tensor scaled to [0, 1]; Normalize then maps each channel to
# [-1, 1] via (x - 0.5) / 0.5. torch2image applies the inverse mapping.
transformer_Arcface = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)
14
15
16
def torch2image(torch_image: torch.Tensor, max_images: int = 8) -> np.ndarray:
    """Convert a [-1, 1]-normalized image tensor to a displayable uint8 array.

    Undoes the per-channel normalization used by ``transformer_Arcface``
    (``x * 0.5 + 0.5``), scales to [0, 255], clips, and casts to ``uint8``.

    Args:
        torch_image: A single image of shape (C, H, W) or a batch of shape
            (B, C, H, W). C must be 3, since the mean/std have three entries.
        max_images: For batched input, only the first ``max_images`` images
            are kept. Defaults to 8 (the previous hard-coded limit).

    Returns:
        For a single image, an (H, W, C) uint8 array. For a batch, the kept
        images tiled left-to-right into one (H, k * W, C) uint8 array, where
        k = min(B, max_images).
    """
    is_batch = torch_image.dim() == 4
    if is_batch:
        # Cap the number of tiled images so preview strips stay a usable width.
        torch_image = torch_image[:max_images]

    device = torch_image.device
    # Shape (C, 1, 1) so the constants broadcast over H, W (and B if present).
    mean = torch.tensor([0.5, 0.5, 0.5]).view(-1, 1, 1).to(device)
    std = torch.tensor([0.5, 0.5, 0.5]).view(-1, 1, 1).to(device)

    denorm_image = std * torch_image + mean

    # Move channels last: (B, C, H, W) -> (B, H, W, C) or (C, H, W) -> (H, W, C).
    if is_batch:
        denorm_image = denorm_image.permute(0, 2, 3, 1)
    else:
        denorm_image = denorm_image.permute(1, 2, 0)

    np_image = denorm_image.detach().cpu().numpy()
    np_image = np.clip(np_image * 255., 0, 255).astype(np.uint8)

    if is_batch:
        # Iterating the leading axis yields (H, W, C) images; join them side by side.
        return np.concatenate(np_image, axis=1)
    return np_image
43
44
45
def make_image_list(images) -> np.ndarray:
    """Stack several image tensors into one uint8 preview grid.

    Each element of ``images`` is converted with ``torch2image`` and the
    resulting arrays are joined top-to-bottom (along axis 0), so each input
    contributes one row of the grid.

    Args:
        images: An iterable of tensors accepted by ``torch2image``.

    Returns:
        The vertically concatenated uint8 array. ``np.concatenate`` requires
        every row to have the same width and channel count.
    """
    rows = [torch2image(tensor) for tensor in images]
    return np.concatenate(rows, axis=0)
53
54
55
def read_torch_image(path: str, size: int = 256) -> torch.Tensor:
    """Load an image file as a normalized (1, 3, size, size) tensor.

    Args:
        path: Path of the image file, read with OpenCV.
        size: Side length the image is resized to. Defaults to 256.

    Returns:
        A float tensor of shape (1, 3, size, size) in RGB channel order,
        normalized by ``transformer_Arcface`` to the range [-1, 1].

    Raises:
        FileNotFoundError: If OpenCV cannot read ``path`` (missing file or
            unsupported/corrupt image).
    """
    image = cv2.imread(path)
    if image is None:
        # cv2.imread reports failure by returning None instead of raising;
        # fail here with a clear message rather than deep inside cv2.resize.
        raise FileNotFoundError(f"Could not read image file: {path}")
    image = cv2.resize(image, (size, size))
    # OpenCV loads channels as BGR; PIL/torchvision expect RGB.
    image = Image.fromarray(image[:, :, ::-1])
    image = transformer_Arcface(image)
    # Add a leading batch dimension: (C, H, W) -> (1, C, H, W).
    return image.unsqueeze(0)
64
65
66
def get_faceswap(source_path: str, target_path: str,
                 G: 'generator model', netArc: 'arcface model',
                 device: 'torch device') -> np.ndarray:
    """Run one face swap and return a side-by-side uint8 preview.

    The identity embedding is computed from the source image and passed,
    together with the target image, to the generator.

    Args:
        source_path: Path of the image providing the identity.
        target_path: Path of the image whose face is replaced.
        G: Generator called as ``G(target, embeds)``; the first element of
            the returned pair is used as the swapped image.
        netArc: Identity encoder, applied to the source resized to 112x112.
        device: Device both input tensors are moved to (presumably the device
            the models live on — confirm at the call site).

    Returns:
        A uint8 array with source, target and swapped result concatenated
        left-to-right.
    """
    # Use the requested device for both inputs (target previously went to
    # .cuda() unconditionally, which broke CPU / non-default-GPU runs).
    source = read_torch_image(source_path).to(device)
    target = read_torch_image(target_path).to(device)

    # Inference only: no autograd graph is needed for either network.
    with torch.no_grad():
        embeds = netArc(F.interpolate(source, [112, 112], mode='bilinear', align_corners=False))
        # NOTE(review): embeds are passed to G without L2 normalization;
        # confirm this matches how G was trained.
        Yt, _ = G(target, embeds)

    Yt = torch2image(Yt)
    source = torch2image(source)
    target = torch2image(target)

    return np.concatenate((cv2.resize(source, (256, 256)), target, Yt), axis=1)
86
87