# Path: blob/master/src/utils/style_ops/conv2d_resample.py
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""2D convolution with optional up/downsampling."""

import torch

from .. import style_misc as misc
from . import conv2d_gradfix
from . import upfirdn2d
from .upfirdn2d import _parse_padding
from .upfirdn2d import _get_filter_size

#----------------------------------------------------------------------------

def _get_weight_shape(w):
    """Return the shape of weight tensor `w` as a plain list of Python ints."""
    with misc.suppress_tracer_warnings(): # this value will be treated as a constant
        shape = [int(sz) for sz in w.shape]
    misc.assert_shape(w, shape) # sanity: the extracted shape must match the tensor
    return shape

#----------------------------------------------------------------------------

def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
    """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.
    """
    _out_channels, _in_channels_per_group, kh, kw = _get_weight_shape(w)

    # Flip weight if requested.
    # Note: conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
    if not flip_weight and (kw > 1 or kh > 1):
        w = w.flip([2, 3])

    # Execute using conv2d_gradfix.
    op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d
    return op(x, w, stride=stride, padding=padding, groups=groups)

#----------------------------------------------------------------------------

def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
    r"""2D convolution with optional up/downsampling.

    Padding is performed only once at the beginning, not between the operations.

    Args:
        x:              Input tensor of shape
                        `[batch_size, in_channels, in_height, in_width]`.
        w:              Weight tensor of shape
                        `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
        f:              Low-pass filter for up/downsampling. Must be prepared beforehand by
                        calling upfirdn2d.setup_filter(). None = identity (default).
        up:             Integer upsampling factor (default: 1).
        down:           Integer downsampling factor (default: 1).
        padding:        Padding with respect to the upsampled image. Can be a single number
                        or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
                        (default: 0).
        groups:         Split input channels into N groups (default: 1).
        flip_weight:    False = convolution, True = correlation (default: True).
        flip_filter:    False = convolution, True = correlation (default: False).

    Returns:
        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
    """
    # Validate arguments.
    assert isinstance(x, torch.Tensor) and (x.ndim == 4)
    assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
    assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
    assert isinstance(up, int) and (up >= 1)
    assert isinstance(down, int) and (down >= 1)
    assert isinstance(groups, int) and (groups >= 1)
    out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
    fw, fh = _get_filter_size(f)
    px0, px1, py0, py1 = _parse_padding(padding)

    # Adjust padding to account for up/downsampling.
    if up > 1:
        px0 += (fw + up - 1) // 2
        px1 += (fw - up) // 2
        py0 += (fh + up - 1) // 2
        py1 += (fh - up) // 2
    if down > 1:
        px0 += (fw - down + 1) // 2
        px1 += (fw - down) // 2
        py0 += (fh - down + 1) // 2
        py1 += (fh - down) // 2

    # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
    if kw == 1 and kh == 1 and (down > 1 and up == 1):
        x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
        x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
        return x

    # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
    if kw == 1 and kh == 1 and (up > 1 and down == 1):
        x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
        x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
        return x

    # Fast path: downsampling only => use strided convolution.
    if down > 1 and up == 1:
        x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
        x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
        return x

    # Fast path: upsampling with optional downsampling => use transpose strided convolution.
    if up > 1:
        # conv_transpose2d() expects weights of shape [in, out//groups, kh, kw],
        # so swap the channel axes (per-group when groups > 1).
        if groups == 1:
            w = w.transpose(0, 1)
        else:
            w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
            w = w.transpose(1, 2)
            w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
        # Account for the implicit padding added by the transposed convolution.
        px0 -= kw - 1
        px1 -= kw - up
        py0 -= kh - 1
        py1 -= kh - up
        # Negative padding is not supported by conv_transpose2d(); defer it to upfirdn2d().
        pxt = max(min(-px0, -px1), 0)
        pyt = max(min(-py0, -py1), 0)
        x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight))
        x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter)
        if down > 1:
            x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
        return x

    # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
    if up == 1 and down == 1:
        if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
            return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight)

    # Fallback: Generic reference implementation.
    x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
    x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
    if down > 1:
        x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
    return x

#----------------------------------------------------------------------------