"""
.. _tut-sleep-stage-classif:

Sleep stage classification from polysomnography (PSG) data
==========================================================

.. note:: This code is taken from the analysis code used in
          :footcite:`ChambonEtAl2018`. If you reuse this code please consider
          citing this work.

This tutorial explains how to perform a toy polysomnography analysis that
answers the following question:

.. important:: Given two subjects from the Sleep Physionet dataset
               :footcite:`KempEtAl2000,GoldbergerEtAl2000`, namely
               *Alice* and *Bob*, how well can we predict the sleep stages of
               *Bob* from *Alice's* data?

This problem is tackled as a supervised multiclass classification task. The
aim is to predict the sleep stage from 5 possible stages for each chunk of 30
seconds of data.

.. _Pipeline: https://scikit-learn.org/stable/modules/generated/sklearn.pipeline.Pipeline.html
.. _FunctionTransformer: https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.FunctionTransformer.html
.. _physionet_labels: https://physionet.org/physiobank/database/sleep-edfx/#sleep-cassette-study-and-data
"""  # noqa: E501

# Authors: Alexandre Gramfort
#          Stanislas Chambon
#          Joan Massich
#
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

# %%

import matplotlib.pyplot as plt
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import FunctionTransformer

import mne
from mne.datasets.sleep_physionet.age import fetch_data

##############################################################################
# Load the data
# -------------
#
# Here we download the data of two subjects. The end goal is to obtain
# :term:`epochs` and the associated ground truth.
#
# MNE-Python provides us with
# :func:`mne.datasets.sleep_physionet.age.fetch_data` to conveniently download
# data from the Sleep Physionet dataset
# :footcite:`KempEtAl2000,GoldbergerEtAl2000`.
# Given a list of subjects and records, the fetcher downloads the data and
# provides us with a pair of files for each subject:
#
# * ``-PSG.edf`` containing the polysomnography. The :term:`raw` data from the
#   EEG helmet,
# * ``-Hypnogram.edf`` containing the :term:`annotations` recorded by an
#   expert.
#
# Combining these two in a :class:`mne.io.Raw` object will allow us to extract
# :term:`events` based on the descriptions of the annotations to obtain the
# :term:`epochs`.
#
# Read the PSG data and Hypnograms to create a raw object
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

ALICE, BOB = 0, 1

[alice_files, bob_files] = fetch_data(subjects=[ALICE, BOB], recording=[1])

# Each subject's file pair is used as (PSG recording, hypnogram annotations).
raw_train = mne.io.read_raw_edf(
    alice_files[0],
    stim_channel="Event marker",
    infer_types=True,
    preload=True,
    verbose="error",  # ignore issues with stored filter settings
)
annot_train = mne.read_annotations(alice_files[1])

raw_train.set_annotations(annot_train, emit_warning=False)

# plot some data
# scalings were chosen manually to allow for simultaneous visualization of
# different channel types in this specific dataset
raw_train.plot(
    start=60,
    duration=60,
    scalings=dict(eeg=1e-4, resp=1e3, eog=1e-4, emg=1e-7, misc=1e-1),
)

##############################################################################
# Extract 30s events from annotations
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# The Sleep Physionet dataset is annotated using
# `8 labels <physionet_labels_>`_:
# Wake (W), Stage 1, Stage 2, Stage 3, Stage 4 corresponding to the range from
# light sleep to deep sleep, REM sleep (R) where REM is the abbreviation for
# Rapid Eye Movement sleep, movement (M), and Stage (?) for any non-scored
# segment.
#
# We will work only with 5 stages: Wake (W), Stage 1, Stage 2, Stage 3/4, and
# REM sleep (R). To do so, we use the ``event_id`` parameter in
# :func:`mne.events_from_annotations` to select which events we are
# interested in and we associate an event identifier to each of them.
#
# Moreover, the recordings contain long awake (W) regions before and after each
# night. To limit the impact of class imbalance, we trim each recording by only
# keeping 30 minutes of wake time before the first occurrence and 30 minutes
# after the last occurrence of sleep stages.

# Duration (in seconds) of each chunk to classify, and amount of wake time (in
# seconds) kept before the first and after the last sleep stage. Both values
# are shared by the training (Alice) and test (Bob) pipelines below so the two
# cannot silently diverge.
EPOCH_DURATION = 30.0
WAKE_MARGIN = 30 * 60

annotation_desc_2_event_id = {
    "Sleep stage W": 1,
    "Sleep stage 1": 2,
    "Sleep stage 2": 3,
    "Sleep stage 3": 4,
    "Sleep stage 4": 4,
    "Sleep stage R": 5,
}

# keep last 30-min wake events before sleep and first 30-min wake events after
# sleep and redefine annotations on raw data.
# NOTE(review): indices 1 and -2 assume a single leading wake annotation and a
# single trailing annotation after the last sleep stage; verify this before
# reusing the code on other recordings.
annot_train.crop(
    annot_train[1]["onset"] - WAKE_MARGIN,
    annot_train[-2]["onset"] + WAKE_MARGIN,
)
raw_train.set_annotations(annot_train, emit_warning=False)

events_train, _ = mne.events_from_annotations(
    raw_train, event_id=annotation_desc_2_event_id, chunk_duration=EPOCH_DURATION
)

# create a new event_id that unifies stages 3 and 4
event_id = {
    "Sleep stage W": 1,
    "Sleep stage 1": 2,
    "Sleep stage 2": 3,
    "Sleep stage 3/4": 4,
    "Sleep stage R": 5,
}

# plot events
fig = mne.viz.plot_events(
    events_train,
    event_id=event_id,
    sfreq=raw_train.info["sfreq"],
    first_samp=events_train[0, 0],
)

# keep the color-code for further plotting
stage_colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]

##############################################################################
# Create Epochs from the data based on the events found in the annotations
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

# tmax is inclusive, so stop one sample before the end of the chunk
tmax = EPOCH_DURATION - 1.0 / raw_train.info["sfreq"]

epochs_train = mne.Epochs(
    raw=raw_train,
    events=events_train,
    event_id=event_id,
    tmin=0.0,
    tmax=tmax,
    baseline=None,
)

del raw_train

print(epochs_train)

##############################################################################
# Applying the same steps to the test data from Bob
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

raw_test = mne.io.read_raw_edf(
    bob_files[0],
    stim_channel="Event marker",
    infer_types=True,
    preload=True,
    verbose="error",  # ignore issues with stored filter settings
)
annot_test = mne.read_annotations(bob_files[1])
annot_test.crop(
    annot_test[1]["onset"] - WAKE_MARGIN,
    annot_test[-2]["onset"] + WAKE_MARGIN,
)
raw_test.set_annotations(annot_test, emit_warning=False)
events_test, _ = mne.events_from_annotations(
    raw_test, event_id=annotation_desc_2_event_id, chunk_duration=EPOCH_DURATION
)
epochs_test = mne.Epochs(
    raw=raw_test,
    events=events_test,
    event_id=event_id,
    tmin=0.0,
    tmax=tmax,
    baseline=None,
)

del raw_test

print(epochs_test)

##############################################################################
# Feature Engineering
# -------------------
#
# Observing the power spectral density (PSD) plot of the :term:`epochs` grouped
# by sleeping stage we can see that different sleep stages have different
# signatures. These signatures remain similar between Alice and Bob's data.
#
# In the rest of this section, we will create EEG features based on relative
# power in specific frequency bands to capture this difference between the
# sleep stages in our data.

# visualize Alice vs. Bob PSD by sleep stage.
fig, (ax1, ax2) = plt.subplots(ncols=2)

# iterate over the subjects
stages = sorted(event_id.keys())
for ax, title, epochs in zip([ax1, ax2], ["Alice", "Bob"], [epochs_train, epochs_test]):
    for stage, color in zip(stages, stage_colors):
        spectrum = epochs[stage].compute_psd(fmin=0.1, fmax=20.0)
        spectrum.plot(
            ci=None,
            color=color,
            axes=ax,
            show=False,
            average=True,
            amplitude=False,
            spatial_colors=False,
            picks="data",
            exclude="bads",
        )
    ax.set(title=title, xlabel="Frequency (Hz)")
ax1.set(ylabel="µV²/Hz (dB)")
ax2.legend(ax2.lines[2::3], stages)

##############################################################################
# Design a scikit-learn transformer from a Python function
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# We will now create a function to extract EEG features based on relative power
# in specific frequency bands to be able to predict sleep stages from EEG
# signals.


def eeg_power_band(epochs, freq_bands=None):
    """EEG relative power band feature extraction.

    This function takes an ``mne.Epochs`` object and creates EEG features based
    on relative power in specific frequency bands that are compatible with
    scikit-learn.

    Parameters
    ----------
    epochs : Epochs
        The data.
    freq_bands : dict of str -> (float, float) | None
        Frequency bands as ``{name: (fmin, fmax)}`` in Hz. Each band covers
        ``fmin <= f < fmax``. The PSD is computed (and normalized) over the
        range spanned by all bands. If None (default), the delta, theta,
        alpha, sigma and beta bands over 0.5-30 Hz are used.

    Returns
    -------
    X : numpy array of shape [n_samples, n_bands * n_channels]
        Transformed data. Columns are grouped by band (in ``freq_bands``
        order), then by EEG channel. With the default bands, ``n_bands`` is 5.

    Raises
    ------
    ValueError
        If ``freq_bands`` is empty or a band contains no PSD frequency bin.
    """
    if freq_bands is None:
        # specific frequency bands
        freq_bands = {
            "delta": (0.5, 4.5),
            "theta": (4.5, 8.5),
            "alpha": (8.5, 11.5),
            "sigma": (11.5, 15.5),
            "beta": (15.5, 30.0),
        }
    if not freq_bands:
        raise ValueError("freq_bands must contain at least one frequency band.")

    # Compute the PSD only over the range covered by the bands (0.5-30 Hz for
    # the default bands).
    psd_fmin = min(band_fmin for band_fmin, _ in freq_bands.values())
    psd_fmax = max(band_fmax for _, band_fmax in freq_bands.values())
    spectrum = epochs.compute_psd(picks="eeg", fmin=psd_fmin, fmax=psd_fmax)
    psds, freqs = spectrum.get_data(return_freqs=True)

    # Normalize the PSDs so each epoch/channel spectrum sums to 1 (relative power)
    psds /= np.sum(psds, axis=-1, keepdims=True)

    X = []
    for name, (fmin, fmax) in freq_bands.items():
        band_mask = (freqs >= fmin) & (freqs < fmax)
        if not band_mask.any():
            # averaging an empty selection would silently produce NaN features
            raise ValueError(
                f"No PSD frequency bin falls in the {name!r} band "
                f"[{fmin}, {fmax}) Hz."
            )
        psds_band = psds[:, :, band_mask].mean(axis=-1)
        X.append(psds_band.reshape(len(psds), -1))

    return np.concatenate(X, axis=1)


##############################################################################
# Multiclass classification workflow using scikit-learn
# -----------------------------------------------------
#
# To answer the question of how well we can predict the sleep stages of Bob
# from Alice's data and avoid as much boilerplate code as possible, we will
# take advantage of two key features of scikit-learn: `Pipeline`_ and
# `FunctionTransformer`_.
#
# A scikit-learn pipeline composes an estimator as a sequence of transforms
# and a final estimator, while the FunctionTransformer converts a Python
# function into an estimator-compatible object. In this manner we can create a
# scikit-learn estimator that takes :class:`mne.Epochs` thanks to the
# ``eeg_power_band`` function we just created.

pipe = make_pipeline(
    FunctionTransformer(eeg_power_band, validate=False),
    RandomForestClassifier(n_estimators=100, random_state=42),
)

# Train
y_train = epochs_train.events[:, 2]
pipe.fit(epochs_train, y_train)

# Test
y_pred = pipe.predict(epochs_test)

# Assess the results
y_test = epochs_test.events[:, 2]
acc = accuracy_score(y_test, y_pred)

print(f"Accuracy score: {acc}")

##############################################################################
# In short, yes. We can predict Bob's sleeping stages based on Alice's data.
#
# Further analysis of the data
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# We can check the confusion matrix or the classification report.
# Pass the event ids explicitly so the matrix rows/columns and the report rows
# follow the order of ``event_id``, and the stage names stay aligned with their
# ids even if a stage is absent from ``y_test`` or ``y_pred``.
stage_ids = list(event_id.values())
stage_names = list(event_id.keys())

print(confusion_matrix(y_test, y_pred, labels=stage_ids))

##############################################################################
#

print(
    classification_report(
        y_test, y_pred, labels=stage_ids, target_names=stage_names
    )
)

##############################################################################
# Exercise
# --------
#
# Fetch 50 subjects from the Physionet database and run a 5-fold
# cross-validation leaving each time 10 subjects out in the test set.
#
# References
# ----------
# .. footbibliography::