Path: blob/master/samples/dnn/tf_text_graph_faster_rcnn.py
16337 views
import argparse
import numpy as np
from tf_text_graph_common import *


def createFasterRCNNGraph(modelPath, configPath, outputPath):
    """Produce an OpenCV-importable text graph for a TF Faster-RCNN model.

    Reads the frozen TensorFlow graph at *modelPath* and the Object Detection
    API training pipeline at *configPath* (a .config text protobuf), rewrites
    the detection sub-graphs into nodes OpenCV's dnn module understands
    (PriorBox, DetectionOutput, ClipByValue, Flatten, ...), and writes the
    result as a text graph to *outputPath*.

    Helpers such as readTextMessage, parseTextGraph, addReshape, addSoftMax,
    addFlatten, addSlice, addConstNode, removeIdentity and NodeDef come from
    tf_text_graph_common (star-imported above).

    :param modelPath:  path to the frozen TensorFlow .pb graph
    :param configPath: path to the *.config file used for training
    :param outputPath: path where the output text graph is written
                       (also used as a scratch file while processing)
    """
    # Node-name prefixes to keep in the rewritten graph; everything outside
    # these scopes is stripped by removeUnusedNodesAndAttrs below.
    scopesToKeep = ('FirstStageFeatureExtractor', 'Conv',
                    'FirstStageBoxPredictor/BoxEncodingPredictor',
                    'FirstStageBoxPredictor/ClassPredictor',
                    'CropAndResize',
                    'MaxPool2D',
                    'SecondStageFeatureExtractor',
                    'SecondStageBoxPredictor',
                    'Preprocessor/sub',
                    'Preprocessor/mul',
                    'image_tensor')

    # Prefixes removed even though they fall inside the kept scopes
    # (shape/assert bookkeeping OpenCV has no use for).
    scopesToIgnore = ('FirstStageFeatureExtractor/Assert',
                      'FirstStageFeatureExtractor/Shape',
                      'FirstStageFeatureExtractor/strided_slice',
                      'FirstStageFeatureExtractor/GreaterEqual',
                      'FirstStageFeatureExtractor/LogicalAnd')

    # Load a config file.
    # readTextMessage appears to return nested dicts of lists (repeated
    # protobuf fields), hence the [0] indexing throughout.
    config = readTextMessage(configPath)
    config = config['model'][0]['faster_rcnn'][0]
    num_classes = int(config['num_classes'][0])

    # Anchor-generator parameters drive the PriorBox node built further down.
    grid_anchor_generator = config['first_stage_anchor_generator'][0]['grid_anchor_generator'][0]
    scales = [float(s) for s in grid_anchor_generator['scales']]
    aspect_ratios = [float(ar) for ar in grid_anchor_generator['aspect_ratios']]
    width_stride = float(grid_anchor_generator['width_stride'][0])
    height_stride = float(grid_anchor_generator['height_stride'][0])
    features_stride = float(config['feature_extractor'][0]['first_stage_features_stride'][0])

    print('Number of classes: %d' % num_classes)
    print('Scales: %s' % str(scales))
    print('Aspect ratios: %s' % str(aspect_ratios))
    print('Width stride: %f' % width_stride)
    print('Height stride: %f' % height_stride)
    print('Features stride: %f' % features_stride)

    # Read the graph.
    writeTextGraph(modelPath, outputPath, ['num_detections', 'detection_scores', 'detection_boxes', 'detection_classes'])
    graph_def = parseTextGraph(outputPath)

    removeIdentity(graph_def)

    def to_remove(name, op):
        # Keep a node only if it is inside a kept scope and not ignored.
        return name.startswith(scopesToIgnore) or not name.startswith(scopesToKeep)

    removeUnusedNodesAndAttrs(to_remove, graph_def)


    # Connect input node to the first layer
    assert(graph_def.node[0].op == 'Placeholder')
    graph_def.node[1].input.insert(0, graph_def.node[0].name)

    # Temporarily remove top nodes.
    # Pops from the tail until (and including) the CropAndResize node; the
    # popped nodes are re-appended in original order after the first-stage
    # RPN subgraph has been rebuilt below.
    topNodes = []
    while True:
        node = graph_def.node.pop()
        topNodes.append(node)
        if node.op == 'CropAndResize':
            break

    # First-stage (RPN) class scores: reshape to [N, -1, 2] (object vs.
    # background), softmax, then flatten for DetectionOutput.
    addReshape('FirstStageBoxPredictor/ClassPredictor/BiasAdd',
               'FirstStageBoxPredictor/ClassPredictor/reshape_1', [0, -1, 2], graph_def)

    addSoftMax('FirstStageBoxPredictor/ClassPredictor/reshape_1',
               'FirstStageBoxPredictor/ClassPredictor/softmax', graph_def)  # Compare with Reshape_4

    addFlatten('FirstStageBoxPredictor/ClassPredictor/softmax',
               'FirstStageBoxPredictor/ClassPredictor/softmax/flatten', graph_def)

    # Compare with FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd
    addFlatten('FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd',
               'FirstStageBoxPredictor/BoxEncodingPredictor/flatten', graph_def)

    # Synthesize a PriorBox node that reproduces the grid anchor generator.
    proposals = NodeDef()
    proposals.name = 'proposals'  # Compare with ClipToWindow/Gather/Gather (NOTE: normalized)
    proposals.op = 'PriorBox'
    proposals.input.append('FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd')
    proposals.input.append(graph_def.node[0].name)  # image_tensor

    proposals.addAttr('flip', False)
    proposals.addAttr('clip', True)
    proposals.addAttr('step', features_stride)
    proposals.addAttr('offset', 0.0)
    proposals.addAttr('variance', [0.1, 0.1, 0.2, 0.2])

    # Anchor sizes: base area is stride**2, scaled and skewed per aspect
    # ratio (sqrt(a) splits the ratio between height and width).
    widths = []
    heights = []
    for a in aspect_ratios:
        for s in scales:
            ar = np.sqrt(a)
            heights.append((height_stride**2) * s / ar)
            widths.append((width_stride**2) * s * ar)

    proposals.addAttr('width', widths)
    proposals.addAttr('height', heights)

    graph_def.node.extend([proposals])

    # First-stage DetectionOutput: turns RPN deltas + anchors into proposals.
    # Compare with Reshape_5
    detectionOut = NodeDef()
    detectionOut.name = 'detection_out'
    detectionOut.op = 'DetectionOutput'

    detectionOut.input.append('FirstStageBoxPredictor/BoxEncodingPredictor/flatten')
    detectionOut.input.append('FirstStageBoxPredictor/ClassPredictor/softmax/flatten')
    detectionOut.input.append('proposals')

    detectionOut.addAttr('num_classes', 2)  # object vs. background only
    detectionOut.addAttr('share_location', True)
    detectionOut.addAttr('background_label_id', 0)
    detectionOut.addAttr('nms_threshold', 0.7)
    detectionOut.addAttr('top_k', 6000)
    detectionOut.addAttr('code_type', "CENTER_SIZE")
    detectionOut.addAttr('keep_top_k', 100)
    detectionOut.addAttr('clip', False)

    graph_def.node.extend([detectionOut])

    # Clamp proposal coordinates to [0, 1] before CropAndResize.
    addConstNode('clip_by_value/lower', [0.0], graph_def)
    addConstNode('clip_by_value/upper', [1.0], graph_def)

    clipByValueNode = NodeDef()
    clipByValueNode.name = 'detection_out/clip_by_value'
    clipByValueNode.op = 'ClipByValue'
    clipByValueNode.input.append('detection_out')
    clipByValueNode.input.append('clip_by_value/lower')
    clipByValueNode.input.append('clip_by_value/upper')
    graph_def.node.extend([clipByValueNode])

    # Restore the temporarily removed top nodes (in their original order).
    for node in reversed(topNodes):
        graph_def.node.extend([node])

    # Second-stage class scores: softmax, drop the background column
    # (slice from channel 1), reshape to a single row.
    addSoftMax('SecondStageBoxPredictor/Reshape_1', 'SecondStageBoxPredictor/Reshape_1/softmax', graph_def)

    addSlice('SecondStageBoxPredictor/Reshape_1/softmax',
             'SecondStageBoxPredictor/Reshape_1/slice',
             [0, 0, 1], [-1, -1, -1], graph_def)

    addReshape('SecondStageBoxPredictor/Reshape_1/slice',
               'SecondStageBoxPredictor/Reshape_1/Reshape', [1, -1], graph_def)

    # Replace Flatten subgraph onto a single node.
    # Iterate in reverse so `del graph_def.node[i]` does not shift
    # indices that have not been visited yet.
    for i in reversed(range(len(graph_def.node))):
        if graph_def.node[i].op == 'CropAndResize':
            # Feed the clipped proposals in as the second input.
            graph_def.node[i].input.insert(1, 'detection_out/clip_by_value')

        if graph_def.node[i].name == 'SecondStageBoxPredictor/Reshape':
            # Force a static target shape of [1, -1, 4] instead of the
            # dynamically-computed shape tensor.
            addConstNode('SecondStageBoxPredictor/Reshape/shape2', [1, -1, 4], graph_def)

            graph_def.node[i].input.pop()
            graph_def.node[i].input.append('SecondStageBoxPredictor/Reshape/shape2')

        if graph_def.node[i].name in ['SecondStageBoxPredictor/Flatten/flatten/Shape',
                                      'SecondStageBoxPredictor/Flatten/flatten/strided_slice',
                                      'SecondStageBoxPredictor/Flatten/flatten/Reshape/shape']:
            del graph_def.node[i]

    for node in graph_def.node:
        if node.name == 'SecondStageBoxPredictor/Flatten/flatten/Reshape':
            # The remaining Reshape becomes the single Flatten node.
            node.op = 'Flatten'
            node.input.pop()

        # OpenCV expects box deltas in a transposed layout; mark the
        # predictors accordingly.
        if node.name in ['FirstStageBoxPredictor/BoxEncodingPredictor/Conv2D',
                         'SecondStageBoxPredictor/BoxEncodingPredictor/MatMul']:
            node.addAttr('loc_pred_transposed', True)

    ################################################################################
    ### Postprocessing
    ################################################################################
    # Take only the 4 box coordinates (offset 3) from the first-stage output.
    addSlice('detection_out/clip_by_value', 'detection_out/slice', [0, 0, 0, 3], [-1, -1, -1, 4], graph_def)

    # Pre-multiply second-stage box deltas by the variances so the final
    # DetectionOutput can run with variance_encoded_in_target=True.
    variance = NodeDef()
    variance.name = 'proposals/variance'
    variance.op = 'Const'
    variance.addAttr('value', [0.1, 0.1, 0.2, 0.2])
    graph_def.node.extend([variance])

    varianceEncoder = NodeDef()
    varianceEncoder.name = 'variance_encoded'
    varianceEncoder.op = 'Mul'
    varianceEncoder.input.append('SecondStageBoxPredictor/Reshape')
    varianceEncoder.input.append(variance.name)
    varianceEncoder.addAttr('axis', 2)
    graph_def.node.extend([varianceEncoder])

    addReshape('detection_out/slice', 'detection_out/slice/reshape', [1, 1, -1], graph_def)
    addFlatten('variance_encoded', 'variance_encoded/flatten', graph_def)

    # Final DetectionOutput: refines the first-stage proposals with the
    # second-stage predictions into per-class detections.
    detectionOut = NodeDef()
    detectionOut.name = 'detection_out_final'
    detectionOut.op = 'DetectionOutput'

    detectionOut.input.append('variance_encoded/flatten')
    detectionOut.input.append('SecondStageBoxPredictor/Reshape_1/Reshape')
    detectionOut.input.append('detection_out/slice/reshape')

    detectionOut.addAttr('num_classes', num_classes)
    detectionOut.addAttr('share_location', False)
    # NOTE(review): background was sliced out above, so num_classes + 1 is an
    # out-of-range label, effectively disabling background handling here.
    detectionOut.addAttr('background_label_id', num_classes + 1)
    detectionOut.addAttr('nms_threshold', 0.6)
    detectionOut.addAttr('code_type', "CENTER_SIZE")
    detectionOut.addAttr('keep_top_k', 100)
    detectionOut.addAttr('clip', True)
    detectionOut.addAttr('variance_encoded_in_target', True)
    graph_def.node.extend([detectionOut])

    # Save as text.
    graph_def.save(outputPath)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Run this script to get a text graph of '
                                                 'Faster-RCNN model from TensorFlow Object Detection API. '
                                                 'Then pass it with .pb file to cv::dnn::readNetFromTensorflow function.')
    parser.add_argument('--input', required=True, help='Path to frozen TensorFlow graph.')
    parser.add_argument('--output', required=True, help='Path to output text graph.')
    parser.add_argument('--config', required=True, help='Path to a *.config file is used for training.')
    args = parser.parse_args()

    createFasterRCNNGraph(args.input, args.config, args.output)