Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
JianGuanTHU
GitHub Repository: JianGuanTHU/StoryEndGen
Path: blob/master/attention_decoder_fn.py
487 views
1
from __future__ import absolute_import
2
from __future__ import division
3
from __future__ import print_function
4
import tensorflow as tf
5
6
from tensorflow.python.ops import gen_data_flow_ops
7
from tensorflow.python.ops import tensor_array_ops
8
9
from tensorflow.contrib.layers.python.layers import layers
10
from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
11
from tensorflow.python.framework import dtypes
12
from tensorflow.python.framework import function
13
from tensorflow.python.framework import ops
14
from tensorflow.python.ops import array_ops
15
from tensorflow.python.ops import control_flow_ops
16
from tensorflow.python.ops import math_ops
17
from tensorflow.python.ops import nn_ops
18
from tensorflow.python.ops import variable_scope
19
from tensorflow.python.util import nest
20
21
# Public API of this module; all other helpers are internal.
__all__ = [
    "prepare_attention", "attention_decoder_fn_train",
    "attention_decoder_fn_inference"
]
25
26
def attention_decoder_fn_train(encoder_state,
                               attention_keys,
                               attention_values,
                               attention_score_fn,
                               attention_construct_fn,
                               max_length=None,
                               name=None):
  """Attentional decoder function for `dynamic_rnn_decoder` during training.

  The `attention_decoder_fn_train` is a training function for an
  attention-based sequence-to-sequence model. It should be used when
  `dynamic_rnn_decoder` is in the training mode.

  The `attention_decoder_fn_train` is called with a set of the user arguments
  and returns the `decoder_fn`, which can be passed to the
  `dynamic_rnn_decoder`, such that

  ```
  dynamic_fn_train = attention_decoder_fn_train(encoder_state)
  outputs_train, state_train = dynamic_rnn_decoder(
      decoder_fn=dynamic_fn_train, ...)
  ```

  Further usage can be found in the `kernel_tests/seq2seq_test.py`.

  Args:
    encoder_state: The encoded state to initialize the `dynamic_rnn_decoder`.
    attention_keys: pair of key tensors (indexed [0] and [1] below) to be
      compared with target states.
    attention_values: pair of value tensors used to construct context vectors.
    attention_score_fn: to compute similarity between key and target states.
      NOTE(review): not referenced in this function; it is only consumed
      indirectly through `attention_construct_fn`.
    attention_construct_fn: pair of functions to build attention states, one
      per key/value channel.
    max_length: (default: `None`) initial size of the alignments
      `TensorArray`; the array is created with `dynamic_size=True`, so this
      acts as a size hint.
    name: (default: `None`) NameScope for the decoder function;
      defaults to "simple_decoder_fn_train"

  Returns:
    A decoder function with the required interface of `dynamic_rnn_decoder`
    intended for training.
  """
  with ops.name_scope(name, "attention_decoder_fn_train", [
      encoder_state, attention_keys, attention_values, attention_score_fn,
      attention_construct_fn
  ]):
    # No graph ops are built here; `decoder_fn` opens its own scope per step.
    pass

  def decoder_fn(time, cell_state, cell_input, cell_output, context_state):
    """Decoder function used in the `dynamic_rnn_decoder` for training.

    Args:
      time: positive integer constant reflecting the current timestep.
      cell_state: state of RNNCell.
      cell_input: input provided by `dynamic_rnn_decoder`.
      cell_output: output of RNNCell.
      context_state: context state provided by `dynamic_rnn_decoder`.

    Returns:
      A tuple (done, next state, next input, emit output, next context state)
      where:

      done: `None`, which is used by the `dynamic_rnn_decoder` to indicate
      that `sequence_lengths` in `dynamic_rnn_decoder` should be used.

      next state: `cell_state`, this decoder function does not modify the
      given state.

      next input: concatenation of `cell_input` and the attention vector.

      emit output: the attention vector (the raw `cell_output` is replaced
      after the first step).

      next context state: a `TensorArray` accumulating the per-step
      alignment weights of the embedding channel.
    """
    with ops.name_scope(
        name, "attention_decoder_fn_train",
        [time, cell_state, cell_input, cell_output, context_state]):
      if cell_state is None:  # first call, return encoder_state
        cell_state = encoder_state

        # init attention: zero vectors shaped like the encoder's top state,
        # projected through the shared "attention_total_tran" layer so step 0
        # creates the same variables later steps reuse.
        attention_embed = _init_attention(encoder_state)
        attention_output = _init_attention(encoder_state)
        # NOTE(review): the 512-unit projection width is hard-coded — it must
        # match the decoder cell size; confirm against the model config.
        attention = layers.linear(tf.concat([attention_embed, attention_output], 1),
            512, biases_initializer=None, scope="attention_total_tran")
        #print("ini_attention:", attention)
        # TensorArray collecting alignment weights; surfaced to the caller
        # via the decoder's context state.
        context_state = tensor_array_ops.TensorArray(dtype=dtypes.float32,
            tensor_array_name="alignments_ta", size=max_length, dynamic_size=True, infer_shape=False)
      else:
        # construct attention over channel 0 (embedding keys/values)
        attention_embed, alignments_embed = attention_construct_fn[0](cell_output, attention_keys[0],
            attention_values[0])
        # construct attention over channel 1 (output keys/values)
        # NOTE(review): alignments_output is computed but never used.
        attention_output, alignments_output = attention_construct_fn[1](cell_output, attention_keys[1],
            attention_values[1])
        # fuse both context vectors with the shared projection (reused vars)
        attention = layers.linear(tf.concat([attention_embed, attention_output], 1),
            512, biases_initializer=None, scope="attention_total_tran", reuse=True)
        # the fused attention vector becomes the emitted output
        cell_output = attention

        # record this step's embedding-channel alignments
        context_state = context_state.write(time - 1, alignments_embed)

      # next cell input is [input; attention]
      next_input = array_ops.concat([cell_input, attention], 1)
      # done=None -> dynamic_rnn_decoder falls back to sequence_lengths
      return (None, cell_state, next_input, cell_output, context_state)

  return decoder_fn
131
132
133
def attention_decoder_fn_inference(output_fn,
                                   encoder_state,
                                   attention_keys,
                                   attention_values,
                                   attention_score_fn,
                                   attention_construct_fn,
                                   embeddings,
                                   start_of_sequence_id,
                                   end_of_sequence_id,
                                   maximum_length,
                                   num_decoder_symbols,
                                   dtype=dtypes.int32,
                                   name=None):
  """Attentional decoder function for `dynamic_rnn_decoder` during inference.

  The `attention_decoder_fn_inference` is a simple inference function for a
  sequence-to-sequence model. It should be used when `dynamic_rnn_decoder` is
  in the inference mode.

  The `attention_decoder_fn_inference` is called with user arguments
  and returns the `decoder_fn`, which can be passed to the
  `dynamic_rnn_decoder`, such that

  ```
  dynamic_fn_inference = attention_decoder_fn_inference(...)
  outputs_inference, state_inference = dynamic_rnn_decoder(
      decoder_fn=dynamic_fn_inference, ...)
  ```

  Further usage can be found in the `kernel_tests/seq2seq_test.py`.

  Args:
    output_fn: An output function to project your `cell_output` onto class
      logits.

      An example of an output function;

      ```
      tf.variable_scope("decoder") as varscope
        output_fn = lambda x: layers.linear(x, num_decoder_symbols,
                                            scope=varscope)

        outputs_train, state_train = seq2seq.dynamic_rnn_decoder(...)
        logits_train = output_fn(outputs_train)

        varscope.reuse_variables()
        logits_inference, state_inference = seq2seq.dynamic_rnn_decoder(
            output_fn=output_fn, ...)
      ```

      If `None` is supplied it will act as an identity function, which
      might be wanted when using the RNNCell `OutputProjectionWrapper`.

    encoder_state: The encoded state to initialize the `dynamic_rnn_decoder`.
    attention_keys: pair of key tensors (indexed [0] and [1] below) to be
      compared with target states.
    attention_values: pair of value tensors used to construct context vectors.
    attention_score_fn: to compute similarity between key and target states.
      NOTE(review): not referenced in this function; it is only consumed
      indirectly through `attention_construct_fn`.
    attention_construct_fn: pair of functions to build attention states, one
      per key/value channel.
    embeddings: The embeddings matrix used for the decoder sized
      `[num_decoder_symbols, embedding_size]`.
    start_of_sequence_id: The start of sequence ID in the decoder embeddings.
    end_of_sequence_id: The end of sequence ID in the decoder embeddings.
    maximum_length: The maximum allowed of time steps to decode.
    num_decoder_symbols: The number of classes to decode at each time step.
    dtype: (default: `dtypes.int32`) The default data type to use when
      handling integer objects.
    name: (default: `None`) NameScope for the decoder function;
      defaults to "attention_decoder_fn_inference"

  Returns:
    A decoder function with the required interface of `dynamic_rnn_decoder`
    intended for inference.
  """
  with ops.name_scope(name, "attention_decoder_fn_inference", [
      output_fn, encoder_state, attention_keys, attention_values,
      attention_score_fn, attention_construct_fn, embeddings,
      start_of_sequence_id, end_of_sequence_id, maximum_length,
      num_decoder_symbols, dtype
  ]):
    # Normalize the python-int ids/limits to tensors of the integer `dtype`.
    start_of_sequence_id = ops.convert_to_tensor(start_of_sequence_id, dtype)
    end_of_sequence_id = ops.convert_to_tensor(end_of_sequence_id, dtype)
    maximum_length = ops.convert_to_tensor(maximum_length, dtype)
    num_decoder_symbols = ops.convert_to_tensor(num_decoder_symbols, dtype)
    # Any flattened piece of the (possibly nested) state carries batch_size.
    encoder_info = nest.flatten(encoder_state)[0]
    batch_size = encoder_info.get_shape()[0].value
    if output_fn is None:
      output_fn = lambda x: x  # identity: cell_output already holds logits
    if batch_size is None:
      # Static batch size unknown at graph-build time; use the dynamic shape.
      batch_size = array_ops.shape(encoder_info)[0]

    def decoder_fn(time, cell_state, cell_input, cell_output, context_state):
      """Decoder function used in the `dynamic_rnn_decoder` for inference.

      The main difference between this decoder function and the `decoder_fn` in
      `attention_decoder_fn_train` is how `next_cell_input` is calculated. In
      decoder function we calculate the next input by applying an argmax across
      the feature dimension of the output from the decoder. This is a
      greedy-search approach. (Bahdanau et al., 2014) & (Sutskever et al., 2014)
      use beam-search instead.

      Args:
        time: positive integer constant reflecting the current timestep.
        cell_state: state of RNNCell.
        cell_input: input provided by `dynamic_rnn_decoder`.
        cell_output: output of RNNCell.
        context_state: context state provided by `dynamic_rnn_decoder`.

      Returns:
        A tuple (done, next state, next input, emit output, next context state)
        where:

        done: A boolean vector to indicate which sentences has reached a
        `end_of_sequence_id`. This is used for early stopping by the
        `dynamic_rnn_decoder`. When `time>=maximum_length` a boolean vector with
        all elements as `true` is returned.

        next state: `cell_state`, this decoder function does not modify the
        given state.

        next input: The embedding from argmax of the `cell_output` is used as
        `next_input`, concatenated with the attention vector.

        emit output: If `output_fn is None` the supplied `cell_output` is
        returned, else the `output_fn` is used to update the `cell_output`
        before calculating `next_input` and returning `cell_output`.

        next context state: a `TensorArray` accumulating the per-step
        alignment weights of the embedding channel.

      Raises:
        ValueError: if cell_input is not None.

      """
      with ops.name_scope(
          name, "attention_decoder_fn_inference",
          [time, cell_state, cell_input, cell_output, context_state]):
        # Inference feeds itself: the decoder input comes from the previous
        # argmax, never from the caller.
        if cell_input is not None:
          raise ValueError("Expected cell_input to be None, but saw: %s" %
                           cell_input)
        if cell_output is None:
          # invariant that this is time == 0
          next_input_id = array_ops.ones(
              [batch_size,], dtype=dtype) * (start_of_sequence_id)
          done = array_ops.zeros([batch_size,], dtype=dtypes.bool)
          cell_state = encoder_state
          # Dummy zero logits so the emit structure matches later steps.
          cell_output = array_ops.zeros(
              [num_decoder_symbols], dtype=dtypes.float32)
          cell_input = array_ops.gather(embeddings, next_input_id)

          # init attention: zero vectors through the shared projection so the
          # "attention_total_tran" variables are created on the first call.
          attention_embed = _init_attention(encoder_state)
          attention_output = _init_attention(encoder_state)
          # NOTE(review): 512 is hard-coded; must match the decoder cell size.
          attention = layers.linear(tf.concat([attention_embed, attention_output], 1),
              512, biases_initializer=None, scope="attention_total_tran")
          # TensorArray collecting per-step alignment weights.
          context_state = tensor_array_ops.TensorArray(dtype=dtypes.float32,
              tensor_array_name="alignments_ta", size=maximum_length, dynamic_size=True, infer_shape=False)
        else:
          # attention over channel 0 (embedding keys/values)
          attention_embed, alignments_embed = attention_construct_fn[0](cell_output, attention_keys[0],
              attention_values[0])
          # attention over channel 1 (output keys/values);
          # NOTE(review): alignments_output is computed but never used.
          attention_output, alignments_output = attention_construct_fn[1](cell_output, attention_keys[1],
              attention_values[1])
          # fuse both context vectors with the shared projection (reused vars)
          attention = layers.linear(tf.concat([attention_embed, attention_output], 1),
              512, biases_initializer=None, scope="attention_total_tran", reuse=True)

          cell_output = attention

          # argmax decoder
          cell_output = output_fn(cell_output)  # logits
          next_input_id = math_ops.cast(
              math_ops.argmax(cell_output, 1), dtype=dtype)
          # a sentence is done once it emits the end-of-sequence id
          done = math_ops.equal(next_input_id, end_of_sequence_id)
          cell_input = array_ops.gather(embeddings, next_input_id)

          # record this step's embedding-channel alignments
          context_state = context_state.write(time - 1, alignments_embed)
        # combine cell_input and attention
        next_input = array_ops.concat([cell_input, attention], 1)

        # if time > maxlen, return all true vector
        done = control_flow_ops.cond(
            math_ops.greater(time, maximum_length),
            lambda: array_ops.ones([batch_size,], dtype=dtypes.bool),
            lambda: done)
        return (done, cell_state, next_input, cell_output, context_state)

    return decoder_fn
318
319
320
## Helper functions ##
321
def prepare_attention(attention_embed,
                      attention_output,
                      attention_option,
                      num_units,
                      reuse=False):
  """Prepare keys/values/functions for the two-channel attention.

  Args:
    attention_embed: hidden states of the embedding channel to attend over.
    attention_output: hidden states of the output channel to attend over.
    attention_option: how to compute attention, either "luong" or "bahdanau".
    num_units: hidden state dimension.
    reuse: whether to reuse variable scope.

  Returns:
    A 4-tuple whose elements are (embed, output) pairs:
    attention_keys: to be compared with target states.
    attention_values: to be used to construct context vectors.
    attention_score_fn: to compute similarity between key and target states.
    attention_construct_fn: to build attention states.
  """

  # Prepare attention keys / values from attention_states
  with variable_scope.variable_scope("attention_embed_keys", reuse=reuse) as scope:
    attention_embed_keys = layers.linear(
        attention_embed, num_units, biases_initializer=None, scope=scope)
  # NOTE(review): scope name says "train" but this projects the embed *values*;
  # renaming would break existing checkpoints, so it is kept as-is.
  with variable_scope.variable_scope("attention_embed_train", reuse=reuse) as scope:
    attention_embed_values = layers.linear(
        attention_embed, num_units, biases_initializer=None, scope=scope)

  with variable_scope.variable_scope("attention_output_keys", reuse=reuse) as scope:
    attention_output_keys = layers.linear(
        attention_output, num_units, biases_initializer=None, scope=scope)
  # The output channel uses its raw states as values (no projection).
  attention_output_values = attention_output

  # Attention score function (embedding channel)
  attention_embed_score_fn = _create_attention_score_fn("attention_embed_score", num_units,
                                                        attention_option, reuse)

  # Attention construction function (embedding channel)
  attention_embed_construct_fn = _create_attention_construct_fn("attention_embed_construct",
                                                                num_units,
                                                                attention_embed_score_fn,
                                                                reuse)

  # Attention score function (output channel)
  attention_output_score_fn = _create_attention_score_fn("attention_output_score", num_units,
                                                         attention_option, reuse)

  # Attention construction function (output channel)
  attention_output_construct_fn = _create_attention_construct_fn("attention_output_construct",
                                                                 num_units,
                                                                 attention_output_score_fn,
                                                                 reuse)

  return ((attention_embed_keys, attention_output_keys), (attention_embed_values, attention_output_values), (attention_embed_score_fn, attention_output_score_fn),
          (attention_embed_construct_fn, attention_output_construct_fn))
376
377
378
def _init_attention(encoder_state):
  """Build a zero attention vector shaped like the encoder's top state.

  Handles both multi-layer states (a tuple of per-layer states, where the
  last element is the top layer) and single-layer states, and both LSTM
  (`LSTMStateTuple`, mirroring its `h` component) and GRU states.

  Args:
    encoder_state: The encoded state to initialize the `dynamic_rnn_decoder`.

  Returns:
    attn: initial zero attention vector.
  """
  # Multi- vs single-layer: stacked states arrive as a tuple; take the top.
  # TODO(thangluong): is this the best way to check?
  top_state = encoder_state[-1] if isinstance(encoder_state, tuple) else encoder_state

  # LSTM vs GRU: LSTM states are (c, h) pairs — mirror the hidden output h;
  # GRU states are a single tensor and are mirrored directly.
  if isinstance(top_state, core_rnn_cell_impl.LSTMStateTuple):
    return array_ops.zeros_like(top_state.h)
  return array_ops.zeros_like(top_state)
402
403
404
def _create_attention_construct_fn(name, num_units, attention_score_fn, reuse):
  """Function to compute attention vectors.

  Args:
    name: to label variables.
    num_units: hidden state dimension.
    attention_score_fn: to compute similarity between key and target states.
    reuse: whether to reuse variable scope.

  Returns:
    attention_construct_fn: to build attention states.
  """
  with variable_scope.variable_scope(name, reuse=reuse) as scope:

    def construct_fn(attention_query, attention_keys, attention_values):
      # Score the query against the keys and pool the values into a context.
      context, alignments = attention_score_fn(attention_query,
                                               attention_keys,
                                               attention_values)

      # Project [query; context] down to num_units.
      # NOTE(review): the flattened width 1024 is hard-coded and presumably
      # assumes query and context are 512-dim each — confirm against config.
      combined = array_ops.concat([attention_query, context], 1)
      combined = array_ops.reshape(combined, [-1, 1024])
      projected = layers.linear(
          combined, num_units, biases_initializer=None, scope=scope)
      return projected, alignments

    return construct_fn
429
430
431
# keys: [batch_size, attention_length, attn_size]
432
# query: [batch_size, 1, attn_size]
433
# return weights [batch_size, attention_length]
434
@function.Defun(func_name="attn_add_fun", noinline=True)
def _attn_add_fun(v, keys, query):
  """Additive (Bahdanau) scores: reduce v * tanh(keys + query) over units."""
  activated = math_ops.tanh(keys + query)
  return math_ops.reduce_sum(v * activated, [2])
437
438
439
@function.Defun(func_name="attn_mul_fun", noinline=True)
def _attn_mul_fun(keys, query):
  """Multiplicative (Luong) scores: reduce keys * query over the unit axis."""
  products = keys * query
  return math_ops.reduce_sum(products, [2])
442
443
444
def _create_attention_score_fn(name,
                               num_units,
                               attention_option,
                               reuse,
                               dtype=dtypes.float32):
  """Different ways to compute attention scores.

  Args:
    name: to label variables.
    num_units: hidden state dimension.
    attention_option: how to compute attention, either "luong" or "bahdanau".
      "bahdanau": additive (Bahdanau et al., ICLR'2015)
      "luong": multiplicative (Luong et al., EMNLP'2015)
    reuse: whether to reuse variable scope.
    dtype: (default: `dtypes.float32`) data type to use.

  Returns:
    attention_score_fn: to compute similarity between key and target states.
  """
  with variable_scope.variable_scope(name, reuse=reuse):
    if attention_option == "bahdanau":
      # Additive attention needs a query projection and a score vector;
      # created eagerly so both graph passes resolve the same variables.
      query_w = variable_scope.get_variable(
          "attnW", [num_units, num_units], dtype=dtype)
      score_v = variable_scope.get_variable("attnV", [num_units], dtype=dtype)

    def attention_score_fn(query, keys, values):
      """Put attention masks on attention_values using attention_keys and query.

      Args:
        query: A Tensor of shape [batch_size, num_units].
        keys: A Tensor of shape [batch_size, attention_length, num_units].
        values: A Tensor of shape [batch_size, attention_length, num_units].

      Returns:
        context_vector: A Tensor of shape [batch_size, num_units].
        alignments_prob: A Tensor of shape [batch_size, attention_length]
          holding the (clipped) attention weights.

      Raises:
        ValueError: if attention_option is neither "luong" nor "bahdanau".

      """
      if attention_option == "bahdanau":
        # transform query
        query = math_ops.matmul(query, query_w)

        # reshape query: [batch_size, 1, num_units]
        query = array_ops.reshape(query, [-1, 1, num_units])

        # attn_fun
        scores = _attn_add_fun(score_v, keys, query)
      elif attention_option == "luong":
        # reshape query: [batch_size, 1, num_units]
        query = array_ops.reshape(query, [-1, 1, num_units])

        # attn_fun
        scores = _attn_mul_fun(keys, query)
      else:
        raise ValueError("Unknown attention option %s!" % attention_option)

      # Compute alignment weights
      # scores: [batch_size, length]
      # alignments: [batch_size, length]
      # TODO(thangluong): not normalize over padding positions.
      # NOTE(review): clipping after the softmax keeps every weight >= 1e-4,
      # but the clipped weights then no longer sum exactly to 1.
      alignments_prob = tf.clip_by_value(nn_ops.softmax(scores), 1e-4, 1)
      alignments = array_ops.expand_dims(alignments_prob, 2)
      # Weighted sum of values -> context vector [batch_size, num_units].
      context_vector = math_ops.reduce_sum(alignments * values, [1])
      context_vector.set_shape([None, num_units])

      return context_vector, alignments_prob

    return attention_score_fn
515
516