Electroencephalogram Signal Classification for Brain-Computer Interface
Author: Okba Bekhelifi
Date created: 2025/01/08
Last modified: 2025/01/08
Description: A Transformer-based classification model for EEG signals in a BCI application.
Introduction
This tutorial explains how to build a Transformer-based neural network to classify Brain-Computer Interface (BCI) Electroencephalography (EEG) data recorded in a Steady-State Visual Evoked Potentials (SSVEPs) experiment for the application of a brain-controlled speller.
The tutorial reproduces an experiment from the SSVEPFormer study [1] (arXiv preprint / peer-reviewed paper). This model was the first Transformer-based model introduced for SSVEP data classification. We will test it on the Nakanishi et al. [2] public dataset, which is dataset 1 in the paper.
The process follows an inter-subject classification experiment: given N subjects in the dataset, the training partition contains data from N-1 subjects and the remaining subject's data is used for testing. The training set therefore does not contain any sample from the testing subject, which yields a true subject-independent model. We keep the same parameters and settings as the original paper in all processing operations, from preprocessing to training.
The tutorial begins with a quick BCI and dataset description, then we go through the technicalities following these sections:
Setup, and imports.
Dataset download and extraction.
Data preprocessing: EEG data filtering, segmentation and visualization of raw and filtered data, and frequency response for a well-performing participant.
Layers and model creation.
Evaluation: classification of a single participant's data as an example, then classification of all participants' data.
Visualization: we show a comparison of training and inference times among the Keras 3 available backends (JAX, TensorFlow, and PyTorch) on three different GPUs.
Conclusion: final discussion and remarks.
Dataset description
BCI and SSVEP:
A BCI offers the ability to communicate using only brain activity. This can be achieved through exogenous stimuli that generate specific responses indicating the intent of the subject; the responses are elicited when the user focuses their attention on the target stimulus. With visual stimuli, the subject is presented with a set of options, typically laid out as a grid on a monitor, to select one command at a time. Each stimulus flickers at a fixed frequency and phase, and the resulting EEG recorded over the occipital and occipito-parietal areas of the cortex (the visual cortex) shows higher power at the frequency associated with the stimulus the subject is looking at. This BCI paradigm is called Steady-State Visual Evoked Potentials (SSVEPs) and has become widely used for multiple applications due to its reliability, high classification performance, and rapidity: a single second of EEG is sufficient to issue a command. Other types of brain responses exist that do not require external stimulation, but they are less reliable. Demo video
This tutorial uses the 12-command (class) public SSVEP dataset [2] with the following interface emulating a phone dial pad.
The dataset was recorded with 10 participants, each facing the above 12 SSVEP stimuli (A). The stimulation frequencies ranged from 9.25 Hz to 14.75 Hz with a 0.5 Hz step, and the phases ranged from 0 to 1.5 π with a 0.5 π step for each row (B). The EEG signal was acquired with 8 electrodes (channels) (PO7, PO3, POz, PO4, PO8, O1, Oz, O2) at a sampling frequency of 2048 Hz, and the stored data were then downsampled to 256 Hz. The subjects completed 15 blocks of recordings, each consisting of 12 randomly ordered stimulations (1 per class) of 4 seconds each. In total, each subject conducted 180 trials.
Setup
Select JAX backend
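In Keras 3 the backend is selected through the KERAS_BACKEND environment variable, which must be set before Keras is imported:

```python
import os

# The backend must be set before Keras is imported.
# "jax" can be swapped for "tensorflow" or "torch".
os.environ["KERAS_BACKEND"] = "jax"
```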
Install dependencies
Imports
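The exact import list depends on the code cells that follow; a typical set for this tutorial would look like the sketch below (NumPy, SciPy, and Matplotlib are assumed for loading, filtering, and plotting):

```python
import numpy as np
import matplotlib.pyplot as plt
from scipy.io import loadmat          # the dataset is distributed as .mat files
from scipy.signal import butter, filtfilt, welch

import keras
from keras import layers
```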
Download and extract dataset
Nakanishi et al. 2015 Dataset Repo
Pre-Processing
The preprocessing steps are as follows: first we read the EEG data for each subject, then we filter the raw data within a frequency interval where most of the useful information lies, and finally we select a fixed duration of signal starting from the stimulation onset (to account for the latency of the visual system, we add 135 milliseconds to the stimulation onset). Lastly, all subjects' data are concatenated in a single tensor of shape [subjects x samples x channels x trials]. The data labels are concatenated following the order of the trials in the experiment, forming a matrix of shape [subjects x trials] (by channels we mean electrodes; we use this notation throughout the tutorial).
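As an illustration of the filtering step, a zero-phase Butterworth band-pass over the [8, 64] Hz band (the band used in the visualization below) could look like this sketch; the array layout, filter type, and order are assumptions:

```python
FS = 256  # sampling frequency in Hz after downsampling

def bandpass_filter(raw_eeg, low=8.0, high=64.0, fs=FS, order=4):
    """Zero-phase Butterworth band-pass filter applied along the time axis.

    raw_eeg: array of shape [samples, channels] for one subject/trial.
    """
    b, a = butter(order, [low, high], btype="band", fs=fs)
    return filtfilt(b, a, raw_eeg, axis=0)
```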
Segment data into epochs
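A minimal sketch of the segmentation described above, starting 135 ms after the stimulation onset and keeping a fixed duration (variable names and the 1-second default are assumptions):

```python
def segment_epoch(filtered_eeg, onset_sample, duration_s=1.0, fs=FS, delay_s=0.135):
    """Extract a fixed-length epoch starting `delay_s` after the stimulation onset.

    filtered_eeg: array of shape [samples, channels].
    Returns an array of shape [epoch_samples, channels].
    """
    start = onset_sample + int(delay_s * fs)
    stop = start + int(duration_s * fs)
    return filtered_eeg[start:stop, :]
```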
Visualize EEG signal
EEG in time
Raw EEG vs Filtered EEG: the same 1-second recording for subject s1 at Oz (the central electrode over the visual cortex, at the back of the head) is illustrated. On the left is the raw EEG as recorded, and on the right is the EEG filtered to the [8, 64] Hz frequency band. The filtered signal shows less noise and amplitude values normalized to a natural EEG range.
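A plot like the one described can be reproduced with Matplotlib roughly as follows, assuming `raw` and `filtered` are [samples, channels] arrays for subject s1 (Oz is at index 6 in the channel order listed above):

```python
oz = 6                   # PO7, PO3, POz, PO4, PO8, O1, Oz, O2
t = np.arange(FS) / FS   # first second of the epoch

fig, axes = plt.subplots(1, 2, figsize=(10, 3), sharex=True)
axes[0].plot(t, raw[:FS, oz])
axes[0].set_title("Raw EEG (Oz)")
axes[1].plot(t, filtered[:FS, oz])
axes[1].set_title("Filtered EEG [8, 64] Hz (Oz)")
for ax in axes:
    ax.set_xlabel("Time (s)")
    ax.set_ylabel("Amplitude (µV)")
plt.tight_layout()
plt.show()
```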
EEG frequency representation
Using the Welch method, we visualize the frequency power for a well-performing subject over the entire 4-second EEG recording at the Oz electrode for each stimulus. The red peaks indicate the stimulus fundamental frequency and the 2nd harmonic (double the fundamental frequency). We see clear peaks, showing strong responses from this subject, which means that this subject is a good candidate for SSVEP BCI control. In many cases the peaks are weak or absent, meaning that the subject does not achieve the task correctly.
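A sketch of the Welch estimate for one stimulus, assuming `epoch` is a [samples, channels] array holding the full 4-second recording:

```python
# Use the full 4-second window as a single segment for 0.25 Hz resolution,
# enough to separate stimuli spaced 0.5 Hz apart.
freqs, psd = welch(epoch[:, 6], fs=FS, nperseg=epoch.shape[0])

plt.semilogy(freqs, psd)
plt.xlim(5, 40)
plt.xlabel("Frequency (Hz)")
plt.ylabel("Power spectral density")
plt.title("Welch PSD at Oz")
plt.show()
```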
Create Layers and model
We create the layers as cross-framework custom components. In the SSVEPFormer, the data is first transformed to the frequency domain through the Fast Fourier Transform (FFT) to construct a complex spectrum representation consisting of the concatenation of frequency and phase information in a fixed frequency band. To keep the model in an end-to-end format, we implement the complex spectrum transformation as a non-trainable layer.
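One way to express this transformation as a non-trainable Keras layer is sketched below, using keras.ops.rfft and concatenating the real and imaginary parts; the band-selection indices and the output layout are assumptions:

```python
class ComplexSpectrum(keras.layers.Layer):
    """Non-trainable layer computing the complex spectrum representation:
    the real and imaginary parts of the FFT of the signal, restricted to a
    fixed frequency band and concatenated along the last axis."""

    def __init__(self, band_start=8, band_stop=64, **kwargs):
        # band_start / band_stop are assumed FFT-bin indices, not Hz values.
        super().__init__(**kwargs)
        self.band_start = band_start
        self.band_stop = band_stop

    def call(self, x):
        # x: [batch, channels, samples]; rfft runs over the last (time) axis.
        real, imag = keras.ops.rfft(x)
        real = real[..., self.band_start : self.band_stop]
        imag = imag[..., self.band_start : self.band_stop]
        spectrum = keras.ops.concatenate([real, imag], axis=-1)
        # Lay the result out as [batch, spectrum, channels] so that the Conv1D
        # blocks that follow combine electrodes at each spectral bin.
        return keras.ops.transpose(spectrum, axes=(0, 2, 1))
```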
Unlike the standard Transformer architecture, the SSVEPFormer does not contain positional encoding/embedding layers; they are replaced by a channel combination block made of a Conv1D layer with kernel size 1 and a number of filters equal to twice the number of input channels (double the electrode count), followed by LayerNorm, GELU activation, and dropout. Another difference from Transformers is the absence of multi-head attention layers. The model encoder contains two identical, successive blocks. Each block has two sub-blocks: a CNN module and an MLP module. The CNN module consists of a LayerNorm, a Conv1D with the same number of filters as the channel combination, a LayerNorm, GELU, dropout, and a residual connection. The MLP module consists of a LayerNorm, a Dense layer, GELU, dropout, and a residual connection; the Dense layer is applied to each channel separately. The last block of the model is an MLP head with a Flatten layer, Dropout, Dense, LayerNorm, GELU, Dropout, and a final Dense layer with softmax activation. All trainable weights are initialized from a normal distribution with mean 0 and standard deviation 0.01, as stated in the original paper.
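As a rough sketch, two of these blocks could be written as Keras layers as follows; the dropout rate, the convolution kernel size in the CNN module, and the class names are assumptions, and the MLP module is omitted but would follow the same residual pattern:

```python
class ChannelComb(layers.Layer):
    """Channel combination block: Conv1D (kernel size 1, 2 x channels filters),
    LayerNorm, GELU, Dropout. It replaces the Transformer's positional encoding."""

    def __init__(self, n_channels, drop_rate=0.5, **kwargs):  # drop rate assumed
        super().__init__(**kwargs)
        init = keras.initializers.RandomNormal(mean=0.0, stddev=0.01)
        self.conv = layers.Conv1D(2 * n_channels, kernel_size=1, kernel_initializer=init)
        self.norm = layers.LayerNormalization()
        self.act = layers.Activation("gelu")
        self.drop = layers.Dropout(drop_rate)

    def call(self, x):
        return self.drop(self.act(self.norm(self.conv(x))))


class CNNModule(layers.Layer):
    """Encoder CNN sub-block with a residual connection. The kernel size is an
    assumption; filters should match the channel combination output (2 x channels)."""

    def __init__(self, filters, kernel_size=31, drop_rate=0.5, **kwargs):
        super().__init__(**kwargs)
        init = keras.initializers.RandomNormal(mean=0.0, stddev=0.01)
        self.norm1 = layers.LayerNormalization()
        self.conv = layers.Conv1D(filters, kernel_size, padding="same",
                                  kernel_initializer=init)
        self.norm2 = layers.LayerNormalization()
        self.act = layers.Activation("gelu")
        self.drop = layers.Dropout(drop_rate)

    def call(self, x):
        y = self.drop(self.act(self.norm2(self.conv(self.norm1(x)))))
        return x + y  # residual connection
```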
Create a sequential model with the layers above
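A possible assembly of the sketched pieces into a Sequential model is shown below. The input shape, head width, dropout rate, and the commented-out `MLPModule` placeholder are assumptions, not the notebook's exact values:

```python
n_channels, n_samples, n_classes = 8, 256, 12  # 8 electrodes, 1-second epochs assumed
init = keras.initializers.RandomNormal(mean=0.0, stddev=0.01)

model = keras.Sequential(
    [
        keras.Input(shape=(n_channels, n_samples)),
        ComplexSpectrum(),              # -> [batch, spectrum, channels], non-trainable
        ChannelComb(n_channels),
        # Two identical encoder blocks; each pairs a CNN module with an MLP module.
        CNNModule(2 * n_channels),
        # MLPModule(2 * n_channels),    # omitted sketch, same residual pattern
        CNNModule(2 * n_channels),
        # MLPModule(2 * n_channels),
        # MLP head
        layers.Flatten(),
        layers.Dropout(0.5),
        layers.Dense(6 * n_classes, kernel_initializer=init),  # head width assumed
        layers.LayerNormalization(),
        layers.Activation("gelu"),
        layers.Dropout(0.5),
        layers.Dense(n_classes, activation="softmax", kernel_initializer=init),
    ]
)
```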
Evaluation
From the entire dataset we select the folds for each subject evaluation: we construct tf.data.Dataset objects for the training and testing data, create the model, and launch the training using the SGD optimizer.
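A sketch of such a per-subject evaluation is given below. It assumes the data has been rearranged as [subjects, trials, channels, samples], that the labels are integer class indices, and that `create_model()` is a hypothetical helper wrapping the model construction above; batch size, learning rate, momentum, and epoch count are also assumptions:

```python
import tensorflow as tf  # only used here for the tf.data input pipeline

def evaluate_subject(data, labels, test_subject, batch_size=128, epochs=100):
    """Leave-one-subject-out evaluation for a single held-out subject."""
    subjects = np.arange(data.shape[0])
    train_subjects = subjects[subjects != test_subject]

    x_train = data[train_subjects].reshape(-1, *data.shape[2:])
    y_train = labels[train_subjects].reshape(-1)
    x_test, y_test = data[test_subject], labels[test_subject]

    train_ds = (
        tf.data.Dataset.from_tensor_slices((x_train, y_train))
        .shuffle(len(x_train))
        .batch(batch_size)
    )
    test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(batch_size)

    model = create_model()  # hypothetical helper returning a fresh model per fold
    model.compile(
        optimizer=keras.optimizers.SGD(learning_rate=0.001, momentum=0.9),
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )
    model.fit(train_ds, epochs=epochs, verbose=0)
    _, accuracy = model.evaluate(test_ds, verbose=0)
    return accuracy
```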
Run evaluation
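For example, holding out the first subject (the index chosen here is arbitrary):

```python
# Evaluate with subject 0 held out of the training set.
acc = evaluate_subject(data, labels, test_subject=0)
print(f"Test accuracy: {acc * 100:.2f}%")
```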
Evaluation on all subjects following a leave-one-subject-out data repartition scheme.
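A hypothetical driver loop over all subjects, reusing `evaluate_subject()` from above:

```python
accuracies = []
for subject in range(data.shape[0]):
    acc = evaluate_subject(data, labels, test_subject=subject)
    accuracies.append(acc)
    print(f"Subject {subject + 1}: {acc * 100:.2f}%")

print(f"Mean accuracy: {np.mean(accuracies) * 100:.2f}% "
      f"(std: {np.std(accuracies) * 100:.2f})")
```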
And that's it! We see how some subjects, despite having no data in the training set, can still achieve almost 100% correct commands, while others show poor performance around 50%. In the original paper, using PyTorch, the average accuracy was 84.04% with a 17.37 standard deviation. We reached comparable values, given the stochastic nature of deep learning.
Visualizations
Training and inference time comparison between the different backends (JAX, TensorFlow, and PyTorch) on the three GPUs available with Colab Free/Pro/Pro+: T4, L4, A100.
Training Time
Inference Time
The JAX backend was the fastest for both training and inference on all GPUs. PyTorch was extremely slow because the jit compilation option had to be disabled: the complex data type produced by the FFT is not supported by the PyTorch jit compiler.
Acknowledgment
I thank Chris Perry (X: @GoogleColab) for supporting this work with GPU compute.
References
[1] Chen, J. et al. (2023) ‘A transformer-based deep neural network model for SSVEP classification’, Neural Networks, 164, pp. 521–534. Available at: https://doi.org/10.1016/j.neunet.2023.04.045.
[2] Nakanishi, M. et al. (2015) ‘A Comparison Study of Canonical Correlation Analysis Based Methods for Detecting Steady-State Visual Evoked Potentials’, PLOS ONE, 10(10), p. e0140703. Available at: https://doi.org/10.1371/journal.pone.0140703.