// Source: POSTECH-CVLab/PyTorch-StudioGAN, src/utils/style_ops/upfirdn2d.cu
//
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.

#include <c10/util/Half.h>
10
#include "upfirdn2d.h"
11

//------------------------------------------------------------------------
// Helpers.

// InternalType<T>::scalar_t is the arithmetic type the kernels below use for
// accumulation and intermediate values when the tensor element type is T.
//   double    -> double
//   float     -> float
//   c10::Half -> float   (half inputs are widened so sums are kept in float)
// The primary template is declared but never defined, so instantiating a kernel
// with any other element type fails at compile time.
template <class T> struct InternalType;
template <> struct InternalType<double>    { using scalar_t = double; };
template <> struct InternalType<float>     { using scalar_t = float;  };
template <> struct InternalType<c10::Half> { using scalar_t = float;  };
// Floor division: returns floor(a / b), i.e. rounds toward negative infinity,
// whereas C++ integer '/' truncates toward zero. The kernels need this because
// receptive-field coordinates such as (outY * down + up - 1 - pad0) can be
// negative near the padded top/left border.
//
// Precondition: b != 0, and not (a == INT_MIN && b == -1) (the quotient itself
// would overflow). Valid for either sign of a and b. It uses a single division
// and, unlike the previous (a + t*b) formulation, has no intermediate overflow
// for large |a|.
// Marked __host__ __device__ so host code can use (and unit-test) it as well.
static __host__ __device__ __forceinline__ int floor_div(int a, int b)
{
    int q = a / b;          // truncated quotient
    int r = a - q * b;      // remainder; zero or has the sign of a, |r| < |b|
    // Truncation rounded toward zero, i.e. upward, exactly when there is a
    // nonzero remainder whose sign differs from the divisor's.
    return q - ((r != 0) & ((r < 0) != (b < 0)));
}

//------------------------------------------------------------------------
// Generic CUDA implementation for large filters.

// Generic upfirdn2d kernel. It evaluates the polyphase form of
// "upsample by p.up -> pad by p.pad0 -> FIR filter p.f -> downsample by p.down",
// then scales by p.gain. Filter taps are read straight from global memory, so
// any filter size is supported. choose_upfirdn2d_kernel uses it as the default
// when no specialized small-filter kernel applies. No shared memory is used.
//
// Launch layout, as decoded below:
//   * blockIdx.x * blockDim.x + threadIdx.x is split into an output row (outY)
//     and a minor lane in [0, p.launchMinor). A thread visits up to p.loopMinor
//     minor indices spaced p.launchMinor apart.
//   * blockIdx.y * p.loopX * blockDim.y + threadIdx.y is the first output
//     column. A thread visits up to p.loopX columns spaced blockDim.y apart.
//   * blockIdx.z * p.loopMajor is the first major index. Up to p.loopMajor
//     consecutive majors are visited.
// (major, minor) is flattened to nc = major * p.sizeMinor + minor and split into
// batch index n and channel c using p.inSize.z.
// Element (x, y, c, n) is addressed through p.inStride / p.outStride
// (.x/.y/.z/.w), so the memory layout is defined entirely by those strides.
// Arithmetic is done in InternalType<T>::scalar_t (float for c10::Half).
template <class T> static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p)
{
    typedef typename InternalType<T>::scalar_t scalar_t;

    // Calculate thread index.
    int minorBase = blockIdx.x * blockDim.x + threadIdx.x;
    int outY = minorBase / p.launchMinor;
    minorBase -= outY * p.launchMinor;
    int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
    int majorBase = blockIdx.z * p.loopMajor;
    if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor)
        return;

    // Filter pointer increments per input step. They depend only on p, so they
    // are computed once here instead of once per output pixel.
    // Without p.flip the filter is walked backwards (i.e. applied as a convolution).
    int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x;
    int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y;

    // Setup Y receptive field: input rows [inY, inY + h) contribute to output row
    // outY, clamped to the valid input range. filterY is the filter row that
    // pairs with input row inY.
    int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y;
    int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y);
    int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY;
    int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y;
    if (p.flip)
        filterY = p.filterSize.y - 1 - filterY;

    // Loop over major, minor, and X.
    for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
    for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor)
    {
        int nc = major * p.sizeMinor + minor;
        int n = nc / p.inSize.z;
        int c = nc - n * p.inSize.z;
        for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y)
        {
            // Setup X receptive field (same construction as for Y above).
            int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x;
            int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x);
            int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX;
            int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x;
            if (p.flip)
                filterX = p.filterSize.x - 1 - filterX;

            // Initialize pointers. When w or h is 0 they are never dereferenced.
            const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
            const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y];

            // Inner loop: walk the w x h input window. Each input step moves the
            // filter pointer by p.up taps (the polyphase selection).
            scalar_t v = 0;
            for (int y = 0; y < h; y++)
            {
                for (int x = 0; x < w; x++)
                {
                    v += (scalar_t)(*xp) * (scalar_t)(*fp);
                    xp += p.inStride.x;
                    fp += filterStepX;
                }
                // Rewind the w column steps and advance one row.
                xp += p.inStride.y - w * p.inStride.x;
                fp += filterStepY - w * filterStepX;
            }

            // Store result.
            v *= p.gain;
            ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
        }
    }
}

//------------------------------------------------------------------------
// Specialized CUDA implementation for small filters.

// Tiled upfirdn2d kernel for compile-time up/down factors and a filter that fits
// in a filterW x filterH footprint.
// Each block produces a tileOutW x tileOutH tile of output pixels for loopMinor
// consecutive minor indices. It stages two things in static __shared__ memory:
//   sf[filterH][filterW]             the filter, stored flipped for the inner loop and
//                                    zero-padded beyond p.filterSize;
//   sx[tileInH][tileInW][loopMinor]  the input tile needed for one output tile.
//
// Launch layout, as decoded below:
//   * blockIdx.x is split into an output tile row (tileOutY / tileOutH) and a
//     minor group in [0, p.launchMinor). Each group covers loopMinor minors.
//   * blockIdx.y * p.loopX * tileOutW is the first output tile column. Up to
//     p.loopX tiles spaced tileOutW apart are processed.
//   * blockIdx.z * p.loopMajor is the first major index. Up to p.loopMajor are
//     processed.
// Work is distributed over threadIdx.x only, strided by blockDim.x.
// The kernel calls __syncthreads(), so every thread of a block must reach it.
// The early return and all loop bounds depend only on blockIdx and p, so they
// are uniform across the block.
// Requirements:
//   * filterW % upx == 0 and filterH % upy == 0 (checked at compile time below);
//   * the runtime filter p.filterSize must not exceed filterW x filterH.
template <class T, int upx, int upy, int downx, int downy, int filterW, int filterH, int tileOutW, int tileOutH, int loopMinor>
static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p)
{
    // The inner loop applies filterW / upx (resp. filterH / upy) taps per output
    // phase. Any remainder columns/rows of sf would be silently ignored, so reject
    // such instantiations at compile time.
    static_assert(filterW % upx == 0 && filterH % upy == 0,
                  "upfirdn2d_kernel_small: filter footprint must be a multiple of the upsampling factor");

    typedef typename InternalType<T>::scalar_t scalar_t;
    const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1;
    const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1;
    __shared__ volatile scalar_t sf[filterH][filterW];
    __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor];

    // Calculate tile index.
    int minorBase = blockIdx.x;
    int tileOutY = minorBase / p.launchMinor;
    minorBase -= tileOutY * p.launchMinor;
    minorBase *= loopMinor;
    tileOutY *= tileOutH;
    int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
    int majorBase = blockIdx.z * p.loopMajor;
    // Block-uniform condition, so no thread of a live block skips a __syncthreads().
    if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor)
        return;

    // Load filter (flipped). Taps outside p.filterSize are zeroed, so a smaller
    // runtime filter can use this larger instantiation.
    for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x)
    {
        int fy = tapIdx / filterW;
        int fx = tapIdx - fy * filterW;
        scalar_t v = 0;
        if (fx < p.filterSize.x & fy < p.filterSize.y)
        {
            int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx;
            int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy;
            v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y];
        }
        sf[fy][fx] = v;
    }

    // The tile's vertical input origin depends only on tileOutY, so it is
    // computed once rather than on every major/X iteration.
    int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y;
    int tileInY = floor_div(tileMidY, upy);

    // Loop over major and X.
    for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
    {
        int baseNC = major * p.sizeMinor + minorBase;
        int n = baseNC / p.inSize.z;
        int baseC = baseNC - n * p.inSize.z;
        for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW)
        {
            // Load input pixels. Out-of-range pixels (padding) are stored as zero.
            int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x;
            int tileInX = floor_div(tileMidX, upx);
            // All reads of sx from the previous iteration must finish before it is overwritten.
            __syncthreads();
            for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x)
            {
                int relC = inIdx;
                int relInX = relC / loopMinor;
                int relInY = relInX / tileInW;
                relC -= relInX * loopMinor;
                relInX -= relInY * tileInW;
                int c = baseC + relC;
                int inX = tileInX + relInX;
                int inY = tileInY + relInY;
                scalar_t v = 0;
                if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z)
                    v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
                sx[relInY][relInX][relC] = v;
            }

            // Loop over output pixels. The barrier makes sx (and, on the first pass, sf) visible.
            __syncthreads();
            for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x)
            {
                int relC = outIdx;
                int relOutX = relC / loopMinor;
                int relOutY = relOutX / tileOutW;
                relC -= relOutX * loopMinor;
                relOutX -= relOutY * tileOutW;
                int c = baseC + relC;
                int outX = tileOutX + relOutX;
                int outY = tileOutY + relOutY;

                // Setup receptive field.
                int midX = tileMidX + relOutX * downx;
                int midY = tileMidY + relOutY * downy;
                int inX = floor_div(midX, upx);
                int inY = floor_div(midY, upy);
                int relInX = inX - tileInX;
                int relInY = inY - tileInY;
                int filterX = (inX + 1) * upx - midX - 1; // flipped
                int filterY = (inY + 1) * upy - midY - 1; // flipped

                // Inner loop: filterW/upx x filterH/upy taps of the polyphase filter.
                if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z)
                {
                    scalar_t v = 0;
                    #pragma unroll
                    for (int y = 0; y < filterH / upy; y++)
                        #pragma unroll
                        for (int x = 0; x < filterW / upx; x++)
                            v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx];
                    v *= p.gain;
                    ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
                }
            }
        }
    }
}

//------------------------------------------------------------------------
// CUDA kernel selection.

// Picks the upfirdn2d kernel instantiation and its tiling for the problem in p.
//
// Selection rule:
//   * The default is the generic upfirdn2d_kernel_large<T>.
//   * It is replaced by a specialized upfirdn2d_kernel_small<T, ...> when the
//     (up, down) factors match one of the compiled cases below and the filter
//     (fx, fy) fits that instantiation's footprint.
//   * Within a case, candidates are listed from largest to smallest footprint
//     and every later match overwrites `spec`, so the smallest fitting
//     instantiation wins.
//   * Entries for filters only one tap tall (fy <= 1) or wide (fx <= 1) come
//     last, so they take precedence over the 2-D entries.
//   * s == 1 (p.inStride.z == 1, unit channel stride) selects the
//     "channels_last" tilings; otherwise the "contiguous" ones are used.
//
// Spec format: each spec is {kernel pointer, four ints}.
//   * For small kernels the first three ints repeat the tileOutW, tileOutH,
//     loopMinor template arguments.
//   * The large kernel uses -1 for the tile sizes.
//   * NOTE(review): the field names and the meaning of the last int (presumably
//     loopX) are defined by upfirdn2d_kernel_spec in upfirdn2d.h; confirm there.
template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p)
{
    int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y;
    upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,1, 4}; // contiguous
    if (s == 1) spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,4, 1}; // channels_last

    // No up/downsampling.
    if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1)
    {
        // contiguous
        if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,24, 64,32,1>, 64,32,1, 1};
        if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,16, 64,32,1>, 64,32,1, 1};
        if (s != 1 && fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 64,16,1>, 64,16,1, 1};
        if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
        if (s != 1 && fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5, 64,16,1>, 64,16,1, 1};
        if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
        if (s != 1 && fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3, 64,16,1>, 64,16,1, 1};
        if (s != 1 && fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
        if (s != 1 && fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
        if (s != 1 && fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
        if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
        if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
        if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
        // channels_last
        if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,24, 32,32,1>, 32,32,1, 1};
        if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,16, 32,32,1>, 32,32,1, 1};
        if (s == 1 && fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 16,16,8>, 16,16,8, 1};
        if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6, 16,16,8>, 16,16,8, 1};
        if (s == 1 && fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5, 16,16,8>, 16,16,8, 1};
        if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
        if (s == 1 && fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3, 16,16,8>, 16,16,8, 1};
        if (s == 1 && fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
        if (s == 1 && fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
        if (s == 1 && fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
        if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
        if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
        if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
    }

    // 2x upsampling.
    if (p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1)
    {
        // contiguous
        if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 24,24, 64,32,1>, 64,32,1, 1};
        if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 16,16, 64,32,1>, 64,32,1, 1};
        if (s != 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 64,16,1>, 64,16,1, 1};
        if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
        if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
        if (s != 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 64,16,1>, 64,16,1, 1};
        // channels_last
        if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 24,24, 32,32,1>, 32,32,1, 1};
        if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 16,16, 32,32,1>, 32,32,1, 1};
        if (s == 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 16,16,8>, 16,16,8, 1};
        if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 16,16,8>, 16,16,8, 1};
        if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
        if (s == 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 16,16,8>, 16,16,8, 1};
    }
    if (p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1)
    {
        // contiguous
        if (s != 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
        if (s != 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
        if (s != 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
        // channels_last
        if (s == 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
        if (s == 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
        if (s == 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
    }
    if (p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1)
    {
        // contiguous
        if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
        if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
        if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
        // channels_last
        if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
        if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
        if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
    }

    // 2x downsampling.
    if (p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2)
    {
        // contiguous
        if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 24,24, 32,16,1>, 32,16,1, 1};
        if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 16,16, 32,16,1>, 32,16,1, 1};
        if (s != 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 32,8,1>, 32,8,1, 1};
        if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 32,8,1>, 32,8,1, 1};
        if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 32,8,1>, 32,8,1, 1};
        if (s != 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 32,8,1>, 32,8,1, 1};
        // channels_last
        if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 24,24, 16,16,1>, 16,16,1, 1};
        if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 16,16, 16,16,1>, 16,16,1, 1};
        if (s == 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 8,8,8>, 8,8,8, 1};
        if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 8,8,8>, 8,8,8, 1};
        if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 8,8,8>, 8,8,8, 1};
        if (s == 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 8,8,8>, 8,8,8, 1};
    }
    if (p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1)
    {
        // contiguous
        if (s != 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,8,1>, 64,8,1, 1};
        if (s != 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,8,1>, 64,8,1, 1};
        if (s != 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,8,1>, 64,8,1, 1};
        // channels_last
        if (s == 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,1,8>, 64,1,8, 1};
        if (s == 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,1,8>, 64,1,8, 1};
        if (s == 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,1,8>, 64,1,8, 1};
    }
    if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2)
    {
        // contiguous
        if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 32,16,1>, 32,16,1, 1};
        if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 32,16,1>, 32,16,1, 1};
        if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 32,16,1>, 32,16,1, 1};
        // channels_last
        if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 1,64,8>, 1,64,8, 1};
        if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 1,64,8>, 1,64,8, 1};
        if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 1,64,8>, 1,64,8, 1};
    }

    // 4x upsampling.
    if (p.up.x == 4 && p.up.y == 4 && p.down.x == 1 && p.down.y == 1)
    {
        // contiguous
        if (s != 1 && fx <= 48 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 48,48, 64,32,1>, 64,32,1, 1};
        if (s != 1 && fx <= 32 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 32,32, 64,32,1>, 64,32,1, 1};
        // channels_last
        if (s == 1 && fx <= 48 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 48,48, 32,32,1>, 32,32,1, 1};
        if (s == 1 && fx <= 32 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 32,32, 32,32,1>, 32,32,1, 1};
    }
    if (p.up.x == 4 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1)
    {
        // contiguous
        if (s != 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 48,1, 128,8,1>, 128,8,1, 1};
        if (s != 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 32,1, 128,8,1>, 128,8,1, 1};
        // channels_last
        if (s == 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 48,1, 128,1,16>, 128,1,16, 1};
        if (s == 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 32,1, 128,1,16>, 128,1,16, 1};
    }
    if (p.up.x == 1 && p.up.y == 4 && p.down.x == 1 && p.down.y == 1)
    {
        // contiguous
        if (s != 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,48, 32,32,1>, 32,32,1, 1};
        if (s != 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,32, 32,32,1>, 32,32,1, 1};
        // channels_last
        if (s == 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,48, 1,128,16>, 1,128,16, 1};
        if (s == 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,32, 1,128,16>, 1,128,16, 1};
    }

    // 4x downsampling (inefficient).
    if (p.up.x == 1 && p.up.y == 1 && p.down.x == 4 && p.down.y == 1)
    {
        // contiguous
        if (s != 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 48,1, 32,8,1>, 32,8,1, 1};
        if (s != 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 32,1, 32,8,1>, 32,8,1, 1};
        // channels_last
        if (s == 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 48,1, 32,1,8>, 32,1,8, 1};
        if (s == 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 32,1, 32,1,8>, 32,1,8, 1};
    }
    if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 4)
    {
        // contiguous
        if (s != 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,48, 32,8,1>, 32,8,1, 1};
        if (s != 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,32, 32,8,1>, 32,8,1, 1};
        // channels_last
        if (s == 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,48, 1,32,8>, 1,32,8, 1};
        if (s == 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,32, 1,32,8>, 1,32,8, 1};
    }
    return spec;
}

//------------------------------------------------------------------------
// Explicit template instantiations.

// Explicit instantiation definitions: emit choose_upfirdn2d_kernel, and with it
// every kernel variant it references, for each supported element type. Other
// translation units can then call it without seeing the template body.
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<double>   (const upfirdn2d_kernel_params&);
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<float>    (const upfirdn2d_kernel_params&);
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<c10::Half>(const upfirdn2d_kernel_params&);

//------------------------------------------------------------------------