Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
TensorSpeech
GitHub Repository: TensorSpeech/TensorFlowTTS
Path: blob/master/tensorflow_tts/models/mb_melgan.py
1558 views
1
# -*- coding: utf-8 -*-
2
# Copyright 2020 The Multi-band MelGAN Authors , Minh Nguyen (@dathudeptrai) and Tomoki Hayashi (@kan-bayashi)
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
# http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
# ============================================================================
16
#
17
# Compatible with https://github.com/kan-bayashi/ParallelWaveGAN/blob/master/parallel_wavegan/layers/pqmf.py.
18
"""Multi-band MelGAN Modules."""
19
20
import numpy as np
21
import tensorflow as tf
22
from scipy.signal import kaiser
23
24
from tensorflow_tts.models import BaseModel
25
from tensorflow_tts.models import TFMelGANGenerator
26
27
28
def design_prototype_filter(taps=62, cutoff_ratio=0.15, beta=9.0):
    """Design prototype filter for PQMF.

    This method is based on `A Kaiser window approach for the design of prototype
    filters of cosine modulated filterbanks`_.

    Args:
        taps (int): The number of filter taps. Must be even.
        cutoff_ratio (float): Cut-off frequency ratio, strictly between 0 and 1.
        beta (float): Beta coefficient for kaiser window.

    Returns:
        ndarray: Impulse response of prototype filter (taps + 1,).

    Raises:
        ValueError: If `taps` is odd or `cutoff_ratio` is not in (0.0, 1.0).

    .. _`A Kaiser window approach for the design of prototype filters of cosine modulated filterbanks`:
        https://ieeexplore.ieee.org/abstract/document/681427

    """
    # Validate with real exceptions: `assert` statements are removed under
    # `python -O`, and an odd `taps` would then silently overwrite a valid
    # coefficient at index `taps // 2` below.
    if taps % 2 != 0:
        raise ValueError("The number of taps must be even number.")
    if not 0.0 < cutoff_ratio < 1.0:
        raise ValueError("Cutoff ratio must be > 0.0 and < 1.0.")

    # Ideal low-pass impulse response sampled on indices centred at taps / 2.
    omega_c = np.pi * cutoff_ratio
    offsets = np.arange(taps + 1) - 0.5 * taps
    with np.errstate(invalid="ignore"):
        # The centre sample evaluates 0 / 0 and yields NaN; it is patched below.
        h_i = np.sin(omega_c * offsets) / (np.pi * offsets)
    # Limit of sin(omega_c * n) / (pi * n) as n -> 0 is cutoff_ratio.
    h_i[taps // 2] = cutoff_ratio

    # Apply kaiser window.
    # NOTE(review): `kaiser` comes from the module-level
    # `from scipy.signal import kaiser`; newer SciPy releases appear to expose
    # window functions only under `scipy.signal.windows` -- confirm against the
    # pinned SciPy version.
    w = kaiser(taps + 1, beta)
    return h_i * w
59
60
61
class TFPQMF(tf.keras.layers.Layer):
    """PQMF filter bank layer providing subband analysis and synthesis."""

    def __init__(self, config, **kwargs):
        """Build the fixed analysis, synthesis and resampling kernels.

        Args:
            config (class): MultiBandMelGANGeneratorConfig. The attributes
                `subbands`, `taps`, `cutoff_ratio` and `beta` are read.

        """
        super().__init__(**kwargs)
        n_bands = config.subbands
        n_taps = config.taps

        # One shared prototype filter is cosine-modulated into one filter per
        # subband; analysis and synthesis differ only in the sign of the
        # pi / 4 phase term.
        prototype = design_prototype_filter(n_taps, config.cutoff_ratio, config.beta)
        n_coeffs = len(prototype)
        analysis_bank = np.zeros((n_bands, n_coeffs))
        synthesis_bank = np.zeros((n_bands, n_coeffs))

        # Kernel used for down/up-sampling: only the first tap is non-zero and
        # it maps every channel onto itself.
        resample_kernel = np.zeros((n_bands, n_bands, n_bands), dtype=np.float32)

        centered = np.arange(n_taps + 1) - (n_taps / 2)
        for band in range(n_bands):
            modulation = (2 * band + 1) * (np.pi / (2 * n_bands)) * centered
            phase = (-1) ** band * np.pi / 4
            analysis_bank[band] = 2 * prototype * np.cos(modulation + phase)
            synthesis_bank[band] = 2 * prototype * np.cos(modulation - phase)
            resample_kernel[0, band, band] = 1.0

        # Conv1D kernels are laid out as [filter_width, in_channels, out_channels]:
        #   analysis  -> [taps + 1, 1, subbands]
        #   synthesis -> [taps + 1, subbands, 1]
        analysis_kernel = analysis_bank[:, np.newaxis, :].transpose(2, 1, 0)
        synthesis_kernel = synthesis_bank[np.newaxis, :, :].transpose(2, 1, 0)

        self.subbands = n_bands
        self.taps = n_taps
        self.analysis_filter = analysis_kernel.astype(np.float32)
        self.synthesis_filter = synthesis_kernel.astype(np.float32)
        self.updown_filter = resample_kernel.astype(np.float32)

    @tf.function(
        experimental_relax_shapes=True,
        input_signature=[tf.TensorSpec(shape=[None, None, 1], dtype=tf.float32)],
    )
    def analysis(self, x):
        """Split a signal into subbands with PQMF.

        Args:
            x (Tensor): Input tensor (B, T, 1).

        Returns:
            Tensor: Output tensor (B, T // subbands, subbands).

        """
        half_taps = self.taps // 2
        # Zero-pad the time axis so the "VALID" filtering keeps the length T.
        padded = tf.pad(x, [[0, 0], [half_taps, half_taps], [0, 0]])
        filtered = tf.nn.conv1d(
            padded, self.analysis_filter, stride=1, padding="VALID"
        )
        # Strided convolution with the resampling kernel keeps every
        # `subbands`-th frame of each channel.
        return tf.nn.conv1d(
            filtered, self.updown_filter, stride=self.subbands, padding="VALID"
        )

    @tf.function(
        experimental_relax_shapes=True,
        input_signature=[tf.TensorSpec(shape=[None, None, None], dtype=tf.float32)],
    )
    def synthesis(self, x):
        """Recombine subband signals into one signal with PQMF.

        Args:
            x (Tensor): Input tensor (B, T // subbands, subbands).

        Returns:
            Tensor: Output tensor (B, T, 1).

        """
        in_shape = tf.shape(x)
        half_taps = self.taps // 2
        # Upsample each subband by `subbands`; the kernel is scaled by the same
        # factor.
        upsampled = tf.nn.conv1d_transpose(
            x,
            self.updown_filter * self.subbands,
            strides=self.subbands,
            output_shape=(in_shape[0], in_shape[1] * self.subbands, self.subbands),
        )
        padded = tf.pad(upsampled, [[0, 0], [half_taps, half_taps], [0, 0]])
        return tf.nn.conv1d(padded, self.synthesis_filter, stride=1, padding="VALID")
158
159
160
class TFMBMelGANGenerator(TFMelGANGenerator):
    """Tensorflow MBMelGAN generator: MelGAN generator followed by PQMF synthesis."""

    def __init__(self, config, **kwargs):
        """Initialize the base generator and attach a PQMF layer.

        Args:
            config (class): Generator config, forwarded to both the base
                class and `TFPQMF`.

        """
        super().__init__(config, **kwargs)
        self.pqmf = TFPQMF(config=config, dtype=tf.float32, name="pqmf")

    def call(self, mels, **kwargs):
        """Calculate forward propagation.

        Delegates to `inference`; extra keyword arguments are ignored.

        Args:
            mels (Tensor): Input mel tensor (B, T, 80).

        Returns:
            Tensor: Output of the PQMF synthesis filter bank.

        """
        audios = self.inference(mels)
        return audios

    @tf.function(
        input_signature=[
            tf.TensorSpec(shape=[None, None, 80], dtype=tf.float32, name="mels")
        ]
    )
    def inference(self, mels):
        """Generate subband signals from mels and merge them with PQMF.

        Args:
            mels (Tensor): Input mel tensor (B, T, 80).

        Returns:
            Tensor: Output of `self.pqmf.synthesis`.

        """
        # `self.melgan` is presumably created by TFMelGANGenerator -- it is not
        # defined in this class.
        return self.pqmf.synthesis(self.melgan(mels))

    @tf.function(
        input_signature=[
            tf.TensorSpec(shape=[1, None, 80], dtype=tf.float32, name="mels")
        ]
    )
    def inference_tflite(self, mels):
        """Same computation as `inference`, with the batch size fixed to 1.

        Args:
            mels (Tensor): Input mel tensor (1, T, 80).

        Returns:
            Tensor: Output of `self.pqmf.synthesis`.

        """
        return self.pqmf.synthesis(self.melgan(mels))
193
194