# Path: blob/master/modules/dnn/test/pascal_semsegm_test_fcn.py
# 16356 views
from __future__ import print_function1from abc import ABCMeta, abstractmethod2import numpy as np3import sys4import argparse5import time67from imagenet_cls_test_alexnet import CaffeModel, DnnCaffeModel8try:9import cv2 as cv10except ImportError:11raise ImportError('Can\'t find OpenCV Python module. If you\'ve built it from sources without installation, '12'configure environment variable PYTHONPATH to "opencv_build_dir/lib" directory (with "python3" subdirectory if required)')131415def get_metrics(conf_mat):16pix_accuracy = np.trace(conf_mat) / np.sum(conf_mat)17t = np.sum(conf_mat, 1)18num_cl = np.count_nonzero(t)19assert num_cl20mean_accuracy = np.sum(np.nan_to_num(np.divide(np.diagonal(conf_mat), t))) / num_cl21col_sum = np.sum(conf_mat, 0)22mean_iou = np.sum(23np.nan_to_num(np.divide(np.diagonal(conf_mat), (t + col_sum - np.diagonal(conf_mat))))) / num_cl24return pix_accuracy, mean_accuracy, mean_iou252627def eval_segm_result(net_out):28assert type(net_out) is np.ndarray29assert len(net_out.shape) == 43031channels_dim = 132y_dim = channels_dim + 133x_dim = y_dim + 134res = np.zeros(net_out.shape).astype(np.int)35for i in range(net_out.shape[y_dim]):36for j in range(net_out.shape[x_dim]):37max_ch = np.argmax(net_out[..., i, j])38res[0, max_ch, i, j] = 139return res404142def get_conf_mat(gt, prob):43assert type(gt) is np.ndarray44assert type(prob) is np.ndarray4546conf_mat = np.zeros((gt.shape[0], gt.shape[0]))47for ch_gt in range(conf_mat.shape[0]):48gt_channel = gt[ch_gt, ...]49for ch_pr in range(conf_mat.shape[1]):50prob_channel = prob[ch_pr, ...]51conf_mat[ch_gt][ch_pr] = np.count_nonzero(np.multiply(gt_channel, prob_channel))52return conf_mat535455class MeanChannelsPreproc:56def __init__(self):57pass5859@staticmethod60def process(img):61image_data = np.array(img).transpose(2, 0, 1).astype(np.float32)62mean = np.ones(image_data.shape)63mean[0] *= 10464mean[1] *= 11765mean[2] *= 12366image_data -= mean67image_data = np.expand_dims(image_data, 0)68return 
image_data697071class DatasetImageFetch(object):72__metaclass__ = ABCMeta73data_prepoc = object7475@abstractmethod76def __iter__(self):77pass7879@abstractmethod80def next(self):81pass8283@staticmethod84def pix_to_c(pix):85return pix[0] * 256 * 256 + pix[1] * 256 + pix[2]8687@staticmethod88def color_to_gt(color_img, colors):89num_classes = len(colors)90gt = np.zeros((num_classes, color_img.shape[0], color_img.shape[1])).astype(np.int)91for img_y in range(color_img.shape[0]):92for img_x in range(color_img.shape[1]):93c = DatasetImageFetch.pix_to_c(color_img[img_y][img_x])94if c in colors:95cls = colors.index(c)96gt[cls][img_y][img_x] = 197return gt9899100class PASCALDataFetch(DatasetImageFetch):101img_dir = ''102segm_dir = ''103names = []104colors = []105i = 0106107def __init__(self, img_dir, segm_dir, names_file, segm_cls_colors_file, preproc):108self.img_dir = img_dir109self.segm_dir = segm_dir110self.colors = self.read_colors(segm_cls_colors_file)111self.data_prepoc = preproc112self.i = 0113114with open(names_file) as f:115for l in f.readlines():116self.names.append(l.rstrip())117118@staticmethod119def read_colors(img_classes_file):120result = []121with open(img_classes_file) as f:122for l in f.readlines():123color = np.array(map(int, l.split()[1:]))124result.append(DatasetImageFetch.pix_to_c(color))125return result126127def __iter__(self):128return self129130def next(self):131if self.i < len(self.names):132name = self.names[self.i]133self.i += 1134segm_file = self.segm_dir + name + ".png"135img_file = self.img_dir + name + ".jpg"136gt = self.color_to_gt(cv.imread(segm_file, cv.IMREAD_COLOR)[:, :, ::-1], self.colors)137img = self.data_prepoc.process(cv.imread(img_file, cv.IMREAD_COLOR)[:, :, ::-1])138return img, gt139else:140self.i = 0141raise StopIteration142143def get_num_classes(self):144return len(self.colors)145146147class SemSegmEvaluation:148log = sys.stdout149150def __init__(self, log_path,):151self.log = open(log_path, 'w')152153def process(self, 
frameworks, data_fetcher):154samples_handled = 0155156conf_mats = [np.zeros((data_fetcher.get_num_classes(), data_fetcher.get_num_classes())) for i in range(len(frameworks))]157blobs_l1_diff = [0] * len(frameworks)158blobs_l1_diff_count = [0] * len(frameworks)159blobs_l_inf_diff = [sys.float_info.min] * len(frameworks)160inference_time = [0.0] * len(frameworks)161162for in_blob, gt in data_fetcher:163frameworks_out = []164samples_handled += 1165for i in range(len(frameworks)):166start = time.time()167out = frameworks[i].get_output(in_blob)168end = time.time()169segm = eval_segm_result(out)170conf_mats[i] += get_conf_mat(gt, segm[0])171frameworks_out.append(out)172inference_time[i] += end - start173174pix_acc, mean_acc, miou = get_metrics(conf_mats[i])175176name = frameworks[i].get_name()177print(samples_handled, 'Pixel accuracy, %s:' % name, 100 * pix_acc, file=self.log)178print(samples_handled, 'Mean accuracy, %s:' % name, 100 * mean_acc, file=self.log)179print(samples_handled, 'Mean IOU, %s:' % name, 100 * miou, file=self.log)180print("Inference time, ms ", \181frameworks[i].get_name(), inference_time[i] / samples_handled * 1000, file=self.log)182183for i in range(1, len(frameworks)):184log_str = frameworks[0].get_name() + " vs " + frameworks[i].get_name() + ':'185diff = np.abs(frameworks_out[0] - frameworks_out[i])186l1_diff = np.sum(diff) / diff.size187print(samples_handled, "L1 difference", log_str, l1_diff, file=self.log)188blobs_l1_diff[i] += l1_diff189blobs_l1_diff_count[i] += 1190if np.max(diff) > blobs_l_inf_diff[i]:191blobs_l_inf_diff[i] = np.max(diff)192print(samples_handled, "L_INF difference", log_str, blobs_l_inf_diff[i], file=self.log)193194self.log.flush()195196for i in range(1, len(blobs_l1_diff)):197log_str = frameworks[0].get_name() + " vs " + frameworks[i].get_name() + ':'198print('Final l1 diff', log_str, blobs_l1_diff[i] / blobs_l1_diff_count[i], file=self.log)199200if __name__ == "__main__":201parser = 
argparse.ArgumentParser()202parser.add_argument("--imgs_dir", help="path to PASCAL VOC 2012 images dir, data/VOC2012/JPEGImages")203parser.add_argument("--segm_dir", help="path to PASCAL VOC 2012 segmentation dir, data/VOC2012/SegmentationClass/")204parser.add_argument("--val_names", help="path to file with validation set image names, download it here: "205"https://github.com/shelhamer/fcn.berkeleyvision.org/blob/master/data/pascal/seg11valid.txt")206parser.add_argument("--cls_file", help="path to file with colors for classes, download it here: "207"https://github.com/opencv/opencv/blob/master/samples/data/dnn/pascal-classes.txt")208parser.add_argument("--prototxt", help="path to caffe prototxt, download it here: "209"https://github.com/opencv/opencv/blob/master/samples/data/dnn/fcn8s-heavy-pascal.prototxt")210parser.add_argument("--caffemodel", help="path to caffemodel file, download it here: "211"http://dl.caffe.berkeleyvision.org/fcn8s-heavy-pascal.caffemodel")212parser.add_argument("--log", help="path to logging file")213parser.add_argument("--in_blob", help="name for input blob", default='data')214parser.add_argument("--out_blob", help="name for output blob", default='score')215args = parser.parse_args()216217prep = MeanChannelsPreproc()218df = PASCALDataFetch(args.imgs_dir, args.segm_dir, args.val_names, args.cls_file, prep)219220fw = [CaffeModel(args.prototxt, args.caffemodel, args.in_blob, args.out_blob, True),221DnnCaffeModel(args.prototxt, args.caffemodel, '', args.out_blob)]222223segm_eval = SemSegmEvaluation(args.log)224segm_eval.process(fw, df)225226227