# Path: blob/master/tensorflow_tts/bin/preprocess.py
# -*- coding: utf-8 -*-
# Copyright 2020 Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Perform preprocessing, with raw feature extraction and normalization of train/valid split."""

import argparse
import glob
import logging
import os
import yaml

import librosa
import numpy as np
import pyworld as pw

from functools import partial
from multiprocessing import Pool
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm

from tensorflow_tts.processor import LJSpeechProcessor
from tensorflow_tts.processor import BakerProcessor
from tensorflow_tts.processor import KSSProcessor
from tensorflow_tts.processor import LibriTTSProcessor
from tensorflow_tts.processor import ThorstenProcessor
from tensorflow_tts.processor import LJSpeechUltimateProcessor
from tensorflow_tts.processor import SynpaflexProcessor
from tensorflow_tts.processor import JSUTProcessor
from tensorflow_tts.processor.ljspeech import LJSPEECH_SYMBOLS
from tensorflow_tts.processor.baker import BAKER_SYMBOLS
from tensorflow_tts.processor.kss import KSS_SYMBOLS
from tensorflow_tts.processor.libritts import LIBRITTS_SYMBOLS
from tensorflow_tts.processor.thorsten import THORSTEN_SYMBOLS
from tensorflow_tts.processor.ljspeechu import LJSPEECH_U_SYMBOLS
from tensorflow_tts.processor.synpaflex import SYNPAFLEX_SYMBOLS
from tensorflow_tts.processor.jsut import JSUT_SYMBOLS

from tensorflow_tts.utils import remove_outlier

# feature extraction is CPU-only; hide GPUs from any TF import downstream
os.environ["CUDA_VISIBLE_DEVICES"] = ""


def parse_and_config():
    """Parse arguments and set configuration parameters.

    Returns:
        Dict: YAML config merged with the parsed CLI arguments.
    Raises:
        AssertionError: if the config "format" field is not "npy".
    """
    parser = argparse.ArgumentParser(
        description="Preprocess audio and text features "
        "(See detail in tensorflow_tts/bin/preprocess_dataset.py)."
    )
    parser.add_argument(
        "--rootdir",
        default=None,
        type=str,
        required=True,
        help="Directory containing the dataset files.",
    )
    parser.add_argument(
        "--outdir",
        default=None,
        type=str,
        required=True,
        help="Output directory where features will be saved.",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="ljspeech",
        choices=["ljspeech", "kss", "libritts", "baker", "thorsten", "ljspeechu", "synpaflex", "jsut"],
        help="Dataset to preprocess.",
    )
    parser.add_argument(
        "--config", type=str, required=True, help="YAML format configuration file."
    )
    parser.add_argument(
        "--n_cpus",
        type=int,
        default=4,
        required=False,
        help="Number of CPUs to use in parallel.",
    )
    parser.add_argument(
        "--test_size",
        type=float,
        default=0.05,
        required=False,
        help="Proportion of files to use as test dataset.",
    )
    parser.add_argument(
        "--verbose",
        type=int,
        default=0,
        choices=[0, 1, 2],
        # help text matches the log_level mapping below
        help="Logging level. 0: DEBUG, 1: WARNING, 2: ERROR",
    )
    args = parser.parse_args()

    # set logger
    FORMAT = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    log_level = {0: logging.DEBUG, 1: logging.WARNING, 2: logging.ERROR}
    logging.basicConfig(level=log_level[args.verbose], format=FORMAT)

    # load config; CLI arguments override/extend YAML keys
    config = yaml.load(open(args.config), Loader=yaml.SafeLoader)
    config.update(vars(args))
    # config checks
    assert config["format"] == "npy", "'npy' is the only supported format."
    return config


def ph_based_trim(
    config,
    utt_id: str,
    text_ids: np.array,
    raw_text: str,
    audio: np.array,
    hop_size: int,
) -> (bool, np.array, np.array):
    """Trim leading/trailing silence using MFA phoneme durations.

    Args:
        config: Parsed yaml config
        utt_id: file name
        text_ids: array with text ids
        raw_text: raw text of file
        audio: parsed wav file
        hop_size: Hop size
    Returns: (bool, np.array, np.array) => if trimmed return True, new text_ids, new audio_array
    """

    os.makedirs(os.path.join(config["rootdir"], "trimmed-durations"), exist_ok=True)
    duration_path = config.get(
        "duration_path", os.path.join(config["rootdir"], "durations")
    )
    duration_fixed_path = config.get(
        "duration_fixed_path", os.path.join(config["rootdir"], "trimmed-durations")
    )
    sil_ph = ["SIL", "END"]  # TODO FIX hardcoded values
    text = raw_text.split(" ")

    trim_start, trim_end = False, False

    if text[0] in sil_ph:
        trim_start = True

    if text[-1] in sil_ph:
        trim_end = True

    if not trim_start and not trim_end:
        return False, text_ids, audio

    # drop the silence token at either end of the id sequence
    idx_start, idx_end = (
        0 if not trim_start else 1,
        text_ids.__len__() if not trim_end else -1,
    )
    text_ids = text_ids[idx_start:idx_end]
    durations = np.load(os.path.join(duration_path, f"{utt_id}-durations.npy"))
    if trim_start:
        # durations are in frames; convert to samples with hop_size
        s_trim = int(durations[0] * hop_size)
        audio = audio[s_trim:]
    if trim_end:
        e_trim = int(durations[-1] * hop_size)
        audio = audio[:-e_trim]

    durations = durations[idx_start:idx_end]
    np.save(os.path.join(duration_fixed_path, f"{utt_id}-durations.npy"), durations)
    return True, text_ids, audio


def gen_audio_features(item, config):
    """Generate audio features and transformations
    Args:
        item (Dict): dictionary containing the attributes to encode.
        config (Dict): configuration dictionary.
    Returns:
        (bool): keep this sample or not.
        mel (ndarray): mel matrix in np.float32.
        energy (ndarray): energy audio profile.
        f0 (ndarray): fundamental frequency.
        item (Dict): dictionary containing the updated attributes.
    """
    # get info from sample.
    audio = item["audio"]
    utt_id = item["utt_id"]
    rate = item["rate"]

    # check audio properties
    assert len(audio.shape) == 1, f"{utt_id} seems to be multi-channel signal."
    assert np.abs(audio).max() <= 1.0, f"{utt_id} is different from 16 bit PCM."

    # check sample rate
    if rate != config["sampling_rate"]:
        audio = librosa.resample(audio, rate, config["sampling_rate"])
        logging.info(f"{utt_id} sampling rate is {rate}, not {config['sampling_rate']}, we resample it.")

    # trim silence
    if config["trim_silence"]:
        if "trim_mfa" in config and config["trim_mfa"]:
            _, item["text_ids"], audio = ph_based_trim(
                config,
                utt_id,
                item["text_ids"],
                item["raw_text"],
                audio,
                config["hop_size"],
            )
            if (
                audio.__len__() < 1
            ):  # very short files can get trimmed fully if mfa didnt extract any tokens LibriTTS maybe take only longer files?
                logging.warning(
                    f"File have only silence or MFA didnt extract any token {utt_id}"
                )
                return False, None, None, None, item
        else:
            audio, _ = librosa.effects.trim(
                audio,
                top_db=config["trim_threshold_in_db"],
                frame_length=config["trim_frame_size"],
                hop_length=config["trim_hop_size"],
            )

    # resample audio if necessary
    if "sampling_rate_for_feats" in config:
        audio = librosa.resample(audio, rate, config["sampling_rate_for_feats"])
        sampling_rate = config["sampling_rate_for_feats"]
        assert (
            config["hop_size"] * config["sampling_rate_for_feats"] % rate == 0
        ), "'hop_size' must be 'int' value. Please check if 'sampling_rate_for_feats' is correct."
        hop_size = config["hop_size"] * config["sampling_rate_for_feats"] // rate
    else:
        sampling_rate = config["sampling_rate"]
        hop_size = config["hop_size"]

    # get spectrogram
    D = librosa.stft(
        audio,
        n_fft=config["fft_size"],
        hop_length=hop_size,
        win_length=config["win_length"],
        window=config["window"],
        pad_mode="reflect",
    )
    S, _ = librosa.magphase(D)  # (#bins, #frames)

    # get mel basis
    fmin = 0 if config["fmin"] is None else config["fmin"]
    fmax = sampling_rate // 2 if config["fmax"] is None else config["fmax"]
    mel_basis = librosa.filters.mel(
        sr=sampling_rate,
        n_fft=config["fft_size"],
        n_mels=config["num_mels"],
        fmin=fmin,
        fmax=fmax,
    )
    # log10 mel spectrogram, floored at 1e-10 to avoid log(0)
    mel = np.log10(np.maximum(np.dot(mel_basis, S), 1e-10)).T  # (#frames, #bins)

    # check audio and feature length
    audio = np.pad(audio, (0, config["fft_size"]), mode="edge")
    audio = audio[: len(mel) * hop_size]
    assert len(mel) * hop_size == len(audio)

    # extract raw pitch
    _f0, t = pw.dio(
        audio.astype(np.double),
        fs=sampling_rate,
        f0_ceil=fmax,
        frame_period=1000 * hop_size / sampling_rate,
    )
    f0 = pw.stonemask(audio.astype(np.double), _f0, t, sampling_rate)
    # align f0 length with mel frames (truncate or zero-pad)
    if len(f0) >= len(mel):
        f0 = f0[: len(mel)]
    else:
        f0 = np.pad(f0, (0, len(mel) - len(f0)))

    # extract energy
    energy = np.sqrt(np.sum(S ** 2, axis=0))
    assert len(mel) == len(f0) == len(energy)

    # remove outlier f0/energy
    f0 = remove_outlier(f0)
    energy = remove_outlier(energy)

    # apply global gain
    if config["global_gain_scale"] > 0.0:
        audio *= config["global_gain_scale"]
    if np.abs(audio).max() >= 1.0:
        # logging.warn is a deprecated alias of logging.warning
        logging.warning(
            f"{utt_id} causes clipping. It is better to reconsider global gain scale value."
        )
    item["audio"] = audio
    item["mel"] = mel
    item["f0"] = f0
    item["energy"] = energy
    return True, mel, energy, f0, item


def save_statistics_to_file(scaler_list, config):
    """Save computed statistics to disk.
    Args:
        scaler_list (List): List of scalers containing statistics to save.
        config (Dict): configuration dictionary.
    """
    for scaler, name in scaler_list:
        # row 0 = mean, row 1 = scale (std) per feature dimension
        stats = np.stack((scaler.mean_, scaler.scale_))
        np.save(
            os.path.join(config["outdir"], f"stats{name}.npy"),
            stats.astype(np.float32),
            allow_pickle=False,
        )


def save_features_to_file(features, subdir, config):
    """Save transformed dataset features in disk.
    Args:
        features (Dict): dictionary containing the attributes to save.
        subdir (str): data split folder where features will be saved.
        config (Dict): configuration dictionary.
    """
    utt_id = features["utt_id"]

    if config["format"] == "npy":
        save_list = [
            (features["audio"], "wavs", "wave", np.float32),
            (features["mel"], "raw-feats", "raw-feats", np.float32),
            (features["text_ids"], "ids", "ids", np.int32),
            (features["f0"], "raw-f0", "raw-f0", np.float32),
            (features["energy"], "raw-energies", "raw-energy", np.float32),
        ]
        for item, name_dir, name_file, fmt in save_list:
            np.save(
                os.path.join(
                    config["outdir"], subdir, name_dir, f"{utt_id}-{name_file}.npy"
                ),
                item.astype(fmt),
                allow_pickle=False,
            )
    else:
        raise ValueError("'npy' is the only supported format.")


def preprocess():
    """Run preprocessing process and compute statistics for normalizing."""
    config = parse_and_config()

    dataset_processor = {
        "ljspeech": LJSpeechProcessor,
        "kss": KSSProcessor,
        "libritts": LibriTTSProcessor,
        "baker": BakerProcessor,
        "thorsten": ThorstenProcessor,
        "ljspeechu": LJSpeechUltimateProcessor,
        "synpaflex": SynpaflexProcessor,
        "jsut": JSUTProcessor,
    }

    dataset_symbol = {
        "ljspeech": LJSPEECH_SYMBOLS,
        "kss": KSS_SYMBOLS,
        "libritts": LIBRITTS_SYMBOLS,
        "baker": BAKER_SYMBOLS,
        "thorsten": THORSTEN_SYMBOLS,
        "ljspeechu": LJSPEECH_U_SYMBOLS,
        "synpaflex": SYNPAFLEX_SYMBOLS,
        "jsut": JSUT_SYMBOLS,
    }

    dataset_cleaner = {
        "ljspeech": "english_cleaners",
        "kss": "korean_cleaners",
        "libritts": None,
        "baker": None,
        "thorsten": "german_cleaners",
        "ljspeechu": "english_cleaners",
        "synpaflex": "basic_cleaners",
        "jsut": None,
    }

    logging.info(f"Selected '{config['dataset']}' processor.")
    processor = dataset_processor[config["dataset"]](
        config["rootdir"],
        symbols=dataset_symbol[config["dataset"]],
        cleaner_names=dataset_cleaner[config["dataset"]],
    )

    # check output directories
    build_dir = lambda x: [
        os.makedirs(os.path.join(config["outdir"], x, y), exist_ok=True)
        for y in ["raw-feats", "wavs", "ids", "raw-f0", "raw-energies"]
    ]
    build_dir("train")
    build_dir("valid")

    # save pretrained-processor to feature dir
    processor._save_mapper(
        os.path.join(config["outdir"], f"{config['dataset']}_mapper.json"),
        extra_attrs_to_save={"pinyin_dict": processor.pinyin_dict}
        if config["dataset"] == "baker"
        else {},
    )

    # build train test split; libritts is stratified by speaker id (last item field)
    if config["dataset"] == "libritts":
        train_split, valid_split, _, _ = train_test_split(
            processor.items,
            [i[-1] for i in processor.items],
            test_size=config["test_size"],
            random_state=42,
            shuffle=True,
        )
    else:
        train_split, valid_split = train_test_split(
            processor.items,
            test_size=config["test_size"],
            random_state=42,
            shuffle=True,
        )
    logging.info(f"Training items: {len(train_split)}")
    logging.info(f"Validation items: {len(valid_split)}")

    get_utt_id = lambda x: os.path.split(x[1])[-1].split(".")[0]
    train_utt_ids = [get_utt_id(x) for x in train_split]
    valid_utt_ids = [get_utt_id(x) for x in valid_split]

    # save train and valid utt_ids to track later
    np.save(os.path.join(config["outdir"], "train_utt_ids.npy"), train_utt_ids)
    np.save(os.path.join(config["outdir"], "valid_utt_ids.npy"), valid_utt_ids)

    # define map iterator
    def iterator_data(items_list):
        for item in items_list:
            yield processor.get_one_sample(item)

    train_iterator_data = iterator_data(train_split)
    valid_iterator_data = iterator_data(valid_split)

    p = Pool(config["n_cpus"])

    # preprocess train files and get statistics for normalizing
    partial_fn = partial(gen_audio_features, config=config)
    train_map = p.imap_unordered(
        partial_fn,
        tqdm(train_iterator_data, total=len(train_split), desc="[Preprocessing train]"),
        chunksize=10,
    )
    # init scaler for multiple features
    scaler_mel = StandardScaler(copy=False)
    scaler_energy = StandardScaler(copy=False)
    scaler_f0 = StandardScaler(copy=False)

    id_to_remove = []
    for result, mel, energy, f0, features in train_map:
        if not result:
            id_to_remove.append(features["utt_id"])
            continue
        save_features_to_file(features, "train", config)
        # partial fitting of scalers; skip samples whose f0/energy are all
        # zero after outlier removal, since they carry no usable statistics
        if len(energy[energy != 0]) == 0 or len(f0[f0 != 0]) == 0:
            id_to_remove.append(features["utt_id"])
            continue
        scaler_mel.partial_fit(mel)
        scaler_energy.partial_fit(energy[energy != 0].reshape(-1, 1))
        scaler_f0.partial_fit(f0[f0 != 0].reshape(-1, 1))

    if len(id_to_remove) > 0:
        # rewrite the utt_id index without the dropped samples
        np.save(
            os.path.join(config["outdir"], "train_utt_ids.npy"),
            [i for i in train_utt_ids if i not in id_to_remove],
        )
        logging.info(
            f"removed {len(id_to_remove)} cause of too many outliers or bad mfa extraction"
        )

    # save statistics to file
    logging.info("Saving computed statistics.")
    scaler_list = [(scaler_mel, ""), (scaler_energy, "_energy"), (scaler_f0, "_f0")]
    save_statistics_to_file(scaler_list, config)

    # preprocess valid files
    partial_fn = partial(gen_audio_features, config=config)
    valid_map = p.imap_unordered(
        partial_fn,
        tqdm(valid_iterator_data, total=len(valid_split), desc="[Preprocessing valid]"),
        chunksize=10,
    )
    for result, *_, features in valid_map:
        # skip failed/fully-trimmed samples: their "mel" key is never set,
        # so saving them would raise KeyError
        if not result:
            continue
        save_features_to_file(features, "valid", config)

    p.close()
    p.join()


def gen_normal_mel(mel_path, scaler, config):
    """Normalize the mel spectrogram and save it to the corresponding path.
    Args:
        mel_path (string): path of the mel spectrogram to normalize.
        scaler (sklearn.base.BaseEstimator): scaling function to use for normalize.
        config (Dict): configuration dictionary.
    """
    mel = np.load(mel_path)
    mel_norm = scaler.transform(mel)
    path, file_name = os.path.split(mel_path)
    # path ends with <subdir>/<suffix>, e.g. train/raw-feats
    *_, subdir, suffix = path.split(os.sep)

    utt_id = file_name.split(f"-{suffix}.npy")[0]
    np.save(
        os.path.join(
            config["outdir"], subdir, "norm-feats", f"{utt_id}-norm-feats.npy"
        ),
        mel_norm.astype(np.float32),
        allow_pickle=False,
    )


def normalize():
    """Normalize mel spectrogram with pre-computed statistics."""
    config = parse_and_config()
    if config["format"] == "npy":
        # init scaler with saved values
        scaler = StandardScaler()
        scaler.mean_, scaler.scale_ = np.load(
            os.path.join(config["outdir"], "stats.npy")
        )
        scaler.n_features_in_ = config["num_mels"]
    else:
        raise ValueError("'npy' is the only supported format.")

    # find all "raw-feats" files in both train and valid folders
    glob_path = os.path.join(config["rootdir"], "**", "raw-feats", "*.npy")
    mel_raw_feats = glob.glob(glob_path, recursive=True)
    logging.info(f"Files to normalize: {len(mel_raw_feats)}")

    # check for output directories
    os.makedirs(os.path.join(config["outdir"], "train", "norm-feats"), exist_ok=True)
    os.makedirs(os.path.join(config["outdir"], "valid", "norm-feats"), exist_ok=True)

    p = Pool(config["n_cpus"])
    partial_fn = partial(gen_normal_mel, scaler=scaler, config=config)
    list(p.map(partial_fn, tqdm(mel_raw_feats, desc="[Normalizing]")))
    p.close()
    p.join()


def compute_statistics():
    """Compute mean / std statistics of some features for later normalization."""
    config = parse_and_config()

    # find features files for the train split
    glob_fn = lambda x: glob.glob(os.path.join(config["rootdir"], "train", x, "*.npy"))
    glob_mel = glob_fn("raw-feats")
    glob_f0 = glob_fn("raw-f0")
    glob_energy = glob_fn("raw-energies")
    assert (
        len(glob_mel) == len(glob_f0) == len(glob_energy)
    ), "Features, f0 and energies have different files in training split."

    logging.info(f"Computing statistics for {len(glob_mel)} files.")
    # init scaler for multiple features
    scaler_mel = StandardScaler(copy=False)
    scaler_energy = StandardScaler(copy=False)
    scaler_f0 = StandardScaler(copy=False)

    for mel, f0, energy in tqdm(
        zip(glob_mel, glob_f0, glob_energy), total=len(glob_mel)
    ):
        # load saved features (f0/energy zeros are excluded from fitting below)
        energy = np.load(energy)
        f0 = np.load(f0)
        # partial fitting of scalers
        scaler_mel.partial_fit(np.load(mel))
        scaler_energy.partial_fit(energy[energy != 0].reshape(-1, 1))
        scaler_f0.partial_fit(f0[f0 != 0].reshape(-1, 1))

    # save statistics to file
    logging.info("Saving computed statistics.")
    scaler_list = [(scaler_mel, ""), (scaler_energy, "_energy"), (scaler_f0, "_f0")]
    save_statistics_to_file(scaler_list, config)


if __name__ == "__main__":
    preprocess()