# Path: blob/master/samples/dnn/mobilenet_ssd_accuracy.py
# 16337 views
from __future__ import print_function
# Script to evaluate MobileNet-SSD object detection model trained in TensorFlow
# using both TensorFlow and OpenCV. Example:
#
# python mobilenet_ssd_accuracy.py \
#    --weights=frozen_inference_graph.pb \
#    --prototxt=ssd_mobilenet_v1_coco.pbtxt \
#    --images=val2017 \
#    --annotations=annotations/instances_val2017.json
#
# Tested on COCO 2017 object detection dataset, http://cocodataset.org/#download
import os
import json
import argparse

# NOTE: cv2, tensorflow and pycocotools are imported inside the functions that
# use them, so the pure helpers below stay importable without those packages.

# Result files consumed by the evaluation step, in the order they are reported.
CV_RESULT_FILE = 'cv_result.json'
TF_RESULT_FILE = 'tf_result.json'


def image_id_from_name(img_name):
    """Return the integer image id encoded in an image file name.

    The id is the file name without directory and extension, parsed as a
    decimal integer (e.g. ``000000000139.jpg`` -> ``139``).

    Raises:
        ValueError: if the file stem is not a decimal integer.
    """
    # splitext instead of slicing at rfind('.'): the latter chops the last
    # character when the name has no extension at all.
    stem = os.path.splitext(os.path.basename(img_name))[0]
    return int(stem)


def to_coco_bbox(x_min, y_min, x_max, y_max, cols, rows):
    """Convert normalized box corners to a ``[x, y, width, height]`` pixel box.

    Args:
        x_min, y_min, x_max, y_max: corner coordinates as fractions of the
            image width / height.
        cols, rows: image width and height in pixels.

    Returns:
        A list of four built-in floats, so that the result is always
        serializable by ``json`` regardless of the input scalar type.
    """
    x = float(x_min) * cols
    y = float(y_min) * rows
    w = float(x_max) * cols - x
    h = float(y_max) * rows - y
    return [x, y, w, h]


def make_detection(image_id, class_id, score, bbox):
    """Build one detection record with the keys written to the result files."""
    return {
        "image_id": image_id,
        "category_id": int(class_id),
        "bbox": bbox,
        "score": float(score)
    }


def parse_args():
    """Parse the command line options of the script."""
    parser = argparse.ArgumentParser(
        description='Evaluate MobileNet-SSD model using both TensorFlow and OpenCV. '
                    'COCO evaluation framework is required: http://cocodataset.org')
    parser.add_argument('--weights', required=True,
                        help='Path to frozen_inference_graph.pb of MobileNet-SSD model. '
                             'Download it from http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_coco_11_06_2017.tar.gz')
    parser.add_argument('--prototxt', help='Path to ssd_mobilenet_v1_coco.pbtxt from opencv_extra.', required=True)
    parser.add_argument('--images', help='Path to COCO validation images directory.', required=True)
    parser.add_argument('--annotations', help='Path to COCO annotations file.', required=True)
    return parser.parse_args()


def iter_images(images_dir):
    """Yield ``(image_id, image)`` for every readable image in *images_dir*.

    Files that OpenCV cannot decode are reported and skipped instead of
    crashing later on ``None.shape``. Names are visited in sorted order so
    that the result files are reproducible between runs.
    """
    import cv2 as cv

    for img_name in sorted(os.listdir(images_dir)):
        img = cv.imread(os.path.join(images_dir, img_name))
        if img is None:
            print('Skipping %s: not a readable image.' % img_name)
            continue
        yield image_id_from_name(img_name), img


def get_opencv_predictions(weights, prototxt, images_dir):
    """Run the model through OpenCV's dnn module and collect detections.

    Returns:
        A list of detection records (see ``make_detection``).
    """
    import cv2 as cv

    net = cv.dnn.readNetFromTensorflow(weights, prototxt)
    net.setPreferableBackend(cv.dnn.DNN_BACKEND_OPENCV)

    detections = []
    for image_id, img in iter_images(images_dir):
        rows, cols = img.shape[:2]
        inp = cv.resize(img, (300, 300))

        # Positional arguments: scale factor, size, mean, swapRB.
        net.setInput(cv.dnn.blobFromImage(inp, 1.0/127.5, (300, 300), (127.5, 127.5, 127.5), True))
        out = net.forward()

        # Each row of out[0, 0] is one detection; the code reads the class id
        # at index 1, the score at index 2 and the box corners at 3..6.
        for det in out[0, 0]:
            # Confidence threshold is in prototxt.
            bbox = to_coco_bbox(det[3], det[4], det[5], det[6], cols, rows)
            detections.append(make_detection(image_id, det[1], det[2], bbox))
    return detections


def get_tensorflow_predictions(weights, images_dir):
    """Run the frozen graph through TensorFlow and collect detections.

    Uses the TensorFlow 1.x API (``tf.gfile``, ``tf.GraphDef``,
    ``tf.Session``), as the original script does.

    Returns:
        A list of detection records (see ``make_detection``).
    """
    import cv2 as cv
    import tensorflow as tf

    # The frozen graph is a binary protobuf: it must be read in 'rb' mode,
    # otherwise Python 3 tries to decode it as text.
    with tf.gfile.FastGFile(weights, 'rb') as f:
        # Load the model
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    detections = []
    with tf.Session() as sess:
        # Restore session: the graph is imported into the default graph,
        # which is the one this session runs.
        tf.import_graph_def(graph_def, name='')

        # Output tensors do not depend on the image: look them up once.
        fetches = [sess.graph.get_tensor_by_name('num_detections:0'),
                   sess.graph.get_tensor_by_name('detection_scores:0'),
                   sess.graph.get_tensor_by_name('detection_boxes:0'),
                   sess.graph.get_tensor_by_name('detection_classes:0')]

        for image_id, img in iter_images(images_dir):
            rows, cols = img.shape[:2]
            inp = cv.resize(img, (300, 300))
            inp = inp[:, :, [2, 1, 0]]  # BGR2RGB
            num, scores, boxes, classes = sess.run(
                fetches,
                feed_dict={'image_tensor:0': inp.reshape(1, inp.shape[0], inp.shape[1], 3)})

            for i in range(int(num[0])):
                score = float(scores[0][i])
                if score > 0.01:
                    # Boxes are read as (y_min, x_min, y_max, x_max).
                    y_min, x_min, y_max, x_max = boxes[0][i]
                    bbox = to_coco_bbox(x_min, y_min, x_max, y_max, cols, rows)
                    detections.append(make_detection(image_id, classes[0][i], score, bbox))
    return detections


def write_detections(path, detections):
    """Dump the detection records to *path* as JSON."""
    with open(path, 'wt') as f:
        json.dump(detections, f)


def evaluate(annotations, result_files):
    """Print the COCO bounding-box evaluation summary for each result file.

    Args:
        annotations: path to the COCO ground-truth annotations file.
        result_files: paths of JSON detection files to evaluate, in order.
    """
    # %matplotlib inline
    # The plotting / image imports below are kept from the original script;
    # the evaluation itself only needs pycocotools.
    import matplotlib.pyplot as plt
    from pycocotools.coco import COCO
    from pycocotools.cocoeval import COCOeval
    import numpy as np
    import skimage.io as io
    import pylab
    pylab.rcParams['figure.figsize'] = (10.0, 8.0)

    ann_type = 'bbox'  # one of 'segm', 'bbox', 'keypoints'
    print('Running demo for *%s* results.' % (ann_type))

    # initialize COCO ground truth api
    coco_gt = COCO(annotations)

    # initialize COCO detections api
    for res_file in result_files:
        print(res_file)
        coco_dt = coco_gt.loadRes(res_file)

        coco_eval = COCOeval(coco_gt, coco_dt, ann_type)
        coco_eval.evaluate()
        coco_eval.accumulate()
        coco_eval.summarize()


def main():
    """Collect OpenCV and TensorFlow predictions, then evaluate both."""
    args = parse_args()

    ### Get OpenCV predictions #################################################
    write_detections(CV_RESULT_FILE,
                     get_opencv_predictions(args.weights, args.prototxt, args.images))

    ### Get TensorFlow predictions #############################################
    write_detections(TF_RESULT_FILE,
                     get_tensorflow_predictions(args.weights, args.images))

    ### Evaluation part ########################################################
    evaluate(args.annotations, [TF_RESULT_FILE, CV_RESULT_FILE])


if __name__ == '__main__':
    main()