Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. A commercial alternative to JupyterHub.
Path: blob/master/utils/nms_rotated/nms_rotated_wrapper.py
Views: 475
import numpy as np
import torch

from . import nms_rotated_ext


def _dets_to_tensor(dets, device_id=None):
    """Return ``(dets_th, is_numpy)`` for a Tensor or numpy array ``dets``.

    Tensors are returned unchanged. Numpy arrays are wrapped with
    ``torch.from_numpy`` and moved to ``cuda:{device_id}``, or kept on the
    CPU when ``device_id`` is None.

    Raises:
        TypeError: if ``dets`` is neither a ``torch.Tensor`` nor an
            ``np.ndarray``.
    """
    if isinstance(dets, torch.Tensor):
        return dets, False
    if isinstance(dets, np.ndarray):
        device = 'cpu' if device_id is None else f'cuda:{device_id}'
        return torch.from_numpy(dets).to(device), True
    raise TypeError('dets must be either a Tensor or numpy array, '
                    f'but got {type(dets)}')


def obb_nms(dets, scores, iou_thr, device_id=None):
    """Rotated-IoU NMS with threshold ``iou_thr``.

    Args:
        dets (tensor/array): (num, [cx cy w h θ]) θ∈[-pi/2, pi/2)
        scores (tensor/array): (num)
        iou_thr (float): IoU threshold forwarded to the NMS kernel.
        device_id (int, optional): CUDA device index used when ``dets`` is a
            numpy array. ``None`` keeps the data on the CPU.

    Returns:
        dets (tensor/array): (n_nms, [cx cy w h θ]), the kept rows of the
            input ``dets`` (same type as the input).
        inds (tensor/array): (n_nms), indices of the kept rows in ``dets``.
            This is a numpy array when ``dets`` was a numpy array, otherwise
            a tensor on the device of ``dets``.

    Raises:
        TypeError: if ``dets`` is neither a Tensor nor a numpy array.
    """
    dets_th, is_numpy = _dets_to_tensor(dets, device_id)

    if dets_th.numel() == 0:
        inds = dets_th.new_zeros(0, dtype=torch.int64)
    else:
        # Boxes whose width or height is below 1e-3 are excluded before the
        # kernel call (the original code notes a kernel bug with such tiny
        # boxes). Surviving indices are mapped back to input positions below.
        too_small = dets_th[:, [2, 3]].min(1)[0] < 0.001  # (num,)
        if too_small.all():
            inds = dets_th.new_zeros(0, dtype=torch.int64)
        else:
            keep = ~too_small
            # Keep the index map on the boxes' device. A CPU tensor cannot be
            # indexed by a CUDA mask or CUDA kernel output.
            ori_inds = torch.arange(dets_th.size(0), device=dets_th.device)[keep]
            # Scores may arrive as a numpy array (or on another device). The
            # kernel receives them as a tensor next to the boxes.
            scores_th = torch.as_tensor(scores, device=dets_th.device)[keep]

            inds = nms_rotated_ext.nms_rotated(dets_th[keep], scores_th, iou_thr)
            inds = ori_inds[inds]

    if is_numpy:
        inds = inds.cpu().numpy()
    return dets[inds, :], inds


def poly_nms(dets, iou_thr, device_id=None):
    """Polygon NMS with threshold ``iou_thr`` (CUDA only).

    Args:
        dets (tensor/array): rows passed to ``nms_rotated_ext.nms_poly``
            after a cast to float32. The row layout is not visible here;
            presumably it is polygon vertices (plus score). TODO confirm
            against the extension.
        iou_thr (float): IoU threshold forwarded to the kernel.
        device_id (int, optional): CUDA device index used when ``dets`` is a
            numpy array. ``None`` means CPU, which is rejected below.

    Returns:
        The kept rows of ``dets`` and their indices. The indices are a numpy
        array when ``dets`` was a numpy array.

    Raises:
        TypeError: if ``dets`` is neither a Tensor nor a numpy array.
        NotImplementedError: if the data lives on the CPU.
    """
    dets_th, is_numpy = _dets_to_tensor(dets, device_id)

    if dets_th.device.type == 'cpu':
        raise NotImplementedError(
            'poly_nms is only implemented for CUDA tensors; pass a CUDA '
            'tensor or a numpy array together with device_id')
    inds = nms_rotated_ext.nms_poly(dets_th.float(), iou_thr)

    if is_numpy:
        inds = inds.cpu().numpy()
    return dets[inds, :], inds


if __name__ == '__main__':
    # Sample rotated boxes [cx, cy, w, h, angle]. The names suggest the same
    # boxes expressed in OpenCV vs. long-edge angle conventions (angles look
    # like degrees). NOTE(review): they are only constructed, never used.
    rboxes_opencv = torch.tensor(([136.6, 111.6, 200, 100, -60],
                                  [136.6, 111.6, 100, 200, -30],
                                  [100, 100, 141.4, 141.4, -45],
                                  [100, 100, 141.4, 141.4, -45]))
    rboxes_longedge = torch.tensor(([136.6, 111.6, 200, 100, -60],
                                    [136.6, 111.6, 200, 100, 120],
                                    [100, 100, 141.4, 141.4, 45],
                                    [100, 100, 141.4, 141.4, 135]))