Book a Demo!
CoCalc Logo Icon
Store Features Docs Share Support News About Policies Sign Up Sign In
TensorSpeech
GitHub Repository: TensorSpeech/TensorFlowTTS
Path: blob/master/test/test_parallel_wavegan.py
1558 views
1
# -*- coding: utf-8 -*-
2
# Copyright 2020 TensorFlowTTS Team.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
# http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
16
import logging
17
import os
18
19
import pytest
20
import tensorflow as tf
21
22
from tensorflow_tts.configs import (
23
ParallelWaveGANGeneratorConfig,
24
ParallelWaveGANDiscriminatorConfig,
25
)
26
from tensorflow_tts.models import (
27
TFParallelWaveGANGenerator,
28
TFParallelWaveGANDiscriminator,
29
)
30
31
# Hide every GPU from TensorFlow so these tests run on CPU only.
# NOTE(review): this executes after ``import tensorflow`` above; it presumably
# still takes effect because device initialization happens on first use, but
# setting it before the import would be more robust — confirm.
os.environ.update(CUDA_VISIBLE_DEVICES="")

# Verbose, timestamped log lines that include the emitting module and line.
logging.basicConfig(
    format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    level=logging.DEBUG,
)
37
38
39
def make_pwgan_generator_args(**kwargs):
    """Return keyword arguments for ``ParallelWaveGANGeneratorConfig``.

    A brand-new mapping of default generator hyper-parameters is built on
    every call (so callers may mutate the result, including the nested
    ``upsample_params``, without affecting later calls). Entries in
    ``kwargs`` replace the matching defaults; unknown keys are appended.

    Args:
        **kwargs: Hyper-parameters overriding or extending the defaults.

    Returns:
        dict: The merged hyper-parameter mapping.
    """
    base = {
        "out_channels": 1,
        "kernel_size": 3,
        "n_layers": 30,
        "stacks": 3,
        "residual_channels": 64,
        "gate_channels": 128,
        "skip_channels": 64,
        "aux_channels": 80,
        "aux_context_window": 2,
        "dropout_rate": 0.0,
        "use_bias": True,
        "use_causal_conv": False,
        "upsample_conditional_features": True,
        "upsample_params": {"upsample_scales": [4, 4, 4, 4]},
        "initializer_seed": 42,
    }
    # Later unpacking wins, and overridden keys keep their original position.
    return {**base, **kwargs}
59
60
61
def make_pwgan_discriminator_args(**kwargs):
    """Return keyword arguments for ``ParallelWaveGANDiscriminatorConfig``.

    The default discriminator hyper-parameters are rebuilt on every call
    (nested ``nonlinear_activation_params`` included), then overlaid with
    ``kwargs``: matching keys are replaced and new keys are appended.

    Args:
        **kwargs: Hyper-parameters overriding or extending the defaults.

    Returns:
        dict: The merged hyper-parameter mapping.
    """
    base = {
        "out_channels": 1,
        "kernel_size": 3,
        "n_layers": 10,
        "conv_channels": 64,
        "use_bias": True,
        "dilation_factor": 1,
        "nonlinear_activation": "LeakyReLU",
        "nonlinear_activation_params": {"alpha": 0.2},
        "initializer_seed": 42,
        "apply_sigmoid_at_last": False,
    }
    return dict(base, **kwargs)
76
77
78
@pytest.mark.parametrize(
    "dict_g, dict_d",
    [
        ({}, {}),
        (
            {"kernel_size": 3, "aux_context_window": 5, "residual_channels": 128},
            {"dilation_factor": 2},
        ),
        ({"stacks": 4, "n_layers": 40}, {"conv_channels": 128}),
    ],
)
def test_melgan_trainable(dict_g, dict_d):
    """Build a Parallel WaveGAN generator/discriminator pair and run them.

    Random conditioning features are fed through the generator (in training
    mode) and the generated audio through the discriminator.

    NOTE(review): despite its name this test exercises the Parallel WaveGAN
    models, not MelGAN; the name is kept so existing test selections keep
    working — consider renaming to ``test_parallel_wavegan_trainable``.

    Args:
        dict_g: Overrides applied on top of ``make_pwgan_generator_args``.
        dict_d: Overrides applied on top of ``make_pwgan_discriminator_args``.
    """
    batch_size, n_frames = 4, 32

    gen_kwargs = make_pwgan_generator_args(**dict_g)
    disc_kwargs = make_pwgan_discriminator_args(**dict_d)

    # Size the conditioning features from the generator config instead of a
    # hard-coded 80, so overriding ``aux_channels`` keeps the test consistent.
    random_c = tf.random.uniform(
        shape=[batch_size, n_frames, gen_kwargs["aux_channels"]], dtype=tf.float32
    )

    generator = TFParallelWaveGANGenerator(
        ParallelWaveGANGeneratorConfig(**gen_kwargs)
    )
    generator._build()
    discriminator = TFParallelWaveGANDiscriminator(
        ParallelWaveGANDiscriminatorConfig(**disc_kwargs)
    )
    discriminator._build()

    generated_audios = generator(random_c, training=True)
    disc_outputs = discriminator(generated_audios)

    generator.summary()
    discriminator.summary()

    # Go beyond "does not crash": the batch dimension must survive both
    # models. The channel check assumes channels-last output (B, T, C)
    # — NOTE(review): confirm against the model's documented output shape.
    assert generated_audios.shape[0] == batch_size
    assert generated_audios.shape[-1] == gen_kwargs["out_channels"]
    assert disc_outputs.shape[0] == batch_size
108
109