Path: blob/master/tensorflow_tts/inference/auto_config.py
1558 views
# -*- coding: utf-8 -*-
# Copyright 2020 The HuggingFace Inc. team and Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tensorflow Auto Config modules."""

import logging
import yaml
import os
from collections import OrderedDict

from tensorflow_tts.configs import (
    FastSpeechConfig,
    FastSpeech2Config,
    MelGANGeneratorConfig,
    MultiBandMelGANGeneratorConfig,
    HifiGANGeneratorConfig,
    Tacotron2Config,
    ParallelWaveGANGeneratorConfig,
)

from tensorflow_tts.utils import CACHE_DIRECTORY, CONFIG_FILE_NAME, LIBRARY_NAME
from tensorflow_tts import __version__ as VERSION
from huggingface_hub import hf_hub_url, cached_download

# Maps the value of the top-level ``model_type`` key of a config.yaml to the
# config class that ``AutoConfig.from_pretrained`` instantiates for it.
CONFIG_MAPPING = OrderedDict(
    [
        ("fastspeech", FastSpeechConfig),
        ("fastspeech2", FastSpeech2Config),
        ("multiband_melgan_generator", MultiBandMelGANGeneratorConfig),
        ("melgan_generator", MelGANGeneratorConfig),
        ("hifigan_generator", HifiGANGeneratorConfig),
        ("tacotron2", Tacotron2Config),
        ("parallel_wavegan_generator", ParallelWaveGANGeneratorConfig),
    ]
)


class AutoConfig:
    """Factory that builds the matching config object from a config.yaml.

    This class is not meant to be instantiated; call
    :meth:`AutoConfig.from_pretrained` instead.
    """

    def __init__(self):
        raise EnvironmentError(
            "AutoConfig is designed to be instantiated "
            "using the `AutoConfig.from_pretrained(pretrained_path)` method."
        )

    @classmethod
    def from_pretrained(cls, pretrained_path, **kwargs):
        """Load a config.yaml and build the config class named by it.

        Args:
            pretrained_path: Path to a local config.yaml. If it is not an
                existing file, it is treated as a Hugging Face Hub repo id,
                and ``CONFIG_FILE_NAME`` is downloaded from that repo into
                ``CACHE_DIRECTORY``.
            **kwargs: Extra keyword arguments passed to the config class
                constructor, alongside the entries of the file's
                ``<model_type>_params`` section. A key present in both
                causes a construction failure (reported as ``ValueError``).

        Returns:
            An instance of ``CONFIG_MAPPING[model_type]``. Its
            ``set_config_params`` has already been called with the full
            parsed YAML mapping.

        Raises:
            ValueError: If the file has no recognised ``model_type``, or if
                the config object cannot be built from the file's contents.
                The underlying exception is chained as ``__cause__``.
        """
        # Not a local file: resolve it as a repo id on the Hugging Face Hub.
        if not os.path.isfile(pretrained_path):
            # retrieve correct hub url
            download_url = hf_hub_url(
                repo_id=pretrained_path, filename=CONFIG_FILE_NAME
            )
            # NOTE(review): newer huggingface_hub releases deprecate
            # `cached_download` in favour of `hf_hub_download`; consider
            # migrating when the minimum supported version allows it.
            pretrained_path = str(
                cached_download(
                    url=download_url,
                    library_name=LIBRARY_NAME,
                    library_version=VERSION,
                    cache_dir=CACHE_DIRECTORY,
                )
            )

        # Explicit encoding: the platform default (e.g. cp1252 on Windows)
        # must not decide how a YAML file is read.
        with open(pretrained_path, encoding="utf-8") as f:
            # SECURITY NOTE(review): `yaml.Loader` can construct arbitrary
            # Python objects, and this file may have been downloaded from the
            # Hub (i.e. untrusted). Consider `yaml.SafeLoader` if configs
            # never rely on Python-specific YAML tags.
            config = yaml.load(f, Loader=yaml.Loader)

        # Step 1: identify the model type. KeyError covers a missing key or
        # an unknown type. TypeError covers a document that is not a mapping
        # (e.g. an empty file parses to None) or an unhashable model_type.
        try:
            model_type = config["model_type"]
            config_class = CONFIG_MAPPING[model_type]
        except (KeyError, TypeError) as err:
            raise ValueError(
                "Unrecognized config in {}. "
                "Should have a `model_type` key in its config.yaml set to one "
                "of the following strings: {}".format(
                    pretrained_path, ", ".join(CONFIG_MAPPING.keys())
                )
            ) from err

        # Step 2: build the config object. Failures here are reported
        # separately from "unknown model_type" so the real cause (missing
        # `<model_type>_params` section, bad parameter, duplicated kwarg) is
        # visible. They remain ValueError for callers that already catch it.
        try:
            config_obj = config_class(**config[model_type + "_params"], **kwargs)
            config_obj.set_config_params(config)
        except Exception as err:
            raise ValueError(
                "Could not build a `{}` config from {} (expected a `{}_params` "
                "section with valid parameters): {!r}".format(
                    model_type, pretrained_path, model_type, err
                )
            ) from err
        return config_obj