Path: blob/master/tensorflow_tts/models/hifigan.py
1558 views
# -*- coding: utf-8 -*-
# Copyright 2020 The Hifigan Authors and TensorflowTTS Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Hifi Modules."""

import numpy as np
import tensorflow as tf

from tensorflow_tts.models.melgan import TFReflectionPad1d
from tensorflow_tts.models.melgan import TFConvTranspose1d

from tensorflow_tts.utils import GroupConv1D
from tensorflow_tts.utils import WeightNormalization

from tensorflow_tts.models import BaseModel
from tensorflow_tts.models import TFMelGANGenerator


class TFHifiResBlock(tf.keras.layers.Layer):
    """Tensorflow Hifigan resblock 1 module."""

    def __init__(
        self,
        kernel_size,
        filters,
        dilation_rate,
        use_bias,
        nonlinear_activation,
        nonlinear_activation_params,
        is_weight_norm,
        initializer_seed,
        **kwargs
    ):
        """Initialize TFHifiResBlock module.

        Args:
            kernel_size (int): Kernel size.
            filters (int): Number of filters.
            dilation_rate (list): List of dilation rates, one dilated/plain conv
                pair is created per entry.
            use_bias (bool): Whether to add bias parameter in convolution layers.
            nonlinear_activation (str): Activation function module name.
            nonlinear_activation_params (dict): Hyperparameters for activation function.
            is_weight_norm (bool): Whether to use weight norm or not.
            initializer_seed (int): Seed for weight initializer.
                NOTE(review): currently unused — the Conv1D layers below use the
                Keras default initializer; confirm whether seeding was intended.
        """
        super().__init__(**kwargs)
        self.blocks_1 = []
        self.blocks_2 = []

        for i in range(len(dilation_rate)):
            # Dilated branch: reflection pad keeps output length equal to input
            # length for the given dilation.
            self.blocks_1.append(
                [
                    TFReflectionPad1d((kernel_size - 1) // 2 * dilation_rate[i]),
                    tf.keras.layers.Conv1D(
                        filters=filters,
                        kernel_size=kernel_size,
                        dilation_rate=dilation_rate[i],
                        use_bias=use_bias,
                    ),
                ]
            )
            # Companion branch with dilation 1.
            self.blocks_2.append(
                [
                    TFReflectionPad1d((kernel_size - 1) // 2 * 1),
                    tf.keras.layers.Conv1D(
                        filters=filters,
                        kernel_size=kernel_size,
                        dilation_rate=1,
                        use_bias=use_bias,
                    ),
                ]
            )

        self.activation = getattr(tf.keras.layers, nonlinear_activation)(
            **nonlinear_activation_params
        )

        # apply weightnorm
        if is_weight_norm:
            self._apply_weightnorm(self.blocks_1)
            self._apply_weightnorm(self.blocks_2)

    def call(self, x, training=False):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, T, C).
        Returns:
            Tensor: Output tensor (B, T, C).
        """
        # Each iteration: activation -> dilated conv pair -> activation ->
        # plain conv pair, with a residual connection around the whole step.
        for c1, c2 in zip(self.blocks_1, self.blocks_2):
            xt = self.activation(x)
            for c in c1:
                xt = c(xt)
            xt = self.activation(xt)
            for c in c2:
                xt = c(xt)
            x = xt + x
        return x

    def _apply_weightnorm(self, list_layers):
        """Try apply weightnorm for all layer in list_layers.

        BUGFIX(review): ``blocks_1``/``blocks_2`` are lists of ``[pad, conv]``
        sub-lists; the previous implementation called ``.name`` on those
        sub-lists, raised ``AttributeError``, and silently skipped every conv
        via the broad ``except`` — weight norm was never applied. We now
        recurse into nested lists so the Conv1D layers are actually wrapped.
        """
        for i in range(len(list_layers)):
            if isinstance(list_layers[i], list):
                self._apply_weightnorm(list_layers[i])
                continue
            try:
                layer_name = list_layers[i].name.lower()
                if "conv1d" in layer_name or "dense" in layer_name:
                    list_layers[i] = WeightNormalization(list_layers[i])
            except Exception:
                # Best-effort: layers without a usable name are left untouched.
                pass


class TFMultiHifiResBlock(tf.keras.layers.Layer):
    """Tensorflow Multi Hifigan resblock 1 module."""

    def __init__(self, list_resblock, **kwargs):
        """Initialize TFMultiHifiResBlock.

        Args:
            list_resblock (list): List of TFHifiResBlock layers run in parallel.
        """
        super().__init__(**kwargs)
        self.list_resblock = list_resblock

    def call(self, x, training=False):
        """Average the outputs of all parallel resblocks.

        Args:
            x (Tensor): Input tensor (B, T, C).
        Returns:
            Tensor: Mean of resblock outputs (B, T, C).
        """
        xs = None
        for resblock in self.list_resblock:
            if xs is None:
                xs = resblock(x, training=training)
            else:
                xs += resblock(x, training=training)
        return xs / len(self.list_resblock)


class TFHifiGANGenerator(BaseModel):
    """Tensorflow HifiGAN generator: mel-spectrogram -> waveform."""

    def __init__(self, config, **kwargs):
        """Initialize generator from a HifiGAN config object.

        Args:
            config: Config with filters, kernel sizes, upsample scales,
                stack settings, activation settings and weight-norm flag.
        """
        super().__init__(**kwargs)
        # check hyper parameter is valid or not
        assert (
            config.stacks
            == len(config.stack_kernel_size)
            == len(config.stack_dilation_rate)
        )

        # add initial layer
        layers = []
        layers += [
            TFReflectionPad1d(
                (config.kernel_size - 1) // 2,
                padding_type=config.padding_type,
                name="first_reflect_padding",
            ),
            tf.keras.layers.Conv1D(
                filters=config.filters,
                kernel_size=config.kernel_size,
                use_bias=config.use_bias,
            ),
        ]

        for i, upsample_scale in enumerate(config.upsample_scales):
            # add upsampling layer; channel count halves at each stage.
            layers += [
                getattr(tf.keras.layers, config.nonlinear_activation)(
                    **config.nonlinear_activation_params
                ),
                TFConvTranspose1d(
                    filters=config.filters // (2 ** (i + 1)),
                    kernel_size=upsample_scale * 2,
                    strides=upsample_scale,
                    padding="same",
                    is_weight_norm=config.is_weight_norm,
                    initializer_seed=config.initializer_seed,
                    name="conv_transpose_._{}".format(i),
                ),
            ]

            # add residual stack layer (multi-receptive-field fusion).
            layers += [
                TFMultiHifiResBlock(
                    list_resblock=[
                        TFHifiResBlock(
                            kernel_size=config.stack_kernel_size[j],
                            filters=config.filters // (2 ** (i + 1)),
                            dilation_rate=config.stack_dilation_rate[j],
                            use_bias=config.use_bias,
                            nonlinear_activation=config.nonlinear_activation,
                            nonlinear_activation_params=config.nonlinear_activation_params,
                            is_weight_norm=config.is_weight_norm,
                            initializer_seed=config.initializer_seed,
                            name="hifigan_resblock_._{}".format(j),
                        )
                        for j in range(config.stacks)
                    ],
                    name="multi_hifigan_resblock_._{}".format(i),
                )
            ]
        # add final layer; dtype pinned to float32 so mixed-precision
        # training still emits full-precision audio.
        layers += [
            getattr(tf.keras.layers, config.nonlinear_activation)(
                **config.nonlinear_activation_params
            ),
            TFReflectionPad1d(
                (config.kernel_size - 1) // 2,
                padding_type=config.padding_type,
                name="last_reflect_padding",
            ),
            tf.keras.layers.Conv1D(
                filters=config.out_channels,
                kernel_size=config.kernel_size,
                use_bias=config.use_bias,
                dtype=tf.float32,
            ),
        ]
        if config.use_final_nolinear_activation:
            layers += [tf.keras.layers.Activation("tanh", dtype=tf.float32)]

        if config.is_weight_norm is True:
            self._apply_weightnorm(layers)

        self.hifigan = tf.keras.models.Sequential(layers)

    def call(self, mels, **kwargs):
        """Calculate forward propagation.

        Args:
            mels (Tensor): Input tensor (B, T, channels)
        Returns:
            Tensor: Output tensor (B, T ** prod(upsample_scales), out_channels)
        """
        return self.inference(mels)

    @tf.function(
        input_signature=[
            tf.TensorSpec(shape=[None, None, 80], dtype=tf.float32, name="mels")
        ]
    )
    def inference(self, mels):
        """Run the generator on a batch of 80-bin mel spectrograms."""
        return self.hifigan(mels)

    @tf.function(
        input_signature=[
            tf.TensorSpec(shape=[1, None, 80], dtype=tf.float32, name="mels")
        ]
    )
    def inference_tflite(self, mels):
        """Single-batch variant of inference with a TFLite-friendly signature."""
        return self.hifigan(mels)

    def _apply_weightnorm(self, list_layers):
        """Try apply weightnorm for all layer in list_layers."""
        for i in range(len(list_layers)):
            try:
                layer_name = list_layers[i].name.lower()
                if "conv1d" in layer_name or "dense" in layer_name:
                    list_layers[i] = WeightNormalization(list_layers[i])
            except Exception:
                # Best-effort: layers without a usable name are left untouched.
                pass

    def _build(self):
        """Build model by passing fake input."""
        fake_mels = tf.random.uniform(shape=[1, 100, 80], dtype=tf.float32)
        self(fake_mels)


class TFHifiGANPeriodDiscriminator(tf.keras.layers.Layer):
    """Tensorflow Hifigan period discriminator module."""

    def __init__(
        self,
        period,
        out_channels=1,
        n_layers=5,
        kernel_size=5,
        strides=3,
        filters=8,
        filter_scales=4,
        max_filters=1024,
        nonlinear_activation="LeakyReLU",
        nonlinear_activation_params={"alpha": 0.2},
        initializer_seed=42,
        is_weight_norm=False,
        **kwargs
    ):
        """Initialize TFHifiGANPeriodDiscriminator.

        Args:
            period (int): Period to fold the waveform into a 2-D grid.
            out_channels (int): Number of output channels.
            n_layers (int): Number of strided Conv2D layers.
            kernel_size (int): Kernel size along the time axis.
            strides (int): Stride along the time axis.
            filters (int): Base number of filters.
            filter_scales (int): Filter growth factor per layer.
            max_filters (int): Upper bound on filters per layer.
            nonlinear_activation (str): Activation function module name.
            nonlinear_activation_params (dict): Hyperparameters for activation.
            initializer_seed (int): Seed for weight initializer.
                NOTE(review): currently unused by the Conv2D layers below.
            is_weight_norm (bool): Whether to wrap convs in WeightNormalization.
        """
        super().__init__(**kwargs)
        self.period = period
        self.out_filters = out_channels
        self.convs = []

        for i in range(n_layers):
            self.convs.append(
                tf.keras.layers.Conv2D(
                    filters=min(filters * (filter_scales ** (i + 1)), max_filters),
                    kernel_size=(kernel_size, 1),
                    strides=(strides, 1),
                    padding="same",
                )
            )
        self.conv_post = tf.keras.layers.Conv2D(
            filters=out_channels, kernel_size=(3, 1), padding="same",
        )
        self.activation = getattr(tf.keras.layers, nonlinear_activation)(
            **nonlinear_activation_params
        )

        if is_weight_norm:
            self._apply_weightnorm(self.convs)
            self.conv_post = WeightNormalization(self.conv_post)

    def call(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input noise signal (B, T, 1).
        Returns:
            List: List of output tensors.
        """
        shape = tf.shape(x)
        n_pad = tf.convert_to_tensor(0, dtype=tf.int32)
        if shape[1] % self.period != 0:
            # Reflect-pad T so it folds evenly into rows of length `period`.
            n_pad = self.period - (shape[1] % self.period)
            x = tf.pad(x, [[0, 0], [0, n_pad], [0, 0]], "REFLECT")
        x = tf.reshape(
            x, [shape[0], (shape[1] + n_pad) // self.period, self.period, x.shape[2]]
        )
        for layer in self.convs:
            x = layer(x)
            x = self.activation(x)
        x = self.conv_post(x)
        x = tf.reshape(x, [shape[0], -1, self.out_filters])
        return [x]

    def _apply_weightnorm(self, list_layers):
        """Try apply weightnorm for all layer in list_layers.

        BUGFIX(review): this discriminator uses Conv2D layers, whose Keras
        names contain "conv2d"; the previous check for "conv1d" never matched,
        so ``is_weight_norm=True`` silently applied no weight norm to
        ``self.convs``. Matching on "conv" covers both Conv1D and Conv2D.
        """
        for i in range(len(list_layers)):
            try:
                layer_name = list_layers[i].name.lower()
                if "conv" in layer_name or "dense" in layer_name:
                    list_layers[i] = WeightNormalization(list_layers[i])
            except Exception:
                # Best-effort: layers without a usable name are left untouched.
                pass


class TFHifiGANMultiPeriodDiscriminator(BaseModel):
    """Tensorflow Hifigan Multi Period discriminator module."""

    def __init__(self, config, **kwargs):
        """Initialize one period discriminator per configured period scale.

        Args:
            config: Config with period_scales and per-discriminator settings.
        """
        super().__init__(**kwargs)
        self.discriminator = []

        # add discriminator
        for i in range(len(config.period_scales)):
            self.discriminator += [
                TFHifiGANPeriodDiscriminator(
                    config.period_scales[i],
                    out_channels=config.out_channels,
                    n_layers=config.n_layers,
                    kernel_size=config.kernel_size,
                    strides=config.strides,
                    filters=config.filters,
                    filter_scales=config.filter_scales,
                    max_filters=config.max_filters,
                    nonlinear_activation=config.nonlinear_activation,
                    nonlinear_activation_params=config.nonlinear_activation_params,
                    initializer_seed=config.initializer_seed,
                    is_weight_norm=config.is_weight_norm,
                    name="hifigan_period_discriminator_._{}".format(i),
                )
            ]

    def call(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input noise signal (B, T, 1).
        Returns:
            List: list of each discriminator outputs
        """
        outs = []
        for f in self.discriminator:
            outs += [f(x)]
        return outs