CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Sign Up · Sign In
hukaixuan19970627

Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.

GitHub Repository: hukaixuan19970627/yolov5_obb
Path: blob/master/DOTA_devkit/ImgSplit.py
Views: 475
1
import os
2
import codecs
3
import numpy as np
4
import math
5
from dota_utils import GetFileFromThisRootDir
6
import cv2
7
import shapely.geometry as shgeo
8
import dota_utils as util
9
import copy
10
11
def choose_best_pointorder_fit_another(poly1, poly2):
    """
    Rotate the vertex order of ``poly1`` so that it best matches ``poly2``.

    Both polygons are flat quadrilaterals ``[x1, y1, x2, y2, x3, y3, x4, y4]``.
    The four cyclic rotations of ``poly1``'s vertices are compared with
    ``poly2`` by the sum of squared coordinate differences, and the closest
    rotation is returned. Only rotations are tried; the winding direction of
    ``poly1`` is preserved.

    :param poly1: sequence of 8 numbers, the polygon whose start vertex is chosen
    :param poly2: sequence of 8 numbers, the reference polygon
    :return: numpy array of 8 numbers, ``poly1`` starting at its best-fit vertex
    """
    points = np.asarray(poly1[:8]).reshape(4, 2)
    # candidates[k] starts at vertex k + 1: (p1..p4), (p2, p3, p4, p1), ...
    candidates = [np.roll(points, -shift, axis=0).ravel() for shift in range(4)]
    target = np.asarray(poly2)
    distances = [np.sum((candidate - target) ** 2) for candidate in candidates]
    # argmin returns the first minimum, so ties keep the lowest rotation.
    return candidates[int(np.argmin(distances))]


def cal_line_length(point1, point2):
    """
    Return the Euclidean distance between two 2-D points.

    :param point1: indexable ``(x, y)``
    :param point2: indexable ``(x, y)``
    :return: float distance

    ``math.hypot`` is used rather than ``sqrt(pow(dx, 2) + pow(dy, 2))``:
    ``math.pow`` raises ``OverflowError`` for very large differences and the
    square-then-root form loses precision, whereas ``hypot`` avoids both.
    """
    return math.hypot(point1[0] - point2[0], point1[1] - point2[1])


class splitbase():
    """
    Cut large DOTA images and their polygon labels into overlapping square patches.

    Images are read from ``<basepath>/images`` and labels from
    ``<basepath>/labelTxt``. Patches and their label files are written to
    ``<outpath>/images`` and ``<outpath>/labelTxt``.
    """

    def __init__(self,
                 basepath,
                 outpath,
                 code='utf-8',
                 gap=100,
                 subsize=1024,
                 thresh=0.7,
                 choosebestpoint=True,
                 ext='.png'
                 ):
        """
        :param basepath: base path for the DOTA data (contains 'images' and 'labelTxt')
        :param outpath: output base path; its 'images' and 'labelTxt'
            subdirectories are created if missing
        :param code: encoding of the label txt files
        :param gap: overlap in pixels between two neighbouring patches
        :param subsize: side length in pixels of each square patch
        :param thresh: when an instance is cut by a patch border, it keeps its
            original difficulty only if the fraction of its area inside the
            patch exceeds ``thresh``; otherwise it is written with difficulty '2'
        :param choosebestpoint: if True, reorder the vertices of a cut instance
            so that its first point best matches the original polygon's
        :param ext: image file extension, used for both input and output images
        :raises ValueError: if ``gap >= subsize``, because the sliding window
            would never advance and splitting would loop forever
        """
        if subsize - gap <= 0:
            raise ValueError('gap (%r) must be smaller than subsize (%r)' % (gap, subsize))
        self.basepath = basepath
        self.outpath = outpath
        self.code = code
        self.gap = gap
        self.subsize = subsize
        # Step between the top-left corners of neighbouring patches.
        self.slide = self.subsize - self.gap
        self.thresh = thresh
        self.imagepath = os.path.join(self.basepath, 'images')
        self.labelpath = os.path.join(self.basepath, 'labelTxt')
        self.outimagepath = os.path.join(self.outpath, 'images')
        self.outlabelpath = os.path.join(self.outpath, 'labelTxt')
        self.choosebestpoint = choosebestpoint
        self.ext = ext
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(self.outimagepath, exist_ok=True)
        os.makedirs(self.outlabelpath, exist_ok=True)

    def polyorig2sub(self, left, up, poly):
        """
        Translate a flat polygon ``[x1, y1, x2, y2, ...]`` from full-image
        coordinates into those of the patch whose top-left corner is ``(left, up)``.

        Each coordinate is truncated with ``int()`` and stored in a float array.
        """
        polyInsub = np.zeros(len(poly))
        for i in range(int(len(poly) / 2)):
            polyInsub[i * 2] = int(poly[i * 2] - left)
            polyInsub[i * 2 + 1] = int(poly[i * 2 + 1] - up)
        return polyInsub

    def calchalf_iou(self, poly1, poly2):
        """
        Return ``(intersection, ratio)``, where ``ratio`` is the intersection area
        divided by ``poly1``'s area. This is not the usual IoU.

        Dividing by a zero-area ``poly1`` raises ZeroDivisionError, so callers
        must filter degenerate polygons first.
        """
        inter_poly = poly1.intersection(poly2)
        inter_area = inter_poly.area
        poly1_area = poly1.area
        half_iou = inter_area / poly1_area
        return inter_poly, half_iou

    def saveimagepatches(self, img, subimgname, left, up):
        """Write the ``subsize`` x ``subsize`` crop of ``img`` at ``(left, up)`` as ``<subimgname><ext>``."""
        subimg = copy.deepcopy(img[up: (up + self.subsize), left: (left + self.subsize)])
        outdir = os.path.join(self.outimagepath, subimgname + self.ext)
        cv2.imwrite(outdir, subimg)

    def GetPoly4FromPoly5(self, poly):
        """
        Reduce a flat 5-vertex polygon (10 numbers) to 4 vertices.

        The two endpoints of the shortest edge (including the closing edge from
        vertex 5 back to vertex 1) are replaced by their midpoint.
        """
        # distances[i] is the length of edge i -> i + 1; the last entry is edge 4 -> 0.
        distances = [cal_line_length((poly[i * 2], poly[i * 2 + 1]), (poly[(i + 1) * 2], poly[(i + 1) * 2 + 1]))
                     for i in range(int(len(poly) / 2 - 1))]
        distances.append(cal_line_length((poly[0], poly[1]), (poly[8], poly[9])))
        pos = np.array(distances).argsort()[0]
        count = 0
        outpoly = []
        while count < 5:
            if count == pos:
                # Emit the midpoint of vertex `pos` and vertex `(pos + 1) % 5`.
                outpoly.append((poly[count * 2] + poly[(count * 2 + 2) % 10]) / 2)
                outpoly.append((poly[(count * 2 + 1) % 10] + poly[(count * 2 + 3) % 10]) / 2)
                count = count + 1
            elif count == (pos + 1) % 5:
                # This vertex was merged into the midpoint above, so skip it.
                count = count + 1
                continue
            else:
                outpoly.append(poly[count * 2])
                outpoly.append(poly[count * 2 + 1])
                count = count + 1
        return outpoly

    def savepatches(self, resizeimg, objects, subimgname, left, up, right, down):
        """
        Write the label file for one patch, then save the patch image.

        Each object is written as ``x1 y1 x2 y2 x3 y3 x4 y4 name difficult`` in
        patch coordinates:

        * Objects lying entirely inside ``(left, up, right, down)`` keep their
          polygon and difficulty.
        * Objects partly inside are clipped to the patch. A clipped result with
          fewer than 4 vertices, more than 5 vertices, or several pieces is
          dropped; a 5-vertex result is reduced to 4 vertices. Coordinates are
          clamped to ``[1, subsize]``. The difficulty becomes '2' unless the
          kept area fraction exceeds ``thresh``.
        * Objects outside the patch, or with non-positive area, are skipped.
        """
        outdir = os.path.join(self.outlabelpath, subimgname + '.txt')
        imgpoly = shgeo.Polygon([(left, up), (right, up), (right, down), (left, down)])
        with codecs.open(outdir, 'w', self.code) as f_out:
            for obj in objects:
                gtpoly = shgeo.Polygon([(obj['poly'][0], obj['poly'][1]),
                                        (obj['poly'][2], obj['poly'][3]),
                                        (obj['poly'][4], obj['poly'][5]),
                                        (obj['poly'][6], obj['poly'][7])])
                if gtpoly.area <= 0:
                    continue
                inter_poly, half_iou = self.calchalf_iou(gtpoly, imgpoly)

                if half_iou == 1:
                    polyInsub = self.polyorig2sub(left, up, obj['poly'])
                    outline = ' '.join(list(map(str, polyInsub)))
                    outline = outline + ' ' + obj['name'] + ' ' + str(obj['difficult'])
                    f_out.write(outline + '\n')
                elif half_iou > 0:
                    # A non-convex label can be cut into several pieces
                    # (MultiPolygon / GeometryCollection). Such a result has no
                    # single exterior ring, so orient() below would fail; skip it.
                    if inter_poly.geom_type != 'Polygon':
                        continue
                    inter_poly = shgeo.polygon.orient(inter_poly, sign=1)
                    # Drop the closing vertex, which repeats the first one.
                    out_poly = list(inter_poly.exterior.coords)[0: -1]
                    if len(out_poly) < 4:
                        continue
                    if len(out_poly) > 5:
                        # Cut instances with more than 5 vertices are not handled currently.
                        continue

                    out_poly2 = [coord for point in out_poly for coord in (point[0], point[1])]
                    if len(out_poly) == 5:
                        out_poly2 = self.GetPoly4FromPoly5(out_poly2)
                    if self.choosebestpoint:
                        out_poly2 = choose_best_pointorder_fit_another(out_poly2, obj['poly'])

                    polyInsub = np.clip(self.polyorig2sub(left, up, out_poly2), 1, self.subsize)
                    outline = ' '.join(list(map(str, polyInsub)))
                    if half_iou > self.thresh:
                        outline = outline + ' ' + obj['name'] + ' ' + str(obj['difficult'])
                    else:
                        # The part left inside the patch is too small: label it difficult '2'.
                        outline = outline + ' ' + obj['name'] + ' ' + '2'
                    f_out.write(outline + '\n')
        self.saveimagepatches(resizeimg, subimgname, left, up)

    def SplitSingle(self, name, rate, extent):
        """
        Split a single image and its ground truth into patches.

        :param name: image base name, without extension
        :param rate: resize scale applied to the image and its polygons before cutting
        :param extent: image file extension, e.g. '.png'
        :return: None; returns early if the image cannot be read
        """
        img = cv2.imread(os.path.join(self.imagepath, name + extent))
        # cv2.imread reports an unreadable or missing file by returning None.
        if img is None:
            return
        fullname = os.path.join(self.labelpath, name + '.txt')
        objects = util.parse_dota_poly2(fullname)
        for obj in objects:
            obj['poly'] = list(map(lambda x: rate * x, obj['poly']))

        if rate != 1:
            resizeimg = cv2.resize(img, None, fx=rate, fy=rate, interpolation=cv2.INTER_CUBIC)
        else:
            resizeimg = img
        outbasename = name + '__' + str(rate) + '__'
        height, width = resizeimg.shape[:2]

        left, up = 0, 0
        while left < width:
            # Snap the last column of patches to the right image border.
            if left + self.subsize >= width:
                left = max(width - self.subsize, 0)
            up = 0
            while up < height:
                # Snap the last row of patches to the bottom image border.
                if up + self.subsize >= height:
                    up = max(height - self.subsize, 0)
                right = min(left + self.subsize, width - 1)
                down = min(up + self.subsize, height - 1)
                subimgname = outbasename + str(left) + '___' + str(up)
                self.savepatches(resizeimg, objects, subimgname, left, up, right, down)
                if up + self.subsize >= height:
                    break
                up = up + self.slide
            if left + self.subsize >= width:
                break
            left = left + self.slide

    def splitdata(self, rate):
        """
        Split every image under ``<basepath>/images``, skipping Windows 'Thumbs' files.

        :param rate: resize rate applied before cutting
        """
        imagelist = GetFileFromThisRootDir(self.imagepath)
        imagenames = [util.custombasename(x) for x in imagelist if (util.custombasename(x) != 'Thumbs')]
        for name in imagenames:
            self.SplitSingle(name, rate, self.ext)


if __name__ == '__main__':
    # Example: cut every image in ./example (with its labelTxt annotations)
    # into patches at the original scale, writing them to ./examplesplit.
    example_splitter = splitbase(basepath=r'example', outpath=r'examplesplit')
    example_splitter.splitdata(rate=1)