Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
TensorSpeech
GitHub Repository: TensorSpeech/TensorFlowTTS
Path: blob/master/examples/fastspeech/decode_fastspeech.py
1558 views
1
# -*- coding: utf-8 -*-
2
# Copyright 2020 Minh Nguyen (@dathudeptrai)
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
# http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
"""Decode trained FastSpeech from folders."""
16
17
import argparse
18
import logging
19
import os
20
import sys
21
22
sys.path.append(".")
23
24
import numpy as np
25
import tensorflow as tf
26
import yaml
27
from tqdm import tqdm
28
29
from examples.fastspeech.fastspeech_dataset import CharactorDataset
30
from tensorflow_tts.configs import FastSpeechConfig
31
from tensorflow_tts.models import TFFastSpeech
32
33
34
def main():
    """Run fastspeech decoding from folder.

    Parses CLI arguments, loads a YAML config plus a trained FastSpeech
    checkpoint, then decodes mel-spectrogram features (before and after
    the postnet) for every charactor-ids file under ``--rootdir`` and
    saves them as ``.npy`` files in ``--outdir``.

    Raises:
        ValueError: If the configured feature format is not "npy".
    """
    parser = argparse.ArgumentParser(
        description="Decode soft-mel features from charactor with trained FastSpeech "
        "(See detail in examples/fastspeech/decode_fastspeech.py)."
    )
    parser.add_argument(
        "--rootdir",
        default=None,
        type=str,
        required=True,
        help="directory including ids/durations files.",
    )
    parser.add_argument(
        "--outdir", type=str, required=True, help="directory to save generated speech."
    )
    parser.add_argument(
        "--checkpoint", type=str, required=True, help="checkpoint file to be loaded."
    )
    parser.add_argument(
        "--config",
        default=None,
        type=str,
        required=True,
        help="yaml format configuration file. if not explicitly provided, "
        "it will be searched in the checkpoint directory. (default=None)",
    )
    parser.add_argument(
        "--batch-size",
        default=8,
        type=int,
        required=False,
        help="Batch size for inference.",
    )
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)",
    )
    args = parser.parse_args()

    # set logger: map verbosity to a level once instead of repeating the
    # identical basicConfig call in every branch.
    log_format = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if args.verbose > 1:
        level = logging.DEBUG
    elif args.verbose > 0:
        level = logging.INFO
    else:
        level = logging.WARN
    logging.basicConfig(level=level, format=log_format)
    if args.verbose <= 0:
        logging.warning("Skip DEBUG/INFO messages")

    # ensure output directory exists; exist_ok avoids the race between a
    # separate os.path.exists() check and os.makedirs().
    os.makedirs(args.outdir, exist_ok=True)

    # load config and merge CLI arguments on top (CLI values win).
    # NOTE(review): yaml.Loader allows arbitrary object construction —
    # fine for trusted project configs, but yaml.safe_load is preferable
    # if the config ever comes from an untrusted source.
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))

    if config["format"] == "npy":
        char_query = "*-ids.npy"
        char_load_fn = np.load
    else:
        raise ValueError("Only npy is supported.")

    # define data-loader over charactor-id files.
    dataset = CharactorDataset(
        root_dir=args.rootdir,
        charactor_query=char_query,
        charactor_load_fn=char_load_fn,
    )
    dataset = dataset.create(batch_size=args.batch_size)

    # define model and load checkpoint weights.
    fastspeech = TFFastSpeech(
        config=FastSpeechConfig(**config["fastspeech_params"]), name="fastspeech"
    )
    fastspeech._build()
    fastspeech.load_weights(args.checkpoint)

    for data in tqdm(dataset, desc="Decoding"):
        utt_ids = data["utt_ids"]
        char_ids = data["input_ids"]

        # fastspeech inference: single speaker (id 0), normal speed ratio.
        batch_dim = tf.shape(char_ids)[0]
        masked_mel_before, masked_mel_after, duration_outputs = fastspeech.inference(
            char_ids,
            speaker_ids=tf.zeros(shape=[batch_dim], dtype=tf.int32),
            speed_ratios=tf.ones(shape=[batch_dim], dtype=tf.float32),
        )

        # convert batched mels to numpy for per-utterance slicing/saving.
        masked_mel_befores = masked_mel_before.numpy()
        masked_mel_afters = masked_mel_after.numpy()

        for utt_id, mel_before, mel_after, durations in zip(
            utt_ids, masked_mel_befores, masked_mel_afters, duration_outputs
        ):
            # real (unpadded) length of the predicted mel is the sum of
            # the per-charactor predicted durations.
            real_length = durations.numpy().sum()
            utt_id = utt_id.numpy().decode("utf-8")
            # save both pre- and post-postnet mels, trimmed to real length.
            np.save(
                os.path.join(args.outdir, f"{utt_id}-fs-before-feats.npy"),
                mel_before[:real_length, :].astype(np.float32),
                allow_pickle=False,
            )
            np.save(
                os.path.join(args.outdir, f"{utt_id}-fs-after-feats.npy"),
                mel_after[:real_length, :].astype(np.float32),
                allow_pickle=False,
            )
156
157
158
# Script entry point: run decoding only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
160
161