# Path: blob/master/tensorflow_tts/models/fastspeech.py
# (1558 views)
# -*- coding: utf-8 -*-
# Copyright 2020 The FastSpeech Authors, The HuggingFace Inc. team and Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tensorflow Model modules for FastSpeech."""

import numpy as np
import tensorflow as tf

from tensorflow_tts.models import BaseModel


def get_initializer(initializer_range=0.02):
    """Create a `tf.initializers.truncated_normal` with the given range.

    Args:
        initializer_range: float, stddev used for the truncated normal.

    Returns:
        TruncatedNormal initializer with stddev = `initializer_range`.
    """
    return tf.keras.initializers.TruncatedNormal(stddev=initializer_range)


def gelu(x):
    """Gaussian Error Linear Unit (exact erf formulation)."""
    cdf = 0.5 * (1.0 + tf.math.erf(x / tf.math.sqrt(2.0)))
    return x * cdf


def gelu_new(x):
    """Smoother Gaussian Error Linear Unit (tanh approximation)."""
    cdf = 0.5 * (1.0 + tf.tanh((np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
    return x * cdf


def swish(x):
    """Swish activation function."""
    return tf.nn.swish(x)


def mish(x):
    """Mish activation: x * tanh(softplus(x))."""
    return x * tf.math.tanh(tf.math.softplus(x))


# Lookup table mapping config activation names to callables.
ACT2FN = {
    "identity": tf.keras.layers.Activation("linear"),
    "tanh": tf.keras.layers.Activation("tanh"),
    "gelu": tf.keras.layers.Activation(gelu),
    "relu": tf.keras.activations.relu,
    "swish": tf.keras.layers.Activation(swish),
    "gelu_new": tf.keras.layers.Activation(gelu_new),
    "mish": tf.keras.layers.Activation(mish),
}


class TFEmbedding(tf.keras.layers.Embedding):
    """Faster version of embedding."""

    def call(self, inputs):
        # A direct gather on the weight matrix avoids the slower parent
        # lookup path; indices are normalized to int32 first.
        indices = tf.cast(inputs, tf.int32)
        return tf.gather(self.embeddings, indices)


class TFFastSpeechEmbeddings(tf.keras.layers.Layer):
    """Construct charactor/phoneme/positional/speaker embeddings."""

    def __init__(self, config, **kwargs):
        """Init variables."""
        super().__init__(**kwargs)
        self.vocab_size = config.vocab_size
        self.hidden_size = config.encoder_self_attention_params.hidden_size
        self.initializer_range = config.initializer_range
        self.config = config

        # Frozen sinusoidal position table; row 0 is reserved for padding.
        self.position_embeddings = TFEmbedding(
            config.max_position_embeddings + 1,
            self.hidden_size,
            weights=[
                self._sincos_embedding(
                    self.hidden_size, self.config.max_position_embeddings
                )
            ],
            name="position_embeddings",
            trainable=False,
        )

        # Speaker conditioning only exists for multi-speaker configs.
        if config.n_speakers > 1:
            self.encoder_speaker_embeddings = TFEmbedding(
                config.n_speakers,
                self.hidden_size,
                embeddings_initializer=get_initializer(self.initializer_range),
                name="speaker_embeddings",
            )
            self.speaker_fc = tf.keras.layers.Dense(
                units=self.hidden_size, name="speaker_fc"
            )

    def build(self, input_shape):
        """Build shared charactor/phoneme embedding layers."""
        with tf.name_scope("charactor_embeddings"):
            self.charactor_embeddings = self.add_weight(
                "weight",
                shape=[self.vocab_size, self.hidden_size],
                initializer=get_initializer(self.initializer_range),
            )
        super().build(input_shape)

    def call(self, inputs, training=False):
        """Get charactor embeddings of inputs.

        Args:
            inputs: pair of
                1. charactor, Tensor (int32) shape [batch_size, length].
                2. speaker_id, Tensor (int32) shape [batch_size].
            training: bool, unused here but kept for Keras API symmetry.

        Returns:
            Tensor (float32) shape [batch_size, length, embedding_size].
        """
        return self._embedding(inputs, training=training)

    def _embedding(self, inputs, training=False):
        """Applies embedding based on inputs tensor."""
        input_ids, speaker_ids = inputs

        seq_length = tf.shape(input_ids)[1]
        # Positions are 1-based so that 0 can act as the padding index.
        position_ids = tf.range(1, seq_length + 1, dtype=tf.int32)[tf.newaxis, :]

        # Token + positional embeddings.
        inputs_embeds = tf.gather(self.charactor_embeddings, input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        embeddings = inputs_embeds + tf.cast(position_embeddings, inputs_embeds.dtype)

        if self.config.n_speakers > 1:
            speaker_embeddings = self.encoder_speaker_embeddings(speaker_ids)
            speaker_features = tf.math.softplus(self.speaker_fc(speaker_embeddings))
            # Broadcast per-utterance speaker features over the time axis.
            embeddings += speaker_features[:, tf.newaxis, :]

        return embeddings

    def _sincos_embedding(
        self, hidden_size, max_positional_embedding,
    ):
        """Build the sinusoidal position table, shape [max+1, hidden_size]."""
        position_enc = np.array(
            [
                [
                    pos / np.power(10000, 2.0 * (i // 2) / hidden_size)
                    for i in range(hidden_size)
                ]
                for pos in range(max_positional_embedding + 1)
            ]
        )

        position_enc[:, 0::2] = np.sin(position_enc[:, 0::2])
        position_enc[:, 1::2] = np.cos(position_enc[:, 1::2])

        # pad embedding.
        position_enc[0] = 0.0

        return position_enc

    def resize_positional_embeddings(self, new_size):
        """Rebuild the position table for a new maximum sequence length."""
        self.position_embeddings = TFEmbedding(
            new_size + 1,
            self.hidden_size,
            weights=[self._sincos_embedding(self.hidden_size, new_size)],
            name="position_embeddings",
            trainable=False,
        )
class TFFastSpeechSelfAttention(tf.keras.layers.Layer):
    """Self attention module for fastspeech."""

    def __init__(self, config, **kwargs):
        """Init variables."""
        super().__init__(**kwargs)
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
            )
        self.output_attentions = config.output_attentions
        self.num_attention_heads = config.num_attention_heads
        self.all_head_size = self.num_attention_heads * config.attention_head_size

        # Separate projections for queries, keys and values.
        self.query = tf.keras.layers.Dense(
            self.all_head_size,
            kernel_initializer=get_initializer(config.initializer_range),
            name="query",
        )
        self.key = tf.keras.layers.Dense(
            self.all_head_size,
            kernel_initializer=get_initializer(config.initializer_range),
            name="key",
        )
        self.value = tf.keras.layers.Dense(
            self.all_head_size,
            kernel_initializer=get_initializer(config.initializer_range),
            name="value",
        )

        self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
        self.config = config

    def transpose_for_scores(self, x, batch_size):
        """Transpose to calculate attention scores."""
        # [batch, time, all_head] -> [batch, heads, time, head_size]
        x = tf.reshape(
            x,
            (batch_size, -1, self.num_attention_heads, self.config.attention_head_size),
        )
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, inputs, training=False):
        """Call logic."""
        hidden_states, attention_mask = inputs

        batch_size = tf.shape(hidden_states)[0]
        query_layer = self.transpose_for_scores(self.query(hidden_states), batch_size)
        key_layer = self.transpose_for_scores(self.key(hidden_states), batch_size)
        value_layer = self.transpose_for_scores(self.value(hidden_states), batch_size)

        # Scaled dot-product attention scores.
        attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
        dk = tf.cast(tf.shape(key_layer)[-1], attention_scores.dtype)
        attention_scores = attention_scores / tf.math.sqrt(dk)

        if attention_mask is not None:
            # Extend the [batch, time] mask for broadcasting over heads and
            # query positions, then push masked logits toward -inf.
            extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
            extended_attention_mask = tf.cast(
                extended_attention_mask, attention_scores.dtype
            )
            extended_attention_mask = (1.0 - extended_attention_mask) * -1e9
            attention_scores = attention_scores + extended_attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = tf.nn.softmax(attention_scores, axis=-1)
        attention_probs = self.dropout(attention_probs, training=training)

        context_layer = tf.matmul(attention_probs, value_layer)
        # Back to [batch, time, all_head_size].
        context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3])
        context_layer = tf.reshape(context_layer, (batch_size, -1, self.all_head_size))

        if self.output_attentions:
            return (context_layer, attention_probs)
        return (context_layer,)
class TFFastSpeechSelfOutput(tf.keras.layers.Layer):
    """Fastspeech output of self attention module."""

    def __init__(self, config, **kwargs):
        """Init variables."""
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(
            config.hidden_size,
            kernel_initializer=get_initializer(config.initializer_range),
            name="dense",
        )
        self.LayerNorm = tf.keras.layers.LayerNormalization(
            epsilon=config.layer_norm_eps, name="LayerNorm"
        )
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)

    def call(self, inputs, training=False):
        """Project, dropout, then add-and-normalize with the residual."""
        hidden_states, input_tensor = inputs

        projected = self.dense(hidden_states)
        projected = self.dropout(projected, training=training)
        return self.LayerNorm(projected + input_tensor)


class TFFastSpeechAttention(tf.keras.layers.Layer):
    """Fastspeech attention module."""

    def __init__(self, config, **kwargs):
        """Init variables."""
        super().__init__(**kwargs)
        self.self_attention = TFFastSpeechSelfAttention(config, name="self")
        self.dense_output = TFFastSpeechSelfOutput(config, name="output")

    def call(self, inputs, training=False):
        """Run self-attention followed by its output projection, re-masking."""
        input_tensor, attention_mask = inputs

        self_outputs = self.self_attention(
            [input_tensor, attention_mask], training=training
        )
        attention_output = self.dense_output(
            [self_outputs[0], input_tensor], training=training
        )
        # Zero out padded positions again after the residual add.
        masked_attention_output = attention_output * tf.cast(
            tf.expand_dims(attention_mask, 2), dtype=attention_output.dtype
        )
        # Append attention probabilities when the config requests them.
        return (masked_attention_output,) + self_outputs[1:]


class TFFastSpeechIntermediate(tf.keras.layers.Layer):
    """Intermediate representation module."""

    def __init__(self, config, **kwargs):
        """Init variables."""
        super().__init__(**kwargs)
        self.conv1d_1 = tf.keras.layers.Conv1D(
            config.intermediate_size,
            kernel_size=config.intermediate_kernel_size,
            kernel_initializer=get_initializer(config.initializer_range),
            padding="same",
            name="conv1d_1",
        )
        self.conv1d_2 = tf.keras.layers.Conv1D(
            config.hidden_size,
            kernel_size=config.intermediate_kernel_size,
            kernel_initializer=get_initializer(config.initializer_range),
            padding="same",
            name="conv1d_2",
        )
        # Accept either a named activation or a ready-made callable.
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def call(self, inputs):
        """Conv -> activation -> conv, with padding positions re-masked."""
        hidden_states, attention_mask = inputs

        hidden_states = self.conv1d_1(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        hidden_states = self.conv1d_2(hidden_states)

        mask = tf.cast(tf.expand_dims(attention_mask, 2), dtype=hidden_states.dtype)
        return hidden_states * mask


class TFFastSpeechOutput(tf.keras.layers.Layer):
    """Output module."""

    def __init__(self, config, **kwargs):
        """Init variables."""
        super().__init__(**kwargs)
        self.LayerNorm = tf.keras.layers.LayerNormalization(
            epsilon=config.layer_norm_eps, name="LayerNorm"
        )
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)

    def call(self, inputs, training=False):
        """Dropout then add-and-normalize with the residual input."""
        hidden_states, input_tensor = inputs

        dropped = self.dropout(hidden_states, training=training)
        return self.LayerNorm(dropped + input_tensor)
class TFFastSpeechLayer(tf.keras.layers.Layer):
    """Fastspeech module (FFT module on the paper)."""

    def __init__(self, config, **kwargs):
        """Init variables."""
        super().__init__(**kwargs)
        self.attention = TFFastSpeechAttention(config, name="attention")
        self.intermediate = TFFastSpeechIntermediate(config, name="intermediate")
        self.bert_output = TFFastSpeechOutput(config, name="output")

    def call(self, inputs, training=False):
        """Attention -> conv intermediate -> output, with final re-masking."""
        hidden_states, attention_mask = inputs

        attention_outputs = self.attention(
            [hidden_states, attention_mask], training=training
        )
        attention_output = attention_outputs[0]
        intermediate_output = self.intermediate(
            [attention_output, attention_mask], training=training
        )
        layer_output = self.bert_output(
            [intermediate_output, attention_output], training=training
        )
        masked_layer_output = layer_output * tf.cast(
            tf.expand_dims(attention_mask, 2), dtype=layer_output.dtype
        )
        # Forward attention probabilities when they were produced.
        return (masked_layer_output,) + attention_outputs[1:]


class TFFastSpeechEncoder(tf.keras.layers.Layer):
    """Fast Speech encoder module."""

    def __init__(self, config, **kwargs):
        """Init variables."""
        super().__init__(**kwargs)
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
        self.layer = [
            TFFastSpeechLayer(config, name="layer_._{}".format(i))
            for i in range(config.num_hidden_layers)
        ]

    def call(self, inputs, training=False):
        """Run the FFT layer stack, optionally collecting intermediates."""
        hidden_states, attention_mask = inputs

        all_hidden_states = ()
        all_attentions = ()
        for layer_module in self.layer:
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = layer_module(
                [hidden_states, attention_mask], training=training
            )
            hidden_states = layer_outputs[0]

            if self.output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        # Add last layer
        if self.output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if self.output_attentions:
            outputs = outputs + (all_attentions,)
        return outputs  # outputs, (hidden states), (attentions)


class TFFastSpeechDecoder(TFFastSpeechEncoder):
    """Fast Speech decoder module."""

    def __init__(self, config, **kwargs):
        # Whether encoder and decoder hidden sizes match; popped before the
        # parent sees kwargs.
        self.is_compatible_encoder = kwargs.pop("is_compatible_encoder", True)

        super().__init__(config, **kwargs)
        self.config = config

        # create decoder positional embedding
        self.decoder_positional_embeddings = TFEmbedding(
            config.max_position_embeddings + 1,
            config.hidden_size,
            weights=[self._sincos_embedding()],
            name="position_embeddings",
            trainable=False,
        )

        # Bridge layer when encoder hidden size differs from the decoder's.
        if self.is_compatible_encoder is False:
            self.project_compatible_decoder = tf.keras.layers.Dense(
                units=config.hidden_size, name="project_compatible_decoder"
            )

        if config.n_speakers > 1:
            self.decoder_speaker_embeddings = TFEmbedding(
                config.n_speakers,
                config.hidden_size,
                embeddings_initializer=get_initializer(config.initializer_range),
                name="speaker_embeddings",
            )
            self.speaker_fc = tf.keras.layers.Dense(
                units=config.hidden_size, name="speaker_fc"
            )

    def call(self, inputs, training=False):
        """Add positional/speaker conditioning, then run the FFT stack."""
        hidden_states, speaker_ids, encoder_mask, decoder_pos = inputs

        if self.is_compatible_encoder is False:
            hidden_states = self.project_compatible_decoder(hidden_states)

        # calculate new hidden states.
        hidden_states += tf.cast(
            self.decoder_positional_embeddings(decoder_pos), hidden_states.dtype
        )

        if self.config.n_speakers > 1:
            speaker_embeddings = self.decoder_speaker_embeddings(speaker_ids)
            speaker_features = tf.math.softplus(self.speaker_fc(speaker_embeddings))
            # Broadcast per-utterance speaker features over the time axis.
            hidden_states += speaker_features[:, tf.newaxis, :]

        return super().call([hidden_states, encoder_mask], training=training)

    def _sincos_embedding(self):
        """Build the decoder's sinusoidal position table from its config."""
        position_enc = np.array(
            [
                [
                    pos / np.power(10000, 2.0 * (i // 2) / self.config.hidden_size)
                    for i in range(self.config.hidden_size)
                ]
                for pos in range(self.config.max_position_embeddings + 1)
            ]
        )

        position_enc[:, 0::2] = np.sin(position_enc[:, 0::2])
        position_enc[:, 1::2] = np.cos(position_enc[:, 1::2])

        # pad embedding.
        position_enc[0] = 0.0

        return position_enc
class TFTacotronPostnet(tf.keras.layers.Layer):
    """Tacotron-2 postnet."""

    def __init__(self, config, **kwargs):
        """Init variables."""
        super().__init__(**kwargs)
        self.conv_batch_norm = []
        for i in range(config.n_conv_postnet):
            # The final conv maps back to num_mels; the rest keep the
            # postnet filter width.
            conv = tf.keras.layers.Conv1D(
                filters=config.postnet_conv_filters
                if i < config.n_conv_postnet - 1
                else config.num_mels,
                kernel_size=config.postnet_conv_kernel_sizes,
                padding="same",
                name="conv_._{}".format(i),
            )
            batch_norm = tf.keras.layers.BatchNormalization(
                axis=-1, name="batch_norm_._{}".format(i)
            )
            self.conv_batch_norm.append((conv, batch_norm))
        self.dropout = tf.keras.layers.Dropout(
            rate=config.postnet_dropout_rate, name="dropout"
        )
        # tanh on every conv except the last, which is left linear.
        self.activation = [tf.nn.tanh] * (config.n_conv_postnet - 1) + [tf.identity]

    def call(self, inputs, training=False):
        """Refine the mel spectrogram; padded frames are zeroed at the end."""
        outputs, mask = inputs
        extended_mask = tf.cast(tf.expand_dims(mask, axis=2), outputs.dtype)
        for i, (conv, bn) in enumerate(self.conv_batch_norm):
            outputs = conv(outputs)
            outputs = bn(outputs)
            outputs = self.activation[i](outputs)
            outputs = self.dropout(outputs, training=training)
        return outputs * extended_mask


class TFFastSpeechDurationPredictor(tf.keras.layers.Layer):
    """FastSpeech duration predictor module."""

    def __init__(self, config, **kwargs):
        """Init variables."""
        super().__init__(**kwargs)
        self.conv_layers = []
        # Each stage: conv -> layernorm -> relu6 -> dropout.
        for i in range(config.num_duration_conv_layers):
            self.conv_layers += [
                tf.keras.layers.Conv1D(
                    config.duration_predictor_filters,
                    config.duration_predictor_kernel_sizes,
                    padding="same",
                    name="conv_._{}".format(i),
                ),
                tf.keras.layers.LayerNormalization(
                    epsilon=config.layer_norm_eps, name="LayerNorm_._{}".format(i)
                ),
                tf.keras.layers.Activation(tf.nn.relu6),
                tf.keras.layers.Dropout(config.duration_predictor_dropout_probs),
            ]
        self.conv_layers_sequence = tf.keras.Sequential(self.conv_layers)
        self.output_layer = tf.keras.layers.Dense(1)

    def call(self, inputs, training=False):
        """Predict one non-negative duration value per input position."""
        encoder_hidden_states, attention_mask = inputs
        attention_mask = tf.cast(
            tf.expand_dims(attention_mask, 2), encoder_hidden_states.dtype
        )

        # mask encoder hidden states
        masked_encoder_hidden_states = encoder_hidden_states * attention_mask

        outputs = self.conv_layers_sequence(masked_encoder_hidden_states)
        outputs = self.output_layer(outputs)
        masked_outputs = outputs * attention_mask
        return tf.squeeze(tf.nn.relu6(masked_outputs), -1)  # make sure positive value.


class TFFastSpeechLengthRegulator(tf.keras.layers.Layer):
    """FastSpeech lengthregulator module."""

    def __init__(self, config, **kwargs):
        """Init variables."""
        # Popped before super() so Keras never sees the custom kwarg.
        self.enable_tflite_convertible = kwargs.pop("enable_tflite_convertible", False)
        super().__init__(**kwargs)
        self.config = config

    def call(self, inputs, training=False):
        """Expand encoder states by their durations.

        Args:
            inputs: pair of
                1. encoder_hidden_states, Tensor (float32) shape
                   [batch_size, length, hidden_size].
                2. durations_gt, Tensor (float32/int32) shape
                   [batch_size, length].
        """
        encoder_hidden_states, durations_gt = inputs
        outputs, encoder_masks = self._length_regulator(
            encoder_hidden_states, durations_gt
        )
        return outputs, encoder_masks

    def _length_regulator(self, encoder_hidden_states, durations_gt):
        """Length regulator logic."""
        sum_durations = tf.reduce_sum(durations_gt, axis=-1)  # [batch_size]
        max_durations = tf.reduce_max(sum_durations)

        input_shape = tf.shape(encoder_hidden_states)
        batch_size = input_shape[0]
        hidden_size = input_shape[-1]

        if self.enable_tflite_convertible:
            # There is only 1 batch in inference, so we don't have to use
            # `tf.While` op with 3-D output tensor.
            repeats = durations_gt[0]
            real_length = tf.reduce_sum(repeats)
            pad_size = max_durations - real_length
            # masks : [max_durations]
            masks = tf.sequence_mask([real_length], max_durations, dtype=tf.int32)
            repeat_encoder_hidden_states = tf.repeat(
                encoder_hidden_states[0], repeats=repeats, axis=0
            )
            repeat_encoder_hidden_states = tf.expand_dims(
                tf.pad(repeat_encoder_hidden_states, [[0, pad_size], [0, 0]]), 0
            )  # [1, max_durations, hidden_size]

            outputs = repeat_encoder_hidden_states
            encoder_masks = masks
        else:
            # Accumulators grown one utterance at a time inside the while loop.
            outputs = tf.zeros(
                shape=[0, max_durations, hidden_size], dtype=encoder_hidden_states.dtype
            )
            encoder_masks = tf.zeros(shape=[0, max_durations], dtype=tf.int32)

            def condition(
                i,
                batch_size,
                outputs,
                encoder_masks,
                encoder_hidden_states,
                durations_gt,
                max_durations,
            ):
                return tf.less(i, batch_size)

            def body(
                i,
                batch_size,
                outputs,
                encoder_masks,
                encoder_hidden_states,
                durations_gt,
                max_durations,
            ):
                repeats = durations_gt[i]
                real_length = tf.reduce_sum(repeats)
                pad_size = max_durations - real_length
                masks = tf.sequence_mask([real_length], max_durations, dtype=tf.int32)
                repeat_encoder_hidden_states = tf.repeat(
                    encoder_hidden_states[i], repeats=repeats, axis=0
                )
                repeat_encoder_hidden_states = tf.expand_dims(
                    tf.pad(repeat_encoder_hidden_states, [[0, pad_size], [0, 0]]), 0
                )  # [1, max_durations, hidden_size]
                outputs = tf.concat([outputs, repeat_encoder_hidden_states], axis=0)
                encoder_masks = tf.concat([encoder_masks, masks], axis=0)
                return [
                    i + 1,
                    batch_size,
                    outputs,
                    encoder_masks,
                    encoder_hidden_states,
                    durations_gt,
                    max_durations,
                ]

            # initialize iteration i.
            i = tf.constant(0, dtype=tf.int32)
            _, _, outputs, encoder_masks, _, _, _ = tf.while_loop(
                condition,
                body,
                [
                    i,
                    batch_size,
                    outputs,
                    encoder_masks,
                    encoder_hidden_states,
                    durations_gt,
                    max_durations,
                ],
                shape_invariants=[
                    i.get_shape(),
                    batch_size.get_shape(),
                    tf.TensorShape(
                        [
                            None,
                            None,
                            self.config.encoder_self_attention_params.hidden_size,
                        ]
                    ),
                    tf.TensorShape([None, None]),
                    encoder_hidden_states.get_shape(),
                    durations_gt.get_shape(),
                    max_durations.get_shape(),
                ],
            )

        return outputs, encoder_masks
class TFFastSpeech(BaseModel):
    """TF Fastspeech module."""

    def __init__(self, config, **kwargs):
        """Init layers for fastspeech."""
        # Popped before super() so Keras never sees the custom kwarg.
        self.enable_tflite_convertible = kwargs.pop("enable_tflite_convertible", False)
        super().__init__(**kwargs)
        self.embeddings = TFFastSpeechEmbeddings(config, name="embeddings")
        self.encoder = TFFastSpeechEncoder(
            config.encoder_self_attention_params, name="encoder"
        )
        # Duration head kept in float32 for numerical stability under
        # mixed precision.
        self.duration_predictor = TFFastSpeechDurationPredictor(
            config, dtype=tf.float32, name="duration_predictor"
        )
        self.length_regulator = TFFastSpeechLengthRegulator(
            config,
            enable_tflite_convertible=self.enable_tflite_convertible,
            name="length_regulator",
        )
        self.decoder = TFFastSpeechDecoder(
            config.decoder_self_attention_params,
            is_compatible_encoder=config.encoder_self_attention_params.hidden_size
            == config.decoder_self_attention_params.hidden_size,
            name="decoder",
        )
        self.mel_dense = tf.keras.layers.Dense(
            units=config.num_mels, dtype=tf.float32, name="mel_before"
        )
        self.postnet = TFTacotronPostnet(
            config=config, dtype=tf.float32, name="postnet"
        )

        self.setup_inference_fn()

    def _build(self):
        """Dummy input for building model."""
        # fake inputs
        input_ids = tf.convert_to_tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], tf.int32)
        speaker_ids = tf.convert_to_tensor([0], tf.int32)
        duration_gts = tf.convert_to_tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], tf.int32)
        self(input_ids, speaker_ids, duration_gts)

    def resize_positional_embeddings(self, new_size):
        """Swap in a larger positional table and rebuild variables."""
        self.embeddings.resize_positional_embeddings(new_size)
        self._build()

    def call(
        self, input_ids, speaker_ids, duration_gts, training=False, **kwargs,
    ):
        """Call logic."""
        # id 0 is padding, so non-zero positions form the attention mask.
        attention_mask = tf.math.not_equal(input_ids, 0)
        embedding_output = self.embeddings([input_ids, speaker_ids], training=training)
        encoder_output = self.encoder(
            [embedding_output, attention_mask], training=training
        )
        last_encoder_hidden_states = encoder_output[0]

        # duration predictor, here use last_encoder_hidden_states, u can use more hidden_states layers
        # rather than just use last_hidden_states of encoder for duration_predictor.
        duration_outputs = self.duration_predictor(
            [last_encoder_hidden_states, attention_mask]
        )  # [batch_size, length]

        # Expand encoder frames using ground-truth durations (teacher forcing).
        length_regulator_outputs, encoder_masks = self.length_regulator(
            [last_encoder_hidden_states, duration_gts], training=training
        )

        # create decoder positional embedding
        decoder_pos = tf.range(
            1, tf.shape(length_regulator_outputs)[1] + 1, dtype=tf.int32
        )
        masked_decoder_pos = tf.expand_dims(decoder_pos, 0) * encoder_masks

        decoder_output = self.decoder(
            [length_regulator_outputs, speaker_ids, encoder_masks, masked_decoder_pos],
            training=training,
        )
        last_decoder_hidden_states = decoder_output[0]

        # here u can use sum or concat more than 1 hidden states layers from decoder.
        mel_before = self.mel_dense(last_decoder_hidden_states)
        mel_after = (
            self.postnet([mel_before, encoder_masks], training=training) + mel_before
        )

        return (mel_before, mel_after, duration_outputs)

    def _inference(self, input_ids, speaker_ids, speed_ratios, **kwargs):
        """Call logic."""
        attention_mask = tf.math.not_equal(input_ids, 0)
        embedding_output = self.embeddings([input_ids, speaker_ids], training=False)
        encoder_output = self.encoder(
            [embedding_output, attention_mask], training=False
        )
        last_encoder_hidden_states = encoder_output[0]

        # duration predictor, here use last_encoder_hidden_states, u can use more hidden_states layers
        # rather than just use last_hidden_states of encoder for duration_predictor.
        duration_outputs = self.duration_predictor(
            [last_encoder_hidden_states, attention_mask]
        )  # [batch_size, length]
        # Predicted durations are in log domain; map back to frame counts.
        duration_outputs = tf.math.exp(duration_outputs) - 1.0

        if speed_ratios is None:
            speed_ratios = tf.convert_to_tensor(np.array([1.0]), dtype=tf.float32)

        speed_ratios = tf.expand_dims(speed_ratios, 1)

        duration_outputs = tf.cast(
            tf.math.round(duration_outputs * speed_ratios), tf.int32
        )

        length_regulator_outputs, encoder_masks = self.length_regulator(
            [last_encoder_hidden_states, duration_outputs], training=False
        )

        # create decoder positional embedding
        decoder_pos = tf.range(
            1, tf.shape(length_regulator_outputs)[1] + 1, dtype=tf.int32
        )
        masked_decoder_pos = tf.expand_dims(decoder_pos, 0) * encoder_masks

        decoder_output = self.decoder(
            [length_regulator_outputs, speaker_ids, encoder_masks, masked_decoder_pos],
            training=False,
        )
        last_decoder_hidden_states = decoder_output[0]

        # here u can use sum or concat more than 1 hidden states layers from decoder.
        mel_before = self.mel_dense(last_decoder_hidden_states)
        mel_after = (
            self.postnet([mel_before, encoder_masks], training=False) + mel_before
        )

        return (mel_before, mel_after, duration_outputs)

    def setup_inference_fn(self):
        """Wrap `_inference` as concrete `tf.function`s for serving/TFLite."""
        self.inference = tf.function(
            self._inference,
            experimental_relax_shapes=True,
            input_signature=[
                tf.TensorSpec(shape=[None, None], dtype=tf.int32, name="input_ids"),
                tf.TensorSpec(shape=[None,], dtype=tf.int32, name="speaker_ids"),
                tf.TensorSpec(shape=[None,], dtype=tf.float32, name="speed_ratios"),
            ],
        )

        # TFLite export requires a fixed batch dimension of 1.
        self.inference_tflite = tf.function(
            self._inference,
            experimental_relax_shapes=True,
            input_signature=[
                tf.TensorSpec(shape=[1, None], dtype=tf.int32, name="input_ids"),
                tf.TensorSpec(shape=[1,], dtype=tf.int32, name="speaker_ids"),
                tf.TensorSpec(shape=[1,], dtype=tf.float32, name="speed_ratios"),
            ],
        )