{ "cells": [ { "cell_type": "markdown", "id": "f7efc033", "metadata": {}, "source": [ "## License\n", "\n", "Copyright 2021-2023 Patrick Hall (jphall@gwu.edu)\n", "\n", "Licensed under the Apache License, Version 2.0 (the \"License\");\n", "you may not use this file except in compliance with the License.\n", "You may obtain a copy of the License at\n", "\n", " http://www.apache.org/licenses/LICENSE-2.0\n", "\n", "Unless required by applicable law or agreed to in writing, software\n", "distributed under the License is distributed on an \"AS IS\" BASIS,\n", "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "See the License for the specific language governing permissions and\n", "limitations under the License.\n", "\n", "*DISCLAIMER*: This notebook is not legal or compliance advice." ] }, { "cell_type": "markdown", "id": "aab60b41", "metadata": {}, "source": [ "# Model Evaluation Notebook" ] }, { "cell_type": "markdown", "id": "281af306", "metadata": {}, "source": [ "#### Imports and initialization" ] }, { "cell_type": "code", "execution_count": 1, "id": "fd180587", "metadata": {}, "outputs": [], "source": [ "import os # for directory and file manipulation\n", "import numpy as np # for basic array manipulation\n", "import pandas as pd # for dataframe manipulation\n", "import datetime # for timestamp\n", "\n", "# for model eval\n", "from sklearn.metrics import accuracy_score, f1_score, log_loss, mean_squared_error, roc_auc_score\n", "\n", "# global constants\n", "ROUND = 3 # rounding precision; extreme precision is generally not needed\n", "SEED = 12345 # seed for better reproducibility\n", "\n", "# set global random seed for better reproducibility\n", "np.random.seed(SEED)" ] }, { "cell_type": "markdown", "id": "eb2a39d4", "metadata": {}, "source": [ "#### Set basic metadata" ] }, { "cell_type": "code", "execution_count": 2, "id": "98f640ed", "metadata": {}, "outputs": [], "source": [ "y_name = 'high_priced'\n", "scores_dir = 'data/scores'" ] }, { "cell_type": 
"markdown", "id": "cc8d83d0", "metadata": {}, "source": [ "#### Read in score files " ] }, { "cell_type": "code", "execution_count": 3, "id": "355c2b81", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", " | high_priced | \n", "fold | \n", "group1_rem_ebm | \n", "group2_rem_ebm | \n", "group2_rem_ebm2 | \n", "group3_rem_piml_EBM | \n", "group3_rem_piml_EBM2 | \n", "group5_rem_xgb2 | \n", "group8_rem_ebm | \n", "group9_rem_xgb | \n", "ph_rem_ebm | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|
0 | \n", "0.0 | \n", "2 | \n", "0.118787 | \n", "0.080557 | \n", "0.080557 | \n", "0.920389 | \n", "0.136749 | \n", "0.078326 | \n", "0.223846 | \n", "0.081792 | \n", "0.219429 | \n", "
1 | \n", "0.0 | \n", "1 | \n", "0.084506 | \n", "0.026001 | \n", "0.026001 | \n", "0.969301 | \n", "0.053751 | \n", "0.035825 | \n", "0.053926 | \n", "0.110702 | \n", "0.053929 | \n", "
2 | \n", "1.0 | \n", "4 | \n", "0.210389 | \n", "0.194961 | \n", "0.194961 | \n", "0.814272 | \n", "0.182311 | \n", "0.195332 | \n", "0.143522 | \n", "0.204048 | \n", "0.133863 | \n", "
3 | \n", "0.0 | \n", "1 | \n", "0.008529 | \n", "0.028556 | \n", "0.028556 | \n", "0.974559 | \n", "0.004065 | \n", "0.022765 | \n", "0.009371 | \n", "0.024038 | \n", "0.014419 | \n", "
4 | \n", "1.0 | \n", "2 | \n", "0.189933 | \n", "0.208263 | \n", "0.208263 | \n", "0.802908 | \n", "0.211120 | \n", "0.193035 | \n", "0.151100 | \n", "0.170243 | \n", "0.156047 | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
19826 | \n", "0.0 | \n", "3 | \n", "0.163697 | \n", "0.228342 | \n", "0.228342 | \n", "0.792251 | \n", "0.209322 | \n", "0.235192 | \n", "0.216720 | \n", "0.181403 | \n", "0.184214 | \n", "
19827 | \n", "0.0 | \n", "1 | \n", "0.114999 | \n", "0.253998 | \n", "0.253998 | \n", "0.762946 | \n", "0.206744 | \n", "0.235832 | \n", "0.161401 | \n", "0.159468 | \n", "0.141663 | \n", "
19828 | \n", "1.0 | \n", "3 | \n", "0.141307 | \n", "0.213364 | \n", "0.213364 | \n", "0.747401 | \n", "0.246610 | \n", "0.208723 | \n", "0.242814 | \n", "0.138141 | \n", "0.233266 | \n", "
19829 | \n", "0.0 | \n", "1 | \n", "0.007766 | \n", "0.002176 | \n", "0.002176 | \n", "0.996455 | \n", "0.000268 | \n", "0.018702 | \n", "0.005657 | \n", "0.034570 | \n", "0.009914 | \n", "
19830 | \n", "0.0 | \n", "0 | \n", "0.163946 | \n", "0.185484 | \n", "0.185484 | \n", "0.811429 | \n", "0.177857 | \n", "0.215085 | \n", "0.167812 | \n", "0.177785 | \n", "0.155447 | \n", "
19831 rows × 11 columns
\n", "\n", " | fold | \n", "metric | \n", "group1_rem_ebm | \n", "group2_rem_ebm | \n", "group2_rem_ebm2 | \n", "group3_rem_piml_EBM | \n", "group3_rem_piml_EBM2 | \n", "group5_rem_xgb2 | \n", "group8_rem_ebm | \n", "group9_rem_xgb | \n", "ph_rem_ebm | \n", "group1_rem_ebm_rank | \n", "group2_rem_ebm_rank | \n", "group2_rem_ebm2_rank | \n", "group3_rem_piml_EBM_rank | \n", "group3_rem_piml_EBM2_rank | \n", "group5_rem_xgb2_rank | \n", "group8_rem_ebm_rank | \n", "group9_rem_xgb_rank | \n", "ph_rem_ebm_rank | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | \n", "0.0 | \n", "acc | \n", "0.900 | \n", "0.901 | \n", "0.901 | \n", "0.900 | \n", "0.901 | \n", "0.901 | \n", "0.901 | \n", "0.900 | \n", "0.901 | \n", "8.0 | \n", "3.5 | \n", "3.5 | \n", "8.0 | \n", "3.5 | \n", "3.5 | \n", "3.5 | \n", "8.0 | \n", "3.5 | \n", "
1 | \n", "0.0 | \n", "auc | \n", "0.781 | \n", "0.840 | \n", "0.840 | \n", "0.163 | \n", "0.821 | \n", "0.836 | \n", "0.793 | \n", "0.797 | \n", "0.791 | \n", "8.0 | \n", "1.5 | \n", "1.5 | \n", "9.0 | \n", "4.0 | \n", "3.0 | \n", "6.0 | \n", "5.0 | \n", "7.0 | \n", "
2 | \n", "0.0 | \n", "f1 | \n", "0.347 | \n", "0.405 | \n", "0.405 | \n", "0.182 | \n", "0.381 | \n", "0.392 | \n", "0.342 | \n", "0.357 | \n", "0.347 | \n", "6.5 | \n", "1.5 | \n", "1.5 | \n", "9.0 | \n", "4.0 | \n", "3.0 | \n", "8.0 | \n", "5.0 | \n", "6.5 | \n", "
3 | \n", "0.0 | \n", "logloss | \n", "0.280 | \n", "0.251 | \n", "0.251 | \n", "3.257 | \n", "0.262 | \n", "0.254 | \n", "0.274 | \n", "0.277 | \n", "0.275 | \n", "8.0 | \n", "1.5 | \n", "1.5 | \n", "9.0 | \n", "4.0 | \n", "3.0 | \n", "5.0 | \n", "7.0 | \n", "6.0 | \n", "
4 | \n", "0.0 | \n", "mse | \n", "0.082 | \n", "0.077 | \n", "0.077 | \n", "0.773 | \n", "0.078 | \n", "0.077 | \n", "0.081 | \n", "0.081 | \n", "0.081 | \n", "8.0 | \n", "2.0 | \n", "2.0 | \n", "9.0 | \n", "4.0 | \n", "2.0 | \n", "6.0 | \n", "6.0 | \n", "6.0 | \n", "
5 | \n", "1.0 | \n", "acc | \n", "0.906 | \n", "0.906 | \n", "0.906 | \n", "0.906 | \n", "0.906 | \n", "0.906 | \n", "0.906 | \n", "0.906 | \n", "0.906 | \n", "5.0 | \n", "5.0 | \n", "5.0 | \n", "5.0 | \n", "5.0 | \n", "5.0 | \n", "5.0 | \n", "5.0 | \n", "5.0 | \n", "
6 | \n", "1.0 | \n", "auc | \n", "0.767 | \n", "0.828 | \n", "0.828 | \n", "0.172 | \n", "0.810 | \n", "0.822 | \n", "0.774 | \n", "0.779 | \n", "0.772 | \n", "8.0 | \n", "1.5 | \n", "1.5 | \n", "9.0 | \n", "4.0 | \n", "3.0 | \n", "6.0 | \n", "5.0 | \n", "7.0 | \n", "
7 | \n", "1.0 | \n", "f1 | \n", "0.312 | \n", "0.368 | \n", "0.368 | \n", "0.172 | \n", "0.348 | \n", "0.360 | \n", "0.319 | \n", "0.329 | \n", "0.321 | \n", "8.0 | \n", "1.5 | \n", "1.5 | \n", "9.0 | \n", "4.0 | \n", "3.0 | \n", "7.0 | \n", "5.0 | \n", "6.0 | \n", "
8 | \n", "1.0 | \n", "logloss | \n", "0.272 | \n", "0.246 | \n", "0.246 | \n", "3.253 | \n", "0.258 | \n", "0.250 | \n", "0.270 | \n", "0.271 | \n", "0.272 | \n", "7.5 | \n", "1.5 | \n", "1.5 | \n", "9.0 | \n", "4.0 | \n", "3.0 | \n", "5.0 | \n", "6.0 | \n", "7.5 | \n", "
9 | \n", "1.0 | \n", "mse | \n", "0.079 | \n", "0.074 | \n", "0.074 | \n", "0.778 | \n", "0.077 | \n", "0.075 | \n", "0.079 | \n", "0.078 | \n", "0.079 | \n", "7.0 | \n", "1.5 | \n", "1.5 | \n", "9.0 | \n", "4.0 | \n", "3.0 | \n", "7.0 | \n", "5.0 | \n", "7.0 | \n", "
10 | \n", "2.0 | \n", "acc | \n", "0.908 | \n", "0.908 | \n", "0.908 | \n", "0.908 | \n", "0.908 | \n", "0.910 | \n", "0.908 | \n", "0.908 | \n", "0.909 | \n", "6.0 | \n", "6.0 | \n", "6.0 | \n", "6.0 | \n", "6.0 | \n", "1.0 | \n", "6.0 | \n", "6.0 | \n", "2.0 | \n", "
11 | \n", "2.0 | \n", "auc | \n", "0.759 | \n", "0.825 | \n", "0.825 | \n", "0.175 | \n", "0.815 | \n", "0.826 | \n", "0.781 | \n", "0.772 | \n", "0.780 | \n", "8.0 | \n", "2.5 | \n", "2.5 | \n", "9.0 | \n", "4.0 | \n", "1.0 | \n", "5.0 | \n", "7.0 | \n", "6.0 | \n", "
12 | \n", "2.0 | \n", "f1 | \n", "0.304 | \n", "0.372 | \n", "0.372 | \n", "0.169 | \n", "0.354 | \n", "0.371 | \n", "0.315 | \n", "0.320 | \n", "0.323 | \n", "8.0 | \n", "1.5 | \n", "1.5 | \n", "9.0 | \n", "4.0 | \n", "3.0 | \n", "7.0 | \n", "6.0 | \n", "5.0 | \n", "
13 | \n", "2.0 | \n", "logloss | \n", "0.271 | \n", "0.246 | \n", "0.246 | \n", "3.284 | \n", "0.251 | \n", "0.245 | \n", "0.264 | \n", "0.271 | \n", "0.264 | \n", "7.5 | \n", "2.5 | \n", "2.5 | \n", "9.0 | \n", "4.0 | \n", "1.0 | \n", "5.5 | \n", "7.5 | \n", "5.5 | \n", "
14 | \n", "2.0 | \n", "mse | \n", "0.078 | \n", "0.073 | \n", "0.073 | \n", "0.781 | \n", "0.074 | \n", "0.073 | \n", "0.076 | \n", "0.077 | \n", "0.076 | \n", "8.0 | \n", "2.0 | \n", "2.0 | \n", "9.0 | \n", "4.0 | \n", "2.0 | \n", "5.5 | \n", "7.0 | \n", "5.5 | \n", "
15 | \n", "3.0 | \n", "acc | \n", "0.903 | \n", "0.903 | \n", "0.903 | \n", "0.903 | \n", "0.903 | \n", "0.903 | \n", "0.903 | \n", "0.903 | \n", "0.903 | \n", "5.0 | \n", "5.0 | \n", "5.0 | \n", "5.0 | \n", "5.0 | \n", "5.0 | \n", "5.0 | \n", "5.0 | \n", "5.0 | \n", "
16 | \n", "3.0 | \n", "auc | \n", "0.772 | \n", "0.826 | \n", "0.826 | \n", "0.174 | \n", "0.809 | \n", "0.823 | \n", "0.775 | \n", "0.786 | \n", "0.772 | \n", "7.5 | \n", "1.5 | \n", "1.5 | \n", "9.0 | \n", "4.0 | \n", "3.0 | \n", "6.0 | \n", "5.0 | \n", "7.5 | \n", "
17 | \n", "3.0 | \n", "f1 | \n", "0.317 | \n", "0.371 | \n", "0.371 | \n", "0.177 | \n", "0.361 | \n", "0.365 | \n", "0.328 | \n", "0.343 | \n", "0.323 | \n", "8.0 | \n", "1.5 | \n", "1.5 | \n", "9.0 | \n", "4.0 | \n", "3.0 | \n", "6.0 | \n", "5.0 | \n", "7.0 | \n", "
18 | \n", "3.0 | \n", "logloss | \n", "0.276 | \n", "0.252 | \n", "0.252 | \n", "3.254 | \n", "0.262 | \n", "0.253 | \n", "0.275 | \n", "0.275 | \n", "0.276 | \n", "7.5 | \n", "1.5 | \n", "1.5 | \n", "9.0 | \n", "4.0 | \n", "3.0 | \n", "5.5 | \n", "5.5 | \n", "7.5 | \n", "
19 | \n", "3.0 | \n", "mse | \n", "0.081 | \n", "0.077 | \n", "0.077 | \n", "0.775 | \n", "0.079 | \n", "0.077 | \n", "0.080 | \n", "0.080 | \n", "0.080 | \n", "8.0 | \n", "2.0 | \n", "2.0 | \n", "9.0 | \n", "4.0 | \n", "2.0 | \n", "6.0 | \n", "6.0 | \n", "6.0 | \n", "
20 | \n", "4.0 | \n", "acc | \n", "0.895 | \n", "0.897 | \n", "0.897 | \n", "0.895 | \n", "0.895 | \n", "0.898 | \n", "0.895 | \n", "0.896 | \n", "0.895 | \n", "7.0 | \n", "2.5 | \n", "2.5 | \n", "7.0 | \n", "7.0 | \n", "1.0 | \n", "7.0 | \n", "4.0 | \n", "7.0 | \n", "
21 | \n", "4.0 | \n", "auc | \n", "0.754 | \n", "0.831 | \n", "0.831 | \n", "0.170 | \n", "0.818 | \n", "0.828 | \n", "0.785 | \n", "0.779 | \n", "0.782 | \n", "8.0 | \n", "1.5 | \n", "1.5 | \n", "9.0 | \n", "4.0 | \n", "3.0 | \n", "5.0 | \n", "7.0 | \n", "6.0 | \n", "
22 | \n", "4.0 | \n", "f1 | \n", "0.323 | \n", "0.401 | \n", "0.401 | \n", "0.190 | \n", "0.404 | \n", "0.397 | \n", "0.364 | \n", "0.354 | \n", "0.362 | \n", "8.0 | \n", "2.5 | \n", "2.5 | \n", "9.0 | \n", "1.0 | \n", "4.0 | \n", "5.0 | \n", "7.0 | \n", "6.0 | \n", "
23 | \n", "4.0 | \n", "logloss | \n", "0.296 | \n", "0.263 | \n", "0.263 | \n", "3.200 | \n", "0.273 | \n", "0.266 | \n", "0.286 | \n", "0.291 | \n", "0.287 | \n", "8.0 | \n", "1.5 | \n", "1.5 | \n", "9.0 | \n", "4.0 | \n", "3.0 | \n", "5.0 | \n", "7.0 | \n", "6.0 | \n", "
24 | \n", "4.0 | \n", "mse | \n", "0.087 | \n", "0.080 | \n", "0.080 | \n", "0.771 | \n", "0.082 | \n", "0.080 | \n", "0.084 | \n", "0.086 | \n", "0.084 | \n", "8.0 | \n", "2.0 | \n", "2.0 | \n", "9.0 | \n", "4.0 | \n", "2.0 | \n", "5.5 | \n", "7.0 | \n", "5.5 | \n", "