Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
Aniket025
GitHub Repository: Aniket025/Medical-Prescription-OCR
Path: blob/master/Model-4/models/freeze_graph.py
427 views
1
# EDITED on 10. 9. 2017 for meta graph freezing
2
#
3
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
4
#
5
# Licensed under the Apache License, Version 2.0 (the "License");
6
# you may not use this file except in compliance with the License.
7
# You may obtain a copy of the License at
8
#
9
# http://www.apache.org/licenses/LICENSE-2.0
10
#
11
# Unless required by applicable law or agreed to in writing, software
12
# distributed under the License is distributed on an "AS IS" BASIS,
13
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
# See the License for the specific language governing permissions and
15
# limitations under the License.
16
# ==============================================================================
17
"""Converts checkpoint variables into Const ops in a standalone GraphDef file.
18
19
This script is designed to take a GraphDef proto, a SaverDef proto, and a set of
20
variable values stored in a checkpoint file, and output a GraphDef with all of
21
the variable ops converted into const ops containing the values of the
22
variables.
23
24
It's useful to do this when we need to load a single file in C++, especially in
25
environments like mobile or embedded where we may not have access to the
26
RestoreTensor ops and file loading calls that they rely on.
27
28
An example of command-line usage is:
29
bazel build tensorflow/python/tools:freeze_graph && \
30
bazel-bin/tensorflow/python/tools/freeze_graph \
31
--input_graph=some_graph_def.pb \
32
--input_checkpoint=model.ckpt-8361242 \
33
--output_graph=/tmp/frozen_graph.pb --output_node_names=softmax
34
35
You can also look at freeze_graph_test.py for an example of how to use it.
36
37
"""
38
from __future__ import absolute_import
39
from __future__ import division
40
from __future__ import print_function
41
42
import argparse
43
import sys
44
45
from google.protobuf import text_format
46
47
from tensorflow.contrib.saved_model.python.saved_model import reader
48
from tensorflow.core.framework import graph_pb2
49
from tensorflow.core.protobuf import saver_pb2
50
from tensorflow.core.protobuf.meta_graph_pb2 import MetaGraphDef
51
from tensorflow.python import pywrap_tensorflow
52
from tensorflow.python.client import session
53
from tensorflow.python.framework import graph_util
54
from tensorflow.python.framework import importer
55
from tensorflow.python.platform import app
56
from tensorflow.python.platform import gfile
57
from tensorflow.python.saved_model import loader
58
from tensorflow.python.saved_model import tag_constants
59
from tensorflow.python.training import saver as saver_lib
60
61
# Module-level holder for parsed command-line flags. It stays None here;
# NOTE(review): presumably assigned by the script entry point, which is not
# visible in this part of the file -- confirm.
FLAGS = None
62
63
64
def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_blacklist="",
                                 input_meta_graph_def=None,
                                 input_saved_model_dir=None,
                                 saved_model_tags=None):
  """Restores variable values and rewrites the graph with them as constants.

  Exactly one restore strategy is used, chosen in this order: an explicit
  SaverDef, a MetaGraphDef, a SavedModel directory, and finally a Saver built
  from whichever checkpoint variables also exist in the imported graph.

  Args:
    input_graph_def: GraphDef proto. When truthy it is imported into the
      default graph, and it is the graph that gets frozen unless
      `input_meta_graph_def` is given.
    input_saver_def: Optional SaverDef proto used to build the restoring Saver.
    input_checkpoint: Checkpoint path (or prefix) to restore values from.
    output_node_names: Comma-separated names of the output nodes to keep.
    restore_op_name: Unused; accepted for interface compatibility.
    filename_tensor_name: Unused; accepted for interface compatibility.
    output_graph: Path to write the serialized frozen GraphDef to. Nothing is
      written when this is falsy.
    clear_devices: When truthy, every node's `device` field is blanked.
    initializer_nodes: Comma-separated names of nodes to run after restoring
      (only in the MetaGraphDef and fallback strategies).
    variable_names_blacklist: Comma-separated variable names forwarded to
      `convert_variables_to_constants`; empty means no blacklist.
    input_meta_graph_def: Optional MetaGraphDef proto to import and freeze.
    input_saved_model_dir: Optional SavedModel directory to load instead of a
      checkpoint.
    saved_model_tags: Tags passed to the SavedModel loader; None is treated as
      an empty list.

  Returns:
    The frozen GraphDef, or -1 when the checkpoint does not exist or no output
    node names were supplied.
  """
  # These two are kept in the signature only; the loading code never reads them.
  del restore_op_name, filename_tensor_name

  # 'input_checkpoint' may be a prefix if we're using Saver V2 format, so the
  # saver library is asked whether it exists. A SavedModel needs no checkpoint.
  checkpoint_required = not input_saved_model_dir
  if checkpoint_required and not saver_lib.checkpoint_exists(input_checkpoint):
    print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
    return -1

  if not output_node_names:
    print("You need to supply the name of a node to --output_node_names.")
    return -1

  # Dropping explicit device placements makes the frozen graph more portable.
  if clear_devices:
    if input_meta_graph_def:
      nodes_to_clear = input_meta_graph_def.graph_def.node
    elif input_graph_def:
      nodes_to_clear = input_graph_def.node
    else:
      nodes_to_clear = ()
    for graph_node in nodes_to_clear:
      graph_node.device = ""

  if input_graph_def:
    importer.import_graph_def(input_graph_def, name="")

  with session.Session() as sess:
    if input_saver_def:
      saver_lib.Saver(saver_def=input_saver_def).restore(sess, input_checkpoint)
    elif input_meta_graph_def:
      meta_saver = saver_lib.import_meta_graph(
          input_meta_graph_def, clear_devices=True)
      meta_saver.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(initializer_nodes.split(","))
    elif input_saved_model_dir:
      tags = [] if saved_model_tags is None else saved_model_tags
      loader.load(sess, tags, input_saved_model_dir)
    else:
      # No saver information at all: restore every checkpoint variable that
      # has a matching tensor in the session graph.
      ckpt_reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
      restorable = {}
      for var_name in ckpt_reader.get_variable_to_shape_map():
        try:
          restorable[var_name] = sess.graph.get_tensor_by_name(var_name + ":0")
        except KeyError:
          # This tensor doesn't exist in the graph (for example it's
          # 'global_step' or a similar housekeeping element) so skip it.
          continue
      saver_lib.Saver(var_list=restorable).restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(initializer_nodes.split(","))

    blacklist = (variable_names_blacklist.split(",")
                 if variable_names_blacklist else None)
    # The meta graph, when present, takes precedence as the graph to freeze.
    if input_meta_graph_def:
      graph_def_to_freeze = input_meta_graph_def.graph_def
    else:
      graph_def_to_freeze = input_graph_def
    output_graph_def = graph_util.convert_variables_to_constants(
        sess,
        graph_def_to_freeze,
        output_node_names.split(","),
        variable_names_blacklist=blacklist)

  # Write GraphDef to file if output path has been given.
  if output_graph:
    with gfile.GFile(output_graph, "wb") as out_file:
      out_file.write(output_graph_def.SerializeToString())

  return output_graph_def
155
156
157
def _parse_input_graph_proto(input_graph, input_binary):
  """Parses a TensorFlow graph file into a GraphDef proto.

  Args:
    input_graph: Path of the graph file to read.
    input_binary: True if the file holds a serialized (binary) proto, False if
      it holds the text format.

  Returns:
    The parsed GraphDef, or -1 (after printing a message) when the file does
    not exist.
  """
  if not gfile.Exists(input_graph):
    print("Input graph file '" + input_graph + "' does not exist!")
    return -1

  parsed_graph = graph_pb2.GraphDef()
  if input_binary:
    with gfile.FastGFile(input_graph, "rb") as graph_file:
      parsed_graph.ParseFromString(graph_file.read())
  else:
    with gfile.FastGFile(input_graph, "r") as graph_file:
      text_format.Merge(graph_file.read(), parsed_graph)
  return parsed_graph
170
171
172
def _parse_input_meta_graph_proto(input_graph, input_binary):
  """Parses a TensorFlow meta graph file into a MetaGraphDef proto.

  Args:
    input_graph: Path of the meta graph file to read.
    input_binary: True if the file holds a serialized (binary) proto, False if
      it holds the text format.

  Returns:
    The parsed MetaGraphDef, or -1 (after printing a message) when the file
    does not exist.
  """
  if not gfile.Exists(input_graph):
    print("Input meta graph file '" + input_graph + "' does not exist!")
    return -1
  input_meta_graph_def = MetaGraphDef()
  mode = "rb" if input_binary else "r"
  with gfile.FastGFile(input_graph, mode) as f:
    if input_binary:
      input_meta_graph_def.ParseFromString(f.read())
    else:
      text_format.Merge(f.read(), input_meta_graph_def)
  # The closing quote was missing, leaving the path with an unbalanced "'".
  print("Loaded meta graph file '" + input_graph + "'")
  return input_meta_graph_def
186
187
188
def _parse_input_saver_proto(input_saver, input_binary):
  """Parses a TensorFlow saver file into a SaverDef proto.

  Args:
    input_saver: Path of the saver file to read.
    input_binary: True if the file holds a serialized (binary) proto, False if
      it holds the text format.

  Returns:
    The parsed SaverDef, or -1 (after printing a message) when the file does
    not exist.
  """
  if not gfile.Exists(input_saver):
    print("Input saver file '" + input_saver + "' does not exist!")
    return -1

  open_mode = "r"
  if input_binary:
    open_mode = "rb"
  with gfile.FastGFile(input_saver, open_mode) as saver_file:
    parsed_saver = saver_pb2.SaverDef()
    contents = saver_file.read()
    if input_binary:
      parsed_saver.ParseFromString(contents)
    else:
      text_format.Merge(contents, parsed_saver)
  return parsed_saver
201
202
203
def get_meta_graph_def(saved_model_dir, tag_set):
  """Looks up the MetaGraphDef of a SavedModel by its exact tag-set.

  A MetaGraphDef matches only when its set of tags equals the requested set;
  the first match in the SavedModel's order is returned.

  Args:
    saved_model_dir: Directory containing the SavedModel to read.
    tag_set: Tag(s) of the MetaGraphDef to load, as a single string with tags
      separated by ','. Every tag of the wanted MetaGraphDef must be listed.

  Returns:
    The MetaGraphDef whose tags equal the requested tag-set.

  Raises:
    RuntimeError: If no MetaGraphDef in the SavedModel has that tag-set.
  """
  saved_model_proto = reader.read_saved_model(saved_model_dir)
  requested_tags = set(tag_set.split(','))

  matches = (candidate for candidate in saved_model_proto.meta_graphs
             if set(candidate.meta_info_def.tags) == requested_tags)
  found = next(matches, None)
  if found is None:
    raise RuntimeError('MetaGraphDef associated with tag-set ' + tag_set +
                       ' could not be found in SavedModel')
  return found
229
230
231
def freeze_graph(input_graph,
                 input_saver,
                 input_binary,
                 input_checkpoint,
                 output_node_names,
                 restore_op_name,
                 filename_tensor_name,
                 output_graph,
                 clear_devices,
                 initializer_nodes,
                 variable_names_blacklist="",
                 input_meta_graph=None,
                 input_saved_model_dir=None,
                 saved_model_tags=tag_constants.SERVING):
  """Converts all variables in a graph and checkpoint into constants.

  Loads the input protos from disk (or from a SavedModel) and hands them to
  `freeze_graph_with_def_protos`, which does the actual freezing.

  Args:
    input_graph: Path of a GraphDef file. Ignored when `input_saved_model_dir`
      is given.
    input_saver: Optional path of a SaverDef file.
    input_binary: True if the input files are serialized (binary) protos,
      False if they are in text format.
    input_checkpoint: Checkpoint path (or prefix) holding the variable values.
    output_node_names: Comma-separated names of the output nodes.
    restore_op_name: Forwarded unchanged; unused by the loading code.
    filename_tensor_name: Forwarded unchanged; unused by the loading code.
    output_graph: Path the frozen GraphDef is written to.
    clear_devices: Whether to strip device placements from the graph nodes.
    initializer_nodes: Comma-separated names of initializer nodes to run.
    variable_names_blacklist: Comma-separated variable names to leave
      unconverted.
    input_meta_graph: Optional path of a MetaGraphDef file.
    input_saved_model_dir: Optional SavedModel directory; when given, its graph
      is used instead of `input_graph`.
    saved_model_tags: Comma-separated tag-set identifying the MetaGraphDef in
      the SavedModel. None is tolerated when no SavedModel is used.

  Returns:
    Whatever `freeze_graph_with_def_protos` returns (the frozen GraphDef, or -1
    on failure), or -1 when one of the input files does not exist.
  """
  # The _parse_* helpers report a missing file by returning the int -1 rather
  # than a proto. That value is truthy, so it must be caught here; otherwise it
  # would be passed on as if it were a real proto and fail further down.
  input_graph_def = None
  if input_saved_model_dir:
    input_graph_def = get_meta_graph_def(
        input_saved_model_dir, saved_model_tags).graph_def
  elif input_graph:
    input_graph_def = _parse_input_graph_proto(input_graph, input_binary)
    if isinstance(input_graph_def, int):
      return -1

  input_meta_graph_def = None
  if input_meta_graph:
    input_meta_graph_def = _parse_input_meta_graph_proto(
        input_meta_graph, input_binary)
    if isinstance(input_meta_graph_def, int):
      return -1

  input_saver_def = None
  if input_saver:
    input_saver_def = _parse_input_saver_proto(input_saver, input_binary)
    if isinstance(input_saver_def, int):
      return -1

  # freeze_graph_with_def_protos treats None as "no tags", so only split when
  # there is a string to split.
  tag_list = None
  if saved_model_tags is not None:
    tag_list = saved_model_tags.split(",")

  # Return the result so callers can tell success (a GraphDef) from failure.
  return freeze_graph_with_def_protos(
      input_graph_def, input_saver_def, input_checkpoint, output_node_names,
      restore_op_name, filename_tensor_name, output_graph, clear_devices,
      initializer_nodes, variable_names_blacklist, input_meta_graph_def,
      input_saved_model_dir, tag_list)
264
265