Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
Tetragramm
GitHub Repository: Tetragramm/opencv
Path: blob/master/samples/dnn/tf_text_graph_faster_rcnn.py
16337 views
1
import argparse
2
import numpy as np
3
from tf_text_graph_common import *
4
5
6
def createFasterRCNNGraph(modelPath, configPath, outputPath):
    """Build an OpenCV-importable text graph for a Faster-RCNN model.

    Reads a frozen TensorFlow graph from ``modelPath`` and the training
    pipeline ``*.config`` file from ``configPath``, strips the TensorFlow
    Object Detection API post-processing subgraphs, inserts equivalent
    OpenCV DNN nodes (PriorBox, DetectionOutput, ClipByValue, ...), and
    saves the transformed graph as text to ``outputPath``.

    NOTE: ``outputPath`` is written twice -- first as a raw text dump of
    the frozen graph (writeTextGraph), then overwritten with the final
    transformed graph (graph_def.save).
    """
    # Node name prefixes to keep in the output graph.
    scopesToKeep = ('FirstStageFeatureExtractor', 'Conv',
                    'FirstStageBoxPredictor/BoxEncodingPredictor',
                    'FirstStageBoxPredictor/ClassPredictor',
                    'CropAndResize',
                    'MaxPool2D',
                    'SecondStageFeatureExtractor',
                    'SecondStageBoxPredictor',
                    'Preprocessor/sub',
                    'Preprocessor/mul',
                    'image_tensor')

    # Node name prefixes to drop even though they fall under a kept scope.
    scopesToIgnore = ('FirstStageFeatureExtractor/Assert',
                      'FirstStageFeatureExtractor/Shape',
                      'FirstStageFeatureExtractor/strided_slice',
                      'FirstStageFeatureExtractor/GreaterEqual',
                      'FirstStageFeatureExtractor/LogicalAnd')

    # Load a config file. readTextMessage parses the text protobuf into
    # nested dicts whose values are lists (hence the [0] indexing below).
    config = readTextMessage(configPath)
    config = config['model'][0]['faster_rcnn'][0]
    num_classes = int(config['num_classes'][0])

    # Anchor-grid parameters used to generate the first-stage priors.
    grid_anchor_generator = config['first_stage_anchor_generator'][0]['grid_anchor_generator'][0]
    scales = [float(s) for s in grid_anchor_generator['scales']]
    aspect_ratios = [float(ar) for ar in grid_anchor_generator['aspect_ratios']]
    width_stride = float(grid_anchor_generator['width_stride'][0])
    height_stride = float(grid_anchor_generator['height_stride'][0])
    features_stride = float(config['feature_extractor'][0]['first_stage_features_stride'][0])

    print('Number of classes: %d' % num_classes)
    print('Scales: %s' % str(scales))
    print('Aspect ratios: %s' % str(aspect_ratios))
    print('Width stride: %f' % width_stride)
    print('Height stride: %f' % height_stride)
    print('Features stride: %f' % features_stride)

    # Read the graph: dump the frozen graph as text, then parse it back
    # into an editable in-memory representation.
    writeTextGraph(modelPath, outputPath, ['num_detections', 'detection_scores', 'detection_boxes', 'detection_classes'])
    graph_def = parseTextGraph(outputPath)

    removeIdentity(graph_def)

    def to_remove(name, op):
        # A node is removed if it is in an ignored scope, or not under
        # any kept scope at all.
        return name.startswith(scopesToIgnore) or not name.startswith(scopesToKeep)

    removeUnusedNodesAndAttrs(to_remove, graph_def)


    # Connect input node to the first layer
    assert(graph_def.node[0].op == 'Placeholder')
    graph_def.node[1].input.insert(0, graph_def.node[0].name)

    # Temporarily remove top nodes: pop from the end of the node list up
    # to and including the CropAndResize node; they are re-appended after
    # the first-stage (RPN) detection nodes have been inserted below.
    topNodes = []
    while True:
        node = graph_def.node.pop()
        topNodes.append(node)
        if node.op == 'CropAndResize':
            break

    # First-stage class scores: reshape to (N, -1, 2) so softmax runs
    # over the object/background pair, then flatten for DetectionOutput.
    addReshape('FirstStageBoxPredictor/ClassPredictor/BiasAdd',
               'FirstStageBoxPredictor/ClassPredictor/reshape_1', [0, -1, 2], graph_def)

    addSoftMax('FirstStageBoxPredictor/ClassPredictor/reshape_1',
               'FirstStageBoxPredictor/ClassPredictor/softmax', graph_def)  # Compare with Reshape_4

    addFlatten('FirstStageBoxPredictor/ClassPredictor/softmax',
               'FirstStageBoxPredictor/ClassPredictor/softmax/flatten', graph_def)

    # Compare with FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd
    addFlatten('FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd',
               'FirstStageBoxPredictor/BoxEncodingPredictor/flatten', graph_def)

    # PriorBox node generating the anchor boxes for the RPN stage.
    proposals = NodeDef()
    proposals.name = 'proposals'  # Compare with ClipToWindow/Gather/Gather (NOTE: normalized)
    proposals.op = 'PriorBox'
    proposals.input.append('FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd')
    proposals.input.append(graph_def.node[0].name)  # image_tensor

    proposals.addAttr('flip', False)
    proposals.addAttr('clip', True)
    proposals.addAttr('step', features_stride)
    proposals.addAttr('offset', 0.0)
    proposals.addAttr('variance', [0.1, 0.1, 0.2, 0.2])

    # One prior box per (aspect_ratio, scale) pair, sized from the
    # anchor-grid strides.
    widths = []
    heights = []
    for a in aspect_ratios:
        for s in scales:
            ar = np.sqrt(a)
            heights.append((height_stride**2) * s / ar)
            widths.append((width_stride**2) * s * ar)

    proposals.addAttr('width', widths)
    proposals.addAttr('height', heights)

    graph_def.node.extend([proposals])

    # First-stage (RPN) DetectionOutput: 2 classes (object/background).
    # Compare with Reshape_5
    detectionOut = NodeDef()
    detectionOut.name = 'detection_out'
    detectionOut.op = 'DetectionOutput'

    detectionOut.input.append('FirstStageBoxPredictor/BoxEncodingPredictor/flatten')
    detectionOut.input.append('FirstStageBoxPredictor/ClassPredictor/softmax/flatten')
    detectionOut.input.append('proposals')

    detectionOut.addAttr('num_classes', 2)
    detectionOut.addAttr('share_location', True)
    detectionOut.addAttr('background_label_id', 0)
    detectionOut.addAttr('nms_threshold', 0.7)
    detectionOut.addAttr('top_k', 6000)
    detectionOut.addAttr('code_type', "CENTER_SIZE")
    detectionOut.addAttr('keep_top_k', 100)
    detectionOut.addAttr('clip', False)

    graph_def.node.extend([detectionOut])

    # Clamp the RPN detections to [0, 1] (normalized coordinates) before
    # they feed CropAndResize.
    addConstNode('clip_by_value/lower', [0.0], graph_def)
    addConstNode('clip_by_value/upper', [1.0], graph_def)

    clipByValueNode = NodeDef()
    clipByValueNode.name = 'detection_out/clip_by_value'
    clipByValueNode.op = 'ClipByValue'
    clipByValueNode.input.append('detection_out')
    clipByValueNode.input.append('clip_by_value/lower')
    clipByValueNode.input.append('clip_by_value/upper')
    graph_def.node.extend([clipByValueNode])

    # Restore the temporarily removed top nodes (popped in reverse order
    # above, so re-append them reversed to recover the original order).
    for node in reversed(topNodes):
        graph_def.node.extend([node])

    # Second-stage class scores: softmax, drop the background column
    # (slice from class index 1), reshape to a single row.
    addSoftMax('SecondStageBoxPredictor/Reshape_1', 'SecondStageBoxPredictor/Reshape_1/softmax', graph_def)

    addSlice('SecondStageBoxPredictor/Reshape_1/softmax',
             'SecondStageBoxPredictor/Reshape_1/slice',
             [0, 0, 1], [-1, -1, -1], graph_def)

    addReshape('SecondStageBoxPredictor/Reshape_1/slice',
               'SecondStageBoxPredictor/Reshape_1/Reshape', [1, -1], graph_def)

    # Rewire CropAndResize to the clipped RPN output, fix the
    # second-stage Reshape target shape, and replace the Flatten
    # subgraph with a single node. Iterated in reverse so `del` does not
    # shift yet-unvisited indices.
    for i in reversed(range(len(graph_def.node))):
        if graph_def.node[i].op == 'CropAndResize':
            graph_def.node[i].input.insert(1, 'detection_out/clip_by_value')

        if graph_def.node[i].name == 'SecondStageBoxPredictor/Reshape':
            addConstNode('SecondStageBoxPredictor/Reshape/shape2', [1, -1, 4], graph_def)

            graph_def.node[i].input.pop()
            graph_def.node[i].input.append('SecondStageBoxPredictor/Reshape/shape2')

        if graph_def.node[i].name in ['SecondStageBoxPredictor/Flatten/flatten/Shape',
                                      'SecondStageBoxPredictor/Flatten/flatten/strided_slice',
                                      'SecondStageBoxPredictor/Flatten/flatten/Reshape/shape']:
            del graph_def.node[i]

    for node in graph_def.node:
        if node.name == 'SecondStageBoxPredictor/Flatten/flatten/Reshape':
            node.op = 'Flatten'
            node.input.pop()

        # Mark box-encoding predictors whose outputs are transposed
        # relative to what OpenCV's DetectionOutput expects.
        if node.name in ['FirstStageBoxPredictor/BoxEncodingPredictor/Conv2D',
                         'SecondStageBoxPredictor/BoxEncodingPredictor/MatMul']:
            node.addAttr('loc_pred_transposed', True)

    ################################################################################
    ### Postprocessing
    ################################################################################
    # Take the box coordinates (last 4 of 7 values per detection) from
    # the clipped RPN output.
    addSlice('detection_out/clip_by_value', 'detection_out/slice', [0, 0, 0, 3], [-1, -1, -1, 4], graph_def)

    variance = NodeDef()
    variance.name = 'proposals/variance'
    variance.op = 'Const'
    variance.addAttr('value', [0.1, 0.1, 0.2, 0.2])
    graph_def.node.extend([variance])

    # Pre-multiply second-stage box deltas by the variances so the final
    # DetectionOutput can use variance_encoded_in_target=True.
    varianceEncoder = NodeDef()
    varianceEncoder.name = 'variance_encoded'
    varianceEncoder.op = 'Mul'
    varianceEncoder.input.append('SecondStageBoxPredictor/Reshape')
    varianceEncoder.input.append(variance.name)
    varianceEncoder.addAttr('axis', 2)
    graph_def.node.extend([varianceEncoder])

    addReshape('detection_out/slice', 'detection_out/slice/reshape', [1, 1, -1], graph_def)
    addFlatten('variance_encoded', 'variance_encoded/flatten', graph_def)

    # Final DetectionOutput over the second-stage predictions, using the
    # RPN boxes as priors.
    detectionOut = NodeDef()
    detectionOut.name = 'detection_out_final'
    detectionOut.op = 'DetectionOutput'

    detectionOut.input.append('variance_encoded/flatten')
    detectionOut.input.append('SecondStageBoxPredictor/Reshape_1/Reshape')
    detectionOut.input.append('detection_out/slice/reshape')

    detectionOut.addAttr('num_classes', num_classes)
    detectionOut.addAttr('share_location', False)
    # Background was sliced out above, so point the background id past
    # every real class label.
    detectionOut.addAttr('background_label_id', num_classes + 1)
    detectionOut.addAttr('nms_threshold', 0.6)
    detectionOut.addAttr('code_type', "CENTER_SIZE")
    detectionOut.addAttr('keep_top_k', 100)
    detectionOut.addAttr('clip', True)
    detectionOut.addAttr('variance_encoded_in_target', True)
    graph_def.node.extend([detectionOut])

    # Save as text.
    graph_def.save(outputPath)
216
217
218
if __name__ == "__main__":
    # Command-line entry point: convert a frozen .pb graph plus its
    # training config into a text graph for cv::dnn::readNetFromTensorflow.
    description = ('Run this script to get a text graph of '
                   'Faster-RCNN model from TensorFlow Object Detection API. '
                   'Then pass it with .pb file to cv::dnn::readNetFromTensorflow function.')
    arg_parser = argparse.ArgumentParser(description=description)
    for flag, flag_help in (('--input', 'Path to frozen TensorFlow graph.'),
                            ('--output', 'Path to output text graph.'),
                            ('--config', 'Path to a *.config file is used for training.')):
        arg_parser.add_argument(flag, required=True, help=flag_help)
    parsed = arg_parser.parse_args()

    createFasterRCNNGraph(parsed.input, parsed.config, parsed.output)