CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutSign UpSign In
hukaixuan19970627

Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.

GitHub Repository: hukaixuan19970627/yolov5_obb
Path: blob/master/utils/nms_rotated/src/nms_rotated_cpu.cpp
Views: 475
1
// Modified from
2
// https://github.com/facebookresearch/detectron2/tree/master/detectron2/layers/csrc/nms_rotated
3
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
4
#include <torch/types.h>
5
#include "box_iou_rotated_utils.h"
6
7
8
template <typename scalar_t>
9
at::Tensor nms_rotated_cpu_kernel(
10
const at::Tensor& dets,
11
const at::Tensor& scores,
12
const float iou_threshold) {
13
// nms_rotated_cpu_kernel is modified from torchvision's nms_cpu_kernel,
14
// however, the code in this function is much shorter because
15
// we delegate the IoU computation for rotated boxes to
16
// the single_box_iou_rotated function in box_iou_rotated_utils.h
17
AT_ASSERTM(dets.device().is_cpu(), "dets must be a CPU tensor");
18
AT_ASSERTM(scores.device().is_cpu(), "scores must be a CPU tensor");
19
AT_ASSERTM(
20
dets.scalar_type() == scores.scalar_type(),
21
"dets should have the same type as scores");
22
23
if (dets.numel() == 0) {
24
return at::empty({0}, dets.options().dtype(at::kLong));
25
}
26
27
auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));
28
29
auto ndets = dets.size(0);
30
at::Tensor suppressed_t = at::zeros({ndets}, dets.options().dtype(at::kByte));
31
at::Tensor keep_t = at::zeros({ndets}, dets.options().dtype(at::kLong));
32
33
auto suppressed = suppressed_t.data_ptr<uint8_t>();
34
auto keep = keep_t.data_ptr<int64_t>();
35
auto order = order_t.data_ptr<int64_t>();
36
37
int64_t num_to_keep = 0;
38
39
for (int64_t _i = 0; _i < ndets; _i++) {
40
auto i = order[_i];
41
if (suppressed[i] == 1) {
42
continue;
43
}
44
45
keep[num_to_keep++] = i;
46
47
for (int64_t _j = _i + 1; _j < ndets; _j++) {
48
auto j = order[_j];
49
if (suppressed[j] == 1) {
50
continue;
51
}
52
53
auto ovr = single_box_iou_rotated<scalar_t>(
54
dets[i].data_ptr<scalar_t>(), dets[j].data_ptr<scalar_t>());
55
if (ovr >= iou_threshold) {
56
suppressed[j] = 1;
57
}
58
}
59
}
60
return keep_t.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep);
61
}
62
63
// CPU entry point for rotated-box non-maximum suppression.
//
// Selects the floating-point scalar type of `dets` (float or double) and
// forwards to nms_rotated_cpu_kernel. That kernel checks that both tensors
// live on the CPU and share a dtype, and it returns int64 indices into
// `dets` of the boxes kept at `iou_threshold`.
at::Tensor nms_rotated_cpu(
    // input must be contiguous
    const at::Tensor& dets,
    const at::Tensor& scores,
    const float iou_threshold) {
  // Placeholder value; the dispatched kernel always overwrites it.
  at::Tensor kept = at::empty({0}, dets.options());

  const auto box_dtype = dets.scalar_type();
  AT_DISPATCH_FLOATING_TYPES(box_dtype, "nms_rotated", [&] {
    kept = nms_rotated_cpu_kernel<scalar_t>(dets, scores, iou_threshold);
  });

  return kept;
}
75
76