Path: blob/master/tensorflow_tts/inference/savable_models.py
1558 views
# -*- coding: utf-8 -*-
# Copyright 2020 TensorFlowTTS Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tensorflow Savable Model modules."""

import numpy as np
import tensorflow as tf

from tensorflow_tts.models import (
    TFFastSpeech,
    TFFastSpeech2,
    TFMelGANGenerator,
    TFMBMelGANGenerator,
    TFHifiGANGenerator,
    TFTacotron2,
    TFParallelWaveGANGenerator,
)


class SavableTFTacotron2(TFTacotron2):
    """TFTacotron2 whose ``call`` forwards a packed input list to ``inference``.

    NOTE(review): presumably exists so the model can be exported with a
    single-argument call signature -- confirm against the export scripts.
    """

    def __init__(self, config, **kwargs):
        """Pass ``config`` and any extra keyword arguments to the parent."""
        super().__init__(config, **kwargs)

    def call(self, inputs, training=False):
        """Unpack ``inputs`` and delegate to the parent's ``inference``.

        Args:
            inputs: Sequence of exactly three items, in order
                ``(input_ids, input_lengths, speaker_ids)``.
            training: Accepted for signature compatibility; not used here.

        Returns:
            Whatever the parent class's ``inference`` method returns.
        """
        input_ids, input_lengths, speaker_ids = inputs
        return super().inference(input_ids, input_lengths, speaker_ids)

    def _build(self):
        """Call the model once on a fixed dummy batch.

        The dummy batch is one sequence of nine int32 ids, its length (9) and
        speaker id 0.
        NOTE(review): presumably this creates the model variables before
        weights are loaded or the model is saved -- confirm with callers.
        """
        input_ids = tf.convert_to_tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9]], dtype=tf.int32)
        input_lengths = tf.convert_to_tensor([9], dtype=tf.int32)
        speaker_ids = tf.convert_to_tensor([0], dtype=tf.int32)
        self([input_ids, input_lengths, speaker_ids])


class SavableTFFastSpeech(TFFastSpeech):
    """TFFastSpeech whose ``call`` forwards a packed input list to ``_inference``.

    NOTE(review): presumably exists so the model can be exported with a
    single-argument call signature -- confirm against the export scripts.
    """

    def __init__(self, config, **kwargs):
        """Pass ``config`` and any extra keyword arguments to the parent."""
        super().__init__(config, **kwargs)

    def call(self, inputs, training=False):
        """Unpack ``inputs`` and delegate to the parent's ``_inference``.

        Args:
            inputs: Sequence of exactly three items, in order
                ``(input_ids, speaker_ids, speed_ratios)``.
            training: Accepted for signature compatibility; not used here.

        Returns:
            Whatever the parent class's ``_inference`` method returns.
        """
        input_ids, speaker_ids, speed_ratios = inputs
        return super()._inference(input_ids, speaker_ids, speed_ratios)

    def _build(self):
        """Call the model once on a fixed dummy batch.

        The dummy batch is one sequence of ten int32 ids, speaker id 0 and a
        float32 speed ratio of 1.0.
        NOTE(review): presumably this creates the model variables before
        weights are loaded or the model is saved -- confirm with callers.
        """
        input_ids = tf.convert_to_tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], tf.int32)
        speaker_ids = tf.convert_to_tensor([0], tf.int32)
        speed_ratios = tf.convert_to_tensor([1.0], tf.float32)
        self([input_ids, speaker_ids, speed_ratios])


class SavableTFFastSpeech2(TFFastSpeech2):
    """TFFastSpeech2 whose ``call`` forwards a packed input list to ``_inference``.

    NOTE(review): presumably exists so the model can be exported with a
    single-argument call signature -- confirm against the export scripts.
    """

    def __init__(self, config, **kwargs):
        """Pass ``config`` and any extra keyword arguments to the parent."""
        super().__init__(config, **kwargs)

    def call(self, inputs, training=False):
        """Unpack ``inputs`` and delegate to the parent's ``_inference``.

        Args:
            inputs: Sequence of exactly five items, in order
                ``(input_ids, speaker_ids, speed_ratios, f0_ratios,
                energy_ratios)``.
            training: Accepted for signature compatibility; not used here.

        Returns:
            Whatever the parent class's ``_inference`` method returns.
        """
        input_ids, speaker_ids, speed_ratios, f0_ratios, energy_ratios = inputs
        return super()._inference(
            input_ids, speaker_ids, speed_ratios, f0_ratios, energy_ratios
        )

    def _build(self):
        """Call the model once on a fixed dummy batch.

        The dummy batch is one sequence of ten int32 ids, speaker id 0 and
        float32 speed, f0 and energy ratios of 1.0 each.
        NOTE(review): presumably this creates the model variables before
        weights are loaded or the model is saved -- confirm with callers.
        """
        input_ids = tf.convert_to_tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], tf.int32)
        speaker_ids = tf.convert_to_tensor([0], tf.int32)
        speed_ratios = tf.convert_to_tensor([1.0], tf.float32)
        f0_ratios = tf.convert_to_tensor([1.0], tf.float32)
        energy_ratios = tf.convert_to_tensor([1.0], tf.float32)
        self([input_ids, speaker_ids, speed_ratios, f0_ratios, energy_ratios])