Path: blob/master/samples/dnn/shrink_tf_graph_weights.py
16337 views
# This file is part of OpenCV project.
# It is subject to the license terms in the LICENSE file found in the top-level directory
# of this distribution and at http://opencv.org/license.html.
#
# Copyright (C) 2017, Intel Corporation, all rights reserved.
# Third party copyrights are property of their respective owners.
"""Convert the weights of a frozen TensorFlow graph from fp32 to fp16.

For the frozen graphs, every node that uses weights is usually connected to a
Const node through an Identity node (typically named with a '/read' suffix).
For each op listed in --ops, every fp32 Identity input that reads an fp32
Const node is rewritten so that:

* the Const node's tensor is stored as fp16 (DT_HALF), and
* the Identity node becomes a Cast from fp16 back to fp32, so the consuming
  op still receives fp32 values.

The modified graph is written as a binary GraphDef to --output.
"""
import argparse

import numpy as np
import tensorflow as tf

DT_FLOAT = 1
DT_HALF = 19


def parse_args():
    """Parse the command-line arguments of the script."""
    parser = argparse.ArgumentParser(description='Convert weights of a frozen TensorFlow graph to fp16.')
    parser.add_argument('--input', required=True, help='Path to frozen graph.')
    parser.add_argument('--output', required=True, help='Path to output graph.')
    parser.add_argument('--ops', default=['Conv2D', 'MatMul'], nargs='+',
                        help='List of ops which weights are converted.')
    return parser.parse_args()


def _node_name(input_name):
    """Return the node name referenced by a NodeDef data input string.

    A data input may carry an output index ('w/read:0'), which still refers
    to the node 'w/read'.
    """
    return input_name.split(':')[0]


def _float32_to_float16_bytes(content):
    """Convert a raw buffer of native-order float32 values to float16 bytes.

    Raises ValueError if the buffer length is not a multiple of 4.
    """
    # float32 -> float16 directly gives the same rounding as the former
    # float32 -> Python float -> float16 path, because float32 -> float64 is exact.
    return np.frombuffer(content, dtype=np.float32).astype(np.float16).tobytes()


def _convert_const_to_half(const_node):
    """Re-encode an fp32 Const node's value tensor as fp16 (in place)."""
    tensor = const_node.attr['value'].tensor
    const_node.attr['dtype'].type = DT_HALF
    tensor.dtype = DT_HALF
    if tensor.tensor_content:
        tensor.tensor_content = _float32_to_float16_bytes(tensor.tensor_content)
    else:
        # Small or splat constants keep their values in float_val rather than
        # tensor_content. Half tensors store raw 16-bit patterns in half_val.
        halfs = np.array(tensor.float_val, dtype=np.float32).astype(np.float16).view(np.uint16)
        tensor.half_val.extend(int(h) for h in halfs)
        tensor.ClearField('float_val')


def _identity_to_cast(identity_node):
    """Turn an fp32 Identity node into a Cast from fp16 to fp32 (in place)."""
    identity_node.op = 'Cast'
    identity_node.attr['DstT'].type = DT_FLOAT
    identity_node.attr['SrcT'].type = DT_HALF
    del identity_node.attr['T']
    # '_class' is not present on every node; deleting a missing key would raise.
    if '_class' in identity_node.attr:
        del identity_node.attr['_class']


def main():
    args = parse_args()

    # Load the model. A frozen GraphDef is a binary protobuf, so read it as bytes.
    with tf.gfile.FastGFile(args.input, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    nodes_by_name = {node.name: node for node in graph_def.node}

    # Names of all data inputs of the desired nodes (control inputs start with '^').
    inputs = {
        _node_name(inp)
        for node in graph_def.node if node.op in args.ops
        for inp in node.input if not inp.startswith('^')
    }

    weights_nodes = set()
    for node in graph_def.node:
        # From the whole inputs we need to keep only fp32 Identity nodes...
        if node.name not in inputs or node.op != 'Identity' or node.attr['T'].type != DT_FLOAT:
            continue
        # ...that read an fp32 Const. Otherwise the Cast would expect half input
        # that never arrives.
        const_node = nodes_by_name.get(_node_name(node.input[0]))
        if const_node is None or const_node.op != 'Const' or const_node.attr['dtype'].type != DT_FLOAT:
            continue
        # NOTE(review): if this Const also feeds other consumers directly (not
        # through an Identity), those consumers will now receive fp16 values.
        _identity_to_cast(node)
        weights_nodes.add(const_node.name)

    # Convert weights to halfs. A Const shared by several Identity nodes is converted once.
    for name in weights_nodes:
        _convert_const_to_half(nodes_by_name[name])

    tf.train.write_graph(graph_def, "", args.output, as_text=False)


if __name__ == '__main__':
    main()