GitHub Repository: POSTECH-CVLab/PyTorch-StudioGAN
Path: src/utils/style_ops/upfirdn2d.cpp
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "upfirdn2d.h"

//------------------------------------------------------------------------
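
// Upsample, FIR-filter, and downsample a batch of 2D images (NCHW or
// channels-last). Host-side wrapper only: it validates the arguments,
// allocates the output tensor, and launches the CUDA kernel selected by
// choose_upfirdn2d_kernel() (declared in upfirdn2d.h).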
static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
{
    // Validate arguments.
    TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
    TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
    TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
    TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
    TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
    TORCH_CHECK(x.numel() > 0, "x has zero size");
    TORCH_CHECK(f.numel() > 0, "f has zero size");
    TORCH_CHECK(x.dim() == 4, "x must be rank 4");
    TORCH_CHECK(f.dim() == 2, "f must be rank 2");
    TORCH_CHECK((x.size(0)-1)*x.stride(0) + (x.size(1)-1)*x.stride(1) + (x.size(2)-1)*x.stride(2) + (x.size(3)-1)*x.stride(3) <= INT_MAX, "x memory footprint is too large");
    TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
    TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
    TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");

    // Create output tensor.
    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
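    // Output extent follows floor((inSize*up + pad0 + pad1 - filterTaps) / down) + 1,
    // folded into a single truncating integer division below.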
    int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
    int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
    TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
    torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
    TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
    TORCH_CHECK((y.size(0)-1)*y.stride(0) + (y.size(1)-1)*y.stride(1) + (y.size(2)-1)*y.stride(2) + (y.size(3)-1)*y.stride(3) <= INT_MAX, "output memory footprint is too large");

    // Initialize CUDA kernel parameters.
    upfirdn2d_kernel_params p;
    p.x = x.data_ptr();
    p.f = f.data_ptr<float>();
    p.y = y.data_ptr();
    p.up = make_int2(upx, upy);
    p.down = make_int2(downx, downy);
    p.pad0 = make_int2(padx0, pady0);
    p.flip = (flip) ? 1 : 0;
    p.gain = gain;
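    // Sizes and strides are packed innermost-first: (w, h, c, n) for the
    // tensors and (x, y) for the filter.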
    p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
    p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
    p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
    p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
    p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
    p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
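    // If the channel stride is 1 (channels-last layout), channels become the
    // minor loop dimension; otherwise batch and channels are fused into the
    // major dimension and the minor dimension collapses to 1.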
    p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
    p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;

    // Choose CUDA kernel.
    upfirdn2d_kernel_spec spec;
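    // AT_DISPATCH_FLOATING_TYPES_AND_HALF instantiates the lambda with
    // scalar_t bound to x's element type (float32, float64, or float16).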
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
    {
        spec = choose_upfirdn2d_kernel<scalar_t>(p);
    });

    // Set looping options.
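    // Each block iterates loopMajor x loopMinor x loopX times over the output;
    // loopMajor is chosen so launchMajor (the grid z-extent) stays <= 16384.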
    p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
    p.loopMinor = spec.loopMinor;
    p.loopX = spec.loopX;
    p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
    p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;

    // Compute grid size.
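    // Grid x spans output rows (replicated launchMinor times), grid y spans
    // output columns in steps of loopX, and grid z spans the major dimension.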
    dim3 blockSize, gridSize;
    if (spec.tileOutW < 0) // large
    {
        blockSize = dim3(4, 32, 1);
        gridSize = dim3(
            ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
            (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
            p.launchMajor);
    }
    else // small
    {
        blockSize = dim3(256, 1, 1);
        gridSize = dim3(
            ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
            (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
            p.launchMajor);
    }

    // Launch CUDA kernel.
    void* args[] = {&p};
    AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
    return y;
}

//------------------------------------------------------------------------

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("upfirdn2d", &upfirdn2d);
}
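
// Illustrative Python-side call, assuming the extension is built as a torch
// extension (e.g. via torch.utils.cpp_extension); the module handle and
// tensor names below are hypothetical, and the argument values are examples:
//
//   y = _plugin.upfirdn2d(x, f, 2, 2, 1, 1, 1, 1, 1, 1, False, 1.0)
//
// Positional arguments: x, f, upx, upy, downx, downy, padx0, padx1,
// pady0, pady1, flip, gain.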

//------------------------------------------------------------------------