Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
zackhy
GitHub Repository: zackhy/TextClassification
Path: blob/master/train.py
1 views
1
# -*- coding: utf-8 -*-
import os
import sys
import csv
import time
import json
import datetime
import pickle as pkl
import tensorflow as tf
from tensorflow.contrib import learn

import data_helper
from rnn_classifier import rnn_clf
from cnn_classifier import cnn_clf
from clstm_classifier import clstm_clf

try:
    from sklearn.model_selection import train_test_split
except ImportError as e:
    error = "Please install scikit-learn."
    print(str(e) + ': ' + error)
    # Exit with a non-zero status: a missing required dependency is a
    # failure. Bare sys.exit() exits with status 0, which shells and CI
    # would interpret as success.
    sys.exit(1)
23
24
# Show warnings and errors only
25
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
26
27
# Parameters
28
# =============================================================================
29
30
# Model choices
31
tf.flags.DEFINE_string('clf', 'cnn', "Type of classifiers. Default: cnn. You have four choices: [cnn, lstm, blstm, clstm]")
32
33
# Data parameters
34
tf.flags.DEFINE_string('data_file', None, 'Data file path')
35
tf.flags.DEFINE_string('stop_word_file', None, 'Stop word file path')
36
tf.flags.DEFINE_string('language', 'en', "Language of the data file. You have two choices: [ch, en]")
37
tf.flags.DEFINE_integer('min_frequency', 0, 'Minimal word frequency')
38
tf.flags.DEFINE_integer('num_classes', 2, 'Number of classes')
39
tf.flags.DEFINE_integer('max_length', 0, 'Max document length')
40
tf.flags.DEFINE_integer('vocab_size', 0, 'Vocabulary size')
41
tf.flags.DEFINE_float('test_size', 0.1, 'Cross validation test size')
42
43
# Model hyperparameters
44
tf.flags.DEFINE_integer('embedding_size', 256, 'Word embedding size. For CNN, C-LSTM.')
45
tf.flags.DEFINE_string('filter_sizes', '3, 4, 5', 'CNN filter sizes. For CNN, C-LSTM.')
46
tf.flags.DEFINE_integer('num_filters', 128, 'Number of filters per filter size. For CNN, C-LSTM.')
47
tf.flags.DEFINE_integer('hidden_size', 128, 'Number of hidden units in the LSTM cell. For LSTM, Bi-LSTM')
48
tf.flags.DEFINE_integer('num_layers', 2, 'Number of the LSTM cells. For LSTM, Bi-LSTM, C-LSTM')
49
tf.flags.DEFINE_float('keep_prob', 0.5, 'Dropout keep probability') # All
50
tf.flags.DEFINE_float('learning_rate', 1e-3, 'Learning rate') # All
51
tf.flags.DEFINE_float('l2_reg_lambda', 0.001, 'L2 regularization lambda') # All
52
53
# Training parameters
54
tf.flags.DEFINE_integer('batch_size', 32, 'Batch size')
55
tf.flags.DEFINE_integer('num_epochs', 50, 'Number of epochs')
56
tf.flags.DEFINE_float('decay_rate', 1, 'Learning rate decay rate. Range: (0, 1]') # Learning rate decay
57
tf.flags.DEFINE_integer('decay_steps', 100000, 'Learning rate decay steps') # Learning rate decay
58
tf.flags.DEFINE_integer('evaluate_every_steps', 100, 'Evaluate the model on validation set after this many steps')
59
tf.flags.DEFINE_integer('save_every_steps', 1000, 'Save the model after this many steps')
60
tf.flags.DEFINE_integer('num_checkpoint', 10, 'Number of models to store')
61
62
FLAGS = tf.app.flags.FLAGS

# Reconcile derived hyperparameters before the model is built.
# The parameter-printing code further down treats 'lstm' and 'blstm'
# identically (both report embedding_size == hidden_size), and rnn_clf
# serves both model types, so both must be adjusted here — the original
# code only handled 'lstm', leaving blstm's actual embedding_size (256)
# inconsistent with what gets printed and saved.
if FLAGS.clf == 'lstm' or FLAGS.clf == 'blstm':
    # RNN classifiers tie the word-embedding width to the LSTM hidden size.
    FLAGS.embedding_size = FLAGS.hidden_size
elif FLAGS.clf == 'clstm':
    # C-LSTM's hidden size is the concatenated CNN feature width:
    # (number of filter sizes) * (filters per size).
    FLAGS.hidden_size = len(FLAGS.filter_sizes.split(",")) * FLAGS.num_filters
68
69
# Output files directory
# =============================================================================

# Every run writes its artifacts (vocab, params, summaries, checkpoints)
# into its own timestamped directory under ./runs/.
timestamp = str(int(time.time()))
outdir = os.path.abspath(os.path.join(os.path.curdir, "runs", timestamp))
os.makedirs(outdir, exist_ok=True)

# Load and save data
# =============================================================================

data, labels, lengths, vocab_processor = data_helper.load_data(
    file_path=FLAGS.data_file,
    sw_path=FLAGS.stop_word_file,
    min_frequency=FLAGS.min_frequency,
    max_length=FLAGS.max_length,
    language=FLAGS.language,
    shuffle=True)

# Persist the fitted vocabulary processor alongside the model outputs so
# the same text-to-id mapping can be reused at prediction time.
vocab_processor.save(os.path.join(outdir, 'vocab'))

# Replace the placeholder flag values with what was actually derived
# from the data.
FLAGS.vocab_size = len(vocab_processor.vocabulary_._mapping)
FLAGS.max_length = vocab_processor.max_document_length
91
92
params = FLAGS.flag_values_dict()

# Print parameters
# Report only the hyperparameters that are meaningful for the chosen
# classifier; irrelevant ones are removed (or derived) per model.
model = params['clf']
if model == 'cnn':
    del params['hidden_size']
    del params['num_layers']
elif model == 'lstm' or model == 'blstm':
    del params['num_filters']
    del params['filter_sizes']
    params['embedding_size'] = params['hidden_size']
elif model == 'clstm':
    # Hidden size equals the concatenated CNN feature width:
    # (number of filter sizes) * (filters per size).
    params['hidden_size'] = len(params['filter_sizes'].split(",")) * params['num_filters']

print('Parameters:')
# sorted() on the items orders by key (keys are unique flag names).
for name, value in sorted(params.items()):
    print('{}: {}'.format(name, value))
print('')

# Save parameters to file
# A context manager guarantees the handle is closed even if pickling
# raises (the original open/close pair leaked the handle on error).
with open(os.path.join(outdir, 'params.pkl'), 'wb') as params_file:
    pkl.dump(params, params_file, True)
115
116
117
# Simple Cross validation
# Hold out FLAGS.test_size of the data as a validation set. Passing
# `lengths` as a third array keeps each document's true length aligned
# with its text and labels through the shuffle; the fixed random_state
# makes the split reproducible across runs.
x_train, x_valid, y_train, y_valid, train_lengths, valid_lengths = train_test_split(data,
                                                                                    labels,
                                                                                    lengths,
                                                                                    test_size=FLAGS.test_size,
                                                                                    random_state=22)
# Batch iterator
# Yields (x, y, lengths) tuples — the shape run_step() unpacks — for
# FLAGS.num_epochs passes over the training split.
train_data = data_helper.batch_iter(x_train, y_train, train_lengths, FLAGS.batch_size, FLAGS.num_epochs)
125
126
# Train
127
# =============================================================================
128
129
# Build the graph, train, periodically evaluate on the held-out split,
# and checkpoint the model under <outdir>/model/.
with tf.Graph().as_default():
    with tf.Session() as sess:
        # Instantiate the classifier selected by --clf; all four take the
        # full FLAGS object as their configuration.
        if FLAGS.clf == 'cnn':
            classifier = cnn_clf(FLAGS)
        elif FLAGS.clf == 'lstm' or FLAGS.clf == 'blstm':
            classifier = rnn_clf(FLAGS)
        elif FLAGS.clf == 'clstm':
            classifier = clstm_clf(FLAGS)
        else:
            raise ValueError('clf should be one of [cnn, lstm, blstm, clstm]')

        # Train procedure
        global_step = tf.Variable(0, name='global_step', trainable=False)
        # Learning rate decay (a no-op when decay_rate == 1, the default).
        starter_learning_rate = FLAGS.learning_rate
        learning_rate = tf.train.exponential_decay(starter_learning_rate,
                                                   global_step,
                                                   FLAGS.decay_steps,
                                                   FLAGS.decay_rate,
                                                   staircase=True)
        optimizer = tf.train.AdamOptimizer(learning_rate)
        grads_and_vars = optimizer.compute_gradients(classifier.cost)
        # apply_gradients increments global_step each training step.
        train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)

        # Summaries
        loss_summary = tf.summary.scalar('Loss', classifier.cost)
        accuracy_summary = tf.summary.scalar('Accuracy', classifier.accuracy)

        # Train summary
        # NOTE(review): both merge_all() calls below merge the same two
        # scalar summaries, so train_summary_op and valid_summary_op are
        # equivalent ops — only the writers (and hence output dirs) differ.
        train_summary_op = tf.summary.merge_all()
        train_summary_dir = os.path.join(outdir, 'summaries', 'train')
        train_summary_writer = tf.summary.FileWriter(train_summary_dir, sess.graph)

        # Validation summary
        valid_summary_op = tf.summary.merge_all()
        valid_summary_dir = os.path.join(outdir, 'summaries', 'valid')
        valid_summary_writer = tf.summary.FileWriter(valid_summary_dir, sess.graph)

        # Keep at most num_checkpoint recent checkpoints.
        saver = tf.train.Saver(max_to_keep=FLAGS.num_checkpoint)

        sess.run(tf.global_variables_initializer())


        def run_step(input_data, is_training=True):
            """Run one step of the training process.

            Args:
                input_data: a (input_x, input_y, sequence_length) tuple —
                    one batch of padded word-id sequences, labels, and the
                    documents' true lengths.
                is_training: when True, runs the optimizer with dropout
                    enabled and writes to the train summary stream; when
                    False, runs evaluation only (keep_prob forced to 1.0)
                    and writes to the validation summary stream.

            Returns:
                The batch accuracy reported by the classifier.
            """
            input_x, input_y, sequence_length = input_data

            fetches = {'step': global_step,
                       'cost': classifier.cost,
                       'accuracy': classifier.accuracy,
                       'learning_rate': learning_rate}
            feed_dict = {classifier.input_x: input_x,
                         classifier.input_y: input_y}

            # The RNN-based models additionally expose a final state and
            # need the batch size / true sequence lengths fed in; the CNN
            # works on the fixed-length padded input alone.
            if FLAGS.clf != 'cnn':
                fetches['final_state'] = classifier.final_state
                feed_dict[classifier.batch_size] = len(input_x)
                feed_dict[classifier.sequence_length] = sequence_length

            if is_training:
                fetches['train_op'] = train_op
                fetches['summaries'] = train_summary_op
                feed_dict[classifier.keep_prob] = FLAGS.keep_prob
            else:
                fetches['summaries'] = valid_summary_op
                feed_dict[classifier.keep_prob] = 1.0  # disable dropout for eval

            vars = sess.run(fetches, feed_dict)
            step = vars['step']
            cost = vars['cost']
            accuracy = vars['accuracy']
            summaries = vars['summaries']

            # Write summaries to file
            if is_training:
                train_summary_writer.add_summary(summaries, step)
            else:
                valid_summary_writer.add_summary(summaries, step)

            time_str = datetime.datetime.now().isoformat()
            print("{}: step: {}, loss: {:g}, accuracy: {:g}".format(time_str, step, cost, accuracy))

            return accuracy


        print('Start training ...')

        for train_input in train_data:
            run_step(train_input, is_training=True)
            current_step = tf.train.global_step(sess, global_step)

            # NOTE(review): the whole validation split is evaluated as a
            # single batch here — fine for small datasets, may exhaust
            # memory on large ones.
            if current_step % FLAGS.evaluate_every_steps == 0:
                print('\nValidation')
                run_step((x_valid, y_valid, valid_lengths), is_training=False)
                print('')

            if current_step % FLAGS.save_every_steps == 0:
                save_path = saver.save(sess, os.path.join(outdir, 'model/clf'), current_step)

        print('\nAll the files have been saved to {}\n'.format(outdir))
229
230