Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
TensorSpeech
GitHub Repository: TensorSpeech/TensorFlowTTS
Path: blob/master/tensorflow_tts/inference/auto_model.py
1558 views
1
# -*- coding: utf-8 -*-
2
# Copyright 2020 The HuggingFace Inc. team and Minh Nguyen (@dathudeptrai)
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
# http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
"""Tensorflow Auto Model modules."""
16
17
import logging
18
import warnings
19
import os
20
import copy
21
22
from collections import OrderedDict
23
24
from tensorflow_tts.configs import (
25
FastSpeechConfig,
26
FastSpeech2Config,
27
MelGANGeneratorConfig,
28
MultiBandMelGANGeneratorConfig,
29
HifiGANGeneratorConfig,
30
Tacotron2Config,
31
ParallelWaveGANGeneratorConfig,
32
)
33
34
from tensorflow_tts.models import (
35
TFMelGANGenerator,
36
TFMBMelGANGenerator,
37
TFHifiGANGenerator,
38
TFParallelWaveGANGenerator,
39
)
40
41
from tensorflow_tts.inference.savable_models import (
42
SavableTFFastSpeech,
43
SavableTFFastSpeech2,
44
SavableTFTacotron2
45
)
46
from tensorflow_tts.utils import CACHE_DIRECTORY, MODEL_FILE_NAME, LIBRARY_NAME
47
from tensorflow_tts import __version__ as VERSION
48
from huggingface_hub import hf_hub_url, cached_download
49
50
51
# Config class -> TF model class instantiated for it by
# ``TFAutoModel.from_pretrained``.
#
# Order matters: ``from_pretrained`` walks this mapping in insertion order and
# returns on the first entry whose config class matches. More specific config
# classes are therefore listed before any class they may derive from, e.g. the
# multi-band MelGAN config is checked before the plain MelGAN one.
TF_MODEL_MAPPING = OrderedDict(
    (
        (FastSpeech2Config, SavableTFFastSpeech2),
        (FastSpeechConfig, SavableTFFastSpeech),
        (MultiBandMelGANGeneratorConfig, TFMBMelGANGenerator),
        (MelGANGeneratorConfig, TFMelGANGenerator),
        (Tacotron2Config, SavableTFTacotron2),
        (HifiGANGeneratorConfig, TFHifiGANGenerator),
        (ParallelWaveGANGeneratorConfig, TFParallelWaveGANGenerator),
    )
)
62
63
64
class TFAutoModel(object):
    """General model class for inferencing.

    Not meant to be instantiated directly: call :meth:`from_pretrained`, which
    selects the concrete TF model class from ``TF_MODEL_MAPPING`` based on the
    type of the supplied (or downloaded) config.
    """

    def __init__(self):
        raise EnvironmentError("Cannot be instantiated using `__init__()`")

    @classmethod
    def from_pretrained(cls, pretrained_path=None, config=None, **kwargs):
        """Build the TF model matching ``config`` and optionally load weights.

        Args:
            pretrained_path: Path to a local weights file or, when it is not
                an existing file, a Hugging Face Hub repo id from which
                ``MODEL_FILE_NAME`` is downloaded into ``CACHE_DIRECTORY``.
                ``None`` builds the model without loading any weights.
            config: Model config instance. When ``None`` and
                ``pretrained_path`` is a hub repo id, the config is fetched
                from the same repo via ``AutoConfig.from_pretrained``.
            **kwargs: Extra keyword arguments forwarded to the model
                class constructor.

        Returns:
            The built model. Weights are loaded only when the resolved path
            contains ``".h5"``.

        Raises:
            ValueError: If no config is available, or if the config's class
                has no matching entry in ``TF_MODEL_MAPPING``.
        """
        if pretrained_path is not None and not os.path.isfile(pretrained_path):
            # Not a local file: treat the value as a hub repo id.
            # NOTE(review): ``cached_download`` is deprecated in newer
            # huggingface_hub releases (superseded by ``hf_hub_download``);
            # confirm the pinned version still provides it.
            download_url = hf_hub_url(repo_id=pretrained_path, filename=MODEL_FILE_NAME)
            downloaded_file = str(
                cached_download(
                    url=download_url,
                    library_name=LIBRARY_NAME,
                    library_version=VERSION,
                    cache_dir=CACHE_DIRECTORY,
                )
            )

            # Fetch the config from the same repo unless the caller gave one.
            if config is None:
                # Imported lazily, presumably to avoid a circular import with
                # ``tensorflow_tts.inference``; TODO confirm.
                from tensorflow_tts.inference import AutoConfig

                config = AutoConfig.from_pretrained(pretrained_path)

            pretrained_path = downloaded_file

        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if config is None:
            raise ValueError(
                "Please make sure to pass a config along to load a model from a local file"
            )

        for config_class, model_class in TF_MODEL_MAPPING.items():
            # ``isinstance`` alone would also accept subclasses. The name check
            # additionally requires the class name to appear in ``str(config)``.
            if isinstance(config, config_class) and str(config_class.__name__) in str(
                config
            ):
                model = model_class(config=config, **kwargs)
                model.set_config(config)
                model._build()
                # Weights are only loaded for paths containing ".h5".
                if pretrained_path is not None and ".h5" in pretrained_path:
                    try:
                        model.load_weights(pretrained_path)
                    except Exception as err:
                        # Catch Exception rather than everything, so that
                        # KeyboardInterrupt / SystemExit still propagate.
                        # Fall back to a lenient, name-based load that skips
                        # layers whose shapes do not match.
                        logging.warning(
                            "Strict weight loading from %s failed (%s); retrying "
                            "with by_name=True, skip_mismatch=True.",
                            pretrained_path,
                            err,
                        )
                        model.load_weights(
                            pretrained_path, by_name=True, skip_mismatch=True
                        )
                return model

        raise ValueError(
            "Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__,
                cls.__name__,
                ", ".join(c.__name__ for c in TF_MODEL_MAPPING.keys()),
            )
        )
122
123