Path: blob/master/src/utils/style_ops/fma.py
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Fused multiply-add, with slightly faster gradients than `torch.addcmul()`."""

import torch

#----------------------------------------------------------------------------

def fma(a, b, c): # => a * b + c
    return _FusedMultiplyAdd.apply(a, b, c)

#----------------------------------------------------------------------------

class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
    @staticmethod
    def forward(ctx, a, b, c): # pylint: disable=arguments-differ
        out = torch.addcmul(c, a, b) # out = c + a * b
        ctx.save_for_backward(a, b)
        ctx.c_shape = c.shape # only the shape of c is needed in backward
        return out

    @staticmethod
    def backward(ctx, dout): # pylint: disable=arguments-differ
        a, b = ctx.saved_tensors
        c_shape = ctx.c_shape
        da = None
        db = None
        dc = None

        if ctx.needs_input_grad[0]:
            da = _unbroadcast(dout * b, a.shape)

        if ctx.needs_input_grad[1]:
            db = _unbroadcast(dout * a, b.shape)

        if ctx.needs_input_grad[2]:
            dc = _unbroadcast(dout, c_shape)

        return da, db, dc

#----------------------------------------------------------------------------

def _unbroadcast(x, shape):
    # Reduce gradient `x` back to `shape` by summing over the dimensions
    # that broadcasting expanded in the forward pass.
    extra_dims = x.ndim - len(shape)
    assert extra_dims >= 0
    dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
    if len(dim):
        x = x.sum(dim=dim, keepdim=True)
    if extra_dims:
        # Fold the leading broadcast dimensions (all size 1 after the sum)
        # into the first retained dimension.
        x = x.reshape(-1, *x.shape[extra_dims+1:])
    assert x.shape == shape
    return x

#----------------------------------------------------------------------------
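#----------------------------------------------------------------------------
# Minimal usage sketch (illustrative addition, not part of the upstream
# NVIDIA file): checks that fma() matches the eager expression a * b + c,
# and that the hand-written backward agrees with autograd's numerical
# gradients in a case where c broadcasts against a * b.

if __name__ == '__main__':
    torch.manual_seed(0)

    # Double precision keeps torch.autograd.gradcheck() reliable.
    a = torch.randn(4, 3, dtype=torch.double, requires_grad=True)
    b = torch.randn(4, 3, dtype=torch.double, requires_grad=True)
    c = torch.randn(3, dtype=torch.double, requires_grad=True) # broadcasts over dim 0

    # Forward pass matches the naive expression.
    assert torch.allclose(fma(a, b, c), a * b + c)

    # gradcheck() raises if the custom backward disagrees with finite differences.
    torch.autograd.gradcheck(fma, (a, b, c))
    print('fma: forward and backward OK')

#----------------------------------------------------------------------------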