GitHub Repository: POSTECH-CVLab/PyTorch-StudioGAN
Path: blob/master/src/utils/style_ops/upfirdn2d.py
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Custom PyTorch ops for efficient resampling of 2D images."""

import os
import numpy as np
import torch

from .. import custom_ops
from .. import style_misc as misc
from . import conv2d_gradfix

#----------------------------------------------------------------------------

_plugin = None

def _init():
    global _plugin
    if _plugin is None:
        _plugin = custom_ops.get_plugin(
            module_name='upfirdn2d_plugin',
            sources=['upfirdn2d.cpp', 'upfirdn2d.cu'],
            headers=['upfirdn2d.h'],
            source_dir=os.path.dirname(__file__),
            extra_cuda_cflags=['--use_fast_math'],
        )
    return True

def _parse_scaling(scaling):
    if isinstance(scaling, int):
        scaling = [scaling, scaling]
    assert isinstance(scaling, (list, tuple))
    assert all(isinstance(x, int) for x in scaling)
    sx, sy = scaling
    assert sx >= 1 and sy >= 1
    return sx, sy

def _parse_padding(padding):
    if isinstance(padding, int):
        padding = [padding, padding]
    assert isinstance(padding, (list, tuple))
    assert all(isinstance(x, int) for x in padding)
    if len(padding) == 2:
        padx, pady = padding
        padding = [padx, padx, pady, pady]
    padx0, padx1, pady0, pady1 = padding
    return padx0, padx1, pady0, pady1

def _get_filter_size(f):
    if f is None:
        return 1, 1
    assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
    fw = f.shape[-1]
    fh = f.shape[0]
    with misc.suppress_tracer_warnings():
        fw = int(fw)
        fh = int(fh)
    misc.assert_shape(f, [fh, fw][:f.ndim])
    assert fw >= 1 and fh >= 1
    return fw, fh

#----------------------------------------------------------------------------

def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None):
    r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`.

    Args:
        f:           Torch tensor, numpy array, or python list of the shape
                     `[filter_height, filter_width]` (non-separable),
                     `[filter_taps]` (separable),
                     `[]` (impulse), or
                     `None` (identity).
        device:      Result device (default: cpu).
        normalize:   Normalize the filter so that it retains the magnitude
                     for constant input signal (DC)? (default: True).
        flip_filter: Flip the filter? (default: False).
        gain:        Overall scaling factor for signal magnitude (default: 1).
        separable:   Return a separable filter? (default: select automatically).

    Returns:
        Float32 tensor of the shape
        `[filter_height, filter_width]` (non-separable) or
        `[filter_taps]` (separable).
    """
    # Validate.
    if f is None:
        f = 1
    f = torch.as_tensor(f, dtype=torch.float32)
    assert f.ndim in [0, 1, 2]
    assert f.numel() > 0
    if f.ndim == 0:
        f = f[np.newaxis]

    # Separable?
    if separable is None:
        separable = (f.ndim == 1 and f.numel() >= 8)
    if f.ndim == 1 and not separable:
        f = f.ger(f)
    assert f.ndim == (1 if separable else 2)

    # Apply normalize, flip, gain, and device.
    if normalize:
        f /= f.sum()
    if flip_filter:
        f = f.flip(list(range(f.ndim)))
    f = f * (gain ** (f.ndim / 2))
    f = f.to(device=device)
    return f
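
# Illustrative usage sketch (added for clarity; not part of the original NVIDIA
# code). Shows how a small binomial filter is typically prepared for the ops
# below: a 4-tap list has fewer than 8 taps, so setup_filter() expands it into
# a non-separable 4x4 outer product and normalizes it to unit DC gain. The tap
# values are an assumption chosen for the example; any 1D/2D filter works.
def _example_setup_filter():
    f = setup_filter([1, 3, 3, 1])              # example tap values (assumption)
    assert f.shape == (4, 4)                    # expanded to a 2D filter
    assert abs(float(f.sum()) - 1.0) < 1e-6     # normalize=True keeps DC gain at 1
    return f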

#----------------------------------------------------------------------------

def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'):
    r"""Pad, upsample, filter, and downsample a batch of 2D images.

    Performs the following sequence of operations for each channel:

    1. Upsample the image by inserting N-1 zeros after each pixel (`up`).

    2. Pad the image with the specified number of zeros on each side (`padding`).
       Negative padding corresponds to cropping the image.

    3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it
       so that the footprint of all output pixels lies within the input image.

    4. Downsample the image by keeping every Nth pixel (`down`).

    This sequence of operations bears close resemblance to scipy.signal.upfirdn().
    The fused op is considerably more efficient than performing the same calculation
    using standard PyTorch ops. It supports gradients of arbitrary order.

    Args:
        x:           Float32/float64/float16 input tensor of the shape
                     `[batch_size, num_channels, in_height, in_width]`.
        f:           Float32 FIR filter of the shape
                     `[filter_height, filter_width]` (non-separable),
                     `[filter_taps]` (separable), or
                     `None` (identity).
        up:          Integer upsampling factor. Can be a single int or a list/tuple
                     `[x, y]` (default: 1).
        down:        Integer downsampling factor. Can be a single int or a list/tuple
                     `[x, y]` (default: 1).
        padding:     Padding with respect to the upsampled image. Can be a single number
                     or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
                     (default: 0).
        flip_filter: False = convolution, True = correlation (default: False).
        gain:        Overall scaling factor for signal magnitude (default: 1).
        impl:        Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).

    Returns:
        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
    """
    assert isinstance(x, torch.Tensor)
    assert impl in ['ref', 'cuda']
    if impl == 'cuda' and x.device.type == 'cuda' and _init():
        return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f)
    return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain)
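
# Illustrative usage sketch (added for clarity; not part of the original NVIDIA
# code). Checks the output shape of upfirdn2d() on the CPU via the reference
# path. The filter taps and padding values below are assumptions chosen to
# reproduce what upsample2d() would pass for a factor-2 upsample with a 4-tap
# filter: an 8x8 input comes back as 16x16.
def _example_upfirdn2d_shapes():
    x = torch.randn([1, 3, 8, 8])
    f = setup_filter([1, 3, 3, 1])              # 4x4 low-pass filter
    y = upfirdn2d(x, f, up=2, padding=[2, 1, 2, 1], gain=4, impl='ref')
    assert y.shape == (1, 3, 16, 16)
    return y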

#----------------------------------------------------------------------------

def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
    """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.
    """
    # Validate arguments.
    assert isinstance(x, torch.Tensor) and x.ndim == 4
    if f is None:
        f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
    assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
    assert f.dtype == torch.float32 and not f.requires_grad
    batch_size, num_channels, in_height, in_width = x.shape
    upx, upy = _parse_scaling(up)
    downx, downy = _parse_scaling(down)
    padx0, padx1, pady0, pady1 = _parse_padding(padding)

    # Check that upsampled buffer is not smaller than the filter.
    upW = in_width * upx + padx0 + padx1
    upH = in_height * upy + pady0 + pady1
    assert upW >= f.shape[-1] and upH >= f.shape[0]

    # Upsample by inserting zeros.
    x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
    x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
    x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])

    # Pad or crop.
    x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)])
    x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)]

    # Setup filter.
    f = f * (gain ** (f.ndim / 2))
    f = f.to(x.dtype)
    if not flip_filter:
        f = f.flip(list(range(f.ndim)))

    # Convolve with the filter.
    f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)
    if f.ndim == 4:
        x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels)
    else:
        x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)
        x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)

    # Downsample by throwing away pixels.
    x = x[:, :, ::downy, ::downx]
    return x

#----------------------------------------------------------------------------

_upfirdn2d_cuda_cache = dict()

def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1):
    """Fast CUDA implementation of `upfirdn2d()` using custom ops.
    """
    # Parse arguments.
    upx, upy = _parse_scaling(up)
    downx, downy = _parse_scaling(down)
    padx0, padx1, pady0, pady1 = _parse_padding(padding)

    # Lookup from cache.
    key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
    if key in _upfirdn2d_cuda_cache:
        return _upfirdn2d_cuda_cache[key]

    # Forward op.
    class Upfirdn2dCuda(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, f): # pylint: disable=arguments-differ
            assert isinstance(x, torch.Tensor) and x.ndim == 4
            if f is None:
                f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
            if f.ndim == 1 and f.shape[0] == 1:
                f = f.square().unsqueeze(0) # Convert separable-1 into full-1x1.
            assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
            y = x
            if f.ndim == 2:
                y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
            else:
                y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, 1.0)
                y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, gain)
            ctx.save_for_backward(f)
            ctx.x_shape = x.shape
            return y

        @staticmethod
        def backward(ctx, dy): # pylint: disable=arguments-differ
            f, = ctx.saved_tensors
            _, _, ih, iw = ctx.x_shape
            _, _, oh, ow = dy.shape
            fw, fh = _get_filter_size(f)
            p = [
                fw - padx0 - 1,
                iw * upx - ow * downx + padx0 - upx + 1,
                fh - pady0 - 1,
                ih * upy - oh * downy + pady0 - upy + 1,
            ]
            dx = None
            df = None

            if ctx.needs_input_grad[0]:
                dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f)

            assert not ctx.needs_input_grad[1]
            return dx, df

    # Add to cache.
    _upfirdn2d_cuda_cache[key] = Upfirdn2dCuda
    return Upfirdn2dCuda

#----------------------------------------------------------------------------

def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'):
    r"""Filter a batch of 2D images using the given 2D FIR filter.

    By default, the result is padded so that its shape matches the input.
    User-specified padding is applied on top of that, with negative values
    indicating cropping. Pixels outside the image are assumed to be zero.

    Args:
        x:           Float32/float64/float16 input tensor of the shape
                     `[batch_size, num_channels, in_height, in_width]`.
        f:           Float32 FIR filter of the shape
                     `[filter_height, filter_width]` (non-separable),
                     `[filter_taps]` (separable), or
                     `None` (identity).
        padding:     Padding with respect to the output. Can be a single number or a
                     list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
                     (default: 0).
        flip_filter: False = convolution, True = correlation (default: False).
        gain:        Overall scaling factor for signal magnitude (default: 1).
        impl:        Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).

    Returns:
        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
    """
    padx0, padx1, pady0, pady1 = _parse_padding(padding)
    fw, fh = _get_filter_size(f)
    p = [
        padx0 + fw // 2,
        padx1 + (fw - 1) // 2,
        pady0 + fh // 2,
        pady1 + (fh - 1) // 2,
    ]
    return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
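
# Illustrative usage sketch (added for clarity; not part of the original NVIDIA
# code). filter2d() pads by roughly half the filter size on each side, so by
# default the output resolution matches the input. Filter taps are assumptions.
def _example_filter2d_shapes():
    x = torch.randn([2, 1, 32, 32])
    f = setup_filter([1, 2, 1])                 # expanded to a 3x3 filter
    y = filter2d(x, f, impl='ref')
    assert y.shape == x.shape                   # same-resolution filtering
    return y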

#----------------------------------------------------------------------------

def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
    r"""Upsample a batch of 2D images using the given 2D FIR filter.

    By default, the result is padded so that its shape is a multiple of the input.
    User-specified padding is applied on top of that, with negative values
    indicating cropping. Pixels outside the image are assumed to be zero.

    Args:
        x:           Float32/float64/float16 input tensor of the shape
                     `[batch_size, num_channels, in_height, in_width]`.
        f:           Float32 FIR filter of the shape
                     `[filter_height, filter_width]` (non-separable),
                     `[filter_taps]` (separable), or
                     `None` (identity).
        up:          Integer upsampling factor. Can be a single int or a list/tuple
                     `[x, y]` (default: 2).
        padding:     Padding with respect to the output. Can be a single number or a
                     list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
                     (default: 0).
        flip_filter: False = convolution, True = correlation (default: False).
        gain:        Overall scaling factor for signal magnitude (default: 1).
        impl:        Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).

    Returns:
        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
    """
    upx, upy = _parse_scaling(up)
    padx0, padx1, pady0, pady1 = _parse_padding(padding)
    fw, fh = _get_filter_size(f)
    p = [
        padx0 + (fw + upx - 1) // 2,
        padx1 + (fw - upx) // 2,
        pady0 + (fh + upy - 1) // 2,
        pady1 + (fh - upy) // 2,
    ]
    return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl)
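
# Illustrative usage sketch (added for clarity; not part of the original NVIDIA
# code). Factor-2 upsampling with the default padding doubles the spatial
# resolution; the extra gain of upx*upy applied inside upsample2d() compensates
# for the zeros inserted between pixels. Filter taps are assumptions.
def _example_upsample2d_shapes():
    x = torch.randn([1, 3, 16, 16])
    f = setup_filter([1, 3, 3, 1])
    y = upsample2d(x, f, up=2, impl='ref')
    assert y.shape == (1, 3, 32, 32)
    return y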

#----------------------------------------------------------------------------

def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
    r"""Downsample a batch of 2D images using the given 2D FIR filter.

    By default, the result is padded so that its shape is a fraction of the input.
    User-specified padding is applied on top of that, with negative values
    indicating cropping. Pixels outside the image are assumed to be zero.

    Args:
        x:           Float32/float64/float16 input tensor of the shape
                     `[batch_size, num_channels, in_height, in_width]`.
        f:           Float32 FIR filter of the shape
                     `[filter_height, filter_width]` (non-separable),
                     `[filter_taps]` (separable), or
                     `None` (identity).
        down:        Integer downsampling factor. Can be a single int or a list/tuple
                     `[x, y]` (default: 2).
        padding:     Padding with respect to the input. Can be a single number or a
                     list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
                     (default: 0).
        flip_filter: False = convolution, True = correlation (default: False).
        gain:        Overall scaling factor for signal magnitude (default: 1).
        impl:        Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).

    Returns:
        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
    """
    downx, downy = _parse_scaling(down)
    padx0, padx1, pady0, pady1 = _parse_padding(padding)
    fw, fh = _get_filter_size(f)
    p = [
        padx0 + (fw - downx + 1) // 2,
        padx1 + (fw - downx) // 2,
        pady0 + (fh - downy + 1) // 2,
        pady1 + (fh - downy) // 2,
    ]
    return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
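
# Illustrative usage sketch (added for clarity; not part of the original NVIDIA
# code). Factor-2 downsampling with the default padding halves the spatial
# resolution, low-pass filtering first to reduce aliasing. Filter taps are
# assumptions.
def _example_downsample2d_shapes():
    x = torch.randn([1, 3, 32, 32])
    f = setup_filter([1, 3, 3, 1])
    y = downsample2d(x, f, down=2, impl='ref')
    assert y.shape == (1, 3, 16, 16)
    return y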

#----------------------------------------------------------------------------