Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
TensorSpeech
GitHub Repository: TensorSpeech/TensorFlowTTS
Path: blob/master/tensorflow_tts/utils/utils.py
1558 views
1
# -*- coding: utf-8 -*-
2
3
# Copyright 2019 Tomoki Hayashi
4
# MIT License (https://opensource.org/licenses/MIT)
5
"""Utility functions."""
6
7
import fnmatch
8
import os
9
import re
10
import tempfile
11
from pathlib import Path
12
13
import tensorflow as tf
14
15
# Conventional file names for a model's weights, its config and its text
# processor (the consumers of these names are outside this module view).
MODEL_FILE_NAME = "model.h5"
CONFIG_FILE_NAME = "config.yml"
PROCESSOR_FILE_NAME = "processor.json"

LIBRARY_NAME = "tensorflow_tts"
# Per-user cache location: <home>/.cache/tensorflow_tts (a plain str path).
CACHE_DIRECTORY = os.path.join(str(Path.home()), ".cache", LIBRARY_NAME)
20
21
22
def find_files(root_dir, query="*.wav", include_root_dir=True):
    """Find files recursively.

    Args:
        root_dir (str): Root directory to search. Symlinked subdirectories
            are followed.
        query (str): ``fnmatch`` pattern matched against each file's base name.
        include_root_dir (bool): If False, returned paths are relative to
            ``root_dir`` instead of starting with it.

    Returns:
        list: List of found filenames, in ``os.walk`` traversal order.
    """
    files = []
    for root, _, filenames in os.walk(root_dir, followlinks=True):
        for filename in fnmatch.filter(filenames, query):
            files.append(os.path.join(root, filename))
    if not include_root_dir:
        # Use os.path.relpath rather than str.replace(root_dir + "/", ""):
        # replace() strips every occurrence of that substring (not just the
        # leading one), does nothing when root_dir already ends with a
        # separator, and assumes "/" is the path separator.
        files = [os.path.relpath(file_, root_dir) for file_ in files]
    return files
39
40
41
def _path_requires_gfile(filepath):
42
"""Checks if the given path requires use of GFile API.
43
44
Args:
45
filepath (str): Path to check.
46
Returns:
47
bool: True if the given path needs GFile API to access, such as
48
"s3://some/path" and "gs://some/path".
49
"""
50
# If the filepath contains a protocol (e.g. "gs://"), it should be handled
51
# using TensorFlow GFile API.
52
return bool(re.match(r"^[a-z]+://", filepath))
53
54
55
def save_weights(model, filepath):
    """Save model weights.

    Same as model.save_weights(filepath), but supports saving to S3 or GCS
    buckets using TensorFlow GFile API.

    Args:
        model (tf.keras.Model): Model to save.
        filepath (str): Path to save the model weights to. Paths starting
            with a URL scheme (e.g. "gs://", "s3://") are written via GFile.
    """
    if not _path_requires_gfile(filepath):
        model.save_weights(filepath)
        return

    # Save to a local temp file and copy to the desired path using GFile API.
    # The extension is kept so model.save_weights sees the same file type the
    # caller asked for.
    _, ext = os.path.splitext(filepath)
    # Write into a private temporary directory instead of a NamedTemporaryFile:
    # the latter stays open while Keras writes to it by name, which fails on
    # Windows, and anything else written next to it would be left behind.
    # The whole directory is removed when the block exits.
    with tempfile.TemporaryDirectory() as temp_dir:
        local_path = os.path.join(temp_dir, "weights" + ext)
        model.save_weights(local_path)
        # NOTE: only the single file at `local_path` is copied, so this branch
        # assumes a single-file weights format (presumably HDF5 ".h5").
        # To preserve the original semantics, we need to overwrite the target
        # file.
        tf.io.gfile.copy(local_path, filepath, overwrite=True)
76
77
78
def load_weights(model, filepath):
    """Load model weights.

    Same as model.load_weights(filepath), but supports loading from S3 or GCS
    buckets using TensorFlow GFile API.

    Args:
        model (tf.keras.Model): Model to load weights to.
        filepath (str): Path to the weights file. Paths starting with a URL
            scheme (e.g. "gs://", "s3://") are fetched via GFile.
    """
    if not _path_requires_gfile(filepath):
        model.load_weights(filepath)
        return

    # Make a local copy and load it. The extension is kept so
    # model.load_weights sees the same file type as the remote file.
    _, ext = os.path.splitext(filepath)
    # Use a private temporary directory instead of a NamedTemporaryFile: an
    # open NamedTemporaryFile cannot be reopened by name on Windows, which
    # would break both the copy into it and the load from it.
    with tempfile.TemporaryDirectory() as temp_dir:
        local_path = os.path.join(temp_dir, "weights" + ext)
        # local_path does not exist yet, so no overwrite is needed.
        tf.io.gfile.copy(filepath, local_path)
        model.load_weights(local_path)
98
99