{ "cells": [ { "cell_type": "markdown", "id": "f7efc033", "metadata": {}, "source": [ "## License \n", "\n", "Copyright 2021-2023 Patrick Hall (jphall@gwu.edu)\n", "\n", "Licensed under the Apache License, Version 2.0 (the \"License\");\n", "you may not use this file except in compliance with the License.\n", "You may obtain a copy of the License at\n", "\n", " http://www.apache.org/licenses/LICENSE-2.0\n", "\n", "Unless required by applicable law or agreed to in writing, software\n", "distributed under the License is distributed on an \"AS IS\" BASIS,\n", "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "See the License for the specific language governing permissions and\n", "limitations under the License.\n", "\n", "*DISCLAIMER*: This notebook is not legal or compliance advice." ] }, { "cell_type": "markdown", "id": "aab60b41", "metadata": {}, "source": [ "# Model Evaluation Notebook" ] }, { "cell_type": "markdown", "id": "281af306", "metadata": {}, "source": [ "#### Imports and inits" ] }, { "cell_type": "code", "execution_count": 1, "id": "fd180587", "metadata": {}, "outputs": [], "source": [ "import os # for directory and file manipulation\n", "import numpy as np # for basic array manipulation\n", "import pandas as pd # for dataframe manipulation\n", "import datetime # for timestamp\n", "\n", "# for model eval\n", "from sklearn.metrics import accuracy_score, f1_score, log_loss, mean_squared_error, roc_auc_score\n", "\n", "# global constants \n", "ROUND = 3 # generally, insane precision is not needed \n", "SEED = 12345 # seed for better reproducibility\n", "\n", "# set global random seed for better reproducibility\n", "np.random.seed(SEED)" ] }, { "cell_type": "markdown", "id": "eb2a39d4", "metadata": {}, "source": [ "#### Set basic metadata" ] }, { "cell_type": "code", "execution_count": 2, "id": "98f640ed", "metadata": {}, "outputs": [], "source": [ "y_name = 'high_priced'\n", "scores_dir = 'data/scores'" ] }, { "cell_type": 
"markdown", "id": "cc8d83d0", "metadata": {}, "source": [ "#### Read in score files " ] }, { "cell_type": "code", "execution_count": 3, "id": "355c2b81", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
high_pricedfoldgroup1_rem_ebmgroup2_rem_ebmgroup2_rem_ebm2group3_rem_piml_EBMgroup3_rem_piml_EBM2group5_rem_xgb2group8_rem_ebmgroup9_rem_xgbph_rem_ebm
00.020.1187870.0805570.0805570.9203890.1367490.0783260.2238460.0817920.219429
10.010.0845060.0260010.0260010.9693010.0537510.0358250.0539260.1107020.053929
21.040.2103890.1949610.1949610.8142720.1823110.1953320.1435220.2040480.133863
30.010.0085290.0285560.0285560.9745590.0040650.0227650.0093710.0240380.014419
41.020.1899330.2082630.2082630.8029080.2111200.1930350.1511000.1702430.156047
....................................
198260.030.1636970.2283420.2283420.7922510.2093220.2351920.2167200.1814030.184214
198270.010.1149990.2539980.2539980.7629460.2067440.2358320.1614010.1594680.141663
198281.030.1413070.2133640.2133640.7474010.2466100.2087230.2428140.1381410.233266
198290.010.0077660.0021760.0021760.9964550.0002680.0187020.0056570.0345700.009914
198300.000.1639460.1854840.1854840.8114290.1778570.2150850.1678120.1777850.155447
\n", "

19831 rows × 11 columns

\n", "
" ], "text/plain": [ " high_priced fold group1_rem_ebm group2_rem_ebm group2_rem_ebm2 \\\n", "0 0.0 2 0.118787 0.080557 0.080557 \n", "1 0.0 1 0.084506 0.026001 0.026001 \n", "2 1.0 4 0.210389 0.194961 0.194961 \n", "3 0.0 1 0.008529 0.028556 0.028556 \n", "4 1.0 2 0.189933 0.208263 0.208263 \n", "... ... ... ... ... ... \n", "19826 0.0 3 0.163697 0.228342 0.228342 \n", "19827 0.0 1 0.114999 0.253998 0.253998 \n", "19828 1.0 3 0.141307 0.213364 0.213364 \n", "19829 0.0 1 0.007766 0.002176 0.002176 \n", "19830 0.0 0 0.163946 0.185484 0.185484 \n", "\n", " group3_rem_piml_EBM group3_rem_piml_EBM2 group5_rem_xgb2 \\\n", "0 0.920389 0.136749 0.078326 \n", "1 0.969301 0.053751 0.035825 \n", "2 0.814272 0.182311 0.195332 \n", "3 0.974559 0.004065 0.022765 \n", "4 0.802908 0.211120 0.193035 \n", "... ... ... ... \n", "19826 0.792251 0.209322 0.235192 \n", "19827 0.762946 0.206744 0.235832 \n", "19828 0.747401 0.246610 0.208723 \n", "19829 0.996455 0.000268 0.018702 \n", "19830 0.811429 0.177857 0.215085 \n", "\n", " group8_rem_ebm group9_rem_xgb ph_rem_ebm \n", "0 0.223846 0.081792 0.219429 \n", "1 0.053926 0.110702 0.053929 \n", "2 0.143522 0.204048 0.133863 \n", "3 0.009371 0.024038 0.014419 \n", "4 0.151100 0.170243 0.156047 \n", "... ... ... ... 
# known test y values; the CSV's 'Unnamed: 0' column is used as the index
scores_frame = pd.read_csv(os.path.join(scores_dir, 'key.csv'), index_col='Unnamed: 0')

# re-seed right before sampling so the fold assignment is reproducible
np.random.seed(SEED)
scores_frame['fold'] = np.random.choice(5, scores_frame.shape[0])

# every CSV in the directory other than the key holds one model's scores
score_files = [name for name in sorted(os.listdir(scores_dir))
               if name.endswith('.csv') and name != 'key.csv']

# add the 'phat' column of each score file as a new column named after the file
for score_file in score_files:
    model_name = os.path.splitext(score_file)[0]
    scores_frame[model_name] = pd.read_csv(os.path.join(scores_dir, score_file))['phat']

# sanity check
scores_frame
def max_acc(y, phat, res=0.01):

    """ Utility function for finding max. accuracy at some cutoff.

    A score strictly greater than the cutoff is treated as a decision of 1,
    otherwise 0. Accuracy is evaluated at every cutoff in
    np.arange(0, 1 + res, res) and the largest value is returned.

    :param y: Known y values (pandas Series, as passed by this notebook).
    :param phat: Model scores (pandas Series, as passed by this notebook).
    :param res: Resolution over which to search for max. accuracy, default 0.01.
    :return: Max. accuracy for model scores.
    :raises ValueError: If res is not positive.

    """

    # a non-positive step cannot produce a usable grid of cutoffs
    if res <= 0:
        raise ValueError('res must be positive, got %r' % (res,))

    # pair known y and score values column-wise so pandas aligns them on index
    paired = pd.concat([y, phat], axis=1)
    y_true = paired.iloc[:, 0].to_numpy()
    scores = paired.iloc[:, 1].to_numpy()

    # find accuracy at each cutoff; results are collected in a plain list
    # because DataFrame.append was removed in pandas 2.0
    acc_by_cut = [accuracy_score(y_true, np.where(scores > cut, 1, 0))
                  for cut in np.arange(0, 1 + res, res)]

    # max accuracy across all cutoffs
    return float(max(acc_by_cut))
def max_f1(y, phat, res=0.01):

    """ Utility function for finding max. F1 at some cutoff.

    A score strictly greater than the cutoff is treated as a decision of 1,
    otherwise 0. F1 is evaluated at every cutoff in
    np.arange(0, 1 + res, res) and the largest value is returned.

    :param y: Known y values (pandas Series, as passed by this notebook).
    :param phat: Model scores (pandas Series, as passed by this notebook).
    :param res: Resolution over which to search for max. F1, default 0.01.
    :return: Max. F1 for model scores.
    :raises ValueError: If res is not positive.

    """

    # a non-positive step cannot produce a usable grid of cutoffs
    if res <= 0:
        raise ValueError('res must be positive, got %r' % (res,))

    # pair known y and score values column-wise so pandas aligns them on index
    paired = pd.concat([y, phat], axis=1)
    y_true = paired.iloc[:, 0].to_numpy()
    scores = paired.iloc[:, 1].to_numpy()

    # find f1 at each cutoff; results are collected in a plain list
    # because DataFrame.append was removed in pandas 2.0
    f1_by_cut = [f1_score(y_true, np.where(scores > cut, 1, 0))
                 for cut in np.arange(0, 1 + res, res)]

    # max f1 across all cutoffs
    return float(max(f1_by_cut))
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
foldmetricgroup1_rem_ebmgroup2_rem_ebmgroup2_rem_ebm2group3_rem_piml_EBMgroup3_rem_piml_EBM2group5_rem_xgb2group8_rem_ebmgroup9_rem_xgbph_rem_ebmgroup1_rem_ebm_rankgroup2_rem_ebm_rankgroup2_rem_ebm2_rankgroup3_rem_piml_EBM_rankgroup3_rem_piml_EBM2_rankgroup5_rem_xgb2_rankgroup8_rem_ebm_rankgroup9_rem_xgb_rankph_rem_ebm_rank
00.0acc0.9000.9010.9010.9000.9010.9010.9010.9000.9018.03.53.58.03.53.53.58.03.5
10.0auc0.7810.8400.8400.1630.8210.8360.7930.7970.7918.01.51.59.04.03.06.05.07.0
20.0f10.3470.4050.4050.1820.3810.3920.3420.3570.3476.51.51.59.04.03.08.05.06.5
30.0logloss0.2800.2510.2513.2570.2620.2540.2740.2770.2758.01.51.59.04.03.05.07.06.0
40.0mse0.0820.0770.0770.7730.0780.0770.0810.0810.0818.02.02.09.04.02.06.06.06.0
51.0acc0.9060.9060.9060.9060.9060.9060.9060.9060.9065.05.05.05.05.05.05.05.05.0
61.0auc0.7670.8280.8280.1720.8100.8220.7740.7790.7728.01.51.59.04.03.06.05.07.0
71.0f10.3120.3680.3680.1720.3480.3600.3190.3290.3218.01.51.59.04.03.07.05.06.0
81.0logloss0.2720.2460.2463.2530.2580.2500.2700.2710.2727.51.51.59.04.03.05.06.07.5
91.0mse0.0790.0740.0740.7780.0770.0750.0790.0780.0797.01.51.59.04.03.07.05.07.0
102.0acc0.9080.9080.9080.9080.9080.9100.9080.9080.9096.06.06.06.06.01.06.06.02.0
112.0auc0.7590.8250.8250.1750.8150.8260.7810.7720.7808.02.52.59.04.01.05.07.06.0
122.0f10.3040.3720.3720.1690.3540.3710.3150.3200.3238.01.51.59.04.03.07.06.05.0
132.0logloss0.2710.2460.2463.2840.2510.2450.2640.2710.2647.52.52.59.04.01.05.57.55.5
142.0mse0.0780.0730.0730.7810.0740.0730.0760.0770.0768.02.02.09.04.02.05.57.05.5
153.0acc0.9030.9030.9030.9030.9030.9030.9030.9030.9035.05.05.05.05.05.05.05.05.0
163.0auc0.7720.8260.8260.1740.8090.8230.7750.7860.7727.51.51.59.04.03.06.05.07.5
173.0f10.3170.3710.3710.1770.3610.3650.3280.3430.3238.01.51.59.04.03.06.05.07.0
183.0logloss0.2760.2520.2523.2540.2620.2530.2750.2750.2767.51.51.59.04.03.05.55.57.5
193.0mse0.0810.0770.0770.7750.0790.0770.0800.0800.0808.02.02.09.04.02.06.06.06.0
204.0acc0.8950.8970.8970.8950.8950.8980.8950.8960.8957.02.52.57.07.01.07.04.07.0
214.0auc0.7540.8310.8310.1700.8180.8280.7850.7790.7828.01.51.59.04.03.05.07.06.0
224.0f10.3230.4010.4010.1900.4040.3970.3640.3540.3628.02.52.59.01.04.05.07.06.0
234.0logloss0.2960.2630.2633.2000.2730.2660.2860.2910.2878.01.51.59.04.03.05.07.06.0
244.0mse0.0870.0800.0800.7710.0820.0800.0840.0860.0848.02.02.09.04.02.05.57.05.5
\n", "
" ], "text/plain": [ " fold metric group1_rem_ebm group2_rem_ebm group2_rem_ebm2 \\\n", "0 0.0 acc 0.900 0.901 0.901 \n", "1 0.0 auc 0.781 0.840 0.840 \n", "2 0.0 f1 0.347 0.405 0.405 \n", "3 0.0 logloss 0.280 0.251 0.251 \n", "4 0.0 mse 0.082 0.077 0.077 \n", "5 1.0 acc 0.906 0.906 0.906 \n", "6 1.0 auc 0.767 0.828 0.828 \n", "7 1.0 f1 0.312 0.368 0.368 \n", "8 1.0 logloss 0.272 0.246 0.246 \n", "9 1.0 mse 0.079 0.074 0.074 \n", "10 2.0 acc 0.908 0.908 0.908 \n", "11 2.0 auc 0.759 0.825 0.825 \n", "12 2.0 f1 0.304 0.372 0.372 \n", "13 2.0 logloss 0.271 0.246 0.246 \n", "14 2.0 mse 0.078 0.073 0.073 \n", "15 3.0 acc 0.903 0.903 0.903 \n", "16 3.0 auc 0.772 0.826 0.826 \n", "17 3.0 f1 0.317 0.371 0.371 \n", "18 3.0 logloss 0.276 0.252 0.252 \n", "19 3.0 mse 0.081 0.077 0.077 \n", "20 4.0 acc 0.895 0.897 0.897 \n", "21 4.0 auc 0.754 0.831 0.831 \n", "22 4.0 f1 0.323 0.401 0.401 \n", "23 4.0 logloss 0.296 0.263 0.263 \n", "24 4.0 mse 0.087 0.080 0.080 \n", "\n", " group3_rem_piml_EBM group3_rem_piml_EBM2 group5_rem_xgb2 \\\n", "0 0.900 0.901 0.901 \n", "1 0.163 0.821 0.836 \n", "2 0.182 0.381 0.392 \n", "3 3.257 0.262 0.254 \n", "4 0.773 0.078 0.077 \n", "5 0.906 0.906 0.906 \n", "6 0.172 0.810 0.822 \n", "7 0.172 0.348 0.360 \n", "8 3.253 0.258 0.250 \n", "9 0.778 0.077 0.075 \n", "10 0.908 0.908 0.910 \n", "11 0.175 0.815 0.826 \n", "12 0.169 0.354 0.371 \n", "13 3.284 0.251 0.245 \n", "14 0.781 0.074 0.073 \n", "15 0.903 0.903 0.903 \n", "16 0.174 0.809 0.823 \n", "17 0.177 0.361 0.365 \n", "18 3.254 0.262 0.253 \n", "19 0.775 0.079 0.077 \n", "20 0.895 0.895 0.898 \n", "21 0.170 0.818 0.828 \n", "22 0.190 0.404 0.397 \n", "23 3.200 0.273 0.266 \n", "24 0.771 0.082 0.080 \n", "\n", " group8_rem_ebm group9_rem_xgb ph_rem_ebm group1_rem_ebm_rank \\\n", "0 0.901 0.900 0.901 8.0 \n", "1 0.793 0.797 0.791 8.0 \n", "2 0.342 0.357 0.347 6.5 \n", "3 0.274 0.277 0.275 8.0 \n", "4 0.081 0.081 0.081 8.0 \n", "5 0.906 0.906 0.906 5.0 \n", "6 0.774 0.779 0.772 8.0 \n", "7 0.319 
0.329 0.321 8.0 \n", "8 0.270 0.271 0.272 7.5 \n", "9 0.079 0.078 0.079 7.0 \n", "10 0.908 0.908 0.909 6.0 \n", "11 0.781 0.772 0.780 8.0 \n", "12 0.315 0.320 0.323 8.0 \n", "13 0.264 0.271 0.264 7.5 \n", "14 0.076 0.077 0.076 8.0 \n", "15 0.903 0.903 0.903 5.0 \n", "16 0.775 0.786 0.772 7.5 \n", "17 0.328 0.343 0.323 8.0 \n", "18 0.275 0.275 0.276 7.5 \n", "19 0.080 0.080 0.080 8.0 \n", "20 0.895 0.896 0.895 7.0 \n", "21 0.785 0.779 0.782 8.0 \n", "22 0.364 0.354 0.362 8.0 \n", "23 0.286 0.291 0.287 8.0 \n", "24 0.084 0.086 0.084 8.0 \n", "\n", " group2_rem_ebm_rank group2_rem_ebm2_rank group3_rem_piml_EBM_rank \\\n", "0 3.5 3.5 8.0 \n", "1 1.5 1.5 9.0 \n", "2 1.5 1.5 9.0 \n", "3 1.5 1.5 9.0 \n", "4 2.0 2.0 9.0 \n", "5 5.0 5.0 5.0 \n", "6 1.5 1.5 9.0 \n", "7 1.5 1.5 9.0 \n", "8 1.5 1.5 9.0 \n", "9 1.5 1.5 9.0 \n", "10 6.0 6.0 6.0 \n", "11 2.5 2.5 9.0 \n", "12 1.5 1.5 9.0 \n", "13 2.5 2.5 9.0 \n", "14 2.0 2.0 9.0 \n", "15 5.0 5.0 5.0 \n", "16 1.5 1.5 9.0 \n", "17 1.5 1.5 9.0 \n", "18 1.5 1.5 9.0 \n", "19 2.0 2.0 9.0 \n", "20 2.5 2.5 7.0 \n", "21 1.5 1.5 9.0 \n", "22 2.5 2.5 9.0 \n", "23 1.5 1.5 9.0 \n", "24 2.0 2.0 9.0 \n", "\n", " group3_rem_piml_EBM2_rank group5_rem_xgb2_rank group8_rem_ebm_rank \\\n", "0 3.5 3.5 3.5 \n", "1 4.0 3.0 6.0 \n", "2 4.0 3.0 8.0 \n", "3 4.0 3.0 5.0 \n", "4 4.0 2.0 6.0 \n", "5 5.0 5.0 5.0 \n", "6 4.0 3.0 6.0 \n", "7 4.0 3.0 7.0 \n", "8 4.0 3.0 5.0 \n", "9 4.0 3.0 7.0 \n", "10 6.0 1.0 6.0 \n", "11 4.0 1.0 5.0 \n", "12 4.0 3.0 7.0 \n", "13 4.0 1.0 5.5 \n", "14 4.0 2.0 5.5 \n", "15 5.0 5.0 5.0 \n", "16 4.0 3.0 6.0 \n", "17 4.0 3.0 6.0 \n", "18 4.0 3.0 5.5 \n", "19 4.0 2.0 6.0 \n", "20 7.0 1.0 7.0 \n", "21 4.0 3.0 5.0 \n", "22 1.0 4.0 5.0 \n", "23 4.0 3.0 5.0 \n", "24 4.0 2.0 5.5 \n", "\n", " group9_rem_xgb_rank ph_rem_ebm_rank \n", "0 8.0 3.5 \n", "1 5.0 7.0 \n", "2 5.0 6.5 \n", "3 7.0 6.0 \n", "4 6.0 6.0 \n", "5 5.0 5.0 \n", "6 5.0 7.0 \n", "7 5.0 6.0 \n", "8 6.0 7.5 \n", "9 5.0 7.0 \n", "10 6.0 2.0 \n", "11 7.0 6.0 \n", "12 6.0 5.0 \n", 
"13 7.5 5.5 \n", "14 7.0 5.5 \n", "15 5.0 5.0 \n", "16 5.0 7.5 \n", "17 5.0 7.0 \n", "18 5.5 7.5 \n", "19 6.0 6.0 \n", "20 4.0 7.0 \n", "21 7.0 6.0 \n", "22 7.0 6.0 \n", "23 7.0 6.0 \n", "24 7.0 5.5 " ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "eval_frame = pd.DataFrame() # init frame to hold score ranking\n", "metric_list = ['acc', 'auc', 'f1', 'logloss', 'mse'] # metric to use for evaluation\n", "\n", "# create eval frame row-by-row\n", "for fold in sorted(scores_frame['fold'].unique()): # loop through folds \n", " for metric_name in metric_list: # loop through metrics\n", " \n", " # init row dict to hold each rows values\n", " row_dict = {'fold': fold,\n", " 'metric': metric_name}\n", " \n", " # cache known y values for fold\n", " fold_y = scores_frame.loc[scores_frame['fold'] == fold, y_name]\n", " \n", " for col_name in scores_frame.columns[2:]:\n", " \n", " # cache fold scores\n", " fold_scores = scores_frame.loc[scores_frame['fold'] == fold, col_name]\n", " \n", " # calculate evaluation metric for fold\n", " # with reasonable precision \n", " \n", " if metric_name == 'acc':\n", " row_dict[col_name] = np.round(max_acc(fold_y, fold_scores), ROUND)\n", " \n", " if metric_name == 'auc':\n", " row_dict[col_name] = np.round(roc_auc_score(fold_y, fold_scores), ROUND)\n", " \n", " if metric_name == 'f1':\n", " row_dict[col_name] = np.round(max_f1(fold_y, fold_scores), ROUND) \n", " \n", " if metric_name == 'logloss':\n", " row_dict[col_name] = np.round(log_loss(fold_y, fold_scores), ROUND)\n", " \n", " if metric_name == 'mse':\n", " row_dict[col_name] = np.round(mean_squared_error(fold_y, fold_scores), ROUND)\n", " \n", " # append row values to eval_frame\n", " eval_frame = eval_frame.append(row_dict, ignore_index=True)\n", "\n", "# init a temporary frame to hold rank information\n", "rank_names = [name + '_rank' for name in sorted(eval_frame.columns) if name not in ['fold', 'metric']]\n", "rank_frame = 
pd.DataFrame(columns=rank_names) \n", "\n", "# set columns to necessary order\n", "eval_frame = eval_frame[['fold', 'metric'] + [name for name in sorted(eval_frame.columns) if name not in ['fold', 'metric']]]\n", "\n", "# determine score ranks row-by-row\n", "for i in range(0, eval_frame.shape[0]):\n", " \n", " # get ranks for row based on metric\n", " metric_name = eval_frame.loc[i, 'metric']\n", " if metric_name in ['logloss', 'mse']:\n", " ranks = eval_frame.iloc[i, 2:].rank().values\n", " else:\n", " ranks = eval_frame.iloc[i, 2:].rank(ascending=False).values\n", " \n", " # create single-row frame and append to rank_frame\n", " row_frame = pd.DataFrame(ranks.reshape(1, ranks.shape[0]), columns=rank_names)\n", " rank_frame = rank_frame.append(row_frame, ignore_index=True)\n", " \n", " # house keeping\n", " del row_frame\n", "\n", "# merge ranks onto eval_frame\n", "eval_frame = pd.concat([eval_frame, rank_frame], axis=1)\n", "\n", "# house keeping\n", "del rank_frame\n", " \n", "eval_frame" ] }, { "cell_type": "markdown", "id": "37ed3b5f", "metadata": {}, "source": [ "#### Save `eval_frame` as CSV" ] }, { "cell_type": "code", "execution_count": 7, "id": "aa89d862", "metadata": {}, "outputs": [], "source": [ "eval_frame.to_csv('model_eval_' + str(datetime.datetime.now().strftime(\"%Y_%m_%d_%H_%M_%S\") + '.csv'), \n", " index=False)" ] }, { "cell_type": "markdown", "id": "4525d3ea", "metadata": {}, "source": [ "#### Display simple ranked score list " ] }, { "cell_type": "code", "execution_count": 8, "id": "f8ff5fa5", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "group2_rem_ebm_rank 2.28\n", "group2_rem_ebm2_rank 2.28\n", "group5_rem_xgb2_rank 2.74\n", "group3_rem_piml_EBM2_rank 4.14\n", "group8_rem_ebm_rank 5.74\n", "group9_rem_xgb_rank 5.96\n", "ph_rem_ebm_rank 5.96\n", "group1_rem_ebm_rank 7.46\n", "group3_rem_piml_EBM_rank 8.44\n", "dtype: float64" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ 
"eval_frame[[name for name in eval_frame.columns if name.endswith('rank')]].mean().sort_values()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" } }, "nbformat": 4, "nbformat_minor": 5 }