# Path: blob/master/FaceMaskOverlay/lib/datasets/cofw.py
# 3443 views
# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Created by Tianheng Cheng([email protected]), Yang Zhao
# ------------------------------------------------------------------------------

import math
import random

import torch
import torch.utils.data as data
import numpy as np

from hdf5storage import loadmat
from ..utils.transforms import (
    fliplr_joints,
    crop,
    generate_target,
    transform_pixel,
)


class COFW(data.Dataset):
    """Landmark dataset backed by a COFW ``.mat`` annotation file.

    Each item is a cropped, normalized image together with one heatmap per
    landmark and a ``meta`` dict describing the crop.

    Args:
        cfg: Configuration object. The fields read here are
            ``DATASET.TRAINSET`` / ``DATASET.TESTSET`` (path of the ``.mat``
            file), ``DATASET.ROOT``, ``DATASET.SCALE_FACTOR``,
            ``DATASET.ROT_FACTOR``, ``DATASET.FLIP``, ``MODEL.IMAGE_SIZE``,
            ``MODEL.HEATMAP_SIZE``, ``MODEL.SIGMA`` and ``MODEL.TARGET_TYPE``.
        is_train: Selects the training file/keys and enables the random
            scale / rotation / flip augmentation in ``__getitem__``.
        transform: Stored on the instance; not used by the code in this class.
    """

    def __init__(self, cfg, is_train=True, transform=None):
        # The train and test splits live in different files and under
        # different keys inside the .mat file.
        if is_train:
            self.mat_file = cfg.DATASET.TRAINSET
            images_key, pts_key = 'IsTr', 'phisTr'
        else:
            self.mat_file = cfg.DATASET.TESTSET
            images_key, pts_key = 'IsT', 'phisT'

        self.is_train = is_train
        self.transform = transform
        self.data_root = cfg.DATASET.ROOT
        self.input_size = cfg.MODEL.IMAGE_SIZE
        self.output_size = cfg.MODEL.HEATMAP_SIZE
        self.sigma = cfg.MODEL.SIGMA
        self.scale_factor = cfg.DATASET.SCALE_FACTOR
        self.rot_factor = cfg.DATASET.ROT_FACTOR
        self.label_type = cfg.MODEL.TARGET_TYPE
        self.flip = cfg.DATASET.FLIP

        # Load the annotations once, up front.
        self.mat = loadmat(self.mat_file)
        self.images = self.mat[images_key]
        self.pts = self.mat[pts_key]

        # Per-channel normalization constants applied in __getitem__.
        self.mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        self.std = np.array([0.229, 0.224, 0.225], dtype=np.float32)

    def __len__(self):
        """Return the number of images in the loaded split."""
        return len(self.images)

    def __getitem__(self, idx):
        """Build one sample.

        Args:
            idx: Index into the loaded images / annotations.

        Returns:
            A tuple ``(img, target, meta)``:

            * ``img`` -- float32 NumPy array, normalized with ``self.mean`` /
              ``self.std`` and transposed with axes ``[2, 0, 1]``
              (channel-first, assuming ``crop`` returns H x W x C -- TODO
              confirm).
            * ``target`` -- ``torch.Tensor`` of shape
              ``(nparts, output_size[0], output_size[1])``.
            * ``meta`` -- dict with keys ``'index'``, ``'center'``,
              ``'scale'``, ``'pts'`` (landmarks in image coordinates, after
              any flip) and ``'tpts'`` (landmarks after ``transform_pixel``).
        """
        img = self.images[idx][0]

        # Replicate a single-channel image into three identical channels so
        # the per-channel mean/std below apply.
        if img.ndim == 2:
            img = np.repeat(img[:, :, np.newaxis], 3, axis=2)

        # The first 58 values are laid out as 29 x-coordinates followed by 29
        # y-coordinates; reshape them into a (29, 2) array of (x, y) rows.
        # NOTE(review): values past index 58 are ignored here -- presumably
        # occlusion flags; confirm against the annotation format.
        pts = self.pts[idx][0:58].reshape(2, -1).transpose()

        # Integer-aligned bounding box around all landmarks.
        left = math.floor(np.min(pts[:, 0]))
        right = math.ceil(np.max(pts[:, 0]))
        top = math.floor(np.min(pts[:, 1]))
        bottom = math.ceil(np.max(pts[:, 1]))

        center_w = (left + right) / 2.0
        center_h = (top + bottom) / 2.0

        # Scale is the longer box side in units of 200 pixels, enlarged by 25%.
        # NOTE(review): 200 is presumably the reference size expected by the
        # crop / transform_pixel helpers -- confirm in utils.transforms.
        scale = max(right - left, bottom - top) / 200.0
        center = torch.Tensor([center_w, center_h])

        scale *= 1.25
        nparts = pts.shape[0]

        r = 0
        if self.is_train:
            # Random scale jitter in [1 - scale_factor, 1 + scale_factor].
            scale = scale * (random.uniform(1 - self.scale_factor,
                                            1 + self.scale_factor))
            # Rotate in 60% of the samples; otherwise keep r = 0.
            if random.random() <= 0.6:
                r = random.uniform(-self.rot_factor, self.rot_factor)

            # Horizontal flip in 50% of the samples when enabled. The random
            # draw happens before the flag is checked, so the RNG stream does
            # not depend on self.flip.
            # NOTE(review): nesting reconstructed from a flattened copy of
            # this file -- confirm the flip is meant to be train-only.
            if random.random() <= 0.5 and self.flip:
                img = np.fliplr(img)
                pts = fliplr_joints(pts, width=img.shape[1], dataset='COFW')
                center[0] = img.shape[1] - center[0]

        img = crop(img, center, scale, self.input_size, rot=r)

        target = np.zeros((nparts, self.output_size[0], self.output_size[1]))
        tpts = pts.copy()

        for i in range(nparts):
            # Landmarks with y <= 0 are skipped: their heatmap stays all-zero
            # and their tpts row keeps the untransformed coordinates.
            # NOTE(review): both statements are assumed to sit under this
            # `if` (reconstructed nesting) -- confirm.
            if tpts[i, 1] > 0:
                # NOTE(review): the +1 / -1 look like a 1-based <-> 0-based
                # coordinate conversion around transform_pixel -- confirm.
                tpts[i, 0:2] = transform_pixel(tpts[i, 0:2] + 1, center,
                                               scale, self.output_size, rot=r)
                target[i] = generate_target(target[i], tpts[i] - 1, self.sigma,
                                            label_type=self.label_type)

        img = img.astype(np.float32)
        img = (img / 255 - self.mean) / self.std
        img = img.transpose([2, 0, 1])
        target = torch.Tensor(target)
        tpts = torch.Tensor(tpts)
        # center is already a torch.Tensor at this point; the re-wrap is kept
        # from the original code.
        center = torch.Tensor(center)

        meta = {'index': idx, 'center': center, 'scale': scale,
                'pts': torch.Tensor(pts), 'tpts': tpts}

        return img, target, meta


if __name__ == '__main__':

    pass