GitHub Repository: POSTECH-CVLab/PyTorch-StudioGAN
Path: blob/master/src/utils/style_ops/bias_act.py
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Custom PyTorch ops for efficient bias and activation."""

import os
import numpy as np
import torch
import utils.style_ops.dnnlib as dnnlib

from .. import custom_ops

#----------------------------------------------------------------------------

activation_funcs = {
    'linear':   dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
    'relu':     dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
    'lrelu':    dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
    'tanh':     dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True),
    'sigmoid':  dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True),
    'elu':      dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True),
    'selu':     dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True),
    'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True),
    'swish':    dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
}
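
#----------------------------------------------------------------------------
# Illustrative sketch (editor's addition, not part of the upstream NVIDIA file):
# each entry above bundles the plain PyTorch op (`func`), the defaults used when
# the caller passes `alpha=None` / `gain=None` (`def_alpha`, `def_gain`), the
# kernel selector for the CUDA path (`cuda_idx`), and, via `ref`, whether the
# backward pass needs the input ('x') or the output ('y'). The helper below is
# hypothetical and only shows how the 'lrelu' spec reproduces gain-scaled
# leaky ReLU.

def _example_lrelu_spec(x):
    spec = activation_funcs['lrelu']
    y = spec.func(x, alpha=spec.def_alpha)   # leaky_relu(x, 0.2)
    return y * spec.def_gain                 # scaled by sqrt(2)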

#----------------------------------------------------------------------------

_plugin = None
_null_tensor = torch.empty([0])

def _init():
    global _plugin
    if _plugin is None:
        _plugin = custom_ops.get_plugin(
            module_name='bias_act_plugin',
            sources=['bias_act.cpp', 'bias_act.cu'],
            headers=['bias_act.h'],
            source_dir=os.path.dirname(__file__),
            extra_cuda_cflags=['--use_fast_math'],
        )
    return True
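
# Note (editor's comment): reading off the call sites below, the compiled op is
# invoked as _plugin.bias_act(x, b, xref, yref, dy, grad, dim, act, alpha, gain,
# clamp), where grad=0 computes the forward pass, grad=1 the first-order
# gradient, and grad=2 the second-order gradient; unused tensor slots are filled
# with `_null_tensor`.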

#----------------------------------------------------------------------------

def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
    r"""Fused bias and activation function.

    Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
    and scales the result by `gain`. Each of the steps is optional. In most cases,
    the fused op is considerably more efficient than performing the same calculation
    using standard PyTorch ops. It supports first and second order gradients,
    but not third order gradients.

    Args:
        x:      Input activation tensor. Can be of any shape.
        b:      Bias vector, or `None` to disable. Must be a 1D tensor of the same type
                as `x`. The shape must be known, and it must match the dimension of `x`
                corresponding to `dim`.
        dim:    The dimension in `x` corresponding to the elements of `b`.
                The value of `dim` is ignored if `b` is not specified.
        act:    Name of the activation function to evaluate, or `"linear"` to disable.
                Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
                See `activation_funcs` for a full list. `None` is not allowed.
        alpha:  Shape parameter for the activation function, or `None` to use the default.
        gain:   Scaling factor for the output tensor, or `None` to use the default.
                See `activation_funcs` for the default scaling of each activation function.
                If unsure, consider specifying 1.
        clamp:  Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
                the clamping (default).
        impl:   Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).

    Returns:
        Tensor of the same shape and datatype as `x`.
    """
    assert isinstance(x, torch.Tensor)
    assert impl in ['ref', 'cuda']
    if impl == 'cuda' and x.device.type == 'cuda' and _init():
        return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
    return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
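
# Illustrative usage sketch (editor's addition, not part of the upstream file):
# a bias-plus-leaky-ReLU step as used in StyleGAN-style layers. `impl='ref'`
# runs the pure-PyTorch fallback and needs no compiled extension; on a CUDA
# tensor the default `impl='cuda'` JIT-compiles and dispatches to the fused
# kernel. The function below is hypothetical and only demonstrates the call.

def _example_bias_act_usage():
    x = torch.randn(8, 512)                  # batch of activations
    b = torch.zeros(512)                     # per-channel bias along dim=1
    y = bias_act(x, b, dim=1, act='lrelu', gain=np.sqrt(2), clamp=256, impl='ref')
    return y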

#----------------------------------------------------------------------------

def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
    """Slow reference implementation of `bias_act()` using standard PyTorch ops.
    """
    assert isinstance(x, torch.Tensor)
    assert clamp is None or clamp >= 0
    spec = activation_funcs[act]
    alpha = float(alpha if alpha is not None else spec.def_alpha)
    gain = float(gain if gain is not None else spec.def_gain)
    clamp = float(clamp if clamp is not None else -1)

    # Add bias.
    if b is not None:
        assert isinstance(b, torch.Tensor) and b.ndim == 1
        assert 0 <= dim < x.ndim
        assert b.shape[0] == x.shape[dim]
        x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])

    # Evaluate activation function.
    alpha = float(alpha)
    x = spec.func(x, alpha=alpha)

    # Scale by gain.
    gain = float(gain)
    if gain != 1:
        x = x * gain

    # Clamp.
    if clamp >= 0:
        x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
    return x
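
# Equivalence sketch (editor's addition): for act='lrelu' and dim=1, the reference
# path above reduces to plain PyTorch ops, i.e. bias broadcast along `dim`,
# leaky ReLU, gain scaling, and optional clamping. The helper is hypothetical and
# only spells out that composition.

def _example_ref_equivalence(x, b, gain=np.sqrt(2), clamp=256):
    y = x + b.reshape([-1 if i == 1 else 1 for i in range(x.ndim)])  # bias along dim=1
    y = torch.nn.functional.leaky_relu(y, 0.2)                       # act='lrelu', alpha=0.2
    y = y * gain                                                     # scale by gain
    return y.clamp(-clamp, clamp)                                    # clamp to [-clamp, +clamp]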

#----------------------------------------------------------------------------

_bias_act_cuda_cache = dict()

def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
    """Fast CUDA implementation of `bias_act()` using custom ops.
    """
    # Parse arguments.
    assert clamp is None or clamp >= 0
    spec = activation_funcs[act]
    alpha = float(alpha if alpha is not None else spec.def_alpha)
    gain = float(gain if gain is not None else spec.def_gain)
    clamp = float(clamp if clamp is not None else -1)

    # Lookup from cache.
    key = (dim, act, alpha, gain, clamp)
    if key in _bias_act_cuda_cache:
        return _bias_act_cuda_cache[key]

    # Forward op.
    class BiasActCuda(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, b): # pylint: disable=arguments-differ
            ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride(1) == 1 else torch.contiguous_format
            x = x.contiguous(memory_format=ctx.memory_format)
            b = b.contiguous() if b is not None else _null_tensor
            y = x
            if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
                y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)
            ctx.save_for_backward(
                x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
                b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
                y if 'y' in spec.ref else _null_tensor)
            return y

        @staticmethod
        def backward(ctx, dy): # pylint: disable=arguments-differ
            dy = dy.contiguous(memory_format=ctx.memory_format)
            x, b, y = ctx.saved_tensors
            dx = None
            db = None

            if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
                dx = dy
                if act != 'linear' or gain != 1 or clamp >= 0:
                    dx = BiasActCudaGrad.apply(dy, x, b, y)

            if ctx.needs_input_grad[1]:
                db = dx.sum([i for i in range(dx.ndim) if i != dim])

            return dx, db

    # Backward op.
    class BiasActCudaGrad(torch.autograd.Function):
        @staticmethod
        def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
            ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride(1) == 1 else torch.contiguous_format
            dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp)
            ctx.save_for_backward(
                dy if spec.has_2nd_grad else _null_tensor,
                x, b, y)
            return dx

        @staticmethod
        def backward(ctx, d_dx): # pylint: disable=arguments-differ
            d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
            dy, x, b, y = ctx.saved_tensors
            d_dy = None
            d_x = None
            d_b = None
            d_y = None

            if ctx.needs_input_grad[0]:
                d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)

            if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]):
                d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp)

            if spec.has_2nd_grad and ctx.needs_input_grad[2]:
                d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])

            return d_dy, d_x, d_b, d_y

    # Add to cache.
    _bias_act_cuda_cache[key] = BiasActCuda
    return BiasActCuda
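
# Caching sketch (editor's addition): each distinct (dim, act, alpha, gain, clamp)
# configuration gets one specialized autograd Function class, which `bias_act()`
# then invokes as `op.apply(x, b)`. The helper below is hypothetical and only
# illustrates that repeated configurations are served from the cache.

def _example_cuda_op_cache():
    op_a = _bias_act_cuda(dim=1, act='lrelu')
    op_b = _bias_act_cuda(dim=1, act='lrelu')
    assert op_a is op_b                      # same class, served from _bias_act_cuda_cache
    return op_a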

#----------------------------------------------------------------------------