Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
TensorSpeech
GitHub Repository: TensorSpeech/TensorFlowTTS
Path: blob/master/tensorflow_tts/models/fastspeech.py
1558 views
1
# -*- coding: utf-8 -*-
2
# Copyright 2020 The FastSpeech Authors, The HuggingFace Inc. team and Minh Nguyen (@dathudeptrai)
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
# http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
"""Tensorflow Model modules for FastSpeech."""
16
17
import numpy as np
18
import tensorflow as tf
19
20
from tensorflow_tts.models import BaseModel
21
22
23
def get_initializer(initializer_range=0.02):
    """Build a truncated-normal kernel initializer.

    Args:
        initializer_range: float, stddev of the truncated normal distribution.

    Returns:
        A `tf.keras.initializers.TruncatedNormal` instance with
        stddev = `initializer_range`.
    """
    return tf.keras.initializers.TruncatedNormal(stddev=initializer_range)
34
35
36
def gelu(x):
    """Exact Gaussian Error Linear Unit (erf formulation)."""
    # Standard normal CDF evaluated at x.
    phi = 0.5 * (1.0 + tf.math.erf(x / tf.math.sqrt(2.0)))
    return x * phi
40
41
42
def gelu_new(x):
    """Tanh-approximated ("smoother") Gaussian Error Linear Unit."""
    inner = np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3))
    cdf = 0.5 * (1.0 + tf.tanh(inner))
    return x * cdf
46
47
48
def swish(x):
    """Swish activation, x * sigmoid(x), delegated to TF's built-in op."""
    return tf.nn.swish(x)
51
52
53
def mish(x):
    """Mish activation: x * tanh(softplus(x))."""
    sp = tf.math.softplus(x)
    return x * tf.math.tanh(sp)
55
56
57
# Map config activation-name strings to callables. Most entries are
# `tf.keras.layers.Activation` wrappers; "relu" is the bare keras activation
# function — both are invoked identically (fn(tensor)) by the consumers.
ACT2FN = {
    "identity": tf.keras.layers.Activation("linear"),
    "tanh": tf.keras.layers.Activation("tanh"),
    "gelu": tf.keras.layers.Activation(gelu),
    "relu": tf.keras.activations.relu,
    "swish": tf.keras.layers.Activation(swish),
    "gelu_new": tf.keras.layers.Activation(gelu_new),
    "mish": tf.keras.layers.Activation(mish),
}
66
67
68
class TFEmbedding(tf.keras.layers.Embedding):
    """Faster version of embedding."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def call(self, inputs):
        """Look up embeddings with a plain `tf.gather` on the weight matrix.

        Bypasses the base `Embedding.call` (hence "faster version"); ids are
        cast to int32 before the gather.
        """
        inputs = tf.cast(inputs, tf.int32)
        outputs = tf.gather(self.embeddings, inputs)
        return outputs
78
79
80
class TFFastSpeechEmbeddings(tf.keras.layers.Layer):
    """Construct charactor/phoneme/positional/speaker embeddings."""

    def __init__(self, config, **kwargs):
        """Init variables."""
        super().__init__(**kwargs)
        self.vocab_size = config.vocab_size
        self.hidden_size = config.encoder_self_attention_params.hidden_size
        self.initializer_range = config.initializer_range
        self.config = config

        # Fixed (non-trainable) sinusoidal position table. Row 0 is an
        # all-zero padding entry, hence the "+ 1" on the table size and the
        # 1-based position ids used in `_embedding`.
        self.position_embeddings = TFEmbedding(
            config.max_position_embeddings + 1,
            self.hidden_size,
            weights=[
                self._sincos_embedding(
                    self.hidden_size, self.config.max_position_embeddings
                )
            ],
            name="position_embeddings",
            trainable=False,
        )

        # Speaker embedding + projection are only created for multi-speaker
        # configurations.
        if config.n_speakers > 1:
            self.encoder_speaker_embeddings = TFEmbedding(
                config.n_speakers,
                self.hidden_size,
                embeddings_initializer=get_initializer(self.initializer_range),
                name="speaker_embeddings",
            )
            self.speaker_fc = tf.keras.layers.Dense(
                units=self.hidden_size, name="speaker_fc"
            )

    def build(self, input_shape):
        """Build shared charactor/phoneme embedding layers."""
        # Created via add_weight (rather than an Embedding layer) so the
        # matrix can be shared/gathered directly in `_embedding`.
        with tf.name_scope("charactor_embeddings"):
            self.charactor_embeddings = self.add_weight(
                "weight",
                shape=[self.vocab_size, self.hidden_size],
                initializer=get_initializer(self.initializer_range),
            )
        super().build(input_shape)

    def call(self, inputs, training=False):
        """Get charactor embeddings of inputs.

        Args:
            1. charactor, Tensor (int32) shape [batch_size, length].
            2. speaker_id, Tensor (int32) shape [batch_size]
        Returns:
            Tensor (float32) shape [batch_size, length, embedding_size].

        """
        return self._embedding(inputs, training=training)

    def _embedding(self, inputs, training=False):
        """Applies embedding based on inputs tensor."""
        input_ids, speaker_ids = inputs

        input_shape = tf.shape(input_ids)
        seq_length = input_shape[1]

        # Positions are 1-based; 0 is reserved for the padding row.
        position_ids = tf.range(1, seq_length + 1, dtype=tf.int32)[tf.newaxis, :]

        # create embeddings
        inputs_embeds = tf.gather(self.charactor_embeddings, input_ids)
        position_embeddings = self.position_embeddings(position_ids)

        # sum embedding
        embeddings = inputs_embeds + tf.cast(position_embeddings, inputs_embeds.dtype)
        if self.config.n_speakers > 1:
            speaker_embeddings = self.encoder_speaker_embeddings(speaker_ids)
            speaker_features = tf.math.softplus(self.speaker_fc(speaker_embeddings))
            # extended speaker embeddings: broadcast [batch, hidden] over the
            # time axis -> [batch, 1, hidden].
            extended_speaker_features = speaker_features[:, tf.newaxis, :]
            embeddings += extended_speaker_features

        return embeddings

    def _sincos_embedding(
        self, hidden_size, max_positional_embedding,
    ):
        # Classic Transformer sinusoidal table: sin on even feature indices,
        # cos on odd ones; shape [max_positional_embedding + 1, hidden_size].
        position_enc = np.array(
            [
                [
                    pos / np.power(10000, 2.0 * (i // 2) / hidden_size)
                    for i in range(hidden_size)
                ]
                for pos in range(max_positional_embedding + 1)
            ]
        )

        position_enc[:, 0::2] = np.sin(position_enc[:, 0::2])
        position_enc[:, 1::2] = np.cos(position_enc[:, 1::2])

        # pad embedding.
        position_enc[0] = 0.0

        return position_enc

    def resize_positional_embeddings(self, new_size):
        # Rebuild the fixed position table for a new maximum length (e.g. to
        # synthesize longer utterances than trained on).
        self.position_embeddings = TFEmbedding(
            new_size + 1,
            self.hidden_size,
            weights=[self._sincos_embedding(self.hidden_size, new_size)],
            name="position_embeddings",
            trainable=False,
        )
189
190
191
class TFFastSpeechSelfAttention(tf.keras.layers.Layer):
    """Self attention module for fastspeech."""

    def __init__(self, config, **kwargs):
        """Init variables."""
        super().__init__(**kwargs)
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
            )
        self.output_attentions = config.output_attentions
        self.num_attention_heads = config.num_attention_heads
        self.all_head_size = self.num_attention_heads * config.attention_head_size

        self.query = tf.keras.layers.Dense(
            self.all_head_size,
            kernel_initializer=get_initializer(config.initializer_range),
            name="query",
        )
        self.key = tf.keras.layers.Dense(
            self.all_head_size,
            kernel_initializer=get_initializer(config.initializer_range),
            name="key",
        )
        self.value = tf.keras.layers.Dense(
            self.all_head_size,
            kernel_initializer=get_initializer(config.initializer_range),
            name="value",
        )

        # Dropout applied to the attention probabilities (post-softmax).
        self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
        self.config = config

    def transpose_for_scores(self, x, batch_size):
        """Transpose to calculate attention scores."""
        # [batch, length, all_head] -> [batch, heads, length, head_size].
        x = tf.reshape(
            x,
            (batch_size, -1, self.num_attention_heads, self.config.attention_head_size),
        )
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, inputs, training=False):
        """Call logic.

        Args:
            inputs: pair of (hidden_states [batch, length, hidden_size],
                attention_mask [batch, length] or None).
            training: bool, enables attention dropout.

        Returns:
            Tuple of (context_layer,) or (context_layer, attention_probs)
            when `output_attentions` is set.
        """
        hidden_states, attention_mask = inputs

        batch_size = tf.shape(hidden_states)[0]
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
        key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
        value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)

        # Scaled dot-product attention: QK^T / sqrt(d_k).
        attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
        dk = tf.cast(
            tf.shape(key_layer)[-1], attention_scores.dtype
        )  # scale attention_scores
        attention_scores = attention_scores / tf.math.sqrt(dk)

        if attention_mask is not None:
            # extended_attention_masks for self attention encoder.
            # Masked positions get a large negative bias so their softmax
            # weight is effectively zero.
            extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
            extended_attention_mask = tf.cast(
                extended_attention_mask, attention_scores.dtype
            )
            extended_attention_mask = (1.0 - extended_attention_mask) * -1e9
            attention_scores = attention_scores + extended_attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = tf.nn.softmax(attention_scores, axis=-1)
        attention_probs = self.dropout(attention_probs, training=training)

        context_layer = tf.matmul(attention_probs, value_layer)
        # Merge heads back: [batch, heads, length, head] -> [batch, length, all_head].
        context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3])
        context_layer = tf.reshape(context_layer, (batch_size, -1, self.all_head_size))

        outputs = (
            (context_layer, attention_probs)
            if self.output_attentions
            else (context_layer,)
        )
        return outputs
275
276
277
class TFFastSpeechSelfOutput(tf.keras.layers.Layer):
    """Fastspeech output of self attention module."""

    def __init__(self, config, **kwargs):
        """Init variables."""
        super().__init__(**kwargs)
        # Projection back to hidden_size after multi-head attention.
        self.dense = tf.keras.layers.Dense(
            config.hidden_size,
            kernel_initializer=get_initializer(config.initializer_range),
            name="dense",
        )
        self.LayerNorm = tf.keras.layers.LayerNormalization(
            epsilon=config.layer_norm_eps, name="LayerNorm"
        )
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)

    def call(self, inputs, training=False):
        """Project attention output, then residual-add and layer-normalize.

        Args:
            inputs: pair of (hidden_states, input_tensor) where
                `input_tensor` is the residual branch.
            training: bool, enables dropout.
        """
        hidden_states, input_tensor = inputs

        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states, training=training)
        # Post-LN residual connection.
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states
301
302
303
class TFFastSpeechAttention(tf.keras.layers.Layer):
    """Fastspeech attention module."""

    def __init__(self, config, **kwargs):
        """Init variables."""
        super().__init__(**kwargs)
        self.self_attention = TFFastSpeechSelfAttention(config, name="self")
        self.dense_output = TFFastSpeechSelfOutput(config, name="output")

    def call(self, inputs, training=False):
        """Run self-attention, output projection, then re-apply the mask.

        Args:
            inputs: pair of (input_tensor, attention_mask).
            training: bool, enables dropout in the sublayers.

        Returns:
            Tuple of (masked_attention_output, [attention_probs]).
        """
        input_tensor, attention_mask = inputs

        self_outputs = self.self_attention(
            [input_tensor, attention_mask], training=training
        )
        attention_output = self.dense_output(
            [self_outputs[0], input_tensor], training=training
        )
        # Zero out padded time steps so they cannot leak into later layers.
        masked_attention_output = attention_output * tf.cast(
            tf.expand_dims(attention_mask, 2), dtype=attention_output.dtype
        )
        outputs = (masked_attention_output,) + self_outputs[
            1:
        ]  # add attentions if we output them
        return outputs
328
329
330
class TFFastSpeechIntermediate(tf.keras.layers.Layer):
    """Intermediate representation module."""

    def __init__(self, config, **kwargs):
        """Init variables."""
        super().__init__(**kwargs)
        # Position-wise feed-forward realized as two 1-D convolutions
        # (FastSpeech's FFT block), expanding to intermediate_size and back.
        self.conv1d_1 = tf.keras.layers.Conv1D(
            config.intermediate_size,
            kernel_size=config.intermediate_kernel_size,
            kernel_initializer=get_initializer(config.initializer_range),
            padding="same",
            name="conv1d_1",
        )
        self.conv1d_2 = tf.keras.layers.Conv1D(
            config.hidden_size,
            kernel_size=config.intermediate_kernel_size,
            kernel_initializer=get_initializer(config.initializer_range),
            padding="same",
            name="conv1d_2",
        )
        # hidden_act may be a name looked up in ACT2FN or a callable.
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def call(self, inputs):
        """Apply conv -> activation -> conv, then re-mask padded steps.

        Args:
            inputs: pair of (hidden_states [batch, length, hidden_size],
                attention_mask [batch, length]).
        """
        hidden_states, attention_mask = inputs

        hidden_states = self.conv1d_1(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        hidden_states = self.conv1d_2(hidden_states)

        # "same" padding lets convolutions smear values into padded frames;
        # mask them back to zero.
        masked_hidden_states = hidden_states * tf.cast(
            tf.expand_dims(attention_mask, 2), dtype=hidden_states.dtype
        )
        return masked_hidden_states
367
368
369
class TFFastSpeechOutput(tf.keras.layers.Layer):
    """Output module."""

    def __init__(self, config, **kwargs):
        """Init variables."""
        super().__init__(**kwargs)
        self.LayerNorm = tf.keras.layers.LayerNormalization(
            epsilon=config.layer_norm_eps, name="LayerNorm"
        )
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)

    def call(self, inputs, training=False):
        """Dropout, residual-add, and layer-normalize the FFT block output.

        Args:
            inputs: pair of (hidden_states, input_tensor) where
                `input_tensor` is the residual branch.
            training: bool, enables dropout.
        """
        hidden_states, input_tensor = inputs

        hidden_states = self.dropout(hidden_states, training=training)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states
387
388
389
class TFFastSpeechLayer(tf.keras.layers.Layer):
    """Fastspeech module (FFT module on the paper)."""

    def __init__(self, config, **kwargs):
        """Init variables."""
        super().__init__(**kwargs)
        self.attention = TFFastSpeechAttention(config, name="attention")
        self.intermediate = TFFastSpeechIntermediate(config, name="intermediate")
        self.bert_output = TFFastSpeechOutput(config, name="output")

    def call(self, inputs, training=False):
        """Run one FFT block: self-attention -> conv feed-forward -> output.

        Args:
            inputs: pair of (hidden_states, attention_mask).
            training: bool, enables dropout throughout.

        Returns:
            Tuple of (masked_layer_output, [attention_probs]).
        """
        hidden_states, attention_mask = inputs

        attention_outputs = self.attention(
            [hidden_states, attention_mask], training=training
        )
        attention_output = attention_outputs[0]
        intermediate_output = self.intermediate(
            [attention_output, attention_mask], training=training
        )
        layer_output = self.bert_output(
            [intermediate_output, attention_output], training=training
        )
        # Keep padded time steps at exactly zero after the residual/LN path.
        masked_layer_output = layer_output * tf.cast(
            tf.expand_dims(attention_mask, 2), dtype=layer_output.dtype
        )
        outputs = (masked_layer_output,) + attention_outputs[
            1:
        ]  # add attentions if we output them
        return outputs
420
421
422
class TFFastSpeechEncoder(tf.keras.layers.Layer):
    """Fast Speech encoder module: a stack of FFT (feed-forward Transformer) blocks."""

    def __init__(self, config, **kwargs):
        """Init variables."""
        super().__init__(**kwargs)
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
        self.layer = [
            TFFastSpeechLayer(config, name="layer_._{}".format(i))
            for i in range(config.num_hidden_layers)
        ]

    def call(self, inputs, training=False):
        """Run the stacked FFT blocks.

        Args:
            inputs: pair of (hidden_states [batch, length, hidden_size],
                attention_mask [batch, length]).
            training: bool, enables dropout in the sublayers.

        Returns:
            Tuple of (last_hidden_states, [all_hidden_states], [all_attentions]);
            the optional entries are included when the corresponding
            `output_hidden_states` / `output_attentions` config flags are set.
        """
        hidden_states, attention_mask = inputs

        all_hidden_states = ()
        all_attentions = ()
        # Iterate layers directly; the previous `enumerate` discarded the
        # index, so it was a needless wrapper.
        for layer_module in self.layer:
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = layer_module(
                [hidden_states, attention_mask], training=training
            )
            hidden_states = layer_outputs[0]

            if self.output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        # Add last layer
        if self.output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if self.output_attentions:
            outputs = outputs + (all_attentions,)
        return outputs  # outputs, (hidden states), (attentions)
463
464
465
class TFFastSpeechDecoder(TFFastSpeechEncoder):
    """Fast Speech decoder module."""

    def __init__(self, config, **kwargs):
        # Popped before super().__init__ so the base Layer never sees it.
        self.is_compatible_encoder = kwargs.pop("is_compatible_encoder", True)

        super().__init__(config, **kwargs)
        self.config = config

        # create decoder positional embedding
        # (fixed sinusoidal table; row 0 reserved for padding, hence "+ 1").
        self.decoder_positional_embeddings = TFEmbedding(
            config.max_position_embeddings + 1,
            config.hidden_size,
            weights=[self._sincos_embedding()],
            name="position_embeddings",
            trainable=False,
        )

        # When encoder and decoder hidden sizes differ, project the encoder
        # output into the decoder's hidden size first.
        if self.is_compatible_encoder is False:
            self.project_compatible_decoder = tf.keras.layers.Dense(
                units=config.hidden_size, name="project_compatible_decoder"
            )

        if config.n_speakers > 1:
            self.decoder_speaker_embeddings = TFEmbedding(
                config.n_speakers,
                config.hidden_size,
                embeddings_initializer=get_initializer(config.initializer_range),
                name="speaker_embeddings",
            )
            self.speaker_fc = tf.keras.layers.Dense(
                units=config.hidden_size, name="speaker_fc"
            )

    def call(self, inputs, training=False):
        """Add positional (and optional speaker) features, then run the FFT stack.

        Args:
            inputs: (hidden_states, speaker_ids, encoder_mask, decoder_pos)
                where decoder_pos holds 1-based position ids (0 = padding).
            training: bool, enables dropout.
        """
        hidden_states, speaker_ids, encoder_mask, decoder_pos = inputs

        if self.is_compatible_encoder is False:
            hidden_states = self.project_compatible_decoder(hidden_states)

        # calculate new hidden states.
        hidden_states += tf.cast(
            self.decoder_positional_embeddings(decoder_pos), hidden_states.dtype
        )

        if self.config.n_speakers > 1:
            speaker_embeddings = self.decoder_speaker_embeddings(speaker_ids)
            speaker_features = tf.math.softplus(self.speaker_fc(speaker_embeddings))
            # extended speaker embeddings: broadcast over the time axis.
            extended_speaker_features = speaker_features[:, tf.newaxis, :]
            hidden_states += extended_speaker_features

        # Reuse the encoder's layer stack (inherited call).
        return super().call([hidden_states, encoder_mask], training=training)

    def _sincos_embedding(self):
        # Transformer sinusoidal position table (sin on even features, cos on
        # odd), sized from the decoder config.
        position_enc = np.array(
            [
                [
                    pos / np.power(10000, 2.0 * (i // 2) / self.config.hidden_size)
                    for i in range(self.config.hidden_size)
                ]
                for pos in range(self.config.max_position_embeddings + 1)
            ]
        )

        position_enc[:, 0::2] = np.sin(position_enc[:, 0::2])
        position_enc[:, 1::2] = np.cos(position_enc[:, 1::2])

        # pad embedding.
        position_enc[0] = 0.0

        return position_enc
537
538
539
class TFTacotronPostnet(tf.keras.layers.Layer):
    """Tacotron-2 postnet."""

    def __init__(self, config, **kwargs):
        """Init variables."""
        super().__init__(**kwargs)
        self.conv_batch_norm = []
        for i in range(config.n_conv_postnet):
            # Last conv maps back to num_mels; earlier ones keep the postnet
            # filter width.
            conv = tf.keras.layers.Conv1D(
                filters=config.postnet_conv_filters
                if i < config.n_conv_postnet - 1
                else config.num_mels,
                kernel_size=config.postnet_conv_kernel_sizes,
                padding="same",
                name="conv_._{}".format(i),
            )
            batch_norm = tf.keras.layers.BatchNormalization(
                axis=-1, name="batch_norm_._{}".format(i)
            )
            self.conv_batch_norm.append((conv, batch_norm))
        self.dropout = tf.keras.layers.Dropout(
            rate=config.postnet_dropout_rate, name="dropout"
        )
        # tanh on all layers except the final one, which is linear (identity).
        self.activation = [tf.nn.tanh] * (config.n_conv_postnet - 1) + [tf.identity]

    def call(self, inputs, training=False):
        """Refine the mel spectrogram; output is masked on padded frames.

        Args:
            inputs: pair of (outputs [batch, length, num_mels], mask [batch, length]).
            training: bool, enables batch-norm training mode and dropout.
        """
        outputs, mask = inputs
        extended_mask = tf.cast(tf.expand_dims(mask, axis=2), outputs.dtype)
        for i, (conv, bn) in enumerate(self.conv_batch_norm):
            outputs = conv(outputs)
            outputs = bn(outputs)
            outputs = self.activation[i](outputs)
            outputs = self.dropout(outputs, training=training)
        return outputs * extended_mask
574
575
576
class TFFastSpeechDurationPredictor(tf.keras.layers.Layer):
    """FastSpeech duration predictor module."""

    def __init__(self, config, **kwargs):
        """Init variables."""
        super().__init__(**kwargs)
        # Each iteration appends a conv -> layernorm -> relu6 -> dropout
        # sub-stack; the whole list is then wrapped in a Sequential.
        self.conv_layers = []
        for i in range(config.num_duration_conv_layers):
            self.conv_layers.append(
                tf.keras.layers.Conv1D(
                    config.duration_predictor_filters,
                    config.duration_predictor_kernel_sizes,
                    padding="same",
                    name="conv_._{}".format(i),
                )
            )
            self.conv_layers.append(
                tf.keras.layers.LayerNormalization(
                    epsilon=config.layer_norm_eps, name="LayerNorm_._{}".format(i)
                )
            )
            self.conv_layers.append(tf.keras.layers.Activation(tf.nn.relu6))
            self.conv_layers.append(
                tf.keras.layers.Dropout(config.duration_predictor_dropout_probs)
            )
        self.conv_layers_sequence = tf.keras.Sequential(self.conv_layers)
        # Final per-frame scalar duration logit.
        self.output_layer = tf.keras.layers.Dense(1)

    def call(self, inputs, training=False):
        """Predict one non-negative duration value per encoder frame.

        Args:
            inputs: pair of (encoder_hidden_states [batch, length, hidden_size],
                attention_mask [batch, length]).
            training: bool (not forwarded to the Sequential here).

        Returns:
            Tensor shape [batch, length] of non-negative durations.
        """
        encoder_hidden_states, attention_mask = inputs
        attention_mask = tf.cast(
            tf.expand_dims(attention_mask, 2), encoder_hidden_states.dtype
        )

        # mask encoder hidden states
        masked_encoder_hidden_states = encoder_hidden_states * attention_mask

        # pass though first layer
        outputs = self.conv_layers_sequence(masked_encoder_hidden_states)
        outputs = self.output_layer(outputs)
        masked_outputs = outputs * attention_mask
        return tf.squeeze(tf.nn.relu6(masked_outputs), -1)  # make sure positive value.
619
620
621
class TFFastSpeechLengthRegulator(tf.keras.layers.Layer):
    """FastSpeech lengthregulator module."""

    def __init__(self, config, **kwargs):
        """Init variables."""
        # TFLite mode avoids tf.while_loop by assuming batch size 1; the flag
        # must be popped before the base Layer sees kwargs.
        self.enable_tflite_convertible = kwargs.pop("enable_tflite_convertible", False)
        super().__init__(**kwargs)
        self.config = config

    def call(self, inputs, training=False):
        """Call logic.
        Args:
            1. encoder_hidden_states, Tensor (float32) shape [batch_size, length, hidden_size]
            2. durations_gt, Tensor (float32/int32) shape [batch_size, length]
        """
        encoder_hidden_states, durations_gt = inputs
        outputs, encoder_masks = self._length_regulator(
            encoder_hidden_states, durations_gt
        )
        return outputs, encoder_masks

    def _length_regulator(self, encoder_hidden_states, durations_gt):
        """Length regulator logic.

        Expands each encoder frame i by durations_gt[:, i] repeats, padding
        every batch item to the longest expanded length.

        Returns:
            outputs: Tensor [batch, max_durations, hidden_size].
            encoder_masks: int32 Tensor [batch, max_durations], 1 on real frames.
        """
        sum_durations = tf.reduce_sum(durations_gt, axis=-1)  # [batch_size]
        max_durations = tf.reduce_max(sum_durations)

        input_shape = tf.shape(encoder_hidden_states)
        batch_size = input_shape[0]
        hidden_size = input_shape[-1]

        # initialize output hidden states and encoder masking.
        if self.enable_tflite_convertible:
            # There is only 1 batch in inference, so we don't have to use
            # `tf.While` op with 3-D output tensor.
            repeats = durations_gt[0]
            real_length = tf.reduce_sum(repeats)
            pad_size = max_durations - real_length
            # masks : [max_durations]
            masks = tf.sequence_mask([real_length], max_durations, dtype=tf.int32)
            repeat_encoder_hidden_states = tf.repeat(
                encoder_hidden_states[0], repeats=repeats, axis=0
            )
            repeat_encoder_hidden_states = tf.expand_dims(
                tf.pad(repeat_encoder_hidden_states, [[0, pad_size], [0, 0]]), 0
            )  # [1, max_durations, hidden_size]

            outputs = repeat_encoder_hidden_states
            encoder_masks = masks
        else:
            # Graph-mode path: grow outputs/masks one batch item at a time in
            # a tf.while_loop, concatenating along the batch axis.
            outputs = tf.zeros(
                shape=[0, max_durations, hidden_size], dtype=encoder_hidden_states.dtype
            )
            encoder_masks = tf.zeros(shape=[0, max_durations], dtype=tf.int32)

            def condition(
                i,
                batch_size,
                outputs,
                encoder_masks,
                encoder_hidden_states,
                durations_gt,
                max_durations,
            ):
                # Loop until every batch item has been expanded.
                return tf.less(i, batch_size)

            def body(
                i,
                batch_size,
                outputs,
                encoder_masks,
                encoder_hidden_states,
                durations_gt,
                max_durations,
            ):
                # Expand item i frame-by-frame, then right-pad to max_durations.
                repeats = durations_gt[i]
                real_length = tf.reduce_sum(repeats)
                pad_size = max_durations - real_length
                masks = tf.sequence_mask([real_length], max_durations, dtype=tf.int32)
                repeat_encoder_hidden_states = tf.repeat(
                    encoder_hidden_states[i], repeats=repeats, axis=0
                )
                repeat_encoder_hidden_states = tf.expand_dims(
                    tf.pad(repeat_encoder_hidden_states, [[0, pad_size], [0, 0]]), 0
                )  # [1, max_durations, hidden_size]
                outputs = tf.concat([outputs, repeat_encoder_hidden_states], axis=0)
                encoder_masks = tf.concat([encoder_masks, masks], axis=0)
                return [
                    i + 1,
                    batch_size,
                    outputs,
                    encoder_masks,
                    encoder_hidden_states,
                    durations_gt,
                    max_durations,
                ]

            # initialize iteration i.
            i = tf.constant(0, dtype=tf.int32)
            # shape_invariants relax batch/time dims of outputs and masks,
            # which grow on every iteration.
            _, _, outputs, encoder_masks, _, _, _, = tf.while_loop(
                condition,
                body,
                [
                    i,
                    batch_size,
                    outputs,
                    encoder_masks,
                    encoder_hidden_states,
                    durations_gt,
                    max_durations,
                ],
                shape_invariants=[
                    i.get_shape(),
                    batch_size.get_shape(),
                    tf.TensorShape(
                        [
                            None,
                            None,
                            self.config.encoder_self_attention_params.hidden_size,
                        ]
                    ),
                    tf.TensorShape([None, None]),
                    encoder_hidden_states.get_shape(),
                    durations_gt.get_shape(),
                    max_durations.get_shape(),
                ],
            )

        return outputs, encoder_masks
749
750
751
class TFFastSpeech(BaseModel):
    """TF Fastspeech module."""

    def __init__(self, config, **kwargs):
        """Init layers for fastspeech."""
        self.enable_tflite_convertible = kwargs.pop("enable_tflite_convertible", False)
        super().__init__(**kwargs)
        self.embeddings = TFFastSpeechEmbeddings(config, name="embeddings")
        self.encoder = TFFastSpeechEncoder(
            config.encoder_self_attention_params, name="encoder"
        )
        # float32 dtype keeps the duration head numerically stable even under
        # mixed-precision policies.
        self.duration_predictor = TFFastSpeechDurationPredictor(
            config, dtype=tf.float32, name="duration_predictor"
        )
        self.length_regulator = TFFastSpeechLengthRegulator(
            config,
            enable_tflite_convertible=self.enable_tflite_convertible,
            name="length_regulator",
        )
        # Decoder may need a projection if its hidden size differs from the
        # encoder's.
        self.decoder = TFFastSpeechDecoder(
            config.decoder_self_attention_params,
            is_compatible_encoder=config.encoder_self_attention_params.hidden_size
            == config.decoder_self_attention_params.hidden_size,
            name="decoder",
        )
        self.mel_dense = tf.keras.layers.Dense(
            units=config.num_mels, dtype=tf.float32, name="mel_before"
        )
        self.postnet = TFTacotronPostnet(
            config=config, dtype=tf.float32, name="postnet"
        )

        self.setup_inference_fn()

    def _build(self):
        """Dummy input for building model."""
        # fake inputs
        input_ids = tf.convert_to_tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], tf.int32)
        speaker_ids = tf.convert_to_tensor([0], tf.int32)
        duration_gts = tf.convert_to_tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], tf.int32)
        self(input_ids, speaker_ids, duration_gts)

    def resize_positional_embeddings(self, new_size):
        # Rebuild the position table, then re-trace with dummy inputs so the
        # new variables are created.
        self.embeddings.resize_positional_embeddings(new_size)
        self._build()

    def call(
        self, input_ids, speaker_ids, duration_gts, training=False, **kwargs,
    ):
        """Call logic.

        Training/teacher-forcing forward pass using ground-truth durations.

        Args:
            input_ids: int32 [batch, length] phoneme/charactor ids (0 = pad).
            speaker_ids: int32 [batch] speaker indices.
            duration_gts: int32 [batch, length] ground-truth durations.
            training: bool, enables dropout/batch-norm training behavior.

        Returns:
            Tuple (mel_before, mel_after, duration_outputs).
        """
        # Padding id is 0, so the mask is simply "not zero".
        attention_mask = tf.math.not_equal(input_ids, 0)
        embedding_output = self.embeddings([input_ids, speaker_ids], training=training)
        encoder_output = self.encoder(
            [embedding_output, attention_mask], training=training
        )
        last_encoder_hidden_states = encoder_output[0]

        # duration predictor, here use last_encoder_hidden_states, u can use more hidden_states layers
        # rather than just use last_hidden_states of encoder for duration_predictor.
        duration_outputs = self.duration_predictor(
            [last_encoder_hidden_states, attention_mask]
        )  # [batch_size, length]

        length_regulator_outputs, encoder_masks = self.length_regulator(
            [last_encoder_hidden_states, duration_gts], training=training
        )

        # create decoder positional embedding
        decoder_pos = tf.range(
            1, tf.shape(length_regulator_outputs)[1] + 1, dtype=tf.int32
        )
        # Zero out positions on padded frames (0 maps to the zero padding row).
        masked_decoder_pos = tf.expand_dims(decoder_pos, 0) * encoder_masks

        decoder_output = self.decoder(
            [length_regulator_outputs, speaker_ids, encoder_masks, masked_decoder_pos],
            training=training,
        )
        last_decoder_hidden_states = decoder_output[0]

        # here u can use sum or concat more than 1 hidden states layers from decoder.
        mel_before = self.mel_dense(last_decoder_hidden_states)
        # Postnet predicts a residual refinement on top of mel_before.
        mel_after = (
            self.postnet([mel_before, encoder_masks], training=training) + mel_before
        )

        outputs = (mel_before, mel_after, duration_outputs)
        return outputs

    def _inference(self, input_ids, speaker_ids, speed_ratios, **kwargs):
        """Call logic.

        Inference pass: durations come from the duration predictor (scaled by
        speed_ratios) instead of ground truth.
        """
        attention_mask = tf.math.not_equal(input_ids, 0)
        embedding_output = self.embeddings([input_ids, speaker_ids], training=False)
        encoder_output = self.encoder(
            [embedding_output, attention_mask], training=False
        )
        last_encoder_hidden_states = encoder_output[0]

        # duration predictor, here use last_encoder_hidden_states, u can use more hidden_states layers
        # rather than just use last_hidden_states of encoder for duration_predictor.
        duration_outputs = self.duration_predictor(
            [last_encoder_hidden_states, attention_mask]
        )  # [batch_size, length]
        # Predictor works in log domain; invert with exp(x) - 1.
        duration_outputs = tf.math.exp(duration_outputs) - 1.0

        if speed_ratios is None:
            speed_ratios = tf.convert_to_tensor(np.array([1.0]), dtype=tf.float32)

        speed_ratios = tf.expand_dims(speed_ratios, 1)

        # Round scaled durations to integer frame counts.
        duration_outputs = tf.cast(
            tf.math.round(duration_outputs * speed_ratios), tf.int32
        )

        length_regulator_outputs, encoder_masks = self.length_regulator(
            [last_encoder_hidden_states, duration_outputs], training=False
        )

        # create decoder positional embedding
        decoder_pos = tf.range(
            1, tf.shape(length_regulator_outputs)[1] + 1, dtype=tf.int32
        )
        masked_decoder_pos = tf.expand_dims(decoder_pos, 0) * encoder_masks

        decoder_output = self.decoder(
            [length_regulator_outputs, speaker_ids, encoder_masks, masked_decoder_pos],
            training=False,
        )
        last_decoder_hidden_states = decoder_output[0]

        # here u can use sum or concat more than 1 hidden states layers from decoder.
        mel_before = self.mel_dense(last_decoder_hidden_states)
        mel_after = (
            self.postnet([mel_before, encoder_masks], training=False) + mel_before
        )

        outputs = (mel_before, mel_after, duration_outputs)
        return outputs

    def setup_inference_fn(self):
        # Two traced variants of _inference: a general one (dynamic batch)
        # and a batch-1 signature suitable for TFLite conversion.
        self.inference = tf.function(
            self._inference,
            experimental_relax_shapes=True,
            input_signature=[
                tf.TensorSpec(shape=[None, None], dtype=tf.int32, name="input_ids"),
                tf.TensorSpec(shape=[None,], dtype=tf.int32, name="speaker_ids"),
                tf.TensorSpec(shape=[None,], dtype=tf.float32, name="speed_ratios"),
            ],
        )

        self.inference_tflite = tf.function(
            self._inference,
            experimental_relax_shapes=True,
            input_signature=[
                tf.TensorSpec(shape=[1, None], dtype=tf.int32, name="input_ids"),
                tf.TensorSpec(shape=[1,], dtype=tf.int32, name="speaker_ids"),
                tf.TensorSpec(shape=[1,], dtype=tf.float32, name="speed_ratios"),
            ],
        )
910