Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
TensorSpeech
GitHub Repository: TensorSpeech/TensorFlowTTS
Path: blob/master/examples/fastspeech2/extractfs_postnets.py
1558 views
1
# -*- coding: utf-8 -*-
2
# Copyright 2020 Minh Nguyen (@dathudeptrai)
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
# http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
"""Decode trained FastSpeech from folders."""
16
17
import argparse
18
import logging
19
import os
20
import sys
21
22
sys.path.append(".")
23
24
import numpy as np
25
import tensorflow as tf
26
import yaml
27
from tqdm import tqdm
28
29
from examples.fastspeech2.fastspeech2_dataset import CharactorDurationF0EnergyMelDataset
30
from tensorflow_tts.configs import FastSpeech2Config
31
from tensorflow_tts.models import TFFastSpeech2
32
33
34
def main():
    """Run FastSpeech2 decoding from a folder and dump post-net mel features.

    Command-line entry point. Loads a trained TFFastSpeech2 checkpoint,
    iterates the charactor/duration/f0/energy/mel dataset under ``--rootdir``
    (batch size forced to 1), runs teacher-forced inference, and saves each
    utterance's post-net mel spectrogram (cropped to its ground-truth mel
    length) as ``<outdir>/postnets/<utt_id>-postnet.npy``.

    Raises:
        ValueError: if the configured feature ``format`` is not "npy".
    """
    parser = argparse.ArgumentParser(
        description="Decode soft-mel features from charactor with trained FastSpeech "
        "(See detail in examples/fastspeech2/decode_fastspeech2.py)."
    )
    parser.add_argument(
        "--rootdir",
        default=None,
        type=str,
        required=True,
        help="directory including ids/durations files.",
    )
    parser.add_argument(
        "--outdir", type=str, required=True, help="directory to save generated speech."
    )
    parser.add_argument(
        "--checkpoint", type=str, required=True, help="checkpoint file to be loaded."
    )
    parser.add_argument(
        "--config",
        default=None,
        type=str,
        required=True,
        help="yaml format configuration file. if not explicitly provided, "
        "it will be searched in the checkpoint directory. (default=None)",
    )
    parser.add_argument(
        "--batch-size",
        default=8,
        type=int,
        required=False,
        help="Batch size for inference.",
    )
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)",
    )
    args = parser.parse_args()

    # set logger; the format is identical for every verbosity level, only the
    # threshold changes.
    log_format = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if args.verbose > 1:
        logging.basicConfig(level=logging.DEBUG, format=log_format)
    elif args.verbose > 0:
        logging.basicConfig(level=logging.INFO, format=log_format)
    else:
        logging.basicConfig(level=logging.WARN, format=log_format)
        logging.warning("Skip DEBUG/INFO messages")

    # check directory existence (exist_ok avoids the check-then-create race)
    os.makedirs(args.outdir, exist_ok=True)
    outdpost = os.path.join(args.outdir, "postnets")
    os.makedirs(outdpost, exist_ok=True)

    # load config and merge CLI args on top of it
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))

    if config["format"] == "npy":
        char_query = "*-ids.npy"
        char_load_fn = np.load
    else:
        raise ValueError("Only npy is supported.")

    # define data-loader
    dataset = CharactorDurationF0EnergyMelDataset(
        root_dir=args.rootdir,
        charactor_query=char_query,
        charactor_load_fn=char_load_fn,
    )
    dataset = dataset.create(
        batch_size=1
    )  # force batch size to 1 otherwise it may miss certain files

    # define model and load checkpoint
    fastspeech2 = TFFastSpeech2(
        config=FastSpeech2Config(**config["fastspeech2_params"]), name="fastspeech2"
    )
    fastspeech2._build()
    fastspeech2.load_weights(args.checkpoint)
    # wrap in tf.function for graph-mode speed; relaxed shapes avoid
    # retracing for every distinct utterance length.
    fastspeech2 = tf.function(fastspeech2, experimental_relax_shapes=True)

    for data in tqdm(dataset, desc="Decoding"):
        utt_ids = data["utt_ids"]
        mel_lens = data["mel_lengths"]

        # fastspeech inference. training=True is intentional here: it feeds
        # the ground-truth durations/f0/energy (teacher forcing), so the
        # extracted post-net mels align with the reference mel lengths.
        # Only the post-net output is needed; the before-net mel, predicted
        # durations, f0 and energy outputs are discarded.
        _, masked_mel_after, _, _, _ = fastspeech2(**data, training=True)

        # convert to numpy
        masked_mel_afters = masked_mel_after.numpy()

        for utt_id, mel_after, mel_len in zip(utt_ids, masked_mel_afters, mel_lens):
            utt_id = utt_id.numpy().decode("utf-8")

            # crop to the ground-truth mel length and store as float32
            np.save(
                os.path.join(outdpost, f"{utt_id}-postnet.npy"),
                mel_after[:mel_len, :].astype(np.float32),
                allow_pickle=False,
            )
# Script entry point: only run decoding when executed directly, not on import.
if __name__ == "__main__":
    main()