Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
POSTECH-CVLab
GitHub Repository: POSTECH-CVLab/PyTorch-StudioGAN
Path: blob/master/src/utils/style_ops/grid_sample_gradfix.py
809 views
1
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
#
3
# NVIDIA CORPORATION and its licensors retain all intellectual property
4
# and proprietary rights in and to this software, related documentation
5
# and any modifications thereto. Any use, reproduction, disclosure or
6
# distribution of this software and related documentation without an express
7
# license agreement from NVIDIA CORPORATION is strictly prohibited.
8
9
"""Custom replacement for `torch.nn.functional.grid_sample` that
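supports arbitrarily high order gradients between the input and output.
Gradients with respect to `grid` are only supported to first order.
The custom op is used only when the module-level `enabled` flag is True.
Only works on 2D images (4-D `input` and `grid` tensors) and assumes
`mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`."""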
13
14
import torch
15
from pkg_resources import parse_version
16
17
# pylint: disable=redefined-builtin
18
# pylint: disable=arguments-differ
19
# pylint: disable=protected-access
20
21
#----------------------------------------------------------------------------
22
23
# Master switch for the custom op. It is off by default; set it to True to make
# grid_sample() route through _GridSample2dForward instead of the stock op.
enabled = False

# Chooses between the two call forms of aten::grid_sampler_2d_backward used in
# _GridSample2dBackward.forward (the newer form takes an extra `output_mask`).
# The trailing 'a' in the threshold lets prerelease builds of 1.11 qualify.
_use_pytorch_1_11_api = (
    parse_version(torch.__version__)
    >= parse_version('1.11.0a')
)
25
26
#----------------------------------------------------------------------------
27
28
def grid_sample(input, grid):
    """Bilinearly sample `input` at the locations given by `grid`.

    Drop-in for `torch.nn.functional.grid_sample` with `mode='bilinear'`,
    `padding_mode='zeros'` and `align_corners=False` fixed. When the custom op
    is enabled, the call is routed through `_GridSample2dForward`, which
    supports higher-order gradients. Otherwise the stock PyTorch op is used.

    Args:
        input: Tensor to sample from.
        grid: Sampling locations, in the layout expected by
            `torch.nn.functional.grid_sample`.

    Returns:
        The sampled tensor.
    """
    if not _should_use_custom_op():
        return torch.nn.functional.grid_sample(
            input=input,
            grid=grid,
            mode='bilinear',
            padding_mode='zeros',
            align_corners=False,
        )
    return _GridSample2dForward.apply(input, grid)
32
33
#----------------------------------------------------------------------------
34
35
def _should_use_custom_op():
    """Return the current value of the module-level `enabled` switch.

    grid_sample() consults this on every call, so changes to `enabled` take
    effect immediately.
    """
    use_custom = enabled
    return use_custom
37
38
#----------------------------------------------------------------------------
39
40
class _GridSample2dForward(torch.autograd.Function):
    """Autograd wrapper around the fixed-mode 2D grid_sample.

    The forward pass is the stock op with bilinear interpolation, zeros padding
    and align_corners=False. The backward pass is delegated to
    `_GridSample2dBackward`, so the gradient computation can itself be
    differentiated.
    """

    @staticmethod
    def forward(ctx, input, grid):
        # Only the 4-D (2D image) variant is supported: the backward op used by
        # _GridSample2dBackward is aten::grid_sampler_2d_backward.
        assert input.ndim == 4 and grid.ndim == 4
        # Both tensors are needed again to evaluate the gradient.
        ctx.save_for_backward(input, grid)
        return torch.nn.functional.grid_sample(
            input=input,
            grid=grid,
            mode='bilinear',
            padding_mode='zeros',
            align_corners=False,
        )

    @staticmethod
    def backward(ctx, grad_output):
        saved_input, saved_grid = ctx.saved_tensors
        d_input, d_grid = _GridSample2dBackward.apply(grad_output, saved_input, saved_grid)
        return d_input, d_grid
54
55
#----------------------------------------------------------------------------
56
57
class _GridSample2dBackward(torch.autograd.Function):
    """Autograd function computing the gradients of the fixed-mode 2D grid_sample.

    forward() evaluates aten::grid_sampler_2d_backward directly and returns
    `(grad_input, grad_grid)`.

    backward() differentiates only along the `grad_output` path. The gradient
    with respect to `grad_output` is obtained by applying the forward sampling
    (`_GridSample2dForward`) to the incoming `grad2_grad_input`, which keeps
    further orders available. The incoming gradient for `grad_grid` is ignored.
    Second-order gradients with respect to `grid` are not implemented and raise
    NotImplementedError.
    """

    @staticmethod
    def forward(ctx, grad_output, input, grid):
        op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
        # Depending on the PyTorch version, _jit_get_operation returns either the
        # operator itself or a (operator, overload_names) tuple. Unwrap by type
        # instead of tying it to the signature switch below, so both call paths
        # work regardless of which form this build returns.
        if isinstance(op, tuple):
            op = op[0]
        # The literal arguments 0, 0, False select bilinear interpolation, zeros
        # padding and align_corners=False, matching _GridSample2dForward.
        if _use_pytorch_1_11_api:
            # Newer signature: ask only for the gradients autograd actually needs
            # (index 1 is `input`, index 2 is `grid` in this forward's arguments).
            # NOTE(review): a masked-out entry presumably comes back as None.
            output_mask = (ctx.needs_input_grad[1], ctx.needs_input_grad[2])
            grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False, output_mask)
        else:
            grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
        # Only `grid` is needed to differentiate grad_input w.r.t. grad_output.
        ctx.save_for_backward(grid)
        return grad_input, grad_grid

    @staticmethod
    def backward(ctx, grad2_grad_input, grad2_grad_grid):
        _ = grad2_grad_grid  # unused: the grad_grid output is not differentiated
        # Fail fast with an explicit error: an `assert` would be stripped under
        # `python -O` and a None gradient for `grid` would be returned silently.
        if ctx.needs_input_grad[2]:
            raise NotImplementedError(
                'grid_sample_gradfix: second-order gradients with respect to `grid` are not supported.'
            )
        grid, = ctx.saved_tensors
        grad2_grad_output = None
        grad2_input = None
        grad2_grid = None

        if ctx.needs_input_grad[0]:
            # Routed through _GridSample2dForward so this result is itself differentiable.
            grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)

        return grad2_grad_output, grad2_input, grad2_grid
82
83
#----------------------------------------------------------------------------
84
85