Path: blob/master/tensorflow_tts/utils/utils.py
1558 views
# -*- coding: utf-8 -*-

# Copyright 2019 Tomoki Hayashi
# MIT License (https://opensource.org/licenses/MIT)
"""Utility functions."""

import fnmatch
import os
import re
import tempfile
from pathlib import Path

import tensorflow as tf

MODEL_FILE_NAME = "model.h5"
CONFIG_FILE_NAME = "config.yml"
PROCESSOR_FILE_NAME = "processor.json"
LIBRARY_NAME = "tensorflow_tts"
CACHE_DIRECTORY = os.path.join(Path.home(), ".cache", LIBRARY_NAME)

# A lowercase URI scheme followed by "://". Digits, "+", "-" and "." are
# allowed after the first letter so that schemes such as "s3" are recognised.
_PROTOCOL_PREFIX_RE = re.compile(r"[a-z][a-z0-9+.\-]*://")


def find_files(root_dir, query="*.wav", include_root_dir=True):
    """Find files recursively.

    Args:
        root_dir (str): Root directory to search; symbolic links are followed.
        query (str): ``fnmatch``-style pattern matched against file names.
        include_root_dir (bool): If False, returned paths are relative to
            ``root_dir`` instead of being prefixed with it.

    Returns:
        list: List of found filenames.
    """
    files = []
    for root, _, filenames in os.walk(root_dir, followlinks=True):
        files.extend(
            os.path.join(root, filename)
            for filename in fnmatch.filter(filenames, query)
        )
    if not include_root_dir:
        # Strip only the leading root_dir. A plain substring replace would
        # also delete later occurrences of "<root_dir>/" inside the path and
        # would strip nothing when root_dir ends with a separator.
        files = [os.path.relpath(file_, root_dir) for file_ in files]

    return files


def _path_requires_gfile(filepath):
    """Checks if the given path requires use of GFile API.

    Args:
        filepath (str): Path to check.

    Returns:
        bool: True if the given path needs GFile API to access, such as
            "s3://some/path" and "gs://some/path".
    """
    # If the filepath starts with a protocol (e.g. "gs://"), it should be
    # handled using TensorFlow GFile API.
    return _PROTOCOL_PREFIX_RE.match(filepath) is not None


def save_weights(model, filepath):
    """Save model weights.

    Same as model.save_weights(filepath), but supports saving to S3 or GCS
    buckets using TensorFlow GFile API.

    Args:
        model (tf.keras.Model): Model to save.
        filepath (str): Path to save the model weights to.
    """
    if not _path_requires_gfile(filepath):
        model.save_weights(filepath)
        return

    # Save to a local temp file and copy to the desired path using GFile API.
    # The temp file keeps the target's extension, which is passed on to
    # model.save_weights() as part of the file name.
    # NOTE(review): only the single temp file is copied; this presumably
    # assumes a single-file weights format such as ".h5" -- confirm.
    _, ext = os.path.splitext(filepath)
    with tempfile.NamedTemporaryFile(suffix=ext) as temp_file:
        model.save_weights(temp_file.name)
        # To preserve the original semantics, we need to overwrite the target
        # file.
        tf.io.gfile.copy(temp_file.name, filepath, overwrite=True)


def load_weights(model, filepath):
    """Load model weights.

    Same as model.load_weights(filepath), but supports loading from S3 or GCS
    buckets using TensorFlow GFile API.

    Args:
        model (tf.keras.Model): Model to load weights to.
        filepath (str): Path to the weights file.
    """
    if not _path_requires_gfile(filepath):
        model.load_weights(filepath)
        return

    # Make a local copy and load it.
    _, ext = os.path.splitext(filepath)
    with tempfile.NamedTemporaryFile(suffix=ext) as temp_file:
        # The temp file already exists at this point (NamedTemporaryFile
        # created it), so the copy has to overwrite it.
        tf.io.gfile.copy(filepath, temp_file.name, overwrite=True)
        model.load_weights(temp_file.name)