Path: blob/master/Model-3/ocr/tfhelpers.py
426 views
# -*- coding: utf-8 -*-
"""
Provide functions and classes:
Graph = Class for loading and using trained models from tensorflow
create_cell = function for creating RNN cells with wrappers
"""
import tensorflow as tf
from tensorflow.python.ops.rnn_cell_impl import LSTMCell, ResidualWrapper, DropoutWrapper, MultiRNNCell


class Graph:
    """Load a saved TensorFlow model into its own graph/session and run it.

    Each instance owns a separate ``tf.Graph`` and ``tf.Session``. Call
    ``close()`` (or use the instance in a ``with`` statement) to release the
    session once the model is no longer needed.
    """

    def __init__(self, loc, operation='activation', input_name='x'):
        """
        loc: location of file containing saved model
        operation: name of operation for running the model
        input_name: name of input placeholder
        """
        # Feeds are keyed by tensor name, i.e. "<placeholder name>:0".
        self.input = input_name + ":0"
        self.graph = tf.Graph()
        self.sess = tf.Session(graph=self.graph)
        try:
            with self.graph.as_default():
                saver = tf.train.import_meta_graph(loc + '.meta', clear_devices=True)
                saver.restore(self.sess, loc)
                # The first output tensor of the named operation is what gets run.
                self.op = self.graph.get_operation_by_name(operation).outputs[0]
        except Exception:
            # Do not leak the session when the model cannot be loaded.
            self.sess.close()
            raise

    def run(self, data):
        """ Run the specified operation on given data """
        return self.sess.run(self.op, feed_dict={self.input: data})

    def eval_feed(self, feed):
        """ Run the specified operation with given feed """
        return self.sess.run(self.op, feed_dict=feed)

    def close(self):
        """ Close the session owned by this instance """
        self.sess.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()


def create_single_cell(cell_fn, num_units, is_residual=False, is_dropout=False, keep_prob=None):
    """Create single RNN cell based on cell_fn.

    cell_fn: callable invoked as cell_fn(num_units) to build the base cell
    num_units: argument passed to cell_fn
    is_residual: wrap the cell in ResidualWrapper
    is_dropout: wrap the cell in DropoutWrapper (input dropout)
    keep_prob: input keep probability, required when is_dropout is set

    Dropout is applied first, so the residual wrapper (if any) is outermost.
    Raises ValueError when is_dropout is set but keep_prob is None.
    """
    if is_dropout and keep_prob is None:
        raise ValueError("keep_prob must be provided when is_dropout is True")

    cell = cell_fn(num_units)
    if is_dropout:
        cell = DropoutWrapper(cell, input_keep_prob=keep_prob)
    if is_residual:
        cell = ResidualWrapper(cell)
    return cell


def create_cell(num_units, num_layers, num_residual_layers, is_dropout=False, keep_prob=None, cell_fn=LSTMCell):
    """Create corresponding number of RNN cells with given wrappers.

    The last num_residual_layers of the num_layers cells are residual.
    Returns the bare cell when num_layers == 1, otherwise a MultiRNNCell
    stacking all created cells.
    """
    # Layers at index >= first_residual get the residual wrapper.
    first_residual = num_layers - num_residual_layers
    cell_list = [
        create_single_cell(
            cell_fn=cell_fn,
            num_units=num_units,
            is_residual=(i >= first_residual),
            is_dropout=is_dropout,
            keep_prob=keep_prob,
        )
        for i in range(num_layers)
    ]

    if num_layers == 1:
        return cell_list[0]
    return MultiRNNCell(cell_list)