Path: blob/master/src/utils/style_ops/bias_act.py
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Custom PyTorch ops for efficient bias and activation."""

import os
import numpy as np
import torch
import utils.style_ops.dnnlib as dnnlib

from .. import custom_ops

#----------------------------------------------------------------------------

activation_funcs = {
    'linear':   dnnlib.EasyDict(func=lambda x, **_:        x,                                         def_alpha=0,   def_gain=1,          cuda_idx=1, ref='',  has_2nd_grad=False),
    'relu':     dnnlib.EasyDict(func=lambda x, **_:        torch.nn.functional.relu(x),               def_alpha=0,   def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
    'lrelu':    dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha),  def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
    'tanh':     dnnlib.EasyDict(func=lambda x, **_:        torch.tanh(x),                             def_alpha=0,   def_gain=1,          cuda_idx=4, ref='y', has_2nd_grad=True),
    'sigmoid':  dnnlib.EasyDict(func=lambda x, **_:        torch.sigmoid(x),                          def_alpha=0,   def_gain=1,          cuda_idx=5, ref='y', has_2nd_grad=True),
    'elu':      dnnlib.EasyDict(func=lambda x, **_:        torch.nn.functional.elu(x),                def_alpha=0,   def_gain=1,          cuda_idx=6, ref='y', has_2nd_grad=True),
    'selu':     dnnlib.EasyDict(func=lambda x, **_:        torch.nn.functional.selu(x),               def_alpha=0,   def_gain=1,          cuda_idx=7, ref='y', has_2nd_grad=True),
    'softplus': dnnlib.EasyDict(func=lambda x, **_:        torch.nn.functional.softplus(x),           def_alpha=0,   def_gain=1,          cuda_idx=8, ref='y', has_2nd_grad=True),
    'swish':    dnnlib.EasyDict(func=lambda x, **_:        torch.sigmoid(x) * x,                      def_alpha=0,   def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
}

#----------------------------------------------------------------------------

_plugin = None
_null_tensor = torch.empty([0])

def _init():
    global _plugin
    if _plugin is None:
        _plugin = custom_ops.get_plugin(
            module_name='bias_act_plugin',
            sources=['bias_act.cpp', 'bias_act.cu'],
            headers=['bias_act.h'],
            source_dir=os.path.dirname(__file__),
            extra_cuda_cflags=['--use_fast_math'],
        )
    return True

#----------------------------------------------------------------------------

def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
    r"""Fused bias and activation function.

    Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
    and scales the result by `gain`. Each of the steps is optional. In most cases,
    the fused op is considerably more efficient than performing the same calculation
    using standard PyTorch ops. It supports first and second order gradients,
    but not third order gradients.

    Args:
        x:      Input activation tensor. Can be of any shape.
        b:      Bias vector, or `None` to disable. Must be a 1D tensor of the same type
                as `x`. The shape must be known, and it must match the dimension of `x`
                corresponding to `dim`.
        dim:    The dimension in `x` corresponding to the elements of `b`.
                The value of `dim` is ignored if `b` is not specified.
        act:    Name of the activation function to evaluate, or `"linear"` to disable.
                Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
                See `activation_funcs` for a full list. `None` is not allowed.
        alpha:  Shape parameter for the activation function, or `None` to use the default.
        gain:   Scaling factor for the output tensor, or `None` to use the default.
                See `activation_funcs` for the default scaling of each activation function.
                If unsure, consider specifying 1.
        clamp:  Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
                the clamping (default).
        impl:   Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).

    Returns:
        Tensor of the same shape and datatype as `x`.
    """
    assert isinstance(x, torch.Tensor)
    assert impl in ['ref', 'cuda']
    if impl == 'cuda' and x.device.type == 'cuda' and _init():
        return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
    return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)

#----------------------------------------------------------------------------

def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
    """Slow reference implementation of `bias_act()` using standard PyTorch ops.
    """
    assert isinstance(x, torch.Tensor)
    assert clamp is None or clamp >= 0
    spec = activation_funcs[act]
    alpha = float(alpha if alpha is not None else spec.def_alpha)
    gain = float(gain if gain is not None else spec.def_gain)
    clamp = float(clamp if clamp is not None else -1)

    # Add bias.
    if b is not None:
        assert isinstance(b, torch.Tensor) and b.ndim == 1
        assert 0 <= dim < x.ndim
        assert b.shape[0] == x.shape[dim]
        x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])

    # Evaluate activation function.
    alpha = float(alpha)
    x = spec.func(x, alpha=alpha)

    # Scale by gain.
    gain = float(gain)
    if gain != 1:
        x = x * gain

    # Clamp.
    if clamp >= 0:
        x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
    return x

#----------------------------------------------------------------------------

_bias_act_cuda_cache = dict()

def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
    """Fast CUDA implementation of `bias_act()` using custom ops.
    """
    # Parse arguments.
    assert clamp is None or clamp >= 0
    spec = activation_funcs[act]
    alpha = float(alpha if alpha is not None else spec.def_alpha)
    gain = float(gain if gain is not None else spec.def_gain)
    clamp = float(clamp if clamp is not None else -1)

    # Lookup from cache.
    key = (dim, act, alpha, gain, clamp)
    if key in _bias_act_cuda_cache:
        return _bias_act_cuda_cache[key]

    # Forward op.
    class BiasActCuda(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, b): # pylint: disable=arguments-differ
            ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride(1) == 1 else torch.contiguous_format
            x = x.contiguous(memory_format=ctx.memory_format)
            b = b.contiguous() if b is not None else _null_tensor
            y = x
            if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
                y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)
            ctx.save_for_backward(
                x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
                b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
                y if 'y' in spec.ref else _null_tensor)
            return y

        @staticmethod
        def backward(ctx, dy): # pylint: disable=arguments-differ
            dy = dy.contiguous(memory_format=ctx.memory_format)
            x, b, y = ctx.saved_tensors
            dx = None
            db = None

            if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
                dx = dy
                if act != 'linear' or gain != 1 or clamp >= 0:
                    dx = BiasActCudaGrad.apply(dy, x, b, y)

            if ctx.needs_input_grad[1]:
                db = dx.sum([i for i in range(dx.ndim) if i != dim])

            return dx, db

    # Backward op.
    class BiasActCudaGrad(torch.autograd.Function):
        @staticmethod
        def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
            ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride(1) == 1 else torch.contiguous_format
            dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp)
            ctx.save_for_backward(
                dy if spec.has_2nd_grad else _null_tensor,
                x, b, y)
            return dx

        @staticmethod
        def backward(ctx, d_dx): # pylint: disable=arguments-differ
            d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
            dy, x, b, y = ctx.saved_tensors
            d_dy = None
            d_x = None
            d_b = None
            d_y = None

            if ctx.needs_input_grad[0]:
                d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)

            if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]):
                d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp)

            if spec.has_2nd_grad and ctx.needs_input_grad[2]:
                d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])

            return d_dy, d_x, d_b, d_y

    # Add to cache.
    _bias_act_cuda_cache[key] = BiasActCuda
    return BiasActCuda

#----------------------------------------------------------------------------
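
# Illustrative usage (a hypothetical helper added for documentation purposes,
# not part of the upstream module). It sketches the typical StyleGAN-style
# call, a fused per-channel bias + leaky ReLU with the default sqrt(2) gain,
# and, when a GPU is available, compares the reference path against the fused
# CUDA path. Tensor shapes are arbitrary assumptions chosen for the demo.

def _usage_example():
    x = torch.randn(4, 512, 16, 16)  # NCHW activations
    b = torch.randn(512)             # per-channel bias along dim=1

    # Reference path: plain PyTorch ops via _bias_act_ref().
    y_ref = bias_act(x, b, dim=1, act='lrelu', impl='ref')

    if torch.cuda.is_available():
        # Fused path: bias + lrelu + gain evaluated by the custom CUDA kernel.
        y_cuda = bias_act(x.cuda(), b.cuda(), dim=1, act='lrelu', impl='cuda')
        print('max abs difference:', (y_ref - y_cuda.cpu()).abs().max().item())
    else:
        print('CUDA not available; reference output shape:', tuple(y_ref.shape))

#----------------------------------------------------------------------------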
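
# Second-order gradient sketch (hypothetical, for illustration only). The
# docstring of bias_act() states that first and second order gradients are
# supported; this demonstrates double backward through the reference path
# using a smooth activation.

def _double_backward_example():
    x = torch.randn(8, 16, requires_grad=True)
    b = torch.randn(16, requires_grad=True)

    y = bias_act(x, b, dim=1, act='tanh', impl='ref')
    # First-order gradient w.r.t. x, kept in the graph for a second pass.
    (dx,) = torch.autograd.grad(y.sum(), x, create_graph=True)
    # Second-order gradient: differentiate a scalar of dx w.r.t. the bias.
    (ddb,) = torch.autograd.grad(dx.sum(), b)
    print('second-order bias gradient shape:', tuple(ddb.shape))

#----------------------------------------------------------------------------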