"""
.. _ex-decoding-csp-eeg:

===========================================================================
Motor imagery decoding from EEG data using the Common Spatial Pattern (CSP)
===========================================================================

Decoding of motor imagery applied to EEG data decomposed using CSP. A
classifier is then applied to features extracted on CSP-filtered signals.

See https://en.wikipedia.org/wiki/Common_spatial_pattern and
:footcite:`Koles1991`. The EEGBCI dataset is documented in
:footcite:`SchalkEtAl2004` and is available at PhysioNet
:footcite:`GoldbergerEtAl2000`.
"""
# Authors: Martin Billinger
#
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

# %%

import matplotlib.pyplot as plt
import numpy as np
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import ShuffleSplit, cross_val_score
from sklearn.pipeline import Pipeline

from mne import Epochs, pick_types
from mne.channels import make_standard_montage
from mne.datasets import eegbci
from mne.decoding import CSP
from mne.io import concatenate_raws, read_raw_edf

print(__doc__)

# #############################################################################
# # Set parameters and read data

# avoid classification of evoked responses by using epochs that start 1s after
# cue onset.
tmin, tmax = -1.0, 4.0
subject = 1
runs = [6, 10, 14]  # motor imagery: hands vs feet

raw_fnames = eegbci.load_data(subject, runs)
raw = concatenate_raws([read_raw_edf(f, preload=True) for f in raw_fnames])
eegbci.standardize(raw)  # set channel names
montage = make_standard_montage("standard_1005")
raw.set_montage(montage)
raw.annotations.rename(dict(T1="hands", T2="feet"))
raw.set_eeg_reference(projection=True)

# Apply band-pass filter
raw.filter(7.0, 30.0, fir_design="firwin", skip_by_annotation="edge")

picks = pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False, exclude="bads")

# Read epochs (train will be done only between 1 and 2s)
# Testing will be done with a running classifier
epochs = Epochs(
    raw,
    event_id=["hands", "feet"],
    tmin=tmin,
    tmax=tmax,
    proj=True,
    picks=picks,
    baseline=None,
    preload=True,
)
epochs_train = epochs.copy().crop(tmin=1.0, tmax=2.0)

# Map the event codes onto 0..k-1 in ascending code order. For codes 2 and 3
# this is identical to ``events[:, -1] - 2``, but it does not hard-code the
# integer ids that were assigned to the "hands"/"feet" annotations.
labels = np.unique(epochs.events[:, -1], return_inverse=True)[1]

# %%
# Classification with linear discriminant analysis

epochs_data = epochs.get_data(copy=False)
epochs_data_train = epochs_train.get_data(copy=False)

# Define a monte-carlo cross-validation generator (reduce variance):
cv = ShuffleSplit(10, test_size=0.2, random_state=42)
# Split generator consumed by the sliding-window analysis further down.
cv_split = cv.split(epochs_data_train)

# Assemble a classifier
lda = LinearDiscriminantAnalysis()
csp = CSP(n_components=4, reg=None, log=True, norm_trace=False)

# Use scikit-learn Pipeline with cross_val_score function
clf = Pipeline([("CSP", csp), ("LDA", lda)])
scores = cross_val_score(clf, epochs_data_train, labels, cv=cv, n_jobs=None)

# Printing the results. Chance level is the accuracy of always predicting the
# majority class.
class_balance = np.mean(labels == labels[0])
class_balance = max(class_balance, 1.0 - class_balance)
print(f"Classification accuracy: {np.mean(scores)} / Chance level: {class_balance}")

# plot CSP patterns estimated on full data for visualization
csp.fit(epochs_data, labels)
csp.plot_patterns(epochs.info, ch_type="eeg", units="Patterns (AU)", size=1.5)

# %%
# Look at performance over time

sfreq = raw.info["sfreq"]
w_length = int(sfreq * 0.5)  # running classifier: window length
w_step = int(sfreq * 0.1)  # running classifier: window step size
# Start sample of every full-length window. The ``+ 1`` keeps the last window
# whose end coincides with the final sample of the epoch.
w_start = np.arange(0, epochs_data.shape[2] - w_length + 1, w_step)

scores_windows = []

for train_idx, test_idx in cv_split:
    y_train, y_test = labels[train_idx], labels[test_idx]

    # fit CSP and the classifier on this fold's 1-2 s training crop only
    X_train = csp.fit_transform(epochs_data_train[train_idx], y_train)
    lda.fit(X_train, y_train)

    # running classifier: test classifier on sliding window.
    # Index the test trials once; each window is then a cheap slice of it.
    test_data = epochs_data[test_idx]
    score_this_window = []
    for n in w_start:
        X_test = csp.transform(test_data[:, :, n : n + w_length])
        score_this_window.append(lda.score(X_test, y_test))
    scores_windows.append(score_this_window)

# Plot scores over time; each score is placed at the centre of its window.
w_times = (w_start + w_length / 2.0) / sfreq + epochs.tmin

plt.figure()
plt.plot(w_times, np.mean(scores_windows, axis=0), label="Score")
plt.axvline(0, linestyle="--", color="k", label="Onset")
# Use the same majority-class chance level that was printed above.
plt.axhline(class_balance, linestyle="-", color="k", label="Chance")
plt.xlabel("time (s)")
plt.ylabel("classification accuracy")
plt.title("Classification score over time")
plt.legend(loc="lower right")
plt.show()

##############################################################################
# References
# ----------
# .. footbibliography::