Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
POSTECH-CVLab
GitHub Repository: POSTECH-CVLab/PyTorch-StudioGAN
Path: blob/master/src/utils/simclr_aug.py
809 views
1
"""
2
this code is borrowed from https://github.com/jh-jeong/ContraD with few modifications
3
4
MIT License
5
6
Copyright (c) 2021 Jongheon Jeong
7
8
Permission is hereby granted, free of charge, to any person obtaining a copy
9
of this software and associated documentation files (the "Software"), to deal
10
in the Software without restriction, including without limitation the rights
11
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
copies of the Software, and to permit persons to whom the Software is
13
furnished to do so, subject to the following conditions:
14
15
The above copyright notice and this permission notice shall be included in all
16
copies or substantial portions of the Software.
17
18
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
SOFTWARE.
25
"""
26
27
import numpy as np
28
import math
29
import torch
30
import torch.nn as nn
31
from torch.nn.functional import affine_grid, grid_sample
32
from torch.nn import functional as F
33
from torch.autograd import Function
34
from kornia.filters import get_gaussian_kernel2d, filter2d
35
import numbers
36
37
38
def rgb2hsv(rgb):
    """Convert a 4-d RGB tensor (N, 3, H, W) to its HSV counterpart.

    Hue is computed with atan2() from the hexagonal-projection definition
    in [1] rather than the usual lookup-table approach of [2, 3]; the two
    agree at multiples of 30 degrees and differ by at most ~1.2 degrees
    elsewhere. All three output channels lie in [0, 1].

    References
    [1] https://en.wikipedia.org/wiki/Hue
    [2] https://www.rapidtables.com/convert/color/rgb-to-hsv.html
    [3] https://github.com/scikit-image/scikit-image/blob/master/skimage/color/colorconv.py#L212
    """
    red, green, blue = rgb.unbind(dim=1)

    c_max = torch.max(rgb, dim=1).values
    c_min = torch.min(rgb, dim=1).values

    # atan2-based hue, wrapped into [0, 2*pi) and normalized to [0, 1).
    angle = torch.atan2(math.sqrt(3) * (green - blue), 2 * red - green - blue)
    hue = torch.remainder(angle, 2 * math.pi) / (2 * math.pi)
    # Small epsilon guards against division by zero for black pixels.
    sat = 1 - c_min / (c_max + 1e-8)

    out = torch.stack([hue, sat, c_max], dim=1)
    # NaN/Inf (e.g. degenerate hue) are mapped to 0.
    out = torch.where(torch.isfinite(out), out, torch.zeros_like(out))
    return out
68
69
70
def hsv2rgb(hsv):
    """Convert a 4-d HSV tensor (N, 3, H, W) back to RGB.

    Implements the "alternative" channel-wise formula of [1]:
        f(n) = v - v*s * max(0, min(k, 4 - k, 1)),  k = (n + 6h) mod 6
    with n = (5, 3, 1) producing the (R, G, B) channels.

    References
    [1] https://en.wikipedia.org/wiki/HSL_and_HSV#HSV_to_RGB_alternative
    """
    hue = hsv[:, [0]]
    sat = hsv[:, [1]]
    val = hsv[:, [2]]

    # Per-channel offsets (5, 3, 1) broadcast against the hue plane.
    offsets = hsv.new_tensor([5, 3, 1]).view(3, 1, 1)
    k = torch.remainder(offsets + hue * 6, 6)
    ramp = torch.clamp(torch.minimum(k, 4. - k), 0, 1)
    return val - (val * sat) * ramp
90
91
92
class RandomApply(nn.Module):
    """Apply a transform to each sample of a batch independently with probability p.

    The wrapped transform is always evaluated on the full batch; a per-sample
    Bernoulli mask then blends transformed and original samples, which keeps
    the op batched and differentiable.
    """

    def __init__(self, fn, p):
        super().__init__()
        self.fn = fn
        self.p = p

    def forward(self, inputs):
        probs = inputs.new_full((inputs.size(0),), self.p)
        # Mask of shape (N, 1, 1, 1) broadcasts over channels and pixels.
        mask = torch.bernoulli(probs).view(-1, 1, 1, 1)
        transformed = self.fn(inputs)
        return inputs * (1 - mask) + transformed * mask
102
103
104
class RandomResizeCropLayer(nn.Module):
    # Differentiable "Inception-style" random resized crop, implemented as a
    # per-sample affine warp (affine_grid + grid_sample) so gradients flow.
    def __init__(self, scale, ratio=(3. / 4., 4. / 3.)):
        '''
        Inception Crop
        scale (tuple): range of size of the origin size cropped
        ratio (tuple): range of aspect ratio of the origin aspect ratio cropped
        '''
        super(RandomResizeCropLayer, self).__init__()

        # 2x3 identity affine matrix; a buffer so it moves with the module
        # across devices.
        _eye = torch.eye(2, 3)
        self.register_buffer('_eye', _eye)
        self.scale = scale
        self.ratio = ratio

    def forward(self, inputs):
        """Crop-and-resize each sample with a random scale and aspect ratio.

        Draws N * 10 candidate crops, keeps only those that fit inside the
        image, and assigns at most one crop per sample; samples left without
        a candidate keep the identity transform (i.e. are returned uncropped).
        """
        _device = inputs.device
        # NOTE(review): inputs are (N, C, H, W), so 'width' and 'height' here
        # are swapped relative to NCHW convention — harmless for square
        # images; confirm before using non-square inputs.
        N, _, width, height = inputs.shape

        _theta = self._eye.repeat(N, 1, 1)

        # N * 10 trial
        area = height * width
        target_area = np.random.uniform(*self.scale, N * 10) * area
        log_ratio = (math.log(self.ratio[0]), math.log(self.ratio[1]))
        # Sampling the ratio in log-space makes w/h and h/w equally likely.
        aspect_ratio = np.exp(np.random.uniform(*log_ratio, N * 10))

        # If doesn't satisfy ratio condition, then do central crop
        w = np.round(np.sqrt(target_area * aspect_ratio))
        h = np.round(np.sqrt(target_area / aspect_ratio))
        cond = (0 < w) * (w <= width) * (0 < h) * (h <= height)
        w = w[cond]
        h = h[cond]
        if len(w) > N:
            # More valid candidates than samples: pick N without replacement.
            inds = np.random.choice(len(w), N, replace=False)
            w = w[inds]
            h = h[inds]
        transform_len = len(w)

        # Random translation, bounded so the crop window stays inside the
        # image (bounds are in pixels, then normalized to [-1, 1] grid units).
        r_w_bias = np.random.randint(w - width, width - w + 1) / width
        r_h_bias = np.random.randint(h - height, height - h + 1) / height
        # Normalize crop extents to scale factors for the affine matrix.
        w = w / width
        h = h / height

        # Fill the first transform_len samples; the rest keep the identity.
        _theta[:transform_len, 0, 0] = torch.tensor(w, device=_device)
        _theta[:transform_len, 1, 1] = torch.tensor(h, device=_device)
        _theta[:transform_len, 0, 2] = torch.tensor(r_w_bias, device=_device)
        _theta[:transform_len, 1, 2] = torch.tensor(r_h_bias, device=_device)

        # Sampling a sub-grid and resampling back to the input size performs
        # crop + resize in a single differentiable step.
        grid = affine_grid(_theta, inputs.size(), align_corners=False)
        output = grid_sample(inputs, grid, padding_mode='reflection', align_corners=False)
        return output
155
156
157
class HorizontalFlipLayer(nn.Module):
    """Mirror each sample of a (N, C, H, W) batch along the width axis with
    probability 0.5, implemented as a differentiable affine warp."""

    def __init__(self):
        super(HorizontalFlipLayer, self).__init__()
        # 2x3 identity affine; buffer so it follows the module across devices.
        self.register_buffer('_eye', torch.eye(2, 3))

    def forward(self, inputs):
        batch = inputs.size(0)
        theta = self._eye.repeat(batch, 1, 1)
        # Per-sample sign: +1 keeps the sample, -1 mirrors the x coordinate.
        signs = torch.bernoulli(torch.ones(batch, device=inputs.device) * 0.5) * 2 - 1
        theta[:, 0, 0] = signs
        grid = affine_grid(theta, inputs.size(), align_corners=False)
        return grid_sample(inputs, grid, padding_mode='reflection', align_corners=False)
180
181
182
class RandomHSVFunction(Function):
    """Autograd function applying per-sample HSV jitter with a
    straight-through gradient (backward passes grad_output unchanged)."""

    @staticmethod
    def forward(ctx, x, f_h, f_s, f_v):
        # Work in HSV space: shift hue, scale saturation and value.
        hsv = rgb2hsv(x)
        # Hue lives in [0, 1]; shift and wrap around.
        hsv[:, 0, :, :] = (hsv[:, 0, :, :] + f_h * 255. / 360.) % 1
        hsv[:, 1, :, :] = hsv[:, 1, :, :] * f_s
        hsv[:, 2, :, :] = hsv[:, 2, :, :] * f_v
        hsv = torch.clamp(hsv, 0, 1)
        return hsv2rgb(hsv)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: gradient flows to x only; the jitter
        # factors are non-differentiable inputs and receive None.
        grad_input = None
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.clone()
        return grad_input, None, None, None
206
207
208
class ColorJitterLayer(nn.Module):
    """Randomly jitter brightness, contrast, saturation and hue of a batch.

    Each argument follows torchvision's ColorJitter convention: a single
    non-negative number x is expanded to the range [center - x, center + x]
    (center 1 for brightness/contrast/saturation, 0 for hue), or an explicit
    (min, max) pair may be given. A range degenerate at the center disables
    that particular jitter.
    """

    def __init__(self, brightness, contrast, saturation, hue):
        super(ColorJitterLayer, self).__init__()
        self.brightness = self._check_input(brightness, 'brightness')
        self.contrast = self._check_input(contrast, 'contrast')
        self.saturation = self._check_input(saturation, 'saturation')
        self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5), clip_first_on_zero=False)

    def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True):
        """Normalize a jitter spec into a [min, max] range, or None if disabled.

        Raises ValueError for a negative scalar or an out-of-bound pair, and
        TypeError for anything that is not a number or a 2-element sequence.
        """
        if isinstance(value, numbers.Number):
            if value < 0:
                raise ValueError("If {} is a single number, it must be non negative.".format(name))
            value = [center - value, center + value]
            if clip_first_on_zero:
                value[0] = max(value[0], 0)
        elif isinstance(value, (tuple, list)) and len(value) == 2:
            if not bound[0] <= value[0] <= value[1] <= bound[1]:
                raise ValueError("{} values should be between {}".format(name, bound))
        else:
            # Fixed typo in the user-facing message ("lenght" -> "length").
            raise TypeError("{} should be a single number or a list/tuple with length 2.".format(name))

        # if value is 0 or (1., 1.) for brightness/contrast/saturation
        # or (0., 0.) for hue, do nothing
        if value[0] == value[1] == center:
            value = None
        return value

    def adjust_contrast(self, x):
        """Blend each image with its per-channel spatial mean by a random factor."""
        if self.contrast:
            factor = x.new_empty(x.size(0), 1, 1, 1).uniform_(*self.contrast)
            means = torch.mean(x, dim=[2, 3], keepdim=True)
            x = (x - means) * factor + means
        return torch.clamp(x, 0, 1)

    def adjust_hsv(self, x):
        """Apply random hue shift plus saturation/brightness scaling in HSV space."""
        f_h = x.new_zeros(x.size(0), 1, 1)
        f_s = x.new_ones(x.size(0), 1, 1)
        f_v = x.new_ones(x.size(0), 1, 1)

        if self.hue:
            f_h.uniform_(*self.hue)
        if self.saturation:
            f_s = f_s.uniform_(*self.saturation)
        if self.brightness:
            f_v = f_v.uniform_(*self.brightness)

        return RandomHSVFunction.apply(x, f_h, f_s, f_v)

    def transform(self, inputs):
        """Apply contrast and HSV jitter in a randomly shuffled order."""
        # Shuffle transform order so neither adjustment dominates.
        if np.random.rand() > 0.5:
            transforms = [self.adjust_contrast, self.adjust_hsv]
        else:
            transforms = [self.adjust_hsv, self.adjust_contrast]

        for t in transforms:
            inputs = t(inputs)
        return inputs

    def forward(self, inputs):
        return self.transform(inputs)
269
270
271
class RandomColorGrayLayer(nn.Module):
    """Convert RGB batches to 3-channel grayscale with ITU-R BT.601 luma
    weights (0.299, 0.587, 0.114), implemented as a 1x1 convolution."""

    def __init__(self):
        super(RandomColorGrayLayer, self).__init__()
        weights = torch.tensor([[0.299, 0.587, 0.114]])
        self.register_buffer('_weight', weights.view(1, 3, 1, 1))

    def forward(self, inputs):
        # 1x1 conv produces the luma channel; replicate it to 3 channels.
        luma = F.conv2d(inputs, self._weight)
        return torch.cat([luma] * 3, dim=1)
281
282
283
class GaussianBlur(nn.Module):
    """Blur a batch with a Gaussian kernel whose sigma is sampled per call.

    Args:
        sigma_range: (low, high) range the blur sigma is drawn from uniformly.
    """

    def __init__(self, sigma_range):
        super(GaussianBlur, self).__init__()
        self.sigma_range = sigma_range

    def forward(self, inputs):
        height = inputs.size(2)

        # Odd kernel size of roughly one tenth of the image height.
        radius = (height // 10) // 2
        ksize = 2 * radius + 1

        sigma = np.random.uniform(*self.sigma_range)
        kernel = get_gaussian_kernel2d((ksize, ksize), (sigma, sigma)).unsqueeze(0)
        return filter2d(inputs, kernel, "reflect")
305
306
307
class CutOut(nn.Module):
    """Zero out one random length x length square per sample.

    The square is centered on a uniformly random pixel, so it may be clipped
    at the image border. Only odd lengths are accepted so the square has a
    well-defined center.
    """

    def __init__(self, length):
        super().__init__()
        if length % 2 == 0:
            raise ValueError("Currently CutOut only accepts odd lengths: length % 2 == 1")
        self.length = length

        # 1-d box filter used to dilate a one-hot center into a segment.
        self.register_buffer('_weight', torch.ones(1, 1, self.length))
        self._padding = (length - 1) // 2

    def forward(self, inputs):
        batch, _, height, width = inputs.shape
        device = inputs.device

        row_mask = inputs.new_zeros(batch, height)
        col_mask = inputs.new_zeros(batch, width)

        row_center = torch.randint(height, (batch, 1), device=device)
        col_center = torch.randint(width, (batch, 1), device=device)

        # One-hot center positions, shaped (N, 1, H) / (N, 1, W) for conv1d.
        row_mask.scatter_(1, row_center, 1).unsqueeze_(1)
        col_mask.scatter_(1, col_center, 1).unsqueeze_(1)

        # Box filtering turns each one-hot into a run of ones (the square's extent).
        row_mask = F.conv1d(row_mask, self._weight, padding=self._padding)
        col_mask = F.conv1d(col_mask, self._weight, padding=self._padding)

        # Outer product marks the square; invert it to get the keep-mask.
        keep = 1. - torch.einsum('bci,bcj->bcij', row_mask, col_mask)
        return inputs * keep
337
338
339
class SimclrAugment(nn.Module):
    """Differentiable SimCLR/BYOL augmentation pipelines.

    Args:
        aug_type: one of "simclr_basic", "simclr_hq", "simclr_hq_cutout",
            "byol". Any other value leaves the module without a pipeline.
    """

    def __init__(self, aug_type):
        super().__init__()
        if aug_type == "simclr_basic":
            # Bug fix: the original wrapped ColorJitterLayer inside itself
            # (ColorJitterLayer(ColorJitterLayer(0.4, 0.4, 0.4, 0.1))), which
            # raises TypeError in _check_input at construction time.
            self.pipeline = nn.Sequential(RandomResizeCropLayer(scale=(0.2, 1.0)), HorizontalFlipLayer(),
                                          RandomApply(ColorJitterLayer(0.4, 0.4, 0.4, 0.1), p=0.8),
                                          RandomApply(RandomColorGrayLayer(), p=0.2))
        elif aug_type == "simclr_hq":
            self.pipeline = nn.Sequential(RandomResizeCropLayer(scale=(0.2, 1.0)), HorizontalFlipLayer(),
                                          RandomApply(ColorJitterLayer(0.4, 0.4, 0.4, 0.1), p=0.8),
                                          RandomApply(RandomColorGrayLayer(), p=0.2), RandomApply(GaussianBlur((0.1, 2.0)), p=0.5))
        elif aug_type == "simclr_hq_cutout":
            self.pipeline = nn.Sequential(RandomResizeCropLayer(scale=(0.2, 1.0)), HorizontalFlipLayer(),
                                          RandomApply(ColorJitterLayer(0.4, 0.4, 0.4, 0.1), p=0.8),
                                          RandomApply(RandomColorGrayLayer(), p=0.2), RandomApply(GaussianBlur((0.1, 2.0)), p=0.5),
                                          RandomApply(CutOut(15), p=0.5))
        elif aug_type == "byol":
            self.pipeline = nn.Sequential(RandomResizeCropLayer(scale=(0.2, 1.0)), HorizontalFlipLayer(),
                                          RandomApply(ColorJitterLayer(0.4, 0.4, 0.2, 0.1), p=0.8),
                                          RandomApply(RandomColorGrayLayer(), p=0.2), RandomApply(GaussianBlur((0.1, 2.0)), p=0.5))

    def forward(self, images):
        return self.pipeline(images)
362
363