Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
Tetragramm
GitHub Repository: Tetragramm/opencv
Path: blob/master/samples/dnn/tf_text_graph_common.py
16337 views
1
def tokenize(s):
    """Split protobuf text-format content into a flat list of string tokens.

    Rules:
      * whitespace, ':', ';' and ',' separate tokens and are dropped;
      * each of '{', '}', '[' and ']' is emitted as a token of its own;
      * text between a quote and the *matching* quote character becomes one
        token with the quotes stripped (an empty string yields '');
      * inside a string, separators, brackets and '#' are ordinary
        characters, and a backslash escape (e.g. an escaped quote) is kept
        verbatim instead of ending the string, so the token can be written
        back unchanged;
      * '#' outside a string starts a comment that runs to the end of line.

    :param s: text to tokenize.
    :return: list of token strings.
    """
    tokens = []
    token = ""
    quote = None      # quote character that opened the current string, or None
    escaped = False   # previous character inside the string was a backslash
    isComment = False

    for symbol in s:
        if quote is not None:
            if escaped:
                escaped = False
                token += symbol
            elif symbol == '\\':
                escaped = True
                token += symbol
            elif symbol == quote:
                tokens.append(token)
                token = ""
                quote = None
            else:
                token += symbol
            continue

        if isComment:
            if symbol != '\n':
                continue
            # The newline ends the comment and is then handled as a separator.
            isComment = False
        elif symbol == '#':
            isComment = True
            continue

        if symbol in ' \t\r\n:;,':
            if token:
                tokens.append(token)
                token = ""
        elif symbol in '"\'':
            if token:
                tokens.append(token)
                token = ""
            quote = symbol
        elif symbol in '{}[]':
            if token:
                tokens.append(token)
                token = ""
            tokens.append(symbol)
        else:
            token += symbol

    if token:
        tokens.append(token)
    return tokens
36
37
38
def parseMessage(tokens, idx):
    """Parse one '{ ... }' message from a token list produced by tokenize().

    Every field maps to a list of values, so repeated fields accumulate.
    Nested messages are parsed recursively into dicts. With the array syntax
    'field [a, b]', each element is appended to the field's list.

    :param tokens: list of token strings.
    :param idx: index of the opening '{' of the message.
    :return: (msg, idx), where idx points at the message's closing '}', or
             None if the tokens run out before the message (or any message
             nested inside it) is closed.
    """
    msg = {}
    assert(tokens[idx] == '{')

    isArray = False
    while True:
        # Inside an array the field name is reused for every element.
        if not isArray:
            idx += 1
            if idx >= len(tokens):
                return None
            fieldName = tokens[idx]
        if fieldName == '}':
            break

        idx += 1
        if idx >= len(tokens):
            return None
        fieldValue = tokens[idx]

        if fieldValue == '{':
            parsed = parseMessage(tokens, idx)
            if parsed is None:
                # Propagate the "truncated input" sentinel instead of
                # failing on tuple unpacking.
                return None
            embeddedMsg, idx = parsed
            msg.setdefault(fieldName, []).append(embeddedMsg)
        elif fieldValue == '[':
            isArray = True
        elif fieldValue == ']':
            isArray = False
        else:
            msg.setdefault(fieldName, []).append(fieldValue)
    return msg, idx
72
73
74
def readTextMessage(filePath):
    """Read a protobuf text-format file and return it as nested dicts.

    The whole file content is wrapped in braces so that it parses as a
    single message.

    :param filePath: path to the text file; a falsy value yields {}.
    :return: the parsed message dict, or {} if parsing returns no result.
    """
    if not filePath:
        return {}

    with open(filePath, 'rt') as f:
        wrapped = '{' + f.read() + '}'

    parsed = parseMessage(tokenize(wrapped), 0)
    if not parsed:
        return {}
    return parsed[0]
83
84
85
def listToTensor(values):
    """Wrap a flat list of numbers as a 1-D tensor attribute value.

    An all-float list becomes DT_FLOAT / 'float_val' and an all-int list
    becomes DT_INT32 / 'int_val'. An empty list is classified as DT_FLOAT,
    because all() over nothing is True. The list itself (not a copy) is
    stored under the value field.

    Booleans are rejected. bool is a subclass of int, so they would
    otherwise be filed under 'int_val' and later written as true/false,
    which are not int32 values.

    :param values: list of numbers, all of one type.
    :return: {'tensor': {'dtype', 'tensor_shape', <value field>}}.
    :raises Exception: if values mixes types or holds non-numeric items.
    """
    if all(isinstance(v, float) for v in values):
        dtype, field = 'DT_FLOAT', 'float_val'
    elif all(isinstance(v, int) and not isinstance(v, bool) for v in values):
        dtype, field = 'DT_INT32', 'int_val'
    else:
        raise Exception('Wrong values types')

    return {
        'tensor': {
            'dtype': dtype,
            'tensor_shape': {
                'dim': {
                    'size': len(values)
                }
            },
            field: values,
        }
    }
107
108
109
def addConstNode(name, values, graph_def):
    """Append a 'Const' node called `name` whose 'value' attribute encodes `values`.

    :param name: name of the new node.
    :param values: attribute payload, encoded by NodeDef.addAttr.
    :param graph_def: graph whose `node` sequence receives the new node.
    """
    const = NodeDef()
    const.op = 'Const'
    const.name = name
    const.addAttr('value', values)
    graph_def.node.extend([const])
115
116
117
def addSlice(inp, out, begins, sizes, graph_def):
    """Append a 'Slice' node named `out` that slices `inp`.

    First, two Const nodes are appended: '<out>/begins' (holding `begins`)
    and '<out>/sizes' (holding `sizes`). They become the slice's second
    and third inputs, after `inp`.
    """
    beginsName = out + '/begins'
    sizesName = out + '/sizes'
    addConstNode(beginsName, begins, graph_def)
    addConstNode(sizesName, sizes, graph_def)

    sliced = NodeDef()
    sliced.name = out
    sliced.op = 'Slice'
    sliced.input.extend([inp, beginsName, sizesName])
    graph_def.node.extend([sliced])
137
138
139
def addReshape(inp, out, shape, graph_def):
    """Append a 'Reshape' node named `out` that reshapes `inp` to `shape`.

    The target shape is stored in a Const node named '<out>/shape'. That node
    is appended first and wired in as the Reshape node's second input.
    """
    shapeName = out + '/shape'
    addConstNode(shapeName, shape, graph_def)

    reshape = NodeDef()
    reshape.name = out
    reshape.op = 'Reshape'
    reshape.input.extend([inp, shapeName])
    graph_def.node.extend([reshape])
152
153
154
def addSoftMax(inp, out, graph_def):
    """Append a 'Softmax' node named `out`, with attribute axis=-1, fed by `inp`."""
    node = NodeDef()
    node.name, node.op = out, 'Softmax'
    node.input.append(inp)
    node.addAttr('axis', -1)
    graph_def.node.extend([node])
161
162
163
def addFlatten(inp, out, graph_def):
    """Append a 'Flatten' node named `out` whose only input is `inp`."""
    flat = NodeDef()
    flat.op = 'Flatten'
    flat.name = out
    flat.input.append(inp)
    graph_def.node.extend([flat])
169
170
171
class NodeDef:
    """Lightweight, dict-backed stand-in for a graph node.

    Fields:
      name  -- node name (str)
      op    -- operation type (str)
      input -- list of input node names
      attr  -- dict mapping attribute name -> value dict ({'i': ...}, {'s': ...}, ...)
    """

    def __init__(self):
        self.Clear()

    def addAttr(self, key, value):
        """Store attribute `key`, encoding `value` according to its Python type.

        bool -> {'b': v}, int -> {'i': v}, float -> {'f': v}, str -> {'s': v},
        list -> tensor value built by listToTensor(). bool is tested before
        int because bool is a subclass of int.

        :raises AssertionError: if `key` is already present.
        :raises Exception: for any other value type.
        """
        assert key not in self.attr
        if isinstance(value, list):
            self.attr[key] = listToTensor(value)
            return
        for pyType, field in ((bool, 'b'), (int, 'i'), (float, 'f'), (str, 's')):
            if isinstance(value, pyType):
                self.attr[key] = {field: value}
                return
        raise Exception('Unknown type of attribute ' + key)

    def Clear(self):
        """Reset every field to its empty default (the state of a fresh node)."""
        self.input = []
        self.name = ""
        self.op = ""
        self.attr = {}
198
199
200
class GraphDef:
    """Minimal graph container: `node` is an ordered list of nodes.

    Each node is expected to expose `name`, `op`, `input` (list of str) and
    `attr` (dict of attribute name -> value dict), as NodeDef does.
    """

    def __init__(self):
        self.node = []

    @staticmethod
    def _formatScalar(v):
        """Render a non-dict attribute value as protobuf text.

        Graphs parsed from text keep every scalar as a str, so the original
        quoting is lost and heuristics are used instead:
          * bool values and the strings 'true'/'false' are written bare;
          * strings starting with 'DT_', or accepted by float(), are written bare;
          * any other string is wrapped in double quotes;
          * non-str values are written with str().
        """
        if isinstance(v, bool):
            return 'true' if v else 'false'
        if v == 'true' or v == 'false':
            return v
        if isinstance(v, str) and not v.startswith('DT_'):
            try:
                float(v)
            except ValueError:
                return '"%s"' % v
        return str(v)

    def save(self, filePath):
        """Write the graph to `filePath` in text (pbtxt) format.

        Node attributes, and the fields of nested messages, are written in
        case-insensitive key order. A list value is written as a repeated
        field, and a dict value as a nested message block. Each nesting
        level indents by two spaces.
        """
        with open(filePath, 'wt') as f:

            def printAttr(d, indent):
                pad = ' ' * indent
                for key, value in sorted(d.items(), key=lambda x: x[0].lower()):
                    values = value if isinstance(value, list) else [value]
                    for v in values:
                        if isinstance(v, dict):
                            f.write(pad + key + ' {\n')
                            printAttr(v, indent + 2)
                            f.write(pad + '}\n')
                        else:
                            f.write(pad + key + ': ' + self._formatScalar(v) + '\n')

            for node in self.node:
                f.write('node {\n')
                f.write('  name: "%s"\n' % node.name)
                f.write('  op: "%s"\n' % node.op)
                for inp in node.input:
                    f.write('  input: "%s"\n' % inp)
                for key, value in sorted(node.attr.items(), key=lambda x: x[0].lower()):
                    f.write('  attr {\n')
                    f.write('    key: "%s"\n' % key)
                    f.write('    value {\n')
                    printAttr(value, 6)
                    f.write('    }\n')
                    f.write('  }\n')
                f.write('}\n')
248
249
250
def parseTextGraph(filePath):
    """Load a text-format graph file into a GraphDef of NodeDef objects.

    Each top-level 'node' message becomes one NodeDef, built as follows:
      * name and op: the first 'name' and 'op' values;
      * input: all 'input' values, in order;
      * attr: one entry per 'attr' message, attr[key] = first 'value' dict.

    A message with no 'node' entries yields an empty GraphDef instead of
    raising KeyError. readTextMessage returns {} for an empty filePath, so
    this covers that case too.
    """
    msg = readTextMessage(filePath)

    graph = GraphDef()
    for node in msg.get('node', []):
        graphNode = NodeDef()
        graphNode.name = node['name'][0]
        graphNode.op = node['op'][0]
        graphNode.input = node.get('input', [])

        for attr in node.get('attr', []):
            graphNode.attr[attr['key'][0]] = attr['value'][0]

        graph.node.append(graphNode)
    return graph
266
267
268
def removeIdentity(graph_def):
    """Remove Identity nodes and rewire their consumers to the real producer.

    Every node with op 'Identity' is removed from graph_def.node, which is
    updated in place via remove(). Any input that names one of these nodes
    is replaced by that Identity's first input. Chains of Identity nodes are
    followed to the first non-Identity name, whatever order the nodes appear in.

    :param graph_def: graph whose `node` sequence is modified in place.
    """
    # Collect first, then remove: calling remove() while iterating the same
    # list skipped the node after each removed one, so consecutive Identity
    # nodes survived.
    identityNodes = [node for node in graph_def.node if node.op == 'Identity']
    identities = {}
    for node in identityNodes:
        identities[node.name] = node.input[0]
        graph_def.node.remove(node)

    def resolve(name):
        # Follow Identity -> Identity chains; `seen` guards against cycles.
        seen = set()
        while name in identities and name not in seen:
            seen.add(name)
            name = identities[name]
        return name

    for node in graph_def.node:
        for i, inp in enumerate(node.input):
            if inp in identities:
                node.input[i] = resolve(inp)
280
281
282
def removeUnusedNodesAndAttrs(to_remove, graph_def):
    """Delete Const nodes and predicate-selected nodes; strip listed attributes.

    :param to_remove: predicate to_remove(name, op) -> bool. It is called
        only for non-Const nodes, and selects the nodes to delete.
    :param graph_def: graph whose `node` sequence is modified in place.

    Every Const node is deleted, but inputs that reference a Const node are
    kept. Inputs that reference any other deleted node are removed from the
    surviving nodes. Surviving nodes also lose the attributes named in
    `unusedAttrs`.
    """
    unusedAttrs = ('T', 'Tshape', 'N', 'Tidx', 'Tdim', 'use_cudnn_on_gpu',
                   'Index', 'Tperm', 'is_training', 'Tpaddings')

    # A set gives O(1) lookups in the input-pruning loop below.
    removedNodes = set()

    # Walk backwards so that deleting by index does not shift unvisited nodes.
    for i in reversed(range(len(graph_def.node))):
        node = graph_def.node[i]
        op = node.op
        name = node.name

        if op == 'Const' or to_remove(name, op):
            if op != 'Const':
                removedNodes.add(name)
            del graph_def.node[i]
        else:
            for attr in unusedAttrs:
                if attr in node.attr:
                    del node.attr[attr]

    # Remove references to removed nodes except Const nodes.
    for node in graph_def.node:
        for i in reversed(range(len(node.input))):
            if node.input[i] in removedNodes:
                del node.input[i]
307
308
309
def writeTextGraph(modelPath, outputPath, outNodes):
    """Convert a binary TensorFlow graph file into a text graph file.

    OpenCV's cv.dnn.writeTextGraph is tried first. If that raises (e.g. cv2
    is missing, the function is unavailable, or the conversion fails), the
    TensorFlow fallback is used:
      * the binary GraphDef is loaded;
      * it is transformed with 'sort_by_execution_order', using
        'image_tensor' as the input and `outNodes` as the outputs;
      * the 'value' attribute of every Const node is dropped;
      * the result is written as text.

    :param modelPath: path to the binary GraphDef.
    :param outputPath: path of the text graph to write.
    :param outNodes: output node names; used only by the TensorFlow fallback.
    """
    try:
        import cv2 as cv

        cv.dnn.writeTextGraph(modelPath, outputPath)
    except Exception:
        # Best-effort fallback. A bare `except:` would also swallow
        # KeyboardInterrupt/SystemExit, which must propagate.
        import tensorflow as tf
        from tensorflow.tools.graph_transforms import TransformGraph

        with tf.gfile.FastGFile(modelPath, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())

        graph_def = TransformGraph(graph_def, ['image_tensor'], outNodes, ['sort_by_execution_order'])

        for node in graph_def.node:
            if node.op == 'Const' and 'value' in node.attr:
                del node.attr['value']

        tf.train.write_graph(graph_def, "", outputPath, as_text=True)
330
331