Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
TensorSpeech
GitHub Repository: TensorSpeech/TensorFlowTTS
Path: blob/master/tensorflow_tts/utils/griffin_lim.py
1558 views
1
# -*- coding: utf-8 -*-
# Copyright 2020 Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Griffin-Lim phase reconstruction algorithm from mel spectrogram."""

import os
18
19
import librosa
20
import numpy as np
21
import soundfile as sf
22
import tensorflow as tf
23
from sklearn.preprocessing import StandardScaler
24
25
26
def griffin_lim_lb(
    mel_spec, stats_path, dataset_config, n_iter=32, output_dir=None, wav_name="lb"
):
    """Generate wave from mel spectrogram with Griffin-Lim algorithm using Librosa.

    Args:
        mel_spec (ndarray): array representing the normalized log10 mel
            spectrogram, shaped [frames, num_mels].
        stats_path (str): path to the `stats.npy` file containing norm statistics
            (mean and scale rows, one value per mel bin).
        dataset_config (Dict): dataset configuration parameters
            (sampling_rate, fft_size, num_mels, fmin, fmax, hop_size, win_length).
        n_iter (int): number of iterations for GL.
        output_dir (str): output directory where audio file will be saved;
            if falsy, no file is written.
        wav_name (str): name of the output file (without extension).

    Returns:
        gl_lb (ndarray): generated wave.
    """
    # De-normalize using the dataset statistics stored as (mean, scale).
    scaler = StandardScaler()
    scaler.mean_, scaler.scale_ = np.load(stats_path)

    # Undo the log10 compression and transpose to [num_mels, frames] for librosa.
    mel_spec = np.power(10.0, scaler.inverse_transform(mel_spec)).T
    # `sr` must be passed by keyword: librosa >= 0.10 made every argument of
    # `filters.mel` keyword-only, so the old positional call raises TypeError.
    mel_basis = librosa.filters.mel(
        sr=dataset_config["sampling_rate"],
        n_fft=dataset_config["fft_size"],
        n_mels=dataset_config["num_mels"],
        fmin=dataset_config["fmin"],
        fmax=dataset_config["fmax"],
    )
    # Pseudo-inverse maps mel back to a linear-magnitude spectrogram; clamp at
    # 1e-10 so the result stays strictly positive for Griffin-Lim.
    mel_to_linear = np.maximum(1e-10, np.dot(np.linalg.pinv(mel_basis), mel_spec))
    gl_lb = librosa.griffinlim(
        mel_to_linear,
        n_iter=n_iter,
        hop_length=dataset_config["hop_size"],
        # Fall back to fft_size when win_length is unset (None/0) in the config.
        win_length=dataset_config["win_length"] or dataset_config["fft_size"],
    )
    if output_dir:
        output_path = os.path.join(output_dir, f"{wav_name}.wav")
        sf.write(output_path, gl_lb, dataset_config["sampling_rate"], "PCM_16")
    return gl_lb
62
63
64
class TFGriffinLim(tf.keras.layers.Layer):
    """Griffin-Lim algorithm for phase reconstruction from mel spectrogram magnitude."""

    def __init__(self, stats_path, dataset_config, normalized: bool = True):
        """Init GL params.

        Args:
            stats_path (str): path to the `stats.npy` file containing norm statistics.
            dataset_config (Dict): dataset configuration parameters.
            normalized (bool): whether incoming mel spectrograms were mean/scale
                normalized and therefore need de-normalizing in `call`.
        """
        super().__init__()
        self.normalized = normalized
        if normalized:
            # stats.npy stores (mean, scale) rows, one value per mel bin.
            scaler = StandardScaler()
            scaler.mean_, scaler.scale_ = np.load(stats_path)
            self.scaler = scaler
        self.ds_config = dataset_config
        # `sr` must be passed by keyword: librosa >= 0.10 made every argument
        # of `filters.mel` keyword-only, so the old positional call raises.
        self.mel_basis = librosa.filters.mel(
            sr=self.ds_config["sampling_rate"],
            n_fft=self.ds_config["fft_size"],
            n_mels=self.ds_config["num_mels"],
            fmin=self.ds_config["fmin"],
            fmax=self.ds_config["fmax"],
        )  # [num_mels, fft_size // 2 + 1]

    def save_wav(self, gl_tf, output_dir, wav_name):
        """Generate WAV file and save it.

        Args:
            gl_tf (tf.Tensor): reconstructed signal from GL algorithm,
                either a single signal [samples] or a batch [batch, samples].
            output_dir (str): output directory where audio file will be saved.
            wav_name (str or List[str]): name(s) of the output file(s);
                for a batch, one name per signal.
        """
        encode_fn = lambda x: tf.audio.encode_wav(x, self.ds_config["sampling_rate"])
        # encode_wav expects a trailing channel axis -> [..., samples, 1].
        gl_tf = tf.expand_dims(gl_tf, -1)
        if not isinstance(wav_name, list):
            wav_name = [wav_name]

        if len(gl_tf.shape) > 2:
            # Batched input: encode each signal and write one file per name.
            bs, *_ = gl_tf.shape
            assert bs == len(wav_name), "Batch and 'wav_name' have different size."
            tf_wav = tf.map_fn(encode_fn, gl_tf, dtype=tf.string)
            # Use a plain Python range here: `wav_name` is a Python list and
            # cannot be indexed with the scalar tensors `tf.range` yields.
            for idx in range(bs):
                output_path = os.path.join(output_dir, f"{wav_name[idx]}.wav")
                tf.io.write_file(output_path, tf_wav[idx])
        else:
            tf_wav = encode_fn(gl_tf)
            tf.io.write_file(os.path.join(output_dir, f"{wav_name[0]}.wav"), tf_wav)

    @tf.function(
        input_signature=[
            tf.TensorSpec(shape=[None, None, None], dtype=tf.float32),
            tf.TensorSpec(shape=[], dtype=tf.int32),
        ]
    )
    def call(self, mel_spec, n_iter=32):
        """Apply GL algorithm to batched mel spectrograms.

        Args:
            mel_spec (tf.Tensor): normalized mel spectrogram,
                shaped [batch, frames, num_mels].
            n_iter (int): number of iterations to run GL algorithm.

        Returns:
            (tf.Tensor): reconstructed signal from GL algorithm.
        """
        # de-normalize mel spectogram and undo the log10 compression
        if self.normalized:
            mel_spec = tf.math.pow(
                10.0, mel_spec * self.scaler.scale_ + self.scaler.mean_
            )
        else:
            mel_spec = tf.math.pow(
                10.0, mel_spec
            )  # TODO @dathudeptrai check if its ok without it wavs were too quiet
        inverse_mel = tf.linalg.pinv(self.mel_basis)

        # [:, num_mels] @ [fft_size // 2 + 1, num_mels].T
        mel_to_linear = tf.linalg.matmul(mel_spec, inverse_mel, transpose_b=True)
        # Clamp at 1e-10 so the magnitude stays strictly positive.
        mel_to_linear = tf.cast(tf.math.maximum(1e-10, mel_to_linear), tf.complex64)

        # Start from a random phase in [0, 2*pi) and refine it iteratively.
        init_phase = tf.cast(
            tf.random.uniform(tf.shape(mel_to_linear), maxval=1), tf.complex64
        )
        phase = tf.math.exp(2j * np.pi * init_phase)
        for _ in tf.range(n_iter):
            # Round-trip through inverse STFT + STFT to project the current
            # estimate onto the set of consistent spectrograms.
            inverse = tf.signal.inverse_stft(
                mel_to_linear * phase,
                frame_length=self.ds_config["win_length"] or self.ds_config["fft_size"],
                frame_step=self.ds_config["hop_size"],
                fft_length=self.ds_config["fft_size"],
                window_fn=tf.signal.inverse_stft_window_fn(self.ds_config["hop_size"]),
            )
            phase = tf.signal.stft(
                inverse,
                self.ds_config["win_length"] or self.ds_config["fft_size"],
                self.ds_config["hop_size"],
                self.ds_config["fft_size"],
            )
            # Keep only the phase: normalize to unit magnitude (guard div-by-0).
            phase /= tf.cast(tf.maximum(1e-10, tf.abs(phase)), tf.complex64)

        return tf.signal.inverse_stft(
            mel_to_linear * phase,
            frame_length=self.ds_config["win_length"] or self.ds_config["fft_size"],
            frame_step=self.ds_config["hop_size"],
            fft_length=self.ds_config["fft_size"],
            window_fn=tf.signal.inverse_stft_window_fn(self.ds_config["hop_size"]),
        )
167
168