Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
TensorSpeech
GitHub Repository: TensorSpeech/TensorFlowTTS
Path: blob/master/tensorflow_tts/models/tacotron2.py
1558 views
1
# -*- coding: utf-8 -*-
2
# Copyright 2020 The Tacotron-2 Authors, Minh Nguyen (@dathudeptrai), Eren Gölge (@erogol) and Jae Yoo (@jaeyoo)
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
# http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
16
"""Tacotron-2 Modules."""
17
18
import collections
19
20
import numpy as np
21
import tensorflow as tf
22
23
# TODO: once https://github.com/tensorflow/addons/pull/1964 is fixed,
24
# uncomment this line.
25
# from tensorflow_addons.seq2seq import dynamic_decode
26
from tensorflow_addons.seq2seq import BahdanauAttention, Decoder, Sampler
27
28
from tensorflow_tts.utils import dynamic_decode
29
30
from tensorflow_tts.models import BaseModel
31
32
33
def get_initializer(initializer_range=0.02):
    """Build a truncated-normal weight initializer.

    Args:
        initializer_range: float, standard deviation of the truncated
            normal distribution.

    Returns:
        A `tf.keras.initializers.TruncatedNormal` instance whose stddev is
        `initializer_range`.
    """
    stddev = initializer_range
    return tf.keras.initializers.TruncatedNormal(stddev=stddev)
41
42
43
def gelu(x):
    """Gaussian Error Linear Unit (exact erf-based form)."""
    gaussian_cdf = 0.5 * (1.0 + tf.math.erf(x / tf.math.sqrt(2.0)))
    return x * gaussian_cdf
47
48
49
def gelu_new(x):
    """Smoother Gaussian Error Linear Unit (tanh approximation)."""
    inner = np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3))
    approx_cdf = 0.5 * (1.0 + tf.tanh(inner))
    return x * approx_cdf
53
54
55
def swish(x):
    """Swish activation, delegated to `tf.nn.swish` (x * sigmoid(x))."""
    activated = tf.nn.swish(x)
    return activated
58
59
60
def mish(x):
    """Mish activation: x * tanh(softplus(x))."""
    softplus_x = tf.math.softplus(x)
    return x * tf.math.tanh(softplus_x)
62
63
64
# Map config activation-name strings to callables used by prenet/conv layers.
# NOTE: "relu" is a plain function while the others are Activation layers;
# both are callable, so lookups behave the same either way.
ACT2FN = {
    "identity": tf.keras.layers.Activation("linear"),
    "tanh": tf.keras.layers.Activation("tanh"),
    "gelu": tf.keras.layers.Activation(gelu),
    "relu": tf.keras.activations.relu,
    "swish": tf.keras.layers.Activation(swish),
    "gelu_new": tf.keras.layers.Activation(gelu_new),
    "mish": tf.keras.layers.Activation(mish),
}
73
74
75
class TFEmbedding(tf.keras.layers.Embedding):
    """Faster version of embedding.

    Same interface as `tf.keras.layers.Embedding`, but the lookup is done
    with `tf.gather_nd` over an expanded index tensor (claimed faster by the
    original authors — presumably on TPU/graph mode; not re-benchmarked here).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def call(self, inputs):
        """Look up embedding rows for integer ids.

        Args:
            inputs: integer Tensor of ids, any shape.

        Returns:
            Tensor of embeddings with one extra trailing embedding dim.
        """
        # Expand ids to index tuples so gather_nd selects whole rows of the
        # embedding matrix; cast guards against non-int32 id dtypes.
        inputs = tf.cast(tf.expand_dims(inputs, -1), tf.int32)
        outputs = tf.gather_nd(self.embeddings, inputs)
        return outputs
85
86
87
class TFTacotronConvBatchNorm(tf.keras.layers.Layer):
    """Tacotron-2 Convolutional Batchnorm module.

    One encoder/postnet building block: Conv1D -> sync batch-norm ->
    activation -> dropout, applied in that order.
    """

    def __init__(
        self, filters, kernel_size, dropout_rate, activation=None, name_idx=None
    ):
        """Init layers.

        Args:
            filters: int, number of Conv1D filters.
            kernel_size: int, Conv1D kernel width ("same" padding keeps length).
            dropout_rate: float, dropout probability.
            activation: str key into ACT2FN (e.g. "relu", "tanh", "identity").
            name_idx: int/str suffix used to name the sub-layers uniquely.
        """
        super().__init__()
        self.conv1d = tf.keras.layers.Conv1D(
            filters,
            kernel_size,
            kernel_initializer=get_initializer(0.02),
            padding="same",
            name="conv_._{}".format(name_idx),
        )
        # SyncBatchNormalization aggregates batch statistics across replicas
        # during distributed training.
        self.norm = tf.keras.layers.experimental.SyncBatchNormalization(
            axis=-1, name="batch_norm_._{}".format(name_idx)
        )
        self.dropout = tf.keras.layers.Dropout(
            rate=dropout_rate, name="dropout_._{}".format(name_idx)
        )
        self.act = ACT2FN[activation]

    def call(self, inputs, training=False):
        """Apply conv -> norm -> activation -> dropout.

        Args:
            inputs: float Tensor [batch_size, time, channels].
            training: bool, controls batch-norm statistics and dropout.

        Returns:
            float Tensor [batch_size, time, filters].
        """
        outputs = self.conv1d(inputs)
        outputs = self.norm(outputs, training=training)
        outputs = self.act(outputs)
        outputs = self.dropout(outputs, training=training)
        return outputs
115
116
117
class TFTacotronEmbeddings(tf.keras.layers.Layer):
    """Construct character/phoneme/positional/speaker embeddings."""

    def __init__(self, config, **kwargs):
        """Init variables.

        Args:
            config: model config providing vocab_size, embedding_hidden_size,
                initializer_range, n_speakers, layer_norm_eps and
                embedding_dropout_prob.
        """
        super().__init__(**kwargs)
        self.vocab_size = config.vocab_size
        self.embedding_hidden_size = config.embedding_hidden_size
        self.initializer_range = config.initializer_range
        self.config = config

        # Speaker conditioning is only created for multi-speaker models.
        if config.n_speakers > 1:
            self.speaker_embeddings = TFEmbedding(
                config.n_speakers,
                config.embedding_hidden_size,
                embeddings_initializer=get_initializer(self.initializer_range),
                name="speaker_embeddings",
            )
            self.speaker_fc = tf.keras.layers.Dense(
                units=config.embedding_hidden_size, name="speaker_fc"
            )

        self.LayerNorm = tf.keras.layers.LayerNormalization(
            epsilon=config.layer_norm_eps, name="LayerNorm"
        )
        self.dropout = tf.keras.layers.Dropout(config.embedding_dropout_prob)

    def build(self, input_shape):
        """Build shared character/phoneme embedding layers."""
        # The character table is a raw weight (not an Embedding layer) so it
        # can be created lazily inside the layer's own name scope.
        with tf.name_scope("character_embeddings"):
            self.character_embeddings = self.add_weight(
                "weight",
                shape=[self.vocab_size, self.embedding_hidden_size],
                initializer=get_initializer(self.initializer_range),
            )
        super().build(input_shape)

    def call(self, inputs, training=False):
        """Get character embeddings of inputs.
        Args:
            1. character, Tensor (int32) shape [batch_size, length].
            2. speaker_id, Tensor (int32) shape [batch_size]
        Returns:
            Tensor (float32) shape [batch_size, length, embedding_size].
        """
        return self._embedding(inputs, training=training)

    def _embedding(self, inputs, training=False):
        """Applies embedding based on inputs tensor."""
        input_ids, speaker_ids = inputs

        # create embeddings
        inputs_embeds = tf.gather(self.character_embeddings, input_ids)
        embeddings = inputs_embeds

        if self.config.n_speakers > 1:
            speaker_embeddings = self.speaker_embeddings(speaker_ids)
            # softplus keeps speaker features strictly positive.
            speaker_features = tf.math.softplus(self.speaker_fc(speaker_embeddings))
            # extended speaker embeddings: broadcast over the time axis.
            extended_speaker_features = speaker_features[:, tf.newaxis, :]
            # sum all embedding
            embeddings += extended_speaker_features

        # apply layer-norm and dropout for embeddings.
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings, training=training)

        return embeddings
185
186
187
class TFTacotronEncoderConvs(tf.keras.layers.Layer):
    """Stack of conv/batch-norm blocks used inside the Tacotron-2 encoder."""

    def __init__(self, config, **kwargs):
        """Create `config.n_conv_encoder` conv-batchnorm sub-layers."""
        super().__init__(**kwargs)
        self.conv_batch_norm = [
            TFTacotronConvBatchNorm(
                filters=config.encoder_conv_filters,
                kernel_size=config.encoder_conv_kernel_sizes,
                activation=config.encoder_conv_activation,
                dropout_rate=config.encoder_conv_dropout_rate,
                name_idx=layer_idx,
            )
            for layer_idx in range(config.n_conv_encoder)
        ]

    def call(self, inputs, training=False):
        """Apply every conv-batchnorm block in sequence.

        Args:
            inputs: float Tensor [batch_size, time, channels].
            training: bool, forwarded to batch-norm/dropout.

        Returns:
            float Tensor with encoder_conv_filters channels.
        """
        hidden = inputs
        for block in self.conv_batch_norm:
            hidden = block(hidden, training=training)
        return hidden
210
211
212
class TFTacotronEncoder(tf.keras.layers.Layer):
    """Tacotron-2 Encoder.

    Pipeline: embeddings -> conv/batch-norm stack -> bidirectional LSTM,
    plus an optional per-speaker feature added to the encoder outputs for
    multi-speaker models.
    """

    def __init__(self, config, **kwargs):
        """Init variables."""
        super().__init__(**kwargs)
        self.embeddings = TFTacotronEmbeddings(config, name="embeddings")
        self.convbn = TFTacotronEncoderConvs(config, name="conv_batch_norm")
        self.bilstm = tf.keras.layers.Bidirectional(
            tf.keras.layers.LSTM(
                units=config.encoder_lstm_units, return_sequences=True
            ),
            name="bilstm",
        )

        if config.n_speakers > 1:
            self.encoder_speaker_embeddings = TFEmbedding(
                config.n_speakers,
                config.embedding_hidden_size,
                embeddings_initializer=get_initializer(config.initializer_range),
                name="encoder_speaker_embeddings",
            )
            # Projects speaker embedding to the bi-LSTM output width (units * 2).
            self.encoder_speaker_fc = tf.keras.layers.Dense(
                units=config.encoder_lstm_units * 2, name="encoder_speaker_fc"
            )

        self.config = config

    def call(self, inputs, training=False):
        """Call logic.

        Args:
            inputs: list of [input_ids, speaker_ids, input_mask] where
                input_ids is int32 [batch_size, length], speaker_ids is
                int32 [batch_size] and input_mask is bool [batch_size, length].
            training: bool.

        Returns:
            float Tensor [batch_size, length, encoder_lstm_units * 2].
        """
        input_ids, speaker_ids, input_mask = inputs

        # create embedding and mask them since we sum
        # speaker embedding to all character embedding.
        input_embeddings = self.embeddings([input_ids, speaker_ids], training=training)

        # pass embeddings to convolution batch norm
        conv_outputs = self.convbn(input_embeddings, training=training)

        # bi-lstm; the mask zeroes contributions from padded positions.
        outputs = self.bilstm(conv_outputs, mask=input_mask)

        if self.config.n_speakers > 1:
            encoder_speaker_embeddings = self.encoder_speaker_embeddings(speaker_ids)
            encoder_speaker_features = tf.math.softplus(
                self.encoder_speaker_fc(encoder_speaker_embeddings)
            )
            # extended encoder speaker embeddings (broadcast over time).
            extended_encoder_speaker_features = encoder_speaker_features[
                :, tf.newaxis, :
            ]
            # sum to encoder outputs
            outputs += extended_encoder_speaker_features

        return outputs
267
268
269
class Tacotron2Sampler(Sampler):
    """Tacotron2 sampler for Seq2Seq training.

    Implements the `tfa.seq2seq.Sampler` interface: at every decoder step it
    decides whether decoding has finished and what the next decoder input is
    (ground-truth frames under teacher forcing, the model's own prediction
    at inference).
    """

    def __init__(
        self, config,
    ):
        super().__init__()
        self.config = config
        # create schedule factor.
        # the input of a next decoder cell is calculated by the formula:
        # next_inputs = ratio * prev_groundtruth_outputs + (1.0 - ratio) * prev_predicted_outputs.
        # ratio fixed at 1.0 here => pure teacher forcing during training.
        self._ratio = tf.constant(1.0, dtype=tf.float32)
        self._reduction_factor = self.config.reduction_factor

    def setup_target(self, targets, mel_lengths):
        """Setup ground-truth mel outputs for decoder.

        Args:
            targets: float Tensor [batch_size, mel_length, n_mels].
            mel_lengths: int Tensor [batch_size], valid frames per example.
        """
        self.mel_lengths = mel_lengths
        self.set_batch_size(tf.shape(targets)[0])
        # Keep every reduction_factor-th frame: the decoder emits
        # reduction_factor frames per step, so teacher-forced inputs are the
        # last frame of each output group.
        self.targets = targets[
            :, self._reduction_factor - 1 :: self._reduction_factor, :
        ]
        self.max_lengths = tf.tile([tf.shape(self.targets)[1]], [self._batch_size])

    @property
    def batch_size(self):
        # Set via setup_target (training) or set_batch_size (inference).
        return self._batch_size

    @property
    def sample_ids_shape(self):
        # Scalar dummy sample ids — Tacotron-2 does not sample tokens.
        return tf.TensorShape([])

    @property
    def sample_ids_dtype(self):
        return tf.int32

    @property
    def reduction_factor(self):
        return self._reduction_factor

    def initialize(self):
        """Return (Finished, next_inputs)."""
        # Start un-finished with an all-zero <go> frame.
        return (
            tf.tile([False], [self._batch_size]),
            tf.tile([[0.0]], [self._batch_size, self.config.n_mels]),
        )

    def sample(self, time, outputs, state):
        # Dummy zero sample ids; required by the Sampler interface only.
        return tf.tile([0], [self._batch_size])

    def next_inputs(
        self,
        time,
        outputs,
        state,
        sample_ids,
        stop_token_prediction,
        training=False,
        **kwargs,
    ):
        """Compute (finished, next_inputs, next_state) for the next step."""
        if training:
            # Teacher forcing: finished once all target steps are consumed.
            finished = time + 1 >= self.max_lengths
            next_inputs = (
                self._ratio * self.targets[:, time, :]
                + (1.0 - self._ratio) * outputs[:, -self.config.n_mels :]
            )
            next_state = state
            return (finished, next_inputs, next_state)
        else:
            # Inference: stop when every stop-token sigmoid rounds to 1.
            stop_token_prediction = tf.nn.sigmoid(stop_token_prediction)
            finished = tf.cast(tf.round(stop_token_prediction), tf.bool)
            finished = tf.reduce_all(finished)
            # Feed back the last predicted mel frame of this step's group.
            next_inputs = outputs[:, -self.config.n_mels :]
            next_state = state
            return (finished, next_inputs, next_state)

    def set_batch_size(self, batch_size):
        self._batch_size = batch_size
346
347
348
class TFTacotronLocationSensitiveAttention(BahdanauAttention):
    """Tacotron-2 Location Sensitive Attention module.

    Extends Bahdanau attention with "location" features computed by a 1-D
    convolution over the previous (optionally cumulative) alignments.
    At inference time attention can additionally be restricted to a sliding
    window around the previous alignment argmax (see `setup_window`).
    """

    def __init__(
        self,
        config,
        memory,
        mask_encoder=True,
        memory_sequence_length=None,
        is_cumulate=True,
    ):
        """Init variables.

        Args:
            config: model config (attention_dim, attention_filters,
                attention_kernel, encoder_lstm_units).
            memory: encoder hidden states, or None if set later via
                `setup_memory`.
            mask_encoder: bool, mask attention by memory lengths when True.
            memory_sequence_length: int Tensor [batch_size] of encoder lengths.
            is_cumulate: bool, accumulate alignments across steps when True.
        """
        memory_length = memory_sequence_length if (mask_encoder is True) else None
        super().__init__(
            units=config.attention_dim,
            memory=memory,
            memory_sequence_length=memory_length,
            probability_fn="softmax",
            name="LocationSensitiveAttention",
        )
        self.location_convolution = tf.keras.layers.Conv1D(
            filters=config.attention_filters,
            kernel_size=config.attention_kernel,
            padding="same",
            use_bias=False,
            name="location_conv",
        )
        self.location_layer = tf.keras.layers.Dense(
            units=config.attention_dim, use_bias=False, name="location_layer"
        )

        # Projects tanh-combined features down to one energy per timestep.
        self.v = tf.keras.layers.Dense(1, use_bias=True, name="scores_attention")
        self.config = config
        self.is_cumulate = is_cumulate
        self.use_window = False

    def setup_window(self, win_front=2, win_back=4):
        """Enable window masking; call only AFTER memory (self.keys) is set."""
        self.win_front = tf.constant(win_front, tf.int32)
        self.win_back = tf.constant(win_back, tf.int32)

        self._indices = tf.expand_dims(tf.range(tf.shape(self.keys)[1]), 0)
        self._indices = tf.tile(
            self._indices, [tf.shape(self.keys)[0], 1]
        )  # [batch_size, max_time]

        self.use_window = True

    def _compute_window_mask(self, max_alignments):
        """Compute window mask for inference.

        Args:
            max_alignments (int): [batch_size]

        Returns:
            float Tensor [batch_size, max_length] with 1.0 marking positions
            OUTSIDE [argmax - win_front, argmax + win_back].
        """
        expanded_max_alignments = tf.expand_dims(max_alignments, 1)  # [batch_size, 1]
        low = expanded_max_alignments - self.win_front
        high = expanded_max_alignments + self.win_back
        mlow = tf.cast((self._indices < low), tf.float32)
        mhigh = tf.cast((self._indices > high), tf.float32)
        mask = mlow + mhigh
        return mask  # [batch_size, max_length]

    def __call__(self, inputs, training=False):
        """Compute attention for one decoder step.

        Args:
            inputs: list of [query, state, prev_max_alignments] — the
                attention-LSTM output, previous (cumulative) alignments, and
                the previous alignment argmax per batch item.
            training: bool (unused here; kept for layer-call symmetry).

        Returns:
            (context, alignments, state): context vector, this step's
            alignments, and the updated (cumulative) alignment state.
        """
        query, state, prev_max_alignments = inputs

        processed_query = self.query_layer(query) if self.query_layer else query
        processed_query = tf.expand_dims(processed_query, 1)

        # Location features from the previous (cumulative) alignments.
        expanded_alignments = tf.expand_dims(state, axis=2)
        f = self.location_convolution(expanded_alignments)
        processed_location_features = self.location_layer(f)

        energy = self._location_sensitive_score(
            processed_query, processed_location_features, self.keys
        )

        # mask energy on inference steps: large negative outside the window
        # so softmax assigns them ~zero probability.
        if self.use_window is True:
            window_mask = self._compute_window_mask(prev_max_alignments)
            energy = energy + window_mask * -1e20

        alignments = self.probability_fn(energy, state)

        if self.is_cumulate:
            state = alignments + state
        else:
            state = alignments

        # Context = alignment-weighted sum over the encoder values.
        expanded_alignments = tf.expand_dims(alignments, 2)
        context = tf.reduce_sum(expanded_alignments * self.values, 1)

        return context, alignments, state

    def _location_sensitive_score(self, W_query, W_fil, W_keys):
        """Calculate location sensitive energy."""
        return tf.squeeze(self.v(tf.nn.tanh(W_keys + W_query + W_fil)), -1)

    def get_initial_state(self, batch_size, size):
        """Get initial alignments."""
        return tf.zeros(shape=[batch_size, size], dtype=tf.float32)

    def get_initial_context(self, batch_size):
        """Get initial attention (context sized to the bi-LSTM output)."""
        return tf.zeros(
            shape=[batch_size, self.config.encoder_lstm_units * 2], dtype=tf.float32
        )
452
453
454
class TFTacotronPrenet(tf.keras.layers.Layer):
    """Tacotron-2 prenet: a small stack of Dense layers with always-on
    dropout, applied to the previous mel frame before the decoder LSTMs."""

    def __init__(self, config, **kwargs):
        """Init variables.

        Args:
            config: model config with prenet_units, prenet_activation,
                n_prenet_layers and prenet_dropout_rate.
        """
        super().__init__(**kwargs)
        self.prenet_dense = [
            tf.keras.layers.Dense(
                units=config.prenet_units,
                activation=ACT2FN[config.prenet_activation],
                name="dense_._{}".format(i),
            )
            for i in range(config.n_prenet_layers)
        ]
        self.dropout = tf.keras.layers.Dropout(
            rate=config.prenet_dropout_rate, name="dropout"
        )

    def call(self, inputs, training=False):
        """Call logic."""
        outputs = inputs
        for layer in self.prenet_dense:
            outputs = layer(outputs)
            # NOTE: dropout is deliberately applied with training=True even
            # at inference — this matches the Tacotron-2 design of keeping
            # prenet dropout active for output variation. Do not "fix".
            outputs = self.dropout(outputs, training=True)
        return outputs
479
480
481
class TFTacotronPostnet(tf.keras.layers.Layer):
    """Tacotron-2 postnet.

    A stack of conv/batch-norm blocks that predicts a residual used to
    refine the decoder's mel-spectrogram output (see TFTacotron2.call).
    """

    def __init__(self, config, **kwargs):
        """Init variables.

        Args:
            config: model config with n_conv_postnet, postnet_conv_filters,
                postnet_conv_kernel_sizes and postnet_dropout_rate.
        """
        super().__init__(**kwargs)
        self.conv_batch_norm = []
        for i in range(config.n_conv_postnet):
            conv = TFTacotronConvBatchNorm(
                filters=config.postnet_conv_filters,
                kernel_size=config.postnet_conv_kernel_sizes,
                dropout_rate=config.postnet_dropout_rate,
                # tanh on every layer except the last, which stays linear.
                activation="identity" if i + 1 == config.n_conv_postnet else "tanh",
                name_idx=i,
            )
            self.conv_batch_norm.append(conv)

    def call(self, inputs, training=False):
        """Run inputs through every conv/batch-norm block in order.

        Args:
            inputs: float Tensor [batch_size, time, n_mels].
            training: bool, forwarded to batch-norm/dropout.

        Returns:
            float Tensor [batch_size, time, postnet_conv_filters] (last
            block is linear, so values are unbounded).
        """
        outputs = inputs
        # Iterate the layers directly — the index produced by the previous
        # `enumerate` was never used.
        for conv in self.conv_batch_norm:
            outputs = conv(outputs, training=training)
        return outputs
504
505
506
# Recurrent state threaded through decoder steps by dynamic_decode.
TFTacotronDecoderCellState = collections.namedtuple(
    "TFTacotronDecoderCellState",
    [
        "attention_lstm_state",  # state of the attention LSTM cell.
        "decoder_lstms_state",  # state of the stacked decoder LSTM cells.
        "context",  # attention context vector.
        "time",  # current decode step (scalar int32).
        "state",  # (cumulative) attention alignments.
        "alignment_history",  # TensorArray of alignments, or () for TFLite.
        "max_alignments",  # argmax of the latest alignments per batch item.
    ],
)
518
519
# Per-step decoder outputs collected by dynamic_decode:
# mel frames, stop-token logits, and dummy sample ids.
TFDecoderOutput = collections.namedtuple(
    "TFDecoderOutput", ("mel_output", "token_output", "sample_id")
)
522
523
524
class TFTacotronDecoderCell(tf.keras.layers.AbstractRNNCell):
525
"""Tacotron-2 custom decoder cell."""
526
527
    def __init__(self, config, enable_tflite_convertible=False, **kwargs):
        """Init variables.

        Args:
            config: model config (decoder_lstm_units, n_lstm_decoder,
                attention_type, n_mels, reduction_factor, ...).
            enable_tflite_convertible: bool, skip TensorArray-based alignment
                history when True so the graph is TFLite-convertible.
        """
        super().__init__(**kwargs)
        self.enable_tflite_convertible = enable_tflite_convertible
        self.prenet = TFTacotronPrenet(config, name="prenet")

        # define lstm cell on decoder.
        # TODO(@dathudeptrai) switch to zone-out lstm.
        self.attention_lstm = tf.keras.layers.LSTMCell(
            units=config.decoder_lstm_units, name="attention_lstm_cell"
        )
        lstm_cells = []
        for i in range(config.n_lstm_decoder):
            lstm_cell = tf.keras.layers.LSTMCell(
                units=config.decoder_lstm_units, name="lstm_cell_._{}".format(i)
            )
            lstm_cells.append(lstm_cell)
        self.decoder_lstms = tf.keras.layers.StackedRNNCells(
            lstm_cells, name="decoder_lstms"
        )

        # define attention layer.
        if config.attention_type == "lsa":
            # create location-sensitive attention; memory is attached later
            # via setup_memory once encoder outputs exist.
            self.attention_layer = TFTacotronLocationSensitiveAttention(
                config,
                memory=None,
                mask_encoder=True,
                memory_sequence_length=None,
                is_cumulate=True,
            )
        else:
            raise ValueError("Only lsa (location-sensitive attention) is supported")

        # frame, stop projection layer.
        self.frame_projection = tf.keras.layers.Dense(
            units=config.n_mels * config.reduction_factor, name="frame_projection"
        )
        self.stop_projection = tf.keras.layers.Dense(
            units=config.reduction_factor, name="stop_projection"
        )

        self.config = config
570
571
    def set_alignment_size(self, alignment_size):
        # Encoder time dimension; sizes the attention alignment vectors.
        self.alignment_size = alignment_size
573
574
    @property
    def output_size(self):
        """Return output (mel) size."""
        # Equals n_mels * reduction_factor (units of frame_projection).
        return self.frame_projection.units
578
579
@property
580
def state_size(self):
581
"""Return hidden state size."""
582
return TFTacotronDecoderCellState(
583
attention_lstm_state=self.attention_lstm.state_size,
584
decoder_lstms_state=self.decoder_lstms.state_size,
585
time=tf.TensorShape([]),
586
attention=self.config.attention_dim,
587
state=self.alignment_size,
588
alignment_history=(),
589
max_alignments=tf.TensorShape([1]),
590
)
591
592
    def get_initial_state(self, batch_size):
        """Get initial states.

        Args:
            batch_size: scalar int tensor.

        Returns:
            TFTacotronDecoderCellState with zeroed LSTM states, zero context
            and alignments, time 0, and an empty alignment history.
        """
        initial_attention_lstm_cell_states = self.attention_lstm.get_initial_state(
            None, batch_size, dtype=tf.float32
        )
        initial_decoder_lstms_cell_states = self.decoder_lstms.get_initial_state(
            None, batch_size, dtype=tf.float32
        )
        # Context width matches the bi-LSTM encoder output (units * 2).
        initial_context = tf.zeros(
            shape=[batch_size, self.config.encoder_lstm_units * 2], dtype=tf.float32
        )
        initial_state = self.attention_layer.get_initial_state(
            batch_size, size=self.alignment_size
        )
        if self.enable_tflite_convertible:
            # TensorArray is not TFLite-convertible; skip history tracking.
            initial_alignment_history = ()
        else:
            initial_alignment_history = tf.TensorArray(
                dtype=tf.float32, size=0, dynamic_size=True
            )
        return TFTacotronDecoderCellState(
            attention_lstm_state=initial_attention_lstm_cell_states,
            decoder_lstms_state=initial_decoder_lstms_cell_states,
            time=tf.zeros([], dtype=tf.int32),
            context=initial_context,
            state=initial_state,
            alignment_history=initial_alignment_history,
            max_alignments=tf.zeros([batch_size], dtype=tf.int32),
        )
621
622
    def call(self, inputs, states, training=False):
        """Run one decoder step.

        Args:
            inputs: float Tensor [batch_size, n_mels], previous mel frame.
            states: TFTacotronDecoderCellState from the previous step.
            training: bool.

        Returns:
            ((decoder_outputs, stop_tokens), new_states): mel frames of size
            n_mels * reduction_factor, stop-token logits, and the updated
            recurrent state.
        """
        decoder_input = inputs

        # 1. apply prenet for decoder_input.
        prenet_out = self.prenet(decoder_input, training=training)  # [batch_size, dim]

        # 2. concat prenet_out and prev context vector
        # then use it as input of attention lstm layer.
        attention_lstm_input = tf.concat([prenet_out, states.context], axis=-1)
        attention_lstm_output, next_attention_lstm_state = self.attention_lstm(
            attention_lstm_input, states.attention_lstm_state
        )

        # 3. compute context, alignment and cumulative alignment.
        prev_state = states.state
        if not self.enable_tflite_convertible:
            prev_alignment_history = states.alignment_history
        prev_max_alignments = states.max_alignments
        context, alignments, state = self.attention_layer(
            [attention_lstm_output, prev_state, prev_max_alignments], training=training,
        )

        # 4. run decoder lstm(s)
        decoder_lstms_input = tf.concat([attention_lstm_output, context], axis=-1)
        decoder_lstms_output, next_decoder_lstms_state = self.decoder_lstms(
            decoder_lstms_input, states.decoder_lstms_state
        )

        # 5. compute frame feature and stop token.
        projection_inputs = tf.concat([decoder_lstms_output, context], axis=-1)
        decoder_outputs = self.frame_projection(projection_inputs)

        stop_inputs = tf.concat([decoder_lstms_output, decoder_outputs], axis=-1)
        stop_tokens = self.stop_projection(stop_inputs)

        # 6. save alignment history to visualize.
        if self.enable_tflite_convertible:
            alignment_history = ()
        else:
            alignment_history = prev_alignment_history.write(states.time, alignments)

        # 7. return new states.
        new_states = TFTacotronDecoderCellState(
            attention_lstm_state=next_attention_lstm_state,
            decoder_lstms_state=next_decoder_lstms_state,
            time=states.time + 1,
            context=context,
            state=state,
            alignment_history=alignment_history,
            max_alignments=tf.argmax(alignments, -1, output_type=tf.int32),
        )

        return (decoder_outputs, stop_tokens), new_states
676
677
678
class TFTacotronDecoder(Decoder):
    """Tacotron-2 Decoder.

    Thin tfa.seq2seq `Decoder` implementation that wires the decoder cell
    together with the sampler for use with `dynamic_decode`.
    """

    def __init__(
        self,
        decoder_cell,
        decoder_sampler,
        output_layer=None,
        enable_tflite_convertible=False,
    ):
        """Initial variables.

        Args:
            decoder_cell: TFTacotronDecoderCell instance.
            decoder_sampler: Tacotron2Sampler instance.
            output_layer: optional layer applied to the cell's mel outputs.
            enable_tflite_convertible: bool, TFLite-friendly shapes if True.
        """
        # NOTE(review): super().__init__() is not called here; the tfa
        # Decoder base appears to tolerate this — confirm before changing.
        self.cell = decoder_cell
        self.sampler = decoder_sampler
        self.output_layer = output_layer
        self.enable_tflite_convertible = enable_tflite_convertible

    def setup_decoder_init_state(self, decoder_init_state):
        # Must be called before decoding starts (see TFTacotron2.call).
        self.initial_state = decoder_init_state

    def initialize(self, **kwargs):
        """Return (finished, first_inputs, initial_state) for dynamic_decode."""
        return self.sampler.initialize() + (self.initial_state,)

    @property
    def output_size(self):
        """Per-step output shapes as a TFDecoderOutput of TensorShapes."""
        return TFDecoderOutput(
            mel_output=tf.nest.map_structure(
                lambda shape: tf.TensorShape(shape), self.cell.output_size
            ),
            token_output=tf.TensorShape(self.sampler.reduction_factor),
            sample_id=tf.TensorShape([1])
            if self.enable_tflite_convertible
            else self.sampler.sample_ids_shape,  # tf.TensorShape([])
        )

    @property
    def output_dtype(self):
        return TFDecoderOutput(tf.float32, tf.float32, self.sampler.sample_ids_dtype)

    @property
    def batch_size(self):
        return self.sampler._batch_size

    def step(self, time, inputs, state, training=False):
        """Run one decode step: cell forward, sample, compute next inputs."""
        (mel_outputs, stop_tokens), cell_state = self.cell(
            inputs, state, training=training
        )
        if self.output_layer is not None:
            mel_outputs = self.output_layer(mel_outputs)
        sample_ids = self.sampler.sample(
            time=time, outputs=mel_outputs, state=cell_state
        )
        (finished, next_inputs, next_state) = self.sampler.next_inputs(
            time=time,
            outputs=mel_outputs,
            state=cell_state,
            sample_ids=sample_ids,
            stop_token_prediction=stop_tokens,
            training=training,
        )

        outputs = TFDecoderOutput(mel_outputs, stop_tokens, sample_ids)
        return (outputs, next_state, next_inputs, finished)
740
741
742
class TFTacotron2(BaseModel):
743
"""Tensorflow tacotron-2 model."""
744
745
def __init__(self, config, **kwargs):
746
"""Initalize tacotron-2 layers."""
747
enable_tflite_convertible = kwargs.pop("enable_tflite_convertible", False)
748
super().__init__(self, **kwargs)
749
self.encoder = TFTacotronEncoder(config, name="encoder")
750
self.decoder_cell = TFTacotronDecoderCell(
751
config,
752
name="decoder_cell",
753
enable_tflite_convertible=enable_tflite_convertible,
754
)
755
self.decoder = TFTacotronDecoder(
756
self.decoder_cell,
757
Tacotron2Sampler(config),
758
enable_tflite_convertible=enable_tflite_convertible,
759
)
760
self.postnet = TFTacotronPostnet(config, name="post_net")
761
self.post_projection = tf.keras.layers.Dense(
762
units=config.n_mels, name="residual_projection"
763
)
764
765
self.use_window_mask = False
766
self.maximum_iterations = 4000
767
self.enable_tflite_convertible = enable_tflite_convertible
768
self.config = config
769
770
def setup_window(self, win_front, win_back):
771
"""Call only for inference."""
772
self.use_window_mask = True
773
self.win_front = win_front
774
self.win_back = win_back
775
776
    def setup_maximum_iterations(self, maximum_iterations):
        """Call only for inference."""
        # Cap on decoder steps passed to dynamic_decode (default 4000).
        self.maximum_iterations = maximum_iterations
779
780
def _build(self):
781
input_ids = np.array([[1, 2, 3, 4, 5, 6, 7, 8, 9]])
782
input_lengths = np.array([9])
783
speaker_ids = np.array([0])
784
mel_outputs = np.random.normal(size=(1, 50, 80)).astype(np.float32)
785
mel_lengths = np.array([50])
786
self(
787
input_ids,
788
input_lengths,
789
speaker_ids,
790
mel_outputs,
791
mel_lengths,
792
10,
793
training=True,
794
)
795
796
    def call(
        self,
        input_ids,
        input_lengths,
        speaker_ids,
        mel_gts,
        mel_lengths,
        maximum_iterations=None,
        use_window_mask=False,
        win_front=2,
        win_back=3,
        training=False,
        **kwargs,
    ):
        """Call logic (teacher-forced forward pass).

        Args:
            input_ids: int32 Tensor [batch_size, length], character ids.
            input_lengths: int32 Tensor [batch_size].
            speaker_ids: int32 Tensor [batch_size].
            mel_gts: float32 Tensor [batch_size, mel_length, n_mels], targets.
            mel_lengths: int32 Tensor [batch_size].
            maximum_iterations: optional cap on decoder steps.
            use_window_mask: bool, enable windowed attention masking.
            win_front: window extent before the alignment argmax.
            win_back: window extent after the alignment argmax.
            training: bool.

        Returns:
            (decoder_outputs, mel_outputs, stop_token_prediction,
            alignment_history).
        """
        # create input-mask based on input_lengths
        input_mask = tf.sequence_mask(
            input_lengths,
            maxlen=tf.reduce_max(input_lengths),
            name="input_sequence_masks",
        )

        # Encoder Step.
        encoder_hidden_states = self.encoder(
            [input_ids, speaker_ids, input_mask], training=training
        )

        batch_size = tf.shape(encoder_hidden_states)[0]
        alignment_size = tf.shape(encoder_hidden_states)[1]

        # Setup some initial placeholders for decoder step. Include:
        # 1. mel_gts, mel_lengths for teacher forcing mode.
        # 2. alignment_size for attention size.
        # 3. initial state for decoder cell.
        # 4. memory (encoder hidden state) for attention mechanism.
        self.decoder.sampler.setup_target(targets=mel_gts, mel_lengths=mel_lengths)
        self.decoder.cell.set_alignment_size(alignment_size)
        self.decoder.setup_decoder_init_state(
            self.decoder.cell.get_initial_state(batch_size)
        )
        self.decoder.cell.attention_layer.setup_memory(
            memory=encoder_hidden_states,
            memory_sequence_length=input_lengths,  # use for mask attention.
        )
        if use_window_mask:
            self.decoder.cell.attention_layer.setup_window(
                win_front=win_front, win_back=win_back
            )

        # run decode step.
        (
            (frames_prediction, stop_token_prediction, _),
            final_decoder_state,
            _,
        ) = dynamic_decode(
            self.decoder,
            maximum_iterations=maximum_iterations,
            enable_tflite_convertible=self.enable_tflite_convertible,
            training=training,
        )

        # Unfold reduction-factor groups back into individual mel frames.
        decoder_outputs = tf.reshape(
            frames_prediction, [batch_size, -1, self.config.n_mels]
        )
        stop_token_prediction = tf.reshape(stop_token_prediction, [batch_size, -1])

        # Postnet predicts a residual refinement added to the decoder mels.
        residual = self.postnet(decoder_outputs, training=training)
        residual_projection = self.post_projection(residual)

        mel_outputs = decoder_outputs + residual_projection

        if self.enable_tflite_convertible:
            # Drop all-zero (padded) frames so TFLite outputs carry no padding.
            mask = tf.math.not_equal(
                tf.cast(
                    tf.reduce_sum(tf.abs(decoder_outputs), axis=-1), dtype=tf.int32
                ),
                0,
            )
            decoder_outputs = tf.expand_dims(
                tf.boolean_mask(decoder_outputs, mask), axis=0
            )
            mel_outputs = tf.expand_dims(tf.boolean_mask(mel_outputs, mask), axis=0)
            alignment_history = ()
        else:
            # Stack TensorArray to [batch, encoder_time, decoder_time].
            alignment_history = tf.transpose(
                final_decoder_state.alignment_history.stack(), [1, 2, 0]
            )

        return decoder_outputs, mel_outputs, stop_token_prediction, alignment_history
885
886
    @tf.function(
        experimental_relax_shapes=True,
        input_signature=[
            tf.TensorSpec([None, None], dtype=tf.int32, name="input_ids"),
            tf.TensorSpec([None,], dtype=tf.int32, name="input_lengths"),
            tf.TensorSpec([None,], dtype=tf.int32, name="speaker_ids"),
        ],
    )
    def inference(self, input_ids, input_lengths, speaker_ids, **kwargs):
        """Inference forward pass (stop-token-driven decoding).

        Args:
            input_ids: int32 Tensor [batch_size, length].
            input_lengths: int32 Tensor [batch_size].
            speaker_ids: int32 Tensor [batch_size].

        Returns:
            (decoder_outputs, mel_outputs, stop_token_predictions,
            alignment_historys).
        """
        # create input-mask based on input_lengths
        input_mask = tf.sequence_mask(
            input_lengths,
            maxlen=tf.reduce_max(input_lengths),
            name="input_sequence_masks",
        )

        # Encoder Step.
        encoder_hidden_states = self.encoder(
            [input_ids, speaker_ids, input_mask], training=False
        )

        batch_size = tf.shape(encoder_hidden_states)[0]
        alignment_size = tf.shape(encoder_hidden_states)[1]

        # Setup some initial placeholders for decoder step. Include:
        # 1. batch_size for inference.
        # 2. alignment_size for attention size.
        # 3. initial state for decoder cell.
        # 4. memory (encoder hidden state) for attention mechanism.
        # 5. window front/back to solve long sentence synthesize problems. (call after setup memory.)
        self.decoder.sampler.set_batch_size(batch_size)
        self.decoder.cell.set_alignment_size(alignment_size)
        self.decoder.setup_decoder_init_state(
            self.decoder.cell.get_initial_state(batch_size)
        )
        self.decoder.cell.attention_layer.setup_memory(
            memory=encoder_hidden_states,
            memory_sequence_length=input_lengths,  # use for mask attention.
        )
        if self.use_window_mask:
            self.decoder.cell.attention_layer.setup_window(
                win_front=self.win_front, win_back=self.win_back
            )

        # run decode step.
        (
            (frames_prediction, stop_token_prediction, _),
            final_decoder_state,
            _,
        ) = dynamic_decode(
            self.decoder, maximum_iterations=self.maximum_iterations, training=False
        )

        # Unfold reduction-factor groups back into individual mel frames.
        decoder_outputs = tf.reshape(
            frames_prediction, [batch_size, -1, self.config.n_mels]
        )
        stop_token_predictions = tf.reshape(stop_token_prediction, [batch_size, -1])

        # Postnet residual refinement.
        residual = self.postnet(decoder_outputs, training=False)
        residual_projection = self.post_projection(residual)

        mel_outputs = decoder_outputs + residual_projection

        # [batch, encoder_time, decoder_time] alignments for visualization.
        alignment_historys = tf.transpose(
            final_decoder_state.alignment_history.stack(), [1, 2, 0]
        )

        return decoder_outputs, mel_outputs, stop_token_predictions, alignment_historys
955
956
    @tf.function(
        experimental_relax_shapes=True,
        input_signature=[
            tf.TensorSpec([1, None], dtype=tf.int32, name="input_ids"),
            tf.TensorSpec([1,], dtype=tf.int32, name="input_lengths"),
            tf.TensorSpec([1,], dtype=tf.int32, name="speaker_ids"),
        ],
    )
    def inference_tflite(self, input_ids, input_lengths, speaker_ids, **kwargs):
        """TFLite-friendly inference pass (fixed batch size 1).

        Same flow as `inference`, but the input signature pins batch size 1
        and, when `enable_tflite_convertible` is set, padded frames are
        stripped and no alignment history is collected.

        Returns:
            (decoder_outputs, mel_outputs, stop_token_predictions,
            alignment_historys).
        """
        # create input-mask based on input_lengths
        input_mask = tf.sequence_mask(
            input_lengths,
            maxlen=tf.reduce_max(input_lengths),
            name="input_sequence_masks",
        )

        # Encoder Step.
        encoder_hidden_states = self.encoder(
            [input_ids, speaker_ids, input_mask], training=False
        )

        batch_size = tf.shape(encoder_hidden_states)[0]
        alignment_size = tf.shape(encoder_hidden_states)[1]

        # Setup some initial placeholders for decoder step. Include:
        # 1. batch_size for inference.
        # 2. alignment_size for attention size.
        # 3. initial state for decoder cell.
        # 4. memory (encoder hidden state) for attention mechanism.
        # 5. window front/back to solve long sentence synthesize problems. (call after setup memory.)
        self.decoder.sampler.set_batch_size(batch_size)
        self.decoder.cell.set_alignment_size(alignment_size)
        self.decoder.setup_decoder_init_state(
            self.decoder.cell.get_initial_state(batch_size)
        )
        self.decoder.cell.attention_layer.setup_memory(
            memory=encoder_hidden_states,
            memory_sequence_length=input_lengths,  # use for mask attention.
        )
        if self.use_window_mask:
            self.decoder.cell.attention_layer.setup_window(
                win_front=self.win_front, win_back=self.win_back
            )

        # run decode step.
        (
            (frames_prediction, stop_token_prediction, _),
            final_decoder_state,
            _,
        ) = dynamic_decode(
            self.decoder,
            maximum_iterations=self.maximum_iterations,
            enable_tflite_convertible=self.enable_tflite_convertible,
            training=False,
        )

        # Unfold reduction-factor groups back into individual mel frames.
        decoder_outputs = tf.reshape(
            frames_prediction, [batch_size, -1, self.config.n_mels]
        )
        stop_token_predictions = tf.reshape(stop_token_prediction, [batch_size, -1])

        # Postnet residual refinement.
        residual = self.postnet(decoder_outputs, training=False)
        residual_projection = self.post_projection(residual)

        mel_outputs = decoder_outputs + residual_projection

        if self.enable_tflite_convertible:
            # Drop all-zero (padded) frames so TFLite outputs carry no padding.
            mask = tf.math.not_equal(
                tf.cast(
                    tf.reduce_sum(tf.abs(decoder_outputs), axis=-1), dtype=tf.int32
                ),
                0,
            )
            decoder_outputs = tf.expand_dims(
                tf.boolean_mask(decoder_outputs, mask), axis=0
            )
            mel_outputs = tf.expand_dims(tf.boolean_mask(mel_outputs, mask), axis=0)
            alignment_historys = ()
        else:
            alignment_historys = tf.transpose(
                final_decoder_state.alignment_history.stack(), [1, 2, 0]
            )

        return decoder_outputs, mel_outputs, stop_token_predictions, alignment_historys
1041
1042