CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutSign UpSign In
hukaixuan19970627

Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.

GitHub Repository: hukaixuan19970627/yolov5_obb
Path: blob/master/utils/rboxs_utils.py
Views: 475
1
"""
2
Oriented Bounding Boxes utils
3
"""
4
import numpy as np
5
pi = np.pi  # full-precision pi (was the truncated literal 3.141592)
6
import cv2
7
import torch
8
9
def gaussian_label_cpu(label, num_class, u=0, sig=4.0):
    """
    Build a Circular Smooth Label (CSL) for a single angle.

    A Gaussian window is centred on the angle class and wrapped around the
    ends of the label vector, so the label is periodic in theta: classes near
    0 and near ``num_class - 1`` are neighbours. This lets the loss stay small
    across the angular boundary even though the raw class difference there is
    large, and makes the label encode the angular distance between classes.

    Args:
        label (float): angle / theta class; ``floor(label)`` is the class
            that receives the window centre. Values outside [0, num_class)
            are wrapped periodically.
        num_class (int): number of theta classes (length of the output).
        u (float): mu of the Gaussian; shifts the peak away from the centre.
        sig (float): sigma of the Gaussian, i.e. the CSL window radius.

    Returns:
        csl_label (array): (num_class,) smooth label. For even ``num_class``
            and ``u == 0`` it peaks with value 1 at ``floor(label) % num_class``.
    """
    x = np.arange(-num_class / 2, num_class / 2)  # class offsets centred on 0
    y_sig = np.exp(-(x - u) ** 2 / (2 * sig ** 2))
    # Rotate the window so that offset 0 (stored at position num_class/2)
    # lands on class floor(label). Flooring the label first keeps the rounding
    # direction identical on both sides of num_class/2; truncating the whole
    # expression with int() rounded fractional labels up below num_class/2 and
    # down above it. The modulo keeps the slice split valid for any label.
    index = int(num_class / 2 - np.floor(label)) % num_class
    return np.concatenate([y_sig[index:], y_sig[:index]], axis=0)
27
28
def regular_theta(theta, mode='180', start=-pi/2):
    """
    Wrap angle(s) into the half-open interval [start, start + period).

    The period is 2*pi for ``mode='360'`` and pi for ``mode='180'``. With the
    default arguments the result therefore lies in [-pi/2, pi/2).

    Args:
        theta: angle(s) in radians; anything supporting ``-``, ``%`` and ``+``
            with floats (e.g. a Python float).
        mode (str): '180' or '360'; selects the wrapping period.
        start (float): lower bound of the target interval.

    Returns:
        The wrapped angle(s).
    """
    assert mode in ['360', '180']
    period = pi if mode != '360' else 2 * pi
    offset = theta - start
    return offset % period + start
38
39
def poly2rbox(polys, num_cls_thata=180, radius=6.0, use_pi=False, use_gaussian=False):
    """
    Convert 4-point polygons to long-edge rotated boxes (rboxes).

    Each polygon is fitted with ``cv2.minAreaRect``. The result is rewritten so
    that ``l`` (the first size) is always the long side, and the angle is
    wrapped into [-pi/2, pi/2) radians, or equivalently [0, 180) degrees.

    Args:
        polys (array): (num_gts, 8) polygons as [x1 y1 x2 y2 x3 y3 x4 y4].
        num_cls_thata (int): number of theta classes for the Circular Smooth
            Label; only used when ``use_gaussian`` is True.
        radius (float): Gaussian window radius (sigma) for the Circular
            Smooth Label.
        use_pi (bool): True -> the angle column is theta in radians,
            theta in [-pi/2, pi/2); False -> it is in degrees, in [0, 180).
        use_gaussian (bool): if True, also return the CSL labels.

    Returns:
        use_gaussian True:
            rboxes (array): (num_gts, 5) as [cx cy l s theta], with l >= s.
            csl_labels (array): (num_gts, num_cls_thata).
        otherwise:
            rboxes (array): (num_gts, 5) as [cx cy l s theta], with l >= s.
    """
    assert polys.shape[-1] == 8
    rboxes = []
    csl_labels = []
    for poly in polys:
        poly = np.float32(poly.reshape(4, 2))
        # angle is in degrees (original note: theta in [0, 90]); theta is
        # wrapped into [-pi/2, pi/2) below regardless of the input range.
        (x, y), (w, h), angle = cv2.minAreaRect(poly)
        theta = -angle / 180 * pi  # negate, then degrees -> radians

        # Long-edge format: w must be the long side. Swapping the sides is a
        # 90-degree rotation, so theta is compensated by pi/2.
        if w != max(w, h):
            w, h = h, w
            theta += pi / 2
        theta = regular_theta(theta)  # theta in [-pi/2, pi/2)
        angle = (theta * 180 / pi) + 90  # degrees, in [0, 180)

        rboxes.append([x, y, w, h, theta if use_pi else angle])
        if use_gaussian:
            csl_labels.append(
                gaussian_label_cpu(label=angle, num_class=num_cls_thata, u=0, sig=radius))

    # np.array([]) has shape (0,); keep the documented 2-D shapes when there
    # are no polygons so that callers can still slice columns.
    rboxes = np.array(rboxes) if rboxes else np.zeros((0, 5))
    if use_gaussian:
        csl_labels = np.array(csl_labels) if csl_labels else np.zeros((0, num_cls_thata))
        return rboxes, csl_labels
    return rboxes
82
83
# def rbox2poly(rboxes):
84
# """
85
# Trans rbox format to poly format.
86
# Args:
87
# rboxes (array): (num_gts, [cx cy l s θ]) θ∈(0, 180]
88
89
# Returns:
90
# polys (array): (num_gts, [x1 y1 x2 y2 x3 y3 x4 y4])
91
# """
92
# assert rboxes.shape[-1] == 5
93
# polys = []
94
# for rbox in rboxes:
95
# x, y, w, h, theta = rbox
96
# if theta > 90 and theta <= 180: # longedge format -> opencv format
97
# w, h = h, w
98
# theta -= 90
99
# if theta <= 0 or theta > 90:
100
# print("cv2.minAreaRect occurs some error. θ isn't in range(0, 90]. The longedge format is: ", rbox)
101
102
# poly = cv2.boxPoints(((x, y), (w, h), theta)).reshape(-1)
103
# polys.append(poly)
104
# return np.array(polys)
105
106
def rbox2poly(obboxes):
    """
    Convert rotated boxes to 4-point polygons.

    Args:
        obboxes (array/tensor): (..., 5) boxes as [cx cy l s theta], with theta
            in radians (typically in [-pi/2, pi/2)). The usual input is
            (num_gts, 5), but any number of leading dims is accepted.

    Returns:
        polys (array/tensor): (..., 8) polygons as [x1 y1 x2 y2 x3 y3 x4 y4],
            of the same kind (numpy array or torch tensor) as the input.
    """
    if isinstance(obboxes, torch.Tensor):
        # Slice along the last axis (not [:, ...]) so inputs with any number
        # of leading dims work, matching the numpy branch below.
        center = obboxes[..., :2]
        w, h, theta = obboxes[..., 2:3], obboxes[..., 3:4], obboxes[..., 4:5]
        cos, sin = torch.cos(theta), torch.sin(theta)

        # Half-extent vectors along the box's long (w) and short (h) axes.
        vector1 = torch.cat((w / 2 * cos, -w / 2 * sin), dim=-1)
        vector2 = torch.cat((-h / 2 * sin, -h / 2 * cos), dim=-1)
        point1 = center + vector1 + vector2
        point2 = center + vector1 - vector2
        point3 = center - vector1 - vector2
        point4 = center - vector1 + vector2
        order = obboxes.shape[:-1]
        return torch.cat((point1, point2, point3, point4), dim=-1).reshape(*order, 8)

    center, w, h, theta = np.split(obboxes, (2, 3, 4), axis=-1)
    cos, sin = np.cos(theta), np.sin(theta)

    # Half-extent vectors along the box's long (w) and short (h) axes.
    vector1 = np.concatenate([w / 2 * cos, -w / 2 * sin], axis=-1)
    vector2 = np.concatenate([-h / 2 * sin, -h / 2 * cos], axis=-1)
    point1 = center + vector1 + vector2
    point2 = center + vector1 - vector2
    point3 = center - vector1 - vector2
    point4 = center - vector1 + vector2
    order = obboxes.shape[:-1]
    return np.concatenate([point1, point2, point3, point4], axis=-1).reshape(*order, 8)
146
147
def poly2hbb(polys):
    """
    Convert polygons to horizontal (axis-aligned) bounding boxes.

    Args:
        polys (array/tensor): (..., 8) polygons as [x1 y1 x2 y2 x3 y3 x4 y4].
            The usual input is (num_gts, 8).

    Returns:
        hbboxes (array/tensor): (..., 4) boxes as [xc yc w h], of the same
            kind (numpy array or torch tensor) as the input.
    """
    assert polys.shape[-1] == 8
    # numpy and torch expose amax/amin/stack with the same positional
    # (input, axis) signature, so one code path serves both backends.
    backend = torch if isinstance(polys, torch.Tensor) else np
    x = polys[..., 0::2]  # (..., 4) x coordinates
    y = polys[..., 1::2]  # (..., 4) y coordinates
    x_max, x_min = backend.amax(x, -1), backend.amin(x, -1)  # (...)
    y_max, y_min = backend.amax(y, -1), backend.amin(y, -1)
    x_ctr, y_ctr = (x_max + x_min) / 2.0, (y_max + y_min) / 2.0
    w = x_max - x_min
    h = y_max - y_min
    return backend.stack((x_ctr, y_ctr, w, h), -1)
182
183
def poly_filter(polys, h, w):
    """
    Build a keep-mask for polygons whose bounding-box centre lies in the image.

    A polygon is kept when the centre (xc, yc) of its axis-aligned bounding
    box satisfies 0 < xc < w and 0 < yc < h; both inequalities are strict.

    Args:
        polys (array): (num, 8) polygons as [x1 y1 x2 y2 x3 y3 x4 y4].
        h: image height.
        w: image width.

    Return:
        keep_masks (array): (num) boolean mask, True for polygons to keep.
    """
    xs, ys = polys[:, 0::2], polys[:, 1::2]
    center_x = (np.amax(xs, axis=1) + np.amin(xs, axis=1)) / 2.0
    center_y = (np.amax(ys, axis=1) + np.amin(ys, axis=1)) / 2.0
    inside_x = (center_x > 0) & (center_x < w)
    inside_y = (center_y > 0) & (center_y < h)
    return inside_x & inside_y
201