CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Sign Up · Sign In
hukaixuan19970627

Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.

GitHub Repository: hukaixuan19970627/yolov5_obb
Path: blob/master/utils/nms_rotated/src/poly_nms_cuda.cu
Views: 475
1
#include <ATen/ATen.h>
2
#include <ATen/cuda/CUDAContext.h>
3
4
#include <THC/THC.h>
5
#include <THC/THCDeviceUtils.cuh>
6
7
#include <vector>
8
#include <iostream>
9
10
// Evaluates a CUDA runtime call once and, if it did not return cudaSuccess,
// prints cudaGetErrorString(...) to std::cout. The error is only reported:
// execution continues after the macro.
//
// The do { } while (0) wrapper gives the local its own scope and makes the macro
// behave like a single statement. The local uses an unlikely name, and
// `condition` is parenthesized, so that e.g. CUDA_CHECK(error) no longer expands
// to the self-initializing `cudaError_t error = error;`.
#define CUDA_CHECK(condition)                                          \
  do {                                                                 \
    cudaError_t cuda_check_err_ = (condition);                         \
    if (cuda_check_err_ != cudaSuccess) {                              \
      std::cout << cudaGetErrorString(cuda_check_err_) << std::endl;   \
    }                                                                  \
  } while (0)
18
19
// Integer ceiling division: DIVUP(m, n) == ceil(m / n) for m >= 0, n > 0.
#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))
// Boxes per tile / threads per block: one unsigned long long mask word holds one
// bit per box of a tile (64 with the usual 8-byte unsigned long long).
int const threadsPerBlock = sizeof(unsigned long long) * 8;

// Capacity (in points) of the fixed-size vertex buffers used by the polygon
// clipping helpers.
#define maxn 10
// Tolerance used by sig(). Declared float rather than double so that sig()'s
// comparisons of float arguments are not promoted to double precision in device
// code. Results are unchanged: 1E-8f is the largest float <= 1E-8 and the next
// float is > 1E-8, so `d > eps` and `d < -eps` agree for every float d.
const float eps = 1E-8f;
25
26
// Tolerant three-way sign of d: 1 when d > eps, -1 when d < -eps, and 0 when d
// lies within [-eps, eps] (also 0 for NaN, since both comparisons are false).
__device__ inline int sig(float d){
  if (d > eps) {
    return 1;
  }
  if (d < -eps) {
    return -1;
  }
  return 0;
}
29
30
// Returns 1 when a and b coincide in both coordinates within the sig()
// tolerance, otherwise 0.
__device__ inline int point_eq(const float2 a, const float2 b) {
  const bool same_x = sig(a.x - b.x) == 0;
  return same_x && (sig(a.y - b.y) == 0);
}
33
34
// Exchanges the two points referenced by a and b (a no-op when a == b).
__device__ inline void point_swap(float2 *a, float2 *b) {
  const float2 saved = *b;
  *b = *a;
  *a = saved;
}
39
40
// Reverses the points in the half-open range [first, last) in place, like
// std::reverse (which is not usable in device code).
__device__ inline void point_reverse(float2 *first, float2* last)
{
  if (first == last) {
    return;  // empty range: nothing to do
  }
  // Walk inward from both ends, swapping until the cursors meet or cross.
  for (float2 *hi = last - 1; first < hi; ++first, --hi) {
    point_swap(first, hi);
  }
}
47
48
// 2D cross product of the vectors (a - o) and (b - o). With y pointing up it is
// positive when o -> a -> b turns counter-clockwise, negative when clockwise,
// and zero when the three points are collinear.
__device__ inline float cross(float2 o,float2 a,float2 b){
  const float ax = a.x - o.x;
  const float ay = a.y - o.y;
  const float bx = b.x - o.x;
  const float by = b.y - o.y;
  return ax * by - bx * ay;
}
51
// Signed area of the polygon ps[0..n-1] via the shoelace formula: positive for
// counter-clockwise vertex order (y up), negative for clockwise.
// Side effect: writes the closing sentinel ps[n] = ps[0], so ps must have room
// for at least n + 1 points.
__device__ inline float area(float2* ps,int n){
  ps[n] = ps[0];
  float res = 0.0f;
  for (int i = 0; i < n; ++i) {
    res += ps[i].x * ps[i + 1].y - ps[i].y * ps[i + 1].x;
  }
  // Halve with a float multiply instead of dividing by the double literal 2.0.
  // Halving is exact in float, so the value is unchanged, but no
  // double-precision instruction is emitted in device code.
  return res * 0.5f;
}
59
// Intersects the infinite lines ab and cd.
// Returns 2 when c and d both lie on line ab (collinear; p is not written),
//         0 when cd is parallel to ab within tolerance (p is not written),
//         1 otherwise, storing the intersection point in p.
__device__ inline int lineCross(float2 a,float2 b,float2 c,float2 d,float2&p){
  const float side_c = cross(a, b, c);
  const float side_d = cross(a, b, d);
  if (sig(side_c) == 0 && sig(side_d) == 0) {
    return 2;
  }
  const float denom = side_d - side_c;
  if (sig(denom) == 0) {
    return 0;
  }
  // Interpolate between c and d using their signed distances to line ab.
  p.x = (c.x * side_d - d.x * side_c) / denom;
  p.y = (c.y * side_d - d.y * side_c) / denom;
  return 1;
}
69
70
// Clips the polygon p[0..n-1] in place against the half-plane strictly left of
// the directed line a -> b (points with sig(cross(a, b, x)) > 0). This is one
// Sutherland-Hodgman clipping step. On return p holds the clipped polygon and n
// its vertex count. Consecutive duplicate vertices, and trailing copies of p[0],
// are dropped.
//
// p needs room for n + 1 points (p[n] is written as a closing sentinel). pp is
// scratch space that must hold every emitted vertex: at most 2 * n.
__device__ inline void polygon_cut(float2*p,int&n,float2 a,float2 b, float2* pp){
  int m = 0;
  p[n] = p[0];
  // Side of the current vertex; carried over so each vertex is classified once.
  int side_cur = sig(cross(a, b, p[0]));
  for (int i = 0; i < n; i++) {
    const int side_next = sig(cross(a, b, p[i + 1]));
    if (side_cur > 0) pp[m++] = p[i];  // vertex inside: keep it
    if (side_cur != side_next) {
      // Edge p[i] -> p[i+1] crosses (or touches) the cut line: emit the crossing.
      float2 q;
      if (lineCross(a, b, p[i], p[i + 1], q) != 1) {
        // lineCross did not write q. Because the two sides differ, this only
        // happens when one endpoint lies on the cut line (sig == 0) and the
        // other is barely off it. That endpoint is the crossing point. The
        // original code emitted an uninitialized vertex here instead.
        q = (side_cur == 0) ? p[i] : p[i + 1];
      }
      pp[m++] = q;
    }
    side_cur = side_next;
  }
  // Copy back, skipping vertices equal to their predecessor.
  n = 0;
  for (int i = 0; i < m; i++) {
    if (i == 0 || !point_eq(pp[i], pp[i - 1])) p[n++] = pp[i];
  }
  // Drop trailing vertices that duplicate the first one.
  while (n > 1 && point_eq(p[n - 1], p[0])) n--;
}
85
86
//-------------------------------------------------------------------------//
// Signed intersection area of triangle (o, a, b) and triangle (o, c, d), where
// o is the origin. The magnitude is the overlap area. The sign is negative when
// exactly one of the two triangles is clockwise (s1 * s2 == -1). Returns 0 if
// either triangle is degenerate.
__device__ inline float intersectArea(float2 a,float2 b,float2 c,float2 d){
  const float2 o = make_float2(0, 0);
  const int s1 = sig(cross(o, a, b));
  const int s2 = sig(cross(o, c, d));
  if (s1 == 0 || s2 == 0) return 0.0f;  // degenerate triangle: zero area
  // Make both triangles counter-clockwise so that polygon_cut (which keeps the
  // left side of each directed edge) keeps the interior of ocd.
  if (s1 == -1) point_swap(&a, &b);
  if (s2 == -1) point_swap(&c, &d);
  // Clip triangle oab by the three edges of triangle ocd. Both buffers are sized
  // by maxn, the file-wide clipping capacity, rather than a separate literal 10,
  // so they cannot drift apart.
  float2 p[maxn] = {o, a, b};
  int n = 3;
  float2 pp[maxn];
  polygon_cut(p, n, o, c, pp);
  polygon_cut(p, n, c, d, pp);
  polygon_cut(p, n, d, o, pp);
  float res = fabsf(area(p, n));
  if (s1 * s2 == -1) res = -res;
  return res;
}
106
// Intersection area of two polygons ps1[0..n1-1] and ps2[0..n2-1] (assumed
// simple). Computed as the sum, over every pair of edges, of the signed overlap
// of the origin-anchored triangles those edges span.
// Side effects: reverses ps1 / ps2 in place when their signed area is negative,
// and writes the closing sentinels ps1[n1] / ps2[n2]. Both arrays therefore need
// room for n + 1 points.
__device__ inline float intersectArea(float2*ps1,int n1,float2*ps2,int n2){
  // Bring both polygons to counter-clockwise order.
  if (area(ps1, n1) < 0) point_reverse(ps1, ps1 + n1);
  if (area(ps2, n2) < 0) point_reverse(ps2, ps2 + n2);
  // Re-close both polygons: a reversal reorders the points after area() set the
  // sentinel.
  ps1[n1] = ps1[0];
  ps2[n2] = ps2[0];
  float total = 0;
  for (int e1 = 0; e1 < n1; ++e1) {
    const float2 u0 = ps1[e1];
    const float2 u1 = ps1[e1 + 1];
    for (int e2 = 0; e2 < n2; ++e2) {
      total += intersectArea(u0, u1, ps2[e2], ps2[e2 + 1]);
    }
  }
  return total;  // expected to be non-negative once both inputs are CCW
}
120
121
// TODO: speed this up by first computing the IoU of the two horizontal
// (axis-aligned) bounding boxes, e.g. to skip clipping when they cannot overlap.
//
// IoU of two quadrilaterals. p and q each point at 4 vertices stored as
// x0, y0, x1, y1, x2, y2, x3, y3 (only the first 8 floats are read).
// If the union area is exactly 0, returns (inter + 1) / (union + 1) instead of
// dividing by zero.
__device__ inline float devPolyIoU(float const * const p, float const * const q) {
  const int n1 = 4;
  const int n2 = 4;
  float2 ps1[maxn], ps2[maxn];
  for (int k = 0; k < 4; ++k) {
    ps1[k] = make_float2(p[2 * k], p[2 * k + 1]);
    ps2[k] = make_float2(q[2 * k], q[2 * k + 1]);
  }
  const float inter_area = intersectArea(ps1, n1, ps2, n2);
  const float union_area = fabs(area(ps1, n1)) + fabs(area(ps2, n2)) - inter_area;
  return (union_area == 0) ? (inter_area + 1) / (union_area + 1)
                           : inter_area / union_area;
}
143
144
// Pairwise-overlap kernel for polygon NMS.
//
// Launch layout (as set up by poly_nms_cuda): a 2D grid of
// ceil(n_polys / threadsPerBlock) x ceil(n_polys / threadsPerBlock) blocks.
// blockIdx.y selects a tile of "row" boxes and blockIdx.x a tile of "column"
// boxes. Assumes blockDim.x == threadsPerBlock, one thread per box of a tile.
//
// dev_polys: n_polys records of 9 floats each. Only the first 8 floats (the
//   four polygon vertices) are used for the IoU.
// dev_mask: n_polys x col_blocks words. Bit i of
//   dev_mask[row * col_blocks + col_start] is set when box `row` and box
//   col_start * threadsPerBlock + i have IoU > nms_overlap_thresh. On the
//   diagonal tile only later boxes are compared (no self-overlap).
//
// Shared memory: threadsPerBlock * 9 floats (static) for the column tile.
__global__ void poly_nms_kernel(const int n_polys, const float nms_overlap_thresh,
                                const float *dev_polys, unsigned long long *dev_mask) {
  constexpr int kBoxDim = 9;  // floats per box record
  const int row_start = blockIdx.y;
  const int col_start = blockIdx.x;

  // The last tile in each direction may be partial.
  const int row_size =
      min(n_polys - row_start * threadsPerBlock, threadsPerBlock);
  const int cols_size =
      min(n_polys - col_start * threadsPerBlock, threadsPerBlock);

  // Stage this block's column boxes in shared memory: one record per thread.
  __shared__ float block_polys[threadsPerBlock * kBoxDim];
  if (threadIdx.x < cols_size) {
    const float *src =
        dev_polys + (size_t)(threadsPerBlock * col_start + threadIdx.x) * kBoxDim;
#pragma unroll
    for (int k = 0; k < kBoxDim; ++k) {
      block_polys[threadIdx.x * kBoxDim + k] = src[k];
    }
  }
  // Reached by every thread: the tile must be complete before anyone reads it.
  __syncthreads();

  if (threadIdx.x < row_size) {
    const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x;
    const float *cur_box = dev_polys + (size_t)cur_box_idx * kBoxDim;
    // On the diagonal tile compare only against boxes after this one.
    const int start = (row_start == col_start) ? threadIdx.x + 1 : 0;
    unsigned long long t = 0;
    for (int i = start; i < cols_size; i++) {
      if (devPolyIoU(cur_box, block_polys + i * kBoxDim) > nms_overlap_thresh) {
        t |= 1ULL << i;
      }
    }
    // Ceil-div computed locally (same value as THCCeilDiv) so the kernel does
    // not depend on THC helpers, which recent PyTorch releases removed.
    const int col_blocks = (n_polys + threadsPerBlock - 1) / threadsPerBlock;
    // size_t offset: n_polys * col_blocks can exceed INT_MAX for large inputs.
    dev_mask[(size_t)cur_box_idx * col_blocks + col_start] = t;
  }
}
195
196
// Host-side greedy sweep over the overlap bitmask produced by poly_nms_kernel.
// Boxes are visited in (score-sorted) index order. A box is kept unless an
// earlier kept box has marked it; a kept box then ORs its own mask row into the
// running "removed" set. Writes kept indices to keep_out and returns the count.
static int poly_nms_collect_keep(const std::vector<unsigned long long>& mask_host,
                                 int boxes_num, int col_blocks, int64_t* keep_out) {
  std::vector<unsigned long long> remv(col_blocks, 0ULL);
  int num_to_keep = 0;
  for (int i = 0; i < boxes_num; i++) {
    const int nblock = i / threadsPerBlock;
    const int inblock = i % threadsPerBlock;
    if (remv[nblock] & (1ULL << inblock)) {
      continue;  // suppressed by a higher-scoring kept box
    }
    keep_out[num_to_keep++] = i;
    const unsigned long long *row = mask_host.data() + (size_t)i * col_blocks;
    // Tiles before nblock only contain boxes that were already decided.
    for (int j = nblock; j < col_blocks; j++) {
      remv[j] |= row[j];
    }
  }
  return num_to_keep;
}

// Polygon (rotated-box) NMS on the GPU.
//
// boxes: CUDA float tensor of shape (N, 9). Columns 0-7 are the four polygon
//   vertices (x0, y0, ..., x3, y3) and column 8 is the score used for ordering.
// nms_overlap_thresh: a box is suppressed by a kept, higher-scoring box when
//   their polygon IoU is strictly greater than this value.
// Returns: int64 tensor on boxes' device holding the indices (into the original,
//   unsorted boxes) of the kept boxes, in descending score order.
at::Tensor poly_nms_cuda(const at::Tensor boxes, float nms_overlap_thresh) {
  at::DeviceGuard guard(boxes.device());

  using scalar_t = float;
  TORCH_CHECK(boxes.device().is_cuda(), "boxes must be a CUDA tensor");
  TORCH_CHECK(boxes.dim() == 2 && boxes.size(1) == 9,
              "boxes must be a N x 9 tensor (8 polygon coordinates + score)");

  auto scores = boxes.select(1, 8);
  auto order_t = std::get<1>(scores.sort(0, /*descending=*/true));

  const int boxes_num = boxes.size(0);
  if (boxes_num == 0) {
    // Nothing to suppress. This also avoids launching a kernel with an empty
    // grid (an invalid configuration) and taking &v[0] of empty vectors.
    return order_t;
  }

  // index_select yields a contiguous tensor, as the kernel's 9-float stride needs.
  auto boxes_sorted = boxes.index_select(0, order_t);
  const scalar_t* boxes_dev = boxes_sorted.data_ptr<scalar_t>();

  const int col_blocks = (boxes_num + threadsPerBlock - 1) / threadsPerBlock;

  // Allocate the mask as a tensor through the PyTorch caching allocator so it is
  // released automatically. The previous THCudaMalloc/THCudaFree pair leaked if
  // anything in between threw. int64 and unsigned long long have the same size.
  at::Tensor mask = at::empty({boxes_num, col_blocks}, boxes.options().dtype(at::kLong));
  unsigned long long* mask_dev =
      reinterpret_cast<unsigned long long*>(mask.data_ptr<int64_t>());

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  dim3 blocks(col_blocks, col_blocks);
  dim3 threads(threadsPerBlock);
  poly_nms_kernel<<<blocks, threads, 0, stream>>>(boxes_num,
                                                  nms_overlap_thresh,
                                                  boxes_dev,
                                                  mask_dev);
  THCudaCheck(cudaGetLastError());  // surface launch-configuration errors

  const size_t mask_words = (size_t)boxes_num * col_blocks;
  std::vector<unsigned long long> mask_host(mask_words);
  THCudaCheck(cudaMemcpyAsync(
      mask_host.data(),
      mask_dev,
      sizeof(unsigned long long) * mask_words,
      cudaMemcpyDeviceToHost,
      stream));
  // mask_host is pageable memory. Synchronize explicitly so the host sweep below
  // does not rely on the implicit blocking behavior of async copies to pageable
  // memory.
  THCudaCheck(cudaStreamSynchronize(stream));

  at::Tensor keep = at::empty({boxes_num}, boxes.options().dtype(at::kLong).device(at::kCPU));
  const int num_to_keep =
      poly_nms_collect_keep(mask_host, boxes_num, col_blocks, keep.data_ptr<int64_t>());

  // Map positions in the sorted order back to indices into the original boxes.
  return order_t.index({
      keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep).to(
          order_t.device(), keep.scalar_type())});
}
262
263
264