{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## License \n", "\n", "Copyright 2017 - 2020 Patrick Hall and the H2O.ai team\n", "\n", "Licensed under the Apache License, Version 2.0 (the \"License\");\n", "you may not use this file except in compliance with the License.\n", "You may obtain a copy of the License at\n", "\n", " http://www.apache.org/licenses/LICENSE-2.0\n", "\n", "Unless required by applicable law or agreed to in writing, software\n", "distributed under the License is distributed on an \"AS IS\" BASIS,\n", "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "See the License for the specific language governing permissions and\n", "limitations under the License." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**DISCLAIMER:** This notebook is not legal compliance advice." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Engineering Transparency into Your Machine Learning Model with Python and XGBoost\n", "#### Monotonic XGBoost models, partial dependence, ICE, and Shapley explanations" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A key to building interpretable models is to limit their complexity. The more complex a model is, the harder it is to explain and understand. Overly complex models can also make unstable predictions on new data, which is both difficult to explain and makes models harder to trust. Monotonicity constraints not only simplify models, but do so in a way that is somewhat natural for human reasoning, increasing the transparency of predictive models. Under monotonicity constraints, model predictions can only increase or only decrease as an input variable value increases, and the direction of the constraint is typically specified by the user for logical reasons. 
For instance, a model might be constrained to produce only increasing probabilities of a certain medical condition as a patient's age increases, or to make only increasing predictions for home prices as a home's square footage increases. \n", "\n", "In this notebook a gradient boosting machine (GBM) is trained with monotonicity constraints to predict credit card payment defaults, using the UCI credit card default data, Python, NumPy, Pandas, and XGBoost. First, the credit card default data is loaded and prepared. Then Pearson correlation with the prediction target is used to determine the direction of the monotonicity constraints for each input variable, and the model is trained. After the model is trained, partial dependence and individual conditional expectation (ICE) plots are used to analyze and verify the model's monotonic behavior. Finally, an example of creating regulator-mandated reason codes from high-fidelity Shapley explanations for any model prediction is presented. This combination of monotonic XGBoost, partial dependence, ICE, and Shapley explanations is probably the most direct way to create an interpretable machine learning model today.\n", "\n", "**Note**: As of the h2o 3.24 \"Yates\" release, Shapley values are supported in h2o, in addition to GBM monotonicity constraints and partial dependence. To see Shapley values and monotonicity constraints for an h2o GBM in action please see: https://github.com/jphall663/interpretable_machine_learning_with_python/blob/master/dia.ipynb." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Python imports " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's start with Python package imports. NumPy is used for basic array, vector, and matrix calculations. Pandas is used for data frame manipulation and plotting, and XGBoost is used to train a GBM with monotonicity constraints."
] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/patrickh/anaconda3/lib/python3.6/site-packages/sklearn/ensemble/weight_boosting.py:29: DeprecationWarning: numpy.core.umath_tests is an internal NumPy module and should not be imported. It will be removed in a future NumPy release.\n", " from numpy.core.umath_tests import inner1d\n" ] } ], "source": [ "import numpy as np # array, vector, matrix calculations\n", "import pandas as pd # DataFrame handling\n", "import shap # for consistent, signed variable importance measurements\n", "import xgboost as xgb # gradient boosting machines (GBMs)\n", "\n", "import matplotlib.pyplot as plt # plotting\n", "pd.options.display.max_columns = 999 # enable display of all columns in notebook\n", "\n", "# enables display of plots in notebook\n", "%matplotlib inline\n", "\n", "np.random.seed(12345) # set random seed for reproducibility" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Download, explore, and prepare UCI credit card default data\n", "\n", "UCI credit card default data: https://archive.ics.uci.edu/ml/datasets/default+of+credit+card+clients\n", "\n", "The UCI credit card default data contains demographic and payment information about credit card customers in Taiwan in the year 2005. The data set contains 23 input variables: \n", "\n", "* **`LIMIT_BAL`**: Amount of given credit (NT dollar)\n", "* **`SEX`**: 1 = male; 2 = female\n", "* **`EDUCATION`**: 1 = graduate school; 2 = university; 3 = high school; 4 = others \n", "* **`MARRIAGE`**: 1 = married; 2 = single; 3 = others\n", "* **`AGE`**: Age in years \n", "* **`PAY_0`, `PAY_2` - `PAY_6`**: History of past payment; `PAY_0` = the repayment status in September, 2005; `PAY_2` = the repayment status in August, 2005; ...; `PAY_6` = the repayment status in April, 2005. 
The measurement scale for the repayment status is: -1 = pay duly; 1 = payment delay for one month; 2 = payment delay for two months; ...; 8 = payment delay for eight months; 9 = payment delay for nine months and above. \n", "* **`BILL_AMT1` - `BILL_AMT6`**: Amount of bill statement (NT dollar). `BILL_AMT1` = amount of bill statement in September, 2005; `BILL_AMT2` = amount of bill statement in August, 2005; ...; `BILL_AMT6` = amount of bill statement in April, 2005. \n", "* **`PAY_AMT1` - `PAY_AMT6`**: Amount of previous payment (NT dollar). `PAY_AMT1` = amount paid in September, 2005; `PAY_AMT2` = amount paid in August, 2005; ...; `PAY_AMT6` = amount paid in April, 2005. \n", "\n", "These 23 input variables are used to predict the target variable, whether or not a customer defaulted on their credit card bill in late 2005. Because XGBoost accepts only numeric inputs, all variables will be treated as numeric." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Import data and clean\n", "The credit card default data is available as an `.xls` file. Pandas can read `.xls` files directly with `read_excel()`, so it's used to load the credit card default data and give the prediction target a shorter name: `DEFAULT_NEXT_MONTH`. " ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# import XLS file\n", "path = 'default_of_credit_card_clients.xls'\n", "data = pd.read_excel(path,\n", "                     skiprows=1) # skip the first row of the spreadsheet\n", "\n", "# remove spaces from target column name \n", "data = data.rename(columns={'default payment next month': 'DEFAULT_NEXT_MONTH'}) " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Assign modeling roles" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The shorthand name `y` is assigned to the prediction target. `X` is assigned to all other input variables in the credit card default data except the row identifier, `ID`."
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "y = DEFAULT_NEXT_MONTH\n", "X = ['LIMIT_BAL', 'SEX', 'EDUCATION', 'MARRIAGE', 'AGE', 'PAY_0', 'PAY_2', 'PAY_3', 'PAY_4', 'PAY_5', 'PAY_6', 'BILL_AMT1', 'BILL_AMT2', 'BILL_AMT3', 'BILL_AMT4', 'BILL_AMT5', 'BILL_AMT6', 'PAY_AMT1', 'PAY_AMT2', 'PAY_AMT3', 'PAY_AMT4', 'PAY_AMT5', 'PAY_AMT6']\n" ] } ], "source": [ "# assign target and inputs for GBM\n", "y = 'DEFAULT_NEXT_MONTH'\n", "X = [name for name in data.columns if name not in [y, 'ID']]\n", "print('y =', y)\n", "print('X =', X)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Display descriptive statistics\n", "The Pandas `describe()` function displays a brief description of the credit card default data. The input variables `SEX`, `EDUCATION`, `MARRIAGE`, `PAY_0`-`PAY_6`, and the prediction target `DEFAULT_NEXT_MONTH`, are really categorical variables, but they have already been encoded into meaningful numeric, integer values, which is great for XGBoost. Also, there are no missing values in this dataset." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " LIMIT_BAL SEX EDUCATION MARRIAGE AGE \\\n", "count 30000.000000 30000.000000 30000.000000 30000.000000 30000.000000 \n", "mean 167484.322667 1.603733 1.853133 1.551867 35.485500 \n", "std 129747.661567 0.489129 0.790349 0.521970 9.217904 \n", "min 10000.000000 1.000000 0.000000 0.000000 21.000000 \n", "25% 50000.000000 1.000000 1.000000 1.000000 28.000000 \n", "50% 140000.000000 2.000000 2.000000 2.000000 34.000000 \n", "75% 240000.000000 2.000000 2.000000 2.000000 41.000000 \n", "max 1000000.000000 2.000000 6.000000 3.000000 79.000000 \n", "\n", " PAY_0 PAY_2 PAY_3 PAY_4 PAY_5 \\\n", "count 30000.000000 30000.000000 30000.000000 30000.000000 30000.000000 \n", "mean -0.016700 -0.133767 -0.166200 -0.220667 -0.266200 \n", "std 1.123802 1.197186 1.196868 1.169139 1.133187 \n", "min -2.000000 -2.000000 -2.000000 -2.000000 -2.000000 \n", "25% -1.000000 -1.000000 -1.000000 -1.000000 -1.000000 \n", "50% 0.000000 0.000000 0.000000 0.000000 0.000000 \n", "75% 0.000000 0.000000 0.000000 0.000000 0.000000 \n", "max 8.000000 8.000000 8.000000 8.000000 8.000000 \n", "\n", " PAY_6 BILL_AMT1 BILL_AMT2 BILL_AMT3 \\\n", "count 30000.000000 30000.000000 30000.000000 3.000000e+04 \n", "mean -0.291100 51223.330900 49179.075167 4.701315e+04 \n", "std 1.149988 73635.860576 71173.768783 6.934939e+04 \n", "min -2.000000 -165580.000000 -69777.000000 -1.572640e+05 \n", "25% -1.000000 3558.750000 2984.750000 2.666250e+03 \n", "50% 0.000000 22381.500000 21200.000000 2.008850e+04 \n", "75% 0.000000 67091.000000 64006.250000 6.016475e+04 \n", "max 8.000000 964511.000000 983931.000000 1.664089e+06 \n", "\n", " BILL_AMT4 BILL_AMT5 BILL_AMT6 PAY_AMT1 \\\n", "count 30000.000000 30000.000000 30000.000000 30000.000000 \n", "mean 43262.948967 40311.400967 38871.760400 5663.580500 \n", "std 64332.856134 60797.155770 59554.107537 16563.280354 \n", "min -170000.000000 -81334.000000 -339603.000000 0.000000 \n", "25% 2326.750000 1763.000000 1256.000000 1000.000000 \n", "50% 
 19052.000000 18104.500000 17071.000000 2100.000000 \n", "75% 54506.000000 50190.500000 49198.250000 5006.000000 \n", "max 891586.000000 927171.000000 961664.000000 873552.000000 \n", "\n", " PAY_AMT2 PAY_AMT3 PAY_AMT4 PAY_AMT5 \\\n", "count 3.000000e+04 30000.00000 30000.000000 30000.000000 \n", "mean 5.921163e+03 5225.68150 4826.076867 4799.387633 \n", "std 2.304087e+04 17606.96147 15666.159744 15278.305679 \n", "min 0.000000e+00 0.00000 0.000000 0.000000 \n", "25% 8.330000e+02 390.00000 296.000000 252.500000 \n", "50% 2.009000e+03 1800.00000 1500.000000 1500.000000 \n", "75% 5.000000e+03 4505.00000 4013.250000 4031.500000 \n", "max 1.684259e+06 896040.00000 621000.000000 426529.000000 \n", "\n", " PAY_AMT6 DEFAULT_NEXT_MONTH \n", "count 30000.000000 30000.000000 \n", "mean 5215.502567 0.221200 \n", "std 17777.465775 0.415062 \n", "min 0.000000 0.000000 \n", "25% 117.750000 0.000000 \n", "50% 1500.000000 0.000000 \n", "75% 4000.000000 0.000000 \n", "max 528666.000000 1.000000 " ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data[X + [y]].describe() # display descriptive statistics for all columns" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Investigate pair-wise Pearson correlations for DEFAULT_NEXT_MONTH\n", "\n", "Monotonic relationships are much easier to explain to colleagues, bosses, customers, and regulators than more complex, non-monotonic relationships, and monotonic relationships may also prevent overfitting and excess error due to variance for new data.\n", "\n", "To train a transparent monotonic classifier, constraints must be supplied to XGBoost that determine whether the learned relationship between an input variable and the prediction target `DEFAULT_NEXT_MONTH` will be increasing for increases in an input variable or decreasing for increases in an input variable. 
Pearson correlation provides a linear measure of the direction of the relationship between each input variable and the target. If the pair-wise Pearson correlation between an input and `DEFAULT_NEXT_MONTH` is positive, the input will be constrained to have an increasing relationship with the predictions for `DEFAULT_NEXT_MONTH`. If the pair-wise Pearson correlation is negative, the input will be constrained to have a decreasing relationship with the predictions for `DEFAULT_NEXT_MONTH`. \n", "\n", "Constraints are supplied to XGBoost in the form of a Python tuple with length equal to the number of inputs. Each item in the tuple is associated with an input variable based on its index in the tuple. The first constraint in the tuple is associated with the first variable in the training data, the second constraint in the tuple is associated with the second variable in the training data, and so on. The constraints themselves take the form of a 1 for a positive relationship and a -1 for a negative relationship." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Calculate Pearson correlation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The Pandas `.corr()` function returns the pair-wise Pearson correlation between variables in a Pandas DataFrame. Because `DEFAULT_NEXT_MONTH` is the last column in the `data` DataFrame, the last column of the Pearson correlation matrix indicates the direction of the linear relationship between each input variable and the prediction target, `DEFAULT_NEXT_MONTH`. According to the calculated values, as a customer's balance limit (`LIMIT_BAL`), bill amounts (`BILL_AMT1`-`BILL_AMT6`), and payment amounts (`PAY_AMT1`-`PAY_AMT6`) increase, their probability of default tends to decrease. However, as a customer's number of late payments increases (`PAY_0`, `PAY_2`-`PAY_6`), their probability of default usually increases. 
In general, the Pearson correlation values make sense, and they will be used to ensure that the modeled relationships will make sense as well. (Pearson correlation values between the target variable, DEFAULT_NEXT_MONTH, and each input variable are displayed directly below.)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " DEFAULT_NEXT_MONTH\n", "LIMIT_BAL -0.153520\n", "SEX -0.039961\n", "EDUCATION 0.028006\n", "MARRIAGE -0.024339\n", "AGE 0.013890\n", "PAY_0 0.324794\n", "PAY_2 0.263551\n", "PAY_3 0.235253\n", "PAY_4 0.216614\n", "PAY_5 0.204149\n", "PAY_6 0.186866\n", "BILL_AMT1 -0.019644\n", "BILL_AMT2 -0.014193\n", "BILL_AMT3 -0.014076\n", "BILL_AMT4 -0.010156\n", "BILL_AMT5 -0.006760\n", "BILL_AMT6 -0.005372\n", "PAY_AMT1 -0.072929\n", "PAY_AMT2 -0.058579\n", "PAY_AMT3 -0.056250\n", "PAY_AMT4 -0.056827\n", "PAY_AMT5 -0.055124\n", "PAY_AMT6 -0.053183" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# displays last column of Pearson correlation matrix as Pandas DataFrame\n", "pd.DataFrame(data[X + [y]].corr()[y]).iloc[:-1] " ] }, { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "#### Create tuple of monotonicity constraints from Pearson correlation values" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The last column of the Pearson correlation matrix is transformed from a numeric column in a Pandas DataFrame into a Python tuple of `1`s and `-1`s that will be used to specify monotonicity constraints for each input variable in XGBoost. If the Pearson correlation between an input variable and `DEFAULT_NEXT_MONTH` is positive, a positive monotonic relationship constraint is specified for that variable using `1`. If the correlation is negative, a negative monotonic constraint is specified using `-1`. (Specifying `0` indicates that no constraints should be used.) The resulting tuple will be passed to XGBoost when the GBM model is trained."
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# creates a tuple in which positive correlation values are assigned a 1\n", "# and negative correlation values are assigned a -1\n", "mono_constraints = tuple([int(i) for i in np.sign(data[X + [y]].corr()[y].values[:-1])])\n", "\n", "# (-1, -1, 1, -1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. Train XGBoost with monotonicity constraints\n", "\n", "XGBoost is a very accurate, open-source GBM library for regression and classification tasks. XGBoost can learn complex relationships between input variables and a target variable, but here the `monotone_constraints` tuning parameter is used to enforce monotonicity between inputs and the prediction for `DEFAULT_NEXT_MONTH`. XGBoost's early stopping functionality is also used to limit overfitting to the training data.\n", "\n", "XGBoost is available from: https://github.com/dmlc/xgboost and the implementation of XGBoost is described in detail here: http://www.kdd.org/kdd2016/papers/files/rfp0697-chenAemb.pdf.\n", "\n", "After training, GBM variable importance is calculated and displayed. GBM variable importance is a global measure of the overall impact of an input variable on the GBM model predictions. Global variable importance values give an indication of the magnitude of a variable's contribution to model predictions for all observations. To enhance trust in the GBM model, variable importance values should typically conform to human domain knowledge and reasonable expectations." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Split data into training and test sets for early stopping" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The credit card default data is split into training and test sets to monitor and prevent overtraining. 
Reproducibility is another important factor in creating trustworthy models, and randomly splitting datasets can introduce randomness in model predictions and other results. A random seed is used here to ensure the data split is reproducible." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train data rows = 20946, columns = 25\n", "Test data rows = 9054, columns = 25\n" ] } ], "source": [ "np.random.seed(12345) # set random seed for reproducibility\n", "split_ratio = 0.7 # 70%/30% train/test split\n", "\n", "# execute split\n", "split = np.random.rand(len(data)) < split_ratio\n", "train = data[split]\n", "test = data[~split]\n", "\n", "# summarize split\n", "print('Train data rows = %d, columns = %d' % (train.shape[0], train.shape[1]))\n", "print('Test data rows = %d, columns = %d' % (test.shape[0], test.shape[1]))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Train XGBoost GBM classifier\n", "To train an XGBoost classifier, the training and test data must be converted from Pandas DataFrames into XGBoost's internal `DMatrix` data structure. The `DMatrix()` constructor in the XGBoost package is used to convert the data. Many XGBoost tuning parameters must be specified as well. Typically a grid search would be performed to identify the best parameters for a given modeling task. For brevity's sake, a previously discovered set of good tuning parameters is specified here. Notice that the monotonicity constraints are passed to XGBoost using the `monotone_constraints` parameter. Because gradient boosting methods typically resample training data, an additional random seed is also specified for XGBoost using the `seed` parameter to create reproducible predictions, error rates, and variable importance values. To avoid overfitting, the `early_stopping_rounds` parameter is used to stop the training process after the test area under the curve (AUC) statistic fails to increase for 50 iterations."
] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0]\ttrain-auc:0.738066\teval-auc:0.733449\n", "Multiple eval metrics have been passed: 'eval-auc' will be used for early stopping.\n", "\n", "Will train until eval-auc hasn't improved in 50 rounds.\n", "[1]\ttrain-auc:0.772991\teval-auc:0.769311\n", "[2]\ttrain-auc:0.775536\teval-auc:0.772624\n", "[3]\ttrain-auc:0.776248\teval-auc:0.771985\n", "[4]\ttrain-auc:0.777216\teval-auc:0.772796\n", "[5]\ttrain-auc:0.777782\teval-auc:0.773066\n", "[6]\ttrain-auc:0.777783\teval-auc:0.773471\n", "[7]\ttrain-auc:0.777856\teval-auc:0.773591\n", "[8]\ttrain-auc:0.777633\teval-auc:0.773209\n", "[9]\ttrain-auc:0.777417\teval-auc:0.772892\n", "[10]\ttrain-auc:0.777519\teval-auc:0.772666\n", "[11]\ttrain-auc:0.778476\teval-auc:0.773346\n", "[12]\ttrain-auc:0.778541\teval-auc:0.773312\n", "[13]\ttrain-auc:0.778351\teval-auc:0.773266\n", "[14]\ttrain-auc:0.778519\teval-auc:0.773695\n", "[15]\ttrain-auc:0.779383\teval-auc:0.774314\n", "[16]\ttrain-auc:0.779555\teval-auc:0.77473\n", "[17]\ttrain-auc:0.780178\teval-auc:0.775043\n", "[18]\ttrain-auc:0.780743\teval-auc:0.775246\n", "[19]\ttrain-auc:0.78146\teval-auc:0.776021\n", "[20]\ttrain-auc:0.781944\teval-auc:0.776477\n", "[21]\ttrain-auc:0.782181\teval-auc:0.776696\n", "[22]\ttrain-auc:0.782319\teval-auc:0.776915\n", "[23]\ttrain-auc:0.78263\teval-auc:0.776998\n", "[24]\ttrain-auc:0.782965\teval-auc:0.777144\n", "[25]\ttrain-auc:0.783396\teval-auc:0.777498\n", "[26]\ttrain-auc:0.783791\teval-auc:0.777791\n", "[27]\ttrain-auc:0.784151\teval-auc:0.778091\n", "[28]\ttrain-auc:0.784382\teval-auc:0.778375\n", "[29]\ttrain-auc:0.784678\teval-auc:0.778567\n", "[30]\ttrain-auc:0.784912\teval-auc:0.778755\n", "[31]\ttrain-auc:0.785231\teval-auc:0.778938\n", "[32]\ttrain-auc:0.785418\teval-auc:0.779125\n", "[33]\ttrain-auc:0.785598\teval-auc:0.779226\n", "[34]\ttrain-auc:0.785838\teval-auc:0.779421\n", 
"[35]\ttrain-auc:0.7861\teval-auc:0.779432\n", "[36]\ttrain-auc:0.786347\teval-auc:0.779549\n", "[37]\ttrain-auc:0.786633\teval-auc:0.779575\n", "[38]\ttrain-auc:0.786833\teval-auc:0.779668\n", "[39]\ttrain-auc:0.787093\teval-auc:0.779743\n", "[40]\ttrain-auc:0.787307\teval-auc:0.779926\n", "[41]\ttrain-auc:0.787602\teval-auc:0.780015\n", "[42]\ttrain-auc:0.787819\teval-auc:0.780089\n", "[43]\ttrain-auc:0.787943\teval-auc:0.780196\n", "[44]\ttrain-auc:0.788092\teval-auc:0.780239\n", "[45]\ttrain-auc:0.788194\teval-auc:0.780198\n", "[46]\ttrain-auc:0.788298\teval-auc:0.780235\n", "[47]\ttrain-auc:0.788422\teval-auc:0.780305\n", "[48]\ttrain-auc:0.78862\teval-auc:0.780372\n", "[49]\ttrain-auc:0.788803\teval-auc:0.780392\n", "[50]\ttrain-auc:0.789049\teval-auc:0.780558\n", "[51]\ttrain-auc:0.789221\teval-auc:0.780659\n", "[52]\ttrain-auc:0.789405\teval-auc:0.780756\n", "[53]\ttrain-auc:0.789553\teval-auc:0.780721\n", "[54]\ttrain-auc:0.78968\teval-auc:0.780791\n", "[55]\ttrain-auc:0.789779\teval-auc:0.780849\n", "[56]\ttrain-auc:0.789875\teval-auc:0.780932\n", "[57]\ttrain-auc:0.79002\teval-auc:0.780963\n", "[58]\ttrain-auc:0.790156\teval-auc:0.781038\n", "[59]\ttrain-auc:0.790292\teval-auc:0.781085\n", "[60]\ttrain-auc:0.790403\teval-auc:0.781079\n", "[61]\ttrain-auc:0.790509\teval-auc:0.781091\n", "[62]\ttrain-auc:0.790554\teval-auc:0.781061\n", "[63]\ttrain-auc:0.790635\teval-auc:0.781148\n", "[64]\ttrain-auc:0.790679\teval-auc:0.781166\n", "[65]\ttrain-auc:0.790822\teval-auc:0.781155\n", "[66]\ttrain-auc:0.790896\teval-auc:0.781178\n", "[67]\ttrain-auc:0.790977\teval-auc:0.781188\n", "[68]\ttrain-auc:0.791155\teval-auc:0.781211\n", "[69]\ttrain-auc:0.791239\teval-auc:0.781179\n", "[70]\ttrain-auc:0.7914\teval-auc:0.781347\n", "[71]\ttrain-auc:0.791525\teval-auc:0.78134\n", "[72]\ttrain-auc:0.791578\teval-auc:0.781312\n", "[73]\ttrain-auc:0.791691\teval-auc:0.781325\n", "[74]\ttrain-auc:0.791747\teval-auc:0.781323\n", 
"[75]\ttrain-auc:0.791801\teval-auc:0.781304\n", "[76]\ttrain-auc:0.791844\teval-auc:0.781313\n", "[77]\ttrain-auc:0.7919\teval-auc:0.781325\n", "[78]\ttrain-auc:0.792056\teval-auc:0.781448\n", "[79]\ttrain-auc:0.792088\teval-auc:0.781397\n", "[80]\ttrain-auc:0.792151\teval-auc:0.78138\n", "[81]\ttrain-auc:0.792173\teval-auc:0.781388\n", "[82]\ttrain-auc:0.792236\teval-auc:0.781301\n", "[83]\ttrain-auc:0.792327\teval-auc:0.781355\n", "[84]\ttrain-auc:0.792372\teval-auc:0.781356\n", "[85]\ttrain-auc:0.792402\teval-auc:0.781321\n", "[86]\ttrain-auc:0.79247\teval-auc:0.781286\n", "[87]\ttrain-auc:0.792521\teval-auc:0.781283\n", "[88]\ttrain-auc:0.792543\teval-auc:0.781265\n", "[89]\ttrain-auc:0.792595\teval-auc:0.781255\n", "[90]\ttrain-auc:0.792618\teval-auc:0.781242\n", "[91]\ttrain-auc:0.792673\teval-auc:0.781309\n", "[92]\ttrain-auc:0.792766\teval-auc:0.781357\n", "[93]\ttrain-auc:0.792826\teval-auc:0.781381\n", "[94]\ttrain-auc:0.792914\teval-auc:0.781387\n", "[95]\ttrain-auc:0.792967\teval-auc:0.781385\n", "[96]\ttrain-auc:0.793016\teval-auc:0.781375\n", "[97]\ttrain-auc:0.793053\teval-auc:0.781353\n", "[98]\ttrain-auc:0.79312\teval-auc:0.78137\n", "[99]\ttrain-auc:0.793145\teval-auc:0.781413\n", "[100]\ttrain-auc:0.793191\teval-auc:0.781456\n", "[101]\ttrain-auc:0.793256\teval-auc:0.781435\n", "[102]\ttrain-auc:0.793282\teval-auc:0.781382\n", "[103]\ttrain-auc:0.793322\teval-auc:0.781385\n", "[104]\ttrain-auc:0.793346\teval-auc:0.781361\n", "[105]\ttrain-auc:0.793399\teval-auc:0.781418\n", "[106]\ttrain-auc:0.793436\teval-auc:0.781398\n", "[107]\ttrain-auc:0.793511\teval-auc:0.781358\n", "[108]\ttrain-auc:0.793578\teval-auc:0.78137\n", "[109]\ttrain-auc:0.793655\teval-auc:0.781336\n", "[110]\ttrain-auc:0.793683\teval-auc:0.781299\n", "[111]\ttrain-auc:0.793701\teval-auc:0.781277\n", "[112]\ttrain-auc:0.79371\teval-auc:0.78128\n", "[113]\ttrain-auc:0.793745\teval-auc:0.781298\n", "[114]\ttrain-auc:0.793817\teval-auc:0.781307\n", 
"[115]\ttrain-auc:0.793838\teval-auc:0.781301\n", "[116]\ttrain-auc:0.793867\teval-auc:0.781308\n", "[117]\ttrain-auc:0.793877\teval-auc:0.781315\n", "[118]\ttrain-auc:0.793917\teval-auc:0.781246\n", "[119]\ttrain-auc:0.793947\teval-auc:0.781302\n", "[120]\ttrain-auc:0.794\teval-auc:0.781358\n", "[121]\ttrain-auc:0.794026\teval-auc:0.781357\n", "[122]\ttrain-auc:0.794053\teval-auc:0.781277\n", "[123]\ttrain-auc:0.794064\teval-auc:0.781257\n", "[124]\ttrain-auc:0.794209\teval-auc:0.781289\n", "[125]\ttrain-auc:0.794219\teval-auc:0.781288\n", "[126]\ttrain-auc:0.794287\teval-auc:0.781347\n", "[127]\ttrain-auc:0.79429\teval-auc:0.781347\n", "[128]\ttrain-auc:0.794327\teval-auc:0.781331\n", "[129]\ttrain-auc:0.794336\teval-auc:0.781349\n", "[130]\ttrain-auc:0.794367\teval-auc:0.781347\n", "[131]\ttrain-auc:0.794364\teval-auc:0.78134\n", "[132]\ttrain-auc:0.794385\teval-auc:0.781363\n", "[133]\ttrain-auc:0.794387\teval-auc:0.781326\n", "[134]\ttrain-auc:0.794472\teval-auc:0.78132\n", "[135]\ttrain-auc:0.794483\teval-auc:0.781326\n", "[136]\ttrain-auc:0.794495\teval-auc:0.781306\n", "[137]\ttrain-auc:0.794583\teval-auc:0.781293\n", "[138]\ttrain-auc:0.794589\teval-auc:0.781293\n", "[139]\ttrain-auc:0.7946\teval-auc:0.781286\n", "[140]\ttrain-auc:0.794642\teval-auc:0.7813\n", "[141]\ttrain-auc:0.794656\teval-auc:0.781303\n", "[142]\ttrain-auc:0.794654\teval-auc:0.781287\n", "[143]\ttrain-auc:0.794695\teval-auc:0.781273\n", "[144]\ttrain-auc:0.794705\teval-auc:0.781268\n", "[145]\ttrain-auc:0.794708\teval-auc:0.781268\n", "[146]\ttrain-auc:0.79471\teval-auc:0.781258\n", "[147]\ttrain-auc:0.794775\teval-auc:0.781284\n", "[148]\ttrain-auc:0.794782\teval-auc:0.781277\n", "[149]\ttrain-auc:0.794794\teval-auc:0.781283\n", "[150]\ttrain-auc:0.794815\teval-auc:0.78126\n", "Stopping. 
Best iteration:\n", "[100]\ttrain-auc:0.793191\teval-auc:0.781456\n", "\n" ] } ], "source": [ "# XGBoost trains on its own DMatrix data structure, not NumPy arrays or Pandas DataFrames \n", "dtrain = xgb.DMatrix(train[X], train[y])\n", "dtest = xgb.DMatrix(test[X], test[y])\n", "\n", "# used to calibrate predictions to mean of y \n", "base_y = train[y].mean()\n", "\n", "# tuning parameters\n", "params = {\n", " 'objective': 'binary:logistic', # produces 0-1 probabilities for binary classification\n", " 'booster': 'gbtree', # base learner will be decision tree\n", " 'eval_metric': 'auc', # stop training based on maximum AUC, AUC always between 0-1\n", " 'eta': 0.08, # learning rate\n", " 'subsample': 0.9, # use 90% of rows in each decision tree\n", " 'colsample_bytree': 0.9, # use 90% of columns in each decision tree\n", " 'max_depth': 15, # allow decision trees to grow to depth of 15\n", " 'monotone_constraints':mono_constraints, # 1 = increasing relationship, -1 = decreasing relationship\n", " 'base_score': base_y, # calibrate predictions to mean of y \n", " 'seed': 12345 # set random seed for reproducibility\n", "}\n", "\n", "# watchlist is used for early stopping\n", "watchlist = [(dtrain, 'train'), (dtest, 'eval')]\n", "\n", "# train model\n", "xgb_model = xgb.train(params, # set tuning parameters from above \n", " dtrain, # training data\n", " 1000, # maximum of 1000 iterations (trees)\n", " evals=watchlist, # use watchlist for early stopping \n", " early_stopping_rounds=50, # stop after 50 iterations (trees) without increase in AUC\n", " verbose_eval=True) # display iteration progress\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Global Shapley variable importance\n", "By setting `pred_contribs=True`, XGBoost's `predict()` function will return Shapley values for each row of the test set. 
Instead of relying on traditional single-value variable importance measures, local Shapley values for each input will be plotted below to get a more holistic and consistent measurement of the global importance of each input variable. Shapley values are introduced in greater detail in Section 6 below, but for now notice the monotonicity of the input variable contributions displayed in the Shapley summary plot." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "# dtest is DMatrix\n", "# shap_values is NumPy array\n", "shap_values = xgb_model.predict(dtest, pred_contribs=True, ntree_limit=xgb_model.best_ntree_limit)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# plot Shapley variable importance summary \n", "shap.summary_plot(shap_values[:, :-1], test[xgb_model.feature_names])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Display Shapley variable importance summary\n", "The variable importance ranking should be consistent with human domain knowledge and reasonable expectations. In this case, `PAY_0` is by far the most important variable. As someone's most recent behavior is a very good indicator of future behavior, this checks out." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4. Calculating partial dependence and ICE to validate and explain monotonic behavior\n", "\n", "Partial dependence plots are used to view the global, average prediction behavior of a variable under the monotonic model. 
Partial dependence plots show the average prediction of the monotonic model as a function of specific values of an input variable of interest, indicating how the monotonic GBM predictions change based on the values of the input variable of interest, while taking nonlinearity into consideration and averaging out the effects of all other\n", "input variables. Partial dependence plots enable increased transparency into the monotonic GBM's mechanisms and enable validation and debugging of the monotonic GBM by comparing a variable's average predictions across its domain to known standards and reasonable expectations. Partial dependence plots are described in greater detail in *The Elements of Statistical Learning*, section 10.13: https://web.stanford.edu/~hastie/ElemStatLearn/printings/ESLII_print12.pdf.\n", "\n", "Individual conditional expectation (ICE) plots, a newer and less well-known adaptation of partial dependence plots, can be used to create more localized explanations for a single observation of data using the same basic ideas as partial dependence plots. ICE is also a type of nonlinear sensitivity analysis in which the model predictions for a single observation are measured while a feature of interest is varied over its domain. ICE increases understanding and transparency by displaying the nonlinear behavior of the monotonic GBM. ICE also enhances trust, accountability, and fairness by enabling comparisons of nonlinear behavior to human domain knowledge and reasonable expectations. ICE, as a type of sensitivity analysis, can also engender trust when model behavior on simulated or extreme data points is acceptable. 
A detailed description of ICE is available in this arXiv preprint: https://arxiv.org/abs/1309.6392.\n", "\n", "Because partial dependence and ICE are measured on the same scale, they can be displayed in the same line plot to compare the global, average prediction behavior for the entire model and the local prediction behavior for certain rows of data. Overlaying the two types of curves enables analysis of both global and local behavior simultaneously and provides an indication of the trustworthiness of the average behavior represented by partial dependence. (Partial dependence can be misleading in the presence of strong interactions or correlation. ICE curves diverging from the partial dependence curve can be indicative of such problems.) Histograms are also presented with the partial dependence and ICE curves, to enable a rough measure of epistemic uncertainty for model predictions: predictions based on small amounts of training data are likely less dependable." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Function for calculating partial dependence\n", "Since partial dependence and ICE will be calculated for several important variables in the GBM model, it is convenient to wrap the calculation in a function. It's probably best to analyze partial dependence and ICE for all variables in a model, but only the top three most important input variables will be investigated here. It's also a good idea to analyze partial dependence and ICE on the test data, or other holdout datasets, to see how the model will perform on new data. \n", "This simple function is designed to return partial dependence when it is called for an entire dataset and ICE when it is called for a single row. The `bins` argument will be used later to calculate ICE values at the same points in an input variable's domain at which partial dependence is calculated directly below. 
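The same idea can be sketched independently of XGBoost. The version below is a minimal illustration under stated assumptions, not the notebook's function: `partial_dependence_sketch`, `predict_fn`, `toy_predict`, and the toy data are hypothetical names introduced here. It collects rows in a list instead of calling `DataFrame.append`, which is no longer available in pandas 2.x, and it works on a copy so the input frame is left untouched. Called on a whole frame it returns partial dependence; called on a one-row frame it returns that row's ICE curve.

```python
import numpy as np
import pandas as pd


def partial_dependence_sketch(xs, frame, predict_fn, resolution=20, bins=None):
    # work on a copy so the caller's frame is never modified
    temp = frame.copy()

    # default grid: `resolution` equally spaced points starting at the column minimum
    if bins is None:
        min_, max_ = temp[xs].min(), temp[xs].max()
        bins = np.linspace(min_, max_, resolution, endpoint=False)

    rows = []
    for value in bins:
        temp[xs] = value  # set the variable of interest to a constant
        # average the predictions over all rows of the altered frame
        rows.append({xs: value, 'partial_dependence': float(np.mean(predict_fn(temp)))})

    return pd.DataFrame(rows, columns=[xs, 'partial_dependence'])


# hypothetical stand-in for a fitted model: risk falls as LIMIT_BAL rises
def toy_predict(df):
    score = df['LIMIT_BAL'].to_numpy() / 100000.0 - df['PAY_0'].to_numpy()
    return 1.0 / (1.0 + np.exp(score))


toy = pd.DataFrame({'LIMIT_BAL': [10000.0, 200000.0, 500000.0],
                    'PAY_0': [2.0, 0.0, -1.0]})

# partial dependence: average behavior over all rows
pd_curve = partial_dependence_sketch('LIMIT_BAL', toy, toy_predict, resolution=5)

# ICE: the same calculation for a single row, on the same grid
ice_curve = partial_dependence_sketch('LIMIT_BAL', toy.iloc[[0]], toy_predict,
                                      bins=list(pd_curve['LIMIT_BAL']))
print(pd_curve)
print(ice_curve)
```

Passing the partial dependence grid as `bins` when computing ICE, as in the last call, keeps both curves on the same x-values so they can be overlaid in one plot.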
" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "def par_dep(xs, frame, model, resolution=20, bins=None):\n", " \n", " \"\"\" Creates Pandas DataFrame containing partial dependence for a \n", " single variable.\n", " \n", " Args:\n", " xs: Variable for which to calculate partial dependence.\n", " frame: Pandas DataFrame for which to calculate partial dependence.\n", " model: XGBoost model for which to calculate partial dependence.\n", " resolution: The number of points across the domain of xs for which \n", " to calculate partial dependence, default 20.\n", " bins: List of values at which to set xs, default 20 equally-spaced \n", " points between column minimum and maximum.\n", " \n", " Returns:\n", " Pandas DataFrame containing partial dependence values.\n", " \n", " \"\"\"\n", " \n", " # turn off pesky Pandas copy warning\n", " pd.options.mode.chained_assignment = None\n", " \n", " # initialize empty Pandas DataFrame with correct column names\n", " par_dep_frame = pd.DataFrame(columns=[xs, 'partial_dependence'])\n", " \n", " # cache original column values \n", " col_cache = frame.loc[:, xs].copy(deep=True)\n", " \n", " # determine values at which to calculate partial dependence\n", " if bins == None:\n", " min_ = frame[xs].min()\n", " max_ = frame[xs].max()\n", " by = (max_ - min_)/resolution\n", " bins = np.arange(min_, max_, by)\n", " \n", " # calculate partial dependence \n", " # by setting column of interest to constant \n", " # and scoring the altered data and taking the mean of the predictions\n", " for j in bins:\n", " frame.loc[:, xs] = j\n", " dframe = xgb.DMatrix(frame)\n", " par_dep_i = pd.DataFrame(model.predict(dframe, ntree_limit=model.best_ntree_limit))\n", " par_dep_j = par_dep_i.mean()[0]\n", " par_dep_frame = par_dep_frame.append({xs:j,\n", " 'partial_dependence': par_dep_j}, \n", " ignore_index=True)\n", " \n", " # return input frame to original cached state \n", " frame.loc[:, xs] = col_cache\n", 
"\n", " return par_dep_frame\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Calculate partial dependence for the most important input variables in the GBM\n", "The partial dependence for `LIMIT_BAL` can be seen to decrease as credit balance limits increase. This finding is aligned with expectations that the model predictions will be monotonically decreasing with increasing `LIMIT_BAL` and parsimonious with well-known business practices in credit lending. Partial dependence for other important values is displayed in plots further below." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " LIMIT_BAL partial_dependence\n", "0 10000.0 0.266888\n", "1 59500.0 0.243301\n", "2 109000.0 0.224759\n", "3 158500.0 0.216966\n", "4 208000.0 0.214594\n", "5 257500.0 0.209798\n", "6 307000.0 0.206901\n", "7 356500.0 0.198712\n", "8 406000.0 0.197883\n", "9 455500.0 0.197086\n", "10 505000.0 0.194449\n", "11 554500.0 0.194407\n", "12 604000.0 0.190247\n", "13 653500.0 0.190209\n", "14 703000.0 0.190188\n", "15 752500.0 0.190188\n", "16 802000.0 0.190188\n", "17 851500.0 0.190188\n", "18 901000.0 0.190188\n", "19 950500.0 0.190188" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "par_dep_PAY_0 = par_dep('PAY_0', test[X], xgb_model) # calculate partial dependence for PAY_0\n", "par_dep_LIMIT_BAL = par_dep('LIMIT_BAL', test[X], xgb_model) # calculate partial dependence for LIMIT_BAL\n", "par_dep_BILL_AMT1 = par_dep('BILL_AMT1', test[X], xgb_model) # calculate partial dependence for BILL_AMT1\n", "\n", "# display partial dependence for LIMIT_BAL\n", "par_dep_LIMIT_BAL" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Helper function for finding percentiles of predictions" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "ICE can be calculated for any row in the training or test data, but without intimate knowledge of a data source it can be difficult to know where to apply ICE. Calculating and analyzing ICE curves for every row of training and test data set can be overwhelming, even for the example credit card default dataset. One place to start with ICE is to calculate ICE curves at every decile of predicted probabilities in a dataset, giving an indication of local prediction behavior across the dataset. The function below finds and returns the row indices for the maximum, minimum, and deciles of one column in terms of another -- in this case, the model predictions (`p_DEFAULT_NEXT_MONTH`) and the row identifier (`ID`), respectively. 
" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "def get_percentile_dict(yhat, id_, frame):\n", "\n", " \"\"\" Returns the percentiles of a column, yhat, as the indices based on \n", " another column id_.\n", " \n", " Args:\n", " yhat: Column in which to find percentiles.\n", " id_: Id column that stores indices for percentiles of yhat.\n", " frame: Pandas DataFrame containing yhat and id_. \n", " \n", " Returns:\n", " Dictionary of percentile values and index column values.\n", " \n", " \"\"\"\n", " \n", " # create a copy of frame and sort it by yhat\n", " sort_df = frame.copy(deep=True)\n", " sort_df.sort_values(yhat, inplace=True)\n", " sort_df.reset_index(inplace=True)\n", " \n", " # find top and bottom percentiles\n", " percentiles_dict = {}\n", " percentiles_dict[0] = sort_df.loc[0, id_]\n", " percentiles_dict[99] = sort_df.loc[sort_df.shape[0]-1, id_]\n", "\n", " # find 10th-90th percentiles\n", " inc = sort_df.shape[0]//10\n", " for i in range(1, 10):\n", " percentiles_dict[i * 10] = sort_df.loc[i * inc, id_]\n", "\n", " return percentiles_dict\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Find some percentiles of yhat in the test data\n", "The values for `ID` that correspond to the maximum, minimum, and deciles of `p_DEFAULT_NEXT_MONTH` are displayed below. ICE will be calculated for the rows of the test dataset associated with these `ID` values." 
] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{0: 23477,\n", " 10: 6226,\n", " 20: 25603,\n", " 30: 12890,\n", " 40: 715,\n", " 50: 14517,\n", " 60: 4908,\n", " 70: 7411,\n", " 80: 6219,\n", " 90: 18421,\n", " 99: 17757}" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# merge GBM predictions onto test data\n", "yhat_test = pd.concat([test.reset_index(drop=True), pd.DataFrame(xgb_model.predict(dtest, \n", " ntree_limit=xgb_model.best_ntree_limit))],\n", " axis=1)\n", "yhat_test = yhat_test.rename(columns={0:'p_DEFAULT_NEXT_MONTH'})\n", "\n", "# find percentiles of predictions\n", "percentile_dict = get_percentile_dict('p_DEFAULT_NEXT_MONTH', 'ID', yhat_test)\n", "\n", "# display percentiles dictionary\n", "# ID values for rows\n", "# from lowest prediction \n", "# to highest prediction\n", "percentile_dict" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Calculate ICE curve values\n", "ICE values represent a model's prediction for a row of data while an input variable of interest is varied across its domain. The values of the input variable are chosen to match the values at which partial dependence was calculated earlier, and ICE is calculated for the top three most important variables and for the rows found at each selected percentile of predictions in the test dataset. 
" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "# retrieve bins from original partial dependence calculation\n", "\n", "bins_PAY_0 = list(par_dep_PAY_0['PAY_0'])\n", "bins_LIMIT_BAL = list(par_dep_LIMIT_BAL['LIMIT_BAL'])\n", "bins_BILL_AMT1 = list(par_dep_BILL_AMT1['BILL_AMT1'])\n", "\n", "# for each percentile in percentile_dict\n", "# create a new column in the par_dep frame \n", "# representing the ICE curve for that percentile\n", "# and the variables of interest\n", "for i in sorted(percentile_dict.keys()):\n", " \n", " col_name = 'Percentile_' + str(i)\n", " \n", " # ICE curves for PAY_0 across percentiles at bins_PAY_0 intervals\n", " par_dep_PAY_0[col_name] = par_dep('PAY_0', \n", " test[test['ID'] == int(percentile_dict[i])][X], \n", " xgb_model, \n", " bins=bins_PAY_0)['partial_dependence']\n", " \n", " # ICE curves for LIMIT_BAL across percentiles at bins_LIMIT_BAL intervals\n", " par_dep_LIMIT_BAL[col_name] = par_dep('LIMIT_BAL', \n", " test[test['ID'] == int(percentile_dict[i])][X], \n", " xgb_model, \n", " bins=bins_LIMIT_BAL)['partial_dependence']\n", " \n", "\n", "\n", " # ICE curves for BILL_AMT1 across percentiles at bins_BILL_AMT1 intervals\n", " par_dep_BILL_AMT1[col_name] = par_dep('BILL_AMT1', \n", " test[test['ID'] == int(percentile_dict[i])][X], \n", " xgb_model, \n", " bins=bins_BILL_AMT1)['partial_dependence']" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Display partial dependence and ICE for `LIMIT_BAL`\n", "Partial dependence and ICE values for rows at the minimum, maximum and deciles (0%, 10%, 20%, ..., 90%, 99%) of predictions for `DEFAULT_NEXT_MONTH` and at the values of `LIMIT_BAL` used for partial dependence are shown here. Each column of ICE values will be a curve in the plots below. 
Each ICE value is the prediction for one row of test data, at the percentile of interest noted in the column name, with `LIMIT_BAL` set to the value in the `LIMIT_BAL` column of the same table row. 
Notice that monotonic decreasing prediction behavior for `LIMIT_BAL` holds at each displayed percentile of predicted `DEFAULT_NEXT_MONTH`, helping to validate that the trained GBM predictions are monotonic for `LIMIT_BAL`." ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " LIMIT_BAL partial_dependence Percentile_0 Percentile_10 Percentile_20 \\\n", "0 10000.0 0.266888 0.009752 0.103610 0.114130 \n", "1 59500.0 0.243301 0.007765 0.095805 0.106769 \n", "2 109000.0 0.224759 0.004514 0.073701 0.089827 \n", "3 158500.0 0.216966 0.003510 0.061587 0.082731 \n", "4 208000.0 0.214594 0.003208 0.059036 0.081863 \n", "5 257500.0 0.209798 0.002925 0.054268 0.078228 \n", "6 307000.0 0.206901 0.002824 0.053263 0.075723 \n", "7 356500.0 0.198712 0.001872 0.046365 0.066366 \n", "8 406000.0 0.197883 0.001872 0.045775 0.064863 \n", "9 455500.0 0.197086 0.001872 0.044761 0.063219 \n", "10 505000.0 0.194449 0.001847 0.044276 0.062503 \n", "11 554500.0 0.194407 0.001847 0.044262 0.062484 \n", "12 604000.0 0.190247 0.001763 0.042238 0.059409 \n", "13 653500.0 0.190209 0.001763 0.042231 0.059384 \n", "14 703000.0 0.190188 0.001763 0.042211 0.059356 \n", "15 752500.0 0.190188 0.001763 0.042211 0.059356 \n", "16 802000.0 0.190188 0.001763 0.042211 0.059356 \n", "17 851500.0 0.190188 0.001763 0.042211 0.059356 \n", "18 901000.0 0.190188 0.001763 0.042211 0.059356 \n", "19 950500.0 0.190188 0.001763 0.042211 0.059356 \n", "\n", " Percentile_30 Percentile_40 Percentile_50 Percentile_60 Percentile_70 \\\n", "0 0.161717 0.178196 0.176423 0.215932 0.237970 \n", "1 0.151972 0.155897 0.162622 0.177366 0.206651 \n", "2 0.122272 0.147910 0.129995 0.164410 0.186139 \n", "3 0.108702 0.139923 0.117246 0.158431 0.174685 \n", "4 0.106850 0.137178 0.115268 0.158413 0.171402 \n", "5 0.102988 0.134296 0.113019 0.155023 0.166766 \n", "6 0.100607 0.133915 0.109701 0.152706 0.162116 \n", "7 0.083248 0.128932 0.090932 0.150245 0.147120 \n", "8 0.081453 0.127613 0.090882 0.149558 0.146809 \n", "9 0.079312 0.126955 0.088818 0.149372 0.146495 \n", "10 0.078314 0.121430 0.087403 0.143454 0.143816 \n", "11 0.078040 0.121430 0.087377 0.143394 0.143775 \n", "12 0.074256 0.119003 0.083055 0.139884 0.137249 \n", "13 0.074229 0.119003 0.082968 0.139812 0.137186 \n", 
"14 0.074195 0.119003 0.082929 0.139812 0.137126 \n", "15 0.074195 0.119003 0.082929 0.139812 0.137126 \n", "16 0.074195 0.119003 0.082929 0.139812 0.137126 \n", "17 0.074195 0.119003 0.082929 0.139812 0.137126 \n", "18 0.074195 0.119003 0.082929 0.139812 0.137126 \n", "19 0.074195 0.119003 0.082929 0.139812 0.137126 \n", "\n", " Percentile_80 Percentile_90 Percentile_99 \n", "0 0.482309 0.600873 0.963266 \n", "1 0.337691 0.589542 0.943703 \n", "2 0.330370 0.575241 0.939937 \n", "3 0.330370 0.571536 0.939715 \n", "4 0.330370 0.569508 0.938817 \n", "5 0.330130 0.559060 0.936782 \n", "6 0.330130 0.555402 0.935711 \n", "7 0.329356 0.552223 0.935394 \n", "8 0.324522 0.550167 0.934938 \n", "9 0.322438 0.550167 0.934894 \n", "10 0.314405 0.549231 0.934661 \n", "11 0.314405 0.549195 0.934661 \n", "12 0.307714 0.547538 0.933731 \n", "13 0.307588 0.547495 0.933731 \n", "14 0.307588 0.547370 0.933731 \n", "15 0.307588 0.547370 0.933731 \n", "16 0.307588 0.547370 0.933731 \n", "17 0.307588 0.547370 0.933731 \n", "18 0.307588 0.547370 0.933731 \n", "19 0.307588 0.547370 0.933731 " ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "par_dep_LIMIT_BAL" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 5. Plotting partial dependence and ICE to validate and explain monotonic behavior\n", "\n", "Overlaying partial dependence onto ICE in a plot is a convenient way to validate and understand both global and local monotonic behavior. Plots of partial dependence curves overlayed onto ICE curves for several percentiles of predictions for `DEFAULT_NEXT_MONTH` are used to validate monotonic behavior, describe the GBM model mechanisms, and to compare the most extreme GBM behavior with the average GBM behavior in the test data. Partial dependence and ICE plots are displayed for the three most important variables in the GBM: `PAY_0`, `LIMIT_BAL`, and `BILL_AMT1`." 
] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "#### Function to plot partial dependence and ICE\n", "\n", "def plot_par_dep_ICE(xs, par_dep_frame):\n", "\n", " \n", " \"\"\" Plots ICE overlaid onto partial dependence for a single variable.\n", " \n", " Args: \n", " xs: Name of variable for which to plot ICE and partial dependence.\n", " par_dep_frame: Pandas DataFrame containing ICE and partial\n", " dependence values.\n", " \n", " \"\"\"\n", " \n", " # initialize figure and axis\n", " fig, ax = plt.subplots()\n", " \n", " # plot ICE curves\n", " par_dep_frame.drop('partial_dependence', axis=1).plot(x=xs, \n", " colormap='gnuplot',\n", " ax=ax)\n", "\n", " # overlay partial dependence, annotate plot\n", " par_dep_frame.plot(title='Partial Dependence and ICE for ' + str(xs),\n", " x=xs, \n", " y='partial_dependence',\n", " style='r-', \n", " linewidth=3, \n", " ax=ax)\n", "\n", " # add legend\n", " _ = plt.legend(bbox_to_anchor=(1.05, 0),\n", " loc=3, \n", " borderaxespad=0.)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Partial dependence and ICE plot for `LIMIT_BAL`\n", "The monotonic prediction behavior displayed in the partial dependence and ICE tables for `LIMIT_BAL` is also visible in this plot. Monotonic decreasing behavior is evident at every percentile of predictions for `DEFAULT_NEXT_MONTH`. Most percentiles of predictions show that the sharpest decreases in probability of default occur when `LIMIT_BAL` increases just slightly from its lowest values in the test set. However, for the customers that are most likely to default according to the GBM model, no increase in `LIMIT_BAL` has a strong impact on probability of default. As mentioned previously, the displayed relationship between credit balance limits and probability of default is not uncommon in credit lending. 
As can be seen from the displayed histogram, above ~$NT 500,000, prediction behavior may have been learned from extremely small samples of data. " ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plot_par_dep_ICE('LIMIT_BAL', par_dep_LIMIT_BAL) # plot partial dependence and ICE for LIMIT_BAL" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "_ = train['LIMIT_BAL'].plot(kind='hist', bins=20, title='Histogram: LIMIT_BAL')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Partial dependence and ICE plot for `PAY_0`\n", "Monotonic increasing prediction behavior for `PAY_0` is displayed for all percentiles of model predictions. Prediction behavior differs across deciles, but it is not abnormal or vastly different from the average prediction behavior represented by the red partial dependence curve. The largest jump in predicted probability appears to occur at `PAY_0 = 2`, or when a customer becomes two months late on their most recent payment. Above `PAY_0 = 3` there are few examples from which the model could learn."
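, "The visual check described for these plots can also be made programmatically. The helper below is a hypothetical sketch (the name `check_monotone` and the toy table are assumptions, not part of the notebook): it tests whether every curve column in a partial dependence and ICE table moves only in the constrained direction along the grid, up to a small tolerance.

```python
import numpy as np
import pandas as pd


def check_monotone(curve_frame, xs, direction, tol=1e-9):
    # direction: 1 for non-decreasing curves, -1 for non-increasing curves
    ordered = curve_frame.sort_values(xs)
    results = {}
    for col in ordered.columns:
        if col == xs:
            continue
        steps = np.diff(ordered[col].to_numpy(dtype=float))
        # a curve passes if no step moves against the constrained direction
        results[col] = bool(np.all(direction * steps >= -tol))
    return pd.Series(results)


# toy table shaped like the partial dependence and ICE frames in this notebook
toy_curves = pd.DataFrame({
    'PAY_0': [-2.0, 0.0, 2.0, 4.0],
    'partial_dependence': [0.15, 0.17, 0.55, 0.60],
    'Percentile_50': [0.10, 0.12, 0.48, 0.48],
    'Percentile_bad': [0.10, 0.30, 0.25, 0.40],  # deliberately violates the constraint
})

result = check_monotone(toy_curves, 'PAY_0', direction=1)
print(result)
```

With the notebook's own frames this might be called as `check_monotone(par_dep_PAY_0, 'PAY_0', direction=1)` or `check_monotone(par_dep_LIMIT_BAL, 'LIMIT_BAL', direction=-1)`. A passing result only covers the rows and grid points that were actually scored."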
] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plot_par_dep_ICE('PAY_0', par_dep_PAY_0) # plot partial dependence and ICE for PAY_0" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "_ = train['PAY_0'].plot(kind='hist', bins=20, title='Histogram: PAY_0')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Partial dependence and ICE plot for `BILL_AMT1`\n", "Monotonic decreasing prediction behavior for `BILL_AMT1` is also displayed for all percentiles. This mild decrease in probability of default as most recent bill amount increases could be related to wealthier, big-spending customers taking on more debt but also being able to pay it off reliably. Also, customers with negative bills are more likely to default, potentially indicating charge-offs are being recorded as negative bills. In a mission-critical situation, this issue would require more debugging. Also predictions below \\$ NT 0 and above \\$ NT 400,000 are based on very little training data." ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "scrolled": true }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plot_par_dep_ICE('BILL_AMT1', par_dep_BILL_AMT1) # plot partial dependence and ICE for BILL_AMT1" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "_ = train['BILL_AMT1'].plot(kind='hist', bins=20, title='Histogram: BILL_AMT1')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 6. 
Generate reason codes using the Shapley method \n", "Now that the monotonic behavior of the GBM has been verified and compared against domain knowledge and reasonable expectations, a method called Shapley explanations will be used to calculate the local variable importance for any one prediction: http://papers.nips.cc/paper/7062-a-unified-approach-to-interpreting-model-predictions. Among additive feature attribution methods, the authors show that Shapley values are the only local variable importance values that are both consistent and locally accurate. (Here consistency means that if a variable is more important than another variable in a given prediction, the more important variable's Shapley value will not be smaller in magnitude than the less important variable's Shapley value.) Crucially, Shapley values, together with the bias term that XGBoost returns as the last column of the `pred_contribs` output, also *always* sum to the XGBoost model's prediction before the logistic link is applied, i.e. in the margin (log-odds) space. When used in a model-specific context for decision tree models, Shapley values are likely the most accurate known local variable importance method available today. In this notebook, XGBoost itself is used to create Shapley values with the `pred_contribs` parameter to `predict()`, but the `shap` package is also available for other types of models: https://github.com/slundberg/shap. \n", "\n", "The numeric Shapley values in each column are an estimate of how much each variable contributed to each prediction. Shapley contributions can indicate how a variable and its values were weighted in any given decision by the model. These values are crucially important for machine learning interpretability and are related to \"local feature importance\", \"reason codes\", or \"turn-down codes.\" The latter phrases are borrowed from credit scoring. Credit lenders in the U.S. must provide reasons for automatically rejecting a credit application. 
Reason codes can be easily extracted from Shapley local variable contribution values by ranking the variables that played the largest role in any given decision."
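, "The ranking step, and the additivity of the contributions, can be sketched on a single row. This is an illustration under stated assumptions, not the notebook's code: the contribution values and the helper name `top_reason_codes` are made up for the example. The only XGBoost behavior relied on is that `pred_contribs=True` returns one column per feature plus a final bias column, and that each row sums to the prediction in log-odds (margin) space for a `binary:logistic` model.

```python
import numpy as np
import pandas as pd


def top_reason_codes(contribs_row, feature_names, k=3):
    # drop the trailing bias term, then rank the per-feature contributions
    contribs = pd.Series(np.asarray(contribs_row[:-1], dtype=float), index=feature_names)
    # the largest positive values push the predicted probability of default up the most
    return contribs[contribs > 0].sort_values(ascending=False).head(k)


# made-up contributions for one customer: three features plus the bias column
features = ['PAY_0', 'LIMIT_BAL', 'BILL_AMT1']
row = np.array([1.20, 0.35, -0.10, -1.25])

# contributions plus bias sum to the margin; the logistic function maps it to a probability
margin = row.sum()
probability = 1.0 / (1.0 + np.exp(-margin))

codes = top_reason_codes(row, features, k=2)
print(codes)
print(round(float(probability), 4))
```

Ranking only the positive contributions matches the adverse-action use case, where the reasons of interest are the ones that raised the predicted risk."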
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To find the index corresponding to a particular row of interest later, the index of the `test` DataFrame is reset to begin at 0 and increase sequentially. Without resetting the index, the `test` DataFrame row indices still correspond to the original raw data from which the test set was sampled." ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "test.reset_index(drop=True, inplace=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Select most risky customer in test data\n", "One person who might be of immediate interest is the customer in the test data who is most likely to default. This customer's row will be selected and local variable importance for the corresponding prediction will be analyzed." ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "decile = 99 # 99th percentile of predictions: the riskiest customer\n", "row = test[test['ID'] == percentile_dict[decile]]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Create a Pandas DataFrame of Shapley values for riskiest customer\n", "The most interesting Shapley values are probably those that push this customer's probability of default higher, i.e. the highest positive Shapley values. Those values are displayed and plotted below." ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [ "# the test index was reset above, so row.index[0] is this customer's row in shap_values\n", "# [:-1] drops the trailing bias term; sort to find largest positive contributions\n", "s_df = pd.DataFrame(shap_values[row.index[0], :][:-1].reshape(-1, 1), columns=['Reason Codes'], index=X)\n", "s_df.sort_values(by='Reason Codes', inplace=True, ascending=False)" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " Reason Codes\n", "PAY_0 1.526435\n", "PAY_5 0.511360\n", "PAY_6 0.328183\n", "PAY_2 0.289542\n", "LIMIT_BAL 0.287896\n", "PAY_4 0.227838\n", "AGE 0.225641\n", "BILL_AMT1 0.197412\n", "PAY_3 0.178575\n", "MARRIAGE 0.171177\n", "PAY_AMT3 0.133252\n", "PAY_AMT1 0.107046\n", "PAY_AMT2 0.082654\n", "BILL_AMT3 0.068008\n", "EDUCATION 0.060491\n", "PAY_AMT4 0.060251\n", "BILL_AMT2 0.057245\n", "BILL_AMT4 0.039342\n", "PAY_AMT5 0.036137\n", "PAY_AMT6 0.025210\n", "BILL_AMT6 0.005666\n", "BILL_AMT5 0.002853\n", "SEX -0.096203" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "s_df" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Plot top local contributions as reason codes" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "_ = s_df[:5].plot(kind='bar', \n", " title='Top Five Reason Codes for a Risky Customer\\n', \n", " legend=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For the customer in the test dataset that the GBM predicts as most likely to default, the most important input variables in the prediction are, in descending order, `PAY_0`, `PAY_5`, `PAY_6`, `PAY_2`, and `LIMIT_BAL`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Display customer in question " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The local contributions for this customer appear reasonable, especially when considering her payment information. Her most recent payment was 3 months late, and her 5th and 6th most recent payments were each 7 months late. Her credit limit was also extremely low, so it's logical that these factors would weigh heavily in the model's prediction of default for this customer."
] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " ID LIMIT_BAL SEX EDUCATION MARRIAGE AGE PAY_0 PAY_2 PAY_3 \\\n", "5399 17757 10000 2 3 1 51 3 2 2 \n", "\n", " PAY_4 PAY_5 PAY_6 BILL_AMT1 BILL_AMT2 BILL_AMT3 BILL_AMT4 \\\n", "5399 7 7 7 2400 2400 2400 2400 \n", "\n", " BILL_AMT5 BILL_AMT6 PAY_AMT1 PAY_AMT2 PAY_AMT3 PAY_AMT4 PAY_AMT5 \\\n", "5399 2400 2400 0 0 0 0 0 \n", "\n", " PAY_AMT6 DEFAULT_NEXT_MONTH \n", "5399 0 1 " ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "row # helps understand reason codes" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To generate reason codes for the model's decision, each locally important variable is paired with its value. If this customer were denied future credit based on this model and data, the top five Shapley-based reason codes for the automated decision would be:\n", "\n", "1. Most recent payment is 3 months delayed.\n", "2. 5th most recent payment is 7 months delayed.\n", "3. 6th most recent payment is 7 months delayed.\n", "4. 2nd most recent payment is 2 months delayed.\n", "5. Credit limit is too low: \\$ NT 10,000.\n", "\n", "(Of course, credit limits are set by the lender and are used to price risk into credit decisions, so using credit limits as reason codes or even in a probability of default model is likely questionable. However, in this small example data set all input columns were used to generate a better model fit. For a slightly more careful treatment of gradient boosting in the context of credit scoring, please see: https://github.com/jphall663/interpretable_machine_learning_with_python/blob/master/dia.ipynb)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Summary\n", "\n", "In this notebook, a highly transparent, nonlinear, monotonic GBM classifier was trained to predict credit card defaults and the monotonic behavior of the classifier was analyzed and validated. 
To do so, Pearson correlation between each input and the target was used to determine the direction for monotonicity constraints for each input variable in the XGBoost classifier. GBM variable importance, partial dependence, and ICE were calculated, plotted, and compared to one another, domain knowledge, and reasonable expectations. Shapley values were then used to explain the model predictions for the single most risky customer in the test set. These techniques should generalize well for many types of business and research problems, enabling you to train a monotonic GBM model and analyze, validate, and explain it to your colleagues, bosses, and potentially, external regulators. " ] } ], "metadata": { "anaconda-cloud": {}, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.4" } }, "nbformat": 4, "nbformat_minor": 2 }