Path: blob/master/tensorflow_tts/models/mb_melgan.py
1558 views
# -*- coding: utf-8 -*-
# Copyright 2020 The Multi-band MelGAN Authors , Minh Nguyen (@dathudeptrai) and Tomoki Hayashi (@kan-bayashi)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
#
# Compatible with https://github.com/kan-bayashi/ParallelWaveGAN/blob/master/parallel_wavegan/layers/pqmf.py.
"""Multi-band MelGAN Modules."""

import numpy as np
import tensorflow as tf

# `scipy.signal.kaiser` was deprecated in SciPy 1.1 and removed in SciPy 1.13;
# the window functions live in `scipy.signal.windows`.
from scipy.signal.windows import kaiser

from tensorflow_tts.models import BaseModel
from tensorflow_tts.models import TFMelGANGenerator


def design_prototype_filter(taps=62, cutoff_ratio=0.15, beta=9.0):
    """Design prototype filter for PQMF.

    This method is based on `A Kaiser window approach for the design of prototype
    filters of cosine modulated filterbanks`_.

    Args:
        taps (int): The number of filter taps. Must be even.
        cutoff_ratio (float): Cut-off frequency ratio, strictly between 0 and 1.
        beta (float): Beta coefficient for kaiser window.

    Returns:
        ndarray: Impulse response of prototype filter (taps + 1,).

    Raises:
        AssertionError: If ``taps`` is odd or ``cutoff_ratio`` is outside (0, 1).

    .. _`A Kaiser window approach for the design of prototype filters of cosine modulated filterbanks`:
        https://ieeexplore.ieee.org/abstract/document/681427

    """
    # check the arguments are valid
    assert taps % 2 == 0, "The number of taps must be an even number."
    assert 0.0 < cutoff_ratio < 1.0, "Cutoff ratio must be > 0.0 and < 1.0."

    # make initial filter: ideal low-pass response sin(w_c * n) / (pi * n),
    # with the time index n centred on the middle tap
    omega_c = np.pi * cutoff_ratio
    n = np.arange(taps + 1) - 0.5 * taps
    with np.errstate(invalid="ignore"):
        # the centre tap evaluates 0 / 0 (nan); it is patched right below
        h_i = np.sin(omega_c * n) / (np.pi * n)
    # fix nan due to indeterminate form: the limit of sin(w_c * n) / (pi * n)
    # as n -> 0 is w_c / pi, i.e. cutoff_ratio
    h_i[taps // 2] = cutoff_ratio

    # apply kaiser window
    w = kaiser(taps + 1, beta)
    h = h_i * w

    return h


class TFPQMF(tf.keras.layers.Layer):
    """PQMF module."""

    def __init__(self, config, **kwargs):
        """Initialize PQMF module.

        Args:
            config (class): MultiBandMelGANGeneratorConfig; this layer reads its
                ``subbands``, ``taps``, ``cutoff_ratio`` and ``beta`` attributes.

        """
        super().__init__(**kwargs)
        subbands = config.subbands
        taps = config.taps
        cutoff_ratio = config.cutoff_ratio
        beta = config.beta

        # define filter coefficient: cosine-modulate the prototype filter into
        # one filter per subband
        h_proto = design_prototype_filter(taps, cutoff_ratio, beta)
        h_analysis = np.zeros((subbands, len(h_proto)))
        h_synthesis = np.zeros((subbands, len(h_proto)))
        n = np.arange(taps + 1) - (taps / 2)
        for k in range(subbands):
            phase = (2 * k + 1) * (np.pi / (2 * subbands)) * n
            # analysis and synthesis filters differ only in the sign of this
            # alternating +-pi/4 phase offset
            offset = (-1) ** k * np.pi / 4
            h_analysis[k] = 2 * h_proto * np.cos(phase + offset)
            h_synthesis[k] = 2 * h_proto * np.cos(phase - offset)

        # tf.nn.conv1d filters are laid out as
        # [filter_width, in_channels, out_channels].
        # analysis: [subbands, taps + 1] -> [taps + 1, 1, subbands]
        analysis_filter = np.expand_dims(h_analysis, 1)
        analysis_filter = np.transpose(analysis_filter, (2, 1, 0))

        # synthesis: [subbands, taps + 1] -> [taps + 1, subbands, 1]
        synthesis_filter = np.expand_dims(h_synthesis, 0)
        synthesis_filter = np.transpose(synthesis_filter, (2, 1, 0))

        # filter for downsampling & upsampling, shape [subbands, subbands, subbands].
        # Only the first tap is non-zero and maps channel k to channel k, so with
        # stride == subbands it keeps (analysis) or places (synthesis) one sample
        # per group of `subbands` time steps.
        updown_filter = np.zeros((subbands, subbands, subbands), dtype=np.float32)
        updown_filter[0] = np.eye(subbands, dtype=np.float32)

        self.subbands = subbands
        self.taps = taps
        self.analysis_filter = analysis_filter.astype(np.float32)
        self.synthesis_filter = synthesis_filter.astype(np.float32)
        self.updown_filter = updown_filter

    @tf.function(
        experimental_relax_shapes=True,
        input_signature=[tf.TensorSpec(shape=[None, None, 1], dtype=tf.float32)],
    )
    def analysis(self, x):
        """Analysis with PQMF.

        Args:
            x (Tensor): Input tensor (B, T, 1).

        Returns:
            Tensor: Output tensor (B, T // subbands, subbands).

        """
        # pad both sides by taps // 2 so the VALID convolution keeps length T
        x = tf.pad(x, [[0, 0], [self.taps // 2, self.taps // 2], [0, 0]])
        x = tf.nn.conv1d(x, self.analysis_filter, stride=1, padding="VALID")
        x = tf.nn.conv1d(x, self.updown_filter, stride=self.subbands, padding="VALID")
        return x

    @tf.function(
        experimental_relax_shapes=True,
        input_signature=[tf.TensorSpec(shape=[None, None, None], dtype=tf.float32)],
    )
    def synthesis(self, x):
        """Synthesis with PQMF.

        Args:
            x (Tensor): Input tensor (B, T // subbands, subbands).

        Returns:
            Tensor: Output tensor (B, T, 1).

        """
        # upsample by zero-insertion; the factor `subbands` compensates for the
        # energy lost by the inserted zeros
        x = tf.nn.conv1d_transpose(
            x,
            self.updown_filter * self.subbands,
            strides=self.subbands,
            output_shape=(
                tf.shape(x)[0],
                tf.shape(x)[1] * self.subbands,
                self.subbands,
            ),
        )
        x = tf.pad(x, [[0, 0], [self.taps // 2, self.taps // 2], [0, 0]])
        return tf.nn.conv1d(x, self.synthesis_filter, stride=1, padding="VALID")


class TFMBMelGANGenerator(TFMelGANGenerator):
    """Tensorflow MBMelGAN generator module."""

    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        self.pqmf = TFPQMF(config=config, dtype=tf.float32, name="pqmf")

    def call(self, mels, **kwargs):
        """Calculate forward propagation.

        Args:
            mels (Tensor): Input mel-spectrogram tensor (B, T, 80).

        Returns:
            Tensor: Output waveform tensor (B, T', 1), where T' is the time
                length produced by ``self.melgan`` multiplied by ``subbands``.

        """
        return self.inference(mels)

    # NOTE: the input signature fixes the mel dimension at 80.
    @tf.function(
        input_signature=[
            tf.TensorSpec(shape=[None, None, 80], dtype=tf.float32, name="mels")
        ]
    )
    def inference(self, mels):
        """Generate sub-band signals with MelGAN and merge them with PQMF."""
        mb_audios = self.melgan(mels)
        return self.pqmf.synthesis(mb_audios)

    # Same as `inference`, but with the batch size fixed at 1 for TFLite export.
    @tf.function(
        input_signature=[
            tf.TensorSpec(shape=[1, None, 80], dtype=tf.float32, name="mels")
        ]
    )
    def inference_tflite(self, mels):
        """Batch-size-1 variant of `inference` for TFLite conversion."""
        mb_audios = self.melgan(mels)
        return self.pqmf.synthesis(mb_audios)