Path: blob/master/FaceMaskOverlay/lib/datasets/aflw.py
3443 views
# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Created by Tianheng Cheng([email protected]), Yang Zhao
# ------------------------------------------------------------------------------

import os
import random

import torch
import torch.utils.data as data
import pandas as pd
from PIL import Image, ImageFile
import numpy as np

from ..utils.transforms import fliplr_joints, crop, generate_target, transform_pixel

# Let PIL decode partially-written/truncated image files instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True


class AFLW(data.Dataset):
    """AFLW facial-landmark dataset.

    Each CSV row is expected to hold: image path, scale, box size,
    center (w, h), then the flattened landmark coordinates (x, y pairs)
    from column 5 onward — inferred from the `iloc` slicing below;
    confirm against the annotation files.

    __getitem__ returns ``(img, target, meta)`` where ``img`` is a
    normalized CHW float image, ``target`` is a per-landmark heatmap
    tensor, and ``meta`` carries bookkeeping (index, center, scale,
    original and transformed points, box size).
    """

    def __init__(self, cfg, is_train=True, transform=None):
        # specify annotation file for dataset
        if is_train:
            self.csv_file = cfg.DATASET.TRAINSET
        else:
            self.csv_file = cfg.DATASET.TESTSET

        self.is_train = is_train
        # NOTE(review): `transform` is stored but never applied in
        # __getitem__ — normalization is done inline below.
        self.transform = transform
        self.data_root = cfg.DATASET.ROOT
        self.input_size = cfg.MODEL.IMAGE_SIZE
        self.output_size = cfg.MODEL.HEATMAP_SIZE
        self.sigma = cfg.MODEL.SIGMA
        # Augmentation ranges (train only): random rescale and rotation.
        self.scale_factor = cfg.DATASET.SCALE_FACTOR
        self.rot_factor = cfg.DATASET.ROT_FACTOR
        self.label_type = cfg.MODEL.TARGET_TYPE
        self.flip = cfg.DATASET.FLIP
        # Channel-wise normalization statistics (the standard ImageNet values).
        self.mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        self.std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
        # load annotations
        self.landmarks_frame = pd.read_csv(self.csv_file)

    def __len__(self) -> int:
        # One sample per annotation row.
        return len(self.landmarks_frame)

    def __getitem__(self, idx: int):
        """Load, augment, and encode the sample at row ``idx``.

        Returns:
            img: np.ndarray, float32, shape (3, H, W), normalized by
                ``self.mean``/``self.std``.
            target: torch.Tensor of per-landmark heatmaps with spatial
                size ``self.output_size``.
            meta: dict with 'index', 'center', 'scale', 'pts' (original
                landmarks), 'tpts' (heatmap-space landmarks), 'box_size'.
        """
        # Columns 0-4: relative image path, scale, box size, center (w, h).
        image_path = os.path.join(self.data_root,
                                  self.landmarks_frame.iloc[idx, 0])
        scale = self.landmarks_frame.iloc[idx, 1]
        box_size = self.landmarks_frame.iloc[idx, 2]

        center_w = self.landmarks_frame.iloc[idx, 3]
        center_h = self.landmarks_frame.iloc[idx, 4]
        center = torch.Tensor([center_w, center_h])

        # Remaining columns: flattened landmark coordinates -> (nparts, 2).
        pts = self.landmarks_frame.iloc[idx, 5:].values
        pts = pts.astype('float').reshape(-1, 2)

        # Enlarge the annotated face box by a fixed margin before cropping.
        scale *= 1.25
        nparts = pts.shape[0]
        img = np.array(Image.open(image_path).convert('RGB'), dtype=np.float32)

        r = 0  # rotation angle in degrees; 0 outside training
        if self.is_train:
            # Random rescale; rotation is applied only 60% of the time.
            # NOTE(review): the order of these random.* calls determines
            # augmentation reproducibility under a fixed seed — do not reorder.
            scale = scale * (random.uniform(1 - self.scale_factor,
                                            1 + self.scale_factor))
            r = random.uniform(-self.rot_factor, self.rot_factor) \
                if random.random() <= 0.6 else 0
            if random.random() <= 0.5 and self.flip:
                # Horizontal flip: mirror pixels, remap left/right landmark
                # pairs (AFLW-specific index mapping), and mirror the center x.
                img = np.fliplr(img)
                pts = fliplr_joints(pts, width=img.shape[1], dataset='AFLW')
                center[0] = img.shape[1] - center[0]

        # Crop/rotate the face region to the network input resolution.
        img = crop(img, center, scale, self.input_size, rot=r)

        target = np.zeros((nparts, self.output_size[0], self.output_size[1]))
        tpts = pts.copy()

        for i in range(nparts):
            # Guard: presumably y <= 0 marks an invisible/unannotated
            # landmark — those get an all-zero heatmap. TODO confirm.
            if tpts[i, 1] > 0:
                # The +1 / -1 pair appears to convert between 1-based
                # annotation coordinates and 0-based heatmap indices —
                # verify against ..utils.transforms before changing.
                tpts[i, 0:2] = transform_pixel(tpts[i, 0:2]+1, center,
                                               scale, self.output_size, rot=r)
                target[i] = generate_target(target[i], tpts[i]-1, self.sigma,
                                            label_type=self.label_type)
        img = img.astype(np.float32)
        # Scale to [0, 1], then normalize per channel; HWC -> CHW.
        img = (img/255.0 - self.mean) / self.std
        img = img.transpose([2, 0, 1])
        target = torch.Tensor(target)
        tpts = torch.Tensor(tpts)
        center = torch.Tensor(center)

        meta = {'index': idx, 'center': center, 'scale': scale,
                'pts': torch.Tensor(pts), 'tpts': tpts, 'box_size': box_size}

        return img, target, meta


if __name__ == '__main__':

    pass