{ "cells": [ { "cell_type": "markdown", "id": "f7efc033", "metadata": {}, "source": [ "## License \n", "\n", "Copyright 2021-2025 Patrick Hall (jphall@gwu.edu)\n", "\n", "Licensed under the Apache License, Version 2.0 (the \"License\");\n", "you may not use this file except in compliance with the License.\n", "You may obtain a copy of the License at\n", "\n", " http://www.apache.org/licenses/LICENSE-2.0\n", "\n", "Unless required by applicable law or agreed to in writing, software\n", "distributed under the License is distributed on an \"AS IS\" BASIS,\n", "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "See the License for the specific language governing permissions and\n", "limitations under the License.\n", "\n", "*DISCLAIMER*: This notebook is not legal or compliance advice." ] }, { "cell_type": "markdown", "id": "aab60b41", "metadata": {}, "source": [ "# Model Evaluation Notebook" ] }, { "cell_type": "markdown", "id": "281af306", "metadata": {}, "source": [ "#### Imports and inits" ] }, { "cell_type": "code", "execution_count": 1, "id": "fd180587", "metadata": {}, "outputs": [], "source": [ "import os # for directory and file manipulation\n", "import numpy as np # for basic array manipulation\n", "import pandas as pd # for dataframe manipulation\n", "import datetime # for timestamp\n", "\n", "# for model eval\n", "from sklearn.metrics import accuracy_score, f1_score, log_loss, mean_squared_error, roc_auc_score\n", "\n", "# global constants \n", "ROUND = 3 # generally, insane precision is not needed \n", "SEED = 12345 # seed for better reproducibility\n", "\n", "# set global random seed for better reproducibility\n", "np.random.seed(SEED)" ] }, { "cell_type": "markdown", "id": "eb2a39d4", "metadata": {}, "source": [ "#### Set basic metadata" ] }, { "cell_type": "code", "execution_count": 2, "id": "98f640ed", "metadata": {}, "outputs": [], "source": [ "y_name = 'high_priced'\n", "scores_dir = 'data/scores'" ] }, { "cell_type": 
"markdown", "id": "cc8d83d0", "metadata": {}, "source": [ "#### Read in score files " ] }, { "cell_type": "code", "execution_count": 3, "id": "355c2b81", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
high_pricedfoldGroup11_nmodels50_rem_ebmGroup11_s23456_rem_ebmGroup11_s309_rem_ebmGroup11_s78600_rem_ebmgroup10_rem_ebmgroup2_rem_ebm_01_20_15group2_rem_ebm_22_17_12group3_rem_ebm_baselinegroup3_rem_ebm_higherAUCgroup4_rem_ebmgroup5_rem_ebmgroup6_rem_ebmgroup7_rem_ebmgroup8_rem_ebmgroup9_rem_ebmph_advval_rem_ebm_300
00.020.1375400.1397480.2052710.1793290.1712820.1857770.1722270.1431100.2232920.1674590.1901140.1359340.2114640.2318740.2179850.046616
10.010.0878200.0829080.0616750.0508350.0524470.0545260.0480650.0861340.0560140.0503920.0507830.0897820.0569150.0507600.0516020.107210
21.040.2058590.1755630.1382440.1342980.1427250.1154590.1439080.1784890.1449380.1349640.1269080.1466950.1453230.1446450.1508860.205463
30.010.0042150.0057340.0045960.0056230.0120510.0173980.0036930.0053060.0035170.0107640.0047220.0099070.0098200.0034670.0031050.026128
41.020.1579310.1282200.1715610.1318820.1280830.1270620.1342110.1305220.1578400.1287870.1455170.1136840.1519890.1597070.1578560.205463
.........................................................
198260.030.1604170.1755630.1802650.2225890.2321890.1786840.2359760.1784890.2059580.2328470.1852100.1520400.1872600.2022800.1907820.161056
198270.010.0911540.1134000.1588040.1308240.1306600.1197420.1314740.1177660.1710420.1310070.1390190.1096610.1563640.1679310.1510610.166683
198281.030.1184010.0973930.2395170.2020360.2125460.2103690.2159910.1003390.2566410.2116100.2164550.0897820.2384780.2349410.2364550.205463
198290.010.0009510.0279180.0013400.0237910.0116420.0067760.0005450.0289720.0005170.0124900.0007690.0031520.0064740.0004770.0000840.006582
198300.000.1185910.1672270.1354000.2084030.2150050.1594520.2194840.1717490.1573230.2129930.1459020.1466950.1267890.1434960.1325110.161920
\n", "

19831 rows × 18 columns

\n", "
" ], "text/plain": [ " high_priced fold Group11_nmodels50_rem_ebm Group11_s23456_rem_ebm \\\n", "0 0.0 2 0.137540 0.139748 \n", "1 0.0 1 0.087820 0.082908 \n", "2 1.0 4 0.205859 0.175563 \n", "3 0.0 1 0.004215 0.005734 \n", "4 1.0 2 0.157931 0.128220 \n", "... ... ... ... ... \n", "19826 0.0 3 0.160417 0.175563 \n", "19827 0.0 1 0.091154 0.113400 \n", "19828 1.0 3 0.118401 0.097393 \n", "19829 0.0 1 0.000951 0.027918 \n", "19830 0.0 0 0.118591 0.167227 \n", "\n", " Group11_s309_rem_ebm Group11_s78600_rem_ebm group10_rem_ebm \\\n", "0 0.205271 0.179329 0.171282 \n", "1 0.061675 0.050835 0.052447 \n", "2 0.138244 0.134298 0.142725 \n", "3 0.004596 0.005623 0.012051 \n", "4 0.171561 0.131882 0.128083 \n", "... ... ... ... \n", "19826 0.180265 0.222589 0.232189 \n", "19827 0.158804 0.130824 0.130660 \n", "19828 0.239517 0.202036 0.212546 \n", "19829 0.001340 0.023791 0.011642 \n", "19830 0.135400 0.208403 0.215005 \n", "\n", " group2_rem_ebm_01_20_15 group2_rem_ebm_22_17_12 \\\n", "0 0.185777 0.172227 \n", "1 0.054526 0.048065 \n", "2 0.115459 0.143908 \n", "3 0.017398 0.003693 \n", "4 0.127062 0.134211 \n", "... ... ... \n", "19826 0.178684 0.235976 \n", "19827 0.119742 0.131474 \n", "19828 0.210369 0.215991 \n", "19829 0.006776 0.000545 \n", "19830 0.159452 0.219484 \n", "\n", " group3_rem_ebm_baseline group3_rem_ebm_higherAUC group4_rem_ebm \\\n", "0 0.143110 0.223292 0.167459 \n", "1 0.086134 0.056014 0.050392 \n", "2 0.178489 0.144938 0.134964 \n", "3 0.005306 0.003517 0.010764 \n", "4 0.130522 0.157840 0.128787 \n", "... ... ... ... 
\n", "19826 0.178489 0.205958 0.232847 \n", "19827 0.117766 0.171042 0.131007 \n", "19828 0.100339 0.256641 0.211610 \n", "19829 0.028972 0.000517 0.012490 \n", "19830 0.171749 0.157323 0.212993 \n", "\n", " group5_rem_ebm group6_rem_ebm group7_rem_ebm group8_rem_ebm \\\n", "0 0.190114 0.135934 0.211464 0.231874 \n", "1 0.050783 0.089782 0.056915 0.050760 \n", "2 0.126908 0.146695 0.145323 0.144645 \n", "3 0.004722 0.009907 0.009820 0.003467 \n", "4 0.145517 0.113684 0.151989 0.159707 \n", "... ... ... ... ... \n", "19826 0.185210 0.152040 0.187260 0.202280 \n", "19827 0.139019 0.109661 0.156364 0.167931 \n", "19828 0.216455 0.089782 0.238478 0.234941 \n", "19829 0.000769 0.003152 0.006474 0.000477 \n", "19830 0.145902 0.146695 0.126789 0.143496 \n", "\n", " group9_rem_ebm ph_advval_rem_ebm_300 \n", "0 0.217985 0.046616 \n", "1 0.051602 0.107210 \n", "2 0.150886 0.205463 \n", "3 0.003105 0.026128 \n", "4 0.157856 0.205463 \n", "... ... ... \n", "19826 0.190782 0.161056 \n", "19827 0.151061 0.166683 \n", "19828 0.236455 0.205463 \n", "19829 0.000084 0.006582 \n", "19830 0.132511 0.161920 \n", "\n", "[19831 rows x 18 columns]" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [
"# init score frame with known test y values\n",
"scores_frame = pd.read_csv(os.path.join(scores_dir, 'key.csv'), index_col='Unnamed: 0')\n",
"\n",
"# create random folds in reproducible way\n",
"np.random.seed(SEED)\n",
"scores_frame['fold'] = np.random.choice(5, scores_frame.shape[0])\n",
"\n",
"# read in each score file in the directory as a new column\n",
"# scores are attached by row position, not by index label, so each score file\n",
"# must have exactly one row per row of key.csv, in the same order\n",
"for file in sorted(os.listdir(scores_dir)):\n",
"    if file != 'key.csv' and file.endswith('.csv'):\n",
"        file_scores = pd.read_csv(os.path.join(scores_dir, file))['phat']\n",
"        if file_scores.shape[0] != scores_frame.shape[0]:\n",
"            raise ValueError('{} has {} rows, expected {}'.format(file, file_scores.shape[0], scores_frame.shape[0]))\n",
"        scores_frame[file[:-4]] = file_scores.values\n",
"\n",
"# sanity check\n",
"scores_frame"
] }, { "cell_type": "markdown", "id": "3e3cccda", "metadata": {}, "source": [ "#### Utility function for max. accuracy" ] },
{ "cell_type": "code", "execution_count": 4, "id": "2eb43506", "metadata": {}, "outputs": [], "source": [
"def max_acc(y, phat, res=0.01):\n",
"\n",
"    \"\"\" Utility function for finding max. accuracy at some cutoff.\n",
"\n",
"    A row is decided as 1 when its score is strictly greater than the cutoff.\n",
"    y and phat are paired by position, so they must have the same length and order.\n",
"\n",
"    :param y: Known y values.\n",
"    :param phat: Model scores.\n",
"    :param res: Resolution over which to search for max. accuracy, default 0.01.\n",
"    :return: Max. accuracy for model scores.\n",
"\n",
"    \"\"\"\n",
"\n",
"    # plain arrays: pair known y and scores by position\n",
"    y_arr = np.asarray(y)\n",
"    phat_arr = np.asarray(phat)\n",
"\n",
"    # find accuracy at cutoffs 0, res, 2*res, ... up to about 1\n",
"    # (collected in a list; DataFrame.append was removed in pandas 2.0)\n",
"    accs = []\n",
"    for cut in np.arange(0, 1 + res, res):\n",
"        decision = np.where(phat_arr > cut, 1, 0)\n",
"        accs.append(accuracy_score(y_arr, decision))\n",
"\n",
"    # find max accuracy across all cutoffs\n",
"    return max(accs)"
] }, { "cell_type": "markdown", "id": "b02c9651", "metadata": {}, "source": [ "#### Utility function for max. F1" ] },
{ "cell_type": "code", "execution_count": 5, "id": "fae3756b", "metadata": {}, "outputs": [], "source": [
"def max_f1(y, phat, res=0.01):\n",
"\n",
"    \"\"\" Utility function for finding max. F1 at some cutoff.\n",
"\n",
"    A row is decided as 1 when its score is strictly greater than the cutoff.\n",
"    y and phat are paired by position, so they must have the same length and order.\n",
"\n",
"    :param y: Known y values.\n",
"    :param phat: Model scores.\n",
"    :param res: Resolution over which to search for max. F1, default 0.01.\n",
"    :return: Max. F1 for model scores.\n",
"\n",
"    \"\"\"\n",
"\n",
"    # plain arrays: pair known y and scores by position\n",
"    y_arr = np.asarray(y)\n",
"    phat_arr = np.asarray(phat)\n",
"\n",
"    # find F1 at cutoffs 0, res, 2*res, ... up to about 1\n",
"    # (collected in a list; DataFrame.append was removed in pandas 2.0)\n",
"    f1s = []\n",
"    for cut in np.arange(0, 1 + res, res):\n",
"        decision = np.where(phat_arr > cut, 1, 0)\n",
"        f1s.append(f1_score(y_arr, decision))\n",
"\n",
"    # find max F1 across all cutoffs\n",
"    return max(f1s)"
] }, { "cell_type": "markdown", "id": "b447b732", "metadata": {}, "source": [ "#### Rank all submitted scores " ] }, { "cell_type": "code", "execution_count": 6, "id": "40fbe608", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
foldmetricGroup11_nmodels50_rem_ebmGroup11_s23456_rem_ebmGroup11_s309_rem_ebmGroup11_s78600_rem_ebmgroup10_rem_ebmgroup2_rem_ebm_01_20_15group2_rem_ebm_22_17_12group3_rem_ebm_baseline...group2_rem_ebm_22_17_12_rankgroup3_rem_ebm_baseline_rankgroup3_rem_ebm_higherAUC_rankgroup4_rem_ebm_rankgroup5_rem_ebm_rankgroup6_rem_ebm_rankgroup7_rem_ebm_rankgroup8_rem_ebm_rankgroup9_rem_ebm_rankph_advval_rem_ebm_300_rank
00.0acc0.9000.9000.9010.9010.9010.9010.9010.900...4.512.512.54.512.512.54.512.512.54.5
10.0auc0.7580.7310.7940.7760.7750.7540.7810.741...6.014.01.010.011.016.05.04.03.09.0
20.0f10.3260.3010.3530.3320.3310.3120.3350.309...7.014.03.011.012.016.06.04.52.04.5
30.0logloss0.2870.2950.2730.2790.2800.2870.2770.292...6.014.02.010.011.016.07.04.54.52.0
40.0mse0.0830.0850.0810.0810.0820.0830.0810.084...4.014.04.09.012.016.04.04.04.09.0
51.0acc0.9060.9060.9060.9060.9060.9060.9060.906...8.58.58.58.58.58.58.58.58.58.5
61.0auc0.7390.7050.7770.7460.7490.7340.7560.710...7.014.01.09.011.016.06.02.54.05.0
71.0f10.2810.2620.3260.2920.2960.2790.2950.265...8.514.03.010.58.516.04.01.55.06.0
81.0logloss0.2790.2880.2700.2780.2780.2810.2750.287...7.014.03.09.59.516.06.03.03.01.0
91.0mse0.0800.0820.0780.0800.0800.0800.0790.081...5.514.05.510.510.516.05.52.05.52.0
102.0acc0.9080.9080.9080.9080.9080.9090.9080.908...10.010.010.010.01.010.010.010.010.02.5
112.0auc0.7250.6930.7840.7630.7600.7460.7700.705...7.014.01.511.010.016.05.53.04.05.5
122.0f10.2750.2610.3280.3070.3040.2940.3160.268...6.014.03.010.011.016.04.52.04.57.0
132.0logloss0.2800.2880.2620.2680.2700.2730.2660.285...6.014.02.011.09.516.07.04.54.52.0
142.0mse0.0800.0810.0760.0770.0780.0770.0770.081...7.514.52.511.57.516.07.52.52.57.5
153.0acc0.9030.9030.9030.9030.9030.9030.9030.903...8.58.58.58.58.58.58.58.58.58.5
163.0auc0.7300.7070.7740.7520.7470.7300.7610.711...6.514.02.010.011.016.06.54.05.01.0
173.0f10.2880.2710.3220.3010.3000.2920.3090.274...7.014.02.010.011.016.06.03.04.01.0
183.0logloss0.2870.2930.2750.2810.2830.2880.2780.292...6.014.02.010.011.016.07.04.04.01.0
193.0mse0.0830.0840.0800.0810.0810.0830.0810.084...7.014.53.010.510.516.07.07.03.01.0
204.0acc0.8950.8950.8960.8960.8960.8950.8960.895...4.012.012.04.012.012.012.04.012.04.0
214.0auc0.7180.6970.7870.7710.7680.7440.7730.707...6.014.01.010.011.016.05.04.03.09.0
224.0f10.2910.2790.3650.3470.3410.3320.3440.287...7.014.02.09.012.016.05.03.04.011.0
234.0logloss0.3060.3110.2850.2900.2910.3020.2900.308...7.514.01.010.011.016.05.54.02.55.5
244.0mse0.0890.0890.0840.0850.0850.0870.0850.089...7.514.03.07.510.516.03.03.03.010.5
\n", "

25 rows × 34 columns

\n", "
" ], "text/plain": [ " fold metric Group11_nmodels50_rem_ebm Group11_s23456_rem_ebm \\\n", "0 0.0 acc 0.900 0.900 \n", "1 0.0 auc 0.758 0.731 \n", "2 0.0 f1 0.326 0.301 \n", "3 0.0 logloss 0.287 0.295 \n", "4 0.0 mse 0.083 0.085 \n", "5 1.0 acc 0.906 0.906 \n", "6 1.0 auc 0.739 0.705 \n", "7 1.0 f1 0.281 0.262 \n", "8 1.0 logloss 0.279 0.288 \n", "9 1.0 mse 0.080 0.082 \n", "10 2.0 acc 0.908 0.908 \n", "11 2.0 auc 0.725 0.693 \n", "12 2.0 f1 0.275 0.261 \n", "13 2.0 logloss 0.280 0.288 \n", "14 2.0 mse 0.080 0.081 \n", "15 3.0 acc 0.903 0.903 \n", "16 3.0 auc 0.730 0.707 \n", "17 3.0 f1 0.288 0.271 \n", "18 3.0 logloss 0.287 0.293 \n", "19 3.0 mse 0.083 0.084 \n", "20 4.0 acc 0.895 0.895 \n", "21 4.0 auc 0.718 0.697 \n", "22 4.0 f1 0.291 0.279 \n", "23 4.0 logloss 0.306 0.311 \n", "24 4.0 mse 0.089 0.089 \n", "\n", " Group11_s309_rem_ebm Group11_s78600_rem_ebm group10_rem_ebm \\\n", "0 0.901 0.901 0.901 \n", "1 0.794 0.776 0.775 \n", "2 0.353 0.332 0.331 \n", "3 0.273 0.279 0.280 \n", "4 0.081 0.081 0.082 \n", "5 0.906 0.906 0.906 \n", "6 0.777 0.746 0.749 \n", "7 0.326 0.292 0.296 \n", "8 0.270 0.278 0.278 \n", "9 0.078 0.080 0.080 \n", "10 0.908 0.908 0.908 \n", "11 0.784 0.763 0.760 \n", "12 0.328 0.307 0.304 \n", "13 0.262 0.268 0.270 \n", "14 0.076 0.077 0.078 \n", "15 0.903 0.903 0.903 \n", "16 0.774 0.752 0.747 \n", "17 0.322 0.301 0.300 \n", "18 0.275 0.281 0.283 \n", "19 0.080 0.081 0.081 \n", "20 0.896 0.896 0.896 \n", "21 0.787 0.771 0.768 \n", "22 0.365 0.347 0.341 \n", "23 0.285 0.290 0.291 \n", "24 0.084 0.085 0.085 \n", "\n", " group2_rem_ebm_01_20_15 group2_rem_ebm_22_17_12 group3_rem_ebm_baseline \\\n", "0 0.901 0.901 0.900 \n", "1 0.754 0.781 0.741 \n", "2 0.312 0.335 0.309 \n", "3 0.287 0.277 0.292 \n", "4 0.083 0.081 0.084 \n", "5 0.906 0.906 0.906 \n", "6 0.734 0.756 0.710 \n", "7 0.279 0.295 0.265 \n", "8 0.281 0.275 0.287 \n", "9 0.080 0.079 0.081 \n", "10 0.909 0.908 0.908 \n", "11 0.746 0.770 0.705 \n", "12 0.294 0.316 0.268 \n", "13 0.273 
0.266 0.285 \n", "14 0.077 0.077 0.081 \n", "15 0.903 0.903 0.903 \n", "16 0.730 0.761 0.711 \n", "17 0.292 0.309 0.274 \n", "18 0.288 0.278 0.292 \n", "19 0.083 0.081 0.084 \n", "20 0.895 0.896 0.895 \n", "21 0.744 0.773 0.707 \n", "22 0.332 0.344 0.287 \n", "23 0.302 0.290 0.308 \n", "24 0.087 0.085 0.089 \n", "\n", " ... group2_rem_ebm_22_17_12_rank group3_rem_ebm_baseline_rank \\\n", "0 ... 4.5 12.5 \n", "1 ... 6.0 14.0 \n", "2 ... 7.0 14.0 \n", "3 ... 6.0 14.0 \n", "4 ... 4.0 14.0 \n", "5 ... 8.5 8.5 \n", "6 ... 7.0 14.0 \n", "7 ... 8.5 14.0 \n", "8 ... 7.0 14.0 \n", "9 ... 5.5 14.0 \n", "10 ... 10.0 10.0 \n", "11 ... 7.0 14.0 \n", "12 ... 6.0 14.0 \n", "13 ... 6.0 14.0 \n", "14 ... 7.5 14.5 \n", "15 ... 8.5 8.5 \n", "16 ... 6.5 14.0 \n", "17 ... 7.0 14.0 \n", "18 ... 6.0 14.0 \n", "19 ... 7.0 14.5 \n", "20 ... 4.0 12.0 \n", "21 ... 6.0 14.0 \n", "22 ... 7.0 14.0 \n", "23 ... 7.5 14.0 \n", "24 ... 7.5 14.0 \n", "\n", " group3_rem_ebm_higherAUC_rank group4_rem_ebm_rank group5_rem_ebm_rank \\\n", "0 12.5 4.5 12.5 \n", "1 1.0 10.0 11.0 \n", "2 3.0 11.0 12.0 \n", "3 2.0 10.0 11.0 \n", "4 4.0 9.0 12.0 \n", "5 8.5 8.5 8.5 \n", "6 1.0 9.0 11.0 \n", "7 3.0 10.5 8.5 \n", "8 3.0 9.5 9.5 \n", "9 5.5 10.5 10.5 \n", "10 10.0 10.0 1.0 \n", "11 1.5 11.0 10.0 \n", "12 3.0 10.0 11.0 \n", "13 2.0 11.0 9.5 \n", "14 2.5 11.5 7.5 \n", "15 8.5 8.5 8.5 \n", "16 2.0 10.0 11.0 \n", "17 2.0 10.0 11.0 \n", "18 2.0 10.0 11.0 \n", "19 3.0 10.5 10.5 \n", "20 12.0 4.0 12.0 \n", "21 1.0 10.0 11.0 \n", "22 2.0 9.0 12.0 \n", "23 1.0 10.0 11.0 \n", "24 3.0 7.5 10.5 \n", "\n", " group6_rem_ebm_rank group7_rem_ebm_rank group8_rem_ebm_rank \\\n", "0 12.5 4.5 12.5 \n", "1 16.0 5.0 4.0 \n", "2 16.0 6.0 4.5 \n", "3 16.0 7.0 4.5 \n", "4 16.0 4.0 4.0 \n", "5 8.5 8.5 8.5 \n", "6 16.0 6.0 2.5 \n", "7 16.0 4.0 1.5 \n", "8 16.0 6.0 3.0 \n", "9 16.0 5.5 2.0 \n", "10 10.0 10.0 10.0 \n", "11 16.0 5.5 3.0 \n", "12 16.0 4.5 2.0 \n", "13 16.0 7.0 4.5 \n", "14 16.0 7.5 2.5 \n", "15 8.5 8.5 8.5 \n", "16 16.0 6.5 
4.0 \n", "17 16.0 6.0 3.0 \n", "18 16.0 7.0 4.0 \n", "19 16.0 7.0 7.0 \n", "20 12.0 12.0 4.0 \n", "21 16.0 5.0 4.0 \n", "22 16.0 5.0 3.0 \n", "23 16.0 5.5 4.0 \n", "24 16.0 3.0 3.0 \n", "\n", " group9_rem_ebm_rank ph_advval_rem_ebm_300_rank \n", "0 12.5 4.5 \n", "1 3.0 9.0 \n", "2 2.0 4.5 \n", "3 4.5 2.0 \n", "4 4.0 9.0 \n", "5 8.5 8.5 \n", "6 4.0 5.0 \n", "7 5.0 6.0 \n", "8 3.0 1.0 \n", "9 5.5 2.0 \n", "10 10.0 2.5 \n", "11 4.0 5.5 \n", "12 4.5 7.0 \n", "13 4.5 2.0 \n", "14 2.5 7.5 \n", "15 8.5 8.5 \n", "16 5.0 1.0 \n", "17 4.0 1.0 \n", "18 4.0 1.0 \n", "19 3.0 1.0 \n", "20 12.0 4.0 \n", "21 3.0 9.0 \n", "22 4.0 11.0 \n", "23 2.5 5.5 \n", "24 3.0 10.5 \n", "\n", "[25 rows x 34 columns]" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [
"metric_list = ['acc', 'auc', 'f1', 'logloss', 'mse'] # metrics to use for evaluation\n",
"\n",
"# function used to compute each metric, called as func(known y, model scores)\n",
"metric_funcs = {'acc': max_acc,\n",
"                'auc': roc_auc_score,\n",
"                'f1': max_f1,\n",
"                'logloss': log_loss,\n",
"                'mse': mean_squared_error}\n",
"\n",
"# metrics for which lower values are better\n",
"lower_is_better = ['logloss', 'mse']\n",
"\n",
"# model score columns: everything except known y values and fold, sorted by name\n",
"model_names = sorted(name for name in scores_frame.columns if name not in [y_name, 'fold'])\n",
"\n",
"# collect one record per fold/metric pair and build eval_frame once at the end\n",
"# (DataFrame.append was removed in pandas 2.0 and is quadratic inside a loop)\n",
"eval_records = []\n",
"for fold in sorted(scores_frame['fold'].unique()): # loop through folds\n",
"\n",
"    # cache row selector and known y values for fold\n",
"    fold_mask = scores_frame['fold'] == fold\n",
"    fold_y = scores_frame.loc[fold_mask, y_name]\n",
"\n",
"    for metric_name in metric_list: # loop through metrics\n",
"\n",
"        # init row dict to hold each row's values\n",
"        row_dict = {'fold': fold,\n",
"                    'metric': metric_name}\n",
"\n",
"        # calculate evaluation metric for each model on fold\n",
"        # with reasonable precision\n",
"        for col_name in model_names:\n",
"            fold_scores = scores_frame.loc[fold_mask, col_name]\n",
"            row_dict[col_name] = np.round(metric_funcs[metric_name](fold_y, fold_scores), ROUND)\n",
"\n",
"        eval_records.append(row_dict)\n",
"\n",
"# set columns to necessary order: fold, metric, then models\n",
"eval_frame = pd.DataFrame(eval_records, columns=['fold', 'metric'] + model_names)\n",
"\n",
"# rank models within each row: rank 1 is best, ties get the average rank\n",
"# loss-type metrics rank ascending, all other metrics rank descending\n",
"model_values = eval_frame[model_names].astype(float)\n",
"ascending_ranks = model_values.rank(axis=1, ascending=True).values\n",
"descending_ranks = model_values.rank(axis=1, ascending=False).values\n",
"use_ascending = eval_frame['metric'].isin(lower_is_better).values.reshape(-1, 1)\n",
"rank_names = [name + '_rank' for name in model_names]\n",
"rank_frame = pd.DataFrame(np.where(use_ascending, ascending_ranks, descending_ranks),\n",
"                          columns=rank_names, index=eval_frame.index)\n",
"\n",
"# merge ranks onto eval_frame\n",
"eval_frame = pd.concat([eval_frame, rank_frame], axis=1)\n",
"\n",
"# house keeping\n",
"del model_values, ascending_ranks, descending_ranks, use_ascending, rank_frame\n",
"\n",
"eval_frame"
] }, { "cell_type": "markdown", "id": "37ed3b5f", "metadata": {}, "source": [ "#### Save `eval_frame` as CSV" ] },
{ "cell_type": "code", "execution_count": 7, "id": "aa89d862", "metadata": {}, "outputs": [], "source": [
"eval_frame.to_csv('model_eval_' + datetime.datetime.now().strftime(\"%Y_%m_%d_%H_%M_%S\") + '.csv',\n",
"                  index=False)"
] }, { "cell_type": "markdown", "id": "4525d3ea", "metadata": {}, "source": [ "#### Display simple ranked score list " ] },
{ "cell_type": "code", "execution_count": 8, "id": "f8ff5fa5", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Group11_s309_rem_ebm_rank 3.44\n", "group3_rem_ebm_higherAUC_rank 3.96\n", "group8_rem_ebm_rank 4.54\n", "group9_rem_ebm_rank 5.06\n", "ph_advval_rem_ebm_300_rank 5.14\n", "group7_rem_ebm_rank 6.26\n", "group2_rem_ebm_22_17_12_rank 6.70\n", "Group11_s78600_rem_ebm_rank 7.74\n", "group10_rem_ebm_rank 8.44\n", "group4_rem_ebm_rank 9.42\n", "group5_rem_ebm_rank 10.16\n", "group2_rem_ebm_01_20_15_rank 11.02\n", "Group11_nmodels50_rem_ebm_rank 11.98\n", "group3_rem_ebm_baseline_rank 13.30\n", "Group11_s23456_rem_ebm_rank 13.98\n", "group6_rem_ebm_rank 14.86\n", "dtype: float64" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "eval_frame[[name for name in eval_frame.columns if name.endswith('rank')]].mean().sort_values()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.17" } }, "nbformat": 4, "nbformat_minor": 5 }