Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.
Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.
Path: blob/master/DOTA_devkit/hrsc2016_evaluation.py
Views: 475
# --------------------------------------------------------
# dota_evaluation_task1
# Licensed under The MIT License [see LICENSE for details]
# Written by Jian Ding, based on code from Bharath Hariharan
# --------------------------------------------------------

"""
Oriented-box (DOTA Task1 style) evaluation, configured here for HRSC2016.

To use the code, users should configure ``detpath``, ``annopath`` and
``imagesetfile`` in ``main()``:

* ``detpath`` is a format string with one ``{:s}`` placeholder that is filled
  with the class name.  Each resulting file holds one detection per line:
  ``image_id score`` followed by the polygon coordinates (same 8-value
  x, y layout as the ground truth).  For the format, you can refer to
  "http://captain.whu.edu.cn/DOTAweb/tasks.html".
* ``annopath`` is a format string filled with each image name; it must name a
  ground-truth text file readable by ``parse_gt``.
* ``imagesetfile`` is a text file listing one image name per line.

Note, the evaluation is on the large scale images.
"""
import xml.etree.ElementTree as ET
import os
from functools import partial

import numpy as np

# matplotlib is only used by the optional (commented-out) P-R curve plotting
# in main(), so it is not a hard requirement for running the evaluation.
try:
    import matplotlib.pyplot as plt
except ImportError:  # pragma: no cover - optional dependency
    plt = None

# polyiou is DOTA_devkit's polygon-IoU module.  Only voc_eval() needs it;
# a soft import keeps parse_gt()/voc_ap() usable when it is not available.
try:
    import polyiou
except ImportError:  # pragma: no cover - depends on the local devkit build
    polyiou = None


def parse_gt(filename):
    """Parse a ground-truth text file.

    Each object line is ``x1 y1 x2 y2 x3 y3 x4 y4 name [difficult]``
    (whitespace separated).  Lines with fewer than 9 fields are skipped.

    :param filename: ground truth file to parse
    :return: all instances in a picture, as a list of dicts with keys
        ``name`` (str), ``difficult`` (int; 0 when the column is absent)
        and ``bbox`` (list of the first 8 fields as floats).
    """
    objects = []
    with open(filename, 'r') as f:
        for line in f:
            # split() (no separator) tolerates repeated spaces / tabs, which
            # split(' ') would turn into empty, unparsable fields.
            parts = line.split()
            if len(parts) < 9:
                continue
            objects.append({
                'name': parts[8],
                # Any line with a 10th column carries the difficult flag;
                # previously lines with >10 columns left the key unset.
                'difficult': int(parts[9]) if len(parts) > 9 else 0,
                'bbox': [float(v) for v in parts[:8]],
            })
    return objects


def voc_ap(rec, prec, use_07_metric=False):
    """ap = voc_ap(rec, prec, [use_07_metric])

    Compute VOC AP given precision and recall.

    :param rec: recall values (array-like).
    :param prec: precision values aligned with ``rec`` (array-like).
    :param use_07_metric: if True, use the VOC 07 11-point method: the mean,
        over t = 0, 0.1, ..., 1.0, of the maximum precision at recall >= t.
        Otherwise compute the area under the monotone precision envelope
        (default: False).
    :return: the average precision as a scalar.
    """
    rec = np.asarray(rec, dtype=np.float64)
    prec = np.asarray(prec, dtype=np.float64)
    if use_07_metric:
        # 11 point metric
        ap = 0.
        for t in np.arange(0., 1.1, 0.1):
            if np.sum(rec >= t) == 0:
                p = 0
            else:
                p = np.max(prec[rec >= t])
            ap = ap + p / 11.
    else:
        # correct AP calculation
        # first append sentinel values at the end
        mrec = np.concatenate(([0.], rec, [1.]))
        mpre = np.concatenate(([0.], prec, [0.]))

        # compute the precision envelope (make precision non-increasing
        # from right to left)
        for i in range(mpre.size - 1, 0, -1):
            mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])

        # to calculate area under PR curve, look for points
        # where X axis (recall) changes value
        i = np.where(mrec[1:] != mrec[:-1])[0]

        # and sum (\Delta recall) * prec
        ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
    return ap


def _hbb_ious(gt_polys, det_poly):
    """Return IoUs between the axis-aligned hull of det_poly and of each GT.

    ``gt_polys`` is a 2-D array with one polygon per row and ``det_poly`` a
    1-D array; both hold alternating x, y coordinates.  Uses the legacy VOC
    "+1 pixel" box convention.  This is only a cheap pre-filter: if two hulls
    do not intersect, the polygons inside them cannot either.
    """
    gt_xmin = np.min(gt_polys[:, 0::2], axis=1)
    gt_ymin = np.min(gt_polys[:, 1::2], axis=1)
    gt_xmax = np.max(gt_polys[:, 0::2], axis=1)
    gt_ymax = np.max(gt_polys[:, 1::2], axis=1)
    det_xmin = np.min(det_poly[0::2])
    det_ymin = np.min(det_poly[1::2])
    det_xmax = np.max(det_poly[0::2])
    det_ymax = np.max(det_poly[1::2])

    # intersection
    ixmin = np.maximum(gt_xmin, det_xmin)
    iymin = np.maximum(gt_ymin, det_ymin)
    ixmax = np.minimum(gt_xmax, det_xmax)
    iymax = np.minimum(gt_ymax, det_ymax)
    iw = np.maximum(ixmax - ixmin + 1., 0.)
    ih = np.maximum(iymax - iymin + 1., 0.)
    inters = iw * ih

    # union
    uni = ((det_xmax - det_xmin + 1.) * (det_ymax - det_ymin + 1.) +
           (gt_xmax - gt_xmin + 1.) * (gt_ymax - gt_ymin + 1.) - inters)
    return inters / uni


def _poly_ious(gt_polys, det_poly):
    """Return the polygon IoU between det_poly and each row of gt_polys."""
    if polyiou is None:
        raise ImportError(
            "voc_eval requires the DOTA_devkit 'polyiou' module to compute "
            "polygon IoU")
    return [polyiou.iou_poly(polyiou.VectorDouble(gt),
                             polyiou.VectorDouble(det_poly))
            for gt in gt_polys]


def voc_eval(detpath,
             annopath,
             imagesetfile,
             classname,
             ovthresh=0.5,
             use_07_metric=False):
    """rec, prec, ap = voc_eval(detpath,
                                annopath,
                                imagesetfile,
                                classname,
                                [ovthresh],
                                [use_07_metric])

    Top level function that does the PASCAL VOC style evaluation, matching
    detections to ground truth by polygon IoU.

    detpath: Path to detections.
        detpath.format(classname) should produce the detection results file;
        each non-blank line is ``image_id score`` followed by the polygon
        coordinates.
    annopath: Path to annotations.
        annopath.format(imagename) should be the ground-truth text file
        parsed by ``parse_gt``.
    imagesetfile: Text file containing the list of images, one image per
        line (blank lines are ignored).
    classname: Category name.
    [ovthresh]: Overlap threshold (default = 0.5); a detection matches a
        ground truth only when their polygon IoU is strictly greater.
    [use_07_metric]: Whether to use VOC07's 11 point AP computation
        (default False).

    Returns (rec, prec, ap): cumulative recall and precision arrays over
    the detections sorted by descending score, and the AP from ``voc_ap``.

    A detection whose best match is a ``difficult`` ground truth counts as
    neither TP nor FP; a detection whose best match was already claimed by a
    higher-scoring detection is a FP.  If the class has no non-difficult
    ground truth, recall divides by zero (NumPy warns, values are inf/nan).
    Raises KeyError if a detection names an image not in imagesetfile.
    """
    # read list of images
    with open(imagesetfile, 'r') as f:
        imagenames = [x.strip() for x in f if x.strip()]

    # load gt and extract the objects of this class
    class_recs = {}
    npos = 0
    for imagename in imagenames:
        R = [obj for obj in parse_gt(annopath.format(imagename))
             if obj['name'] == classname]
        bbox = np.array([x['bbox'] for x in R], dtype=np.float64)
        # builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24.
        difficult = np.array([x['difficult'] for x in R], dtype=bool)
        npos += int(np.sum(~difficult))
        class_recs[imagename] = {'bbox': bbox,
                                 'difficult': difficult,
                                 'det': [False] * len(R)}

    # read dets from Task1* files
    detfile = detpath.format(classname)
    with open(detfile, 'r') as f:
        splitlines = [line.split() for line in f if line.strip()]
    image_ids = [x[0] for x in splitlines]
    confidence = np.array([float(x[1]) for x in splitlines], dtype=np.float64)
    if splitlines:
        BB = np.array([[float(z) for z in x[2:]] for x in splitlines],
                      dtype=np.float64)
    else:
        # no detections: keep BB 2-D so the row indexing below stays valid
        BB = np.zeros((0, 8), dtype=np.float64)

    # sort by confidence (descending)
    sorted_ind = np.argsort(-confidence)
    BB = BB[sorted_ind, :]
    image_ids = [image_ids[x] for x in sorted_ind]

    # go down dets and mark TPs and FPs
    nd = len(image_ids)
    tp = np.zeros(nd)
    fp = np.zeros(nd)
    for d in range(nd):
        R = class_recs[image_ids[d]]
        bb = BB[d, :]
        BBGT = R['bbox']
        ovmax = -np.inf
        jmax = -1

        if BBGT.size > 0:
            # 1. cheap axis-aligned pre-filter: if the hulls do not overlap,
            #    the polygon IoU is 0 too, so only keep overlapping GTs.
            keep = np.where(_hbb_ious(BBGT, bb) > 0)[0]
            if len(keep) > 0:
                # 2. exact polygon IoU on the remaining candidates
                overlaps = _poly_ious(BBGT[keep, :], bb)
                ovmax = np.max(overlaps)
                jmax = keep[np.argmax(overlaps)]

        if ovmax > ovthresh:
            if not R['difficult'][jmax]:
                if not R['det'][jmax]:
                    tp[d] = 1.
                    R['det'][jmax] = True
                else:
                    fp[d] = 1.
            # else: matched a difficult GT -> neither TP nor FP
        else:
            fp[d] = 1.

    # compute precision recall
    print('npos num:', npos)
    fp = np.cumsum(fp)
    tp = np.cumsum(tp)
    rec = tp / float(npos)
    # avoid divide by zero in case the first detection matches a difficult
    # ground truth
    prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
    ap = voc_ap(rec, prec, use_07_metric)

    return rec, prec, ap


def main():
    """Evaluate the configured HRSC2016 results and print per-class AP/mAP."""
    # detpath = r'/mnt/SSD/lwt_workdir/BeyondBoundingBox/hrsc_pkl/result_raw/Task1_{:s}.txt'
    detpath = r'/mnt/SSD/lwt_workdir/BeyondBoundingBox/hrsc_pkl/s2a/result_raw/Task1_{:s}.txt'
    # change the directory to the path of val/labelTxt, if you want to do
    # evaluation on the valset
    annopath = r'/mnt/SSD/lwt_workdir/BeyondBoundingBox/data/HRSC2016/Test/labelTxt/{:s}.txt'
    imagesetfile = r'/mnt/SSD/lwt_workdir/BeyondBoundingBox/data/HRSC2016/Test/test.txt'

    # For HRSC2016
    classnames = ['ship']
    classaps = []
    mean_ap = 0  # avoid shadowing the builtin ``map``
    for classname in classnames:
        print('classname:', classname)
        rec, prec, ap = voc_eval(detpath,
                                 annopath,
                                 imagesetfile,
                                 classname,
                                 ovthresh=0.5,
                                 use_07_metric=True)
        mean_ap = mean_ap + ap
        print('ap: ', ap)
        classaps.append(ap)

        # uncomment to show p-r curve of each category (needs matplotlib)
        # plt.figure(figsize=(8,4))
        # plt.xlabel('recall')
        # plt.ylabel('precision')
        # plt.plot(rec, prec)
        # plt.savefig('PRcurve.png')
        # plt.show()
    mean_ap = mean_ap / len(classnames)
    print('map:', mean_ap)
    classaps = 100 * np.array(classaps)
    print('classaps: ', classaps)


if __name__ == '__main__':
    main()