GitHub Repository: POSTECH-CVLab/PyTorch-StudioGAN
Path: blob/master/src/utils/style_ops/conv2d_gradfix.py
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Custom replacement for `torch.nn.functional.conv2d` that supports
arbitrarily high order gradients with zero performance penalty."""

import contextlib
import torch
from pkg_resources import parse_version

# pylint: disable=redefined-builtin
# pylint: disable=arguments-differ
# pylint: disable=protected-access

#----------------------------------------------------------------------------

enabled = False                     # Enable the custom op by setting this to true.
weight_gradients_disabled = False   # Forcefully disable computation of gradients with respect to the weights.
_use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11

@contextlib.contextmanager
def no_weight_gradients(disable=True):
    global weight_gradients_disabled
    old = weight_gradients_disabled
    if disable:
        weight_gradients_disabled = True
    yield
    weight_gradients_disabled = old

#----------------------------------------------------------------------------

def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    if _should_use_custom_op(input):
        return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias)
    return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)

def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
    if _should_use_custom_op(input):
        return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias)
    return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)

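# Illustrative usage sketch (added for clarity, not part of the original file;
# the tensor names `x`, `w`, `b` and the import path are placeholders/assumptions):
#
#   from utils.style_ops import conv2d_gradfix
#   conv2d_gradfix.enabled = True                 # opt in to the custom op
#   y = conv2d_gradfix.conv2d(x, w, bias=b, stride=1, padding=1)
#
#   # Skip gradient-w.r.t.-weight computation, e.g. during a regularization pass:
#   with conv2d_gradfix.no_weight_gradients():
#       (grad_x,) = torch.autograd.grad(y.sum(), [x], create_graph=True)
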
#----------------------------------------------------------------------------

def _should_use_custom_op(input):
    assert isinstance(input, torch.Tensor)
    if (not enabled) or (not torch.backends.cudnn.enabled):
        return False
    if _use_pytorch_1_11_api:
        # The work-around code doesn't work on PyTorch 1.11.0 onwards.
        return False
    if input.device.type != 'cuda':
        return False
    return True

def _tuple_of_ints(xs, ndim):
    xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
    assert len(xs) == ndim
    assert all(isinstance(x, int) for x in xs)
    return xs

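# For illustration (not in the original file): `_tuple_of_ints` normalizes a
# per-dimension argument so that, e.g., _tuple_of_ints(3, 2) -> (3, 3) and
# _tuple_of_ints([1, 2], 2) -> (1, 2); non-integer entries fail the assert.
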
#----------------------------------------------------------------------------

_conv2d_gradfix_cache = dict()
_null_tensor = torch.empty([0])

def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
    # Parse arguments.
    ndim = 2
    weight_shape = tuple(weight_shape)
    stride = _tuple_of_ints(stride, ndim)
    padding = _tuple_of_ints(padding, ndim)
    output_padding = _tuple_of_ints(output_padding, ndim)
    dilation = _tuple_of_ints(dilation, ndim)

    # Lookup from cache.
    key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
    if key in _conv2d_gradfix_cache:
        return _conv2d_gradfix_cache[key]

    # Validate arguments.
    assert groups >= 1
    assert len(weight_shape) == ndim + 2
    assert all(stride[i] >= 1 for i in range(ndim))
    assert all(padding[i] >= 0 for i in range(ndim))
    assert all(dilation[i] >= 0 for i in range(ndim))
    if not transpose:
        assert all(output_padding[i] == 0 for i in range(ndim))
    else: # transpose
        assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))

    # Helpers.
    common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
    def calc_output_padding(input_shape, output_shape):
        if transpose:
            return [0, 0]
        return [
            input_shape[i + 2]
            - (output_shape[i + 2] - 1) * stride[i]
            - (1 - 2 * padding[i])
            - dilation[i] * (weight_shape[i + 2] - 1)
            for i in range(ndim)
        ]

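    # Note added for clarity (not in the original file): in the non-transposed
    # case, grad_input is computed as a transposed convolution of grad_output,
    # whose output size is (n - 1)*stride - 2*padding + dilation*(kernel - 1)
    # + output_padding + 1. Solving that for output_padding yields the
    # expression above. Example: input 8, kernel 3, stride 2, padding 1,
    # dilation 1 gives a conv output of 4, and
    # 8 - (4 - 1)*2 - (1 - 2*1) - 1*(3 - 1) = 1, so output_padding = 1 restores
    # the original size of 8.
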
    # Forward & backward.
    class Conv2d(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input, weight, bias):
            assert weight.shape == weight_shape
            ctx.save_for_backward(
                input if weight.requires_grad else _null_tensor,
                weight if input.requires_grad else _null_tensor,
            )
            ctx.input_shape = input.shape

            # Simple 1x1 convolution => cuBLAS (only on Volta, not on Ampere).
            if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0) and torch.cuda.get_device_capability(input.device) < (8, 0):
                a = weight.reshape(groups, weight_shape[0] // groups, weight_shape[1])
                b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1)
                c = (a.transpose(1, 2) if transpose else a) @ b.permute(1, 2, 0, 3).flatten(2)
                c = c.reshape(-1, input.shape[0], *input.shape[2:]).transpose(0, 1)
                c = c if bias is None else c + bias.unsqueeze(0).unsqueeze(2).unsqueeze(3)
                return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format))

            # General case => cuDNN.
            if transpose:
                return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)
            return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)

        @staticmethod
        def backward(ctx, grad_output):
            input, weight = ctx.saved_tensors
            input_shape = ctx.input_shape
            grad_input = None
            grad_weight = None
            grad_bias = None

            if ctx.needs_input_grad[0]:
                p = calc_output_padding(input_shape=input_shape, output_shape=grad_output.shape)
                op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs)
                grad_input = op.apply(grad_output, weight, None)
                assert grad_input.shape == input_shape

            if ctx.needs_input_grad[1] and not weight_gradients_disabled:
                grad_weight = Conv2dGradWeight.apply(grad_output, input)
                assert grad_weight.shape == weight_shape

            if ctx.needs_input_grad[2]:
                grad_bias = grad_output.sum([0, 2, 3])

            return grad_input, grad_weight, grad_bias

    # Gradient with respect to the weights.
    class Conv2dGradWeight(torch.autograd.Function):
        @staticmethod
        def forward(ctx, grad_output, input):
            ctx.save_for_backward(
                grad_output if input.requires_grad else _null_tensor,
                input if grad_output.requires_grad else _null_tensor,
            )
            ctx.grad_output_shape = grad_output.shape
            ctx.input_shape = input.shape

            # Simple 1x1 convolution => cuBLAS (on both Volta and Ampere).
            if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0):
                a = grad_output.reshape(grad_output.shape[0], groups, grad_output.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2)
                b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2)
                c = (b @ a.transpose(1, 2) if transpose else a @ b.transpose(1, 2)).reshape(weight_shape)
                return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format))

            # General case => cuDNN.
            name = 'aten::cudnn_convolution_transpose_backward_weight' if transpose else 'aten::cudnn_convolution_backward_weight'
            flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32]
            return torch._C._jit_get_operation(name)(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags)

        @staticmethod
        def backward(ctx, grad2_grad_weight):
            grad_output, input = ctx.saved_tensors
            grad_output_shape = ctx.grad_output_shape
            input_shape = ctx.input_shape
            grad2_grad_output = None
            grad2_input = None

            if ctx.needs_input_grad[0]:
                grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)
                assert grad2_grad_output.shape == grad_output_shape

            if ctx.needs_input_grad[1]:
                p = calc_output_padding(input_shape=input_shape, output_shape=grad_output_shape)
                op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs)
                grad2_input = op.apply(grad_output, grad2_grad_weight, None)
                assert grad2_input.shape == input_shape

            return grad2_grad_output, grad2_input

    _conv2d_gradfix_cache[key] = Conv2d
    return Conv2d

#----------------------------------------------------------------------------

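# Second-order usage sketch (added for clarity, not part of the original file;
# assumes a CUDA device and a PyTorch version for which _should_use_custom_op
# returns True). This is the pattern the op exists for, e.g. gradient-penalty
# terms that differentiate through a gradient:
#
#   conv2d_gradfix.enabled = True
#   x = torch.randn(4, 3, 16, 16, device='cuda', requires_grad=True)
#   w = torch.randn(8, 3, 3, 3, device='cuda', requires_grad=True)
#   y = conv2d_gradfix.conv2d(x, w, padding=1)
#   (grad_x,) = torch.autograd.grad(y.sum(), [x], create_graph=True)
#   penalty = grad_x.square().sum()
#   penalty.backward()   # second-order backprop through grad_x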