{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This code performs four main tasks:\n",
    "\n",
    "**[First part]**\n",
    "- annotate the MitImpact's variants with additional information provided by supplementary files available in the `data/APOGEE2_2022` folder. \n",
    "- select the features of interest that will later be used in the learning process and store them in a file named `mitimpact_features.csv` stored in a newly created `extracted_data` folder.\n",
    "\n",
    "**[Seconda part]**\n",
    "- perform a nested cross-validation. You will be given the chance to choose the machine learning classifier among a few alternatives, the number of splits and partitions for the test cross-validation, the number of splits for the grid-search cross-validation.\n",
    "- calculate a few performance measures.\n",
    "\n",
    "Note that several hours are required to execute the entire notebook and that some of the hyperparameters have been dropped from the grid-search procedure in order to speed up the computation. The notebook will create a `test/` folder to store models and predictions (if `test/` is not empty its content will be overwritten). After completing the cross-validation procedure, the notebook will calculate the mean auROC and its confidence intervals."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### File description\n",
    "`data/`\\\n",
    "   `MitImpact_db_3.0.6.1.txt`: MitImpact file; it contains all possible nonsynonymous mitochondrial SNVs\\\n",
    "   `mtmam.csv`: MtMam substitution matrix for mitochondrial amino acid changes\\\n",
    "   `ddg_dict.pk`: pickle file encoding a Python dictionary for the precomputed ΔΔG values (ΔG_mutant - ΔG_wildtype) resulting from the amino acid changes\\\n",
    "   `AA3Dposition_aligned.csv`: 3D coordinates of the amino acids in the mitochondrial proteins\\\n",
    "   `dataset_21-04-21.csv`: labels in the reference dataset (`P` and `N` indicate respectively that a variant is pathogenic or neutral); only SNVs are reported in the file\n",
    "\n",
    "`extracted_data/`\n",
    "   `mitimpact_features.csv`: features extracted from MitImpact and used to annotateall SNVs\n",
    "\n",
    "*n.b., `extracted_data/` is created in the section '**Create of a list containing the names of features to extract from the `mitimpact` dataframe**' below*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import pickle as pk\n",
    "import os, warnings\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "from scipy import stats\n",
    "\n",
    "from sklearn.experimental import enable_iterative_imputer\n",
    "from sklearn.base import clone\n",
    "\n",
    "from sklearn.model_selection import RepeatedStratifiedKFold, StratifiedKFold, GridSearchCV, ParameterGrid\n",
    "from sklearn.metrics import roc_auc_score, roc_curve\n",
    "from sklearn.metrics import precision_recall_curve\n",
    "from sklearn.metrics import average_precision_score\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.ensemble import RandomForestRegressor\n",
    "from sklearn.tree import DecisionTreeClassifier\n",
    "from sklearn.impute import IterativeImputer\n",
    "from sklearn.feature_selection import SelectFromModel\n",
    "\n",
    "from imblearn.pipeline import Pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if  os.path.exists(\"./data/APOGEE2_2022/\"):\n",
    "    PATH= \".\"\n",
    "else: # not os.path.exists(\"./playgrounds/data/APOGEE2_2022/\"):\n",
    "    !git clone https://github.com/mazzalab/playgrounds.git\n",
    "    PATH=\"./playgrounds\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# First part"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Retrieve and prepare the MitImpact database flat file"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mitimpact = pd.read_csv(PATH+\"/data/APOGEE2_2022/MitImpact_db_3.0.6.1.txt\", sep=\"\\t\", low_memory=False, index_col=\"MitImpact_id\")\n",
    "\n",
    "# Rename the column Start to Pos\n",
    "mitimpact.rename(columns={\"Start\": \"Pos\"}, inplace=True)\n",
    "\n",
    "# Replace \".\" with NaN values\n",
    "mitimpact.replace(\".\", np.nan, inplace=True)\n",
    "\n",
    "mitimpact.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Read the mitochondrial substitution matrix (mtMam)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mtmam = pd.read_csv(PATH+\"/data/APOGEE2_2022/mtmam.csv\", sep=\"\\t\",  index_col=0)\n",
    "mtmam.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Annotate the mitimpact variants with mtMam scores\n",
    "mitimpact[\"mtmam\"] = mtmam.stack().loc[mitimpact.set_index([\"AA_ref\", \"AA_alt\"]).index].to_numpy()\n",
    "mitimpact[[\"mtmam\"]].head()"
   ]
  },
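  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The lookup above relies on `stack()` turning the substitution matrix into a Series indexed by (row, column) pairs. Below is a minimal sketch of the same idea on a toy matrix. The amino acids, the values, and the names `toy_matrix` and `toy_variants` are made up for illustration and are not real mtMam data:\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "# Toy 3x3 substitution matrix (hypothetical values, not real mtMam scores).\n",
    "toy_matrix = pd.DataFrame(\n",
    "    [[0, 1, 2], [1, 0, 3], [2, 3, 0]],\n",
    "    index=['A', 'G', 'V'],\n",
    "    columns=['A', 'G', 'V'],\n",
    ")\n",
    "\n",
    "# Toy variants, each with a reference and an alternate amino acid.\n",
    "toy_variants = pd.DataFrame({'AA_ref': ['A', 'V', 'G'], 'AA_alt': ['G', 'A', 'V']})\n",
    "\n",
    "# stack() yields a Series indexed by (row, column) pairs, so a MultiIndex\n",
    "# of (AA_ref, AA_alt) pairs selects one score per variant, in variant order.\n",
    "pairs = toy_variants.set_index(['AA_ref', 'AA_alt']).index\n",
    "toy_variants['score'] = toy_matrix.stack().loc[pairs].to_numpy()\n",
    "print(toy_variants)\n",
    "```\n",
    "\n",
    "Here `A -> G` picks row `A`, column `G` of the toy matrix, and likewise for the other variants."
   ]
  },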
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load, format, and integrate the precomputed energetic variation values, calculated as $\\Delta \\Delta G = \\Delta G_{mutant} - \\Delta G_{wt}$,  into MitImpact "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ddg = pd.DataFrame.from_dict(pk.load(open(PATH+\"/data/APOGEE2_2022/ddg_dict.pk\", \"rb\")), orient=\"index\").stack().to_frame()\n",
    "\n",
    "# Parse the dataframe\n",
    "ddg = pd.DataFrame(ddg[0].values.tolist(), index=ddg.index).stack().to_frame()\n",
    "\n",
    "# Rename the column of the scores as \"ddg\"\n",
    "ddg.columns = [\"ddg\"]\n",
    "\n",
    "# Reset index\n",
    "ddg.reset_index(inplace=True)\n",
    "\n",
    "# Parse the information about the position of AA in the protein\n",
    "ddg[\"AA_pos\"] = ddg.level_1.str.split(r\"[A-Z]\", expand=True)[1].astype(float)\n",
    "\n",
    "# Parse the information about the reference AA\n",
    "ddg[\"level_1\"] = ddg.level_1.str.split(\"\", expand=True)[1]\n",
    "\n",
    "# Rename the columns\n",
    "ddg.rename(columns={\"level_0\":\"Gene_symbol\",\"level_2\":\"AA_alt\",\"level_1\":\"AA_ref\"}, inplace=True)\n",
    "\n",
    "# Annotate variants in MitImpact with the ddg values (from the `ddg` dataframe)\n",
    "mitimpact[\"ddg\"] = pd.merge(\n",
    "    mitimpact[[\"Gene_symbol\",\"AA_pos\",\"AA_ref\",\"AA_alt\"]],\n",
    "    ddg,\n",
    "    on=[\"Gene_symbol\",\"AA_pos\",\"AA_ref\",\"AA_alt\"],\n",
    "    how=\"left\",\n",
    ").ddg.to_numpy()\n",
    "\n",
    "mitimpact[[\"ddg\"]].head()"
   ]
  },
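  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The annotation above writes the merged column back by position (`to_numpy()`), which is only safe when the left merge returns exactly one row per variant, in the original order. Below is a hedged sketch of that assumption on toy data. The frames `left` and `right` are hypothetical, and `validate='many_to_one'` is an optional pandas check that the notebook itself does not use:\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "# Toy variants (index = variant id) and a toy annotation table.\n",
    "left = pd.DataFrame(\n",
    "    {'Gene_symbol': ['G1', 'G2', 'G1'], 'AA_pos': [10, 5, 11]},\n",
    "    index=['v1', 'v2', 'v3'],\n",
    ")\n",
    "right = pd.DataFrame(\n",
    "    {'Gene_symbol': ['G1', 'G2'], 'AA_pos': [11, 5], 'value': [0.7, -1.2]}\n",
    ")\n",
    "\n",
    "# A left merge keeps the order of the left rows; validate raises an error\n",
    "# if a key is duplicated in `right`, which would add extra rows.\n",
    "merged = pd.merge(left, right, on=['Gene_symbol', 'AA_pos'],\n",
    "                  how='left', validate='many_to_one')\n",
    "assert len(merged) == len(left)\n",
    "\n",
    "# The merged frame has a fresh 0..n-1 index, so assign by position.\n",
    "left['value'] = merged['value'].to_numpy()\n",
    "print(left)\n",
    "```\n",
    "\n",
    "Variant `v1` has no match in `right`, so it receives `NaN`, which mirrors how unannotated variants end up with missing values."
   ]
  },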
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Read the 3D coordinates of the amino acids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "coords3D = pd.read_csv(PATH+\"/data/APOGEE2_2022/AA3Dposition_aligned.csv\", sep=\"\\t\", header=None)\n",
    "\n",
    "# Rename the columns appropriately\n",
    "coords3D.columns=[\"Gene_symbol\", \"AA_pos\", \"X\", \"Y\", \"Z\"]\n",
    "\n",
    "# Annotate the variants in MitImpact with X, Y, and Z values corresponding to the 3D coordinates of the amino acids.\n",
    "mitimpact[[\"X\", \"Y\", \"Z\"]] = pd.merge(\n",
    "    mitimpact[[\"Gene_symbol\",\"AA_pos\"]],\n",
    "    coords3D,\n",
    "    on=[\"Gene_symbol\",\"AA_pos\"],\n",
    "    how=\"left\",\n",
    ")[[\"X\", \"Y\", \"Z\"]].to_numpy()\n",
    "\n",
    "mitimpact[[\"X\", \"Y\", \"Z\"]].head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Create of a list containing the names of features to extract from the `mitimpact` dataframe"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "features = [\n",
    "    \"PhyloP_100V\", \"PhastCons_100V\", \"PolyPhen2_score\", \"SIFT_score\", \"FatHmmW_score\",\n",
    "    \"FatHmm_score\", \"PROVEAN_score\", \"MutationAssessor_score\", \"EFIN_SP_score\", \"EFIN_HD_score\",\n",
    "    \"CADD_phred_score\", \"VEST_pvalue\", \"PANTHER_score\", \"PhD-SNP_score\", \"SNAP_score\", \"MutationTaster_score\",\n",
    "    \"mtmam\", \"ddg\", \"Pos\",  \"X\", \"Y\", \"Z\"\n",
    "]\n",
    "\n",
    "# Cast the variables into float values.\n",
    "mitimpact[features] = mitimpact[features].astype(float)\n",
    "\n",
    "# Save the dataframe obtained.\n",
    "try:\n",
    "  os.mkdir(PATH+\"/data/APOGEE2_2022/extracted_data/\")\n",
    "except FileExistsError:\n",
    "  print(\"The folder exists. Data will be overwritten\")\n",
    "  \n",
    "mitimpact[features].to_csv(PATH+\"/data/APOGEE2_2022/extracted_data/mitimpact_features.csv\", sep=\"\\t\")\n",
    "mitimpact[features].head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Second part"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Handy functions and files check"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import shutil\n",
    "\n",
    "def _empty_or_create(folder_path):\n",
    "  if(os.path.exists(folder_path)):\n",
    "    warnings.warn(folder_path+\" is not empty. It will be emptied\")\n",
    "    shutil.rmtree(folder_path)\n",
    "  \n",
    "  os.mkdir(folder_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Existence file check\n",
    "if not 'dataset_21-04-21.csv' in os.listdir(PATH+\"/data/APOGEE2_2022/\"):\n",
    "  raise FileNotFoundError(f\"Cannot find 'dataset_21-04-21.csv' in folder {PATH}/data/APOGEE2_2022/\")\n",
    "elif not 'mitimpact_features.csv' in os.listdir(PATH+\"/data/APOGEE2_2022/extracted_data\"):\n",
    "  raise FileNotFoundError(f\"Cannot find 'mitimpact_features.csv' in folder {PATH}/data/APOGEE2_2022/. Please ensure you have executed the first part of the code above\")\n",
    "else:\n",
    "    print(\"Required files present\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Read the classes of training-set variants:\n",
    "- `0` means Benign variants;\n",
    "- `1` means Pathogenic variants.\n",
    "\n",
    "The index values are the `MitImpact_ID` of the variants."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Y = pd.read_csv(PATH+\"/data/APOGEE2_2022/dataset_21-04-21.csv\", sep=\"\\t\", index_col=\"MitImpact_ID\", comment=\"#\").Class\n",
    "Y.replace([\"N\", \"P\"], [0, 1], inplace=True)\n",
    "\n",
    "# Target Vector\n",
    "Y.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Annotate the training-set variants (`Y` vector) with the features extracted from the `mitimpact_features.csv` file (previously generated in the **First part**)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X = pd.read_csv(PATH+\"/data/APOGEE2_2022/extracted_data/mitimpact_features.csv\", sep=\"\\t\", index_col=\"MitImpact_id\").loc[Y.index]\n",
    "\n",
    "# Feature Matrix\n",
    "X.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ML Settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#@title ##Settings\n",
    "\n",
    "#@markdown <h2>**SETTINGS**</h2>\n",
    "\n",
    "#@markdown\n",
    "\n",
    "\n",
    "#@markdown choose the machine learning classifier\n",
    "method = 'KNN_RusSmote' #@param [\"BalancedRF\", \"GNB_BalancedBagging\", \"KNN_BalancedBagging\", \"KNN_RusSmote\", \"rbfSVC\", \"RusSmoteForest\"]\n",
    "\n",
    "#@markdown\n",
    "\n",
    "#@markdown select the number of splits for the test cross-validation (outer cv splits)\n",
    "n_splits = 10 #@param {type:\"slider\", min:5, max:20, step:1}\n",
    "\n",
    "#@markdown\n",
    "\n",
    "#@markdown select the number of different partitions for the test cross-validation (outer cv repetitions)\n",
    "n_repeats = 2 #@param {type:\"slider\", min:1, max:10, step:1}\n",
    "\n",
    "#@markdown <h7>**Attention:**</h7>\n",
    "#@markdown <h7><i>For computational limit, please ensure that `n_splits * n_repeats` is not greater than 60</i></h7>\n",
    "\n",
    "#@markdown\n",
    "\n",
    "#@markdown select the number of splits for the grid-search cross-validation (inner cv splits)\n",
    "gridsearc_cv_splits = 5 #@param {type:\"slider\", min:5, max:19, step:1}\n",
    "\n",
    "\n",
    "#@markdown\n",
    "\n",
    "#@markdown type an integer number as random state (used for the partitions in both inner and outer cv)\n",
    "random_state = 118 #@param {type:\"integer\"}\n",
    "\n",
    "kf = RepeatedStratifiedKFold(n_splits=n_splits, n_repeats=n_repeats, random_state=random_state)\n",
    "\n",
    "\n",
    "#@markdown\n",
    "\n",
    "#@markdown check if you want to have real-time feedbacks during the training, validation and test (recommended for debugging)\n",
    "verbose = True #@param {type:\"boolean\"}\n",
    "\n",
    "# Define the three sequential preprocessing steps: standardize the features, impute the missing values, remove the low importance features.\n",
    "preprocessing_pipeline = Pipeline([\n",
    "    (\"scaler\", StandardScaler()),\n",
    "    (\"imputer\", IterativeImputer(RandomForestRegressor(n_jobs=-1), tol=0.05, verbose=verbose)),\n",
    "    (\"feature_selection\", SelectFromModel(DecisionTreeClassifier(criterion=\"entropy\", class_weight=\"balanced\"), threshold=0.01)),\n",
    "])\n",
    "\n",
    "\n",
    "# Initialize the classifier according to the settings\n",
    "# In addition hyper-parameters grid is initialized according to the `method` chosen in the settings\n",
    "def init_model_and_params():\n",
    "  global clf, grid\n",
    "\n",
    "  if method==\"rbfSVC\":\n",
    "    from sklearn.svm import SVC\n",
    "    clf = SVC(class_weight=\"balanced\", kernel='rbf', probability=True)\n",
    "    grid = ParameterGrid({\n",
    "        \"C\": [[100], [10], [1], [0.1]],\n",
    "        \"gamma\": [[0.1], [0.01], [0.001], [0.0001]],\n",
    "    })\n",
    "    return\n",
    "  \n",
    "  if method==\"KNN_RusSmote\":\n",
    "    from sklearn.ensemble import BaggingClassifier\n",
    "    from sklearn.neighbors import KNeighborsClassifier\n",
    "    from imblearn.under_sampling import RandomUnderSampler\n",
    "    from imblearn.over_sampling import SMOTE\n",
    "    clf = BaggingClassifier(\n",
    "        base_estimator=Pipeline([\n",
    "            (\"rus\", RandomUnderSampler()),\n",
    "            (\"smote\", SMOTE(n_jobs=-1)),\n",
    "            (\"clf\", KNeighborsClassifier(n_jobs=-1))\n",
    "        ]),\n",
    "        n_estimators=200,\n",
    "        n_jobs=-1,\n",
    "        verbose=0\n",
    "    )\n",
    "    grid = ParameterGrid({\n",
    "        \"base_estimator__clf__n_neighbors\": [[3], [5], [7], [9]],\n",
    "        \"base_estimator__rus__sampling_strategy\": [[0.25], [0.5]],\n",
    "        \"max_features\": [[0.25], [0.5]],\n",
    "        \"base_estimator__clf__weights\": [[\"uniform\"], [\"distance\"]],\n",
    "    })\n",
    "    return\n",
    "  \n",
    "  if method==\"GNB_BalancedBagging\":\n",
    "    \n",
    "    from imblearn.ensemble import BalancedBaggingClassifier\n",
    "    from sklearn.naive_bayes import GaussianNB\n",
    "    clf = BalancedBaggingClassifier(GaussianNB(), n_estimators=200, n_jobs=-1)\n",
    "    grid = ParameterGrid({\n",
    "        \"base_estimator__var_smoothing\": [[1e-07], [1e-08], [1e-09], [1e-10], [1e-11]],\n",
    "        \"max_features\": [[0.25], [0.5], [0.75]],\n",
    "        \"max_samples\": [[0.25], [0.5], [0.75]],\n",
    "    })\n",
    "    return\n",
    "\n",
    "  if method==\"BalancedRF\":\n",
    "    from imblearn.ensemble import BalancedRandomForestClassifier\n",
    "    clf = BalancedRandomForestClassifier(n_estimators=200, criterion=\"entropy\", n_jobs=-1)\n",
    "    grid = ParameterGrid({\n",
    "        \"max_depth\": [[7], [9], [11], [13]],\n",
    "        \"max_features\": [[0.25], [0.5], [0.75]],\n",
    "        \"min_samples_leaf\": [[1], [2], [3]],\n",
    "        \"min_samples_split\": [[2], [4]],\n",
    "    })\n",
    "    return\n",
    "  \n",
    "  if method==\"RusSmoteForest\":\n",
    "    from sklearn.ensemble import BaggingClassifier\n",
    "    from imblearn.under_sampling import RandomUnderSampler\n",
    "    from imblearn.over_sampling import SMOTE\n",
    "    clf = BaggingClassifier(\n",
    "        base_estimator=Pipeline([\n",
    "            (\"rus\", RandomUnderSampler()),\n",
    "            (\"smote\", SMOTE(n_jobs=-1)),\n",
    "            (\"clf\", DecisionTreeClassifier(criterion=\"entropy\"))\n",
    "        ]),\n",
    "        n_estimators=200,\n",
    "        n_jobs=-1,\n",
    "        verbose=0\n",
    "    )\n",
    "    grid = ParameterGrid({\n",
    "        \"base_estimator__clf__max_depth\": [[11], [15], [19]],\n",
    "        \"base_estimator__clf__min_samples_leaf\": [[1], [2], [3]],\n",
    "        \"base_estimator__clf__min_samples_split\": [[2], [4]],\n",
    "        \"base_estimator__rus__sampling_strategy\": [[0.25], [0.5], [0.75]],\n",
    "        \"base_estimator__clf__max_features\": [[0.25], [0.5], [0.75]],\n",
    "        \n",
    "    })\n",
    "    return\n",
    "\n",
    "\n",
    "  if method==\"KNN_BalancedBagging\":\n",
    "    from imblearn.ensemble import BalancedBaggingClassifier\n",
    "    from sklearn.neighbors import KNeighborsClassifier\n",
    "    clf = BalancedBaggingClassifier(KNeighborsClassifier(), n_estimators=200, n_jobs=-1)\n",
    "    grid = ParameterGrid({\n",
    "        \"base_estimator__n_neighbors\": [[3], [5], [7], [9]],\n",
    "        \"max_features\": [[0.25], [0.5], [0.75]],\n",
    "        \"base_estimator__weights\": [[\"distance\"]],\n",
    "        \"max_samples\": [[0.25], [0.5], [0.75]],\n",
    "    })\n",
    "    return\n",
    "    \n",
    "  raise ValueError(\"%s is not a valid method\"%method)\n",
    "\n",
    "\n",
    "init_model_and_params()\n",
    "\n",
    "# Define the grid-search object according to the classifier and hyper-paramethers grid\n",
    "gridsearch_clf = GridSearchCV(\n",
    "    estimator=clf,\n",
    "    param_grid=grid,\n",
    "    scoring=\"roc_auc\",\n",
    "    cv=StratifiedKFold(gridsearc_cv_splits, shuffle=True, random_state=random_state),\n",
    "    n_jobs=-1,\n",
    "    error_score=\"raise\",\n",
    "    verbose=verbose,\n",
    ")\n",
    "\n",
    "if n_repeats*n_splits > 60:\n",
    "  raise ValueError(\"`n_repeats*n_splits` expected to be <=60, but %s * %s == %s.\\nI now rise this `ValueError` in order to avoid a future `IOError`.\"%(n_repeats, n_splits, n_repeats*n_splits))"
   ]
  },
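  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The grids above wrap every candidate value in its own one-element list. Below is a small sketch of what iterating over such a `ParameterGrid` yields. The parameter names mirror the `rbfSVC` grid, but the values are shortened for illustration:\n",
    "\n",
    "```python\n",
    "from sklearn.model_selection import ParameterGrid\n",
    "\n",
    "# Every candidate value is wrapped in a one-element list.\n",
    "grid = ParameterGrid({'C': [[1], [10]], 'gamma': [[0.1], [0.01]]})\n",
    "\n",
    "# Iterating yields one dict per combination; each dict maps a parameter\n",
    "# to a one-element list, i.e. it is itself a grid with a single point.\n",
    "combos = list(grid)\n",
    "print(len(combos))\n",
    "for combo in combos:\n",
    "    print(combo)\n",
    "```\n",
    "\n",
    "Since `GridSearchCV` documents `param_grid` as a dict or a list of dicts, each element of `combos` can serve as a one-point grid, and together they cover the same combinations as the usual `{'C': [1, 10], 'gamma': [0.1, 0.01]}` form."
   ]
  },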
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#@markdown Folder where data from test will be stored.\n",
    "TEST_PATH = PATH + \"/data/APOGEE2_2022/test\" #@param {type:\"string\"}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For each cross-validation iteration:\n",
    "1. fit the three sequential preprocessing steps on the training-set;\n",
    "2. save the fitted preprocessing models.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "_empty_or_create(TEST_PATH)\n",
    "_empty_or_create(TEST_PATH+\"/preprocessing/\")\n",
    "\n",
    "for i, (train, test) in enumerate(kf.split(X, Y)):\n",
    "    preprocessing_pipeline.fit(X.iloc[train], Y[train])\n",
    "    pk.dump(preprocessing_pipeline, open(TEST_PATH+\"/preprocessing/fold_%i.pk\"%i, \"wb\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following code:\n",
    "1. loads the fitted preprocessing steps\n",
    "2. preprocesses both training and test sets\n",
    "3. tunes and trains the classifier on the pre-processed training set\n",
    "4. saves the grid-search (as `GridSearchCV` object) results in the `test/classifier` folder\n",
    "5. predicts the scores for the pre-processed test set elements\n",
    "6. saves the predictions in the `test/predictions` folder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_on_fold(i, train, test):\n",
    "  fitted_preprocessing = pk.load(open(TEST_PATH+\"/preprocessing/fold_%i.pk\"%i, \"rb\"))\n",
    "  X_train = fitted_preprocessing.transform(X.iloc[train])\n",
    "  X_test = fitted_preprocessing.transform(X.iloc[test])\n",
    "  tuned_classifier = clone(gridsearch_clf).fit(X_train, Y[train])\n",
    "  pk.dump(tuned_classifier, open(TEST_PATH+\"/classifier/fold_%i.pk\"%i, \"wb\"))\n",
    "  pk.dump(tuned_classifier.predict_proba(X_test)[:,1], open(TEST_PATH+\"/predictions/fold_%i.pk\"%(i), \"wb\"))\n",
    "\n",
    "_empty_or_create(TEST_PATH+\"/classifier/\")\n",
    "_empty_or_create(TEST_PATH+\"/predictions/\")\n",
    "\n",
    "# For each cross-validation iteration train the classifier as explained in the `test_on_fold` function\n",
    "for i, (train, test) in enumerate(kf.split(X, Y)):\n",
    "  print(f\"############################## TEST {i} ################################\")\n",
    "  test_on_fold(i, train, test)    "
   ]
  },
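  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For readers new to nested cross-validation, the loop above can be sketched on synthetic data as follows. This is an illustration under stated assumptions, not the notebook's pipeline: the data come from `make_classification`, the classifier is a plain k-nearest-neighbours model, and the scaler is refit inside the inner folds rather than once per outer fold. All names (`X_toy`, `y_toy`, `outer_cv`, `inner_cv`) are hypothetical:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from sklearn.datasets import make_classification\n",
    "from sklearn.metrics import roc_auc_score\n",
    "from sklearn.model_selection import GridSearchCV, StratifiedKFold\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "from sklearn.pipeline import make_pipeline\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "# Synthetic, imbalanced binary problem (about 20% positives).\n",
    "X_toy, y_toy = make_classification(\n",
    "    n_samples=200, n_features=8, weights=[0.8, 0.2], random_state=0\n",
    ")\n",
    "\n",
    "outer_cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)\n",
    "inner_cv = StratifiedKFold(n_splits=3, shuffle=True, random_state=0)\n",
    "\n",
    "outer_aucs = []\n",
    "for train_idx, test_idx in outer_cv.split(X_toy, y_toy):\n",
    "    # Inner loop: tune hyper-parameters using the outer training fold only.\n",
    "    search = GridSearchCV(\n",
    "        make_pipeline(StandardScaler(), KNeighborsClassifier()),\n",
    "        param_grid={'kneighborsclassifier__n_neighbors': [3, 5, 7]},\n",
    "        scoring='roc_auc',\n",
    "        cv=inner_cv,\n",
    "    )\n",
    "    search.fit(X_toy[train_idx], y_toy[train_idx])\n",
    "\n",
    "    # Outer loop: score the refitted best model on the held-out fold.\n",
    "    scores = search.predict_proba(X_toy[test_idx])[:, 1]\n",
    "    outer_aucs.append(roc_auc_score(y_toy[test_idx], scores))\n",
    "\n",
    "outer_aucs = np.array(outer_aucs)\n",
    "print(outer_aucs.mean())\n",
    "```\n",
    "\n",
    "The held-out fold never takes part in tuning, so the outer AUCs estimate the performance of the whole tune-then-train procedure."
   ]
  },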
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Compute Area Under the Receiver Operating Characteristic Curve (auROC) from prediction scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Y_pred = [pk.load(open(TEST_PATH+\"/predictions/fold_%i.pk\"%i, \"rb\")) for i in range(kf.get_n_splits())]\n",
    "Y_true = [Y[test] for _, test in kf.split(X, Y)]\n",
    "len(Y_pred), len(Y_true)\n",
    "\n",
    "AUCs = np.array([roc_auc_score(t, y) for t, y in zip(Y_true, Y_pred)])\n",
    "plt.boxplot(AUCs)\n",
    "plt.ylim(.5,1)\n",
    "plt.xticks([])\n",
    "plt.ylabel(\"test AUC\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Once the significance level (alpha) has been set, compute and show the mean AUC and its confidence interval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#@markdown Significance level:\n",
    "\n",
    "alpha=0.95 #@param {type:\"slider\", min:0.5, max:0.995, step:0.005}\n",
    "\n",
    "AUC_mean_distribution = stats.norm(loc=AUCs.mean(), scale=AUCs.std(ddof=1)/np.sqrt(len(Y_pred)))\n",
    "CI = AUC_mean_distribution.interval(alpha)\n",
    "_bins = np.linspace(.5,1, 1001)\n",
    "plt.plot(_bins, AUC_mean_distribution.pdf(_bins))\n",
    "plt.vlines(\n",
    "  AUCs.mean(), 0, AUC_mean_distribution.pdf(AUCs.mean()),\n",
    "  linestyle=\"--\", label=\"mean AUC = %.4f\"%AUCs.mean()\n",
    ")\n",
    "plt.fill_between(\n",
    "  _bins[(_bins>CI[0])&(_bins<CI[1])],\n",
    "  AUC_mean_distribution.pdf(_bins[(_bins>CI[0])&(_bins<CI[1])]),\n",
    "  alpha=.25, label=\"CI = (%.4f, %.4f)\"%(CI[0], CI[1])\n",
    ")\n",
    "plt.xlim(AUCs.min(), 1)\n",
    "plt.xlabel(\"mean AUC\")\n",
    "plt.ylabel(\"PDF\")\n",
    "plt.ylim(0,None)\n",
    "\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
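  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The interval computed above is a normal approximation centred on the mean AUC, with the standard error of the mean as its scale. Below is a worked sketch with made-up AUC values; `aucs` is hypothetical and does not come from any real run:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from scipy import stats\n",
    "\n",
    "# Hypothetical per-fold AUCs, for illustration only.\n",
    "aucs = np.array([0.90, 0.92, 0.88, 0.91, 0.89])\n",
    "\n",
    "# Standard error of the mean, using the sample standard deviation (ddof=1).\n",
    "sem = aucs.std(ddof=1) / np.sqrt(len(aucs))\n",
    "\n",
    "# interval() takes a confidence level (0.95), not a significance level (0.05).\n",
    "low, high = stats.norm(loc=aucs.mean(), scale=sem).interval(0.95)\n",
    "print(aucs.mean(), low, high)\n",
    "```\n",
    "\n",
    "The bounds are symmetric around the mean and lie about 1.96 standard errors away from it. With few folds, a Student's t interval would be wider; this sketch keeps the normal form to match the cell above."
   ]
  },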
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Performance comparison with other meta-predictors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "external_predictors = [\"Meta-SNP_score\",\"CAROL_score\",\"Condel_score\",\"COVEC_WMV_score\",\"MtoolBox_DS\",\"APOGEE_score\"]\n",
    "thresholds = np.array([.5, .98, .5, 0, .43, .5])\n",
    "len(external_predictors), len(thresholds)\n",
    "\n",
    "extarnal_tools_predictions = mitimpact.loc[Y.index, external_predictors].replace(\".\", np.nan).astype(float)\n",
    "extarnal_tools_predictions[\"Condel_score\"] = -extarnal_tools_predictions[\"Condel_score\"]\n",
    "\n",
    "extarnal_tools_predictions.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Compare the mean ROCs of APOGEE 2 vs. other meta-predictors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with plt.style.context('bmh'):\n",
    "    plt.figure(figsize=(10,10))\n",
    "\n",
    "    fpr_bins = np.linspace(0,1,51)\n",
    "\n",
    "    for predictor in external_predictors:\n",
    "      fpr, tpr, _ = roc_curve(Y[~extarnal_tools_predictions[predictor].isna()], extarnal_tools_predictions[predictor].dropna())\n",
    "      plt.plot(\n",
    "        [0]+list(fpr), [0]+list(tpr),\n",
    "        linewidth=2, alpha=.75, linestyle=\"--\",\n",
    "        label=\"%s (mean AUC=%.4f)\"%(predictor, roc_auc_score(Y[~extarnal_tools_predictions[predictor].isna()], extarnal_tools_predictions[predictor].dropna()))\n",
    "      )\n",
    "    \n",
    "\n",
    "    pr_bins = np.linspace(0,1,51)\n",
    "\n",
    "    tprs = []\n",
    "    for y_true, y_pred in zip(Y_true, Y_pred): \n",
    "        fpr, tpr, _ = roc_curve(y_true, y_pred)\n",
    "        tprs.append(np.interp(fpr_bins, fpr, tpr))\n",
    "    tprs = np.array(tprs)\n",
    "    plt.plot(\n",
    "        [0]+list(fpr_bins), [0]+list(tprs.mean(0)),\n",
    "        linewidth=3, alpha=.75, color=\"k\",\n",
    "        label=\"APOGEE2_score (mean AUC=%.4f)\"%(AUCs.mean()))\n",
    "    \n",
    "    plt.plot([0,1], [0,1], \"k:\", alpha=.5, label=\"chance\")\n",
    "\n",
    "    plt.xticks(np.linspace(0,1,11))\n",
    "    plt.yticks(np.linspace(0,1,11))\n",
    "    plt.legend(loc=\"lower right\")\n",
    "    plt.ylim(-.025,1.025)\n",
    "    plt.xlim(-.025,1.025)\n",
    "    plt.xlabel(\"false positive rate\")\n",
    "    plt.ylabel(\"true positive rate\")\n",
    "    plt.legend()\n",
    "    plt.show()"
   ]
  },
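  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The mean ROC curve above is obtained by interpolating each fold's curve onto a shared grid of false positive rates and averaging the results. Below is a minimal sketch with random toy scores. The labels, the scores, and the names `fpr_grid` and `interp_tprs` are assumptions for illustration:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from sklearn.metrics import roc_curve\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "fpr_grid = np.linspace(0, 1, 51)  # shared x-axis for all folds\n",
    "\n",
    "interp_tprs = []\n",
    "for fold in range(3):\n",
    "    # Made-up labels and noisy scores standing in for one test fold.\n",
    "    y_true = np.array([0] * 30 + [1] * 10)\n",
    "    y_score = 0.5 * y_true + rng.normal(size=y_true.size)\n",
    "    fpr, tpr, _ = roc_curve(y_true, y_score)\n",
    "    # Resample this fold's curve at the shared false positive rates.\n",
    "    interp_tprs.append(np.interp(fpr_grid, fpr, tpr))\n",
    "\n",
    "# Point-wise average of the resampled curves.\n",
    "mean_tpr = np.mean(interp_tprs, axis=0)\n",
    "print(mean_tpr.shape)\n",
    "```\n",
    "\n",
    "Averaging on a common grid is needed because each fold's `roc_curve` output has its own set of thresholds, and therefore its own false positive rates."
   ]
  },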
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Compare the mean PRCs of APOGEE 2 vs. other meta-predictors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with plt.style.context('bmh'):\n",
    "    plt.figure(figsize=(10,10))\n",
    "\n",
    "    recall_bins = np.linspace(0,1,51)\n",
    "\n",
    "    for predictor in external_predictors:\n",
    "\n",
    "      precision,recall, thresholds = precision_recall_curve(Y[~extarnal_tools_predictions[predictor].isna()], extarnal_tools_predictions[predictor].dropna())\n",
    "      avarage_PS=average_precision_score(Y[~extarnal_tools_predictions[predictor].isna()], extarnal_tools_predictions[predictor].dropna())\n",
    "      plt.plot(\n",
    "        [0]+list(recall), [1]+list(precision),\n",
    "        linewidth=2, alpha=.75, linestyle=\"--\",\n",
    "        label=\"%s (average precision=%.4f)\"%(predictor,avarage_PS) \n",
    "      )\n",
    "\n",
    "##########\n",
    "    precision_apo2, recall_apo2, _ = precision_recall_curve(np.concatenate(Y_true), np.concatenate(Y_pred))\n",
    "    avarage_PS_apo2 = average_precision_score(np.concatenate(Y_true), np.concatenate(Y_pred))\n",
    "\n",
    "    plt.plot(\n",
    "        list(recall_apo2), list(precision_apo2),\n",
    "        linewidth=3, alpha=.75, color=\"k\",\n",
    "        label=\"APOGEE2_score (average precision=%.4f)\"%(avarage_PS_apo2))\n",
    "    \n",
    "    plt.xticks(np.linspace(0,1,11))\n",
    "    plt.yticks(np.linspace(0,1,11))\n",
    "    plt.legend(loc=\"lower right\")\n",
    "    plt.ylim(-.025,1.025)\n",
    "    plt.xlim(-.025,1.025)\n",
    "    plt.xlabel(\"Recall\")\n",
    "    plt.ylabel(\"Precision\")\n",
    "    plt.legend()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Print system and required packages information"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext watermark\n",
    "%watermark -v -m -p numpy,pandas,matplotlib,sklearn,scipy\n",
    "\n",
    "# date\n",
    "print(\" \")\n",
    "%watermark -u -n -t -z"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.13 ('playgrounds')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "1bfe35eb01be7a105c8cf53de6ee0282833865f0d560b2cfa320b64389c32503"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}