Path: blob/master/samples/dnn/tf_text_graph_common.py
16337 views
# Characters that terminate an unquoted token in protobuf text format.
_SEPARATORS = frozenset(' \t\r\n:;,')
# Characters that open (and close) a quoted string literal.
_QUOTES = frozenset('\'"')
# Structural brackets, always emitted as tokens of their own.
_BRACKETS = frozenset('{}[]')


def tokenize(s):
    """Split protobuf text-format content into a flat list of tokens.

    Whitespace, ':', ';' and ',' separate tokens and are dropped; '{', '}',
    '[' and ']' are emitted as standalone tokens; '#' starts a comment that
    runs to the end of the line. A quoted string ('...' or "...") becomes a
    single token without its quotes (an empty string yields ''). Inside a
    string only the quote character that opened it closes it, brackets and
    '#' are ordinary text, and a backslash escapes the next character. Both
    characters of an escape are kept verbatim, so the token can be written
    back unchanged.
    """
    tokens = []
    token = ""
    quote = None      # opening quote character while inside a string literal
    escaped = False   # previous character inside the string was a backslash
    isComment = False
    for symbol in s:
        if quote is not None:
            if escaped:
                escaped = False
                token += symbol
            elif symbol == quote:
                # Closing quote: emit the string even if it is empty.
                tokens.append(token)
                token = ""
                quote = None
            else:
                escaped = symbol == '\\'
                token += symbol
            continue

        isComment = (isComment and symbol != '\n') or symbol == '#'
        if isComment:
            continue

        if symbol in _SEPARATORS or symbol in _QUOTES:
            if token:
                tokens.append(token)
                token = ""
            if symbol in _QUOTES:
                quote = symbol
        elif symbol in _BRACKETS:
            if token:
                tokens.append(token)
                token = ""
            tokens.append(symbol)
        else:
            token += symbol
    if token:
        tokens.append(token)
    return tokens


def parseMessage(tokens, idx):
    """Parse one '{ ... }' message whose opening brace is tokens[idx].

    Every field maps to a list of values, so repeated fields accumulate.
    Nested messages become dicts. The values of a '[a, b]' list are appended
    to the field's list. Scalar values are kept as the raw token strings.

    Returns a (message, idx) tuple where idx is the position of the closing
    '}', or None when the tokens run out before the message is closed.
    """
    msg = {}
    assert tokens[idx] == '{'

    isArray = False
    fieldName = None
    while True:
        if not isArray:
            idx += 1
            if idx >= len(tokens):
                return None
            fieldName = tokens[idx]
            if fieldName == '}':
                break

        idx += 1
        if idx >= len(tokens):
            return None
        fieldValue = tokens[idx]

        if fieldValue == '{':
            parsed = parseMessage(tokens, idx)
            if parsed is None:
                # Propagate truncation instead of failing to unpack None.
                return None
            embeddedMsg, idx = parsed
            msg.setdefault(fieldName, []).append(embeddedMsg)
        elif fieldValue == '[':
            isArray = True
        elif fieldValue == ']':
            isArray = False
        else:
            msg.setdefault(fieldName, []).append(fieldValue)
    return msg, idx


def readTextMessage(filePath):
    """Read a protobuf text-format file into nested dicts (see parseMessage).

    Returns {} when filePath is empty or the content cannot be parsed.
    """
    if not filePath:
        return {}
    with open(filePath, 'rt') as f:
        content = f.read()

    # Wrap the content in braces so the top level parses as a single message.
    tokens = tokenize('{' + content + '}')
    msg = parseMessage(tokens, 0)
    return msg[0] if msg else {}


def listToTensor(values):
    """Build a 1-D tensor attribute value from a list of Python numbers.

    All floats -> DT_FLOAT stored in 'float_val'; all ints -> DT_INT32 stored
    in 'int_val'. An empty list is treated as DT_FLOAT, and bools count as
    ints because bool is a subclass of int.

    Raises:
        Exception: if the list mixes types or contains non-numbers.
    """
    if all(isinstance(v, float) for v in values):
        dtype = 'DT_FLOAT'
        field = 'float_val'
    elif all(isinstance(v, int) for v in values):
        dtype = 'DT_INT32'
        field = 'int_val'
    else:
        raise Exception('Wrong values types')

    msg = {
        'tensor': {
            'dtype': dtype,
            'tensor_shape': {
                'dim': {
                    'size': len(values)
                }
            }
        }
    }
    msg['tensor'][field] = values
    return msg


def addConstNode(name, values, graph_def):
    """Append a Const node called name whose 'value' attribute is values."""
    node = NodeDef()
    node.name = name
    node.op = 'Const'
    node.addAttr('value', values)
    graph_def.node.extend([node])


def addSlice(inp, out, begins, sizes, graph_def):
    """Append a Slice node out = Slice(inp, begins, sizes).

    The begins and sizes operands are added first as Const nodes named
    out + '/begins' and out + '/sizes'.
    """
    beginsName = out + '/begins'
    sizesName = out + '/sizes'
    addConstNode(beginsName, begins, graph_def)
    addConstNode(sizesName, sizes, graph_def)

    sliced = NodeDef()
    sliced.name = out
    sliced.op = 'Slice'
    sliced.input.append(inp)
    sliced.input.append(beginsName)
    sliced.input.append(sizesName)
    graph_def.node.extend([sliced])


def addReshape(inp, out, shape, graph_def):
    """Append a Reshape node out = Reshape(inp, shape).

    The target shape is added first as a Const node named out + '/shape'.
    """
    shapeName = out + '/shape'
    addConstNode(shapeName, shape, graph_def)

    reshape = NodeDef()
    reshape.name = out
    reshape.op = 'Reshape'
    reshape.input.append(inp)
    reshape.input.append(shapeName)
    graph_def.node.extend([reshape])


def addSoftMax(inp, out, graph_def):
    """Append a Softmax node named out over the last axis of inp."""
    softmax = NodeDef()
    softmax.name = out
    softmax.op = 'Softmax'
    softmax.addAttr('axis', -1)
    softmax.input.append(inp)
    graph_def.node.extend([softmax])


def addFlatten(inp, out, graph_def):
    """Append a Flatten node named out consuming inp."""
    flatten = NodeDef()
    flatten.name = out
    flatten.op = 'Flatten'
    flatten.input.append(inp)
    graph_def.node.extend([flatten])


class NodeDef:
    """Minimal stand-in for a graph node: name, op type, inputs and attrs."""

    def __init__(self):
        self.input = []
        self.name = ""
        self.op = ""
        self.attr = {}

    def addAttr(self, key, value):
        """Add attribute key, wrapping value in the matching value field.

        bool -> 'b', int -> 'i', float -> 'f', str -> 's', and list -> tensor
        (see listToTensor). bool is tested before int because bool is a
        subclass of int.

        Raises:
            Exception: if value has any other type.
        """
        assert key not in self.attr, 'Attribute %s is already set' % key
        if isinstance(value, bool):
            self.attr[key] = {'b': value}
        elif isinstance(value, int):
            self.attr[key] = {'i': value}
        elif isinstance(value, float):
            self.attr[key] = {'f': value}
        elif isinstance(value, str):
            self.attr[key] = {'s': value}
        elif isinstance(value, list):
            self.attr[key] = listToTensor(value)
        else:
            raise Exception('Unknown type of attribute ' + key)

    def Clear(self):
        """Reset the node to the freshly constructed (empty) state."""
        self.input = []
        self.name = ""
        self.op = ""
        self.attr = {}


class GraphDef:
    """Minimal stand-in for a graph definition holding a list of NodeDef."""

    def __init__(self):
        self.node = []

    def save(self, filePath):
        """Write the graph to filePath in protobuf text format.

        Attributes are written in case-insensitive key order. Strings that do
        not parse as numbers and do not start with 'DT_' are quoted. Bools
        and 'true'/'false' strings are written as bare true/false.
        """
        with open(filePath, 'wt') as f:

            def printAttr(d, indent):
                indent = ' ' * indent
                for key, value in sorted(d.items(), key=lambda x: x[0].lower()):
                    value = value if isinstance(value, list) else [value]
                    for v in value:
                        if isinstance(v, dict):
                            f.write(indent + key + ' {\n')
                            printAttr(v, len(indent) + 2)
                            f.write(indent + '}\n')
                        else:
                            isString = False
                            if isinstance(v, str) and not v.startswith('DT_'):
                                try:
                                    float(v)
                                except ValueError:
                                    isString = True

                            if isinstance(v, bool):
                                printed = 'true' if v else 'false'
                            elif v == 'true' or v == 'false':
                                printed = 'true' if v == 'true' else 'false'
                            elif isString:
                                printed = '\"%s\"' % v
                            else:
                                printed = str(v)
                            f.write(indent + key + ': ' + printed + '\n')

            # Nesting: node (0) > fields (2) > attr fields (4) > value body (6).
            for node in self.node:
                f.write('node {\n')
                f.write('  name: \"%s\"\n' % node.name)
                f.write('  op: \"%s\"\n' % node.op)
                for inp in node.input:
                    f.write('  input: \"%s\"\n' % inp)
                for key, value in sorted(node.attr.items(), key=lambda x: x[0].lower()):
                    f.write('  attr {\n')
                    f.write('    key: \"%s\"\n' % key)
                    f.write('    value {\n')
                    printAttr(value, 6)
                    f.write('    }\n')
                    f.write('  }\n')
                f.write('}\n')


def parseTextGraph(filePath):
    """Load a text-format graph file into a GraphDef of NodeDef objects.

    Attribute values are kept as the nested dicts produced by
    readTextMessage, where every field maps to a list of raw token strings.
    A file without nodes (or an empty filePath) yields an empty GraphDef.
    """
    msg = readTextMessage(filePath)

    graph = GraphDef()
    for node in msg.get('node', []):
        graphNode = NodeDef()
        graphNode.name = node['name'][0]
        graphNode.op = node['op'][0]
        graphNode.input = node.get('input', [])

        for attr in node.get('attr', []):
            graphNode.attr[attr['key'][0]] = attr['value'][0]

        graph.node.append(graphNode)
    return graph


# Removes Identity nodes
def removeIdentity(graph_def):
    """Remove Identity nodes and rewire their consumers to the Identity input.

    Chains of Identity nodes are followed back to the first non-Identity
    source, so no consumer is left pointing at a deleted node.
    """
    identities = {}
    # Delete by index from the back. Removing from the list while iterating it
    # forward would skip the node right after each deleted one.
    for i in reversed(range(len(graph_def.node))):
        node = graph_def.node[i]
        if node.op == 'Identity':
            identities[node.name] = node.input[0]
            del graph_def.node[i]

    def resolve(name):
        # Follow Identity -> Identity chains; 'seen' guards against cycles.
        seen = set()
        while name in identities and name not in seen:
            seen.add(name)
            name = identities[name]
        return name

    for node in graph_def.node:
        for i, inp in enumerate(node.input):
            if inp in identities:
                node.input[i] = resolve(inp)


def removeUnusedNodesAndAttrs(to_remove, graph_def):
    """Delete Const nodes and nodes selected by to_remove; strip unused attrs.

    to_remove(name, op) is called for every non-Const node. When it returns
    a truthy value, the node is deleted and its name is removed from the
    inputs of the remaining nodes. Const nodes are deleted, but references to
    them in other nodes' inputs are kept. Surviving nodes lose the attributes
    listed in unusedAttrs.
    """
    unusedAttrs = ('T', 'Tshape', 'N', 'Tidx', 'Tdim', 'use_cudnn_on_gpu',
                   'Index', 'Tperm', 'is_training', 'Tpaddings')

    removedNodes = set()

    for i in reversed(range(len(graph_def.node))):
        op = graph_def.node[i].op
        name = graph_def.node[i].name

        if op == 'Const' or to_remove(name, op):
            if op != 'Const':
                removedNodes.add(name)

            del graph_def.node[i]
        else:
            for attr in unusedAttrs:
                if attr in graph_def.node[i].attr:
                    del graph_def.node[i].attr[attr]

    # Remove references to removed nodes except Const nodes.
    for node in graph_def.node:
        for i in reversed(range(len(node.input))):
            if node.input[i] in removedNodes:
                del node.input[i]


def writeTextGraph(modelPath, outputPath, outNodes):
    """Write a text-format graph for the binary model at modelPath.

    First tries OpenCV's cv.dnn.writeTextGraph. If that fails for any
    reason (e.g. cv2 missing or too old), it falls back to TensorFlow:
    the graph is sorted by execution order starting from 'image_tensor'
    towards outNodes, and the 'value' attribute of every Const node is
    dropped before writing.
    """
    try:
        import cv2 as cv

        cv.dnn.writeTextGraph(modelPath, outputPath)
    except Exception:
        # Deliberate best-effort fallback; Exception (not a bare except) lets
        # KeyboardInterrupt/SystemExit propagate.
        import tensorflow as tf
        from tensorflow.tools.graph_transforms import TransformGraph

        with tf.gfile.FastGFile(modelPath, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())

        graph_def = TransformGraph(graph_def, ['image_tensor'], outNodes, ['sort_by_execution_order'])

        for node in graph_def.node:
            if node.op == 'Const':
                if 'value' in node.attr:
                    del node.attr['value']

        tf.train.write_graph(graph_def, "", outputPath, as_text=True)