Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
TensorSpeech
GitHub Repository: TensorSpeech/TensorFlowTTS
Path: blob/master/tensorflow_tts/configs/tacotron2.py
1558 views
1
# -*- coding: utf-8 -*-
2
# Copyright 2020 Minh Nguyen (@dathudeptrai)
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
# http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
"""Tacotron-2 Config object."""
16
17
18
from tensorflow_tts.configs import BaseConfig
19
from tensorflow_tts.processor.jsut import JSUT_SYMBOLS
20
from tensorflow_tts.processor.ljspeech import LJSPEECH_SYMBOLS as lj_symbols
21
from tensorflow_tts.processor.kss import KSS_SYMBOLS as kss_symbols
22
from tensorflow_tts.processor.baker import BAKER_SYMBOLS as bk_symbols
23
from tensorflow_tts.processor.libritts import LIBRITTS_SYMBOLS as lbri_symbols
24
from tensorflow_tts.processor.ljspeechu import LJSPEECH_U_SYMBOLS as lju_symbols
25
from tensorflow_tts.processor.synpaflex import SYNPAFLEX_SYMBOLS as synpaflex_symbols
26
from tensorflow_tts.processor.jsut import JSUT_SYMBOLS as jsut_symbols
27
28
29
class Tacotron2Config(BaseConfig):
    """Hyperparameter container for the Tacotron-2 model.

    Every constructor argument except ``dataset`` ends up as an instance
    attribute of the same name. ``vocab_size`` is special: its final value
    is resolved from ``dataset`` (see ``__init__``).
    """

    def __init__(
        self,
        dataset="ljspeech",
        vocab_size=len(lj_symbols),
        embedding_hidden_size=512,
        initializer_range=0.02,
        layer_norm_eps=1e-6,
        embedding_dropout_prob=0.1,
        n_speakers=5,
        n_conv_encoder=3,
        encoder_conv_filters=512,
        encoder_conv_kernel_sizes=5,
        encoder_conv_activation="mish",
        encoder_conv_dropout_rate=0.5,
        encoder_lstm_units=256,
        reduction_factor=5,
        n_prenet_layers=2,
        prenet_units=256,
        prenet_activation="mish",
        prenet_dropout_rate=0.5,
        n_lstm_decoder=1,
        decoder_lstm_units=1024,
        attention_type="lsa",
        attention_dim=128,
        attention_filters=32,
        attention_kernel=31,
        n_mels=80,
        n_conv_postnet=5,
        postnet_conv_filters=512,
        postnet_conv_kernel_sizes=5,
        postnet_dropout_rate=0.1,
    ):
        """Store the Tacotron-2 hyperparameters on the instance.

        Args:
            dataset: Name of the dataset whose symbol table determines
                ``self.vocab_size``. Accepted values are ``"ljspeech"``,
                ``"kss"``, ``"baker"``, ``"libritts"``, ``"ljspeechu"``,
                ``"synpaflex"`` and ``"jsut"``. It is not itself stored.
            vocab_size: Honored only when ``dataset == "ljspeech"``. Its
                default, ``len(lj_symbols)``, is evaluated once when the
                class body runs. For any other dataset this argument is
                ignored and replaced by the length of that dataset's
                symbol table.
            embedding_hidden_size ... postnet_dropout_rate: Copied verbatim
                onto same-named attributes. How they are interpreted is up
                to the model code that reads this config; that code is not
                shown here.

        Raises:
            ValueError: If ``dataset`` is not one of the accepted names.
                This is raised before any attribute is assigned.
        """
        # Only LJSpeech lets the caller override the vocabulary size; every
        # other dataset is pinned to the size of its own symbol table.
        if dataset == "ljspeech":
            resolved_vocab_size = vocab_size
        else:
            symbol_tables = (
                ("kss", kss_symbols),
                ("baker", bk_symbols),
                ("libritts", lbri_symbols),
                ("ljspeechu", lju_symbols),
                ("synpaflex", synpaflex_symbols),
                ("jsut", jsut_symbols),
            )
            # Ordered scan with ``==`` (not a dict lookup) so matching
            # behaves exactly like a plain if/elif chain on ``dataset``.
            symbols = next(
                (table for name, table in symbol_tables if dataset == name),
                None,
            )
            if symbols is None:
                raise ValueError("No such dataset: {}".format(dataset))
            resolved_vocab_size = len(symbols)
        self.vocab_size = resolved_vocab_size

        # Embedding, encoder and general parameters.
        self.embedding_hidden_size = embedding_hidden_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.embedding_dropout_prob = embedding_dropout_prob
        self.n_speakers = n_speakers
        self.n_conv_encoder = n_conv_encoder
        self.encoder_conv_filters = encoder_conv_filters
        self.encoder_conv_kernel_sizes = encoder_conv_kernel_sizes
        self.encoder_conv_activation = encoder_conv_activation
        self.encoder_conv_dropout_rate = encoder_conv_dropout_rate
        self.encoder_lstm_units = encoder_lstm_units

        # Decoder parameters.
        self.reduction_factor = reduction_factor
        self.n_prenet_layers = n_prenet_layers
        self.prenet_units = prenet_units
        self.prenet_activation = prenet_activation
        self.prenet_dropout_rate = prenet_dropout_rate
        self.n_lstm_decoder = n_lstm_decoder
        self.decoder_lstm_units = decoder_lstm_units
        self.attention_type = attention_type
        self.attention_dim = attention_dim
        self.attention_filters = attention_filters
        self.attention_kernel = attention_kernel
        self.n_mels = n_mels

        # Postnet parameters.
        self.n_conv_postnet = n_conv_postnet
        self.postnet_conv_filters = postnet_conv_filters
        self.postnet_conv_kernel_sizes = postnet_conv_kernel_sizes
        self.postnet_dropout_rate = postnet_dropout_rate
112
113