# Path: blob/master/FBAMatting/networks/transforms.py
# (3119 views)
import cv21import numpy as np2import torch345def dt(a):6return cv2.distanceTransform((a * 255).astype(np.uint8), cv2.DIST_L2, 0)789def trimap_transform(trimap):10h, w = trimap.shape[0], trimap.shape[1]1112clicks = np.zeros((h, w, 6))13for k in range(2):14if np.count_nonzero(trimap[:, :, k]) > 0:15dt_mask = -dt(1 - trimap[:, :, k]) ** 216L = 32017clicks[:, :, 3 * k] = np.exp(dt_mask / (2 * ((0.02 * L) ** 2)))18clicks[:, :, 3 * k + 1] = np.exp(dt_mask / (2 * ((0.08 * L) ** 2)))19clicks[:, :, 3 * k + 2] = np.exp(dt_mask / (2 * ((0.16 * L) ** 2)))2021return clicks222324# For RGB !25group_norm_std = [0.229, 0.224, 0.225]26group_norm_mean = [0.485, 0.456, 0.406]272829def groupnorm_normalise_image(img, format="nhwc"):30"""31Accept rgb in range 0,132"""33if format == "nhwc":34for i in range(3):35img[..., i] = (img[..., i] - group_norm_mean[i]) / group_norm_std[i]36else:37for i in range(3):38img[..., i, :, :] = (39img[..., i, :, :] - group_norm_mean[i]40) / group_norm_std[i]4142return img434445def groupnorm_denormalise_image(img, format="nhwc"):46"""47Accept rgb, normalised, return in range 0,148"""49if format == "nhwc":50for i in range(3):51img[:, :, :, i] = img[:, :, :, i] * group_norm_std[i] + group_norm_mean[i]52else:53img1 = torch.zeros_like(img).cuda()54for i in range(3):55img1[:, i, :, :] = img[:, i, :, :] * group_norm_std[i] + group_norm_mean[i]56return img157return img585960