Path: blob/master/src/utils/style_ops/conv2d_gradfix.py
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Custom replacement for `torch.nn.functional.conv2d` that supports
arbitrarily high order gradients with zero performance penalty."""

import contextlib
import torch
from pkg_resources import parse_version

# pylint: disable=redefined-builtin
# pylint: disable=arguments-differ
# pylint: disable=protected-access

#----------------------------------------------------------------------------

enabled = False                    # Enable the custom op by setting this to true.
weight_gradients_disabled = False  # Forcefully disable computation of gradients with respect to the weights.
_use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a')  # Allow prerelease builds of 1.11

@contextlib.contextmanager
def no_weight_gradients(disable=True):
    global weight_gradients_disabled
    old = weight_gradients_disabled
    if disable:
        weight_gradients_disabled = True
    yield
    weight_gradients_disabled = old

#----------------------------------------------------------------------------

def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    if _should_use_custom_op(input):
        return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias)
    return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)

def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
    if _should_use_custom_op(input):
        return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias)
    return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)

#----------------------------------------------------------------------------
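# Illustrative usage note (a sketch, not part of the original NVIDIA source):
# callers typically set `conv2d_gradfix.enabled = True` once at startup and then
# route their convolutions through the wrappers above, e.g.
#
#     y = conv2d_gradfix.conv2d(x, w, padding=1)
#
# so that losses involving gradients (such as gradient-penalty terms) can be
# differentiated again through the backward pass. When the custom op is not
# applicable, the wrappers silently fall back to torch.nn.functional.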
def _should_use_custom_op(input):
    assert isinstance(input, torch.Tensor)
    if (not enabled) or (not torch.backends.cudnn.enabled):
        return False
    if _use_pytorch_1_11_api:
        # The work-around code doesn't work on PyTorch 1.11.0 onwards.
        return False
    if input.device.type != 'cuda':
        return False
    return True

def _tuple_of_ints(xs, ndim):
    xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
    assert len(xs) == ndim
    assert all(isinstance(x, int) for x in xs)
    return xs

#----------------------------------------------------------------------------

_conv2d_gradfix_cache = dict()
_null_tensor = torch.empty([0])

def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
    # Parse arguments.
    ndim = 2
    weight_shape = tuple(weight_shape)
    stride = _tuple_of_ints(stride, ndim)
    padding = _tuple_of_ints(padding, ndim)
    output_padding = _tuple_of_ints(output_padding, ndim)
    dilation = _tuple_of_ints(dilation, ndim)

    # Lookup from cache.
    key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
    if key in _conv2d_gradfix_cache:
        return _conv2d_gradfix_cache[key]

    # Validate arguments.
    assert groups >= 1
    assert len(weight_shape) == ndim + 2
    assert all(stride[i] >= 1 for i in range(ndim))
    assert all(padding[i] >= 0 for i in range(ndim))
    assert all(dilation[i] >= 0 for i in range(ndim))
    if not transpose:
        assert all(output_padding[i] == 0 for i in range(ndim))
    else: # transpose
        assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))

    # Helpers.
    common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
    def calc_output_padding(input_shape, output_shape):
        if transpose:
            return [0, 0]
        return [
            input_shape[i + 2]
            - (output_shape[i + 2] - 1) * stride[i]
            - (1 - 2 * padding[i])
            - dilation[i] * (weight_shape[i + 2] - 1)
            for i in range(ndim)
        ]

    # Forward & backward.
    class Conv2d(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input, weight, bias):
            assert weight.shape == weight_shape
            ctx.save_for_backward(
                input if weight.requires_grad else _null_tensor,
                weight if input.requires_grad else _null_tensor,
            )
            ctx.input_shape = input.shape

            # Simple 1x1 convolution => cuBLAS (only on Volta, not on Ampere).
            if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0) and torch.cuda.get_device_capability(input.device) < (8, 0):
                a = weight.reshape(groups, weight_shape[0] // groups, weight_shape[1])
                b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1)
                c = (a.transpose(1, 2) if transpose else a) @ b.permute(1, 2, 0, 3).flatten(2)
                c = c.reshape(-1, input.shape[0], *input.shape[2:]).transpose(0, 1)
                c = c if bias is None else c + bias.unsqueeze(0).unsqueeze(2).unsqueeze(3)
                return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format))

            # General case => cuDNN.
            if transpose:
                return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)
            return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)

        @staticmethod
        def backward(ctx, grad_output):
            input, weight = ctx.saved_tensors
            input_shape = ctx.input_shape
            grad_input = None
            grad_weight = None
            grad_bias = None

            if ctx.needs_input_grad[0]:
                p = calc_output_padding(input_shape=input_shape, output_shape=grad_output.shape)
                op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs)
                grad_input = op.apply(grad_output, weight, None)
                assert grad_input.shape == input_shape

            if ctx.needs_input_grad[1] and not weight_gradients_disabled:
                grad_weight = Conv2dGradWeight.apply(grad_output, input)
                assert grad_weight.shape == weight_shape

            if ctx.needs_input_grad[2]:
                grad_bias = grad_output.sum([0, 2, 3])

            return grad_input, grad_weight, grad_bias
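    # The weight gradient below is implemented as a separate autograd.Function so
    # that its own backward pass can again be expressed through Conv2d and the
    # transposed _conv2d_gradfix op; this recursion is what allows arbitrarily
    # high order gradients.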
    # Gradient with respect to the weights.
    class Conv2dGradWeight(torch.autograd.Function):
        @staticmethod
        def forward(ctx, grad_output, input):
            ctx.save_for_backward(
                grad_output if input.requires_grad else _null_tensor,
                input if grad_output.requires_grad else _null_tensor,
            )
            ctx.grad_output_shape = grad_output.shape
            ctx.input_shape = input.shape

            # Simple 1x1 convolution => cuBLAS (on both Volta and Ampere).
            if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0):
                a = grad_output.reshape(grad_output.shape[0], groups, grad_output.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2)
                b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2)
                c = (b @ a.transpose(1, 2) if transpose else a @ b.transpose(1, 2)).reshape(weight_shape)
                return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format))

            # General case => cuDNN.
            name = 'aten::cudnn_convolution_transpose_backward_weight' if transpose else 'aten::cudnn_convolution_backward_weight'
            flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32]
            return torch._C._jit_get_operation(name)(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags)

        @staticmethod
        def backward(ctx, grad2_grad_weight):
            grad_output, input = ctx.saved_tensors
            grad_output_shape = ctx.grad_output_shape
            input_shape = ctx.input_shape
            grad2_grad_output = None
            grad2_input = None

            if ctx.needs_input_grad[0]:
                grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)
                assert grad2_grad_output.shape == grad_output_shape

            if ctx.needs_input_grad[1]:
                p = calc_output_padding(input_shape=input_shape, output_shape=grad_output_shape)
                op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs)
                grad2_input = op.apply(grad_output, grad2_grad_weight, None)
                assert grad2_input.shape == input_shape

            return grad2_grad_output, grad2_input

    _conv2d_gradfix_cache[key] = Conv2d
    return Conv2d

#----------------------------------------------------------------------------
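#----------------------------------------------------------------------------
# Minimal usage sketch (illustrative addition, not part of the original NVIDIA
# file). It assumes a CUDA device and a PyTorch build older than 1.11 for the
# custom op to take effect; on other setups the wrappers fall back to
# torch.nn.functional, so the demo still runs without the work-around.

if __name__ == '__main__':
    enabled = True  # Opt in to the custom op (module-level flag defined above).

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    x = torch.randn(4, 3, 32, 32, device=device, requires_grad=True)
    w = torch.randn(8, 3, 3, 3, device=device, requires_grad=True)

    # Forward pass through the wrapper and a first-order gradient kept in the
    # graph, so that a gradient-penalty style term can be differentiated again.
    y = conv2d(x, w, padding=1)
    loss = (y ** 2).sum()
    (grad_x,) = torch.autograd.grad(loss, x, create_graph=True)
    penalty = (grad_x ** 2).sum()
    penalty.backward()  # Second-order backward through the convolution.
    print('grad shapes:', tuple(x.grad.shape), tuple(w.grad.shape))

    # no_weight_gradients() skips the weight-gradient branch of the custom op's
    # backward pass (it has no effect on the plain PyTorch fallback).
    y2 = conv2d(x, w, padding=1)
    with no_weight_gradients():
        (grad_x2,) = torch.autograd.grad(y2.sum(), x)
    print('input gradient only:', tuple(grad_x2.shape))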