{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "\n",
    "# Decoding in time-frequency space using Common Spatial Patterns (CSP)\n",
    "\n",
    "The time-frequency decomposition is estimated by iterating over raw data that\n",
    "has been band-passed at different frequencies. This is used to compute a\n",
    "covariance matrix over each epoch or a rolling time-window and extract the CSP\n",
    "filtered signals. A linear discriminant classifier is then applied to these\n",
    "signals.\n",
    "\n",
    "Decoding performance is reported as cross-validated ROC AUC, for which the\n",
    "chance level is 0.5.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Authors: Laura Gwilliams\n",
    "#          Jean-R\u00e9mi King\n",
    "#          Alex Barachant\n",
    "#          Alexandre Gramfort\n",
    "#\n",
    "# License: BSD-3-Clause\n",
    "# Copyright the MNE-Python contributors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from sklearn.discriminant_analysis import LinearDiscriminantAnalysis\n",
    "from sklearn.model_selection import StratifiedKFold, cross_val_score\n",
    "from sklearn.pipeline import make_pipeline\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "\n",
    "from mne import Epochs, create_info\n",
    "from mne.datasets import eegbci\n",
    "from mne.decoding import CSP\n",
    "from mne.io import concatenate_raws, read_raw_edf\n",
    "from mne.time_frequency import AverageTFRArray"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set parameters and read data\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "subject = 1\n",
    "runs = [6, 10, 14]\n",
    "raw_fnames = eegbci.load_data(subject, runs)\n",
    "raw = concatenate_raws([read_raw_edf(f) for f in raw_fnames])\n",
    "raw.annotations.rename(dict(T1=\"hands\", T2=\"feet\"))\n",
    "\n",
    "# Extract information from the raw file\n",
    "sfreq = raw.info[\"sfreq\"]\n",
    "raw.pick(picks=\"eeg\", exclude=\"bads\")\n",
    "raw.load_data()\n",
    "\n",
    "# Assemble the classifier using scikit-learn pipeline\n",
    "clf = make_pipeline(\n",
    "    CSP(n_components=4, reg=None, log=True, norm_trace=False),\n",
    "    LinearDiscriminantAnalysis(),\n",
    ")\n",
    "n_splits = 3  # for cross-validation, 5 is better, here we use 3 for speed\n",
    "cv = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)\n",
    "\n",
    "# All scores below are ROC AUC: a classifier guessing at random scores 0.5\n",
    "# on that metric whatever the class balance, so chance is not the class ratio.\n",
    "chance = 0.5\n",
    "\n",
    "# Classification & time-frequency parameters\n",
    "tmin, tmax = -0.200, 2.000\n",
    "n_cycles = 10.0  # how many complete cycles: used to define window size\n",
    "min_freq = 8.0\n",
    "max_freq = 20.0\n",
    "n_freqs = 6  # number of band edges, i.e. n_freqs - 1 frequency bands\n",
    "\n",
    "# Assemble list of frequency range tuples\n",
    "freqs = np.linspace(min_freq, max_freq, n_freqs)  # band edges\n",
    "freq_ranges = list(zip(freqs[:-1], freqs[1:]))  # (fmin, fmax) of each band\n",
    "\n",
    "# Infer window spacing from the max freq and number of cycles to avoid gaps\n",
    "window_spacing = n_cycles / np.max(freqs) / 2.0\n",
    "centered_w_times = np.arange(tmin, tmax, window_spacing)[1:]\n",
    "n_windows = len(centered_w_times)\n",
    "\n",
    "# Instantiate label encoder\n",
    "le = LabelEncoder()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Band-limited epochs helper\n",
    "\n",
    "Both decoding loops below start from the same band-pass filtered, padded\n",
    "epochs, so those shared steps live in a single function.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def make_band_epochs(raw, fmin, fmax):\n",
    "    \"\"\"Band-pass filter ``raw`` and cut padded epochs for one frequency band.\n",
    "\n",
    "    Relies on the notebook-level settings ``n_cycles``, ``tmin``, ``tmax`` and\n",
    "    on the label encoder ``le`` defined in the parameter cell.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    raw : Raw instance\n",
    "        Continuous recording. It is copied before filtering, so the object\n",
    "        passed in is not filtered itself.\n",
    "    fmin, fmax : float\n",
    "        Lower and upper edge of the pass-band, in Hz.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    epochs : Epochs instance\n",
    "        \"hands\" and \"feet\" epochs spanning ``tmin - w_size`` to\n",
    "        ``tmax + w_size``.\n",
    "    y : ndarray\n",
    "        Encoded class label of each epoch, as returned by\n",
    "        ``le.fit_transform``.\n",
    "    w_size : float\n",
    "        Window length in seconds: ``n_cycles`` periods of the band's center\n",
    "        frequency.\n",
    "    \"\"\"\n",
    "    # Infer window size based on the frequency being used\n",
    "    w_size = n_cycles / ((fmax + fmin) / 2.0)  # in seconds\n",
    "\n",
    "    # Apply band-pass filter to isolate the specified frequencies\n",
    "    raw_filter = raw.copy().filter(\n",
    "        fmin, fmax, fir_design=\"firwin\", skip_by_annotation=\"edge\"\n",
    "    )\n",
    "\n",
    "    # Extract epochs from filtered data, padded by window size\n",
    "    epochs = Epochs(\n",
    "        raw_filter,\n",
    "        event_id=[\"hands\", \"feet\"],\n",
    "        tmin=tmin - w_size,\n",
    "        tmax=tmax + w_size,\n",
    "        proj=False,\n",
    "        baseline=None,\n",
    "        preload=True,\n",
    "    )\n",
    "    epochs.drop_bad()\n",
    "    y = le.fit_transform(epochs.events[:, 2])\n",
    "    return epochs, y, w_size"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loop through frequencies, apply classifier and save scores\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# init scores\n",
    "freq_scores = np.zeros((n_freqs - 1,))\n",
    "\n",
    "# Loop through each frequency range of interest\n",
    "for freq, (fmin, fmax) in enumerate(freq_ranges):\n",
    "    epochs, y, _ = make_band_epochs(raw, fmin, fmax)\n",
    "\n",
    "    # Decode from the whole (padded) epoch\n",
    "    X = epochs.get_data(copy=False)\n",
    "\n",
    "    # Save mean score over folds for this frequency band\n",
    "    freq_scores[freq] = np.mean(\n",
    "        cross_val_score(estimator=clf, X=X, y=y, scoring=\"roc_auc\", cv=cv), axis=0\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot frequency results\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots()\n",
    "# One bar per band, drawn from its lower edge and one band wide\n",
    "ax.bar(\n",
    "    freqs[:-1], freq_scores, width=np.diff(freqs)[0], align=\"edge\", edgecolor=\"black\"\n",
    ")\n",
    "ax.set_xticks(freqs)\n",
    "ax.set_ylim([0, 1])\n",
    "ax.axhline(chance, color=\"k\", linestyle=\"--\", label=\"chance level\")\n",
    "ax.legend()\n",
    "ax.set_xlabel(\"Frequency (Hz)\")\n",
    "ax.set_ylabel(\"Decoding Scores\")\n",
    "ax.set_title(\"Frequency Decoding Scores\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loop through frequencies and time, apply classifier and save scores\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# init scores\n",
    "tf_scores = np.zeros((n_freqs - 1, n_windows))\n",
    "\n",
    "# Loop through each frequency range of interest\n",
    "for freq, (fmin, fmax) in enumerate(freq_ranges):\n",
    "    epochs, y, w_size = make_band_epochs(raw, fmin, fmax)\n",
    "\n",
    "    # Roll covariance, csp and lda over time\n",
    "    for t, w_time in enumerate(centered_w_times):\n",
    "        # Center the min and max of the window\n",
    "        w_tmin = w_time - w_size / 2.0\n",
    "        w_tmax = w_time + w_size / 2.0\n",
    "\n",
    "        # Crop data into time-window of interest\n",
    "        X = epochs.get_data(tmin=w_tmin, tmax=w_tmax, copy=False)\n",
    "\n",
    "        # Save mean scores over folds for each frequency and time window\n",
    "        tf_scores[freq, t] = np.mean(\n",
    "            cross_val_score(estimator=clf, X=X, y=y, scoring=\"roc_auc\", cv=cv),\n",
    "            axis=0,\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot time-frequency results\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Set up time frequency object\n",
    "av_tfr = AverageTFRArray(\n",
    "    info=create_info([\"freq\"], sfreq),\n",
    "    data=tf_scores[np.newaxis, :],\n",
    "    times=centered_w_times,\n",
    "    freqs=freqs[1:],\n",
    "    nave=1,\n",
    ")\n",
    "\n",
    "# Start the color scale at chance so that chance-level scores plot as white\n",
    "av_tfr.plot(\n",
    "    [0], vlim=(chance, None), title=\"Time-Frequency Decoding Scores\", cmap=plt.cm.Reds\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}