Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
POSTECH-CVLab
GitHub Repository: POSTECH-CVLab/PyTorch-StudioGAN
Path: blob/master/src/utils/cr.py
809 views
1
# PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN
2
# The MIT License (MIT)
3
# See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details
4
5
# src/utils/cr.py
6
7
import random
8
9
import torch
10
import torch.nn.functional as F
11
12
13
def apply_cr_aug(x, flip=True, translation=True):
    """Apply random flip and/or random translation augmentation to a batch.

    Args:
        x: image batch tensor, forwarded to ``random_flip`` /
            ``random_translation`` (both index it as ``(n, c, h, w)``).
        flip: if True, call ``random_flip(x, 0.5)`` so each sample is
            flipped along dim 3 with probability 0.5.
        translation: if True, call ``random_translation(x, 1 / 8)`` so each
            sample is shifted by up to 1/8 of each spatial size.

    Returns:
        The input object ``x`` itself when both augmentations are disabled;
        otherwise a contiguous augmented tensor.
    """
    if not (flip or translation):
        # Nothing to apply: hand back the very same object.
        return x

    augmented = random_flip(x, 0.5) if flip else x
    if translation:
        augmented = random_translation(augmented, 1 / 8)
    # random_translation returns a permuted view; normalise the layout.
    return augmented.contiguous()
21
22
23
def random_flip(x, p):
    """Flip each sample of a batch along dim 3, independently with probability ``p``.

    Args:
        x: 4-D tensor indexed as ``(n, c, h, w)``; dim 3 is flipped
            (presumably image width — confirm with callers).
        p: per-sample flip probability. Samples whose uniform [0, 1) draw is
            below ``p`` are flipped, so ``p <= 0`` never flips and ``p >= 1``
            always flips.

    Returns:
        A new tensor with the same shape and dtype as ``x``; ``x`` is not
        modified.
    """
    n = x.shape[0]
    # Draw the coin flips exactly as the legacy ``torch.FloatTensor(n, 1)
    # .uniform_(0.0, 1.0)`` did (float32, CPU generator, then moved to x's
    # device) so seeded runs keep producing the same augmentation sequence.
    flip_prob = torch.empty(n, 1, dtype=torch.float32, device="cpu").uniform_(0.0, 1.0)
    flip_mask = (flip_prob.to(x.device) < p).view(n, 1, 1, 1)
    # Broadcasting the per-sample mask over (c, h, w) avoids materialising a
    # full-size boolean mask and the masked gather/scatter round trip.
    return torch.where(flip_mask, torch.flip(x, [3]), x)
31
32
33
def random_translation(x, ratio):
    """Randomly shift each sample of a batch, filling exposed borders by reflection.

    Each sample gets its own integer offset, drawn uniformly from
    ``[-int(h * ratio), int(h * ratio)]`` along dim 2 and
    ``[-int(w * ratio), int(w * ratio)]`` along dim 3. The batch is
    reflect-padded by the maximum offset, and a window of the original size
    is cropped at the shifted position, so the output shape equals the input
    shape.

    Args:
        x: 4-D tensor indexed as ``(n, c, h, w)``.
        ratio: maximum shift as a fraction of each spatial size. ``F.pad`` in
            ``'reflect'`` mode needs each pad to be smaller than the padded
            dimension, so ``int(size * ratio)`` must stay below ``size``.

    Returns:
        Tensor with the same shape as ``x``. It is a permuted view of a freshly
        gathered tensor, so it is generally not contiguous.
    """
    n, h, w = x.shape[0], x.shape[2], x.shape[3]
    max_t_x, max_t_y = int(h * ratio), int(w * ratio)
    # Per-sample offsets shaped (n, 1, 1) so they broadcast over (h, w).
    t_x = torch.randint(-max_t_x, max_t_x + 1, size=[n, 1, 1], device=x.device)
    t_y = torch.randint(-max_t_y, max_t_y + 1, size=[n, 1, 1], device=x.device)

    # Broadcastable index tensors replace torch.meshgrid: advanced indexing
    # broadcasts (n,1,1), (n,h,1) and (n,1,w) to (n,h,w). This avoids
    # meshgrid's missing-``indexing`` warning and the three full-size grids.
    # "+ max_t" converts a shifted coordinate into padded-tensor coordinates.
    idx_batch = torch.arange(n, dtype=torch.long, device=x.device).view(n, 1, 1)
    idx_x = torch.arange(h, dtype=torch.long, device=x.device).view(1, h, 1) + t_x + max_t_x
    idx_y = torch.arange(w, dtype=torch.long, device=x.device).view(1, 1, w) + t_y + max_t_y

    # F.pad lists pads from the last dim backwards: (left, right) for dim 3,
    # then (top, bottom) for dim 2.
    x_pad = F.pad(input=x, pad=[max_t_y, max_t_y, max_t_x, max_t_x], mode='reflect')
    # Gather in channels-last layout so each (batch, row, col) index selects a
    # whole channel vector; the result is (n, h, w, c), then back to (n, c, h, w).
    gathered = x_pad.permute(0, 2, 3, 1).contiguous()[idx_batch, idx_x, idx_y]
    return gathered.permute(0, 3, 1, 2)
49
50