Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
TensorSpeech
GitHub Repository: TensorSpeech/TensorFlowTTS
Path: blob/master/tensorflow_tts/models/parallel_wavegan.py
1558 views
1
# -*- coding: utf-8 -*-
# Copyright 2020 The TensorFlowTTS Team and Tomoki Hayashi (@kan-bayashi)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Parallel-wavegan Modules. Based on pytorch implementation (https://github.com/kan-bayashi/ParallelWaveGAN)"""

import tensorflow as tf

from tensorflow_tts.models import BaseModel

def get_initializer(initializer_seed=42):
    """Build a seeded He-normal kernel initializer.

    Args:
        initializer_seed: int, seed for reproducible weight initialization.

    Returns:
        A `tf.keras.initializers.he_normal` instance seeded with
        `initializer_seed`.
    """
    seeded_init = tf.keras.initializers.he_normal(seed=initializer_seed)
    return seeded_init
31
32
33
class TFConv1d1x1(tf.keras.layers.Conv1D):
    """Pointwise (kernel size 1) Conv1d with seeded He-normal initialization."""

    def __init__(self, filters, use_bias, padding, initializer_seed, **kwargs):
        """Initialize 1x1 Conv1d module.

        Args:
            filters (int): Number of output channels.
            use_bias (bool): Whether the convolution adds a bias term.
            padding (str): Padding mode forwarded to `tf.keras.layers.Conv1D`.
            initializer_seed (int): Seed for the He-normal kernel initializer.
            **kwargs: Extra keyword arguments for `tf.keras.layers.Conv1D`
                (e.g. `name`).
        """
        # A pointwise convolution is pinned to kernel_size / strides /
        # dilation of 1; only channels, bias, and padding are configurable.
        super().__init__(
            kernel_size=1,
            strides=1,
            dilation_rate=1,
            filters=filters,
            use_bias=use_bias,
            padding=padding,
            kernel_initializer=get_initializer(initializer_seed),
            **kwargs,
        )
48
49
50
class TFConv1d(tf.keras.layers.Conv1D):
    """Conv1d whose kernel weights use a seeded He-normal initializer."""

    def __init__(self, *args, **kwargs):
        """Initialize Conv1d module.

        Accepts every `tf.keras.layers.Conv1D` argument, plus an optional
        `initializer_seed` keyword (default 42) that seeds the kernel
        initializer.
        """
        seed = kwargs.pop("initializer_seed", 42)
        super().__init__(
            *args, **kwargs, kernel_initializer=get_initializer(seed)
        )
59
60
61
class TFResidualBlock(tf.keras.layers.Layer):
    """Residual block module in WaveNet."""

    def __init__(
        self,
        kernel_size=3,
        residual_channels=64,
        gate_channels=128,
        skip_channels=64,
        aux_channels=80,
        dropout_rate=0.0,
        dilation_rate=1,
        use_bias=True,
        use_causal_conv=False,
        initializer_seed=42,
        **kwargs,
    ):
        """Initialize ResidualBlock module.

        Args:
            kernel_size (int): Kernel size of dilation convolution layer.
            residual_channels (int): Number of channels for residual connection.
            gate_channels (int): Number of channels of the gated activation
                (split in half into tanh/sigmoid branches, so it must be even).
            skip_channels (int): Number of channels for skip connection.
            aux_channels (int): Local conditioning channels i.e. auxiliary input dimension.
                If <= 0, no local conditioning layer is created.
            dropout_rate (float): Dropout probability.
            dilation_rate (int): Dilation factor.
            use_bias (bool): Whether to add bias parameter in convolution layers.
            use_causal_conv (bool): Whether to use use_causal_conv or non-use_causal_conv convolution.
            initializer_seed (int32): initializer seed.
        """
        super().__init__(**kwargs)
        self.dropout_rate = dropout_rate
        # no future time stamps available
        self.use_causal_conv = use_causal_conv

        # dilation conv
        self.conv = TFConv1d(
            filters=gate_channels,
            kernel_size=kernel_size,
            padding="same" if self.use_causal_conv is False else "causal",
            strides=1,
            dilation_rate=dilation_rate,
            use_bias=use_bias,
            initializer_seed=initializer_seed,
        )

        # local conditioning
        if aux_channels > 0:
            self.conv1x1_aux = TFConv1d1x1(
                gate_channels,
                use_bias=False,
                padding="same",
                initializer_seed=initializer_seed,
                name="conv1x1_aux",
            )
        else:
            self.conv1x1_aux = None

        # conv output is split into two halves for the gated activation, so
        # the 1x1 projections below consume gate_channels // 2 features.
        self.conv1x1_out = TFConv1d1x1(
            residual_channels,
            use_bias=use_bias,
            padding="same",
            initializer_seed=initializer_seed,
            name="conv1x1_out",
        )
        self.conv1x1_skip = TFConv1d1x1(
            skip_channels,
            use_bias=use_bias,
            padding="same",
            initializer_seed=initializer_seed,
            name="conv1x1_skip",
        )

        self.dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)

    def call(self, x, c, training=False):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, T, residual_channels).
            c (Tensor): Local conditioning auxiliary tensor (B, T, aux_channels),
                or None to skip local conditioning.

        Returns:
            Tensor: Output tensor for residual connection (B, T, residual_channels).
            Tensor: Output tensor for skip connection (B, T, skip_channels).
        """
        residual = x
        x = self.dropout(x, training=training)
        x = self.conv(x)

        # split into two part for gated activation
        xa, xb = tf.split(x, 2, axis=-1)

        # local conditioning
        if c is not None:
            assert self.conv1x1_aux is not None
            c = self.conv1x1_aux(c)
            ca, cb = tf.split(c, 2, axis=-1)
            xa, xb = xa + ca, xb + cb

        x = tf.nn.tanh(xa) * tf.nn.sigmoid(xb)

        # for skip connection
        s = self.conv1x1_skip(x)

        # for residual connection; scale by sqrt(0.5) to keep the variance of
        # the sum close to that of the inputs.
        x = self.conv1x1_out(x)
        x = (x + residual) * tf.math.sqrt(0.5)

        return x, s
173
174
175
class TFStretch1d(tf.keras.layers.Layer):
    """Stretch2d module.

    Rescales the time axis by `x_scale` and the frequency axis by `y_scale`
    via image interpolation.
    """

    def __init__(self, x_scale, y_scale, method="nearest", **kwargs):
        """Initialize Stretch2d module.

        Args:
            x_scale (int): X scaling factor (Time axis in spectrogram).
            y_scale (int): Y scaling factor (Frequency axis in spectrogram).
            method (str): Interpolation method passed to `tf.image.resize`.
        """
        super().__init__(**kwargs)
        self.x_scale = x_scale
        self.y_scale = y_scale
        self.method = method

    def call(self, x):
        """Stretch the input by the configured scale factors.

        Args:
            x (Tensor): Input tensor (B, T, C, 1).

        Returns:
            Tensor: Interpolated tensor (B, T * x_scale, C * y_scale, 1).
        """
        dynamic_shape = tf.shape(x)
        target_size = (
            dynamic_shape[1] * self.x_scale,
            dynamic_shape[2] * self.y_scale,
        )
        return tf.image.resize(x, method=self.method, size=target_size)
205
206
207
class TFUpsampleNetWork(tf.keras.layers.Layer):
    """Upsampling network module.

    Stretches the conditioning features along the time axis by each scale in
    `upsample_scales`, smoothing each stretch with a single-channel Conv2D.
    """

    def __init__(
        self,
        output_channels,
        upsample_scales,
        nonlinear_activation=None,
        nonlinear_activation_params={},
        interpolate_mode="nearest",
        freq_axis_kernel_size=1,
        use_causal_conv=False,
        **kwargs,
    ):
        """Initialize upsampling network module.

        Args:
            output_channels (int): output feature channels.
                NOTE(review): not referenced anywhere in this layer; the
                channel count is preserved implicitly by the stretch/conv
                stack — confirm whether it can be dropped from callers.
            upsample_scales (list): List of upsampling scales.
            nonlinear_activation (str): Activation function name (a class name
                looked up on `tf.keras.layers`), or None for no activation.
            nonlinear_activation_params (dict): Arguments for specified activation function.
            interpolate_mode (str): Interpolation mode.
            freq_axis_kernel_size (int): Kernel size in the direction of frequency axis.
            use_causal_conv (bool): Whether to request causal convolutions.
                NOTE(review): `tf.keras.layers.Conv2D` does not support
                padding="causal" (only Conv1D does), so True likely raises
                when the conv layer is built — confirm before enabling.
        """
        super().__init__(**kwargs)
        self.use_causal_conv = use_causal_conv
        self.up_layers = []

        for scale in upsample_scales:
            # interpolation layer
            stretch = TFStretch1d(
                scale, 1, interpolate_mode, name="stretch_._{}".format(scale)
            )  # ->> outputs: [B, T * scale, C * 1, 1]
            self.up_layers += [stretch]

            # conv layer; kernel spans scale*2+1 steps to smooth the
            # interpolation artifacts introduced by the stretch above
            assert (
                freq_axis_kernel_size - 1
            ) % 2 == 0, "Not support even number freq axis kernel size."
            kernel_size = scale * 2 + 1
            conv = tf.keras.layers.Conv2D(
                filters=1,
                kernel_size=(kernel_size, freq_axis_kernel_size),
                padding="causal" if self.use_causal_conv is True else "same",
                use_bias=False,
            )  # ->> outputs: [B, T * scale, C * 1, 1]
            self.up_layers += [conv]

            # nonlinear (optional, appended once per scale)
            if nonlinear_activation is not None:
                nonlinear = getattr(tf.keras.layers, nonlinear_activation)(
                    **nonlinear_activation_params
                )
                self.up_layers += [nonlinear]

    def call(self, c):
        """Calculate forward propagation.

        Args:
            c : Input tensor (B, T, C).

        Returns:
            Tensor: Upsampled tensor (B, T', C), where T' = T * prod(upsample_scales).
        """
        c = tf.expand_dims(c, -1)  # [B, T, C, 1]
        for f in self.up_layers:
            c = f(c)
        return tf.squeeze(c, -1)  # [B, T, C]
274
275
276
class TFConvInUpsampleNetWork(tf.keras.layers.Layer):
    """Convolution + upsampling network module."""

    def __init__(
        self,
        upsample_scales,
        nonlinear_activation=None,
        nonlinear_activation_params={},
        interpolate_mode="nearest",
        freq_axis_kernel_size=1,
        aux_channels=80,
        aux_context_window=0,
        use_causal_conv=False,
        initializer_seed=42,
        **kwargs,
    ):
        """Initialize convolution + upsampling network module.

        Args:
            upsample_scales (list): List of upsampling scales.
            nonlinear_activation (str): Activation function name.
            nonlinear_activation_params (dict): Arguments for specified activation function.
            interpolate_mode (str): Interpolation mode.
            freq_axis_kernel_size (int): Kernel size in the direction of frequency axis.
            aux_channels (int): Number of channels of pre-convolutional layer.
            aux_context_window (int): Context window size of the pre-convolutional layer.
            use_causal_conv (bool): Whether to use causal structure.
            initializer_seed (int): Seed for the conv kernel initializer.
        """
        super().__init__(**kwargs)
        self.aux_context_window = aux_context_window
        # causal mode is only flagged when a context window is actually used
        self.use_causal_conv = use_causal_conv and aux_context_window > 0

        # To capture wide-context information in conditional features
        kernel_size = (
            aux_context_window + 1 if use_causal_conv else 2 * aux_context_window + 1
        )

        # NOTE(review): padding stays "same" even when use_causal_conv is
        # True; presumably causality is handled by the kernel-size choice
        # above or upstream — confirm against the pytorch reference.
        self.conv_in = TFConv1d(
            filters=aux_channels,
            kernel_size=kernel_size,
            padding="same",
            use_bias=False,
            initializer_seed=initializer_seed,
            name="conv_in",
        )
        self.upsample = TFUpsampleNetWork(
            output_channels=aux_channels,
            upsample_scales=upsample_scales,
            nonlinear_activation=nonlinear_activation,
            nonlinear_activation_params=nonlinear_activation_params,
            interpolate_mode=interpolate_mode,
            freq_axis_kernel_size=freq_axis_kernel_size,
            use_causal_conv=use_causal_conv,
            name="upsample_network",
        )

    def call(self, c):
        """Calculate forward propagation.

        Args:
            c : Input tensor (B, T', C).

        Returns:
            Tensor: Upsampled tensor (B, T, C),
                where T = (T' - aux_context_window * 2) * prod(upsample_scales).

        Note:
            The length of inputs considers the context window size.
        """
        c_ = self.conv_in(c)
        return self.upsample(c_)
348
349
350
class TFParallelWaveGANGenerator(BaseModel):
    """Parallel WaveGAN Generator module.

    Maps random Gaussian noise plus upsampled mel conditioning features to a
    raw waveform through a stack of dilated residual blocks (WaveNet style).
    """

    def __init__(self, config, **kwargs):
        """Initialize generator.

        Args:
            config: Configuration object exposing out_channels, aux_channels,
                n_layers, stacks, kernel_size, residual/gate/skip channel
                sizes, dropout_rate, use_bias, use_causal_conv,
                initializer_seed, and the upsampling options used below.
        """
        super().__init__(**kwargs)
        self.out_channels = config.out_channels
        self.aux_channels = config.aux_channels
        self.n_layers = config.n_layers
        self.stacks = config.stacks
        self.kernel_size = config.kernel_size
        self.upsample_params = config.upsample_params

        # check the number of layers and stacks
        assert self.n_layers % self.stacks == 0
        n_layers_per_stack = self.n_layers // self.stacks

        # define first convolution
        self.first_conv = TFConv1d1x1(
            filters=config.residual_channels,
            use_bias=True,
            padding="same",
            initializer_seed=config.initializer_seed,
            name="first_convolution",
        )

        # define conv + upsampling network
        if config.upsample_conditional_features:
            self.upsample_params.update({"use_causal_conv": config.use_causal_conv})
            self.upsample_params.update(
                {
                    "aux_channels": config.aux_channels,
                    "aux_context_window": config.aux_context_window,
                }
            )
            self.upsample_net = TFConvInUpsampleNetWork(**self.upsample_params)
        else:
            self.upsample_net = None

        # define residual blocks; dilation doubles within each stack
        self.conv_layers = []
        for layer in range(self.n_layers):
            dilation_rate = 2 ** (layer % n_layers_per_stack)
            conv = TFResidualBlock(
                kernel_size=config.kernel_size,
                residual_channels=config.residual_channels,
                gate_channels=config.gate_channels,
                skip_channels=config.skip_channels,
                aux_channels=config.aux_channels,
                dilation_rate=dilation_rate,
                dropout_rate=config.dropout_rate,
                use_bias=config.use_bias,
                use_causal_conv=config.use_causal_conv,
                initializer_seed=config.initializer_seed,
                name="residual_block_._{}".format(layer),
            )
            self.conv_layers += [conv]

        # define output layers
        self.last_conv_layers = [
            tf.keras.layers.ReLU(),
            TFConv1d1x1(
                filters=config.skip_channels,
                use_bias=config.use_bias,
                padding="same",
                initializer_seed=config.initializer_seed,
            ),
            tf.keras.layers.ReLU(),
            TFConv1d1x1(
                filters=config.out_channels,
                use_bias=True,
                padding="same",
                initializer_seed=config.initializer_seed,
            ),
            tf.keras.layers.Activation("tanh"),
        ]

    def _build(self):
        # Run a dummy forward pass so Keras builds all layer weights.
        mels = tf.random.uniform(shape=[2, 20, 80], dtype=tf.float32)
        self(mels, training=tf.cast(True, tf.bool))

    def call(self, mels, training=False, **kwargs):
        """Calculate forward propagation.

        Args:
            mels (Tensor): Local conditioning auxiliary features (B, T', C).

        Returns:
            Tensor: Output tensor (B, T, 1)
        """
        # perform upsampling
        # NOTE(review): if `mels` is None or `self.upsample_net` is None, `c`
        # is never assigned and the noise construction below raises
        # UnboundLocalError — confirm callers always pass conditioning with
        # upsampling enabled.
        if mels is not None and self.upsample_net is not None:
            c = self.upsample_net(mels)

        # random noise x
        # encode to hidden representation
        x = tf.expand_dims(tf.random.normal(shape=tf.shape(c)[0:2]), axis=2)
        x = self.first_conv(x)
        skips = 0
        for f in self.conv_layers:
            x, h = f(x, c, training=training)
            skips += h
        # normalize the accumulated skip connections by sqrt(1 / n_layers)
        skips *= tf.math.sqrt(1.0 / len(self.conv_layers))

        # apply final layers
        x = skips
        for f in self.last_conv_layers:
            x = f(x)

        return x

    @tf.function(
        experimental_relax_shapes=True,
        input_signature=[
            tf.TensorSpec(shape=[None, None, 80], dtype=tf.float32, name="mels"),
        ],
    )
    def inference(self, mels):
        """Calculate forward propagation (graph-compiled inference path).

        Mirrors `call` with `training=False`; kept separate so it can carry a
        fixed input signature for serving/export.

        Args:
            mels (Tensor): Local conditioning auxiliary features (B, T', C).

        Returns:
            Tensor: Output tensor (B, T, 1)
        """
        # perform upsampling
        if mels is not None and self.upsample_net is not None:
            c = self.upsample_net(mels)

        # encode to hidden representation
        x = tf.expand_dims(tf.random.normal(shape=tf.shape(c)[0:2]), axis=2)
        x = self.first_conv(x)
        skips = 0
        for f in self.conv_layers:
            x, h = f(x, c, training=False)
            skips += h
        # normalize the accumulated skip connections by sqrt(1 / n_layers)
        skips *= tf.math.sqrt(1.0 / len(self.conv_layers))

        # apply final layers
        x = skips
        for f in self.last_conv_layers:
            x = f(x)

        return x
494
495
496
class TFParallelWaveGANDiscriminator(BaseModel):
    """Parallel WaveGAN Discriminator module.

    A stack of dilated Conv1d layers interleaved with a configurable
    activation, followed by a final projection to `out_channels`.
    """

    def __init__(self, config, **kwargs):
        """Initialize discriminator from a config object."""
        super().__init__(**kwargs)
        assert (config.kernel_size - 1) % 2 == 0, "Not support even number kernel size."
        assert config.dilation_factor > 0, "Dilation factor must be > 0."

        layers = []
        for idx in range(config.n_layers - 1):
            # first layer is undilated; afterwards the dilation grows either
            # linearly (factor == 1) or exponentially (factor ** idx)
            if idx == 0:
                dilation = 1
            elif config.dilation_factor == 1:
                dilation = idx
            else:
                dilation = config.dilation_factor ** idx
            layers.append(
                TFConv1d(
                    filters=config.conv_channels,
                    kernel_size=config.kernel_size,
                    padding="same",
                    dilation_rate=dilation,
                    use_bias=config.use_bias,
                    initializer_seed=config.initializer_seed,
                )
            )
            layers.append(
                getattr(tf.keras.layers, config.nonlinear_activation)(
                    **config.nonlinear_activation_params
                )
            )

        # final projection down to the output channel count
        layers.append(
            TFConv1d(
                filters=config.out_channels,
                kernel_size=config.kernel_size,
                padding="same",
                use_bias=config.use_bias,
                initializer_seed=config.initializer_seed,
            )
        )

        if config.apply_sigmoid_at_last:
            layers.append(tf.keras.layers.Activation("sigmoid"))

        self.conv_layers = layers

    def _build(self):
        # Run a dummy forward pass so Keras builds all layer weights.
        dummy = tf.random.uniform(shape=[2, 16000, 1])
        self(dummy)

    def call(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input noise signal (B, T, 1).

        Returns:
            Tensor: Output tensor (B, T, 1)
        """
        for layer in self.conv_layers:
            x = layer(x)
        return x
557
558