Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
Tetragramm
GitHub Repository: Tetragramm/opencv
Path: blob/master/samples/dnn/tf_text_graph_mask_rcnn.py
16337 views
1
import argparse
import numpy as np
from tf_text_graph_common import *

# Command-line interface. The script converts a frozen Mask-RCNN .pb graph plus
# its Object Detection API training config into a text graph that OpenCV's
# cv::dnn::readNetFromTensorflow can consume alongside the .pb file.
parser = argparse.ArgumentParser(description='Run this script to get a text graph of '
                                             'Mask-RCNN model from TensorFlow Object Detection API. '
                                             'Then pass it with .pb file to cv::dnn::readNetFromTensorflow function.')
parser.add_argument('--input', required=True, help='Path to frozen TensorFlow graph.')
parser.add_argument('--output', required=True, help='Path to output text graph.')
# Fixed ungrammatical help text ("file is used for training" -> "file used for training").
parser.add_argument('--config', required=True, help='Path to a *.config file used for training.')
args = parser.parse_args()
13
# Node-name prefixes to keep in the output graph: backbone, RPN box/class
# predictors, RoI crop, second-stage layers, preprocessing arithmetic and the
# input placeholder. Used via str.startswith(tuple) in to_remove below.
scopesToKeep = ('FirstStageFeatureExtractor', 'Conv',
                'FirstStageBoxPredictor/BoxEncodingPredictor',
                'FirstStageBoxPredictor/ClassPredictor',
                'CropAndResize',
                'MaxPool2D',
                'SecondStageFeatureExtractor',
                'SecondStageBoxPredictor',
                'Preprocessor/sub',
                'Preprocessor/mul',
                'image_tensor')

# Prefixes to drop even though they fall under a kept scope: TensorFlow-side
# shape checks and assertions with no OpenCV counterpart.
scopesToIgnore = ('FirstStageFeatureExtractor/Assert',
                  'FirstStageFeatureExtractor/Shape',
                  'FirstStageFeatureExtractor/strided_slice',
                  'FirstStageFeatureExtractor/GreaterEqual',
                  'FirstStageFeatureExtractor/LogicalAnd')
29
30
# Load the training config and extract the values the converter needs.
cfg = readTextMessage(args.config)
cfg = cfg['model'][0]['faster_rcnn'][0]
num_classes = int(cfg['num_classes'][0])

anchor_cfg = cfg['first_stage_anchor_generator'][0]['grid_anchor_generator'][0]
scales = [float(value) for value in anchor_cfg['scales']]
aspect_ratios = [float(value) for value in anchor_cfg['aspect_ratios']]
width_stride = float(anchor_cfg['width_stride'][0])
height_stride = float(anchor_cfg['height_stride'][0])
features_stride = float(cfg['feature_extractor'][0]['first_stage_features_stride'][0])

# Echo the parsed parameters so the user can sanity-check the config.
print('Number of classes: {:d}'.format(num_classes))
print('Scales: {}'.format(str(scales)))
print('Aspect ratios: {}'.format(str(aspect_ratios)))
print('Width stride: {:f}'.format(width_stride))
print('Height stride: {:f}'.format(height_stride))
print('Features stride: {:f}'.format(features_stride))
48
49
# Read the graph: dump the frozen .pb as a text graph with the listed output
# nodes, then parse that text back into an editable representation.
outputNodes = ['num_detections', 'detection_scores', 'detection_boxes',
               'detection_classes', 'detection_masks']
writeTextGraph(args.input, args.output, outputNodes)
graph_def = parseTextGraph(args.output)

removeIdentity(graph_def)

def to_remove(name, op):
    """Predicate for removeUnusedNodesAndAttrs: drop ignored scopes and anything outside the kept scopes."""
    if name.startswith(scopesToIgnore):
        return True
    return not name.startswith(scopesToKeep)

removeUnusedNodesAndAttrs(to_remove, graph_def)
59
60
61
# Connect the input placeholder to the first real layer.
assert(graph_def.node[0].op == 'Placeholder')
graph_def.node[1].input.insert(0, graph_def.node[0].name)

# Temporarily strip nodes from the top of the graph, down to (and including)
# the second CropAndResize; they are re-attached after the RPN head is rebuilt.
topNodes = []
numCropAndResize = 0
while numCropAndResize < 2:
    poppedNode = graph_def.node.pop()
    topNodes.append(poppedNode)
    if poppedNode.op == 'CropAndResize':
        numCropAndResize += 1
75
76
# Rebuild the RPN objectness head: reshape logits to pairs, softmax them,
# then flatten for the DetectionOutput layer below.
clsBiasAdd = 'FirstStageBoxPredictor/ClassPredictor/BiasAdd'
clsReshape = 'FirstStageBoxPredictor/ClassPredictor/reshape_1'
clsSoftmax = 'FirstStageBoxPredictor/ClassPredictor/softmax'

addReshape(clsBiasAdd, clsReshape, [0, -1, 2], graph_def)

addSoftMax(clsReshape, clsSoftmax, graph_def)  # Compare with Reshape_4

addFlatten(clsSoftmax, clsSoftmax + '/flatten', graph_def)

# Compare with FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd
addFlatten('FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd',
           'FirstStageBoxPredictor/BoxEncodingPredictor/flatten', graph_def)
88
89
# RPN anchors expressed as a PriorBox layer.
# Compare with ClipToWindow/Gather/Gather (NOTE: normalized coordinates).
proposals = NodeDef()
proposals.name = 'proposals'
proposals.op = 'PriorBox'
proposals.input.append('FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd')
proposals.input.append(graph_def.node[0].name)  # image_tensor

for attrName, attrValue in [('flip', False),
                            ('clip', True),
                            ('step', features_stride),
                            ('offset', 0.0),
                            ('variance', [0.1, 0.1, 0.2, 0.2])]:
    proposals.addAttr(attrName, attrValue)

# Anchor sizes: base area features_stride**2 times scale, with sqrt(aspect)
# multiplied into the width and divided out of the height.
baseArea = features_stride ** 2
widths = [baseArea * s * np.sqrt(a) for a in aspect_ratios for s in scales]
heights = [baseArea * s / np.sqrt(a) for a in aspect_ratios for s in scales]

proposals.addAttr('width', widths)
proposals.addAttr('height', heights)

graph_def.node.extend([proposals])
113
114
# First-stage DetectionOutput: combines flattened box deltas, objectness
# scores and the PriorBox anchors. Compare with Reshape_5.
detectionOut = NodeDef()
detectionOut.name = 'detection_out'
detectionOut.op = 'DetectionOutput'

for inputName in ('FirstStageBoxPredictor/BoxEncodingPredictor/flatten',
                  'FirstStageBoxPredictor/ClassPredictor/softmax/flatten',
                  'proposals'):
    detectionOut.input.append(inputName)

for attrName, attrValue in [('num_classes', 2),
                            ('share_location', True),
                            ('background_label_id', 0),
                            ('nms_threshold', 0.7),
                            ('top_k', 6000),
                            ('code_type', "CENTER_SIZE"),
                            ('keep_top_k', 100),
                            ('clip', True)]:
    detectionOut.addAttr(attrName, attrValue)

graph_def.node.extend([detectionOut])
133
134
# Re-attach the temporarily removed nodes from the bottom up, stopping just
# before the last remaining CropAndResize (it is restored later, wired to the
# final detections). Nodes left in topNodes belong to the mask head.
while topNodes:
    candidate = topNodes[-1]
    if candidate.op == 'CropAndResize':
        if numCropAndResize == 1:
            break
        numCropAndResize -= 1
    graph_def.node.extend([topNodes.pop()])
146
147
# Second-stage classification head: softmax over class logits, slice off the
# background column, then reshape the scores to a single row.
secondStageScores = 'SecondStageBoxPredictor/Reshape_1'

addSoftMax(secondStageScores, secondStageScores + '/softmax', graph_def)

addSlice(secondStageScores + '/softmax',
         secondStageScores + '/slice',
         [0, 0, 1], [-1, -1, -1], graph_def)

addReshape(secondStageScores + '/slice',
           secondStageScores + '/Reshape', [1, -1], graph_def)
155
156
# Replace the Flatten subgraph with a single node, feed the first-stage
# detections into CropAndResize, and patch the second-stage Reshape target.
flattenHelperNames = ['SecondStageBoxPredictor/Flatten/flatten/Shape',
                      'SecondStageBoxPredictor/Flatten/flatten/strided_slice',
                      'SecondStageBoxPredictor/Flatten/flatten/Reshape/shape']

# Iterate indices in reverse so deletions do not shift unvisited entries.
for idx in reversed(range(len(graph_def.node))):
    currentNode = graph_def.node[idx]

    if currentNode.op == 'CropAndResize':
        currentNode.input.insert(1, 'detection_out')

    if currentNode.name == 'SecondStageBoxPredictor/Reshape':
        # Swap the dynamic shape input for a constant [1, -1, 4].
        addConstNode('SecondStageBoxPredictor/Reshape/shape2', [1, -1, 4], graph_def)
        currentNode.input.pop()
        currentNode.input.append('SecondStageBoxPredictor/Reshape/shape2')

    if currentNode.name in flattenHelperNames:
        del graph_def.node[idx]

for currentNode in graph_def.node:
    if currentNode.name == 'SecondStageBoxPredictor/Flatten/flatten/Reshape':
        # Collapse the remaining Reshape into a Flatten and drop its shape input.
        currentNode.op = 'Flatten'
        currentNode.input.pop()

    if currentNode.name in ['FirstStageBoxPredictor/BoxEncodingPredictor/Conv2D',
                            'SecondStageBoxPredictor/BoxEncodingPredictor/MatMul']:
        currentNode.addAttr('loc_pred_transposed', True)
180
181
################################################################################
### Postprocessing
################################################################################
# Slice the decoded box coordinates out of the first-stage detections.
addSlice('detection_out', 'detection_out/slice', [0, 0, 0, 3], [-1, -1, -1, 4], graph_def)

# Constant variance vector multiplied into the second-stage box deltas.
varianceNode = NodeDef()
varianceNode.name = 'proposals/variance'
varianceNode.op = 'Const'
varianceNode.addAttr('value', [0.1, 0.1, 0.2, 0.2])
graph_def.node.extend([varianceNode])

scaledDeltas = NodeDef()
scaledDeltas.name = 'variance_encoded'
scaledDeltas.op = 'Mul'
scaledDeltas.input.append('SecondStageBoxPredictor/Reshape')
scaledDeltas.input.append(varianceNode.name)
scaledDeltas.addAttr('axis', 2)
graph_def.node.extend([scaledDeltas])

# Flatten both operands to the layouts the final DetectionOutput expects.
addReshape('detection_out/slice', 'detection_out/slice/reshape', [1, 1, -1], graph_def)
addFlatten('variance_encoded', 'variance_encoded/flatten', graph_def)
202
203
# Final DetectionOutput: refines the first-stage boxes with the variance-scaled
# per-class deltas and the background-stripped class scores.
finalDetections = NodeDef()
finalDetections.name = 'detection_out_final'
finalDetections.op = 'DetectionOutput'

for inputName in ('variance_encoded/flatten',
                  'SecondStageBoxPredictor/Reshape_1/Reshape',
                  'detection_out/slice/reshape'):
    finalDetections.input.append(inputName)

for attrName, attrValue in [('num_classes', num_classes),
                            ('share_location', False),
                            ('background_label_id', num_classes + 1),
                            ('nms_threshold', 0.6),
                            ('code_type', "CENTER_SIZE"),
                            ('keep_top_k', 100),
                            ('clip', True),
                            ('variance_encoded_in_target', True),
                            ('confidence_threshold', 0.3),
                            ('group_by_classes', False)]:
    finalDetections.addAttr(attrName, attrValue)

graph_def.node.extend([finalDetections])
222
223
# Restore the remaining (mask head) nodes that were left in topNodes.
for maskNode in reversed(topNodes):
    graph_def.node.extend([maskNode])

# Wire the final detections into the mask head's CropAndResize.
for idx in reversed(range(len(graph_def.node))):
    if graph_def.node[idx].op == 'CropAndResize':
        graph_def.node[idx].input.insert(1, 'detection_out_final')
        break

# Turn the last node into the mask output: sigmoid, one input dropped.
graph_def.node[-1].name = 'detection_masks'
graph_def.node[-1].op = 'Sigmoid'
graph_def.node[-1].input.pop()

# Save as text.
graph_def.save(args.output)
237
238