Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.
Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.
Path: blob/master/DOTA_devkit/dota_evaluation_task1.py
Views: 475
# --------------------------------------------------------
# dota_evaluation_task1
# Licensed under The MIT License [see LICENSE for details]
# Written by Jian Ding, based on code from Bharath Hariharan
# --------------------------------------------------------

"""
Evaluate DOTA Task1 (oriented bounding box) detections with PASCAL VOC style AP.

To use the code, configure detpath, annopath and imagesetfile (see parse_args).
detpath is the path template of the per-class result files; for the format,
you can refer to "http://captain.whu.edu.cn/DOTAweb/tasks.html".
Note, the evaluation is on the large scale images.
"""
import xml.etree.ElementTree as ET
import os
#import cPickle
import numpy as np
from functools import partial
import argparse

try:
    # Only needed by the optional (commented-out) P-R curve snippet in main().
    import matplotlib.pyplot as plt
except ImportError:
    plt = None

try:
    # Extension module providing the polygon IoU; required by voc_eval only.
    import polyiou
except ImportError:
    polyiou = None


def _require_polyiou():
    """Raise a descriptive ImportError when the polyiou module is unavailable."""
    if polyiou is None:
        raise ImportError('the polyiou extension module could not be imported; '
                          'build/install it before running the evaluation')


def parse_gt(filename):
    """
    Parse one ground-truth label file.

    Every annotation line is expected to hold eight polygon coordinates, the
    category name and an optional difficulty flag, separated by whitespace:
    ``x1 y1 x2 y2 x3 y3 x4 y4 name [difficult]``. Lines with fewer than nine
    fields (e.g. header or blank lines) are skipped.

    :param filename: ground truth file to parse
    :return: all instances in a picture, as a list of dicts with the keys
             'name' (str), 'difficult' (int) and 'bbox' (list of 8 floats)
    """
    objects = []
    with open(filename, 'r') as f:
        for line in f:
            # split() without argument tolerates tabs and repeated spaces
            fields = line.split()
            if len(fields) < 9:
                continue
            objects.append({
                'name': fields[8],
                # The flag is optional; fields after it are ignored so that
                # 'difficult' is always present for voc_eval.
                'difficult': int(fields[9]) if len(fields) >= 10 else 0,
                'bbox': [float(v) for v in fields[:8]],
            })
    return objects


def voc_ap(rec, prec, use_07_metric=False):
    """ap = voc_ap(rec, prec, [use_07_metric])

    Compute VOC AP given precision and recall.

    :param rec: 1-D sequence of recall values
    :param prec: 1-D sequence of precision values, same length as rec
    :param use_07_metric: if true, uses the VOC 07 11 point method
                          (default: False)
    :return: the average precision
    """
    rec = np.asarray(rec, dtype=float)
    prec = np.asarray(prec, dtype=float)
    if use_07_metric:
        # 11 point metric: mean of the best precision at recall >= 0, 0.1, ..., 1
        ap = 0.
        for t in np.arange(0., 1.1, 0.1):
            reached = rec >= t
            p = np.max(prec[reached]) if reached.any() else 0
            ap = ap + p / 11.
    else:
        # correct AP calculation
        # first append sentinel values at the end
        mrec = np.concatenate(([0.], rec, [1.]))
        mpre = np.concatenate(([0.], prec, [0.]))

        # compute the precision envelope (running maximum from the right)
        mpre = np.maximum.accumulate(mpre[::-1])[::-1]

        # to calculate area under PR curve, look for points
        # where X axis (recall) changes value
        i = np.where(mrec[1:] != mrec[:-1])[0]

        # and sum (\Delta recall) * prec
        ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
    return ap


def _hbb_overlaps(BBGT, bb):
    """Return the IoU between the axis-aligned hull of bb and of each row of BBGT."""
    BBGT_xmin = np.min(BBGT[:, 0::2], axis=1)
    BBGT_ymin = np.min(BBGT[:, 1::2], axis=1)
    BBGT_xmax = np.max(BBGT[:, 0::2], axis=1)
    BBGT_ymax = np.max(BBGT[:, 1::2], axis=1)
    bb_xmin = np.min(bb[0::2])
    bb_ymin = np.min(bb[1::2])
    bb_xmax = np.max(bb[0::2])
    bb_ymax = np.max(bb[1::2])

    # intersection
    iw = np.maximum(np.minimum(BBGT_xmax, bb_xmax) - np.maximum(BBGT_xmin, bb_xmin) + 1., 0.)
    ih = np.maximum(np.minimum(BBGT_ymax, bb_ymax) - np.maximum(BBGT_ymin, bb_ymin) + 1., 0.)
    inters = iw * ih

    # union
    uni = ((bb_xmax - bb_xmin + 1.) * (bb_ymax - bb_ymin + 1.) +
           (BBGT_xmax - BBGT_xmin + 1.) *
           (BBGT_ymax - BBGT_ymin + 1.) - inters)
    return inters / uni


def _poly_overlaps(BBGT_keep, bb):
    """Return the polygon IoU between bb and every polygon in BBGT_keep."""
    _require_polyiou()
    return [polyiou.iou_poly(polyiou.VectorDouble(gt), polyiou.VectorDouble(bb))
            for gt in BBGT_keep]


def voc_eval(detpath,
             annopath,
             imagesetfile,
             classname,
             # cachedir,
             ovthresh=0.5,
             use_07_metric=False):
    """rec, prec, ap = voc_eval(detpath,
                                annopath,
                                imagesetfile,
                                classname,
                                [ovthresh],
                                [use_07_metric])

    Top level function that does the PASCAL VOC style evaluation of one class.

    detpath: Path to detections
        detpath.format(classname) should produce the detection results file.
        Every non-blank line holds: image name, confidence, polygon coordinates.
    annopath: Path to annotations
        annopath.format(imagename) should be the label file read by parse_gt.
    imagesetfile: Text file containing the list of images, one image per line.
    classname: Category name
    [ovthresh]: Overlap threshold (default = 0.5)
    [use_07_metric]: Whether to use VOC07's 11 point AP computation
        (default False)

    Returns the cumulative recall array, the cumulative precision array and
    the average precision.

    Raises KeyError when a detection names an image that is not listed in
    imagesetfile, and ImportError when a polygon IoU is needed but the
    polyiou module is unavailable.
    """
    # read list of images (blank lines are ignored)
    with open(imagesetfile, 'r') as f:
        imagenames = [x.strip() for x in f if x.strip()]

    # first load gt; keyed by image name, so a name listed twice is parsed
    # and counted only once
    recs = {}
    for imagename in imagenames:
        if imagename not in recs:
            recs[imagename] = parse_gt(annopath.format(imagename))

    # extract gt objects for this class
    class_recs = {}
    npos = 0
    for imagename, objects in recs.items():
        R = [obj for obj in objects if obj['name'] == classname]
        bbox = np.array([x['bbox'] for x in R])
        difficult = np.array([x['difficult'] for x in R]).astype(np.bool_)
        # difficult instances do not count as positives
        npos += int(np.sum(~difficult))
        class_recs[imagename] = {'bbox': bbox,
                                 'difficult': difficult,
                                 'det': [False] * len(R)}

    # read dets from Task1* files (blank lines are ignored)
    with open(detpath.format(classname), 'r') as f:
        splitlines = [x.split() for x in f if x.strip()]

    image_ids = [x[0] for x in splitlines]
    confidence = np.array([float(x[1]) for x in splitlines])
    BB = np.array([[float(z) for z in x[2:]] for x in splitlines])

    # sort by confidence (descending)
    sorted_ind = np.argsort(-confidence)
    if splitlines:
        # with no detections BB is 1-D and cannot be indexed along two axes
        BB = BB[sorted_ind, :]
    image_ids = [image_ids[x] for x in sorted_ind]

    # go down dets and mark TPs and FPs
    nd = len(image_ids)
    tp = np.zeros(nd)
    fp = np.zeros(nd)
    for d, image_id in enumerate(image_ids):
        R = class_recs[image_id]
        bb = BB[d, :].astype(float)
        ovmax = -np.inf
        BBGT = R['bbox'].astype(float)

        ## compute det bb with each BBGT
        if BBGT.size > 0:
            # 1. calculate the overlaps between hbbs, if the iou between hbbs
            #    are 0, the iou between obbs are 0, too.
            BBGT_keep_index = np.flatnonzero(_hbb_overlaps(BBGT, bb) > 0)
            if BBGT_keep_index.size > 0:
                # 2. exact polygon iou, only for the surviving candidates
                overlaps = _poly_overlaps(BBGT[BBGT_keep_index, :], bb)
                best = int(np.argmax(overlaps))
                ovmax = overlaps[best]
                # map back to the index among all gt boxes of the image
                jmax = BBGT_keep_index[best]

        if ovmax > ovthresh:
            # matches with difficult gt are neither TP nor FP
            if not R['difficult'][jmax]:
                if not R['det'][jmax]:
                    tp[d] = 1.
                    R['det'][jmax] = True
                else:
                    # gt already claimed by a higher-scoring detection
                    fp[d] = 1.
        else:
            fp[d] = 1.

    # compute precision recall

    print('check fp:', fp)
    print('check tp', tp)

    print('npos num:', npos)
    fp = np.cumsum(fp)
    tp = np.cumsum(tp)

    eps = np.finfo(np.float64).eps
    # without positives tp is all zero, so recall becomes 0 instead of NaN
    rec = tp / np.maximum(float(npos), eps)  # recall
    # avoid divide by zero in case the first detection matches a difficult
    # ground truth
    prec = tp / np.maximum(tp + fp, eps)  # precision
    ap = voc_ap(rec, prec, use_07_metric)

    return rec, prec, ap


def GetFileFromThisRootDir(dir, ext=None):
    """
    Recursively collect the files below a directory.

    :param dir: root directory to walk
    :param ext: extension (without dot) or collection of extensions to keep;
                None keeps every file
    :return: list of file paths (each joined onto the walked root)
    """
    if isinstance(ext, str):
        # a bare string would otherwise be matched by substring
        ext = (ext,)
    allfiles = []
    for root, _dirs, files in os.walk(dir):
        for filename in files:
            filepath = os.path.join(root, filename)
            if ext is None or os.path.splitext(filepath)[1][1:] in ext:
                allfiles.append(filepath)
    return allfiles


def image2txt(srcpath, dstpath):
    """
    Append the base names (without extension) of all files found below
    srcpath to imgnamefile.txt inside dstpath, one name per line.

    @param srcpath: imageset directory
    @param dstpath: directory in which imgnamefile.txt is stored
    """
    # paths of all files below srcpath, eg: ['example_split/../P0001.txt', ...]
    filelist = GetFileFromThisRootDir(srcpath)
    if not filelist:
        # nothing to write: do not create the directory or the file
        return
    if not os.path.exists(dstpath):
        os.makedirs(dstpath)
    dstname = os.path.join(dstpath, 'imgnamefile.txt')  # eg: result/imgnamefile.txt
    with open(dstname, 'a') as f:
        # keep only the file name, eg: P0001
        f.writelines(os.path.basename(os.path.splitext(fullname)[0]) + '\n'
                     for fullname in filelist)


def parse_args(argv=None):
    """
    Parse the command line options of the evaluation script.

    :param argv: list of argument strings; None (default) reads sys.argv[1:]
    :return: namespace with the attributes detpath, annopath and imagesetfile
    """
    parser = argparse.ArgumentParser(description='Evaluate DOTA Task1 (oriented bounding box) detection results')
    parser.add_argument('--detpath', default='runs/val/yolov5t_DroneVehicle_val/splited_obb_prediction_Txt/Task1_{:s}.txt',
                        help='detection result file template; "{:s}" is replaced by the class name')
    parser.add_argument('--annopath', default='/media/test/4d846cae-2315-4928-8d1b-ca6d3a61a3c6/DroneVehicle/val/raw/labelTxt/{:s}.txt',
                        help='ground-truth label file template; "{:s}" is replaced by the image name')
    parser.add_argument('--imagesetfile', default='/media/test/4d846cae-2315-4928-8d1b-ca6d3a61a3c6/DroneVehicle/val/raw/imgnamefile.txt',
                        help='text file listing the image names to evaluate, one per line')
    args = parser.parse_args(argv)
    return args


def main():
    """Evaluate every configured class and print the per-class AP and the mAP."""
    args = parse_args()
    # detpath = r'/mnt/SSD/lwt_workdir/data/dota_angle/result_merge_roitran/{:s}.txt'
    detpath = args.detpath
    annopath = args.annopath
    imagesetfile = args.imagesetfile
    # For DroneVehicle
    classnames = ['vehicle']
    # For DOTA-v2.0
    # classnames = [ 'plane', 'baseball-diamond', 'bridge', 'ground-track-field', 'small-vehicle', 'large-vehicle', 'ship',
    #             'tennis-court', 'basketball-court', 'storage-tank', 'soccer-ball-field', 'roundabout', 'harbor',
    #             'swimming-pool', 'helicopter', 'container-crane', 'airport', 'helipad']
    # For DOTA-v1.5
    # classnames = ['plane', 'baseball-diamond', 'bridge', 'ground-track-field', 'small-vehicle', 'large-vehicle', 'ship', 'tennis-court',
    #             'basketball-court', 'storage-tank', 'soccer-ball-field', 'roundabout', 'harbor', 'swimming-pool', 'helicopter', 'container-crane']
    # For DOTA-v1.0
    # classnames = ['plane', 'baseball-diamond', 'bridge', 'ground-track-field', 'small-vehicle', 'large-vehicle', 'ship', 'tennis-court',
    #             'basketball-court', 'storage-tank', 'soccer-ball-field', 'roundabout', 'harbor', 'swimming-pool', 'helicopter']

    # fail before any file is parsed if the IoU extension is missing
    _require_polyiou()

    classaps = []
    for classname in classnames:
        print('classname:', classname)
        detfile = detpath.format(classname)
        if not os.path.exists(detfile):
            # classes without a result file are left out of the mean
            print('This class is not be detected in your dataset: {:s}'.format(classname))
            continue
        rec, prec, ap = voc_eval(detpath,
                                 annopath,
                                 imagesetfile,
                                 classname,
                                 ovthresh=0.5,
                                 use_07_metric=True)
        #print('rec: ', rec, 'prec: ', prec, 'ap: ', ap)
        print('ap: ', ap)
        classaps.append(ap)

        # # uncomment to show p-r curve of each category
        # plt.figure(figsize=(8,4))
        # plt.xlabel('Recall')
        # plt.ylabel('Precision')
        # plt.xticks(fontsize=11)
        # plt.yticks(fontsize=11)
        # plt.xlim(0, 1)
        # plt.ylim(0, 1)
        # ax = plt.gca()
        # ax.spines['top'].set_color('none')
        # ax.spines['right'].set_color('none')
        # plt.plot(rec, prec)
        # # plt.show()
        # plt.savefig('pr_curve/{}.png'.format(classname))

    # mean over the evaluated classes; 0.0 when every class was skipped
    mean_ap = sum(classaps) / len(classaps) if classaps else 0.0
    print('map:', mean_ap)
    classaps = 100 * np.array(classaps)
    print('classaps: ', classaps)


if __name__ == '__main__':
    main()
    # image2txt('/media/test/4d846cae-2315-4928-8d1b-ca6d3a61a3c6/DroneVehicle/val/raw/images',
    #           '/media/test/4d846cae-2315-4928-8d1b-ca6d3a61a3c6/DroneVehicle/val/raw/')
    # image2txt('dataset/dataset_demo_rate1.0_split1024_gap200/images', 'dataset/dataset_demo_rate1.0_split1024_gap200/')