Source: model.py from the GitHub repository JianGuanTHU/StoryEndGen
(captured via a CoCalc share page; site navigation text removed).
1
import numpy as np
2
import tensorflow as tf
3
4
from tensorflow.python.ops.nn import dynamic_rnn
5
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import GRUCell, LSTMCell, MultiRNNCell
6
import attention_decoder_fn
7
from tensorflow.contrib.seq2seq.python.ops.seq2seq import dynamic_rnn_decoder
8
from tensorflow.contrib.seq2seq.python.ops.loss import sequence_loss
9
from tensorflow.contrib.lookup.lookup_ops import HashTable, KeyValueTensorInitializer
10
from tensorflow.contrib.layers.python.layers import layers
11
from output_projection import output_projection_layer
12
from tensorflow.python.ops import variable_scope
13
14
PAD_ID = 0
15
UNK_ID = 1
16
GO_ID = 2
17
EOS_ID = 3
18
_START_VOCAB = ['_PAD', '_UNK', '_GO', '_EOS', '_NAF_H', '_NAF_R', '_NAF_T']
19
20
class IEMSAModel(object):
    """Seq2seq story-ending generation model with incremental context
    encoding and graph attention over (head, relation, tail) entity triples.

    Four context sentences (``posts_1`` .. ``posts_4``) are encoded in a
    chain: sentence k+1 is decoded (teacher-forced) from the final state of
    sentence k with Luong attention, and an auxiliary sequence loss is placed
    on each intermediate sentence.  The ending decoder then attends over the
    last context encoding together with a graph embedding computed from the
    KB triples attached to each token.

    NOTE(review): built on TensorFlow 1.x contrib APIs (static graph +
    placeholders).  This source was recovered from a formatting-mangled web
    scrape; the nesting of the ``variable_scope('', reuse=True)`` regions was
    reconstructed from the variable-reuse requirements of the repeated
    ``scope="decoder"`` / output-projection calls — confirm against the
    upstream repository.
    """

    def __init__(self,
                 num_symbols,
                 num_embed_units,
                 num_units,
                 num_layers,
                 is_train,
                 vocab=None,
                 embed=None,
                 learning_rate=0.1,
                 learning_rate_decay_factor=0.95,
                 max_gradient_norm=5.0,
                 num_samples=512,
                 max_length=30,
                 use_lstm=True):
        """Build the full training or inference graph.

        Args:
            num_symbols: vocabulary size.
            num_embed_units: word-embedding dimension.
            num_units: RNN cell size.
            num_layers: number of stacked RNN layers.
            is_train: build the training graph (losses + optimizer) if True,
                otherwise the greedy-decoding inference graph.
            vocab: list/array of vocabulary strings (required when is_train).
            embed: optional pretrained embedding matrix initializer.
            learning_rate, learning_rate_decay_factor, max_gradient_norm:
                optimization hyper-parameters.
            num_samples: sampled-softmax sample count for the output layer.
            max_length: maximum decoding length at inference time.
            use_lstm: use LSTM cells if True, else GRU cells.
        """
        # ---- placeholders -------------------------------------------------
        # Context sentences as string tokens: [batch, max_len].
        self.posts_1 = tf.placeholder(tf.string, shape=(None, None))
        self.posts_2 = tf.placeholder(tf.string, shape=(None, None))
        self.posts_3 = tf.placeholder(tf.string, shape=(None, None))
        self.posts_4 = tf.placeholder(tf.string, shape=(None, None))

        # KB triples per token: [batch, max_len, max_triple_num, 3] strings
        # (head, relation, tail).
        self.entity_1 = tf.placeholder(tf.string, shape=(None, None, None, 3))
        self.entity_2 = tf.placeholder(tf.string, shape=(None, None, None, 3))
        self.entity_3 = tf.placeholder(tf.string, shape=(None, None, None, 3))
        self.entity_4 = tf.placeholder(tf.string, shape=(None, None, None, 3))

        # 1/0 validity mask over triples: [batch, max_len, max_triple_num].
        self.entity_mask_1 = tf.placeholder(tf.float32, shape=(None, None, None))
        self.entity_mask_2 = tf.placeholder(tf.float32, shape=(None, None, None))
        self.entity_mask_3 = tf.placeholder(tf.float32, shape=(None, None, None))
        self.entity_mask_4 = tf.placeholder(tf.float32, shape=(None, None, None))

        # Sentence lengths.  NOTE(review): shape=(None) is just shape=None
        # (fully unspecified), not rank-1; kept as-is for compatibility.
        self.posts_length_1 = tf.placeholder(tf.int32, shape=(None))
        self.posts_length_2 = tf.placeholder(tf.int32, shape=(None))
        self.posts_length_3 = tf.placeholder(tf.int32, shape=(None))
        self.posts_length_4 = tf.placeholder(tf.int32, shape=(None))

        # Target ending sentence and its lengths.
        self.responses = tf.placeholder(tf.string, shape=(None, None))
        self.responses_length = tf.placeholder(tf.int32, shape=(None))

        self.epoch = tf.Variable(0, trainable=False, name='epoch')
        self.epoch_add_op = self.epoch.assign(self.epoch + 1)

        # ---- vocabulary lookup -------------------------------------------
        # At inference time the real vocab is restored from the checkpoint,
        # so a dummy '.' vocabulary of the right size is enough here.
        if is_train:
            self.symbols = tf.Variable(vocab, trainable=False, name="symbols")
        else:
            self.symbols = tf.Variable(np.array(['.'] * num_symbols), name="symbols")

        self.symbol2index = HashTable(
            KeyValueTensorInitializer(
                self.symbols,
                tf.Variable(np.arange(num_symbols, dtype=np.int32), False)),
            default_value=UNK_ID, name="symbol2index")

        self.posts_input_1 = self.symbol2index.lookup(self.posts_1)
        # Each intermediate sentence is both a decoding target and (shifted)
        # a decoder input, hence the twin aliases.
        self.posts_2_target = self.posts_2_embed = self.symbol2index.lookup(self.posts_2)
        self.posts_3_target = self.posts_3_embed = self.symbol2index.lookup(self.posts_3)
        self.posts_4_target = self.posts_4_embed = self.symbol2index.lookup(self.posts_4)
        self.responses_target = self.symbol2index.lookup(self.responses)

        batch_size, decoder_len = tf.shape(self.posts_1)[0], tf.shape(self.responses)[1]

        # Teacher-forcing inputs: prepend GO, drop the last token.
        self.posts_input_2 = tf.concat([tf.ones([batch_size, 1], dtype=tf.int32) * GO_ID,
            tf.split(self.posts_2_embed, [tf.shape(self.posts_2)[1] - 1, 1], 1)[0]], 1)
        self.posts_input_3 = tf.concat([tf.ones([batch_size, 1], dtype=tf.int32) * GO_ID,
            tf.split(self.posts_3_embed, [tf.shape(self.posts_3)[1] - 1, 1], 1)[0]], 1)
        self.posts_input_4 = tf.concat([tf.ones([batch_size, 1], dtype=tf.int32) * GO_ID,
            tf.split(self.posts_4_embed, [tf.shape(self.posts_4)[1] - 1, 1], 1)[0]], 1)
        self.responses_input = tf.concat([tf.ones([batch_size, 1], dtype=tf.int32) * GO_ID,
            tf.split(self.responses_target, [decoder_len - 1, 1], 1)[0]], 1)

        # Loss masks: reversed cumsum of a one-hot at (length - 1) yields
        # 1.0 for positions < length and 0.0 afterwards.
        self.encoder_2_mask = tf.reshape(tf.cumsum(tf.one_hot(self.posts_length_2 - 1,
            tf.shape(self.posts_2)[1]), reverse=True, axis=1), [-1, tf.shape(self.posts_2)[1]])
        self.encoder_3_mask = tf.reshape(tf.cumsum(tf.one_hot(self.posts_length_3 - 1,
            tf.shape(self.posts_3)[1]), reverse=True, axis=1), [-1, tf.shape(self.posts_3)[1]])
        self.encoder_4_mask = tf.reshape(tf.cumsum(tf.one_hot(self.posts_length_4 - 1,
            tf.shape(self.posts_4)[1]), reverse=True, axis=1), [-1, tf.shape(self.posts_4)[1]])
        self.decoder_mask = tf.reshape(tf.cumsum(tf.one_hot(self.responses_length - 1,
            decoder_len), reverse=True, axis=1), [-1, decoder_len])

        # ---- embeddings ---------------------------------------------------
        if embed is None:
            self.embed = tf.get_variable('embed', [num_symbols, num_embed_units], tf.float32)
        else:
            self.embed = tf.get_variable('embed', dtype=tf.float32, initializer=embed)

        self.encoder_input_1 = tf.nn.embedding_lookup(self.embed, self.posts_input_1)
        self.encoder_input_2 = tf.nn.embedding_lookup(self.embed, self.posts_input_2)
        self.encoder_input_3 = tf.nn.embedding_lookup(self.embed, self.posts_input_3)
        self.encoder_input_4 = tf.nn.embedding_lookup(self.embed, self.posts_input_4)
        self.decoder_input = tf.nn.embedding_lookup(self.embed, self.responses_input)

        def embed_triples(entity):
            """Embed one sentence's triples and flatten the 3 slots:
            [batch, max_len, max_triple_num, 3*num_embed_units]."""
            return tf.reshape(
                tf.nn.embedding_lookup(self.embed, self.symbol2index.lookup(entity)),
                [batch_size, tf.shape(entity)[1], tf.shape(entity)[2], 3 * num_embed_units])

        head_1, relation_1, tail_1 = tf.split(embed_triples(self.entity_1), [num_embed_units] * 3, axis=3)
        head_2, relation_2, tail_2 = tf.split(embed_triples(self.entity_2), [num_embed_units] * 3, axis=3)
        head_3, relation_3, tail_3 = tf.split(embed_triples(self.entity_3), [num_embed_units] * 3, axis=3)
        head_4, relation_4, tail_4 = tf.split(embed_triples(self.entity_4), [num_embed_units] * 3, axis=3)

        def graph_attention(head, relation, tail, entity_mask, reuse):
            """Attention over each token's KB triples.

            Returns a graph embedding [batch, max_len, 2*num_embed_units]:
            a masked, softmax-weighted sum of [head; tail] vectors, with
            weights scored against the transformed relation vector.
            """
            with tf.variable_scope('graph_attention', reuse=reuse):
                # [batch, max_len, max_triple_num, 2*embed_units]
                head_tail = tf.concat([head, tail], axis=3)
                head_tail_transformed = tf.layers.dense(
                    head_tail, num_embed_units, activation=tf.tanh, name='head_tail_transform')
                relation_transformed = tf.layers.dense(
                    relation, num_embed_units, name='relation_transform')
                # [batch, max_len, max_triple_num] attention logits/weights.
                e_weight = tf.reduce_sum(relation_transformed * head_tail_transformed, axis=3)
                alpha_weight = tf.nn.softmax(e_weight)
                return tf.reduce_sum(
                    tf.expand_dims(alpha_weight, 3) * (tf.expand_dims(entity_mask, 3) * head_tail),
                    axis=2)

        # First call creates the dense-layer variables; later calls reuse them.
        graph_embed_1 = graph_attention(head_1, relation_1, tail_1, self.entity_mask_1, None)
        graph_embed_2 = graph_attention(head_2, relation_2, tail_2, self.entity_mask_2, True)
        graph_embed_3 = graph_attention(head_3, relation_3, tail_3, self.entity_mask_3, True)
        graph_embed_4 = graph_attention(head_4, relation_4, tail_4, self.entity_mask_4, True)

        # ---- RNN cell -----------------------------------------------------
        # BUG FIX: the original built MultiRNNCell([Cell(num_units)] * num_layers),
        # which repeats ONE cell object so every layer shares the same weights.
        # Each layer needs its own cell instance.
        if use_lstm:
            cell = MultiRNNCell([LSTMCell(num_units) for _ in range(num_layers)])
        else:
            cell = MultiRNNCell([GRUCell(num_units) for _ in range(num_layers)])

        # Sampled-softmax output projection shared by all decoding stages.
        output_fn, sampled_sequence_loss = output_projection_layer(
            num_units, num_symbols, num_samples)

        # ---- incremental encoding chain -----------------------------------
        # Stage 1: plain RNN encoding of sentence 1.
        encoder_output_1, encoder_state_1 = dynamic_rnn(
            cell, self.encoder_input_1, self.posts_length_1, dtype=tf.float32, scope="encoder")

        # Stage 2: decode sentence 2 from stage-1 state, with Luong attention
        # over the graph embedding and the stage-1 outputs.
        attention_keys_1, attention_values_1, attention_score_fn_1, attention_construct_fn_1 \
            = attention_decoder_fn.prepare_attention(graph_embed_1, encoder_output_1, 'luong', num_units)
        decoder_fn_train_1 = attention_decoder_fn.attention_decoder_fn_train(
            encoder_state_1, attention_keys_1, attention_values_1,
            attention_score_fn_1, attention_construct_fn_1,
            max_length=tf.reduce_max(self.posts_length_2))
        encoder_output_2, encoder_state_2, alignments_ta_2 = dynamic_rnn_decoder(
            cell, decoder_fn_train_1, self.encoder_input_2, self.posts_length_2, scope="decoder")
        # TensorArray of per-step alignments -> [batch, time, source_len].
        self.alignments_2 = tf.transpose(alignments_ta_2.stack(), perm=[1, 0, 2])

        # Auxiliary loss: stage 2 must reproduce sentence 2.
        self.decoder_loss_2 = sampled_sequence_loss(
            encoder_output_2, self.posts_2_target, self.encoder_2_mask)

        # Stages 3 and 4 reuse the attention/decoder/projection variables
        # created above.
        with variable_scope.variable_scope('', reuse=True):
            attention_keys_2, attention_values_2, attention_score_fn_2, attention_construct_fn_2 \
                = attention_decoder_fn.prepare_attention(graph_embed_2, encoder_output_2, 'luong', num_units)
            decoder_fn_train_2 = attention_decoder_fn.attention_decoder_fn_train(
                encoder_state_2, attention_keys_2, attention_values_2,
                attention_score_fn_2, attention_construct_fn_2,
                max_length=tf.reduce_max(self.posts_length_3))
            encoder_output_3, encoder_state_3, alignments_ta_3 = dynamic_rnn_decoder(
                cell, decoder_fn_train_2, self.encoder_input_3, self.posts_length_3, scope="decoder")
            self.alignments_3 = tf.transpose(alignments_ta_3.stack(), perm=[1, 0, 2])

            self.decoder_loss_3 = sampled_sequence_loss(
                encoder_output_3, self.posts_3_target, self.encoder_3_mask)

            attention_keys_3, attention_values_3, attention_score_fn_3, attention_construct_fn_3 \
                = attention_decoder_fn.prepare_attention(graph_embed_3, encoder_output_3, 'luong', num_units)
            decoder_fn_train_3 = attention_decoder_fn.attention_decoder_fn_train(
                encoder_state_3, attention_keys_3, attention_values_3,
                attention_score_fn_3, attention_construct_fn_3,
                max_length=tf.reduce_max(self.posts_length_4))
            encoder_output_4, encoder_state_4, alignments_ta_4 = dynamic_rnn_decoder(
                cell, decoder_fn_train_3, self.encoder_input_4, self.posts_length_4, scope="decoder")
            self.alignments_4 = tf.transpose(alignments_ta_4.stack(), perm=[1, 0, 2])

            self.decoder_loss_4 = sampled_sequence_loss(
                encoder_output_4, self.posts_4_target, self.encoder_4_mask)

            # Attention context for the final (ending) decoder.
            attention_keys, attention_values, attention_score_fn, attention_construct_fn \
                = attention_decoder_fn.prepare_attention(graph_embed_4, encoder_output_4, 'luong', num_units)

        # ---- ending decoder ----------------------------------------------
        if is_train:
            with variable_scope.variable_scope('', reuse=True):
                decoder_fn_train = attention_decoder_fn.attention_decoder_fn_train(
                    encoder_state_4, attention_keys, attention_values,
                    attention_score_fn, attention_construct_fn,
                    max_length=tf.reduce_max(self.responses_length))
                self.decoder_output, _, alignments_ta = dynamic_rnn_decoder(
                    cell, decoder_fn_train, self.decoder_input, self.responses_length,
                    scope="decoder")
                self.alignments = tf.transpose(alignments_ta.stack(), perm=[1, 0, 2])

                self.decoder_loss = sampled_sequence_loss(
                    self.decoder_output, self.responses_target, self.decoder_mask)

            self.params = tf.trainable_variables()

            self.learning_rate = tf.Variable(float(learning_rate), trainable=False,
                                             dtype=tf.float32)
            self.learning_rate_decay_op = self.learning_rate.assign(
                self.learning_rate * learning_rate_decay_factor)
            self.global_step = tf.Variable(0, trainable=False)

            #opt = tf.train.GradientDescentOptimizer(self.learning_rate)
            opt = tf.train.MomentumOptimizer(self.learning_rate, 0.9)

            # Total loss = ending loss + the three incremental-encoding losses.
            gradients = tf.gradients(
                self.decoder_loss + self.decoder_loss_2 + self.decoder_loss_3 + self.decoder_loss_4,
                self.params)
            clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(
                gradients, max_gradient_norm)
            self.update = opt.apply_gradients(zip(clipped_gradients, self.params),
                                              global_step=self.global_step)
        else:
            with variable_scope.variable_scope('', reuse=True):
                decoder_fn_inference = attention_decoder_fn.attention_decoder_fn_inference(
                    output_fn, encoder_state_4, attention_keys, attention_values,
                    attention_score_fn, attention_construct_fn, self.embed,
                    GO_ID, EOS_ID, max_length, num_symbols)
                self.decoder_distribution, _, alignments_ta = dynamic_rnn_decoder(
                    cell, decoder_fn_inference, scope="decoder")
                output_len = tf.shape(self.decoder_distribution)[1]
                self.alignments = tf.transpose(
                    alignments_ta.gather(tf.range(output_len)), [1, 0, 2])

                # Drop the PAD/UNK columns before the argmax, then shift ids
                # back by 2 so UNK is never generated.
                self.generation_index = tf.argmax(tf.split(self.decoder_distribution,
                    [2, num_symbols - 2], 2)[1], 2) + 2
                self.generation = tf.nn.embedding_lookup(
                    self.symbols, self.generation_index, name="generation")

            self.params = tf.trainable_variables()

        self.saver = tf.train.Saver(tf.global_variables(), write_version=tf.train.SaverDef.V2,
                                    max_to_keep=10, pad_step_number=True,
                                    keep_checkpoint_every_n_hours=1.0)

    def print_parameters(self):
        """Print the name and shape of every trainable parameter."""
        for item in self.params:
            print('%s: %s' % (item.name, item.get_shape()))

    def step_decoder(self, session, data, forward_only=False):
        """Run one training (or evaluation) step.

        Args:
            session: an active tf.Session with this graph.
            data: dict with keys posts_1..4, entity_1..4, entity_mask_1..4,
                posts_length_1..4, responses, responses_length.
            forward_only: if True, only compute the loss (no parameter
                update); also fetches the stage-2 alignments.

        Returns:
            [loss, alignments_2] when forward_only, else
            [loss, gradient_norm, update_op].
        """
        input_feed = {self.posts_1: data['posts_1'],
                      self.posts_2: data['posts_2'],
                      self.posts_3: data['posts_3'],
                      self.posts_4: data['posts_4'],
                      self.entity_1: data['entity_1'],
                      self.entity_2: data['entity_2'],
                      self.entity_3: data['entity_3'],
                      self.entity_4: data['entity_4'],
                      self.entity_mask_1: data['entity_mask_1'],
                      self.entity_mask_2: data['entity_mask_2'],
                      self.entity_mask_3: data['entity_mask_3'],
                      self.entity_mask_4: data['entity_mask_4'],
                      self.posts_length_1: data['posts_length_1'],
                      self.posts_length_2: data['posts_length_2'],
                      self.posts_length_3: data['posts_length_3'],
                      self.posts_length_4: data['posts_length_4'],
                      self.responses: data['responses'],
                      self.responses_length: data['responses_length']}
        if forward_only:
            output_feed = [self.decoder_loss, self.alignments_2]
        else:
            output_feed = [self.decoder_loss, self.gradient_norm, self.update]
        return session.run(output_feed, input_feed)
290
291