# src/utils/style_ops/grid_sample_gradfix.py
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.1#2# NVIDIA CORPORATION and its licensors retain all intellectual property3# and proprietary rights in and to this software, related documentation4# and any modifications thereto. Any use, reproduction, disclosure or5# distribution of this software and related documentation without an express6# license agreement from NVIDIA CORPORATION is strictly prohibited.78"""Custom replacement for `torch.nn.functional.grid_sample` that9supports arbitrarily high order gradients between the input and output.10Only works on 2D images and assumes11`mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`."""1213import torch14from pkg_resources import parse_version1516# pylint: disable=redefined-builtin17# pylint: disable=arguments-differ18# pylint: disable=protected-access1920#----------------------------------------------------------------------------2122enabled = False # Enable the custom op by setting this to true.23_use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.112425#----------------------------------------------------------------------------2627def grid_sample(input, grid):28if _should_use_custom_op():29return _GridSample2dForward.apply(input, grid)30return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)3132#----------------------------------------------------------------------------3334def _should_use_custom_op():35return enabled3637#----------------------------------------------------------------------------3839class _GridSample2dForward(torch.autograd.Function):40@staticmethod41def forward(ctx, input, grid):42assert input.ndim == 443assert grid.ndim == 444output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)45ctx.save_for_backward(input, grid)46return output4748@staticmethod49def 
backward(ctx, grad_output):50input, grid = ctx.saved_tensors51grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)52return grad_input, grad_grid5354#----------------------------------------------------------------------------5556class _GridSample2dBackward(torch.autograd.Function):57@staticmethod58def forward(ctx, grad_output, input, grid):59op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')60if _use_pytorch_1_11_api:61output_mask = (ctx.needs_input_grad[1], ctx.needs_input_grad[2])62grad_input, grad_grid = op[0](grad_output, input, grid, 0, 0, False, output_mask)63else:64grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)65ctx.save_for_backward(grid)66return grad_input, grad_grid6768@staticmethod69def backward(ctx, grad2_grad_input, grad2_grad_grid):70_ = grad2_grad_grid # unused71grid, = ctx.saved_tensors72grad2_grad_output = None73grad2_input = None74grad2_grid = None7576if ctx.needs_input_grad[0]:77grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)7879assert not ctx.needs_input_grad[2]80return grad2_grad_output, grad2_input, grad2_grid8182#----------------------------------------------------------------------------838485