Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
zackhy
GitHub Repository: zackhy/TextClassification
Path: blob/master/rnn_classifier.py
1 views
1
# -*- coding: utf-8 -*-
2
import tensorflow as tf
3
4
class rnn_clf(object):
    """
    LSTM and Bi-LSTM classifiers for text classification.

    Builds a TensorFlow 1.x graph. Token ids are embedded and passed through
    input dropout. They are then encoded either by a stacked LSTM
    (``config.clf == 'lstm'``) or by a stacked bidirectional LSTM (any other
    ``config.clf`` value), and projected to ``num_classes`` logits.

    Args:
        config: object providing ``num_classes``, ``vocab_size``,
            ``hidden_size``, ``num_layers``, ``l2_reg_lambda`` and ``clf``.

    Graph endpoints set on the instance:
        batch_size, input_x, input_y, keep_prob, sequence_length: placeholders
            to feed.
        logits, predictions: unnormalized class scores and argmax class ids.
        cost: mean cross-entropy plus ``l2_reg_lambda`` times the L2 penalty.
        correct_num, accuracy: number and fraction of correct predictions.
    """

    def __init__(self, config):
        self.num_classes = config.num_classes
        self.vocab_size = config.vocab_size
        self.hidden_size = config.hidden_size
        self.num_layers = config.num_layers
        self.l2_reg_lambda = config.l2_reg_lambda

        # Placeholders
        self.batch_size = tf.placeholder(dtype=tf.int32, shape=[], name='batch_size')
        # Token ids, batch-major: the RNNs below run with time_major=False (their default).
        self.input_x = tf.placeholder(dtype=tf.int32, shape=[None, None], name='input_x')
        # One integer class id per example.
        self.input_y = tf.placeholder(dtype=tf.int64, shape=[None], name='input_y')
        self.keep_prob = tf.placeholder(dtype=tf.float32, shape=[], name='keep_prob')
        # Unpadded length of each example; the RNN stops updating its state past it.
        self.sequence_length = tf.placeholder(dtype=tf.int32, shape=[None], name='sequence_length')

        # Running L2 penalty: output-layer weight/bias plus LSTM kernels (added below).
        self.l2_loss = tf.constant(0.0)

        # Word embedding. Its width is hidden_size, so every stacked LSTM
        # layer receives inputs of the same width.
        with tf.device('/cpu:0'), tf.name_scope('embedding'):
            embedding = tf.get_variable('embedding',
                                        shape=[self.vocab_size, self.hidden_size],
                                        dtype=tf.float32)
            inputs = tf.nn.embedding_lookup(embedding, self.input_x)

        # Input dropout
        self.inputs = tf.nn.dropout(inputs, keep_prob=self.keep_prob)

        # Encoder. `final_output` is the sentence representation fed to the
        # softmax layer, and `output_size` is its width.
        if config.clf == 'lstm':
            # final_state is the per-layer state tuple; classify from the top layer's h.
            self.final_state = self.normal_lstm()
            final_output = self.final_state[self.num_layers - 1].h
            output_size = self.hidden_size
        else:
            # final_state is already [forward h ; backward h] of the top layer.
            self.final_state = self.bi_lstm()
            final_output = self.final_state
            output_size = 2 * self.hidden_size

        # Softmax output layer
        with tf.name_scope('softmax'):
            softmax_w = tf.get_variable('softmax_w', shape=[output_size, self.num_classes], dtype=tf.float32)
            softmax_b = tf.get_variable('softmax_b', shape=[self.num_classes], dtype=tf.float32)

            # L2 regularization for output layer
            self.l2_loss += tf.nn.l2_loss(softmax_w)
            self.l2_loss += tf.nn.l2_loss(softmax_b)

            self.logits = tf.matmul(final_output, softmax_w) + softmax_b
            # softmax is monotonic, so this argmax equals the argmax of the logits.
            predictions = tf.nn.softmax(self.logits)
            self.predictions = tf.argmax(predictions, 1, name='predictions')

        # Loss
        with tf.name_scope('loss'):
            # L2 regularization for LSTM weights. The embedding and softmax
            # variables are named without 'kernel', so they are not matched here.
            # NOTE(review): relies on LSTMCell naming its weight matrix "kernel"
            # (TF >= 1.1; earlier releases used "weights") - confirm for the pinned TF.
            for tv in tf.trainable_variables():
                if 'kernel' in tv.name:
                    self.l2_loss += tf.nn.l2_loss(tv)

            losses = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=self.input_y,
                                                                    logits=self.logits)
            self.cost = tf.reduce_mean(losses) + self.l2_reg_lambda * self.l2_loss

        # Accuracy
        with tf.name_scope('accuracy'):
            correct_predictions = tf.equal(self.predictions, self.input_y)
            self.correct_num = tf.reduce_sum(tf.cast(correct_predictions, tf.float32))
            self.accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32), name='accuracy')

    def normal_lstm(self):
        """
        Encode ``self.inputs`` with a stacked unidirectional LSTM.

        Also stores the zero initial state in ``self._initial_state``.

        Returns:
            The final state of the ``MultiRNNCell``: a tuple holding one
            ``LSTMStateTuple`` per layer.
        """
        cell = self._multi_lstm_cell()

        self._initial_state = cell.zero_state(self.batch_size, dtype=tf.float32)

        # Dynamic LSTM. With sequence_length set, the returned state is the one
        # at each example's last real token rather than at the padded end.
        with tf.variable_scope('LSTM'):
            _, final_state = tf.nn.dynamic_rnn(cell,
                                               inputs=self.inputs,
                                               initial_state=self._initial_state,
                                               sequence_length=self.sequence_length)

        return final_state

    def bi_lstm(self):
        """
        Encode ``self.inputs`` with a stacked bidirectional LSTM.

        The forward and backward directions each get their own
        ``num_layers``-deep stack. Their zero initial states are stored in
        ``self._initial_state_fw`` and ``self._initial_state_bw``.

        Returns:
            A tensor concatenating the top layer's forward and backward ``h``
            along axis 1, so its width is ``2 * hidden_size``.
        """
        cell_fw = self._multi_lstm_cell()
        cell_bw = self._multi_lstm_cell()

        self._initial_state_fw = cell_fw.zero_state(self.batch_size, dtype=tf.float32)
        self._initial_state_bw = cell_bw.zero_state(self.batch_size, dtype=tf.float32)

        # Dynamic Bi-LSTM; only the final (fw, bw) states are used, not per-step outputs.
        with tf.variable_scope('Bi-LSTM'):
            _, (state_fw, state_bw) = tf.nn.bidirectional_dynamic_rnn(
                cell_fw,
                cell_bw,
                inputs=self.inputs,
                initial_state_fw=self._initial_state_fw,
                initial_state_bw=self._initial_state_bw,
                sequence_length=self.sequence_length)

        top = self.num_layers - 1
        return tf.concat([state_fw[top].h, state_bw[top].h], 1)

    def _lstm_cell(self):
        """Build one LSTM layer with dropout (``keep_prob``) on its output."""
        cell = tf.contrib.rnn.LSTMCell(self.hidden_size,
                                       forget_bias=1.0,
                                       state_is_tuple=True,
                                       reuse=tf.get_variable_scope().reuse)
        # Add dropout to cell output
        return tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=self.keep_prob)

    def _multi_lstm_cell(self):
        """Stack ``num_layers`` independent dropout-wrapped LSTM layers."""
        # Build a fresh cell for every layer. `[cell] * num_layers` repeats a
        # single cell object: in TF >= 1.2 all layers then share one set of
        # weights, and in TF 1.1 graph construction fails with a scope-reuse error.
        return tf.contrib.rnn.MultiRNNCell(
            [self._lstm_cell() for _ in range(self.num_layers)],
            state_is_tuple=True)
143