# Path: blob/master/Model-5/models/freeze_graph.py
# 427 views
# EDITED on 10. 9. 2017 for meta graph freezing1#2# Copyright 2015 The TensorFlow Authors. All Rights Reserved.3#4# Licensed under the Apache License, Version 2.0 (the "License");5# you may not use this file except in compliance with the License.6# You may obtain a copy of the License at7#8# http://www.apache.org/licenses/LICENSE-2.09#10# Unless required by applicable law or agreed to in writing, software11# distributed under the License is distributed on an "AS IS" BASIS,12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.13# See the License for the specific language governing permissions and14# limitations under the License.15# ==============================================================================16"""Converts checkpoint variables into Const ops in a standalone GraphDef file.1718This script is designed to take a GraphDef proto, a SaverDef proto, and a set of19variable values stored in a checkpoint file, and output a GraphDef with all of20the variable ops converted into const ops containing the values of the21variables.2223It's useful to do this when we need to load a single file in C++, especially in24environments like mobile or embedded where we may not have access to the25RestoreTensor ops and file loading calls that they rely on.2627An example of command-line usage is:28bazel build tensorflow/python/tools:freeze_graph && \29bazel-bin/tensorflow/python/tools/freeze_graph \30--input_graph=some_graph_def.pb \31--input_checkpoint=model.ckpt-8361242 \32--output_graph=/tmp/frozen_graph.pb --output_node_names=softmax3334You can also look at freeze_graph_test.py for an example of how to use it.3536"""37from __future__ import absolute_import38from __future__ import division39from __future__ import print_function4041import argparse42import sys4344from google.protobuf import text_format4546from tensorflow.contrib.saved_model.python.saved_model import reader47from tensorflow.core.framework import graph_pb248from tensorflow.core.protobuf 
import saver_pb249from tensorflow.core.protobuf.meta_graph_pb2 import MetaGraphDef50from tensorflow.python import pywrap_tensorflow51from tensorflow.python.client import session52from tensorflow.python.framework import graph_util53from tensorflow.python.framework import importer54from tensorflow.python.platform import app55from tensorflow.python.platform import gfile56from tensorflow.python.saved_model import loader57from tensorflow.python.saved_model import tag_constants58from tensorflow.python.training import saver as saver_lib5960FLAGS = None616263def freeze_graph_with_def_protos(input_graph_def,64input_saver_def,65input_checkpoint,66output_node_names,67restore_op_name,68filename_tensor_name,69output_graph,70clear_devices,71initializer_nodes,72variable_names_blacklist="",73input_meta_graph_def=None,74input_saved_model_dir=None,75saved_model_tags=None):76"""Converts all variables in a graph and checkpoint into constants."""77del restore_op_name, filename_tensor_name # Unused by updated loading code.7879# 'input_checkpoint' may be a prefix if we're using Saver V2 format80if (not input_saved_model_dir and81not saver_lib.checkpoint_exists(input_checkpoint)):82print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")83return -18485if not output_node_names:86print("You need to supply the name of a node to --output_node_names.")87return -18889# Remove all the explicit device specifications for this node. 
This helps to90# make the graph more portable.91if clear_devices:92if input_meta_graph_def:93for node in input_meta_graph_def.graph_def.node:94node.device = ""95elif input_graph_def:96for node in input_graph_def.node:97node.device = ""9899if input_graph_def:100_ = importer.import_graph_def(input_graph_def, name="")101with session.Session() as sess:102if input_saver_def:103saver = saver_lib.Saver(saver_def=input_saver_def)104saver.restore(sess, input_checkpoint)105elif input_meta_graph_def:106restorer = saver_lib.import_meta_graph(107input_meta_graph_def, clear_devices=True)108restorer.restore(sess, input_checkpoint)109if initializer_nodes:110sess.run(initializer_nodes.split(","))111elif input_saved_model_dir:112if saved_model_tags is None:113saved_model_tags = []114loader.load(sess, saved_model_tags, input_saved_model_dir)115else:116var_list = {}117reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)118var_to_shape_map = reader.get_variable_to_shape_map()119for key in var_to_shape_map:120try:121tensor = sess.graph.get_tensor_by_name(key + ":0")122except KeyError:123# This tensor doesn't exist in the graph (for example it's124# 'global_step' or a similar housekeeping element) so skip it.125continue126var_list[key] = tensor127saver = saver_lib.Saver(var_list=var_list)128saver.restore(sess, input_checkpoint)129if initializer_nodes:130sess.run(initializer_nodes.split(","))131132variable_names_blacklist = (variable_names_blacklist.split(",")133if variable_names_blacklist else None)134135if input_meta_graph_def:136output_graph_def = graph_util.convert_variables_to_constants(137sess,138input_meta_graph_def.graph_def,139output_node_names.split(","),140variable_names_blacklist=variable_names_blacklist)141else:142output_graph_def = graph_util.convert_variables_to_constants(143sess,144input_graph_def,145output_node_names.split(","),146variable_names_blacklist=variable_names_blacklist)147148# Write GraphDef to file if output path has been given.149if 
output_graph:150with gfile.GFile(output_graph, "wb") as f:151f.write(output_graph_def.SerializeToString())152153return output_graph_def154155156def _parse_input_graph_proto(input_graph, input_binary):157"""Parser input tensorflow graph into GraphDef proto."""158if not gfile.Exists(input_graph):159print("Input graph file '" + input_graph + "' does not exist!")160return -1161input_graph_def = graph_pb2.GraphDef()162mode = "rb" if input_binary else "r"163with gfile.FastGFile(input_graph, mode) as f:164if input_binary:165input_graph_def.ParseFromString(f.read())166else:167text_format.Merge(f.read(), input_graph_def)168return input_graph_def169170171def _parse_input_meta_graph_proto(input_graph, input_binary):172"""Parser input tensorflow graph into MetaGraphDef proto."""173if not gfile.Exists(input_graph):174print("Input meta graph file '" + input_graph + "' does not exist!")175return -1176input_meta_graph_def = MetaGraphDef()177mode = "rb" if input_binary else "r"178with gfile.FastGFile(input_graph, mode) as f:179if input_binary:180input_meta_graph_def.ParseFromString(f.read())181else:182text_format.Merge(f.read(), input_meta_graph_def)183print("Loaded meta graph file '" + input_graph)184return input_meta_graph_def185186187def _parse_input_saver_proto(input_saver, input_binary):188"""Parser input tensorflow Saver into SaverDef proto."""189if not gfile.Exists(input_saver):190print("Input saver file '" + input_saver + "' does not exist!")191return -1192mode = "rb" if input_binary else "r"193with gfile.FastGFile(input_saver, mode) as f:194saver_def = saver_pb2.SaverDef()195if input_binary:196saver_def.ParseFromString(f.read())197else:198text_format.Merge(f.read(), saver_def)199return saver_def200201202def get_meta_graph_def(saved_model_dir, tag_set):203"""Gets MetaGraphDef from SavedModel.204205Returns the MetaGraphDef for the given tag-set and SavedModel directory.206207Args:208saved_model_dir: Directory containing the SavedModel to inspect or execute.209tag_set: Group 
of tag(s) of the MetaGraphDef to load, in string format,210separated by ','. For tag-set contains multiple tags, all tags must be211passed in.212213Raises:214RuntimeError: An error when the given tag-set does not exist in the215SavedModel.216217Returns:218A MetaGraphDef corresponding to the tag-set.219"""220saved_model = reader.read_saved_model(saved_model_dir)221set_of_tags = set(tag_set.split(','))222for meta_graph_def in saved_model.meta_graphs:223if set(meta_graph_def.meta_info_def.tags) == set_of_tags:224return meta_graph_def225226raise RuntimeError('MetaGraphDef associated with tag-set ' + tag_set +227' could not be found in SavedModel')228229230def freeze_graph(input_graph,231input_saver,232input_binary,233input_checkpoint,234output_node_names,235restore_op_name,236filename_tensor_name,237output_graph,238clear_devices,239initializer_nodes,240variable_names_blacklist="",241input_meta_graph=None,242input_saved_model_dir=None,243saved_model_tags=tag_constants.SERVING):244"""Converts all variables in a graph and checkpoint into constants."""245input_graph_def = None246if input_saved_model_dir:247input_graph_def = get_meta_graph_def(248input_saved_model_dir, saved_model_tags).graph_def249elif input_graph:250input_graph_def = _parse_input_graph_proto(input_graph, input_binary)251input_meta_graph_def = None252if input_meta_graph:253input_meta_graph_def = _parse_input_meta_graph_proto(254input_meta_graph, input_binary)255input_saver_def = None256if input_saver:257input_saver_def = _parse_input_saver_proto(input_saver, input_binary)258freeze_graph_with_def_protos(259input_graph_def, input_saver_def, input_checkpoint, output_node_names,260restore_op_name, filename_tensor_name, output_graph, clear_devices,261initializer_nodes, variable_names_blacklist, input_meta_graph_def,262input_saved_model_dir, saved_model_tags.split(","))263264265