"""
Title: Automatic Speech Recognition with Transformer
Author: [Apoorv Nandan](https://twitter.com/NandanApoorv)
Date created: 2021/01/13
Last modified: 2021/01/13
Description: Training a sequence-to-sequence Transformer for automatic speech recognition.
Accelerator: GPU
"""

"""
## Introduction

Automatic speech recognition (ASR) consists of transcribing audio speech segments into text.
ASR can be treated as a sequence-to-sequence problem, where the
audio can be represented as a sequence of feature vectors
and the text as a sequence of characters, words, or subword tokens.

For this demonstration, we will use the LJSpeech dataset from the
[LibriVox](https://librivox.org/) project. It consists of short
audio clips of a single speaker reading passages from 7 non-fiction books.
Our model will be similar to the original Transformer (both encoder and decoder)
as proposed in the paper, "Attention is All You Need".


**References:**

- [Attention is All You Need](https://papers.nips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf)
- [Very Deep Self-Attention Networks for End-to-End Speech Recognition](https://arxiv.org/abs/1904.13377)
- [Speech Transformers](https://ieeexplore.ieee.org/document/8462506)
- [LJSpeech Dataset](https://keithito.com/LJ-Speech-Dataset/)
"""

import re
import os

os.environ["KERAS_BACKEND"] = "tensorflow"

from glob import glob
import tensorflow as tf
import keras
from keras import layers

"""
44
## Define the Transformer Input Layer
45
46
When processing past target tokens for the decoder, we compute the sum of
47
position embeddings and token embeddings.
48
49
When processing audio features, we apply convolutional layers to downsample
50
them (via convolution strides) and process local relationships.
51
"""
52
53
54
class TokenEmbedding(layers.Layer):
55
def __init__(self, num_vocab=1000, maxlen=100, num_hid=64):
56
super().__init__()
57
self.emb = keras.layers.Embedding(num_vocab, num_hid)
58
self.pos_emb = layers.Embedding(input_dim=maxlen, output_dim=num_hid)
59
60
def call(self, x):
61
maxlen = tf.shape(x)[-1]
62
x = self.emb(x)
63
positions = tf.range(start=0, limit=maxlen, delta=1)
64
positions = self.pos_emb(positions)
65
return x + positions
66
67
68
class SpeechFeatureEmbedding(layers.Layer):
69
def __init__(self, num_hid=64, maxlen=100):
70
super().__init__()
71
self.conv1 = keras.layers.Conv1D(
72
num_hid, 11, strides=2, padding="same", activation="relu"
73
)
74
self.conv2 = keras.layers.Conv1D(
75
num_hid, 11, strides=2, padding="same", activation="relu"
76
)
77
self.conv3 = keras.layers.Conv1D(
78
num_hid, 11, strides=2, padding="same", activation="relu"
79
)
80
81
def call(self, x):
82
x = self.conv1(x)
83
x = self.conv2(x)
84
return self.conv3(x)
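

"""
As a quick shape check (illustrative only, using the spectrogram shape produced by the
preprocessing further below): the three stride-2 convolutions in `SpeechFeatureEmbedding`
downsample the time axis by a factor of 8, while `TokenEmbedding` keeps the sequence
length unchanged and simply adds a learned position embedding to each token embedding.

```python
feats = tf.random.normal((1, 2754, 129))  # (batch, spectrogram frames, frequency bins)
print(SpeechFeatureEmbedding(num_hid=64)(feats).shape)  # (1, 345, 64) -- 2754 frames, ~8x shorter

tokens = tf.zeros((1, 200), dtype=tf.int32)
print(TokenEmbedding(num_vocab=34, maxlen=200, num_hid=64)(tokens).shape)  # (1, 200, 64)
```
"""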


"""
## Transformer Encoder Layer
"""


class TransformerEncoder(layers.Layer):
    def __init__(self, embed_dim, num_heads, feed_forward_dim, rate=0.1):
        super().__init__()
        self.att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.ffn = keras.Sequential(
            [
                layers.Dense(feed_forward_dim, activation="relu"),
                layers.Dense(embed_dim),
            ]
        )
        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(rate)
        self.dropout2 = layers.Dropout(rate)

    def call(self, inputs, training=False):
        attn_output = self.att(inputs, inputs)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(inputs + attn_output)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(out1 + ffn_output)
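

"""
A single encoder layer preserves the sequence length and embedding size, so several of
them can be stacked (as the full model below does) to refine the representation.
A minimal standalone check (illustrative values, not needed for training):

```python
enc_layer = TransformerEncoder(embed_dim=64, num_heads=2, feed_forward_dim=128)
print(enc_layer(tf.random.normal((1, 50, 64))).shape)  # (1, 50, 64) -- shape is preserved
```
"""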


"""
## Transformer Decoder Layer
"""


class TransformerDecoder(layers.Layer):
    def __init__(self, embed_dim, num_heads, feed_forward_dim, dropout_rate=0.1):
        super().__init__()
        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm3 = layers.LayerNormalization(epsilon=1e-6)
        self.self_att = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim
        )
        self.enc_att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.self_dropout = layers.Dropout(0.5)
        self.enc_dropout = layers.Dropout(0.1)
        self.ffn_dropout = layers.Dropout(0.1)
        self.ffn = keras.Sequential(
            [
                layers.Dense(feed_forward_dim, activation="relu"),
                layers.Dense(embed_dim),
            ]
        )

    def causal_attention_mask(self, batch_size, n_dest, n_src, dtype):
        """Masks the upper half of the dot product matrix in self attention.

        This prevents the flow of information from future tokens to the current token.
        1's in the lower triangle, counting from the lower right corner.
        """
        i = tf.range(n_dest)[:, None]
        j = tf.range(n_src)
        m = i >= j - n_src + n_dest
        mask = tf.cast(m, dtype)
        mask = tf.reshape(mask, [1, n_dest, n_src])
        mult = tf.concat(
            [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)], 0
        )
        return tf.tile(mask, mult)

    def call(self, enc_out, target):
        input_shape = tf.shape(target)
        batch_size = input_shape[0]
        seq_len = input_shape[1]
        causal_mask = self.causal_attention_mask(batch_size, seq_len, seq_len, tf.bool)
        target_att = self.self_att(target, target, attention_mask=causal_mask)
        target_norm = self.layernorm1(target + self.self_dropout(target_att))
        enc_out = self.enc_att(target_norm, enc_out)
        enc_out_norm = self.layernorm2(self.enc_dropout(enc_out) + target_norm)
        ffn_out = self.ffn(enc_out_norm)
        ffn_out_norm = self.layernorm3(enc_out_norm + self.ffn_dropout(ffn_out))
        return ffn_out_norm
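

"""
For intuition, the causal mask is just a lower-triangular matrix: position `i` of the
decoder input may only attend to positions `0..i`. A quick illustration for a length-4
sequence (this check is not part of the training code):

```python
dec_layer = TransformerDecoder(embed_dim=64, num_heads=2, feed_forward_dim=128)
print(dec_layer.causal_attention_mask(tf.constant(1), 4, 4, tf.int32))
# shape (1, 4, 4):
# [[[1 0 0 0]
#   [1 1 0 0]
#   [1 1 1 0]
#   [1 1 1 1]]]
```
"""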


"""
## Complete the Transformer model

Our model takes audio spectrograms as inputs and predicts a sequence of characters.
During training, we use teacher forcing: the decoder receives the target character
sequence without its last token as input, and is trained to predict the same sequence
shifted one position to the left, i.e. the next character at every step. During
inference, the decoder uses its own past predictions to predict the next token.
"""


class Transformer(keras.Model):
    def __init__(
        self,
        num_hid=64,
        num_head=2,
        num_feed_forward=128,
        source_maxlen=100,
        target_maxlen=100,
        num_layers_enc=4,
        num_layers_dec=1,
        num_classes=10,
    ):
        super().__init__()
        self.loss_metric = keras.metrics.Mean(name="loss")
        self.num_layers_enc = num_layers_enc
        self.num_layers_dec = num_layers_dec
        self.target_maxlen = target_maxlen
        self.num_classes = num_classes

        self.enc_input = SpeechFeatureEmbedding(num_hid=num_hid, maxlen=source_maxlen)
        self.dec_input = TokenEmbedding(
            num_vocab=num_classes, maxlen=target_maxlen, num_hid=num_hid
        )

        self.encoder = keras.Sequential(
            [self.enc_input]
            + [
                TransformerEncoder(num_hid, num_head, num_feed_forward)
                for _ in range(num_layers_enc)
            ]
        )

        for i in range(num_layers_dec):
            setattr(
                self,
                f"dec_layer_{i}",
                TransformerDecoder(num_hid, num_head, num_feed_forward),
            )

        self.classifier = layers.Dense(num_classes)

    def decode(self, enc_out, target):
        y = self.dec_input(target)
        for i in range(self.num_layers_dec):
            y = getattr(self, f"dec_layer_{i}")(enc_out, y)
        return y

    def call(self, inputs):
        source = inputs[0]
        target = inputs[1]
        x = self.encoder(source)
        y = self.decode(x, target)
        return self.classifier(y)

    @property
    def metrics(self):
        return [self.loss_metric]

    def train_step(self, batch):
        """Processes one batch inside model.fit()."""
        source = batch["source"]
        target = batch["target"]
        dec_input = target[:, :-1]
        dec_target = target[:, 1:]
        with tf.GradientTape() as tape:
            preds = self([source, dec_input])
            one_hot = tf.one_hot(dec_target, depth=self.num_classes)
            mask = tf.math.logical_not(tf.math.equal(dec_target, 0))
            loss = self.compute_loss(None, one_hot, preds, sample_weight=mask)
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        self.loss_metric.update_state(loss)
        return {"loss": self.loss_metric.result()}

    def test_step(self, batch):
        source = batch["source"]
        target = batch["target"]
        dec_input = target[:, :-1]
        dec_target = target[:, 1:]
        preds = self([source, dec_input])
        one_hot = tf.one_hot(dec_target, depth=self.num_classes)
        mask = tf.math.logical_not(tf.math.equal(dec_target, 0))
        loss = self.compute_loss(None, one_hot, preds, sample_weight=mask)
        self.loss_metric.update_state(loss)
        return {"loss": self.loss_metric.result()}

    def generate(self, source, target_start_token_idx):
        """Performs inference over one batch of inputs using greedy decoding."""
        bs = tf.shape(source)[0]
        enc = self.encoder(source)
        dec_input = tf.ones((bs, 1), dtype=tf.int32) * target_start_token_idx
        dec_logits = []
        for i in range(self.target_maxlen - 1):
            dec_out = self.decode(enc, dec_input)
            logits = self.classifier(dec_out)
            logits = tf.argmax(logits, axis=-1, output_type=tf.int32)
            last_logit = tf.expand_dims(logits[:, -1], axis=-1)
            dec_logits.append(last_logit)
            dec_input = tf.concat([dec_input, last_logit], axis=-1)
        return dec_input
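

"""
The slicing in `train_step` is what implements teacher forcing. With a toy target
sequence for "<hi>" plus one padding token (token ids follow the character vocabulary
defined in the preprocessing section below), the decoder input and the labels line up
as follows:

```python
target = tf.constant([[2, 11, 12, 3, 0]])  # "<", "h", "i", ">", padding
dec_input = target[:, :-1]                 # [[2, 11, 12, 3]]  -> what the decoder sees
dec_target = target[:, 1:]                 # [[11, 12, 3, 0]]  -> what it must predict
```

Padding positions (id 0) are masked out of the loss via `sample_weight`.
"""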


"""
## Download the dataset

Note: This requires ~3.6 GB of disk space and
takes ~5 minutes for the extraction of files.
"""

pattern_wav_name = re.compile(r"([^/\\\.]+)")

keras.utils.get_file(
    os.path.join(os.getcwd(), "data.tar.gz"),
    "https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2",
    extract=True,
    archive_format="tar",
    cache_dir=".",
)


saveto = "./datasets/LJSpeech-1.1"
wavs = glob("{}/**/*.wav".format(saveto), recursive=True)

id_to_text = {}
with open(os.path.join(saveto, "metadata.csv"), encoding="utf-8") as f:
    for line in f:
        id = line.strip().split("|")[0]
        text = line.strip().split("|")[2]
        id_to_text[id] = text


def get_data(wavs, id_to_text, maxlen=50):
    """Returns a list of audio-path / transcription pairs."""
    data = []
    for w in wavs:
        id = pattern_wav_name.split(w)[-4]
        if len(id_to_text[id]) < maxlen:
            data.append({"audio": w, "text": id_to_text[id]})
    return data
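

"""
The `pattern_wav_name.split(w)[-4]` expression pulls the file id (e.g. `LJ001-0001`) out
of the wav path so it can be looked up in `metadata.csv`. For a hypothetical path
(illustrative; the exact prefix depends on where the archive was extracted), the split
looks like this:

```python
pattern_wav_name.split("./datasets/LJSpeech-1.1/wavs/LJ001-0001.wav")
# ['./', 'datasets', '/', 'LJSpeech-1', '.', '1', '/', 'wavs', '/', 'LJ001-0001', '.', 'wav', '']
# index -4 is the file id: 'LJ001-0001'
```
"""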


"""
## Preprocess the dataset
"""


class VectorizeChar:
    def __init__(self, max_len=50):
        self.vocab = (
            ["-", "#", "<", ">"]
            + [chr(i + 96) for i in range(1, 27)]
            + [" ", ".", ",", "?"]
        )
        self.max_len = max_len
        self.char_to_idx = {}
        for i, ch in enumerate(self.vocab):
            self.char_to_idx[ch] = i

    def __call__(self, text):
        text = text.lower()
        text = text[: self.max_len - 2]
        text = "<" + text + ">"
        pad_len = self.max_len - len(text)
        return [self.char_to_idx.get(ch, 1) for ch in text] + [0] * pad_len

    def get_vocabulary(self):
        return self.vocab
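

"""
`VectorizeChar` wraps each transcript in start (`<`) and end (`>`) tokens, maps unknown
characters to `#` (index 1), and pads with `-` (index 0) up to `max_len`. A small
illustration with a toy `max_len` (not the value used for training):

```python
print(VectorizeChar(max_len=10)("Hi"))
# [2, 11, 12, 3, 0, 0, 0, 0, 0, 0]  -> "<", "h", "i", ">", then padding
```
"""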


max_target_len = 200  # all transcripts in our data are < 200 characters
data = get_data(wavs, id_to_text, max_target_len)
vectorizer = VectorizeChar(max_target_len)
print("vocab size", len(vectorizer.get_vocabulary()))


def create_text_ds(data):
    texts = [_["text"] for _ in data]
    text_ds = [vectorizer(t) for t in texts]
    text_ds = tf.data.Dataset.from_tensor_slices(text_ds)
    return text_ds


def path_to_audio(path):
    # spectrogram using stft
    audio = tf.io.read_file(path)
    audio, _ = tf.audio.decode_wav(audio, 1)
    audio = tf.squeeze(audio, axis=-1)
    stfts = tf.signal.stft(audio, frame_length=200, frame_step=80, fft_length=256)
    x = tf.math.pow(tf.abs(stfts), 0.5)
    # normalisation
    means = tf.math.reduce_mean(x, 1, keepdims=True)
    stddevs = tf.math.reduce_std(x, 1, keepdims=True)
    x = (x - means) / stddevs
    audio_len = tf.shape(x)[0]
    # padding to 10 seconds
    pad_len = 2754
    paddings = tf.constant([[0, pad_len], [0, 0]])
    x = tf.pad(x, paddings, "CONSTANT")[:pad_len, :]
    return x
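

"""
Every clip is turned into a fixed-size spectrogram: `fft_length=256` gives
`256 // 2 + 1 = 129` frequency bins per frame, and `pad_len = 2754` corresponds to
roughly 10 seconds of audio (LJSpeech is sampled at 22,050 Hz, so with `frame_step=80`
that is about `22050 * 10 / 80 ≈ 2756` STFT frames). A quick check, assuming the dataset
has been downloaded:

```python
print(path_to_audio(wavs[0]).shape)  # (2754, 129)
```
"""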


def create_audio_ds(data):
    flist = [_["audio"] for _ in data]
    audio_ds = tf.data.Dataset.from_tensor_slices(flist)
    audio_ds = audio_ds.map(path_to_audio, num_parallel_calls=tf.data.AUTOTUNE)
    return audio_ds


def create_tf_dataset(data, bs=4):
    audio_ds = create_audio_ds(data)
    text_ds = create_text_ds(data)
    ds = tf.data.Dataset.zip((audio_ds, text_ds))
    ds = ds.map(lambda x, y: {"source": x, "target": y})
    ds = ds.batch(bs)
    ds = ds.prefetch(tf.data.AUTOTUNE)
    return ds


split = int(len(data) * 0.99)
train_data = data[:split]
test_data = data[split:]
ds = create_tf_dataset(train_data, bs=64)
val_ds = create_tf_dataset(test_data, bs=4)
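

"""
Each element of the resulting datasets is a dictionary with a `"source"` spectrogram batch
and a `"target"` token batch, which is exactly what `train_step` and `test_step` expect.
A quick peek at the shapes (illustrative, not required for training):

```python
example = next(iter(ds))
print(example["source"].shape, example["target"].shape)  # (64, 2754, 129) (64, 200)
```
"""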

"""
## Callbacks to display predictions
"""


class DisplayOutputs(keras.callbacks.Callback):
    def __init__(
        self, batch, idx_to_token, target_start_token_idx=27, target_end_token_idx=28
    ):
        """Displays a batch of outputs every five epochs.

        Args:
            batch: A test batch containing the keys "source" and "target"
            idx_to_token: A list mapping token indices to vocabulary tokens
            target_start_token_idx: A start token index in the target vocabulary
            target_end_token_idx: An end token index in the target vocabulary
        """
        self.batch = batch
        self.target_start_token_idx = target_start_token_idx
        self.target_end_token_idx = target_end_token_idx
        self.idx_to_char = idx_to_token

    def on_epoch_end(self, epoch, logs=None):
        if epoch % 5 != 0:
            return
        source = self.batch["source"]
        target = self.batch["target"].numpy()
        bs = tf.shape(source)[0]
        preds = self.model.generate(source, self.target_start_token_idx)
        preds = preds.numpy()
        for i in range(bs):
            target_text = "".join([self.idx_to_char[_] for _ in target[i, :]])
            prediction = ""
            for idx in preds[i, :]:
                prediction += self.idx_to_char[idx]
                if idx == self.target_end_token_idx:
                    break
            print(f"target: {target_text.replace('-','')}")
            print(f"prediction: {prediction}\n")


"""
## Learning rate schedule
"""


class CustomSchedule(keras.optimizers.schedules.LearningRateSchedule):
    def __init__(
        self,
        init_lr=0.00001,
        lr_after_warmup=0.001,
        final_lr=0.00001,
        warmup_epochs=15,
        decay_epochs=85,
        steps_per_epoch=203,
    ):
        super().__init__()
        self.init_lr = init_lr
        self.lr_after_warmup = lr_after_warmup
        self.final_lr = final_lr
        self.warmup_epochs = warmup_epochs
        self.decay_epochs = decay_epochs
        self.steps_per_epoch = steps_per_epoch

    def calculate_lr(self, epoch):
        """Linear warmup followed by linear decay."""
        warmup_lr = (
            self.init_lr
            + ((self.lr_after_warmup - self.init_lr) / (self.warmup_epochs - 1)) * epoch
        )
        decay_lr = tf.math.maximum(
            self.final_lr,
            self.lr_after_warmup
            - (epoch - self.warmup_epochs)
            * (self.lr_after_warmup - self.final_lr)
            / self.decay_epochs,
        )
        return tf.math.minimum(warmup_lr, decay_lr)

    def __call__(self, step):
        epoch = step // self.steps_per_epoch
        epoch = tf.cast(epoch, "float32")
        return self.calculate_lr(epoch)
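

"""
With the default settings, the learning rate ramps linearly from `init_lr` to
`lr_after_warmup` over the first 15 epochs, then decays linearly towards `final_lr`
over the next 85 epochs. A quick check of the schedule at a few epochs (illustrative,
not part of the training script):

```python
sched = CustomSchedule(steps_per_epoch=203)
for epoch in [0, 14, 50, 100]:
    print(epoch, float(sched(epoch * 203)))
# approximately: 1e-05, 1e-03, 5.9e-04, 1e-05
```
"""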


"""
## Create & train the end-to-end model
"""

batch = next(iter(val_ds))

# The vocabulary to convert predicted indices into characters
idx_to_char = vectorizer.get_vocabulary()
display_cb = DisplayOutputs(
    batch, idx_to_char, target_start_token_idx=2, target_end_token_idx=3
)  # set the arguments as per vocabulary index for '<' and '>'

model = Transformer(
    num_hid=200,
    num_head=2,
    num_feed_forward=400,
    target_maxlen=max_target_len,
    num_layers_enc=4,
    num_layers_dec=1,
    num_classes=34,
)
loss_fn = keras.losses.CategoricalCrossentropy(
    from_logits=True,
    label_smoothing=0.1,
)

learning_rate = CustomSchedule(
    init_lr=0.00001,
    lr_after_warmup=0.001,
    final_lr=0.00001,
    warmup_epochs=15,
    decay_epochs=85,
    steps_per_epoch=len(ds),
)
optimizer = keras.optimizers.Adam(learning_rate)
model.compile(optimizer=optimizer, loss=loss_fn)

history = model.fit(ds, validation_data=val_ds, callbacks=[display_cb], epochs=1)

"""
In practice, you should train for around 100 epochs or more.

Some of the predicted text at or around epoch 35 may look as follows:
```
target: <as they sat in the car, frazier asked oswald where his lunch was>
prediction: <as they sat in the car frazier his lunch ware mis lunch was>

target: <under the entry for may one, nineteen sixty,>
prediction: <under the introus for may monee, nin the sixty,>
```
"""