{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## License \n", "\n", "Copyright 2017 - 2020 Patrick Hall and the H2O.ai team\n", "\n", "Licensed under the Apache License, Version 2.0 (the \"License\");\n", "you may not use this file except in compliance with the License.\n", "You may obtain a copy of the License at\n", "\n", " http://www.apache.org/licenses/LICENSE-2.0\n", "\n", "Unless required by applicable law or agreed to in writing, software\n", "distributed under the License is distributed on an \"AS IS\" BASIS,\n", "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "See the License for the specific language governing permissions and\n", "limitations under the License." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**DISCLAIMER:** This notebook is not legal compliance advice." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Increase Transparency and Accountability in Your Machine Learning Project with Python and H2O\n", "#### Explain your complex models with decision tree surrogates, GBM feature importance, and reason codes" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Decision trees and decision tree ensembles are among the most popular machine learning models used in commercial practice. They can train and make predictions on data containing character values and missing values, both of which are common in large commercial data stores. Single decision trees are easily represented as directed graphs, which can drastically increase their interpretability and transparency. Decision tree ensembles, such as random forests and gradient boosting machines (GBMs), can be used to increase the accuracy and stability of single decision tree models, but they are far less interpretable than single trees. 
These characteristics of decision trees will be leveraged here to increase transparency and accountability in complex, nonlinear machine learning models.\n", "\n", "This notebook starts by training a GBM on the UCI credit card default data using the popular open source library h2o. A single decision tree *surrogate* model will then be trained on the original inputs and the predictions from the h2o GBM to create an approximate flow chart for the GBM's global decision-making processes. A technique known as leave-one-covariate/column-out (LOCO) will then be used to generate local explanations for any row-wise prediction made by the GBM model. Finally, local explanations from multiple similar models are ensembled to increase explanation stability. \n", "\n", "**Note**: As of the h2o 3.24 \"Yates\" release, Shapley values are supported in h2o and LOCO is no longer recommended. To see Shapley values for an h2o GBM in action, please see: https://github.com/jphall663/interpretable_machine_learning_with_python/blob/master/dia.ipynb." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Python imports\n", "In general, NumPy and Pandas will be used for data manipulation and h2o will be used for modeling. 
" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# imports\n", "\n", "# h2o Python API with specific classes\n", "import h2o \n", "from h2o.estimators.gbm import H2OGradientBoostingEstimator # for GBM\n", "from h2o.estimators.random_forest import H2ORandomForestEstimator # for single tree\n", "from h2o.backend import H2OLocalServer # for plotting local tree in-notebook\n", "\n", "import numpy as np # array, vector, matrix calculations\n", "import pandas as pd # DataFrame handling\n", "\n", "# system packages for calling external graphviz processes\n", "import os\n", "import re\n", "import subprocess\n", "\n", "# in-notebook display\n", "from IPython.display import Image\n", "from IPython.display import display\n", "%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Start h2o\n", "H2o is both a library and a server. The machine learning algorithms in the library take advantage of the multithreaded and distributed architecture provided by the server to train models very efficiently. The API for the library was imported above in cell 1, but the server still needs to be started." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Checking whether there is an H2O instance running at http://localhost:54321 ..... 
not found.\n", "Attempting to start a local H2O server...\n", " Java Version: openjdk version \"1.8.0_232\"; OpenJDK Runtime Environment (build 1.8.0_232-8u232-b09-0ubuntu1~16.04.1-b09); OpenJDK 64-Bit Server VM (build 25.232-b09, mixed mode)\n", " Starting server from /home/patrickh/workspace/interpretable_machine_learning_with_python/env_iml/lib/python3.6/site-packages/h2o/backend/bin/h2o.jar\n", " Ice root: /tmp/tmpjpy5faje\n", " JVM stdout: /tmp/tmpjpy5faje/h2o_patrickh_started_from_python.out\n", " JVM stderr: /tmp/tmpjpy5faje/h2o_patrickh_started_from_python.err\n", " Server is running at http://127.0.0.1:54321\n", "Connecting to H2O server at http://127.0.0.1:54321 ... successful.\n", "Warning: Your H2O cluster version is too old (4 months and 14 days)! Please download and install the latest version from http://h2o.ai/download/\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
H2O cluster uptime:01 secs
H2O cluster timezone:America/New_York
H2O data parsing timezone:UTC
H2O cluster version:3.26.0.3
H2O cluster version age:4 months and 14 days !!!
H2O cluster name:H2O_from_python_patrickh_6414b6
H2O cluster total nodes:1
H2O cluster free memory:1.778 Gb
H2O cluster total cores:8
H2O cluster allowed cores:8
H2O cluster status:accepting new members, healthy
H2O connection url:http://127.0.0.1:54321
H2O connection proxy:None
H2O internal security:False
H2O API Extensions:Amazon S3, XGBoost, Algos, AutoML, Core V3, Core V4
Python version:3.6.3 final
" ], "text/plain": [ "-------------------------- ---------------------------------------------------\n", "H2O cluster uptime: 01 secs\n", "H2O cluster timezone: America/New_York\n", "H2O data parsing timezone: UTC\n", "H2O cluster version: 3.26.0.3\n", "H2O cluster version age: 4 months and 14 days !!!\n", "H2O cluster name: H2O_from_python_patrickh_6414b6\n", "H2O cluster total nodes: 1\n", "H2O cluster free memory: 1.778 Gb\n", "H2O cluster total cores: 8\n", "H2O cluster allowed cores: 8\n", "H2O cluster status: accepting new members, healthy\n", "H2O connection url: http://127.0.0.1:54321\n", "H2O connection proxy:\n", "H2O internal security: False\n", "H2O API Extensions: Amazon S3, XGBoost, Algos, AutoML, Core V3, Core V4\n", "Python version: 3.6.3 final\n", "-------------------------- ---------------------------------------------------" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "h2o.init(max_mem_size='2G') # start h2o\n", "h2o.remove_all() # remove any existing data structures from h2o memory" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Download, explore, and prepare UCI credit card default data\n", "\n", "UCI credit card default data: https://archive.ics.uci.edu/ml/datasets/default+of+credit+card+clients\n", "\n", "The UCI credit card default data contains demographic and payment information about credit card customers in Taiwan in the year 2005. The data set contains 23 input variables: \n", "\n", "* **`LIMIT_BAL`**: Amount of given credit (NT dollar)\n", "* **`SEX`**: 1 = male; 2 = female\n", "* **`EDUCATION`**: 1 = graduate school; 2 = university; 3 = high school; 4 = others \n", "* **`MARRIAGE`**: 1 = married; 2 = single; 3 = others\n", "* **`AGE`**: Age in years \n", "* **`PAY_0`, `PAY_2` - `PAY_6`**: History of past payment; `PAY_0` = the repayment status in September, 2005; `PAY_2` = the repayment status in August, 2005; ...; `PAY_6` = the repayment status in April, 2005. 
The measurement scale for the repayment status is: -1 = pay duly; 1 = payment delay for one month; 2 = payment delay for two months; ...; 8 = payment delay for eight months; 9 = payment delay for nine months and above. The data also contain the values -2 and 0, which are recoded below as `no consumption` and `use of revolving credit`. \n", "* **`BILL_AMT1` - `BILL_AMT6`**: Amount of bill statement (NT dollar). `BILL_AMT1` = amount of bill statement in September, 2005; `BILL_AMT2` = amount of bill statement in August, 2005; ...; `BILL_AMT6` = amount of bill statement in April, 2005. \n", "* **`PAY_AMT1` - `PAY_AMT6`**: Amount of previous payment (NT dollar). `PAY_AMT1` = amount paid in September, 2005; `PAY_AMT2` = amount paid in August, 2005; ...; `PAY_AMT6` = amount paid in April, 2005. \n", "\n", "These 23 input variables are used to predict the target variable: whether or not a customer defaulted on their credit card bill in late 2005.\n", "\n", "Because h2o accepts both numeric and character inputs, some variables will be recoded into more transparent character values." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Import data and clean\n", "The credit card default data is available as an `.xls` file. Pandas can read `.xls` files directly, so it is used to load the credit card default data and give the prediction target a shorter name: `DEFAULT_NEXT_MONTH`." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# import XLS file\n", "path = 'default_of_credit_card_clients.xls'\n", "data = pd.read_excel(path,\n", " skiprows=1)\n", "\n", "# remove spaces from target column name \n", "data = data.rename(columns={'default payment next month': 'DEFAULT_NEXT_MONTH'}) " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Assign modeling roles\n", "The shorthand name `y` is assigned to the prediction target. `X` is assigned to all other input variables in the credit card default data except the row identifier, `ID`."
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "y = DEFAULT_NEXT_MONTH\n", "X = ['LIMIT_BAL', 'SEX', 'EDUCATION', 'MARRIAGE', 'AGE', 'PAY_0', 'PAY_2', 'PAY_3', 'PAY_4', 'PAY_5', 'PAY_6', 'BILL_AMT1', 'BILL_AMT2', 'BILL_AMT3', 'BILL_AMT4', 'BILL_AMT5', 'BILL_AMT6', 'PAY_AMT1', 'PAY_AMT2', 'PAY_AMT3', 'PAY_AMT4', 'PAY_AMT5', 'PAY_AMT6']\n" ] } ], "source": [ "# assign target and inputs for GBM\n", "y = 'DEFAULT_NEXT_MONTH'\n", "X = [name for name in data.columns if name not in [y, 'ID']]\n", "print('y =', y)\n", "print('X =', X)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Helper function for recoding values in the UCI credit card default data\n", "This simple function maps the original integer values of the input variables found in the dataset to longer, more understandable character string values from the UCI credit card default data dictionary. These character values can be used directly in h2o decision tree models. The function returns the recoded Pandas DataFrame as an h2o object, an H2OFrame, because h2o models cannot run on Pandas DataFrames; they require H2OFrames." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Parse progress: |█████████████████████████████████████████████████████████| 100%\n" ] } ], "source": [ "def recode_cc_data(frame):\n", " \n", " \"\"\" Recodes numeric categorical variables into categorical character variables\n", " with more transparent values. 
\n", " \n", " Args:\n", " frame: Pandas DataFrame version of UCI credit card default data.\n", " \n", " Returns: \n", " H2OFrame with recoded values.\n", " \n", " \"\"\"\n", " \n", " # define recoded values\n", " sex_dict = {1:'male', 2:'female'}\n", " education_dict = {0:'other', 1:'graduate school', 2:'university', 3:'high school', \n", " 4:'other', 5:'other', 6:'other'}\n", " marriage_dict = {0:'other', 1:'married', 2:'single', 3:'divorced'}\n", " pay_dict = {-2:'no consumption', -1:'pay duly', 0:'use of revolving credit', 1:'1 month delay', \n", " 2:'2 month delay', 3:'3 month delay', 4:'4 month delay', 5:'5 month delay', 6:'6 month delay', \n", " 7:'7 month delay', 8:'8 month delay', 9:'9+ month delay'}\n", " \n", " # recode values using Pandas apply() and anonymous function\n", " frame['SEX'] = frame['SEX'].apply(lambda i: sex_dict[i])\n", " frame['EDUCATION'] = frame['EDUCATION'].apply(lambda i: education_dict[i]) \n", " frame['MARRIAGE'] = frame['MARRIAGE'].apply(lambda i: marriage_dict[i]) \n", " for name in frame.columns:\n", " if name in ['PAY_0', 'PAY_2', 'PAY_3', 'PAY_4', 'PAY_5', 'PAY_6']:\n", " frame[name] = frame[name].apply(lambda i: pay_dict[i]) \n", " \n", " return h2o.H2OFrame(frame)\n", "\n", "data = recode_cc_data(data)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Ensure target is handled as a categorical variable" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In h2o, a numeric variable can be treated as numeric or categorical. The target variable `DEFAULT_NEXT_MONTH` takes on values of `0` or `1`. To ensure this numeric variable is treated as a categorical variable, the `asfactor()` function is used to explicitly declare that it is a categorical variable. 
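
Both the recoding helper and `asfactor()` assume the code values are known in advance. The dictionary lookups inside `recode_cc_data()` above raise a bare `KeyError` on the first code that is missing from a dictionary, while Pandas `Series.map()` with a dictionary would instead turn unknown codes into missing values without any warning. A minimal pure-Python sketch, assuming the same `PAY_*` code book as `pay_dict` above, that checks a whole column first and names every unexpected code; `recode` and `PAY_CODES` are hypothetical names used only for illustration:

```python
# Same labels as pay_dict in recode_cc_data() above
PAY_CODES = {-2: 'no consumption', -1: 'pay duly', 0: 'use of revolving credit',
             9: '9+ month delay'}
PAY_CODES.update({k: f'{k} month delay' for k in range(1, 9)})


def recode(values, code_book, column):
    '''Map integer codes to labels; fail once, listing every unknown code.'''
    unknown = sorted(set(values) - set(code_book))
    if unknown:
        raise ValueError(f'{column}: codes {unknown} not in code book')
    return [code_book[v] for v in values]


print(recode([-1, 0, 2], PAY_CODES, 'PAY_0'))
# ['pay duly', 'use of revolving credit', '2 month delay']
```

Failing loudly is usually the safer choice here: silently introduced missing values would interfere with the LOCO calculations in sections 5 and 6, which use missing values as a deliberate perturbation.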
" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "data[y] = data[y].asfactor() " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Display descriptive statistics\n", "The h2o `describe()` function displays a brief description of the credit card default data. For the categorical input variables `SEX`, `EDUCATION`, `MARRIAGE`, and `PAY_0`-`PAY_6`, the new character values created above in cell 5 are visible. Basic descriptive statistics are displayed for numeric inputs such as `LIMIT_BAL`. It is also easy to see that there are no missing values in this dataset, which will be an important consideration for calculating LOCO values in sections 5 and 6." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Rows:30000\n", "Cols:24\n", "\n", "\n" ] }, { "data": { "text/html": [ "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
LIMIT_BAL SEX EDUCATION MARRIAGE AGE PAY_0 PAY_2 PAY_3 PAY_4 PAY_5 PAY_6 BILL_AMT1 BILL_AMT2 BILL_AMT3 BILL_AMT4 BILL_AMT5 BILL_AMT6 PAY_AMT1 PAY_AMT2 PAY_AMT3 PAY_AMT4 PAY_AMT5 PAY_AMT6 DEFAULT_NEXT_MONTH
type int enum enum enum int enum enum enum enum enum enum int int int int int int int int int int int int enum
mins 10000.0 21.0 -165580.0 -69777.0 -157264.0 -170000.0 -81334.0 -339603.0 0.0 0.0 0.0 0.0 0.0 0.0
mean 167484.32266666688 35.48549999999994 51223.3309000000949179.0751666666847013.1547999997143262.9489666666 40311.4009666665338871.760399999915663.580500000014 5921.16350000001 5225.681500000005 4826.076866666661 4799.387633333302 5215.502566666664
maxs 1000000.0 79.0 964511.0 983931.0 1664089.0 891586.0 927171.0 961664.0 873552.0 1684259.0 896040.0 621000.0 426529.0 528666.0
sigma 129747.66156720225 9.21790406809016 73635.8605755295971173.7687825283669349.3874270368164332.8561339164160797.1557702648 59554.1075367457416563.28035402576323040.87040205722617606.96146980311515666.15974403199315278.30567914479317777.465775435332
zeros 0 0 2008 2506 2870 3195 3506 4020 5249 5396 5968 6408 6703 7173
missing0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 20000.0 femaleuniversity married 24.0 2 month delay 2 month delay pay duly pay duly no consumption no consumption 3913.0 3102.0 689.0 0.0 0.0 0.0 0.0 689.0 0.0 0.0 0.0 0.0 1
1 120000.0 femaleuniversity single 26.0 pay duly 2 month delay use of revolving credituse of revolving credituse of revolving credit2 month delay 2682.0 1725.0 2682.0 3272.0 3455.0 3261.0 0.0 1000.0 1000.0 1000.0 0.0 2000.0 1
2 90000.0 femaleuniversity single 34.0 use of revolving credituse of revolving credituse of revolving credituse of revolving credituse of revolving credituse of revolving credit29239.0 14027.0 13559.0 14331.0 14948.0 15549.0 1518.0 1500.0 1000.0 1000.0 1000.0 5000.0 0
3 50000.0 femaleuniversity married 37.0 use of revolving credituse of revolving credituse of revolving credituse of revolving credituse of revolving credituse of revolving credit46990.0 48233.0 49291.0 28314.0 28959.0 29547.0 2000.0 2019.0 1200.0 1100.0 1069.0 1000.0 0
4 50000.0 male university married 57.0 pay duly use of revolving creditpay duly use of revolving credituse of revolving credituse of revolving credit8617.0 5670.0 35835.0 20940.0 19146.0 19131.0 2000.0 36681.0 10000.0 9000.0 689.0 679.0 0
5 50000.0 male graduate schoolsingle 37.0 use of revolving credituse of revolving credituse of revolving credituse of revolving credituse of revolving credituse of revolving credit64400.0 57069.0 57608.0 19394.0 19619.0 20024.0 2500.0 1815.0 657.0 1000.0 1000.0 800.0 0
6 500000.0 male graduate schoolsingle 29.0 use of revolving credituse of revolving credituse of revolving credituse of revolving credituse of revolving credituse of revolving credit367965.0 412023.0 445007.0 542653.0 483003.0 473944.0 55000.0 40000.0 38000.0 20239.0 13750.0 13770.0 0
7 100000.0 femaleuniversity single 23.0 use of revolving creditpay duly pay duly use of revolving credituse of revolving creditpay duly 11876.0 380.0 601.0 221.0 -159.0 567.0 380.0 601.0 0.0 581.0 1687.0 1542.0 0
8 140000.0 femalehigh school married 28.0 use of revolving credituse of revolving credit2 month delay use of revolving credituse of revolving credituse of revolving credit11285.0 14096.0 12108.0 12211.0 11793.0 3719.0 3329.0 0.0 432.0 1000.0 1000.0 1000.0 0
9 20000.0 male high school single 35.0 no consumption no consumption no consumption no consumption pay duly pay duly 0.0 0.0 0.0 0.0 13007.0 13912.0 0.0 0.0 0.0 13007.0 1122.0 0.0 0
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "data[X + [y]].describe()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Train an H2O GBM classifier" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Split data into training and test sets for early stopping\n", "The credit card default data is split into training and test sets to monitor and prevent overfitting. Reproducibility is also an important factor in creating trustworthy models, and randomly splitting datasets can introduce randomness into model predictions and other results. A random seed is used here to ensure the data split is reproducible." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train data rows = 21060, columns = 25\n", "Test data rows = 8940, columns = 25\n" ] } ], "source": [ "# split into training and validation\n", "train, test = data.split_frame([0.7], seed=12345)\n", "\n", "# summarize split\n", "print('Train data rows = %d, columns = %d' % (train.shape[0], train.shape[1]))\n", "print('Test data rows = %d, columns = %d' % (test.shape[0], test.shape[1]))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Train h2o GBM classifier\n", "Many tuning parameters can be specified to train a GBM using h2o. Typically, a grid search would be performed to identify the best parameters for a given modeling task using the `H2OGridSearch` class. For brevity's sake, a previously discovered set of good tuning parameters is specified here. Because gradient boosting methods typically resample training data, an additional random seed is also specified for the h2o GBM using the `seed` parameter to create reproducible predictions, error rates, and variable importance values. 
To avoid overfitting, the `stopping_rounds` parameter is used to stop the training process after the test error fails to decrease for 5 iterations.\n", "\n", "The `balance_classes` parameter ensures the positive and negative classes of the target variable are seen by the model in equal proportions during training. For unbalanced data, this can be very important for the LOCO calculations in sections 5 and 6. From experiments across several data sets, explanations generated by LOCO for rows with a majority class label for the target variable (i.e., 0) are more likely to match those generated by another popular explanatory technique, LIME, when the target class is rebalanced during training. `balance_classes` is commented out below because the row explained in this notebook has a minority class label (i.e., 1)." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "gbm Model Build progress: |███████████████████████████████████████████████| 100%\n", "GBM Test AUC = 0.78\n" ] } ], "source": [ "# initialize GBM model\n", "model = H2OGradientBoostingEstimator(ntrees=150, # maximum 150 trees in GBM\n", " max_depth=4, # trees can have maximum depth of 4\n", " sample_rate=0.9, # use 90% of rows in each iteration (tree)\n", " col_sample_rate=0.9, # use 90% of variables in each iteration (tree)\n", " #balance_classes=True, # sample to balance 0/1 distribution of target - can help LOCO\n", " stopping_rounds=5, # stop if validation error does not decrease for 5 iterations (trees)\n", " score_tree_interval=1, # for reproducibility, set higher for bigger data\n", " seed=12345) # for reproducibility\n", "\n", "# train a GBM model\n", "model.train(y=y, x=X, training_frame=train, validation_frame=test)\n", "\n", "# print AUC\n", "print('GBM Test AUC = %.2f' % model.auc(valid=True))\n", "\n", "# uncomment to see model details\n", "# print(model) " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ 
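"Early stopping, as configured above with `stopping_rounds=5` and `score_tree_interval=1`, checks the validation error after each new tree. The h2o documentation describes its rule in terms of a moving average of the stopping metric together with a relative tolerance (`stopping_tolerance`); the sketch below deliberately uses a simpler patience rule, so it illustrates the idea rather than reproducing h2o's exact criterion. `stop_iteration` is a hypothetical helper and the error values are made up:

```python
def stop_iteration(errors, rounds=5):
    '''Return the 1-based iteration at which training stops, or None.

    errors[i] is the validation error after iteration i + 1. Training stops
    once the best error seen so far has not improved for `rounds`
    consecutive iterations (a simplified patience rule, not h2o's exact one).
    '''
    best = float('inf')
    since_best = 0
    for i, err in enumerate(errors, start=1):
        if err < best:
            best, since_best = err, 0
        else:
            since_best += 1
            if since_best >= rounds:
                return i
    return None


errors = [0.50, 0.45, 0.44, 0.445, 0.446, 0.447, 0.448, 0.449]
print(stop_iteration(errors))  # 8: no improvement after iteration 3
```

", 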
"#### Display variable importance\n", "During training, the h2o GBM aggregates the improvement in error caused by each split in each decision tree across all the decision trees in the ensemble classifier. These values are attributed to the input variable used in each split and give an indication of the contribution each input variable makes toward the model's predictions. The variable importance ranking should be consistent with human domain knowledge and reasonable expectations. In this case, a customer's most recent payment behavior, `PAY_0`, is by far the most important variable, followed by their second and third most recent payment behaviors, `PAY_2` and `PAY_3`. This result is well aligned with business practices in credit lending: people who miss their most recent payments are likely to default soon." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "model.varimp_plot()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. Train a decision tree surrogate model to describe GBM" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A surrogate model is a simple model that is used to explain a complex model. One of the original references for surrogate models is available here: https://papers.nips.cc/paper/1152-extracting-tree-structured-representations-of-trained-networks.pdf. In this example, a single decision tree will be trained on the original inputs and predictions of the h2o GBM model, and the tree will be visualized using special functionality in h2o and GraphViz. 
The variable importance, interactions, and decision paths displayed in the directed graph of the trained decision tree surrogate model are then assumed to be indicative of the internal mechanisms of the more complex GBM model, creating an approximate, overall flowchart for the GBM. There are few mathematical guarantees that the simple surrogate model is highly representative of the more complex GBM, but a recent preprint article has put forward ideas on strengthening the theoretical relationship between surrogate models and more complex models: https://arxiv.org/pdf/1705.08504.pdf. Since surrogate models alone do not guarantee accurate transparency, they will be used along with GBM variable importance and LOCO to build a cohesive narrative about the mechanisms within the GBM. **Because many currently available explanatory techniques are approximate, it is recommended that users employ several different explanatory techniques and trust only results that are consistent across techniques. Also, as of h2o 3.24, Shapley values are supported for h2o GBM. Use them instead of LOCO for any high-stakes application.**" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Create dataset for surrogate model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To train a surrogate model, the predictions and original inputs of the complex model to be explained need to be in the same dataset. The test data is used here to see how the model behaves on holdout data, which should better reflect its behavior on new data than a surrogate trained on the training inputs and predictions would."
] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "gbm prediction progress: |████████████████████████████████████████████████| 100%\n" ] } ], "source": [ "# cbind predictions to test frame\n", "# give them a nice name\n", "yhat = 'p_DEFAULT_NEXT_MONTH'\n", "preds1 = test['ID'].cbind(model.predict(test).drop(['predict', 'p0']))\n", "preds1.columns = ['ID', yhat]\n", "test_yhat = test.cbind(preds1[yhat])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Train single h2o decision tree\n", "A single decision tree is trained on the test inputs and predictions. To simulate a single decision tree in h2o, the `H2ORandomForestEstimator` class is used, but only one tree is trained instead of a forest of decision trees. Setting the `mtries` parameter to `-2` tells the `H2ORandomForestEstimator` to consider all variables in all splits of the tree, instead of a random subset of columns. It is also recommended to set a random seed for reproducibility and to set `max_depth` to a small number, say less than 6, so that the surrogate model does not become overly complex and hard to explain. Once the tree is trained, a MOJO (Model Object, Optimized) representation of the tree is saved. H2o provides a way to visualize the trained tree in detail using the MOJO and GraphViz."
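, "\n", "\n", "The surrogate idea can be illustrated without h2o. The following is a minimal pure-Python sketch, assuming a single numeric input and made-up GBM predictions: a depth-1 regression tree (a stump) is fit to the complex model's predictions rather than to the true labels, and its fidelity to the GBM is measured with R². This is not h2o's DRF algorithm, and `fit_stump`, `predict_stump`, and `r_squared` are hypothetical helpers:

```python
def fit_stump(xs, ys):
    '''Fit a one-split regression tree on a single numeric input.

    Returns (threshold, left_mean, right_mean) for the split that minimizes
    the squared error of predicting each side by its mean.
    '''
    pairs = sorted(zip(xs, ys))
    best = None
    for i in range(1, len(pairs)):
        if pairs[i - 1][0] == pairs[i][0]:
            continue  # equal x values cannot be separated by a threshold
        threshold = (pairs[i - 1][0] + pairs[i][0]) / 2
        left = [y for _, y in pairs[:i]]
        right = [y for _, y in pairs[i:]]
        left_mean = sum(left) / len(left)
        right_mean = sum(right) / len(right)
        sse = (sum((y - left_mean) ** 2 for y in left)
               + sum((y - right_mean) ** 2 for y in right))
        if best is None or sse < best[0]:
            best = (sse, threshold, left_mean, right_mean)
    if best is None:
        raise ValueError('need at least two distinct x values')
    return best[1:]


def predict_stump(stump, x):
    threshold, left_mean, right_mean = stump
    return left_mean if x < threshold else right_mean


def r_squared(targets, fitted):
    '''Fidelity: share of the variance in the GBM predictions the surrogate explains.'''
    mean = sum(targets) / len(targets)
    ss_tot = sum((t - mean) ** 2 for t in targets)
    ss_res = sum((t - f) ** 2 for t, f in zip(targets, fitted))
    return 1 - ss_res / ss_tot


# Made-up data: PAY_0 delay codes and the GBM's predicted default probabilities
pay_0 = [0, 0, 1, 2, 2, 3]
gbm_preds = [0.10, 0.12, 0.30, 0.55, 0.60, 0.70]
stump = fit_stump(pay_0, gbm_preds)
fidelity = r_squared(gbm_preds, [predict_stump(stump, x) for x in pay_0])
print(stump[0], round(fidelity, 2))  # 1.5 0.89
```

An R² close to 1 means the surrogate reproduces the GBM's predictions closely on this data; a low value is a signal not to read the surrogate's splits as a description of the GBM.
"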
] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "drf Model Build progress: |███████████████████████████████████████████████| 100%\n", "Generated MOJO path:\n", " /home/patrickh/workspace/interpretable_machine_learning_with_python/dt_surrogate_mojo.zip\n" ] } ], "source": [ "model_id = 'dt_surrogate_mojo' # gives MOJO artifact a recognizable name\n", "\n", "# initialize single tree surrogate model\n", "surrogate = H2ORandomForestEstimator(ntrees=1, # use only one tree\n", " sample_rate=1, # use all rows in that tree\n", " mtries=-2, # use all columns in that tree\n", " max_depth=3, # shallow trees are easier to understand\n", " seed=12345, # random seed for reproducibility\n", " model_id=model_id) # gives MOJO artifact a recognizable name\n", "\n", "# train single tree surrogate model\n", "surrogate.train(x=X, y=yhat, training_frame=test_yhat)\n", "\n", "# persist MOJO (compiled representation of trained model)\n", "# from which to generate plot of surrogate\n", "mojo_path = surrogate.download_mojo(path='.')\n", "print('Generated MOJO path:\\n', mojo_path)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Create GraphViz dot file\n", "GraphViz is an open source graph visualization tool. It is freely available from http://www.graphviz.org/. To plot the trained decision tree surrogate model, a special h2o class, `PrintMojo`, is executed against the MOJO to create a GraphViz dot file representation of the tree."
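, "\n", "\n", "To give a sense of what a GraphViz dot file contains, the sketch below builds DOT text for a one-split tree by hand. It only illustrates the DOT format under simple assumptions (plain ASCII labels, a hypothetical `stump_to_dot` helper, made-up split values); it does not reproduce the richer output that `PrintMojo` generates from the MOJO:

```python
import json


def stump_to_dot(feature, threshold, left_value, right_value, title):
    '''Return GraphViz DOT text for a one-split tree (illustrative only).'''
    # json.dumps wraps each label in double quotes, which DOT accepts as a
    # quoted ID; this is adequate for plain ASCII labels without special escapes.
    q = json.dumps
    statements = [
        'label=' + q(title) + ';',
        'root [shape=box, label=' + q(feature + ' < ' + str(threshold)) + '];',
        'left [label=' + q('mean prediction ' + str(left_value)) + '];',
        'right [label=' + q('mean prediction ' + str(right_value)) + '];',
        'root -> left [label=yes];',
        'root -> right [label=no];',
    ]
    return 'digraph surrogate { ' + ' '.join(statements) + ' }'


print(stump_to_dot('PAY_0', 1.5, 0.17, 0.62, 'Demo Surrogate'))
```

Saved to a `.gv` file, this text could be rendered with the same `dot -Tpng` command used below.
"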
] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Discovered H2O jar path:\n", " /home/patrickh/workspace/interpretable_machine_learning_with_python/env_iml/lib/python3.6/site-packages/h2o/backend/bin/h2o.jar\n", "\n", "Calling external process ...\n", "java -cp /home/patrickh/workspace/interpretable_machine_learning_with_python/env_iml/lib/python3.6/site-packages/h2o/backend/bin/h2o.jar hex.genmodel.tools.PrintMojo --tree 0 -i /home/patrickh/workspace/interpretable_machine_learning_with_python/dt_surrogate_mojo.zip -o dt_surrogate_mojo.gv --title Credit Card Default Decision Tree Surrogate\n" ] } ], "source": [ "# title for plot\n", "title = 'Credit Card Default Decision Tree Surrogate' \n", "\n", "# locate h2o jar\n", "hs = H2OLocalServer()\n", "h2o_jar_path = hs._find_jar()\n", "print('Discovered H2O jar path:\\n', h2o_jar_path)\n", "\n", "# construct command line call to generate graphviz version of \n", "# surrogate tree see for more information: \n", "# http://docs.h2o.ai/h2o/latest-stable/h2o-genmodel/javadoc/index.html\n", "gv_file_name = model_id + '.gv'\n", "gv_args = str('-cp ' + h2o_jar_path +\n", " ' hex.genmodel.tools.PrintMojo --tree 0 -i '\n", " + mojo_path + ' -o').split()\n", "gv_args.insert(0, 'java')\n", "gv_args.append(gv_file_name)\n", "if title is not None:\n", " gv_args = gv_args + ['--title', title]\n", " \n", "# call \n", "print()\n", "print('Calling external process ...')\n", "print(' '.join(gv_args))\n", "# if the line below is failing for you, try instead:\n", "# _ = subprocess.call(gv_args, shell=True) \n", "_ = subprocess.call(gv_args)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Create PNG from GraphViz dot file and display\n", "Then a GraphViz command line tool is used to create a static PNG image from the dot file ... 
" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Calling external process ...\n", "dot -Tpng dt_surrogate_mojo.gv -o dt_surrogate_mojo.png\n" ] } ], "source": [ "# construct call to generate PNG from \n", "# graphviz representation of the tree\n", "png_file_name = model_id + '.png'\n", "png_args = str('dot -Tpng ' + gv_file_name + ' -o ' + png_file_name)\n", "png_args = png_args.split()\n", "\n", "# call\n", "print('Calling external process ...')\n", "print(' '.join(png_args))\n", "# if the line below is failing for you, try instead:\n", "# _ = subprocess.call(png_args, shell=True) \n", "_ = subprocess.call(png_args)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Display surrogate decision tree in notebook\n", "... and the image is displayed in the notebook." ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# display in-notebook\n", "display(Image(png_file_name))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4. Analyze surrogate model and compare to global GBM variable importance\n", "\n", "The displayed tree can be compared with the global GBM variable importance. A simple heuristic rule for variable importance in a decision tree relates to the depth and frequency at which a variable is split on in a tree: variables used higher in the tree and more frequently in the tree are more important. Most of the variables pictured in this tree also appear as highly important in the GBM variable importance plot. In both cases, `PAY_0` appears crucially important, with other payment behavior variables following close behind. The surrogate decision tree enables users to understand and confirm not only which input variables are important, but also how their values contribute to model decisions. 
For instance, to fall into the leaf node with the lowest probability of default in the surrogate decision tree, a customer must make their first and second payments in a timely fashion and then pay more than 1515.5 New Taiwan dollars for their fifth payment. Conversely, customers who miss their first, fifth, and third payments fall into the leaf node of the surrogate decision tree with the highest probability of default. It is also imperative to compare these results to domain knowledge and reasonable expectations. In this case, the global explanatory methods applied thus far tell a consistent and reasonable story about the GBM's behavior. If this were not so, steps should be taken to either reconcile or remove inconsistencies and unreasonable prediction behavior." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 5. Generate reason codes using the LOCO method \n", "\n", "Now that a solid understanding of global model behavior has been attained, local behavior for any given row of data and prediction can be analyzed and validated using LOCO. The LOCO method presented here is adapted from *Distribution-Free Predictive Inference for Regression* by Jing Lei et al., http://www.stat.cmu.edu/~ryantibs/papers/conformal.pdf. Here the local contribution of an input variable to a prediction for a single row of data is estimated by rescoring the GBM on that row once for each input variable, each time leaving out one input variable (i.e., a \"covariate\") by setting it to missing, and then subtracting the new score from the original score. By default, h2o scores missing data in decision trees by running them through the majority decision path. This means LOCO will be a numeric measure of how different the local contribution of an input variable is from the most common local contribution of that variable in the model. This variant of LOCO differs from the original method, in which one input variable is dropped from the model and the model is retrained without that variable. 
For nonlinear models, nonlinear dependencies can allow variables to nearly completely replace one another when a variable is dropped and the model is retrained. Hence, the approach of injecting missing values is used to estimate local contributions of input variables for nonlinear models here, as opposed to dropping a variable and retraining the model." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Calculate LOCO reason values for each row of the test set\n", "To implement LOCO, GBM model predictions are calculated once for the test data and then again for each input variable, setting the entire input variable column to missing. Once the prediction without the variable is found for every row of data in the test set, that column vector of predictions on corrupted data can be subtracted from the column vector of predictions on the original, non-corrupted data to estimate the local contribution of that variable for each prediction in the test data. For better local accuracy and explainability, LOCO contributions are scaled such that, for each row, the contributions plus the overall average of `DEFAULT_NEXT_MONTH` always sum to the model prediction.
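The rescore-and-subtract step and the additive rescaling can be illustrated without h2o. The following is a minimal sketch under stated assumptions: `toy_model`, its coefficients, and the `DEFAULTS` used for missing inputs are hypothetical and invented for illustration, and `None` stands in for a missing value that the toy model routes to a fixed default, loosely mimicking how a tree sends missing data down a default branch. It is not the notebook's h2o code.

```python
import math

# Hypothetical stand-in for the GBM (not the notebook's model).
# A missing input (None) falls back to a fixed default value, loosely
# mimicking a tree that routes missing data down a default branch.
DEFAULTS = {'PAY_0': 0.0, 'LIMIT_BAL': 50000.0, 'AGE': 35.0}


def toy_model(row):
    '''Return a probability of default for a dict of named inputs.'''
    x = {name: DEFAULTS[name] if row[name] is None else row[name]
         for name in DEFAULTS}
    z = -1.5 + 0.8 * x['PAY_0'] - 0.00001 * x['LIMIT_BAL'] + 0.01 * (x['AGE'] - 35.0)
    return 1.0 / (1.0 + math.exp(-z))


def loco_contributions(model, row, features):
    '''Original score minus the score with one input set to missing.'''
    original = model(row)
    contribs = {}
    for name in features:
        corrupted = dict(row)
        corrupted[name] = None  # leave one covariate out by making it missing
        contribs[name] = original - model(corrupted)
    return original, contribs


def scale_contributions(pred, contribs, y_0):
    '''Rescale contributions so that y_0 plus their sum equals pred.'''
    total = sum(contribs.values())
    scaler = (pred - y_0) / total  # undefined when total == 0
    return {name: value * scaler for name, value in contribs.items()}


row = {'PAY_0': 3.0, 'LIMIT_BAL': 20000.0, 'AGE': 59.0}
y_0 = 0.22  # hypothetical overall mean of DEFAULT_NEXT_MONTH
pred, raw = loco_contributions(toy_model, row, list(row))
scaled = scale_contributions(pred, raw, y_0)
reason_codes = sorted(scaled, key=scaled.get, reverse=True)
print(reason_codes)
```

Because the rescaling divides by the sum of the raw LOCO values, a row whose raw contributions sum to nearly zero can receive very large scaled values; the notebook's vectorized expression, `(pred_ - y_0) / preds2_pd[X].sum(axis=1)`, shares this property.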
] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Calculating LOCO contributions ...\n", "LOCO Progress: LIMIT_BAL (1/23) ...\n", "LOCO Progress: SEX (2/23) ...\n", "LOCO Progress: EDUCATION (3/23) ...\n", "LOCO Progress: MARRIAGE (4/23) ...\n", "LOCO Progress: AGE (5/23) ...\n", "LOCO Progress: PAY_0 (6/23) ...\n", "LOCO Progress: PAY_2 (7/23) ...\n", "LOCO Progress: PAY_3 (8/23) ...\n", "LOCO Progress: PAY_4 (9/23) ...\n", "LOCO Progress: PAY_5 (10/23) ...\n", "LOCO Progress: PAY_6 (11/23) ...\n", "LOCO Progress: BILL_AMT1 (12/23) ...\n", "LOCO Progress: BILL_AMT2 (13/23) ...\n", "LOCO Progress: BILL_AMT3 (14/23) ...\n", "LOCO Progress: BILL_AMT4 (15/23) ...\n", "LOCO Progress: BILL_AMT5 (16/23) ...\n", "LOCO Progress: BILL_AMT6 (17/23) ...\n", "LOCO Progress: PAY_AMT1 (18/23) ...\n", "LOCO Progress: PAY_AMT2 (19/23) ...\n", "LOCO Progress: PAY_AMT3 (20/23) ...\n", "LOCO Progress: PAY_AMT4 (21/23) ...\n", "LOCO Progress: PAY_AMT5 (22/23) ...\n", "LOCO Progress: PAY_AMT6 (23/23) ...\n", "\n", "Scaling contributions ...\n", "Done.\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
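<table><thead><tr><th></th><th>ID</th><th>p_DEFAULT_NEXT_MONTH</th><th>LIMIT_BAL</th><th>SEX</th><th>EDUCATION</th><th>MARRIAGE</th><th>AGE</th><th>PAY_0</th><th>PAY_2</th><th>PAY_3</th><th>...</th><th>BILL_AMT3</th><th>BILL_AMT4</th><th>BILL_AMT5</th><th>BILL_AMT6</th><th>PAY_AMT1</th><th>PAY_AMT2</th><th>PAY_AMT3</th><th>PAY_AMT4</th><th>PAY_AMT5</th><th>PAY_AMT6</th></tr></thead><tbody>
<tr><th>0</th><td>4</td><td>0.144991</td><td>-0.079758</td><td>-0.000000</td><td>-0.00000</td><td>-0.009143</td><td>-0.001728</td><td>-0.000000</td><td>-0.000000</td><td>-0.000000</td><td>...</td><td>-0.002113</td><td>0.005768</td><td>-0.000000</td><td>-0.000000</td><td>-0.000000</td><td>-0.000000</td><td>-0.000000</td><td>-0.000000</td><td>-0.000000</td><td>-0.005340</td></tr>
<tr><th>1</th><td>8</td><td>0.128193</td><td>-0.020007</td><td>-0.000000</td><td>-0.00000</td><td>-0.000000</td><td>-0.000000</td><td>0.011403</td><td>-0.000000</td><td>0.045036</td><td>...</td><td>0.010062</td><td>-0.000000</td><td>-0.059467</td><td>0.015406</td><td>-0.028304</td><td>-0.036129</td><td>-0.067713</td><td>-0.000000</td><td>0.057314</td><td>-0.000000</td></tr>
<tr><th>2</th><td>10</td><td>0.179911</td><td>-0.024094</td><td>-0.002945</td><td>-0.00000</td><td>-0.000000</td><td>-0.009850</td><td>0.003778</td><td>-0.000000</td><td>0.005310</td><td>...</td><td>-0.000000</td><td>-0.000000</td><td>-0.000000</td><td>-0.000000</td><td>-0.005015</td><td>-0.003077</td><td>-0.009657</td><td>0.019491</td><td>-0.000000</td><td>-0.004973</td></tr>
<tr><th>3</th><td>16</td><td>0.325205</td><td>0.012617</td><td>0.000000</td><td>0.00000</td><td>0.004043</td><td>0.000000</td><td>0.027090</td><td>0.055267</td><td>0.000000</td><td>...</td><td>0.000000</td><td>-0.003265</td><td>0.000000</td><td>0.000000</td><td>0.005403</td><td>0.000000</td><td>0.000000</td><td>0.000000</td><td>0.000000</td><td>0.000000</td></tr>
<tr><th>4</th><td>17</td><td>0.408821</td><td>0.033188</td><td>0.000000</td><td>0.00186</td><td>0.000000</td><td>0.000000</td><td>0.000000</td><td>0.000000</td><td>0.051523</td><td>...</td><td>-0.001966</td><td>0.000000</td><td>0.000000</td><td>0.000000</td><td>0.000000</td><td>0.000000</td><td>0.000000</td><td>0.001247</td><td>0.000000</td><td>0.008397</td></tr></tbody></table>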
\n", "

5 rows × 25 columns

\n", "
" ], "text/plain": [ " ID p_DEFAULT_NEXT_MONTH LIMIT_BAL SEX EDUCATION MARRIAGE \\\n", "0 4 0.144991 -0.079758 -0.000000 -0.00000 -0.009143 \n", "1 8 0.128193 -0.020007 -0.000000 -0.00000 -0.000000 \n", "2 10 0.179911 -0.024094 -0.002945 -0.00000 -0.000000 \n", "3 16 0.325205 0.012617 0.000000 0.00000 0.004043 \n", "4 17 0.408821 0.033188 0.000000 0.00186 0.000000 \n", "\n", " AGE PAY_0 PAY_2 PAY_3 ... BILL_AMT3 BILL_AMT4 \\\n", "0 -0.001728 -0.000000 -0.000000 -0.000000 ... -0.002113 0.005768 \n", "1 -0.000000 0.011403 -0.000000 0.045036 ... 0.010062 -0.000000 \n", "2 -0.009850 0.003778 -0.000000 0.005310 ... -0.000000 -0.000000 \n", "3 0.000000 0.027090 0.055267 0.000000 ... 0.000000 -0.003265 \n", "4 0.000000 0.000000 0.000000 0.051523 ... -0.001966 0.000000 \n", "\n", " BILL_AMT5 BILL_AMT6 PAY_AMT1 PAY_AMT2 PAY_AMT3 PAY_AMT4 PAY_AMT5 \\\n", "0 -0.000000 -0.000000 -0.000000 -0.000000 -0.000000 -0.000000 -0.000000 \n", "1 -0.059467 0.015406 -0.028304 -0.036129 -0.067713 -0.000000 0.057314 \n", "2 -0.000000 -0.000000 -0.005015 -0.003077 -0.009657 0.019491 -0.000000 \n", "3 0.000000 0.000000 0.005403 0.000000 0.000000 0.000000 0.000000 \n", "4 0.000000 0.000000 0.000000 0.000000 0.000000 0.001247 0.000000 \n", "\n", " PAY_AMT6 \n", "0 -0.005340 \n", "1 -0.000000 \n", "2 -0.004973 \n", "3 0.000000 \n", "4 0.008397 \n", "\n", "[5 rows x 25 columns]" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "h2o.no_progress() # turn off h2o progress bars\n", "\n", "# create set of original predictions and row ID\n", "preds2 = test['ID'].cbind(model.predict(test).drop(['predict', 'p0']))\n", "preds2.columns = ['ID', yhat]\n", "\n", "# calculate LOCO for each variable\n", "print('Calculating LOCO contributions ...')\n", "for k, i in enumerate(X):\n", "\n", " # rescore (no retraining) with x_i set to missing\n", " test_loco = h2o.deep_copy(test, 'test_loco')\n", " test_loco[i] = np.nan\n", " preds_loco = 
model.predict(test_loco).drop(['predict','p0'])\n", " \n", " # create a new, named column for the LOCO prediction\n", " preds_loco.columns = [i]\n", " preds2 = preds2.cbind(preds_loco)\n", " \n", " # subtract the LOCO prediction from the original prediction\n", " preds2[i] = preds2[yhat] - preds2[i]\n", " \n", " # update progress\n", " print('LOCO Progress: ' + i + ' (' + str(k+1) + '/' + str(len(X)) + ') ...')\n", " \n", "# scale contributions to sum to yhat - y_0\n", "print('\\nScaling contributions ...')\n", "\n", "y_0 = test[y].mean()[0]\n", "preds2_pd = preds2.as_data_frame()\n", "pred_ = preds2_pd[yhat]\n", "scaler = (pred_ - y_0) / preds2_pd[X].sum(axis=1)\n", "preds2_pd[X] = preds2_pd[X].multiply(scaler, axis=0) \n", "\n", "print('Done.') \n", "\n", "preds2_pd.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The numeric LOCO values in each column are an estimate of how much each variable contributed to each prediction. LOCO can indicate how a variable and its values were weighted in any given decision by the model. These values are crucially important for machine learning interpretability and are related to \"local feature importance\", \"reason codes\", or \"turn-down codes.\" The latter phrases are borrowed from credit scoring. Credit lenders in the U.S. must provide reasons for automatically rejecting a credit application. Reason codes can be easily extracted from LOCO local variable contribution values by simply ranking the variables that played the largest role in any given decision." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Helper function for finding percentile indices\n", "The function below finds and returns the row indices for the minimum, the maximum, and the deciles of one column in terms of another, in this case the model predictions (`p_DEFAULT_NEXT_MONTH`) and the row identifier (`ID`), respectively. These indices are used as a starting point for finding potentially interesting predictions. 
Outlying predictions found through residual analysis are another group of potentially interesting local predictions to analyze with LOCO." ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{0: 28716,\n", " 99: 29116,\n", " 10: 8942,\n", " 20: 28257,\n", " 30: 4074,\n", " 40: 13411,\n", " 50: 16633,\n", " 60: 2402,\n", " 70: 19769,\n", " 80: 25069,\n", " 90: 21372}" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def get_percentile_dict(yhat, id_, frame):\n", "\n", " \"\"\" Returns the minimum, maximum, and deciles of a column, yhat, \n", " as the corresponding values of another column, id_.\n", " \n", " Args:\n", " yhat: Column in which to find percentiles.\n", " id_: Id column that stores indices for percentiles of yhat.\n", " frame: Pandas DataFrame containing yhat and id_. \n", " \n", " Returns:\n", " Dictionary of percentile values and index column values.\n", " \n", " \"\"\"\n", " \n", " # copy the input frame and sort by yhat\n", " sort_df = frame.copy(deep=True)\n", " sort_df.sort_values(yhat, inplace=True)\n", " sort_df.reset_index(inplace=True)\n", "\n", " # find top and bottom percentiles\n", " percentiles_dict = {}\n", " percentiles_dict[0] = sort_df.loc[0, id_]\n", " percentiles_dict[99] = sort_df.loc[sort_df.shape[0]-1, id_]\n", " inc = sort_df.shape[0]//10\n", " \n", " # find 10th-90th percentiles \n", " for i in range(1, 10):\n", " percentiles_dict[i * 10] = sort_df.loc[i * inc, id_]\n", "\n", " return percentiles_dict\n", "\n", "# display percentiles dictionary\n", "# ID values for rows\n", "# from lowest prediction \n", "# to highest prediction\n", "percentile_dict = get_percentile_dict(yhat, 'ID', preds2_pd)\n", "percentile_dict" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Plot some reason codes for a risky customer\n", "Investigating customers with very high or low predicted probabilities to determine if their local explanations justify 
their extreme predictions is typically a productive exercise in boundary testing, model debugging, and validation. Reason codes are generated below in cell 18 for the customer with the highest probability of default in the test dataset, but LOCO can create local explanations for any or all rows in the training or test datasets, and on new data." ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# select single customer from the Pandas LOCO frame\n", "# and drop prediction and row ID\n", "\n", "risky_loco = preds2_pd[preds2_pd['ID'] == int(percentile_dict[99])].drop(['ID', yhat], axis=1) \n", "\n", "# transpose into a column vector and sort by that column,\n", "# whatever the customer's Pandas index label is\n", "risky_loco = risky_loco.T\n", "risky_loco = risky_loco.sort_values(by=risky_loco.columns[0], ascending=False)[:5]\n", "\n", "# plot\n", "_ = risky_loco.plot(kind='bar', \n", " title='Top Five Reason Codes for a Risky Customer\\n', \n", " legend=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For the customer in the test dataset that the GBM predicts as most likely to default, the most important input variables in the prediction are, in descending order, `PAY_0`, `PAY_6`, `PAY_3`, `PAY_5`, and `AGE`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Display customer in question " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The local contributions for this customer appear reasonable, especially when considering her payment information. Her most recent payment was 3 months late and her sixth most recent payment was 4 months late, so it is logical that these weigh heavily in the model's prediction of default for this customer." ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
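<table><tr><th>ID</th><th>LIMIT_BAL</th><th>SEX</th><th>EDUCATION</th><th>MARRIAGE</th><th>AGE</th><th>PAY_0</th><th>PAY_2</th><th>PAY_3</th><th>PAY_4</th><th>PAY_5</th><th>PAY_6</th><th>BILL_AMT1</th><th>BILL_AMT2</th><th>BILL_AMT3</th><th>BILL_AMT4</th><th>BILL_AMT5</th><th>BILL_AMT6</th><th>PAY_AMT1</th><th>PAY_AMT2</th><th>PAY_AMT3</th><th>PAY_AMT4</th><th>PAY_AMT5</th><th>PAY_AMT6</th><th>DEFAULT_NEXT_MONTH</th><th>p_DEFAULT_NEXT_MONTH</th></tr>
<tr><td>29116</td><td>20000</td><td>female</td><td>university</td><td>married</td><td>59</td><td>3 month delay</td><td>2 month delay</td><td>3 month delay</td><td>2 month delay</td><td>2 month delay</td><td>4 month delay</td><td>8803</td><td>11137</td><td>10672</td><td>11201</td><td>12721</td><td>11946</td><td>2800</td><td>0</td><td>1000</td><td>2000</td><td>0</td><td>0</td><td>1</td><td>0.895285</td></tr></table>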
" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test_yhat[test_yhat['ID'] == int(percentile_dict[99]), :] # helps understand reason codes" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To generate reason codes for the model's decision, the locally important variable and its value are used together. If this customer were denied future credit based on this model and data, the top five LOCO-based reason codes for the automated decision would be:\n", "\n", "1. Most recent payment is 3 months delayed.\n", "2. 6th most recent payment is 4 months delayed.\n", "3. 3rd most recent payment is 3 months delayed.\n", "4. 5th most recent payment is 2 months delayed.\n", "5. Customer age is 59. \n", "\n", "(Of course, in many places, variables like `AGE` and `SEX` cannot, and should not, be used in credit lending or other high-stakes decisions. For a slightly more careful treatment of GBM in a fair lending context, see: https://github.com/jphall663/interpretable_machine_learning_with_python/blob/master/dia.ipynb)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 6. Bonus: Generate ensemble LOCO reason codes for greater explanation stability\n", "Just like predictions from high-variance, nonlinear models, *explanations* derived from machine learning models can be unstable. One general way to decrease variance is to ensemble the results of many models. The last section of this notebook puts forward a simple approach to creating ensemble explanations." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Train multiple models\n", "To create ensemble explanations, several accurate models are trained. The models and their predictions on the test data are stored in Python lists.
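Once every model's LOCO values are available, combining them for one row amounts to a per-variable mean, a standard deviation across models, and the same additive rescaling used for a single model. The sketch below uses plain Python with hypothetical LOCO values for three models; `loco_by_model`, `ensemble_loco`, and `y_0` are illustrative assumptions rather than notebook results, while `pred` reuses the prediction reported for the riskiest customer. The standard deviation is computed over the per-model values only.

```python
import statistics

# Hypothetical per-model LOCO values for one customer (illustrative only).
loco_by_model = [
    {'PAY_0': 0.20, 'PAY_3': 0.07, 'AGE': 0.02},
    {'PAY_0': 0.14, 'PAY_3': 0.09, 'AGE': 0.08},
    {'PAY_0': 0.16, 'PAY_3': 0.05, 'AGE': 0.02},
]


def ensemble_loco(loco_by_model, pred, y_0):
    '''Mean, rescaled mean, and across-model std. dev. of LOCO values.'''
    features = list(loco_by_model[0])
    values = {f: [m[f] for m in loco_by_model] for f in features}
    mean = {f: statistics.mean(v) for f, v in values.items()}
    # sample std. dev. over the individual models only (needs >= 2 models)
    std = {f: statistics.stdev(v) for f, v in values.items()}
    # rescale the means so that y_0 plus their sum equals the prediction
    scaler = (pred - y_0) / sum(mean.values())
    scaled = {f: mean[f] * scaler for f in features}
    return mean, scaled, std


pred = 0.895285  # prediction shown for the riskiest customer in the notebook
y_0 = 0.22       # hypothetical overall mean of DEFAULT_NEXT_MONTH
mean, scaled, std = ensemble_loco(loco_by_model, pred, y_0)
for f in sorted(scaled, key=scaled.get, reverse=True):
    print(f, round(mean[f], 4), '+/-', round(std[f], 4))
```

Restricting the spread estimate to the per-model values keeps derived quantities, such as the mean itself, from shrinking the reported standard deviation.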
" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training Progress: model 1/10, AUC = 0.7813 ...\n", "Training Progress: model 2/10, AUC = 0.7803 ...\n", "Training Progress: model 3/10, AUC = 0.7787 ...\n", "Training Progress: model 4/10, AUC = 0.7826 ...\n", "Training Progress: model 5/10, AUC = 0.7804 ...\n", "Training Progress: model 6/10, AUC = 0.7800 ...\n", "Training Progress: model 7/10, AUC = 0.7802 ...\n", "Training Progress: model 8/10, AUC = 0.7799 ...\n", "Training Progress: model 9/10, AUC = 0.7796 ...\n", "Training Progress: model 10/10, AUC = 0.7811 ...\n", "Done.\n" ] } ], "source": [ "n_models = 10 # select number of models\n", "\n", "# lists for holding models and predictions\n", "models = []\n", "pred_frames = []\n", "\n", "for i in range(0, n_models):\n", "\n", " # initialize and store models\n", " models.append(H2OGradientBoostingEstimator(ntrees=150,\n", " max_depth=4,\n", " sample_rate=0.9 - ((i + 1)*0.01), # perturb sample rate\n", " col_sample_rate=0.9 - ((i + 1)*0.01), # perturb column sample rate\n", " #balance_classes=True, # sample to balance 0/1 distribution of target - helps LOCO\n", " stopping_rounds=5, # stop if the validation metric's moving average does not improve for 5 scoring rounds\n", " seed=i + 1)) # new random seed for each model\n", " \n", " # train models\n", " models[i].train(y=y, x=X, training_frame=train, validation_frame=test)\n", " \n", " # store predictions\n", " pred_frames.append(test['ID'].cbind(models[i].predict(test).drop(['predict','p0'])))\n", " pred_frames[i].columns = ['ID', yhat]\n", " \n", " # update progress\n", " print('Training Progress: model %d/%d, AUC = %.4f ...' 
% (i + 1, n_models, models[i].auc(valid=True)))\n", "\n", "print('Done.')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Calculate LOCO for each model\n", "LOCO is calculated for each stored model, each input variable, and each row of the test set, using the stored models and predictions." ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "LOCO Progress: model 1/10 ...\n", "LOCO Progress: model 2/10 ...\n", "LOCO Progress: model 3/10 ...\n", "LOCO Progress: model 4/10 ...\n", "LOCO Progress: model 5/10 ...\n", "LOCO Progress: model 6/10 ...\n", "LOCO Progress: model 7/10 ...\n", "LOCO Progress: model 8/10 ...\n", "LOCO Progress: model 9/10 ...\n", "LOCO Progress: model 10/10 ...\n", "Done.\n" ] } ], "source": [ "# for each new model (named model_k so the original GBM, model, is not overwritten) ...\n", "for k, model_k in enumerate(models):\n", "\n", " # calculate LOCO for each input variable \n", " for i in X:\n", "\n", " # rescore (no retraining) with Xi set to missing\n", " test_loco = h2o.deep_copy(test, 'test_loco')\n", " test_loco[i] = np.nan\n", " preds_loco = model_k.predict(test_loco).drop(['predict','p0'])\n", "\n", " # create a new, named column for the LOCO prediction\n", " preds_loco.columns = [i]\n", " pred_frames[k] = pred_frames[k].cbind(preds_loco)\n", "\n", " # subtract the LOCO prediction from the original prediction\n", " pred_frames[k][i] = pred_frames[k][yhat] - pred_frames[k][i]\n", " \n", " # update progress \n", " print('LOCO Progress: model %d/%d ...' % (k + 1, n_models))\n", "\n", "print('Done.')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Collect LOCO values for each model for a risky customer\n", "To create ensemble explanations for a single row, the LOCO values for each variable in the row are averaged across all models. Single-model and mean LOCO values for the riskiest customer in the test set are displayed below. 
Notice that even slight changes in model specifications can result in different explanations. For example, the local contribution of `PAY_0` for the riskiest customer ranges from roughly 0.09 to 0.20 across the 10 models in the table below. " ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
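<table><thead><tr><th></th><th>Loco 1</th><th>Loco 2</th><th>Loco 3</th><th>Loco 4</th><th>Loco 5</th><th>Loco 6</th><th>Loco 7</th><th>Loco 8</th><th>Loco 9</th><th>Loco 10</th><th>Mean Local Importance</th><th>Scaled Mean Local Importance</th><th>Std. Dev. Local Importance</th></tr></thead><tbody>
<tr><th>LIMIT_BAL</th><td>0.013205</td><td>0.011040</td><td>0.012483</td><td>0.001345</td><td>0.015150</td><td>-0.006688</td><td>-0.002158</td><td>-0.009428</td><td>-0.002005</td><td>-0.005503</td><td>0.002744</td><td>0.004116</td><td>0.008836</td></tr>
<tr><th>SEX</th><td>0.000000</td><td>0.000000</td><td>0.000000</td><td>0.000000</td><td>0.000000</td><td>0.000000</td><td>0.000000</td><td>0.000000</td><td>0.000000</td><td>0.000000</td><td>0.000000</td><td>0.000000</td><td>0.000000</td></tr>
<tr><th>EDUCATION</th><td>0.000000</td><td>0.000000</td><td>0.000000</td><td>0.000000</td><td>0.000000</td><td>0.000000</td><td>0.000000</td><td>0.000000</td><td>0.000000</td><td>0.000000</td><td>0.000000</td><td>0.000000</td><td>0.000000</td></tr>
<tr><th>MARRIAGE</th><td>0.051186</td><td>0.023304</td><td>0.049574</td><td>0.044072</td><td>0.034987</td><td>0.083513</td><td>0.053444</td><td>0.034809</td><td>0.051611</td><td>0.097880</td><td>0.052438</td><td>0.078652</td><td>0.021392</td></tr>
<tr><th>AGE</th><td>0.022646</td><td>0.084028</td><td>0.015365</td><td>0.019450</td><td>0.047277</td><td>0.008411</td><td>0.019311</td><td>0.019194</td><td>0.064969</td><td>0.047981</td><td>0.034863</td><td>0.052292</td><td>0.023673</td></tr>
<tr><th>PAY_0</th><td>0.199285</td><td>0.136971</td><td>0.155368</td><td>0.123913</td><td>0.103977</td><td>0.094563</td><td>0.129808</td><td>0.166523</td><td>0.163303</td><td>0.118950</td><td>0.139266</td><td>0.208887</td><td>0.030281</td></tr>
<tr><th>PAY_2</th><td>0.002721</td><td>0.023228</td><td>0.029913</td><td>0.064136</td><td>0.047839</td><td>0.003387</td><td>0.044296</td><td>0.002752</td><td>0.021887</td><td>0.028216</td><td>0.026837</td><td>0.040254</td><td>0.019742</td></tr>
<tr><th>PAY_3</th><td>0.068494</td><td>0.093103</td><td>0.047630</td><td>0.046408</td><td>0.044800</td><td>0.016914</td><td>0.038656</td><td>0.046357</td><td>0.088166</td><td>0.054055</td><td>0.054458</td><td>0.081683</td><td>0.021809</td></tr>
<tr><th>PAY_4</th><td>0.030388</td><td>0.053298</td><td>0.039892</td><td>-0.009098</td><td>0.021445</td><td>0.022628</td><td>0.043056</td><td>0.038197</td><td>0.063759</td><td>0.021841</td><td>0.032541</td><td>0.048808</td><td>0.019174</td></tr>
<tr><th>PAY_5</th><td>0.064508</td><td>0.051064</td><td>0.033604</td><td>0.069855</td><td>0.024448</td><td>0.038036</td><td>0.031212</td><td>0.020676</td><td>0.063064</td><td>0.072998</td><td>0.046947</td><td>0.070416</td><td>0.018681</td></tr>
<tr><th>PAY_6</th><td>0.030094</td><td>0.033554</td><td>0.027220</td><td>-0.023089</td><td>0.005430</td><td>0.005204</td><td>0.022503</td><td>0.029180</td><td>0.028095</td><td>0.030959</td><td>0.018915</td><td>0.028371</td><td>0.017000</td></tr>
<tr><th>BILL_AMT1</th><td>0.000933</td><td>0.036190</td><td>0.002133</td><td>0.025375</td><td>0.001071</td><td>0.004628</td><td>0.018029</td><td>0.004658</td><td>0.021524</td><td>0.027100</td><td>0.014164</td><td>0.021245</td><td>0.012328</td></tr>
<tr><th>BILL_AMT2</th><td>0.008804</td><td>0.001631</td><td>0.000000</td><td>0.000000</td><td>0.031767</td><td>0.000000</td><td>0.008615</td><td>0.000000</td><td>0.000000</td><td>0.000000</td><td>0.005082</td><td>0.007622</td><td>0.009515</td></tr>
<tr><th>BILL_AMT3</th><td>0.000000</td><td>-0.001008</td><td>0.000000</td><td>0.000000</td><td>0.000000</td><td>0.000000</td><td>0.000000</td><td>0.000000</td><td>-0.001648</td><td>0.000000</td><td>-0.000266</td><td>-0.000398</td><td>0.000550</td></tr>
<tr><th>BILL_AMT4</th><td>-0.001370</td><td>0.000000</td><td>0.000000</td><td>0.000000</td><td>0.000000</td><td>-0.010070</td><td>0.000000</td><td>-0.005586</td><td>0.000000</td><td>-0.001738</td><td>-0.001876</td><td>-0.002814</td><td>0.003198</td></tr>
<tr><th>BILL_AMT5</th><td>0.000000</td><td>0.000000</td><td>0.000000</td><td>0.003339</td><td>0.000000</td><td>0.000000</td><td>0.000000</td><td>0.004020</td><td>-0.010965</td><td>0.009204</td><td>0.000560</td><td>0.000840</td><td>0.004787</td></tr>
<tr><th>BILL_AMT6</th><td>0.000000</td><td>0.000000</td><td>0.000000</td><td>0.000000</td><td>0.006478</td><td>0.000000</td><td>0.000000</td><td>0.000000</td><td>0.000000</td><td>0.000000</td><td>0.000648</td><td>0.000972</td><td>0.001943</td></tr>
<tr><th>PAY_AMT1</th><td>0.000000</td><td>0.004307</td><td>0.000000</td><td>0.000000</td><td>0.000000</td><td>0.003919</td><td>0.000000</td><td>0.000000</td><td>0.000000</td><td>0.000000</td><td>0.000823</td><td>0.001234</td><td>0.001648</td></tr>
<tr><th>PAY_AMT2</th><td>0.000000</td><td>0.001541</td><td>0.000000</td><td>-0.001350</td><td>0.004503</td><td>0.002450</td><td>0.000000</td><td>0.001922</td><td>0.000000</td><td>0.000985</td><td>0.001005</td><td>0.001508</td><td>0.001582</td></tr>
<tr><th>PAY_AMT3</th><td>0.000000</td><td>0.000000</td><td>0.000000</td><td>0.000000</td><td>0.000000</td><td>0.000000</td><td>0.000000</td><td>0.000000</td><td>0.000000</td><td>0.005040</td><td>0.000504</td><td>0.000756</td><td>0.001512</td></tr>
<tr><th>PAY_AMT4</th><td>0.000000</td><td>0.000000</td><td>0.000000</td><td>-0.006528</td><td>0.000000</td><td>0.000000</td><td>0.000000</td><td>0.000000</td><td>0.000000</td><td>0.000000</td><td>-0.000653</td><td>-0.000979</td><td>0.001959</td></tr>
<tr><th>PAY_AMT5</th><td>0.000000</td><td>0.014605</td><td>0.016906</td><td>0.016336</td><td>0.001440</td><td>0.004975</td><td>0.029793</td><td>0.045810</td><td>0.019439</td><td>0.047491</td><td>0.019679</td><td>0.029517</td><td>0.015936</td></tr>
<tr><th>PAY_AMT6</th><td>-0.000925</td><td>0.000000</td><td>0.000000</td><td>-0.009342</td><td>-0.000800</td><td>-0.000556</td><td>0.000000</td><td>0.000000</td><td>0.000000</td><td>0.000000</td><td>-0.001162</td><td>-0.001743</td><td>0.002749</td></tr></tbody></table>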
\n", "
" ], "text/plain": [ " Loco 1 Loco 2 Loco 3 Loco 4 Loco 5 Loco 6 \\\n", "LIMIT_BAL 0.013205 0.011040 0.012483 0.001345 0.015150 -0.006688 \n", "SEX 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 \n", "EDUCATION 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 \n", "MARRIAGE 0.051186 0.023304 0.049574 0.044072 0.034987 0.083513 \n", "AGE 0.022646 0.084028 0.015365 0.019450 0.047277 0.008411 \n", "PAY_0 0.199285 0.136971 0.155368 0.123913 0.103977 0.094563 \n", "PAY_2 0.002721 0.023228 0.029913 0.064136 0.047839 0.003387 \n", "PAY_3 0.068494 0.093103 0.047630 0.046408 0.044800 0.016914 \n", "PAY_4 0.030388 0.053298 0.039892 -0.009098 0.021445 0.022628 \n", "PAY_5 0.064508 0.051064 0.033604 0.069855 0.024448 0.038036 \n", "PAY_6 0.030094 0.033554 0.027220 -0.023089 0.005430 0.005204 \n", "BILL_AMT1 0.000933 0.036190 0.002133 0.025375 0.001071 0.004628 \n", "BILL_AMT2 0.008804 0.001631 0.000000 0.000000 0.031767 0.000000 \n", "BILL_AMT3 0.000000 -0.001008 0.000000 0.000000 0.000000 0.000000 \n", "BILL_AMT4 -0.001370 0.000000 0.000000 0.000000 0.000000 -0.010070 \n", "BILL_AMT5 0.000000 0.000000 0.000000 0.003339 0.000000 0.000000 \n", "BILL_AMT6 0.000000 0.000000 0.000000 0.000000 0.006478 0.000000 \n", "PAY_AMT1 0.000000 0.004307 0.000000 0.000000 0.000000 0.003919 \n", "PAY_AMT2 0.000000 0.001541 0.000000 -0.001350 0.004503 0.002450 \n", "PAY_AMT3 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 \n", "PAY_AMT4 0.000000 0.000000 0.000000 -0.006528 0.000000 0.000000 \n", "PAY_AMT5 0.000000 0.014605 0.016906 0.016336 0.001440 0.004975 \n", "PAY_AMT6 -0.000925 0.000000 0.000000 -0.009342 -0.000800 -0.000556 \n", "\n", " Loco 7 Loco 8 Loco 9 Loco 10 Mean Local Importance \\\n", "LIMIT_BAL -0.002158 -0.009428 -0.002005 -0.005503 0.002744 \n", "SEX 0.000000 0.000000 0.000000 0.000000 0.000000 \n", "EDUCATION 0.000000 0.000000 0.000000 0.000000 0.000000 \n", "MARRIAGE 0.053444 0.034809 0.051611 0.097880 0.052438 \n", "AGE 0.019311 0.019194 0.064969 0.047981 
0.034863 \n", "PAY_0 0.129808 0.166523 0.163303 0.118950 0.139266 \n", "PAY_2 0.044296 0.002752 0.021887 0.028216 0.026837 \n", "PAY_3 0.038656 0.046357 0.088166 0.054055 0.054458 \n", "PAY_4 0.043056 0.038197 0.063759 0.021841 0.032541 \n", "PAY_5 0.031212 0.020676 0.063064 0.072998 0.046947 \n", "PAY_6 0.022503 0.029180 0.028095 0.030959 0.018915 \n", "BILL_AMT1 0.018029 0.004658 0.021524 0.027100 0.014164 \n", "BILL_AMT2 0.008615 0.000000 0.000000 0.000000 0.005082 \n", "BILL_AMT3 0.000000 0.000000 -0.001648 0.000000 -0.000266 \n", "BILL_AMT4 0.000000 -0.005586 0.000000 -0.001738 -0.001876 \n", "BILL_AMT5 0.000000 0.004020 -0.010965 0.009204 0.000560 \n", "BILL_AMT6 0.000000 0.000000 0.000000 0.000000 0.000648 \n", "PAY_AMT1 0.000000 0.000000 0.000000 0.000000 0.000823 \n", "PAY_AMT2 0.000000 0.001922 0.000000 0.000985 0.001005 \n", "PAY_AMT3 0.000000 0.000000 0.000000 0.005040 0.000504 \n", "PAY_AMT4 0.000000 0.000000 0.000000 0.000000 -0.000653 \n", "PAY_AMT5 0.029793 0.045810 0.019439 0.047491 0.019679 \n", "PAY_AMT6 0.000000 0.000000 0.000000 0.000000 -0.001162 \n", "\n", " Scaled Mean Local Importance Std. Dev. 
Local Importance \n", "LIMIT_BAL 0.004116 0.008836 \n", "SEX 0.000000 0.000000 \n", "EDUCATION 0.000000 0.000000 \n", "MARRIAGE 0.078652 0.021392 \n", "AGE 0.052292 0.023673 \n", "PAY_0 0.208887 0.030281 \n", "PAY_2 0.040254 0.019742 \n", "PAY_3 0.081683 0.021809 \n", "PAY_4 0.048808 0.019174 \n", "PAY_5 0.070416 0.018681 \n", "PAY_6 0.028371 0.017000 \n", "BILL_AMT1 0.021245 0.012328 \n", "BILL_AMT2 0.007622 0.009515 \n", "BILL_AMT3 -0.000398 0.000550 \n", "BILL_AMT4 -0.002814 0.003198 \n", "BILL_AMT5 0.000840 0.004787 \n", "BILL_AMT6 0.000972 0.001943 \n", "PAY_AMT1 0.001234 0.001648 \n", "PAY_AMT2 0.001508 0.001582 \n", "PAY_AMT3 0.000756 0.001512 \n", "PAY_AMT4 -0.000979 0.001959 \n", "PAY_AMT5 0.029517 0.015936 \n", "PAY_AMT6 -0.001743 0.002749 " ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# holds predictions for a specific row \n", "risky_loco_frames = []\n", "\n", "# column names for Pandas DataFrame of combined LOCO prediction\n", "col_names = ['Loco ' + str(i) for i in range(1, n_models + 1)]\n", "\n", "# for each new model ...\n", "for i in range(0, n_models):\n", " \n", " # collect LOCO for that model and a specific row \n", " # as a column vector in a Pandas DataFrame\n", " preds = pred_frames[i]\n", " risky_loco_frames.append(preds[preds['ID'] == int(percentile_dict[99]), :] # row for risky person\n", " .as_data_frame() # convert to Pandas\n", " .drop(['ID', yhat], axis=1) # drop predictions and row ID\n", " .T) # Transpose into column vector\n", "\n", "# bind LOCO for each row as column vectors \n", "# into the same Pandas DataFrame\n", "loco_ensemble = pd.concat(risky_loco_frames, axis=1) \n", "\n", "# update column names\n", "loco_ensemble.columns = col_names\n", "\n", "# mean local importance across models\n", "loco_ensemble['Mean Local Importance'] = loco_ensemble.mean(axis=1)\n", "\n", "# scale contribs\n", "scaler = (test_yhat[test_yhat['ID'] == int(percentile_dict[99]), yhat] - y_0) /\\\n", " 
(loco_ensemble['Mean Local Importance'].sum())\n", "loco_ensemble['Scaled Mean Local Importance'] = loco_ensemble['Mean Local Importance'] * scaler[0, 0]\n", "\n", "# std deviation across the individual models only\n", "loco_ensemble['Std. Dev. Local Importance'] = loco_ensemble[col_names].std(axis=1)\n", " \n", "# display\n", "loco_ensemble" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Plot some mean reason codes for a risky customer\n", "Taking mean explanations across multiple models leads to reason codes somewhat different from the reason codes produced by a single model. Mean reason codes may be more stable because they represent explanations from several models, and they may take practitioners a step closer to using machine learning models to draw inferential conclusions about phenomena represented in the training or test data, instead of simply providing an approximate explanation of a single model's decision processes. " ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "risky_mean_loco = loco_ensemble['Mean Local Importance'].sort_values(ascending=False)[:5]\n", "_ = risky_mean_loco.plot(kind='bar', \n", " title='Top Five Reason Codes for a Risky Customer\\n', \n", " color='b',\n", " legend=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Shut down H2O\n", "After using h2o, it's typically best to shut it down. However, before doing so, users should ensure that they have saved any h2o data structures, such as models and H2OFrames, or scoring artifacts, such as POJOs and MOJOs." ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Are you sure you want to shutdown the H2O instance running at http://127.0.0.1:54321 (Y/N)? 
y\n", "H2O session _sid_ae25 closed.\n" ] } ], "source": [ "# be careful, this can erase your work!\n", "h2o.cluster().shutdown(prompt=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Summary\n", "\n", "In this notebook, a complex GBM classifier was trained to predict credit card defaults, explained at a global scale with a decision tree surrogate model, and explained at a local scale with LOCO. An ensemble LOCO approach was also introduced to stabilize approximate explanations. The decision tree surrogate creates an overall approximate flowchart for the GBM's decision processes, and LOCO can be used to create reason codes for each model prediction. All of these techniques enhance the transparency of the complex model, which in turn enables greater accountability for the model's predictions. These techniques should generalize to many types of business and research problems, enabling you to train a complex GBM model and explain it to your colleagues, bosses, and potentially, external regulators. " ] } ], "metadata": { "anaconda-cloud": {}, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.4" } }, "nbformat": 4, "nbformat_minor": 2 }