Path: blob/master/examples/mfa_extraction/fix_mismatch.py
1558 views
# -*- coding: utf-8 -*-
# Copyright 2020 TensorFlowTTS Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fix mismatch between sum durations and mel lengths."""

import logging
import os
import sys

import click
import numpy as np
from tqdm import tqdm


logging.basicConfig(
    level=logging.DEBUG,
    stream=sys.stdout,
    format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)

# Mismatches larger than this many frames are reported as "big".
# The original comment equates 30 frames with 300 ms (i.e. presumably a
# 10 ms hop size -- confirm against the feature-extraction config).
_BIG_DIFF_FRAMES = 30


def _match_durations_to_length(durations, target_len):
    """Return a copy of *durations* adjusted so that it sums to *target_len*.

    If the durations are too short, the missing frames are added to the last
    entry. If they are too long, frames are removed starting from the last
    entry and walking backwards; entries at index 0 and 1 are never touched.

    Args:
        durations: 1-D array of per-token durations (in frames).
        target_len: Number of mel frames the durations must sum to.

    Returns:
        A ``(fixed, resolved)`` pair. ``fixed`` is the adjusted copy of
        ``durations`` (the input is not modified). ``resolved`` is ``False``
        when the mismatch could not be fully removed.
    """
    fixed = np.array(durations, copy=True)
    excess = np.sum(fixed) - target_len

    if excess == 0:
        return fixed, True

    if excess < 0:
        # Durations are shorter than the mel: pad the final entry.
        if len(fixed) == 0:
            # Nothing to pad; report instead of raising IndexError.
            return fixed, False
        fixed[-1] -= excess
        return fixed, True

    # Durations are longer than the mel: trim from the end, stopping before
    # the first two entries (same range the original loop covered).
    remaining = excess
    for idx in range(len(fixed) - 1, 1, -1):
        frames = fixed[idx]
        if frames >= remaining:
            fixed[idx] -= remaining
            remaining = 0
            break
        fixed[idx] = 0
        remaining -= frames

    # Checked after the loop so that arrays too short to be trimmed at all
    # (two entries or fewer) are also reported as unresolved.
    return fixed, remaining == 0


@click.command()
@click.option("--base_path", default="dump")
@click.option("--trimmed_dur_path", default="dataset/trimmed-durations")
@click.option("--dur_path", default="dataset/durations")
@click.option("--use_norm", default="f")
def fix(base_path: str, dur_path: str, trimmed_dur_path: str, use_norm: str):
    """Write durations whose sum matches the mel length for train/valid sets.

    For every entry listed in ``<base_path>/<set>/ids`` the mel features and
    the durations are loaded, the durations are adjusted to sum to the number
    of mel frames, and the result is saved as int32 under
    ``<base_path>/<set>/fix_dur``. Summary statistics are logged per set.

    Args:
        base_path: Root folder containing the ``train`` and ``valid`` dumps.
        dur_path: Folder with untrimmed durations, used as a fallback.
        trimmed_dur_path: Folder with trimmed durations, tried first.
        use_norm: ``"t"`` to read ``norm-feats``, anything else reads
            ``raw-feats``.
    """
    feat_kind = "norm-feats" if use_norm == "t" else "raw-feats"

    for t in ["train", "valid"]:
        mfa_longer = []
        mfa_shorter = []
        big_diff = []
        not_fixed = []
        pre_path = os.path.join(base_path, t)
        out_dir = os.path.join(pre_path, "fix_dur")
        os.makedirs(out_dir, exist_ok=True)

        logging.info(f"FIXING {t} set ...\n")
        for i in tqdm(os.listdir(os.path.join(pre_path, "ids"))):
            # The utterance id is everything before the first "-" of the name.
            utt_id = i.split("-")[0]
            dur_name = f"{utt_id}-durations.npy"

            mel = np.load(
                os.path.join(pre_path, feat_kind, f"{utt_id}-{feat_kind}.npy")
            )

            # Prefer trimmed durations; fall back to the untrimmed ones when
            # the trimmed file is missing or unreadable.
            try:
                dur = np.load(os.path.join(trimmed_dur_path, dur_name))
            except (OSError, ValueError, EOFError):
                dur = np.load(os.path.join(dur_path, dur_name))

            l_mel = len(mel)
            dur_s = np.sum(dur)
            diff = abs(l_mel - dur_s)

            if diff > _BIG_DIFF_FRAMES:  # more than 300 ms
                big_diff.append([i, diff])

            cloned, resolved = _match_durations_to_length(dur, l_mel)

            if dur_s > l_mel:
                mfa_longer.append(diff)
            elif dur_s < l_mel:
                mfa_shorter.append(diff)

            if not resolved:
                not_fixed.append(i)

            np.save(
                os.path.join(out_dir, dur_name),
                cloned.astype(np.int32),
                allow_pickle=False,
            )

        logging.info(
            f"{t} stats: number of mfa with longer duration: {len(mfa_longer)}, total diff: {sum(mfa_longer)}"
            f", mean diff: {sum(mfa_longer)/len(mfa_longer) if len(mfa_longer) > 0 else 0}"
        )
        logging.info(
            f"{t} stats: number of mfa with shorter duration: {len(mfa_shorter)}, total diff: {sum(mfa_shorter)}"
            f", mean diff: {sum(mfa_shorter)/len(mfa_shorter) if len(mfa_shorter) > 0 else 0}"
        )
        logging.info(
            f"{t} stats: number of files with a ''big'' duration diff: {len(big_diff)} if number>1 you should check it"
        )
        logging.info(f"{t} stats: not fixed len: {len(not_fixed)}\n")


if __name__ == "__main__":
    fix()