{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Interpretable or Accurate? Why not both?\n", "\n", "## Case Study: Predicting Employee Attrition Using Machine Learning\n", "\n", "This notebook contains the code for the accompanying blog post, [Interpretable or Accurate? Why not both?](https://towardsdatascience.com/interpretable-or-accurate-why-not-both-4d9c73512192?sk=2f44377541a2f49939c921e54eb3cde7)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Installation\n", "\n", "Interpret is supported across Windows, Mac and Linux on Python 3.5+. Please refer to the [documentation](https://interpret.ml/docs/getting-started.html) for more details.\n", "\n", "### pip\n", "`pip install interpret`\n", "\n", "### conda\n", "`conda install -c interpretml interpret`\n", "\n", "### source\n", "`git clone https://github.com/interpretml/interpret.git && cd interpret/scripts && make install`\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Importing necessary libraries\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "import numpy as np\n", "\n", "\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.metrics import f1_score, accuracy_score\n", "\n", "from interpret import show\n", "from interpret import set_visualize_provider\n", "from interpret.provider import InlineProvider\n", "from interpret.data import ClassHistogram\n", "set_visualize_provider(InlineProvider())\n", "from interpret.glassbox import (\n", "    LogisticRegression,\n", "    ClassificationTree,\n", "    ExplainableBoostingClassifier,\n", ")\n", "\n", "\n", "seed = 42" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Importing the Dataset" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " Age Attrition BusinessTravel DailyRate Department \\\n", "0 41 Yes Travel_Rarely 1102 Sales \n", "1 49 No Travel_Frequently 279 Research & Development \n", "2 37 Yes Travel_Rarely 1373 Research & Development \n", "3 33 No Travel_Frequently 1392 Research & Development \n", "4 27 No Travel_Rarely 591 Research & Development \n", "\n", " DistanceFromHome Education EducationField EmployeeCount EmployeeNumber \\\n", "0 1 2 Life Sciences 1 1 \n", "1 8 1 Life Sciences 1 2 \n", "2 2 2 Other 1 4 \n", "3 3 4 Life Sciences 1 5 \n", "4 2 1 Medical 1 7 \n", "\n", " ... RelationshipSatisfaction StandardHours StockOptionLevel \\\n", "0 ... 1 80 0 \n", "1 ... 4 80 1 \n", "2 ... 2 80 0 \n", "3 ... 3 80 0 \n", "4 ... 4 80 1 \n", "\n", " TotalWorkingYears TrainingTimesLastYear WorkLifeBalance YearsAtCompany \\\n", "0 8 0 1 6 \n", "1 10 3 3 10 \n", "2 7 3 3 0 \n", "3 8 3 3 8 \n", "4 6 3 3 2 \n", "\n", " YearsInCurrentRole YearsSinceLastPromotion YearsWithCurrManager \n", "0 4 0 5 \n", "1 7 1 7 \n", "2 0 0 0 \n", "3 7 3 0 \n", "4 2 2 2 \n", "\n", "[5 rows x 35 columns]" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "\n", "df = pd.read_csv(\"WA_Fn-UseC_-HR-Employee-Attrition.csv\")\n", "df.head()" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0 1\n", "1 0\n", "2 1\n", "3 0\n", "4 0\n", "5 0\n", "6 0\n", "7 0\n", "8 0\n", "9 0\n", "Name: Attrition, dtype: int64\n" ] } ], "source": [ "#Encoding the target variable i.e Attrition\n", "\n", "target_map = {'Yes': 1, 'No': 0}\n", "target = df[\"Attrition\"].apply(lambda x: target_map[x])\n", "print(target[:10])" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# Deleting columns that are not useful for the predicitons\n", "\n", "df.drop(['EmployeeCount', 'EmployeeNumber', 'Over18', 'StandardHours','Attrition'], axis=\"columns\", inplace=True)" ] 
}, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# Split data into train and test\n", "X_train, X_test, y_train, y_test = train_test_split(df, \n", " target, \n", " test_size=0.2,\n", " random_state=seed,\n", " stratify=target)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Exploring the Dataset with histogram visualizations" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", "
\n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "hist = ClassHistogram().explain_data(X_train, y_train, name='Train Data')\n", "show(hist)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training GlassBox Models" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1. Explainable Boosting Machine (EBM)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "ExplainableBoostingClassifier(feature_names=['Age', 'BusinessTravel',\n", " 'DailyRate', 'Department',\n", " 'DistanceFromHome', 'Education',\n", " 'EducationField',\n", " 'EnvironmentSatisfaction',\n", " 'Gender', 'HourlyRate',\n", " 'JobInvolvement', 'JobLevel',\n", " 'JobRole', 'JobSatisfaction',\n", " 'MaritalStatus', 'MonthlyIncome',\n", " 'MonthlyRate',\n", " 'NumCompaniesWorked', 'OverTime',\n", " 'PercentSalaryHike',\n", " 'Perfor...\n", " 'categorical', 'continuous',\n", " 'categorical', 'continuous',\n", " 'continuous', 'continuous',\n", " 'categorical', 'continuous',\n", " 'categorical', 'continuous',\n", " 'continuous', 'continuous',\n", " 'categorical', 'continuous',\n", " 'continuous', 'continuous',\n", " 'continuous', 'continuous',\n", " 'continuous', 'continuous',\n", " 'continuous', 'continuous',\n", " 'continuous', 'continuous', ...],\n", " inner_bags=100, n_jobs=-1, outer_bags=100)" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ebm = ExplainableBoostingClassifier(random_state=seed, n_jobs=-1, inner_bags=100, outer_bags=100)\n", "ebm.fit(X_train, y_train)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Global Explanations\n", "\n", "Global explanations describe the model's overall behavior: which features matter most and how each one shifts the prediction across its range." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "ebm_global = ebm.explain_global(name='EBM')\n", "show(ebm_global)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Local Explanations\n", "\n", "Local explanations help us understand how and why an individual prediction was made." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "ebm_local = ebm.explain_local(X_test[:5], y_test[:5], name='EBM')\n", "show(ebm_local)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluating EBM performance" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "\n", "from interpret.perf import ROC\n", "\n", "ebm_perf = ROC(ebm.predict_proba).explain_perf(X_test, y_test, name='EBM')\n", "show(ebm_perf)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Comparing the performance with other GlassBox models\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2. Logistic Regression and Decision Tree" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# One-hot encode the categorical variables, which Logistic Regression and Decision Tree require\n", "X_enc = pd.get_dummies(df, prefix_sep='.')\n", "feature_names = list(X_enc.columns)\n", "X_train_enc, X_test_enc, y_train, y_test = train_test_split(X_enc, target, test_size=0.20, random_state=seed)\n", "\n", "lr = LogisticRegression(random_state=seed, feature_names=feature_names, penalty='l1', solver='liblinear')\n", "lr.fit(X_train_enc, y_train)\n", "\n", "tree = ClassificationTree()\n", "tree.fit(X_train_enc, y_train)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Comparing the performance of all the models" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "\n", "lr_perf = ROC(lr.predict_proba).explain_perf(X_test_enc, y_test, name='Logistic Regression')\n", "tree_perf = ROC(tree.predict_proba).explain_perf(X_test_enc, y_test, name='Classification Tree')\n", "\n", "show(lr_perf)\n", "show(tree_perf)\n", "show(ebm_perf)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training Blackbox Models\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1. Random Forest Classifier" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Pipeline(steps=[('pca', PCA()), ('rf', RandomForestClassifier(n_jobs=-1))])" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.ensemble import RandomForestClassifier\n", "from sklearn.decomposition import PCA\n", "from sklearn.pipeline import Pipeline\n", "\n", "# A blackbox system can include preprocessing steps, not just a classifier\n", "pca = PCA()\n", "rf = RandomForestClassifier(n_estimators=100, n_jobs=-1)\n", "\n", "# Same one-hot encoding and split as used for the glassbox baselines\n", "X_enc = pd.get_dummies(df, prefix_sep='.')\n", "feature_names = list(X_enc.columns)\n", "X_train_enc, X_test_enc, y_train, y_test = train_test_split(X_enc, target, test_size=0.20, random_state=seed)\n", "\n", "blackbox_model = Pipeline([('pca', pca), ('rf', rf)])\n", "blackbox_model.fit(X_train_enc, y_train)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Evaluating BlackBox models" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from interpret import show\n", "from interpret.perf import ROC\n", "\n", "blackbox_perf = ROC(blackbox_model.predict_proba).explain_perf(X_test_enc, y_test, name='Blackbox')\n", "show(blackbox_perf)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Explaining local BlackBox predictions with [LIME](https://arxiv.org/abs/1602.04938v3)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from interpret.blackbox import LimeTabular\n", "from interpret import show\n", "\n", "#Blackbox explainers need a predict function, and optionally a dataset\n", "lime = LimeTabular(predict_fn=blackbox_model.predict_proba, data=X_train_enc, random_state=1)\n", "\n", "#Pick the instances to explain, optionally pass in labels if you have them\n", "lime_local = lime.explain_local(X_test_enc[:5], y_test[:5], name='LIME')\n", "\n", "show(lime_local)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Explaining global BlackBox predictions with [PDP](https://christophm.github.io/interpretable-ml-book/pdp.html)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from interpret.blackbox import PartialDependence\n", "\n", "pdp = PartialDependence(predict_fn=blackbox_model.predict_proba, data=X_train_enc)\n", "pdp_global = pdp.explain_global(name='Partial Dependence')\n", "\n", "show(pdp_global)\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" } }, "nbformat": 4, "nbformat_minor": 4 }