Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
ai-forever
GitHub Repository: ai-forever/sber-swap
Path: blob/main/utils/training/Dataset.py
1285 views
1
from torch.utils.data import TensorDataset
2
import torchvision.transforms as transforms
3
from PIL import Image
4
import glob
5
import pickle
6
import random
7
import os
8
import cv2
9
import tqdm
10
import sys
11
sys.path.append('..')
12
# from utils.cap_aug import CAP_AUG
13
14
15
class FaceEmbed(TensorDataset):
    """Dataset of (source, target) face image pairs drawn from flat image folders.

    Every folder in ``data_path_list`` is scanned for files matching ``*.*g``
    and the folders are concatenated into one flat index space.

    Each item is a 4-tuple::

        (transforms_arcface(Xs), transforms_base(Xs), transforms_base(Xt), same_person)

    * ``Xs`` is the image at the requested index.
    * ``Xt`` is a copy of ``Xs`` with probability ``same_prob``
      (``same_person == 1``); otherwise it is an image drawn at random from a
      randomly chosen folder (``same_person == 0``).  The random draw does not
      exclude ``Xs`` itself.

    NOTE(review): the 224x224 pipeline is presumably the input of an ArcFace
    identity encoder (going by the attribute name) - confirm against the
    training code.
    """

    def __init__(self, data_path_list, same_prob=0.8):
        """Index the images of every folder in ``data_path_list``.

        Args:
            data_path_list: iterable of directory paths containing images.
            same_prob: probability that the target image is the source image.
        """
        self.same_prob = same_prob

        # Sorted so that the index -> file mapping does not depend on the
        # order in which the filesystem happens to enumerate the directory.
        self.datasets = [
            sorted(glob.glob(f'{data_path}/*.*g'))
            for data_path in data_path_list
        ]
        # Per-folder image counts, in the same order as ``self.datasets``.
        self.N = [len(image_list) for image_list in self.datasets]

        # ColorJitter is random, so every call of a pipeline draws a fresh
        # jitter - the three tensors of one item are jittered independently.
        # ToTensor + Normalize(0.5, 0.5) maps pixel values to [-1, 1].
        self.transforms_arcface = transforms.Compose([
            transforms.ColorJitter(0.2, 0.2, 0.2, 0.01),
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

        self.transforms_base = transforms.Compose([
            transforms.ColorJitter(0.2, 0.2, 0.2, 0.01),
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    def _locate(self, item):
        """Map a flat index to ``(folder_index, index_within_folder)``.

        Negative indices count from the end of the concatenated dataset.

        Raises:
            IndexError: if ``item`` is outside ``[-len(self), len(self))``.
        """
        total = sum(self.N)
        position = item + total if item < 0 else item
        if not 0 <= position < total:
            raise IndexError(
                f'index {item} is out of range for dataset of size {total}'
            )
        for idx, size in enumerate(self.N):
            if position < size:
                return idx, position
            position -= size
        # Unreachable while ``self.N`` is consistent with the bounds check.
        raise IndexError(f'index {item} is out of range')

    @staticmethod
    def _load_rgb(image_path):
        """Read ``image_path`` with OpenCV and return it as an RGB PIL image.

        Raises:
            OSError: if OpenCV cannot read or decode the file.
        """
        bgr = cv2.imread(image_path)
        if bgr is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise OSError(f'cannot read image: {image_path}')
        # OpenCV decodes to BGR; reverse the channel axis to obtain RGB.
        return Image.fromarray(bgr[:, :, ::-1])

    def __getitem__(self, item):
        """Return ``(arcface_source, source, target, same_person)`` for ``item``."""
        idx, offset = self._locate(item)
        Xs = self._load_rgb(self.datasets[idx][offset])

        if random.random() > self.same_prob:
            # Skip empty folders: random.choice raises on an empty sequence.
            # At least one folder is non-empty because the source came from it.
            candidates = [image_list for image_list in self.datasets if image_list]
            Xt = self._load_rgb(random.choice(random.choice(candidates)))
            same_person = 0
        else:
            Xt = Xs.copy()
            same_person = 1

        return self.transforms_arcface(Xs), self.transforms_base(Xs), self.transforms_base(Xt), same_person

    def __len__(self):
        """Return the total number of images across all folders."""
        return sum(self.N)
69
70
71
class FaceEmbedVGG2(TensorDataset):
    """Dataset of (source, target) face pairs from a one-folder-per-identity tree.

    Images are collected from ``{data_path}/*/*.*g``; the immediate parent
    directory of an image is treated as its identity.

    Each item is a 4-tuple::

        (transforms_arcface(Xs), transforms_base(Xs), transforms_base(Xt), same_person)

    * ``Xs`` is the image at the requested index.
    * With probability ``1 - same_prob`` the target ``Xt`` is drawn at random
      from all images and ``same_person == 0``.  The draw does not exclude
      images from the source's own folder.
    * Otherwise ``same_person == 1`` and ``Xt`` is either a random image from
      the source's folder (``same_identity=True``) or a copy of ``Xs``.

    NOTE(review): the 224x224 pipeline is presumably the input of an ArcFace
    identity encoder (going by the attribute name) - confirm against the
    training code.
    """

    def __init__(self, data_path, same_prob=0.8, same_identity=False):
        """Index every image below ``data_path`` and group them by folder.

        Args:
            data_path: root directory holding one sub-directory per identity.
            same_prob: probability that the target shares the source identity.
            same_identity: if true, a same-person target is another image from
                the source's folder instead of a copy of the source.
        """
        self.same_prob = same_prob
        self.same_identity = same_identity

        # Sorted so that the index -> file mapping does not depend on the
        # order in which the filesystem happens to enumerate directories.
        self.images_list = sorted(glob.glob(f'{data_path}/*/*.*g'))
        self.folders_list = sorted(glob.glob(f'{data_path}/*'))

        # Built from ``images_list`` itself so that (a) only files matching
        # the image pattern can be drawn as same-identity targets and (b) the
        # keys always agree with the ``os.path.dirname`` lookup in __getitem__.
        self.folder2imgs = self._group_by_folder(
            tqdm.tqdm(self.images_list), self.folders_list
        )

        self.N = len(self.images_list)

        # ColorJitter is random, so every call of a pipeline draws a fresh
        # jitter - the three tensors of one item are jittered independently.
        # ToTensor + Normalize(0.5, 0.5) maps pixel values to [-1, 1].
        self.transforms_arcface = transforms.Compose([
            transforms.ColorJitter(0.2, 0.2, 0.2, 0.01),
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

        self.transforms_base = transforms.Compose([
            transforms.ColorJitter(0.2, 0.2, 0.2, 0.01),
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    @staticmethod
    def _group_by_folder(image_paths, folders=()):
        """Return ``{folder: [image paths in that folder]}``.

        ``folders`` pre-seeds keys (with empty lists) so that directories
        without any matching image are still present in the mapping.
        """
        groups = {folder: [] for folder in folders}
        for path in image_paths:
            groups.setdefault(os.path.dirname(path), []).append(path)
        return groups

    @staticmethod
    def _load_rgb(image_path):
        """Read ``image_path`` with OpenCV and return it as an RGB PIL image.

        Raises:
            OSError: if OpenCV cannot read or decode the file.
        """
        bgr = cv2.imread(image_path)
        if bgr is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise OSError(f'cannot read image: {image_path}')
        # OpenCV decodes to BGR; reverse the channel axis to obtain RGB.
        return Image.fromarray(bgr[:, :, ::-1])

    def __getitem__(self, item):
        """Return ``(arcface_source, source, target, same_person)`` for ``item``."""
        image_path = self.images_list[item]
        Xs = self._load_rgb(image_path)

        if random.random() > self.same_prob:
            Xt = self._load_rgb(random.choice(self.images_list))
            same_person = 0
        else:
            if self.same_identity:
                # The source's own folder always contains at least the source.
                siblings = self.folder2imgs[os.path.dirname(image_path)]
                Xt = self._load_rgb(random.choice(siblings))
            else:
                Xt = Xs.copy()
            same_person = 1

        return self.transforms_arcface(Xs), self.transforms_base(Xs), self.transforms_base(Xt), same_person

    def __len__(self):
        """Return the number of indexed images."""
        return self.N
132