# Path: blob/master/tensorflow_tts/models/melgan.py
# 1558 views
# -*- coding: utf-8 -*-
# Copyright 2020 The MelGAN Authors and Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MelGAN Modules."""

import numpy as np
import tensorflow as tf

from tensorflow_tts.models import BaseModel
from tensorflow_tts.utils import GroupConv1D, WeightNormalization


def get_initializer(initializer_seed=42):
    """Creates a `tf.initializers.glorot_normal` with the given seed.

    Args:
        initializer_seed (int): Initializer seed.
    Returns:
        GlorotNormal initializer with seed = `initializer_seed`.
    """
    return tf.keras.initializers.GlorotNormal(seed=initializer_seed)


class TFReflectionPad1d(tf.keras.layers.Layer):
    """Tensorflow ReflectionPad1d module."""

    def __init__(self, padding_size, padding_type="REFLECT", **kwargs):
        """Initialize TFReflectionPad1d module.

        Args:
            padding_size (int): Padding size added to each side of the time axis.
            padding_type (str): "CONSTANT", "REFLECT", or "SYMMETRIC".
                Default is "REFLECT".
        """
        super().__init__(**kwargs)
        self.padding_size = padding_size
        self.padding_type = padding_type

    def call(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, T, C).
        Returns:
            Tensor: Padded tensor (B, T + 2 * padding_size, C).
        """
        return tf.pad(
            x,
            [[0, 0], [self.padding_size, self.padding_size], [0, 0]],
            self.padding_type,
        )


class TFConvTranspose1d(tf.keras.layers.Layer):
    """Tensorflow ConvTranspose1d module."""

    def __init__(
        self,
        filters,
        kernel_size,
        strides,
        padding,
        is_weight_norm,
        initializer_seed,
        **kwargs
    ):
        """Initialize TFConvTranspose1d module.

        Args:
            filters (int): Number of filters.
            kernel_size (int): Kernel size.
            strides (int): Stride width.
            padding (str): Padding type ("same" or "valid").
                NOTE(review): this argument is currently ignored — the inner
                Conv2DTranspose below is hard-coded to "same" padding. Kept
                in the signature for interface compatibility.
            is_weight_norm (bool): Whether to wrap the convolution with
                WeightNormalization.
            initializer_seed (int): Seed for the kernel initializer.
        """
        super().__init__(**kwargs)
        # A 1D transposed convolution is emulated with a 2D transposed
        # convolution over a dummy spatial axis of size 1 (see `call`).
        self.conv1d_transpose = tf.keras.layers.Conv2DTranspose(
            filters=filters,
            kernel_size=(kernel_size, 1),
            strides=(strides, 1),
            padding="same",
            kernel_initializer=get_initializer(initializer_seed),
        )
        if is_weight_norm:
            self.conv1d_transpose = WeightNormalization(self.conv1d_transpose)

    def call(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, T, C).
        Returns:
            Tensor: Output tensor (B, T', C').
        """
        x = tf.expand_dims(x, 2)  # (B, T, 1, C)
        x = self.conv1d_transpose(x)
        x = tf.squeeze(x, 2)  # (B, T', C')
        return x


class TFResidualStack(tf.keras.layers.Layer):
    """Tensorflow ResidualStack module."""

    def __init__(
        self,
        kernel_size,
        filters,
        dilation_rate,
        use_bias,
        nonlinear_activation,
        nonlinear_activation_params,
        is_weight_norm,
        initializer_seed,
        **kwargs
    ):
        """Initialize TFResidualStack module.

        Args:
            kernel_size (int): Kernel size.
            filters (int): Number of filters.
            dilation_rate (int): Dilation rate.
            use_bias (bool): Whether to add bias parameter in convolution layers.
            nonlinear_activation (str): Activation function module name
                (looked up on `tf.keras.layers`, e.g. "LeakyReLU").
            nonlinear_activation_params (dict): Hyperparameters for the
                activation function.
            is_weight_norm (bool): Whether to apply weight normalization.
            initializer_seed (int): Seed for the kernel initializers.
        """
        super().__init__(**kwargs)
        # Main branch: activation -> reflect pad -> dilated conv
        #              -> activation -> 1x1 conv.
        self.blocks = [
            getattr(tf.keras.layers, nonlinear_activation)(
                **nonlinear_activation_params
            ),
            TFReflectionPad1d((kernel_size - 1) // 2 * dilation_rate),
            tf.keras.layers.Conv1D(
                filters=filters,
                kernel_size=kernel_size,
                dilation_rate=dilation_rate,
                use_bias=use_bias,
                kernel_initializer=get_initializer(initializer_seed),
            ),
            getattr(tf.keras.layers, nonlinear_activation)(
                **nonlinear_activation_params
            ),
            tf.keras.layers.Conv1D(
                filters=filters,
                kernel_size=1,
                use_bias=use_bias,
                kernel_initializer=get_initializer(initializer_seed),
            ),
        ]
        # 1x1 projection on the residual (shortcut) branch.
        self.shortcut = tf.keras.layers.Conv1D(
            filters=filters,
            kernel_size=1,
            use_bias=use_bias,
            kernel_initializer=get_initializer(initializer_seed),
            name="shortcut",
        )

        # apply weightnorm
        if is_weight_norm:
            self._apply_weightnorm(self.blocks)
            self.shortcut = WeightNormalization(self.shortcut)

    def call(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, T, C).
        Returns:
            Tensor: Output tensor (B, T, C).
        """
        _x = tf.identity(x)
        for layer in self.blocks:
            _x = layer(_x)
        shortcut = self.shortcut(x)
        return shortcut + _x

    def _apply_weightnorm(self, list_layers):
        """Try to apply weightnorm to all conv/dense layers in list_layers."""
        for i in range(len(list_layers)):
            try:
                layer_name = list_layers[i].name.lower()
                if "conv1d" in layer_name or "dense" in layer_name:
                    list_layers[i] = WeightNormalization(list_layers[i])
            except Exception:
                # Best-effort wrapping: layers that do not expose a usable
                # `.name` are deliberately left unwrapped.
                pass


class TFMelGANGenerator(BaseModel):
    """Tensorflow MelGAN generator module."""

    def __init__(self, config, **kwargs):
        """Initialize TFMelGANGenerator module.

        Args:
            config: Config object of MelGAN generator.
        """
        super().__init__(**kwargs)

        # check hyper parameter is valid or not
        assert config.filters >= np.prod(config.upsample_scales)
        assert config.filters % (2 ** len(config.upsample_scales)) == 0

        # add initial layer
        layers = []
        layers += [
            TFReflectionPad1d(
                (config.kernel_size - 1) // 2,
                padding_type=config.padding_type,
                name="first_reflect_padding",
            ),
            tf.keras.layers.Conv1D(
                filters=config.filters,
                kernel_size=config.kernel_size,
                use_bias=config.use_bias,
                kernel_initializer=get_initializer(config.initializer_seed),
            ),
        ]

        for i, upsample_scale in enumerate(config.upsample_scales):
            # add upsampling layer; channel count halves at each stage.
            layers += [
                getattr(tf.keras.layers, config.nonlinear_activation)(
                    **config.nonlinear_activation_params
                ),
                TFConvTranspose1d(
                    filters=config.filters // (2 ** (i + 1)),
                    kernel_size=upsample_scale * 2,
                    strides=upsample_scale,
                    padding="same",
                    is_weight_norm=config.is_weight_norm,
                    initializer_seed=config.initializer_seed,
                    name="conv_transpose_._{}".format(i),
                ),
            ]

            # add residual stack layer
            for j in range(config.stacks):
                layers += [
                    TFResidualStack(
                        kernel_size=config.stack_kernel_size,
                        filters=config.filters // (2 ** (i + 1)),
                        dilation_rate=config.stack_kernel_size ** j,
                        use_bias=config.use_bias,
                        nonlinear_activation=config.nonlinear_activation,
                        nonlinear_activation_params=config.nonlinear_activation_params,
                        is_weight_norm=config.is_weight_norm,
                        initializer_seed=config.initializer_seed,
                        name="residual_stack_._{}._._{}".format(i, j),
                    )
                ]
        # add final layer; the last Conv1D (and optional tanh) are forced to
        # float32 so outputs stay full precision under mixed-precision policies.
        layers += [
            getattr(tf.keras.layers, config.nonlinear_activation)(
                **config.nonlinear_activation_params
            ),
            TFReflectionPad1d(
                (config.kernel_size - 1) // 2,
                padding_type=config.padding_type,
                name="last_reflect_padding",
            ),
            tf.keras.layers.Conv1D(
                filters=config.out_channels,
                kernel_size=config.kernel_size,
                use_bias=config.use_bias,
                kernel_initializer=get_initializer(config.initializer_seed),
                dtype=tf.float32,
            ),
        ]
        if config.use_final_nolinear_activation:
            layers += [tf.keras.layers.Activation("tanh", dtype=tf.float32)]

        if config.is_weight_norm is True:
            self._apply_weightnorm(layers)

        self.melgan = tf.keras.models.Sequential(layers)

    def call(self, mels, **kwargs):
        """Calculate forward propagation.

        Args:
            mels (Tensor): Input mel spectrogram (B, T, channels).
        Returns:
            Tensor: Output tensor (B, T * prod(upsample_scales), out_channels).
        """
        return self.inference(mels)

    @tf.function(
        input_signature=[
            tf.TensorSpec(shape=[None, None, 80], dtype=tf.float32, name="mels")
        ]
    )
    def inference(self, mels):
        """Run inference with a batched, variable-length signature."""
        return self.melgan(mels)

    @tf.function(
        input_signature=[
            tf.TensorSpec(shape=[1, None, 80], dtype=tf.float32, name="mels")
        ]
    )
    def inference_tflite(self, mels):
        """Run inference with a fixed batch size of 1 (TFLite conversion)."""
        return self.melgan(mels)

    def _apply_weightnorm(self, list_layers):
        """Try to apply weightnorm to all conv/dense layers in list_layers."""
        for i in range(len(list_layers)):
            try:
                layer_name = list_layers[i].name.lower()
                if "conv1d" in layer_name or "dense" in layer_name:
                    list_layers[i] = WeightNormalization(list_layers[i])
            except Exception:
                # Best-effort wrapping; incompatible layers stay unwrapped.
                pass

    def _build(self):
        """Build model by passing fake input."""
        fake_mels = tf.random.uniform(shape=[1, 100, 80], dtype=tf.float32)
        self(fake_mels)


class TFMelGANDiscriminator(tf.keras.layers.Layer):
    """Tensorflow MelGAN discriminator module."""

    def __init__(
        self,
        out_channels=1,
        kernel_sizes=(5, 3),
        filters=16,
        max_downsample_filters=1024,
        use_bias=True,
        downsample_scales=(4, 4, 4, 4),
        nonlinear_activation="LeakyReLU",
        nonlinear_activation_params=None,
        padding_type="REFLECT",
        is_weight_norm=True,
        initializer_seed=0.02,
        **kwargs
    ):
        """Initilize MelGAN discriminator module.

        Args:
            out_channels (int): Number of output channels.
            kernel_sizes (list): List of two kernel sizes. The prod will be
                used for the first conv layer, and the first and the second
                kernel sizes will be used for the last two layers.
                For example if kernel_sizes = [5, 3], the first layer kernel
                size will be 5 * 3 = 15, the last two layers' kernel size
                will be 5 and 3, respectively.
            filters (int): Initial number of filters for conv layer.
            max_downsample_filters (int): Maximum number of filters for
                downsampling layers.
            use_bias (bool): Whether to add bias parameter in convolution layers.
            downsample_scales (list): List of downsampling scales.
            nonlinear_activation (str): Activation function module name.
            nonlinear_activation_params (dict): Hyperparameters for activation
                function. Defaults to {"alpha": 0.2} when None.
            padding_type (str): Padding type (support only "REFLECT",
                "CONSTANT", "SYMMETRIC").
            is_weight_norm (bool): Whether to apply weight normalization.
            initializer_seed: Seed for the kernel initializers.
                NOTE(review): the default is a float (0.02), which looks like
                a copied weight-init scale rather than a seed — kept for
                compatibility; configs normally supply an int.
        """
        super().__init__(**kwargs)
        # Avoid shared mutable default argument for the params dict.
        if nonlinear_activation_params is None:
            nonlinear_activation_params = {"alpha": 0.2}
        discriminator = []

        # check kernel_size is valid
        assert len(kernel_sizes) == 2
        assert kernel_sizes[0] % 2 == 1
        assert kernel_sizes[1] % 2 == 1

        # add first layer
        discriminator = [
            TFReflectionPad1d(
                (np.prod(kernel_sizes) - 1) // 2, padding_type=padding_type
            ),
            tf.keras.layers.Conv1D(
                filters=filters,
                kernel_size=int(np.prod(kernel_sizes)),
                use_bias=use_bias,
                kernel_initializer=get_initializer(initializer_seed),
            ),
            getattr(tf.keras.layers, nonlinear_activation)(
                **nonlinear_activation_params
            ),
        ]

        # add downsample layers
        in_chs = filters
        with tf.keras.utils.CustomObjectScope({"GroupConv1D": GroupConv1D}):
            for downsample_scale in downsample_scales:
                out_chs = min(in_chs * downsample_scale, max_downsample_filters)
                discriminator += [
                    GroupConv1D(
                        filters=out_chs,
                        kernel_size=downsample_scale * 10 + 1,
                        strides=downsample_scale,
                        padding="same",
                        use_bias=use_bias,
                        groups=in_chs // 4,
                        kernel_initializer=get_initializer(initializer_seed),
                    )
                ]
                discriminator += [
                    getattr(tf.keras.layers, nonlinear_activation)(
                        **nonlinear_activation_params
                    )
                ]
                in_chs = out_chs

        # add final layers
        out_chs = min(in_chs * 2, max_downsample_filters)
        discriminator += [
            tf.keras.layers.Conv1D(
                filters=out_chs,
                kernel_size=kernel_sizes[0],
                padding="same",
                use_bias=use_bias,
                kernel_initializer=get_initializer(initializer_seed),
            )
        ]
        discriminator += [
            getattr(tf.keras.layers, nonlinear_activation)(
                **nonlinear_activation_params
            )
        ]
        discriminator += [
            tf.keras.layers.Conv1D(
                filters=out_channels,
                kernel_size=kernel_sizes[1],
                padding="same",
                use_bias=use_bias,
                kernel_initializer=get_initializer(initializer_seed),
            )
        ]

        if is_weight_norm is True:
            self._apply_weightnorm(discriminator)

        # NOTE(review): attribute name keeps the historical typo
        # ("disciminator"); renaming a tracked Keras attribute would change
        # the checkpoint object graph and break loading existing weights.
        self.disciminator = discriminator

    def call(self, x, **kwargs):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input noise signal (B, T, 1).
        Returns:
            List: List of output tensors of each layer (feature maps).
        """
        outs = []
        for f in self.disciminator:
            x = f(x)
            outs += [x]
        return outs

    def _apply_weightnorm(self, list_layers):
        """Try to apply weightnorm to all conv/dense layers in list_layers."""
        for i in range(len(list_layers)):
            try:
                layer_name = list_layers[i].name.lower()
                if "conv1d" in layer_name or "dense" in layer_name:
                    list_layers[i] = WeightNormalization(list_layers[i])
            except Exception:
                # Best-effort wrapping; incompatible layers stay unwrapped.
                pass


class TFMelGANMultiScaleDiscriminator(BaseModel):
    """MelGAN multi-scale discriminator module."""

    def __init__(self, config, **kwargs):
        """Initilize MelGAN multi-scale discriminator module.

        Args:
            config: Config object for melgan discriminator.
        """
        super().__init__(**kwargs)
        self.discriminator = []

        # add discriminator
        for i in range(config.scales):
            self.discriminator += [
                TFMelGANDiscriminator(
                    out_channels=config.out_channels,
                    kernel_sizes=config.kernel_sizes,
                    filters=config.filters,
                    max_downsample_filters=config.max_downsample_filters,
                    use_bias=config.use_bias,
                    downsample_scales=config.downsample_scales,
                    nonlinear_activation=config.nonlinear_activation,
                    nonlinear_activation_params=config.nonlinear_activation_params,
                    padding_type=config.padding_type,
                    is_weight_norm=config.is_weight_norm,
                    initializer_seed=config.initializer_seed,
                    name="melgan_discriminator_scale_._{}".format(i),
                )
            ]
        # Pooling applied between scales so each discriminator sees a
        # progressively downsampled signal.
        self.pooling = getattr(tf.keras.layers, config.downsample_pooling)(
            **config.downsample_pooling_params
        )

    def call(self, x, **kwargs):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input noise signal (B, T, 1).
        Returns:
            List: List of list of each discriminator outputs, which consists
                of each layer output tensors.
        """
        outs = []
        for f in self.discriminator:
            outs += [f(x)]
            x = self.pooling(x)
        return outs