CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Sign Up · Sign In
hukaixuan19970627

Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.

GitHub Repository: hukaixuan19970627/yolov5_obb
Path: blob/master/utils/nms_rotated/src/nms_rotated_ext.cpp
Views: 475
// Modified from
// https://github.com/facebookresearch/detectron2/tree/master/detectron2/layers/csrc/nms_rotated
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
#include <ATen/ATen.h>
5
#include <torch/extension.h>
6
7
8
#ifdef WITH_CUDA
9
at::Tensor nms_rotated_cuda(
10
const at::Tensor& dets,
11
const at::Tensor& scores,
12
const float iou_threshold);
13
14
at::Tensor poly_nms_cuda(
15
const at::Tensor boxes,
16
float nms_overlap_thresh);
17
#endif
18
19
at::Tensor nms_rotated_cpu(
20
const at::Tensor& dets,
21
const at::Tensor& scores,
22
const float iou_threshold);
23
24
25
// Rotated-box NMS entry point, exposed to Python as `nms_rotated`.
//
// Dispatches on the device of `dets`:
//   - CUDA tensors go to nms_rotated_cuda;
//   - anything else goes to nms_rotated_cpu.
// Both inputs are made contiguous before being handed to the backend.
//
// Args:
//   dets:          rotated boxes. The expected layout is defined by the
//                  backend implementations (not visible here) — TODO confirm.
//   scores:        per-box scores; must be on the same device type
//                  (CPU vs. CUDA) as `dets`.
//   iou_threshold: overlap threshold, forwarded unchanged to the backend.
// Returns:
//   The tensor produced by the selected backend.
// Throws:
//   c10::Error (RuntimeError in Python) if `dets` and `scores` are on
//   different device types, or if CUDA tensors are passed to a build compiled
//   without WITH_CUDA.
inline at::Tensor nms_rotated(
    const at::Tensor& dets,
    const at::Tensor& scores,
    const float iou_threshold) {
  // Use TORCH_CHECK rather than assert: assert disappears when NDEBUG is
  // defined (common for extension builds), and when active it aborts the
  // interpreter instead of raising a catchable Python exception.
  TORCH_CHECK(
      dets.device().is_cuda() == scores.device().is_cuda(),
      "nms_rotated: dets and scores must both be CPU tensors or both be CUDA tensors");
  if (dets.device().is_cuda()) {
#ifdef WITH_CUDA
    return nms_rotated_cuda(
        dets.contiguous(), scores.contiguous(), iou_threshold);
#else
    // AT_ERROR is deprecated; TORCH_CHECK(false, msg) is its replacement.
    TORCH_CHECK(false, "Not compiled with GPU support");
#endif
  }
  return nms_rotated_cpu(dets.contiguous(), scores.contiguous(), iou_threshold);
}
40
41
42
// Polygon NMS entry point, exposed to Python as `nms_poly`.
//
// Only a CUDA backend exists. CUDA inputs are forwarded to poly_nms_cuda.
// CPU inputs, or CUDA inputs in a build without WITH_CUDA, raise an error.
//
// Args:
//   dets:          polygon detections. The expected layout is defined by
//                  poly_nms_cuda (not visible here) — TODO confirm.
//   iou_threshold: overlap threshold, forwarded to poly_nms_cuda.
// Returns:
//   The tensor produced by poly_nms_cuda, or an empty int64 tensor when
//   `dets` has no elements.
// Throws:
//   c10::Error (RuntimeError in Python) for CPU inputs, or when the
//   extension was built without CUDA support.
inline at::Tensor nms_poly(
    const at::Tensor& dets,
    const float iou_threshold) {
  if (dets.device().is_cuda()) {
#ifdef WITH_CUDA
    if (dets.numel() == 0) {
      // Skip the backend entirely for empty input.
      // NOTE(review): this empty result is allocated on the CPU; confirm
      // whether callers expect it on the same device as poly_nms_cuda's
      // non-empty result.
      return at::empty({0}, dets.options().dtype(at::kLong).device(at::kCPU));
    }
    return poly_nms_cuda(dets, iou_threshold);
#else
    // AT_ERROR is deprecated; TORCH_CHECK(false, msg) is its replacement.
    TORCH_CHECK(false, "POLY_NMS is not compiled with GPU support");
#endif
  }
  TORCH_CHECK(false, "POLY_NMS is not implemented on CPU");
}
56
57
// Python bindings: registers nms_rotated and nms_poly on the extension module
// (named by TORCH_EXTENSION_NAME), in that order.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, mod) {
  mod.def("nms_rotated", &nms_rotated, "nms for rotated bboxes")
      .def("nms_poly", &nms_poly, "nms for poly bboxes");
}
61
62