GitHub Repository: keras-team/keras-io
Path: blob/master/examples/audio/ctc_asr.py
"""
Title: Automatic Speech Recognition using CTC
Authors: [Mohamed Reda Bouadjenek](https://rbouadjenek.github.io/) and [Ngoc Dung Huynh](https://www.linkedin.com/in/parkerhuynh/)
Date created: 2021/09/26
Last modified: 2026/01/27
Description: Training a CTC-based model for automatic speech recognition.
Accelerator: GPU
Converted to Keras 3 by: [Harshith K](https://github.com/kharshith-k/)
"""

"""
## Introduction

Speech recognition is an interdisciplinary subfield of computer science
and computational linguistics that develops methodologies and technologies
enabling the recognition and translation of spoken language into text
by computers. It is also known as automatic speech recognition (ASR),
computer speech recognition, or speech to text (STT). It incorporates
knowledge and research from the computer science, linguistics, and computer
engineering fields.

This demonstration shows how to combine a 2D CNN, an RNN, and a Connectionist
Temporal Classification (CTC) loss to build an ASR model. CTC is an algorithm
used to train deep neural networks for speech recognition, handwriting
recognition, and other sequence problems. CTC is used when we don't know
how the input aligns with the output (that is, how the characters in the
transcript align with the audio). The model we create is similar to
[DeepSpeech2](https://nvidia.github.io/OpenSeq2Seq/html/speech-recognition/deepspeech2.html).

We will use the LJSpeech dataset from the
[LibriVox](https://librivox.org/) project. It consists of short
audio clips of a single speaker reading passages from 7 non-fiction books.

We will evaluate the quality of the model using the
[Word Error Rate (WER)](https://en.wikipedia.org/wiki/Word_error_rate).
WER is obtained by adding up the substitutions, insertions, and deletions
that occur in a sequence of recognized words, and dividing that number by
the total number of words originally spoken. To compute the WER you need to
install the [jiwer](https://pypi.org/project/jiwer/) package, which you can
do with the following command:

```
pip install jiwer
```
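
As a quick, illustrative check of how the metric behaves (the sentences below are
made up for this example and are not from the dataset):

```
from jiwer import wer

# One substitution ("sat" -> "sit") and one insertion ("down")
# over three reference words gives WER = 2 / 3
print(wer(["the cat sat"], ["the cat sit down"]))
```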

**References:**

- [LJSpeech Dataset](https://keithito.com/LJ-Speech-Dataset/)
- [Speech recognition](https://en.wikipedia.org/wiki/Speech_recognition)
- [Sequence Modeling With CTC](https://distill.pub/2017/ctc/)
- [DeepSpeech2](https://nvidia.github.io/OpenSeq2Seq/html/speech-recognition/deepspeech2.html)
"""

"""
## Setup
"""

import pandas as pd
import numpy as np
import tensorflow as tf
import keras
from keras import layers
from keras import ops
import matplotlib.pyplot as plt
from IPython import display
from jiwer import wer

"""
69
## Load the LJSpeech Dataset
70
71
Let's download the [LJSpeech Dataset](https://keithito.com/LJ-Speech-Dataset/).
72
The dataset contains 13,100 audio files as `wav` files in the `/wavs/` folder.
73
The label (transcript) for each audio file is a string
74
given in the `metadata.csv` file. The fields are:
75
76
- **ID**: this is the name of the corresponding .wav file
77
- **Transcription**: words spoken by the reader (UTF-8)
78
- **Normalized transcription**: transcription with numbers,
79
ordinals, and monetary units expanded into full words (UTF-8).
80
81
For this demo we will use on the "Normalized transcription" field.
82
83
Each audio file is a single-channel 16-bit PCM WAV with a sample rate of 22,050 Hz.
84
"""

data_url = "https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2"
data_path = keras.utils.get_file("LJSpeech-1.1", data_url, untar=True)
wavs_path = data_path + "/LJSpeech-1.1/wavs/"
metadata_path = data_path + "/LJSpeech-1.1" + "/metadata.csv"


# Read the metadata file and parse it
metadata_df = pd.read_csv(metadata_path, sep="|", header=None, quoting=3)
metadata_df.columns = ["file_name", "transcription", "normalized_transcription"]
metadata_df = metadata_df[["file_name", "normalized_transcription"]]
metadata_df = metadata_df.sample(frac=1).reset_index(drop=True)
metadata_df.head(3)

"""
We now split the data into training and validation sets.
"""

split = int(len(metadata_df) * 0.90)
df_train = metadata_df[:split]
df_val = metadata_df[split:]

print(f"Size of the training set: {len(df_train)}")
print(f"Size of the validation set: {len(df_val)}")

"""
## Preprocessing

We first prepare the vocabulary to be used.
"""

# The set of characters accepted in the transcription.
characters = [x for x in "abcdefghijklmnopqrstuvwxyz'?! "]
# Mapping characters to integers
char_to_num = keras.layers.StringLookup(vocabulary=characters, oov_token="")
# Mapping integers back to the original characters
num_to_char = keras.layers.StringLookup(
    vocabulary=char_to_num.get_vocabulary(), oov_token="", invert=True
)

print(
    f"The vocabulary is: {char_to_num.get_vocabulary()} "
    f"(size ={char_to_num.vocabulary_size()})"
)
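
"""
As a quick sanity check (illustrative only, not required for the pipeline), we can
encode a sample string with `char_to_num` and decode it back with `num_to_char` to
verify that the two lookups are inverses of each other:
"""

# Round-trip a sample string through the two lookup layers (illustrative only)
sample_text = tf.strings.unicode_split("hello world", input_encoding="UTF-8")
encoded = char_to_num(sample_text)
decoded = tf.strings.reduce_join(num_to_char(encoded)).numpy().decode("utf-8")
print(encoded.numpy(), "->", decoded)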

"""
Next, we create the function that describes the transformation that we apply to each
element of our dataset.
"""

# An integer scalar Tensor. The window length in samples.
frame_length = 256
# An integer scalar Tensor. The number of samples to step.
frame_step = 160
# An integer scalar Tensor. The size of the FFT to apply.
# If not provided, uses the smallest power of 2 enclosing frame_length.
fft_length = 384


def encode_single_sample(wav_file, label):
    ###########################################
    ##  Process the Audio
    ##########################################
    # 1. Read wav file
    file = tf.io.read_file(wavs_path + wav_file + ".wav")
    # 2. Decode the wav file
    audio, _ = tf.audio.decode_wav(file)
    audio = ops.squeeze(audio)
    # 3. Change type to float
    audio = ops.cast(audio, "float32")
    # 4. Get the spectrogram
    stft_output = ops.stft(
        audio,
        sequence_length=frame_length,
        sequence_stride=frame_step,
        fft_length=fft_length,
        center=False,
    )
    # 5. We only need the magnitude, which can be computed from the real and imaginary
    #    parts (ops.stft returns a (real, imag) tuple): magnitude = sqrt(real^2 + imag^2)
    spectrogram = ops.sqrt(ops.square(stft_output[0]) + ops.square(stft_output[1]))
    spectrogram = ops.power(spectrogram, 0.5)
    # 6. Normalisation
    means = ops.mean(spectrogram, axis=1, keepdims=True)
    stddevs = ops.std(spectrogram, axis=1, keepdims=True)
    spectrogram = (spectrogram - means) / (stddevs + 1e-10)
    ###########################################
    ##  Process the label
    ##########################################
    # 7. Convert the label to lower case
    label = tf.strings.lower(label)
    # 8. Split the label
    label = tf.strings.unicode_split(label, input_encoding="UTF-8")
    # 9. Map the characters in the label to numbers
    label = char_to_num(label)
    # 10. Return the spectrogram and the label
    return spectrogram, label
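
"""
As a rough sketch of the shape this function produces (assuming the STFT settings
above with `center=False`): a clip with `num_samples` audio samples yields about
`1 + (num_samples - frame_length) // frame_step` frames, each holding
`fft_length // 2 + 1` frequency bins. The clip length below is illustrative.
"""

# Expected spectrogram shape for a hypothetical one-second clip at 22,050 Hz
num_samples = 22050
num_frames = 1 + (num_samples - frame_length) // frame_step
num_bins = fft_length // 2 + 1
print(f"A one-second clip -> roughly ({num_frames}, {num_bins}) spectrogram")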


"""
## Creating `Dataset` objects

We create a `tf.data.Dataset` object that yields
the transformed elements, in the same order as they
appeared in the input.
"""

batch_size = 32
# Define the training dataset
train_dataset = tf.data.Dataset.from_tensor_slices(
    (list(df_train["file_name"]), list(df_train["normalized_transcription"]))
)
train_dataset = (
    train_dataset.map(encode_single_sample, num_parallel_calls=tf.data.AUTOTUNE)
    .padded_batch(batch_size, padded_shapes=([None, fft_length // 2 + 1], [None]))
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)

# Define the validation dataset
validation_dataset = tf.data.Dataset.from_tensor_slices(
    (list(df_val["file_name"]), list(df_val["normalized_transcription"]))
)
validation_dataset = (
    validation_dataset.map(encode_single_sample, num_parallel_calls=tf.data.AUTOTUNE)
    .padded_batch(batch_size, padded_shapes=([None, fft_length // 2 + 1], [None]))
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)
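
"""
It can help to peek at one padded batch to confirm the shapes (illustrative only):
spectrograms come out as `(batch, time, fft_length // 2 + 1)` and labels as
`(batch, max_label_length)` for that particular batch.
"""

# Inspect the shapes of a single training batch (illustrative only)
for spectrograms, labels in train_dataset.take(1):
    print("Spectrogram batch shape:", spectrograms.shape)
    print("Label batch shape:", labels.shape)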

"""
## Visualize the data

Let's visualize an example in our dataset, including the
audio clip, the spectrogram and the corresponding label.
"""

fig = plt.figure(figsize=(8, 5))
for batch in train_dataset.take(1):
    spectrogram = batch[0][0].numpy()
    spectrogram = np.array([np.trim_zeros(x) for x in np.transpose(spectrogram)])
    label = batch[1][0]
    # Spectrogram
    label = tf.strings.reduce_join(num_to_char(label)).numpy().decode("utf-8")
    ax = plt.subplot(2, 1, 1)
    ax.imshow(spectrogram, vmax=1)
    ax.set_title(label)
    ax.axis("off")
    # Wav
    file = tf.io.read_file(wavs_path + list(df_train["file_name"])[0] + ".wav")
    audio, _ = tf.audio.decode_wav(file)
    audio = audio.numpy()
    ax = plt.subplot(2, 1, 2)
    plt.plot(audio)
    ax.set_title("Signal Wave")
    ax.set_xlim(0, len(audio))
    display.display(display.Audio(np.transpose(audio), rate=22050))
plt.show()

"""
243
## Model
244
245
We first define the CTC Loss function.
246
"""
247
248
249
def CTCLoss(y_true, y_pred):
    # Compute the training-time loss value
    batch_len = ops.shape(y_true)[0]
    input_length = ops.shape(y_pred)[1]
    label_length = ops.shape(y_true)[1]

    # Create length tensors - CTC needs to know the actual sequence lengths
    input_length = input_length * ops.ones(shape=(batch_len,), dtype="int32")
    label_length = label_length * ops.ones(shape=(batch_len,), dtype="int32")

    # Use the Keras ops CTC loss (backend-agnostic).
    # Note: mask_index should match the blank token index.
    # With StringLookup(oov_token=""), index 0 is reserved, so we use 0 as the mask.
    loss = ops.nn.ctc_loss(
        target=ops.cast(y_true, "int32"),
        output=y_pred,
        target_length=label_length,
        output_length=input_length,
        mask_index=0,
    )
    return loss
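
"""
As a minimal sanity check of the loss function (illustrative values only, not part of
the training pipeline), we can feed it dummy labels and dummy softmax outputs and
confirm that it returns one loss value per batch element:
"""

# Dummy batch: 2 samples, 10 label tokens each, 50 output frames (illustrative only)
dummy_labels = ops.ones((2, 10), dtype="int32")
dummy_logits = keras.random.uniform((2, 50, char_to_num.vocabulary_size() + 1))
dummy_probs = ops.softmax(dummy_logits, axis=-1)
print(ops.shape(CTCLoss(dummy_labels, dummy_probs)))  # one loss value per sample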


"""
We now define our model, which will be similar to
[DeepSpeech2](https://nvidia.github.io/OpenSeq2Seq/html/speech-recognition/deepspeech2.html).
"""


def build_model(input_dim, output_dim, rnn_layers=5, rnn_units=128):
    """Model similar to DeepSpeech2."""
    # Model's input
    input_spectrogram = layers.Input((None, input_dim), name="input")
    # Expand the dimension to use 2D CNN.
    x = layers.Reshape((-1, input_dim, 1), name="expand_dim")(input_spectrogram)
    # Convolution layer 1
    x = layers.Conv2D(
        filters=32,
        kernel_size=[11, 41],
        strides=[2, 2],
        padding="same",
        use_bias=False,
        name="conv_1",
    )(x)
    x = layers.BatchNormalization(name="conv_1_bn")(x)
    x = layers.ReLU(name="conv_1_relu")(x)
    # Convolution layer 2
    x = layers.Conv2D(
        filters=32,
        kernel_size=[11, 21],
        strides=[1, 2],
        padding="same",
        use_bias=False,
        name="conv_2",
    )(x)
    x = layers.BatchNormalization(name="conv_2_bn")(x)
    x = layers.ReLU(name="conv_2_relu")(x)
    # Reshape the resulting volume to feed the RNN layers
    x = layers.Reshape((-1, x.shape[-2] * x.shape[-1]))(x)
    # RNN layers
    for i in range(1, rnn_layers + 1):
        recurrent = layers.GRU(
            units=rnn_units,
            activation="tanh",
            recurrent_activation="sigmoid",
            use_bias=True,
            return_sequences=True,
            reset_after=True,
            name=f"gru_{i}",
        )
        x = layers.Bidirectional(
            recurrent, name=f"bidirectional_{i}", merge_mode="concat"
        )(x)
        if i < rnn_layers:
            x = layers.Dropout(rate=0.5)(x)
    # Dense layer
    x = layers.Dense(units=rnn_units * 2, name="dense_1")(x)
    x = layers.ReLU(name="dense_1_relu")(x)
    x = layers.Dropout(rate=0.5)(x)
    # Classification layer
    output = layers.Dense(units=output_dim + 1, activation="softmax")(x)
    # Model
    model = keras.Model(input_spectrogram, output, name="DeepSpeech_2")
    # Optimizer
    opt = keras.optimizers.Adam(learning_rate=1e-4)
    # Compile the model and return
    model.compile(optimizer=opt, loss=CTCLoss)
    return model


# Get the model
model = build_model(
    input_dim=fft_length // 2 + 1,
    output_dim=char_to_num.vocabulary_size(),
    rnn_units=512,
)
model.summary(line_length=110)
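
"""
As a quick wiring check (illustrative only), we can run one padded batch through the
untrained model and confirm that the output has shape
`(batch, time_steps, char_to_num.vocabulary_size() + 1)`, where the last axis holds
per-character probabilities plus the CTC blank and `time_steps` is reduced by the
convolution strides.
"""

# Run one batch through the untrained model to inspect the output shape (illustrative)
for spectrograms, labels in train_dataset.take(1):
    outputs = model(spectrograms)
    print("Model output shape:", outputs.shape)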

"""
## Training and Evaluating
"""


# A utility function to decode the output of the network
def decode_batch_predictions(pred):
    input_len = np.ones(pred.shape[0]) * pred.shape[1]

    # Use the Keras ops CTC decoder with the greedy strategy (backend-agnostic)
    decoded = ops.nn.ctc_decode(
        inputs=pred,
        sequence_lengths=ops.cast(input_len, "int32"),
        strategy="greedy",
        mask_index=0,
    )

    # ctc_decode returns a tuple of (decoded_sequences, log_probabilities).
    # For the greedy strategy, decoded_sequences has shape (1, batch_size, max_length),
    # so decoded[0][0] gives the batch with shape (batch_size, max_length).
    decoded_sequences = decoded[0][0]

    # Convert to numpy once for the whole batch
    decoded_sequences = ops.convert_to_numpy(decoded_sequences)

    # Iterate over the results and get back the text
    output_text = []
    for sequence in decoded_sequences:
        # Remove padding/mask values (0 is the mask index)
        sequence = sequence[sequence > 0]
        # Convert indices to characters
        text = tf.strings.reduce_join(num_to_char(sequence)).numpy().decode("utf-8")
        output_text.append(text)
    return output_text


# A callback class to output a few transcriptions during training
class CallbackEval(keras.callbacks.Callback):
    """Displays a batch of outputs after every epoch."""

    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset

    def on_epoch_end(self, epoch: int, logs=None):
        predictions = []
        targets = []
        # Limit to 10 batches to avoid long evaluation times
        for i, batch in enumerate(self.dataset):
            if i >= 10:
                break
            X, y = batch
            print(f"Batch {i}: X shape = {X.shape}, y shape = {y.shape}")
            batch_predictions = model.predict(X, verbose=0)
            print(f"Batch {i}: predictions shape = {batch_predictions.shape}")
            batch_predictions = decode_batch_predictions(batch_predictions)
            print(f"Batch {i}: decoded {len(batch_predictions)} predictions")
            predictions.extend(batch_predictions)
            for label in y:
                label = (
                    tf.strings.reduce_join(num_to_char(label)).numpy().decode("utf-8")
                )
                targets.append(label)
        print(f"\nTotal: {len(predictions)} predictions, {len(targets)} targets")
        wer_score = wer(targets, predictions)
        print("-" * 100)
        print(f"Word Error Rate: {wer_score:.4f}")
        print("-" * 100)
        for i in np.random.randint(0, len(predictions), 2):
            print(f"Target    : {targets[i]}")
            print(f"Prediction: {predictions[i]}")
            print("-" * 100)


"""
Let's start the training process.
"""

# Define the number of epochs.
epochs = 1
# Callback function to check transcription on the val set.
validation_callback = CallbackEval(validation_dataset)
# Train the model
history = model.fit(
    train_dataset,
    validation_data=validation_dataset,
    epochs=epochs,
    callbacks=[validation_callback],
)

"""
## Inference
"""

# Let's check results on more validation samples
predictions = []
targets = []
for batch in validation_dataset:
    X, y = batch
    batch_predictions = model.predict(X)
    batch_predictions = decode_batch_predictions(batch_predictions)
    predictions.extend(batch_predictions)
    for label in y:
        label = tf.strings.reduce_join(num_to_char(label)).numpy().decode("utf-8")
        targets.append(label)
wer_score = wer(targets, predictions)
print("-" * 100)
print(f"Word Error Rate: {wer_score:.4f}")
print("-" * 100)
for i in np.random.randint(0, len(predictions), 5):
    print(f"Target    : {targets[i]}")
    print(f"Prediction: {predictions[i]}")
    print("-" * 100)

"""
## Conclusion

In practice, you should train for around 50 epochs or more. Each epoch
takes approximately 8-10 minutes using a `Colab A100` GPU.
The model trained for 50 epochs has a `Word Error Rate (WER) ≈ 16% to 17%`.

Some of the transcriptions around epoch 50:

**Audio file: LJ017-0009.wav**
```
- Target    : sir thomas overbury was undoubtedly poisoned by lord rochester in the reign
of james the first
- Prediction: cer thomas overbery was undoubtedly poisoned by lordrochester in the reign
of james the first
```

**Audio file: LJ003-0340.wav**
```
- Target    : the committee does not seem to have yet understood that newgate could be
only and properly replaced
- Prediction: the committee does not seem to have yet understood that newgate could be
only and proberly replace
```

**Audio file: LJ011-0136.wav**
```
- Target    : still no sentence of death was carried out for the offense and in eighteen
thirtytwo
- Prediction: still no sentence of death was carried out for the offense and in eighteen
thirtytwo
```

An example trained model is available on Hugging Face.

| Trained Model | Demo |
| :--: | :--: |
| [![Generic badge](https://img.shields.io/badge/🤗%20Model-CTC%20ASR-black.svg)](https://huggingface.co/keras-io/ctc_asr) | |
"""