Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
TensorSpeech
GitHub Repository: TensorSpeech/TensorFlowTTS
Path: blob/master/tensorflow_tts/models/hifigan.py
1558 views
1
# -*- coding: utf-8 -*-
2
# Copyright 2020 The Hifigan Authors and TensorflowTTS Team.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
# http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
"""Hifi Modules."""
16
17
import numpy as np
18
import tensorflow as tf
19
20
from tensorflow_tts.models.melgan import TFReflectionPad1d
21
from tensorflow_tts.models.melgan import TFConvTranspose1d
22
23
from tensorflow_tts.utils import GroupConv1D
24
from tensorflow_tts.utils import WeightNormalization
25
26
from tensorflow_tts.models import BaseModel
27
from tensorflow_tts.models import TFMelGANGenerator
28
29
30
class TFHifiResBlock(tf.keras.layers.Layer):
    """Tensorflow Hifigan resblock 1 module.

    Each entry of ``dilation_rate`` contributes a pair of convolutions:
    a dilated conv (``blocks_1``) followed by a rate-1 conv (``blocks_2``),
    combined with a residual connection.
    """

    def __init__(
        self,
        kernel_size,
        filters,
        dilation_rate,
        use_bias,
        nonlinear_activation,
        nonlinear_activation_params,
        is_weight_norm,
        initializer_seed,
        **kwargs
    ):
        """Initialize TFHifiResBlock module.

        Args:
            kernel_size (int): Kernel size.
            filters (int): Number of filters.
            dilation_rate (list): List of dilation rates, one per conv pair.
            use_bias (bool): Whether to add bias parameter in convolution layers.
            nonlinear_activation (str): Activation function module name.
            nonlinear_activation_params (dict): Hyperparameters for activation function.
            is_weight_norm (bool): Whether to use weight norm or not.
            initializer_seed (int): Accepted for interface compatibility with the
                caller; not used by this layer (convs use Keras default init).
        """
        super().__init__(**kwargs)
        self.blocks_1 = []
        self.blocks_2 = []

        for i in range(len(dilation_rate)):
            # First conv of each pair is dilated; its partner always uses rate 1.
            self.blocks_1.append(
                [
                    TFReflectionPad1d((kernel_size - 1) // 2 * dilation_rate[i]),
                    tf.keras.layers.Conv1D(
                        filters=filters,
                        kernel_size=kernel_size,
                        dilation_rate=dilation_rate[i],
                        use_bias=use_bias,
                    ),
                ]
            )
            self.blocks_2.append(
                [
                    TFReflectionPad1d((kernel_size - 1) // 2 * 1),
                    tf.keras.layers.Conv1D(
                        filters=filters,
                        kernel_size=kernel_size,
                        dilation_rate=1,
                        use_bias=use_bias,
                    ),
                ]
            )

        self.activation = getattr(tf.keras.layers, nonlinear_activation)(
            **nonlinear_activation_params
        )

        # apply weightnorm
        if is_weight_norm:
            self._apply_weightnorm(self.blocks_1)
            self._apply_weightnorm(self.blocks_2)

    def call(self, x, training=False):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, T, C).
        Returns:
            Tensor: Output tensor (B, T, C).
        """
        for c1, c2 in zip(self.blocks_1, self.blocks_2):
            xt = self.activation(x)
            for c in c1:
                xt = c(xt)
            xt = self.activation(xt)
            for c in c2:
                xt = c(xt)
            # residual connection
            x = xt + x
        return x

    def _apply_weightnorm(self, list_layers):
        """Try apply weightnorm for all layers in list_layers.

        BUG FIX: ``blocks_1``/``blocks_2`` are lists of ``[pad, conv]``
        sub-lists. The previous implementation called ``.name`` directly on
        each sub-list, which raised AttributeError and was silently swallowed
        by the broad ``except`` — so ``is_weight_norm=True`` never actually
        wrapped any conv here. We now recurse into sub-lists.
        """
        for i in range(len(list_layers)):
            if isinstance(list_layers[i], list):
                # Recurse into a [pad, conv] pair so the conv gets wrapped.
                self._apply_weightnorm(list_layers[i])
                continue
            try:
                layer_name = list_layers[i].name.lower()
                if "conv1d" in layer_name or "dense" in layer_name:
                    list_layers[i] = WeightNormalization(list_layers[i])
            except Exception:
                pass
class TFMultiHifiResBlock(tf.keras.layers.Layer):
    """Tensorflow Multi Hifigan resblock 1 module.

    Applies every resblock to the same input and returns the mean of their
    outputs (HiFi-GAN's multi-receptive-field fusion).
    """

    def __init__(self, list_resblock, **kwargs):
        """Initialize with the list of resblocks whose outputs are averaged."""
        super().__init__(**kwargs)
        self.list_resblock = list_resblock

    def call(self, x, training=False):
        """Return the element-wise average of all resblock outputs on x."""
        outputs = [block(x, training=training) for block in self.list_resblock]
        total = outputs[0]
        for out in outputs[1:]:
            total = total + out
        return total / len(self.list_resblock)
class TFHifiGANGenerator(BaseModel):
    """HiFi-GAN generator: upsamples mel-spectrograms (B, T, n_mels) to waveforms."""

    def __init__(self, config, **kwargs):
        """Build the generator as a single Keras Sequential model.

        Args:
            config: HifiGAN generator config providing kernel_size, filters,
                upsample_scales, stacks, stack_kernel_size, stack_dilation_rate,
                padding_type, use_bias, out_channels, nonlinear_activation,
                nonlinear_activation_params, is_weight_norm, initializer_seed
                and use_final_nolinear_activation.
        """
        super().__init__(**kwargs)
        # check hyper parameter is valid or not:
        # one kernel size and one dilation-rate list are required per resblock stack
        assert (
            config.stacks
            == len(config.stack_kernel_size)
            == len(config.stack_dilation_rate)
        )

        # add initial layer
        layers = []
        layers += [
            TFReflectionPad1d(
                (config.kernel_size - 1) // 2,
                padding_type=config.padding_type,
                name="first_reflect_padding",
            ),
            tf.keras.layers.Conv1D(
                filters=config.filters,
                kernel_size=config.kernel_size,
                use_bias=config.use_bias,
            ),
        ]

        for i, upsample_scale in enumerate(config.upsample_scales):
            # add upsampling layer; channel count is halved at each stage
            layers += [
                getattr(tf.keras.layers, config.nonlinear_activation)(
                    **config.nonlinear_activation_params
                ),
                TFConvTranspose1d(
                    filters=config.filters // (2 ** (i + 1)),
                    kernel_size=upsample_scale * 2,
                    strides=upsample_scale,
                    padding="same",
                    is_weight_norm=config.is_weight_norm,
                    initializer_seed=config.initializer_seed,
                    name="conv_transpose_._{}".format(i),
                ),
            ]

            # add residual stack layer: averages `config.stacks` resblocks with
            # differing kernel/dilation settings (multi-receptive-field fusion)
            layers += [
                TFMultiHifiResBlock(
                    list_resblock=[
                        TFHifiResBlock(
                            kernel_size=config.stack_kernel_size[j],
                            filters=config.filters // (2 ** (i + 1)),
                            dilation_rate=config.stack_dilation_rate[j],
                            use_bias=config.use_bias,
                            nonlinear_activation=config.nonlinear_activation,
                            nonlinear_activation_params=config.nonlinear_activation_params,
                            is_weight_norm=config.is_weight_norm,
                            initializer_seed=config.initializer_seed,
                            name="hifigan_resblock_._{}".format(j),
                        )
                        for j in range(config.stacks)
                    ],
                    name="multi_hifigan_resblock_._{}".format(i),
                )
            ]
        # add final layer (dtype pinned to float32 so output stays full
        # precision even under a mixed-precision policy)
        layers += [
            getattr(tf.keras.layers, config.nonlinear_activation)(
                **config.nonlinear_activation_params
            ),
            TFReflectionPad1d(
                (config.kernel_size - 1) // 2,
                padding_type=config.padding_type,
                name="last_reflect_padding",
            ),
            tf.keras.layers.Conv1D(
                filters=config.out_channels,
                kernel_size=config.kernel_size,
                use_bias=config.use_bias,
                dtype=tf.float32,
            ),
        ]
        # NOTE: "nolinear" is the (misspelled) attribute name the config
        # actually defines; kept as-is for compatibility.
        if config.use_final_nolinear_activation:
            layers += [tf.keras.layers.Activation("tanh", dtype=tf.float32)]

        if config.is_weight_norm is True:
            self._apply_weightnorm(layers)

        self.hifigan = tf.keras.models.Sequential(layers)

    def call(self, mels, **kwargs):
        """Calculate forward propagation.

        Args:
            mels (Tensor): Input mel-spectrogram tensor (B, T, channels).
        Returns:
            Tensor: Output tensor (B, T * prod(upsample_scales), out_channels).
        """
        return self.inference(mels)

    @tf.function(
        input_signature=[
            tf.TensorSpec(shape=[None, None, 80], dtype=tf.float32, name="mels")
        ]
    )
    def inference(self, mels):
        """Run the generator; the signature fixes 80 mel bins for SavedModel export."""
        return self.hifigan(mels)

    @tf.function(
        input_signature=[
            tf.TensorSpec(shape=[1, None, 80], dtype=tf.float32, name="mels")
        ]
    )
    def inference_tflite(self, mels):
        """TFLite-friendly variant of `inference` with batch size fixed to 1."""
        return self.hifigan(mels)

    def _apply_weightnorm(self, list_layers):
        """Try apply weightnorm for all layer in list_layers."""
        # NOTE(review): only layers whose name contains "conv1d"/"dense" are
        # wrapped; TFConvTranspose1d receives is_weight_norm in its own
        # constructor above, so it is presumably handled internally — confirm.
        for i in range(len(list_layers)):
            try:
                layer_name = list_layers[i].name.lower()
                if "conv1d" in layer_name or "dense" in layer_name:
                    list_layers[i] = WeightNormalization(list_layers[i])
            except Exception:
                pass

    def _build(self):
        """Build model by passing fake input (batch 1, 100 frames, 80 mel bins)."""
        fake_mels = tf.random.uniform(shape=[1, 100, 80], dtype=tf.float32)
        self(fake_mels)
class TFHifiGANPeriodDiscriminator(tf.keras.layers.Layer):
    """Tensorflow Hifigan period discriminator module.

    Reshapes the 1D signal into (T/period, period) and runs a stack of 2D
    convolutions over it.
    """

    def __init__(
        self,
        period,
        out_channels=1,
        n_layers=5,
        kernel_size=5,
        strides=3,
        filters=8,
        filter_scales=4,
        max_filters=1024,
        nonlinear_activation="LeakyReLU",
        nonlinear_activation_params={"alpha": 0.2},  # NOTE: mutable default, but never mutated
        initializer_seed=42,
        is_weight_norm=False,
        **kwargs
    ):
        """Initialize TFHifiGANPeriodDiscriminator module.

        Args:
            period (int): Period used to fold the signal into 2D.
            out_channels (int): Number of output channels of the final conv.
            n_layers (int): Number of strided Conv2D layers.
            kernel_size (int): Kernel size along the time axis.
            strides (int): Stride along the time axis.
            filters (int): Base filter count; grows by `filter_scales` per layer.
            filter_scales (int): Multiplicative filter growth factor.
            max_filters (int): Upper bound on filters per layer.
            nonlinear_activation (str): Activation layer class name.
            nonlinear_activation_params (dict): Hyperparameters for the activation.
            initializer_seed (int): Accepted for interface compatibility; unused here.
            is_weight_norm (bool): Whether to wrap convs with WeightNormalization.
        """
        super().__init__(**kwargs)
        self.period = period
        self.out_filters = out_channels
        self.convs = []

        for i in range(n_layers):
            self.convs.append(
                tf.keras.layers.Conv2D(
                    filters=min(filters * (filter_scales ** (i + 1)), max_filters),
                    kernel_size=(kernel_size, 1),
                    strides=(strides, 1),
                    padding="same",
                )
            )
        self.conv_post = tf.keras.layers.Conv2D(
            filters=out_channels, kernel_size=(3, 1), padding="same",
        )
        self.activation = getattr(tf.keras.layers, nonlinear_activation)(
            **nonlinear_activation_params
        )

        if is_weight_norm:
            self._apply_weightnorm(self.convs)
            self.conv_post = WeightNormalization(self.conv_post)

    def call(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input noise signal (B, T, 1).
        Returns:
            List: List of output tensors.
        """
        shape = tf.shape(x)
        # Reflect-pad the time axis so it divides evenly by the period.
        n_pad = tf.convert_to_tensor(0, dtype=tf.int32)
        if shape[1] % self.period != 0:
            n_pad = self.period - (shape[1] % self.period)
            x = tf.pad(x, [[0, 0], [0, n_pad], [0, 0]], "REFLECT")
        # Fold (B, T, C) -> (B, T/period, period, C) for 2D convs.
        x = tf.reshape(
            x, [shape[0], (shape[1] + n_pad) // self.period, self.period, x.shape[2]]
        )
        for layer in self.convs:
            x = layer(x)
            x = self.activation(x)
        x = self.conv_post(x)
        # Flatten back to (B, T', out_channels).
        x = tf.reshape(x, [shape[0], -1, self.out_filters])
        return [x]

    def _apply_weightnorm(self, list_layers):
        """Try apply weightnorm for all layers in list_layers.

        BUG FIX: the layers in ``self.convs`` are Conv2D (their default names
        contain "conv2d"), but the previous check matched only "conv1d" and
        "dense", so ``is_weight_norm=True`` silently had no effect on them.
        The check now also matches "conv2d".
        """
        for i in range(len(list_layers)):
            try:
                layer_name = list_layers[i].name.lower()
                if (
                    "conv1d" in layer_name
                    or "conv2d" in layer_name
                    or "dense" in layer_name
                ):
                    list_layers[i] = WeightNormalization(list_layers[i])
            except Exception:
                pass
class TFHifiGANMultiPeriodDiscriminator(BaseModel):
    """Tensorflow Hifigan Multi Period discriminator module.

    Holds one period discriminator per configured period scale and collects
    all of their outputs.
    """

    def __init__(self, config, **kwargs):
        """Build one TFHifiGANPeriodDiscriminator per entry in config.period_scales."""
        super().__init__(**kwargs)
        self.discriminator = [
            TFHifiGANPeriodDiscriminator(
                period_scale,
                out_channels=config.out_channels,
                n_layers=config.n_layers,
                kernel_size=config.kernel_size,
                strides=config.strides,
                filters=config.filters,
                filter_scales=config.filter_scales,
                max_filters=config.max_filters,
                nonlinear_activation=config.nonlinear_activation,
                nonlinear_activation_params=config.nonlinear_activation_params,
                initializer_seed=config.initializer_seed,
                is_weight_norm=config.is_weight_norm,
                name="hifigan_period_discriminator_._{}".format(idx),
            )
            for idx, period_scale in enumerate(config.period_scales)
        ]

    def call(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input noise signal (B, T, 1).
        Returns:
            List: list of each discriminator outputs
        """
        return [discriminator(x) for discriminator in self.discriminator]