# Source: blob/master/tensorflow_tts/utils/decoder.py (scraped page header; 1558 views)
# -*- coding: utf-8 -*-
# Copyright 2020 TensorFlow Authors, All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Tuple, Union

import tensorflow as tf
from tensorflow.python.ops import control_flow_util
from tensorflow_addons.seq2seq import Decoder
from tensorflow_addons.seq2seq.decoder import (
    BaseDecoder,
    _prepend_batch,
    _transpose_batch_time,
)
from tensorflow_addons.utils.types import Number, TensorLike


def dynamic_decode(
    decoder: Union[Decoder, BaseDecoder],
    output_time_major: bool = False,
    impute_finished: bool = False,
    maximum_iterations: Optional[TensorLike] = None,
    parallel_iterations: int = 32,
    swap_memory: bool = False,
    training: Optional[bool] = None,
    scope: Optional[str] = None,
    enable_tflite_convertible: bool = False,
    **kwargs
) -> Tuple[Any, Any, Any]:
    """Perform dynamic decoding with `decoder`.

    Calls initialize() once and step() repeatedly on the Decoder object.

    Args:
        decoder: A `Decoder` instance.
        output_time_major: Python boolean. Default: `False` (batch major). If
            `True`, outputs are returned as time major tensors (this mode is
            faster). Otherwise, outputs are returned as batch major tensors
            (this adds extra time to the computation).
        impute_finished: Python boolean. If `True`, then states for batch
            entries which are marked as finished get copied through and the
            corresponding outputs get zeroed out. This causes some slowdown at
            each time step, but ensures that the final state and outputs have
            the correct values and that backprop ignores time steps that were
            marked as finished.
        maximum_iterations: A strictly positive `int32` scalar, the maximum
            allowed number of decoding steps. Default is `None` (decode until
            the decoder is fully done).
        parallel_iterations: Argument passed to `tf.while_loop`.
        swap_memory: Argument passed to `tf.while_loop`.
        training: Python boolean. Indicates whether the layer should behave
            in training mode or in inference mode. Only relevant
            when `dropout` or `recurrent_dropout` is used.
        scope: Optional name scope to use.
        enable_tflite_convertible: Python boolean. If `True`, then the
            variables of `TensorArray` become of 1-D static shape. Also zero
            pads in the output tensor will be discarded. Default: `False`.
        **kwargs: dict, other keyword arguments for dynamic_decode. It might
            contain arguments for `BaseDecoder` to initialize, which takes all
            tensor inputs during call().

    Returns:
        `(final_outputs, final_state, final_sequence_lengths)`.

    Raises:
        ValueError: if `maximum_iterations` is provided but is not a scalar,
            or is not strictly positive, or is missing under XLA compilation.
    """
    with tf.name_scope(scope or "decoder"):
        is_xla = not tf.executing_eagerly() and control_flow_util.GraphOrParentsInXlaContext(
            tf.compat.v1.get_default_graph()
        )

        if maximum_iterations is not None:
            maximum_iterations = tf.convert_to_tensor(
                maximum_iterations, dtype=tf.int32, name="maximum_iterations"
            )
            if maximum_iterations.shape.ndims != 0:
                raise ValueError("maximum_iterations must be a scalar")
            tf.debugging.assert_greater(
                maximum_iterations,
                0,
                message="maximum_iterations should be greater than 0",
            )
        elif is_xla:
            # XLA requires a statically-known loop bound.
            raise ValueError("maximum_iterations is required for XLA compilation.")

        if isinstance(decoder, Decoder):
            initial_finished, initial_inputs, initial_state = decoder.initialize()
        else:
            # For BaseDecoder that takes tensor inputs during call.
            decoder_init_input = kwargs.pop("decoder_init_input", None)
            decoder_init_kwargs = kwargs.pop("decoder_init_kwargs", {})
            initial_finished, initial_inputs, initial_state = decoder.initialize(
                decoder_init_input, **decoder_init_kwargs
            )

        if enable_tflite_convertible:
            # Assume the batch_size = 1 for inference.
            # So we can change 2-D TensorArray into 1-D by reshaping it.
            zero_outputs = tf.nest.map_structure(
                lambda shape, dtype: tf.reshape(
                    tf.zeros(_prepend_batch(decoder.batch_size, shape), dtype=dtype),
                    [-1],
                ),
                decoder.output_size,
                decoder.output_dtype,
            )
        else:
            zero_outputs = tf.nest.map_structure(
                lambda shape, dtype: tf.zeros(
                    _prepend_batch(decoder.batch_size, shape), dtype=dtype
                ),
                decoder.output_size,
                decoder.output_dtype,
            )

        if maximum_iterations is not None:
            # A zero-step budget means we are finished before we start.
            initial_finished = tf.logical_or(initial_finished, 0 >= maximum_iterations)
        initial_sequence_lengths = tf.zeros_like(initial_finished, dtype=tf.int32)
        initial_time = tf.constant(0, dtype=tf.int32)

        def _shape(batch_size, from_shape):
            # Static element shape for the output TensorArray, or None when
            # it cannot be determined.
            if not isinstance(from_shape, tf.TensorShape) or from_shape.ndims == 0:
                return None
            else:
                batch_size = tf.get_static_value(
                    tf.convert_to_tensor(batch_size, name="batch_size")
                )
                if enable_tflite_convertible:
                    # Since we can't use 2-D TensorArray and assume
                    # `batch_size` = 1, we use `from_shape` dimension only.
                    return from_shape
                return tf.TensorShape([batch_size]).concatenate(from_shape)

        dynamic_size = maximum_iterations is None or not is_xla
        # The dynamic shape `TensorArray` is not allowed in TFLite yet.
        dynamic_size = dynamic_size and (not enable_tflite_convertible)

        def _create_ta(s, d):
            # One TensorArray per output field; statically sized when the
            # loop bound is known and dynamic sizing is disallowed.
            return tf.TensorArray(
                dtype=d,
                size=0 if dynamic_size else maximum_iterations,
                dynamic_size=dynamic_size,
                element_shape=_shape(decoder.batch_size, s),
            )

        initial_outputs_ta = tf.nest.map_structure(
            _create_ta, decoder.output_size, decoder.output_dtype
        )

        def condition(
            unused_time,
            unused_outputs_ta,
            unused_state,
            unused_inputs,
            finished,
            unused_sequence_lengths,
        ):
            # Keep looping until every batch entry is finished.
            return tf.logical_not(tf.reduce_all(finished))

        def body(time, outputs_ta, state, inputs, finished, sequence_lengths):
            """Internal while_loop body.

            Args:
                time: scalar int32 tensor.
                outputs_ta: structure of TensorArray.
                state: (structure of) state tensors and TensorArrays.
                inputs: (structure of) input tensors.
                finished: bool tensor (keeping track of what's finished).
                sequence_lengths: int32 tensor (keeping track of time of
                    finish).

            Returns:
                `(time + 1, outputs_ta, next_state, next_inputs,
                next_finished, next_sequence_lengths)`.
            """
            (next_outputs, decoder_state, next_inputs, decoder_finished) = decoder.step(
                time, inputs, state, training
            )
            decoder_state_sequence_lengths = False
            if decoder.tracks_own_finished:
                next_finished = decoder_finished
                lengths = getattr(decoder_state, "lengths", None)
                if lengths is not None:
                    # sequence lengths are provided by decoder_state.lengths;
                    # overwrite our sequence lengths.
                    decoder_state_sequence_lengths = True
                    sequence_lengths = tf.cast(lengths, tf.int32)
            else:
                next_finished = tf.logical_or(decoder_finished, finished)

            if decoder_state_sequence_lengths:
                # Just pass something through the loop; at the next iteration
                # we'll pull the sequence lengths from the decoder_state again.
                next_sequence_lengths = sequence_lengths
            else:
                # Entries finishing this step get length `time + 1`.
                next_sequence_lengths = tf.where(
                    tf.logical_not(finished),
                    tf.fill(tf.shape(sequence_lengths), time + 1),
                    sequence_lengths,
                )

            tf.nest.assert_same_structure(state, decoder_state)
            tf.nest.assert_same_structure(outputs_ta, next_outputs)
            tf.nest.assert_same_structure(inputs, next_inputs)

            # Zero out output values past finish
            if impute_finished:

                def zero_out_finished(out, zero):
                    if finished.shape.rank < zero.shape.rank:
                        broadcast_finished = tf.broadcast_to(
                            tf.expand_dims(finished, axis=-1), zero.shape
                        )
                        return tf.where(broadcast_finished, zero, out)
                    else:
                        return tf.where(finished, zero, out)

                emit = tf.nest.map_structure(
                    zero_out_finished, next_outputs, zero_outputs
                )
            else:
                emit = next_outputs

            # Copy through states past finish
            def _maybe_copy_state(new, cur):
                # TensorArrays and scalar states get passed through.
                if isinstance(cur, tf.TensorArray):
                    pass_through = True
                else:
                    new.set_shape(cur.shape)
                    pass_through = new.shape.ndims == 0
                if not pass_through:
                    broadcast_finished = tf.broadcast_to(
                        tf.expand_dims(finished, axis=-1), new.shape
                    )
                    return tf.where(broadcast_finished, cur, new)
                else:
                    return new

            if impute_finished:
                next_state = tf.nest.map_structure(
                    _maybe_copy_state, decoder_state, state
                )
            else:
                next_state = decoder_state

            if enable_tflite_convertible:
                # Reshape to 1-D.
                emit = tf.nest.map_structure(lambda x: tf.reshape(x, [-1]), emit)

            outputs_ta = tf.nest.map_structure(
                lambda ta, out: ta.write(time, out), outputs_ta, emit
            )
            return (
                time + 1,
                outputs_ta,
                next_state,
                next_inputs,
                next_finished,
                next_sequence_lengths,
            )

        res = tf.while_loop(
            condition,
            body,
            loop_vars=(
                initial_time,
                initial_outputs_ta,
                initial_state,
                initial_inputs,
                initial_finished,
                initial_sequence_lengths,
            ),
            parallel_iterations=parallel_iterations,
            maximum_iterations=maximum_iterations,
            swap_memory=swap_memory,
        )

        final_outputs_ta = res[1]
        final_state = res[2]
        final_sequence_lengths = res[5]

        final_outputs = tf.nest.map_structure(lambda ta: ta.stack(), final_outputs_ta)

        try:
            final_outputs, final_state = decoder.finalize(
                final_outputs, final_state, final_sequence_lengths
            )
        except NotImplementedError:
            # Decoders without a finalize step return outputs/state as-is.
            pass

        if not output_time_major:
            if enable_tflite_convertible:
                # Reshape the output to the original shape.
                def _restore_batch(x):
                    # Re-insert the batch axis (batch_size = 1) that was
                    # flattened away for the 1-D TensorArray.
                    return tf.expand_dims(x, [1])

                final_outputs = tf.nest.map_structure(_restore_batch, final_outputs)

            final_outputs = tf.nest.map_structure(_transpose_batch_time, final_outputs)

        return final_outputs, final_state, final_sequence_lengths