Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
Aniket025
GitHub Repository: Aniket025/Medical-Prescription-OCR
Path: blob/master/Model-4/models/graph_optimizer.py
427 views
1
"""
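Freeze the latest checkpoint of a TensorFlow model into a .pb graph, then
optimize that graph for mobile use with the graph_transforms
``transform_graph`` tool.

Usage:
    python graph_optimizer.py \
        --tf_path ../../tensorflow/ \
        --model "path_to_the_model_folder" \
        --output_names "activation,accuracy" \
        --input_names "x"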
"""
9
10
import os, argparse
11
from subprocess import call
12
13
import freeze_graph
14
import tensorflow as tf
15
16
# Absolute directory containing this script, with symlinks resolved.
# NOTE: the name shadows the builtin ``dir``; it is kept for compatibility
# and is not referenced in the visible portion of this file.
dir = os.path.split(os.path.realpath(__file__))[0]

# Suffixes appended to the checkpoint path to name the generated graphs.
fr_name = "_frozen.pb"      # output of the freeze step (graph_freez)
op_name = "_optimized.pb"   # output of the transform step (graph_optimization)
20
21
22
def graph_freez(model_folder, output_names):
    """Freeze the most recent checkpoint in *model_folder* into a .pb graph.

    Looks up the checkpoint state of ``model_folder`` and calls the local
    ``freeze_graph.freeze_graph`` with that checkpoint and its ``.meta``
    graph, writing the frozen graph next to the checkpoint.

    Args:
        model_folder: Directory handed to ``tf.train.get_checkpoint_state``.
        output_names: Comma-separated names of the output nodes to keep.
            Whitespace around each name is dropped, so "a, b" == "a,b".

    Returns:
        Path of the frozen graph: the checkpoint path plus ``fr_name``.

    Raises:
        ValueError: If no checkpoint state is found in *model_folder*.
    """
    print("Model folder", model_folder)
    checkpoint = tf.train.get_checkpoint_state(model_folder)
    print(checkpoint)
    if checkpoint is None or not checkpoint.model_checkpoint_path:
        # get_checkpoint_state returns None when no checkpoint state exists;
        # report that directly instead of failing with AttributeError on None.
        raise ValueError(
            "No TensorFlow checkpoint found in %r" % (model_folder,))

    checkpoint_path = checkpoint.model_checkpoint_path
    output_graph_filename = checkpoint_path + fr_name

    input_saver_def_path = ""
    input_binary = True
    # Trim whitespace so the "a, b" form yields node names without spaces.
    output_node_names = ",".join(
        name.strip() for name in output_names.split(",") if name.strip())
    restore_op_name = "save/restore_all"
    filename_tensor_name = "save/Const:0"
    clear_devices = False
    input_meta_graph = checkpoint_path + ".meta"

    # The first argument (input graph path) is left empty; presumably the
    # graph is loaded from input_meta_graph instead -- confirm against
    # freeze_graph. NOTE(review): arguments are positional, so check their
    # order against the local freeze_graph module's signature before editing.
    freeze_graph.freeze_graph(
        "", input_saver_def_path, input_binary, checkpoint_path,
        output_node_names, restore_op_name, filename_tensor_name,
        output_graph_filename, clear_devices, "", "", input_meta_graph)

    return output_graph_filename
43
44
45
def _normalize_node_names(names):
    """Return comma-separated *names* with whitespace around each name removed."""
    # NOTE(review): transform_graph does not appear to trim names itself, so
    # "a, b" (as in the module usage example) would look up a node " b".
    return ",".join(name.strip() for name in names.split(",") if name.strip())


def graph_optimization(tf_path, graph_file, input_names, output_names,
                       input_shape="1,299,299,3", input_type="float"):
    """Run TensorFlow's ``transform_graph`` tool on a frozen graph.

    Applies strip_unused_nodes, fold_constants, fold_batch_norms and
    fold_old_batch_norms and writes the result next to *graph_file*.

    Args:
        tf_path: Root of a TensorFlow checkout in which
            ``bazel-bin/tensorflow/tools/graph_transforms/transform_graph``
            has been built. A trailing path separator is optional.
        graph_file: Frozen graph to optimize (normally ends with ``fr_name``).
        input_names: Comma-separated input node names.
        output_names: Comma-separated output node names.
        input_shape: Comma-separated input shape given to strip_unused_nodes.
            Defaults to the previously hard-coded "1,299,299,3".
        input_type: Input data type given to strip_unused_nodes.

    Returns:
        Path of the optimized graph: *graph_file* with its ``fr_name`` (or
        ``.pb``) suffix replaced by ``op_name``.

    Raises:
        RuntimeError: If ``transform_graph`` exits with a non-zero status.
        OSError: If the ``transform_graph`` executable cannot be started.
    """
    # Only strip a suffix that is actually present; blindly slicing
    # len(fr_name) characters would mangle any other file name.
    if graph_file.endswith(fr_name):
        stem = graph_file[:-len(fr_name)]
    elif graph_file.endswith(".pb"):
        stem = graph_file[:-len(".pb")]
    else:
        stem = graph_file
    output_file = stem + op_name

    tool = os.path.join(
        tf_path, "bazel-bin/tensorflow/tools/graph_transforms/transform_graph")

    # transform_graph accepts whitespace-separated transforms.
    transforms = [
        'strip_unused_nodes(type=%s, shape="%s")' % (input_type, input_shape),
        "fold_constants(ignore_errors=true)",
        "fold_batch_norms",
        "fold_old_batch_norms",
    ]
    command = [
        tool,
        "--in_graph=" + graph_file,
        "--out_graph=" + output_file,
        "--inputs=" + _normalize_node_names(input_names),
        "--outputs=" + _normalize_node_names(output_names),
        "--transforms=\n" + "\n".join(transforms),
    ]

    return_code = call(command)
    if return_code != 0:
        # Surface the failure instead of silently leaving no optimized graph.
        raise RuntimeError(
            "transform_graph (%s) failed with exit code %d"
            % (tool, return_code))
    return output_file
59
60
61
if __name__ == '__main__':
    # Command-line entry point: freeze the model's checkpoint, then optimize
    # the frozen graph with transform_graph.
    parser = argparse.ArgumentParser(
        # Passed as description: the first positional argument of
        # ArgumentParser is `prog`, which would corrupt the usage line.
        description="Script freezes graph and optimize it for mobile usage")
    parser.add_argument(
        "--model", "--model_folder",
        dest="model",
        type=str,
        required=True,
        help="Folder that contains the model's TensorFlow checkpoint files")
    parser.add_argument(
        "--input_names",
        type=str,
        default="",
        help="Input node names, comma separated.")
    parser.add_argument(
        "--output_names",
        type=str,
        default="",
        help="Output node names, comma separated.")
    parser.add_argument(
        "--tf_path",
        type=str,
        default="../../tensorflow/",
        help="Path to the folder with tensorflow (requires bazel build of graph_transforms)")

    args = parser.parse_args()

    graph = graph_freez(args.model, args.output_names)
    graph_optimization(args.tf_path, graph, args.input_names, args.output_names)
88
89