Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
POSTECH-CVLab
GitHub Repository: POSTECH-CVLab/PyTorch-StudioGAN
Path: blob/master/src/utils/style_ops/conv2d_resample.py
809 views
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""2D convolution with optional up/downsampling."""

import torch

from .. import style_misc as misc
from . import conv2d_gradfix
from . import upfirdn2d
from .upfirdn2d import _parse_padding
from .upfirdn2d import _get_filter_size

#----------------------------------------------------------------------------

def _get_weight_shape(w):
    """Return the shape of weight tensor `w` as a list of plain Python ints."""
    with misc.suppress_tracer_warnings():  # the sizes are treated as constants
        dims = [int(d) for d in w.shape]
    misc.assert_shape(w, dims)
    return dims

#----------------------------------------------------------------------------

def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
    """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.
    """
    _, _, kernel_h, kernel_w = _get_weight_shape(w)

    # conv2d() computes cross-correlation (flip_weight=True); for a true
    # convolution (flip_weight=False) the kernel must be mirrored in both
    # spatial dims. A 1x1 kernel is its own mirror image, so skip the flip.
    needs_flip = (not flip_weight) and (kernel_w > 1 or kernel_h > 1)
    if needs_flip:
        w = w.flip([2, 3])

    # Dispatch to the gradfix-aware implementation.
    if transpose:
        op = conv2d_gradfix.conv_transpose2d
    else:
        op = conv2d_gradfix.conv2d
    return op(x, w, stride=stride, padding=padding, groups=groups)

#----------------------------------------------------------------------------

def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
    r"""2D convolution with optional up/downsampling.

    Padding is performed only once at the beginning, not between the operations.

    Args:
        x:              Input tensor of shape
                        `[batch_size, in_channels, in_height, in_width]`.
        w:              Weight tensor of shape
                        `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
        f:              Low-pass filter for up/downsampling. Must be prepared beforehand by
                        calling upfirdn2d.setup_filter(). None = identity (default).
        up:             Integer upsampling factor (default: 1).
        down:           Integer downsampling factor (default: 1).
        padding:        Padding with respect to the upsampled image. Can be a single number
                        or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
                        (default: 0).
        groups:         Split input channels into N groups (default: 1).
        flip_weight:    False = convolution, True = correlation (default: True).
        flip_filter:    False = convolution, True = correlation (default: False).

    Returns:
        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
    """
    # Validate arguments.
    assert isinstance(x, torch.Tensor) and (x.ndim == 4)
    assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
    assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
    assert isinstance(up, int) and (up >= 1)
    assert isinstance(down, int) and (down >= 1)
    assert isinstance(groups, int) and (groups >= 1)
    out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
    fw, fh = _get_filter_size(f)
    px0, px1, py0, py1 = _parse_padding(padding)

    # Adjust padding to account for up/downsampling.
    # The extra padding compensates for the spatial support of the resampling
    # filter f (fw x fh taps), split between the before/after sides; the
    # asymmetric split absorbs odd filter/factor combinations.
    if up > 1:
        px0 += (fw + up - 1) // 2
        px1 += (fw - up) // 2
        py0 += (fh + up - 1) // 2
        py1 += (fh - up) // 2
    if down > 1:
        px0 += (fw - down + 1) // 2
        px1 += (fw - down) // 2
        py0 += (fh - down + 1) // 2
        py1 += (fh - down) // 2

    # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
    # (A 1x1 kernel has no spatial extent, so the cheaper order is equivalent.)
    if kw == 1 and kh == 1 and (down > 1 and up == 1):
        x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
        x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
        return x

    # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
    if kw == 1 and kh == 1 and (up > 1 and down == 1):
        x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
        x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
        return x

    # Fast path: downsampling only => use strided convolution.
    if down > 1 and up == 1:
        x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
        x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
        return x

    # Fast path: upsampling with optional downsampling => use transpose strided convolution.
    if up > 1:
        # conv_transpose2d() expects the weight laid out as
        # [in_channels, out_channels // groups, kh, kw], so swap the channel
        # axes (per group in the grouped case).
        if groups == 1:
            w = w.transpose(0, 1)
        else:
            w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
            w = w.transpose(1, 2)
            w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
        # Transpose convolution implicitly adds (k - 1) of context on each
        # side; subtract that from the requested padding.
        px0 -= kw - 1
        px1 -= kw - up
        py0 -= kh - 1
        py1 -= kh - up
        # conv_transpose2d() only accepts symmetric non-negative padding, so
        # apply just that portion (pxt/pyt) here and fold the remainder into
        # the subsequent upfirdn2d() call.
        pxt = max(min(-px0, -px1), 0)
        pyt = max(min(-py0, -py1), 0)
        # flip_weight is inverted because transpose convolution flips the
        # kernel relative to conv2d().
        x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight))
        x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter)
        if down > 1:
            x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
        return x

    # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
    if up == 1 and down == 1:
        if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
            return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight)

    # Fallback: Generic reference implementation.
    x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
    x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
    if down > 1:
        x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
    return x

#----------------------------------------------------------------------------