{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n\n# Sleep stage classification from polysomnography (PSG) data\n\n**Note:** This code is taken from the analysis code used in\n:footcite:`ChambonEtAl2018`. If you reuse this code, please consider\nciting this work.
\n\nThis tutorial explains how to perform a toy polysomnography analysis that\nanswers the following question:\n\n.. important:: Given two subjects from the Sleep Physionet dataset\n :footcite:`KempEtAl2000,GoldbergerEtAl2000`, namely\n *Alice* and *Bob*, how well can we predict the sleep stages of\n *Bob* from *Alice's* data?\n\nThis problem is tackled as a supervised multiclass classification task. The\naim is to predict the sleep stage, out of 5 possible stages, for each\n30-second chunk of data.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Authors: Alexandre Gramfort \n# Stanislas Chambon \n# Joan Massich \n#\n# License: BSD-3-Clause\n# Copyright the MNE-Python contributors." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\nimport numpy as np\nfrom sklearn.ensemble import RandomForestClassifier\nfrom sklearn.metrics import accuracy_score, classification_report, confusion_matrix\nfrom sklearn.pipeline import make_pipeline\nfrom sklearn.preprocessing import FunctionTransformer\n\nimport mne\nfrom mne.datasets.sleep_physionet.age import fetch_data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load the data\n\nHere we download the data of two subjects. The end goal is to obtain\n:term:`epochs` and the associated ground truth.\n\nMNE-Python provides us with\n:func:`mne.datasets.sleep_physionet.age.fetch_data` to conveniently download\ndata from the Sleep Physionet dataset\n:footcite:`KempEtAl2000,GoldbergerEtAl2000`.\nGiven a list of subjects and records, the fetcher downloads the data and\nprovides us with a pair of files for each subject:\n\n* ``-PSG.edf`` containing the polysomnography, i.e. the :term:`raw` data from\n the EEG helmet,\n* ``-Hypnogram.edf`` containing the :term:`annotations` recorded by an\n expert.\n\nCombining these two in a :class:`mne.io.Raw` object will allow us to extract\n:term:`events` based on the descriptions of the annotations to obtain the\n:term:`epochs`.\n\n### Read the PSG data and Hypnograms to create a raw object\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "ALICE, BOB = 0, 1\n\n[alice_files, bob_files] = fetch_data(subjects=[ALICE, BOB], recording=[1])\n\nraw_train = mne.io.read_raw_edf(\n    alice_files[0],\n    stim_channel=\"Event marker\",\n    infer_types=True,\n    preload=True,\n    verbose=\"error\",  # ignore issues with stored filter settings\n)\nannot_train = mne.read_annotations(alice_files[1])\n\nraw_train.set_annotations(annot_train, emit_warning=False)\n\n# plot some data\n# scalings were chosen manually to allow for simultaneous visualization of\n# different channel types in this specific dataset\nraw_train.plot(\n    start=60,\n    duration=60,\n    scalings=dict(eeg=1e-4, resp=1e3, eog=1e-4, emg=1e-7, misc=1e-1),\n)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Extract 30s events from annotations\n\nThe Sleep Physionet dataset is annotated using\n[8 labels](physionet_labels_):\nWake (W), Stage 1, Stage 2, Stage 3, Stage 4, corresponding to the range from\nlight sleep to deep sleep, REM sleep (R), where REM is the abbreviation for\nRapid Eye Movement sleep, movement (M), and Stage (?) for any unscored\nsegment.\n\nWe will work only with 5 stages: Wake (W), Stage 1, Stage 2, Stage 3/4, and\nREM sleep (R). To do so, we use the ``event_id`` parameter in\n:func:`mne.events_from_annotations` to select the events we are interested\nin and to associate an event identifier with each of them.\n\nMoreover, the recordings contain long wake (W) periods before and after each\nnight. 
To limit the impact of class imbalance, we trim each recording, keeping only\n30 minutes of wake time before the first occurrence and 30 minutes after the\nlast occurrence of sleep stages.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "annotation_desc_2_event_id = {\n    \"Sleep stage W\": 1,\n    \"Sleep stage 1\": 2,\n    \"Sleep stage 2\": 3,\n    \"Sleep stage 3\": 4,\n    \"Sleep stage 4\": 4,\n    \"Sleep stage R\": 5,\n}\n\n# keep last 30-min wake events before sleep and first 30-min wake events after\n# sleep and redefine annotations on raw data\nannot_train.crop(annot_train[1][\"onset\"] - 30 * 60, annot_train[-2][\"onset\"] + 30 * 60)\nraw_train.set_annotations(annot_train, emit_warning=False)\n\nevents_train, _ = mne.events_from_annotations(\n    raw_train, event_id=annotation_desc_2_event_id, chunk_duration=30.0\n)\n\n# create a new event_id that unifies stages 3 and 4\nevent_id = {\n    \"Sleep stage W\": 1,\n    \"Sleep stage 1\": 2,\n    \"Sleep stage 2\": 3,\n    \"Sleep stage 3/4\": 4,\n    \"Sleep stage R\": 5,\n}\n\n# plot events\nfig = mne.viz.plot_events(\n    events_train,\n    event_id=event_id,\n    sfreq=raw_train.info[\"sfreq\"],\n    first_samp=events_train[0, 0],\n)\n\n# keep the color-code for further plotting\nstage_colors = plt.rcParams[\"axes.prop_cycle\"].by_key()[\"color\"]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create Epochs from the data based on the events found in the annotations\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# tmax is included in the epoch, so stop one sample short of 30 s\ntmax = 30.0 - 1.0 / raw_train.info[\"sfreq\"]\n\nepochs_train = mne.Epochs(\n    raw=raw_train,\n    events=events_train,\n    event_id=event_id,\n    tmin=0.0,\n    tmax=tmax,\n    baseline=None,\n)\ndel raw_train\n\nprint(epochs_train)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Apply the same steps to the test data from Bob\n\n" ] }, { "cell_type": 
"code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "raw_test = mne.io.read_raw_edf(\n    bob_files[0],\n    stim_channel=\"Event marker\",\n    infer_types=True,\n    preload=True,\n    verbose=\"error\",\n)\nannot_test = mne.read_annotations(bob_files[1])\nannot_test.crop(annot_test[1][\"onset\"] - 30 * 60, annot_test[-2][\"onset\"] + 30 * 60)\nraw_test.set_annotations(annot_test, emit_warning=False)\nevents_test, _ = mne.events_from_annotations(\n    raw_test, event_id=annotation_desc_2_event_id, chunk_duration=30.0\n)\nepochs_test = mne.Epochs(\n    raw=raw_test,\n    events=events_test,\n    event_id=event_id,\n    tmin=0.0,\n    tmax=tmax,\n    baseline=None,\n)\ndel raw_test\n\nprint(epochs_test)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Feature Engineering\n\nObserving the power spectral density (PSD) plot of the :term:`epochs` grouped\nby sleep stage, we can see that different sleep stages have different\nsignatures. These signatures remain similar between Alice's and Bob's data.\n\nIn the rest of this section, we will create EEG features based on relative\npower in specific frequency bands to capture these differences between the\nsleep stages in our data.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# visualize Alice vs. Bob PSD by sleep stage.\nfig, (ax1, ax2) = plt.subplots(ncols=2)\n\n# iterate over the subjects\nstages = sorted(event_id.keys())\nfor ax, title, epochs in zip([ax1, ax2], [\"Alice\", \"Bob\"], [epochs_train, epochs_test]):\n    for stage, color in zip(stages, stage_colors):\n        spectrum = epochs[stage].compute_psd(fmin=0.1, fmax=20.0)\n        spectrum.plot(\n            ci=None,\n            color=color,\n            axes=ax,\n            show=False,\n            average=True,\n            amplitude=False,\n            spatial_colors=False,\n            picks=\"data\",\n            exclude=\"bads\",\n        )\n    ax.set(title=title, xlabel=\"Frequency (Hz)\")\nax1.set(ylabel=\"\u00b5V\u00b2/Hz (dB)\")\nax2.legend(ax2.lines[2::3], stages)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Design a scikit-learn transformer from a Python function\n\nWe will now create a function to extract EEG features based on relative power\nin specific frequency bands to be able to predict sleep stages from EEG\nsignals.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def eeg_power_band(epochs):\n    \"\"\"EEG relative power band feature extraction.\n\n    This function takes an ``mne.Epochs`` object and creates EEG features based\n    on relative power in specific frequency bands that are compatible with\n    scikit-learn.\n\n    Parameters\n    ----------\n    epochs : Epochs\n        The data.\n\n    Returns\n    -------\n    X : numpy array of shape [n_samples, 5 * n_channels]\n        Transformed data.\n    \"\"\"\n    # specific frequency bands\n    FREQ_BANDS = {\n        \"delta\": [0.5, 4.5],\n        \"theta\": [4.5, 8.5],\n        \"alpha\": [8.5, 11.5],\n        \"sigma\": [11.5, 15.5],\n        \"beta\": [15.5, 30],\n    }\n\n    spectrum = epochs.compute_psd(picks=\"eeg\", fmin=0.5, fmax=30.0)\n    psds, freqs = spectrum.get_data(return_freqs=True)\n    # Normalize the PSDs\n    psds /= np.sum(psds, axis=-1, keepdims=True)\n\n    X = []\n    for fmin, fmax in FREQ_BANDS.values():\n        psds_band = psds[:, :, (freqs >= fmin) & (freqs < fmax)].mean(axis=-1)\n        X.append(psds_band.reshape(len(psds), -1))\n\n    return np.concatenate(X, axis=1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Multiclass classification workflow using scikit-learn\n\nTo answer the question of how well we can predict the sleep stages of Bob\nfrom Alice's data while avoiding as much boilerplate code as possible, we\nwill take advantage of two key features of scikit-learn: `Pipeline`_ and\n`FunctionTransformer`_.\n\nA scikit-learn pipeline composes an estimator as a sequence of transforms\nand a final estimator, while the FunctionTransformer converts a Python\nfunction into an estimator-compatible object. In this manner, we can create\na scikit-learn estimator that takes :class:`mne.Epochs` as input, thanks to\nthe ``eeg_power_band`` function we just created.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "pipe = make_pipeline(\n    FunctionTransformer(eeg_power_band, validate=False),\n    RandomForestClassifier(n_estimators=100, random_state=42),\n)\n\n# Train\ny_train = epochs_train.events[:, 2]\npipe.fit(epochs_train, y_train)\n\n# Test\ny_pred = pipe.predict(epochs_test)\n\n# Assess the results\ny_test = epochs_test.events[:, 2]\nacc = accuracy_score(y_test, y_pred)\n\nprint(f\"Accuracy score: {acc}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In short, yes. 
We can predict Bob's sleep stages from Alice's data.\n\n### Further analysis of the data\n\nWe can check the confusion matrix or the classification report.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "print(confusion_matrix(y_test, y_pred))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "print(classification_report(y_test, y_pred, target_names=event_id.keys()))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Exercise\n\nFetch 50 subjects from the Physionet database and run a 5-fold\ncross-validation, leaving a different group of 10 subjects out in the test\nset each time.\n\n## References\n.. footbibliography::\n\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.2" } }, "nbformat": 4, "nbformat_minor": 0 }