# Source: TensorSpeech/TensorFlowTTS — test/test_tacotron2.py (master branch).
# -*- coding: utf-8 -*-
# Copyright 2020 Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import time

import numpy as np
import pytest
import tensorflow as tf
import yaml

from tensorflow_tts.configs import Tacotron2Config
from tensorflow_tts.models import TFTacotron2
from tensorflow_tts.utils import return_strategy

from examples.tacotron2.train_tacotron2 import Tacotron2Trainer

# Hide every GPU so these tests always run on CPU.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

logging.basicConfig(
    level=logging.WARNING,
    format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
@pytest.mark.parametrize(
    "var_train_expr, config_path",
    [
        ("embeddings|decoder_cell", "./examples/tacotron2/conf/tacotron2.v1.yaml"),
        (None, "./examples/tacotron2/conf/tacotron2.v1.yaml"),
        (
            "embeddings|decoder_cell",
            "./examples/tacotron2/conf/tacotron2.baker.v1.yaml",
        ),
        ("embeddings|decoder_cell", "./examples/tacotron2/conf/tacotron2.kss.v1.yaml"),
    ],
)
def test_tacotron2_train_some_layers(var_train_expr, config_path):
    """Verify that ``var_train_expr`` restricts which variables the trainer trains.

    With ``var_train_expr is None`` every model variable must be trainable;
    with a regex like ``"embeddings|decoder_cell"`` the trainer must select a
    strict subset of the model's trainable variables.
    """
    config = Tacotron2Config(n_speakers=5, reduction_factor=1)
    model = TFTacotron2(config, name="tacotron2")
    model._build()
    # `lr=` is deprecated and removed in newer Keras optimizers; use
    # `learning_rate=` (same value, same behavior).
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

    with open(config_path) as f:
        # NOTE(review): yaml.Loader can construct arbitrary Python objects.
        # Acceptable here because the config files are part of this repo,
        # but yaml.safe_load would be preferable.
        config = yaml.load(f, Loader=yaml.Loader)

    config.update({"outdir": "./"})
    config.update({"var_train_expr": var_train_expr})

    STRATEGY = return_strategy()

    # steps=0 / epochs=0: we only need trainer.compile(), never a real run.
    trainer = Tacotron2Trainer(
        config=config, strategy=STRATEGY, steps=0, epochs=0, is_mixed_precision=False,
    )
    trainer.compile(model, optimizer)

    len_trainable_vars = len(trainer._trainable_variables)
    all_trainable_vars = len(model.trainable_variables)

    if var_train_expr is None:
        # No filter: everything is trainable.
        tf.debugging.assert_equal(len_trainable_vars, all_trainable_vars)
    else:
        # Filter applied: strictly fewer variables selected.
        tf.debugging.assert_less(len_trainable_vars, all_trainable_vars)
@pytest.mark.parametrize(
    "n_speakers, n_chars, max_input_length, max_mel_length, batch_size",
    [(2, 15, 25, 50, 2),],
)
def test_tacotron2_trainable(
    n_speakers, n_chars, max_input_length, max_mel_length, batch_size
):
    """Smoke-test training: run two optimizer steps on random data.

    The first step traces the tf.function; only the second (re-traced) step
    is timed.
    """
    config = Tacotron2Config(n_speakers=n_speakers, reduction_factor=1)
    model = TFTacotron2(config, name="tacotron2")
    model._build()

    # Fake inputs. 80 is the mel-bin count used throughout this repo
    # (presumably matches config.n_mels — TODO confirm).
    input_ids = tf.random.uniform(
        [batch_size, max_input_length], maxval=n_chars, dtype=tf.int32
    )
    # BUG FIX: the original hard-coded [max_input_length, max_input_length],
    # which is only correct when batch_size == 2.
    input_lengths = tf.constant([max_input_length] * batch_size, tf.int32)
    speaker_ids = tf.convert_to_tensor([0] * batch_size, tf.int32)
    mel_gts = tf.random.uniform(shape=[batch_size, max_mel_length, 80])
    mel_lengths = np.random.randint(
        max_mel_length, high=max_mel_length + 1, size=[batch_size]
    )
    # Guarantee at least one sample spans the full mel length.
    mel_lengths[-1] = max_mel_length
    mel_lengths = tf.convert_to_tensor(mel_lengths, dtype=tf.int32)

    # `lr=` is deprecated and removed in newer Keras optimizers; use
    # `learning_rate=` (same value, same behavior).
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

    binary_crossentropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

    @tf.function(experimental_relax_shapes=True)
    def one_step_training(input_ids, speaker_ids, mel_gts, mel_lengths):
        """Run one gradient step; return (total loss, alignment history)."""
        with tf.GradientTape() as tape:
            mel_preds, post_mel_preds, stop_preds, alignment_history = model(
                input_ids,
                input_lengths,
                speaker_ids,
                mel_gts,
                mel_lengths,
                training=True,
            )
            loss_before = tf.keras.losses.MeanSquaredError()(mel_gts, mel_preds)
            loss_after = tf.keras.losses.MeanSquaredError()(mel_gts, post_mel_preds)

            # Build stop-token targets: 1.0 from the last valid frame onward.
            stop_gts = tf.expand_dims(
                tf.range(tf.reduce_max(mel_lengths), dtype=tf.int32), 0
            )  # [1, max_len]
            stop_gts = tf.tile(stop_gts, [tf.shape(mel_lengths)[0], 1])  # [B, max_len]
            stop_gts = tf.cast(
                tf.math.greater_equal(stop_gts, tf.expand_dims(mel_lengths, 1) - 1),
                tf.float32,
            )

            # Stop-token loss plus pre-net and post-net mel reconstruction losses.
            stop_token_loss = binary_crossentropy(stop_gts, stop_preds)
            loss = stop_token_loss + loss_before + loss_after

        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        return loss, alignment_history

    start = None
    for step in range(2):
        if step == 1:
            # Start timing after the first call so tf.function tracing cost
            # is excluded.
            start = time.time()
        loss, alignment_history = one_step_training(
            input_ids, speaker_ids, mel_gts, mel_lengths
        )
        print(f" > loss: {loss}")
    total_runtime = time.time() - start
    print(f" > Total run-time: {total_runtime}")
    # Exactly one step is timed (the original divided by 10 — a leftover
    # constant), so the average equals the total.
    print(f" > Avg run-time: {total_runtime}")