GitHub Repository: hackassin/learnopencv
Path: blob/master/Face-Recognition-with-ArcFace/embeddings.py
# Original code
# https://github.com/ZhaoJ9014/face.evoLVe.PyTorch/blob/master/util/extract_feature_v1.py

import os

import cv2
import numpy as np
import torch
import torch.utils.data as data
import torchvision.datasets as datasets
import torch.nn.functional as F
import torchvision.transforms as transforms
from backbone import Backbone
from tqdm import tqdm


def get_embeddings(data_root, model_root, input_size=[112, 112], embedding_size=512):

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # check data and model paths
    assert os.path.exists(data_root)
    assert os.path.exists(model_root)
    print(f"Data root: {data_root}")

    # define image preprocessing
    transform = transforms.Compose(
        [
            transforms.Resize(
                [int(128 * input_size[0] / 112), int(128 * input_size[0] / 112)],
            ),  # upscale to 128x128 (for the default 112x112 input) before cropping
            transforms.CenterCrop([input_size[0], input_size[1]]),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ],
    )

    # define data loader
    dataset = datasets.ImageFolder(data_root, transform)
    loader = data.DataLoader(
        dataset, batch_size=1, shuffle=False, pin_memory=True, num_workers=0,
    )
    print(f"Number of classes: {len(loader.dataset.classes)}")

    # load backbone weights from a checkpoint
    backbone = Backbone(input_size)
    backbone.load_state_dict(torch.load(model_root, map_location=torch.device("cpu")))
    backbone.to(device)
    backbone.eval()

    # get embedding for each face
    embeddings = np.zeros([len(loader.dataset), embedding_size])
    with torch.no_grad():
        for idx, (image, _) in enumerate(
            tqdm(loader, desc="Create embeddings matrix", total=len(loader)),
        ):
            embeddings[idx, :] = F.normalize(backbone(image.to(device))).cpu()

    # get all original images
    images = []
    for img_path, _ in dataset.samples:
        img = cv2.imread(img_path)
        images.append(img)

    return images, embeddings
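
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original script). The paths below are
# placeholders: point data_root at an ImageFolder-style directory of aligned
# face crops (one sub-folder per identity) and model_root at a compatible
# Backbone checkpoint file.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    images, embeddings = get_embeddings(
        data_root="data/aligned_faces",          # hypothetical dataset path
        model_root="weights/backbone_ir50.pth",  # hypothetical checkpoint path
    )
    # Each row of `embeddings` is an L2-normalized 512-d face embedding, so the
    # cosine similarity of two faces is just the dot product of their rows.
    print(f"{len(images)} images -> embeddings matrix of shape {embeddings.shape}")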