from torch.utils.data import TensorDataset
import torchvision.transforms as transforms
from PIL import Image
import glob
import pickle
import random
import os
import cv2
import tqdm
import sys
sys.path.append('..')
class FaceEmbed(TensorDataset):
    """Source/target face-image pairs drawn from one or more flat directories.

    Every file matching ``*.*g`` (e.g. ``.jpg``, ``.png``, ``.jpeg``) directly
    inside each directory of ``data_path_list`` is one source sample; indices
    run over the directories concatenated in the given order.

    For each source image ``Xs`` a target image ``Xt`` is produced:

    * with probability ``same_prob`` it is a copy of ``Xs`` (``same_person == 1``);
    * otherwise it is a random image from a uniformly chosen non-empty
      directory (``same_person == 0``). The random pick may coincidentally be
      ``Xs`` itself; it is still labelled 0.

    NOTE(review): ``TensorDataset.__init__`` is never called, so the inherited
    ``tensors`` attribute is never set. Only the ``__getitem__`` / ``__len__``
    defined here are meaningful.
    """

    def __init__(self, data_path_list, same_prob=0.8):
        """Index the images of every directory in ``data_path_list``.

        Args:
            data_path_list: iterable of directory paths to scan (non-recursive).
            same_prob: probability that the target is the source image itself.
        """
        self.same_prob = same_prob
        # glob order is filesystem-dependent; sort so a given index always maps
        # to the same file across runs and machines.
        self.datasets = [
            sorted(glob.glob(f'{data_path}/*.*g')) for data_path in data_path_list
        ]
        self.N = [len(image_list) for image_list in self.datasets]
        self.transforms_arcface = self._make_transform(224)
        self.transforms_base = self._make_transform(256)

    @staticmethod
    def _make_transform(size):
        """Build jitter -> resize to ``size`` x ``size`` -> tensor -> normalise.

        ``Normalize`` with mean/std 0.5 maps ``ToTensor``'s [0, 1] output to
        [-1, 1].
        """
        return transforms.Compose([
            transforms.ColorJitter(0.2, 0.2, 0.2, 0.01),
            transforms.Resize((size, size)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    @staticmethod
    def _load_rgb(image_path):
        """Read ``image_path`` with OpenCV and return it as an RGB PIL image.

        Raises:
            OSError: if OpenCV cannot read/decode the file. ``cv2.imread``
                returns ``None`` instead of raising, so check it explicitly.
        """
        bgr = cv2.imread(image_path)
        if bgr is None:
            raise OSError(f'could not read image: {image_path}')
        # OpenCV decodes to BGR channel order; reverse the last axis for RGB.
        return Image.fromarray(bgr[:, :, ::-1])

    def _locate(self, item):
        """Map a global index to ``(dataset_idx, local_idx)``.

        Negative indices count back from the end of the concatenated datasets.

        Raises:
            IndexError: if ``item`` is out of range.
        """
        total = sum(self.N)
        index = item + total if item < 0 else item
        if not 0 <= index < total:
            raise IndexError(f'index {item} out of range for dataset of size {total}')
        idx = 0
        # Walk past whole datasets until the remaining offset falls inside one;
        # empty datasets (N == 0) are skipped naturally.
        while index >= self.N[idx]:
            index -= self.N[idx]
            idx += 1
        return idx, index

    def _random_other_path(self):
        """Return a random image path from a uniformly chosen non-empty directory."""
        # Skip empty directories: random.choice([]) would raise IndexError.
        candidates = [image_list for image_list in self.datasets if image_list]
        return random.choice(random.choice(candidates))

    def __getitem__(self, item):
        """Return ``(Xs_arcface, Xs_base, Xt_base, same_person)`` for ``item``.

        ``Xs_arcface`` is ``Xs`` through ``transforms_arcface`` (224x224).
        ``Xs_base`` and ``Xt_base`` go through ``transforms_base`` (256x256).
        The colour jitter is re-sampled on every transform call.
        ``same_person`` is 1 when ``Xt`` is a copy of ``Xs``, else 0.
        """
        idx, local = self._locate(item)
        Xs = self._load_rgb(self.datasets[idx][local])
        if random.random() > self.same_prob:
            Xt = self._load_rgb(self._random_other_path())
            same_person = 0
        else:
            Xt = Xs.copy()
            same_person = 1
        return (
            self.transforms_arcface(Xs),
            self.transforms_base(Xs),
            self.transforms_base(Xt),
            same_person,
        )

    def __len__(self):
        """Total number of images across all directories."""
        return sum(self.N)
class FaceEmbedVGG2(TensorDataset):
    """Source/target face-image pairs over a ``<data_path>/<folder>/<image>`` tree.

    Every file matching ``*.*g`` one level below ``data_path`` is a source
    sample. For each source image ``Xs`` a target image ``Xt`` is produced:

    * with probability ``1 - same_prob``: a random image from the whole dataset,
      labelled ``same_person == 0``. It may coincidentally come from the same
      folder, or be ``Xs`` itself.
    * otherwise, ``same_person == 1`` and:
        * if ``same_identity`` is set, ``Xt`` is a random image from ``Xs``'s
          own folder. Sub-folders are treated as identities; presumably one
          per person, as in VGGFace2.
        * if it is not set, ``Xt`` is a copy of ``Xs``.

    NOTE(review): ``TensorDataset.__init__`` is never called, so the inherited
    ``tensors`` attribute is never set. Only the ``__getitem__`` / ``__len__``
    defined here are meaningful.
    """

    def __init__(self, data_path, same_prob=0.8, same_identity=False):
        """Index all images and build the folder -> images lookup.

        Args:
            data_path: root directory whose sub-directories contain images.
            same_prob: probability of producing a positive (``same_person == 1``) pair.
            same_identity: if True, positive targets come from the source's folder
                instead of being a copy of the source image.
        """
        self.same_prob = same_prob
        self.same_identity = same_identity
        # glob order is filesystem-dependent; sort for a reproducible index.
        self.images_list = sorted(glob.glob(f'{data_path}/*/*.*g'))
        self.folders_list = sorted(glob.glob(f'{data_path}/*'))
        self.folder2imgs = self._index_folders(tqdm.tqdm(self.folders_list))
        self.N = len(self.images_list)
        self.transforms_arcface = self._make_transform(224)
        self.transforms_base = self._make_transform(256)

    @staticmethod
    def _index_folders(folders):
        """Map each folder to the sorted ``*.*g`` files directly inside it.

        Uses the same pattern as ``images_list`` so a positive pair can never
        pick a non-image file (e.g. a ``.txt`` sidecar) that OpenCV cannot read.
        """
        return {folder: sorted(glob.glob(f'{folder}/*.*g')) for folder in folders}

    @staticmethod
    def _make_transform(size):
        """Build jitter -> resize to ``size`` x ``size`` -> tensor -> normalise.

        ``Normalize`` with mean/std 0.5 maps ``ToTensor``'s [0, 1] output to
        [-1, 1].
        """
        return transforms.Compose([
            transforms.ColorJitter(0.2, 0.2, 0.2, 0.01),
            transforms.Resize((size, size)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    @staticmethod
    def _load_rgb(image_path):
        """Read ``image_path`` with OpenCV and return it as an RGB PIL image.

        Raises:
            OSError: if OpenCV cannot read/decode the file. ``cv2.imread``
                returns ``None`` instead of raising, so check it explicitly.
        """
        bgr = cv2.imread(image_path)
        if bgr is None:
            raise OSError(f'could not read image: {image_path}')
        # OpenCV decodes to BGR channel order; reverse the last axis for RGB.
        return Image.fromarray(bgr[:, :, ::-1])

    def __getitem__(self, item):
        """Return ``(Xs_arcface, Xs_base, Xt_base, same_person)`` for ``item``.

        ``Xs_arcface`` is ``Xs`` through ``transforms_arcface`` (224x224).
        ``Xs_base`` and ``Xt_base`` go through ``transforms_base`` (256x256).
        The colour jitter is re-sampled on every transform call.
        """
        image_path = self.images_list[item]
        Xs = self._load_rgb(image_path)
        if random.random() > self.same_prob:
            Xt = self._load_rgb(random.choice(self.images_list))
            same_person = 0
        elif self.same_identity:
            # os.path.dirname reproduces the folder keys glob built on every
            # platform; splitting on '/' breaks with Windows separators.
            folder_name = os.path.dirname(image_path)
            Xt = self._load_rgb(random.choice(self.folder2imgs[folder_name]))
            same_person = 1
        else:
            Xt = Xs.copy()
            same_person = 1
        return (
            self.transforms_arcface(Xs),
            self.transforms_base(Xs),
            self.transforms_base(Xt),
            same_person,
        )

    def __len__(self):
        """Number of images found under ``data_path``."""
        return self.N