# Path: blob/master/test/test_parallel_wavegan.py
# 1558 views
# -*- coding: utf-8 -*-
# Copyright 2020 TensorFlowTTS Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os

import pytest
import tensorflow as tf

from tensorflow_tts.configs import (
    ParallelWaveGANGeneratorConfig,
    ParallelWaveGANDiscriminatorConfig,
)
from tensorflow_tts.models import (
    TFParallelWaveGANGenerator,
    TFParallelWaveGANDiscriminator,
)

# An empty CUDA_VISIBLE_DEVICES is presumably meant to keep this test on CPU.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)


def make_pwgan_generator_args(**kwargs):
    """Return keyword arguments for ``ParallelWaveGANGeneratorConfig``.

    Args:
        **kwargs: Overrides applied on top of the baseline values below;
            keys not present in the baseline are added as-is.

    Returns:
        dict: A fresh dictionary (built on every call) holding the baseline
        generator settings updated with ``kwargs``.
    """
    defaults = dict(
        out_channels=1,
        kernel_size=3,
        n_layers=30,
        stacks=3,
        residual_channels=64,
        gate_channels=128,
        skip_channels=64,
        aux_channels=80,
        aux_context_window=2,
        dropout_rate=0.0,
        use_bias=True,
        use_causal_conv=False,
        upsample_conditional_features=True,
        upsample_params={"upsample_scales": [4, 4, 4, 4]},
        initializer_seed=42,
    )
    defaults.update(kwargs)
    return defaults


def make_pwgan_discriminator_args(**kwargs):
    """Return keyword arguments for ``ParallelWaveGANDiscriminatorConfig``.

    Args:
        **kwargs: Overrides applied on top of the baseline values below;
            keys not present in the baseline are added as-is.

    Returns:
        dict: A fresh dictionary (built on every call) holding the baseline
        discriminator settings updated with ``kwargs``.
    """
    defaults = dict(
        out_channels=1,
        kernel_size=3,
        n_layers=10,
        conv_channels=64,
        use_bias=True,
        dilation_factor=1,
        nonlinear_activation="LeakyReLU",
        nonlinear_activation_params={"alpha": 0.2},
        initializer_seed=42,
        apply_sigmoid_at_last=False,
    )
    defaults.update(kwargs)
    return defaults


@pytest.mark.parametrize(
    "dict_g, dict_d",
    [
        ({}, {}),
        (
            {"kernel_size": 3, "aux_context_window": 5, "residual_channels": 128},
            {"dilation_factor": 2},
        ),
        ({"stacks": 4, "n_layers": 40}, {"conv_channels": 128}),
    ],
)
def test_melgan_trainable(dict_g, dict_d):
    """Smoke-test a generator -> discriminator forward pass.

    Builds both Parallel WaveGAN models from the baseline arguments updated
    with the parametrized overrides, runs the generator on random input,
    feeds its output to the discriminator and prints both model summaries.
    The test makes no assertions; it passes when nothing raises.

    NOTE(review): the name says "melgan" although the models under test are
    Parallel WaveGAN -- presumably a copy-paste leftover. The name is kept
    so existing test selection by name keeps working.

    Args:
        dict_g: Overrides for the generator arguments.
        dict_d: Overrides for the discriminator arguments.
    """
    # Last dimension (80) matches the baseline ``aux_channels`` value.
    random_c = tf.random.uniform(shape=[4, 32, 80], dtype=tf.float32)

    args_g = make_pwgan_generator_args(**dict_g)
    args_d = make_pwgan_discriminator_args(**dict_d)

    args_g = ParallelWaveGANGeneratorConfig(**args_g)
    args_d = ParallelWaveGANDiscriminatorConfig(**args_d)

    generator = TFParallelWaveGANGenerator(args_g)
    generator._build()
    discriminator = TFParallelWaveGANDiscriminator(args_d)
    discriminator._build()

    generated_audios = generator(random_c, training=True)
    discriminator(generated_audios)

    generator.summary()
    discriminator.summary()