Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
TensorSpeech
GitHub Repository: TensorSpeech/TensorFlowTTS
Path: blob/master/examples/fastspeech2/decode_fastspeech2.py
1558 views
1
# -*- coding: utf-8 -*-
2
# Copyright 2020 Minh Nguyen (@dathudeptrai)
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
# http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
"""Decode trained FastSpeech from folders."""
16
17
import argparse
18
import logging
19
import os
20
import sys
21
22
sys.path.append(".")
23
24
import numpy as np
25
import tensorflow as tf
26
import yaml
27
from tqdm import tqdm
28
29
from examples.fastspeech.fastspeech_dataset import CharactorDataset
30
from tensorflow_tts.configs import FastSpeech2Config
31
from tensorflow_tts.models import TFFastSpeech2
32
33
34
def main():
    """Run FastSpeech2 decoding from a folder of charactor-id files.

    Parses command-line arguments, loads the YAML config and model
    checkpoint, then decodes mel-spectrogram features (before and after
    the postnet) for every ``*-ids.npy`` file under ``--rootdir``,
    saving the trimmed features as ``.npy`` files into ``--outdir``.

    Raises:
        ValueError: if the config's ``format`` field is not ``"npy"``.
    """
    parser = argparse.ArgumentParser(
        description="Decode soft-mel features from charactor with trained FastSpeech "
        "(See detail in examples/fastspeech2/decode_fastspeech2.py)."
    )
    parser.add_argument(
        "--rootdir",
        default=None,
        type=str,
        required=True,
        help="directory including ids/durations files.",
    )
    parser.add_argument(
        "--outdir", type=str, required=True, help="directory to save generated speech."
    )
    parser.add_argument(
        "--checkpoint", type=str, required=True, help="checkpoint file to be loaded."
    )
    parser.add_argument(
        "--config",
        default=None,
        type=str,
        required=True,
        help="yaml format configuration file. if not explicitly provided, "
        "it will be searched in the checkpoint directory. (default=None)",
    )
    parser.add_argument(
        "--batch-size",
        default=8,
        type=int,
        required=False,
        help="Batch size for inference.",
    )
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)",
    )
    args = parser.parse_args()

    # set logger: map the verbosity count to a level once instead of
    # duplicating basicConfig across three branches with the same format.
    if args.verbose > 1:
        level = logging.DEBUG
    elif args.verbose > 0:
        level = logging.INFO
    else:
        level = logging.WARN
    logging.basicConfig(
        level=level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    if level == logging.WARN:
        logging.warning("Skip DEBUG/INFO messages")

    # create output directory; exist_ok avoids the check-then-create race
    # of the former os.path.exists() guard.
    os.makedirs(args.outdir, exist_ok=True)

    # load config; safe_load is sufficient for a plain-YAML config file and
    # avoids arbitrary Python object construction (yaml.Loader is unsafe).
    with open(args.config) as f:
        config = yaml.safe_load(f)
    config.update(vars(args))

    if config["format"] == "npy":
        char_query = "*-ids.npy"
        char_load_fn = np.load
    else:
        raise ValueError("Only npy is supported.")

    # define data-loader over the charactor-id files.
    dataset = CharactorDataset(
        root_dir=args.rootdir,
        charactor_query=char_query,
        charactor_load_fn=char_load_fn,
    )
    dataset = dataset.create(batch_size=args.batch_size)

    # define model and load checkpoint.
    fastspeech2 = TFFastSpeech2(
        config=FastSpeech2Config(**config["fastspeech2_params"]), name="fastspeech2"
    )
    fastspeech2._build()  # NOTE: project-private builder; needed before load_weights.
    fastspeech2.load_weights(args.checkpoint)

    for data in tqdm(dataset, desc="Decoding"):
        utt_ids = data["utt_ids"]
        char_ids = data["input_ids"]

        # fastspeech inference. Hoist the batch dimension once rather than
        # recomputing tf.shape(char_ids)[0] for each of the four ratio tensors.
        batch = tf.shape(char_ids)[0]
        (
            masked_mel_before,
            masked_mel_after,
            duration_outputs,
            _,
            _,
        ) = fastspeech2.inference(
            char_ids,
            speaker_ids=tf.zeros(shape=[batch], dtype=tf.int32),
            speed_ratios=tf.ones(shape=[batch], dtype=tf.float32),
            f0_ratios=tf.ones(shape=[batch], dtype=tf.float32),
            energy_ratios=tf.ones(shape=[batch], dtype=tf.float32),
        )

        # convert to numpy for per-utterance slicing/saving.
        masked_mel_befores = masked_mel_before.numpy()
        masked_mel_afters = masked_mel_after.numpy()

        for utt_id, mel_before, mel_after, durations in zip(
            utt_ids, masked_mel_befores, masked_mel_afters, duration_outputs
        ):
            # real length of the predicted mel: sum of per-charactor durations,
            # used to trim padding frames before saving.
            real_length = durations.numpy().sum()
            utt_id = utt_id.numpy().decode("utf-8")
            # save before/after-postnet features to the output folder.
            np.save(
                os.path.join(args.outdir, f"{utt_id}-fs-before-feats.npy"),
                mel_before[:real_length, :].astype(np.float32),
                allow_pickle=False,
            )
            np.save(
                os.path.join(args.outdir, f"{utt_id}-fs-after-feats.npy"),
                mel_after[:real_length, :].astype(np.float32),
                allow_pickle=False,
            )
166
# Script entry point: run decoding only when executed directly, not on import.
if __name__ == "__main__":
    main()