Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
Aniket025
GitHub Repository: Aniket025/Medical-Prescription-OCR
Path: blob/master/Model-3/ocr/tfhelpers.py
426 views
1
# -*- coding: utf-8 -*-
2
"""
3
Provide functions and classes:
4
Graph = Class for loading and using trained models from tensorflow
5
create_cell = function for creating RNN cells with wrappers
6
"""
7
import tensorflow as tf
8
from tensorflow.python.ops.rnn_cell_impl import LSTMCell, ResidualWrapper, DropoutWrapper, MultiRNNCell
9
10
class Graph:
    """Load a saved TensorFlow model into an isolated graph and run it.

    Each instance owns its own ``tf.Graph`` and ``tf.Session``. Call
    ``close()`` (or use the instance as a context manager) to release the
    session once the model is no longer needed.
    """

    def __init__(self, loc, operation='activation', input_name='x'):
        """
        loc: checkpoint path prefix of the saved model; the graph definition
            is read from ``loc + '.meta'`` and the variables are restored
            from ``loc``
        operation: name of the operation whose first output is evaluated
        input_name: name of the input placeholder (fed as ``input_name + ':0'``)

        Raises whatever TensorFlow raises while importing/restoring the model;
        the session is closed before the exception propagates.
        """
        self.input = input_name + ":0"
        self.graph = tf.Graph()
        self.sess = tf.Session(graph=self.graph)
        try:
            with self.graph.as_default():
                # clear_devices strips device placements recorded in the saved
                # graph so it can be loaded on the devices available here.
                saver = tf.train.import_meta_graph(loc + '.meta', clear_devices=True)
                saver.restore(self.sess, loc)
                self.op = self.graph.get_operation_by_name(operation).outputs[0]
        except Exception:
            # The caller never receives an instance when loading fails, so the
            # session must be released here or it would leak.
            self.sess.close()
            raise

    def run(self, data):
        """Evaluate the output operation with ``data`` fed to the input placeholder."""
        return self.sess.run(self.op, feed_dict={self.input: data})

    def eval_feed(self, feed):
        """Evaluate the output operation with an explicit ``feed_dict``."""
        return self.sess.run(self.op, feed_dict=feed)

    def close(self):
        """Release the resources held by this model's session."""
        self.sess.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        # Returning False lets any exception from the with-body propagate.
        return False
def create_single_cell(cell_fn, num_units, is_residual=False, is_dropout=False, keep_prob=None):
    """Create one RNN cell with ``cell_fn`` and optionally wrap it.

    cell_fn: callable that builds the cell from ``num_units`` (e.g. ``LSTMCell``)
    num_units: value passed to ``cell_fn``
    is_residual: wrap the cell in a ``ResidualWrapper``
    is_dropout: wrap the cell in a ``DropoutWrapper`` that applies dropout to
        the cell's inputs
    keep_prob: input keep probability for the dropout wrapper; required when
        ``is_dropout`` is true

    Returns the (possibly wrapped) cell. When both wrappers are requested, the
    dropout wrapper is applied first, so the residual wrapper is outermost.

    Raises ValueError if ``is_dropout`` is true but ``keep_prob`` is None.
    """
    if is_dropout and keep_prob is None:
        # None is not a meaningful keep probability; reject it up front with a
        # clear message instead of handing it on to DropoutWrapper.
        raise ValueError("keep_prob must be provided when is_dropout is True")
    cell = cell_fn(num_units)
    if is_dropout:
        cell = DropoutWrapper(cell, input_keep_prob=keep_prob)
    if is_residual:
        cell = ResidualWrapper(cell)
    return cell
def create_cell(num_units, num_layers, num_residual_layers, is_dropout=False, keep_prob=None, cell_fn=LSTMCell):
    """Build a stack of ``num_layers`` RNN cells created with ``cell_fn``.

    The last ``num_residual_layers`` layers get a residual wrapper; when
    ``is_dropout`` is true, every layer gets input dropout with ``keep_prob``.

    Returns the lone cell when ``num_layers`` is 1, otherwise a
    ``MultiRNNCell`` stacking all layers in order.
    """
    # Layers with index at or above this threshold are residual.
    first_residual = num_layers - num_residual_layers
    layers = [
        create_single_cell(
            cell_fn=cell_fn,
            num_units=num_units,
            is_residual=layer >= first_residual,
            is_dropout=is_dropout,
            keep_prob=keep_prob,
        )
        for layer in range(num_layers)
    ]
    return layers[0] if num_layers == 1 else MultiRNNCell(layers)