Path: blob/master/examples/audio/uk_ireland_accent_recognition.py
8338 views
"""1Title: English speaker accent recognition using Transfer Learning2Author: [Fadi Badine](https://twitter.com/fadibadine)3Date created: 2022/04/164Last modified: 2022/04/165Description: Training a model to classify UK & Ireland accents using feature extraction from Yamnet.6Accelerator: GPU7"""89"""10## Introduction1112The following example shows how to use feature extraction in order to13train a model to classify the English accent spoken in an audio wave.1415Instead of training a model from scratch, transfer learning enables us to16take advantage of existing state-of-the-art deep learning models and use them as feature extractors.1718Our process:1920* Use a TF Hub pre-trained model (Yamnet) and apply it as part of the tf.data pipeline which transforms21the audio files into feature vectors.22* Train a dense model on the feature vectors.23* Use the trained model for inference on a new audio file.2425Note:2627* We need to install TensorFlow IO in order to resample audio files to 16 kHz as required by Yamnet model.28* In the test section, ffmpeg is used to convert the mp3 file to wav.2930You can install TensorFlow IO with the following command:31"""3233"""shell34pip install -U -q tensorflow_io35"""3637"""38## Configuration39"""4041SEED = 133742EPOCHS = 10043BATCH_SIZE = 6444VALIDATION_RATIO = 0.145MODEL_NAME = "uk_irish_accent_recognition"4647# Location where the dataset will be downloaded.48# By default (None), keras.utils.get_file will use ~/.keras/ as the CACHE_DIR49CACHE_DIR = None5051# The location of the dataset52URL_PATH = "https://www.openslr.org/resources/83/"5354# List of datasets compressed files that contain the audio files55zip_files = {560: "irish_english_male.zip",571: "midlands_english_female.zip",582: "midlands_english_male.zip",593: "northern_english_female.zip",604: "northern_english_male.zip",615: "scottish_english_female.zip",626: "scottish_english_male.zip",637: "southern_english_female.zip",648: "southern_english_male.zip",659: "welsh_english_female.zip",6610: "welsh_english_male.zip",67}6869# We see that there are 2 compressed files for each accent (except Irish):70# - One for male speakers71# - One for female speakers72# However, we will be using a gender agnostic dataset.7374# List of gender agnostic categories75gender_agnostic_categories = [76"ir", # Irish77"mi", # Midlands78"no", # Northern79"sc", # Scottish80"so", # Southern81"we", # Welsh82]8384class_names = [85"Irish",86"Midlands",87"Northern",88"Scottish",89"Southern",90"Welsh",91"Not a speech",92]9394"""95## Imports96"""9798import os99import io100import csv101import numpy as np102import pandas as pd103import tensorflow as tf104import tensorflow_hub as hub105import tensorflow_io as tfio106from tensorflow import keras107import matplotlib.pyplot as plt108import seaborn as sns109from scipy import stats110from IPython.display import Audio111112# Set all random seeds in order to get reproducible results113keras.utils.set_random_seed(SEED)114115# Where to download the dataset116DATASET_DESTINATION = os.path.join(CACHE_DIR if CACHE_DIR else "~/.keras/", "datasets")117118"""119## Yamnet Model120121Yamnet is an audio event classifier trained on the AudioSet dataset to predict audio122events from the AudioSet ontology. It is available on TensorFlow Hub.123124Yamnet accepts a 1-D tensor of audio samples with a sample rate of 16 kHz.125As output, the model returns a 3-tuple:126127* Scores of shape `(N, 521)` representing the scores of the 521 classes.128* Embeddings of shape `(N, 1024)`.129* The log-mel spectrogram of the entire audio frame.130131We will use the embeddings, which are the features extracted from the audio samples, as the input to our dense model.132133For more detailed information about Yamnet, please refer to its [TensorFlow Hub](https://tfhub.dev/google/yamnet/1) page.134"""135136yamnet_model = hub.load("https://tfhub.dev/google/yamnet/1")137138"""139## Dataset140141The dataset used is the142[Crowdsourced high-quality UK and Ireland English Dialect speech data set](https://openslr.org/83/)143which consists of a total of 17,877 high-quality audio wav files.144145This dataset includes over 31 hours of recording from 120 volunteers who self-identify as146native speakers of Southern England, Midlands, Northern England, Wales, Scotland and Ireland.147148For more info, please refer to the above link or to the following paper:149[Open-source Multi-speaker Corpora of the English Accents in the British Isles](https://aclanthology.org/2020.lrec-1.804.pdf)150"""151152"""153## Download the data154"""155156# CSV file that contains information about the dataset. For each entry, we have:157# - ID158# - wav file name159# - transcript160line_index_file = keras.utils.get_file(161fname="line_index_file", origin=URL_PATH + "line_index_all.csv"162)163164# Download the list of compressed files that contain the audio wav files165for i in zip_files:166fname = zip_files[i].split(".")[0]167url = URL_PATH + zip_files[i]168169zip_file = keras.utils.get_file(fname=fname, origin=url, extract=True)170os.remove(zip_file)171172"""173## Load the data in a Dataframe174175Of the 3 columns (ID, filename and transcript), we are only interested in the filename column in order to read the audio file.176We will ignore the other two.177"""178179dataframe = pd.read_csv(180line_index_file, names=["id", "filename", "transcript"], usecols=["filename"]181)182dataframe.head()183184"""185Let's now preprocess the dataset by:186187* Adjusting the filename (removing a leading space & adding ".wav" extension to the188filename).189* Creating a label using the first 2 characters of the filename which indicate the190accent.191* Shuffling the samples.192"""193194195# The purpose of this function is to preprocess the dataframe by applying the following:196# - Cleaning the filename from a leading space197# - Generating a label column that is gender agnostic i.e.198# welsh english male and welsh english female for example are both labeled as199# welsh english200# - Add extension .wav to the filename201# - Shuffle samples202def preprocess_dataframe(dataframe):203# Remove leading space in filename column204dataframe["filename"] = dataframe.apply(lambda row: row["filename"].strip(), axis=1)205206# Create gender agnostic labels based on the filename first 2 letters207dataframe["label"] = dataframe.apply(208lambda row: gender_agnostic_categories.index(row["filename"][:2]), axis=1209)210211# Add the file path to the name212dataframe["filename"] = dataframe.apply(213lambda row: os.path.join(DATASET_DESTINATION, row["filename"] + ".wav"), axis=1214)215216# Shuffle the samples217dataframe = dataframe.sample(frac=1, random_state=SEED).reset_index(drop=True)218219return dataframe220221222dataframe = preprocess_dataframe(dataframe)223dataframe.head()224225"""226## Prepare training & validation sets227228Let's split the samples creating training and validation sets.229"""230231split = int(len(dataframe) * (1 - VALIDATION_RATIO))232train_df = dataframe[:split]233valid_df = dataframe[split:]234235print(236f"We have {train_df.shape[0]} training samples & {valid_df.shape[0]} validation ones"237)238239"""240## Prepare a TensorFlow Dataset241242Next, we need to create a `tf.data.Dataset`.243This is done by creating a `dataframe_to_dataset` function that does the following:244245* Create a dataset using filenames and labels.246* Get the Yamnet embeddings by calling another function `filepath_to_embeddings`.247* Apply caching, reshuffling and setting batch size.248249The `filepath_to_embeddings` does the following:250251* Load audio file.252* Resample audio to 16 kHz.253* Generate scores and embeddings from Yamnet model.254* Since Yamnet generates multiple samples for each audio file,255this function also duplicates the label for all the generated samples256that have `score=0` (speech) whereas sets the label for the others as257'other' indicating that this audio segment is not a speech and we won't label it as one of the accents.258259The below `load_16k_audio_file` is copied from the following tutorial260[Transfer learning with YAMNet for environmental sound classification](https://www.tensorflow.org/tutorials/audio/transfer_learning_audio)261"""262263264@tf.function265def load_16k_audio_wav(filename):266# Read file content267file_content = tf.io.read_file(filename)268269# Decode audio wave270audio_wav, sample_rate = tf.audio.decode_wav(file_content, desired_channels=1)271audio_wav = tf.squeeze(audio_wav, axis=-1)272sample_rate = tf.cast(sample_rate, dtype=tf.int64)273274# Resample to 16k275audio_wav = tfio.audio.resample(audio_wav, rate_in=sample_rate, rate_out=16000)276277return audio_wav278279280def filepath_to_embeddings(filename, label):281# Load 16k audio wave282audio_wav = load_16k_audio_wav(filename)283284# Get audio embeddings & scores.285# The embeddings are the audio features extracted using transfer learning286# while scores will be used to identify time slots that are not speech287# which will then be gathered into a specific new category 'other'288scores, embeddings, _ = yamnet_model(audio_wav)289290# Number of embeddings in order to know how many times to repeat the label291embeddings_num = tf.shape(embeddings)[0]292labels = tf.repeat(label, embeddings_num)293294# Change labels for time-slots that are not speech into a new category 'other'295labels = tf.where(tf.argmax(scores, axis=1) == 0, label, len(class_names) - 1)296297# Using one-hot in order to use AUC298return (embeddings, tf.one_hot(labels, len(class_names)))299300301def dataframe_to_dataset(dataframe, batch_size=64):302dataset = tf.data.Dataset.from_tensor_slices(303(dataframe["filename"], dataframe["label"])304)305306dataset = dataset.map(307lambda x, y: filepath_to_embeddings(x, y),308num_parallel_calls=tf.data.experimental.AUTOTUNE,309).unbatch()310311return dataset.cache().batch(batch_size).prefetch(tf.data.AUTOTUNE)312313314train_ds = dataframe_to_dataset(train_df)315valid_ds = dataframe_to_dataset(valid_df)316317"""318## Build the model319320The model that we use consists of:321322* An input layer which is the embedding output of the Yamnet classifier.323* 4 dense hidden layers and 4 dropout layers.324* An output dense layer.325326The model's hyperparameters were selected using327[KerasTuner](https://keras.io/keras_tuner/).328"""329330keras.backend.clear_session()331332333def build_and_compile_model():334inputs = keras.layers.Input(shape=(1024), name="embedding")335336x = keras.layers.Dense(256, activation="relu", name="dense_1")(inputs)337x = keras.layers.Dropout(0.15, name="dropout_1")(x)338339x = keras.layers.Dense(384, activation="relu", name="dense_2")(x)340x = keras.layers.Dropout(0.2, name="dropout_2")(x)341342x = keras.layers.Dense(192, activation="relu", name="dense_3")(x)343x = keras.layers.Dropout(0.25, name="dropout_3")(x)344345x = keras.layers.Dense(384, activation="relu", name="dense_4")(x)346x = keras.layers.Dropout(0.2, name="dropout_4")(x)347348outputs = keras.layers.Dense(len(class_names), activation="softmax", name="ouput")(349x350)351352model = keras.Model(inputs=inputs, outputs=outputs, name="accent_recognition")353354model.compile(355optimizer=keras.optimizers.Adam(learning_rate=1.9644e-5),356loss=keras.losses.CategoricalCrossentropy(),357metrics=["accuracy", keras.metrics.AUC(name="auc")],358)359360return model361362363model = build_and_compile_model()364model.summary()365366"""367## Class weights calculation368369Since the dataset is quite unbalanced, we will use `class_weight` argument during training.370371Getting the class weights is a little tricky because even though we know the number of372audio files for each class, it does not represent the number of samples for that class373since Yamnet transforms each audio file into multiple audio samples of 0.96 seconds each.374So every audio file will be split into a number of samples that is proportional to its length.375376Therefore, to get those weights, we have to calculate the number of samples for each class377after preprocessing through Yamnet.378"""379380class_counts = tf.zeros(shape=(len(class_names),), dtype=tf.int32)381382for x, y in iter(train_ds):383class_counts = class_counts + tf.math.bincount(384tf.cast(tf.math.argmax(y, axis=1), tf.int32), minlength=len(class_names)385)386387class_weight = {388i: tf.math.reduce_sum(class_counts).numpy() / class_counts[i].numpy()389for i in range(len(class_counts))390}391392print(class_weight)393394"""395## Callbacks396397We use Keras callbacks in order to:398399* Stop whenever the validation AUC stops improving.400* Save the best model.401* Call TensorBoard in order to later view the training and validation logs.402"""403404early_stopping_cb = keras.callbacks.EarlyStopping(405monitor="val_auc", patience=10, restore_best_weights=True406)407408model_checkpoint_cb = keras.callbacks.ModelCheckpoint(409MODEL_NAME + ".h5", monitor="val_auc", save_best_only=True410)411412tensorboard_cb = keras.callbacks.TensorBoard(413os.path.join(os.curdir, "logs", model.name)414)415416callbacks = [early_stopping_cb, model_checkpoint_cb, tensorboard_cb]417418"""419## Training420"""421422history = model.fit(423train_ds,424epochs=EPOCHS,425validation_data=valid_ds,426class_weight=class_weight,427callbacks=callbacks,428verbose=2,429)430431"""432## Results433434Let's plot the training and validation AUC and accuracy.435"""436437fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(14, 5))438439axs[0].plot(range(EPOCHS), history.history["accuracy"], label="Training")440axs[0].plot(range(EPOCHS), history.history["val_accuracy"], label="Validation")441axs[0].set_xlabel("Epochs")442axs[0].set_title("Training & Validation Accuracy")443axs[0].legend()444axs[0].grid(True)445446axs[1].plot(range(EPOCHS), history.history["auc"], label="Training")447axs[1].plot(range(EPOCHS), history.history["val_auc"], label="Validation")448axs[1].set_xlabel("Epochs")449axs[1].set_title("Training & Validation AUC")450axs[1].legend()451axs[1].grid(True)452453plt.show()454455"""456## Evaluation457"""458459train_loss, train_acc, train_auc = model.evaluate(train_ds)460valid_loss, valid_acc, valid_auc = model.evaluate(valid_ds)461462"""463Let's try to compare our model's performance to Yamnet's using one of Yamnet metrics (d-prime)464Yamnet achieved a d-prime value of 2.318.465Let's check our model's performance.466"""467468469# The following function calculates the d-prime score from the AUC470def d_prime(auc):471standard_normal = stats.norm()472d_prime = standard_normal.ppf(auc) * np.sqrt(2.0)473return d_prime474475476print(477"train d-prime: {0:.3f}, validation d-prime: {1:.3f}".format(478d_prime(train_auc), d_prime(valid_auc)479)480)481482"""483We can see that the model achieves the following results:484485Results | Training | Validation486-----------|-----------|------------487Accuracy | 54% | 51%488AUC | 0.91 | 0.89489d-prime | 1.882 | 1.740490491"""492493"""494## Confusion Matrix495496Let's now plot the confusion matrix for the validation dataset.497498The confusion matrix lets us see, for every class, not only how many samples were correctly classified,499but also which other classes were the samples confused with.500501It allows us to calculate the precision and recall for every class.502"""503504# Create x and y tensors505x_valid = None506y_valid = None507508for x, y in iter(valid_ds):509if x_valid is None:510x_valid = x.numpy()511y_valid = y.numpy()512else:513x_valid = np.concatenate((x_valid, x.numpy()), axis=0)514y_valid = np.concatenate((y_valid, y.numpy()), axis=0)515516# Generate predictions517y_pred = model.predict(x_valid)518519# Calculate confusion matrix520confusion_mtx = tf.math.confusion_matrix(521np.argmax(y_valid, axis=1), np.argmax(y_pred, axis=1)522)523524# Plot the confusion matrix525plt.figure(figsize=(10, 8))526sns.heatmap(527confusion_mtx, xticklabels=class_names, yticklabels=class_names, annot=True, fmt="g"528)529plt.xlabel("Prediction")530plt.ylabel("Label")531plt.title("Validation Confusion Matrix")532plt.show()533534"""535## Precision & recall536537For every class:538539* Recall is the ratio of correctly classified samples i.e. it shows how many samples540of this specific class, the model is able to detect.541It is the ratio of diagonal elements to the sum of all elements in the row.542* Precision shows the accuracy of the classifier. It is the ratio of correctly predicted543samples among the ones classified as belonging to this class.544It is the ratio of diagonal elements to the sum of all elements in the column.545"""546547for i, label in enumerate(class_names):548precision = confusion_mtx[i, i] / np.sum(confusion_mtx[:, i])549recall = confusion_mtx[i, i] / np.sum(confusion_mtx[i, :])550print(551"{0:15} Precision:{1:.2f}%; Recall:{2:.2f}%".format(552label, precision * 100, recall * 100553)554)555556"""557## Run inference on test data558559Let's now run a test on a single audio file.560Let's check this example from [The Scottish Voice](https://www.thescottishvoice.org.uk/home/)561562We will:563564* Download the mp3 file.565* Convert it to a 16k wav file.566* Run the model on the wav file.567* Plot the results.568"""569570filename = "audio-sample-Stuart"571url = "https://www.thescottishvoice.org.uk/files/cm/files/"572573if os.path.exists(filename + ".wav") == False:574print(f"Downloading {filename}.mp3 from {url}")575command = f"wget {url}{filename}.mp3"576os.system(command)577578print(f"Converting mp3 to wav and resampling to 16 kHZ")579command = (580f"ffmpeg -hide_banner -loglevel panic -y -i {filename}.mp3 -acodec "581f"pcm_s16le -ac 1 -ar 16000 {filename}.wav"582)583os.system(command)584585filename = filename + ".wav"586587588"""589The below function `yamnet_class_names_from_csv` was copied and very slightly changed590from this [Yamnet Notebook](https://colab.research.google.com/github/tensorflow/hub/blob/master/examples/colab/yamnet.ipynb).591"""592593594def yamnet_class_names_from_csv(yamnet_class_map_csv_text):595"""Returns list of class names corresponding to score vector."""596yamnet_class_map_csv = io.StringIO(yamnet_class_map_csv_text)597yamnet_class_names = [598name for (class_index, mid, name) in csv.reader(yamnet_class_map_csv)599]600yamnet_class_names = yamnet_class_names[1:] # Skip CSV header601return yamnet_class_names602603604yamnet_class_map_path = yamnet_model.class_map_path().numpy()605yamnet_class_names = yamnet_class_names_from_csv(606tf.io.read_file(yamnet_class_map_path).numpy().decode("utf-8")607)608609610def calculate_number_of_non_speech(scores):611number_of_non_speech = tf.math.reduce_sum(612tf.where(tf.math.argmax(scores, axis=1, output_type=tf.int32) != 0, 1, 0)613)614615return number_of_non_speech616617618def filename_to_predictions(filename):619# Load 16k audio wave620audio_wav = load_16k_audio_wav(filename)621622# Get audio embeddings & scores.623scores, embeddings, mel_spectrogram = yamnet_model(audio_wav)624625print(626"Out of {} samples, {} are not speech".format(627scores.shape[0], calculate_number_of_non_speech(scores)628)629)630631# Predict the output of the accent recognition model with embeddings as input632predictions = model.predict(embeddings)633634return audio_wav, predictions, mel_spectrogram635636637"""638Let's run the model on the audio file:639"""640641audio_wav, predictions, mel_spectrogram = filename_to_predictions(filename)642643infered_class = class_names[predictions.mean(axis=0).argmax()]644print(f"The main accent is: {infered_class} English")645646"""647Listen to the audio648"""649650Audio(audio_wav, rate=16000)651652"""653The below function was copied from this [Yamnet notebook](tinyurl.com/4a8xn7at) and adjusted to our need.654655This function plots the following:656657* Audio waveform658* Mel spectrogram659* Predictions for every time step660"""661662plt.figure(figsize=(10, 6))663664# Plot the waveform.665plt.subplot(3, 1, 1)666plt.plot(audio_wav)667plt.xlim([0, len(audio_wav)])668669# Plot the log-mel spectrogram (returned by the model).670plt.subplot(3, 1, 2)671plt.imshow(672mel_spectrogram.numpy().T, aspect="auto", interpolation="nearest", origin="lower"673)674675# Plot and label the model output scores for the top-scoring classes.676mean_predictions = np.mean(predictions, axis=0)677678top_class_indices = np.argsort(mean_predictions)[::-1]679plt.subplot(3, 1, 3)680plt.imshow(681predictions[:, top_class_indices].T,682aspect="auto",683interpolation="nearest",684cmap="gray_r",685)686687# patch_padding = (PATCH_WINDOW_SECONDS / 2) / PATCH_HOP_SECONDS688# values from the model documentation689patch_padding = (0.025 / 2) / 0.01690plt.xlim([-patch_padding - 0.5, predictions.shape[0] + patch_padding - 0.5])691# Label the top_N classes.692yticks = range(0, len(class_names), 1)693plt.yticks(yticks, [class_names[top_class_indices[x]] for x in yticks])694_ = plt.ylim(-0.5 + np.array([len(class_names), 0]))695696697