# Path: blob/master/modules/dnn/test/cityscapes_semsegm_test_enet.py
# 16347 views
"""Evaluate a Torch (Lua) semantic-segmentation model on Cityscapes validation data.

The same model file is loaded twice: once through PyTorch's legacy Lua loader
(``TorchModel``) and once through OpenCV's dnn module (``DnnTorchModel``).
Both frameworks, together with a ``CityscapesDataFetch`` iterator, are handed to
``SemSegmEvaluation.process`` from ``pascal_semsegm_test_fcn``.

Example::

    python cityscapes_semsegm_test_enet.py --imgs_dir <leftImg8bit/val> --segm_dir <gtFine/val> --model <torch model> --log <log file>
"""
import numpy as np
import sys
import os
import fnmatch
import argparse

try:
    import cv2 as cv
except ImportError:
    raise ImportError('Can\'t find OpenCV Python module. If you\'ve built it from sources without installation, '
                      'configure environment variable PYTHONPATH to "opencv_build_dir/lib" directory (with "python3" subdirectory if required)')
try:
    import torch
except ImportError:
    raise ImportError('Can\'t find pytorch. Please install it by following instructions on the official site')

# NOTE(review): torch.utils.serialization.load_lua is only available in old
# PyTorch releases (it was removed around 1.0); this script therefore needs
# such a release to run the Torch reference path — confirm before upgrading.
from torch.utils.serialization import load_lua
from pascal_semsegm_test_fcn import eval_segm_result, get_conf_mat, get_metrics, DatasetImageFetch, SemSegmEvaluation
from imagenet_cls_test_alexnet import Framework, DnnCaffeModel


class NormalizePreproc:
    """Turn an HxWxC image array into a 1xCxHxW float32 blob divided by 255."""

    def __init__(self):
        pass

    @staticmethod
    def process(img):
        """Return *img* as a float32 NCHW blob with values divided by 255.0.

        Args:
            img: array-like image with channels on the last axis.

        Returns:
            numpy.ndarray of shape (1, C, H, W), dtype float32.
        """
        # HWC -> CHW, then add the leading batch axis expected by both frameworks.
        image_data = np.array(img).transpose(2, 0, 1).astype(np.float32)
        image_data = np.expand_dims(image_data, 0)
        image_data /= 255.0
        return image_data


class CityscapesDataFetch(DatasetImageFetch):
    """Iterate over (input blob, ground truth) pairs for Cityscapes colour annotations.

    Every ``*_color.png`` file found (recursively) under ``segm_dir`` is paired
    with the image at the same relative location under ``img_dir``, whose file
    name is obtained by replacing the last ``gtFine_color`` with ``leftImg8bit``.
    Both images are resized to 1024x512 (width x height). The annotation is
    converted to ground truth by the inherited ``color_to_gt`` and the image by
    the supplied preprocessor.

    The iterator resets itself after raising ``StopIteration``, so the same
    instance can be iterated more than once.
    """
    img_dir = ''
    segm_dir = ''
    segm_files = []
    colors = []
    i = 0

    def __init__(self, img_dir, segm_dir, preproc):
        """
        Args:
            img_dir: root directory of the input images (leftImg8bit split).
            segm_dir: root directory of the colour annotations (gtFine split).
            preproc: object with a ``process(img)`` method producing the network input.
        """
        self.img_dir = img_dir
        self.segm_dir = segm_dir
        # locate() yields absolute paths (it walks os.path.abspath(segm_dir)).
        self.segm_files = sorted(self.locate('*_color.png', segm_dir))
        self.colors = self.get_colors()
        self.data_prepoc = preproc
        self.i = 0

    @staticmethod
    def get_colors():
        """Return the class palette converted with ``DatasetImageFetch.pix_to_c``.

        The list index of each colour is its class index.
        """
        colors_list = (
            (0, 0, 0), (128, 64, 128), (244, 35, 232), (70, 70, 70), (102, 102, 156), (190, 153, 153), (153, 153, 153),
            (250, 170, 30), (220, 220, 0), (107, 142, 35), (152, 251, 152), (70, 130, 180), (220, 20, 60), (255, 0, 0),
            (0, 0, 142), (0, 0, 70), (0, 60, 100), (0, 80, 100), (0, 0, 230), (119, 11, 32))
        return [DatasetImageFetch.pix_to_c(c) for c in colors_list]

    def __iter__(self):
        return self

    @staticmethod
    def _read_rgb(path):
        """Read *path* with OpenCV and return it with channels reversed (BGR -> RGB).

        Raises:
            IOError: if OpenCV cannot read the file.
        """
        bgr = cv.imread(path, cv.IMREAD_COLOR)
        if bgr is None:
            # cv.imread reports failure by returning None instead of raising.
            raise IOError("Can't read image: " + path)
        return bgr[:, :, ::-1]

    def next(self):
        """Return the next ``(preprocessed image, ground truth)`` pair.

        Raises:
            StopIteration: after the last annotation (the position is reset first).
            IOError: if an annotation or its matching image cannot be read.
        """
        if self.i >= len(self.segm_files):
            self.i = 0
            raise StopIteration

        segm_file = self.segm_files[self.i]
        segm = self._read_rgb(segm_file)
        # Nearest-neighbour keeps label colours exact (no blending between classes).
        segm = cv.resize(segm, (1024, 512), interpolation=cv.INTER_NEAREST)

        # segm_file is absolute, so compute its location relative to the absolute
        # annotation root; slicing by len(self.segm_dir) breaks for relative roots
        # or differing trailing slashes.
        rel_path = os.path.relpath(segm_file, os.path.abspath(self.segm_dir))
        rel_path = self.rreplace(rel_path, 'gtFine_color', 'leftImg8bit')
        img_file = os.path.join(self.img_dir, rel_path)
        img = self._read_rgb(img_file)
        img = cv.resize(img, (1024, 512))

        self.i += 1
        gt = self.color_to_gt(segm, self.colors)
        img = self.data_prepoc.process(img)
        return img, gt

    # Python 3 iterator protocol; `next` alone is only honoured by Python 2.
    __next__ = next

    def get_num_classes(self):
        """Return the number of classes in the palette."""
        return len(self.colors)

    @staticmethod
    def locate(pattern, root_path):
        """Yield absolute paths of files under *root_path* whose names match *pattern* (fnmatch syntax)."""
        for path, dirs, files in os.walk(os.path.abspath(root_path)):
            for filename in fnmatch.filter(files, pattern):
                yield os.path.join(path, filename)

    @staticmethod
    def rreplace(s, old, new, occurrence=1):
        """Replace the last *occurrence* occurrences of *old* in *s* with *new*."""
        li = s.rsplit(old, occurrence)
        return new.join(li)


class TorchModel(Framework):
    """Reference framework: the Lua Torch model loaded through ``load_lua``."""
    net = object

    def __init__(self, model_file):
        self.net = load_lua(model_file)

    def get_name(self):
        return 'Torch'

    def get_output(self, input_blob):
        """Run *input_blob* through the Torch model and return the output as a numpy array."""
        tensor = torch.FloatTensor(input_blob)
        out = self.net.forward(tensor).numpy()
        return out


class DnnTorchModel(DnnCaffeModel):
    """The same Torch model executed by OpenCV's dnn module."""
    net = cv.dnn.Net()

    def __init__(self, model_file):
        self.net = cv.dnn.readNetFromTorch(model_file)

    def get_output(self, input_blob):
        """Run *input_blob* through the OpenCV network and return the last layer's output."""
        # Older dnn releases exposed Net.setBlob/getBlob; the current cv2 Net API
        # uses setInput + forward(). forward() without an output name returns the
        # output of the network's last layer, matching getLayerNames()[-1].
        self.net.setInput(input_blob)
        return self.net.forward()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--imgs_dir", required=True,
                        help="path to Cityscapes validation images dir, imgsfine/leftImg8bit/val")
    parser.add_argument("--segm_dir", required=True,
                        help="path to Cityscapes dir with segmentation, gtfine/gtFine/val")
    parser.add_argument("--model", required=True,
                        help="path to torch model, download it here: "
                             "https://www.dropbox.com/sh/dywzk3gyb12hpe5/AAD5YkUa8XgMpHs2gCRgmCVCa")
    parser.add_argument("--log", help="path to logging file")
    args = parser.parse_args()

    prep = NormalizePreproc()
    df = CityscapesDataFetch(args.imgs_dir, args.segm_dir, prep)

    fw = [TorchModel(args.model),
          DnnTorchModel(args.model)]

    segm_eval = SemSegmEvaluation(args.log)
    segm_eval.process(fw, df)