Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
Tetragramm
GitHub Repository: Tetragramm/opencv
Path: blob/master/modules/dnn/test/cityscapes_semsegm_test_enet.py
16347 views
1
import numpy as np
2
import sys
3
import os
4
import fnmatch
5
import argparse
6
7
try:
8
import cv2 as cv
9
except ImportError:
10
raise ImportError('Can\'t find OpenCV Python module. If you\'ve built it from sources without installation, '
11
'configure environment variable PYTHONPATH to "opencv_build_dir/lib" directory (with "python3" subdirectory if required)')
12
try:
13
import torch
14
except ImportError:
15
raise ImportError('Can\'t find pytorch. Please install it by following instructions on the official site')
16
17
from torch.utils.serialization import load_lua
18
from pascal_semsegm_test_fcn import eval_segm_result, get_conf_mat, get_metrics, DatasetImageFetch, SemSegmEvaluation
19
from imagenet_cls_test_alexnet import Framework, DnnCaffeModel
20
21
22
class NormalizePreproc:
    """Convert an HxWxC image into a 1xCxHxW float32 network input blob.

    Pixel values are divided by 255 (so 8-bit input lands in [0, 1]).
    """

    @staticmethod
    def process(img):
        # Copy into a channels-first float32 array.
        chw = np.array(img).transpose(2, 0, 1).astype(np.float32)
        # Prepend the batch axis (a view, no copy).
        blob = chw[np.newaxis, ...]
        # Scale in place, keeping float32.
        np.divide(blob, 255.0, out=blob)
        return blob
32
33
34
class CityscapesDataFetch(DatasetImageFetch):
    """Iterate over Cityscapes (input blob, ground truth) pairs.

    Ground-truth colour images are the ``*_color.png`` files found recursively
    under ``segm_dir``. For each one, the matching input image is expected
    under ``img_dir`` at the same relative path, with the last occurrence of
    ``gtFine_color`` in the path replaced by ``leftImg8bit``. Both images are
    resized to 1024x512 (nearest-neighbour for the labels so colours are not
    blended).

    Works as an iterator under both Python 2 (``next``) and Python 3
    (``__next__``). Exhausting it resets the position so it can be reused.
    """
    img_dir = ''
    segm_dir = ''
    segm_files = []
    colors = []
    i = 0

    def __init__(self, img_dir, segm_dir, preproc):
        """
        :param img_dir: root of the input images (e.g. leftImg8bit/val).
        :param segm_dir: root of the colour ground truth (e.g. gtFine/val).
        :param preproc: object whose ``process(img)`` turns an RGB image into
            the network input blob.
        """
        self.img_dir = img_dir
        self.segm_dir = segm_dir
        # locate() yields absolute paths in os.walk order; sort for determinism.
        self.segm_files = sorted(self.locate('*_color.png', segm_dir))
        self.colors = self.get_colors()
        self.data_prepoc = preproc
        self.i = 0

    @staticmethod
    def get_colors():
        """Return the class colours (RGB) packed into ints via ``pix_to_c``."""
        colors_list = (
            (0, 0, 0), (128, 64, 128), (244, 35, 232), (70, 70, 70), (102, 102, 156), (190, 153, 153), (153, 153, 153),
            (250, 170, 30), (220, 220, 0), (107, 142, 35), (152, 251, 152), (70, 130, 180), (220, 20, 60), (255, 0, 0),
            (0, 0, 142), (0, 0, 70), (0, 60, 100), (0, 80, 100), (0, 0, 230), (119, 11, 32))
        return [DatasetImageFetch.pix_to_c(c) for c in colors_list]

    def __iter__(self):
        return self

    def next(self):
        """Return the next ``(input_blob, ground_truth)`` pair.

        :raises StopIteration: when all files are consumed (position is reset).
        :raises IOError: if an image cannot be read.
        """
        if self.i >= len(self.segm_files):
            self.i = 0
            raise StopIteration

        segm_file = self.segm_files[self.i]
        segm = self._imread_rgb(segm_file)
        # Nearest-neighbour keeps label colours exact.
        segm = cv.resize(segm, (1024, 512), interpolation=cv.INTER_NEAREST)

        img = self._imread_rgb(self._image_path(segm_file))
        img = cv.resize(img, (1024, 512))

        self.i += 1
        gt = self.color_to_gt(segm, self.colors)
        img = self.data_prepoc.process(img)
        return img, gt

    # Python 3 iterator protocol; `for ... in` would otherwise fail because
    # __iter__ returns self, which only defined the Python 2 `next`.
    __next__ = next

    def get_num_classes(self):
        return len(self.colors)

    def _image_path(self, segm_file):
        """Map an (absolute) ground-truth path to its input-image path."""
        # segm_file is rooted at abspath(segm_dir) (see locate), so compute the
        # relative part against that rather than slicing by len(segm_dir),
        # which breaks for relative dirs or missing trailing separators.
        rel_path = os.path.relpath(segm_file, os.path.abspath(self.segm_dir))
        return self.rreplace(os.path.join(self.img_dir, rel_path), 'gtFine_color', 'leftImg8bit')

    @staticmethod
    def _imread_rgb(path):
        """Read a colour image and return it in RGB channel order."""
        bgr = cv.imread(path, cv.IMREAD_COLOR)
        if bgr is None:
            raise IOError("Can't read image: " + path)
        return bgr[:, :, ::-1]

    @staticmethod
    def locate(pattern, root_path):
        """Yield absolute paths of files under root_path matching pattern."""
        for path, dirs, files in os.walk(os.path.abspath(root_path)):
            for filename in fnmatch.filter(files, pattern):
                yield os.path.join(path, filename)

    @staticmethod
    def rreplace(s, old, new, occurrence=1):
        """Replace the last ``occurrence`` occurrences of ``old`` in ``s``."""
        li = s.rsplit(old, occurrence)
        return new.join(li)
96
97
98
class TorchModel(Framework):
    """Reference framework: a Lua Torch model loaded through ``load_lua``."""
    net = object

    def __init__(self, model_file):
        self.net = load_lua(model_file)

    def get_name(self):
        return 'Torch'

    def get_output(self, input_blob):
        """Run the model on ``input_blob`` and return the result as a numpy array."""
        result = self.net.forward(torch.FloatTensor(input_blob))
        return result.numpy()
111
112
113
class DnnTorchModel(DnnCaffeModel):
    """OpenCV DNN wrapper around a Torch model file."""
    net = cv.dnn.Net()

    def __init__(self, model_file):
        self.net = cv.dnn.readNetFromTorch(model_file)

    def get_output(self, input_blob):
        """Feed ``input_blob`` to the network and return the last layer's output."""
        if hasattr(self.net, 'setInput'):
            # Current cv.dnn API: forward() with no name returns the output
            # of the last layer.
            self.net.setInput(input_blob)
            return self.net.forward()
        # Fallback for older OpenCV bindings that only expose setBlob/getBlob.
        self.net.setBlob("", input_blob)
        self.net.forward()
        return self.net.getBlob(self.net.getLayerNames()[-1])
123
124
if __name__ == "__main__":
    # Compare the reference Torch model against OpenCV DNN on Cityscapes val.
    parser = argparse.ArgumentParser()
    # The three paths below are mandatory: without them the data fetcher and
    # model loaders receive None and fail with obscure errors.
    parser.add_argument("--imgs_dir", required=True,
                        help="path to Cityscapes validation images dir, imgsfine/leftImg8bit/val")
    parser.add_argument("--segm_dir", required=True,
                        help="path to Cityscapes dir with segmentation, gtfine/gtFine/val")
    parser.add_argument("--model", required=True,
                        help="path to torch model, download it here: "
                             "https://www.dropbox.com/sh/dywzk3gyb12hpe5/AAD5YkUa8XgMpHs2gCRgmCVCa")
    parser.add_argument("--log", help="path to logging file")
    args = parser.parse_args()

    prep = NormalizePreproc()
    df = CityscapesDataFetch(args.imgs_dir, args.segm_dir, prep)

    fw = [TorchModel(args.model),
          DnnTorchModel(args.model)]

    segm_eval = SemSegmEvaluation(args.log)
    segm_eval.process(fw, df)
141
142