Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.
Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.
Path: blob/master/utils/nms_rotated/src/box_iou_rotated_utils.h
Views: 475
// Modified from
// https://github.com/facebookresearch/detectron2/tree/master/detectron2/layers/csrc/box_iou_rotated
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
#pragma once

#include <cassert>
#include <cmath>

#if defined(__CUDACC__) || __HCC__ == 1 || __HIP__ == 1
// Designates functions callable from the host (CPU) and the device (GPU)
#define HOST_DEVICE __host__ __device__
#define HOST_DEVICE_INLINE HOST_DEVICE __forceinline__
#else
#include <algorithm>
#define HOST_DEVICE
#define HOST_DEVICE_INLINE HOST_DEVICE inline
#endif

// A rotated rectangle: center (x_ctr, y_ctr), side lengths w and h, and a
// rotation angle a. In this file `a` is passed unchanged to cos()/sin()
// (see get_rotated_vertices), i.e. it is treated as radians.
template <typename T>
struct RotatedBox {
  T x_ctr, y_ctr, w, h, a;
};

// Minimal 2D point/vector with the arithmetic needed by the routines below.
template <typename T>
struct Point {
  T x, y;
  HOST_DEVICE_INLINE Point(const T& px = 0, const T& py = 0) : x(px), y(py) {}
  HOST_DEVICE_INLINE Point operator+(const Point& p) const {
    return Point(x + p.x, y + p.y);
  }
  HOST_DEVICE_INLINE Point& operator+=(const Point& p) {
    x += p.x;
    y += p.y;
    return *this;
  }
  HOST_DEVICE_INLINE Point operator-(const Point& p) const {
    return Point(x - p.x, y - p.y);
  }
  HOST_DEVICE_INLINE Point operator*(const T coeff) const {
    return Point(x * coeff, y * coeff);
  }
};

// Dot product of two 2D vectors.
template <typename T>
HOST_DEVICE_INLINE T dot_2d(const Point<T>& A, const Point<T>& B) {
  return A.x * B.x + A.y * B.y;
}

// Z component of the cross product of two 2D vectors.
// R: result type. can be different from input type
template <typename T, typename R = T>
HOST_DEVICE_INLINE R cross_2d(const Point<T>& A, const Point<T>& B) {
  return static_cast<R>(A.x) * static_cast<R>(B.y) -
      static_cast<R>(B.x) * static_cast<R>(A.y);
}

// Writes the four corners of `box` into `pts`.
// pts[2] and pts[3] are the reflections of pts[0] and pts[1] through the
// box center, so consecutive entries are adjacent corners.
template <typename T>
HOST_DEVICE_INLINE void get_rotated_vertices(
    const RotatedBox<T>& box,
    Point<T> (&pts)[4]) {
  // M_PI / 180. == 0.0174532925
  // The upstream degree-to-radian conversion is intentionally disabled here:
  //   double theta = box.a * 0.01745329251;
  // box.a is used directly as the angle.
  double theta = box.a;
  T cosTheta2 = static_cast<T>(cos(theta)) * 0.5f;
  T sinTheta2 = static_cast<T>(sin(theta)) * 0.5f;

  // y: top --> down; x: left --> right
  pts[0].x = box.x_ctr + sinTheta2 * box.h + cosTheta2 * box.w;
  pts[0].y = box.y_ctr + cosTheta2 * box.h - sinTheta2 * box.w;
  pts[1].x = box.x_ctr - sinTheta2 * box.h + cosTheta2 * box.w;
  pts[1].y = box.y_ctr - cosTheta2 * box.h - sinTheta2 * box.w;
  pts[2].x = 2 * box.x_ctr - pts[0].x;
  pts[2].y = 2 * box.y_ctr - pts[0].y;
  pts[3].x = 2 * box.x_ctr - pts[1].x;
  pts[3].y = 2 * box.y_ctr - pts[1].y;
}

// Collects candidate vertices of the intersection polygon of two rectangles
// given by their corners: edge/edge crossings, corners of rect1 inside rect2,
// and corners of rect2 inside rect1. Duplicates are not removed.
// Returns the number of points written to `intersections` (at most
// 4 x 4 + 4 + 4 = 24).
template <typename T>
HOST_DEVICE_INLINE int get_intersection_points(
    const Point<T> (&pts1)[4],
    const Point<T> (&pts2)[4],
    Point<T> (&intersections)[24]) {
  // Line vector
  // A line from p1 to p2 is: p1 + (p2-p1)*t, t=[0,1]
  Point<T> vec1[4], vec2[4];
  for (int i = 0; i < 4; i++) {
    vec1[i] = pts1[(i + 1) % 4] - pts1[i];
    vec2[i] = pts2[(i + 1) % 4] - pts2[i];
  }

  // Line test - test all line combos for intersection
  int num = 0; // number of intersections
  for (int i = 0; i < 4; i++) {
    for (int j = 0; j < 4; j++) {
      // Solve for 2x2 Ax=b
      T det = cross_2d<T>(vec2[j], vec1[i]);

      // This takes care of parallel lines
      if (fabs(det) <= 1e-14) {
        continue;
      }

      auto vec12 = pts2[j] - pts1[i];

      T t1 = cross_2d<T>(vec2[j], vec12) / det;
      T t2 = cross_2d<T>(vec1[i], vec12) / det;

      // Both parameters must lie on their segments for a real crossing.
      if (t1 >= 0.0f && t1 <= 1.0f && t2 >= 0.0f && t2 <= 1.0f) {
        intersections[num++] = pts1[i] + vec1[i] * t1;
      }
    }
  }

  // Check for vertices of rect1 inside rect2
  {
    const auto& AB = vec2[0];
    const auto& DA = vec2[3];
    auto ABdotAB = dot_2d<T>(AB, AB);
    auto ADdotAD = dot_2d<T>(DA, DA);
    for (int i = 0; i < 4; i++) {
      // assume ABCD is the rectangle, and P is the point to be judged
      // P is inside ABCD iff. P's projection on AB lies within AB
      // and P's projection on AD lies within AD

      auto AP = pts1[i] - pts2[0];

      auto APdotAB = dot_2d<T>(AP, AB);
      auto APdotAD = -dot_2d<T>(AP, DA); // AD == -DA

      if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) &&
          (APdotAD <= ADdotAD)) {
        intersections[num++] = pts1[i];
      }
    }
  }

  // Reverse the check - check for vertices of rect2 inside rect1
  {
    const auto& AB = vec1[0];
    const auto& DA = vec1[3];
    auto ABdotAB = dot_2d<T>(AB, AB);
    auto ADdotAD = dot_2d<T>(DA, DA);
    for (int i = 0; i < 4; i++) {
      auto AP = pts2[i] - pts1[0];

      auto APdotAB = dot_2d<T>(AP, AB);
      auto APdotAD = -dot_2d<T>(AP, DA); // AD == -DA

      if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) &&
          (APdotAD <= ADdotAD)) {
        intersections[num++] = pts2[i];
      }
    }
  }

  return num;
}

// Graham-scan convex hull of the first `num_in` points of `p`.
// The hull vertices are written to the front of `q` and their count is
// returned. When `shift_to_zero` is true the output stays translated so that
// the starting point (minimum y, then minimum x) is the origin; otherwise the
// points are shifted back to their original coordinates.
// Requires num_in >= 2 (asserted).
template <typename T>
HOST_DEVICE_INLINE int convex_hull_graham(
    const Point<T> (&p)[24],
    const int& num_in,
    Point<T> (&q)[24],
    bool shift_to_zero = false) {
  assert(num_in >= 2);

  // Step 1:
  // Find point with minimum y
  // if more than 1 points have the same minimum y,
  // pick the one with the minimum x.
  int t = 0;
  for (int i = 1; i < num_in; i++) {
    if (p[i].y < p[t].y || (p[i].y == p[t].y && p[i].x < p[t].x)) {
      t = i;
    }
  }
  auto& start = p[t]; // starting point

  // Step 2:
  // Subtract starting point from every points (for sorting in the next step)
  for (int i = 0; i < num_in; i++) {
    q[i] = p[i] - start;
  }

  // Swap the starting point to position 0
  auto tmp = q[0];
  q[0] = q[t];
  q[t] = tmp;

  // Step 3:
  // Sort point 1 ~ num_in according to their relative cross-product values
  // (essentially sorting according to angles)
  // If the angles are the same, sort according to their distance to origin
  T dist[24];
#if defined(__CUDACC__) || __HCC__ == 1 || __HIP__ == 1
  // compute distance to origin before sort, and sort them together with the
  // points
  for (int i = 0; i < num_in; i++) {
    dist[i] = dot_2d<T>(q[i], q[i]);
  }

  // CUDA version
  // In the future, we can potentially use thrust
  // for sorting here to improve speed (though not guaranteed)
  for (int i = 1; i < num_in - 1; i++) {
    for (int j = i + 1; j < num_in; j++) {
      T crossProduct = cross_2d<T>(q[i], q[j]);
      if ((crossProduct < -1e-6) ||
          (fabs(crossProduct) < 1e-6 && dist[i] > dist[j])) {
        auto q_tmp = q[i];
        q[i] = q[j];
        q[j] = q_tmp;
        auto dist_tmp = dist[i];
        dist[i] = dist[j];
        dist[j] = dist_tmp;
      }
    }
  }
#else
  // CPU version
  std::sort(
      q + 1, q + num_in, [](const Point<T>& A, const Point<T>& B) -> bool {
        T temp = cross_2d<T>(A, B);
        if (fabs(temp) < 1e-6) {
          return dot_2d<T>(A, A) < dot_2d<T>(B, B);
        } else {
          return temp > 0;
        }
      });
  // compute distance to origin after sort, since the points are now different.
  for (int i = 0; i < num_in; i++) {
    dist[i] = dot_2d<T>(q[i], q[i]);
  }
#endif

  // Step 4:
  // Make sure there are at least 2 points (that don't overlap with each other)
  // in the stack
  int k; // index of the non-overlapped second point
  for (k = 1; k < num_in; k++) {
    if (dist[k] > 1e-8) {
      break;
    }
  }
  if (k == num_in) {
    // We reach the end, which means the convex hull is just one point
    q[0] = p[t];
    return 1;
  }
  q[1] = q[k];
  int m = 2; // 2 points in the stack
  // Step 5:
  // Finally we can start the scanning process.
  // When a non-convex relationship between the 3 points is found
  // (either concave shape or duplicated points),
  // we pop the previous point from the stack
  // until the 3-point relationship is convex again, or
  // until the stack only contains two points
  for (int i = k + 1; i < num_in; i++) {
    while (m > 1) {
      auto q1 = q[i] - q[m - 2], q2 = q[m - 1] - q[m - 2];
      // cross_2d() uses FMA and therefore computes round(round(q1.x*q2.y) -
      // q2.x*q1.y) So it may not return 0 even when q1==q2. Therefore we
      // compare round(q1.x*q2.y) and round(q2.x*q1.y) directly. (round means
      // round to nearest floating point).
      if (q1.x * q2.y >= q2.x * q1.y)
        m--;
      else
        break;
    }
    // Using double also helps, but float can solve the issue for now.
    // while (m > 1 && cross_2d<T, double>(q[i] - q[m - 2], q[m - 1] - q[m - 2])
    // >= 0) {
    //     m--;
    // }
    q[m++] = q[i];
  }

  // Step 6 (Optional):
  // In general sense we need the original coordinates, so we
  // need to shift the points back (reverting Step 2)
  // But if we're only interested in getting the area/perimeter of the shape
  // We can simply return.
  if (!shift_to_zero) {
    for (int i = 0; i < m; i++) {
      q[i] += start;
    }
  }

  return m;
}

// Area of the polygon formed by the first `m` points of `q`, computed as a
// triangle fan anchored at q[0]. Returns 0 when m <= 2.
template <typename T>
HOST_DEVICE_INLINE T polygon_area(const Point<T> (&q)[24], const int& m) {
  if (m <= 2) {
    return 0;
  }

  T area = 0;
  for (int i = 1; i < m - 1; i++) {
    area += fabs(cross_2d<T>(q[i] - q[0], q[i + 1] - q[0]));
  }

  return area / 2.0;
}

// Area of the intersection of two rotated boxes. Returns 0 when at most two
// candidate intersection points are found.
template <typename T>
HOST_DEVICE_INLINE T rotated_boxes_intersection(
    const RotatedBox<T>& box1,
    const RotatedBox<T>& box2) {
  // There are up to 4 x 4 + 4 + 4 = 24 intersections (including dups) returned
  // from get_intersection_points
  Point<T> intersectPts[24], orderedPts[24];

  Point<T> pts1[4];
  Point<T> pts2[4];
  get_rotated_vertices<T>(box1, pts1);
  get_rotated_vertices<T>(box2, pts2);

  int num = get_intersection_points<T>(pts1, pts2, intersectPts);

  if (num <= 2) {
    return 0.0;
  }

  // Convex Hull to order the intersection points in clockwise order and find
  // the contour area.
  // shift_to_zero = true: only the area is needed, so the hull may stay
  // translated to the origin.
  int num_convex = convex_hull_graham<T>(intersectPts, num, orderedPts, true);
  return polygon_area<T>(orderedPts, num_convex);
}

// Intersection-over-union of two rotated boxes.
// Each raw box is read as 5 consecutive values: [x_ctr, y_ctr, w, h, a].
// Returns 0 when either box has an area (w * h) below 1e-14.
template <typename T>
HOST_DEVICE_INLINE T
single_box_iou_rotated(T const* const box1_raw, T const* const box2_raw) {
  // shift center to the middle point to achieve higher precision in result
  RotatedBox<T> box1, box2;
  auto center_shift_x = (box1_raw[0] + box2_raw[0]) / 2.0;
  auto center_shift_y = (box1_raw[1] + box2_raw[1]) / 2.0;
  box1.x_ctr = box1_raw[0] - center_shift_x;
  box1.y_ctr = box1_raw[1] - center_shift_y;
  box1.w = box1_raw[2];
  box1.h = box1_raw[3];
  box1.a = box1_raw[4];
  box2.x_ctr = box2_raw[0] - center_shift_x;
  box2.y_ctr = box2_raw[1] - center_shift_y;
  box2.w = box2_raw[2];
  box2.h = box2_raw[3];
  box2.a = box2_raw[4];

  T area1 = box1.w * box1.h;
  T area2 = box2.w * box2.h;
  if (area1 < 1e-14 || area2 < 1e-14) {
    return 0.f;
  }

  T intersection = rotated_boxes_intersection<T>(box1, box2);
  T iou = intersection / (area1 + area2 - intersection);
  return iou;
}