Path: blob/master/tensorflow_tts/utils/griffin_lim.py
1558 views
# -*- coding: utf-8 -*-
# Copyright 2020 Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Griffin-Lim phase reconstruction algorithm from mel spectrogram."""

import os

import librosa
import numpy as np
import soundfile as sf
import tensorflow as tf
from sklearn.preprocessing import StandardScaler


def griffin_lim_lb(
    mel_spec, stats_path, dataset_config, n_iter=32, output_dir=None, wav_name="lb"
):
    """Generate wave from mel spectrogram with Griffin-Lim algorithm using Librosa.

    The input is de-normalized with the statistics stored in `stats_path`,
    raised to the power of 10, projected back to a linear-frequency magnitude
    spectrogram with the pseudo-inverse of the mel filter bank and finally
    passed to `librosa.griffinlim`.

    Args:
        mel_spec (ndarray): array representing the mel spectrogram. It is
            transposed before being multiplied by the inverted mel basis, so
            it is expected as [frames, num_mels].
        stats_path (str): path to the `stats.npy` file containing norm statistics
            (unpacked as mean and scale).
        dataset_config (Dict): dataset configuration parameters. The keys read
            are "sampling_rate", "fft_size", "num_mels", "fmin", "fmax",
            "hop_size" and "win_length".
        n_iter (int): number of iterations for GL.
        output_dir (str): output directory where audio file will be saved.
            Nothing is written when it is empty or None.
        wav_name (str): name of the output file (without the ".wav" extension).
    Returns:
        gl_lb (ndarray): generated wave.
    """
    scaler = StandardScaler()
    scaler.mean_, scaler.scale_ = np.load(stats_path)

    # Undo the standardization, then undo the log10 compression.
    mel_spec = np.power(10.0, scaler.inverse_transform(mel_spec)).T
    # `sr` is passed by keyword: librosa >= 0.10 rejects positional arguments
    # here, while older releases accept the keyword form as well.
    mel_basis = librosa.filters.mel(
        sr=dataset_config["sampling_rate"],
        n_fft=dataset_config["fft_size"],
        n_mels=dataset_config["num_mels"],
        fmin=dataset_config["fmin"],
        fmax=dataset_config["fmax"],
    )
    # Floor at 1e-10 because the pseudo-inverse can yield negative magnitudes.
    mel_to_linear = np.maximum(1e-10, np.dot(np.linalg.pinv(mel_basis), mel_spec))
    gl_lb = librosa.griffinlim(
        mel_to_linear,
        n_iter=n_iter,
        hop_length=dataset_config["hop_size"],
        # A falsy `win_length` (e.g. None) falls back to the FFT size.
        win_length=dataset_config["win_length"] or dataset_config["fft_size"],
    )
    if output_dir:
        output_path = os.path.join(output_dir, f"{wav_name}.wav")
        sf.write(output_path, gl_lb, dataset_config["sampling_rate"], "PCM_16")
    return gl_lb


class TFGriffinLim(tf.keras.layers.Layer):
    """Griffin-Lim algorithm for phase reconstruction from mel spectrogram magnitude."""

    def __init__(self, stats_path, dataset_config, normalized: bool = True):
        """Init GL params.

        Args:
            stats_path (str): path to the `stats.npy` file containing norm statistics.
                Only read when `normalized` is True.
            dataset_config (Dict): dataset configuration parameters.
            normalized (bool): whether the mel spectrograms given to `call` are
                standardized and must be de-normalized with the loaded statistics.
        """
        super().__init__()
        self.normalized = normalized
        if normalized:
            scaler = StandardScaler()
            scaler.mean_, scaler.scale_ = np.load(stats_path)
            self.scaler = scaler
        self.ds_config = dataset_config
        # `sr` is passed by keyword: librosa >= 0.10 rejects positional
        # arguments here, while older releases accept the keyword form as well.
        self.mel_basis = librosa.filters.mel(
            sr=self.ds_config["sampling_rate"],
            n_fft=self.ds_config["fft_size"],
            n_mels=self.ds_config["num_mels"],
            fmin=self.ds_config["fmin"],
            fmax=self.ds_config["fmax"],
        )  # [num_mels, fft_size // 2 + 1]

    def save_wav(self, gl_tf, output_dir, wav_name):
        """Generate WAV file and save it.

        Args:
            gl_tf (tf.Tensor): reconstructed signal from GL algorithm, either a
                single signal or a batch of signals.
            output_dir (str): output directory where audio file will be saved.
            wav_name (str or List[str]): name of the output file, or one name
                per batch element (without the ".wav" extension).
        """

        def encode_fn(signal):
            return tf.audio.encode_wav(signal, self.ds_config["sampling_rate"])

        # Append a trailing channel axis for the WAV encoder.
        gl_tf = tf.expand_dims(gl_tf, -1)
        if not isinstance(wav_name, list):
            wav_name = [wav_name]

        if len(gl_tf.shape) > 2:
            # Batched input: encode every signal and write one file per name.
            bs, *_ = gl_tf.shape
            assert bs == len(wav_name), "Batch and 'wav_name' have different size."
            tf_wav = tf.map_fn(encode_fn, gl_tf, dtype=tf.string)
            # Plain Python range: the index is used on a Python list.
            for idx in range(bs):
                output_path = os.path.join(output_dir, f"{wav_name[idx]}.wav")
                tf.io.write_file(output_path, tf_wav[idx])
        else:
            tf_wav = encode_fn(gl_tf)
            tf.io.write_file(os.path.join(output_dir, f"{wav_name[0]}.wav"), tf_wav)

    @tf.function(
        input_signature=[
            tf.TensorSpec(shape=[None, None, None], dtype=tf.float32),
            tf.TensorSpec(shape=[], dtype=tf.int32),
        ]
    )
    def call(self, mel_spec, n_iter=32):
        """Apply GL algorithm to batched mel spectrograms.

        Args:
            mel_spec (tf.Tensor): rank-3 float32 mel spectrogram whose last axis
                is num_mels; standardized when the layer was built with
                `normalized=True`.
            n_iter (int): number of iterations to run GL algorithm.
        Returns:
            (tf.Tensor): reconstructed signal from GL algorithm.
        """
        # de-normalize mel spectrogram
        if self.normalized:
            mel_spec = tf.math.pow(
                10.0, mel_spec * self.scaler.scale_ + self.scaler.mean_
            )
        else:
            mel_spec = tf.math.pow(
                10.0, mel_spec
            )  # TODO @dathudeptrai check if its ok without it wavs were too quiet
        inverse_mel = tf.linalg.pinv(self.mel_basis)

        # [:, num_mels] @ [fft_size // 2 + 1, num_mels].T
        mel_to_linear = tf.linalg.matmul(mel_spec, inverse_mel, transpose_b=True)
        # Floor at 1e-10 because the pseudo-inverse can yield negative magnitudes.
        mel_to_linear = tf.cast(tf.math.maximum(1e-10, mel_to_linear), tf.complex64)

        # STFT parameters are loop-invariant, so resolve them once.
        fft_size = self.ds_config["fft_size"]
        hop_size = self.ds_config["hop_size"]
        # A falsy `win_length` (e.g. None) falls back to the FFT size.
        frame_length = self.ds_config["win_length"] or fft_size
        window_fn = tf.signal.inverse_stft_window_fn(hop_size)

        # Start from a uniformly random phase in [0, 2*pi).
        init_phase = tf.cast(
            tf.random.uniform(tf.shape(mel_to_linear), maxval=1), tf.complex64
        )
        phase = tf.math.exp(2j * np.pi * init_phase)
        # `tf.range` is required here: `n_iter` is a tensor under tf.function.
        for _ in tf.range(n_iter):
            inverse = tf.signal.inverse_stft(
                mel_to_linear * phase,
                frame_length=frame_length,
                frame_step=hop_size,
                fft_length=fft_size,
                window_fn=window_fn,
            )
            phase = tf.signal.stft(inverse, frame_length, hop_size, fft_size)
            # Keep only the phase: normalize every bin to unit magnitude.
            phase /= tf.cast(tf.maximum(1e-10, tf.abs(phase)), tf.complex64)

        return tf.signal.inverse_stft(
            mel_to_linear * phase,
            frame_length=frame_length,
            frame_step=hop_size,
            fft_length=fft_size,
            window_fn=window_fn,
        )