Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
ai-forever
GitHub Repository: ai-forever/sber-swap
Path: blob/main/utils/training/detector.py
1285 views
1
import torch
2
import numpy as np
3
import cv2
4
from PIL import Image
5
import torchvision.transforms as transforms
6
from AdaptiveWingLoss.utils.utils import get_preds_fromhm
7
from .image_processing import torch2image
8
9
10
# Pipeline that turns an annotated PIL image back into a normalised tensor
# (used by ``paint_eyes``). Steps, in order:
#   1. random colour jitter (brightness/contrast/saturation 0.2, hue 0.01);
#   2. resize to 256x256;
#   3. convert to a float CHW tensor in [0, 1];
#   4. normalise each of the 3 channels with mean 0.5 / std 0.5, i.e. map
#      [0, 1] to [-1, 1].
# NOTE(review): step 1 is random, so anything passed through this pipeline
# gets a different colour perturbation on every call.
transforms_base = transforms.Compose(
    [
        transforms.ColorJitter(
            brightness=0.2, contrast=0.2, saturation=0.2, hue=0.01
        ),
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)
16
17
18
def detect_landmarks(inputs, model_ft):
    """Run the landmark network on a batch and pull out the two eye points.

    Args:
        inputs: image batch tensor. It is un-normalised here with
            ``x * 0.5 + 0.5`` per channel, so it is presumably (B, 3, H, W)
            in [-1, 1] — TODO confirm against callers.
        model_ft: landmark model called as ``model_ft(x)`` and returning a
            pair ``(outputs, boundary_channels)``. Only ``outputs[-1]`` is
            used; the boundary output is ignored.

    Returns:
        ``(eyes, heatmap_96, heatmap_97)`` where
        - ``eyes`` is landmark 96 and landmark 97 concatenated along dim 1
          (presumably (B, 4) laid out as x0, y0, x1, y1 — verify);
        - the two heatmaps are channels 96 and 97 of the final prediction.
          They are on the CPU because the prediction is moved there before
          decoding.
    """
    # One shared (3, 1, 1) tensor stands in for both mean and std (all 0.5).
    half = torch.tensor([0.5, 0.5, 0.5], device=inputs.device).view(-1, 1, 1)
    restored = half * inputs + half

    stage_outputs, _ = model_ft(restored)
    # The last channel of the final stage is excluded; presumably it is not a
    # landmark heatmap (e.g. a boundary map) — verify against the model.
    heatmaps = stage_outputs[-1][:, :-1, :, :].cpu()

    coords, _ = get_preds_fromhm(heatmaps)
    # The 4x scale presumably maps heatmap resolution to input-image
    # resolution — confirm the heatmap/input size ratio.
    coords = coords * 4.0

    first_eye, second_eye = 96, 97
    eye_coords = torch.cat(
        (coords[:, first_eye, :], coords[:, second_eye, :]), 1
    )
    return (
        eye_coords,
        heatmaps[:, first_eye, :, :],
        heatmaps[:, second_eye, :, :],
    )
29
30
31
def paint_eyes(images, eyes):
    """Draw two filled eye markers on every image and re-batch the results.

    Args:
        images: sequence of image tensors, each accepted by ``torch2image``.
        eyes: indexable per image. ``eyes[k]`` must have at least four values:
            ``(x0, y0)`` is the centre of the first marker and ``(x1, y1)`` the
            centre of the second. This is presumably the ``eyes`` output of
            ``detect_landmarks``; coordinates must be in the pixel space of the
            array returned by ``torch2image`` (verify).

    Returns:
        ``torch.stack`` of ``transforms_base`` applied to each painted image.

    Raises:
        IndexError: if ``eyes`` has fewer entries than ``images``.
        RuntimeError: from ``torch.stack`` when ``images`` is empty.

    NOTE(review): ``transforms_base`` starts with a random ColorJitter, so the
    returned images are randomly colour-perturbed on every call — confirm this
    is intended.
    """
    painted = []
    for idx, image in enumerate(images):
        point = eyes[idx]
        # Swap the channel order before drawing. The flip below swaps it
        # back, so the (0, 255, 255) colour is specified in the swapped order.
        canvas = cv2.cvtColor(torch2image(image), cv2.COLOR_BGR2RGB)

        for cx, cy in ((point[0], point[1]), (point[2], point[3])):
            cv2.circle(
                canvas,
                (int(cx), int(cy)),
                radius=3,
                color=(0, 255, 255),
                thickness=-1,  # negative thickness -> filled disc
            )

        # Reverse the channel axis back before converting to a PIL image.
        canvas = canvas[:, :, ::-1]
        painted.append(transforms_base(Image.fromarray(canvas)))

    return torch.stack(painted)
45