{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Using Interpretability Dashboard with Employee Attrition Data\n",
    "\n",
    "This notebook illustrates creating explanations for a binary classification model, employee attrition classification, that uses one to one and one to many feature transformations from raw data to engineered features. It will showcase raw feature transformations with TabularExplainer from SHAP.\n",
    "\n",
    "Problem: Employee attrition classification with scikit-learn (run model explainer locally)\n",
    "\n",
    "1. Transform raw features to engineered features.\n",
    "2. Train a classification model using Scikit-learn.\n",
    "3. Run 'explain_model' globally and locally with full dataset.\n",
    "4. Visualize the global and local explanations with the interpretability visualization dashboard."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Install Required Packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip install --upgrade interpret-community"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After installing packages, you must close and reopen the notebook as well as restarting the kernel."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Explain\n",
    "\n",
    "### Run model explainer at training time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.pipeline import Pipeline\n",
    "from sklearn.impute import SimpleImputer\n",
    "from sklearn.preprocessing import StandardScaler, OneHotEncoder\n",
    "from lightgbm import LGBMClassifier\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from urllib.request import urlretrieve\n",
    "import zipfile"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load the employee attrition data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "outdirname = 'dataset.6.21.19'\n",
    "zipfilename = outdirname + '.zip'\n",
    "urlretrieve('https://publictestdatasets.blob.core.windows.net/data/' + zipfilename, zipfilename)\n",
    "with zipfile.ZipFile(zipfilename, 'r') as unzip:\n",
    "    unzip.extractall('.')\n",
    "attritionData = pd.read_csv('./WA_Fn-UseC_-HR-Employee-Attrition.csv')\n",
    "\n",
    "# Dropping Employee count as all values are 1 and hence attrition is independent of this feature\n",
    "attritionData = attritionData.drop(['EmployeeCount'], axis=1)\n",
    "# Dropping Employee Number since it is merely an identifier\n",
    "attritionData = attritionData.drop(['EmployeeNumber'], axis=1)\n",
    "attritionData = attritionData.drop(['Over18'], axis=1)\n",
    "\n",
    "# Since all values are 80\n",
    "attritionData = attritionData.drop(['StandardHours'], axis=1)\n",
    "\n",
    "# Converting target variables from string to numerical values\n",
    "target_map = {'Yes': 0, 'No': 1}\n",
    "attritionData[\"Attrition_numerical\"] = attritionData[\"Attrition\"].apply(lambda x: target_map[x])\n",
    "target = attritionData[\"Attrition_numerical\"]\n",
    "\n",
    "attritionXData = attritionData.drop(['Attrition_numerical', 'Attrition'], axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split data into train and test\n",
    "from sklearn.model_selection import train_test_split\n",
    "X_train, X_test, y_train, y_test = train_test_split(attritionXData, \n",
    "                                                    target, \n",
    "                                                    test_size=0.2,\n",
    "                                                    random_state=0,\n",
    "                                                    stratify=target)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Creating dummy columns for each categorical feature\n",
    "categorical = []\n",
    "for col, value in attritionXData.iteritems():\n",
    "    if value.dtype == 'object':\n",
    "        categorical.append(col)\n",
    "        \n",
    "# Store the numerical columns in a list numerical\n",
    "numerical = attritionXData.columns.difference(categorical)        "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Transform raw features"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can explain raw features by either using a `sklearn.compose.ColumnTransformer` or a list of fitted transformer tuples. The cell below uses `sklearn.compose.ColumnTransformer`. In case you want to run the example with the list of fitted transformer tuples, comment the cell below and uncomment the cell that follows after. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.compose import ColumnTransformer\n",
    "\n",
    "# We create the preprocessing pipelines for both numeric and categorical data.\n",
    "numeric_transformer = Pipeline(steps=[\n",
    "    ('imputer', SimpleImputer(strategy='median')),\n",
    "    ('scaler', StandardScaler())])\n",
    "\n",
    "categorical_transformer = Pipeline(steps=[\n",
    "    ('imputer', SimpleImputer(strategy='constant', fill_value='missing')),\n",
    "    ('onehot', OneHotEncoder(handle_unknown='ignore'))])\n",
    "\n",
    "transformations = ColumnTransformer(\n",
    "    transformers=[\n",
    "        ('num', numeric_transformer, numerical),\n",
    "        ('cat', categorical_transformer, categorical)])\n",
    "\n",
    "# Append classifier to preprocessing pipeline.\n",
    "# Now we have a full prediction pipeline.\n",
    "clf = Pipeline(steps=[('preprocessor', transformations),\n",
    "                      ('classifier', LGBMClassifier())])\n"
   ]
  },
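  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you prefer the list of fitted transformer tuples, the commented-out cell below is a minimal sketch of the same preprocessing in that form. It assumes the `sklearn-pandas` package (which provides `DataFrameMapper`) is installed so the tuple list can also be applied inside the prediction pipeline; note that on recent scikit-learn versions `sparse=False` is spelled `sparse_output=False`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # Alternative: express the same preprocessing as a list of (columns, transformer) tuples\n",
    "# from sklearn_pandas import DataFrameMapper\n",
    "\n",
    "# # Impute and standardize the numeric features\n",
    "# numeric_transformations = [([f], Pipeline(steps=[\n",
    "#     ('imputer', SimpleImputer(strategy='median')),\n",
    "#     ('scaler', StandardScaler())])) for f in numerical]\n",
    "\n",
    "# # One-hot encode the categorical features\n",
    "# categorical_transformations = [([f], OneHotEncoder(handle_unknown='ignore', sparse=False)) for f in categorical]\n",
    "\n",
    "# transformations = numeric_transformations + categorical_transformations\n",
    "\n",
    "# # DataFrameMapper applies the tuple list inside the prediction pipeline\n",
    "# clf = Pipeline(steps=[('preprocessor', DataFrameMapper(transformations)),\n",
    "#                       ('classifier', LGBMClassifier())])"
   ]
  },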
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train a LightGBM classification model, which you want to explain"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = clf.fit(X_train, y_train)"
   ]
  },
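  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an optional sanity check before explaining the model, you can score the fitted pipeline on the held-out test set. The cell below is a minimal sketch using standard scikit-learn metrics; it is not part of the explanation workflow itself."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import accuracy_score, classification_report\n",
    "\n",
    "# Evaluate the fitted pipeline on the held-out test set\n",
    "y_pred = model.predict(X_test)\n",
    "print('test accuracy: {:.3f}'.format(accuracy_score(y_test, y_pred)))\n",
    "# Class 0 is 'Leaving' ('Yes' attrition) and class 1 is 'Staying', per target_map above\n",
    "print(classification_report(y_test, y_pred, target_names=['Leaving', 'Staying']))"
   ]
  },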
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Explain your model predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# clf.steps[-1][1] returns the trained classification model\n",
    "\n",
    "# Using SHAP TabularExplainer\n",
    "from interpret.ext.blackbox import TabularExplainer\n",
    "explainer = TabularExplainer(clf.steps[-1][1], \n",
    "                             initialization_examples=X_train, \n",
    "                             features=attritionXData.columns, \n",
    "                             classes=['Leaving', 'Staying'], \n",
    "                             transformations=transformations)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generate global explanations\n",
    "Explain overall model predictions (global explanation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Passing in test dataset for evaluation examples - note it must be a representative sample of the original data\n",
    "# X_train can be passed as well, but with more examples explanations will take longer although they may be more accurate\n",
    "global_explanation = explainer.explain_global(X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print out a dictionary that holds the sorted feature importance names and values\n",
    "print('global importance rank: {}'.format(global_explanation.get_feature_importance_dict()))"
   ]
  },
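  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Beyond the importance dictionary, the global explanation object also exposes the ranked feature names and values as parallel lists through interpret-community's `get_ranked_global_values()` and `get_ranked_global_names()` accessors, as sketched below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Retrieve the ranked global importances as parallel lists of names and values\n",
    "sorted_global_importance_values = global_explanation.get_ranked_global_values()\n",
    "sorted_global_importance_names = global_explanation.get_ranked_global_names()\n",
    "dict(zip(sorted_global_importance_names, sorted_global_importance_values))"
   ]
  },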
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generate local explanations\n",
    "Explain local data points (individual instances)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# You can pass a specific data point or a group of data points to the explain_local function\n",
    "# E.g., Explain the first data point in the test set\n",
    "instance_num = 1\n",
    "local_explanation = explainer.explain_local(X_test[:instance_num])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Get the prediction for the first member of the test set and explain why model made that prediction\n",
    "prediction_value = clf.predict(X_test)[instance_num]\n",
    "\n",
    "sorted_local_importance_values = local_explanation.get_ranked_local_values()[prediction_value]\n",
    "sorted_local_importance_names = local_explanation.get_ranked_local_names()[prediction_value]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print('local importance values: {}'.format(sorted_local_importance_values))\n",
    "print('local importance names: {}'.format(sorted_local_importance_names))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualize\n",
    "Load the interpretability visualization dashboard"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from raiwidgets import ExplanationDashboard"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ExplanationDashboard(global_explanation, model, dataset=X_test, true_y=y_test.values)"
   ]
  }
 ],
 "metadata": {
  "authors": [
   {
    "name": "mesameki"
   }
  ],
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}