Path: blob/master/tensorflow_tts/models/tacotron2.py
# -*- coding: utf-8 -*-
# Copyright 2020 The Tacotron-2 Authors, Minh Nguyen (@dathudeptrai), Eren Gölge (@erogol) and Jae Yoo (@jaeyoo)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tacotron-2 Modules."""

import collections

import numpy as np
import tensorflow as tf

# TODO: once https://github.com/tensorflow/addons/pull/1964 is fixed,
# uncomment this line.
# from tensorflow_addons.seq2seq import dynamic_decode
from tensorflow_addons.seq2seq import BahdanauAttention, Decoder, Sampler

from tensorflow_tts.utils import dynamic_decode

from tensorflow_tts.models import BaseModel


def get_initializer(initializer_range=0.02):
    """Create a `tf.initializers.truncated_normal` with the given range.

    Args:
        initializer_range: float, initializer range for stddev.

    Returns:
        TruncatedNormal initializer with stddev = `initializer_range`.
    """
    return tf.keras.initializers.TruncatedNormal(stddev=initializer_range)


def gelu(x):
    """Gaussian Error Linear Unit."""
    cdf = 0.5 * (1.0 + tf.math.erf(x / tf.math.sqrt(2.0)))
    return x * cdf


def gelu_new(x):
    """Smoother Gaussian Error Linear Unit."""
    cdf = 0.5 * (1.0 + tf.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3))))
    return x * cdf


def swish(x):
    """Swish activation function."""
    return tf.nn.swish(x)


def mish(x):
    """Mish activation function."""
    return x * tf.math.tanh(tf.math.softplus(x))


ACT2FN = {
    "identity": tf.keras.layers.Activation("linear"),
    "tanh": tf.keras.layers.Activation("tanh"),
    "gelu": tf.keras.layers.Activation(gelu),
    "relu": tf.keras.activations.relu,
    "swish": tf.keras.layers.Activation(swish),
    "gelu_new": tf.keras.layers.Activation(gelu_new),
    "mish": tf.keras.layers.Activation(mish),
}


class TFEmbedding(tf.keras.layers.Embedding):
    """Faster version of embedding."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def call(self, inputs):
        inputs = tf.cast(tf.expand_dims(inputs, -1), tf.int32)
        outputs = tf.gather_nd(self.embeddings, inputs)
        return outputs


class TFTacotronConvBatchNorm(tf.keras.layers.Layer):
    """Tacotron-2 convolutional batch-norm module."""

    def __init__(
        self, filters, kernel_size, dropout_rate, activation=None, name_idx=None
    ):
        super().__init__()
        self.conv1d = tf.keras.layers.Conv1D(
            filters,
            kernel_size,
            kernel_initializer=get_initializer(0.02),
            padding="same",
            name="conv_._{}".format(name_idx),
        )
        self.norm = tf.keras.layers.experimental.SyncBatchNormalization(
            axis=-1, name="batch_norm_._{}".format(name_idx)
        )
        self.dropout = tf.keras.layers.Dropout(
            rate=dropout_rate, name="dropout_._{}".format(name_idx)
        )
        self.act = ACT2FN[activation]

    def call(self, inputs, training=False):
        outputs = self.conv1d(inputs)
        outputs = self.norm(outputs, training=training)
        outputs = self.act(outputs)
        outputs = self.dropout(outputs, training=training)
        return outputs


class TFTacotronEmbeddings(tf.keras.layers.Layer):
    """Construct character/phoneme/positional/speaker embeddings."""

    def __init__(self, config, **kwargs):
        """Init variables."""
        super().__init__(**kwargs)
        self.vocab_size = config.vocab_size
        self.embedding_hidden_size = config.embedding_hidden_size
        self.initializer_range = config.initializer_range
        self.config = config

        if config.n_speakers > 1:
            self.speaker_embeddings = TFEmbedding(
                config.n_speakers,
                config.embedding_hidden_size,
                embeddings_initializer=get_initializer(self.initializer_range),
                name="speaker_embeddings",
            )
            self.speaker_fc = tf.keras.layers.Dense(
                units=config.embedding_hidden_size, name="speaker_fc"
            )

        self.LayerNorm = tf.keras.layers.LayerNormalization(
            epsilon=config.layer_norm_eps, name="LayerNorm"
        )
        self.dropout = tf.keras.layers.Dropout(config.embedding_dropout_prob)

    def build(self, input_shape):
        """Build shared character/phoneme embedding layers."""
        with tf.name_scope("character_embeddings"):
            self.character_embeddings = self.add_weight(
                "weight",
                shape=[self.vocab_size, self.embedding_hidden_size],
                initializer=get_initializer(self.initializer_range),
            )
        super().build(input_shape)

    def call(self, inputs, training=False):
        """Get character embeddings of inputs.

        Args:
            1. character, Tensor (int32) shape [batch_size, length].
            2. speaker_id, Tensor (int32) shape [batch_size].

        Returns:
            Tensor (float32) shape [batch_size, length, embedding_size].
        """
        return self._embedding(inputs, training=training)

    def _embedding(self, inputs, training=False):
        """Apply embedding based on inputs tensor."""
        input_ids, speaker_ids = inputs

        # create embeddings
        inputs_embeds = tf.gather(self.character_embeddings, input_ids)
        embeddings = inputs_embeds

        if self.config.n_speakers > 1:
            speaker_embeddings = self.speaker_embeddings(speaker_ids)
            speaker_features = tf.math.softplus(self.speaker_fc(speaker_embeddings))
            # extended speaker embeddings
            extended_speaker_features = speaker_features[:, tf.newaxis, :]
            # sum all embeddings
            embeddings += extended_speaker_features

        # apply layer-norm and dropout for embeddings.
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings, training=training)

        return embeddings


class TFTacotronEncoderConvs(tf.keras.layers.Layer):
    """Tacotron-2 encoder convolutional batch-norm module."""

    def __init__(self, config, **kwargs):
        """Init variables."""
        super().__init__(**kwargs)
        self.conv_batch_norm = []
        for i in range(config.n_conv_encoder):
            conv = TFTacotronConvBatchNorm(
                filters=config.encoder_conv_filters,
                kernel_size=config.encoder_conv_kernel_sizes,
                activation=config.encoder_conv_activation,
                dropout_rate=config.encoder_conv_dropout_rate,
                name_idx=i,
            )
            self.conv_batch_norm.append(conv)

    def call(self, inputs, training=False):
        """Call logic."""
        outputs = inputs
        for conv in self.conv_batch_norm:
            outputs = conv(outputs, training=training)
        return outputs


class TFTacotronEncoder(tf.keras.layers.Layer):
    """Tacotron-2 Encoder."""

    def __init__(self, config, **kwargs):
        """Init variables."""
        super().__init__(**kwargs)
        self.embeddings = TFTacotronEmbeddings(config, name="embeddings")
        self.convbn = TFTacotronEncoderConvs(config, name="conv_batch_norm")
        self.bilstm = tf.keras.layers.Bidirectional(
            tf.keras.layers.LSTM(
                units=config.encoder_lstm_units, return_sequences=True
            ),
            name="bilstm",
        )

        if config.n_speakers > 1:
            self.encoder_speaker_embeddings = TFEmbedding(
                config.n_speakers,
                config.embedding_hidden_size,
                embeddings_initializer=get_initializer(config.initializer_range),
                name="encoder_speaker_embeddings",
            )
            self.encoder_speaker_fc = tf.keras.layers.Dense(
                units=config.encoder_lstm_units * 2, name="encoder_speaker_fc"
            )

        self.config = config

    def call(self, inputs, training=False):
        """Call logic."""
        input_ids, speaker_ids, input_mask = inputs

        # create embeddings and mask them since we sum
        # the speaker embedding into all character embeddings.
        input_embeddings = self.embeddings([input_ids, speaker_ids], training=training)

        # pass embeddings to convolution batch norm
        conv_outputs = self.convbn(input_embeddings, training=training)

        # bi-lstm.
        outputs = self.bilstm(conv_outputs, mask=input_mask)

        if self.config.n_speakers > 1:
            encoder_speaker_embeddings = self.encoder_speaker_embeddings(speaker_ids)
            encoder_speaker_features = tf.math.softplus(
                self.encoder_speaker_fc(encoder_speaker_embeddings)
            )
            # extended encoder speaker embeddings
            extended_encoder_speaker_features = encoder_speaker_features[
                :, tf.newaxis, :
            ]
            # sum to encoder outputs
            outputs += extended_encoder_speaker_features

        return outputs


class Tacotron2Sampler(Sampler):
    """Tacotron-2 sampler for seq2seq training."""

    def __init__(
        self, config,
    ):
        super().__init__()
        self.config = config
        # create schedule factor.
        # the input of the next decoder cell is calculated by the formula:
        # next_inputs = ratio * prev_groundtruth_outputs + (1.0 - ratio) * prev_predicted_outputs.
        self._ratio = tf.constant(1.0, dtype=tf.float32)
        self._reduction_factor = self.config.reduction_factor

    def setup_target(self, targets, mel_lengths):
        """Setup ground-truth mel outputs for decoder."""
        self.mel_lengths = mel_lengths
        self.set_batch_size(tf.shape(targets)[0])
        self.targets = targets[
            :, self._reduction_factor - 1 :: self._reduction_factor, :
        ]
        self.max_lengths = tf.tile([tf.shape(self.targets)[1]], [self._batch_size])

    @property
    def batch_size(self):
        return self._batch_size

    @property
    def sample_ids_shape(self):
        return tf.TensorShape([])

    @property
    def sample_ids_dtype(self):
        return tf.int32

    @property
    def reduction_factor(self):
        return self._reduction_factor

    def initialize(self):
        """Return (finished, next_inputs)."""
        return (
            tf.tile([False], [self._batch_size]),
            tf.tile([[0.0]], [self._batch_size, self.config.n_mels]),
        )

    def sample(self, time, outputs, state):
        return tf.tile([0], [self._batch_size])

    def next_inputs(
        self,
        time,
        outputs,
        state,
        sample_ids,
        stop_token_prediction,
        training=False,
        **kwargs,
    ):
        if training:
            finished = time + 1 >= self.max_lengths
            next_inputs = (
                self._ratio * self.targets[:, time, :]
                + (1.0 - self._ratio) * outputs[:, -self.config.n_mels :]
            )
            next_state = state
            return (finished, next_inputs, next_state)
        else:
            stop_token_prediction = tf.nn.sigmoid(stop_token_prediction)
            finished = tf.cast(tf.round(stop_token_prediction), tf.bool)
            finished = tf.reduce_all(finished)
            next_inputs = outputs[:, -self.config.n_mels :]
            next_state = state
            return (finished, next_inputs, next_state)

    def set_batch_size(self, batch_size):
        self._batch_size = batch_size


class TFTacotronLocationSensitiveAttention(BahdanauAttention):
    """Tacotron-2 location-sensitive attention module."""

    def __init__(
        self,
        config,
        memory,
        mask_encoder=True,
        memory_sequence_length=None,
        is_cumulate=True,
    ):
        """Init variables."""
        memory_length = memory_sequence_length if (mask_encoder is True) else None
        super().__init__(
            units=config.attention_dim,
            memory=memory,
            memory_sequence_length=memory_length,
            probability_fn="softmax",
            name="LocationSensitiveAttention",
        )
        self.location_convolution = tf.keras.layers.Conv1D(
            filters=config.attention_filters,
            kernel_size=config.attention_kernel,
            padding="same",
            use_bias=False,
            name="location_conv",
        )
        self.location_layer = tf.keras.layers.Dense(
            units=config.attention_dim, use_bias=False, name="location_layer"
        )

        self.v = tf.keras.layers.Dense(1, use_bias=True, name="scores_attention")
        self.config = config
        self.is_cumulate = is_cumulate
        self.use_window = False

    def setup_window(self, win_front=2, win_back=4):
        self.win_front = tf.constant(win_front, tf.int32)
        self.win_back = tf.constant(win_back, tf.int32)

        self._indices = tf.expand_dims(tf.range(tf.shape(self.keys)[1]), 0)
        self._indices = tf.tile(
            self._indices, [tf.shape(self.keys)[0], 1]
        )  # [batch_size, max_time]

        self.use_window = True

    def _compute_window_mask(self, max_alignments):
        """Compute window mask for inference.

        Args:
            max_alignments (int): [batch_size]
        """
        expanded_max_alignments = tf.expand_dims(max_alignments, 1)  # [batch_size, 1]
        low = expanded_max_alignments - self.win_front
        high = expanded_max_alignments + self.win_back
        mlow = tf.cast((self._indices < low), tf.float32)
        mhigh = tf.cast((self._indices > high), tf.float32)
        mask = mlow + mhigh
        return mask  # [batch_size, max_length]

    def __call__(self, inputs, training=False):
        query, state, prev_max_alignments = inputs

        processed_query = self.query_layer(query) if self.query_layer else query
        processed_query = tf.expand_dims(processed_query, 1)

        expanded_alignments = tf.expand_dims(state, axis=2)
        f = self.location_convolution(expanded_alignments)
        processed_location_features = self.location_layer(f)

        energy = self._location_sensitive_score(
            processed_query, processed_location_features, self.keys
        )

        # mask energy on inference steps.
        if self.use_window is True:
            window_mask = self._compute_window_mask(prev_max_alignments)
            energy = energy + window_mask * -1e20

        alignments = self.probability_fn(energy, state)

        if self.is_cumulate:
            state = alignments + state
        else:
            state = alignments

        expanded_alignments = tf.expand_dims(alignments, 2)
        context = tf.reduce_sum(expanded_alignments * self.values, 1)

        return context, alignments, state

    def _location_sensitive_score(self, W_query, W_fil, W_keys):
        """Calculate location-sensitive energy."""
        return tf.squeeze(self.v(tf.nn.tanh(W_keys + W_query + W_fil)), -1)

    def get_initial_state(self, batch_size, size):
        """Get initial alignments."""
        return tf.zeros(shape=[batch_size, size], dtype=tf.float32)

    def get_initial_context(self, batch_size):
        """Get initial attention."""
        return tf.zeros(
            shape=[batch_size, self.config.encoder_lstm_units * 2], dtype=tf.float32
        )


class TFTacotronPrenet(tf.keras.layers.Layer):
    """Tacotron-2 prenet."""

    def __init__(self, config, **kwargs):
        """Init variables."""
        super().__init__(**kwargs)
        self.prenet_dense = [
            tf.keras.layers.Dense(
                units=config.prenet_units,
                activation=ACT2FN[config.prenet_activation],
                name="dense_._{}".format(i),
            )
            for i in range(config.n_prenet_layers)
        ]
        self.dropout = tf.keras.layers.Dropout(
            rate=config.prenet_dropout_rate, name="dropout"
        )

    def call(self, inputs, training=False):
        """Call logic."""
        outputs = inputs
        for layer in self.prenet_dense:
            outputs = layer(outputs)
            # prenet dropout is kept active at inference time as well.
            outputs = self.dropout(outputs, training=True)
        return outputs


class TFTacotronPostnet(tf.keras.layers.Layer):
    """Tacotron-2 postnet."""

    def __init__(self, config, **kwargs):
        """Init variables."""
        super().__init__(**kwargs)
        self.conv_batch_norm = []
        for i in range(config.n_conv_postnet):
            conv = TFTacotronConvBatchNorm(
                filters=config.postnet_conv_filters,
                kernel_size=config.postnet_conv_kernel_sizes,
                dropout_rate=config.postnet_dropout_rate,
                activation="identity" if i + 1 == config.n_conv_postnet else "tanh",
                name_idx=i,
            )
            self.conv_batch_norm.append(conv)

    def call(self, inputs, training=False):
        """Call logic."""
        outputs = inputs
        for conv in self.conv_batch_norm:
            outputs = conv(outputs, training=training)
        return outputs


TFTacotronDecoderCellState = collections.namedtuple(
    "TFTacotronDecoderCellState",
    [
        "attention_lstm_state",
        "decoder_lstms_state",
        "context",
        "time",
        "state",
        "alignment_history",
        "max_alignments",
    ],
)

TFDecoderOutput = collections.namedtuple(
    "TFDecoderOutput", ("mel_output", "token_output", "sample_id")
)


class TFTacotronDecoderCell(tf.keras.layers.AbstractRNNCell):
    """Tacotron-2 custom decoder cell."""

    def __init__(self, config, enable_tflite_convertible=False, **kwargs):
        """Init variables."""
        super().__init__(**kwargs)
        self.enable_tflite_convertible = enable_tflite_convertible
        self.prenet = TFTacotronPrenet(config, name="prenet")

        # define lstm cell on decoder.
        # TODO(@dathudeptrai) switch to zone-out lstm.
        self.attention_lstm = tf.keras.layers.LSTMCell(
            units=config.decoder_lstm_units, name="attention_lstm_cell"
        )
        lstm_cells = []
        for i in range(config.n_lstm_decoder):
            lstm_cell = tf.keras.layers.LSTMCell(
                units=config.decoder_lstm_units, name="lstm_cell_._{}".format(i)
            )
            lstm_cells.append(lstm_cell)
        self.decoder_lstms = tf.keras.layers.StackedRNNCells(
            lstm_cells, name="decoder_lstms"
        )

        # define attention layer.
        if config.attention_type == "lsa":
            # create location-sensitive attention.
            self.attention_layer = TFTacotronLocationSensitiveAttention(
                config,
                memory=None,
                mask_encoder=True,
                memory_sequence_length=None,
                is_cumulate=True,
            )
        else:
            raise ValueError("Only lsa (location-sensitive attention) is supported")

        # frame, stop projection layer.
        self.frame_projection = tf.keras.layers.Dense(
            units=config.n_mels * config.reduction_factor, name="frame_projection"
        )
        self.stop_projection = tf.keras.layers.Dense(
            units=config.reduction_factor, name="stop_projection"
        )

        self.config = config

    def set_alignment_size(self, alignment_size):
        self.alignment_size = alignment_size

    @property
    def output_size(self):
        """Return output (mel) size."""
        return self.frame_projection.units

    @property
    def state_size(self):
        """Return hidden state size."""
        return TFTacotronDecoderCellState(
            attention_lstm_state=self.attention_lstm.state_size,
            decoder_lstms_state=self.decoder_lstms.state_size,
            time=tf.TensorShape([]),
            context=self.config.attention_dim,
            state=self.alignment_size,
            alignment_history=(),
            max_alignments=tf.TensorShape([1]),
        )

    def get_initial_state(self, batch_size):
batch_size):592"""Get initial states."""593initial_attention_lstm_cell_states = self.attention_lstm.get_initial_state(594None, batch_size, dtype=tf.float32595)596initial_decoder_lstms_cell_states = self.decoder_lstms.get_initial_state(597None, batch_size, dtype=tf.float32598)599initial_context = tf.zeros(600shape=[batch_size, self.config.encoder_lstm_units * 2], dtype=tf.float32601)602initial_state = self.attention_layer.get_initial_state(603batch_size, size=self.alignment_size604)605if self.enable_tflite_convertible:606initial_alignment_history = ()607else:608initial_alignment_history = tf.TensorArray(609dtype=tf.float32, size=0, dynamic_size=True610)611return TFTacotronDecoderCellState(612attention_lstm_state=initial_attention_lstm_cell_states,613decoder_lstms_state=initial_decoder_lstms_cell_states,614time=tf.zeros([], dtype=tf.int32),615context=initial_context,616state=initial_state,617alignment_history=initial_alignment_history,618max_alignments=tf.zeros([batch_size], dtype=tf.int32),619)620621def call(self, inputs, states, training=False):622"""Call logic."""623decoder_input = inputs624625# 1. apply prenet for decoder_input.626prenet_out = self.prenet(decoder_input, training=training) # [batch_size, dim]627628# 2. concat prenet_out and prev context vector629# then use it as input of attention lstm layer.630attention_lstm_input = tf.concat([prenet_out, states.context], axis=-1)631attention_lstm_output, next_attention_lstm_state = self.attention_lstm(632attention_lstm_input, states.attention_lstm_state633)634635# 3. compute context, alignment and cumulative alignment.636prev_state = states.state637if not self.enable_tflite_convertible:638prev_alignment_history = states.alignment_history639prev_max_alignments = states.max_alignments640context, alignments, state = self.attention_layer(641[attention_lstm_output, prev_state, prev_max_alignments], training=training,642)643644# 4. run decoder lstm(s)645decoder_lstms_input = tf.concat([attention_lstm_output, context], axis=-1)646decoder_lstms_output, next_decoder_lstms_state = self.decoder_lstms(647decoder_lstms_input, states.decoder_lstms_state648)649650# 5. compute frame feature and stop token.651projection_inputs = tf.concat([decoder_lstms_output, context], axis=-1)652decoder_outputs = self.frame_projection(projection_inputs)653654stop_inputs = tf.concat([decoder_lstms_output, decoder_outputs], axis=-1)655stop_tokens = self.stop_projection(stop_inputs)656657# 6. save alignment history to visualize.658if self.enable_tflite_convertible:659alignment_history = ()660else:661alignment_history = prev_alignment_history.write(states.time, alignments)662663# 7. 
return new states.664new_states = TFTacotronDecoderCellState(665attention_lstm_state=next_attention_lstm_state,666decoder_lstms_state=next_decoder_lstms_state,667time=states.time + 1,668context=context,669state=state,670alignment_history=alignment_history,671max_alignments=tf.argmax(alignments, -1, output_type=tf.int32),672)673674return (decoder_outputs, stop_tokens), new_states675676677class TFTacotronDecoder(Decoder):678"""Tacotron-2 Decoder."""679680def __init__(681self,682decoder_cell,683decoder_sampler,684output_layer=None,685enable_tflite_convertible=False,686):687"""Initial variables."""688self.cell = decoder_cell689self.sampler = decoder_sampler690self.output_layer = output_layer691self.enable_tflite_convertible = enable_tflite_convertible692693def setup_decoder_init_state(self, decoder_init_state):694self.initial_state = decoder_init_state695696def initialize(self, **kwargs):697return self.sampler.initialize() + (self.initial_state,)698699@property700def output_size(self):701return TFDecoderOutput(702mel_output=tf.nest.map_structure(703lambda shape: tf.TensorShape(shape), self.cell.output_size704),705token_output=tf.TensorShape(self.sampler.reduction_factor),706sample_id=tf.TensorShape([1])707if self.enable_tflite_convertible708else self.sampler.sample_ids_shape, # tf.TensorShape([])709)710711@property712def output_dtype(self):713return TFDecoderOutput(tf.float32, tf.float32, self.sampler.sample_ids_dtype)714715@property716def batch_size(self):717return self.sampler._batch_size718719def step(self, time, inputs, state, training=False):720(mel_outputs, stop_tokens), cell_state = self.cell(721inputs, state, training=training722)723if self.output_layer is not None:724mel_outputs = self.output_layer(mel_outputs)725sample_ids = self.sampler.sample(726time=time, outputs=mel_outputs, state=cell_state727)728(finished, next_inputs, next_state) = self.sampler.next_inputs(729time=time,730outputs=mel_outputs,731state=cell_state,732sample_ids=sample_ids,733stop_token_prediction=stop_tokens,734training=training,735)736737outputs = TFDecoderOutput(mel_outputs, stop_tokens, sample_ids)738return (outputs, next_state, next_inputs, finished)739740741class TFTacotron2(BaseModel):742"""Tensorflow tacotron-2 model."""743744def __init__(self, config, **kwargs):745"""Initalize tacotron-2 layers."""746enable_tflite_convertible = kwargs.pop("enable_tflite_convertible", False)747super().__init__(self, **kwargs)748self.encoder = TFTacotronEncoder(config, name="encoder")749self.decoder_cell = TFTacotronDecoderCell(750config,751name="decoder_cell",752enable_tflite_convertible=enable_tflite_convertible,753)754self.decoder = TFTacotronDecoder(755self.decoder_cell,756Tacotron2Sampler(config),757enable_tflite_convertible=enable_tflite_convertible,758)759self.postnet = TFTacotronPostnet(config, name="post_net")760self.post_projection = tf.keras.layers.Dense(761units=config.n_mels, name="residual_projection"762)763764self.use_window_mask = False765self.maximum_iterations = 4000766self.enable_tflite_convertible = enable_tflite_convertible767self.config = config768769def setup_window(self, win_front, win_back):770"""Call only for inference."""771self.use_window_mask = True772self.win_front = win_front773self.win_back = win_back774775def setup_maximum_iterations(self, maximum_iterations):776"""Call only for inference."""777self.maximum_iterations = maximum_iterations778779def _build(self):780input_ids = np.array([[1, 2, 3, 4, 5, 6, 7, 8, 9]])781input_lengths = np.array([9])782speaker_ids = np.array([0])783mel_outputs = 
        mel_lengths = np.array([50])
        self(
            input_ids,
            input_lengths,
            speaker_ids,
            mel_outputs,
            mel_lengths,
            10,  # maximum_iterations
            training=True,
        )

    def call(
        self,
        input_ids,
        input_lengths,
        speaker_ids,
        mel_gts,
        mel_lengths,
        maximum_iterations=None,
        use_window_mask=False,
        win_front=2,
        win_back=3,
        training=False,
        **kwargs,
    ):
        """Call logic."""
        # create input-mask based on input_lengths
        input_mask = tf.sequence_mask(
            input_lengths,
            maxlen=tf.reduce_max(input_lengths),
            name="input_sequence_masks",
        )

        # Encoder Step.
        encoder_hidden_states = self.encoder(
            [input_ids, speaker_ids, input_mask], training=training
        )

        batch_size = tf.shape(encoder_hidden_states)[0]
        alignment_size = tf.shape(encoder_hidden_states)[1]

        # Setup some initial placeholders for decoder step. Include:
        # 1. mel_gts, mel_lengths for teacher forcing mode.
        # 2. alignment_size for attention size.
        # 3. initial state for decoder cell.
        # 4. memory (encoder hidden state) for attention mechanism.
        self.decoder.sampler.setup_target(targets=mel_gts, mel_lengths=mel_lengths)
        self.decoder.cell.set_alignment_size(alignment_size)
        self.decoder.setup_decoder_init_state(
            self.decoder.cell.get_initial_state(batch_size)
        )
        self.decoder.cell.attention_layer.setup_memory(
            memory=encoder_hidden_states,
            memory_sequence_length=input_lengths,  # use for mask attention.
        )
        if use_window_mask:
            self.decoder.cell.attention_layer.setup_window(
                win_front=win_front, win_back=win_back
            )

        # run decode step.
        (
            (frames_prediction, stop_token_prediction, _),
            final_decoder_state,
            _,
        ) = dynamic_decode(
            self.decoder,
            maximum_iterations=maximum_iterations,
            enable_tflite_convertible=self.enable_tflite_convertible,
            training=training,
        )

        decoder_outputs = tf.reshape(
            frames_prediction, [batch_size, -1, self.config.n_mels]
        )
        stop_token_prediction = tf.reshape(stop_token_prediction, [batch_size, -1])

        residual = self.postnet(decoder_outputs, training=training)
        residual_projection = self.post_projection(residual)

        mel_outputs = decoder_outputs + residual_projection

        if self.enable_tflite_convertible:
            mask = tf.math.not_equal(
                tf.cast(
                    tf.reduce_sum(tf.abs(decoder_outputs), axis=-1), dtype=tf.int32
                ),
                0,
            )
            decoder_outputs = tf.expand_dims(
                tf.boolean_mask(decoder_outputs, mask), axis=0
            )
            mel_outputs = tf.expand_dims(tf.boolean_mask(mel_outputs, mask), axis=0)
            alignment_history = ()
        else:
            alignment_history = tf.transpose(
                final_decoder_state.alignment_history.stack(), [1, 2, 0]
            )

        return decoder_outputs, mel_outputs, stop_token_prediction, alignment_history

    @tf.function(
        experimental_relax_shapes=True,
        input_signature=[
            tf.TensorSpec([None, None], dtype=tf.int32, name="input_ids"),
            tf.TensorSpec([None,], dtype=tf.int32, name="input_lengths"),
            tf.TensorSpec([None,], dtype=tf.int32, name="speaker_ids"),
        ],
    )
    def inference(self, input_ids, input_lengths, speaker_ids, **kwargs):
        """Call logic."""
        # create input-mask based on input_lengths
        input_mask = tf.sequence_mask(
            input_lengths,
            maxlen=tf.reduce_max(input_lengths),
            name="input_sequence_masks",
        )

        # Encoder Step.
        encoder_hidden_states = self.encoder(
            [input_ids, speaker_ids, input_mask], training=False
        )

        batch_size = tf.shape(encoder_hidden_states)[0]
        alignment_size = tf.shape(encoder_hidden_states)[1]

        # Setup some initial placeholders for decoder step. Include:
step. Include:911# 1. batch_size for inference.912# 2. alignment_size for attention size.913# 3. initial state for decoder cell.914# 4. memory (encoder hidden state) for attention mechanism.915# 5. window front/back to solve long sentence synthesize problems. (call after setup memory.)916self.decoder.sampler.set_batch_size(batch_size)917self.decoder.cell.set_alignment_size(alignment_size)918self.decoder.setup_decoder_init_state(919self.decoder.cell.get_initial_state(batch_size)920)921self.decoder.cell.attention_layer.setup_memory(922memory=encoder_hidden_states,923memory_sequence_length=input_lengths, # use for mask attention.924)925if self.use_window_mask:926self.decoder.cell.attention_layer.setup_window(927win_front=self.win_front, win_back=self.win_back928)929930# run decode step.931(932(frames_prediction, stop_token_prediction, _),933final_decoder_state,934_,935) = dynamic_decode(936self.decoder, maximum_iterations=self.maximum_iterations, training=False937)938939decoder_outputs = tf.reshape(940frames_prediction, [batch_size, -1, self.config.n_mels]941)942stop_token_predictions = tf.reshape(stop_token_prediction, [batch_size, -1])943944residual = self.postnet(decoder_outputs, training=False)945residual_projection = self.post_projection(residual)946947mel_outputs = decoder_outputs + residual_projection948949alignment_historys = tf.transpose(950final_decoder_state.alignment_history.stack(), [1, 2, 0]951)952953return decoder_outputs, mel_outputs, stop_token_predictions, alignment_historys954955@tf.function(956experimental_relax_shapes=True,957input_signature=[958tf.TensorSpec([1, None], dtype=tf.int32, name="input_ids"),959tf.TensorSpec([1,], dtype=tf.int32, name="input_lengths"),960tf.TensorSpec([1,], dtype=tf.int32, name="speaker_ids"),961],962)963def inference_tflite(self, input_ids, input_lengths, speaker_ids, **kwargs):964"""Call logic."""965# create input-mask based on input_lengths966input_mask = tf.sequence_mask(967input_lengths,968maxlen=tf.reduce_max(input_lengths),969name="input_sequence_masks",970)971972# Encoder Step.973encoder_hidden_states = self.encoder(974[input_ids, speaker_ids, input_mask], training=False975)976977batch_size = tf.shape(encoder_hidden_states)[0]978alignment_size = tf.shape(encoder_hidden_states)[1]979980# Setup some initial placeholders for decoder step. Include:981# 1. batch_size for inference.982# 2. alignment_size for attention size.983# 3. initial state for decoder cell.984# 4. memory (encoder hidden state) for attention mechanism.985# 5. window front/back to solve long sentence synthesize problems. 
(call after setup memory.)986self.decoder.sampler.set_batch_size(batch_size)987self.decoder.cell.set_alignment_size(alignment_size)988self.decoder.setup_decoder_init_state(989self.decoder.cell.get_initial_state(batch_size)990)991self.decoder.cell.attention_layer.setup_memory(992memory=encoder_hidden_states,993memory_sequence_length=input_lengths, # use for mask attention.994)995if self.use_window_mask:996self.decoder.cell.attention_layer.setup_window(997win_front=self.win_front, win_back=self.win_back998)9991000# run decode step.1001(1002(frames_prediction, stop_token_prediction, _),1003final_decoder_state,1004_,1005) = dynamic_decode(1006self.decoder,1007maximum_iterations=self.maximum_iterations,1008enable_tflite_convertible=self.enable_tflite_convertible,1009training=False,1010)10111012decoder_outputs = tf.reshape(1013frames_prediction, [batch_size, -1, self.config.n_mels]1014)1015stop_token_predictions = tf.reshape(stop_token_prediction, [batch_size, -1])10161017residual = self.postnet(decoder_outputs, training=False)1018residual_projection = self.post_projection(residual)10191020mel_outputs = decoder_outputs + residual_projection10211022if self.enable_tflite_convertible:1023mask = tf.math.not_equal(1024tf.cast(1025tf.reduce_sum(tf.abs(decoder_outputs), axis=-1), dtype=tf.int321026),10270,1028)1029decoder_outputs = tf.expand_dims(1030tf.boolean_mask(decoder_outputs, mask), axis=01031)1032mel_outputs = tf.expand_dims(tf.boolean_mask(mel_outputs, mask), axis=0)1033alignment_historys = ()1034else:1035alignment_historys = tf.transpose(1036final_decoder_state.alignment_history.stack(), [1, 2, 0]1037)10381039return decoder_outputs, mel_outputs, stop_token_predictions, alignment_historys104010411042
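

if __name__ == "__main__":
    # Minimal usage sketch: builds the model once with the dummy batch from
    # `_build`, then synthesizes a mel spectrogram from a toy id sequence.
    # It assumes a `Tacotron2Config` (e.g. from `tensorflow_tts.configs`)
    # exposing the attribute names referenced above (n_mels, n_speakers,
    # reduction_factor, ...); in practice input ids come from a text or
    # phoneme processor rather than being hand-written.
    from tensorflow_tts.configs import Tacotron2Config

    tacotron2 = TFTacotron2(Tacotron2Config(), name="tacotron2")
    tacotron2._build()  # trace a dummy forward pass to create the weights.

    decoder_outputs, mel_outputs, stop_token_predictions, alignment_history = tacotron2.inference(
        input_ids=tf.constant([[1, 2, 3, 4, 5, 6, 7, 8, 9]], tf.int32),
        input_lengths=tf.constant([9], tf.int32),
        speaker_ids=tf.constant([0], tf.int32),
    )
    print(mel_outputs.shape)  # (1, mel_length, n_mels)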