GitHub Repository: POSTECH-CVLab/PyTorch-StudioGAN
Path: blob/master/src/utils/style_ops/fma.py

# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Fused multiply-add, with slightly faster gradients than `torch.addcmul()`."""

import torch

#----------------------------------------------------------------------------

def fma(a, b, c): # => a * b + c
    return _FusedMultiplyAdd.apply(a, b, c)

#----------------------------------------------------------------------------

class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
    @staticmethod
    def forward(ctx, a, b, c): # pylint: disable=arguments-differ
        out = torch.addcmul(c, a, b) # fused c + a * b
        ctx.save_for_backward(a, b)
        ctx.c_shape = c.shape # only the shape of c is needed for backward
        return out

    @staticmethod
    def backward(ctx, dout): # pylint: disable=arguments-differ
        a, b = ctx.saved_tensors
        c_shape = ctx.c_shape
        da = None
        db = None
        dc = None

        if ctx.needs_input_grad[0]:
            da = _unbroadcast(dout * b, a.shape) # d(a*b+c)/da = b

        if ctx.needs_input_grad[1]:
            db = _unbroadcast(dout * a, b.shape) # d(a*b+c)/db = a

        if ctx.needs_input_grad[2]:
            dc = _unbroadcast(dout, c_shape) # d(a*b+c)/dc = 1

        return da, db, dc

#----------------------------------------------------------------------------
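
# Editor's note, not part of the upstream file: the backward pass above leans
# on `_unbroadcast` (defined below) to undo broadcasting. For example, if
# a.shape == (4, 3) and b.shape == (3,), then `dout * a` has shape (4, 3), and
# `_unbroadcast` sums it over dim 0 so that `db` comes back with shape (3,).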

def _unbroadcast(x, shape):
    # Reduce x (an upstream gradient) back to `shape`, reversing any
    # broadcasting that occurred in the forward pass.
    extra_dims = x.ndim - len(shape)
    assert extra_dims >= 0
    # Dimensions that were broadcast: either dims prepended by broadcasting,
    # or dims that were size 1 in the target shape. Sum the gradient there.
    dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
    if len(dim):
        x = x.sum(dim=dim, keepdim=True)
    if extra_dims:
        # Collapse the prepended singleton dims into the first real dim
        # so the rank matches `shape`.
        x = x.reshape(-1, *x.shape[extra_dims+1:])
    assert x.shape == shape
    return x

#----------------------------------------------------------------------------
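
# Editor's sketch, not part of the upstream file: a minimal self-check that the
# custom op matches `a * b + c` under broadcasting, gradients included. The
# double-precision inputs are required by `torch.autograd.gradcheck`.

if __name__ == "__main__":
    torch.manual_seed(0)
    a = torch.randn(4, 3, dtype=torch.double, requires_grad=True)
    b = torch.randn(3, dtype=torch.double, requires_grad=True)  # broadcasts against a
    c = torch.randn(1, dtype=torch.double, requires_grad=True)  # broadcasts against a * b

    # Forward pass agrees with the plain expression.
    assert torch.allclose(fma(a, b, c), a * b + c)

    # Numerical gradient check exercises the hand-written backward,
    # including the _unbroadcast reductions for b and c.
    torch.autograd.gradcheck(fma, (a, b, c))
    print("fma: forward and backward checks passed")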