Path: blob/master/src/utils/style_ops/upfirdn2d.cu
809 views
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto.  Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.

#include <c10/util/Half.h>
#include "upfirdn2d.h"

//------------------------------------------------------------------------
// Helpers.

// Maps a storage type T to the type used for in-kernel accumulation.
// Half-precision data is accumulated in float.
template <class T> struct InternalType;
template <> struct InternalType<double>    { typedef double scalar_t; };
template <> struct InternalType<float>     { typedef float  scalar_t; };
template <> struct InternalType<c10::Half> { typedef float  scalar_t; };

// Integer division rounding toward -infinity (C++ '/' truncates toward zero).
// Every caller in this file passes an upsampling factor as the divisor;
// the formula assumes b > 0 (NOTE(review): not checked here -- confirm the
// host side guarantees it).
static __device__ __forceinline__ int floor_div(int a, int b)
{
    // Shift 'a' by a multiple of b into the non-negative range, where
    // truncation and flooring agree, divide, then undo the shift.
    int t = 1 - a / b;
    return (a + t * b) / b - t;
}

//------------------------------------------------------------------------
// Generic CUDA implementation for large filters.
//
// Works for any up/down factors and filter size. Thread mapping, as given by
// the index math below:
//   - blockIdx.x * blockDim.x + threadIdx.x enumerates (outY, minorBase)
//     pairs, with p.launchMinor minor slots per output row; each thread then
//     visits up to p.loopMinor minors spaced p.launchMinor apart;
//   - blockIdx.y * p.loopX * blockDim.y + threadIdx.y is the first output X;
//     each thread then strides by blockDim.y for up to p.loopX columns;
//   - blockIdx.z selects a run of up to p.loopMajor consecutive major indices.
// Uses no shared memory and no block-level synchronization.
template <class T> static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p)
{
    typedef typename InternalType<T>::scalar_t scalar_t;

    // Calculate thread index.
    int minorBase = blockIdx.x * blockDim.x + threadIdx.x;
    int outY = minorBase / p.launchMinor;
    minorBase -= outY * p.launchMinor;
    int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
    int majorBase = blockIdx.z * p.loopMajor;
    if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor)
        return;

    // Setup Y receptive field: first contributing input row (inY), number of
    // contributing rows (h), both clamped to the input, and the filter row
    // that pairs with inY.
    int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y;
    int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y);
    int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY;
    int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y;
    if (p.flip)
        filterY = p.filterSize.y - 1 - filterY;

    // Filter pointer increments per input pixel / row. They depend only on p,
    // so compute them once instead of inside the loops.
    int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x;
    int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y;

    // Loop over major, minor, and X.
    for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
    for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor)
    {
        // (major, minor) flatten to nc, which splits as nc = n * inSize.z + c.
        int nc = major * p.sizeMinor + minor;
        int n = nc / p.inSize.z;
        int c = nc - n * p.inSize.z;
        for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y)
        {
            // Setup X receptive field (same scheme as Y above).
            int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x;
            int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x);
            int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX;
            int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x;
            if (p.flip)
                filterX = p.filterSize.x - 1 - filterX;

            // Initialize pointers.
            const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
            const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y];

            // Inner loop: walk the w x h input window; the filter pointer moves
            // 'up' taps per input pixel, in a direction chosen by p.flip.
            scalar_t v = 0;
            for (int y = 0; y < h; y++)
            {
                for (int x = 0; x < w; x++)
                {
                    v += (scalar_t)(*xp) * (scalar_t)(*fp);
                    xp += p.inStride.x;
                    fp += filterStepX;
                }
                // Rewind the X advance and step one row in input and filter.
                xp += p.inStride.y - w * p.inStride.x;
                fp += filterStepY - w * filterStepX;
            }

            // Store result.
            v *= p.gain;
            ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
        }
    }
}

//------------------------------------------------------------------------
// Specialized CUDA implementation for small filters.
//
// Each block produces one tileOutW x tileOutH output tile for loopMinor
// consecutive minor indices. It iterates over up to p.loopMajor major
// indices and up to p.loopX tiles along X.
//
// The filter (sf) and the input footprint of the current tile (sx) are
// staged in static shared memory. Threads act as a flat pool: only
// threadIdx.x / blockDim.x are read.
//
// Template parameters: up/down factors, the filter footprint
// filterW x filterH, the output tile size, and loopMinor. Taps of p.f outside
// the footprint are never loaded, so the caller must choose
// filterW >= p.filterSize.x and filterH >= p.filterSize.y; smaller filters
// are zero-padded.
//
// Every __syncthreads() sits inside loops whose bounds depend only on
// blockIdx and p. The early return also depends only on blockIdx and p.
// As a result, all threads of a block reach each barrier the same number of
// times.
template <class T, int upx, int upy, int downx, int downy, int filterW, int filterH, int tileOutW, int tileOutH, int loopMinor>
static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p)
{
    typedef typename InternalType<T>::scalar_t scalar_t;
    // Input footprint (in input pixels) needed to produce one output tile.
    const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1;
    const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1;
    __shared__ volatile scalar_t sf[filterH][filterW];
    __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor];

    // Calculate tile index.
    int minorBase = blockIdx.x;
    int tileOutY = minorBase / p.launchMinor;
    minorBase -= tileOutY * p.launchMinor;
    minorBase *= loopMinor;
    tileOutY *= tileOutH;
    int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
    int majorBase = blockIdx.z * p.loopMajor;
    if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor)
        return;

    // Load filter (flipped), zero-filling taps beyond p.filterSize.
    for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x)
    {
        int fy = tapIdx / filterW;
        int fx = tapIdx - fy * filterW;
        scalar_t v = 0;
        if (fx < p.filterSize.x & fy < p.filterSize.y)
        {
            int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx;
            int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy;
            v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y];
        }
        sf[fy][fx] = v;
    }

    // The tile's Y origin is fixed for this block, so compute its receptive
    // field once rather than on every major / X iteration.
    int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y;
    int tileInY = floor_div(tileMidY, upy);

    // Loop over major and X.
    for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
    {
        int baseNC = major * p.sizeMinor + minorBase;
        int n = baseNC / p.inSize.z;
        int baseC = baseNC - n * p.inSize.z;
        for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW)
        {
            // Load input pixels (zero outside the input bounds).
            int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x;
            int tileInX = floor_div(tileMidX, upx);
            // Ensure every thread has finished reading sx from the previous
            // iteration before it is overwritten.
            __syncthreads();
            for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x)
            {
                // Decompose the flat index into (relInY, relInX, relC), relC fastest.
                int relC = inIdx;
                int relInX = relC / loopMinor;
                int relInY = relInX / tileInW;
                relC -= relInX * loopMinor;
                relInX -= relInY * tileInW;
                int c = baseC + relC;
                int inX = tileInX + relInX;
                int inY = tileInY + relInY;
                scalar_t v = 0;
                if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z)
                    v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
                sx[relInY][relInX][relC] = v;
            }

            // Loop over output pixels.
            // Make the freshly written sx (and, on the first pass, sf) visible
            // to all threads of the block.
            __syncthreads();
            for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x)
            {
                // Decompose the flat index into (relOutY, relOutX, relC), relC fastest.
                int relC = outIdx;
                int relOutX = relC / loopMinor;
                int relOutY = relOutX / tileOutW;
                relC -= relOutX * loopMinor;
                relOutX -= relOutY * tileOutW;
                int c = baseC + relC;
                int outX = tileOutX + relOutX;
                int outY = tileOutY + relOutY;

                // Setup receptive field. filterX/filterY are the polyphase
                // offsets into the flipped filter, in [0, up - 1].
                int midX = tileMidX + relOutX * downx;
                int midY = tileMidY + relOutY * downy;
                int inX = floor_div(midX, upx);
                int inY = floor_div(midY, upy);
                int relInX = inX - tileInX;
                int relInY = inY - tileInY;
                int filterX = (inX + 1) * upx - midX - 1; // flipped
                int filterY = (inY + 1) * upy - midY - 1; // flipped

                // Inner loop.
                if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z)
                {
                    scalar_t v = 0;
                    #pragma unroll
                    for (int y = 0; y < filterH / upy; y++)
                        #pragma unroll
                        for (int x = 0; x < filterW / upx; x++)
                            v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx];
                    v *= p.gain;
                    ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
                }
            }
        }
    }
}

//------------------------------------------------------------------------
// CUDA kernel selection.
//
// Returns the kernel to launch for p, together with its tiling. In every
// spec initializer, the kernel pointer is followed by four integers:
//   - For the small kernels, the first three repeat the template's
//     tileOutW, tileOutH and loopMinor.
//   - The last integer is presumably the per-block loopX count
//     (NOTE(review): field names/meaning live in upfirdn2d.h -- confirm).
//   - The generic large-filter kernel is the fallback and carries -1 tile
//     sizes.
//
// The candidate checks run in sequence, and each match overwrites the
// previous choice, so the LAST matching line wins. Within each group, more
// specialized templates (smaller, or 1-D horizontal/vertical) appear later.
// This selects the most specific template that still fits (fx, fy).
//
// s == 1 (unit input channel stride) selects the "channels_last"
// configurations; any other stride uses the "contiguous" ones.
template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p)
{
    int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y;
    upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,1, 4}; // contiguous
    if (s == 1)           spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,4, 1}; // channels_last

    // No up/downsampling.
    if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1)
    {
        // contiguous
        if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,24, 64,32,1>, 64,32,1, 1};
        if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,16, 64,32,1>, 64,32,1, 1};
        if (s != 1 && fx <= 7  && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7,   64,16,1>, 64,16,1, 1};
        if (s != 1 && fx <= 6  && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6,   64,16,1>, 64,16,1, 1};
        if (s != 1 && fx <= 5  && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5,   64,16,1>, 64,16,1, 1};
        if (s != 1 && fx <= 4  && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4,   64,16,1>, 64,16,1, 1};
        if (s != 1 && fx <= 3  && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3,   64,16,1>, 64,16,1, 1};
        if (s != 1 && fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1,  128,8,1>, 128,8,1, 1};
        if (s != 1 && fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1,  128,8,1>, 128,8,1, 1};
        if (s != 1 && fx <= 8  && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1,   128,8,1>, 128,8,1, 1};
        if (s != 1 && fx <= 1  && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24,  32,32,1>, 32,32,1, 1};
        if (s != 1 && fx <= 1  && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16,  32,32,1>, 32,32,1, 1};
        if (s != 1 && fx <= 1  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8,   32,32,1>, 32,32,1, 1};
        // channels_last
        if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,24, 32,32,1>, 32,32,1, 1};
        if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,16, 32,32,1>, 32,32,1, 1};
        if (s == 1 && fx <= 7  && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7,   16,16,8>, 16,16,8, 1};
        if (s == 1 && fx <= 6  && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6,   16,16,8>, 16,16,8, 1};
        if (s == 1 && fx <= 5  && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5,   16,16,8>, 16,16,8, 1};
        if (s == 1 && fx <= 4  && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4,   16,16,8>, 16,16,8, 1};
        if (s == 1 && fx <= 3  && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3,   16,16,8>, 16,16,8, 1};
        if (s == 1 && fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1,  128,1,16>, 128,1,16, 1};
        if (s == 1 && fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1,  128,1,16>, 128,1,16, 1};
        if (s == 1 && fx <= 8  && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1,   128,1,16>, 128,1,16, 1};
        if (s == 1 && fx <= 1  && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24,  1,128,16>, 1,128,16, 1};
        if (s == 1 && fx <= 1  && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16,  1,128,16>, 1,128,16, 1};
        if (s == 1 && fx <= 1  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8,   1,128,16>, 1,128,16, 1};
    }

    // 2x upsampling.
    if (p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1)
    {
        // contiguous
        if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 24,24, 64,32,1>, 64,32,1, 1};
        if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 16,16, 64,32,1>, 64,32,1, 1};
        if (s != 1 && fx <= 8  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8,   64,16,1>, 64,16,1, 1};
        if (s != 1 && fx <= 6  && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6,   64,16,1>, 64,16,1, 1};
        if (s != 1 && fx <= 4  && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4,   64,16,1>, 64,16,1, 1};
        if (s != 1 && fx <= 2  && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2,   64,16,1>, 64,16,1, 1};
        // channels_last
        if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 24,24, 32,32,1>, 32,32,1, 1};
        if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 16,16, 32,32,1>, 32,32,1, 1};
        if (s == 1 && fx <= 8  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8,   16,16,8>, 16,16,8, 1};
        if (s == 1 && fx <= 6  && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6,   16,16,8>, 16,16,8, 1};
        if (s == 1 && fx <= 4  && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4,   16,16,8>, 16,16,8, 1};
        if (s == 1 && fx <= 2  && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2,   16,16,8>, 16,16,8, 1};
    }
    if (p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1)
    {
        // contiguous
        if (s != 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
        if (s != 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
        if (s != 1 && fx <= 8  && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1,  128,8,1>, 128,8,1, 1};
        // channels_last
        if (s == 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
        if (s == 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
        if (s == 1 && fx <= 8  && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1,  128,1,16>, 128,1,16, 1};
    }
    if (p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1)
    {
        // contiguous
        if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
        if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
        if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8,  32,32,1>, 32,32,1, 1};
        // channels_last
        if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
        if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
        if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8,  1,128,16>, 1,128,16, 1};
    }

    // 2x downsampling.
    if (p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2)
    {
        // contiguous
        if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 24,24, 32,16,1>, 32,16,1, 1};
        if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 16,16, 32,16,1>, 32,16,1, 1};
        if (s != 1 && fx <= 8  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8,   32,8,1>,  32,8,1,  1};
        if (s != 1 && fx <= 6  && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6,   32,8,1>,  32,8,1,  1};
        if (s != 1 && fx <= 4  && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4,   32,8,1>,  32,8,1,  1};
        if (s != 1 && fx <= 2  && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2,   32,8,1>,  32,8,1,  1};
        // channels_last
        if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 24,24, 16,16,1>, 16,16,1, 1};
        if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 16,16, 16,16,1>, 16,16,1, 1};
        if (s == 1 && fx <= 8  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8,   8,8,8>,   8,8,8,   1};
        if (s == 1 && fx <= 6  && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6,   8,8,8>,   8,8,8,   1};
        if (s == 1 && fx <= 4  && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4,   8,8,8>,   8,8,8,   1};
        if (s == 1 && fx <= 2  && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2,   8,8,8>,   8,8,8,   1};
    }
    if (p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1)
    {
        // contiguous
        if (s != 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,8,1>, 64,8,1, 1};
        if (s != 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,8,1>, 64,8,1, 1};
        if (s != 1 && fx <= 8  && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1,  64,8,1>, 64,8,1, 1};
        // channels_last
        if (s == 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,1,8>, 64,1,8, 1};
        if (s == 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,1,8>, 64,1,8, 1};
        if (s == 1 && fx <= 8  && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1,  64,1,8>, 64,1,8, 1};
    }
    if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2)
    {
        // contiguous
        if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 32,16,1>, 32,16,1, 1};
        if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 32,16,1>, 32,16,1, 1};
        if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8,  32,16,1>, 32,16,1, 1};
        // channels_last
        if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 1,64,8>, 1,64,8, 1};
        if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 1,64,8>, 1,64,8, 1};
        if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8,  1,64,8>, 1,64,8, 1};
    }

    // 4x upsampling.
    if (p.up.x == 4 && p.up.y == 4 && p.down.x == 1 && p.down.y == 1)
    {
        // contiguous
        if (s != 1 && fx <= 48 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 48,48, 64,32,1>, 64,32,1, 1};
        if (s != 1 && fx <= 32 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 32,32, 64,32,1>, 64,32,1, 1};
        // channels_last
        if (s == 1 && fx <= 48 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 48,48, 32,32,1>, 32,32,1, 1};
        if (s == 1 && fx <= 32 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 32,32, 32,32,1>, 32,32,1, 1};
    }
    if (p.up.x == 4 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1)
    {
        // contiguous
        if (s != 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 48,1, 128,8,1>, 128,8,1, 1};
        if (s != 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 32,1, 128,8,1>, 128,8,1, 1};
        // channels_last
        if (s == 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 48,1, 128,1,16>, 128,1,16, 1};
        if (s == 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 32,1, 128,1,16>, 128,1,16, 1};
    }
    if (p.up.x == 1 && p.up.y == 4 && p.down.x == 1 && p.down.y == 1)
    {
        // contiguous
        if (s != 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,48, 32,32,1>, 32,32,1, 1};
        if (s != 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,32, 32,32,1>, 32,32,1, 1};
        // channels_last
        if (s == 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,48, 1,128,16>, 1,128,16, 1};
        if (s == 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,32, 1,128,16>, 1,128,16, 1};
    }

    // 4x downsampling (inefficient).
    if (p.up.x == 1 && p.up.y == 1 && p.down.x == 4 && p.down.y == 1)
    {
        // contiguous
        if (s != 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 48,1, 32,8,1>, 32,8,1, 1};
        if (s != 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 32,1, 32,8,1>, 32,8,1, 1};
        // channels_last
        if (s == 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 48,1, 32,1,8>, 32,1,8, 1};
        if (s == 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 32,1, 32,1,8>, 32,1,8, 1};
    }
    if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 4)
    {
        // contiguous
        if (s != 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,48, 32,8,1>, 32,8,1, 1};
        if (s != 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,32, 32,8,1>, 32,8,1, 1};
        // channels_last
        if (s == 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,48, 1,32,8>, 1,32,8, 1};
        if (s == 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,32, 1,32,8>, 1,32,8, 1};
    }
    return spec;
}

//------------------------------------------------------------------------
// Template specializations.

template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<double>   (const upfirdn2d_kernel_params& p);
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<float>    (const upfirdn2d_kernel_params& p);
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<c10::Half>(const upfirdn2d_kernel_params& p);

//------------------------------------------------------------------------