# Path: blob/master/tensorflow_tts/models/fastspeech2.py
# (source mirror metadata: 1558 views)
# -*- coding: utf-8 -*-
# Copyright 2020 The FastSpeech2 Authors and Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tensorflow Model modules for FastSpeech2."""

import tensorflow as tf

from tensorflow_tts.models.fastspeech import TFFastSpeech, get_initializer


class TFFastSpeechVariantPredictor(tf.keras.layers.Layer):
    """FastSpeech variant (duration/f0/energy) predictor module.

    A stack of Conv1D -> ReLU -> LayerNorm -> Dropout blocks followed by a
    Dense(1) projection. The same architecture is reused for the duration,
    f0 and energy predictors in TFFastSpeech2.
    """

    def __init__(self, config, **kwargs):
        """Init variables.

        Args:
            config: model config object providing
                `variant_prediction_num_conv_layers`,
                `variant_predictor_filter`, `variant_predictor_kernel_size`,
                `variant_predictor_dropout_rate`, `layer_norm_eps`,
                `n_speakers`, `initializer_range` and
                `encoder_self_attention_params.hidden_size`.
        """
        super().__init__(**kwargs)
        self.conv_layers = []
        for i in range(config.variant_prediction_num_conv_layers):
            self.conv_layers.append(
                tf.keras.layers.Conv1D(
                    config.variant_predictor_filter,
                    config.variant_predictor_kernel_size,
                    padding="same",
                    name="conv_._{}".format(i),
                )
            )
            self.conv_layers.append(tf.keras.layers.Activation(tf.nn.relu))
            self.conv_layers.append(
                tf.keras.layers.LayerNormalization(
                    epsilon=config.layer_norm_eps, name="LayerNorm_._{}".format(i)
                )
            )
            self.conv_layers.append(
                tf.keras.layers.Dropout(config.variant_predictor_dropout_rate)
            )
        self.conv_layers_sequence = tf.keras.Sequential(self.conv_layers)
        self.output_layer = tf.keras.layers.Dense(1)

        # Multi-speaker mode: add a learned per-speaker bias to the encoder
        # hidden states before prediction.
        if config.n_speakers > 1:
            self.decoder_speaker_embeddings = tf.keras.layers.Embedding(
                config.n_speakers,
                config.encoder_self_attention_params.hidden_size,
                embeddings_initializer=get_initializer(config.initializer_range),
                name="speaker_embeddings",
            )
            self.speaker_fc = tf.keras.layers.Dense(
                units=config.encoder_self_attention_params.hidden_size,
                name="speaker_fc",
            )

        self.config = config

    def call(self, inputs, training=False):
        """Call logic.

        Args:
            inputs: tuple of (encoder_hidden_states, speaker_ids,
                attention_mask). encoder_hidden_states is
                [batch_size, length, hidden_size]; attention_mask is
                [batch_size, length] with zeros at padded positions.
            training: Python bool, forwarded to the dropout layers.

        Returns:
            Tensor of shape [batch_size, length] with masked predictions.
        """
        encoder_hidden_states, speaker_ids, attention_mask = inputs
        attention_mask = tf.cast(
            tf.expand_dims(attention_mask, 2), encoder_hidden_states.dtype
        )

        if self.config.n_speakers > 1:
            speaker_embeddings = self.decoder_speaker_embeddings(speaker_ids)
            speaker_features = tf.math.softplus(self.speaker_fc(speaker_embeddings))
            # extended speaker embeddings: broadcast over the time axis
            extended_speaker_features = speaker_features[:, tf.newaxis, :]
            encoder_hidden_states += extended_speaker_features

        # mask encoder hidden states so padded positions contribute nothing
        masked_encoder_hidden_states = encoder_hidden_states * attention_mask

        # pass through the conv stack; BUGFIX: forward `training` explicitly
        # so the Dropout layers inside the Sequential honor the flag
        # (implicit training-context propagation is TF-version dependent).
        outputs = self.conv_layers_sequence(
            masked_encoder_hidden_states, training=training
        )
        outputs = self.output_layer(outputs)
        masked_outputs = outputs * attention_mask

        # squeeze the trailing singleton channel -> [batch_size, length]
        outputs = tf.squeeze(masked_outputs, -1)
        return outputs


class TFFastSpeech2(TFFastSpeech):
    """TF Fastspeech2 module: FastSpeech plus f0/energy variance adaptors."""

    def __init__(self, config, **kwargs):
        """Init layers for fastspeech2.

        Args:
            config: model config object (see TFFastSpeechVariantPredictor
                and the TFFastSpeech base class for required fields).
        """
        super().__init__(config, **kwargs)
        # Predictors are kept in float32 regardless of any mixed-precision
        # policy (dtype=tf.float32).
        self.f0_predictor = TFFastSpeechVariantPredictor(
            config, dtype=tf.float32, name="f0_predictor"
        )
        self.energy_predictor = TFFastSpeechVariantPredictor(
            config, dtype=tf.float32, name="energy_predictor",
        )
        self.duration_predictor = TFFastSpeechVariantPredictor(
            config, dtype=tf.float32, name="duration_predictor"
        )

        # define f0_embeddings and energy_embeddings: project the scalar
        # f0/energy tracks back into the encoder hidden size.
        self.f0_embeddings = tf.keras.layers.Conv1D(
            filters=config.encoder_self_attention_params.hidden_size,
            kernel_size=9,
            padding="same",
            name="f0_embeddings",
        )
        self.f0_dropout = tf.keras.layers.Dropout(0.5)
        self.energy_embeddings = tf.keras.layers.Conv1D(
            filters=config.encoder_self_attention_params.hidden_size,
            kernel_size=9,
            padding="same",
            name="energy_embeddings",
        )
        self.energy_dropout = tf.keras.layers.Dropout(0.5)

    def _build(self):
        """Dummy input for building model."""
        # fake inputs
        input_ids = tf.convert_to_tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], tf.int32)
        speaker_ids = tf.convert_to_tensor([0], tf.int32)
        duration_gts = tf.convert_to_tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], tf.int32)
        f0_gts = tf.convert_to_tensor(
            [[10, 10, 10, 10, 10, 10, 10, 10, 10, 10]], tf.float32
        )
        energy_gts = tf.convert_to_tensor(
            [[10, 10, 10, 10, 10, 10, 10, 10, 10, 10]], tf.float32
        )
        self(
            input_ids=input_ids,
            speaker_ids=speaker_ids,
            duration_gts=duration_gts,
            f0_gts=f0_gts,
            energy_gts=energy_gts,
        )

    def call(
        self,
        input_ids,
        speaker_ids,
        duration_gts,
        f0_gts,
        energy_gts,
        training=False,
        **kwargs,
    ):
        """Call logic (training forward pass with ground-truth variances).

        Args:
            input_ids: [batch_size, length] int32 token ids; id 0 is padding.
            speaker_ids: [batch_size] int32 speaker ids.
            duration_gts: [batch_size, length] int32 ground-truth durations.
            f0_gts: [batch_size, length] float32 ground-truth f0 track.
            energy_gts: [batch_size, length] float32 ground-truth energy track.
            training: Python bool training flag.

        Returns:
            Tuple (mels_before, mels_after, duration_outputs, f0_outputs,
            energy_outputs).
        """
        attention_mask = tf.math.not_equal(input_ids, 0)
        embedding_output = self.embeddings([input_ids, speaker_ids], training=training)
        encoder_output = self.encoder(
            [embedding_output, attention_mask], training=training
        )
        last_encoder_hidden_states = encoder_output[0]

        # energy predictor, here use last_encoder_hidden_states, u can use more hidden_states layers
        # rather than just use last_hidden_states of encoder for energy_predictor.
        # BUGFIX: forward `training` like the f0/energy predictor calls below,
        # so the duration predictor's dropout is active during training.
        duration_outputs = self.duration_predictor(
            [last_encoder_hidden_states, speaker_ids, attention_mask],
            training=training,
        )  # [batch_size, length]

        f0_outputs = self.f0_predictor(
            [last_encoder_hidden_states, speaker_ids, attention_mask], training=training
        )
        energy_outputs = self.energy_predictor(
            [last_encoder_hidden_states, speaker_ids, attention_mask], training=training
        )

        f0_embedding = self.f0_embeddings(
            tf.expand_dims(f0_gts, 2)
        )  # [batch_size, mel_length, feature]
        energy_embedding = self.energy_embeddings(
            tf.expand_dims(energy_gts, 2)
        )  # [batch_size, mel_length, feature]

        # apply dropout both training/inference (intentional: training=True)
        f0_embedding = self.f0_dropout(f0_embedding, training=True)
        energy_embedding = self.energy_dropout(energy_embedding, training=True)

        # sum features
        last_encoder_hidden_states += f0_embedding + energy_embedding

        length_regulator_outputs, encoder_masks = self.length_regulator(
            [last_encoder_hidden_states, duration_gts], training=training
        )

        # create decoder positional embedding (1-based; 0 is masked out)
        decoder_pos = tf.range(
            1, tf.shape(length_regulator_outputs)[1] + 1, dtype=tf.int32
        )
        masked_decoder_pos = tf.expand_dims(decoder_pos, 0) * encoder_masks

        decoder_output = self.decoder(
            [length_regulator_outputs, speaker_ids, encoder_masks, masked_decoder_pos],
            training=training,
        )
        last_decoder_hidden_states = decoder_output[0]

        # here u can use sum or concat more than 1 hidden states layers from decoder.
        mels_before = self.mel_dense(last_decoder_hidden_states)
        mels_after = (
            self.postnet([mels_before, encoder_masks], training=training) + mels_before
        )

        outputs = (
            mels_before,
            mels_after,
            duration_outputs,
            f0_outputs,
            energy_outputs,
        )
        return outputs

    def _inference(
        self, input_ids, speaker_ids, speed_ratios, f0_ratios, energy_ratios, **kwargs,
    ):
        """Inference forward pass using predicted durations/f0/energy.

        Args:
            input_ids: [batch_size, length] int32 token ids; id 0 is padding.
            speaker_ids: [batch_size] int32 speaker ids.
            speed_ratios: [batch_size] float32 duration scaling factors.
            f0_ratios: [batch_size] float32 f0 scaling factors.
            energy_ratios: [batch_size] float32 energy scaling factors.

        Returns:
            Tuple (mel_before, mel_after, duration_outputs, f0_outputs,
            energy_outputs).
        """
        attention_mask = tf.math.not_equal(input_ids, 0)
        embedding_output = self.embeddings([input_ids, speaker_ids], training=False)
        encoder_output = self.encoder(
            [embedding_output, attention_mask], training=False
        )
        last_encoder_hidden_states = encoder_output[0]

        # expand ratios
        speed_ratios = tf.expand_dims(speed_ratios, 1)  # [B, 1]
        f0_ratios = tf.expand_dims(f0_ratios, 1)  # [B, 1]
        energy_ratios = tf.expand_dims(energy_ratios, 1)  # [B, 1]

        # energy predictor, here use last_encoder_hidden_states, u can use more hidden_states layers
        # rather than just use last_hidden_states of encoder for energy_predictor.
        duration_outputs = self.duration_predictor(
            [last_encoder_hidden_states, speaker_ids, attention_mask]
        )  # [batch_size, length]
        # Predictor outputs log-durations; exp then shift so zero maps to zero.
        duration_outputs = tf.nn.relu(tf.math.exp(duration_outputs) - 1.0)
        duration_outputs = tf.cast(
            tf.math.round(duration_outputs * speed_ratios), tf.int32
        )

        f0_outputs = self.f0_predictor(
            [last_encoder_hidden_states, speaker_ids, attention_mask], training=False
        )
        f0_outputs *= f0_ratios

        energy_outputs = self.energy_predictor(
            [last_encoder_hidden_states, speaker_ids, attention_mask], training=False
        )
        energy_outputs *= energy_ratios

        # apply dropout both training/inference (intentional: training=True)
        f0_embedding = self.f0_dropout(
            self.f0_embeddings(tf.expand_dims(f0_outputs, 2)), training=True
        )
        energy_embedding = self.energy_dropout(
            self.energy_embeddings(tf.expand_dims(energy_outputs, 2)), training=True
        )

        # sum features
        last_encoder_hidden_states += f0_embedding + energy_embedding

        length_regulator_outputs, encoder_masks = self.length_regulator(
            [last_encoder_hidden_states, duration_outputs], training=False
        )

        # create decoder positional embedding (1-based; 0 is masked out)
        decoder_pos = tf.range(
            1, tf.shape(length_regulator_outputs)[1] + 1, dtype=tf.int32
        )
        masked_decoder_pos = tf.expand_dims(decoder_pos, 0) * encoder_masks

        decoder_output = self.decoder(
            [length_regulator_outputs, speaker_ids, encoder_masks, masked_decoder_pos],
            training=False,
        )
        last_decoder_hidden_states = decoder_output[0]

        # here u can use sum or concat more than 1 hidden states layers from decoder.
        mel_before = self.mel_dense(last_decoder_hidden_states)
        mel_after = (
            self.postnet([mel_before, encoder_masks], training=False) + mel_before
        )

        outputs = (mel_before, mel_after, duration_outputs, f0_outputs, energy_outputs)
        return outputs

    def setup_inference_fn(self):
        """Wrap `_inference` as concrete tf.functions for serving and TFLite."""
        self.inference = tf.function(
            self._inference,
            experimental_relax_shapes=True,
            input_signature=[
                tf.TensorSpec(shape=[None, None], dtype=tf.int32, name="input_ids"),
                tf.TensorSpec(shape=[None,], dtype=tf.int32, name="speaker_ids"),
                tf.TensorSpec(shape=[None,], dtype=tf.float32, name="speed_ratios"),
                tf.TensorSpec(shape=[None,], dtype=tf.float32, name="f0_ratios"),
                tf.TensorSpec(shape=[None,], dtype=tf.float32, name="energy_ratios"),
            ],
        )

        # TFLite variant pins batch size to 1.
        self.inference_tflite = tf.function(
            self._inference,
            experimental_relax_shapes=True,
            input_signature=[
                tf.TensorSpec(shape=[1, None], dtype=tf.int32, name="input_ids"),
                tf.TensorSpec(shape=[1,], dtype=tf.int32, name="speaker_ids"),
                tf.TensorSpec(shape=[1,], dtype=tf.float32, name="speed_ratios"),
                tf.TensorSpec(shape=[1,], dtype=tf.float32, name="f0_ratios"),
                tf.TensorSpec(shape=[1,], dtype=tf.float32, name="energy_ratios"),
            ],
        )