Path: blob/master/src/utils/simclr_aug.py
"""1this code is borrowed from https://github.com/jh-jeong/ContraD with few modifications23MIT License45Copyright (c) 2021 Jongheon Jeong67Permission is hereby granted, free of charge, to any person obtaining a copy8of this software and associated documentation files (the "Software"), to deal9in the Software without restriction, including without limitation the rights10to use, copy, modify, merge, publish, distribute, sublicense, and/or sell11copies of the Software, and to permit persons to whom the Software is12furnished to do so, subject to the following conditions:1314The above copyright notice and this permission notice shall be included in all15copies or substantial portions of the Software.1617THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR18IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,19FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE20AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER21LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,22OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE23SOFTWARE.24"""2526import numpy as np27import math28import torch29import torch.nn as nn30from torch.nn.functional import affine_grid, grid_sample31from torch.nn import functional as F32from torch.autograd import Function33from kornia.filters import get_gaussian_kernel2d, filter2d34import numbers353637def rgb2hsv(rgb):38"""Convert a 4-d RGB tensor to the HSV counterpart.39Here, we compute hue using atan2() based on the definition in [1],40instead of using the common lookup table approach as in [2, 3].41Those values agree when the angle is a multiple of 30°,42otherwise they may differ at most ~1.2°.43>>> %timeit rgb2hsv_lookup(rgb)441.07 ms ± 2.96 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)45>>> %timeit rgb2hsv(rgb)46380 µs ± 555 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)47>>> (rgb2hsv_lookup(rgb) - rgb2hsv(rgb)).abs().max()48tensor(0.0031, device='cuda:0')49References50[1] https://en.wikipedia.org/wiki/Hue51[2] https://www.rapidtables.com/convert/color/rgb-to-hsv.html52[3] https://github.com/scikit-image/scikit-image/blob/master/skimage/color/colorconv.py#L21253"""5455r, g, b = rgb[:, 0, :, :], rgb[:, 1, :, :], rgb[:, 2, :, :]5657Cmax = rgb.max(1)[0]58Cmin = rgb.min(1)[0]5960hue = torch.atan2(math.sqrt(3) * (g - b), 2 * r - g - b)61hue = (hue % (2 * math.pi)) / (2 * math.pi)62saturate = 1 - Cmin / (Cmax + 1e-8)63value = Cmax64hsv = torch.stack([hue, saturate, value], dim=1)65hsv[~torch.isfinite(hsv)] = 0.66return hsv676869def hsv2rgb(hsv):70"""Convert a 4-d HSV tensor to the RGB counterpart.71>>> %timeit hsv2rgb_lookup(hsv)722.37 ms ± 13.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)73>>> %timeit hsv2rgb(rgb)74298 µs ± 542 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)75>>> torch.allclose(hsv2rgb(hsv), hsv2rgb_lookup(hsv), atol=1e-6)76True77References78[1] https://en.wikipedia.org/wiki/HSL_and_HSV#HSV_to_RGB_alternative79"""8081h, s, v = hsv[:, [0]], hsv[:, [1]], hsv[:, [2]]82c = v * s8384n = hsv.new_tensor([5, 3, 1]).view(3, 1, 1)85k = (n + h * 6) % 686t = torch.min(k, 4. 
class RandomApply(nn.Module):
    def __init__(self, fn, p):
        super().__init__()
        self.fn = fn
        self.p = p

    def forward(self, inputs):
        # Per-sample Bernoulli mask: each image is transformed with probability p.
        _prob = inputs.new_full((inputs.size(0),), self.p)
        _mask = torch.bernoulli(_prob).view(-1, 1, 1, 1)
        return inputs * (1 - _mask) + self.fn(inputs) * _mask


class RandomResizeCropLayer(nn.Module):
    def __init__(self, scale, ratio=(3. / 4., 4. / 3.)):
        '''
        Inception-style random resized crop.
        scale (tuple): range of the area fraction of the original image to crop
        ratio (tuple): range of aspect ratios of the crop
        '''
        super(RandomResizeCropLayer, self).__init__()

        _eye = torch.eye(2, 3)
        self.register_buffer('_eye', _eye)
        self.scale = scale
        self.ratio = ratio

    def forward(self, inputs):
        _device = inputs.device
        N, _, height, width = inputs.shape  # PyTorch layout is (N, C, H, W)

        _theta = self._eye.repeat(N, 1, 1)

        # Sample N * 10 candidate crops and keep those that fit in the image.
        area = height * width
        target_area = np.random.uniform(*self.scale, N * 10) * area
        log_ratio = (math.log(self.ratio[0]), math.log(self.ratio[1]))
        aspect_ratio = np.exp(np.random.uniform(*log_ratio, N * 10))

        w = np.round(np.sqrt(target_area * aspect_ratio))
        h = np.round(np.sqrt(target_area / aspect_ratio))
        cond = (0 < w) * (w <= width) * (0 < h) * (h <= height)
        w = w[cond]
        h = h[cond]
        if len(w) > N:
            inds = np.random.choice(len(w), N, replace=False)
            w = w[inds]
            h = h[inds]
        # Samples without a valid candidate keep the identity transform (full image).
        transform_len = len(w)

        r_w_bias = np.random.randint(w - width, width - w + 1) / width
        r_h_bias = np.random.randint(h - height, height - h + 1) / height
        w = w / width
        h = h / height

        _theta[:transform_len, 0, 0] = torch.tensor(w, device=_device)
        _theta[:transform_len, 1, 1] = torch.tensor(h, device=_device)
        _theta[:transform_len, 0, 2] = torch.tensor(r_w_bias, device=_device)
        _theta[:transform_len, 1, 2] = torch.tensor(r_h_bias, device=_device)

        grid = affine_grid(_theta, inputs.size(), align_corners=False)
        output = grid_sample(inputs, grid, padding_mode='reflection', align_corners=False)
        return output


class HorizontalFlipLayer(nn.Module):
    def __init__(self):
        """Randomly flips each image in the batch horizontally with probability 0.5."""
        super(HorizontalFlipLayer, self).__init__()

        _eye = torch.eye(2, 3)
        self.register_buffer('_eye', _eye)

    def forward(self, inputs):
        _device = inputs.device

        N = inputs.size(0)
        _theta = self._eye.repeat(N, 1, 1)
        # A sign of -1 on the x-scale of the affine transform mirrors the image.
        r_sign = torch.bernoulli(torch.ones(N, device=_device) * 0.5) * 2 - 1
        _theta[:, 0, 0] = r_sign
        grid = affine_grid(_theta, inputs.size(), align_corners=False)
        output = grid_sample(inputs, grid, padding_mode='reflection', align_corners=False)
        return output
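
# --- Example (hedged sketch, not part of the original module): RandomApply
# draws an independent Bernoulli mask per sample, so each image in the batch
# is transformed with probability p rather than the whole batch at once.
def _demo_random_apply():
    flip = RandomApply(HorizontalFlipLayer(), p=0.5)  # flip roughly half the batch
    images = torch.rand(8, 3, 32, 32)
    print(flip(images).shape)  # torch.Size([8, 3, 32, 32])
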
class RandomHSVFunction(Function):
    @staticmethod
    def forward(ctx, x, f_h, f_s, f_v):
        # ctx is a context object that can be used to stash information
        # for backward computation
        x = rgb2hsv(x)
        h = x[:, 0, :, :]
        h += (f_h * 255. / 360.)
        h = (h % 1)
        x[:, 0, :, :] = h
        x[:, 1, :, :] = x[:, 1, :, :] * f_s
        x[:, 2, :, :] = x[:, 2, :, :] * f_v
        x = torch.clamp(x, 0, 1)
        x = hsv2rgb(x)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        # We return as many input gradients as there were arguments.
        # Gradients of non-Tensor arguments to forward must be None.
        grad_input = None
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.clone()
        return grad_input, None, None, None


class ColorJitterLayer(nn.Module):
    def __init__(self, brightness, contrast, saturation, hue):
        super(ColorJitterLayer, self).__init__()
        self.brightness = self._check_input(brightness, 'brightness')
        self.contrast = self._check_input(contrast, 'contrast')
        self.saturation = self._check_input(saturation, 'saturation')
        self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5), clip_first_on_zero=False)

    def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True):
        if isinstance(value, numbers.Number):
            if value < 0:
                raise ValueError("If {} is a single number, it must be non-negative.".format(name))
            value = [center - value, center + value]
            if clip_first_on_zero:
                value[0] = max(value[0], 0)
        elif isinstance(value, (tuple, list)) and len(value) == 2:
            if not bound[0] <= value[0] <= value[1] <= bound[1]:
                raise ValueError("{} values should be between {}".format(name, bound))
        else:
            raise TypeError("{} should be a single number or a list/tuple with length 2.".format(name))

        # If value is 0 or (1., 1.) for brightness/contrast/saturation,
        # or (0., 0.) for hue, do nothing.
        if value[0] == value[1] == center:
            value = None
        return value

    def adjust_contrast(self, x):
        if self.contrast:
            factor = x.new_empty(x.size(0), 1, 1, 1).uniform_(*self.contrast)
            means = torch.mean(x, dim=[2, 3], keepdim=True)
            x = (x - means) * factor + means
        return torch.clamp(x, 0, 1)

    def adjust_hsv(self, x):
        f_h = x.new_zeros(x.size(0), 1, 1)
        f_s = x.new_ones(x.size(0), 1, 1)
        f_v = x.new_ones(x.size(0), 1, 1)

        if self.hue:
            f_h.uniform_(*self.hue)
        if self.saturation:
            f_s = f_s.uniform_(*self.saturation)
        if self.brightness:
            f_v = f_v.uniform_(*self.brightness)

        return RandomHSVFunction.apply(x, f_h, f_s, f_v)

    def transform(self, inputs):
        # Apply contrast and HSV adjustments in a random order.
        if np.random.rand() > 0.5:
            transforms = [self.adjust_contrast, self.adjust_hsv]
        else:
            transforms = [self.adjust_hsv, self.adjust_contrast]

        for t in transforms:
            inputs = t(inputs)
        return inputs

    def forward(self, inputs):
        return self.transform(inputs)


class RandomColorGrayLayer(nn.Module):
    def __init__(self):
        super(RandomColorGrayLayer, self).__init__()
        # ITU-R BT.601 luma coefficients for RGB-to-grayscale conversion.
        _weight = torch.tensor([[0.299, 0.587, 0.114]])
        self.register_buffer('_weight', _weight.view(1, 3, 1, 1))

    def forward(self, inputs):
        l = F.conv2d(inputs, self._weight)
        gray = torch.cat([l, l, l], dim=1)
        return gray


class GaussianBlur(nn.Module):
    def __init__(self, sigma_range):
        """Blurs the given image with a Gaussian kernel.

        Args:
            sigma_range: Range of sigma used for the Gaussian kernel.
        """
        super(GaussianBlur, self).__init__()
        self.sigma_range = sigma_range

    def forward(self, inputs):
        _device = inputs.device

        batch_size, num_channels, height, width = inputs.size()

        # Kernel size is roughly 10% of the image height, rounded to an odd number.
        kernel_size = height // 10
        radius = int(kernel_size / 2)
        kernel_size = radius * 2 + 1

        sigma = np.random.uniform(*self.sigma_range)
        kernel = torch.unsqueeze(get_gaussian_kernel2d((kernel_size, kernel_size), (sigma, sigma)), dim=0)
        blurred = filter2d(inputs, kernel, "reflect")
        return blurred
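
# --- Example (hedged sketch, not part of the original module): ColorJitterLayer
# with the SimCLR-style strengths used in the pipelines below; contrast and HSV
# factors are sampled per image, and outputs are clamped back to [0, 1].
def _demo_color_jitter():
    jitter = ColorJitterLayer(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
    images = torch.rand(8, 3, 32, 32)
    out = jitter(images)
    print(out.min().item(), out.max().item())  # stays within [0, 1]
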
class CutOut(nn.Module):
    def __init__(self, length):
        super().__init__()
        if length % 2 == 0:
            raise ValueError("Currently CutOut only accepts odd lengths: length % 2 == 1")
        self.length = length

        _weight = torch.ones(1, 1, self.length)
        self.register_buffer('_weight', _weight)
        self._padding = (length - 1) // 2

    def forward(self, inputs):
        _device = inputs.device
        N, _, h, w = inputs.shape

        mask_h = inputs.new_zeros(N, h)
        mask_w = inputs.new_zeros(N, w)

        h_center = torch.randint(h, (N, 1), device=_device)
        w_center = torch.randint(w, (N, 1), device=_device)

        mask_h.scatter_(1, h_center, 1).unsqueeze_(1)
        mask_w.scatter_(1, w_center, 1).unsqueeze_(1)

        # Dilate the one-hot centers into length-wide bands, then take their
        # outer product to form a square zero mask around each center.
        mask_h = F.conv1d(mask_h, self._weight, padding=self._padding)
        mask_w = F.conv1d(mask_w, self._weight, padding=self._padding)

        mask = 1. - torch.einsum('bci,bcj->bcij', mask_h, mask_w)
        outputs = inputs * mask
        return outputs


class SimclrAugment(nn.Module):
    def __init__(self, aug_type):
        super().__init__()
        if aug_type == "simclr_basic":
            self.pipeline = nn.Sequential(RandomResizeCropLayer(scale=(0.2, 1.0)), HorizontalFlipLayer(),
                                          RandomApply(ColorJitterLayer(0.4, 0.4, 0.4, 0.1), p=0.8),
                                          RandomApply(RandomColorGrayLayer(), p=0.2))
        elif aug_type == "simclr_hq":
            self.pipeline = nn.Sequential(RandomResizeCropLayer(scale=(0.2, 1.0)), HorizontalFlipLayer(),
                                          RandomApply(ColorJitterLayer(0.4, 0.4, 0.4, 0.1), p=0.8),
                                          RandomApply(RandomColorGrayLayer(), p=0.2),
                                          RandomApply(GaussianBlur((0.1, 2.0)), p=0.5))
        elif aug_type == "simclr_hq_cutout":
            self.pipeline = nn.Sequential(RandomResizeCropLayer(scale=(0.2, 1.0)), HorizontalFlipLayer(),
                                          RandomApply(ColorJitterLayer(0.4, 0.4, 0.4, 0.1), p=0.8),
                                          RandomApply(RandomColorGrayLayer(), p=0.2),
                                          RandomApply(GaussianBlur((0.1, 2.0)), p=0.5),
                                          RandomApply(CutOut(15), p=0.5))
        elif aug_type == "byol":
            self.pipeline = nn.Sequential(RandomResizeCropLayer(scale=(0.2, 1.0)), HorizontalFlipLayer(),
                                          RandomApply(ColorJitterLayer(0.4, 0.4, 0.2, 0.1), p=0.8),
                                          RandomApply(RandomColorGrayLayer(), p=0.2),
                                          RandomApply(GaussianBlur((0.1, 2.0)), p=0.5))

    def forward(self, images):
        return self.pipeline(images)
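
# --- Example (hedged sketch, not part of the original module): build the full
# augmentation pipeline and produce two independently augmented views of a
# batch, as in SimCLR-style contrastive training.
if __name__ == "__main__":
    aug = SimclrAugment("simclr_hq")
    images = torch.rand(8, 3, 32, 32)        # (N, C, H, W) batch in [0, 1]
    view1, view2 = aug(images), aug(images)  # two stochastic views per image
    print(view1.shape, view2.shape)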