{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Motor imagery decoding from EEG data using the Common Spatial Pattern (CSP)\n",
    "\n",
    "Decoding of motor imagery applied to EEG data decomposed using CSP. A\n",
    "classifier is then applied to features extracted on CSP-filtered signals.\n",
    "\n",
    "See https://en.wikipedia.org/wiki/Common_spatial_pattern and\n",
    "Koles (1991) [1]. The EEGBCI dataset is documented in\n",
    "Schalk et al. (2004) [2] and is available at PhysioNet\n",
    "(Goldberger et al., 2000) [3].\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Authors: Martin Billinger\n",
    "#\n",
    "# License: BSD-3-Clause\n",
    "# Copyright the MNE-Python contributors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from sklearn.discriminant_analysis import LinearDiscriminantAnalysis\n",
    "from sklearn.model_selection import ShuffleSplit, cross_val_score\n",
    "from sklearn.pipeline import Pipeline\n",
    "\n",
    "from mne import Epochs, pick_types\n",
    "from mne.channels import make_standard_montage\n",
    "from mne.datasets import eegbci\n",
    "from mne.decoding import CSP\n",
    "from mne.io import concatenate_raws, read_raw_edf"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set parameters and read data\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Epochs span -1 s to 4 s around cue onset. To avoid classifying evoked\n",
    "# responses, the classifier is trained only on the 1-2 s window after the\n",
    "# cue (see `epochs_train` below).\n",
    "tmin, tmax = -1.0, 4.0\n",
    "subject = 1\n",
    "runs = [6, 10, 14]  # motor imagery: hands vs feet\n",
    "\n",
    "raw_fnames = eegbci.load_data(subject, runs)\n",
    "raw = concatenate_raws([read_raw_edf(f, preload=True) for f in raw_fnames])\n",
    "eegbci.standardize(raw)  # set channel names\n",
    "montage = make_standard_montage(\"standard_1005\")\n",
    "raw.set_montage(montage)\n",
    "raw.annotations.rename(dict(T1=\"hands\", T2=\"feet\"))\n",
    "raw.set_eeg_reference(projection=True)\n",
    "\n",
    "# Apply band-pass filter\n",
    "raw.filter(7.0, 30.0, fir_design=\"firwin\", skip_by_annotation=\"edge\")\n",
    "\n",
    "picks = pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False, exclude=\"bads\")\n",
    "\n",
    "# Read epochs (training is done only on the 1-2 s window).\n",
    "# Testing is done later with a running classifier over the full epoch.\n",
    "epochs = Epochs(\n",
    "    raw,\n",
    "    event_id=[\"hands\", \"feet\"],\n",
    "    tmin=tmin,\n",
    "    tmax=tmax,\n",
    "    proj=True,\n",
    "    picks=picks,\n",
    "    baseline=None,\n",
    "    preload=True,\n",
    ")\n",
    "epochs_train = epochs.copy().crop(tmin=1.0, tmax=2.0)\n",
    "# NOTE(review): subtracting 2 presumably maps the two selected event codes\n",
    "# (assumed to be 2 and 3) onto class labels 0/1 -- confirm against\n",
    "# `epochs.event_id`.\n",
    "labels = epochs.events[:, -1] - 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Classification with linear discriminant analysis\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Define a Monte Carlo cross-validation generator (reduce variance).\n",
    "# The integer random_state makes every call to cv.split() yield the same\n",
    "# folds, so the sliding-window analysis further down can call cv.split()\n",
    "# again and evaluate on the folds used here.\n",
    "epochs_data = epochs.get_data(copy=False)\n",
    "epochs_data_train = epochs_train.get_data(copy=False)\n",
    "cv = ShuffleSplit(10, test_size=0.2, random_state=42)\n",
    "\n",
    "# Assemble a classifier\n",
    "lda = LinearDiscriminantAnalysis()\n",
    "csp = CSP(n_components=4, reg=None, log=True, norm_trace=False)\n",
    "\n",
    "# Use scikit-learn Pipeline with cross_val_score function\n",
    "clf = Pipeline([(\"CSP\", csp), (\"LDA\", lda)])\n",
    "scores = cross_val_score(clf, epochs_data_train, labels, cv=cv, n_jobs=None)\n",
    "\n",
    "# Printing the results; chance level is the proportion of the majority class\n",
    "class_balance = np.mean(labels == labels[0])\n",
    "class_balance = max(class_balance, 1.0 - class_balance)\n",
    "print(f\"Classification accuracy: {np.mean(scores)} / Chance level: {class_balance}\")\n",
    "\n",
    "# plot CSP patterns estimated on full data for visualization\n",
    "csp.fit_transform(epochs_data, labels)\n",
    "\n",
    "csp.plot_patterns(epochs.info, ch_type=\"eeg\", units=\"Patterns (AU)\", size=1.5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Look at performance over time\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "sfreq = raw.info[\"sfreq\"]\n",
    "w_length = int(sfreq * 0.5)  # running classifier: window length\n",
    "w_step = int(sfreq * 0.1)  # running classifier: window step size\n",
    "w_start = np.arange(0, epochs_data.shape[2] - w_length, w_step)\n",
    "\n",
    "scores_windows = []\n",
    "\n",
    "# Call cv.split() here instead of reusing a generator object created in an\n",
    "# earlier cell: a generator is exhausted after one pass, so re-running this\n",
    "# cell would otherwise loop over zero folds and leave scores_windows empty.\n",
    "for train_idx, test_idx in cv.split(epochs_data_train):\n",
    "    y_train, y_test = labels[train_idx], labels[test_idx]\n",
    "\n",
    "    # fit CSP and the classifier on the cropped training window only\n",
    "    X_train = csp.fit_transform(epochs_data_train[train_idx], y_train)\n",
    "    lda.fit(X_train, y_train)\n",
    "\n",
    "    # select the held-out epochs once per fold (full-length epochs)\n",
    "    epochs_data_test = epochs_data[test_idx]\n",
    "\n",
    "    # running classifier: test classifier on sliding window\n",
    "    score_this_window = []\n",
    "    for n in w_start:\n",
    "        X_test = csp.transform(epochs_data_test[:, :, n : (n + w_length)])\n",
    "        score_this_window.append(lda.score(X_test, y_test))\n",
    "    scores_windows.append(score_this_window)\n",
    "\n",
    "# Plot scores over time; each score is placed at the centre of its window\n",
    "w_times = (w_start + w_length / 2.0) / sfreq + epochs.tmin\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(w_times, np.mean(scores_windows, axis=0), label=\"Score\")\n",
    "ax.axvline(0, linestyle=\"--\", color=\"k\", label=\"Onset\")\n",
    "ax.axhline(0.5, linestyle=\"-\", color=\"k\", label=\"Chance\")\n",
    "ax.set_xlabel(\"time (s)\")\n",
    "ax.set_ylabel(\"classification accuracy\")\n",
    "ax.set_title(\"Classification score over time\")\n",
    "ax.legend(loc=\"lower right\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## References\n",
    "\n",
    "1. Koles, Z. J. (1991). The quantitative extraction and topographic mapping of the abnormal components in the clinical EEG. *Electroencephalography and Clinical Neurophysiology*, 79(6), 440-447.\n",
    "2. Schalk, G., McFarland, D. J., Hinterberger, T., Birbaumer, N., & Wolpaw, J. R. (2004). BCI2000: a general-purpose brain-computer interface (BCI) system. *IEEE Transactions on Biomedical Engineering*, 51(6), 1034-1043.\n",
    "3. Goldberger, A. L., et al. (2000). PhysioBank, PhysioToolkit, and PhysioNet: Components of a new research resource for complex physiologic signals. *Circulation*, 101(23), e215-e220.\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}