{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Machine Learning pipeline: Shu *et al*, 2020\n",
    "\n",
    "This notebook attempts to replicate the Machine Learning pipeline of the following paper:\n",
    "\n",
    "<div class=\"alert alert-block alert-success\">\n",
    "Shu, Zhen‐Yu, et al. <a href=\"https://onlinelibrary.wiley.com/doi/full/10.1002/mrm.28522?casa_token=Ab53WvMlODcAAAAA%3AXcgDLmq8egqW7uwd2g3jY9jIljhLu3VhIbvMWgbcfoWOxjO_9H7Arf91t2FBZDZ8E94Je4Wmrn0ZmkeZ\">Predicting the progression of Parkinson's disease using conventional MRI and machine learning: An application of radiomic biomarkers in whole‐brain white matter.</a> Magnetic Resonance in Medicine 85.3 (2021): 1611-1624.</div>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys, os\n",
    "sys.path.append('code/scripts')\n",
    "import os\n",
    "import glob\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn.model_selection import GridSearchCV, StratifiedKFold, cross_val_score\n",
    "from sklearn.base import BaseEstimator, TransformerMixin\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.pipeline import Pipeline\n",
    "from sklearn.tree import DecisionTreeClassifier\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "from sklearn.naive_bayes import GaussianNB\n",
    "from sklearn.svm import SVC\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from mrmr.pandas import mrmr_classif\n",
    "from verify import check_data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Usage\n",
    "\n",
    "Assign the name of the CSV file containing your PPMI data to the `data` variable. Additionally, list the cohort directories in the `cohorts` list.\n",
    "\n",
    "Each cohort directory must contain the following files:\n",
    "\n",
    "- **cohort.csv**: Must have columns - PATNO, EVENT_ID, Description.\n",
    "\n",
    "- **`demographics*.csv`**: Must have columns - PATNO, Age, Stage, SEX, output (the prediction target), and optionally UPDRS_TOT."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = \"radiomics_results.csv\"\n",
    "\n",
    "# One directory per cohort, each holding cohort.csv and a demographics*.csv file\n",
    "cohorts = [\n",
    "    \"cohorts/max_cohorts/cohort_max_off_2_35_matched\",\n",
    "    \"cohorts/max_cohorts/cohort_max_off_off_65_75_unmatched\",\n",
    "    \"cohorts/max_cohorts/cohort_max_both_2_59_matched\",\n",
    "    \"cohorts/max_cohorts/cohort_max_both_102_92_unmatched\",\n",
    "    \"cohorts/max_cohorts/cohort_max_on_2_43_matched\",\n",
    "    \"cohorts/max_cohorts/cohort_max_on_on_66_81_unmatched\",\n",
    "    \"cohorts/max_cohorts/cohort_max_2_72_matched\"\n",
    "]\n",
    "\n",
    "# Verify data file\n",
    "assert os.path.exists(data), f\"{data} does not exist.\"\n",
    "\n",
    "# Verify data, naming every cohort that fails so all problems can be fixed at once\n",
    "invalid_cohorts = [cohort for cohort in cohorts if not check_data(cohort)]\n",
    "assert not invalid_cohorts, f\"Not all cohorts have valid data files: {invalid_cohorts}\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Helper functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_df(data, cohort, demographics, useUPDRS=False):\n",
    "    \"\"\"\n",
    "    Merge and preprocess data from multiple CSV files to create a consolidated dataframe.\n",
    "\n",
    "    Parameters:\n",
    "    - data (str): The path to the primary data CSV file.\n",
    "    - cohort (str): The path to the 'cohort.csv' file containing cohort information.\n",
    "    - demographics (str): The path to the 'demographics*.csv' file containing demographic information.\n",
    "    - useUPDRS (bool, optional): Whether to include the 'UPDRS_TOT' column in the output dataframe. Defaults to False.\n",
    "\n",
    "    Returns:\n",
    "    - pandas.DataFrame: A consolidated dataframe containing merged data from the specified files.\n",
    "    \"\"\"\n",
    "    # Collect data files\n",
    "    data_df = pd.read_csv(data)\n",
    "    cohort_df = pd.read_csv(cohort)\n",
    "    demographics_df = pd.read_csv(demographics)\n",
    "\n",
    "    # Create output dataframe\n",
    "    columnsToDrop = [\"EVENT_ID\", \"Description\", \"Age\", \"Stage\", \"SEX\", \"UPDRS_TOT\"] if useUPDRS else [\"EVENT_ID\", \"Description\", \"Age\", \"Stage\", \"SEX\"]\n",
    "    output_df = cohort_df \\\n",
    "        .merge(demographics_df, on=\"PATNO\") \\\n",
    "        .drop(columnsToDrop, axis=1)\n",
    "\n",
    "    # Merge together\n",
    "    df = data_df \\\n",
    "        .merge(output_df, on=\"PATNO\") \\\n",
    "        .drop([\"Unnamed: 0\",], axis=1)\n",
    "\n",
    "    return df\n",
    "\n",
    "\n",
    "def getModel(model):\n",
    "    \"\"\"\n",
    "    Helper function that returns a machine learning model and its parameters for tuning based on the top parameters from Shu et al. models.\n",
    "\n",
    "    Parameters:\n",
    "    - model (str): The type of machine learning model to be returned. Options: \"SVM\", \"DecisionTree\", \"kNN\", \"GNB\".\n",
    "\n",
    "    Returns:\n",
    "    - tuple: A tuple containing the machine learning model (classifier) and a dictionary of hyperparameter grid for tuning.\n",
    "    \"\"\"\n",
    "    prefix = \"train__\"\n",
    "    \n",
    "    if model == \"SVM\":\n",
    "        param_grid = {f'{prefix}C': [10, 100, 1000],\n",
    "                      f'{prefix}gamma': [0.1, 0.01, 0.001],\n",
    "                      f'{prefix}kernel': ['rbf'],\n",
    "                      f'{prefix}class_weight': [None, 'balanced']}\n",
    "        clf = SVC(probability=True)\n",
    "    elif model == \"DecisionTree\":\n",
    "        param_grid = {f'{prefix}max_depth': [4, 5, 6, 8],\n",
    "                      f'{prefix}max_leaf_nodes': list(range(5, 20, 1)),\n",
    "                      f'{prefix}min_samples_split': [5, 8, 16],\n",
    "                      f'{prefix}class_weight': [None, 'balanced']}\n",
    "        clf = DecisionTreeClassifier()\n",
    "    elif model == \"kNN\":\n",
    "        param_grid = {f'{prefix}n_neighbors': list(range(5, 20)),\n",
    "                      f'{prefix}p': [1, 2],\n",
    "                      f'{prefix}weights': [\"uniform\", \"distance\"]}\n",
    "        clf = KNeighborsClassifier()\n",
    "    elif model == \"GNB\":\n",
    "        param_grid = {f'{prefix}var_smoothing': np.logspace(0, -9, num=100)}\n",
    "        clf = GaussianNB()\n",
    "\n",
    "    param_grid[f'pca__n_components'] = [2, 3, 5, 7, 10]\n",
    "\n",
    "    return clf, param_grid\n",
    "\n",
    "\n",
    "class MRMRFeatureSelector(BaseEstimator, TransformerMixin):\n",
    "    \"\"\"\n",
    "    scikit-learn transformer that keeps only the columns chosen by mRMR (Maximum Relevance, Minimum Redundancy),\n",
    "    as computed by `mrmr_classif`.\n",
    "\n",
    "    Parameters:\n",
    "    - K (int, optional): Number of features to keep. Defaults to 7.\n",
    "    - n_jobs (int, optional): Forwarded to `mrmr_classif` as `n_jobs`. Defaults to -1 (all available cores).\n",
    "\n",
    "    Attributes:\n",
    "    - selected_features_: The column labels returned by `mrmr_classif` during `fit`.\n",
    "    \"\"\"\n",
    "    def __init__(self, K=7, n_jobs=-1):\n",
    "        # Keep constructor arguments verbatim under the same names: BaseEstimator.get_params\n",
    "        # (used by clone / GridSearchCV) reads them back from these attributes.\n",
    "        self.K = K\n",
    "        self.n_jobs = n_jobs\n",
    "\n",
    "    def fit(self, X, y):\n",
    "        \"\"\"\n",
    "        Run mRMR on the training data and remember which columns it picked.\n",
    "\n",
    "        Parameters:\n",
    "        - X (pandas.DataFrame): The feature matrix.\n",
    "        - y (pandas.Series): The target variable.\n",
    "\n",
    "        Returns:\n",
    "        - self (MRMRFeatureSelector): The fitted transformer.\n",
    "        \"\"\"\n",
    "        chosen_columns = mrmr_classif(X, y, K=self.K, n_jobs=self.n_jobs)\n",
    "        self.selected_features_ = chosen_columns\n",
    "        return self\n",
    "\n",
    "    def transform(self, X):\n",
    "        \"\"\"\n",
    "        Restrict X to the columns selected during `fit`.\n",
    "\n",
    "        Parameters:\n",
    "        - X (pandas.DataFrame): The feature matrix; it must contain every selected column.\n",
    "\n",
    "        Returns:\n",
    "        - pandas.DataFrame: X with only the selected columns, in the order `mrmr_classif` returned them.\n",
    "        \"\"\"\n",
    "        columns_to_keep = self.selected_features_\n",
    "        return X[columns_to_keep]\n",
    "\n",
    "\n",
    "def nested_cross_validation(df, clf, param_grid, usePCA=True):\n",
    "    \"\"\"\n",
    "    Estimate ROC AUC with nested cross-validation: a 5-fold grid search (inner loop) inside 7-fold scoring (outer loop).\n",
    "\n",
    "    Parameters:\n",
    "    - df (pandas.DataFrame): Features plus the 'output' target column; a 'PATNO' column, if present, is ignored.\n",
    "    - clf (BaseEstimator): The classifier used as the pipeline's 'train' step.\n",
    "    - param_grid (dict): The hyperparameter grid for the grid search. It is not modified; when usePCA is False\n",
    "      any 'pca__n_components' entry is left out of the search.\n",
    "    - usePCA (bool, optional): If True the pipeline is scale -> PCA -> classifier, otherwise\n",
    "      scale -> mRMR (K=7) -> classifier. Defaults to True.\n",
    "\n",
    "    Returns:\n",
    "    - tuple: (grid, scores). `grid` is the GridSearchCV object handed to cross_val_score; this function never\n",
    "      calls grid.fit, so it is returned unfitted. `scores` is the array of outer-fold ROC AUC scores.\n",
    "    \"\"\"\n",
    "    # Get data\n",
    "    features = df.columns[~df.columns.isin(['PATNO', 'output'])]\n",
    "    X = df[features]\n",
    "    y = df[\"output\"]\n",
    "\n",
    "    # Outer folds estimate performance; inner folds choose hyperparameters\n",
    "    cv_outer = StratifiedKFold(n_splits=7, shuffle=True, random_state=42)\n",
    "    cv_inner = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)\n",
    "\n",
    "    # Work on a copy so the caller's grid can be reused for another call (e.g. with the other usePCA value)\n",
    "    search_grid = dict(param_grid)\n",
    "\n",
    "    # Define pipeline with or without PCA\n",
    "    if usePCA:\n",
    "        pipeline = Pipeline([\n",
    "            ('scale', StandardScaler()),\n",
    "            ('pca', PCA()),\n",
    "            ('train', clf)\n",
    "        ])\n",
    "    else:\n",
    "        # There is no 'pca' step here, so a PCA parameter would make GridSearchCV fail\n",
    "        search_grid.pop(\"pca__n_components\", None)\n",
    "        pipeline = Pipeline([\n",
    "            # MRMRFeatureSelector selects columns by label, so keep the scaled data as a DataFrame\n",
    "            ('scale', StandardScaler().set_output(transform=\"pandas\")),\n",
    "            ('mrmr', MRMRFeatureSelector(K=7)),\n",
    "            ('train', clf)\n",
    "        ])\n",
    "\n",
    "    # Train model using grid search\n",
    "    grid = GridSearchCV(pipeline, search_grid, scoring='roc_auc', cv=cv_inner, n_jobs=1, verbose=0, refit=True)\n",
    "    scores = cross_val_score(grid, X, y, scoring='roc_auc', cv=cv_outer, n_jobs=-1, verbose=0)\n",
    "\n",
    "    return grid, scores\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Training loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run():\n",
    "    \"\"\"\n",
    "    Perform nested cross-validation for every cohort and model, once with PCA and once with mRMR feature reduction.\n",
    "\n",
    "    Reads the notebook-level `data` path and `cohorts` list; cohorts failing `check_data` are skipped.\n",
    "\n",
    "    Returns:\n",
    "    - pandas.DataFrame: One row per (method, cohort), labelled 'PCA_<cohort>' or 'MRMR_<cohort>', and one column\n",
    "      per model; each cell holds the outer-fold ROC AUC as 'mean +/- std' with three decimals.\n",
    "    \"\"\"\n",
    "    # Model identifiers understood by getModel -> column labels in the results table\n",
    "    model_columns = {\"SVM\": \"SVM\", \"DecisionTree\": \"Decision Tree\", \"kNN\": \"kNN\", \"GNB\": \"GNB\"}\n",
    "    results_df = pd.DataFrame(columns=list(model_columns.values()))\n",
    "\n",
    "    for cohort in cohorts:\n",
    "        # Skip missing files\n",
    "        if not check_data(cohort):\n",
    "            print(f\"Skipping {cohort}\")\n",
    "            continue\n",
    "\n",
    "        # Get data; sort the glob so the demographics file chosen does not depend on filesystem order\n",
    "        demographics = sorted(glob.glob(f\"{cohort}/demographics*\"))[0]\n",
    "        df = get_df(data, f\"{cohort}/cohort.csv\", demographics)\n",
    "        cohort_name = cohort.split(\"/\")[-1]\n",
    "\n",
    "        # PATNO identifies patients and is not a feature\n",
    "        df_features = df.loc[:, ~df.columns.isin(['PATNO'])]\n",
    "\n",
    "        print(f\"COHORT: {cohort_name}\\n========================================\")\n",
    "        for model_type, column in model_columns.items():\n",
    "            print(f\"\\tTraining the following model: {model_type}\")\n",
    "\n",
    "            # Define classifier\n",
    "            clf, param_grid = getModel(model_type)\n",
    "\n",
    "            for method, use_pca in ((\"PCA\", True), (\"MRMR\", False)):\n",
    "                # Pass a copy so keys removed inside nested_cross_validation cannot affect later calls\n",
    "                _, scores = nested_cross_validation(df_features, clf, dict(param_grid), usePCA=use_pca)\n",
    "                results_df.loc[f\"{method}_{cohort_name}\", column] = f\"{scores.mean():.3f} +/- {scores.std():.3f}\"\n",
    "\n",
    "    return results_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "COHORT: cohort_max_off_2_35_matched\n",
      "========================================\n",
      "\tTraining the following model: SVM\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\tTraining the following model: DecisionTree\n",
      "\tTraining the following model: kNN\n",
      "\tTraining the following model: GNB\n",
      "COHORT: cohort_max_off_off_65_75_unmatched\n",
      "========================================\n",
      "\tTraining the following model: SVM\n",
      "\tTraining the following model: DecisionTree\n",
      "\tTraining the following model: kNN\n",
      "\tTraining the following model: GNB\n",
      "COHORT: cohort_max_both_2_59_matched\n",
      "========================================\n",
      "\tTraining the following model: SVM\n",
      "\tTraining the following model: DecisionTree\n",
      "\tTraining the following model: kNN\n",
      "\tTraining the following model: GNB\n",
      "COHORT: cohort_max_both_102_92_unmatched\n",
      "========================================\n",
      "\tTraining the following model: SVM\n",
      "\tTraining the following model: DecisionTree\n",
      "\tTraining the following model: kNN\n",
      "\tTraining the following model: GNB\n",
      "COHORT: cohort_max_on_2_43_matched\n",
      "========================================\n",
      "\tTraining the following model: SVM\n",
      "\tTraining the following model: DecisionTree\n",
      "\tTraining the following model: kNN\n",
      "\tTraining the following model: GNB\n",
      "COHORT: cohort_max_on_on_66_81_unmatched\n",
      "========================================\n",
      "\tTraining the following model: SVM\n",
      "\tTraining the following model: DecisionTree\n",
      "\tTraining the following model: kNN\n",
      "\tTraining the following model: GNB\n",
      "COHORT: cohort_max_2_72_matched\n",
      "========================================\n",
      "\tTraining the following model: SVM\n",
      "\tTraining the following model: DecisionTree\n",
      "\tTraining the following model: kNN\n",
      "\tTraining the following model: GNB\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>SVM</th>\n",
       "      <th>Decision Tree</th>\n",
       "      <th>kNN</th>\n",
       "      <th>GNB</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>PCA_cohort_max_off_2_35_matched</th>\n",
       "      <td>0.617 +/- 0.244</td>\n",
       "      <td>0.597 +/- 0.198</td>\n",
       "      <td>0.529 +/- 0.184</td>\n",
       "      <td>0.566 +/- 0.181</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>PCA_cohort_max_off_off_65_75_unmatched</th>\n",
       "      <td>0.541 +/- 0.137</td>\n",
       "      <td>0.6 +/- 0.142</td>\n",
       "      <td>0.526 +/- 0.186</td>\n",
       "      <td>0.527 +/- 0.098</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>PCA_cohort_max_both_2_59_matched</th>\n",
       "      <td>0.586 +/- 0.189</td>\n",
       "      <td>0.515 +/- 0.089</td>\n",
       "      <td>0.502 +/- 0.152</td>\n",
       "      <td>0.415 +/- 0.109</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>PCA_cohort_max_both_102_92_unmatched</th>\n",
       "      <td>0.458 +/- 0.112</td>\n",
       "      <td>0.553 +/- 0.091</td>\n",
       "      <td>0.504 +/- 0.06</td>\n",
       "      <td>0.553 +/- 0.134</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>PCA_cohort_max_on_2_43_matched</th>\n",
       "      <td>0.506 +/- 0.173</td>\n",
       "      <td>0.495 +/- 0.146</td>\n",
       "      <td>0.545 +/- 0.14</td>\n",
       "      <td>0.486 +/- 0.192</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>PCA_cohort_max_on_on_66_81_unmatched</th>\n",
       "      <td>0.531 +/- 0.089</td>\n",
       "      <td>0.519 +/- 0.07</td>\n",
       "      <td>0.518 +/- 0.066</td>\n",
       "      <td>0.48 +/- 0.134</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>PCA_cohort_max_2_72_matched</th>\n",
       "      <td>0.405 +/- 0.099</td>\n",
       "      <td>0.461 +/- 0.105</td>\n",
       "      <td>0.467 +/- 0.105</td>\n",
       "      <td>0.509 +/- 0.082</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                    SVM    Decision Tree  \\\n",
       "PCA_cohort_max_off_2_35_matched         0.617 +/- 0.244  0.597 +/- 0.198   \n",
       "PCA_cohort_max_off_off_65_75_unmatched  0.541 +/- 0.137    0.6 +/- 0.142   \n",
       "PCA_cohort_max_both_2_59_matched        0.586 +/- 0.189  0.515 +/- 0.089   \n",
       "PCA_cohort_max_both_102_92_unmatched    0.458 +/- 0.112  0.553 +/- 0.091   \n",
       "PCA_cohort_max_on_2_43_matched          0.506 +/- 0.173  0.495 +/- 0.146   \n",
       "PCA_cohort_max_on_on_66_81_unmatched    0.531 +/- 0.089   0.519 +/- 0.07   \n",
       "PCA_cohort_max_2_72_matched             0.405 +/- 0.099  0.461 +/- 0.105   \n",
       "\n",
       "                                                    kNN              GNB  \n",
       "PCA_cohort_max_off_2_35_matched         0.529 +/- 0.184  0.566 +/- 0.181  \n",
       "PCA_cohort_max_off_off_65_75_unmatched  0.526 +/- 0.186  0.527 +/- 0.098  \n",
       "PCA_cohort_max_both_2_59_matched        0.502 +/- 0.152  0.415 +/- 0.109  \n",
       "PCA_cohort_max_both_102_92_unmatched     0.504 +/- 0.06  0.553 +/- 0.134  \n",
       "PCA_cohort_max_on_2_43_matched           0.545 +/- 0.14  0.486 +/- 0.192  \n",
       "PCA_cohort_max_on_on_66_81_unmatched    0.518 +/- 0.066   0.48 +/- 0.134  \n",
       "PCA_cohort_max_2_72_matched             0.467 +/- 0.105  0.509 +/- 0.082  "
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Run nested cross-validation for every cohort / model / feature-reduction combination and show the summary table.\n",
    "# NOTE(review): the saved output lists only PCA_* rows although run() also records MRMR_* rows;\n",
    "# it appears to come from an earlier version of run(), so re-run this cell to refresh it.\n",
    "results_df = run()\n",
    "results_df"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "research",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}