# Source file: tensorflow_tts/models/parallel_wavegan.py
# -*- coding: utf-8 -*-
# Copyright 2020 The TensorFlowTTS Team and Tomoki Hayashi (@kan-bayashi)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Parallel-wavegan Modules. Based on pytorch implementation (https://github.com/kan-bayashi/ParallelWaveGAN)"""

import tensorflow as tf

from tensorflow_tts.models import BaseModel


def get_initializer(initializer_seed=42):
    """Creates a `tf.initializers.he_normal` with the given seed.

    Args:
        initializer_seed: int, initializer seed.

    Returns:
        HeNormal initializer with seed = `initializer_seed`.
    """
    return tf.keras.initializers.he_normal(seed=initializer_seed)


class TFConv1d1x1(tf.keras.layers.Conv1D):
    """1x1 Conv1d with customized initialization (seeded He-normal kernel)."""

    def __init__(self, filters, use_bias, padding, initializer_seed, **kwargs):
        """Initialize 1x1 Conv1d module.

        Args:
            filters (int): Number of output channels.
            use_bias (bool): Whether to add a bias term.
            padding (str): Keras padding mode ("same" or "causal").
            initializer_seed (int): Seed for the He-normal kernel initializer.
        """
        super().__init__(
            filters=filters,
            kernel_size=1,
            strides=1,
            padding=padding,
            dilation_rate=1,
            use_bias=use_bias,
            kernel_initializer=get_initializer(initializer_seed),
            **kwargs,
        )


class TFConv1d(tf.keras.layers.Conv1D):
    """Conv1d with customized initialization (seeded He-normal kernel)."""

    def __init__(self, *args, **kwargs):
        """Initialize Conv1d module.

        Accepts all `tf.keras.layers.Conv1D` arguments plus an optional
        `initializer_seed` keyword (default 42) used for the kernel initializer.
        """
        # Pop the custom kwarg before forwarding the rest to Conv1D.
        initializer_seed = kwargs.pop("initializer_seed", 42)
        super().__init__(
            *args, **kwargs, kernel_initializer=get_initializer(initializer_seed)
        )


class TFResidualBlock(tf.keras.layers.Layer):
    """Residual block module in WaveNet (gated activation + residual/skip 1x1 convs)."""

    def __init__(
        self,
        kernel_size=3,
        residual_channels=64,
        gate_channels=128,
        skip_channels=64,
        aux_channels=80,
        dropout_rate=0.0,
        dilation_rate=1,
        use_bias=True,
        use_causal_conv=False,
        initializer_seed=42,
        **kwargs,
    ):
        """Initialize ResidualBlock module.

        Args:
            kernel_size (int): Kernel size of dilation convolution layer.
            residual_channels (int): Number of channels for residual connection.
            gate_channels (int): Number of channels of the gated (dilated) conv;
                its output is split in half for the tanh/sigmoid gate.
            skip_channels (int): Number of channels for skip connection.
            aux_channels (int): Local conditioning channels i.e. auxiliary input
                dimension; set <= 0 to disable local conditioning.
            dropout_rate (float): Dropout probability.
            dilation_rate (int): Dilation factor.
            use_bias (bool): Whether to add bias parameter in convolution layers.
            use_causal_conv (bool): Whether to use causal or non-causal convolution.
            initializer_seed (int32): initializer seed.
        """
        super().__init__(**kwargs)
        self.dropout_rate = dropout_rate
        # no future time stamps available when causal
        self.use_causal_conv = use_causal_conv

        # dilation conv
        self.conv = TFConv1d(
            filters=gate_channels,
            kernel_size=kernel_size,
            padding="same" if self.use_causal_conv is False else "causal",
            strides=1,
            dilation_rate=dilation_rate,
            use_bias=use_bias,
            initializer_seed=initializer_seed,
        )

        # local conditioning: projects auxiliary features to gate_channels so they
        # can be added to the gate halves in call().
        if aux_channels > 0:
            self.conv1x1_aux = TFConv1d1x1(
                gate_channels,
                use_bias=False,
                padding="same",
                initializer_seed=initializer_seed,
                name="conv1x1_aux",
            )
        else:
            self.conv1x1_aux = None

        # conv output is split into two groups
        # NOTE(review): gate_out_channels is computed but never used in this block.
        gate_out_channels = gate_channels // 2
        self.conv1x1_out = TFConv1d1x1(
            residual_channels,
            use_bias=use_bias,
            padding="same",
            initializer_seed=initializer_seed,
            name="conv1x1_out",
        )
        self.conv1x1_skip = TFConv1d1x1(
            skip_channels,
            use_bias=use_bias,
            padding="same",
            initializer_seed=initializer_seed,
            name="conv1x1_skip",
        )

        self.dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)

    def call(self, x, c, training=False):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, T, residual_channels) — channels-last,
                as implied by the `axis=-1` splits below.
            c (Tensor): Local conditioning auxiliary tensor (B, T, aux_channels),
                or None to skip conditioning.

        Returns:
            Tensor: Output tensor for residual connection (B, T, residual_channels).
            Tensor: Output tensor for skip connection (B, T, skip_channels).
        """
        residual = x
        x = self.dropout(x, training=training)
        x = self.conv(x)

        # split into two parts for gated activation
        xa, xb = tf.split(x, 2, axis=-1)

        # local conditioning
        if c is not None:
            assert self.conv1x1_aux is not None
            c = self.conv1x1_aux(c)
            ca, cb = tf.split(c, 2, axis=-1)
            xa, xb = xa + ca, xb + cb

        # WaveNet gated activation unit: tanh(filter) * sigmoid(gate)
        x = tf.nn.tanh(xa) * tf.nn.sigmoid(xb)

        # for skip connection
        s = self.conv1x1_skip(x)

        # for residual connection; sqrt(0.5) keeps the variance of the sum stable
        x = self.conv1x1_out(x)
        x = (x + residual) * tf.math.sqrt(0.5)

        return x, s


class TFStretch1d(tf.keras.layers.Layer):
    """Stretch2d module: upsamples a (B, T, C, 1) map by image resizing."""

    def __init__(self, x_scale, y_scale, method="nearest", **kwargs):
        """Initialize Stretch2d module.

        Args:
            x_scale (int): X scaling factor (Time axis in spectrogram).
            y_scale (int): Y scaling factor (Frequency axis in spectrogram).
            method (str): Interpolation method (any mode accepted by
                `tf.image.resize`, e.g. "nearest").
        """
        super().__init__(**kwargs)
        self.x_scale = x_scale
        self.y_scale = y_scale
        self.method = method

    def call(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, T, C, 1).

        Returns:
            Tensor: Interpolated tensor (B, T * x_scale, C * y_scale, 1).
        """
        # Compute the target size dynamically so variable-length inputs work.
        x_shape = tf.shape(x)
        new_size = (x_shape[1] * self.x_scale, x_shape[2] * self.y_scale)
        x = tf.image.resize(x, method=self.method, size=new_size)
        return x


class TFUpsampleNetWork(tf.keras.layers.Layer):
    """Upsampling network module: alternating stretch + smoothing Conv2D layers."""

    def __init__(
        self,
        output_channels,
        upsample_scales,
        nonlinear_activation=None,
        # NOTE(review): mutable default argument; safe only because it is never
        # mutated here, but a None sentinel would be more robust.
        nonlinear_activation_params={},
        interpolate_mode="nearest",
        freq_axis_kernel_size=1,
        use_causal_conv=False,
        **kwargs,
    ):
        """Initialize upsampling network module.

        Args:
            output_channels (int): output feature channels.
                NOTE(review): accepted but not used in this constructor.
            upsample_scales (list): List of upsampling scales.
            nonlinear_activation (str): Activation function name (a
                `tf.keras.layers` class name), or None for no activation.
            nonlinear_activation_params (dict): Arguments for specified activation function.
            interpolate_mode (str): Interpolation mode.
            freq_axis_kernel_size (int): Kernel size in the direction of frequency axis.
            use_causal_conv (bool): Whether to use causal conv padding.
        """
        super().__init__(**kwargs)
        self.use_causal_conv = use_causal_conv
        self.up_layers = []

        for scale in upsample_scales:
            # interpolation layer
            stretch = TFStretch1d(
                scale, 1, interpolate_mode, name="stretch_._{}".format(scale)
            )  # ->> outputs: [B, T * scale, C * 1, 1]
            self.up_layers += [stretch]

            # conv layer smoothing the interpolated map
            assert (
                freq_axis_kernel_size - 1
            ) % 2 == 0, "Not support even number freq axis kernel size."
            kernel_size = scale * 2 + 1
            # NOTE(review): tf.keras.layers.Conv2D does not support
            # padding="causal" (only Conv1D does), so this likely raises when
            # use_causal_conv=True — confirm before enabling causal mode.
            conv = tf.keras.layers.Conv2D(
                filters=1,
                kernel_size=(kernel_size, freq_axis_kernel_size),
                padding="causal" if self.use_causal_conv is True else "same",
                use_bias=False,
            )  # ->> outputs: [B, T * scale, C * 1, 1]
            self.up_layers += [conv]

            # nonlinear
            if nonlinear_activation is not None:
                nonlinear = getattr(tf.keras.layers, nonlinear_activation)(
                    **nonlinear_activation_params
                )
                self.up_layers += [nonlinear]

    def call(self, c):
        """Calculate forward propagation.

        Args:
            c : Input tensor (B, T, C).

        Returns:
            Tensor: Upsampled tensor (B, T', C), where T' = T * prod(upsample_scales).
        """
        c = tf.expand_dims(c, -1)  # [B, T, C, 1]
        for f in self.up_layers:
            c = f(c)
        return tf.squeeze(c, -1)  # [B, T', C]


class TFConvInUpsampleNetWork(tf.keras.layers.Layer):
    """Convolution + upsampling network module."""

    def __init__(
        self,
        upsample_scales,
        nonlinear_activation=None,
        # NOTE(review): mutable default argument; never mutated here.
        nonlinear_activation_params={},
        interpolate_mode="nearest",
        freq_axis_kernel_size=1,
        aux_channels=80,
        aux_context_window=0,
        use_causal_conv=False,
        initializer_seed=42,
        **kwargs,
    ):
        """Initialize convolution + upsampling network module.

        Args:
            upsample_scales (list): List of upsampling scales.
            nonlinear_activation (str): Activation function name.
            nonlinear_activation_params (dict): Arguments for specified activation function.
            interpolate_mode (str): Interpolation mode.
            freq_axis_kernel_size (int): Kernel size in the direction of frequency axis.
            aux_channels (int): Number of channels of pre-convolutional layer.
            aux_context_window (int): Context window size of the pre-convolutional layer.
            use_causal_conv (bool): Whether to use causal structure.
            initializer_seed (int): Seed for the pre-conv kernel initializer.
        """
        super().__init__(**kwargs)
        self.aux_context_window = aux_context_window
        self.use_causal_conv = use_causal_conv and aux_context_window > 0

        # To capture wide-context information in conditional features:
        # causal mode only looks back, so the kernel is roughly half as wide.
        kernel_size = (
            aux_context_window + 1 if use_causal_conv else 2 * aux_context_window + 1
        )

        self.conv_in = TFConv1d(
            filters=aux_channels,
            kernel_size=kernel_size,
            padding="same",
            use_bias=False,
            initializer_seed=initializer_seed,
            name="conv_in",
        )
        self.upsample = TFUpsampleNetWork(
            output_channels=aux_channels,
            upsample_scales=upsample_scales,
            nonlinear_activation=nonlinear_activation,
            nonlinear_activation_params=nonlinear_activation_params,
            interpolate_mode=interpolate_mode,
            freq_axis_kernel_size=freq_axis_kernel_size,
            use_causal_conv=use_causal_conv,
            name="upsample_network",
        )

    def call(self, c):
        """Calculate forward propagation.

        Args:
            c : Input tensor (B, T', C).

        Returns:
            Tensor: Upsampled tensor (B, T, C),
                where T = (T' - aux_context_window * 2) * prod(upsample_scales).

        Note:
            The length of inputs considers the context window size.
        """
        c_ = self.conv_in(c)
        return self.upsample(c_)


class TFParallelWaveGANGenerator(BaseModel):
    """Parallel WaveGAN Generator module (WaveNet-style vocoder driven by noise)."""

    def __init__(self, config, **kwargs):
        """Initialize generator from a config object (see config attributes used below)."""
        super().__init__(**kwargs)
        self.out_channels = config.out_channels
        self.aux_channels = config.aux_channels
        self.n_layers = config.n_layers
        self.stacks = config.stacks
        self.kernel_size = config.kernel_size
        self.upsample_params = config.upsample_params

        # check the number of layers and stacks
        assert self.n_layers % self.stacks == 0
        n_layers_per_stack = self.n_layers // self.stacks

        # define first convolution (projects 1-channel noise to residual_channels)
        self.first_conv = TFConv1d1x1(
            filters=config.residual_channels,
            use_bias=True,
            padding="same",
            initializer_seed=config.initializer_seed,
            name="first_convolution",
        )

        # define conv + upsampling network
        if config.upsample_conditional_features:
            # NOTE(review): these updates mutate config.upsample_params in place
            # (self.upsample_params aliases it).
            self.upsample_params.update({"use_causal_conv": config.use_causal_conv})
            self.upsample_params.update(
                {
                    "aux_channels": config.aux_channels,
                    "aux_context_window": config.aux_context_window,
                }
            )
            self.upsample_net = TFConvInUpsampleNetWork(**self.upsample_params)
        else:
            self.upsample_net = None

        # define residual blocks; dilation doubles within each stack: 1, 2, 4, ...
        self.conv_layers = []
        for layer in range(self.n_layers):
            dilation_rate = 2 ** (layer % n_layers_per_stack)
            conv = TFResidualBlock(
                kernel_size=config.kernel_size,
                residual_channels=config.residual_channels,
                gate_channels=config.gate_channels,
                skip_channels=config.skip_channels,
                aux_channels=config.aux_channels,
                dilation_rate=dilation_rate,
                dropout_rate=config.dropout_rate,
                use_bias=config.use_bias,
                use_causal_conv=config.use_causal_conv,
                initializer_seed=config.initializer_seed,
                name="residual_block_._{}".format(layer),
            )
            self.conv_layers += [conv]

        # define output layers applied to the summed skip connections
        self.last_conv_layers = [
            tf.keras.layers.ReLU(),
            TFConv1d1x1(
                filters=config.skip_channels,
                use_bias=config.use_bias,
                padding="same",
                initializer_seed=config.initializer_seed,
            ),
            tf.keras.layers.ReLU(),
            TFConv1d1x1(
                filters=config.out_channels,
                use_bias=True,
                padding="same",
                initializer_seed=config.initializer_seed,
            ),
            tf.keras.layers.Activation("tanh"),
        ]

    def _build(self):
        """Run a dummy forward pass to build all layer variables."""
        mels = tf.random.uniform(shape=[2, 20, 80], dtype=tf.float32)
        self(mels, training=tf.cast(True, tf.bool))

    def call(self, mels, training=False, **kwargs):
        """Calculate forward propagation.

        Args:
            mels (Tensor): Local conditioning auxiliary features (B, T', C).

        Returns:
            Tensor: Output tensor (B, T, 1)
        """
        # perform upsampling
        # NOTE(review): if mels is None or upsample_net is None, `c` is never
        # bound and the next line raises NameError — confirm callers always
        # provide mels with upsample_conditional_features enabled.
        if mels is not None and self.upsample_net is not None:
            c = self.upsample_net(mels)

        # random noise x
        # encode to hidden representation
        x = tf.expand_dims(tf.random.normal(shape=tf.shape(c)[0:2]), axis=2)
        x = self.first_conv(x)
        skips = 0
        for f in self.conv_layers:
            x, h = f(x, c, training=training)
            skips += h
        # normalize the skip sum by the number of contributing blocks
        skips *= tf.math.sqrt(1.0 / len(self.conv_layers))

        # apply final layers
        x = skips
        for f in self.last_conv_layers:
            x = f(x)

        return x

    @tf.function(
        experimental_relax_shapes=True,
        input_signature=[
            tf.TensorSpec(shape=[None, None, 80], dtype=tf.float32, name="mels"),
        ],
    )
    def inference(self, mels):
        """Calculate forward propagation (graph-compiled inference path).

        Args:
            mels (Tensor): Local conditioning auxiliary features (B, T', C).

        Returns:
            Tensor: Output tensor (B, T, 1)
        """
        # perform upsampling
        # NOTE(review): same unbound-`c` hazard as in call() if the condition
        # is false.
        if mels is not None and self.upsample_net is not None:
            c = self.upsample_net(mels)

        # encode to hidden representation
        x = tf.expand_dims(tf.random.normal(shape=tf.shape(c)[0:2]), axis=2)
        x = self.first_conv(x)
        skips = 0
        for f in self.conv_layers:
            x, h = f(x, c, training=False)
            skips += h
        skips *= tf.math.sqrt(1.0 / len(self.conv_layers))

        # apply final layers
        x = skips
        for f in self.last_conv_layers:
            x = f(x)

        return x


class TFParallelWaveGANDiscriminator(BaseModel):
    """Parallel WaveGAN Discriminator module (stack of dilated 1-D convolutions)."""

    def __init__(self, config, **kwargs):
        """Initialize discriminator from a config object."""
        super().__init__(**kwargs)
        assert (config.kernel_size - 1) % 2 == 0, "Not support even number kernel size."
        assert config.dilation_factor > 0, "Dilation factor must be > 0."
        self.conv_layers = []
        for i in range(config.n_layers - 1):
            if i == 0:
                dilation_rate = 1
            else:
                # dilation_factor == 1 falls back to a linearly growing dilation
                dilation_rate = (
                    i if config.dilation_factor == 1 else config.dilation_factor ** i
                )
            self.conv_layers += [
                TFConv1d(
                    filters=config.conv_channels,
                    kernel_size=config.kernel_size,
                    padding="same",
                    dilation_rate=dilation_rate,
                    use_bias=config.use_bias,
                    initializer_seed=config.initializer_seed,
                )
            ]
            self.conv_layers += [
                getattr(tf.keras.layers, config.nonlinear_activation)(
                    **config.nonlinear_activation_params
                )
            ]
        # final projection to out_channels (no activation before optional sigmoid)
        self.conv_layers += [
            TFConv1d(
                filters=config.out_channels,
                kernel_size=config.kernel_size,
                padding="same",
                use_bias=config.use_bias,
                initializer_seed=config.initializer_seed,
            )
        ]

        if config.apply_sigmoid_at_last:
            self.conv_layers += [
                tf.keras.layers.Activation("sigmoid"),
            ]

    def _build(self):
        """Run a dummy forward pass to build all layer variables."""
        x = tf.random.uniform(shape=[2, 16000, 1])
        self(x)

    def call(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input noise signal (B, T, 1).

        Returns:
            Tensor: Output tensor (B, T, 1)
        """
        for f in self.conv_layers:
            x = f(x)
        return x