Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
Tetragramm
GitHub Repository: Tetragramm/opencv
Path: blob/master/modules/dnn/test/imagenet_cls_test_alexnet.py
16354 views
1
from __future__ import print_function
2
from abc import ABCMeta, abstractmethod
3
import numpy as np
4
import sys
5
import os
6
import argparse
7
import time
8
9
try:
10
import caffe
11
except ImportError:
12
raise ImportError('Can\'t find Caffe Python module. If you\'ve built it from sources without installation, '
13
'configure environment variable PYTHONPATH to "git/caffe/python" directory')
14
try:
15
import cv2 as cv
16
except ImportError:
17
raise ImportError('Can\'t find OpenCV Python module. If you\'ve built it from sources without installation, '
18
'configure environment variable PYTHONPATH to "opencv_build_dir/lib" directory (with "python3" subdirectory if required)')
19
20
# Python 2/3 compatibility: code below iterates with ``xrange``.  Python 2
# provides it as a builtin; Python 3 removed it, so alias it to ``range``.
if sys.version_info[0] >= 3:
    xrange = range
24
25
26
class DataFetch(ABCMeta('DataFetchBase', (object,), {})):
    """Abstract loader that turns image file names into a network input batch.

    Subclasses configure ``imgs_dir``, ``frame_size`` and optionally
    ``bgr_to_rgb`` and implement :meth:`preprocess`.  The base class is built
    with ``ABCMeta`` directly (instead of the Python 2-only ``__metaclass__``
    attribute) so ``@abstractmethod`` is enforced on both Python 2 and 3.
    """
    imgs_dir = ''
    frame_size = 0
    bgr_to_rgb = False

    @abstractmethod
    def preprocess(self, img):
        """Return the preprocessed form of a (3, frame_size, frame_size) image."""
        pass

    @staticmethod
    def _center_crop(img, size):
        """Return the central ``size`` x ``size`` window of an (H, W, ...) image."""
        # Floor division keeps the offsets integral; with true division they
        # become floats on Python 3 and slicing raises TypeError.
        y1 = (img.shape[0] - size) // 2
        x1 = (img.shape[1] - size) // 2
        return img[y1:y1 + size, x1:x1 + size]

    def get_batch(self, imgs_names):
        """Load, resize, centre-crop and preprocess a batch of images.

        Args:
            imgs_names: list (or tuple) of image file names relative to
                ``imgs_dir``.

        Returns:
            float32 array of shape (len(imgs_names), 3, frame_size, frame_size).

        Raises:
            TypeError: if ``imgs_names`` is not a list or tuple.
            IOError: if an image file does not exist or cannot be decoded.
        """
        if not isinstance(imgs_names, (list, tuple)):
            raise TypeError('imgs_names must be a list, got %s' % type(imgs_names).__name__)
        batch = np.zeros((len(imgs_names), 3, self.frame_size, self.frame_size), dtype=np.float32)
        for i, img_name in enumerate(imgs_names):
            img_file = os.path.join(self.imgs_dir, img_name)
            if not os.path.exists(img_file):
                raise IOError('Image file not found: %s' % img_file)
            img = cv.imread(img_file, cv.IMREAD_COLOR)
            if img is None:  # imread signals decode failure by returning None
                raise IOError('Failed to decode image: %s' % img_file)
            # Scale so the shorter side equals frame_size, then take the centre.
            resize_ratio = self.frame_size / float(min(img.shape[0], img.shape[1]))
            img = cv.resize(img, (0, 0), fx=resize_ratio, fy=resize_ratio)
            img = self._center_crop(img, self.frame_size)
            if self.bgr_to_rgb:
                img = img[..., ::-1]
            # HWC -> CHW to match the batch's (N, 3, H, W) layout.
            image_data = img[:, :, 0:3].transpose(2, 0, 1)
            batch[i] = self.preprocess(image_data)
        return batch
59
60
61
class MeanBlobFetch(DataFetch):
    """Data fetcher that subtracts a Caffe mean image from every sample.

    The mean is read from a ``.binaryproto`` file and centre-cropped to
    ``frame_size`` x ``frame_size`` so it can be subtracted element-wise from
    each (3, frame_size, frame_size) image.
    """
    mean_blob = np.ndarray(())

    def __init__(self, frame_size, mean_blob_path, imgs_dir):
        """
        Args:
            frame_size: side length of the square network input.
            mean_blob_path: path to a Caffe BlobProto file with the mean image.
            imgs_dir: directory containing the images to load.
        """
        self.imgs_dir = imgs_dir
        self.frame_size = frame_size
        blob = caffe.proto.caffe_pb2.BlobProto()
        # Use a context manager so the file handle is closed deterministically.
        with open(mean_blob_path, 'rb') as blob_file:
            blob.ParseFromString(blob_file.read())
        mean = np.array(caffe.io.blobproto_to_array(blob))
        # The blob is indexed as 4-D (N, C, H, W); only the first item is kept.
        # Floor division keeps the offsets integral on Python 3, and height and
        # width are centred independently so non-square means crop correctly.
        y1 = (mean.shape[2] - self.frame_size) // 2
        x1 = (mean.shape[3] - self.frame_size) // 2
        self.mean_blob = mean[:, :, y1:y1 + self.frame_size, x1:x1 + self.frame_size][0]

    def preprocess(self, img):
        """Return ``img`` with the cropped mean image subtracted."""
        return img - self.mean_blob
77
78
79
class MeanChannelsFetch(MeanBlobFetch):
    """Fetcher that subtracts fixed per-channel means 104, 117 and 123.

    No mean file is read: the parent ``__init__`` is deliberately not called
    and a constant (3, frame_size, frame_size) float32 mean image is built in
    memory instead.  Channel order follows the image as loaded (``bgr_to_rgb``
    keeps its class default of False here).
    """
    def __init__(self, frame_size, imgs_dir):
        self.imgs_dir = imgs_dir
        self.frame_size = frame_size
        channel_means = np.array([104, 117, 123], dtype=np.float32).reshape(3, 1, 1)
        # Broadcast each channel's scalar over a full frame.
        self.mean_blob = np.ones((3, frame_size, frame_size), dtype=np.float32) * channel_means
87
88
89
class MeanValueFetch(MeanBlobFetch):
    """Fetcher that subtracts one scalar mean (117) from all three channels.

    Like ``MeanChannelsFetch`` it skips the parent's mean-file loading and
    builds a constant float32 mean image; it additionally lets the caller
    choose whether images are flipped from BGR to RGB before preprocessing.
    """
    def __init__(self, frame_size, imgs_dir, bgr_to_rgb):
        self.imgs_dir, self.frame_size = imgs_dir, frame_size
        self.mean_blob = np.full((3, frame_size, frame_size), 117, dtype=np.float32)
        self.bgr_to_rgb = bgr_to_rgb
96
97
98
def get_correct_answers(img_list, img_classes, net_output_blob):
    """Count images whose ground-truth class is among their five top scores.

    Args:
        img_list: image names; ``img_list[i]`` corresponds to row ``i`` of
            ``net_output_blob``.
        img_classes: mapping from image name to ground-truth class index.
        net_output_blob: per-image class scores, indexable by row.

    Returns:
        Number of top-5 hits in this batch.
    """
    # Indices of the five largest scores for each row (argsort is ascending).
    top5_rows = (np.argsort(net_output_blob[row])[-5:] for row in range(len(img_list)))
    return sum(1 for name, top5 in zip(img_list, top5_rows) if img_classes[name] in top5)
106
107
108
class Framework(ABCMeta('FrameworkBase', (object,), {})):
    """Abstract wrapper around one inference backend being compared.

    Subclasses record the network's input/output blob names and implement
    :meth:`get_name` and :meth:`get_output`.  The base is created with
    ``ABCMeta`` directly because the ``__metaclass__`` attribute is ignored on
    Python 3, which left the abstract methods unenforced there.
    """
    in_blob_name = ''
    out_blob_name = ''

    @abstractmethod
    def get_name(self):
        """Return the backend name used in log lines."""
        pass

    @abstractmethod
    def get_output(self, input_blob):
        """Run a forward pass on ``input_blob`` and return the output blob."""
        pass
121
122
123
class CaffeModel(Framework):
    """Reference backend: runs the network through pycaffe in CPU mode."""
    # Class-level placeholder; every instance replaces it with a loaded net.
    net = caffe.Net
    need_reshape = False

    def __init__(self, prototxt, caffemodel, in_blob_name, out_blob_name, need_reshape=False):
        """Load the model in TEST phase and remember the blob names.

        When ``need_reshape`` is true, :meth:`get_output` resizes the input
        blob to the incoming batch's shape before each forward pass.
        """
        caffe.set_mode_cpu()
        self.net = caffe.Net(prototxt, caffemodel, caffe.TEST)
        self.in_blob_name, self.out_blob_name = in_blob_name, out_blob_name
        self.need_reshape = need_reshape

    def get_name(self):
        return 'Caffe'

    def get_output(self, input_blob):
        """Forward ``input_blob`` through the net and return the output blob."""
        input_name = self.in_blob_name
        if self.need_reshape:
            self.net.blobs[input_name].reshape(*input_blob.shape)
        results = self.net.forward_all(**{input_name: input_blob})
        return results[self.out_blob_name]
141
142
143
class DnnCaffeModel(Framework):
    """Backend under test: the same Caffe model executed by OpenCV's dnn module."""
    # Class-level placeholder; every instance replaces it with a loaded net.
    net = object

    def __init__(self, prototxt, caffemodel, in_blob_name, out_blob_name):
        """Load the Caffe model with ``cv.dnn.readNetFromCaffe``."""
        loaded_net = cv.dnn.readNetFromCaffe(prototxt, caffemodel)
        self.net = loaded_net
        self.in_blob_name, self.out_blob_name = in_blob_name, out_blob_name

    def get_name(self):
        return 'DNN'

    def get_output(self, input_blob):
        """Bind ``input_blob`` to the input and return the requested output."""
        dnn_net = self.net
        dnn_net.setInput(input_blob, self.in_blob_name)
        return dnn_net.forward(self.out_blob_name)
157
158
159
class ClsAccEvaluation(object):
    """Compare top-5 accuracy, speed and output drift of several frameworks.

    ``frameworks[0]`` passed to :meth:`process` is the reference; the outputs
    of every other framework are compared to it with mean-absolute (L1) and
    maximum-absolute (L-inf) differences.
    """
    log = sys.stdout
    img_classes = {}
    batch_size = 0

    def __init__(self, log_path, img_classes_file, batch_size):
        """
        Args:
            log_path: file the log is written to; ``None`` logs to stdout.
            img_classes_file: text file of ``<image name> <class id>`` lines.
            batch_size: images per batch; coerced to ``int`` so string values
                (e.g. from the command line) also work.

        Raises:
            ValueError: if ``batch_size`` is not a positive integer.
        """
        batch_size = int(batch_size)
        # Validate before opening the log so a bad value does not leak a handle.
        if batch_size <= 0:
            raise ValueError('batch_size must be positive, got %d' % batch_size)
        self.log = sys.stdout if log_path is None else open(log_path, 'w')
        self.img_classes = self.read_classes(img_classes_file)
        self.batch_size = batch_size

    @staticmethod
    def read_classes(img_classes_file):
        """Parse ``<image name> <class id>`` lines into a dict; blank lines are skipped."""
        result = {}
        with open(img_classes_file) as classes_file:
            for line in classes_file:
                fields = line.split()
                if not fields:
                    continue
                result[fields[0]] = int(fields[1])
        return result

    @staticmethod
    def _pair_label(frameworks, i):
        """Return the 'reference vs other:' label used in difference log lines."""
        return frameworks[0].get_name() + " vs " + frameworks[i].get_name() + ':'

    def process(self, frameworks, data_fetcher):
        """Run every framework over all images batch by batch and log metrics.

        Args:
            frameworks: Framework objects; ``frameworks[0]`` is the reference.
            data_fetcher: DataFetch object that builds the input batches.
        """
        sorted_imgs_names = sorted(self.img_classes.keys())
        fw_count = len(frameworks)
        correct_answers = [0] * fw_count
        blobs_l1_diff = [0.0] * fw_count
        blobs_l1_diff_count = [0] * fw_count
        # Absolute differences are >= 0, so 0.0 is the correct starting maximum;
        # sys.float_info.min is a tiny positive number and would be reported in
        # place of an exact zero difference.
        blobs_l_inf_diff = [0.0] * fw_count
        inference_time = [0.0] * fw_count
        samples_handled = 0

        for x in xrange(0, len(sorted_imgs_names), self.batch_size):
            sublist = sorted_imgs_names[x:x + self.batch_size]
            batch = data_fetcher.get_batch(sublist)
            samples_handled += len(sublist)

            frameworks_out = []
            for i, framework in enumerate(frameworks):
                start = time.time()
                out = framework.get_output(batch)
                end = time.time()
                correct_answers[i] += get_correct_answers(sublist, self.img_classes, out)
                accuracy = 100 * correct_answers[i] / float(samples_handled)
                frameworks_out.append(out)
                inference_time[i] += end - start
                print(samples_handled, 'Accuracy for', framework.get_name() + ':', accuracy, file=self.log)
                print("Inference time, ms ",
                      framework.get_name(), inference_time[i] / samples_handled * 1000, file=self.log)

            for i in range(1, fw_count):
                log_str = self._pair_label(frameworks, i)
                diff = np.abs(frameworks_out[0] - frameworks_out[i])
                l1_diff = np.sum(diff) / diff.size
                print(samples_handled, "L1 difference", log_str, l1_diff, file=self.log)
                blobs_l1_diff[i] += l1_diff
                blobs_l1_diff_count[i] += 1
                blobs_l_inf_diff[i] = max(blobs_l_inf_diff[i], np.max(diff))
                print(samples_handled, "L_INF difference", log_str, blobs_l_inf_diff[i], file=self.log)

            self.log.flush()

        for i in range(1, fw_count):
            log_str = self._pair_label(frameworks, i)
            # Guard against division by zero when no batch was processed.
            if blobs_l1_diff_count[i]:
                final_l1 = blobs_l1_diff[i] / blobs_l1_diff_count[i]
            else:
                final_l1 = float('nan')
            print('Final l1 diff', log_str, final_l1, file=self.log)
        self.log.flush()
222
223
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--imgs_dir", help="path to ImageNet validation subset images dir, ILSVRC2012_img_val dir")
    parser.add_argument("--img_cls_file", help="path to file with classes ids for images, val.txt file from this "
                                               "archive: http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz")
    parser.add_argument("--prototxt", help="path to caffe prototxt, download it here: "
                                           "https://github.com/BVLC/caffe/blob/master/models/bvlc_alexnet/deploy.prototxt")
    parser.add_argument("--caffemodel", help="path to caffemodel file, download it here: "
                                             "http://dl.caffe.berkeleyvision.org/bvlc_alexnet.caffemodel")
    parser.add_argument("--log", help="path to logging file")
    # Adjacent literals are joined verbatim, so the space after "from" matters.
    parser.add_argument("--mean", help="path to ImageNet mean blob caffe file, imagenet_mean.binaryproto file from "
                                       "this archive: http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz")
    # type=int: without it, values given on the command line arrive as strings
    # and break xrange()/np.zeros(); only the defaults were ints before.
    parser.add_argument("--batch_size", help="size of images in batch", type=int, default=1000)
    parser.add_argument("--frame_size", help="size of input image", type=int, default=227)
    parser.add_argument("--in_blob", help="name for input blob", default='data')
    parser.add_argument("--out_blob", help="name for output blob", default='prob')
    args = parser.parse_args()

    data_fetcher = MeanBlobFetch(args.frame_size, args.mean, args.imgs_dir)

    # Caffe comes first and is therefore the reference in ClsAccEvaluation.process;
    # OpenCV DNN receives an empty input blob name for setInput.
    frameworks = [CaffeModel(args.prototxt, args.caffemodel, args.in_blob, args.out_blob),
                  DnnCaffeModel(args.prototxt, args.caffemodel, '', args.out_blob)]

    acc_eval = ClsAccEvaluation(args.log, args.img_cls_file, args.batch_size)
    acc_eval.process(frameworks, data_fetcher)
248
249