Path: blob/master/examples/audio/transformer_asr.py
8174 views
"""1Title: Automatic Speech Recognition with Transformer2Author: [Apoorv Nandan](https://twitter.com/NandanApoorv)3Date created: 2021/01/134Last modified: 2021/01/135Description: Training a sequence-to-sequence Transformer for automatic speech recognition.6Accelerator: GPU7"""89"""10## Introduction1112Automatic speech recognition (ASR) consists of transcribing audio speech segments into text.13ASR can be treated as a sequence-to-sequence problem, where the14audio can be represented as a sequence of feature vectors15and the text as a sequence of characters, words, or subword tokens.1617For this demonstration, we will use the LJSpeech dataset from the18[LibriVox](https://librivox.org/) project. It consists of short19audio clips of a single speaker reading passages from 7 non-fiction books.20Our model will be similar to the original Transformer (both encoder and decoder)21as proposed in the paper, "Attention is All You Need".222324**References:**2526- [Attention is All You Need](https://papers.nips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf)27- [Very Deep Self-Attention Networks for End-to-End Speech Recognition](https://arxiv.org/abs/1904.13377)28- [Speech Transformers](https://ieeexplore.ieee.org/document/8462506)29- [LJSpeech Dataset](https://keithito.com/LJ-Speech-Dataset/)30"""3132import re33import os3435os.environ["KERAS_BACKEND"] = "tensorflow"3637from glob import glob38import tensorflow as tf39import keras40from keras import layers4142"""43## Define the Transformer Input Layer4445When processing past target tokens for the decoder, we compute the sum of46position embeddings and token embeddings.4748When processing audio features, we apply convolutional layers to downsample49them (via convolution strides) and process local relationships.50"""515253class TokenEmbedding(layers.Layer):54def __init__(self, num_vocab=1000, maxlen=100, num_hid=64):55super().__init__()56self.emb = keras.layers.Embedding(num_vocab, num_hid)57self.pos_emb = layers.Embedding(input_dim=maxlen, output_dim=num_hid)5859def call(self, x):60maxlen = tf.shape(x)[-1]61x = self.emb(x)62positions = tf.range(start=0, limit=maxlen, delta=1)63positions = self.pos_emb(positions)64return x + positions656667class SpeechFeatureEmbedding(layers.Layer):68def __init__(self, num_hid=64, maxlen=100):69super().__init__()70self.conv1 = keras.layers.Conv1D(71num_hid, 11, strides=2, padding="same", activation="relu"72)73self.conv2 = keras.layers.Conv1D(74num_hid, 11, strides=2, padding="same", activation="relu"75)76self.conv3 = keras.layers.Conv1D(77num_hid, 11, strides=2, padding="same", activation="relu"78)7980def call(self, x):81x = self.conv1(x)82x = self.conv2(x)83return self.conv3(x)848586"""87## Transformer Encoder Layer88"""899091class TransformerEncoder(layers.Layer):92def __init__(self, embed_dim, num_heads, feed_forward_dim, rate=0.1):93super().__init__()94self.att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)95self.ffn = keras.Sequential(96[97layers.Dense(feed_forward_dim, activation="relu"),98layers.Dense(embed_dim),99]100)101self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)102self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)103self.dropout1 = layers.Dropout(rate)104self.dropout2 = layers.Dropout(rate)105106def call(self, inputs, training=False):107attn_output = self.att(inputs, inputs)108attn_output = self.dropout1(attn_output, training=training)109out1 = self.layernorm1(inputs + attn_output)110ffn_output = self.ffn(out1)111ffn_output = self.dropout2(ffn_output, training=training)112return self.layernorm2(out1 + ffn_output)113114115"""116## Transformer Decoder Layer117"""118119120class TransformerDecoder(layers.Layer):121def __init__(self, embed_dim, num_heads, feed_forward_dim, dropout_rate=0.1):122super().__init__()123self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)124self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)125self.layernorm3 = layers.LayerNormalization(epsilon=1e-6)126self.self_att = layers.MultiHeadAttention(127num_heads=num_heads, key_dim=embed_dim128)129self.enc_att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)130self.self_dropout = layers.Dropout(0.5)131self.enc_dropout = layers.Dropout(0.1)132self.ffn_dropout = layers.Dropout(0.1)133self.ffn = keras.Sequential(134[135layers.Dense(feed_forward_dim, activation="relu"),136layers.Dense(embed_dim),137]138)139140def causal_attention_mask(self, batch_size, n_dest, n_src, dtype):141"""Masks the upper half of the dot product matrix in self attention.142143This prevents flow of information from future tokens to current token.1441's in the lower triangle, counting from the lower right corner.145"""146i = tf.range(n_dest)[:, None]147j = tf.range(n_src)148m = i >= j - n_src + n_dest149mask = tf.cast(m, dtype)150mask = tf.reshape(mask, [1, n_dest, n_src])151mult = tf.concat(152[tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)], 0153)154return tf.tile(mask, mult)155156def call(self, enc_out, target):157input_shape = tf.shape(target)158batch_size = input_shape[0]159seq_len = input_shape[1]160causal_mask = self.causal_attention_mask(batch_size, seq_len, seq_len, tf.bool)161target_att = self.self_att(target, target, attention_mask=causal_mask)162target_norm = self.layernorm1(target + self.self_dropout(target_att))163enc_out = self.enc_att(target_norm, enc_out)164enc_out_norm = self.layernorm2(self.enc_dropout(enc_out) + target_norm)165ffn_out = self.ffn(enc_out_norm)166ffn_out_norm = self.layernorm3(enc_out_norm + self.ffn_dropout(ffn_out))167return ffn_out_norm168169170"""171## Complete the Transformer model172173Our model takes audio spectrograms as inputs and predicts a sequence of characters.174During training, we give the decoder the target character sequence shifted to the left175as input. During inference, the decoder uses its own past predictions to predict the176next token.177"""178179180class Transformer(keras.Model):181def __init__(182self,183num_hid=64,184num_head=2,185num_feed_forward=128,186source_maxlen=100,187target_maxlen=100,188num_layers_enc=4,189num_layers_dec=1,190num_classes=10,191):192super().__init__()193self.loss_metric = keras.metrics.Mean(name="loss")194self.num_layers_enc = num_layers_enc195self.num_layers_dec = num_layers_dec196self.target_maxlen = target_maxlen197self.num_classes = num_classes198199self.enc_input = SpeechFeatureEmbedding(num_hid=num_hid, maxlen=source_maxlen)200self.dec_input = TokenEmbedding(201num_vocab=num_classes, maxlen=target_maxlen, num_hid=num_hid202)203204self.encoder = keras.Sequential(205[self.enc_input]206+ [207TransformerEncoder(num_hid, num_head, num_feed_forward)208for _ in range(num_layers_enc)209]210)211212for i in range(num_layers_dec):213setattr(214self,215f"dec_layer_{i}",216TransformerDecoder(num_hid, num_head, num_feed_forward),217)218219self.classifier = layers.Dense(num_classes)220221def decode(self, enc_out, target):222y = self.dec_input(target)223for i in range(self.num_layers_dec):224y = getattr(self, f"dec_layer_{i}")(enc_out, y)225return y226227def call(self, inputs):228source = inputs[0]229target = inputs[1]230x = self.encoder(source)231y = self.decode(x, target)232return self.classifier(y)233234@property235def metrics(self):236return [self.loss_metric]237238def train_step(self, batch):239"""Processes one batch inside model.fit()."""240source = batch["source"]241target = batch["target"]242dec_input = target[:, :-1]243dec_target = target[:, 1:]244with tf.GradientTape() as tape:245preds = self([source, dec_input])246one_hot = tf.one_hot(dec_target, depth=self.num_classes)247mask = tf.math.logical_not(tf.math.equal(dec_target, 0))248loss = self.compute_loss(None, one_hot, preds, sample_weight=mask)249trainable_vars = self.trainable_variables250gradients = tape.gradient(loss, trainable_vars)251self.optimizer.apply_gradients(zip(gradients, trainable_vars))252self.loss_metric.update_state(loss)253return {"loss": self.loss_metric.result()}254255def test_step(self, batch):256source = batch["source"]257target = batch["target"]258dec_input = target[:, :-1]259dec_target = target[:, 1:]260preds = self([source, dec_input])261one_hot = tf.one_hot(dec_target, depth=self.num_classes)262mask = tf.math.logical_not(tf.math.equal(dec_target, 0))263loss = self.compute_loss(None, one_hot, preds, sample_weight=mask)264self.loss_metric.update_state(loss)265return {"loss": self.loss_metric.result()}266267def generate(self, source, target_start_token_idx):268"""Performs inference over one batch of inputs using greedy decoding."""269bs = tf.shape(source)[0]270enc = self.encoder(source)271dec_input = tf.ones((bs, 1), dtype=tf.int32) * target_start_token_idx272dec_logits = []273for i in range(self.target_maxlen - 1):274dec_out = self.decode(enc, dec_input)275logits = self.classifier(dec_out)276logits = tf.argmax(logits, axis=-1, output_type=tf.int32)277last_logit = tf.expand_dims(logits[:, -1], axis=-1)278dec_logits.append(last_logit)279dec_input = tf.concat([dec_input, last_logit], axis=-1)280return dec_input281282283"""284## Download the dataset285286Note: This requires ~3.6 GB of disk space and287takes ~5 minutes for the extraction of files.288"""289290pattern_wav_name = re.compile(r"([^/\\\.]+)")291292keras.utils.get_file(293os.path.join(os.getcwd(), "data.tar.gz"),294"https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2",295extract=True,296archive_format="tar",297cache_dir=".",298)299300301saveto = "./datasets/LJSpeech-1.1"302wavs = glob("{}/**/*.wav".format(saveto), recursive=True)303304id_to_text = {}305with open(os.path.join(saveto, "metadata.csv"), encoding="utf-8") as f:306for line in f:307id = line.strip().split("|")[0]308text = line.strip().split("|")[2]309id_to_text[id] = text310311312def get_data(wavs, id_to_text, maxlen=50):313"""returns mapping of audio paths and transcription texts"""314data = []315for w in wavs:316id = pattern_wav_name.split(w)[-4]317if len(id_to_text[id]) < maxlen:318data.append({"audio": w, "text": id_to_text[id]})319return data320321322"""323## Preprocess the dataset324"""325326327class VectorizeChar:328def __init__(self, max_len=50):329self.vocab = (330["-", "#", "<", ">"]331+ [chr(i + 96) for i in range(1, 27)]332+ [" ", ".", ",", "?"]333)334self.max_len = max_len335self.char_to_idx = {}336for i, ch in enumerate(self.vocab):337self.char_to_idx[ch] = i338339def __call__(self, text):340text = text.lower()341text = text[: self.max_len - 2]342text = "<" + text + ">"343pad_len = self.max_len - len(text)344return [self.char_to_idx.get(ch, 1) for ch in text] + [0] * pad_len345346def get_vocabulary(self):347return self.vocab348349350max_target_len = 200 # all transcripts in out data are < 200 characters351data = get_data(wavs, id_to_text, max_target_len)352vectorizer = VectorizeChar(max_target_len)353print("vocab size", len(vectorizer.get_vocabulary()))354355356def create_text_ds(data):357texts = [_["text"] for _ in data]358text_ds = [vectorizer(t) for t in texts]359text_ds = tf.data.Dataset.from_tensor_slices(text_ds)360return text_ds361362363def path_to_audio(path):364# spectrogram using stft365audio = tf.io.read_file(path)366audio, _ = tf.audio.decode_wav(audio, 1)367audio = tf.squeeze(audio, axis=-1)368stfts = tf.signal.stft(audio, frame_length=200, frame_step=80, fft_length=256)369x = tf.math.pow(tf.abs(stfts), 0.5)370# normalisation371means = tf.math.reduce_mean(x, 1, keepdims=True)372stddevs = tf.math.reduce_std(x, 1, keepdims=True)373x = (x - means) / stddevs374audio_len = tf.shape(x)[0]375# padding to 10 seconds376pad_len = 2754377paddings = tf.constant([[0, pad_len], [0, 0]])378x = tf.pad(x, paddings, "CONSTANT")[:pad_len, :]379return x380381382def create_audio_ds(data):383flist = [_["audio"] for _ in data]384audio_ds = tf.data.Dataset.from_tensor_slices(flist)385audio_ds = audio_ds.map(path_to_audio, num_parallel_calls=tf.data.AUTOTUNE)386return audio_ds387388389def create_tf_dataset(data, bs=4):390audio_ds = create_audio_ds(data)391text_ds = create_text_ds(data)392ds = tf.data.Dataset.zip((audio_ds, text_ds))393ds = ds.map(lambda x, y: {"source": x, "target": y})394ds = ds.batch(bs)395ds = ds.prefetch(tf.data.AUTOTUNE)396return ds397398399split = int(len(data) * 0.99)400train_data = data[:split]401test_data = data[split:]402ds = create_tf_dataset(train_data, bs=64)403val_ds = create_tf_dataset(test_data, bs=4)404405"""406## Callbacks to display predictions407"""408409410class DisplayOutputs(keras.callbacks.Callback):411def __init__(412self, batch, idx_to_token, target_start_token_idx=27, target_end_token_idx=28413):414"""Displays a batch of outputs after every epoch415416Args:417batch: A test batch containing the keys "source" and "target"418idx_to_token: A List containing the vocabulary tokens corresponding to their indices419target_start_token_idx: A start token index in the target vocabulary420target_end_token_idx: An end token index in the target vocabulary421"""422self.batch = batch423self.target_start_token_idx = target_start_token_idx424self.target_end_token_idx = target_end_token_idx425self.idx_to_char = idx_to_token426427def on_epoch_end(self, epoch, logs=None):428if epoch % 5 != 0:429return430source = self.batch["source"]431target = self.batch["target"].numpy()432bs = tf.shape(source)[0]433preds = self.model.generate(source, self.target_start_token_idx)434preds = preds.numpy()435for i in range(bs):436target_text = "".join([self.idx_to_char[_] for _ in target[i, :]])437prediction = ""438for idx in preds[i, :]:439prediction += self.idx_to_char[idx]440if idx == self.target_end_token_idx:441break442print(f"target: {target_text.replace('-','')}")443print(f"prediction: {prediction}\n")444445446"""447## Learning rate schedule448"""449450451class CustomSchedule(keras.optimizers.schedules.LearningRateSchedule):452def __init__(453self,454init_lr=0.00001,455lr_after_warmup=0.001,456final_lr=0.00001,457warmup_epochs=15,458decay_epochs=85,459steps_per_epoch=203,460):461super().__init__()462self.init_lr = init_lr463self.lr_after_warmup = lr_after_warmup464self.final_lr = final_lr465self.warmup_epochs = warmup_epochs466self.decay_epochs = decay_epochs467self.steps_per_epoch = steps_per_epoch468469def calculate_lr(self, epoch):470"""linear warm up - linear decay"""471warmup_lr = (472self.init_lr473+ ((self.lr_after_warmup - self.init_lr) / (self.warmup_epochs - 1)) * epoch474)475decay_lr = tf.math.maximum(476self.final_lr,477self.lr_after_warmup478- (epoch - self.warmup_epochs)479* (self.lr_after_warmup - self.final_lr)480/ self.decay_epochs,481)482return tf.math.minimum(warmup_lr, decay_lr)483484def __call__(self, step):485epoch = step // self.steps_per_epoch486epoch = tf.cast(epoch, "float32")487return self.calculate_lr(epoch)488489490"""491## Create & train the end-to-end model492"""493494batch = next(iter(val_ds))495496# The vocabulary to convert predicted indices into characters497idx_to_char = vectorizer.get_vocabulary()498display_cb = DisplayOutputs(499batch, idx_to_char, target_start_token_idx=2, target_end_token_idx=3500) # set the arguments as per vocabulary index for '<' and '>'501502model = Transformer(503num_hid=200,504num_head=2,505num_feed_forward=400,506target_maxlen=max_target_len,507num_layers_enc=4,508num_layers_dec=1,509num_classes=34,510)511loss_fn = keras.losses.CategoricalCrossentropy(512from_logits=True,513label_smoothing=0.1,514)515516learning_rate = CustomSchedule(517init_lr=0.00001,518lr_after_warmup=0.001,519final_lr=0.00001,520warmup_epochs=15,521decay_epochs=85,522steps_per_epoch=len(ds),523)524optimizer = keras.optimizers.Adam(learning_rate)525model.compile(optimizer=optimizer, loss=loss_fn)526527history = model.fit(ds, validation_data=val_ds, callbacks=[display_cb], epochs=1)528529"""530In practice, you should train for around 100 epochs or more.531532Some of the predicted text at or around epoch 35 may look as follows:533```534target: <as they sat in the car, frazier asked oswald where his lunch was>535prediction: <as they sat in the car frazier his lunch ware mis lunch was>536537target: <under the entry for may one, nineteen sixty,>538prediction: <under the introus for may monee, nin the sixty,>539```540"""541542543