Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
TensorSpeech
GitHub Repository: TensorSpeech/TensorFlowTTS
Path: blob/master/tensorflow_tts/inference/savable_models.py
1558 views
1
# -*- coding: utf-8 -*-
2
# Copyright 2020 TensorFlowTTS Team
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
# http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
"""Tensorflow Savable Model modules."""
16
17
import numpy as np
18
import tensorflow as tf
19
20
from tensorflow_tts.models import (
21
TFFastSpeech,
22
TFFastSpeech2,
23
TFMelGANGenerator,
24
TFMBMelGANGenerator,
25
TFHifiGANGenerator,
26
TFTacotron2,
27
TFParallelWaveGANGenerator,
28
)
29
30
31
class SavableTFTacotron2(TFTacotron2):
    """TFTacotron2 subclass whose ``call`` forwards to the parent's ``inference``.

    The model takes a single packed list of tensors so that it can be invoked
    as ``model([input_ids, input_lengths, speaker_ids])``.
    """

    def __init__(self, config, **kwargs):
        """Hand ``config`` and any extra keyword arguments to ``TFTacotron2``."""
        super().__init__(config, **kwargs)

    def call(self, inputs, training=False):
        """Run the parent's ``inference`` on a packed three-element input.

        Args:
            inputs: sequence unpacked as
                ``(input_ids, input_lengths, speaker_ids)``.
            training: accepted for signature compatibility; not used here.

        Returns:
            Whatever ``TFTacotron2.inference`` returns for those arguments.
        """
        token_ids, token_lengths, spk_ids = inputs
        return super().inference(token_ids, token_lengths, spk_ids)

    def _build(self):
        """Invoke the model once on a fixed dummy input.

        The dummy input is a 1x9 id tensor, a length tensor ``[9]`` and a
        speaker id tensor ``[0]``, all int32.
        """
        # NOTE(review): presumably run so the model is built before it is
        # exported -- confirm against the callers.
        dummy_inputs = [
            tf.convert_to_tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9]], tf.int32),
            tf.convert_to_tensor([9], tf.int32),
            tf.convert_to_tensor([0], tf.int32),
        ]
        self(dummy_inputs)
44
45
46
class SavableTFFastSpeech(TFFastSpeech):
    """TFFastSpeech subclass whose ``call`` forwards to the parent's ``_inference``.

    The model takes a single packed list of tensors so that it can be invoked
    as ``model([input_ids, speaker_ids, speed_ratios])``.
    """

    def __init__(self, config, **kwargs):
        """Hand ``config`` and any extra keyword arguments to ``TFFastSpeech``."""
        super().__init__(config, **kwargs)

    def call(self, inputs, training=False):
        """Run the parent's ``_inference`` on a packed three-element input.

        Args:
            inputs: sequence unpacked as
                ``(input_ids, speaker_ids, speed_ratios)``.
            training: accepted for signature compatibility; not used here.

        Returns:
            Whatever ``TFFastSpeech._inference`` returns for those arguments.
        """
        token_ids, spk_ids, speeds = inputs
        return super()._inference(token_ids, spk_ids, speeds)

    def _build(self):
        """Invoke the model once on a fixed dummy input.

        The dummy input is a 1x10 int32 id tensor, an int32 speaker id tensor
        ``[0]`` and a float32 speed ratio tensor ``[1.0]``.
        """
        # NOTE(review): presumably run so the model is built before it is
        # exported -- confirm against the callers.
        dummy_inputs = [
            tf.convert_to_tensor(
                [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], dtype=tf.int32
            ),
            tf.convert_to_tensor([0], dtype=tf.int32),
            tf.convert_to_tensor([1.0], dtype=tf.float32),
        ]
        self(dummy_inputs)
59
60
61
class SavableTFFastSpeech2(TFFastSpeech2):
    """TFFastSpeech2 subclass whose ``call`` forwards to the parent's ``_inference``.

    The model takes a single packed list of tensors so that it can be invoked
    as ``model([input_ids, speaker_ids, speed_ratios, f0_ratios,
    energy_ratios])``.
    """

    def __init__(self, config, **kwargs):
        """Hand ``config`` and any extra keyword arguments to ``TFFastSpeech2``."""
        super().__init__(config, **kwargs)

    def call(self, inputs, training=False):
        """Run the parent's ``_inference`` on a packed five-element input.

        Args:
            inputs: sequence unpacked as ``(input_ids, speaker_ids,
                speed_ratios, f0_ratios, energy_ratios)``.
            training: accepted for signature compatibility; not used here.

        Returns:
            Whatever ``TFFastSpeech2._inference`` returns for those arguments.
        """
        token_ids, spk_ids, speeds, f0_scales, energy_scales = inputs
        return super()._inference(
            token_ids, spk_ids, speeds, f0_scales, energy_scales
        )

    def _build(self):
        """Invoke the model once on a fixed dummy input.

        The dummy input is a 1x10 int32 id tensor, an int32 speaker id tensor
        ``[0]`` and three float32 ratio tensors (speed, f0, energy), each
        ``[1.0]``.
        """
        # NOTE(review): presumably run so the model is built before it is
        # exported -- confirm against the callers.
        dummy_inputs = [
            tf.convert_to_tensor(
                [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], dtype=tf.int32
            ),
            tf.convert_to_tensor([0], dtype=tf.int32),
            tf.convert_to_tensor([1.0], dtype=tf.float32),  # speed ratio
            tf.convert_to_tensor([1.0], dtype=tf.float32),  # f0 ratio
            tf.convert_to_tensor([1.0], dtype=tf.float32),  # energy ratio
        ]
        self(dummy_inputs)
78
79