GitHub Repository: TensorSpeech/TensorFlowTTS
Path: blob/master/tensorflow_tts/configs/fastspeech.py
# -*- coding: utf-8 -*-
# Copyright 2020 Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""FastSpeech Config object."""

import collections

from tensorflow_tts.configs import BaseConfig
from tensorflow_tts.processor.ljspeech import LJSPEECH_SYMBOLS as lj_symbols
from tensorflow_tts.processor.kss import KSS_SYMBOLS as kss_symbols
from tensorflow_tts.processor.baker import BAKER_SYMBOLS as bk_symbols
from tensorflow_tts.processor.libritts import LIBRITTS_SYMBOLS as lbri_symbols
from tensorflow_tts.processor.jsut import JSUT_SYMBOLS as jsut_symbols


SelfAttentionParams = collections.namedtuple(
    "SelfAttentionParams",
    [
        "n_speakers",
        "hidden_size",
        "num_hidden_layers",
        "num_attention_heads",
        "attention_head_size",
        "intermediate_size",
        "intermediate_kernel_size",
        "hidden_act",
        "output_attentions",
        "output_hidden_states",
        "initializer_range",
        "hidden_dropout_prob",
        "attention_probs_dropout_prob",
        "layer_norm_eps",
        "max_position_embeddings",
    ],
)


class FastSpeechConfig(BaseConfig):
    """Initialize FastSpeech Config."""

    def __init__(
        self,
        dataset="ljspeech",
        vocab_size=len(lj_symbols),
        n_speakers=1,
        encoder_hidden_size=384,
        encoder_num_hidden_layers=4,
        encoder_num_attention_heads=2,
        encoder_attention_head_size=192,
        encoder_intermediate_size=1024,
        encoder_intermediate_kernel_size=3,
        encoder_hidden_act="mish",
        decoder_hidden_size=384,
        decoder_num_hidden_layers=4,
        decoder_num_attention_heads=2,
        decoder_attention_head_size=192,
        decoder_intermediate_size=1024,
        decoder_intermediate_kernel_size=3,
        decoder_hidden_act="mish",
        output_attentions=True,
        output_hidden_states=True,
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        initializer_range=0.02,
        layer_norm_eps=1e-5,
        max_position_embeddings=2048,
        num_duration_conv_layers=2,
        duration_predictor_filters=256,
        duration_predictor_kernel_sizes=3,
        num_mels=80,
        duration_predictor_dropout_probs=0.1,
        n_conv_postnet=5,
        postnet_conv_filters=512,
        postnet_conv_kernel_sizes=5,
        postnet_dropout_rate=0.1,
        **kwargs
    ):
        """Init parameters for FastSpeech model."""
        # resolve vocab_size from the processor symbol table of the chosen dataset
        if dataset == "ljspeech":
            self.vocab_size = vocab_size
        elif dataset == "kss":
            self.vocab_size = len(kss_symbols)
        elif dataset == "baker":
            self.vocab_size = len(bk_symbols)
        elif dataset == "libritts":
            self.vocab_size = len(lbri_symbols)
        elif dataset == "jsut":
            self.vocab_size = len(jsut_symbols)
        else:
            raise ValueError("No such dataset: {}".format(dataset))
        self.initializer_range = initializer_range
        self.max_position_embeddings = max_position_embeddings
        self.n_speakers = n_speakers
        self.layer_norm_eps = layer_norm_eps

        # encoder params
        self.encoder_self_attention_params = SelfAttentionParams(
            n_speakers=n_speakers,
            hidden_size=encoder_hidden_size,
            num_hidden_layers=encoder_num_hidden_layers,
            num_attention_heads=encoder_num_attention_heads,
            attention_head_size=encoder_attention_head_size,
            hidden_act=encoder_hidden_act,
            intermediate_size=encoder_intermediate_size,
            intermediate_kernel_size=encoder_intermediate_kernel_size,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            initializer_range=initializer_range,
            hidden_dropout_prob=hidden_dropout_prob,
            attention_probs_dropout_prob=attention_probs_dropout_prob,
            layer_norm_eps=layer_norm_eps,
            max_position_embeddings=max_position_embeddings,
        )

        # decoder params
        self.decoder_self_attention_params = SelfAttentionParams(
            n_speakers=n_speakers,
            hidden_size=decoder_hidden_size,
            num_hidden_layers=decoder_num_hidden_layers,
            num_attention_heads=decoder_num_attention_heads,
            attention_head_size=decoder_attention_head_size,
            hidden_act=decoder_hidden_act,
            intermediate_size=decoder_intermediate_size,
            intermediate_kernel_size=decoder_intermediate_kernel_size,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            initializer_range=initializer_range,
            hidden_dropout_prob=hidden_dropout_prob,
            attention_probs_dropout_prob=attention_probs_dropout_prob,
            layer_norm_eps=layer_norm_eps,
            max_position_embeddings=max_position_embeddings,
        )

        # duration predictor params
        self.duration_predictor_dropout_probs = duration_predictor_dropout_probs
        self.num_duration_conv_layers = num_duration_conv_layers
        self.duration_predictor_filters = duration_predictor_filters
        self.duration_predictor_kernel_sizes = duration_predictor_kernel_sizes
        self.num_mels = num_mels

        # postnet
        self.n_conv_postnet = n_conv_postnet
        self.postnet_conv_filters = postnet_conv_filters
        self.postnet_conv_kernel_sizes = postnet_conv_kernel_sizes
        self.postnet_dropout_rate = postnet_dropout_rate
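The config keeps separate SelfAttentionParams namedtuples for the encoder and decoder, so the two Transformer stacks can be sized independently while sharing dropout, initializer, and layer-norm settings. Below is a minimal usage sketch, not part of the original file, assuming the tensorflow_tts package and its processor symbol tables are importable; the overridden values (6 layers, 4 heads) are illustrative only.

    from tensorflow_tts.configs.fastspeech import FastSpeechConfig

    # Default LJSpeech config: vocab_size falls back to len(LJSPEECH_SYMBOLS).
    config = FastSpeechConfig()
    print(config.vocab_size, config.num_mels)

    # Korean KSS config: vocab_size is resolved from KSS_SYMBOLS, and any of
    # the keyword arguments above can be overridden at construction time.
    kss_config = FastSpeechConfig(
        dataset="kss",
        encoder_num_hidden_layers=6,
        encoder_num_attention_heads=4,
    )

    # Encoder/decoder hyperparameters are exposed as SelfAttentionParams
    # namedtuples, ready to be passed to the corresponding model stacks.
    enc = kss_config.encoder_self_attention_params
    print(enc.hidden_size, enc.num_hidden_layers)  # 384 6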