Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
zackhy
GitHub Repository: zackhy/TextClassification
Path: blob/master/test.py
1 views
1
# -*- coding: utf-8 -*-
"""Evaluate a trained text-classification model on a held-out test set.

Restores the graph, weights, vocabulary processor and preprocessing
parameters saved by a training run (``run_dir``), scores the test file,
and writes per-example predictions to ``predictions.csv``.
"""
import os
import csv
import numpy as np
import pickle as pkl
import tensorflow as tf
from tensorflow.contrib import learn

import data_helper

# Show warnings and errors only (suppress TF C++ info logging).
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# File paths
tf.flags.DEFINE_string('test_data_file', None, 'Test data file path')
tf.flags.DEFINE_string('run_dir', None, 'Restore the model from this run')
tf.flags.DEFINE_string('checkpoint', None, 'Restore the graph from this checkpoint')

# Test batch size
tf.flags.DEFINE_integer('batch_size', 64, 'Test batch size')

# Use the same alias the flags were defined through.  tf.app.flags and
# tf.flags are the same module, but mixing the two spellings (as the
# original did) is confusing — keep it consistent with the DEFINE_* calls.
FLAGS = tf.flags.FLAGS
23
24
# Restore the hyper-parameters saved by the training run.
# NOTE(review): pickle.load executes arbitrary code on malicious input —
# only point run_dir at training runs you trust.
params_path = os.path.join(FLAGS.run_dir, 'params.pkl')
with open(params_path, 'rb') as params_file:
    params = pkl.load(params_file, encoding='bytes')

# Restore the fitted vocabulary processor so test text is mapped to the
# same token ids that were used during training.
vocab_path = os.path.join(FLAGS.run_dir, 'vocab')
vocab_processor = learn.preprocessing.VocabularyProcessor.restore(vocab_path)
30
31
# Load and preprocess the test set with the exact settings used at training
# time (stop-word file, frequency cutoff, padding length, language,
# vocabulary).  shuffle=False keeps predictions aligned with file order.
loader_kwargs = dict(
    file_path=FLAGS.test_data_file,
    sw_path=params['stop_word_file'],
    min_frequency=params['min_frequency'],
    max_length=params['max_length'],
    language=params['language'],
    vocab_processor=vocab_processor,
    shuffle=False,
)
data, labels, lengths, _ = data_helper.load_data(**loader_kwargs)
39
40
# Rebuild the computation graph from the saved metagraph and restore the
# trained weights into a fresh session.
graph = tf.Graph()
with graph.as_default():
    sess = tf.Session()
    checkpoint_prefix = os.path.join(FLAGS.run_dir, 'model', FLAGS.checkpoint)
    # Restore the graph structure first, then the variable values.
    saver = tf.train.import_meta_graph('{}.meta'.format(checkpoint_prefix))
    saver.restore(sess, checkpoint_prefix)

    # Look up the tensors needed for inference by the names they were given
    # at training time.
    input_x = graph.get_tensor_by_name('input_x:0')
    input_y = graph.get_tensor_by_name('input_y:0')
    keep_prob = graph.get_tensor_by_name('keep_prob:0')
    predictions = graph.get_tensor_by_name('softmax/predictions:0')
    accuracy = graph.get_tensor_by_name('accuracy/accuracy:0')
55
56
# Generate batches: one ordered pass over the test set.
batches = data_helper.batch_iter(data, labels, lengths, FLAGS.batch_size, 1)

is_cnn = params['clf'] == 'cnn'
if not is_cnn:
    # RNN-style graphs take two extra placeholders.  The lookups are
    # loop-invariant, so hoist them out of the batch loop.
    batch_size = graph.get_tensor_by_name('batch_size:0')
    sequence_length = graph.get_tensor_by_name('sequence_length:0')

all_predictions = []
sum_accuracy = 0
# Count batches as they are yielded instead of precomputing
# int(len(data)/FLAGS.batch_size): the precomputed value was 0 (and caused
# ZeroDivisionError) whenever len(data) < batch_size, and undercounted when
# batch_iter yields a final partial batch.
num_batches = 0

# Test
for batch in batches:
    x_test, y_test, x_lengths = batch
    feed_dict = {input_x: x_test, input_y: y_test, keep_prob: 1.0}
    if not is_cnn:
        # Feed the actual size of this batch (not FLAGS.batch_size) so a
        # final partial batch is handled correctly.
        feed_dict[batch_size] = len(x_test)
        feed_dict[sequence_length] = x_lengths
    # The sess.run was duplicated in both branches; run it once.
    batch_predictions, batch_accuracy = sess.run([predictions, accuracy], feed_dict)

    sum_accuracy += batch_accuracy
    num_batches += 1
    all_predictions = np.concatenate([all_predictions, batch_predictions])

# Unweighted mean of per-batch accuracies; guard the empty-test-set case.
final_accuracy = sum_accuracy / num_batches if num_batches else 0.0

# Print test accuracy
print('Test accuracy: {}'.format(final_accuracy))
83
84
# Write one (true label, predicted label) row per test example.
predictions_path = os.path.join(FLAGS.run_dir, 'predictions.csv')
with open(predictions_path, 'w', encoding='utf-8', newline='') as out_file:
    writer = csv.writer(out_file)
    writer.writerow(['True class', 'Prediction'])
    for i, predicted in enumerate(all_predictions):
        writer.writerow([labels[i], predicted])
print('Predictions saved to {}'.format(predictions_path))
91
92