CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutSign UpSign In
hukaixuan19970627

Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.

GitHub Repository: hukaixuan19970627/yolov5_obb
Path: blob/master/DOTA_devkit/dota_evaluation_task1.py
Views: 475
1
# --------------------------------------------------------
2
# dota_evaluation_task1
3
# Licensed under The MIT License [see LICENSE for details]
4
# Written by Jian Ding, based on code from Bharath Hariharan
5
# --------------------------------------------------------
6
7
"""
8
To use the code, configure detpath, annopath and imagesetfile.
9
detpath is the path pattern for the per-class result files (15 of them for DOTA-v1.0); for the format, refer to "http://captain.whu.edu.cn/DOTAweb/tasks.html"
10
set these paths with the --detpath, --annopath and --imagesetfile command-line arguments (see parse_args)
11
Note: the evaluation is performed on the large-scale images.
12
"""
13
import xml.etree.ElementTree as ET
14
import os
15
#import cPickle
16
import numpy as np
17
import matplotlib.pyplot as plt
18
import polyiou
19
from functools import partial
20
import argparse
21
22
def parse_gt(filename):
    """Parse a DOTA-style ground-truth label file.

    Object lines are expected to look like
    ``x1 y1 x2 y2 x3 y3 x4 y4 classname [difficult]``. Any line with fewer than
    9 whitespace-separated fields (metadata/header lines, blank lines) is skipped.

    :param filename: ground truth file to parse
    :return: all instances in a picture, as a list of dicts with keys
        'name' (str), 'difficult' (int, 0 when the field is absent) and
        'bbox' (list of the first 8 fields as floats)
    """
    objects = []
    with open(filename, 'r') as f:
        for line in f:
            # split() (not split(' ')) so repeated spaces/tabs don't create empty fields
            fields = line.split()
            if len(fields) < 9:
                continue
            objects.append({
                'name': fields[8],
                # The 10th field, when present, is the difficulty flag. Previously a line
                # with extra trailing fields left 'difficult' unset, so voc_eval crashed.
                'difficult': int(fields[9]) if len(fields) >= 10 else 0,
                'bbox': [float(v) for v in fields[:8]],
            })
    return objects
54
def voc_ap(rec, prec, use_07_metric=False):
    """ap = voc_ap(rec, prec, [use_07_metric])

    Compute the VOC average precision from aligned recall/precision arrays.

    :param rec: numpy array of recall values, one per ranked detection
    :param prec: numpy array of precision values aligned with ``rec``
    :param use_07_metric: if True, use the VOC07 11-point interpolated AP;
        otherwise (default) integrate the area under the precision envelope
    :return: the average precision
    """
    if not use_07_metric:
        # Pad with sentinels so the curve spans recall 0 .. 1.
        recall_pts = np.concatenate(([0.], rec, [1.]))
        precision_pts = np.concatenate(([0.], prec, [0.]))
        # Precision envelope: running maximum taken from the right-hand end.
        precision_pts = np.maximum.accumulate(precision_pts[::-1])[::-1]
        # Only positions where recall actually changes contribute area.
        steps = np.flatnonzero(recall_pts[1:] != recall_pts[:-1])
        widths = recall_pts[steps + 1] - recall_pts[steps]
        return np.sum(widths * precision_pts[steps + 1])

    # VOC07: average the best precision at 11 recall thresholds 0, 0.1, ..., 1.0.
    total = 0.
    for threshold in np.arange(0., 1.1, 0.1):
        reached = rec >= threshold
        best = np.max(prec[reached]) if np.sum(reached) else 0
        total = total + best / 11.
    return total
86
87
88
def voc_eval(detpath,
             annopath,
             imagesetfile,
             classname,
             # cachedir,
             ovthresh=0.5,
             use_07_metric=False):
    """rec, prec, ap = voc_eval(detpath, annopath, imagesetfile, classname,
                                [ovthresh], [use_07_metric])

    Top-level function that runs a PASCAL VOC style evaluation for one class,
    matching detections to ground truth by polygon IoU (via polyiou).

    detpath: Path pattern to detections.
        detpath.format(classname) should produce the detection results file,
        one detection per line: ``image_id score <polygon coordinates>``
        (presumably 8 values x1 y1 ... x4 y4 -- they are compared against the
        8-value ground-truth polygons from parse_gt).
    annopath: Path pattern to annotations.
        annopath.format(imagename) should be the ground-truth label file
        understood by parse_gt.
    imagesetfile: Text file containing the list of images, one image per line
        (blank lines are ignored).
    classname: Category name.
    [ovthresh]: IoU threshold above which a detection can be a true positive
        (default = 0.5).
    [use_07_metric]: Whether to use VOC07's 11 point AP computation
        (default False).

    Returns (rec, prec, ap): cumulative recall and precision arrays in order of
    descending detection score, and the average precision.
    """
    # read list of images (skip blank lines, e.g. a trailing empty line)
    with open(imagesetfile, 'r') as f:
        imagenames = [x.strip() for x in f if x.strip()]

    recs = {imagename: parse_gt(annopath.format(imagename)) for imagename in imagenames}

    # extract gt objects for this class
    class_recs = {}
    npos = 0
    for imagename in imagenames:
        R = [obj for obj in recs[imagename] if obj['name'] == classname]
        bbox = np.array([x['bbox'] for x in R])
        difficult = np.array([x['difficult'] for x in R]).astype(np.bool_)
        det = [False] * len(R)
        # difficult objects are not counted as positives to be recalled
        npos = npos + sum(~difficult)
        class_recs[imagename] = {'bbox': bbox,
                                 'difficult': difficult,
                                 'det': det}

    # read dets from Task1* files (skip blank lines)
    detfile = detpath.format(classname)
    with open(detfile, 'r') as f:
        splitlines = [x.strip().split(' ') for x in f if x.strip()]
    image_ids = [x[0] for x in splitlines]
    confidence = np.array([float(x[1]) for x in splitlines])
    if splitlines:
        BB = np.array([[float(z) for z in x[2:]] for x in splitlines])
    else:
        # keep BB 2-D so the BB[sorted_ind, :] indexing below works for an empty file
        BB = np.zeros((0, 8))

    # sort by confidence, highest first
    sorted_ind = np.argsort(-confidence)
    BB = BB[sorted_ind, :]
    image_ids = [image_ids[x] for x in sorted_ind]

    def _poly_overlaps(gt_polys, det_poly):
        """Polygon IoU between det_poly and each row of gt_polys, via polyiou."""
        return [polyiou.iou_poly(polyiou.VectorDouble(gt), polyiou.VectorDouble(det_poly))
                for gt in gt_polys]

    # go down dets and mark TPs and FPs
    nd = len(image_ids)
    tp = np.zeros(nd)
    fp = np.zeros(nd)
    for d in range(nd):
        R = class_recs[image_ids[d]]
        bb = BB[d, :].astype(float)
        ovmax = -np.inf
        BBGT = R['bbox'].astype(float)

        if BBGT.size > 0:
            # 1. Cheap pre-filter on the axis-aligned hulls: if the hulls do not
            #    overlap, the polygons cannot overlap either.
            BBGT_xmin = np.min(BBGT[:, 0::2], axis=1)
            BBGT_ymin = np.min(BBGT[:, 1::2], axis=1)
            BBGT_xmax = np.max(BBGT[:, 0::2], axis=1)
            BBGT_ymax = np.max(BBGT[:, 1::2], axis=1)
            bb_xmin = np.min(bb[0::2])
            bb_ymin = np.min(bb[1::2])
            bb_xmax = np.max(bb[0::2])
            bb_ymax = np.max(bb[1::2])

            ixmin = np.maximum(BBGT_xmin, bb_xmin)
            iymin = np.maximum(BBGT_ymin, bb_ymin)
            ixmax = np.minimum(BBGT_xmax, bb_xmax)
            iymax = np.minimum(BBGT_ymax, bb_ymax)
            iw = np.maximum(ixmax - ixmin + 1., 0.)
            ih = np.maximum(iymax - iymin + 1., 0.)
            inters = iw * ih

            uni = ((bb_xmax - bb_xmin + 1.) * (bb_ymax - bb_ymin + 1.) +
                   (BBGT_xmax - BBGT_xmin + 1.) *
                   (BBGT_ymax - BBGT_ymin + 1.) - inters)
            hbb_overlaps = inters / uni

            # 2. Exact polygon IoU only for the candidates that survived the filter.
            BBGT_keep_index = np.where(hbb_overlaps > 0)[0]
            if len(BBGT_keep_index) > 0:
                overlaps = _poly_overlaps(BBGT[BBGT_keep_index, :], bb)
                ovmax = np.max(overlaps)
                # map back from the filtered subset to the full gt index
                jmax = BBGT_keep_index[np.argmax(overlaps)]

        if ovmax > ovthresh:
            # matches to difficult gt are neither TP nor FP
            if not R['difficult'][jmax]:
                if not R['det'][jmax]:
                    tp[d] = 1.
                    R['det'][jmax] = 1
                else:
                    # duplicate detection of an already-matched gt
                    fp[d] = 1.
        else:
            fp[d] = 1.

    # compute precision recall
    print('npos num:', npos)
    fp = np.cumsum(fp)
    tp = np.cumsum(tp)

    # NOTE(review): npos == 0 (no non-difficult gt) yields nan/inf recall here.
    rec = tp / float(npos)
    # avoid divide by zero in case the first detection matches a difficult
    # ground truth
    prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)  # precision
    ap = voc_ap(rec, prec, use_07_metric)

    return rec, prec, ap
250
251
def GetFileFromThisRootDir(dir, ext=None):
    """Recursively collect the paths of all files under ``dir``.

    :param dir: root directory, walked recursively with os.walk
    :param ext: optional extension filter without the leading dot; either a
        single string such as ``'txt'`` or a collection such as
        ``['png', 'jpg']``. ``None`` keeps every file.
    :return: list of file paths (os.path.join(root, name)) in os.walk order
    """
    if isinstance(ext, str):
        # A bare string would be tested with substring semantics
        # ('' in 'txt' and 'tx' in 'txt' are both True), letting
        # extension-less and partially matching files slip through.
        ext = (ext,)
    allfiles = []
    for root, _dirs, files in os.walk(dir):
        for filename in files:
            filepath = os.path.join(root, filename)
            extension = os.path.splitext(filepath)[1][1:]
            if ext is None or extension in ext:
                allfiles.append(filepath)
    return allfiles
263
264
def image2txt(srcpath, dstpath):
    """
    Append the name of every file under srcpath (directory and extension
    stripped, e.g. 'P0001') to <dstpath>/imgnamefile.txt, one name per line.

    The file is opened in append mode, so running this twice on the same
    dstpath duplicates entries (which would double-count ground truth in
    voc_eval); delete the old file first if that is not intended.

    @param srcpath: image set directory (searched recursively)
    @param dstpath: directory in which imgnamefile.txt is stored (created if missing)
    """
    # relative paths of all files under srcpath, e.g. ['example_split/../P0001.txt', ...]
    filelist = GetFileFromThisRootDir(srcpath)
    if not filelist:
        # nothing to write; like before, don't create dstpath or the file
        return
    os.makedirs(dstpath, exist_ok=True)
    dstname = os.path.join(dstpath, 'imgnamefile.txt')  # e.g. result/imgnamefile.txt
    # open once instead of once per file
    with open(dstname, 'a') as f:
        f.writelines(os.path.basename(os.path.splitext(fullname)[0]) + '\n'
                     for fullname in filelist)
278
279
def parse_args(argv=None):
    """Parse command-line options for the DOTA task-1 (oriented box) evaluation.

    :param argv: optional list of argument strings; None means sys.argv[1:]
    :return: argparse.Namespace with detpath, annopath and imagesetfile
    """
    parser = argparse.ArgumentParser(
        description='Evaluate DOTA task-1 (oriented bounding box) detections')
    parser.add_argument('--detpath', default='runs/val/yolov5t_DroneVehicle_val/splited_obb_prediction_Txt/Task1_{:s}.txt',
                        help='detection result file pattern; "{:s}" is replaced by the class name')
    parser.add_argument('--annopath', default='/media/test/4d846cae-2315-4928-8d1b-ca6d3a61a3c6/DroneVehicle/val/raw/labelTxt/{:s}.txt',
                        help='ground-truth label file pattern; "{:s}" is replaced by the image name')
    parser.add_argument('--imagesetfile', default='/media/test/4d846cae-2315-4928-8d1b-ca6d3a61a3c6/DroneVehicle/val/raw/imgnamefile.txt',
                        help='text file listing the image names to evaluate, one per line')
    return parser.parse_args(argv)
286
287
def main():
    """Evaluate every configured class, printing per-class AP and the mean AP.

    Classes whose detection file does not exist are skipped and excluded from
    the mean. If every class is skipped, no mAP is reported.
    """
    args = parse_args()
    # detpath = r'/mnt/SSD/lwt_workdir/data/dota_angle/result_merge_roitran/{:s}.txt'
    detpath = args.detpath
    annopath = args.annopath
    imagesetfile = args.imagesetfile
    # For DroneVehicle
    classnames = ['vehicle']
    # For DOTA-v2.0
    # classnames = [ 'plane', 'baseball-diamond', 'bridge', 'ground-track-field', 'small-vehicle', 'large-vehicle', 'ship',
    #                'tennis-court', 'basketball-court', 'storage-tank', 'soccer-ball-field', 'roundabout', 'harbor',
    #                'swimming-pool', 'helicopter', 'container-crane', 'airport', 'helipad']
    # For DOTA-v1.5
    # classnames = ['plane', 'baseball-diamond', 'bridge', 'ground-track-field', 'small-vehicle', 'large-vehicle', 'ship', 'tennis-court',
    #               'basketball-court', 'storage-tank', 'soccer-ball-field', 'roundabout', 'harbor', 'swimming-pool', 'helicopter', 'container-crane']
    # For DOTA-v1.0
    # classnames = ['plane', 'baseball-diamond', 'bridge', 'ground-track-field', 'small-vehicle', 'large-vehicle', 'ship', 'tennis-court',
    #               'basketball-court', 'storage-tank', 'soccer-ball-field', 'roundabout', 'harbor', 'swimming-pool', 'helicopter']
    classaps = []
    mean_ap = 0  # renamed from `map` to avoid shadowing the builtin
    skippedClassCount = 0
    for classname in classnames:
        print('classname:', classname)
        detfile = detpath.format(classname)
        if not os.path.exists(detfile):
            skippedClassCount += 1
            print('This class is not be detected in your dataset: {:s}'.format(classname))
            continue
        rec, prec, ap = voc_eval(detpath,
                                 annopath,
                                 imagesetfile,
                                 classname,
                                 ovthresh=0.5,
                                 use_07_metric=True)
        mean_ap = mean_ap + ap
        print('ap: ', ap)
        classaps.append(ap)

        # # Uncomment to show the P-R curve of each category
        # plt.figure(figsize=(8,4))
        # plt.xlabel('Recall')
        # plt.ylabel('Precision')
        # plt.xticks(fontsize=11)
        # plt.yticks(fontsize=11)
        # plt.xlim(0, 1)
        # plt.ylim(0, 1)
        # ax = plt.gca()
        # ax.spines['top'].set_color('none')
        # ax.spines['right'].set_color('none')
        # plt.plot(rec, prec)
        # # plt.show()
        # plt.savefig('pr_curve/{}.png'.format(classname))

    evaluated = len(classnames) - skippedClassCount
    if evaluated == 0:
        # previously this divided by zero
        print('No detection files were found; mAP cannot be computed.')
        return
    mean_ap = mean_ap / evaluated
    print('map:', mean_ap)
    classaps = 100 * np.array(classaps)
    print('classaps: ', classaps)


if __name__ == '__main__':
    main()
    # image2txt('/media/test/4d846cae-2315-4928-8d1b-ca6d3a61a3c6/DroneVehicle/val/raw/images',
    #           '/media/test/4d846cae-2315-4928-8d1b-ca6d3a61a3c6/DroneVehicle/val/raw/')
    # image2txt('dataset/dataset_demo_rate1.0_split1024_gap200/images', 'dataset/dataset_demo_rate1.0_split1024_gap200/')
349