Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
Tetragramm
GitHub Repository: Tetragramm/opencv
Path: blob/master/samples/dnn/shrink_tf_graph_weights.py
16337 views
1
# This file is part of OpenCV project.
2
# It is subject to the license terms in the LICENSE file found in the top-level directory
3
# of this distribution and at http://opencv.org/license.html.
4
#
5
# Copyright (C) 2017, Intel Corporation, all rights reserved.
6
# Third party copyrights are property of their respective owners.
7
import tensorflow as tf
8
import struct
9
import argparse
10
import numpy as np
11
12
# Command-line interface: --input and --output are mandatory paths, --ops
# takes one or more op type names and defaults to Conv2D and MatMul.
parser = argparse.ArgumentParser(
    description='Convert weights of a frozen TensorFlow graph to fp16.',
)
parser.add_argument('--input', help='Path to frozen graph.', required=True)
parser.add_argument('--output', help='Path to output graph.', required=True)
parser.add_argument(
    '--ops',
    nargs='+',
    default=['Conv2D', 'MatMul'],
    help='List of ops which weights are converted.',
)
args = parser.parse_args()
18
19
# TensorFlow DataType enum values (tensorflow/core/framework/types.proto),
# written into the dtype-related attributes of the rewritten graph nodes.
DT_FLOAT, DT_HALF = 1, 19  # 32-bit float, 16-bit IEEE half float
21
22
# In frozen graphs, every node that uses weights is connected to the Const node
# holding them through an Identity node, usually named after the Const with a
# '/read' suffix. We'll replace all of these Identity nodes with Cast nodes.
25
26
# Load the frozen model. A GraphDef file is a binary protobuf, so it must be
# opened in binary mode; text mode would try to decode the bytes as text and
# hand a str (not bytes) to ParseFromString.
with tf.gfile.FastGFile(args.input, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
30
31
# Set of the names of all nodes feeding data into the ops selected by --ops.
# An input reference may carry an output index ('name:0'); it is dropped so
# entries compare equal to node.name. Control dependencies ('^name') carry no
# data and can never equal a node name, so they are skipped.
inputs = set()
for node in graph_def.node:
    if node.op in args.ops:
        inputs.update(name.split(':')[0] for name in node.input
                      if not name.startswith('^'))
36
37
# Retype the fp32 Identity readers of the selected ops' inputs to Casts and
# remember the node each one reads from (presumably the weight Const — the
# code does not verify the op type of that node).
weightsNodes = set()
for node in graph_def.node:
    # From all collected inputs keep only the fp32 Identity nodes.
    if node.name in inputs and node.op == 'Identity' and node.attr['T'].type == DT_FLOAT:
        # Strip a possible ':N' output suffix so the name matches node.name.
        weightsNodes.add(node.input[0].split(':')[0])

        # Replace Identity with a Cast from fp16 back to fp32, so consumers
        # still receive fp32 data once the source node is stored as fp16.
        node.op = 'Cast'
        node.attr['DstT'].type = DT_FLOAT
        node.attr['SrcT'].type = DT_HALF
        del node.attr['T']
        # '_class' is not guaranteed to be present, and deleting a missing
        # key from a protobuf map raises KeyError.
        if '_class' in node.attr:
            del node.attr['_class']
49
50
# Convert the collected weight nodes from fp32 to fp16 in place.
# NOTE(review): every consumer of such a node now receives fp16 data; this
# assumes it is read only through the Identity nodes retyped to Cast —
# confirm for graphs that share weight constants between several readers.
for node in graph_def.node:
    if node.name not in weightsNodes:
        continue

    node.attr['dtype'].type = DT_HALF
    tensor = node.attr['value'].tensor
    tensor.dtype = DT_HALF

    if tensor.tensor_content:
        # Dense weights stored as raw fp32 bytes: decode them, round to fp16
        # and store the fp16 bytes back (same native byte order as before).
        floats = np.frombuffer(tensor.tensor_content, dtype=np.float32)
        tensor.tensor_content = floats.astype(np.float16).tobytes()
    else:
        # Values kept in the repeated float_val field instead. A DT_HALF
        # TensorProto stores its values in half_val as raw 16-bit patterns,
        # so move them there; leaving them in float_val would be invalid.
        halfs = np.asarray(tensor.float_val, dtype=np.float32).astype(np.float16)
        tensor.ClearField('float_val')
        tensor.half_val.extend(halfs.view(np.uint16).tolist())
61
62
# Serialize the modified GraphDef in binary form to the --output path.
tf.train.write_graph(graph_def, logdir="", name=args.output, as_text=False)
63
64