Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
TensorSpeech
GitHub Repository: TensorSpeech/TensorFlowTTS
Path: blob/master/tensorflow_tts/utils/decoder.py
1558 views
# -*- coding: utf-8 -*-
# Copyright 2020 TensorFlow Authors, All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Tuple, Union
16
17
import tensorflow as tf
18
from tensorflow.python.ops import control_flow_util
19
from tensorflow_addons.seq2seq import Decoder
20
from tensorflow_addons.seq2seq.decoder import (
21
BaseDecoder,
22
_prepend_batch,
23
_transpose_batch_time,
24
)
25
from tensorflow_addons.utils.types import Number, TensorLike
26
27
28
def dynamic_decode(
    decoder: Union[Decoder, BaseDecoder],
    output_time_major: bool = False,
    impute_finished: bool = False,
    maximum_iterations: Optional[TensorLike] = None,
    parallel_iterations: int = 32,
    swap_memory: bool = False,
    training: Optional[bool] = None,
    scope: Optional[str] = None,
    enable_tflite_convertible: bool = False,
    **kwargs
) -> Tuple[Any, Any, Any]:
    """Perform dynamic decoding with `decoder`.

    Calls initialize() once and step() repeatedly on the Decoder object,
    driving the iteration with `tf.while_loop` until every batch entry is
    marked finished (or `maximum_iterations` is reached).

    Args:
        decoder: A `Decoder` instance.
        output_time_major: Python boolean. Default: `False` (batch major). If
            `True`, outputs are returned as time major tensors (this mode is
            faster). Otherwise, outputs are returned as batch major tensors
            (this adds extra time to the computation).
        impute_finished: Python boolean. If `True`, then states for batch
            entries which are marked as finished get copied through and the
            corresponding outputs get zeroed out. This causes some slowdown at
            each time step, but ensures that the final state and outputs have
            the correct values and that backprop ignores time steps that were
            marked as finished.
        maximum_iterations: A strictly positive `int32` scalar, the maximum
            allowed number of decoding steps. Default is `None` (decode until
            the decoder is fully done).
        parallel_iterations: Argument passed to `tf.while_loop`.
        swap_memory: Argument passed to `tf.while_loop`.
        training: Python boolean. Indicates whether the layer should behave
            in training mode or in inference mode. Only relevant
            when `dropout` or `recurrent_dropout` is used.
        scope: Optional name scope to use.
        enable_tflite_convertible: Python boolean. If `True`, then the variables
            of `TensorArray` become of 1-D static shape. Also zero pads in the
            output tensor will be discarded. Default: `False`.
        **kwargs: dict, other keyword arguments for dynamic_decode. It might
            contain arguments for `BaseDecoder` to initialize, which takes all
            tensor inputs during call().

    Returns:
        `(final_outputs, final_state, final_sequence_lengths)`.

    Raises:
        ValueError: if `maximum_iterations` is provided but is not a scalar,
            or if it is omitted while compiling inside an XLA context.
    """
    with tf.name_scope(scope or "decoder"):
        # XLA needs statically-bounded loops, so detect whether we are being
        # traced inside an XLA context (only possible in graph mode).
        is_xla = not tf.executing_eagerly() and control_flow_util.GraphOrParentsInXlaContext(
            tf.compat.v1.get_default_graph()
        )

        if maximum_iterations is not None:
            maximum_iterations = tf.convert_to_tensor(
                maximum_iterations, dtype=tf.int32, name="maximum_iterations"
            )
            if maximum_iterations.shape.ndims != 0:
                raise ValueError("maximum_iterations must be a scalar")
            tf.debugging.assert_greater(
                maximum_iterations,
                0,
                message="maximum_iterations should be greater than 0",
            )
        elif is_xla:
            raise ValueError("maximum_iterations is required for XLA compilation.")

        if isinstance(decoder, Decoder):
            initial_finished, initial_inputs, initial_state = decoder.initialize()
        else:
            # For BaseDecoder that takes tensor inputs during call.
            decoder_init_input = kwargs.pop("decoder_init_input", None)
            decoder_init_kwargs = kwargs.pop("decoder_init_kwargs", {})
            initial_finished, initial_inputs, initial_state = decoder.initialize(
                decoder_init_input, **decoder_init_kwargs
            )

        # Zero-valued output template, used below to blank out the outputs of
        # batch entries that are already finished (when impute_finished=True).
        if enable_tflite_convertible:
            # Assume the batch_size = 1 for inference.
            # So we can change 2-D TensorArray into 1-D by reshaping it.
            zero_outputs = tf.nest.map_structure(
                lambda shape, dtype: tf.reshape(
                    tf.zeros(_prepend_batch(decoder.batch_size, shape), dtype=dtype),
                    [-1],
                ),
                decoder.output_size,
                decoder.output_dtype,
            )
        else:
            zero_outputs = tf.nest.map_structure(
                lambda shape, dtype: tf.zeros(
                    _prepend_batch(decoder.batch_size, shape), dtype=dtype
                ),
                decoder.output_size,
                decoder.output_dtype,
            )

        if maximum_iterations is not None:
            # A non-positive iteration budget marks everything finished before
            # the loop even starts.
            initial_finished = tf.logical_or(initial_finished, 0 >= maximum_iterations)
        initial_sequence_lengths = tf.zeros_like(initial_finished, dtype=tf.int32)
        initial_time = tf.constant(0, dtype=tf.int32)

        def _shape(batch_size, from_shape):
            # Static element shape for the output TensorArrays:
            # `[batch_size] + from_shape` normally, `from_shape` alone in
            # TFLite mode, or None when `from_shape` is unknown or scalar.
            if not isinstance(from_shape, tf.TensorShape) or from_shape.ndims == 0:
                return None
            else:
                batch_size = tf.get_static_value(
                    tf.convert_to_tensor(batch_size, name="batch_size")
                )
                if enable_tflite_convertible:
                    # Since we can't use 2-D TensorArray and assume `batch_size` = 1,
                    # we use `from_shape` dimension only.
                    return from_shape
                return tf.TensorShape([batch_size]).concatenate(from_shape)

        dynamic_size = maximum_iterations is None or not is_xla
        # The dynamic shape `TensorArray` is not allowed in TFLite yet.
        dynamic_size = dynamic_size and (not enable_tflite_convertible)

        def _create_ta(s, d):
            # One TensorArray per output field; statically sized when dynamic
            # sizing is unavailable (XLA or TFLite mode).
            return tf.TensorArray(
                dtype=d,
                size=0 if dynamic_size else maximum_iterations,
                dynamic_size=dynamic_size,
                element_shape=_shape(decoder.batch_size, s),
            )

        initial_outputs_ta = tf.nest.map_structure(
            _create_ta, decoder.output_size, decoder.output_dtype
        )

        def condition(
            unused_time,
            unused_outputs_ta,
            unused_state,
            unused_inputs,
            finished,
            unused_sequence_lengths,
        ):
            # Keep looping while at least one batch entry is unfinished.
            return tf.logical_not(tf.reduce_all(finished))

        def body(time, outputs_ta, state, inputs, finished, sequence_lengths):
            """Internal while_loop body.

            Args:
              time: scalar int32 tensor.
              outputs_ta: structure of TensorArray.
              state: (structure of) state tensors and TensorArrays.
              inputs: (structure of) input tensors.
              finished: bool tensor (keeping track of what's finished).
              sequence_lengths: int32 tensor (keeping track of time of finish).

            Returns:
              `(time + 1, outputs_ta, next_state, next_inputs, next_finished,
              next_sequence_lengths)`.
            """
            (next_outputs, decoder_state, next_inputs, decoder_finished) = decoder.step(
                time, inputs, state, training
            )
            decoder_state_sequence_lengths = False
            if decoder.tracks_own_finished:
                next_finished = decoder_finished
                lengths = getattr(decoder_state, "lengths", None)
                if lengths is not None:
                    # sequence lengths are provided by decoder_state.lengths;
                    # overwrite our sequence lengths.
                    decoder_state_sequence_lengths = True
                    sequence_lengths = tf.cast(lengths, tf.int32)
            else:
                next_finished = tf.logical_or(decoder_finished, finished)

            if decoder_state_sequence_lengths:
                # Just pass something through the loop; at the next iteration
                # we'll pull the sequence lengths from the decoder_state again.
                next_sequence_lengths = sequence_lengths
            else:
                # Entries finishing at this step get length `time + 1`;
                # already-finished entries keep their recorded length.
                next_sequence_lengths = tf.where(
                    tf.logical_not(finished),
                    tf.fill(tf.shape(sequence_lengths), time + 1),
                    sequence_lengths,
                )

            tf.nest.assert_same_structure(state, decoder_state)
            tf.nest.assert_same_structure(outputs_ta, next_outputs)
            tf.nest.assert_same_structure(inputs, next_inputs)

            # Zero out output values past finish
            if impute_finished:

                def zero_out_finished(out, zero):
                    # Broadcast the per-batch `finished` mask up to the
                    # output's rank before selecting the zero template.
                    if finished.shape.rank < zero.shape.rank:
                        broadcast_finished = tf.broadcast_to(
                            tf.expand_dims(finished, axis=-1), zero.shape
                        )
                        return tf.where(broadcast_finished, zero, out)
                    else:
                        return tf.where(finished, zero, out)

                emit = tf.nest.map_structure(
                    zero_out_finished, next_outputs, zero_outputs
                )
            else:
                emit = next_outputs

            # Copy through states past finish
            def _maybe_copy_state(new, cur):
                # TensorArrays and scalar states get passed through.
                if isinstance(cur, tf.TensorArray):
                    pass_through = True
                else:
                    new.set_shape(cur.shape)
                    pass_through = new.shape.ndims == 0
                if not pass_through:
                    # For finished entries, keep the old state instead of the
                    # decoder's new one.
                    broadcast_finished = tf.broadcast_to(
                        tf.expand_dims(finished, axis=-1), new.shape
                    )
                    return tf.where(broadcast_finished, cur, new)
                else:
                    return new

            if impute_finished:
                next_state = tf.nest.map_structure(
                    _maybe_copy_state, decoder_state, state
                )
            else:
                next_state = decoder_state

            if enable_tflite_convertible:
                # Reshape to 1-D.
                emit = tf.nest.map_structure(lambda x: tf.reshape(x, [-1]), emit)

            outputs_ta = tf.nest.map_structure(
                lambda ta, out: ta.write(time, out), outputs_ta, emit
            )
            return (
                time + 1,
                outputs_ta,
                next_state,
                next_inputs,
                next_finished,
                next_sequence_lengths,
            )

        res = tf.while_loop(
            condition,
            body,
            loop_vars=(
                initial_time,
                initial_outputs_ta,
                initial_state,
                initial_inputs,
                initial_finished,
                initial_sequence_lengths,
            ),
            parallel_iterations=parallel_iterations,
            maximum_iterations=maximum_iterations,
            swap_memory=swap_memory,
        )

        final_outputs_ta = res[1]
        final_state = res[2]
        final_sequence_lengths = res[5]

        # Stacking the TensorArrays yields time-major outputs:
        # `[time, batch, ...]` (or `[time, ...]` in TFLite mode).
        final_outputs = tf.nest.map_structure(lambda ta: ta.stack(), final_outputs_ta)

        try:
            final_outputs, final_state = decoder.finalize(
                final_outputs, final_state, final_sequence_lengths
            )
        except NotImplementedError:
            # Decoders without a finalize() keep the raw stacked outputs.
            pass

        if not output_time_major:
            if enable_tflite_convertible:
                # Reshape the output to the original shape.
                def _restore_batch(x):
                    # Re-insert the batch axis (assumed to be 1) that was
                    # flattened away for the 1-D TensorArray.
                    return tf.expand_dims(x, [1])

                final_outputs = tf.nest.map_structure(_restore_batch, final_outputs)

            final_outputs = tf.nest.map_structure(_transpose_batch_time, final_outputs)

    return final_outputs, final_state, final_sequence_lengths