GitHub Repository: jianguanthu/storyendgen
Path: blob/master/main.py
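# main.py -- training / inference entry point for the IEMSA story-ending
# generation model (imported from model.py below). The script declares its
# hyper-parameters as TensorFlow 1.x flags, builds the vocabulary and GloVe
# embeddings, loads commonsense triples, and then either trains the model or
# generates endings for the test set.
#
# Illustrative invocations (paths are placeholders; assumes TF 1.x flag
# parsing from sys.argv):
#   python main.py --data_dir=./data --train_dir=./train       # training
#   python main.py --is_train=False --inference_version=0      # inference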
import numpy as np
import tensorflow as tf
import sys
import time
import random
from pattern.en import lemma
random.seed(time.time())

from model import IEMSAModel, _START_VOCAB

tf.app.flags.DEFINE_boolean("is_train", True, "Set to False to run inference.")
tf.app.flags.DEFINE_integer("symbols", 10000, "Vocabulary size.")
tf.app.flags.DEFINE_integer("embed_units", 200, "Size of word embedding.")
tf.app.flags.DEFINE_integer("units", 512, "Size of each model layer.")
tf.app.flags.DEFINE_integer("layers", 2, "Number of layers in the model.")
tf.app.flags.DEFINE_integer("batch_size", 128, "Batch size to use during training.")
tf.app.flags.DEFINE_string("data_dir", "./data", "Data directory.")
tf.app.flags.DEFINE_string("train_dir", "./train", "Training directory.")
tf.app.flags.DEFINE_integer("per_checkpoint", 1000, "How many steps to do per checkpoint.")
tf.app.flags.DEFINE_integer("inference_version", 0, "The checkpoint version to use for inference.")
tf.app.flags.DEFINE_integer("triple_num", 10, "Max number of triples kept for each query word.")
tf.app.flags.DEFINE_boolean("log_parameters", True, "Set to True to show the parameters.")
tf.app.flags.DEFINE_string("inference_path", "", "Output filename for inference; default is the screen.")

FLAGS = tf.app.flags.FLAGS

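# load_data pairs each line of <fname>.post (tab-separated context sentences,
# each whitespace-tokenized; the rest of the script assumes four sentences per
# story) with the corresponding line of <fname>.response (the tokenized ending).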
def load_data(path, fname):
    post = []
    with open('%s/%s.post' % (path, fname)) as f:
        for line in f:
            tmp = line.strip().split("\t")
            post.append([p.split() for p in tmp])

    with open('%s/%s.response' % (path, fname)) as f:
        response = [line.strip().split() for line in f.readlines()]
    data = []
    for p, r in zip(post, response):
        data.append({'post': p, 'response': r})
    return data

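# load_relation reads knowledge triples from triples_shrink.txt (head,
# relation, tail -- judging from how the fields are used), groups them by
# head word, and keeps at most FLAGS.triple_num triples per head, selected
# by the training-corpus frequency of the tail word; only tails present in
# the global vocab_dict are eligible.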
def load_relation(path):
    file = open('%s/triples_shrink.txt' % (path), "r")

    relation = {}
    for line in file:
        tmp = line.strip().split()
        if tmp[0] in relation:
            if tmp[2] not in relation[tmp[0]]:
                relation[tmp[0]].append(tmp)
        else:
            relation[tmp[0]] = [tmp]

    for r in relation.keys():
        tmp_vocab = {}
        i = 0
        for re in relation[r]:
            if re[2] in vocab_dict.keys():
                tmp_vocab[i] = vocab_dict[re[2]]
            i += 1
        tmp_list = sorted(tmp_vocab, key=tmp_vocab.get)[:FLAGS.triple_num] if len(tmp_vocab) > FLAGS.triple_num else sorted(tmp_vocab, key=tmp_vocab.get)
        new_relation = []
        for i in tmp_list:
            new_relation.append(relation[r][i])
        relation[r] = new_relation

    return relation

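# build_vocab counts tokens over the training posts and responses, prepends
# _START_VOCAB and the relation names from relations.txt, truncates the list
# to FLAGS.symbols entries, and looks up GloVe 200d vectors for each word
# (out-of-vocabulary words get zero vectors). It returns the vocabulary list,
# the embedding matrix, and the raw frequency dictionary.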
def build_vocab(path, data):
    print("Creating vocabulary...")

    relation_vocab_list = []
    relation_file = open(path + "/relations.txt", "r")
    for line in relation_file:
        relation_vocab_list += line.strip().split()

    vocab = {}
    for i, pair in enumerate(data):
        if i % 100000 == 0:
            print(" processing line %d" % i)
        for token in [word for p in pair['post'] for word in p] + pair['response']:
            if token in vocab:
                vocab[token] += 1
            else:
                vocab[token] = 1
    vocab_list = _START_VOCAB + relation_vocab_list + sorted(vocab, key=vocab.get, reverse=True)

    if len(vocab_list) > FLAGS.symbols:
        vocab_list = vocab_list[:FLAGS.symbols]

    print("Loading word vectors...")
    vectors = {}
    with open(path + '/glove.6B.200d.txt', 'r') as f:
        for i, line in enumerate(f):
            if i % 100000 == 0:
                print(" processing line %d" % i)
            s = line.strip()
            word = s[:s.find(' ')]
            vector = s[s.find(' ')+1:]
            vectors[word] = vector

    embed = []
    for word in vocab_list:
        if word in vectors:
            vector = map(float, vectors[word].split())
        else:
            vector = np.zeros((FLAGS.embed_units), dtype=np.float32)
        embed.append(vector)
    embed = np.array(embed, dtype=np.float32)
    return vocab_list, embed, vocab

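# gen_batched_data pads the four context sentences and the response of every
# item to the batch maxima (one '_EOS' followed by '_PAD's) and records the
# +1 lengths. It also attaches knowledge triples: each context word is
# lemmatized (pattern.en) and looked up in the global `relation` dict, falling
# back to a single '_NAF_H'/'_NAF_R'/'_NAF_T' placeholder triple. Triple lists
# are padded to a rectangular shape and 0/1 masks mark the placeholder entries.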
def gen_batched_data(data):
    encoder_len = [max([len(item['post'][i]) for item in data]) + 1 for i in range(4)]
    decoder_len = max([len(item['response']) for item in data]) + 1
    posts_1, posts_2, posts_3, posts_4, posts_length_1, posts_length_2, posts_length_3, posts_length_4, responses, responses_length = [], [], [], [], [], [], [], [], [], []

    def padding(sent, l):
        return sent + ['_EOS'] + ['_PAD'] * (l - len(sent) - 1)

    for item in data:
        posts_1.append(padding(item['post'][0], encoder_len[0]))
        posts_2.append(padding(item['post'][1], encoder_len[1]))
        posts_3.append(padding(item['post'][2], encoder_len[2]))
        posts_4.append(padding(item['post'][3], encoder_len[3]))

        posts_length_1.append(len(item['post'][0]) + 1)
        posts_length_2.append(len(item['post'][1]) + 1)
        posts_length_3.append(len(item['post'][2]) + 1)
        posts_length_4.append(len(item['post'][3]) + 1)

        responses.append(padding(item['response'], decoder_len))
        responses_length.append(len(item['response']) + 1)

    entity = [[], [], [], []]
    for item in data:
        for i in range(4):
            entity[i].append([])
            for word in item['post'][i]:
                try:
                    w = lemma(word).encode("ascii")
                except UnicodeDecodeError:
                    w = word
                if w in relation:
                    entity[i][-1].append(relation[w])
                else:
                    entity[i][-1].append([['_NAF_H', '_NAF_R', '_NAF_T']])
    max_response_length = [0, 0, 0, 0]
    max_triple_length = [0, 0, 0, 0]
    for i in range(4):
        for item in entity[i]:
            if len(item) > max_response_length[i]:
                max_response_length[i] = len(item)
            for triple in item:
                if len(triple) > max_triple_length[i]:
                    max_triple_length[i] = len(triple)
    for i in range(4):
        for j in range(len(entity[i])):
            for k in range(len(entity[i][j])):
                if len(entity[i][j][k]) < max_triple_length[i]:
                    entity[i][j][k] = entity[i][j][k] + [['_NAF_H', '_NAF_R', '_NAF_T']] * (max_triple_length[i] - len(entity[i][j][k]))
            if len(entity[i][j]) < (max_response_length[i] + 1):
                entity[i][j] = entity[i][j] + [[['_NAF_H', '_NAF_R', '_NAF_T']] * max_triple_length[i]] * (max_response_length[i] + 1 - len(entity[i][j]))

    entity_0, entity_1, entity_2, entity_3 = entity[0], entity[1], entity[2], entity[3]
    entity_mask = [[], [], [], []]
    for i in range(4):
        for j in range(len(entity[i])):
            entity_mask[i].append([])
            for k in range(len(entity[i][j])):
                entity_mask[i][-1].append([])
                for r in entity[i][j][k]:
                    if r[0] == '_NAF_H':
                        entity_mask[i][-1][-1].append(0)
                    else:
                        entity_mask[i][-1][-1].append(1)

    entity_mask_0, entity_mask_1, entity_mask_2, entity_mask_3 = entity_mask[0], entity_mask[1], entity_mask[2], entity_mask[3]

    batched_data = {'posts_1': np.array(posts_1),
                    'posts_2': np.array(posts_2),
                    'posts_3': np.array(posts_3),
                    'posts_4': np.array(posts_4),
                    'entity_1': np.array(entity_0),
                    'entity_2': np.array(entity_1),
                    'entity_3': np.array(entity_2),
                    'entity_4': np.array(entity_3),
                    'entity_mask_1': np.array(entity_mask_0),
                    'entity_mask_2': np.array(entity_mask_1),
                    'entity_mask_3': np.array(entity_mask_2),
                    'entity_mask_4': np.array(entity_mask_3),
                    'posts_length_1': posts_length_1,
                    'posts_length_2': posts_length_2,
                    'posts_length_3': posts_length_3,
                    'posts_length_4': posts_length_4,
                    'responses': np.array(responses),
                    'responses_length': responses_length}
    return batched_data

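# train runs one epoch over `dataset` in FLAGS.batch_size chunks, calling
# model.step_decoder for the parameter updates, bumps the model's epoch
# counter, and returns the mean batch loss. It reads the global `epoch` only
# for the progress printout.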
def train(model, sess, dataset):
    st, ed, loss = 0, 0, []
    while ed < len(dataset):
        print "epoch %d, training %.4f %%...\r" % (epoch, float(ed) / len(dataset) * 100),
        st, ed = ed, ed + FLAGS.batch_size if ed + \
            FLAGS.batch_size < len(dataset) else len(dataset)
        batch_data = gen_batched_data(dataset[st:ed])
        outputs = model.step_decoder(sess, batch_data)
        loss.append(outputs[0])

    sess.run(model.epoch_add_op)
    return np.mean(loss)

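# evaluate is the same loop as train but with forward_only=True, so the loss
# is computed without updating any parameters.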
def evaluate(model, sess, dataset):
    st, ed, loss = 0, 0, []
    while ed < len(dataset):
        print "epoch %d, evaluate %.4f %%...\r" % (epoch, float(ed) / len(dataset) * 100),
        st, ed = ed, ed + FLAGS.batch_size if ed + \
            FLAGS.batch_size < len(dataset) else len(dataset)
        batch_data = gen_batched_data(dataset[st:ed])
        outputs = model.step_decoder(sess, batch_data, forward_only=True)
        loss.append(outputs[0])
    return np.mean(loss)

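# inference feeds each test batch to the model, fetches the generated tokens
# ('generation:0') together with the attention alignments, and writes every
# generated ending (truncated at the first '_EOS') to
# ./output_<inference_version>.txt, one story per line.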
def inference(model, sess, dataset):
    st, ed, posts, truth, generations, alignments_2, alignments_3, alignments_4, alignments = 0, 0, [], [], [], [], [], [], []
    while ed < len(dataset):
        st, ed = ed, ed + FLAGS.batch_size if ed + \
            FLAGS.batch_size < len(dataset) else len(dataset)
        data = gen_batched_data(dataset[st:ed])
        outputs = sess.run(['generation:0', model.alignments_2, model.alignments_3, model.alignments_4, model.alignments],
                           {model.posts_1: data['posts_1'],
                            model.posts_2: data['posts_2'],
                            model.posts_3: data['posts_3'],
                            model.posts_4: data['posts_4'],
                            model.entity_1: data['entity_1'],
                            model.entity_2: data['entity_2'],
                            model.entity_3: data['entity_3'],
                            model.entity_4: data['entity_4'],
                            model.entity_mask_1: data['entity_mask_1'],
                            model.entity_mask_2: data['entity_mask_2'],
                            model.entity_mask_3: data['entity_mask_3'],
                            model.entity_mask_4: data['entity_mask_4'],
                            model.posts_length_1: data['posts_length_1'],
                            model.posts_length_2: data['posts_length_2'],
                            model.posts_length_3: data['posts_length_3'],
                            model.posts_length_4: data['posts_length_4']})
        generations.append(outputs[0])
        alignments_2.append(outputs[1])
        alignments_3.append(outputs[2])
        alignments_4.append(outputs[3])
        alignments.append(outputs[4])

        posts.append([d['post'] for d in dataset[st:ed]])
        truth.append([d['response'] for d in dataset[st:ed]])

    output_file = open("./output_" + str(FLAGS.inference_version) + ".txt", "w")

    for batch_generation in generations:
        for response in batch_generation:
            result = []
            for token in response:
                if token != '_EOS':
                    result.append(token)
                else:
                    break
            print >> output_file, ' '.join(result)
    return

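# Script body: with GPU memory growth enabled, either train (--is_train=True,
# the default) or generate endings for the test set. Training restores the
# latest checkpoint if one exists, otherwise initializes fresh parameters,
# then loops over epochs indefinitely, saving a checkpoint per epoch, decaying
# the learning rate whenever the epoch loss is worse than all of the previous
# three, and reporting train/val/test perplexity.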
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess:
    if FLAGS.is_train:
        data_train = load_data(FLAGS.data_dir, 'train')
        data_dev = load_data(FLAGS.data_dir, 'val')
        data_test = load_data(FLAGS.data_dir, 'test')
        vocab, embed, vocab_dict = build_vocab(FLAGS.data_dir, data_train)
        relation = load_relation(FLAGS.data_dir)

        model = IEMSAModel(
            FLAGS.symbols,
            FLAGS.embed_units,
            FLAGS.units,
            FLAGS.layers,
            is_train=True,
            vocab=vocab,
            embed=embed)

        if FLAGS.log_parameters:
            model.print_parameters()

        if tf.train.get_checkpoint_state(FLAGS.train_dir):
            print("Reading model parameters from %s" % FLAGS.train_dir)
            model.saver.restore(sess, tf.train.latest_checkpoint(FLAGS.train_dir))
            model.symbol2index.init.run()
        else:
            print("Created model with fresh parameters.")
            tf.global_variables_initializer().run()
            model.symbol2index.init.run()
        pre_losses = [1e18] * 3
        while True:
            epoch = model.epoch.eval()
            random.shuffle(data_train)
            start_time = time.time()
            loss = train(model, sess, data_train)
            model.saver.save(sess, '%s/checkpoint' %
                             FLAGS.train_dir, global_step=model.global_step)
            if loss > max(pre_losses):
                sess.run(model.learning_rate_decay_op)
            pre_losses = pre_losses[1:] + [loss]
            print "epoch %d learning rate %.4f epoch-time %.4f perplexity [%.8f]" \
                % (epoch, model.learning_rate.eval(), time.time() - start_time, np.exp(loss))

            loss = evaluate(model, sess, data_dev)
            print " val_set, perplexity [%.8f]" % np.exp(loss)
            loss = evaluate(model, sess, data_test)
            print " test_set, perplexity [%.8f]" % np.exp(loss)

    else:
        model = IEMSAModel(
            FLAGS.symbols,
            FLAGS.embed_units,
            FLAGS.units,
            FLAGS.layers,
            is_train=False,
            vocab=None)

        if FLAGS.log_parameters:
            model.print_parameters()

        if FLAGS.inference_version == 0:
            model_path = tf.train.latest_checkpoint(FLAGS.train_dir)
        else:
            model_path = '%s/checkpoint-%08d' % (
                FLAGS.train_dir, FLAGS.inference_version)
        print 'restore from %s' % model_path
        model.saver.restore(sess, model_path)
        model.symbol2index.init.run()

        data_train = load_data(FLAGS.data_dir, 'train')
        data_dev = load_data(FLAGS.data_dir, 'val')
        data_test = load_data(FLAGS.data_dir, 'test')
        vocab, embed, vocab_dict = build_vocab(FLAGS.data_dir, data_train)
        relation = load_relation(FLAGS.data_dir)

        inference(model, sess, data_test)