{
"cells": [
{
"cell_type": "markdown",
"source": [
"## BreezeGen - Imbalanced Classification\n",
"\n",
"Analyze an imbalanced dataset. Train models using the imbalanced-learn library for over- and under-sampling, and compare model performance.\n",
"\n",
"BreezeGen is a maintenance company specializing in wind turbines. They have supplied ciphered sensor data. Generator maintenance is costly, so preventing a full replacement is paramount.\n",
"\n",
"* An inspection costs \\$5,000 regardless of whether any repairs end up being necessary.\n",
"\n",
"* A repair costs \\$15,000.\n",
"\n",
"* A complete device replacement costs \\$40,000. We obviously want to minimize the frequency of replacement.\n",
"\n",
"The target variable encodes the failure state, with 0 indicating no failure and 1 signalling failure."
],
"metadata": {
"id": "M5GT5IPIB9FL"
}
},
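The cost figures above can be collapsed into a single number for comparing models. The sketch below assumes a maintenance policy that the description does not spell out: every predicted failure triggers an inspection, a correctly predicted failure is then also repaired, and a missed failure ends in a full replacement. `maintenance_cost` is a hypothetical helper, not part of any library.

```python
import numpy as np

# Costs stated in the problem description (USD)
INSPECTION = 5_000
REPAIR = 15_000
REPLACEMENT = 40_000

def maintenance_cost(y_true, y_pred):
    """Total cost under an ASSUMED policy (hypothetical helper):
    - every predicted failure (TP or FP) is inspected,
    - a correctly predicted failure (TP) is additionally repaired,
    - a missed failure (FN) ends in a full replacement.
    True negatives cost nothing.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    tp = np.sum((y_true == 1) & (y_pred == 1))
    fp = np.sum((y_true == 0) & (y_pred == 1))
    fn = np.sum((y_true == 1) & (y_pred == 0))
    return int(tp * (INSPECTION + REPAIR) + fp * INSPECTION + fn * REPLACEMENT)

# One true positive, one false positive and one false negative
print(maintenance_cost([0, 0, 1, 1, 0], [0, 1, 1, 0, 0]))  # 65000
```

Under this assumption a missed failure (\$40,000) costs twice as much as a caught one (\$20,000) and eight times as much as a false alarm (\$5,000), which is why recall on the failure class matters more than raw accuracy.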
{
"cell_type": "markdown",
"metadata": {
"id": "jjFpJBnb4jak"
},
"source": [
"## Importing libraries"
]
},
{
"cell_type": "markdown",
"source": [
"First, upgrade the scikit-learn and imbalanced-learn libraries."
],
"metadata": {
"id": "cHuqjfjK39GX"
}
},
{
"cell_type": "code",
"source": [
"! pip install scikit-learn -U"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "V2uT9Z0_vm0C",
"outputId": "a9d0c943-0b8f-41e0-8bb6-a4e5ac07a5e4"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Requirement already satisfied: scikit-learn in /usr/local/lib/python3.8/dist-packages (1.0.2)\n",
"Collecting scikit-learn\n",
" Downloading scikit_learn-1.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (9.7 MB)\n",
"Requirement already satisfied: scipy>=1.3.2 in /usr/local/lib/python3.8/dist-packages (from scikit-learn) (1.7.3)\n",
"Requirement already satisfied: joblib>=1.1.1 in /usr/local/lib/python3.8/dist-packages (from scikit-learn) (1.2.0)\n",
"Requirement already satisfied: numpy>=1.17.3 in /usr/local/lib/python3.8/dist-packages (from scikit-learn) (1.21.6)\n",
"Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.8/dist-packages (from scikit-learn) (3.1.0)\n",
"Installing collected packages: scikit-learn\n",
" Attempting uninstall: scikit-learn\n",
" Found existing installation: scikit-learn 1.0.2\n",
" Uninstalling scikit-learn-1.0.2:\n",
" Successfully uninstalled scikit-learn-1.0.2\n",
"Successfully installed scikit-learn-1.2.0\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"! pip install imbalanced-learn -U"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "REbQIklqyoZd",
"outputId": "c180c302-1ec8-4478-9efe-cf8837514686"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Requirement already satisfied: imbalanced-learn in /usr/local/lib/python3.8/dist-packages (0.8.1)\n",
"Collecting imbalanced-learn\n",
" Downloading imbalanced_learn-0.10.1-py3-none-any.whl (226 kB)\n",
"Requirement already satisfied: numpy>=1.17.3 in /usr/local/lib/python3.8/dist-packages (from imbalanced-learn) (1.21.6)\n",
"Requirement already satisfied: scipy>=1.3.2 in /usr/local/lib/python3.8/dist-packages (from imbalanced-learn) (1.7.3)\n",
"Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.8/dist-packages (from imbalanced-learn) (3.1.0)\n",
"Requirement already satisfied: joblib>=1.1.1 in /usr/local/lib/python3.8/dist-packages (from imbalanced-learn) (1.2.0)\n",
"Requirement already satisfied: scikit-learn>=1.0.2 in /usr/local/lib/python3.8/dist-packages (from imbalanced-learn) (1.2.0)\n",
"Installing collected packages: imbalanced-learn\n",
" Attempting uninstall: imbalanced-learn\n",
" Found existing installation: imbalanced-learn 0.8.1\n",
" Uninstalling imbalanced-learn-0.8.1:\n",
" Successfully uninstalled imbalanced-learn-0.8.1\n",
"Successfully installed imbalanced-learn-0.10.1\n"
]
}
]
},
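After upgrading, it can help to confirm which versions are actually installed. This sketch uses the standard-library `importlib.metadata`; `installed_version` is a hypothetical helper. Note that it reports the installed distribution, so if a package was imported before the upgrade, the running session keeps the old module until the runtime is restarted.

```python
from importlib.metadata import PackageNotFoundError, version

def installed_version(package):
    """Return the installed distribution version as a string, or None if absent."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None

# Distribution names as used with pip (not the import names sklearn / imblearn)
for pkg in ("scikit-learn", "imbalanced-learn", "xgboost"):
    print(f"{pkg}: {installed_version(pkg)}")
```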
{
"cell_type": "markdown",
"source": [
"Then we can import the necessary libraries."
],
"metadata": {
"id": "jraQT35h4D0v"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "83D17_Wl4jal"
},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"from matplotlib import pyplot as plt\n",
"import seaborn as sns\n",
"\n",
"# data processing\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.preprocessing import StandardScaler, OneHotEncoder\n",
"from sklearn.impute import KNNImputer\n",
"from imblearn.over_sampling import SMOTE\n",
"from imblearn.under_sampling import RandomUnderSampler\n",
"\n",
"# models building\n",
"from sklearn.model_selection import StratifiedKFold, cross_val_score\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"from sklearn.ensemble import BaggingClassifier, RandomForestClassifier\n",
"from sklearn.ensemble import AdaBoostClassifier, GradientBoostingClassifier\n",
"from sklearn.ensemble import VotingClassifier\n",
"from xgboost import XGBClassifier\n",
"\n",
"# model assessment and production\n",
"from sklearn import metrics\n",
"from sklearn.model_selection import RandomizedSearchCV, GridSearchCV\n",
"from sklearn.pipeline import Pipeline, make_pipeline\n",
"from sklearn.ensemble import StackingClassifier"
]
},
{
"cell_type": "code",
"source": [
"# display setups\n",
"\n",
"# plotting theme\n",
"sns.set_theme()\n",
"\n",
"# dataframe display all columns\n",
"pd.set_option('display.max_columns',None)"
],
"metadata": {
"id": "nfI1WIP10ca5"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "vqF4q7G94jam"
},
"source": [
"## Loading Data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "oJnKoHy14jam"
},
"outputs": [],
"source": [
"df=pd.read_csv('dataset_train.csv')"
]
},
{
"cell_type": "code",
"source": [
"data=df.copy()"
],
"metadata": {
"id": "8-2KXw_jz4FU"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "DVCj6_DD4jan"
},
"source": [
"## EDA"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "4e-xjd2YhKyj",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 427
},
"outputId": "b402b3e8-ceec-4bfa-edff-e13a254b44df"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" V1 V2 V3 V4 V5 V6 V7 \\\n",
"3841 -1.761261 1.725575 3.115833 -0.722076 1.636716 -0.665690 -1.947430 \n",
"12898 -0.319513 -3.223125 6.945362 -6.073991 0.560963 -1.151006 -1.368454 \n",
"15032 3.660333 -0.500171 1.673259 -0.759881 -3.446873 -0.177149 -0.150797 \n",
"36781 -2.031378 -7.075963 1.609581 -2.808285 -2.626699 -3.161704 -2.401735 \n",
"9201 -1.167479 4.636223 -1.011644 0.316850 3.275565 -0.310062 1.388811 \n",
"21288 -3.187048 3.810940 0.850021 1.188452 3.331867 -1.399935 -1.550283 \n",
"37321 5.669580 1.844251 7.218002 1.555360 -3.158366 -2.388045 -0.937380 \n",
"8600 -2.459670 -2.393870 3.225757 0.063384 -0.580143 -2.797114 -0.634555 \n",
"33089 1.001489 0.983718 -2.561203 3.374567 2.495414 -0.720727 0.067609 \n",
"39511 -3.648807 -1.589335 -0.575212 0.649347 1.329551 -2.730112 -2.246243 \n",
"\n",
" V8 V9 V10 V11 V12 V13 V14 \\\n",
"3841 0.340637 -0.434492 -1.861200 -0.748612 1.483686 3.233458 -1.749122 \n",
"12898 -0.203942 -3.529692 5.112241 -3.728806 2.433346 3.220423 1.316607 \n",
"15032 -0.962328 0.798321 -0.378106 2.456049 1.879260 2.194074 -1.375423 \n",
"36781 2.357427 -1.254624 2.679572 -4.866034 5.473158 5.339906 2.596990 \n",
"9201 4.651706 -4.848882 -2.074384 3.231886 5.032511 -2.634830 -2.825284 \n",
"21288 0.635168 -0.389022 -2.848663 -3.003486 3.366907 2.550711 -2.365782 \n",
"37321 -2.299472 3.368328 -0.475084 -1.593526 -0.763408 4.474542 -1.536328 \n",
"8600 2.118902 -2.976534 2.924087 -3.354634 4.535033 0.790923 0.388404 \n",
"33089 3.662200 0.327424 -1.763364 -3.399646 -4.088555 -5.765949 1.293401 \n",
"39511 2.806568 -0.600226 -1.071473 -5.598787 4.039250 2.624120 0.526896 \n",
"\n",
" V15 V16 V17 V18 V19 V20 V21 \\\n",
"3841 -5.027887 -1.380684 -1.399922 2.132650 -1.505987 -2.451394 -4.512434 \n",
"12898 -1.652448 -0.902062 -6.406294 2.609491 -0.583582 -3.180174 -3.682714 \n",
"15032 -1.303997 -3.645691 -1.916947 -0.564766 3.382290 -0.034927 -5.565015 \n",
"36781 -5.179196 -2.976010 -0.568500 3.978816 1.724852 6.053674 -5.327246 \n",
"9201 1.290487 6.432091 0.251476 -1.082105 0.626093 -7.337963 0.912257 \n",
"21288 -6.294215 -2.694268 1.701079 3.318560 -0.941472 -2.522348 -4.691335 \n",
"37321 -2.689027 -7.051011 -1.965468 -2.414158 3.769097 0.893762 -9.227253 \n",
"8600 -0.780539 0.137287 -0.147010 0.193798 4.075666 1.939926 -1.893998 \n",
"33089 4.038515 3.405233 6.860836 -0.479386 -3.237459 2.401008 4.419077 \n",
"39511 -5.738545 -2.256579 3.684117 4.407944 -0.596813 3.524609 -3.745387 \n",
"\n",
" V22 V23 V24 V25 V26 V27 V28 \\\n",
"3841 0.036412 0.488204 2.010276 -2.194374 5.600887 -4.419503 -3.459057 \n",
"12898 -0.107914 -3.111936 1.132707 1.378427 2.399556 5.679397 -3.810228 \n",
"15032 2.032844 -0.867979 -1.966172 0.320279 -0.502116 2.028311 -0.768731 \n",
"36781 3.560616 3.365471 -0.391724 1.284162 -5.200010 2.720958 -0.473470 \n",
"9201 -0.578727 5.034060 8.326799 -4.863529 5.333237 -8.192814 -1.593566 \n",
"21288 1.915986 4.572873 6.014727 -3.252703 6.471669 -7.581418 -2.034078 \n",
"37321 1.080952 -4.955377 -3.369827 2.379708 4.442623 3.109184 -2.405830 \n",
"8600 0.262768 -0.026678 3.035784 0.127010 -0.972351 0.181684 0.249843 \n",
"33089 -0.393963 2.743099 4.115304 1.274670 -1.566217 -4.673000 2.077497 \n",
"39511 3.215177 6.282294 4.594613 -1.079856 -0.412728 -5.065619 -0.157468 \n",
"\n",
" V29 V30 V31 V32 V33 V34 V35 \\\n",
"3841 1.848695 3.183373 3.131519 4.559034 5.339557 -4.860047 3.971955 \n",
"12898 -4.106580 -3.610371 9.925717 2.086696 0.300373 -0.581224 2.849599 \n",
"15032 1.730565 3.023456 -2.705545 -3.397233 0.711632 0.350068 5.073397 \n",
"36781 -4.738653 0.069819 -1.207865 0.449110 -5.094617 2.979341 2.545381 \n",
"9201 5.291905 6.553607 5.233752 10.877396 8.250367 -5.159849 4.674169 \n",
"21288 -0.284077 2.482489 0.878916 7.499800 3.778780 -5.057147 2.754992 \n",
"37321 -1.814943 -1.632428 0.736474 -4.127790 -1.698119 -0.144663 6.158802 \n",
"8600 -2.562915 -0.478893 2.429896 2.392210 -2.040141 3.390597 2.891005 \n",
"33089 -0.692994 -1.068473 2.815314 5.783949 -1.469405 -2.904733 -2.098224 \n",
"39511 -3.058230 1.398945 -1.026855 6.162144 -1.610211 -1.636248 1.450378 \n",
"\n",
" V36 V37 V38 V39 V40 Target \n",
"3841 1.084243 -0.286629 -2.123846 0.088027 0.467075 0 \n",
"12898 11.783096 -1.036491 -2.656733 4.048503 -0.043561 0 \n",
"15032 -2.271627 -0.376924 -0.115856 -0.630295 -1.382311 0 \n",
"36781 7.661457 2.721677 -5.888977 3.233609 -2.895987 0 \n",
"9201 -3.516928 -1.730339 1.969370 -3.425079 2.554127 0 \n",
"21288 -0.312605 0.304002 0.109490 -0.080787 -0.263837 0 \n",
"37321 -0.501844 -2.075539 -3.096800 0.029782 -2.895171 0 \n",
"8600 5.108812 0.085175 -1.349147 2.056523 -2.479685 0 \n",
"33089 -2.400717 -0.437158 -3.385150 -3.896752 3.312009 1 \n",
"39511 2.772455 2.196238 -3.787583 0.694094 -1.008109 0 "
]
},
"metadata": {},
"execution_count": 7
}
],
"source": [
"data.sample(10,random_state=1)"
]
},
{
"cell_type": "markdown",
"source": [
"* The ciphered features are floating-point numbers, both positive and negative, mostly within a few units of 0.\n",
"\n",
"* The target variable is binary (0 or 1). At first glance the classes look imbalanced: only one of the ten sampled rows is a failure."
],
"metadata": {
"id": "Q1JOFXdT8wzH"
}
},
{
"cell_type": "code",
"source": [
"data['Target'].value_counts(normalize=True)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "klLXk5rK65vn",
"outputId": "5c48a543-c603-499d-9127-82b074652d27"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"0 0.945325\n",
"1 0.054675\n",
"Name: Target, dtype: float64"
]
},
"metadata": {},
"execution_count": 8
}
]
},
{
"cell_type": "markdown",
"source": [
"Indeed, about 94.5% of the target observations are 0, which the data dictionary defines as 'No Failure'. That is good news for the turbines, but it leaves only about 5.5% of observations as 'Failure' cases."
],
"metadata": {
"id": "x2Oq5OEb9jjH"
}
},
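In absolute terms, 0.054675 × 40,000 = 2,187 failures against 37,813 non-failures, roughly 17 to 1. A model that always predicts 0 would therefore score about 94.5% accuracy while catching no failures at all. Random under-sampling, one of the two resampling strategies imported above, rebalances the classes by discarding majority-class rows. The sketch below re-implements that idea from scratch in NumPy to show what it does; it is not imbalanced-learn's code, and `random_undersample` is a hypothetical helper. Whichever sampler is used, it should only be applied to the training split so that validation and test data keep their natural imbalance.

```python
import numpy as np

def random_undersample(X, y, seed=0):
    """Keep every minority-class row plus an equal-sized random subset of each
    larger class (from-scratch sketch of the idea behind RandomUnderSampler)."""
    rng = np.random.default_rng(seed)
    X, y = np.asarray(X), np.asarray(y)
    classes, counts = np.unique(y, return_counts=True)
    n_keep = counts.min()  # size of the smallest class
    keep = np.concatenate([
        # sample n_keep row indices of class c without replacement
        rng.choice(np.flatnonzero(y == c), size=n_keep, replace=False)
        for c in classes
    ])
    keep.sort()  # restore the original row order
    return X[keep], y[keep]

# Toy data (hypothetical) with roughly the 94.5 / 5.5 split seen above
rng = np.random.default_rng(42)
y = (rng.random(2_000) < 0.055).astype(int)
X = rng.normal(size=(y.size, 3))

X_res, y_res = random_undersample(X, y)
print("before:", np.bincount(y))
print("after: ", np.bincount(y_res))
```

SMOTE works in the opposite direction: it synthesises new minority rows by interpolating between nearby minority samples, so the training set grows instead of shrinking.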
{
"cell_type": "code",
"source": [
"data.info()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "o5AVc3Eb7Nrp",
"outputId": "4dfbd329-e6f4-4a4d-a78e-37967bdf34d0"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"<class 'pandas.core.frame.DataFrame'>\n",
"RangeIndex: 40000 entries, 0 to 39999\n",
"Data columns (total 41 columns):\n",
" # Column Non-Null Count Dtype \n",
"--- ------ -------------- ----- \n",
" 0 V1 39954 non-null float64\n",
" 1 V2 39961 non-null float64\n",
" 2 V3 40000 non-null float64\n",
" 3 V4 40000 non-null float64\n",
" 4 V5 40000 non-null float64\n",
" 5 V6 40000 non-null float64\n",
" 6 V7 40000 non-null float64\n",
" 7 V8 40000 non-null float64\n",
" 8 V9 40000 non-null float64\n",
" 9 V10 40000 non-null float64\n",
" 10 V11 40000 non-null float64\n",
" 11 V12 40000 non-null float64\n",
" 12 V13 40000 non-null float64\n",
" 13 V14 40000 non-null float64\n",
" 14 V15 40000 non-null float64\n",
" 15 V16 40000 non-null float64\n",
" 16 V17 40000 non-null float64\n",
" 17 V18 40000 non-null float64\n",
" 18 V19 40000 non-null float64\n",
" 19 V20 40000 non-null float64\n",
" 20 V21 40000 non-null float64\n",
" 21 V22 40000 non-null float64\n",
" 22 V23 40000 non-null float64\n",
" 23 V24 40000 non-null float64\n",
" 24 V25 40000 non-null float64\n",
" 25 V26 40000 non-null float64\n",
" 26 V27 40000 non-null float64\n",
" 27 V28 40000 non-null float64\n",
" 28 V29 40000 non-null float64\n",
" 29 V30 40000 non-null float64\n",
" 30 V31 40000 non-null float64\n",
" 31 V32 40000 non-null float64\n",
" 32 V33 40000 non-null float64\n",
" 33 V34 40000 non-null float64\n",
" 34 V35 40000 non-null float64\n",
" 35 V36 40000 non-null float64\n",
" 36 V37 40000 non-null float64\n",
" 37 V38 40000 non-null float64\n",
" 38 V39 40000 non-null float64\n",
" 39 V40 40000 non-null float64\n",
" 40 Target 40000 non-null int64 \n",
"dtypes: float64(40), int64(1)\n",
"memory usage: 12.5 MB\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"To double-check the integrity of the data, we confirm that all forty ciphered features are stored as floats. This indicates there are no errant entries, such as '?', which would force pandas to read the column as the `object` (string) dtype.\n",
"\n",
"The non-null counts also show that some columns have missing entries."
],
"metadata": {
"id": "9KjbX9C798pD"
}
},
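To illustrate why float dtypes rule out stray text entries, here is a toy CSV (hypothetical data) in which one value of `V2` is `?`. pandas falls back to the `object` dtype for that column; `select_dtypes` flags it, and `pd.to_numeric(..., errors="coerce")` turns the bad entry into NaN so it can be handled like any other missing value.

```python
import io

import pandas as pd

# Toy CSV (hypothetical data): column V2 contains a stray "?" entry
raw = io.StringIO("V1,V2,Target\n0.5,1.2,0\n-1.3,?,1\n2.0,-0.7,0\n")
toy = pd.read_csv(raw)
print(toy.dtypes)  # V2 is parsed as object because of the "?"

# Feature columns that were not parsed as numeric
suspect = toy.drop(columns="Target").select_dtypes(exclude="number").columns.tolist()
print(suspect)

# Coerce unparseable entries to NaN; the column becomes float64
toy[suspect] = toy[suspect].apply(pd.to_numeric, errors="coerce")
print(toy.dtypes)
print(toy.isna().sum())
```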
{
"cell_type": "code",
"source": [
"data.columns[data.isna().any()].tolist()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "mwJ5LtR07X5m",
"outputId": "d93b6cfe-b5d1-45e3-a81e-418cb8439653"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"['V1', 'V2']"
]
},
"metadata": {},
"execution_count": 10
}
]
},
{
"cell_type": "markdown",
"source": [
"Columns 'V1' and 'V2' have missing values: 46 and 39 entries respectively, according to the non-null counts from `info()`."
],
"metadata": {
"id": "713vuP3W-b-R"
}
},
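The `info()` counts imply 46 missing values in `V1` (40,000 - 39,954) and 39 in `V2` (40,000 - 39,961), about 0.1% of rows each. A small helper such as the hypothetical `missing_report` below tabulates counts and percentages in one step; it is demonstrated on a toy frame rather than the real data. With gaps this small, imputation (for example with the `KNNImputer` imported earlier, fit on the training split only) is one option instead of dropping rows.

```python
import numpy as np
import pandas as pd

def missing_report(frame):
    """Count and percentage of missing values, for columns that have any."""
    counts = frame.isna().sum()
    counts = counts[counts > 0]
    return pd.DataFrame({
        "n_missing": counts,
        "pct_missing": (100 * counts / len(frame)).round(3),
    })

# Toy frame (hypothetical data) with gaps in V1 and V2 only
toy = pd.DataFrame({
    "V1": [0.1, np.nan, -2.0, 1.5],
    "V2": [np.nan, 0.3, np.nan, 0.8],
    "V3": [1.0, 2.0, 3.0, 4.0],
})
print(missing_report(toy))
```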
{
"cell_type": "code",
"source": [
"data.describe().T"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"id": "tOQZs1Ah8F_S",
"outputId": "cf90cdd9-5c28-4b5d-bfb3-f0e8160cccf2"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" count mean std min 25% 50% 75% \\\n",
"V1 39954.0 -0.288120 3.449072 -13.501880 -2.751460 -0.773518 1.836708 \n",
"V2 39961.0 0.442672 3.139431 -13.212051 -1.638355 0.463939 2.537508 \n",
"V3 40000.0 2.505514 3.406263 -11.469369 0.202682 2.265319 4.584920 \n",
"V4 40000.0 -0.066078 3.437330 -16.015417 -2.349574 -0.123691 2.148596 \n",
"V5 40000.0 -0.044574 2.107183 -8.612973 -1.507206 -0.096824 1.346224 \n",
"V6 40000.0 -1.000849 2.036756 -10.227147 -2.363446 -1.006635 0.373909 \n",
"V7 40000.0 -0.892793 1.756510 -8.205806 -2.036913 -0.934738 0.206820 \n",
"V8 40000.0 -0.563123 3.298916 -15.657561 -2.660415 -0.384188 1.714383 \n",
"V9 40000.0 -0.007739 2.161833 -8.596313 -1.493676 -0.052085 1.425713 \n",
"V10 40000.0 -0.001848 2.183034 -11.000790 -1.390549 0.105779 1.486105 \n",
"V11 40000.0 -1.917794 3.116426 -14.832058 -3.940969 -1.941726 0.089444 \n",
"V12 40000.0 1.578095 2.914613 -13.619304 -0.431373 1.485367 3.540787 \n",
"V13 40000.0 1.591309 2.865222 -13.830128 -0.208522 1.653836 3.476336 \n",
"V14 40000.0 -0.946620 1.787759 -8.309443 -2.164513 -0.957444 0.265874 \n",
"V15 40000.0 -2.435720 3.341244 -17.201998 -4.451365 -2.398608 -0.381757 \n",
"V16 40000.0 -2.943168 4.211646 -21.918711 -5.631812 -2.718600 -0.112947 \n",
"V17 40000.0 -0.142794 3.344332 -17.633947 -2.227048 -0.027895 2.071801 \n",
"V18 40000.0 1.188949 2.586164 -11.643994 -0.402848 0.867433 2.564239 \n",
"V19 40000.0 1.181333 3.394979 -13.491784 -1.050903 1.278402 3.497277 \n",
"V20 40000.0 0.027201 3.674985 -13.922659 -2.433811 0.030136 2.513245 \n",
"V21 40000.0 -3.621359 3.556979 -19.436404 -5.920847 -3.559327 -1.284178 \n",
"V22 40000.0 0.943242 1.645538 -10.122095 -0.112147 0.962802 2.018031 \n",
"V23 40000.0 -0.387617 4.052147 -16.187510 -3.118868 -0.275339 2.438047 \n",
"V24 40000.0 1.142220 3.912820 -18.487811 -1.483210 0.963586 3.563055 \n",
"V25 40000.0 -0.003019 2.024691 -8.228266 -1.373400 0.021100 1.399816 \n",
"V26 40000.0 1.895717 3.421454 -12.587902 -0.319231 1.963826 4.163146 \n",
"V27 40000.0 -0.616838 4.392161 -14.904939 -3.692075 -0.909640 2.200608 \n",
"V28 40000.0 -0.888121 1.924947 -9.685082 -2.192763 -0.904757 0.376856 \n",
"V29 40000.0 -1.005327 2.676299 -12.579469 -2.799008 -1.206027 0.604473 \n",
"V30 40000.0 -0.032664 3.031009 -14.796047 -1.908202 0.184613 2.040131 \n",
"V31 40000.0 0.505885 3.482735 -19.376732 -1.798975 0.491352 2.777519 \n",
"V32 40000.0 0.326831 5.499369 -23.200866 -3.392115 0.056243 3.789241 \n",
"V33 40000.0 0.056542 3.574219 -17.454014 -2.237550 -0.049729 2.255985 \n",
"V34 40000.0 -0.464127 3.185712 -17.985094 -2.127757 -0.250842 1.432885 \n",
"V35 40000.0 2.234861 2.924185 -15.349803 0.332081 2.110125 4.044659 \n",
"V36 40000.0 1.530020 3.819754 -17.478949 -0.937119 1.571511 3.996721 \n",
"V37 40000.0 -0.000498 1.778273 -7.639952 -1.265717 -0.132620 1.160828 \n",
"V38 40000.0 -0.351199 3.964186 -17.375002 -3.016805 -0.318724 2.291342 \n",
"V39 40000.0 0.900035 1.751022 -7.135788 -0.261578 0.921321 2.069016 \n",
"V40 40000.0 -0.897166 2.997750 -11.930259 -2.949590 -0.949269 1.092178 \n",
"Target 40000.0 0.054675 0.227348 0.000000 0.000000 0.000000 0.000000 \n",
"\n",
" max \n",
"V1 17.436981 \n",
"V2 13.089269 \n",
"V3 18.366477 \n",
"V4 13.279712 \n",
"V5 9.403469 \n",
"V6 7.065470 \n",
"V7 8.006091 \n",
"V8 11.679495 \n",
"V9 8.507138 \n",
"V10 8.108472 \n",
"V11 13.851834 \n",
"V12 15.753586 \n",
"V13 15.419616 \n",
"V14 6.213289 \n",
"V15 12.874679 \n",
"V16 13.583212 \n",
"V17 17.404510 \n",
"V18 13.179863 \n",
"V19 16.059004 \n",
"V20 16.052339 \n",
"V21 13.840473 \n",
"V22 7.409856 \n",
"V23 15.080172 \n",
"V24 19.769376 \n",
"V25 8.223389 \n",
"V26 16.836410 \n",
"V27 21.594552 \n",
"V28 6.906865 \n",
"V29 11.852476 \n",
"V30 13.190889 \n",
"V31 17.255090 \n",
"V32 24.847833 \n",
"V33 16.692486 \n",
"V34 14.358213 \n",
"V35 16.804859 \n",
"V36 19.329576 \n",
"V37 7.803278 \n",
"V38 15.964053 \n",
"V39 7.997832 \n",
"V40 10.654265 \n",
"Target 1.000000 "
],
"text/html": [
" 40000.0 \n",
" -0.897166 \n",
" 2.997750 \n",
" -11.930259 \n",
" -2.949590 \n",
" -0.949269 \n",
" 1.092178 \n",
" 10.654265 \n",
" \n",
" \n",
" Target \n",
" 40000.0 \n",
" 0.054675 \n",
" 0.227348 \n",
" 0.000000 \n",
" 0.000000 \n",
" 0.000000 \n",
" 0.000000 \n",
" 1.000000 \n",
" \n",
" \n",
" "
]
},
"metadata": {},
"execution_count": 11
}
]
},
{
"cell_type": "markdown",
"source": [
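"The summary statistics offer limited insight on their own, since the features are ciphered. For most features, however, the mean and median (50%) are close and the quartiles sit roughly evenly around the median, which suggests the distributions are fairly symmetric. The mean of `Target` (about 0.055) is the share of failures in the data."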
],
"metadata": {
"id": "9u-ZGJc3_FpB"
}
},
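{
"cell_type": "markdown",
"source": [
"A quick numeric check supports this impression: for a symmetric distribution, the mean and median nearly coincide and the sample skewness is close to 0. The sketch below uses synthetic columns as a hypothetical stand-in for the ciphered sensor features, together with pandas' `DataFrame.skew()`; on the real data, the same lines could be run on `data.drop('Target', axis=1)`.\n",
"\n",
"```python\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"# Hypothetical stand-in for the sensor features: two symmetric columns\n",
"# and one right-skewed column for contrast\n",
"rng = np.random.default_rng(0)\n",
"df = pd.DataFrame({\n",
"    'sym_a': rng.normal(0, 3, 5000),\n",
"    'sym_b': rng.normal(1, 2, 5000),\n",
"    'skewed': rng.exponential(2, 5000),\n",
"})\n",
"\n",
"summary = pd.DataFrame({\n",
"    'mean': df.mean(),\n",
"    'median': df.median(),\n",
"    'skew': df.skew(),  # sample skewness, near 0 for symmetric data\n",
"})\n",
"print(summary.round(3))\n",
"```\n",
"\n",
"Treating an absolute skewness above roughly 1 as worth a closer look is a rule of thumb, not a formal test."
],
"metadata": {}
},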
{
"cell_type": "code",
"source": [
"plt.figure(figsize=(8,5))\n",
"plt.title('Target Distribution',fontsize=20)\n",
"sns.countplot(data=data,x='Target');"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 360
},
"id": "XzU2ZlRK2HJz",
"outputId": "0dd4abac-c31f-4876-c277-c2f01b267553"
},
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"Failures (`Target` = 1) are rare, making up only about 5.5% of observations, so the classes are heavily imbalanced.\n",
"\n",
"In the next section, we process the data for modeling. This includes scaling the features to mean 0 and standard deviation 1 and imputing missing values."
],
"metadata": {
"id": "CByL6xIW56bk"
}
},
{
"cell_type": "markdown",
"metadata": {
"id": "knk0w9XH4jao"
},
"source": [
"## Data Pre-processing"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "2JbJc1bX4jao"
},
"outputs": [],
"source": [
"data['Target']=pd.Categorical(data['Target'])"
]
},
{
"cell_type": "markdown",
"source": [
"To start, we change the target variable type to categorical."
],
"metadata": {
"id": "ObRFwZex9hLi"
}
},
{
"cell_type": "code",
"source": [
"X=data.drop('Target',axis=1)\n",
"y=data['Target']"
],
"metadata": {
"id": "P8jMwMQfHrmt"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"X_train,X_val,y_train,y_val=train_test_split(X,y,test_size=0.25,stratify=y,random_state=57)"
],
"metadata": {
"id": "VFIlKv9qHw3H"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"We first separate the features `X` from the target `y`, then split them into training (75%) and validation (25%) sets. Setting `stratify=y` keeps the failure rate the same in both sets. Recall that a separate CSV file holds the test data for the final model, so we need not set aside a test set here."
],
"metadata": {
"id": "GNUkjsXU-wPJ"
}
},
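{
"cell_type": "markdown",
"source": [
"To illustrate what `stratify=y` does, the sketch below splits a synthetic imbalanced dataset (generated with scikit-learn's `make_classification`, with about 5.5% positives as an assumption mirroring the failure rate here) and compares the positive rate in each split.\n",
"\n",
"```python\n",
"from sklearn.datasets import make_classification\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Synthetic data with roughly 5.5% positives (assumption mirroring Target)\n",
"X_demo, y_demo = make_classification(\n",
"    n_samples=4000, n_features=10, weights=[0.945], random_state=0\n",
")\n",
"\n",
"X_tr, X_va, y_tr, y_va = train_test_split(\n",
"    X_demo, y_demo, test_size=0.25, stratify=y_demo, random_state=57\n",
")\n",
"\n",
"# With stratification, both splits keep (almost) the same positive rate\n",
"print('overall:', y_demo.mean().round(4))\n",
"print('train:  ', y_tr.mean().round(4))\n",
"print('val:    ', y_va.mean().round(4))\n",
"```\n",
"\n",
"Without stratification, a rare class can end up noticeably over- or under-represented in a small validation set, which would distort the recall estimates used throughout this notebook."
],
"metadata": {}
},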
{
"cell_type": "code",
"source": [
"pre=make_pipeline(\n",
" StandardScaler(),\n",
" KNNImputer()\n",
").set_output(transform='pandas')"
],
"metadata": {
"id": "szJ3YASp6yxm"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
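"Next, we build a simple pre-processing pipeline that outputs pandas DataFrames (`set_output(transform='pandas')` requires scikit-learn 1.2 or later, which is why we upgraded the library at the top of the notebook). The pipeline scales the data and imputes missing values:\n",
"* We scale each column to mean 0 and standard deviation 1. `StandardScaler` ignores missing values when fitting and leaves them in place when transforming.\n",
"* We then impute missing data using the K-Nearest Neighbors method, with the `KNNImputer` default of 5 neighbors."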
],
"metadata": {
"id": "Bpd-Cq4Y_P86"
}
},
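{
"cell_type": "markdown",
"source": [
"The sketch below runs the same two steps on a tiny hypothetical DataFrame with missing values. It shows that the NaNs pass through the scaler, are filled by the imputer, and that the result stays a DataFrame with the original column names.\n",
"\n",
"```python\n",
"import numpy as np\n",
"import pandas as pd\n",
"from sklearn.impute import KNNImputer\n",
"from sklearn.pipeline import make_pipeline\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"# Hypothetical toy data with a few missing readings\n",
"toy = pd.DataFrame({\n",
"    'V1': [1.0, 2.0, np.nan, 4.0, 5.0, 6.0],\n",
"    'V2': [10.0, np.nan, 30.0, 40.0, 50.0, 60.0],\n",
"})\n",
"\n",
"pre_demo = make_pipeline(\n",
"    StandardScaler(),  # NaNs are skipped in fit and kept in transform\n",
"    KNNImputer(),      # n_neighbors=5 by default\n",
").set_output(transform='pandas')\n",
"\n",
"out = pre_demo.fit_transform(toy)\n",
"print(out.round(3))\n",
"```\n",
"\n",
"Scaling before imputing means the neighbor distances are computed on standardized features, so no single column dominates them."
],
"metadata": {}
},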
{
"cell_type": "code",
"source": [
"X_train=pre.fit_transform(X_train)"
],
"metadata": {
"id": "oF_bLl6x7U4Q"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"We fit the pipeline on the training data and transform it."
],
"metadata": {
"id": "tb3eAMkW_wDu"
}
},
{
"cell_type": "code",
"source": [
"X_val=pre.transform(X_val)"
],
"metadata": {
"id": "HiA9Se0s9Cdk"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"By fitting the pipeline on the _training_ data and then using it to transform the validation data, we avoid data leakage: the validation data does not influence the means, standard deviations, or nearest-neighbor imputation."
],
"metadata": {
"id": "NR-6ZpnO_zJI"
}
},
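{
"cell_type": "markdown",
"source": [
"The same principle applies inside cross-validation. Because `pre` is fit on the whole training set before any cross-validation runs, each fold's held-out part has already influenced the scaling and imputation. A stricter variant, sketched below on synthetic data with artificially removed values (the dataset and the 5% missing rate are assumptions for illustration), places the preprocessing and the model in one pipeline so that `cross_val_score` refits the scaler and imputer on each training fold only.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn.datasets import make_classification\n",
"from sklearn.impute import KNNImputer\n",
"from sklearn.model_selection import StratifiedKFold, cross_val_score\n",
"from sklearn.pipeline import make_pipeline\n",
"from sklearn.preprocessing import StandardScaler\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"\n",
"# Synthetic imbalanced data with about 5% of entries knocked out (assumption)\n",
"X_demo, y_demo = make_classification(\n",
"    n_samples=2000, n_features=10, weights=[0.945], random_state=1\n",
")\n",
"rng = np.random.default_rng(1)\n",
"X_demo[rng.random(X_demo.shape) < 0.05] = np.nan\n",
"\n",
"# Preprocessing lives inside the pipeline, so it is refit on every training fold\n",
"leak_free = make_pipeline(\n",
"    StandardScaler(),\n",
"    KNNImputer(),\n",
"    DecisionTreeClassifier(random_state=1),\n",
")\n",
"kfold = StratifiedKFold(n_splits=5, shuffle=True, random_state=2)\n",
"scores = cross_val_score(leak_free, X_demo, y_demo, scoring='recall', cv=kfold)\n",
"print(f'Mean CV recall: {scores.mean():.3f}')\n",
"```"
],
"metadata": {}
},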
{
"cell_type": "markdown",
"source": [
"## Model Building"
],
"metadata": {
"id": "e1816sbz7tB4"
}
},
{
"cell_type": "markdown",
"source": [
"### Functions\n",
"\n",
"The following functions will assist with model building and assessment. Some are adapted from my previous projects."
],
"metadata": {
"id": "DB3gT0OD1Nyx"
}
},
{
"cell_type": "code",
"source": [
"def confusion_heatmap(model,show_scores=True):\n",
" '''Heatmap of confusion matrix for\n",
" model performance on validation data.'''\n",
"\n",
" actual=y_val\n",
" predicted=model.predict(X_val)\n",
"\n",
" # generate confusion matrix\n",
" cm=metrics.confusion_matrix(actual,predicted)\n",
" cm=np.flip(cm).T\n",
"\n",
" # heatmap labels\n",
" labels=['TP','FP','FN','TN']\n",
" cm_labels=np.array(cm).flatten()\n",
" cm_percents=np.round((cm_labels/np.sum(cm))*100,3)\n",
" annot_labels=[]\n",
" for i in range(4):\n",
" annot_labels.append(str(labels[i])+'\\nCount:'+str(cm_labels[i])+'\\n'+str(cm_percents[i])+'%')\n",
" annot_labels=np.array(annot_labels).reshape(2,2)\n",
"\n",
" # print figure\n",
" plt.figure(figsize=(8,5))\n",
" plt.title('Confusion Matrix',fontsize=20)\n",
" sns.heatmap(data=cm,\n",
" annot=annot_labels,\n",
" annot_kws={'fontsize':'x-large'},\n",
" xticklabels=[1,0],\n",
" yticklabels=[1,0],\n",
" cmap='Greens',\n",
" fmt='s')\n",
" plt.xlabel('Actual',fontsize=14)\n",
" plt.ylabel('Predicted',fontsize=14)\n",
" plt.tight_layout();\n",
"\n",
" # scores\n",
" if show_scores:\n",
" scores=['Accuracy','Precision','Recall','F1']\n",
" score_list=[metrics.accuracy_score(actual,predicted),\n",
" metrics.precision_score(actual,predicted),\n",
" metrics.recall_score(actual,predicted),\n",
" metrics.f1_score(actual,predicted)]\n",
" df=pd.DataFrame(index=scores)\n",
" df['Scores']=score_list\n",
" return df\n",
" return\n",
"\n",
"# alias function name to something shorter\n",
"ch=confusion_heatmap"
],
"metadata": {
"id": "vkdBqzHC1NIg"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"The function above plots a heatmap of the confusion matrix for model performance on the validation data. When `show_scores=True` (the default), it also returns a table of validation accuracy, precision, recall, and F1 scores."
],
"metadata": {
"id": "YWhS_cUZ34Mm"
}
},
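{
"cell_type": "markdown",
"source": [
"The only non-obvious line in `confusion_heatmap` is `np.flip(cm).T`. scikit-learn's `confusion_matrix` puts actual classes in rows and predicted classes in columns, both ordered 0 then 1, i.e. `[[TN, FP], [FN, TP]]`. Flipping both axes and transposing gives `[[TP, FP], [FN, TN]]`, with predicted class (1, 0) in the rows and actual class (1, 0) in the columns, which matches the `TP, FP, FN, TN` labels and the axis titles. The toy labels below make this concrete.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn import metrics\n",
"\n",
"# Toy labels: 3 TP, 1 FP, 2 FN, 4 TN\n",
"actual    = [1, 1, 1, 0, 1, 1, 0, 0, 0, 0]\n",
"predicted = [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]\n",
"\n",
"cm = metrics.confusion_matrix(actual, predicted)\n",
"print(cm)       # layout [[TN, FP], [FN, TP]], here [[4, 1], [2, 3]]\n",
"\n",
"flipped = np.flip(cm).T\n",
"print(flipped)  # layout [[TP, FP], [FN, TN]], here [[3, 1], [2, 4]]\n",
"```"
],
"metadata": {}
},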
{
"cell_type": "code",
"source": [
"def cv_recall(estimator,sample_strategy=None):\n",
" '''Compute a recall score using\n",
" stratified k-fold cross-validation.'''\n",
"\n",
" # define data based on sampling strategy\n",
" if sample_strategy=='over':\n",
" X_data=X_train_over\n",
" y_data=y_train_over\n",
" elif sample_strategy=='under':\n",
" X_data=X_train_under\n",
" y_data=y_train_under\n",
" else:\n",
" X_data=X_train\n",
" y_data=y_train\n",
" \n",
" # cv strategy\n",
" e=estimator\n",
" kfold=StratifiedKFold(n_splits=5,\n",
" shuffle=True,\n",
" random_state=2)\n",
" \n",
" # run cv\n",
" cvs=cross_val_score(estimator=e,\n",
" X=X_data,\n",
" y=y_data,\n",
" scoring='recall',\n",
" cv=kfold,\n",
" n_jobs=-1)\n",
" return cvs.mean()"
],
"metadata": {
"id": "HseIpOYK2Haq"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
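"The function above returns the mean recall across five stratified, shuffled cross-validation folds for a given model. The `sample_strategy` argument selects the original, oversampled, or undersampled training data."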
],
"metadata": {
"id": "CrFwj2Yj4G-8"
}
},
{
"cell_type": "code",
"source": [
"model_table=pd.DataFrame(columns=['Train Acc',\n",
" 'Val Acc',\n",
" 'Train Recall',\n",
" 'CV Recall',\n",
" 'Val Recall'])\n",
"\n",
"def tabulate(model,name,sample=None,cvs=None):\n",
" '''Compute train/val accuracy and\n",
" recall for a given model. Add to table.'''\n",
"\n",
" # run predictions with model\n",
" X_val_pred=model.predict(X_val)\n",
" if sample is None:\n",
" y_tr=y_train\n",
" y_pred=model.predict(X_train)\n",
" elif sample=='over':\n",
" y_tr=y_train_over\n",
" y_pred=model.predict(X_train_over)\n",
" elif sample=='under':\n",
" y_tr=y_train_under\n",
" y_pred=model.predict(X_train_under)\n",
" else:\n",
" raise ValueError(\"Sample parameter takes values in {None,'over','under'}.\")\n",
"\n",
" # cross validation recall\n",
" if cvs is None:\n",
" m=cv_recall(model,sample_strategy=sample)\n",
" else:\n",
" m=cvs\n",
"\n",
" # collect data for new table row\n",
" model_table.loc[name]=[metrics.accuracy_score(y_tr,y_pred),\n",
" metrics.accuracy_score(y_val,X_val_pred),\n",
" metrics.recall_score(y_tr,y_pred),\n",
" m,\n",
" metrics.recall_score(y_val,X_val_pred)]\n",
"\n",
" return model_table"
],
"metadata": {
"id": "YndYupxhtb0n"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"The function above collects several metrics for evaluating model performance (training and validation accuracy, plus training, cross-validated, and validation recall) into a comparison table."
],
"metadata": {
"id": "yWvWLY3x4OzT"
}
},
{
"cell_type": "markdown",
"metadata": {
"id": "eqCDCbcw4jas"
},
"source": [
"### Model Building with original data"
]
},
{
"cell_type": "markdown",
"source": [
"#### Decision Tree"
],
"metadata": {
"id": "FtZ_9iAkvLyv"
}
},
{
"cell_type": "code",
"source": [
"dtree=DecisionTreeClassifier(random_state=1)"
],
"metadata": {
"id": "YgAMFzSg47tl"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"m=cv_recall(dtree)\n",
"print(f'Cross-validated recall is {m}.')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "0_pageE63kpG",
"outputId": "21bca094-717e-4d2d-eec0-3b918fe84dd2"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Cross-validated recall is 0.7207317073170731.\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"A plain decision tree classifier yields a cross-validated recall of 0.72, a decent result for a first attempt."
],
"metadata": {
"id": "hrDuy3iuydNp"
}
},
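{
"cell_type": "markdown",
"source": [
"Recall is the share of actual failures that the model catches, TP / (TP + FN). We focus on it because a missed failure (a false negative) risks a full replacement, the costliest outcome. The hypothetical toy labels below check the formula against `metrics.recall_score`.\n",
"\n",
"```python\n",
"from sklearn import metrics\n",
"\n",
"# Toy labels (hypothetical): 4 actual failures, 3 of them caught\n",
"actual    = [1, 1, 1, 1, 0, 0, 0, 0]\n",
"predicted = [1, 1, 1, 0, 0, 0, 1, 0]\n",
"\n",
"tp = sum(a == 1 and p == 1 for a, p in zip(actual, predicted))\n",
"fn = sum(a == 1 and p == 0 for a, p in zip(actual, predicted))\n",
"manual = tp / (tp + fn)\n",
"\n",
"print(manual)                                   # 0.75\n",
"print(metrics.recall_score(actual, predicted))  # 0.75\n",
"```\n",
"\n",
"Note that the single false positive (the seventh observation) does not affect recall at all; it only lowers precision."
],
"metadata": {}
},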
{
"cell_type": "code",
"source": [
"dtree.fit(X_train,y_train)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 75
},
"id": "LlX8rM-xHQ3h",
"outputId": "635f3582-48c9-48c9-db1e-f7489b8ccbbb"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"DecisionTreeClassifier(random_state=1)"
]
},
"metadata": {},
"execution_count": 26
}
]
},
{
"cell_type": "code",
"source": [
"ch(dtree)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 523
},
"id": "52ZyEH706qGo",
"outputId": "548f8b85-90c0-41a3-f38b-6f350cbdf6fe"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Scores\n",
"Accuracy 0.972400\n",
"Precision 0.734024\n",
"Recall 0.776965\n",
"F1 0.754885"
]
},
"metadata": {},
"execution_count": 27
},
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"After fitting the model to the whole training set, we find high accuracy (around 97%), while precision (73%) and recall (78%) are much lower.\n",
"\n",
"**Note:** All confusion matrices in this project are compiled using the validation data."
],
"metadata": {
"id": "-GGoguhHyr-Z"
}
},
{
"cell_type": "code",
"source": [
"tabulate(dtree,'dtree',cvs=m)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 81
},
"id": "6koRY1NVHQ1F",
"outputId": "00281147-d5d2-4605-db6c-46e18f4de425"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Train Acc Val Acc Train Recall CV Recall Val Recall\n",
"dtree 1.0 0.9724 1.0 0.720732 0.776965"
]
},
"metadata": {},
"execution_count": 28
}
]
},
{
"cell_type": "markdown",
"source": [
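"Moreover, comparing training and validation results reveals clear overfitting: the tree scores a perfect 1.0 on training accuracy and recall, but its recall drops to 0.72 in cross-validation and 0.78 on the validation set."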
],
"metadata": {
"id": "AJPD0slZy7eB"
}
},
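{
"cell_type": "markdown",
"source": [
"One common way to rein in an unpruned tree is to cap its depth. The sketch below, on synthetic imbalanced data (an assumption standing in for the turbine data), compares training accuracy for an unrestricted tree and for a tree with a hypothetical `max_depth=3`; the depth is an arbitrary illustrative value, not a tuned setting.\n",
"\n",
"```python\n",
"from sklearn.datasets import make_classification\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"\n",
"# Synthetic stand-in for the turbine data (assumption)\n",
"X_demo, y_demo = make_classification(\n",
"    n_samples=2000, n_features=10, weights=[0.945], random_state=3\n",
")\n",
"\n",
"full = DecisionTreeClassifier(random_state=1).fit(X_demo, y_demo)\n",
"shallow = DecisionTreeClassifier(max_depth=3, random_state=1).fit(X_demo, y_demo)\n",
"\n",
"# An unrestricted tree can memorize the training set; a shallow one cannot\n",
"print('full tree depth:', full.get_depth(), 'train acc:', full.score(X_demo, y_demo))\n",
"print('depth-3 tree train acc:', shallow.score(X_demo, y_demo))\n",
"```\n",
"\n",
"In practice, the depth (and related parameters such as `min_samples_leaf`) would be chosen by cross-validated recall rather than fixed by hand."
],
"metadata": {}
},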
{
"cell_type": "markdown",
"source": [
"#### Logistic Regression"
],
"metadata": {
"id": "vKvj0hJc7kJg"
}
},
{
"cell_type": "code",
"source": [
"lr=LogisticRegression()"
],
"metadata": {
"id": "yW7rmlP9HQyq"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"m=cv_recall(lr)\n",
"print(f'Cross-validated recall is {m}.')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "OPgufQDjHQv_",
"outputId": "f0e26e1d-e704-464c-af7f-79431360428a"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Cross-validated recall is 0.47804878048780486.\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
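"The plain logistic regression model performs far worse than the decision tree, with a cross-validated mean recall of under 50%. In other words, it detects fewer than half of the actual failures (class 1). Judged by recall alone, it does worse than labeling each observation with a fair coin flip, which would catch about half of the failures in expectation, albeit with very poor precision."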
],
"metadata": {
"id": "GGx97RH-4buT"
}
},
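{
"cell_type": "markdown",
"source": [
"For comparison, the simulation below labels each observation 1 or 0 with a fair coin, using a synthetic target with a 5.5% failure rate (an assumption matching the share of failures in this data). Recall lands near 0.5, while precision falls to roughly the failure rate itself.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn import metrics\n",
"\n",
"rng = np.random.default_rng(42)\n",
"n = 100_000\n",
"y_true = (rng.random(n) < 0.055).astype(int)  # about 5.5% failures\n",
"y_coin = rng.integers(0, 2, size=n)           # fair coin for every observation\n",
"\n",
"recall = metrics.recall_score(y_true, y_coin)\n",
"precision = metrics.precision_score(y_true, y_coin)\n",
"print(f'coin-flip recall:    {recall:.3f}')     # close to 0.5\n",
"print(f'coin-flip precision: {precision:.3f}')  # close to 0.055\n",
"```\n",
"\n",
"A coin flip would therefore send inspectors to roughly half of all turbines to catch half of the failures, which is why recall has to be read alongside precision."
],
"metadata": {}
},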
{
"cell_type": "code",
"source": [
"lr.fit(X_train,y_train)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 75
},
"id": "q3KJobQ1HQs-",
"outputId": "dd96485b-9703-4d06-b9fe-49e1bac404ad"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"LogisticRegression()"
]
},
"metadata": {},
"execution_count": 31
}
]
},
{
"cell_type": "code",
"source": [
"ch(lr)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 523
},
"id": "K31Iae7f8JYZ",
"outputId": "7e873e5f-7248-44e2-edda-33d321e8df5e"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Scores\n",
"Accuracy 0.967700\n",
"Precision 0.839394\n",
"Recall 0.506399\n",
"F1 0.631699"
]
},
"metadata": {},
"execution_count": 32
},
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"Because the negative class (0) is by far the majority, this model rarely predicts failure. It therefore produces few false positives and scores relatively well on precision (0.84), but it misses about half of the actual failures (recall 0.51)."
],
"metadata": {
"id": "KNCF-NcZ41jZ"
}
},
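{
"cell_type": "markdown",
"source": [
"Since this model rarely predicts failure at the default 0.5 probability cutoff, one option is to lower the cutoff applied to `predict_proba`. The sketch below, on synthetic imbalanced data (an assumption), compares the default cutoff with a hypothetical, untuned cutoff of 0.2. Lowering the threshold can only add positive predictions, so recall cannot decrease, although precision usually falls.\n",
"\n",
"```python\n",
"from sklearn.datasets import make_classification\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.metrics import precision_score, recall_score\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Synthetic imbalanced data (assumption), split as in this notebook\n",
"X_demo, y_demo = make_classification(\n",
"    n_samples=5000, n_features=10, weights=[0.945], random_state=4\n",
")\n",
"X_tr, X_va, y_tr, y_va = train_test_split(\n",
"    X_demo, y_demo, test_size=0.25, stratify=y_demo, random_state=57\n",
")\n",
"\n",
"lr_demo = LogisticRegression().fit(X_tr, y_tr)\n",
"proba = lr_demo.predict_proba(X_va)[:, 1]  # estimated probability of failure\n",
"\n",
"results = {}\n",
"for threshold in (0.5, 0.2):  # 0.2 is an illustrative, untuned cutoff\n",
"    pred = (proba >= threshold).astype(int)\n",
"    results[threshold] = recall_score(y_va, pred)\n",
"    prec = precision_score(y_va, pred, zero_division=0)\n",
"    print(f'threshold {threshold}: recall {results[threshold]:.3f}, precision {prec:.3f}')\n",
"```\n",
"\n",
"If this route were taken, the cutoff itself would need to be chosen on validation or cross-validation data, ideally with the inspection, repair, and replacement costs in mind."
],
"metadata": {}
},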
{
"cell_type": "code",
"source": [
"tabulate(lr,'Logistic Regr',cvs=m)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 112
},
"id": "aIb5k-zJ8JV6",
"outputId": "cf8b90e2-3a46-4728-a088-e7699fac1fc4"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Train Acc Val Acc Train Recall CV Recall Val Recall\n",
"dtree 1.000000 0.9724 1.00000 0.720732 0.776965\n",
"Logistic Regr 0.966967 0.9677 0.47622 0.478049 0.506399"
]
},
"metadata": {},
"execution_count": 33
}
]
},
{
"cell_type": "markdown",
"source": [
"Training, cross-validated, and validation recall are all between 0.48 and 0.51, so while this model performs poorly, at least it isn't overfit."
],
"metadata": {
"id": "CxTaTCau5XrS"
}
},
{
"cell_type": "markdown",
"source": [
"#### Bagging Classifier"
],
"metadata": {
"id": "OU9dKO8C8gCP"
}
},
{
"cell_type": "code",
"source": [
"bag=BaggingClassifier(random_state=1)\n",
"\n",
"m=cv_recall(bag)\n",
"print(f'Cross-validated recall is {m}.')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "RqTjbE7w8kkv",
"outputId": "b9eed8ae-2bbb-4de4-b700-7d83db76815b"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Cross-validated recall is 0.7128048780487805.\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"A CV recall of 71% is much better than that of the logistic regression and similar to the decision tree's score."
],
"metadata": {
"id": "Tvbu6cd_5e0n"
}
},
{
"cell_type": "code",
"source": [
"bag.fit(X_train,y_train)\n",
"\n",
"ch(bag)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 523
},
"id": "MU1gl4l08JS7",
"outputId": "78bba327-979e-4894-bead-f2d55ad93e45"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Scores\n",
"Accuracy 0.985500\n",
"Precision 0.944690\n",
"Recall 0.780622\n",
"F1 0.854855"
]
},
"metadata": {},
"execution_count": 35
},
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"In addition to good recall, this model achieves high accuracy (98.55%) and precision (94.5%)!"
],
"metadata": {
"id": "Iw_ORunz5qRE"
}
},
{
"cell_type": "code",
"source": [
"tabulate(bag,'Bagging Clfr',cvs=m)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 143
},
"id": "9HDT_8j_8JQN",
"outputId": "d156c4e4-610d-4092-fb9f-ebd2e8f71150"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Train Acc Val Acc Train Recall CV Recall Val Recall\n",
"dtree 1.000000 0.9724 1.00000 0.720732 0.776965\n",
"Logistic Regr 0.966967 0.9677 0.47622 0.478049 0.506399\n",
"Bagging Clfr 0.997233 0.9855 0.95061 0.712805 0.780622"
]
},
"metadata": {},
"execution_count": 36
}
]
},
{
"cell_type": "markdown",
"source": [
"Unfortunately, comparing training and validation scores shows that this model also overfits: training recall is 0.95, while cross-validated and validation recall are only 0.71 and 0.78."
],
"metadata": {
"id": "7XF628ij5wGU"
}
},
{
"cell_type": "markdown",
"source": [
"#### Random Forest"
],
"metadata": {
"id": "KBwJwZ8B-HX0"
}
},
{
"cell_type": "code",
"source": [
"rf=RandomForestClassifier(random_state=1)\n",
"\n",
"m=cv_recall(rf)\n",
"print(f'Cross-validated recall is {m}.')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "FxpQCkxk8JM6",
"outputId": "07c6ed6f-5340-4866-857f-166706e1ce8a"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Cross-validated recall is 0.7567073170731707.\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"Our second bagging model, random forest, scores a bit better, with a CV recall of around 76%."
],
"metadata": {
"id": "WHxZsPWD55HO"
}
},
{
"cell_type": "code",
"source": [
"rf.fit(X_train,y_train)\n",
"\n",
"ch(rf)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 523
},
"id": "Q3e_HTdN-Lyp",
"outputId": "b550d285-5ebe-407d-e7a7-09654e9328c3"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Scores\n",
"Accuracy 0.988300\n",
"Precision 0.986425\n",
"Recall 0.797075\n",
"F1 0.881699"
]
},
"metadata": {},
"execution_count": 38
},
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"Again, precision and accuracy are stellar on the validation data. Note in the confusion matrix above that there are only **SIX** false positives among 10,000 validation observations! This translates to a specificity (true negative rate) of over 99.9%."
],
"metadata": {
"id": "YhY8JwId6A7z"
}
},
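{
"cell_type": "markdown",
"source": [
"Specificity is TN / (TN + FP). The helper below is a hypothetical addition, not one of this notebook's functions; it reads the counts off scikit-learn's `confusion_matrix`. The label arrays are rebuilt from counts back-calculated from the random forest's reported validation scores (10,000 validation rows, about 547 failures, accuracy 0.9883, precision 0.986, recall 0.797), so they are a reconstruction rather than the actual predictions.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn import metrics\n",
"\n",
"def specificity(actual, predicted):\n",
"    '''True negative rate, TN / (TN + FP), for 0/1 labels.'''\n",
"    tn, fp, fn, tp = metrics.confusion_matrix(actual, predicted, labels=[0, 1]).ravel()\n",
"    return tn / (tn + fp)\n",
"\n",
"# Labels rebuilt from back-calculated counts: TP=436, FN=111, FP=6, TN=9447\n",
"actual = np.repeat([1, 1, 0, 0], [436, 111, 6, 9447])\n",
"predicted = np.repeat([1, 0, 1, 0], [436, 111, 6, 9447])\n",
"\n",
"print(round(specificity(actual, predicted), 5))  # 0.99937\n",
"```\n",
"\n",
"As a cross-check, the same arrays reproduce the reported precision (436 / 442) and recall (436 / 547)."
],
"metadata": {}
},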
{
"cell_type": "code",
"source": [
"tabulate(rf,'Random Forest',cvs=m)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 175
},
"id": "2hGucm2w-Lvi",
"outputId": "0e8bad5a-c499-46cb-d3e3-2aed0a3b83ec"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Train Acc Val Acc Train Recall CV Recall Val Recall\n",
"dtree 1.000000 0.9724 1.00000 0.720732 0.776965\n",
"Logistic Regr 0.966967 0.9677 0.47622 0.478049 0.506399\n",
"Bagging Clfr 0.997233 0.9855 0.95061 0.712805 0.780622\n",
"Random Forest 1.000000 0.9883 1.00000 0.756707 0.797075"
]
},
"metadata": {},
"execution_count": 39
}
]
},
{
"cell_type": "markdown",
"source": [
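"Unfortunately, the random forest also overfits: it scores a perfect 1.0 on training accuracy and recall, but only 0.76 CV recall and 0.80 validation recall."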
],
"metadata": {
"id": "UaR9Pwyj6y_a"
}
},
{
"cell_type": "markdown",
"source": [
"#### AdaBoost"
],
"metadata": {
"id": "_8f8_MuaGHfD"
}
},
{
"cell_type": "code",
"source": [
"abc=AdaBoostClassifier(random_state=1)\n",
"\n",
"m=cv_recall(abc)\n",
"print(f'Cross-validated recall is {m}.')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "TyQNqccl-Llf",
"outputId": "a34fdf4e-053a-4e42-fa75-c49f9034f71d"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Cross-validated recall is 0.6079268292682926.\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"AdaBoost yields a lower CV recall of only about 61%, but we hope that boosting methods will be less susceptible to overfitting."
],
"metadata": {
"id": "tGre733m67Vh"
}
},
{
"cell_type": "code",
"source": [
"abc.fit(X_train,y_train)\n",
"\n",
"ch(abc)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 523
},
"id": "U967patu-Lim",
"outputId": "83821eb1-7f50-4ebf-db2b-ce64834e53b8"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Scores\n",
"Accuracy 0.975000\n",
"Precision 0.856115\n",
"Recall 0.652651\n",
"F1 0.740664"
]
},
"metadata": {},
"execution_count": 41
},
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"An accuracy of 97.5% and a precision of 86% are promising. Somewhat surprisingly, recall on the validation data (65%) is higher than the CV recall computed above (61%)."
],
"metadata": {
"id": "rkf-Axfm7IqC"
}
},
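{
"cell_type": "markdown",
"source": [
"One plausible explanation is ordinary sampling noise: each held-out fold contains only a few hundred failures, so recall moves noticeably from fold to fold, and `cv_recall` reports only the mean. Each CV model is also trained on just 80% of the training set. The sketch below, on synthetic data (an assumption), prints the per-fold recall scores and their spread; on the notebook's data, the same inspection could be done on the array that `cross_val_score` returns before it is averaged.\n",
"\n",
"```python\n",
"from sklearn.datasets import make_classification\n",
"from sklearn.ensemble import AdaBoostClassifier\n",
"from sklearn.model_selection import StratifiedKFold, cross_val_score\n",
"\n",
"# Synthetic imbalanced data (assumption)\n",
"X_demo, y_demo = make_classification(\n",
"    n_samples=3000, n_features=10, weights=[0.945], random_state=5\n",
")\n",
"\n",
"kfold = StratifiedKFold(n_splits=5, shuffle=True, random_state=2)\n",
"fold_recalls = cross_val_score(AdaBoostClassifier(random_state=1),\n",
"                               X_demo, y_demo, scoring='recall', cv=kfold)\n",
"\n",
"# Per-fold scores show how much a single recall estimate can move\n",
"print(fold_recalls.round(3))\n",
"print(f'mean {fold_recalls.mean():.3f}, std {fold_recalls.std():.3f}')\n",
"```"
],
"metadata": {}
},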
{
"cell_type": "code",
"source": [
"tabulate(abc,'AdaBoost',cvs=m)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 206
},
"id": "FhriF7wQGzc7",
"outputId": "797b8876-0ae6-410e-9320-63af982c5d4f"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Train Acc Val Acc Train Recall CV Recall Val Recall\n",
"dtree 1.000000 0.9724 1.000000 0.720732 0.776965\n",
"Logistic Regr 0.966967 0.9677 0.476220 0.478049 0.506399\n",
"Bagging Clfr 0.997233 0.9855 0.950610 0.712805 0.780622\n",
"Random Forest 1.000000 0.9883 1.000000 0.756707 0.797075\n",
"AdaBoost 0.973900 0.9750 0.626829 0.607927 0.652651"
]
},
"metadata": {},
"execution_count": 42
}
]
},
{
"cell_type": "markdown",
"source": [
"There is little evidence of overfitting here (training, CV, and validation recall all fall between 0.61 and 0.65), but performance is appreciably lower than that of every other model except logistic regression."
],
"metadata": {
"id": "pZObw3sJ7cSO"
}
},
{
"cell_type": "markdown",
"source": [
"#### Gradient Boosting"
],
"metadata": {
"id": "hb2wrOYqG5HD"
}
},
{
"cell_type": "code",
"source": [
"gbc=GradientBoostingClassifier(random_state=1)\n",
"\n",
"m=cv_recall(gbc)\n",
"print(f'Cross-validated recall is {m}.')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "E6bZifl_Gzau",
"outputId": "d4c1325a-4220-400f-e3ad-4e2fcfd22930"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Cross-validated recall is 0.7201219512195122.\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"Gradient boosting does much better than AdaBoost, with a CV recall of 72%."
],
"metadata": {
"id": "McWnrdVL8UWx"
}
},
{
"cell_type": "code",
"source": [
"gbc.fit(X_train,y_train)\n",
"\n",
"ch(gbc)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 523
},
"id": "E1aV6pkvGzYe",
"outputId": "3ea05536-4a0b-449a-ab6d-d170cdd45668"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Scores\n",
"Accuracy 0.984500\n",
"Precision 0.957944\n",
"Recall 0.749543\n",
"F1 0.841026"
]
},
"metadata": {},
"execution_count": 44
},
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"As with most other models, recall (75%) is the lowest of the four metrics above, well below accuracy (98%) and precision (96%)."
],
"metadata": {
"id": "gkDwUv4G8c8t"
}
},
{
"cell_type": "code",
"source": [
"tabulate(gbc,'Grad Boost',cvs=m)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 238
},
"id": "c8kPqLO_GzUS",
"outputId": "39045f81-081e-4f53-96d7-1ab003d083e9"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Train Acc Val Acc Train Recall CV Recall Val Recall\n",
"dtree 1.000000 0.9724 1.000000 0.720732 0.776965\n",
"Logistic Regr 0.966967 0.9677 0.476220 0.478049 0.506399\n",
"Bagging Clfr 0.997233 0.9855 0.950610 0.712805 0.780622\n",
"Random Forest 1.000000 0.9883 1.000000 0.756707 0.797075\n",
"AdaBoost 0.973900 0.9750 0.626829 0.607927 0.652651\n",
"Grad Boost 0.987500 0.9845 0.783537 0.720122 0.749543"
]
},
"metadata": {},
"execution_count": 45
}
]
},
{
"cell_type": "markdown",
"source": [
"So far, gradient boosting is one of the better models. The only concern is mild overfitting: training recall (0.78) is a few points higher than validation recall (0.75) and CV recall (0.72)."
],
"metadata": {
"id": "mPwZ5VhO9Vz_"
}
},
{
"cell_type": "markdown",
"source": [
"#### XGBoost"
],
"metadata": {
"id": "QxdSgzK3HmEw"
}
},
{
"cell_type": "code",
"source": [
"xgb=XGBClassifier(random_state=1)\n",
"\n",
"m=cv_recall(xgb)\n",
"print(f'Cross-validated recall is {m}.')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "cTIuu9t6HlzZ",
"outputId": "1a5cd0b6-0dae-408d-8840-8fc94af5b963"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Cross-validated recall is 0.7365853658536585.\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"A CV recall of 74% is comparable to the better-performing models above."
],
"metadata": {
"id": "-or99U2G9sFk"
}
},
{
"cell_type": "code",
"source": [
"xgb.fit(X_train,y_train)\n",
"\n",
"ch(xgb)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 523
},
"id": "rkl6XglWHlwq",
"outputId": "e95c7c8a-de8b-4f28-f108-3f13978959db"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Scores\n",
"Accuracy 0.985900\n",
"Precision 0.959276\n",
"Recall 0.775137\n",
"F1 0.857432"
]
},
"metadata": {},
"execution_count": 47
},
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"Recall at 77% is one of the best we've seen so far, and both precision and accuracy scores are quite high."
],
"metadata": {
"id": "AJVMnfKF9xvg"
}
},
{
"cell_type": "code",
"source": [
"tabulate(xgb,'XGBoost',cvs=m)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 269
},
"id": "fVZHfC23HltT",
"outputId": "667154ce-61ac-437f-d04e-403d4de4338d"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Train Acc Val Acc Train Recall CV Recall Val Recall\n",
"dtree 1.000000 0.9724 1.000000 0.720732 0.776965\n",
"Logistic Regr 0.966967 0.9677 0.476220 0.478049 0.506399\n",
"Bagging Clfr 0.997233 0.9855 0.950610 0.712805 0.780622\n",
"Random Forest 1.000000 0.9883 1.000000 0.756707 0.797075\n",
"AdaBoost 0.973900 0.9750 0.626829 0.607927 0.652651\n",
"Grad Boost 0.987500 0.9845 0.783537 0.720122 0.749543\n",
"XGBoost 0.987667 0.9859 0.786585 0.736585 0.775137"
]
},
"metadata": {},
"execution_count": 48
}
]
},
{
"cell_type": "markdown",
"source": [
"Note how much lower the CV recall (0.74) is than the training and validation recall (0.79 and 0.78). The latter two are close, so overfitting does not appear to be a concern here."
],
"metadata": {
"id": "LG-X6w5Y985s"
}
},
{
"cell_type": "markdown",
"metadata": {
"id": "oBKJaFU24jas"
},
"source": [
"### Model Building with Oversampled data\n"
]
},
{
"cell_type": "code",
"source": [
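"# Oversample the minority (failure) class with SMOTE: each synthetic point is\n",
"# interpolated between a failure case and one of its 5 nearest failure neighbors.\n",
"# sampling_strategy=1.0 grows the minority class until it matches the majority.\n",
"sm=SMOTE(\n",
" k_neighbors=5,\n",
" sampling_strategy=1.0,\n",
" random_state=1\n",
")\n",
"\n",
"# Resample the training set only; the validation data stays untouched\n",
"X_train_over,y_train_over=sm.fit_resample(X_train,y_train)"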
],
"metadata": {
"id": "3HYaCLWLYiy4"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"The cell above oversamples the minority class (failures) of the target variable with SMOTE, so the training data used for the models below is balanced 1:1."
],
"metadata": {
"id": "m86cYMl8-Mjp"
}
},
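{
"cell_type": "markdown",
"source": [
"For intuition, the cell below is a minimal NumPy sketch of the interpolation idea behind SMOTE: pick a minority-class point, choose one of its `k` nearest minority-class neighbors, and place a synthetic point at a random position on the line segment between them. The helper `smote_sketch` and the toy data are hypothetical; this is a simplified illustration, not imbalanced-learn's actual implementation."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"\n",
"def smote_sketch(X_min, n_new, k=5, seed=1):\n",
"    '''Hypothetical helper: build n_new synthetic points from minority samples X_min.'''\n",
"    rng = np.random.default_rng(seed)\n",
"    # Pairwise Euclidean distances between all minority-class points\n",
"    d = np.linalg.norm(X_min[:, None, :] - X_min[None, :, :], axis=2)\n",
"    np.fill_diagonal(d, np.inf)  # a point is never its own neighbor\n",
"    nn = np.argsort(d, axis=1)[:, :k]  # indices of each point's k nearest neighbors\n",
"    base = rng.integers(0, len(X_min), n_new)  # randomly chosen seed points\n",
"    nbr = nn[base, rng.integers(0, k, n_new)]  # one random neighbor per seed point\n",
"    gap = rng.random((n_new, 1))  # random position along each segment, in [0, 1)\n",
"    return X_min[base] + gap * (X_min[nbr] - X_min[base])\n",
"\n",
"# Toy minority class: 20 points in 3 dimensions\n",
"X_toy = np.random.default_rng(0).normal(size=(20, 3))\n",
"X_syn = smote_sketch(X_toy, n_new=50, k=5)\n",
"print(X_syn.shape)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},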
{
"cell_type": "markdown",
"source": [
"#### Decision Tree [Oversampled]"
],
"metadata": {
"id": "8g8_kos4LTuS"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "uYDlbnUO4jat",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "519df4a9-845e-4aec-d998-c82538160e1f"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Cross-validated recall is 0.9704866008462624.\n"
]
}
],
"source": [
"dtree_over=DecisionTreeClassifier(random_state=1)\n",
"\n",
"m=cv_recall(dtree_over,sample_strategy='over')\n",
"print(f'Cross-validated recall is {m}.')"
]
},
{
"cell_type": "markdown",
"source": [
"A CV recall of 97% is suspiciously high, so I am already skeptical of this model."
],
"metadata": {
"id": "EcU_G5r5_NsF"
}
},
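{
"cell_type": "markdown",
"source": [
"The 97% CV recall sits well above the 86% validation recall. The `cv_recall` helper is not shown in this section, so this is only a hypothesis: if SMOTE is applied before the training data is split into folds, synthetic failure points interpolated from real ones can land in the held-out fold and make it easier to predict, inflating the score. One common safeguard is to wrap the sampler and the model in an imbalanced-learn `Pipeline`, so that `cross_val_score` oversamples only the training part of each fold. The sketch below shows that pattern on toy data from `make_classification` (an assumption standing in for the sensor data); it does not reproduce this notebook's numbers."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"from imblearn.over_sampling import SMOTE\n",
"from imblearn.pipeline import Pipeline\n",
"from sklearn.datasets import make_classification\n",
"from sklearn.model_selection import StratifiedKFold, cross_val_score\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"\n",
"# Toy imbalanced data (roughly 5% positives) standing in for the sensor data\n",
"X_demo, y_demo = make_classification(\n",
"    n_samples=2000, n_features=10, weights=[0.95], random_state=1\n",
")\n",
"\n",
"# SMOTE is refit inside each training fold; the held-out fold is never resampled\n",
"pipe = Pipeline([\n",
"    ('smote', SMOTE(k_neighbors=5, sampling_strategy=1.0, random_state=1)),\n",
"    ('tree', DecisionTreeClassifier(random_state=1)),\n",
"])\n",
"\n",
"skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=1)\n",
"scores = cross_val_score(pipe, X_demo, y_demo, cv=skf, scoring='recall')\n",
"print(f'Fold recalls: {scores.round(3)}, mean: {scores.mean():.3f}')"
],
"metadata": {},
"execution_count": null,
"outputs": []
},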
{
"cell_type": "code",
"source": [
"dtree_over.fit(X_train_over,y_train_over)\n",
"\n",
"ch(dtree_over)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 523
},
"id": "Plzs-TUpLkT5",
"outputId": "cea00b1d-8bbb-45f1-fe5d-648d600c4e7b"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Scores\n",
"Accuracy 0.951000\n",
"Precision 0.532423\n",
"Recall 0.855576\n",
"F1 0.656381"
]
},
"metadata": {},
"execution_count": 51
},
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
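"Unlike most of the previous models, recall is higher than precision here. This is very likely an effect of the oversampling: training on a balanced set makes the tree far more willing to predict failure, which trades precision for recall."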
],
"metadata": {
"id": "NQ9GXT0H_VU7"
}
},
{
"cell_type": "code",
"source": [
"tabulate(dtree_over,'dtree (over)',sample='over',cvs=m)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 300
},
"id": "8mpBzmayLkQe",
"outputId": "b0831d90-ca04-4c98-b9d0-387d2cf4b3aa"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Train Acc Val Acc Train Recall CV Recall Val Recall\n",
"dtree 1.000000 0.9724 1.000000 0.720732 0.776965\n",
"Logistic Regr 0.966967 0.9677 0.476220 0.478049 0.506399\n",
"Bagging Clfr 0.997233 0.9855 0.950610 0.712805 0.780622\n",
"Random Forest 1.000000 0.9883 1.000000 0.756707 0.797075\n",
"AdaBoost 0.973900 0.9750 0.626829 0.607927 0.652651\n",
"Grad Boost 0.987500 0.9845 0.783537 0.720122 0.749543\n",
"XGBoost 0.987667 0.9859 0.786585 0.736585 0.775137\n",
"dtree (over) 1.000000 0.9510 1.000000 0.970487 0.855576"
]
},
"metadata": {},
"execution_count": 52
}
]
},
{
"cell_type": "markdown",
"source": [
"Unsurprisingly, though, this model is heavily overfit: it scores perfectly on the training data but reaches only 0.86 recall on validation."
],
"metadata": {
"id": "nWyZZKyd_fkX"
}
},
{
"cell_type": "markdown",
"source": [
"#### Logistic Regression [Oversampled]"
],
"metadata": {
"id": "RYTJbehCL--f"
}
},
{
"cell_type": "code",
"source": [
"lr_over=LogisticRegression()\n",
"\n",
"m=cv_recall(lr_over,sample_strategy='over')\n",
"print(f'Cross-validated recall is {m}.')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "FoIZVa0_LkHK",
"outputId": "6418791d-2d5a-4183-f3f9-52ce1a350643"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Cross-validated recall is 0.866784203102962.\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"A CV recall score of 87% is one of the highest we've seen so far, but isn't so high as to immediately suggest overfitting."
],
"metadata": {
"id": "BdBBBXpN_o98"
}
},
{
"cell_type": "code",
"source": [
"lr_over.fit(X_train_over,y_train_over)\n",
"\n",
"ch(lr_over)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 523
},
"id": "4siCEcqELkDc",
"outputId": "8f56c546-fbb6-4ac3-bfc2-e9830149a7bf"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Scores\n",
"Accuracy 0.871800\n",
"Precision 0.280072\n",
"Recall 0.855576\n",
"F1 0.422002"
]
},
"metadata": {},
"execution_count": 54
},
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
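"Recall is good with this logistic regression model, and accuracy (87%) looks respectable, although it is below the roughly 94.5% a model would get by always predicting no failure. The real cost is disastrous precision: only 28% of the turbines flagged as failing actually fail. That is still better than flagging turbines at random, which would give a precision equal to the failure rate (about 5.5% of the validation set), but it means most alarms would trigger an unnecessary inspection."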
],
"metadata": {
"id": "FiFnOsrYABRX"
}
},
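{
"cell_type": "markdown",
"source": [
"To put the 28% precision in context, the cell below backs out the validation confusion matrix from the reported scores. It assumes the validation set has 10,000 rows containing 547 failures; these counts are not printed in this section but are inferred from the scores (for example, 468 / 547 ≈ 0.8556 recall), so treat them as an assumption. A classifier that flags turbines at random has an expected precision equal to the failure rate, which makes that the natural baseline for precision."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Assumed validation-set composition (inferred from the reported scores)\n",
"n_val, n_pos = 10_000, 547\n",
"n_neg = n_val - n_pos\n",
"\n",
"# Reported validation recall and precision for the oversampled logistic regression\n",
"val_recall, val_precision = 0.855576, 0.280072\n",
"\n",
"tp = round(val_recall * n_pos)       # failures correctly flagged\n",
"fn = n_pos - tp                      # failures missed\n",
"fp = round(tp / val_precision) - tp  # healthy turbines flagged anyway\n",
"tn = n_neg - fp                      # healthy turbines correctly left alone\n",
"\n",
"val_accuracy = (tp + tn) / n_val\n",
"random_precision = n_pos / n_val       # expected precision of flagging at random\n",
"all_negative_accuracy = n_neg / n_val  # accuracy of always predicting no failure\n",
"\n",
"print(f'TP={tp}, FN={fn}, FP={fp}, TN={tn}, accuracy={val_accuracy:.4f}')\n",
"print(f'random-flagging precision={random_precision:.4f}')\n",
"print(f'always-no-failure accuracy={all_negative_accuracy:.4f}')"
],
"metadata": {},
"execution_count": null,
"outputs": []
},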
{
"cell_type": "code",
"source": [
"tabulate(lr_over,'Logistic Regr (over)',sample='over',cvs=m)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 332
},
"id": "A7b0cgXJLkAZ",
"outputId": "09abb894-de36-445f-e6cb-1b44a4beee6c"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Train Acc Val Acc Train Recall CV Recall Val Recall\n",
"dtree 1.000000 0.9724 1.000000 0.720732 0.776965\n",
"Logistic Regr 0.966967 0.9677 0.476220 0.478049 0.506399\n",
"Bagging Clfr 0.997233 0.9855 0.950610 0.712805 0.780622\n",
"Random Forest 1.000000 0.9883 1.000000 0.756707 0.797075\n",
"AdaBoost 0.973900 0.9750 0.626829 0.607927 0.652651\n",
"Grad Boost 0.987500 0.9845 0.783537 0.720122 0.749543\n",
"XGBoost 0.987667 0.9859 0.786585 0.736585 0.775137\n",
"dtree (over) 1.000000 0.9510 1.000000 0.970487 0.855576\n",
"Logistic Regr (over) 0.867331 0.8718 0.866890 0.866784 0.855576"
]
},
"metadata": {},
"execution_count": 55
}
]
},
{
"cell_type": "markdown",
"source": [
"There is no real evidence of overfitting: training, CV, and validation recall all sit between 0.86 and 0.87, and accuracy is consistent between training and validation, so this model looks solid. Its only real downside is its poor precision."
],
"metadata": {
"id": "HAxqzul3ALvm"
}
},
{
"cell_type": "markdown",
"source": [
"#### Bagging Classifier [Oversampled]"
],
"metadata": {
"id": "hvBnxtfghFc1"
}
},
{
"cell_type": "code",
"source": [
"bag_over=BaggingClassifier(random_state=1)\n",
"\n",
"m=cv_recall(bag_over,sample_strategy='over')\n",
"print(f'Cross-validated recall is {m}.')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "bn_LZ5IELjzr",
"outputId": "55b040de-cfe1-4a69-edc6-9db80e9b4b74"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Cross-validated recall is 0.9737658674188999.\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"As with the oversampled decision tree, this CV recall (97%) is high enough to raise suspicion of overfitting."
],
"metadata": {
"id": "5SbQmkuwAwYF"
}
},
{
"cell_type": "code",
"source": [
"bag_over.fit(X_train_over,y_train_over)\n",
"\n",
"ch(bag_over)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 523
},
"id": "gHA9A7jJhMCX",
"outputId": "7d274f2c-ff1c-40b2-cd61-5fb05041efb3"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Scores\n",
"Accuracy 0.985600\n",
"Precision 0.854130\n",
"Recall 0.888483\n",
"F1 0.870968"
]
},
"metadata": {},
"execution_count": 57
},
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"Scores this strong across the board are promising, but they could also be a symptom of overfitting."
],
"metadata": {
"id": "2KlZ_ucsBAx3"
}
},
{
"cell_type": "code",
"source": [
"tabulate(bag_over,'Bagging (over)',sample='over',cvs=m)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 363
},
"id": "DkwNoGLBhMAO",
"outputId": "5f7c8da4-0905-414b-8b7d-7108451d9176"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Train Acc Val Acc Train Recall CV Recall Val Recall\n",
"dtree 1.000000 0.9724 1.000000 0.720732 0.776965\n",
"Logistic Regr 0.966967 0.9677 0.476220 0.478049 0.506399\n",
"Bagging Clfr 0.997233 0.9855 0.950610 0.712805 0.780622\n",
"Random Forest 1.000000 0.9883 1.000000 0.756707 0.797075\n",
"AdaBoost 0.973900 0.9750 0.626829 0.607927 0.652651\n",
"Grad Boost 0.987500 0.9845 0.783537 0.720122 0.749543\n",
"XGBoost 0.987667 0.9859 0.786585 0.736585 0.775137\n",
"dtree (over) 1.000000 0.9510 1.000000 0.970487 0.855576\n",
"Logistic Regr (over) 0.867331 0.8718 0.866890 0.866784 0.855576\n",
"Bagging (over) 0.999418 0.9856 0.998907 0.973766 0.888483"
]
},
"metadata": {},
"execution_count": 58
}
]
},
{
"cell_type": "markdown",
"source": [
"Sure enough, this model is overfit: training recall is nearly perfect (0.999), while validation recall drops to 0.888."
],
"metadata": {
"id": "xIqA4AjlBH-K"
}
},
{
"cell_type": "markdown",
"source": [
"#### Random Forest [Oversampled]"
],
"metadata": {
"id": "VsHvSIUXhlew"
}
},
{
"cell_type": "code",
"source": [
"rf_over=RandomForestClassifier(random_state=1)\n",
"\n",
"m=cv_recall(rf_over,sample_strategy='over')\n",
"print(f'Cross-validated recall is {m}.')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "qFAYuMJthL9Q",
"outputId": "3bcc69ff-fc98-47aa-a947-897b61b0fe55"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Cross-validated recall is 0.9824047954866009.\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"Another suspiciously high CV recall here: 98%."
],
"metadata": {
"id": "fNWaB7U5FZgR"
}
},
{
"cell_type": "code",
"source": [
"rf_over.fit(X_train_over,y_train_over)\n",
"\n",
"ch(rf_over)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 523
},
"id": "9O5cNvT6hL6s",
"outputId": "164087e2-487f-420c-96ea-6e8a40cccc67"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Scores\n",
"Accuracy 0.991300\n",
"Precision 0.945736\n",
"Recall 0.892139\n",
"F1 0.918156"
]
},
"metadata": {},
"execution_count": 60
},
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"Again, these scores would be excellent if they were believable, but we should first check the table below to assess whether this model is overfit."
],
"metadata": {
"id": "84O1hUjzFgKV"
}
},
{
"cell_type": "code",
"source": [
"tabulate(rf_over,'Rand Forest (over)',sample='over',cvs=m)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 394
},
"id": "BIgg9Q-ohL4M",
"outputId": "c6e41fc2-1fc0-4f27-9904-6d64b0d58207"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Train Acc Val Acc Train Recall CV Recall Val Recall\n",
"dtree 1.000000 0.9724 1.000000 0.720732 0.776965\n",
"Logistic Regr 0.966967 0.9677 0.476220 0.478049 0.506399\n",
"Bagging Clfr 0.997233 0.9855 0.950610 0.712805 0.780622\n",
"Random Forest 1.000000 0.9883 1.000000 0.756707 0.797075\n",
"AdaBoost 0.973900 0.9750 0.626829 0.607927 0.652651\n",
"Grad Boost 0.987500 0.9845 0.783537 0.720122 0.749543\n",
"XGBoost 0.987667 0.9859 0.786585 0.736585 0.775137\n",
"dtree (over) 1.000000 0.9510 1.000000 0.970487 0.855576\n",
"Logistic Regr (over) 0.867331 0.8718 0.866890 0.866784 0.855576\n",
"Bagging (over) 0.999418 0.9856 0.998907 0.973766 0.888483\n",
"Rand Forest (over) 1.000000 0.9913 1.000000 0.982405 0.892139"
]
},
"metadata": {},
"execution_count": 61
}
]
},
{
"cell_type": "markdown",
"source": [
"Indeed, the gap between training recall (1.00) and validation recall (0.89) confirms overfitting."
],
"metadata": {
"id": "y3-u4mssFrAe"
}
},
{
"cell_type": "markdown",
"source": [
"#### AdaBoost [Oversampled]"
],
"metadata": {
"id": "HC8mY_eyh5Po"
}
},
{
"cell_type": "code",
"source": [
"abc_over=AdaBoostClassifier(random_state=1)\n",
"\n",
"m=cv_recall(abc_over,sample_strategy='over')\n",
"print(f'Cross-validated recall is {m}.')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "CJeRL5H3hL1e",
"outputId": "97918c54-3814-450c-8ace-31dbc361857e"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Cross-validated recall is 0.8840267983074753.\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"Gradient boosting and XGBoost did well on the original data, so a CV recall of 88% is promising, and a big improvement on the 61% AdaBoost managed without oversampling."
],
"metadata": {
"id": "WduXcceDG_9h"
}
},
{
"cell_type": "code",
"source": [
"abc_over.fit(X_train_over,y_train_over)\n",
"\n",
"ch(abc_over)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 523
},
"id": "QOuQbKRPhLy_",
"outputId": "9534cefc-0852-4de5-9de3-746159128a7a"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Scores\n",
"Accuracy 0.907600\n",
"Precision 0.359223\n",
"Recall 0.879342\n",
"F1 0.510074"
]
},
"metadata": {},
"execution_count": 63
},
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"Here the price of good recall is poor precision (36%). Indeed, this AdaBoost model produces nearly twice as many false positives as true positives (see the confusion matrix)."
],
"metadata": {
"id": "Gc4q0iZfHG6h"
}
},
{
"cell_type": "code",
"source": [
"tabulate(abc_over,'AdaBoost (over)',sample='over',cvs=m)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 426
},
"id": "rKuEGk0PhLwk",
"outputId": "b411bd21-1d9f-4e35-f5c0-e76cc6302ce8"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Train Acc Val Acc Train Recall CV Recall Val Recall\n",
"dtree 1.000000 0.9724 1.000000 0.720732 0.776965\n",
"Logistic Regr 0.966967 0.9677 0.476220 0.478049 0.506399\n",
"Bagging Clfr 0.997233 0.9855 0.950610 0.712805 0.780622\n",
"Random Forest 1.000000 0.9883 1.000000 0.756707 0.797075\n",
"AdaBoost 0.973900 0.9750 0.626829 0.607927 0.652651\n",
"Grad Boost 0.987500 0.9845 0.783537 0.720122 0.749543\n",
"XGBoost 0.987667 0.9859 0.786585 0.736585 0.775137\n",
"dtree (over) 1.000000 0.9510 1.000000 0.970487 0.855576\n",
"Logistic Regr (over) 0.867331 0.8718 0.866890 0.866784 0.855576\n",
"Bagging (over) 0.999418 0.9856 0.998907 0.973766 0.888483\n",
"Rand Forest (over) 1.000000 0.9913 1.000000 0.982405 0.892139\n",
"AdaBoost (over) 0.898554 0.9076 0.886142 0.884027 0.879342"
]
},
"metadata": {},
"execution_count": 64
}
]
},
{
"cell_type": "markdown",
"source": [
"This boosting model offers one of the highest recall scores we have seen without overfitting: training, CV, and validation recall all sit close to 0.88."
],
"metadata": {
"id": "7SycUJSXHYRf"
}
},
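{
"cell_type": "markdown",
"source": [
"To make the overfitting comparisons easier to scan, the cell below rebuilds the oversampled rows of the results table (values copied from the output above) and computes the gap between training and validation recall. The `gap_table` name and the 0.05 cut-off used to flag a model are arbitrary illustrative choices, not a standard threshold."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import pandas as pd\n",
"\n",
"# Training and validation recall, copied from the results table above\n",
"gap_table = pd.DataFrame(\n",
"    {\n",
"        'Train Recall': [1.000000, 0.866890, 0.998907, 1.000000, 0.886142],\n",
"        'Val Recall': [0.855576, 0.855576, 0.888483, 0.892139, 0.879342],\n",
"    },\n",
"    index=['dtree (over)', 'Logistic Regr (over)', 'Bagging (over)',\n",
"           'Rand Forest (over)', 'AdaBoost (over)'],\n",
")\n",
"\n",
"# A large positive gap means the model does much better on the data it was trained on\n",
"# than on unseen data, which is the overfitting signal discussed above\n",
"gap_table['Recall Gap'] = (gap_table['Train Recall'] - gap_table['Val Recall']).round(3)\n",
"gap_table['Overfit?'] = gap_table['Recall Gap'] > 0.05  # arbitrary illustrative cut-off\n",
"\n",
"gap_table.sort_values('Recall Gap', ascending=False)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},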
{
"cell_type": "markdown",
"source": [
"#### Gradient Boosting [Oversampled]"
],
"metadata": {
"id": "nOnFdTTQilK4"
}
},
{
"cell_type": "code",
"source": [
"gbc_over=GradientBoostingClassifier(random_state=1)\n",
"\n",
"m=cv_recall(gbc_over,sample_strategy='over')\n",
"print(f'Cross-validated recall is {m}.')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "XFz6QLlchLto",
"outputId": "1c2cb410-fa12-48b7-c6d1-0c40cfbe297f"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Cross-validated recall is 0.9094146685472497.\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"As with the original data, gradient boosting beats AdaBoost on CV recall, reaching about 91% versus 88%."
],
"metadata": {
"id": "eiC-vW2iH3b5"
}
},
{
"cell_type": "code",
"source": [
"gbc_over.fit(X_train_over,y_train_over)\n",
"\n",
"ch(gbc_over)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 523
},
"id": "-Gcpgqz7hLqt",
"outputId": "713ef537-ad06-4837-924d-d662cb3c325c"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Scores\n",
"Accuracy 0.967100\n",
"Precision 0.641927\n",
"Recall 0.901280\n",
"F1 0.749810"
]
},
"metadata": {},
"execution_count": 66
},
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"Here too we have good performance on all four metrics. Compared with the oversampled AdaBoost model, precision is much better (64% versus 36%), so nearly two in three flagged turbines actually fail."
],
"metadata": {
"id": "HTeQNZbeIEdq"
}
},
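{
"cell_type": "markdown",
"source": [
"The costs listed at the top of the notebook suggest judging models by expected maintenance spend rather than by recall alone. The cell below is a rough sketch under two assumptions of ours: (1) a correctly predicted failure leads to a repair (\\$15,000), a missed failure to a replacement (\\$40,000), and a false alarm to an inspection (\\$5,000), with correct no-failure predictions costing nothing; and (2) the validation set contains 547 failures, a count inferred from the reported scores. The helper `maintenance_cost` is hypothetical, and the resulting ranking depends entirely on that assumed mapping."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# ASSUMED cost per outcome, based on the costs listed at the top of the notebook:\n",
"# caught failure -> repair, missed failure -> replacement, false alarm -> inspection\n",
"COST = {'tp': 15_000, 'fn': 40_000, 'fp': 5_000}\n",
"N_FAILURES = 547  # assumed failures in the validation set (inferred from the scores)\n",
"\n",
"# Reported validation (recall, precision) for a few of the models above\n",
"val_scores = {\n",
"    'XGBoost': (0.775137, 0.959276),\n",
"    'Logistic Regr (over)': (0.855576, 0.280072),\n",
"    'AdaBoost (over)': (0.879342, 0.359223),\n",
"    'Grad Boost (over)': (0.901280, 0.641927),\n",
"}\n",
"\n",
"def maintenance_cost(recall, precision, n_pos=N_FAILURES):\n",
"    '''Hypothetical helper: total validation-set cost under the assumed mapping.'''\n",
"    tp = round(recall * n_pos)       # failures caught -> repaired\n",
"    fn = n_pos - tp                  # failures missed -> replaced\n",
"    fp = round(tp / precision) - tp  # false alarms -> inspected\n",
"    return tp * COST['tp'] + fn * COST['fn'] + fp * COST['fp']\n",
"\n",
"costs = {name: maintenance_cost(r, p) for name, (r, p) in val_scores.items()}\n",
"for name, cost in sorted(costs.items(), key=lambda kv: kv[1]):\n",
"    print(f'{name:<22} ${cost:,}')"
],
"metadata": {},
"execution_count": null,
"outputs": []
},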
{
"cell_type": "code",
"source": [
"tabulate(gbc_over,'Grad Boost (over)',sample='over',cvs=m)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 457
},
"id": "QTWCyQ6DjBLk",
"outputId": "5aeeb7b4-0735-475c-c640-b8d7ae9035e2"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Train Acc Val Acc Train Recall CV Recall Val Recall\n",
"dtree 1.000000 0.9724 1.000000 0.720732 0.776965\n",
"Logistic Regr 0.966967 0.9677 0.476220 0.478049 0.506399\n",
"Bagging Clfr 0.997233 0.9855 0.950610 0.712805 0.780622\n",
"Random Forest 1.000000 0.9883 1.000000 0.756707 0.797075\n",
"AdaBoost 0.973900 0.9750 0.626829 0.607927 0.652651\n",
"Grad Boost 0.987500 0.9845 0.783537 0.720122 0.749543\n",
"XGBoost 0.987667 0.9859 0.786585 0.736585 0.775137\n",
"dtree (over) 1.000000 0.9510 1.000000 0.970487 0.855576\n",
"Logistic Regr (over) 0.867331 0.8718 0.866890 0.866784 0.855576\n",
"Bagging (over) 0.999418 0.9856 0.998907 0.973766 0.888483\n",
"Rand Forest (over) 1.000000 0.9913 1.000000 0.982405 0.892139\n",
"AdaBoost (over) 0.898554 0.9076 0.886142 0.884027 0.879342\n",
"Grad Boost (over) 0.942719 0.9671 0.914245 0.909415 0.901280"
]
},
"metadata": {},
"execution_count": 67
}
]
},
{
"cell_type": "markdown",
"source": [
"As we have seen with the other boosting models, this model shows little sign of overfitting: training, cross-validated, and validation recall all sit around 90%."
],
"metadata": {
"id": "GUx5gZh2Ia3x"
}
},
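{
"cell_type": "markdown",
"source": [
"To make the overfitting comparisons in this notebook concrete, one can compute the gap between training recall and validation recall for each model. The sketch below hard-codes a few rows from the results table above. The 0.05 cut-off is an arbitrary illustrative threshold, not something the analysis defines, and `recall_gap` is a hypothetical helper.\n",
"\n",
"```python\n",
"# (train recall, validation recall) copied from the results table above\n",
"scores = {\n",
"    'dtree': (1.000000, 0.776965),\n",
"    'Random Forest': (1.000000, 0.797075),\n",
"    'Bagging (over)': (0.998907, 0.888483),\n",
"    'Rand Forest (over)': (1.000000, 0.892139),\n",
"    'AdaBoost (over)': (0.886142, 0.879342),\n",
"    'Grad Boost (over)': (0.914245, 0.901280),\n",
"}\n",
"\n",
"GAP_THRESHOLD = 0.05  # illustrative cut-off, not part of the original analysis\n",
"\n",
"def recall_gap(train_recall, val_recall):\n",
"    # A large positive gap means the model fits the training data far better\n",
"    # than it generalises to unseen validation data\n",
"    return train_recall - val_recall\n",
"\n",
"# List models from largest to smallest gap\n",
"for name, (train, val) in sorted(scores.items(), key=lambda kv: recall_gap(*kv[1]), reverse=True):\n",
"    gap = recall_gap(train, val)\n",
"    flag = 'possible overfit' if gap > GAP_THRESHOLD else 'ok'\n",
"    print(f'{name:<20} gap={gap:.3f}  {flag}')\n",
"```\n",
"\n",
"The boosting models trained on oversampled data show gaps of about 0.01, while the tree ensembles sit above 0.1."
],
"metadata": {}
},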
{
"cell_type": "markdown",
"source": [
"#### XGBoost [Oversampled]"
],
"metadata": {
"id": "cMFgSuvSjFqJ"
}
},
{
"cell_type": "code",
"source": [
"xgb_over=XGBClassifier(random_state=1)\n",
"\n",
"m=cv_recall(xgb_over,sample_strategy='over')\n",
"print(f'Cross-validated recall is {m}.')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "CwWtHJyZjBI7",
"outputId": "14869cd2-b7e0-4262-e4e6-6798f2bf750f"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Cross-validated recall is 0.9058180535966149.\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"XGBoost's cross-validated recall of about 91% is unsurprising, given how closely its performance has tracked gradient boosting."
],
"metadata": {
"id": "xwzprLXKJA_x"
}
},
{
"cell_type": "code",
"source": [
"xgb_over.fit(X_train_over,y_train_over)\n",
"\n",
"ch(xgb_over)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 523
},
"id": "4WXAf21WjBGS",
"outputId": "63e67174-4078-40a3-83c1-ef4c8208fecc"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Scores\n",
"Accuracy 0.967100\n",
"Precision 0.640464\n",
"Recall 0.908592\n",
"F1 0.751323"
]
},
"metadata": {},
"execution_count": 69
},
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"Good or great performance on all metrics, with more true positives than either false positives or false negatives."
],
"metadata": {
"id": "hNaBZlfOJZhe"
}
},
{
"cell_type": "code",
"source": [
"tabulate(xgb_over,'XGBoost (over)',sample='over',cvs=m)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 488
},
"id": "9APwSDTBjBDS",
"outputId": "f9977957-0f44-4357-8cc0-bf21534d0440"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Train Acc Val Acc Train Recall CV Recall Val Recall\n",
"dtree 1.000000 0.9724 1.000000 0.720732 0.776965\n",
"Logistic Regr 0.966967 0.9677 0.476220 0.478049 0.506399\n",
"Bagging Clfr 0.997233 0.9855 0.950610 0.712805 0.780622\n",
"Random Forest 1.000000 0.9883 1.000000 0.756707 0.797075\n",
"AdaBoost 0.973900 0.9750 0.626829 0.607927 0.652651\n",
"Grad Boost 0.987500 0.9845 0.783537 0.720122 0.749543\n",
"XGBoost 0.987667 0.9859 0.786585 0.736585 0.775137\n",
"dtree (over) 1.000000 0.9510 1.000000 0.970487 0.855576\n",
"Logistic Regr (over) 0.867331 0.8718 0.866890 0.866784 0.855576\n",
"Bagging (over) 0.999418 0.9856 0.998907 0.973766 0.888483\n",
"Rand Forest (over) 1.000000 0.9913 1.000000 0.982405 0.892139\n",
"AdaBoost (over) 0.898554 0.9076 0.886142 0.884027 0.879342\n",
"Grad Boost (over) 0.942719 0.9671 0.914245 0.909415 0.901280\n",
"XGBoost (over) 0.939774 0.9671 0.906946 0.905818 0.908592"
]
},
"metadata": {},
"execution_count": 70
}
]
},
{
"cell_type": "markdown",
"source": [
"Looking back over the table so far, the gradient boosting and XGBoost models trained on oversampled data have performed best. The random forest shows promising precision, if only we could curtail its overfitting."
],
"metadata": {
"id": "9nxuqeGJJ0HC"
}
},
{
"cell_type": "markdown",
"metadata": {
"id": "1aimb6bn4jat"
},
"source": [
"### Model Building with Undersampled data"
]
},
{
"cell_type": "code",
"source": [
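"# Randomly drop majority-class rows until both classes are the same size (1:1 ratio)\n",
"rus=RandomUnderSampler(\n",
"    sampling_strategy=1.0,\n",
"    random_state=1\n",
")\n",
"\n",
"# Assign to y_train_under so the original y_train is not overwritten\n",
"X_train_under,y_train_under=rus.fit_resample(X_train,y_train)"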
],
"metadata": {
"id": "0Ws_HcJmZTdM"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Now we undersample the majority class of the target variable. This is another way to balance the classes in the target, this time by discarding majority-class rows rather than adding minority-class ones."
],
"metadata": {
"id": "RPwnp_r8KLK4"
}
},
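{
"cell_type": "markdown",
"source": [
"The idea behind `RandomUnderSampler(sampling_strategy=1.0)` can be sketched in plain Python. Keep every minority-class row, then draw an equal number of majority-class rows without replacement, which gives a 1:1 class ratio. This is a simplified illustration for binary labels, not imbalanced-learn's actual implementation. The function name `random_undersample` and the toy data are hypothetical.\n",
"\n",
"```python\n",
"import random\n",
"from collections import Counter\n",
"\n",
"def random_undersample(X, y, seed=1):\n",
"    # Simplified binary-label sketch of random undersampling to a 1:1 ratio\n",
"    counts = Counter(y)\n",
"    minority = min(counts, key=counts.get)\n",
"    n_keep = counts[minority]\n",
"\n",
"    idx_by_class = {label: [i for i, lab in enumerate(y) if lab == label] for label in counts}\n",
"    rng = random.Random(seed)\n",
"\n",
"    keep = []\n",
"    for label, idx in idx_by_class.items():\n",
"        # Keep all minority rows; sample the majority class down to the minority size\n",
"        keep.extend(idx if label == minority else rng.sample(idx, n_keep))\n",
"    keep.sort()  # preserve the original row order\n",
"\n",
"    return [X[i] for i in keep], [y[i] for i in keep]\n",
"\n",
"# Hypothetical toy data: 2 failures among 10 turbines\n",
"X = [[i] for i in range(10)]\n",
"y = [0, 0, 1, 0, 0, 0, 1, 0, 0, 0]\n",
"X_res, y_res = random_undersample(X, y)\n",
"print(Counter(y_res))\n",
"```\n",
"\n",
"Undersampling cuts the majority class down to the size of the minority class, so the training set shrinks considerably. This loss of data is the main trade-off of undersampling."
],
"metadata": {}
},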
{
"cell_type": "markdown",
"source": [
"#### Decision Tree [Undersampled]"
],
"metadata": {
"id": "cl6_o_XtjzBd"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jROP_DVF4jau",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "64a9c306-18c0-4fc1-c4f8-080f3c3276cf"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Cross-validated recall is 0.8530487804878047.\n"
]
}
],
"source": [
"dtree_under=DecisionTreeClassifier(random_state=1)\n",
"\n",
"m=cv_recall(dtree_under,sample_strategy='under')\n",
"print(f'Cross-validated recall is {m}.')"
]
},
{
"cell_type": "markdown",
"source": [
"A score of 85% is good for this estimator, but since previous decision trees had a tendency to overfit, I will wait to assess performance until I see the other metrics."
],
"metadata": {
"id": "Bt_xyMAaIVn1"
}
},
{
"cell_type": "code",
"source": [
"dtree_under.fit(X_train_under,y_train_under)\n",
"\n",
"ch(dtree_under)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 523
},
"id": "XONzrLVroShO",
"outputId": "b71169cb-1312-4aaa-e992-40bbeccc191e"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Scores\n",
"Accuracy 0.843900\n",
"Precision 0.241853\n",
"Recall 0.868373\n",
"F1 0.378335"
]
},
"metadata": {},
"execution_count": 73
},
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"This decision tree has the worst precision of any model so far (note the large number of false positives). Recall and accuracy, however, are good."
],
"metadata": {
"id": "kRLLQC7LIvjM"
}
},
{
"cell_type": "code",
"source": [
"tabulate(dtree_under,'dtree (under)',sample='under',cvs=m)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 520
},
"id": "UdO3Ql_xoSSi",
"outputId": "ee346d22-e085-436c-ce2f-b8e2d67e9cee"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Train Acc Val Acc Train Recall CV Recall Val Recall\n",
"dtree 1.000000 0.9724 1.000000 0.720732 0.776965\n",
"Logistic Regr 0.966967 0.9677 0.476220 0.478049 0.506399\n",
"Bagging Clfr 0.997233 0.9855 0.950610 0.712805 0.780622\n",
"Random Forest 1.000000 0.9883 1.000000 0.756707 0.797075\n",
"AdaBoost 0.973900 0.9750 0.626829 0.607927 0.652651\n",
"Grad Boost 0.987500 0.9845 0.783537 0.720122 0.749543\n",
"XGBoost 0.987667 0.9859 0.786585 0.736585 0.775137\n",
"dtree (over) 1.000000 0.9510 1.000000 0.970487 0.855576\n",
"Logistic Regr (over) 0.867331 0.8718 0.866890 0.866784 0.855576\n",
"Bagging (over) 0.999418 0.9856 0.998907 0.973766 0.888483\n",
"Rand Forest (over) 1.000000 0.9913 1.000000 0.982405 0.892139\n",
"AdaBoost (over) 0.898554 0.9076 0.886142 0.884027 0.879342\n",
"Grad Boost (over) 0.942719 0.9671 0.914245 0.909415 0.901280\n",
"XGBoost (over) 0.939774 0.9671 0.906946 0.905818 0.908592\n",
"dtree (under) 1.000000 0.8439 1.000000 0.853049 0.868373"
]
},
"metadata": {},
"execution_count": 74
}
]
},
{
"cell_type": "markdown",
"source": [
"As with the other decision trees, this model overfits disastrously."
],
"metadata": {
"id": "rVc_hSeXJE1S"
}
},
{
"cell_type": "markdown",
"source": [
"#### Logistic Regression [Undersampled]"
],
"metadata": {
"id": "Zg3UAbHorCvw"
}
},
{
"cell_type": "code",
"source": [
"lr_under=LogisticRegression()\n",
"\n",
"m=cv_recall(lr_under,sample_strategy='under')\n",
"print(f'Cross-validated recall is {m}.')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "FDfoFwvErGEo",
"outputId": "45209c77-63dd-46f6-b7ac-5635912d8aff"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Cross-validated recall is 0.8445121951219512.\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"At 84%, this score is comparable to the last model's and promising!"
],
"metadata": {
"id": "foBps5PxJ-pO"
}
},
{
"cell_type": "code",
"source": [
"lr_under.fit(X_train_under,y_train_under)\n",
"\n",
"ch(lr_under)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 523
},
"id": "wWkfkLpjrGCB",
"outputId": "8d2ac6b7-7c9e-4297-ba26-484be1877914"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Scores\n",
"Accuracy 0.869900\n",
"Precision 0.277450\n",
"Recall 0.859232\n",
"F1 0.419456"
]
},
"metadata": {},
"execution_count": 76
},
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"Again, the large number of false positives cuts into precision (and thus the F1 score). Accuracy and recall are good, though."
],
"metadata": {
"id": "zJTGuPKbKH-z"
}
},
{
"cell_type": "code",
"source": [
"tabulate(lr_under,'Logistic Regr (under)',sample='under',cvs=m)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 551
},
"id": "0Ws2455RrF_V",
"outputId": "d5f433c5-846d-41ee-f11b-09fddfd9363e"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Train Acc Val Acc Train Recall CV Recall Val Recall\n",
"dtree 1.000000 0.9724 1.000000 0.720732 0.776965\n",
"Logistic Regr 0.966967 0.9677 0.476220 0.478049 0.506399\n",
"Bagging Clfr 0.997233 0.9855 0.950610 0.712805 0.780622\n",
"Random Forest 1.000000 0.9883 1.000000 0.756707 0.797075\n",
"AdaBoost 0.973900 0.9750 0.626829 0.607927 0.652651\n",
"Grad Boost 0.987500 0.9845 0.783537 0.720122 0.749543\n",
"XGBoost 0.987667 0.9859 0.786585 0.736585 0.775137\n",
"dtree (over) 1.000000 0.9510 1.000000 0.970487 0.855576\n",
"Logistic Regr (over) 0.867331 0.8718 0.866890 0.866784 0.855576\n",
"Bagging (over) 0.999418 0.9856 0.998907 0.973766 0.888483\n",
"Rand Forest (over) 1.000000 0.9913 1.000000 0.982405 0.892139\n",
"AdaBoost (over) 0.898554 0.9076 0.886142 0.884027 0.879342\n",
"Grad Boost (over) 0.942719 0.9671 0.914245 0.909415 0.901280\n",
"XGBoost (over) 0.939774 0.9671 0.906946 0.905818 0.908592\n",
"dtree (under) 1.000000 0.8439 1.000000 0.853049 0.868373\n",
"Logistic Regr (under) 0.857622 0.8699 0.848171 0.844512 0.859232"
]
},
"metadata": {},
"execution_count": 77
}
]
},
{
"cell_type": "markdown",
"source": [
"Thankfully, this model doesn't overfit. Its decent recall would make it a top contender, were it not for its low precision of about 28%. Nearly three out of every four flagged turbines would be false alarms, each costing BreezeGen an unnecessary \\$5,000 inspection."
],
"metadata": {
"id": "ZvC0hsL8KW5A"
}
},
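{
"cell_type": "markdown",
"source": [
"To put inspection costs in perspective, here is a back-of-the-envelope sketch using the \\$5,000 inspection cost from the introduction. It assumes every predicted failure triggers one inspection, so each false positive is one wasted inspection. The 100 flagged turbines are a hypothetical round number, and only the precision values come from the results above. `wasted_inspection_cost` is a hypothetical helper written for this illustration.\n",
"\n",
"```python\n",
"INSPECTION_COST = 5_000  # per inspection, from the problem statement\n",
"\n",
"def wasted_inspection_cost(n_flagged, precision):\n",
"    # False positives among flagged turbines = n_flagged * (1 - precision)\n",
"    false_positives = round(n_flagged * (1 - precision))\n",
"    return false_positives, false_positives * INSPECTION_COST\n",
"\n",
"# Validation precision values reported earlier in this notebook\n",
"for name, prec in [('Logistic Regr (under)', 0.277450), ('Grad Boost (over)', 0.641927)]:\n",
"    fp, cost = wasted_inspection_cost(100, prec)\n",
"    print(f'{name:<22} false alarms per 100 flagged: {fp:>3}, wasted inspection cost: {cost:,}')\n",
"```\n",
"\n",
"At about 28% precision, roughly 72 of every 100 inspections would be wasted. At the gradient boosting model's 64% precision, that figure falls to about 36."
],
"metadata": {}
},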
{
"cell_type": "markdown",
"source": [
"#### Bagging Classifier [Undersampled]"
],
"metadata": {
"id": "1s6As99rK68d"
}
},
{
"cell_type": "code",
"source": [
"bag_under=BaggingClassifier(random_state=1)\n",
"\n",
"m=cv_recall(bag_under,sample_strategy='under')\n",
"print(f'Cross-validated recall is {m}.')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "aCJoA3tkMjdd",
"outputId": "202ad1f7-ca36-45db-8f18-c67e741ed09f"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Cross-validated recall is 0.8713414634146343.\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"Bagging has overfit so far, so I'm curious to see whether the model trained on undersampled data avoids this issue. The cross-validated recall score is 87%."
],
"metadata": {
"id": "bAIUymLSMYvV"
}
},
{
"cell_type": "code",
"source": [
"bag_under.fit(X_train_under,y_train_under)\n",
"\n",
"ch(bag_under)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 523
},
"id": "n4k88218MjX7",
"outputId": "3a7ddd46-81ae-4376-9623-8f961049948f"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Scores\n",
"Accuracy 0.947000\n",
"Precision 0.508827\n",
"Recall 0.895795\n",
"F1 0.649007"
]
},
"metadata": {},
"execution_count": 79
},
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"With about as many false positives as true positives (see the confusion matrix), precision lands at about 51%, so roughly half of the flagged turbines would not actually fail. That being said, accuracy and recall are both great."
],
"metadata": {
"id": "GnluxrX6TSbN"
}
},
{
"cell_type": "code",
"source": [
"tabulate(bag_under,'Bagging (under)',sample='under',cvs=m)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 582
},
"id": "Uk3tm6mnMjUx",
"outputId": "24447e5d-3162-49c9-ee5b-e7ffb6dac655"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Train Acc Val Acc Train Recall CV Recall Val Recall\n",
"dtree 1.000000 0.9724 1.000000 0.720732 0.776965\n",
"Logistic Regr 0.966967 0.9677 0.476220 0.478049 0.506399\n",
"Bagging Clfr 0.997233 0.9855 0.950610 0.712805 0.780622\n",
"Random Forest 1.000000 0.9883 1.000000 0.756707 0.797075\n",
"AdaBoost 0.973900 0.9750 0.626829 0.607927 0.652651\n",
"Grad Boost 0.987500 0.9845 0.783537 0.720122 0.749543\n",
"XGBoost 0.987667 0.9859 0.786585 0.736585 0.775137\n",
"dtree (over) 1.000000 0.9510 1.000000 0.970487 0.855576\n",
"Logistic Regr (over) 0.867331 0.8718 0.866890 0.866784 0.855576\n",
"Bagging (over) 0.999418 0.9856 0.998907 0.973766 0.888483\n",
"Rand Forest (over) 1.000000 0.9913 1.000000 0.982405 0.892139\n",
"AdaBoost (over) 0.898554 0.9076 0.886142 0.884027 0.879342\n",
"Grad Boost (over) 0.942719 0.9671 0.914245 0.909415 0.901280\n",
"XGBoost (over) 0.939774 0.9671 0.906946 0.905818 0.908592\n",
"dtree (under) 1.000000 0.8439 1.000000 0.853049 0.868373\n",
"Logistic Regr (under) 0.857622 0.8699 0.848171 0.844512 0.859232\n",
"Bagging (under) 0.991463 0.9470 0.985366 0.871341 0.895795"
]
},
"metadata": {},
"execution_count": 80
}
]
},
{
"cell_type": "markdown",
"source": [
"Unfortunately, overfitting plagues this model too. Note especially the gap between training recall (98.5%) and cross-validated recall (87%)."
],
"metadata": {
"id": "XhK5c1z8TkAO"
}
},
{
"cell_type": "markdown",
"source": [
"#### Random Forest [Undersampled]"
],
"metadata": {
"id": "ZfIu9Kn2LAlk"
}
},
{
"cell_type": "code",
"source": [
"rf_under=RandomForestClassifier(random_state=1)\n",
"\n",
"m=cv_recall(rf_under,sample_strategy='under')\n",
"print(f'Cross-validated recall is {m}.')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "fJG8FuZENFKA",
"outputId": "c6b5596f-a87b-4bca-bdb2-43f3e9bf154e"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Cross-validated recall is 0.8914634146341462.\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"This model scores about 89%, well below the 98% cross-validated recall of the random forest trained on oversampled data."
],
"metadata": {
"id": "wcU0tf4tTr_n"
}
},
{
"cell_type": "code",
"source": [
"rf_under.fit(X_train_under,y_train_under)\n",
"\n",
"ch(rf_under)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 523
},
"id": "Akk0oU0iNKEc",
"outputId": "2fce765b-5a6b-448d-b4db-960f40a44b49"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Scores\n",
"Accuracy 0.963300\n",
"Precision 0.609756\n",
"Recall 0.914077\n",
"F1 0.731529"
]
},
"metadata": {},
"execution_count": 82
},
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"As with the random forest trained on oversampled data, precision is good (about 61%), and accuracy and recall are high."
],
"metadata": {
"id": "WHx9XbGhZdbj"
}
},
{
"cell_type": "code",
"source": [
"tabulate(rf_under,'Rand Forest (under)',sample='under',cvs=m)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 614
},
"id": "oBHd6uu-NKA_",
"outputId": "ecf4c5fa-426a-4705-a2a6-67e38c4d4cea"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Train Acc Val Acc Train Recall CV Recall Val Recall\n",
"dtree 1.000000 0.9724 1.000000 0.720732 0.776965\n",
"Logistic Regr 0.966967 0.9677 0.476220 0.478049 0.506399\n",
"Bagging Clfr 0.997233 0.9855 0.950610 0.712805 0.780622\n",
"Random Forest 1.000000 0.9883 1.000000 0.756707 0.797075\n",
"AdaBoost 0.973900 0.9750 0.626829 0.607927 0.652651\n",
"Grad Boost 0.987500 0.9845 0.783537 0.720122 0.749543\n",
"XGBoost 0.987667 0.9859 0.786585 0.736585 0.775137\n",
"dtree (over) 1.000000 0.9510 1.000000 0.970487 0.855576\n",
"Logistic Regr (over) 0.867331 0.8718 0.866890 0.866784 0.855576\n",
"Bagging (over) 0.999418 0.9856 0.998907 0.973766 0.888483\n",
"Rand Forest (over) 1.000000 0.9913 1.000000 0.982405 0.892139\n",
"AdaBoost (over) 0.898554 0.9076 0.886142 0.884027 0.879342\n",
"Grad Boost (over) 0.942719 0.9671 0.914245 0.909415 0.901280\n",
"XGBoost (over) 0.939774 0.9671 0.906946 0.905818 0.908592\n",
"dtree (under) 1.000000 0.8439 1.000000 0.853049 0.868373\n",
"Logistic Regr (under) 0.857622 0.8699 0.848171 0.844512 0.859232\n",
"Bagging (under) 0.991463 0.9470 0.985366 0.871341 0.895795\n",
"Rand Forest (under) 1.000000 0.9633 1.000000 0.891463 0.914077"
]
},
"metadata": {},
"execution_count": 83
}
]
},
{
"cell_type": "markdown",
"source": [
"This model is clearly overfit: training accuracy and recall are both 100%. Tuning may curtail the issue, and with decent precision, this model could be worth tuning to retain the cost savings that good precision affords."
],
"metadata": {
"id": "0iJLytcwajc_"
}
},
{
"cell_type": "markdown",
"source": [
"#### AdaBoost [Undersampled]"
],
"metadata": {
"id": "4v8cS-zHLAi7"
}
},
{
"cell_type": "code",
"source": [
"abc_under=AdaBoostClassifier(random_state=1)\n",
"\n",
"m=cv_recall(abc_under,sample_strategy='under')\n",
"print(f'Cross-validated recall is {m}.')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "BiWWvscEOnGT",
"outputId": "4d660745-1e98-4d92-e724-e67136efe243"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Cross-validated recall is 0.8591463414634146.\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"The cross-validated recall score for this AdaBoost classifier is 86%."
],
"metadata": {
"id": "JjGZdCWca6Kl"
}
},
{
"cell_type": "code",
"source": [
"abc_under.fit(X_train_under,y_train_under)\n",
"\n",
"ch(abc_under)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 523
},
"id": "iLBTVCOkOtCL",
"outputId": "8ad0738e-d024-4c60-a050-ff98e8ab375d"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Scores\n",
"Accuracy 0.894100\n",
"Precision 0.325850\n",
"Recall 0.875686\n",
"F1 0.474963"
]
},
"metadata": {},
"execution_count": 85
},
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"Accuracy and recall are both good on the validation set, but precision is quite poor at about 33%."
],
"metadata": {
"id": "WRVuUXY0bA-P"
}
},
{
"cell_type": "code",
"source": [
"tabulate(abc_under,'AdaBoost (under)',sample='under',cvs=m)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 645
},
"id": "BYXiffWIOs7O",
"outputId": "5803207e-1127-48fc-cecb-4f8be06e328b"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Train Acc Val Acc Train Recall CV Recall Val Recall\n",
"dtree 1.000000 0.9724 1.000000 0.720732 0.776965\n",
"Logistic Regr 0.966967 0.9677 0.476220 0.478049 0.506399\n",
"Bagging Clfr 0.997233 0.9855 0.950610 0.712805 0.780622\n",
"Random Forest 1.000000 0.9883 1.000000 0.756707 0.797075\n",
"AdaBoost 0.973900 0.9750 0.626829 0.607927 0.652651\n",
"Grad Boost 0.987500 0.9845 0.783537 0.720122 0.749543\n",
"XGBoost 0.987667 0.9859 0.786585 0.736585 0.775137\n",
"dtree (over) 1.000000 0.9510 1.000000 0.970487 0.855576\n",
"Logistic Regr (over) 0.867331 0.8718 0.866890 0.866784 0.855576\n",
"Bagging (over) 0.999418 0.9856 0.998907 0.973766 0.888483\n",
"Rand Forest (over) 1.000000 0.9913 1.000000 0.982405 0.892139\n",
"AdaBoost (over) 0.898554 0.9076 0.886142 0.884027 0.879342\n",
"Grad Boost (over) 0.942719 0.9671 0.914245 0.909415 0.901280\n",
"XGBoost (over) 0.939774 0.9671 0.906946 0.905818 0.908592\n",
"dtree (under) 1.000000 0.8439 1.000000 0.853049 0.868373\n",
"Logistic Regr (under) 0.857622 0.8699 0.848171 0.844512 0.859232\n",
"Bagging (under) 0.991463 0.9470 0.985366 0.871341 0.895795\n",
"Rand Forest (under) 1.000000 0.9633 1.000000 0.891463 0.914077\n",
"AdaBoost (under) 0.902439 0.8941 0.881707 0.859146 0.875686"
]
},
"metadata": {},
"execution_count": 86
}
]
},
{
"cell_type": "markdown",
"source": [
"This model performs well on recall, but its poor precision would cost BreezeGen in the long run: every false positive triggers a \\$5,000 inspection that turns out to be unnecessary."
],
"metadata": {
"id": "j2ZY5C7FbJRy"
}
},
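The cost of poor precision can be made concrete with a small tally over confusion-matrix counts. This is a minimal sketch under a simplifying assumption not stated explicitly above: a true positive leads to a repair ($15,000), a false positive leads to an inspection that finds nothing ($5,000), and a false negative ends in a full replacement ($40,000). The function name `maintenance_cost` and the example counts are hypothetical.

```python
def maintenance_cost(tp, fp, fn, *, repair=15_000, inspection=5_000, replacement=40_000):
    """Total maintenance cost implied by confusion-matrix counts.

    Assumed mapping (hypothetical, for illustration only):
      - true positive  -> failure caught in time, generator repaired
      - false positive -> inspection only, no repair needed
      - false negative -> missed failure, generator replaced
      - true negative  -> no action, no cost
    """
    return tp * repair + fp * inspection + fn * replacement


# Two hypothetical models with identical recall (90 of 100 failures caught)
# but very different precision (90% versus about 33%).
high_precision = maintenance_cost(tp=90, fp=10, fn=10)
low_precision = maintenance_cost(tp=90, fp=180, fn=10)
print(high_precision, low_precision)
```

Under these assumptions, recall still dominates the bill because each missed failure costs eight times as much as a false alarm, but a model with many false alarms still pays noticeably more at the same recall.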
{
"cell_type": "markdown",
"source": [
"#### Gradient Boosting [Undersampled]"
],
"metadata": {
"id": "YHW8uveoLAgH"
}
},
{
"cell_type": "code",
"source": [
"gbc_under=GradientBoostingClassifier(random_state=1)\n",
"\n",
"m=cv_recall(gbc_under,sample_strategy='under')\n",
"print(f'Cross-validated recall is {m}.')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "wJIiGNyhPEbC",
"outputId": "c1bc18d8-1626-43ae-bcd1-c4cde4156522"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Cross-validated recall is 0.8847560975609756.\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"A CV recall score of 88% is slightly lower than the 91% achieved by the gradient boosting classifier trained on oversampled data."
],
"metadata": {
"id": "fcccjvVZbSYx"
}
},
{
"cell_type": "code",
"source": [
"gbc_under.fit(X_train_under,y_train_under)\n",
"\n",
"ch(gbc_under)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 523
},
"id": "ky-xHlE0PEYI",
"outputId": "37a79953-aa09-4e44-8ce1-cac592ba5a43"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Scores\n",
"Accuracy 0.951700\n",
"Precision 0.534409\n",
"Recall 0.908592\n",
"F1 0.672986"
]
},
"metadata": {},
"execution_count": 88
},
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
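"Accuracy and recall are both great, and the training and validation scores are close (95.2% vs. 95.2% accuracy, 91.7% vs. 90.9% recall), so overfitting does not appear to be a concern. Precision, however, is only about 53%: compare the true positives and false positives in the confusion matrix above. Roughly one in two predicted failures is a false alarm."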
],
"metadata": {
"id": "Y0NF_ieTbl-h"
}
},
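Precision and recall can both be read directly off the confusion matrix. A minimal sketch in plain Python; the helper name `scores_from_counts` and the counts are hypothetical, while scikit-learn's `precision_score`, `recall_score`, and `f1_score` compute the same quantities from label arrays.

```python
def scores_from_counts(tp, fp, fn):
    """Precision, recall, and F1 for the positive (failure) class."""
    precision = tp / (tp + fp)  # share of flagged turbines that truly fail
    recall = tp / (tp + fn)     # share of true failures that get flagged
    f1 = 2 * precision * recall / (precision + recall)  # harmonic mean of the two
    return precision, recall, f1


# Hypothetical counts: 100 flagged failures, half of them false alarms,
# and 5 real failures missed.
p, r, f = scores_from_counts(tp=50, fp=50, fn=5)
print(round(p, 3), round(r, 3), round(f, 3))
```

Because F1 is a harmonic mean, a precision near 50% drags it well below recall, which matches the pattern in the validation scores above.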
{
"cell_type": "code",
"source": [
"tabulate(gbc_under,'Grad Boost (under)',sample='under',cvs=m)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 677
},
"id": "7qm9NI7bPEVL",
"outputId": "6d3a256f-af73-4faf-dff5-0a1769b24cd0"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Train Acc Val Acc Train Recall CV Recall Val Recall\n",
"dtree 1.000000 0.9724 1.000000 0.720732 0.776965\n",
"Logistic Regr 0.966967 0.9677 0.476220 0.478049 0.506399\n",
"Bagging Clfr 0.997233 0.9855 0.950610 0.712805 0.780622\n",
"Random Forest 1.000000 0.9883 1.000000 0.756707 0.797075\n",
"AdaBoost 0.973900 0.9750 0.626829 0.607927 0.652651\n",
"Grad Boost 0.987500 0.9845 0.783537 0.720122 0.749543\n",
"XGBoost 0.987667 0.9859 0.786585 0.736585 0.775137\n",
"dtree (over) 1.000000 0.9510 1.000000 0.970487 0.855576\n",
"Logistic Regr (over) 0.867331 0.8718 0.866890 0.866784 0.855576\n",
"Bagging (over) 0.999418 0.9856 0.998907 0.973766 0.888483\n",
"Rand Forest (over) 1.000000 0.9913 1.000000 0.982405 0.892139\n",
"AdaBoost (over) 0.898554 0.9076 0.886142 0.884027 0.879342\n",
"Grad Boost (over) 0.942719 0.9671 0.914245 0.909415 0.901280\n",
"XGBoost (over) 0.939774 0.9671 0.906946 0.905818 0.908592\n",
"dtree (under) 1.000000 0.8439 1.000000 0.853049 0.868373\n",
"Logistic Regr (under) 0.857622 0.8699 0.848171 0.844512 0.859232\n",
"Bagging (under) 0.991463 0.9470 0.985366 0.871341 0.895795\n",
"Rand Forest (under) 1.000000 0.9633 1.000000 0.891463 0.914077\n",
"AdaBoost (under) 0.902439 0.8941 0.881707 0.859146 0.875686\n",
"Grad Boost (under) 0.952439 0.9517 0.917073 0.884756 0.908592"
]
},
"metadata": {},
"execution_count": 89
}
]
},
{
"cell_type": "markdown",
"source": [
"This is one of the better models we've seen, with great validation recall and accuracy of about 91% and 95%, respectively."
],
"metadata": {
"id": "vfwR4iISb7Hf"
}
},
{
"cell_type": "markdown",
"source": [
"#### XGBoost [Undersampled]"
],
"metadata": {
"id": "_AxGifdHLARe"
}
},
{
"cell_type": "code",
"source": [
"xgb_under=XGBClassifier(random_state=1)\n",
"\n",
"m=cv_recall(xgb_under,sample_strategy='under')\n",
"print(f'Cross-validated recall is {m}.')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "HBfajULcrF6N",
"outputId": "6ef62bbf-ddef-438d-b4e8-fa3e1ddd5486"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Cross-validated recall is 0.8774390243902438.\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"A CV recall score of 88% is good, comparable with the previous gradient boosting model."
],
"metadata": {
"id": "NTrQChyPcLBk"
}
},
{
"cell_type": "code",
"source": [
"xgb_under.fit(X_train_under,y_train_under)\n",
"\n",
"ch(xgb_under)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 523
},
"id": "IH0uLnBNrF3s",
"outputId": "d0face3a-e48d-489f-a4d5-b60d8356fba2"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Scores\n",
"Accuracy 0.959800\n",
"Precision 0.586207\n",
"Recall 0.901280\n",
"F1 0.710375"
]
},
"metadata": {},
"execution_count": 91
},
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"This model scores comparably to the previous one on accuracy and recall, but improves on precision (59% versus 53%)!"
],
"metadata": {
"id": "TKbI0h0VcVBq"
}
},
{
"cell_type": "code",
"source": [
"tabulate(xgb_under,'XGBoost (under)',sample='under',cvs=m)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 708
},
"id": "6UfPeJ4zrF0w",
"outputId": "d551583b-056e-48ca-f39e-c0b59acbb157"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Train Acc Val Acc Train Recall CV Recall Val Recall\n",
"dtree 1.000000 0.9724 1.000000 0.720732 0.776965\n",
"Logistic Regr 0.966967 0.9677 0.476220 0.478049 0.506399\n",
"Bagging Clfr 0.997233 0.9855 0.950610 0.712805 0.780622\n",
"Random Forest 1.000000 0.9883 1.000000 0.756707 0.797075\n",
"AdaBoost 0.973900 0.9750 0.626829 0.607927 0.652651\n",
"Grad Boost 0.987500 0.9845 0.783537 0.720122 0.749543\n",
"XGBoost 0.987667 0.9859 0.786585 0.736585 0.775137\n",
"dtree (over) 1.000000 0.9510 1.000000 0.970487 0.855576\n",
"Logistic Regr (over) 0.867331 0.8718 0.866890 0.866784 0.855576\n",
"Bagging (over) 0.999418 0.9856 0.998907 0.973766 0.888483\n",
"Rand Forest (over) 1.000000 0.9913 1.000000 0.982405 0.892139\n",
"AdaBoost (over) 0.898554 0.9076 0.886142 0.884027 0.879342\n",
"Grad Boost (over) 0.942719 0.9671 0.914245 0.909415 0.901280\n",
"XGBoost (over) 0.939774 0.9671 0.906946 0.905818 0.908592\n",
"dtree (under) 1.000000 0.8439 1.000000 0.853049 0.868373\n",
"Logistic Regr (under) 0.857622 0.8699 0.848171 0.844512 0.859232\n",
"Bagging (under) 0.991463 0.9470 0.985366 0.871341 0.895795\n",
"Rand Forest (under) 1.000000 0.9633 1.000000 0.891463 0.914077\n",
"AdaBoost (under) 0.902439 0.8941 0.881707 0.859146 0.875686\n",
"Grad Boost (under) 0.952439 0.9517 0.917073 0.884756 0.908592\n",
"XGBoost (under) 0.949390 0.9598 0.908537 0.877439 0.901280"
]
},
"metadata": {},
"execution_count": 92
}
]
},
{
"cell_type": "markdown",
"source": [
"Another promising model: great recall and accuracy, without excessive detriment to precision."
],
"metadata": {
"id": "Vyh50L-OcgqE"
}
},
{
"cell_type": "markdown",
"metadata": {
"id": "yZGY1eL84jau"
},
"source": [
"## Hyperparameter Tuning"
]
},
{
"cell_type": "markdown",
"source": [
"### Model Finalists"
],
"metadata": {
"id": "ZP0UOpHOfqM5"
}
},
{
"cell_type": "markdown",
"source": [
"We'll assemble some candidate finalists and compare them in a table. The following function computes many performance metrics for a given model."
],
"metadata": {
"id": "hpvWLk4kvkCy"
}
},
{
"cell_type": "code",
"source": [
"def model_scores(model,*,sample):\n",
"    # `sample` names the training set the model was fit on: None (original), 'over', or 'under'.\n",
"    y_val_pred=model.predict(X_val)\n",
"    if sample is None:\n",
"        y_tr=y_train\n",
"        y_tr_pred=model.predict(X_train)\n",
"    elif sample=='over':\n",
"        y_tr=y_train_over\n",
"        y_tr_pred=model.predict(X_train_over)\n",
"    elif sample=='under':\n",
"        y_tr=y_train_under\n",
"        y_tr_pred=model.predict(X_train_under)\n",
"    else:\n",
"        raise ValueError(\"Sample parameter takes values in {None,'over','under'}.\")\n",
"\n",
"    ser=pd.Series(dtype=float)\n",
"\n",
"    # accuracy\n",
"    ser.loc['Train Accuracy']=metrics.accuracy_score(y_tr,y_tr_pred)\n",
"    ser.loc['Validation Accuracy']=metrics.accuracy_score(y_val,y_val_pred)\n",
"\n",
"    # recall\n",
"    ser.loc['Train Recall']=metrics.recall_score(y_tr,y_tr_pred)\n",
"    ser.loc['Validation Recall']=metrics.recall_score(y_val,y_val_pred)\n",
"\n",
"    # validation precision and f1\n",
"    ser.loc['Validation Precision']=metrics.precision_score(y_val,y_val_pred)\n",
"    ser.loc['Validation F1']=metrics.f1_score(y_val,y_val_pred)\n",
"    return ser"
],
"metadata": {
"id": "NVn7vlR4fp6M"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"finalists=pd.DataFrame()\n",
"finalists['Bag_over']=model_scores(bag_over,sample='over')\n",
"finalists['Bag_under']=model_scores(bag_under,sample='under')\n",
"finalists['RF']=model_scores(rf,sample=None)\n",
"finalists['RF_over']=model_scores(rf_over,sample='over')\n",
"finalists['RF_under']=model_scores(rf_under,sample='under')\n",
"finalists['GB_over']=model_scores(gbc_over,sample='over')\n",
"finalists['GB_under']=model_scores(gbc_under,sample='under')\n",
"finalists['XGB_over']=model_scores(xgb_over,sample='over')\n",
"finalists['XGB_under']=model_scores(xgb_under,sample='under')"
],
"metadata": {
"id": "TyfJY3HHfu3i"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"From the model building process, we compile a list of candidates for tuning."
],
"metadata": {
"id": "yYPTguL_xh5I"
}
},
{
"cell_type": "code",
"source": [
"finalists"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 238
},
"id": "wz0d_0_fgMmU",
"outputId": "55ea1924-a5c6-40f2-f34b-b53eb4a80759"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Bag_over Bag_under RF RF_over RF_under \\\n",
"Train Accuracy 0.999418 0.991463 1.000000 1.000000 1.000000 \n",
"Validation Accuracy 0.985600 0.947000 0.988300 0.991300 0.963300 \n",
"Train Recall 0.998907 0.985366 1.000000 1.000000 1.000000 \n",
"Validation Recall 0.888483 0.895795 0.797075 0.892139 0.914077 \n",
"Validation Precision 0.854130 0.508827 0.986425 0.945736 0.609756 \n",
"Validation F1 0.870968 0.649007 0.881699 0.918156 0.731529 \n",
"\n",
" GB_over GB_under XGB_over XGB_under \n",
"Train Accuracy 0.942719 0.952439 0.939774 0.949390 \n",
"Validation Accuracy 0.967100 0.951700 0.967100 0.959800 \n",
"Train Recall 0.914245 0.917073 0.906946 0.908537 \n",
"Validation Recall 0.901280 0.908592 0.908592 0.901280 \n",
"Validation Precision 0.641927 0.534409 0.640464 0.586207 \n",
"Validation F1 0.749810 0.672986 0.751323 0.710375 "
]
},
"metadata": {},
"execution_count": 95
}
]
},
{
"cell_type": "markdown",
"source": [
"* The bagging classifiers are outperformed by the random forest classifiers on every validation metric, so we focus on the random forest models.\n",
"\n",
"* I also want to tune at least one model trained on the original data. Random forest performed best in this category.\n",
"\n",
"* The random forest models trained on oversampled and undersampled data are both strong. While the undersampled model shows higher validation recall, the oversampled model boasts impressive precision and F1 (about 95% and 92%) without much sacrifice in recall. A model with both stellar recall _and_ good precision will further cut operating costs for BreezeGen. The main issue with the random forest models is overfitting: both reach 100% training accuracy and recall. We will tune both to reduce overfitting and find out which model comes out on top.\n",
"\n",
"* Both gradient boosting and XGBoost performed very well, and these models showed much less overfitting than the random forest models. The versions trained on oversampled and undersampled data all boast great accuracy and recall. We favor the oversampled ones, since their validation precision (about 64%) beats that of the undersampled versions (53% and 59%), so we'll tune both `GB_over` and `XGB_over`."
],
"metadata": {
"id": "KIIOUrpYhu6Y"
}
},
{
"cell_type": "markdown",
"source": [
"### Random Forest"
],
"metadata": {
"id": "V3JSLzLIzgPJ"
}
},
{
"cell_type": "code",
"source": [
"params={'n_estimators':np.arange(100,250,50),\n",
" 'max_depth':np.arange(3,10),\n",
" 'class_weight':[None,'balanced']}"
],
"metadata": {
"id": "oFOp3XBMzf6c"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"We will start by tuning the random forest trained on the original data. We vary `n_estimators` in an attempt to improve performance, use `max_depth` to control overfitting, and try balanced `class_weight`, which should help recall on our highly imbalanced dataset."
],
"metadata": {
"id": "yt8igVmzxAV1"
}
},
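scikit-learn documents `class_weight='balanced'` as weighting each class by `n_samples / (n_classes * np.bincount(y))`. A plain-Python sketch of that formula; the helper `balanced_class_weights` and the toy labels are hypothetical, not part of scikit-learn or this dataset.

```python
from collections import Counter


def balanced_class_weights(y):
    """Mirror scikit-learn's documented 'balanced' heuristic:
    weight_c = n_samples / (n_classes * count_c)."""
    counts = Counter(y)
    n_samples, n_classes = len(y), len(counts)
    return {c: n_samples / (n_classes * k) for c, k in sorted(counts.items())}


# Hypothetical labels: 95 non-failures (0) and 5 failures (1).
y = [0] * 95 + [1] * 5
weights = balanced_class_weights(y)
print(weights)
```

Each class ends up carrying the same total weight (50 here), so a misclassified failure counts 19 times as much as a misclassified non-failure, which is why balancing tends to push the forest toward higher recall.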
{
"cell_type": "code",
"source": [
"rf_tuned=RandomForestClassifier(random_state=2)\n",
"\n",
"search=RandomizedSearchCV(estimator=rf_tuned,\n",
" param_distributions=params,\n",
" n_iter=20,\n",
" scoring='recall',\n",
" n_jobs=-1,\n",
" cv=5,\n",
" verbose=1,\n",
" random_state=1)\n",
"\n",
"search.fit(X_train,y_train)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 135
},
"id": "7JHmm8yO0mw0",
"outputId": "214ba8b3-613b-49d3-9224-75440ab6ca1e"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Fitting 5 folds for each of 20 candidates, totalling 100 fits\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"RandomizedSearchCV(cv=5, estimator=RandomForestClassifier(random_state=2),\n",
" n_iter=20, n_jobs=-1,\n",
" param_distributions={'class_weight': [None, 'balanced'],\n",
" 'max_depth': array([3, 4, 5, 6, 7, 8, 9]),\n",
" 'n_estimators': array([100, 150, 200])},\n",
" random_state=1, scoring='recall', verbose=1)"
]
},
"metadata": {},
"execution_count": 97
}
]
},
{
"cell_type": "code",
"source": [
"search.best_params_"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "kFFYz_Ji-2V8",
"outputId": "692dc78c-f57c-4192-a2cd-71337fad060b"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{'n_estimators': 100, 'max_depth': 5, 'class_weight': 'balanced'}"
]
},
"metadata": {},
"execution_count": 98
}
]
},
{
"cell_type": "markdown",
"source": [
"The search selected the smallest number of estimators in our range (100), a depth of 5 (roughly in the middle of our proposed range), and balanced class weights over unweighted classes."
],
"metadata": {
"id": "fQECuQEcx0i2"
}
},
{
"cell_type": "code",
"source": [
"best_rf=search.best_params_\n",
"\n",
"# fit model with best params\n",
"rf_tuned=RandomForestClassifier(\n",
" random_state=2,\n",
" n_jobs=-1,\n",
" **best_rf\n",
")\n",
"\n",
"rf_tuned.fit(X_train,y_train)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 92
},
"id": "kCmd_FKy3h0q",
"outputId": "368e59cd-1e3d-4411-c154-6f1dfb17188c"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"RandomForestClassifier(class_weight='balanced', max_depth=5, n_jobs=-1,\n",
" random_state=2)"
]
},
"metadata": {},
"execution_count": 99
}
]
},
{
"cell_type": "markdown",
"source": [
"The trained estimator has the following performance on validation data."
],
"metadata": {
"id": "3RdlQIUkyJy-"
}
},
{
"cell_type": "code",
"source": [
"ch(rf_tuned)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 523
},
"id": "zEr0Zzs07hFq",
"outputId": "be8d78d1-c76d-4c7d-a0a9-986cbe85e3a7"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Scores\n",
"Accuracy 0.950400\n",
"Precision 0.528365\n",
"Recall 0.868373\n",
"F1 0.656985"
]
},
"metadata": {},
"execution_count": 100
},
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"source": [
"tuned_models=pd.DataFrame()\n",
"\n",
"tuned_models['RF']=model_scores(rf_tuned,sample=None)\n",
"tuned_models"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 238
},
"id": "dfhptaYh-12n",
"outputId": "987af58b-e68b-4051-b3d7-3ccd815ed05d"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" RF\n",
"Train Accuracy 0.950633\n",
"Validation Accuracy 0.950400\n",
"Train Recall 0.878049\n",
"Validation Recall 0.868373\n",
"Validation Precision 0.528365\n",
"Validation F1 0.656985"
]
},
"metadata": {},
"execution_count": 101
}
]
},
{
"cell_type": "markdown",
"source": [
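"Accuracy is high, and there is no sign of overfitting: training and validation scores are nearly identical. Recall is good, at around 87%. Precision is much lower, at about 53%, so nearly half of the predicted failures are false alarms. That is still far better than random guessing, whose precision would only match the small proportion of failures in the data. Precision is not our top priority, but the other finalist models achieve noticeably better precision, which would further reduce costs for BreezeGen."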
],
"metadata": {
"id": "nJeMrcbNyUUQ"
}
},
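{
"cell_type": "markdown",
"source": [
"As a sanity check on the precision baseline: a classifier that flags turbines at random, without looking at the sensor data, has an expected precision equal to the failure rate, not 50%. The sketch below illustrates this on simulated labels. The 5.5% failure rate and the 10% flagging rate are hypothetical values chosen for illustration, not figures from BreezeGen's data.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn.metrics import precision_score\n",
"\n",
"# Hypothetical data: 100,000 simulated turbines with a 5.5% failure rate\n",
"rng = np.random.default_rng(0)\n",
"y_true = (rng.random(100_000) < 0.055).astype(int)\n",
"\n",
"# A random-guessing classifier flags 10% of turbines regardless of their features\n",
"y_random = (rng.random(100_000) < 0.10).astype(int)\n",
"\n",
"random_precision = precision_score(y_true, y_random)\n",
"print(f'Failure rate:           {y_true.mean():.3f}')\n",
"print(f'Random-guess precision: {random_precision:.3f}')\n",
"```\n",
"\n",
"Both printed values should be close to 0.055. A precision of about 53% therefore reflects genuine signal, even though it still leaves many false alarms."
],
"metadata": {}
},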
{
"cell_type": "markdown",
"source": [
"### Random Forest [Oversampled]"
],
"metadata": {
"id": "4O8Lw-Kz23D-"
}
},
{
"cell_type": "code",
"source": [
"params={'n_estimators':np.arange(250,350,25),\n",
" 'max_depth':np.arange(4,9),\n",
" 'max_features':['sqrt',0.5]}"
],
"metadata": {
"id": "5RYylH7ACrKq"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Again, we vary the number of estimators to improve performance, and we curtail overfitting with the `max_depth` and `max_features` parameters. Limiting `max_features` also decorrelates the trees, which can further help performance. (After many trials, I found that low values of `n_estimators` caused some overfitting, so I set 250 as the minimum value in the parameter distribution.)\n",
"\n",
"Additionally, we instantiate the estimator with `min_samples_leaf=2` to further prevent overfitting. Every leaf must contain at least two training samples, so the model is less likely to memorize individual noisy points in the training data."
],
"metadata": {
"id": "PVAxYuU_y00T"
}
},
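{
"cell_type": "markdown",
"source": [
"`RandomizedSearchCV` does not try every combination. When every entry in `param_distributions` is a list or array, as here, scikit-learn's `ParameterSampler` draws candidates from the full grid without replacement. The search below should therefore evaluate 15 distinct settings out of 4 x 5 x 2 = 40. The following minimal sketch lists those candidates using the same `params`. It assumes the sampler is seeded the same way `RandomizedSearchCV` seeds it with `random_state=1`.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn.model_selection import ParameterGrid, ParameterSampler\n",
"\n",
"params = {'n_estimators': np.arange(250, 350, 25),\n",
"          'max_depth': np.arange(4, 9),\n",
"          'max_features': ['sqrt', 0.5]}\n",
"\n",
"# Full grid size: 4 x 5 x 2 combinations\n",
"n_combinations = len(ParameterGrid(params))\n",
"\n",
"# The same sampler RandomizedSearchCV uses internally, with the same seed\n",
"candidates = list(ParameterSampler(params, n_iter=15, random_state=1))\n",
"distinct = {tuple(sorted(c.items())) for c in candidates}\n",
"\n",
"print(n_combinations, len(candidates), len(distinct))\n",
"for c in candidates[:3]:\n",
"    print(c)\n",
"```\n",
"\n",
"Listing the candidates shows which regions of the grid the search actually covered. This is useful when the chosen values sit at the edge of the range."
],
"metadata": {}
},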
{
"cell_type": "code",
"source": [
"rf_over_tuned=RandomForestClassifier(random_state=2,min_samples_leaf=2)\n",
"\n",
"search=RandomizedSearchCV(estimator=rf_over_tuned,\n",
" param_distributions=params,\n",
" n_iter=15,\n",
" scoring='recall',\n",
" n_jobs=-1,\n",
" cv=5,\n",
" verbose=1,\n",
" random_state=1)\n",
"\n",
"search.fit(X_train_over,y_train_over)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 135
},
"id": "aUiheQta28hg",
"outputId": "dee59858-2355-4c24-c9db-6a0577d36c3a"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Fitting 5 folds for each of 15 candidates, totalling 75 fits\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"RandomizedSearchCV(cv=5,\n",
" estimator=RandomForestClassifier(min_samples_leaf=2,\n",
" random_state=2),\n",
" n_iter=15, n_jobs=-1,\n",
" param_distributions={'max_depth': array([4, 5, 6, 7, 8]),\n",
" 'max_features': ['sqrt', 0.5],\n",
" 'n_estimators': array([250, 275, 300, 325])},\n",
" random_state=1, scoring='recall', verbose=1)"
]
},
"metadata": {},
"execution_count": 103
}
]
},
{
"cell_type": "code",
"source": [
"search.best_params_"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "kDtolaIa9F5B",
"outputId": "6686d1cc-e2c1-4919-b675-1634adc28ff6"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{'n_estimators': 325, 'max_features': 0.5, 'max_depth': 8}"
]
},
"metadata": {},
"execution_count": 104
}
]
},
{
"cell_type": "markdown",
"source": [
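"The search picked the largest depth and the largest number of estimators in the grid (both at its upper edge), suggesting that deeper trees and more estimators increase recall here. Additionally, considering 50% of the features at each split (more than the $\\sqrt{n_{\\text{features}}}$ used by `'sqrt'`) yielded a higher score."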
],
"metadata": {
"id": "eSKcPu_O5Ko-"
}
},
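{
"cell_type": "markdown",
"source": [
"For reference, scikit-learn trees resolve `max_features='sqrt'` to `int(sqrt(n_features))` features per split, and a float such as `0.5` to `int(0.5 * n_features)`. Both have a floor of 1. The helper below is a hypothetical sketch of that rule. The 40-feature count is an assumed example, not a figure stated in this section.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"def features_per_split(max_features, n_features):\n",
"    # Hypothetical helper mirroring how scikit-learn trees resolve max_features\n",
"    if max_features == 'sqrt':\n",
"        return max(1, int(math.sqrt(n_features)))\n",
"    if isinstance(max_features, float):\n",
"        return max(1, int(max_features * n_features))\n",
"    raise ValueError('only sqrt and float fractions are handled in this sketch')\n",
"\n",
"n_features = 40  # assumed feature count, for illustration only\n",
"for setting in ['sqrt', 0.5]:\n",
"    print(setting, features_per_split(setting, n_features))\n",
"```"
],
"metadata": {}
},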
{
"cell_type": "code",
"source": [
"best_rf_over=search.best_params_\n",
"\n",
"# fit model with best params\n",
"rf_over_tuned=RandomForestClassifier(\n",
" random_state=2,\n",
" min_samples_leaf=2,\n",
" n_jobs=-1,\n",
" **best_rf_over\n",
")\n",
"\n",
"rf_over_tuned.fit(X_train_over,y_train_over)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 92
},
"id": "SEpnTVh2YGr8",
"outputId": "8ed224c3-580d-42ee-d117-ebc33d246ee1"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"RandomForestClassifier(max_depth=8, max_features=0.5, min_samples_leaf=2,\n",
" n_estimators=325, n_jobs=-1, random_state=2)"
]
},
"metadata": {},
"execution_count": 105
}
]
},
{
"cell_type": "code",
"source": [
"ch(rf_over_tuned)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 523
},
"id": "m_4HYron0B1H",
"outputId": "66bc4fe0-153a-4428-a086-e02ecef66d42"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Scores\n",
"Accuracy 0.978800\n",
"Precision 0.754173\n",
"Recall 0.908592\n",
"F1 0.824212"
]
},
"metadata": {},
"execution_count": 106
},
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"The metrics for validation data are great: 91% recall and around 98% accuracy. Precision is good too, at around 75%, yielding an F1 score of 82%. Note especially the rarity of false negatives in the confusion matrix above."
],
"metadata": {
"id": "Pz-8eBYG6Kuk"
}
},
{
"cell_type": "code",
"source": [
"tuned_models['RF_over']=model_scores(rf_over_tuned,sample='over')\n",
"tuned_models"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 238
},
"id": "Oi0sW9r63Edm",
"outputId": "6c4f1a40-1776-40b2-a089-2f0843e2c0e0"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" RF RF_over\n",
"Train Accuracy 0.950633 0.950864\n",
"Validation Accuracy 0.950400 0.978800\n",
"Train Recall 0.878049 0.916890\n",
"Validation Recall 0.868373 0.908592\n",
"Validation Precision 0.528365 0.754173\n",
"Validation F1 0.656985 0.824212"
]
},
"metadata": {},
"execution_count": 107
}
]
},
{
"cell_type": "markdown",
"source": [
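"Compared with the training scores, there is little sign of overfitting here. Validation accuracy is actually higher, likely because the training scores are computed on the balanced, oversampled data. Recall holds at around 91% on both sets, and validation precision (75%) is a large improvement on the previous random forest model (53%)."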
],
"metadata": {
"id": "DHgd4HgzzR0S"
}
},
{
"cell_type": "markdown",
"source": [
"### Random Forest [Undersampled]"
],
"metadata": {
"id": "GMuedqh63HdB"
}
},
{
"cell_type": "code",
"source": [
"params={'n_estimators':np.arange(150,300,50),\n",
" 'max_depth':np.arange(3,10),\n",
" 'max_features':['sqrt',0.5]}"
],
"metadata": {
"id": "6XvYCWCykT_r"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"As with the last model, we will test values for `n_estimators`, `max_depth`, and `max_features`. Mostly, we are looking to prevent overfitting."
],
"metadata": {
"id": "57vDOLNl6z3B"
}
},
{
"cell_type": "code",
"source": [
"rf_under_tuned=RandomForestClassifier(random_state=2,min_samples_leaf=2)\n",
"\n",
"search=RandomizedSearchCV(estimator=rf_under_tuned,\n",
" param_distributions=params,\n",
" n_iter=30,\n",
" scoring='recall',\n",
" n_jobs=-1,\n",
" cv=5,\n",
" verbose=1,\n",
" random_state=1)\n",
"\n",
"search.fit(X_train_under,y_train_under)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 135
},
"id": "QjUoE7tY28du",
"outputId": "cb810f71-6ec0-455d-82b7-6d415de6e7fa"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Fitting 5 folds for each of 30 candidates, totalling 150 fits\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"RandomizedSearchCV(cv=5,\n",
" estimator=RandomForestClassifier(min_samples_leaf=2,\n",
" random_state=2),\n",
" n_iter=30, n_jobs=-1,\n",
" param_distributions={'max_depth': array([3, 4, 5, 6, 7, 8, 9]),\n",
" 'max_features': ['sqrt', 0.5],\n",
" 'n_estimators': array([150, 200, 250])},\n",
" random_state=1, scoring='recall', verbose=1)"
]
},
"metadata": {},
"execution_count": 110
}
]
},
{
"cell_type": "code",
"source": [
"search.best_params_"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "EtF_31UF3Jr9",
"outputId": "9d2a0d4c-e54e-47ba-de1a-da7cd20e1128"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{'n_estimators': 150, 'max_features': 'sqrt', 'max_depth': 9}"
]
},
"metadata": {},
"execution_count": 111
}
]
},
{
"cell_type": "markdown",
"source": [
"In this case, fewer estimators (150, the bottom of the range) and fewer features per split (`'sqrt'`) yielded better results. As with the previous model, greater depth was preferable."
],
"metadata": {
"id": "yyqrrL6jiD8K"
}
},
{
"cell_type": "code",
"source": [
"best_rf_under=search.best_params_\n",
"\n",
"# fit model with best params\n",
"rf_under_tuned=RandomForestClassifier(\n",
" random_state=2,\n",
" min_samples_leaf=2,\n",
" n_jobs=-1,\n",
" **best_rf_under\n",
")\n",
"\n",
"rf_under_tuned.fit(X_train_under,y_train_under)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 92
},
"id": "1xCp55aL3QMT",
"outputId": "e53502e9-b2b3-40f8-843f-339a7a25274b"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"RandomForestClassifier(max_depth=9, min_samples_leaf=2, n_estimators=150,\n",
" n_jobs=-1, random_state=2)"
]
},
"metadata": {},
"execution_count": 112
}
]
},
{
"cell_type": "code",
"source": [
"ch(rf_under_tuned)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 523
},
"id": "BzQN2KFakdXG",
"outputId": "2af3bee2-5846-4c68-c9f1-e64daae3dae6"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Scores\n",
"Accuracy 0.966000\n",
"Precision 0.630847\n",
"Recall 0.912249\n",
"F1 0.745889"
]
},
"metadata": {},
"execution_count": 122
},
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"source": [
"tuned_models['RF_under']=model_scores(rf_under_tuned,sample='under')\n",
"tuned_models"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 238
},
"id": "y2EQjfBB3JoT",
"outputId": "4b446624-bc76-475b-e750-f3598b905dc9"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" RF RF_over RF_under\n",
"Train Accuracy 0.950633 0.950864 0.958841\n",
"Validation Accuracy 0.950400 0.978800 0.966000\n",
"Train Recall 0.878049 0.916890 0.918902\n",
"Validation Recall 0.868373 0.908592 0.912249\n",
"Validation Precision 0.528365 0.754173 0.630847\n",
"Validation F1 0.656985 0.824212 0.745889"
]
},
"metadata": {},
"execution_count": 123
}
]
},
{
"cell_type": "markdown",
"source": [
"Performance here is quite similar to the previous model. What differs is consistency across data sets: the gap between training and validation scores is smaller than for `RF_over` (under one percentage point for both accuracy and recall). Recall is reliably about 91%. The model falls short of `RF_over` mainly in precision (63% vs. 75%), which also lowers its accuracy and F1."
],
"metadata": {
"id": "d4xiuYXeicY6"
}
},
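{
"cell_type": "markdown",
"source": [
"The consistency claim can be checked directly from the `tuned_models` table. The sketch below re-enters the accuracy and recall figures printed above for `RF_over` and `RF_under`. They are copied by hand, so the sketch does not need the fitted models. It then computes the absolute train-versus-validation gap for each model.\n",
"\n",
"```python\n",
"import pandas as pd\n",
"\n",
"# Scores copied from the tuned_models table above\n",
"scores = pd.DataFrame({\n",
"    'RF_over': {'Train Accuracy': 0.950864, 'Validation Accuracy': 0.978800,\n",
"                'Train Recall': 0.916890, 'Validation Recall': 0.908592},\n",
"    'RF_under': {'Train Accuracy': 0.958841, 'Validation Accuracy': 0.966000,\n",
"                 'Train Recall': 0.918902, 'Validation Recall': 0.912249},\n",
"})\n",
"\n",
"# Absolute train-vs-validation gap: smaller means more consistent across data sets\n",
"gaps = pd.DataFrame({\n",
"    'Accuracy gap': (scores.loc['Train Accuracy'] - scores.loc['Validation Accuracy']).abs(),\n",
"    'Recall gap': (scores.loc['Train Recall'] - scores.loc['Validation Recall']).abs(),\n",
"})\n",
"print(gaps.round(4))\n",
"```\n",
"\n",
"Both gaps are smaller for `RF_under` (about 0.007 each) than for `RF_over` (about 0.028 for accuracy and 0.008 for recall)."
],
"metadata": {}
},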
{
"cell_type": "markdown",
"source": [
"### Gradient Boosting [Oversampled]"
],
"metadata": {
"id": "TfoNTvTY3UwL"
}
},
{
"cell_type": "code",
"source": [
"params={'n_estimators':np.arange(50,125,25),\n",
" 'subsample':[0.5,0.75],\n",
" 'max_depth':[3,4]}"
],
"metadata": {
"id": "QTNFx-M028aC"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"After many trials, the main issue with tuned gradient boosting was overfitting. Keeping `n_estimators` low helped curtail it, as did setting `subsample` below 1, so that each tree is fit on a random fraction of the training data. Adjusting `max_depth` helped increase precision without much impact on recall."
],
"metadata": {
"id": "-t7LH5XPjeT0"
}
},
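{
"cell_type": "markdown",
"source": [
"One way to watch the overfitting described above is to track scores as boosting stages are added. `GradientBoostingClassifier.staged_predict` yields predictions after each stage, so a single fit gives the whole recall curve. The sketch below uses synthetic imbalanced data from `make_classification` as an assumed stand-in for the oversampled training set and the validation set. A training curve that keeps rising while the validation curve flattens or drops is the overfitting signal.\n",
"\n",
"```python\n",
"from sklearn.datasets import make_classification\n",
"from sklearn.ensemble import GradientBoostingClassifier\n",
"from sklearn.metrics import recall_score\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Synthetic stand-in for the sensor data: roughly 5% positives\n",
"X, y = make_classification(n_samples=4000, n_features=20, weights=[0.95],\n",
"                           random_state=2)\n",
"X_tr, X_val, y_tr, y_val = train_test_split(X, y, test_size=0.25,\n",
"                                            stratify=y, random_state=2)\n",
"\n",
"gbc = GradientBoostingClassifier(n_estimators=100, max_depth=4, subsample=0.5,\n",
"                                 min_samples_leaf=4, random_state=2)\n",
"gbc.fit(X_tr, y_tr)\n",
"\n",
"# Recall after each boosting stage, on training and held-out data\n",
"train_recall = [recall_score(y_tr, p) for p in gbc.staged_predict(X_tr)]\n",
"val_recall = [recall_score(y_val, p) for p in gbc.staged_predict(X_val)]\n",
"\n",
"for stage in (25, 50, 75, 100):\n",
"    print(stage, round(train_recall[stage - 1], 3), round(val_recall[stage - 1], 3))\n",
"```"
],
"metadata": {}
},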
{
"cell_type": "code",
"source": [
"gbc_over_tuned=GradientBoostingClassifier(random_state=2,min_samples_leaf=4)\n",
"\n",
"search=RandomizedSearchCV(estimator=gbc_over_tuned,\n",
" param_distributions=params,\n",
" n_iter=8,\n",
" scoring='recall',\n",
" n_jobs=-1,\n",
" cv=5,\n",
" verbose=1,\n",
" random_state=1)\n",
"\n",
"search.fit(X_train_over,y_train_over)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 135
},
"id": "ik-ASmWB28Vz",
"outputId": "b55e885a-1923-47f0-c543-ba4de4d4fa94"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Fitting 5 folds for each of 8 candidates, totalling 40 fits\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"RandomizedSearchCV(cv=5,\n",
" estimator=GradientBoostingClassifier(min_samples_leaf=4,\n",
" random_state=2),\n",
" n_iter=8, n_jobs=-1,\n",
" param_distributions={'max_depth': [3, 4],\n",
" 'n_estimators': array([ 50, 75, 100]),\n",
" 'subsample': [0.5, 0.75]},\n",
" random_state=1, scoring='recall', verbose=1)"
]
},
"metadata": {},
"execution_count": 126
}
]
},
{
"cell_type": "code",
"source": [
"search.best_params_"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "0QblZzLkigs_",
"outputId": "133bb493-0f23-4a27-e9ea-b7203c1b75ad"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{'subsample': 0.5, 'n_estimators': 100, 'max_depth': 4}"
]
},
"metadata": {},
"execution_count": 127
}
]
},
{
"cell_type": "markdown",
"source": [
"The search chose the top of the `n_estimators` range, but earlier trials revealed that values above 100 introduced overfitting. Better results came from the lower `subsample` (0.5) and the higher `max_depth` (4)."
],
"metadata": {
"id": "QwXGya3_kL5T"
}
},
{
"cell_type": "code",
"source": [
"best_gbc_over=search.best_params_\n",
"\n",
"# fit model with best params\n",
"gbc_over_tuned=GradientBoostingClassifier(\n",
" random_state=2,\n",
" min_samples_leaf=4,\n",
" **best_gbc_over\n",
")\n",
"\n",
"gbc_over_tuned.fit(X_train_over,y_train_over)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 92
},
"id": "CzQgIDLiIwGg",
"outputId": "29ff09d5-4f11-4c2e-91eb-a0e6db8949f2"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"GradientBoostingClassifier(max_depth=4, min_samples_leaf=4, random_state=2,\n",
" subsample=0.5)"
]
},
"metadata": {},
"execution_count": 129
}
]
},
{
"cell_type": "code",
"source": [
"ch(gbc_over_tuned)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 523
},
"id": "RfF_UZP6ktX4",
"outputId": "c3087537-41ca-4171-e65f-5f17a3e3ce3e"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Scores\n",
"Accuracy 0.977600\n",
"Precision 0.739259\n",
"Recall 0.912249\n",
"F1 0.816694"
]
},
"metadata": {},
"execution_count": 131
},
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"source": [
"tuned_models['GBC_over']=model_scores(gbc_over_tuned,sample='over')\n",
"tuned_models"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 238
},
"id": "J6lApZPLI6Ak",
"outputId": "d61018bb-5a09-4d26-ed7f-822e2dc3fc85"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" RF RF_over RF_under GBC_over\n",
"Train Accuracy 0.950633 0.950864 0.958841 0.953138\n",
"Validation Accuracy 0.950400 0.978800 0.966000 0.977600\n",
"Train Recall 0.878049 0.916890 0.918902 0.924260\n",
"Validation Recall 0.868373 0.908592 0.912249 0.912249\n",
"Validation Precision 0.528365 0.754173 0.630847 0.739259\n",
"Validation F1 0.656985 0.824212 0.745889 0.816694"
]
},
"metadata": {},
"execution_count": 132
}
]
},
{
"cell_type": "markdown",
"source": [
"This model performs comparably to `RF_over`, which is to say, very well. Accuracy is in the 95-98% range on training and validation data, with validation recall at 91%. Validation precision is 74%, one of the top scores among these finalist models."
],
"metadata": {
"id": "8IRS0zPflADO"
}
},
{
"cell_type": "markdown",
"source": [
"### XGBoost [Oversampled]"
],
"metadata": {
"id": "g-ChEjXO62Cz"
}
},
{
"cell_type": "code",
"source": [
"params={'eta':[0.05,0.1,0.15],\n",
" 'colsample_bytree':[0.5,0.75,1.0],\n",
" 'max_depth':np.arange(2,5)}"
],
"metadata": {
"id": "Dz38cRN3UNqr"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"After much experimentation, I narrowed the parameter space down to `eta` (an alias for `learning_rate`), `colsample_bytree`, and `max_depth`. All three parameters are used to control overfitting and to improve secondary metrics such as precision, since recall sits around 90% across a large swath of the parameter space.\n",
"\n",
"What's more, instantiating the classifier with `tree_method='gpu_hist'` makes training on a GPU lightning fast. It is fast enough that I can run exhaustive parameter searches with `GridSearchCV` in very little time."
],
"metadata": {
"id": "jwWFBTwt23k0"
}
},
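{
"cell_type": "markdown",
"source": [
"Unlike the randomized searches above, `GridSearchCV` fits every combination, so its cost is easy to predict. Three values for each of the three parameters give 27 candidates, and 5-fold cross-validation turns those into 135 fits. A quick check with scikit-learn's `ParameterGrid`, which does not need XGBoost installed:\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn.model_selection import ParameterGrid\n",
"\n",
"params = {'eta': [0.05, 0.1, 0.15],\n",
"          'colsample_bytree': [0.5, 0.75, 1.0],\n",
"          'max_depth': np.arange(2, 5)}\n",
"\n",
"n_candidates = len(ParameterGrid(params))\n",
"n_folds = 5\n",
"n_fits = n_candidates * n_folds\n",
"print(n_candidates, n_fits)\n",
"\n",
"# Note, assuming XGBoost >= 2.0: 'gpu_hist' is deprecated there, and the GPU\n",
"# equivalent is XGBClassifier(tree_method='hist', device='cuda').\n",
"```\n",
"\n",
"Because grid cost grows multiplicatively with each added parameter, this exhaustive approach is practical here mainly because GPU training is fast."
],
"metadata": {}
},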
{
"cell_type": "code",
"source": [
"xgb_over_tuned=XGBClassifier(random_state=1,\n",
" tree_method='gpu_hist')\n",
"\n",
"go=GridSearchCV(estimator=xgb_over_tuned,\n",
" param_grid=params,\n",
" scoring='recall',\n",
" cv=5,\n",
" n_jobs=-1,\n",
" verbose=1)\n",
"\n",
"go.fit(X_train_over,y_train_over)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 135
},
"id": "lqERdqGpRRWa",
"outputId": "66cb4b3b-0386-414f-8585-63c0d313345a"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Fitting 5 folds for each of 27 candidates, totalling 135 fits\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"GridSearchCV(cv=5,\n",
" estimator=XGBClassifier(random_state=1, tree_method='gpu_hist'),\n",
" n_jobs=-1,\n",
" param_grid={'colsample_bytree': [0.5, 0.75, 1.0],\n",
" 'eta': [0.05, 0.1, 0.15],\n",
" 'max_depth': array([2, 3, 4])},\n",
" scoring='recall', verbose=1)"
]
},
"metadata": {},
"execution_count": 134
}
]
},
{
"cell_type": "code",
"source": [
"go.best_params_"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ptiw1nyQbJUJ",
"outputId": "67e4c36b-ffab-4657-b700-f7cbd23e1771"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{'colsample_bytree': 1.0, 'eta': 0.05, 'max_depth': 4}"
]
},
"metadata": {},
"execution_count": 135
}
]
},
{
"cell_type": "markdown",
"source": [
"The best setting for `colsample_bytree` turns out to be the default of 1.0. The search favoured the lowest `eta` in the grid, but in earlier trials values below 0.05 threatened overfitting. Reducing `max_depth` to 4 (from the default of 6) helped prevent overfitting."
],
"metadata": {
"id": "bbQ4YndRlxTn"
}
},
{
"cell_type": "code",
"source": [
"best_xgb_over=go.best_params_\n",
"\n",
"# fit model with best params\n",
"xgb_over_tuned=XGBClassifier(\n",
" random_state=1,\n",
" tree_method='gpu_hist',\n",
" **best_xgb_over\n",
")\n",
"\n",
"xgb_over_tuned.fit(X_train_over,y_train_over)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 92
},
"id": "RGHDzYCLSbKi",
"outputId": "d896448a-3dc6-44cc-e9e7-c66d56676782"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"XGBClassifier(colsample_bytree=1.0, eta=0.05, max_depth=4, random_state=1,\n",
" tree_method='gpu_hist')"
]
},
"metadata": {},
"execution_count": 136
}
]
},
{
"cell_type": "code",
"source": [
"ch(xgb_over_tuned)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 523
},
"id": "QYVK9de1KwAu",
"outputId": "4b751a69-672a-4945-d7fa-89c00ac42101"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Scores\n",
"Accuracy 0.978800\n",
"Precision 0.754947\n",
"Recall 0.906764\n",
"F1 0.823920"
]
},
"metadata": {},
"execution_count": 137
},
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"source": [
"tuned_models['XGB_over']=model_scores(xgb_over_tuned,sample='over')\n",
"tuned_models"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 238
},
"id": "scjrQDlJl1Th",
"outputId": "8e480797-74e3-49f9-a6f7-d0a0f7d71dc2"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" RF RF_over RF_under GBC_over XGB_over\n",
"Train Accuracy 0.950633 0.950864 0.958841 0.953138 0.950934\n",
"Validation Accuracy 0.950400 0.978800 0.966000 0.977600 0.978800\n",
"Train Recall 0.878049 0.916890 0.918902 0.924260 0.918900\n",
"Validation Recall 0.868373 0.908592 0.912249 0.912249 0.906764\n",
"Validation Precision 0.528365 0.754173 0.630847 0.739259 0.754947\n",
"Validation F1 0.656985 0.824212 0.745889 0.816694 0.823920"
]
},
"metadata": {},
"execution_count": 138
}
]
},
{
"cell_type": "markdown",
"source": [
"Performance here is nearly identical to `RF_over` and `GBC_over`. High recall should translate into substantial cost savings for BreezeGen, and good precision will help too."
],
"metadata": {
"id": "DcH2AjFymtwx"
}
},
{
"cell_type": "markdown",
"metadata": {
"id": "D9JNnpxa4jau"
},
"source": [
"## Model performance comparison and choosing the final model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0JG85rkY4jav",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 238
},
"outputId": "3df11d55-0943-48f4-d64e-450b3d7b1ab4"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" RF RF_over RF_under GBC_over XGB_over\n",
"Train Accuracy 0.950633 0.950864 0.958841 0.953138 0.950934\n",
"Validation Accuracy 0.950400 0.978800 0.966000 0.977600 0.978800\n",
"Train Recall 0.878049 0.916890 0.918902 0.924260 0.918900\n",
"Validation Recall 0.868373 0.908592 0.912249 0.912249 0.906764\n",
"Validation Precision 0.528365 0.754173 0.630847 0.739259 0.754947\n",
"Validation F1 0.656985 0.824212 0.745889 0.816694 0.823920"
]
},
"metadata": {},
"execution_count": 139
}
],
"source": [
"tuned_models"
]
},
{
"cell_type": "markdown",
"source": [
"* All models boast accuracy in the 95-97% range without overfitting.\n",
"\n",
"* The random forest trained on the original data scores worst on recall, so it will not be the final model. Every other model has a recall score squarely in the 91-92% range.\n",
"\n",
"* Precision is an important secondary metric, as has been discussed throughout. While recall will cut down on BreezeGen's greatest expense, namely repair and replacement costs, precision reduces the instances where an inspection is unnecessary. This reduces money wasted on needless inspections. Of the remaining models, the random forest trained on undersampled data scores worst on precision, so it will not be the final model."
],
"metadata": {
"id": "Jb4pvYi1nTg7"
}
},
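{
"cell_type": "markdown",
"source": [
"The recall-first, precision-second reasoning above can be made concrete by pricing each model's errors. The sketch below is a rough illustration and is not part of the original analysis. It assumes the following:\n",
"\n",
"* Every positive prediction triggers a \\$5,000 inspection.\n",
"* A caught failure is then repaired for \\$15,000.\n",
"* A missed failure ends in a \\$40,000 replacement.\n",
"* A false alarm costs only the inspection.\n",
"\n",
"The function name `cost_per_failure` is hypothetical. The recall and precision values are the validation scores from the `tuned_models` table.\n",
"\n",
"```python\n",
"# Sketch: expected maintenance cost per actual generator failure, from a\n",
"# model's recall and precision.\n",
"# ASSUMPTION: every positive prediction triggers an inspection; a true\n",
"# positive is then repaired, a false negative ends in a replacement, and a\n",
"# false positive costs only the inspection.\n",
"INSPECTION = 5_000\n",
"REPAIR = 15_000\n",
"REPLACEMENT = 40_000\n",
"\n",
"\n",
"def cost_per_failure(recall, precision):\n",
"    tp = recall                            # failures caught per actual failure\n",
"    fn = 1 - recall                        # failures missed per actual failure\n",
"    fp = tp * (1 - precision) / precision  # false alarms per actual failure\n",
"    return tp * (INSPECTION + REPAIR) + fn * REPLACEMENT + fp * INSPECTION\n",
"\n",
"\n",
"# Validation (recall, precision) pairs from the tuned_models table\n",
"validation = {\n",
"    'RF': (0.868373, 0.528365),\n",
"    'RF_over': (0.908592, 0.754173),\n",
"    'RF_under': (0.912249, 0.630847),\n",
"    'GBC_over': (0.912249, 0.739259),\n",
"    'XGB_over': (0.906764, 0.754947),\n",
"}\n",
"\n",
"costs = {name: cost_per_failure(r, p) for name, (r, p) in validation.items()}\n",
"for name, cost in sorted(costs.items(), key=lambda item: item[1]):\n",
"    print(f'{name:9s} {cost:10,.0f}')\n",
"```\n",
"\n",
"Under these assumptions, `RF` and `RF_under` come out as the two most expensive models, matching the bullets above. The three oversampled models nearly tie."
],
"metadata": {}
},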
{
"cell_type": "markdown",
"metadata": {
"id": "d_pDMFAz4jav"
},
"source": [
"### Test set final performance"
]
},
{
"cell_type": "markdown",
"source": [
"To choose the best model, we'll look at performance on completely unseen data."
],
"metadata": {
"id": "N3-wH7eTsJCp"
}
},
{
"cell_type": "code",
"source": [
"test_data=pd.read_csv('dataset_test.csv')"
],
"metadata": {
"id": "9bE7W638olop"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"X_test=test_data.drop('Target',axis=1)\n",
"y_test=test_data['Target']\n",
"\n",
"# X_test=pre.transform(X_test)"
],
"metadata": {
"id": "wv-X56z0o6Xf"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"After spliting the data into predictor features and target, we run `X_test` through the preprocessing pipeline defined earlier. This pipeline includes a scaler and an imputer."
],
"metadata": {
"id": "_ChhWvjjsOk0"
}
},
{
"cell_type": "code",
"source": [
"test_perf=pd.DataFrame()\n",
"\n",
"def test_scores(model):\n",
"\n",
" y_pred=model.predict(X_test)\n",
" ser=pd.Series(dtype=float)\n",
"\n",
" ser.loc['Test Accuracy']=metrics.accuracy_score(y_test,y_pred)\n",
" ser.loc['Test Recall']=metrics.recall_score(y_test,y_pred)\n",
" ser.loc['Test Precision']=metrics.precision_score(y_test,y_pred)\n",
" return ser"
],
"metadata": {
"id": "st6zk1jwWO25"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"This funciton will compile metrics."
],
"metadata": {
"id": "P7NUG2fzshsT"
}
},
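{
"cell_type": "markdown",
"source": [
"As a check on what `test_scores` reports, the three metrics can also be computed directly from confusion-matrix counts. This is a minimal sketch. `scores_from_counts` is a hypothetical helper, and the counts are made up for illustration rather than taken from BreezeGen's data.\n",
"\n",
"```python\n",
"# Sketch: accuracy, recall and precision from confusion-matrix counts.\n",
"# The counts used below are hypothetical.\n",
"def scores_from_counts(tp, fp, fn, tn):\n",
"    accuracy = (tp + tn) / (tp + fp + fn + tn)\n",
"    recall = tp / (tp + fn)     # share of real failures that get flagged\n",
"    precision = tp / (tp + fp)  # share of flagged generators that really fail\n",
"    return accuracy, recall, precision\n",
"\n",
"\n",
"# Hypothetical test set: 100 real failures among 1,000 generators\n",
"acc, rec, prec = scores_from_counts(tp=80, fp=20, fn=20, tn=880)\n",
"print(f'accuracy={acc:.2f} recall={rec:.2f} precision={prec:.2f}')\n",
"```\n",
"\n",
"With these counts, accuracy is 0.96 while recall and precision are both 0.80. A model that never predicted failure would still reach 90% accuracy on the same data, (0 + 900) / 1000, which is why recall and precision drive the model choice here."
],
"metadata": {}
},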
{
"cell_type": "code",
"source": [
"test_perf['RF']=test_scores(rf_over_tuned)\n",
"test_perf['GBC']=test_scores(gbc_over_tuned)\n",
"test_perf['XGB']=test_scores(xgb_over_tuned)\n",
"test_perf"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 143
},
"id": "V0zg_7oBpqS5",
"outputId": "b0b96847-8092-41c6-d61c-93a954b78726"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" RF GBC XGB\n",
"Test Accuracy 0.977600 0.974200 0.974600\n",
"Test Recall 0.861060 0.872029 0.870201\n",
"Test Precision 0.760905 0.717293 0.722307"
]
},
"metadata": {},
"execution_count": 149
}
]
},
{
"cell_type": "markdown",
"source": [
"Interesting! Recall is lower on the testing data than on training and validation data. All three models score around 97% on accuracy, and around 87% on recall. What sets apart our winner is its precision score. The random forest model has noticably better precision (76%), which will cut down on inspection costs for BreezeGen.\n",
"\n",
"The best model is the tuned random forest trained on the oversampled data."
],
"metadata": {
"id": "GDoPfchasnBc"
}
},
{
"cell_type": "markdown",
"metadata": {
"id": "TM6VZTRn4jav"
},
"source": [
"## Pipelines to build the final model\n"
]
},
{
"cell_type": "markdown",
"source": [
"Our preprocessing pipeline was already most of what we needed, so we will add to what was already built above."
],
"metadata": {
"id": "472zsc8atU4i"
}
},
{
"cell_type": "code",
"source": [
"transformers=[\n",
" ('Scaler',StandardScaler()),\n",
" ('Imputer',KNNImputer()),\n",
" ('Predictor',rf_over_tuned)\n",
"]\n",
"pipe=Pipeline(transformers)"
],
"metadata": {
"id": "P1SEqh3itTIF"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"The three steps in our pipeline are scaling, imputing, and predicting."
],
"metadata": {
"id": "thQP5weeuG0r"
}
},
{
"cell_type": "code",
"source": [
"pipe.fit(X_train_over,y_train_over)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 161
},
"id": "UtphVJxZuN1c",
"outputId": "ce2a0e36-29b0-42e7-dabc-699e6ed0ada7"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"Pipeline(steps=[('Scaler', StandardScaler()), ('Imputer', KNNImputer()),\n",
" ('Predictor',\n",
" RandomForestClassifier(max_depth=8, max_features=0.5,\n",
" min_samples_leaf=2, n_estimators=325,\n",
" n_jobs=-1, random_state=2))])"
]
},
"metadata": {},
"execution_count": 167
}
]
},
{
"cell_type": "markdown",
"source": [
"This pipeline can be used to quickly process new data and subsequently make predictions for possible inspections."
],
"metadata": {
"id": "IHiQhMX6wcJL"
}
},
{
"cell_type": "markdown",
"metadata": {
"id": "c5hPmHyR4jaw"
},
"source": [
"# Business Insights and Conclusions"
]
},
{
"cell_type": "code",
"source": [
"ch(pipe,show_scores=False)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 365
},
"id": "HkjhkSa2wwwR",
"outputId": "b1ecbc7f-9f70-4874-c359-7e445fe43a9d"
},
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"BreezeGen incurs the following maintenance costs: \n",
"* \\$40,000 - generator replacement,\n",
"* \\$15,000 - generator repair,\n",
"* \\$5,00 - generator inspection.\n",
"\n",
"To save money on maintenance, BreezeGen must reduce occurances of replacement first. The company spends money unnecessarily when sensors do not alert technicians that a component has failed. This will lead to degredation or outright failure of the generator, necessitating replacement. Such situations are coded as false negatives, and we built a model that avoids false negatives. Our model correctly predicts component failures 86% of the time. Note how few instances of false negatives (FN) are present in the confusion matrix above.\n",
"\n",
"While repair costs are inevitable, unnecessary inspections add expense without any operational benefit. Cutting down on inspection costs means reducing false negatives; a false negative is predicting failure in a generator where the compontents are all functioning. For every three superfluous inspections, BreezeGen can afford another generator repair, so these savings translate into serious business gains. Our model correctly rejects false negatives 76% of the time, allowing technicians to spend more work hours on generators in genuine need of repair. The additional savings our model offers by reducing unnecessary inspections can be put toward repair expenses, meaning BreezeGen's expenses are going toward maintaining their generator infrastructure without much overhead."
],
"metadata": {
"id": "vTvPdtzewoNx"
}
},
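{
"cell_type": "markdown",
"source": [
"The break-even figures behind this argument take only a few lines of arithmetic. In this sketch, the claim that three wasted inspections equal one repair follows directly from the listed costs. The assumption that a caught failure costs an inspection plus a repair is ours, not BreezeGen's.\n",
"\n",
"```python\n",
"# Sketch: how many unnecessary inspections equal one avoided expense.\n",
"# ASSUMPTION: a caught failure costs an inspection plus a repair.\n",
"INSPECTION = 5_000\n",
"REPAIR = 15_000\n",
"REPLACEMENT = 40_000\n",
"\n",
"# The figure quoted in the text: three wasted inspections cost one repair\n",
"inspections_per_repair = REPAIR / INSPECTION\n",
"\n",
"# Catching a failure turns a replacement into an inspection plus a repair\n",
"saving_per_caught_failure = REPLACEMENT - (INSPECTION + REPAIR)\n",
"inspections_per_caught_failure = saving_per_caught_failure / INSPECTION\n",
"\n",
"print(inspections_per_repair, inspections_per_caught_failure)\n",
"```\n",
"\n",
"Under that assumption, catching one more failure saves \\$20,000, the price of four unnecessary inspections. This is why recall is the primary metric and precision the secondary one."
],
"metadata": {}
},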
{
"cell_type": "markdown",
"source": [
"***"
],
"metadata": {
"id": "VB3eO21n_sgt"
}
}
],
"metadata": {
"colab": {
"provenance": [],
"collapsed_sections": [
"TWlGr1u1hKyk",
"tgkSxZkO0tRa",
"Rxw_gopM4jar",
"d_pDMFAz4jav",
"TM6VZTRn4jav",
"c5hPmHyR4jaw"
],
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
},
"gpuClass": "premium"
},
"nbformat": 4,
"nbformat_minor": 0
}