Path: blob/master/Model-4/models/graph_optimizer.py
427 views
r"""Freeze a TensorFlow checkpoint and optimize the frozen graph for mobile.

The latest checkpoint listed in ``--model`` is frozen into
``<checkpoint_path>_frozen.pb``. TensorFlow's ``transform_graph`` tool then
rewrites it into ``<checkpoint_path>_optimized.pb``. The tool must already be
built with bazel inside ``--tf_path``.

Usage:
python graph_optimizer.py \
--tf_path ../../tensorflow/ \
--model "path_to_the_model_folder" \
--output_names "activation,accuracy" \
--input_names "x"
"""

import argparse
import os
from subprocess import CalledProcessError, call

import freeze_graph
import tensorflow as tf

# Suffixes appended to the checkpoint path for the two generated graph files.
fr_name = "_frozen.pb"
op_name = "_optimized.pb"

# Location of the transform_graph binary relative to a bazel-built TF tree.
_TRANSFORM_GRAPH_BIN = "bazel-bin/tensorflow/tools/graph_transforms/transform_graph"


def _normalize_node_names(names):
    """Strip whitespace around each entry of a comma-separated name list.

    ``"activation, accuracy"`` becomes ``"activation,accuracy"``. Empty
    entries are dropped, so a stray or trailing comma cannot produce a blank
    node name.
    """
    return ",".join(part.strip() for part in names.split(",") if part.strip())


def graph_freez(model_folder, output_names):
    """Freeze the latest checkpoint in *model_folder* into a GraphDef file.

    Args:
        model_folder: directory containing the TensorFlow ``checkpoint``
            state file, as read by ``tf.train.get_checkpoint_state``.
        output_names: comma-separated names of the output nodes to keep.

    Returns:
        Path of the written frozen graph (``<checkpoint_path>_frozen.pb``).

    Raises:
        ValueError: if no checkpoint is found in *model_folder*.
    """
    print("Model folder", model_folder)
    checkpoint = tf.train.get_checkpoint_state(model_folder)
    print(checkpoint)
    # get_checkpoint_state returns None when the folder has no checkpoint
    # state file; report that clearly instead of raising AttributeError below.
    if checkpoint is None or not checkpoint.model_checkpoint_path:
        raise ValueError(
            "No checkpoint found in model folder: {}".format(model_folder))

    checkpoint_path = checkpoint.model_checkpoint_path
    output_graph_filename = checkpoint_path + fr_name

    input_saver_def_path = ""
    input_binary = True
    output_node_names = _normalize_node_names(output_names)
    restore_op_name = "save/restore_all"
    filename_tensor_name = "save/Const:0"
    clear_devices = False
    initializer_nodes = ""
    # The graph structure is read from the .meta file written next to the
    # checkpoint, so no separate input_graph file is passed.
    input_meta_graph = checkpoint_path + ".meta"

    # input_meta_graph is passed by keyword. Newer freeze_graph versions add
    # variable_names_whitelist before it, so a positional argument would
    # land in the wrong parameter.
    freeze_graph.freeze_graph(
        "", input_saver_def_path, input_binary, checkpoint_path,
        output_node_names, restore_op_name, filename_tensor_name,
        output_graph_filename, clear_devices, initializer_nodes,
        input_meta_graph=input_meta_graph)

    return output_graph_filename


def graph_optimization(tf_path, graph_file, input_names, output_names,
                       input_shape="1,299,299,3"):
    """Run TensorFlow's ``transform_graph`` tool on a frozen graph.

    Args:
        tf_path: root of a TensorFlow source tree in which ``transform_graph``
            has been built with bazel. A trailing slash is optional.
        graph_file: frozen graph path, normally the value returned by
            ``graph_freez`` (ending in ``_frozen.pb``).
        input_names: comma-separated input node names.
        output_names: comma-separated output node names.
        input_shape: comma-separated shape passed to ``strip_unused_nodes``.
            Defaults to the previously hard-coded ``"1,299,299,3"``.

    Returns:
        Path of the optimized graph file (``..._optimized.pb``).

    Raises:
        subprocess.CalledProcessError: if ``transform_graph`` exits non-zero.
    """
    # Swap the "_frozen.pb" suffix for "_optimized.pb". If the suffix is
    # absent, append rather than chopping unrelated characters off the path.
    if graph_file.endswith(fr_name):
        base = graph_file[:-len(fr_name)]
    else:
        base = graph_file
    output_file = base + op_name

    # os.path.join works whether or not tf_path ends with a slash.
    transform_bin = os.path.join(tf_path, _TRANSFORM_GRAPH_BIN)

    # transform_graph accepts whitespace-separated transforms.
    transforms = (
        'strip_unused_nodes(type=float, shape="{}") '
        "fold_constants(ignore_errors=true) "
        "fold_batch_norms "
        "fold_old_batch_norms").format(input_shape)

    command = [
        transform_bin,
        "--in_graph=" + graph_file,
        "--out_graph=" + output_file,
        "--inputs=" + _normalize_node_names(input_names),
        "--outputs=" + _normalize_node_names(output_names),
        "--transforms=" + transforms,
    ]
    # call() reports failure only through its return code, so surface it
    # instead of silently continuing without an optimized graph.
    return_code = call(command)
    if return_code != 0:
        raise CalledProcessError(return_code, command)
    return output_file


if __name__ == '__main__':
    # The text must go in `description`; the first positional argument of
    # ArgumentParser is `prog`, which would corrupt the usage line.
    parser = argparse.ArgumentParser(
        description="Freeze a graph and optimize it for mobile usage.")
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="Folder containing the model's TensorFlow 'checkpoint' file")
    parser.add_argument(
        "--input_names",
        type=str,
        default="",
        help="Input node names, comma separated.")
    parser.add_argument(
        "--output_names",
        type=str,
        required=True,
        help="Output node names, comma separated.")
    parser.add_argument(
        "--tf_path",
        type=str,
        default="../../tensorflow/",
        help="Path to the folder with tensorflow (requires bazel build of graph_transforms)")
    parser.add_argument(
        "--input_shape",
        type=str,
        default="1,299,299,3",
        help="Input shape for strip_unused_nodes, comma separated.")

    args = parser.parse_args()

    graph = graph_freez(args.model, args.output_names)
    optimized = graph_optimization(args.tf_path, graph, args.input_names,
                                   args.output_names, args.input_shape)
    print("Optimized graph", optimized)