Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
TensorSpeech
GitHub Repository: TensorSpeech/TensorFlowTTS
Path: blob/master/test/test_mb_melgan.py
1558 views
1
# -*- coding: utf-8 -*-
2
# Copyright 2020 Minh Nguyen (@dathudeptrai)
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
# http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
16
import tensorflow as tf
17
18
import logging
19
import os
20
21
import numpy as np
22
import pytest
23
24
from tensorflow_tts.configs import MultiBandMelGANGeneratorConfig
25
from tensorflow_tts.models import TFPQMF, TFMelGANGenerator
26
27
# Expose no CUDA devices to the process so the test exercises the CPU path.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

# Verbose root-logger configuration: every record is tagged with its
# timestamp, originating module and line number.
_LOG_FORMAT = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.DEBUG)
33
34
35
def make_multi_band_melgan_generator_args(**kwargs):
    """Build keyword arguments for ``MultiBandMelGANGeneratorConfig``.

    Starts from a fixed set of default generator / PQMF settings and applies
    ``kwargs`` on top, so each test only spells out the values it varies.

    Args:
        **kwargs: Overrides for (or additions to) the default settings.

    Returns:
        dict: A freshly built dict on every call (including a fresh
        ``upsample_scales`` list), so callers may mutate it freely.
    """
    defaults = dict(
        out_channels=1,
        kernel_size=7,
        filters=512,
        use_bias=True,
        upsample_scales=[8, 8, 2, 2],
        stack_kernel_size=3,
        stacks=3,
        nonlinear_activation="LeakyReLU",
        nonlinear_activation_params={"alpha": 0.2},
        padding_type="REFLECT",
        subbands=4,
        # Presumably the PQMF filter tap count. The key is ``taps``: the
        # previous ``tabs`` spelling would not be read by the config and so
        # was silently ignored.
        # NOTE(review): confirm against MultiBandMelGANGeneratorConfig.
        taps=62,
        cutoff_ratio=0.15,
        beta=9.0,
    )
    defaults.update(kwargs)
    return defaults
54
55
56
@pytest.mark.parametrize(
    "dict_g",
    [
        {"subbands": 4, "upsample_scales": [2, 4, 8], "stacks": 4, "out_channels": 4},
        {"subbands": 4, "upsample_scales": [4, 4, 4], "stacks": 5, "out_channels": 4},
    ],
)
def test_multi_band_melgan(dict_g):
    """Check that the MB-MelGAN generator and PQMF agree on tensor shapes.

    The generator's output is run through ``pqmf.synthesis`` and must match
    the shape of a waveform of the expected length. That waveform is run
    through ``pqmf.analysis`` and must match the shape of the generator's
    output.
    """
    n_frames = 100
    n_mels = 80

    g_kwargs = make_multi_band_melgan_generator_args(**dict_g)
    config = MultiBandMelGANGeneratorConfig(**g_kwargs)
    generator = TFMelGANGenerator(config, name="multi_band_melgan")
    generator._build()

    pqmf = TFPQMF(config, name="pqmf")

    # Waveform samples per mel frame. Assumes the generator upsamples each
    # frame by prod(upsample_scales) per sub-band and PQMF synthesis
    # multiplies the length by the sub-band count. This matches the
    # previously hard-coded 256 (= 64 * 4) for both parametrizations, but it
    # stays correct if the parametrization changes.
    hop_size = int(np.prod(g_kwargs["upsample_scales"])) * g_kwargs["subbands"]

    fake_mels = tf.random.uniform(shape=[1, n_frames, n_mels], dtype=tf.float32)
    fake_y = tf.random.uniform(shape=[1, n_frames * hop_size, 1], dtype=tf.float32)

    y_hat_subbands = generator(fake_mels)
    y_hat = pqmf.synthesis(y_hat_subbands)
    y_subbands = pqmf.analysis(fake_y)

    assert np.shape(y_subbands) == np.shape(y_hat_subbands)
    assert np.shape(fake_y) == np.shape(y_hat)
80
81