CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutSign UpSign In
hukaixuan19970627

Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.

GitHub Repository: hukaixuan19970627/yolov5_obb
Path: blob/master/utils/nms_rotated/src/box_iou_rotated_utils.h
Views: 475
1
// Modified from
2
// https://github.com/facebookresearch/detectron2/tree/master/detectron2/layers/csrc/box_iou_rotated
3
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
4
#pragma once
5
6
#include <cassert>
7
#include <cmath>
8
9
#if defined(__CUDACC__) || __HCC__ == 1 || __HIP__ == 1
10
// Designates functions callable from the host (CPU) and the device (GPU)
11
#define HOST_DEVICE __host__ __device__
12
#define HOST_DEVICE_INLINE HOST_DEVICE __forceinline__
13
#else
14
#include <algorithm>
15
#define HOST_DEVICE
16
#define HOST_DEVICE_INLINE HOST_DEVICE inline
17
#endif
18
19
20
/// Oriented rectangle: center, size and rotation.
/// Member order is significant because callers may aggregate-initialize it
/// as {x_ctr, y_ctr, w, h, a}.
template <typename T>
struct RotatedBox {
  T x_ctr;  // center, x coordinate
  T y_ctr;  // center, y coordinate
  T w;      // width
  T h;      // height
  T a;      // rotation; get_rotated_vertices passes it directly to cos/sin
};
24
25
template <typename T>
26
struct Point {
27
T x, y;
28
HOST_DEVICE_INLINE Point(const T& px = 0, const T& py = 0) : x(px), y(py) {}
29
HOST_DEVICE_INLINE Point operator+(const Point& p) const {
30
return Point(x + p.x, y + p.y);
31
}
32
HOST_DEVICE_INLINE Point& operator+=(const Point& p) {
33
x += p.x;
34
y += p.y;
35
return *this;
36
}
37
HOST_DEVICE_INLINE Point operator-(const Point& p) const {
38
return Point(x - p.x, y - p.y);
39
}
40
HOST_DEVICE_INLINE Point operator*(const T coeff) const {
41
return Point(x * coeff, y * coeff);
42
}
43
};
44
45
template <typename T>
46
HOST_DEVICE_INLINE T dot_2d(const Point<T>& A, const Point<T>& B) {
47
return A.x * B.x + A.y * B.y;
48
}
49
50
// R: result type. can be different from input type
51
template <typename T, typename R = T>
52
HOST_DEVICE_INLINE R cross_2d(const Point<T>& A, const Point<T>& B) {
53
return static_cast<R>(A.x) * static_cast<R>(B.y) -
54
static_cast<R>(B.x) * static_cast<R>(A.y);
55
}
56
57
template <typename T>
58
HOST_DEVICE_INLINE void get_rotated_vertices(
59
const RotatedBox<T>& box,
60
Point<T> (&pts)[4]) {
61
// M_PI / 180. == 0.01745329251
62
//double theta = box.a * 0.01745329251; ++++++++++++++++++++++++++++++++++++++++++++++++++++++++
63
double theta = box.a;
64
T cosTheta2 = (T)cos(theta) * 0.5f;
65
T sinTheta2 = (T)sin(theta) * 0.5f;
66
67
// y: top --> down; x: left --> right
68
pts[0].x = box.x_ctr + sinTheta2 * box.h + cosTheta2 * box.w;
69
pts[0].y = box.y_ctr + cosTheta2 * box.h - sinTheta2 * box.w;
70
pts[1].x = box.x_ctr - sinTheta2 * box.h + cosTheta2 * box.w;
71
pts[1].y = box.y_ctr - cosTheta2 * box.h - sinTheta2 * box.w;
72
pts[2].x = 2 * box.x_ctr - pts[0].x;
73
pts[2].y = 2 * box.y_ctr - pts[0].y;
74
pts[3].x = 2 * box.x_ctr - pts[1].x;
75
pts[3].y = 2 * box.y_ctr - pts[1].y;
76
}
77
78
template <typename T>
79
HOST_DEVICE_INLINE int get_intersection_points(
80
const Point<T> (&pts1)[4],
81
const Point<T> (&pts2)[4],
82
Point<T> (&intersections)[24]) {
83
// Line vector
84
// A line from p1 to p2 is: p1 + (p2-p1)*t, t=[0,1]
85
Point<T> vec1[4], vec2[4];
86
for (int i = 0; i < 4; i++) {
87
vec1[i] = pts1[(i + 1) % 4] - pts1[i];
88
vec2[i] = pts2[(i + 1) % 4] - pts2[i];
89
}
90
91
// Line test - test all line combos for intersection
92
int num = 0; // number of intersections
93
for (int i = 0; i < 4; i++) {
94
for (int j = 0; j < 4; j++) {
95
// Solve for 2x2 Ax=b
96
T det = cross_2d<T>(vec2[j], vec1[i]);
97
98
// This takes care of parallel lines
99
if (fabs(det) <= 1e-14) {
100
continue;
101
}
102
103
auto vec12 = pts2[j] - pts1[i];
104
105
T t1 = cross_2d<T>(vec2[j], vec12) / det;
106
T t2 = cross_2d<T>(vec1[i], vec12) / det;
107
108
if (t1 >= 0.0f && t1 <= 1.0f && t2 >= 0.0f && t2 <= 1.0f) {
109
intersections[num++] = pts1[i] + vec1[i] * t1;
110
}
111
}
112
}
113
114
// Check for vertices of rect1 inside rect2
115
{
116
const auto& AB = vec2[0];
117
const auto& DA = vec2[3];
118
auto ABdotAB = dot_2d<T>(AB, AB);
119
auto ADdotAD = dot_2d<T>(DA, DA);
120
for (int i = 0; i < 4; i++) {
121
// assume ABCD is the rectangle, and P is the point to be judged
122
// P is inside ABCD iff. P's projection on AB lies within AB
123
// and P's projection on AD lies within AD
124
125
auto AP = pts1[i] - pts2[0];
126
127
auto APdotAB = dot_2d<T>(AP, AB);
128
auto APdotAD = -dot_2d<T>(AP, DA);
129
130
if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) &&
131
(APdotAD <= ADdotAD)) {
132
intersections[num++] = pts1[i];
133
}
134
}
135
}
136
137
// Reverse the check - check for vertices of rect2 inside rect1
138
{
139
const auto& AB = vec1[0];
140
const auto& DA = vec1[3];
141
auto ABdotAB = dot_2d<T>(AB, AB);
142
auto ADdotAD = dot_2d<T>(DA, DA);
143
for (int i = 0; i < 4; i++) {
144
auto AP = pts2[i] - pts1[0];
145
146
auto APdotAB = dot_2d<T>(AP, AB);
147
auto APdotAD = -dot_2d<T>(AP, DA);
148
149
if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) &&
150
(APdotAD <= ADdotAD)) {
151
intersections[num++] = pts2[i];
152
}
153
}
154
}
155
156
return num;
157
}
158
159
template <typename T>
160
HOST_DEVICE_INLINE int convex_hull_graham(
161
const Point<T> (&p)[24],
162
const int& num_in,
163
Point<T> (&q)[24],
164
bool shift_to_zero = false) {
165
assert(num_in >= 2);
166
167
// Step 1:
168
// Find point with minimum y
169
// if more than 1 points have the same minimum y,
170
// pick the one with the minimum x.
171
int t = 0;
172
for (int i = 1; i < num_in; i++) {
173
if (p[i].y < p[t].y || (p[i].y == p[t].y && p[i].x < p[t].x)) {
174
t = i;
175
}
176
}
177
auto& start = p[t]; // starting point
178
179
// Step 2:
180
// Subtract starting point from every points (for sorting in the next step)
181
for (int i = 0; i < num_in; i++) {
182
q[i] = p[i] - start;
183
}
184
185
// Swap the starting point to position 0
186
auto tmp = q[0];
187
q[0] = q[t];
188
q[t] = tmp;
189
190
// Step 3:
191
// Sort point 1 ~ num_in according to their relative cross-product values
192
// (essentially sorting according to angles)
193
// If the angles are the same, sort according to their distance to origin
194
T dist[24];
195
#if defined(__CUDACC__) || __HCC__ == 1 || __HIP__ == 1
196
// compute distance to origin before sort, and sort them together with the
197
// points
198
for (int i = 0; i < num_in; i++) {
199
dist[i] = dot_2d<T>(q[i], q[i]);
200
}
201
202
// CUDA version
203
// In the future, we can potentially use thrust
204
// for sorting here to improve speed (though not guaranteed)
205
for (int i = 1; i < num_in - 1; i++) {
206
for (int j = i + 1; j < num_in; j++) {
207
T crossProduct = cross_2d<T>(q[i], q[j]);
208
if ((crossProduct < -1e-6) ||
209
(fabs(crossProduct) < 1e-6 && dist[i] > dist[j])) {
210
auto q_tmp = q[i];
211
q[i] = q[j];
212
q[j] = q_tmp;
213
auto dist_tmp = dist[i];
214
dist[i] = dist[j];
215
dist[j] = dist_tmp;
216
}
217
}
218
}
219
#else
220
// CPU version
221
std::sort(
222
q + 1, q + num_in, [](const Point<T>& A, const Point<T>& B) -> bool {
223
T temp = cross_2d<T>(A, B);
224
if (fabs(temp) < 1e-6) {
225
return dot_2d<T>(A, A) < dot_2d<T>(B, B);
226
} else {
227
return temp > 0;
228
}
229
});
230
// compute distance to origin after sort, since the points are now different.
231
for (int i = 0; i < num_in; i++) {
232
dist[i] = dot_2d<T>(q[i], q[i]);
233
}
234
#endif
235
236
// Step 4:
237
// Make sure there are at least 2 points (that don't overlap with each other)
238
// in the stack
239
int k; // index of the non-overlapped second point
240
for (k = 1; k < num_in; k++) {
241
if (dist[k] > 1e-8) {
242
break;
243
}
244
}
245
if (k == num_in) {
246
// We reach the end, which means the convex hull is just one point
247
q[0] = p[t];
248
return 1;
249
}
250
q[1] = q[k];
251
int m = 2; // 2 points in the stack
252
// Step 5:
253
// Finally we can start the scanning process.
254
// When a non-convex relationship between the 3 points is found
255
// (either concave shape or duplicated points),
256
// we pop the previous point from the stack
257
// until the 3-point relationship is convex again, or
258
// until the stack only contains two points
259
for (int i = k + 1; i < num_in; i++) {
260
while (m > 1) {
261
auto q1 = q[i] - q[m - 2], q2 = q[m - 1] - q[m - 2];
262
// cross_2d() uses FMA and therefore computes round(round(q1.x*q2.y) -
263
// q2.x*q1.y) So it may not return 0 even when q1==q2. Therefore we
264
// compare round(q1.x*q2.y) and round(q2.x*q1.y) directly. (round means
265
// round to nearest floating point).
266
if (q1.x * q2.y >= q2.x * q1.y)
267
m--;
268
else
269
break;
270
}
271
// Using double also helps, but float can solve the issue for now.
272
// while (m > 1 && cross_2d<T, double>(q[i] - q[m - 2], q[m - 1] - q[m - 2])
273
// >= 0) {
274
// m--;
275
// }
276
q[m++] = q[i];
277
}
278
279
// Step 6 (Optional):
280
// In general sense we need the original coordinates, so we
281
// need to shift the points back (reverting Step 2)
282
// But if we're only interested in getting the area/perimeter of the shape
283
// We can simply return.
284
if (!shift_to_zero) {
285
for (int i = 0; i < m; i++) {
286
q[i] += start;
287
}
288
}
289
290
return m;
291
}
292
293
template <typename T>
294
HOST_DEVICE_INLINE T polygon_area(const Point<T> (&q)[24], const int& m) {
295
if (m <= 2) {
296
return 0;
297
}
298
299
T area = 0;
300
for (int i = 1; i < m - 1; i++) {
301
area += fabs(cross_2d<T>(q[i] - q[0], q[i + 1] - q[0]));
302
}
303
304
return area / 2.0;
305
}
306
307
template <typename T>
308
HOST_DEVICE_INLINE T rotated_boxes_intersection(
309
const RotatedBox<T>& box1,
310
const RotatedBox<T>& box2) {
311
// There are up to 4 x 4 + 4 + 4 = 24 intersections (including dups) returned
312
// from rotated_rect_intersection_pts
313
Point<T> intersectPts[24], orderedPts[24];
314
315
Point<T> pts1[4];
316
Point<T> pts2[4];
317
get_rotated_vertices<T>(box1, pts1);
318
get_rotated_vertices<T>(box2, pts2);
319
320
int num = get_intersection_points<T>(pts1, pts2, intersectPts);
321
322
if (num <= 2) {
323
return 0.0;
324
}
325
326
// Convex Hull to order the intersection points in clockwise order and find
327
// the contour area.
328
int num_convex = convex_hull_graham<T>(intersectPts, num, orderedPts, true);
329
return polygon_area<T>(orderedPts, num_convex);
330
}
331
332
333
template <typename T>
334
HOST_DEVICE_INLINE T
335
single_box_iou_rotated(T const* const box1_raw, T const* const box2_raw) {
336
// shift center to the middle point to achieve higher precision in result
337
RotatedBox<T> box1, box2;
338
auto center_shift_x = (box1_raw[0] + box2_raw[0]) / 2.0;
339
auto center_shift_y = (box1_raw[1] + box2_raw[1]) / 2.0;
340
box1.x_ctr = box1_raw[0] - center_shift_x;
341
box1.y_ctr = box1_raw[1] - center_shift_y;
342
box1.w = box1_raw[2];
343
box1.h = box1_raw[3];
344
box1.a = box1_raw[4];
345
box2.x_ctr = box2_raw[0] - center_shift_x;
346
box2.y_ctr = box2_raw[1] - center_shift_y;
347
box2.w = box2_raw[2];
348
box2.h = box2_raw[3];
349
box2.a = box2_raw[4];
350
351
T area1 = box1.w * box1.h;
352
T area2 = box2.w * box2.h;
353
if (area1 < 1e-14 || area2 < 1e-14) {
354
return 0.f;
355
}
356
357
T intersection = rotated_boxes_intersection<T>(box1, box2);
358
T iou = intersection / (area1 + area2 - intersection);
359
return iou;
360
}
361
362