Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.
Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.
Path: blob/master/utils/nms_rotated/src/nms_rotated_cpu.cpp
Views: 475
// Modified from1// https://github.com/facebookresearch/detectron2/tree/master/detectron2/layers/csrc/nms_rotated2// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved3#include <torch/types.h>4#include "box_iou_rotated_utils.h"567template <typename scalar_t>8at::Tensor nms_rotated_cpu_kernel(9const at::Tensor& dets,10const at::Tensor& scores,11const float iou_threshold) {12// nms_rotated_cpu_kernel is modified from torchvision's nms_cpu_kernel,13// however, the code in this function is much shorter because14// we delegate the IoU computation for rotated boxes to15// the single_box_iou_rotated function in box_iou_rotated_utils.h16AT_ASSERTM(dets.device().is_cpu(), "dets must be a CPU tensor");17AT_ASSERTM(scores.device().is_cpu(), "scores must be a CPU tensor");18AT_ASSERTM(19dets.scalar_type() == scores.scalar_type(),20"dets should have the same type as scores");2122if (dets.numel() == 0) {23return at::empty({0}, dets.options().dtype(at::kLong));24}2526auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));2728auto ndets = dets.size(0);29at::Tensor suppressed_t = at::zeros({ndets}, dets.options().dtype(at::kByte));30at::Tensor keep_t = at::zeros({ndets}, dets.options().dtype(at::kLong));3132auto suppressed = suppressed_t.data_ptr<uint8_t>();33auto keep = keep_t.data_ptr<int64_t>();34auto order = order_t.data_ptr<int64_t>();3536int64_t num_to_keep = 0;3738for (int64_t _i = 0; _i < ndets; _i++) {39auto i = order[_i];40if (suppressed[i] == 1) {41continue;42}4344keep[num_to_keep++] = i;4546for (int64_t _j = _i + 1; _j < ndets; _j++) {47auto j = order[_j];48if (suppressed[j] == 1) {49continue;50}5152auto ovr = single_box_iou_rotated<scalar_t>(53dets[i].data_ptr<scalar_t>(), dets[j].data_ptr<scalar_t>());54if (ovr >= iou_threshold) {55suppressed[j] = 1;56}57}58}59return keep_t.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep);60}6162at::Tensor nms_rotated_cpu(63// input must be contiguous64const at::Tensor& dets,65const at::Tensor& scores,66const float iou_threshold) {67auto result = at::empty({0}, dets.options());6869AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "nms_rotated", [&] {70result = nms_rotated_cpu_kernel<scalar_t>(dets, scores, iou_threshold);71});72return result;73}747576