{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "
\n", "\n", " \n", "## [mlcourse.ai](https://mlcourse.ai) – Open Machine Learning Course \n", "###
AEGo, ODS Slack nickname: AEGo\n", " \n", "##
Individual data analysis project" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Research plan**\n", " - Dataset and features description\n", " - Exploratory data analysis\n", " - Visual analysis of the features\n", " - Patterns, insights, peculiarities of data\n", " - Data preprocessing\n", " - Feature engineering and description\n", " - Cross-validation, hyperparameter tuning\n", " - Validation and learning curves\n", " - Prediction for hold-out and test samples\n", " - Model evaluation with metrics description\n", " - Conclusions" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "import json\n", "from sklearn.model_selection import StratifiedKFold, cross_val_score, GridSearchCV\n", "from sklearn.metrics import mean_absolute_error\n", "from sklearn.metrics import classification_report, confusion_matrix \n", "from sklearn.metrics import accuracy_score, precision_score, balanced_accuracy_score\n", "from sklearn.metrics import label_ranking_average_precision_score, label_ranking_loss, coverage_error\n", "from scipy.sparse import csr_matrix, hstack\n", "from scipy.stats import probplot\n", "import pickle\n", "import matplotlib.pyplot as plt\n", "get_ipython().run_line_magic('matplotlib', 'inline')\n", "import seaborn as sns\n", "import gc\n", "import warnings\n", "warnings.filterwarnings('ignore')\n", "import time\n", "import random\n", "import itertools\n", "from scipy.signal import resample\n", "from sklearn.utils import shuffle\n", "from sklearn.preprocessing import OneHotEncoder,StandardScaler\n", "from sklearn.linear_model import LogisticRegression\n", "from sklearn.model_selection import train_test_split\n", "\n", "from keras.models import Model\n", "from keras.layers import Input, Dense, Conv1D, MaxPooling1D, Softmax, Add, Flatten, Activation\n", "from keras import backend as K\n", "from keras.optimizers import Adam\n", "from keras.callbacks import 
LearningRateScheduler, ModelCheckpoint\n", "from keras.wrappers.scikit_learn import KerasClassifier\n", "\n", "from time import gmtime, strftime\n", "from scipy.signal import butter, lfilter\n", "import pywt\n", "from scipy.signal import medfilt\n", "\n", "import math\n", "import os\n", "\n", "from numpy.random import seed\n", "from tensorflow import set_random_seed\n", "\n", "sns.set(style=\"darkgrid\")\n", "\n", "import tensorflow as tf\n", "\n", "K.clear_session()\n", "\n", "session_conf = tf.ConfigProto(intra_op_parallelism_threads=1,\n", " inter_op_parallelism_threads=1)\n", "\n", "# The below tf.set_random_seed() will make random number generation\n", "# in the TensorFlow backend have a well-defined initial state.\n", "# For further details, see:\n", "# https://www.tensorflow.org/api_docs/python/tf/set_random_seed\n", "\n", "tf.set_random_seed(1234)\n", "\n", "sess = tf.Session(graph=tf.get_default_graph(), config=session_conf)\n", "K.set_session(sess)\n", "\n", "seed(17)\n", "set_random_seed(17)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Part 1. Dataset and features description" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[ECG Heartbeat Categorization Dataset](https://www.kaggle.com/shayanfazeli/heartbeat)\n", "\n", "This dataset is composed of two collections of heartbeat signals derived from two famous datasets in heartbeat classification, the MIT-BIH Arrhythmia Dataset and The PTB Diagnostic ECG Database. \n", "\n", "The signals correspond to electrocardiogram (ECG) shapes of heartbeats for the normal case and the cases affected by different arrhythmias and myocardial infarction. 
These signals are preprocessed and segmented, with each segment corresponding to a heartbeat.\n", "\n", "**Content**\n", "\n", "**Arrhythmia Dataset**\n", "\n", "| Physionet's MIT-BIH Arrhythmia Dataset | |\n", "|-|-|\n", "| Number of Samples | 109446 |\n", "| Number of Categories | 5 |\n", "| Sampling Frequency | 125 Hz |\n", "| Classes | {'N': 0, 'S': 1, 'V': 2, 'F': 3, 'Q': 4} |\n", "\n", "| Class | Description |\n", "|-|-|\n", "| N | Normal |\n", "| S | Supraventricular premature beat |\n", "| V | Premature ventricular contraction |\n", "| F | Fusion of ventricular and normal beat |\n", "| Q | Unclassifiable beat |\n", "\n", " \n", "| Physionet's PTB Diagnostic Database | |\n", "|-|-|\n", "| Number of Samples | 14552 |\n", "| Number of Categories | 2 |\n", "| Sampling Frequency | 125 Hz |\n", "\n", "Remark: All the samples are cropped, downsampled and, if necessary, padded with zeroes to the fixed dimension of **188** (187 signal values plus the class label in the last column).\n", "\n", "An electrocardiogram (**ECG**) is a graphical representation of the electrical activity of the heart and is commonly used for cardiovascular disease diagnosis.\n", "\n", "**Data Files**\n", "\n", "This dataset consists of a series of CSV files. Each of these CSV files contains a matrix, with each row representing an example in that portion of the dataset. The final element of each row denotes the class to which that example belongs." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "PATH = \"../../data\"\n", "signal_frequency = 125" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_df = pd.read_csv(os.path.join(PATH, 'mitbih_train.csv'), header=None)\n", "test_df = pd.read_csv(os.path.join(PATH, 'mitbih_test.csv'), header=None)\n", "full_df = pd.concat([train_df, test_df], axis=0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Part 2. 
Exploratory data analysis" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_df.shape, test_df.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "full_df.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "There are 109446 samples" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "full_df.info()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "##### There are 87554 train samples" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_df[187].value_counts(normalize=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_df.info()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "##### There are 21892 test samples" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_df.info()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_df[187].value_counts(normalize=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "M = full_df.values\n", "X = M[:, :-1]\n", "y = M[:, -1].astype(int)\n", "\n", "M_train = train_df.values\n", "XX_train = M_train[:, :-1]\n", "yy_train = M_train[:, -1].astype(int)\n", "\n", "M_test = test_df.values\n", "XX_test = M_test[:, :-1]\n", "yy_test = M_test[:, -1].astype(int)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Delete unused variable" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "del M, M_train, M_test" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Part 3. 
Visual analysis of the features" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def plot_count(data, x, ax, title='Data', xlabel='Class', tick_labels=('N', 'S', 'V', 'F', 'Q')):\n", "    # Bar chart of class counts; each bar is annotated with its share of all samples.\n", "    ax1 = sns.countplot(x=x, data=data, ax=ax)\n", "    ncount = len(data)\n", "    for p in ax1.patches:\n", "        bar_x = p.get_bbox().get_points()[:, 0]  # left and right edges of the bar\n", "        bar_y = p.get_bbox().get_points()[1, 1]  # bar height, i.e. the class count\n", "        ax1.annotate('{:.1f}%'.format(100. * bar_y / ncount), (bar_x.mean(), bar_y),\n", "                     ha='center', va='bottom')  # set the alignment of the text\n", "    ax.set_xticklabels(list(tick_labels))\n", "    ax.set_title(title)\n", "    ax.set_xlabel(xlabel)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Distribution of all data" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(12, 5))\n", "plot_count(full_df, 187, ax, title='All Data')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Distribution of train/test data" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(16, 6))\n", "plot_count(train_df, 187, ax[0], title='Train')\n", "plot_count(test_df, 187, ax[1], title='Test')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "C0 = np.argwhere(y == 0).flatten()\n", "C1 = np.argwhere(y == 1).flatten()\n", "C2 = np.argwhere(y == 2).flatten()\n", "C3 = np.argwhere(y == 3).flatten()\n", "C4 = np.argwhere(y == 4).flatten()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x = np.arange(0, 187)*8/1000  # time axis in seconds (8 ms per sample at 125 Hz)\n", "\n", "plt.figure(figsize=(20,8))\n", "plt.plot(x, X[C0, :][0], label=\"Cat. N\")\n", "plt.plot(x, X[C1, :][0], label=\"Cat. S\")\n", "plt.plot(x, X[C2, :][0], label=\"Cat. V\")\n", "plt.plot(x, X[C3, :][0], label=\"Cat. F\")\n", "plt.plot(x, X[C4, :][0], label=\"Cat. 
Q\")\n", "plt.legend()\n", "plt.title(\"1-beat ECG for every category\", fontsize=20)\n", "plt.ylabel(\"Amplitude\", fontsize=15)\n", "plt.grid(True)\n", "plt.xlabel(\"Time (s)\", fontsize=15)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Part 4. Patterns, insights, peculiarities of data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Electrocardiogram (ECG) is a graphical representation of the electric activity of the heart and has been commonly used for cardiovascular disease diagnosis." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Some useful functions (plot data, filter data, detect R-peaks, compute the heart rate (HR))**:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from scipy import signal\n", "from scipy.signal import butter, lfilter\n", "\n", "def butter_bandpass(lowcut, highcut, fs, order=5):\n", "    # Normalize the cut-off frequencies by the Nyquist frequency (fs / 2).\n", "    nyq = 0.5 * fs\n", "    low = lowcut / nyq\n", "    high = highcut / nyq\n", "    b, a = butter(order, [low, high], btype='band')\n", "    return b, a\n", "\n", "\n", "def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):\n", "    b, a = butter_bandpass(lowcut, highcut, fs, order=order)\n", "    y = lfilter(b, a, data)\n", "    return y\n", "\n", "\n", "def create_fir_filter(cutoff_hz, fs, n=0, window=\"hamming\"):\n", "    nyq = 0.5 * fs  # Nyquist frequency\n", "    if n == 0:\n", "        n = int(nyq)  # default number of taps\n", "    # Low-pass FIR filter; the cut-off is normalized by the Nyquist frequency.\n", "    return signal.firwin(n, cutoff=cutoff_hz / nyq, window=window)\n", "\n", "\n", "def fir_filter(fir, data):\n", "    return lfilter(fir, 1.0, data)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# code from https://github.com/c-labpl/qrs_detector\n", " \n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from time import gmtime, strftime\n", "from scipy.signal import butter, lfilter\n", "\n", "\n", "LOG_DIR = \"logs/\"\n", "PLOT_DIR = \"plots/\"\n", "\n", "\n", "class QRSDetectorOffline(object):\n", "    \"\"\"\n", "    Python Offline 
ECG QRS Detector based on the Pan-Tomkins algorithm.\n", " \n", " Michał Sznajder (Jagiellonian University) - technical contact (msznajder@gmail.com)\n", " Marta Łukowska (Jagiellonian University)\n", "\n", "\n", " The module is offline Python implementation of QRS complex detection in the ECG signal based\n", " on the Pan-Tomkins algorithm: Pan J, Tompkins W.J., A real-time QRS detection algorithm,\n", " IEEE Transactions on Biomedical Engineering, Vol. BME-32, No. 3, March 1985, pp. 230-236.\n", "\n", " The QRS complex corresponds to the depolarization of the right and left ventricles of the human heart. It is the most visually obvious part of the ECG signal. QRS complex detection is essential for time-domain ECG signal analyses, namely heart rate variability. It makes it possible to compute inter-beat interval (RR interval) values that correspond to the time between two consecutive R peaks. Thus, a QRS complex detector is an ECG-based heart contraction detector.\n", "\n", " Offline version detects QRS complexes in a pre-recorded ECG signal dataset (e.g. stored in .csv format).\n", "\n", " This implementation of a QRS Complex Detector is by no means a certified medical tool and should not be used in health monitoring. 
It was created and used for experimental purposes in psychophysiology and psychology.\n", "\n", " You can find more information in module documentation:\n", " https://github.com/c-labpl/qrs_detector\n", "\n", " If you use these modules in a research project, please consider citing it:\n", " https://zenodo.org/record/583770\n", "\n", " If you use these modules in any other project, please refer to MIT open-source license.\n", "\n", "\n", " MIT License\n", "\n", " Copyright (c) 2017 Michał Sznajder, Marta Łukowska\n", "\n", " Permission is hereby granted, free of charge, to any person obtaining a copy\n", " of this software and associated documentation files (the \"Software\"), to deal\n", " in the Software without restriction, including without limitation the rights\n", " to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n", " copies of the Software, and to permit persons to whom the Software is\n", " furnished to do so, subject to the following conditions:\n", "\n", " The above copyright notice and this permission notice shall be included in all\n", " copies or substantial portions of the Software.\n", "\n", " THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n", " IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n", " FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE\n", " AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n", " LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n", " OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n", " SOFTWARE.\n", " \"\"\"\n", "\n", " def __init__(self, ecg_data_path, ecg_data_raw=None, signal_frequency=125,\n", " filter_lowcut=0.1, filter_highcut=15.0, filter_order=1,\n", " verbose=True, log_data=False, plot_data=False, show_plot=False,\n", " save_plot=False):\n", " \"\"\"\n", " QRSDetectorOffline class initialisation method.\n", " :param string ecg_data_path: path to the ECG dataset\n", " :param bool verbose: flag for printing the results\n", " :param bool log_data: flag for logging the results\n", " :param bool plot_data: flag for plotting the results to a file\n", " :param bool show_plot: flag for showing generated results plot - will not show anything if plot is not generated\n", " \"\"\"\n", " # Configuration parameters.\n", " self.ecg_data_path = ecg_data_path\n", "\n", " self.signal_frequency = signal_frequency # Set ECG device frequency in samples per second here.\n", "\n", " self.filter_lowcut = filter_lowcut\n", " self.filter_highcut = filter_highcut\n", " self.filter_order = filter_order\n", "\n", " self.integration_window = 8 #15 # Change proportionally when adjusting frequency (in samples).\n", "\n", " self.findpeaks_limit = 0.35\n", " self.findpeaks_spacing = 25 #50 # Change proportionally when adjusting frequency (in samples).\n", "\n", " self.refractory_period = 60 #120 # Change proportionally when adjusting frequency (in samples).\n", " self.qrs_peak_filtering_factor = 0.125\n", " self.noise_peak_filtering_factor = 0.125\n", " self.qrs_noise_diff_weight = 0.25\n", "\n", " # Loaded ECG data.\n", " self.ecg_data_raw = None\n", "\n", " # Measured and calculated values.\n", " self.filtered_ecg_measurements = None\n", " self.differentiated_ecg_measurements = None\n", " 
self.squared_ecg_measurements = None\n", " self.integrated_ecg_measurements = None\n", " self.detected_peaks_indices = None\n", " self.detected_peaks_values = None\n", "\n", " self.qrs_peak_value = 0.0\n", " self.noise_peak_value = 0.0\n", " self.threshold_value = 0.0\n", "\n", " # Detection results.\n", " self.qrs_peaks_indices = np.array([], dtype=int)\n", " self.noise_peaks_indices = np.array([], dtype=int)\n", "\n", " # Final ECG data and QRS detection results array - samples with detected QRS are marked with 1 value.\n", " self.ecg_data_detected = None\n", "\n", " # Run whole detector flow.\n", " if ecg_data_raw is not None:\n", " self.ecg_data_raw = ecg_data_raw \n", " else: \n", " self.load_ecg_data()\n", " self.detect_peaks()\n", " self.detect_qrs()\n", "\n", " if verbose:\n", " self.print_detection_data()\n", "\n", " if log_data:\n", " self.log_path = \"{:s}QRS_offline_detector_log_{:s}.csv\".format(LOG_DIR,\n", " strftime(\"%Y_%m_%d_%H_%M_%S\", gmtime()))\n", " self.log_detection_data()\n", "\n", " if plot_data:\n", " self.plot_path = \"{:s}QRS_offline_detector_plot_{:s}.png\".format(PLOT_DIR,\n", " strftime(\"%Y_%m_%d_%H_%M_%S\", gmtime()))\n", " self.plot_detection_data(show_plot=show_plot, save_plot=save_plot)\n", "\n", " \"\"\"Loading ECG measurements data methods.\"\"\"\n", "\n", " def load_ecg_data(self):\n", " \"\"\"\n", " Method loading ECG data set from a file.\n", " \"\"\"\n", " self.ecg_data_raw = np.loadtxt(self.ecg_data_path, skiprows=1, delimiter=',')\n", "\n", " \"\"\"ECG measurements data processing methods.\"\"\"\n", "\n", " def detect_peaks(self):\n", " \"\"\"\n", " Method responsible for extracting peaks from loaded ECG measurements data through measurements processing.\n", " \"\"\"\n", " # Extract measurements from loaded ECG data.\n", " ecg_measurements = self.ecg_data_raw#[:, 1]\n", "\n", " # Measurements filtering - 0-15 Hz band pass filter.\n", " self.filtered_ecg_measurements = self.bandpass_filter(ecg_measurements.flatten(), \n", 
" lowcut=self.filter_lowcut,\n", " highcut=self.filter_highcut, \n", " signal_freq=self.signal_frequency,\n", " filter_order=self.filter_order)\n", "# self.filtered_ecg_measurements[:5] = self.filtered_ecg_measurements[5]\n", "\n", " # Derivative - provides QRS slope information.\n", " self.differentiated_ecg_measurements = np.ediff1d(self.filtered_ecg_measurements)\n", "\n", " # Squaring - intensifies values received in derivative.\n", " self.squared_ecg_measurements = self.differentiated_ecg_measurements ** 2\n", "\n", " # Moving-window integration.\n", " self.integrated_ecg_measurements = np.convolve(\n", " self.squared_ecg_measurements, \n", " np.ones(self.integration_window))\n", "\n", " # Fiducial mark - peak detection on integrated measurements.\n", " self.detected_peaks_indices = self.findpeaks(data=self.integrated_ecg_measurements,\n", " limit=self.findpeaks_limit,\n", " spacing=self.findpeaks_spacing)\n", "\n", " self.detected_peaks_values = self.integrated_ecg_measurements[self.detected_peaks_indices]\n", "\n", " \"\"\"QRS detection methods.\"\"\"\n", "\n", " def detect_qrs(self):\n", " \"\"\"\n", " Method responsible for classifying detected ECG measurements peaks either as noise or as QRS complex (heart beat).\n", " \"\"\"\n", " if self.detected_peaks_indices is None:\n", " return;\n", " for detected_peak_index, detected_peaks_value in zip(self.detected_peaks_indices, self.detected_peaks_values):\n", "\n", " try:\n", " last_qrs_index = self.qrs_peaks_indices[-1]\n", " except IndexError:\n", " last_qrs_index = 0\n", "\n", " # After a valid QRS complex detection, there is a 200 ms refractory period before next one can be detected.\n", " if detected_peak_index - last_qrs_index > self.refractory_period or not self.qrs_peaks_indices.size:\n", " # Peak must be classified either as a noise peak or a QRS peak.\n", " # To be classified as a QRS peak it must exceed dynamically set threshold value.\n", " if detected_peaks_value > self.threshold_value:\n", " 
self.qrs_peaks_indices = np.append(self.qrs_peaks_indices, detected_peak_index)\n", "\n", " # Adjust QRS peak value used later for setting QRS-noise threshold.\n", " self.qrs_peak_value = self.qrs_peak_filtering_factor * detected_peaks_value + \\\n", " (1 - self.qrs_peak_filtering_factor) * self.qrs_peak_value\n", " else:\n", " self.noise_peaks_indices = np.append(self.noise_peaks_indices, detected_peak_index)\n", "\n", " # Adjust noise peak value used later for setting QRS-noise threshold.\n", " self.noise_peak_value = self.noise_peak_filtering_factor * detected_peaks_value + \\\n", " (1 - self.noise_peak_filtering_factor) * self.noise_peak_value\n", "\n", " # Adjust QRS-noise threshold value based on previously detected QRS or noise peaks value.\n", " self.threshold_value = self.noise_peak_value + \\\n", " self.qrs_noise_diff_weight * (self.qrs_peak_value - self.noise_peak_value)\n", "\n", " # Create array containing both input ECG measurements data and QRS detection indication column.\n", " # We mark QRS detection with '1' flag in 'qrs_detected' log column ('0' otherwise).\n", " measurement_qrs_detection_flag = np.zeros([len(self.ecg_data_raw), 1])\n", " measurement_qrs_detection_flag[self.qrs_peaks_indices] = 1\n", " self.ecg_data_detected = np.append(self.ecg_data_raw, measurement_qrs_detection_flag, 1)\n", "\n", " \"\"\"Results reporting methods.\"\"\"\n", "\n", " def print_detection_data(self):\n", " \"\"\"\n", " Method responsible for printing the results.\n", " \"\"\"\n", " print(\"qrs peaks indices\")\n", " print(self.qrs_peaks_indices)\n", " print(\"noise peaks indices\")\n", " print(self.noise_peaks_indices)\n", "\n", " def log_detection_data(self):\n", " \"\"\"\n", " Method responsible for logging measured ECG and detection results to a file.\n", " \"\"\"\n", " with open(self.log_path, \"wb\") as fin:\n", " fin.write(b\"timestamp,ecg_measurement,qrs_detected\\n\")\n", " np.savetxt(fin, self.ecg_data_detected, delimiter=\",\")\n", "\n", " def 
plot_detection_data(self, show_plot=False, save_plot=False):\n", " \"\"\"\n", " Method responsible for plotting detection results.\n", " :param bool show_plot: flag for plotting the results and showing plot\n", " \"\"\"\n", " def plot_data(axis, data, title='', fontsize=10):\n", " axis.set_title(title, fontsize=fontsize)\n", " axis.grid(which='both', axis='both', linestyle='--')\n", " axis.plot(data, color=\"salmon\", zorder=1)\n", "\n", " def plot_points(axis, values, indices):\n", " axis.scatter(x=indices, y=values[indices], c=\"black\", s=50, zorder=2)\n", "\n", " plt.close('all')\n", " fig, axarr = plt.subplots(6, sharex=True, figsize=(15, 18))\n", "\n", " plot_data(axis=axarr[0], data=self.ecg_data_raw, title='Raw ECG measurements')\n", " plot_data(axis=axarr[1], data=self.filtered_ecg_measurements, title='Filtered ECG measurements')\n", " plot_data(axis=axarr[2], data=self.differentiated_ecg_measurements, title='Differentiated ECG measurements')\n", " plot_data(axis=axarr[3], data=self.squared_ecg_measurements, title='Squared ECG measurements')\n", " plot_data(axis=axarr[4], data=self.integrated_ecg_measurements, title='Integrated ECG measurements with QRS peaks marked (black)')\n", " plot_points(axis=axarr[4], values=self.integrated_ecg_measurements, indices=self.qrs_peaks_indices)\n", " plot_data(axis=axarr[5], data=self.ecg_data_detected[:, 1], title='Raw ECG measurements with QRS peaks marked (black)')\n", " plot_points(axis=axarr[5], values=self.ecg_data_detected[:, 1], indices=self.qrs_peaks_indices)\n", "\n", " plt.tight_layout()\n", " if save_plot:\n", " fig.savefig(self.plot_path)\n", "\n", " if show_plot:\n", " plt.show()\n", "\n", " plt.close()\n", "\n", " \"\"\"Tools methods.\"\"\"\n", "\n", " def bandpass_filter(self, data, lowcut, highcut, signal_freq, filter_order):\n", " \"\"\"\n", " Method responsible for creating and applying Butterworth filter.\n", " :param deque data: raw data\n", " :param float lowcut: filter lowcut frequency value\n", " 
:param float highcut: filter highcut frequency value\n", " :param int signal_freq: signal frequency in samples per second (Hz)\n", " :param int filter_order: filter order\n", " :return array: filtered data\n", " \"\"\"\n", " nyquist_freq = 0.5 * signal_freq\n", " low = lowcut / nyquist_freq\n", " high = highcut / nyquist_freq\n", " b, a = butter(filter_order, [low, high], btype=\"band\")\n", " y = lfilter(b, a, data)\n", " return y\n", "\n", " def findpeaks(self, data, spacing=1, limit=None):\n", " \"\"\"\n", " Janko Slavic peak detection algorithm and implementation.\n", " https://github.com/jankoslavic/py-tools/tree/master/findpeaks\n", " Finds peaks in `data` which are of `spacing` width and >=`limit`.\n", " :param ndarray data: data\n", " :param float spacing: minimum spacing to the next peak (should be 1 or more)\n", " :param float limit: peaks should have value greater or equal\n", " :return array: detected peaks indexes array\n", " \"\"\"\n", " len = data.size\n", " x = np.zeros(len + 2 * spacing)\n", " x[:spacing] = data[0] - 1.e-6\n", " x[-spacing:] = data[-1] - 1.e-6\n", " x[spacing:spacing + len] = data\n", " peak_candidate = np.zeros(len)\n", " peak_candidate[:] = True\n", " for s in range(spacing):\n", " start = spacing - s - 1\n", " h_b = x[start: start + len] # before\n", " start = spacing\n", " h_c = x[start: start + len] # central\n", " start = spacing + s + 1\n", " h_a = x[start: start + len] # after\n", " peak_candidate = np.logical_and(peak_candidate, np.logical_and(h_c > h_b, h_c > h_a))\n", "\n", " ind = np.argwhere(peak_candidate)\n", " ind = ind.reshape(ind.size)\n", " if limit is not None:\n", " ind = ind[data[ind] > limit]\n", " return ind" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class FilterHelper():\n", " def __init__(self, \n", " signal_frequency=125,\n", " filter_lowcut=0.01, \n", " filter_highcut=50.0, \n", " filter_order=1,\n", " cutoff_hz=15.0,\n", " verbose=False, \n", " 
log_data=False, \n", " plot_data=False, \n", " show_plot=False,\n", " save_plot=False):\n", " self.signal_frequency = signal_frequency\n", " self.filter_lowcut = filter_lowcut\n", " self.filter_highcut = filter_highcut\n", " self.filter_order = filter_order\n", " self.cutoff_hz=cutoff_hz\n", " self.verbose=verbose \n", " self.log_data=log_data\n", " self.plot_data=plot_data\n", " self.show_plot=show_plot\n", " self.save_plot=save_plot\n", " \n", " def rr_peaks_detect(self, x):\n", " qrs_detector = QRSDetectorOffline(ecg_data_raw=x.reshape(-1,1), \n", " filter_lowcut=self.filter_lowcut,\n", " filter_highcut=self.filter_highcut,\n", " filter_order=self.filter_order,\n", " signal_frequency=self.signal_frequency,\n", " ecg_data_path=None, \n", " verbose=self.verbose,\n", " log_data=self.log_data, \n", " plot_data=self.plot_data, \n", " show_plot=self.show_plot)\n", " return qrs_detector\n", "\n", " def qrs_filter(self, x):\n", " qrs_detector = QRSDetectorOffline(ecg_data_raw=x.reshape(-1,1), \n", " filter_lowcut=self.filter_lowcut,\n", " filter_highcut=self.filter_highcut,\n", " filter_order=self.filter_order,\n", " signal_frequency=self.signal_frequency,\n", " ecg_data_path=None, \n", " verbose=self.verbose,\n", " log_data=self.log_data, \n", " plot_data=self.plot_data, \n", " show_plot=self.show_plot)\n", " return qrs_detector.filtered_ecg_measurements.flatten()\n", "\n", " def fir_filter(self, x):\n", " # The cutoff frequency of the filter.\n", " flt = create_fir_filter(self.cutoff_hz, self.signal_frequency)\n", " return fir_filter(flt, x)\n", "\n", " def bandpass_filter(self, x):\n", " return butter_bandpass_filter(\n", " x, \n", " self.filter_lowcut, \n", " self.filter_highcut, \n", " self.signal_frequency, \n", " order=self.filter_order)\n", " \n", " def compose_filters(self, x):\n", " return self.fir_filter(self.bandpass_filter(x))\n", " \n", "def calculate_heart_rate(qrs_detector, signal_frequency=125):\n", " sec_per_minute = 60\n", " hr = 0\n", " i1 = 
qrs_detector.qrs_peaks_indices[0] if len(qrs_detector.qrs_peaks_indices) > 0 else 0\n", "    i2 = qrs_detector.qrs_peaks_indices[1] if len(qrs_detector.qrs_peaks_indices) > 1 else 0\n", "    if len(qrs_detector.qrs_peaks_indices) > 1:\n", "        hr = signal_frequency / (i2 - i1) * sec_per_minute\n", "    return hr, i1, i2\n", "\n", "def plot_ecg(data, filter_helper=None, show_filtered=True, show_hr=True, title='ECG', ax=None):\n", "    if ax is not None:\n", "        plt.sca(ax)\n", "    else:\n", "        plt.figure(figsize=(20,6))\n", "\n", "    # Filtering and R-peak detection need a helper; fall back to the defaults.\n", "    if filter_helper is None and (show_filtered or show_hr):\n", "        filter_helper = FilterHelper()\n", "\n", "    if show_filtered:\n", "        # Assumption: band-pass followed by FIR low-pass (FilterHelper.compose_filters).\n", "        flt = filter_helper.compose_filters(data)\n", "\n", "    plt.plot(data, color='b', label='Original ECG')\n", "\n", "    if show_hr:\n", "        qrsdetector = filter_helper.rr_peaks_detect(data)\n", "        hr, i1, i2 = calculate_heart_rate(qrsdetector)\n", "        print('Heart Rate: {}, R1: {}, R2: {}'.format(hr, i1, i2))\n", "\n", "        # Mark the detected R-peaks on the original signal.\n", "        plt.scatter(x=qrsdetector.qrs_peaks_indices,\n", "                    y=data[qrsdetector.qrs_peaks_indices],\n", "                    c=\"black\",\n", "                    s=50,\n", "                    zorder=2)\n", "\n", "    if show_filtered:\n", "        plt.plot(flt, color='r', label='Filtered')\n", "    plt.title(title, fontsize=16)\n", "    plt.grid(True)\n", "    plt.legend()\n", "    if ax is None:\n", "        plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Examples of ECG for each class" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(16, 6))\n", "plot_ecg(X[np.argwhere(y==0)[0][0], :], ax=axes[0],\n", "         title='Normal ECG', show_filtered=False, show_hr=False)\n", "plot_ecg(X[np.argwhere(y==1)[0][0], :], ax=axes[1],\n", "         title='Class: Supraventricular premature beat', \n", "         show_filtered=False, show_hr=False)\n", "fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(16, 6))\n", "plot_ecg(X[np.argwhere(y==2)[0][0], :], ax=axes[0],\n", "         title='Class: Premature ventricular contraction', \n", "         show_filtered=False, show_hr=False)\n", "plot_ecg(X[np.argwhere(y==3)[0][0], 
:], ax=axes[1],\n", "         title='Class: Fusion of ventricular and normal beat', \n", "         show_filtered=False, show_hr=False)\n", "fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(16, 6))\n", "plot_ecg(X[np.argwhere(y==4)[0][0], :], ax=axes[0],\n", "         title='Class: Unclassifiable beat', \n", "         show_filtered=False, show_hr=False)\n", "fig.delaxes(axes[1])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Part 5. Data preprocessing" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The baseline of the signal is subtracted; some additional noise removal could also be done.\n", "\n", "Two median filters (kernel sizes 71 and 215 samples) are applied in sequence to estimate the baseline.\n", "Some ideas are taken from [this repository](https://github.com/mondejar/ecg-classification)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def baseline_ecg(ecg):\n", "    # Estimate the baseline with two consecutive median filters, then remove it\n", "    baseline = medfilt(ecg, 71)\n", "    baseline = medfilt(baseline, 215)\n", "    return ecg - baseline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Example of ECG with subtracted baseline:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "orig_ecg = X[C4, :][0]\n", "\n", "flt_bl = baseline_ecg(orig_ecg)\n", " \n", "plt.figure(figsize=(20,8))\n", "plt.plot(x, orig_ecg, label=\"Original\")\n", "plt.plot(x, flt_bl, label=\"Baselined\")\n", "plt.legend()\n", "plt.title(\"'Baselined' 1-beat ECG\", fontsize=20)\n", "plt.ylabel(\"Amplitude\", fontsize=15)\n", "plt.grid(True)\n", "plt.xlabel(\"Time (s)\", fontsize=15)\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "XX_train_bl = np.apply_along_axis(baseline_ecg, 1, XX_train)\n", "XX_test_bl = np.apply_along_axis(baseline_ecg, 1, XX_test)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "XX_train_bl.shape, XX_train.shape, XX_test.shape, XX_test_bl.shape" ] }, { "cell_type": 
"markdown", "metadata": {}, "source": [ "**Data augmentation**\n", "\n", "To train properly the model, we sould have to augment all data to the same level. Nevertheless, for a first try, we will just augment the smallest class to the same level as class 1.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def stretch(x):\n", " l = int(187 * (1 + (random.random()-0.5)/3))\n", " y = resample(x, l)\n", " if l < 187:\n", " y_ = np.zeros(shape=(187, ))\n", " y_[:l] = y\n", " else:\n", " y_ = y[:187]\n", " return y_\n", "\n", "def amplify(x):\n", " alpha = (random.random()-0.5)\n", " factor = -alpha*x + (1+alpha)\n", " return x*factor\n", "\n", "def augment(x):\n", " result = np.zeros(shape= (4, 187))\n", " for i in range(3):\n", " if random.random() < 0.33:\n", " new_y = stretch(x)\n", " elif random.random() < 0.66:\n", " new_y = amplify(x)\n", " else:\n", " new_y = stretch(x)\n", " new_y = amplify(new_y)\n", " result[i, :] = new_y\n", " return result" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plt.figure(figsize=(20,10))\n", "plt.plot(X[0, :], label='original')\n", "plt.plot(amplify(X[0, :]), label='amplify')\n", "plt.plot(stretch(X[0, :]), label='stretch')\n", "plt.title(\"'Amplified/Stretched' ECG\", fontsize=20)\n", "plt.ylabel(\"Amplitude\", fontsize=15)\n", "plt.xlabel(\"Time (ms)\", fontsize=15)\n", "plt.legend()\n", "plt.grid(True)\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "result = np.apply_along_axis(augment, axis=1, arr=X[C3]).reshape(-1, 187)\n", "classes = np.ones(shape=(result.shape[0],), dtype=int)*3\n", "X = np.vstack([X, result])\n", "y = np.hstack([y, classes])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "subC0 = np.random.choice(C0, 800)\n", "subC1 = np.random.choice(C1, 800)\n", "subC2 = np.random.choice(C2, 800)\n", "subC3 = 
np.random.choice(C3, 800)\n", "subC4 = np.random.choice(C4, 800)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "use800_test = False" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "if use800_test:\n", "    X_test = np.vstack([X[subC0], X[subC1], X[subC2], X[subC3], X[subC4]])\n", "    y_test = np.hstack([y[subC0], y[subC1], y[subC2], y[subC3], y[subC4]])\n", "\n", "    X_train = np.delete(X, [subC0, subC1, subC2, subC3, subC4], axis=0)\n", "    y_train = np.delete(y, [subC0, subC1, subC2, subC3, subC4], axis=0)\n", "else:\n", "    X_test_orig = XX_test\n", "    X_test = XX_test_bl #XX_test\n", "    y_test = yy_test\n", "\n", "    X_train_orig = XX_train\n", "    X_train = XX_train_bl #XX_train\n", "    y_train = yy_train\n", "\n", "X_train, y_train, X_train_orig = shuffle(X_train, y_train, X_train_orig, random_state=17)\n", "X_test, y_test, X_test_orig = shuffle(X_test, y_test, X_test_orig, random_state=17)\n", "\n", "\n", "del X\n", "del y\n", "del XX_train, XX_test, XX_train_bl, XX_test_bl" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### X_train_orig, X_test_orig - original (raw) data\n", "\n", "#### X_train, X_test - 'baselined' data" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ptrain_df = pd.DataFrame(y_train, columns=['class'])\n", "ptest_df = pd.DataFrame(y_test, columns=['class'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Class counts of train/test data" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(ptrain_df['class'].value_counts())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(ptest_df['class'].value_counts())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(16, 6))\n", 
"plot_count(ptrain_df, 'class', ax[0], title='Train')\n", "plot_count(ptest_df, 'class', ax[1], title='Test')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# delete unneccessary variables\n", "del ptrain_df\n", "del ptest_df" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Part 6. Feature engineering and description " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Filter ECG signal using Butterworth filter**\n", "\n", "The high-pass and low-pass filters together are known as a bandpass filter, literally allowing only a certain frequency band to pass through. We will use bandpass filter with cutoff=50." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "filter_helper = FilterHelper(\n", " signal_frequency=125,\n", " filter_lowcut=0.01, \n", " filter_highcut=50.0, \n", " filter_order=1,\n", " cutoff_hz=30.,\n", " verbose=False, \n", " log_data=False, \n", " plot_data=False, \n", " show_plot=False,\n", " save_plot=False)\n", "\n", "filter_function = filter_helper.bandpass_filter\n", "\n", "\n", "def calc_hr(data):\n", " qrsdetector = filter_helper.rr_peaks_detect(data)\n", " hr, i1, i2 = calculate_heart_rate(qrsdetector)\n", " return hr, i1, i2" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A typical ECG-based heartbeat mainly consists of three waves including P-wave, QRS complex ([wiki](https://en.wikipedia.org/wiki/QRS_complex)), and T-wave. The QRS complex is the most prominent feature and it can be used to obtain additional useful clinical information from ECG signals, such as RR interval, QT interval, and PR interval, etc. 
Thus, QRS detection is critical for ECG signal analysis and ECG-based health evaluation.\n", "\n", "For more detail, see [Electrocardiography](https://en.wikipedia.org/wiki/Electrocardiography).\n", "\n", "We use **QRSDetectorOffline** for detecting R-peaks (QRS)." ] }, { "attachments": { "%D0%B8%D0%B7%D0%BE%D0%B1%D1%80%D0%B0%D0%B6%D0%B5%D0%BD%D0%B8%D0%B5.png": { "image/png": "" } }, "cell_type": "markdown", "metadata": {}, "source": [ "![%D0%B8%D0%B7%D0%BE%D0%B1%D1%80%D0%B0%D0%B6%D0%B5%D0%BD%D0%B8%D0%B5.png](attachment:%D0%B8%D0%B7%D0%BE%D0%B1%D1%80%D0%B0%D0%B6%D0%B5%D0%BD%D0%B8%D0%B5.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Examples of **original/filtered** ECGs with **R-peaks** (QRS)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(16, 6))\n", "plot_ecg(X_train[np.argwhere(y_train==0)[0][0], :], filter_helper, ax=axes[0],\n", "         title='Normal ECG',\n", "         show_filtered=True, show_hr=True)\n", "plot_ecg(X_train[np.argwhere(y_train==1)[0][0], :], filter_helper, ax=axes[1],\n", "         title='Class: Supraventricular premature beat',\n", "         show_filtered=True, show_hr=True)\n", "fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(16, 6))\n", "plot_ecg(X_train[np.argwhere(y_train==2)[0][0], :], filter_helper, ax=axes[0],\n", "         title='Class: Premature ventricular contraction',\n", "         show_filtered=True, show_hr=True)\n", "plot_ecg(X_train[np.argwhere(y_train==3)[0][0], :], filter_helper, ax=axes[1],\n", "         title='Class: Fusion of ventricular and normal beat',\n", "         show_filtered=True, show_hr=True)\n", "fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(16, 6))\n", "plot_ecg(X_train[np.argwhere(y_train==4)[0][0], :], filter_helper, ax=axes[0],\n", "         title='Class: Unclassifiable beat',\n", "         show_filtered=True, 
show_hr=True)\n", "fig.delaxes(axes[1])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "X_train_hr = np.apply_along_axis(calc_hr, 1, X_train)\n", "X_test_hr = np.apply_along_axis(calc_hr, 1, X_test)\n", "\n", "scaler_hr = StandardScaler()\n", "X_train_hr = scaler_hr.fit_transform(X_train_hr)\n", "X_test_hr = scaler_hr.transform(X_test_hr)\n", "\n", "X_train_hr.shape, X_test_hr.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "X_train_filtered = np.apply_along_axis(filter_function, 1, X_train)\n", "X_test_filtered = np.apply_along_axis(filter_function, 1, X_test)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**X_train_filtered, X_test_filtered - filtered data**" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(\"X_train\", X_train.shape)\n", "print(\"X_train_filtered\", X_train_filtered.shape)\n", "print(\"y_train\", y_train.shape)\n", "print(\"X_test\", X_test.shape)\n", "print(\"X_test_filtered\", X_test_filtered.shape)\n", "print(\"y_test\", y_test.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Wavelets**\n", "\n", "Wavelet transforms extract information from both the frequency and time domains, which makes them suitable for describing ECG signals. Each beat is decomposed with `pywt.wavedec` using the db1 family and 3 levels." 
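, "\n", "\n", "As a rough sketch of what such a descriptor contains, assuming only the level-3 approximation coefficients are kept and ignoring boundary padding: for db1 (the Haar wavelet) each approximation level replaces a pair of neighbouring values by their sum divided by $\\sqrt{2}$, so after three levels each coefficient is a scaled sum over 8 consecutive samples of the beat $x$:\n", "\n", "$$a_3[k] = 2^{-3/2} \\sum_{n=8k}^{8k+7} x[n]$$\n", "\n", "For a 187-sample beat and the default PyWavelets padding, the lengths would shrink as 187, 94, 47, 24, so the descriptor holds 24 values: a coarse, downsampled view of the beat shape."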
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def compute_wavelet_descriptor(beat, family='db1', level=3):\n", "    \"\"\"Compute the wavelet descriptor of an ECG beat (coarsest-level approximation coefficients).\"\"\"\n", "    wave_family = pywt.Wavelet(family)\n", "    coeffs = pywt.wavedec(beat, wave_family, level=level)\n", "    return coeffs[0]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "X_train_wv_only = np.apply_along_axis(compute_wavelet_descriptor, 1, X_train_orig)\n", "X_test_wv_only = np.apply_along_axis(compute_wavelet_descriptor, 1, X_test_orig)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "scaler = StandardScaler()\n", "X_train_wv_only_scaled = scaler.fit_transform(X_train_wv_only)\n", "X_test_wv_only_scaled = scaler.transform(X_test_wv_only)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Heart rate feature" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "X_train_orig_hr = np.hstack([X_train_orig, X_train_hr])\n", "X_test_orig_hr = np.hstack([X_test_orig, X_test_hr])\n", "X_train_orig_hr.shape, X_test_orig_hr.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "X_train_wv_scaled = np.hstack([X_train, X_train_wv_only_scaled, X_train_hr])\n", "X_test_wv_scaled = np.hstack([X_test, X_test_wv_only_scaled, X_test_hr])\n", "X_train_wv_scaled.shape, X_test_wv_scaled.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "X_train_wv_filtered_scaled = np.hstack([X_train_filtered, X_train_wv_only_scaled, X_train_hr])\n", "X_test_wv_filtered_scaled = np.hstack([X_test_filtered, X_test_wv_only_scaled, X_test_hr])\n", "X_train_wv_filtered_scaled.shape, X_test_wv_filtered_scaled.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**All our datasets:**\n", "\n", "| Dataset | Type of data |\n", 
"|-|-|\n", "| X_train_orig, X_test_orig | Original(raw) data |\n", "| X_train, X_test | Baselined data |\n", "| X_train_orig_hr, X_test_orig_hr | Original data + HR |\n", "| X_train_filtered, X_test_filtered | Baselined & Filtered data |\n", "| X_train_wv_scaled, X_test_wv_scaled | Baselined data with wavelet coefficients + HR |\n", "| X_train_wv_filtered_scaled, X_test_wv_filtered_scaled | Baselined & Filtered data with wavelet coefficients + HR |" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "X_test_for_pred = X_test" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Part 7. Cross-validation, hyperparameter tuning" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Some usefull model creation/training functions:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def grid_clf(estimator, param_grid, Xtrain, ytrain, cv, scoring='accuracy'):\n", " \"\"\"\n", " CV using GridSearchCV for given estimator.\n", " \"\"\" \n", " grid_nn = GridSearchCV(estimator=estimator, scoring=scoring, param_grid=param_grid, cv=cv)\n", " grid_result_nn = grid_nn.fit(Xtrain, ytrain)\n", " \n", " # summarize results\n", " print(\"Best: %f using %s\" % (grid_result_nn.best_score_, grid_result_nn.best_params_))\n", " means = grid_result_nn.cv_results_['mean_test_score']\n", " stds = grid_result_nn.cv_results_['std_test_score']\n", " params = grid_result_nn.cv_results_['params']\n", " for mean, stdev, param in zip(means, stds, params):\n", " print(\"%f (%f) with: %r\" % (mean, stdev, param))\n", " \n", " return grid_result_nn\n", " \n", "def grid_logit_model(\n", " Xtrain, \n", " ytrain, \n", " scoring='accuracy',\n", " cv=StratifiedKFold(n_splits=3),\n", " param_grid=None):\n", " \"\"\"\n", " CV using GridSearchCV for LogisticRegression model.\n", " \"\"\" \n", " clf = LogisticRegression(\n", " multi_class='ovr', \n", " solver='saga',\n", " random_state=17, \n", " 
n_jobs=-1)\n", "\n", "    if param_grid is None:\n", "        Cs = [1, 0.01]\n", "        param_grid = dict(\n", "            C=Cs,\n", "            multi_class=['ovr']\n", "        )\n", "\n", "    return grid_clf(clf, param_grid, Xtrain, ytrain, cv, scoring=scoring)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def plot_confusion_matrix(cm, classes,\n", "                          normalize=False,\n", "                          title='Confusion matrix',\n", "                          cmap=plt.cm.Blues):\n", "    \"\"\"\n", "    This function prints and plots the confusion matrix.\n", "    Normalization can be applied by setting `normalize=True`.\n", "    \"\"\"\n", "    if normalize:\n", "        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]\n", "        print(\"Normalized confusion matrix\")\n", "    else:\n", "        print('Confusion matrix, without normalization')\n", "\n", "    plt.imshow(cm, interpolation='nearest', cmap=cmap)\n", "    plt.title(title)\n", "    plt.colorbar()\n", "    tick_marks = np.arange(len(classes))\n", "    plt.xticks(tick_marks, classes, rotation=45)\n", "    plt.yticks(tick_marks, classes)\n", "\n", "    fmt = '.2f' if normalize else 'd'\n", "    thresh = cm.max() / 2.\n", "    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):\n", "        plt.text(j, i, format(cm[i, j], fmt),\n", "                 horizontalalignment=\"center\",\n", "                 color=\"white\" if cm[i, j] > thresh else \"black\")\n", "\n", "    plt.tight_layout()\n", "    plt.ylabel('True label')\n", "    plt.xlabel('Predicted label')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def show_total_report(ypred, ytrain, ytest, additional_title='',\n", "                      show_classification_report = True, show_confusion_plot=True):\n", "    \"\"\"\n", "    Show base classification metrics.\n", "    \"\"\"\n", "\n", "    ohe = OneHotEncoder()\n", "    ytrain_ = ohe.fit_transform(ytrain.reshape(-1,1))\n", "    ytest_ = ohe.transform(ytest.reshape(-1,1))\n", "    ypred_ = ohe.transform(ypred.reshape(-1,1))\n", "\n", "    print(\"ranking-based average precision : {:.3f}\".format(\n", "        
label_ranking_average_precision_score(ytest_.toarray(), ypred_.toarray())))\n", "    print(\"Ranking loss : {:.3f}\".format(label_ranking_loss(ytest_.toarray(), ypred_.toarray())))\n", "    print(\"Coverage_error : {:.3f}\".format(coverage_error(ytest_.toarray(), ypred_.toarray())))\n", "\n", "    if show_classification_report:\n", "        # Convert the one-hot rows back to class indices\n", "        print(classification_report(ytest_.toarray().argmax(axis=1), ypred_.toarray().argmax(axis=1)))\n", "\n", "    if show_confusion_plot:\n", "        # Compute confusion matrix\n", "        cnf_matrix = confusion_matrix(ytest_.toarray().argmax(axis=1), ypred_.toarray().argmax(axis=1))\n", "        np.set_printoptions(precision=2)\n", "\n", "        # Plot non-normalized confusion matrix\n", "        plt.figure(figsize=(10, 10))\n", "        plot_confusion_matrix(cnf_matrix, classes=['N', 'S', 'V', 'F', 'Q'],\n", "                              title=additional_title + ' Confusion matrix, without normalization')\n", "        plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
**Train LogisticRegression model**\n", "\n", "Our base model is LogisticRegression. For the multiclass problem we use LogisticRegression with a one-vs-rest scheme (`multi_class='ovr'`)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def show_clf_results(\n", "        clf,\n", "        Xtest,\n", "        ytest,\n", "        ytrain,\n", "        additional_title='',\n", "        show_metrics=True,\n", "        show_classification_report=True,\n", "        show_confusion_plot=True):\n", "\n", "    ypred = clf.predict(Xtest)\n", "\n", "    if show_metrics:\n", "        print('accuracy: {}'.format(accuracy_score(y_pred=ypred, y_true=ytest)))\n", "        print('precision: {}'.format(precision_score(y_pred=ypred, y_true=ytest, average='macro')))\n", "\n", "    if show_classification_report or show_confusion_plot:\n", "        show_total_report(ypred, ytrain, ytest, additional_title=additional_title,\n", "                          show_classification_report = show_classification_report,\n", "                          show_confusion_plot=show_confusion_plot)\n", "    # Return the predictions so callers can reuse them\n", "    return ypred\n", "\n", "def train_logit(\n", "        Xtrain,\n", "        ytrain,\n", "        Xtest,\n", "        ytest,\n", "        C=1,\n", "        multi_class='ovr',\n", "        additional_title='[Logit]',\n", "        show_metrics=True,\n", "        show_classification_report=True,\n", "        show_confusion_plot=True):\n", "\n", "    logit = LogisticRegression(\n", "        C=C,\n", "        multi_class=multi_class,\n", "        random_state=17,\n", "        n_jobs=-1)\n", "\n", "    logit.fit(Xtrain, ytrain)\n", "\n", "    return show_clf_results(\n", "        logit,\n", "        Xtest,\n", "        ytest,\n", "        ytrain,\n", "        additional_title,\n", "        show_metrics,\n", "        show_classification_report,\n", "        show_confusion_plot\n", "    ), logit" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
Let's try LogisticRegression with original data - it is our **baseline**." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "\n", "X_train_part, X_valid, y_train_part, y_valid = \\\n", " train_test_split(\n", " X_train_orig,\n", " y_train,\n", " test_size=0.3,\n", " random_state=17)\n", "\n", "y_pred_logit, logit = train_logit(\n", " X_train_part, \n", " y_train_part, \n", " X_valid, \n", " y_valid, \n", " C=1,\n", " additional_title='[Logit]',\n", " show_metrics=True, \n", " show_classification_report=True, \n", " show_confusion_plot=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For base LogisticRegression with **original** data we have:\n", "\n", "Accuracy: **0.905**\n", "\n", "Precision: **0.821**\n", "\n", "Ranking-based average precision : **0.924**" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
\n", "\n", "**HR features improve model quality:**" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "X_train_part, X_valid, y_train_part, y_valid = \\\n", " train_test_split(\n", " X_train_orig_hr,\n", " y_train,\n", " test_size=0.3,\n", " random_state=17)\n", "\n", "y_pred_logit, logit = train_logit(\n", " X_train_part, \n", " y_train_part, \n", " X_valid, \n", " y_valid, \n", " C=1,\n", " additional_title='[Logit HR]',\n", " show_metrics=True, \n", " show_classification_report=True, \n", " show_confusion_plot=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### LogisticRegression with **Baseline+Filtered** data:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "X_train_part, X_valid, y_train_part, y_valid = \\\n", " train_test_split(\n", " X_train_filtered,\n", " y_train,\n", " test_size=0.3,\n", " random_state=17)\n", "\n", "y_pred_logit, logit = train_logit(\n", " X_train_part, \n", " y_train_part, \n", " X_valid, \n", " y_valid, \n", " C=1,\n", " additional_title='[Logit Filtered]',\n", " show_metrics=True, \n", " show_classification_report=True, \n", " show_confusion_plot=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "LogisticRegression with **Baseline+Filtered** data we have:\n", "\n", "Accuracy: **0.911**!\n", "\n", "Precision: **0.830**!\n", "\n", "Ranking-based average precision : **0.929**!\n", "\n", "So using Baseline & Filter functions improves model quality." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### LogisticRegression with **Baselined+Filtered+Wavelet+HR** data:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "X_train_part, X_valid, y_train_part, y_valid = \\\n", "    train_test_split(\n", "        X_train_wv_filtered_scaled,\n", "        y_train,\n", "        test_size=0.3,\n", "        random_state=17)\n", "\n", "y_pred_logit, logit = train_logit(\n", "    X_train_part,\n", "    y_train_part,\n", "    X_valid,\n", "    y_valid,\n", "    C=1,\n", "    additional_title='[Logit BL+Filtered_WV+HR]',\n", "    show_metrics=True,\n", "    show_classification_report=True,\n", "    show_confusion_plot=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For LogisticRegression with **Baselined+Filtered+Wavelet+HR** data we have:\n", "\n", "Accuracy: **0.933**\n", "\n", "Precision: **0.882**\n", "\n", "Ranking-based average precision : **0.947**\n", "\n", "The wavelet coefficients also add some improvement." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### LogisticRegression with **Baselined + Wavelet + HR** data:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "X_train_part, X_valid, y_train_part, y_valid = \\\n", "    train_test_split(\n", "        X_train_wv_scaled,\n", "        y_train,\n", "        test_size=0.3,\n", "        random_state=17)\n", "\n", "y_pred_logit_flt, logit_flt = train_logit(\n", "    X_train_part,\n", "    y_train_part,\n", "    X_valid,\n", "    y_valid,\n", "    C=1,\n", "    additional_title='[Scaled BL+Wavelet+HR Logit]',\n", "    show_metrics=True,\n", "    show_classification_report=True,\n", "    show_confusion_plot=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For LogisticRegression with **Baselined + Wavelet + HR** data the results are:\n", "\n", "Accuracy: **0.933**\n", "\n", "Precision: **0.882**\n", "\n", "Ranking-based average precision : **0.947**" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", 
"#### The number of samples in both collections is large enough for training a deep neural network!" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class KerasModelCreator():\n", " def __init__(\n", " self, \n", " feature,\n", " depth,\n", " filters=32, #32\n", " pool_size=5, # 5\n", " random_state=17):\n", " \n", " self.feature=feature\n", " self.depth=depth\n", " self.filters=filters\n", " self.pool_size=pool_size\n", " self.random_state=random_state\n", " \n", " def __call__(\n", " self, \n", " optimizer='adam', \n", " init='glorot_uniform', \n", " loss='categorical_crossentropy', #'sparse_categorical_crossentropy'\n", " metrics=['accuracy'],\n", " **sk_params):\n", " \n", " seed(self.random_state)\n", " set_random_seed(self.random_state)\n", "\n", " filters=self.filters\n", " pool_size=self.pool_size\n", "\n", " inp = Input(shape=(self.feature, self.depth))\n", " C = Conv1D(filters=filters, kernel_size=5, strides=1, kernel_initializer=init)(inp)\n", "\n", " C11 = Conv1D(filters=filters, kernel_size=5, strides=1, kernel_initializer=init, padding='same')(C)\n", " A11 = Activation(\"relu\")(C11)\n", " C12 = Conv1D(filters=filters, kernel_size=5, strides=1, kernel_initializer=init, padding='same')(A11)\n", " S11 = Add()([C12, C])\n", " A12 = Activation(\"relu\")(S11)\n", " M11 = MaxPooling1D(pool_size=pool_size, strides=2)(A12)\n", "\n", "\n", " C21 = Conv1D(filters=filters, kernel_size=5, strides=1, kernel_initializer=init, padding='same')(M11)\n", " A21 = Activation(\"relu\")(C21)\n", " C22 = Conv1D(filters=filters, kernel_size=5, strides=1, kernel_initializer=init, padding='same')(A21)\n", " S21 = Add()([C22, M11])\n", " A22 = Activation(\"relu\")(S11)\n", " M21 = MaxPooling1D(pool_size=pool_size, strides=2)(A22)\n", "\n", "\n", " C31 = Conv1D(filters=filters, kernel_size=5, strides=1, kernel_initializer=init, padding='same')(M21)\n", " A31 = Activation(\"relu\")(C31)\n", " C32 = Conv1D(filters=filters, 
kernel_size=5, strides=1, kernel_initializer=init, padding='same')(A31)\n", "        S31 = Add()([C32, M21])\n", "        A32 = Activation(\"relu\")(S31)\n", "        M31 = MaxPooling1D(pool_size=pool_size, strides=2)(A32)\n", "\n", "\n", "        C41 = Conv1D(filters=filters, kernel_size=5, strides=1, kernel_initializer=init, padding='same')(M31)\n", "        A41 = Activation(\"relu\")(C41)\n", "        C42 = Conv1D(filters=filters, kernel_size=5, strides=1, kernel_initializer=init, padding='same')(A41)\n", "        S41 = Add()([C42, M31])\n", "        A42 = Activation(\"relu\")(S41)\n", "        M41 = MaxPooling1D(pool_size=pool_size, strides=2)(A42)\n", "\n", "\n", "        C51 = Conv1D(filters=filters, kernel_size=5, strides=1, kernel_initializer=init, padding='same')(M41)\n", "        A51 = Activation(\"relu\")(C51)\n", "        C52 = Conv1D(filters=filters, kernel_size=5, strides=1, kernel_initializer=init, padding='same')(A51)\n", "        S51 = Add()([C52, M41])\n", "        A52 = Activation(\"relu\")(S51)\n", "        M51 = MaxPooling1D(pool_size=pool_size, strides=2)(A52)\n", "\n", "        F1 = Flatten()(M51)\n", "\n", "        D1 = Dense(filters, kernel_initializer=init)(F1)\n", "        A6 = Activation(\"relu\")(D1)\n", "        D2 = Dense(filters, kernel_initializer=init)(A6)\n", "        D3 = Dense(5, kernel_initializer=init)(D2)\n", "        A7 = Softmax()(D3)\n", "\n", "        model = Model(inputs=inp, outputs=A7)\n", "        model.compile(loss=loss, optimizer=optimizer, metrics=metrics, **sk_params)\n", "        return model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def grid_nn_model(\n", "        Xtrain,\n", "        ytrain,\n", "        scoring='accuracy',\n", "        batch_size=100,\n", "        epochs=15,\n", "        filters=32,\n", "        pool_size=5,\n", "        cv=StratifiedKFold(n_splits=3),\n", "        param_grid=None):\n", "    \"\"\"\n", "    CV using GridSearchCV for NN model.\n", "    \"\"\"\n", "    n_obs, feature, depth = Xtrain.shape\n", "    model_creator = KerasModelCreator(\n", "        feature=feature,\n", "        depth=depth,\n", "        filters=filters,\n", "        pool_size=pool_size,\n", "        random_state=17)\n", "\n", "    clf 
= KerasClassifier(build_fn=model_creator,\n", "                          epochs=epochs,  # use the function argument; a grid over `epochs` overrides it\n", "                          batch_size=batch_size,\n", "                          verbose=1)\n", "    if param_grid is None:\n", "        optimizers = ['rmsprop', 'adam']\n", "        init = ['glorot_uniform', 'normal', 'uniform']\n", "        metrics=[['accuracy']]#'['categorical_accuracy', 'accuracy']\n", "        epochs = [1, 2]\n", "        batches = [50, 100]\n", "        param_grid = dict(\n", "            optimizer=optimizers,\n", "            epochs=epochs,\n", "            batch_size=batches,\n", "            init=init,\n", "#             metrics=metrics\n", "        )\n", "\n", "    return grid_clf(clf, param_grid, Xtrain, ytrain, cv, scoring=scoring)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def exp_decay(epoch):\n", "    initial_lrate = 0.001\n", "    k = 0.75\n", "    t = n_obs//(10000 * batch_size) # every epoch we do n_obs/batch_size iteration\n", "    lrate = initial_lrate * math.exp(-k*t)\n", "    return lrate\n", "\n", "lrate = LearningRateScheduler(exp_decay)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We use the scikit-learn wrapper provided by Keras, **KerasClassifier**." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def train_nn_model(\n", "        Xtrain,\n", "        ytrain,\n", "        Xtest,\n", "        ytest,\n", "        batch_size=100,\n", "        epochs=15,\n", "        filters=32,\n", "        pool_size=5,\n", "        additional_title='[NN]',\n", "        show_metrics=True,\n", "        show_classification_report=True,\n", "        show_confusion_plot=True\n", "    ):\n", "\n", "    seed(17)\n", "    set_random_seed(17)\n", "\n", "    n_obs, feature, depth = Xtrain.shape\n", "    model_creator = KerasModelCreator(\n", "        feature=feature,\n", "        depth=depth,\n", "        filters=filters,\n", "        pool_size=pool_size,\n", "        random_state=17)\n", "\n", "    model_nn1 = KerasClassifier(\n", "        build_fn=model_creator,\n", "        epochs=epochs,\n", "        batch_size=batch_size,\n", "        verbose=0)\n", "\n", "    history = model_nn1.fit(Xtrain,\n", "                            ytrain,\n", "                            epochs=epochs,\n", "                            batch_size=batch_size,\n", "                            verbose=2,\n", "                            # validation_data=(X_test, y_test_nn),\n", "                            # callbacks=[lrate]\n", "                            )\n", "\n", "    # Pass the reporting flags through instead of hard-coding them\n", "    return show_clf_results(\n", "        model_nn1,\n", "        Xtest,\n", "        ytest,\n", "        ytrain,\n", "        additional_title,\n", "        show_metrics=show_metrics,\n", "        show_classification_report=show_classification_report,\n", "        show_confusion_plot=show_confusion_plot\n", "    ), model_nn1" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# NN:\n", "ohe = OneHotEncoder()\n", "y_train_nn = ohe.fit_transform(y_train.reshape(-1,1))\n", "y_test_nn = ohe.transform(y_test.reshape(-1,1))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "EPOCHS = 2\n", "BATCH_SIZE = 100" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Our first **Deep Neural Net** with original data:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "X_train_nn1 = np.expand_dims(X_train_orig, 2)\n", "X_test_nn1 = np.expand_dims(X_test_orig, 2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": 
[], "source": [ "%%time\n", "\n", "X_train_part, X_valid, y_train_part, y_valid = \\\n", " train_test_split(\n", " X_train_nn1,\n", " y_train,\n", " test_size=0.3,\n", " random_state=17)\n", "\n", "y_pred_nn1, model_nn1 = train_nn_model(\n", " X_train_part, \n", " y_train_part, \n", " X_valid, \n", " y_valid, \n", " additional_title='[NN]',\n", " show_metrics=True, \n", " show_classification_report=True, \n", " show_confusion_plot=False,\n", " batch_size=BATCH_SIZE,\n", " epochs=EPOCHS)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For **original** data with DNN model we have:\n", "\n", "Accuracy: **0.974**\n", "\n", "Precision: **0.957**\n", "\n", "Ranking-based average precision : **0.979**\n", "\n", "
Base DNN model is better than LogisticRegression." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### NN with **Baselined+Filtered** data" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "X_train_nn2 = np.expand_dims(X_train_filtered, 2)\n", "X_test_nn2 = np.expand_dims(X_test_filtered, 2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "\n", "X_train_part, X_valid, y_train_part, y_valid = \\\n", " train_test_split(\n", " X_train_nn2,\n", " y_train,\n", " test_size=0.3,\n", " random_state=17)\n", "\n", "y_pred_nn2, model_nn2 = train_nn_model(\n", " X_train_part, \n", " y_train_part, \n", " X_valid, \n", " y_valid, \n", " additional_title='[NN Filtered data]',\n", " show_metrics=True, \n", " show_classification_report=True, \n", " show_confusion_plot=False,\n", " batch_size=BATCH_SIZE,\n", " epochs=EPOCHS)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For DNN model with **Baselined+Filtered** data we have:\n", "\n", "Accuracy: **0.972**\n", "\n", "Precision: **0.957**\n", "\n", "Ranking-based average precision : **0.977**" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Neural Network with **Baselined+Filtered+Wavelet** data:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "X_train_nn3 = np.expand_dims(X_train_wv_filtered_scaled, 2)\n", "X_test_nn3 = np.expand_dims(X_test_wv_filtered_scaled, 2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "X_train_part, X_valid, y_train_part, y_valid = \\\n", " train_test_split(\n", " X_train_nn3,\n", " y_train,\n", " test_size=0.3,\n", " random_state=17)\n", "\n", "y_pred_nn3, model_nn3 = train_nn_model(\n", " X_train_part, \n", " y_train_part, \n", " X_valid, \n", " y_valid,\n", " additional_title='[NN BL+Filtered+Wavelet]',\n", " show_metrics=True, \n", " 
show_classification_report=True,\n", "    show_confusion_plot=False,\n", "    batch_size=BATCH_SIZE,\n", "    epochs=EPOCHS)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For DNN model with **Baselined+Filtered+Wavelet** data we have:\n", "\n", "Accuracy: **0.972**\n", "\n", "Precision: **0.966**\n", "\n", "Ranking-based average precision : **0.978**" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Neural Network with **Baselined+Wavelet+HR** data:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "X_train_nn4 = np.expand_dims(X_train_wv_scaled, 2)\n", "X_test_nn4 = np.expand_dims(X_test_wv_scaled, 2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "X_train_part, X_valid, y_train_part, y_valid = \\\n", "    train_test_split(\n", "        X_train_nn4,\n", "        y_train,\n", "        test_size=0.3,\n", "        random_state=17)\n", "\n", "y_pred_nn4, model_nn4 = train_nn_model(\n", "    X_train_part,\n", "    y_train_part,\n", "    X_valid,\n", "    y_valid,\n", "    additional_title='[NN BL+Wavelet+HR]',\n", "    show_metrics=True,\n", "    show_classification_report=True,\n", "    show_confusion_plot=False,\n", "    batch_size=BATCH_SIZE,\n", "    epochs=EPOCHS)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For DNN model with **Baselined+Wavelet+HR** data we have:\n", "\n", "Accuracy: **0.975**\n", "\n", "Precision: **0.960**\n", "\n", "Ranking-based average precision : **0.980**" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**LogisticRegression cross-validation**" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "skf = StratifiedKFold(n_splits=3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def plot_with_err(x, mu, std, legend=False, grid=False, title=None, xlabel=None, **kwargs):\n", "    lines = plt.plot(x, mu, '-', **kwargs)\n", "    plt.fill_between(x, mu - 
std, mu + std, edgecolor='none',\n", " facecolor=lines[0].get_color(), alpha=0.2)\n", " if legend:\n", " plt.legend()\n", " if grid:\n", " plt.grid(True);\n", " if title is not None:\n", " plt.title(title)\n", " if xlabel is not None:\n", " plt.xlabel(xlabel)\n", " \n", "from sklearn.model_selection import learning_curve\n", "\n", "def plot_learning_curve(estimator, X, y, cv):\n", " train_sizes = np.linspace(0.05, 1, 20)\n", " N_train, val_train, val_test = learning_curve(estimator,\n", " X, y, \n", " train_sizes=train_sizes, \n", " cv=cv,\n", " scoring='accuracy')\n", " plot_with_err(N_train, val_train.mean(1), val_train.std(1), label='training scores')\n", " plot_with_err(N_train, val_test.mean(1), val_test.std(1), label='validation scores')\n", " plt.xlabel('Training Set Size'); \n", " plt.ylabel('Accuracy')\n", " plt.legend()\n", " plt.grid(True);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### LogisticRegression for our best data:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "param_grid = dict(\n", " C=np.arange(1, 3, 0.25),\n", " multi_class=['ovr'],\n", " )\n", "\n", "\n", "X_train_part, X_valid, y_train_part, y_valid = \\\n", " train_test_split(\n", " X_train_wv_filtered_scaled,\n", " y_train,\n", " test_size=0.3,\n", " random_state=17)\n", "\n", "grid_logit = grid_logit_model(\n", " X_train_part, \n", " y_train_part, \n", " cv=skf,\n", " param_grid=param_grid)\n", "\n", "print(grid_logit.best_params_)\n", "\n", "y_pred_logit_gr = show_clf_results(\n", " grid_logit, \n", " X_valid, \n", " y_valid, \n", " y_train, \n", " additional_title='[Logit CV]',\n", " show_metrics=True, \n", " show_classification_report=True, \n", " show_confusion_plot=True\n", " )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "So the optimal C value is **2.5**." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "C = 2.5\n", "grid_logit.best_params_" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# with open('grid_logit.pkl', 'wb') as f:\n", "# pickle.dump(grid_logit, file=f)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Neural Net cross-validation**" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "\n", "X_train_part, X_valid, y_train_part, y_valid = \\\n", " train_test_split(\n", " X_train_nn3,\n", " y_train,\n", " test_size=0.3,\n", " random_state=17)\n", "\n", "param_nn_grid = dict(\n", " epochs=[1, 2, 10, 30, 50], \n", " batch_size=[50], \n", ")\n", "\n", "grid_nn = grid_nn_model(\n", " X_train_part, \n", " y_train_part, \n", " cv=StratifiedKFold(n_splits=3),\n", " param_grid=param_nn_grid)\n", "\n", "print(grid_nn.best_params_)\n", "\n", "y_pred_nn_gr = show_clf_results(\n", " grid_nn, \n", " X_valid, \n", " y_valid, \n", " y_train, \n", " additional_title='[NN CV]',\n", " show_metrics=True, \n", " show_classification_report=True, \n", " show_confusion_plot=True\n", " )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**So the optimal epochs value is 30**" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "grid_nn.best_params_" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# with open('grid_nn.pkl', 'wb') as f:\n", "# pickle.dump(grid_nn, file=f)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Part 8. 
Validation and learning curves" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### LogisticRegression validation/learning curves" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plot_with_err(grid_logit.param_grid['C'], \n", " grid_logit.cv_results_['mean_train_score'], \n", " grid_logit.cv_results_['std_train_score'], \n", " legend=True,\n", " grid=True,\n", " xlabel='C',\n", " label='Training scores')\n", "plot_with_err(grid_logit.param_grid['C'], \n", " grid_logit.cv_results_['mean_test_score'], \n", " grid_logit.cv_results_['std_test_score'], \n", " legend=True, \n", " grid=True,\n", " xlabel='C',\n", " title='Logit',\n", " label='Test scores')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### NN validation/learning curves" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plot_with_err(grid_nn.param_grid['epochs'], \n", " grid_nn.cv_results_['mean_train_score'], \n", " grid_nn.cv_results_['std_train_score'], \n", " legend=True,\n", " grid=True,\n", " xlabel='epochs',\n", " label='Training scores')\n", "plot_with_err(grid_nn.param_grid['epochs'], \n", " grid_nn.cv_results_['mean_test_score'], \n", " grid_nn.cv_results_['std_test_score'], \n", " legend=True, \n", " grid=True,\n", " xlabel='epochs',\n", " title='NN',\n", " label='Test scores')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Part 9. 
Prediction for hold-out and test samples " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Build model using full train data:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "# grid_logit_full = grid_logit.fit(\n", "# X_train_wv_filtered_scaled,\n", "# y_train)\n", "# our train/valid for logit\n", "X_train_part, X_valid, y_train_part, y_valid = \\\n", " train_test_split(\n", " X_train_wv_filtered_scaled,\n", " y_train,\n", " test_size=0.3,\n", " random_state=17)\n", "\n", "y_pred_logit_full, grid_logit_full = train_logit(\n", " X_train_wv_filtered_scaled, \n", " y_train, \n", " X_test_wv_filtered_scaled, \n", " y_test, \n", " C=grid_logit.best_params_['C'],\n", " additional_title='[Logit BL+Filtered_WV+HR]',\n", " show_metrics=True, \n", " show_classification_report=True, \n", " show_confusion_plot=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**LogisticRegression hold-out results:**" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "y_pred_logit_valid_full = show_clf_results(\n", " grid_logit_full, \n", " X_valid, \n", " y_valid, \n", " y_train, \n", " additional_title='[Logit hold-out]',\n", " show_metrics=True, \n", " show_classification_report=True, \n", " show_confusion_plot=True\n", " )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**LogisticRegression test results:**" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "y_pred_logit_test_full = show_clf_results(\n", " grid_logit_full, \n", " X_test_wv_filtered_scaled,\n", " y_test,\n", " y_train,\n", " additional_title='[Logit Test]',\n", " show_metrics=True, \n", " show_classification_report=True, \n", " show_confusion_plot=True\n", " )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
\n", "\n", "#### Build NN model using full train data:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "X_train_part, X_valid, y_train_part, y_valid = \\\n", " train_test_split(\n", " X_train_nn3,\n", " y_train,\n", " test_size=0.3,\n", " random_state=17)\n", "\n", "y_pred_nn3_full, grid_nn_full = train_nn_model(\n", " X_train_nn3, \n", " y_train, \n", " X_test_nn3,\n", " y_test,\n", " additional_title='[NN BL+Filtered+Wavelet]',\n", " show_metrics=True, \n", " show_classification_report=True, \n", " show_confusion_plot=False,\n", " batch_size=grid_nn.best_params_['batch_size'], \n", " epochs=grid_nn.best_params_['epochs'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**NN hold-out results:**" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "y_pred_nn_valid_full = show_clf_results(\n", " grid_nn_full, \n", " X_valid, \n", " y_valid, \n", " y_train, \n", " additional_title='[NN hold-out]',\n", " show_metrics=True, \n", " show_classification_report=True, \n", " show_confusion_plot=True\n", " )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**NN test results:**" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "grid_nn_test_full = show_clf_results(\n", " grid_nn_full, \n", " X_test_nn3,\n", " y_test,\n", " y_train,\n", " additional_title='[NN Test]',\n", " show_metrics=True, \n", " show_classification_report=True, \n", " show_confusion_plot=True\n", " )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Part 10. Model evaluation with metrics description" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We use accuracy metrics in our model. 
In medical diagnosis, test sensitivity (recall) is the ability of a test to correctly identify those with the disease ([wiki](https://en.wikipedia.org/wiki/Sensitivity_and_specificity#Medical_examples)), so we also use the **precision** metric (we also calculate other classification metrics, such as the f1-score).\n", "\n", "For LogisticRegression we get accuracy=0.934 and precision=0.908 on the test data. The NN is much better: accuracy=0.985, precision=0.943." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Part 11. Conclusions" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The Deep Neural Network model's accuracy is 0.968, which is a very good result. \n", "\n", "Still, the model can be improved. Some ideas for improvement:\n", "* Use additional ECG features, such as the P, Q, R, S, T waves\n", "* Our data is imbalanced, so try data oversampling (see data augmentation above)\n", "* Using the original ECG would be a better solution (in our current DB we have only 'one R-peak' segments)\n", "* Use FFT analysis\n", "* Tune hyperparameters (Neural Network tuning needs powerful hardware)\n", "* Try other models (SVM, ...)\n", "\n", "\n", "\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.6" } }, "nbformat": 4, "nbformat_minor": 2 }