# Path: blob/master/modules/dnn/test/imagenet_cls_test_alexnet.py
# 16354 views
from __future__ import print_function1from abc import ABCMeta, abstractmethod2import numpy as np3import sys4import os5import argparse6import time78try:9import caffe10except ImportError:11raise ImportError('Can\'t find Caffe Python module. If you\'ve built it from sources without installation, '12'configure environment variable PYTHONPATH to "git/caffe/python" directory')13try:14import cv2 as cv15except ImportError:16raise ImportError('Can\'t find OpenCV Python module. If you\'ve built it from sources without installation, '17'configure environment variable PYTHONPATH to "opencv_build_dir/lib" directory (with "python3" subdirectory if required)')1819try:20xrange # Python 221except NameError:22xrange = range # Python 3232425class DataFetch(object):26imgs_dir = ''27frame_size = 028bgr_to_rgb = False29__metaclass__ = ABCMeta3031@abstractmethod32def preprocess(self, img):33pass3435def get_batch(self, imgs_names):36assert type(imgs_names) is list37batch = np.zeros((len(imgs_names), 3, self.frame_size, self.frame_size)).astype(np.float32)38for i in range(len(imgs_names)):39img_name = imgs_names[i]40img_file = self.imgs_dir + img_name41assert os.path.exists(img_file)42img = cv.imread(img_file, cv.IMREAD_COLOR)43min_dim = min(img.shape[-3], img.shape[-2])44resize_ratio = self.frame_size / float(min_dim)45img = cv.resize(img, (0, 0), fx=resize_ratio, fy=resize_ratio)46cols = img.shape[1]47rows = img.shape[0]48y1 = (rows - self.frame_size) / 249y2 = y1 + self.frame_size50x1 = (cols - self.frame_size) / 251x2 = x1 + self.frame_size52img = img[y1:y2, x1:x2]53if self.bgr_to_rgb:54img = img[..., ::-1]55image_data = img[:, :, 0:3].transpose(2, 0, 1)56batch[i] = self.preprocess(image_data)57return batch585960class MeanBlobFetch(DataFetch):61mean_blob = np.ndarray(())6263def __init__(self, frame_size, mean_blob_path, imgs_dir):64self.imgs_dir = imgs_dir65self.frame_size = frame_size66blob = caffe.proto.caffe_pb2.BlobProto()67data = open(mean_blob_path, 
'rb').read()68blob.ParseFromString(data)69self.mean_blob = np.array(caffe.io.blobproto_to_array(blob))70start = (self.mean_blob.shape[2] - self.frame_size) / 271stop = start + self.frame_size72self.mean_blob = self.mean_blob[:, :, start:stop, start:stop][0]7374def preprocess(self, img):75return img - self.mean_blob767778class MeanChannelsFetch(MeanBlobFetch):79def __init__(self, frame_size, imgs_dir):80self.imgs_dir = imgs_dir81self.frame_size = frame_size82self.mean_blob = np.ones((3, self.frame_size, self.frame_size)).astype(np.float32)83self.mean_blob[0] *= 10484self.mean_blob[1] *= 11785self.mean_blob[2] *= 123868788class MeanValueFetch(MeanBlobFetch):89def __init__(self, frame_size, imgs_dir, bgr_to_rgb):90self.imgs_dir = imgs_dir91self.frame_size = frame_size92self.mean_blob = np.ones((3, self.frame_size, self.frame_size)).astype(np.float32)93self.mean_blob *= 11794self.bgr_to_rgb = bgr_to_rgb959697def get_correct_answers(img_list, img_classes, net_output_blob):98correct_answers = 099for i in range(len(img_list)):100indexes = np.argsort(net_output_blob[i])[-5:]101correct_index = img_classes[img_list[i]]102if correct_index in indexes:103correct_answers += 1104return correct_answers105106107class Framework(object):108in_blob_name = ''109out_blob_name = ''110111__metaclass__ = ABCMeta112113@abstractmethod114def get_name(self):115pass116117@abstractmethod118def get_output(self, input_blob):119pass120121122class CaffeModel(Framework):123net = caffe.Net124need_reshape = False125126def __init__(self, prototxt, caffemodel, in_blob_name, out_blob_name, need_reshape=False):127caffe.set_mode_cpu()128self.net = caffe.Net(prototxt, caffemodel, caffe.TEST)129self.in_blob_name = in_blob_name130self.out_blob_name = out_blob_name131self.need_reshape = need_reshape132133def get_name(self):134return 'Caffe'135136def get_output(self, input_blob):137if self.need_reshape:138self.net.blobs[self.in_blob_name].reshape(*input_blob.shape)139return 
self.net.forward_all(**{self.in_blob_name: input_blob})[self.out_blob_name]140141142class DnnCaffeModel(Framework):143net = object144145def __init__(self, prototxt, caffemodel, in_blob_name, out_blob_name):146self.net = cv.dnn.readNetFromCaffe(prototxt, caffemodel)147self.in_blob_name = in_blob_name148self.out_blob_name = out_blob_name149150def get_name(self):151return 'DNN'152153def get_output(self, input_blob):154self.net.setInput(input_blob, self.in_blob_name)155return self.net.forward(self.out_blob_name)156157158class ClsAccEvaluation:159log = sys.stdout160img_classes = {}161batch_size = 0162163def __init__(self, log_path, img_classes_file, batch_size):164self.log = open(log_path, 'w')165self.img_classes = self.read_classes(img_classes_file)166self.batch_size = batch_size167168@staticmethod169def read_classes(img_classes_file):170result = {}171with open(img_classes_file) as file:172for l in file.readlines():173result[l.split()[0]] = int(l.split()[1])174return result175176def process(self, frameworks, data_fetcher):177sorted_imgs_names = sorted(self.img_classes.keys())178correct_answers = [0] * len(frameworks)179samples_handled = 0180blobs_l1_diff = [0] * len(frameworks)181blobs_l1_diff_count = [0] * len(frameworks)182blobs_l_inf_diff = [sys.float_info.min] * len(frameworks)183inference_time = [0.0] * len(frameworks)184185for x in xrange(0, len(sorted_imgs_names), self.batch_size):186sublist = sorted_imgs_names[x:x + self.batch_size]187batch = data_fetcher.get_batch(sublist)188189samples_handled += len(sublist)190191frameworks_out = []192fw_accuracy = []193for i in range(len(frameworks)):194start = time.time()195out = frameworks[i].get_output(batch)196end = time.time()197correct_answers[i] += get_correct_answers(sublist, self.img_classes, out)198fw_accuracy.append(100 * correct_answers[i] / float(samples_handled))199frameworks_out.append(out)200inference_time[i] += end - start201print(samples_handled, 'Accuracy for', frameworks[i].get_name() + ':', 
fw_accuracy[i], file=self.log)202print("Inference time, ms ", \203frameworks[i].get_name(), inference_time[i] / samples_handled * 1000, file=self.log)204205for i in range(1, len(frameworks)):206log_str = frameworks[0].get_name() + " vs " + frameworks[i].get_name() + ':'207diff = np.abs(frameworks_out[0] - frameworks_out[i])208l1_diff = np.sum(diff) / diff.size209print(samples_handled, "L1 difference", log_str, l1_diff, file=self.log)210blobs_l1_diff[i] += l1_diff211blobs_l1_diff_count[i] += 1212if np.max(diff) > blobs_l_inf_diff[i]:213blobs_l_inf_diff[i] = np.max(diff)214print(samples_handled, "L_INF difference", log_str, blobs_l_inf_diff[i], file=self.log)215216self.log.flush()217218for i in range(1, len(blobs_l1_diff)):219log_str = frameworks[0].get_name() + " vs " + frameworks[i].get_name() + ':'220print('Final l1 diff', log_str, blobs_l1_diff[i] / blobs_l1_diff_count[i], file=self.log)221222if __name__ == "__main__":223parser = argparse.ArgumentParser()224parser.add_argument("--imgs_dir", help="path to ImageNet validation subset images dir, ILSVRC2012_img_val dir")225parser.add_argument("--img_cls_file", help="path to file with classes ids for images, val.txt file from this "226"archive: http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz")227parser.add_argument("--prototxt", help="path to caffe prototxt, download it here: "228"https://github.com/BVLC/caffe/blob/master/models/bvlc_alexnet/deploy.prototxt")229parser.add_argument("--caffemodel", help="path to caffemodel file, download it here: "230"http://dl.caffe.berkeleyvision.org/bvlc_alexnet.caffemodel")231parser.add_argument("--log", help="path to logging file")232parser.add_argument("--mean", help="path to ImageNet mean blob caffe file, imagenet_mean.binaryproto file from"233"this archive: http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz")234parser.add_argument("--batch_size", help="size of images in batch", default=1000)235parser.add_argument("--frame_size", help="size of input image", 
default=227)236parser.add_argument("--in_blob", help="name for input blob", default='data')237parser.add_argument("--out_blob", help="name for output blob", default='prob')238args = parser.parse_args()239240data_fetcher = MeanBlobFetch(args.frame_size, args.mean, args.imgs_dir)241242frameworks = [CaffeModel(args.prototxt, args.caffemodel, args.in_blob, args.out_blob),243DnnCaffeModel(args.prototxt, args.caffemodel, '', args.out_blob)]244245acc_eval = ClsAccEvaluation(args.log, args.img_cls_file, args.batch_size)246acc_eval.process(frameworks, data_fetcher)247248249