# Path: blob/master/samples/dnn/tf_text_graph_mask_rcnn.py
# (16337 views)
# Produces a text graph for a Mask-RCNN model trained with the TensorFlow
# Object Detection API so it can be imported by cv::dnn::readNetFromTensorflow.
import argparse
import numpy as np
from tf_text_graph_common import *

parser = argparse.ArgumentParser(description='Run this script to get a text graph of '
                                             'Mask-RCNN model from TensorFlow Object Detection API. '
                                             'Then pass it with .pb file to cv::dnn::readNetFromTensorflow function.')
parser.add_argument('--input', required=True, help='Path to frozen TensorFlow graph.')
parser.add_argument('--output', required=True, help='Path to output text graph.')
parser.add_argument('--config', required=True, help='Path to a *.config file is used for training.')
args = parser.parse_args()

# Node name prefixes that survive the pruning pass below.
scopesToKeep = ('FirstStageFeatureExtractor', 'Conv',
                'FirstStageBoxPredictor/BoxEncodingPredictor',
                'FirstStageBoxPredictor/ClassPredictor',
                'CropAndResize',
                'MaxPool2D',
                'SecondStageFeatureExtractor',
                'SecondStageBoxPredictor',
                'Preprocessor/sub',
                'Preprocessor/mul',
                'image_tensor')

# Node name prefixes that are always removed, even inside kept scopes.
scopesToIgnore = ('FirstStageFeatureExtractor/Assert',
                  'FirstStageFeatureExtractor/Shape',
                  'FirstStageFeatureExtractor/strided_slice',
                  'FirstStageFeatureExtractor/GreaterEqual',
                  'FirstStageFeatureExtractor/LogicalAnd')


def _make_node(name, op, inputs=(), attrs=None):
    # Assemble a NodeDef with the given inputs and attributes.
    # Attribute order follows dict insertion order (Python 3.7+).
    node = NodeDef()
    node.name = name
    node.op = op
    for inp in inputs:
        node.input.append(inp)
    for key, value in (attrs or {}).items():
        node.addAttr(key, value)
    return node


# Load a config file (text-format protobuf from the Object Detection API).
config = readTextMessage(args.config)
config = config['model'][0]['faster_rcnn'][0]
num_classes = int(config['num_classes'][0])

grid_anchor_generator = config['first_stage_anchor_generator'][0]['grid_anchor_generator'][0]
scales = [float(s) for s in grid_anchor_generator['scales']]
aspect_ratios = [float(ar) for ar in grid_anchor_generator['aspect_ratios']]
width_stride = float(grid_anchor_generator['width_stride'][0])
height_stride = float(grid_anchor_generator['height_stride'][0])
features_stride = float(config['feature_extractor'][0]['first_stage_features_stride'][0])

print('Number of classes: %d' % num_classes)
print('Scales: %s' % str(scales))
print('Aspect ratios: %s' % str(aspect_ratios))
print('Width stride: %f' % width_stride)
print('Height stride: %f' % height_stride)
print('Features stride: %f' % features_stride)

# Read the graph.
writeTextGraph(args.input, args.output, ['num_detections', 'detection_scores', 'detection_boxes', 'detection_classes', 'detection_masks'])
graph_def = parseTextGraph(args.output)

removeIdentity(graph_def)


def to_remove(name, op):
    # Drop a node unless it lives in a kept scope and is not explicitly ignored.
    keep = name.startswith(scopesToKeep) and not name.startswith(scopesToIgnore)
    return not keep


removeUnusedNodesAndAttrs(to_remove, graph_def)


# Connect input node to the first layer
assert(graph_def.node[0].op == 'Placeholder')
graph_def.node[1].input.insert(0, graph_def.node[0].name)

# Temporarily remove top nodes: pop everything above (and including) the
# second CropAndResize, to be re-attached after the detection head is built.
topNodes = []
numCropAndResize = 0
while numCropAndResize < 2:
    node = graph_def.node.pop()
    topNodes.append(node)
    if node.op == 'CropAndResize':
        numCropAndResize += 1

addReshape('FirstStageBoxPredictor/ClassPredictor/BiasAdd',
           'FirstStageBoxPredictor/ClassPredictor/reshape_1', [0, -1, 2], graph_def)

addSoftMax('FirstStageBoxPredictor/ClassPredictor/reshape_1',
           'FirstStageBoxPredictor/ClassPredictor/softmax', graph_def)  # Compare with Reshape_4

addFlatten('FirstStageBoxPredictor/ClassPredictor/softmax',
           'FirstStageBoxPredictor/ClassPredictor/softmax/flatten', graph_def)

# Compare with FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd
addFlatten('FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd',
           'FirstStageBoxPredictor/BoxEncodingPredictor/flatten', graph_def)

# Anchors generator. Compare with ClipToWindow/Gather/Gather (NOTE: normalized)
proposals = _make_node('proposals', 'PriorBox',
                       inputs=['FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd',
                               graph_def.node[0].name],  # image_tensor
                       attrs={'flip': False,
                              'clip': True,
                              'step': features_stride,
                              'offset': 0.0,
                              'variance': [0.1, 0.1, 0.2, 0.2]})

# Anchor boxes: one (width, height) per (aspect_ratio, scale) pair.
widths = [(features_stride**2) * s * np.sqrt(a) for a in aspect_ratios for s in scales]
heights = [(features_stride**2) * s / np.sqrt(a) for a in aspect_ratios for s in scales]

proposals.addAttr('width', widths)
proposals.addAttr('height', heights)

graph_def.node.extend([proposals])

# First-stage detections. Compare with Reshape_5
detectionOut = _make_node('detection_out', 'DetectionOutput',
                          inputs=['FirstStageBoxPredictor/BoxEncodingPredictor/flatten',
                                  'FirstStageBoxPredictor/ClassPredictor/softmax/flatten',
                                  'proposals'],
                          attrs={'num_classes': 2,
                                 'share_location': True,
                                 'background_label_id': 0,
                                 'nms_threshold': 0.7,
                                 'top_k': 6000,
                                 'code_type': "CENTER_SIZE",
                                 'keep_top_k': 100,
                                 'clip': True})

graph_def.node.extend([detectionOut])

# Save as text.
# Re-attach detached nodes up to (but not including) the last CropAndResize.
for node in reversed(topNodes):
    if node.op == 'CropAndResize' and numCropAndResize == 1:
        break
    graph_def.node.extend([node])
    topNodes.pop()
    if node.op == 'CropAndResize':
        numCropAndResize -= 1

addSoftMax('SecondStageBoxPredictor/Reshape_1', 'SecondStageBoxPredictor/Reshape_1/softmax', graph_def)

addSlice('SecondStageBoxPredictor/Reshape_1/softmax',
         'SecondStageBoxPredictor/Reshape_1/slice',
         [0, 0, 1], [-1, -1, -1], graph_def)

addReshape('SecondStageBoxPredictor/Reshape_1/slice',
           'SecondStageBoxPredictor/Reshape_1/Reshape', [1, -1], graph_def)

# Replace Flatten subgraph onto a single node.
for i in reversed(range(len(graph_def.node))):
    node = graph_def.node[i]
    if node.op == 'CropAndResize':
        node.input.insert(1, 'detection_out')

    if node.name == 'SecondStageBoxPredictor/Reshape':
        addConstNode('SecondStageBoxPredictor/Reshape/shape2', [1, -1, 4], graph_def)

        node.input.pop()
        node.input.append('SecondStageBoxPredictor/Reshape/shape2')

    if node.name in ('SecondStageBoxPredictor/Flatten/flatten/Shape',
                     'SecondStageBoxPredictor/Flatten/flatten/strided_slice',
                     'SecondStageBoxPredictor/Flatten/flatten/Reshape/shape'):
        del graph_def.node[i]

for node in graph_def.node:
    if node.name == 'SecondStageBoxPredictor/Flatten/flatten/Reshape':
        node.op = 'Flatten'
        node.input.pop()

    if node.name in ('FirstStageBoxPredictor/BoxEncodingPredictor/Conv2D',
                     'SecondStageBoxPredictor/BoxEncodingPredictor/MatMul'):
        node.addAttr('loc_pred_transposed', True)

################################################################################
### Postprocessing
################################################################################
addSlice('detection_out', 'detection_out/slice', [0, 0, 0, 3], [-1, -1, -1, 4], graph_def)

variance = _make_node('proposals/variance', 'Const',
                      attrs={'value': [0.1, 0.1, 0.2, 0.2]})
graph_def.node.extend([variance])

varianceEncoder = _make_node('variance_encoded', 'Mul',
                             inputs=['SecondStageBoxPredictor/Reshape', variance.name],
                             attrs={'axis': 2})
graph_def.node.extend([varianceEncoder])

addReshape('detection_out/slice', 'detection_out/slice/reshape', [1, 1, -1], graph_def)
addFlatten('variance_encoded', 'variance_encoded/flatten', graph_def)

# Final detections from the second stage.
detectionOut = _make_node('detection_out_final', 'DetectionOutput',
                          inputs=['variance_encoded/flatten',
                                  'SecondStageBoxPredictor/Reshape_1/Reshape',
                                  'detection_out/slice/reshape'],
                          attrs={'num_classes': num_classes,
                                 'share_location': False,
                                 'background_label_id': num_classes + 1,
                                 'nms_threshold': 0.6,
                                 'code_type': "CENTER_SIZE",
                                 'keep_top_k': 100,
                                 'clip': True,
                                 'variance_encoded_in_target': True,
                                 'confidence_threshold': 0.3,
                                 'group_by_classes': False})
graph_def.node.extend([detectionOut])

# Restore the remaining mask-head nodes.
for node in reversed(topNodes):
    graph_def.node.extend([node])

# Feed final detections into the last CropAndResize (mask-head input).
for i in reversed(range(len(graph_def.node))):
    if graph_def.node[i].op == 'CropAndResize':
        graph_def.node[i].input.insert(1, 'detection_out_final')
        break

# Turn the last node into a Sigmoid producing 'detection_masks'.
graph_def.node[-1].name = 'detection_masks'
graph_def.node[-1].op = 'Sigmoid'
graph_def.node[-1].input.pop()

# Save as text.
graph_def.save(args.output)