# Path: blob/master/tensorflow_tts/inference/auto_model.py
# -*- coding: utf-8 -*-
# Copyright 2020 The HuggingFace Inc. team and Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tensorflow Auto Model modules."""

import logging
import warnings
import os
import copy

from collections import OrderedDict

from tensorflow_tts.configs import (
    FastSpeechConfig,
    FastSpeech2Config,
    MelGANGeneratorConfig,
    MultiBandMelGANGeneratorConfig,
    HifiGANGeneratorConfig,
    Tacotron2Config,
    ParallelWaveGANGeneratorConfig,
)

from tensorflow_tts.models import (
    TFMelGANGenerator,
    TFMBMelGANGenerator,
    TFHifiGANGenerator,
    TFParallelWaveGANGenerator,
)

from tensorflow_tts.inference.savable_models import (
    SavableTFFastSpeech,
    SavableTFFastSpeech2,
    SavableTFTacotron2,
)
from tensorflow_tts.utils import CACHE_DIRECTORY, MODEL_FILE_NAME, LIBRARY_NAME
from tensorflow_tts import __version__ as VERSION
from huggingface_hub import hf_hub_url, cached_download


# Maps a config class to the model class it instantiates.  Order matters:
# subclasses must come before their base classes (e.g. FastSpeech2Config
# before FastSpeechConfig, MultiBandMelGAN before MelGAN) because lookup
# below uses isinstance(), which also matches base-class configs.
TF_MODEL_MAPPING = OrderedDict(
    [
        (FastSpeech2Config, SavableTFFastSpeech2),
        (FastSpeechConfig, SavableTFFastSpeech),
        (MultiBandMelGANGeneratorConfig, TFMBMelGANGenerator),
        (MelGANGeneratorConfig, TFMelGANGenerator),
        (Tacotron2Config, SavableTFTacotron2),
        (HifiGANGeneratorConfig, TFHifiGANGenerator),
        (ParallelWaveGANGeneratorConfig, TFParallelWaveGANGenerator),
    ]
)


class TFAutoModel(object):
    """General model class for inferencing."""

    def __init__(self):
        # Factory class: use TFAutoModel.from_pretrained(...) instead.
        raise EnvironmentError("Cannot be instantiated using `__init__()`")

    @classmethod
    def from_pretrained(cls, pretrained_path=None, config=None, **kwargs):
        """Instantiate a model from a local checkpoint or the HF Hub.

        Args:
            pretrained_path: Either a local ``.h5`` weight file or a Hub
                repo id.  When it is not an existing file it is treated as
                a repo id and the weights (and, if ``config`` is None, the
                config) are downloaded from the Hub.  ``None`` builds an
                uninitialized model from ``config`` alone.
            config: A config instance matching one of the keys of
                ``TF_MODEL_MAPPING``.  Required when loading from a local
                file; optional when downloading from the Hub.
            **kwargs: Forwarded to the model class constructor.

        Returns:
            The built model with weights loaded when available.

        Raises:
            ValueError: If ``config`` matches no entry in TF_MODEL_MAPPING.
        """
        # load weights from hf hub
        if pretrained_path is not None:
            if not os.path.isfile(pretrained_path):
                # retrieve correct hub url
                download_url = hf_hub_url(repo_id=pretrained_path, filename=MODEL_FILE_NAME)

                downloaded_file = str(
                    cached_download(
                        url=download_url,
                        library_name=LIBRARY_NAME,
                        library_version=VERSION,
                        cache_dir=CACHE_DIRECTORY,
                    )
                )

                # load config from repo as well
                if config is None:
                    # local import to avoid a circular import at module load
                    from tensorflow_tts.inference import AutoConfig

                    config = AutoConfig.from_pretrained(pretrained_path)

                pretrained_path = downloaded_file

        # NOTE(review): assert is stripped under `python -O`; kept (not
        # converted to raise) so the exception type callers see is unchanged.
        assert config is not None, "Please make sure to pass a config along to load a model from a local file"

        for config_class, model_class in TF_MODEL_MAPPING.items():
            # isinstance alone would also match base-class configs, so the
            # class-name check pins the exact config type.
            if isinstance(config, config_class) and str(config_class.__name__) in str(
                config
            ):
                model = model_class(config=config, **kwargs)
                model.set_config(config)
                model._build()
                if pretrained_path is not None and ".h5" in pretrained_path:
                    try:
                        model.load_weights(pretrained_path)
                    except Exception:
                        # Fix: was a bare `except:` which also swallowed
                        # KeyboardInterrupt/SystemExit.  Fall back to a
                        # lenient by-name load for partially matching
                        # checkpoints, and record that we did so.
                        logging.warning(
                            "Exact weight load failed; retrying with "
                            "by_name=True, skip_mismatch=True."
                        )
                        model.load_weights(
                            pretrained_path, by_name=True, skip_mismatch=True
                        )
                return model

        raise ValueError(
            "Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__,
                cls.__name__,
                ", ".join(c.__name__ for c in TF_MODEL_MAPPING.keys()),
            )
        )