CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Sign Up · Sign In
hukaixuan19970627

Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.

GitHub Repository: hukaixuan19970627/yolov5_obb
Path: blob/master/utils/nms_rotated/nms_rotated_wrapper.py
Views: 475
1
import numpy as np
2
import torch
3
4
from . import nms_rotated_ext
5
6
def obb_nms(dets, scores, iou_thr, device_id=None):
    """
    RIoU NMS - iou_thr.

    Args:
        dets (tensor/array): (num, [cx cy w h θ]) θ∈[-pi/2, pi/2)
        scores (tensor/array): (num)
        iou_thr (float): (1)
        device_id (int | None): CUDA device index, used only when ``dets``
            is a numpy array; ``None`` keeps the converted data on the CPU.

    Returns:
        dets (tensor/array): (n_nms, [cx cy w h θ]) rows of the input
            ``dets`` that survive NMS (same container type as the input).
        inds (tensor/array): (n_nms), int64 nms index of dets; a numpy
            array when ``dets`` was a numpy array, otherwise a tensor.

    Raises:
        TypeError: if ``dets`` is neither a torch.Tensor nor a numpy array.
    """
    if isinstance(dets, torch.Tensor):
        is_numpy = False
        dets_th = dets
    elif isinstance(dets, np.ndarray):
        is_numpy = True
        device = 'cpu' if device_id is None else f'cuda:{device_id}'
        dets_th = torch.from_numpy(dets).to(device)
    else:
        raise TypeError('dets must be either a Tensor or numpy array, '
                        f'but got {type(dets)}')

    if dets_th.numel() == 0:  # no boxes at all
        inds = dets_th.new_zeros(0, dtype=torch.int64)
    else:
        # Per the original note, the extension hits the same bug with
        # degenerate boxes as with empty input, so drop every box whose
        # shorter side is below 1e-3 before calling it.
        too_small = dets_th[:, [2, 3]].min(1)[0] < 0.001  # [n]
        if too_small.all():  # nothing left to suppress
            inds = dets_th.new_zeros(0, dtype=torch.int64)
        else:
            keep = ~too_small
            # Build the index map on dets_th's device: indexing a CPU tensor
            # with a CUDA mask / CUDA indices raises in PyTorch.
            ori_inds = torch.arange(dets_th.size(0),
                                    device=dets_th.device)[keep]
            # scores may be a numpy array (see Args); the extension needs a
            # tensor on the same device as the boxes.
            scores_th = torch.as_tensor(scores, device=dets_th.device)[keep]

            inds = nms_rotated_ext.nms_rotated(dets_th[keep], scores_th,
                                               iou_thr)
            # Map indices into the filtered set back to the original rows.
            inds = ori_inds[inds]

    if is_numpy:
        inds = inds.cpu().numpy()
    return dets[inds, :], inds
47
48
49
def poly_nms(dets, iou_thr, device_id=None):
    """
    Polygon NMS through the compiled ``nms_rotated_ext.nms_poly`` op.

    Only GPU data is supported: CPU tensors (and numpy input without a
    ``device_id``) raise ``NotImplementedError``.

    Args:
        dets (tensor/array): detections forwarded to ``nms_poly`` after a
            float32 cast. The per-row layout is defined by the extension
            (presumably polygon vertices, possibly plus a score) -- TODO
            confirm against the extension source.
        iou_thr (float): IoU threshold forwarded to the extension.
        device_id (int | None): CUDA device index, used only when ``dets``
            is a numpy array; ``None`` means CPU, which is unsupported.

    Returns:
        dets (tensor/array): rows of the input ``dets`` kept by NMS.
        inds (tensor/array): indices of the kept rows; a numpy array when
            ``dets`` was a numpy array, otherwise a tensor.

    Raises:
        TypeError: if ``dets`` is neither a torch.Tensor nor a numpy array.
        NotImplementedError: if the data would be processed on the CPU.
    """
    if isinstance(dets, torch.Tensor):
        is_numpy = False
        dets_th = dets
    elif isinstance(dets, np.ndarray):
        is_numpy = True
        device = 'cpu' if device_id is None else f'cuda:{device_id}'
        dets_th = torch.from_numpy(dets).to(device)
    else:
        raise TypeError('dets must be either a Tensor or numpy array, '
                        f'but got {type(dets)}')

    if dets_th.device.type == 'cpu':
        raise NotImplementedError(
            'poly_nms is only implemented for CUDA tensors; got data on the '
            'CPU (pass a CUDA tensor, or set device_id for numpy input)')

    if dets_th.numel() == 0:
        # Same empty-input short-circuit as obb_nms: nothing to suppress,
        # so skip the extension call entirely.
        inds = dets_th.new_zeros(0, dtype=torch.int64)
    else:
        inds = nms_rotated_ext.nms_poly(dets_th.float(), iou_thr)

    if is_numpy:
        inds = inds.cpu().numpy()
    return dets[inds, :], inds
68
69
if __name__ == '__main__':
    # Hand-written sample rotated boxes, one row per box laid out as
    # [cx, cy, w, h, angle]. The variable names suggest two angle
    # conventions (OpenCV-style vs. long-edge). The angle values here look
    # like degrees, whereas obb_nms documents radians -- NOTE(review):
    # confirm the unit before feeding these to obb_nms.
    opencv_rows = [
        [136.6, 111.6, 200, 100, -60],
        [136.6, 111.6, 100, 200, -30],
        [100, 100, 141.4, 141.4, -45],
        [100, 100, 141.4, 141.4, -45],
    ]
    longedge_rows = [
        [136.6, 111.6, 200, 100, -60],
        [136.6, 111.6, 200, 100, 120],
        [100, 100, 141.4, 141.4, 45],
        [100, 100, 141.4, 141.4, 135],
    ]

    rboxes_opencv = torch.tensor(opencv_rows)
    rboxes_longedge = torch.tensor(longedge_rows)
78
79