{ "cells": [ { "cell_type": "code", "execution_count": 2, "id": "fb121c58-651d-4fd6-b37e-362058107486", "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "import numpy as np\n", "from sklearn.metrics import roc_auc_score\n", "from lightgbm import LGBMClassifier\n", "from interpret.glassbox import ExplainableBoostingClassifier\n", "from sklearn.model_selection import RandomizedSearchCV\n", "from sklearn.model_selection import train_test_split\n", "from interpret.glassbox import LogisticRegression\n", "from sklearn.preprocessing import StandardScaler\n", "from interpret import perf\n", "from interpret.blackbox import ShapKernel\n", "from interpret.blackbox import LimeTabular\n", "from interpret import show\n", "from interpret.provider import InlineProvider\n", "from interpret import set_visualize_provider\n", "\n", "set_visualize_provider(InlineProvider())\n", "from interpret.data import ClassHistogram" ] }, { "cell_type": "markdown", "id": "305e3d47-a705-4c0a-9b5a-aef619aff43a", "metadata": {}, "source": [ "When vizualization doesn't work (https://github.com/interpretml/interpret/issues/259)" ] }, { "cell_type": "code", "execution_count": 3, "id": "d2d4ee8f-45c2-47c0-b556-b89eb5581fb1", "metadata": {}, "outputs": [], "source": [ "# from interpret.data import Marginal\n", "# from interpret import preserve" ] }, { "cell_type": "markdown", "id": "1b0edeca-f91f-46aa-a45f-929c9e8721be", "metadata": {}, "source": [ "# Import Data" ] }, { "cell_type": "code", "execution_count": 4, "id": "c7a7cf42-0321-412c-bbec-f57be98297af", "metadata": {}, "outputs": [], "source": [ "data = pd.read_csv(\"heart_failure_clinical_records_dataset.csv\")" ] }, { "cell_type": "code", "execution_count": 5, "id": "dc1b367d-4038-4160-a257-393c2c7c79ce", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
ageanaemiacreatinine_phosphokinasediabetesejection_fractionhigh_blood_pressureplateletsserum_creatinineserum_sodiumsexsmokingtimeDEATH_EVENT
075.005820201265000.001.91301041
155.0078610380263358.031.11361061
265.001460200162000.001.31291171
350.011110200210000.001.91371071
465.011601200327000.002.71160081
\n", "
" ], "text/plain": [ " age anaemia creatinine_phosphokinase diabetes ejection_fraction \\\n", "0 75.0 0 582 0 20 \n", "1 55.0 0 7861 0 38 \n", "2 65.0 0 146 0 20 \n", "3 50.0 1 111 0 20 \n", "4 65.0 1 160 1 20 \n", "\n", " high_blood_pressure platelets serum_creatinine serum_sodium sex \\\n", "0 1 265000.00 1.9 130 1 \n", "1 0 263358.03 1.1 136 1 \n", "2 0 162000.00 1.3 129 1 \n", "3 0 210000.00 1.9 137 1 \n", "4 0 327000.00 2.7 116 0 \n", "\n", " smoking time DEATH_EVENT \n", "0 0 4 1 \n", "1 0 6 1 \n", "2 1 7 1 \n", "3 0 7 1 \n", "4 0 8 1 " ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data.head()" ] }, { "cell_type": "code", "execution_count": 6, "id": "5c33c836-1c3c-48e0-b52e-36e32f5d00d8", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "RangeIndex: 299 entries, 0 to 298\n", "Data columns (total 13 columns):\n", " # Column Non-Null Count Dtype \n", "--- ------ -------------- ----- \n", " 0 age 299 non-null float64\n", " 1 anaemia 299 non-null int64 \n", " 2 creatinine_phosphokinase 299 non-null int64 \n", " 3 diabetes 299 non-null int64 \n", " 4 ejection_fraction 299 non-null int64 \n", " 5 high_blood_pressure 299 non-null int64 \n", " 6 platelets 299 non-null float64\n", " 7 serum_creatinine 299 non-null float64\n", " 8 serum_sodium 299 non-null int64 \n", " 9 sex 299 non-null int64 \n", " 10 smoking 299 non-null int64 \n", " 11 time 299 non-null int64 \n", " 12 DEATH_EVENT 299 non-null int64 \n", "dtypes: float64(3), int64(10)\n", "memory usage: 30.5 KB\n" ] } ], "source": [ "data.info()" ] }, { "cell_type": "markdown", "id": "f76272e7-740e-45e8-9628-20addea3c932", "metadata": { "tags": [] }, "source": [ "# Train Test Split" ] }, { "cell_type": "code", "execution_count": 7, "id": "72443f7a-c53f-46fd-9d13-dedce85a872e", "metadata": {}, "outputs": [], "source": [ "X = data.drop([\"time\", \"DEATH_EVENT\"], axis=1)\n", "y = data[\"DEATH_EVENT\"]" ] }, { "cell_type": "code", "execution_count": 8, "id": "082abf55-08a1-4495-b2a0-cc1c7fb89be0", "metadata": {}, "outputs": [], "source": [ "X_train, X_test, y_train, y_test = train_test_split(\n", " X, y, test_size=0.33, random_state=42, stratify=y\n", ")" ] }, { "cell_type": "markdown", "id": "fee8f187-65f5-4865-8c4c-0a3c2bb44b84", "metadata": { "tags": [] }, "source": [ "# Explore the data" ] }, { "cell_type": "markdown", "id": "1efc965a-5630-45ff-9a5c-ea94c02fbbe4", "metadata": {}, "source": [ "InterpretML also support data exploration " ] }, { "cell_type": "code", "execution_count": 9, "id": "315a7d34-9c92-4196-89ec-ff89fab11c32", "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", "
\n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "hist = ClassHistogram().explain_data(X_train, y_train, name=\"Train Data\")\n", "show(hist)" ] }, { "cell_type": "markdown", "id": "6db74474-f68b-45bb-a3a7-256582222954", "metadata": { "tags": [] }, "source": [ "# Modeling & Hyperparameter Tuning" ] }, { "cell_type": "markdown", "id": "61c49590-a816-4c08-900b-d3940bc4b0c9", "metadata": { "tags": [] }, "source": [ "## Logistic Regression" ] }, { "cell_type": "code", "execution_count": 10, "id": "46dfbfeb-997e-4997-9451-dc13ab310156", "metadata": {}, "outputs": [], "source": [ "scaler = StandardScaler()\n", "features_to_scale = [\n", " \"age\",\n", " \"creatinine_phosphokinase\",\n", " \"ejection_fraction\",\n", " \"platelets\",\n", " \"serum_creatinine\",\n", " \"serum_sodium\",\n", "]\n", "\n", "X_train_scaled, X_test_scaled = X_train.copy(), X_test.copy()\n", "X_train_scaled[features_to_scale] = scaler.fit_transform(\n", " X_train_scaled[features_to_scale]\n", ")\n", "X_test_scaled[features_to_scale] = scaler.transform(X_test_scaled[features_to_scale])" ] }, { "cell_type": "markdown", "id": "86c23a3c-7d99-4ad2-85dd-95ad3155b77a", "metadata": {}, "source": [ "Note: InterpretML's Logistic Regression currently cannot run with RandomizedSearchCV since it does not implement a 'get_params' method." ] }, { "cell_type": "code", "execution_count": 11, "id": "99257e87-04d0-4fa8-8a00-9eeab0e71f94", "metadata": {}, "outputs": [], "source": [ "LR_clf = LogisticRegression(solver=\"liblinear\", class_weight=\"balanced\", penalty=\"l1\")" ] }, { "cell_type": "code", "execution_count": 12, "id": "a6e31057-2041-40cf-a73a-23030dcbdef4", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "LR_clf.fit(X_train_scaled, y_train)" ] }, { "cell_type": "code", "execution_count": 13, "id": "bc489fcc-0c5f-4053-bbf6-f67dc2acde6e", "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "lr_global = LR_clf.explain_global()\n", "show(lr_global)" ] }, { "cell_type": "code", "execution_count": 14, "id": "4827855e-1b72-43db-abfa-ac40dce03b2d", "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "lr_local = LR_clf.explain_local(X_test_scaled[10:15], y_test[10:15])\n", "show(lr_local)" ] }, { "cell_type": "markdown", "id": "92472cba-cd98-4d0a-a43a-1e06038d2575", "metadata": { "tags": [] }, "source": [ "## EBM" ] }, { "cell_type": "markdown", "id": "9aaadf47-24fd-4334-847c-2a7c31b8f2e4", "metadata": {}, "source": [ "### Hyperparameter Optimization" ] }, { "cell_type": "code", "execution_count": 15, "id": "acbc80e5-ef83-4cd8-9d81-e930213a8890", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "RandomizedSearchCV(cv=3,\n", " estimator=ExplainableBoostingClassifier(feature_names=Index(['age', 'anaemia', 'creatinine_phosphokinase', 'diabetes',\n", " 'ejection_fraction', 'high_blood_pressure', 'platelets',\n", " 'serum_creatinine', 'serum_sodium', 'sex', 'smoking'],\n", " dtype='object')),\n", " param_distributions={'interactions': [5, 10, 15],\n", " 'learning_rate': [0.001, 0.005, 0.01,\n", " 0.03],\n", " 'max_interaction_bins': [10, 15, 20],\n", " 'max_leaves': [3, 5, 10],\n", " 'max_rounds': [5000, 10000, 15000,\n", " 20000],\n", " 'min_samples_leaf': [2, 3, 5]},\n", " random_state=314, scoring='roc_auc', verbose=False)" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "param_test = {\n", " \"learning_rate\": [0.001, 0.005, 0.01, 0.03],\n", " \"interactions\": [5, 10, 15],\n", " \"max_interaction_bins\": [10, 15, 20],\n", " \"max_rounds\": [5000, 10000, 15000, 20000],\n", " \"min_samples_leaf\": [2, 3, 5],\n", " \"max_leaves\": [3, 5, 10],\n", "}\n", "\n", "n_HP_points_to_test = 10\n", "EBM_clf = ExplainableBoostingClassifier(feature_names=X_train.columns)\n", "EBM_gs = RandomizedSearchCV(\n", " estimator=EBM_clf,\n", " param_distributions=param_test,\n", " n_iter=n_HP_points_to_test,\n", " scoring=\"roc_auc\",\n", " cv=3,\n", " refit=True,\n", " random_state=314,\n", " verbose=False,\n", ")\n", "\n", "EBM_gs.fit(X_train, y_train)" ] }, { "cell_type": "markdown", "id": "72cf0a3f-6afa-43c4-afdd-fed40039209d", "metadata": {}, "source": [ "### Model Performance " ] }, { "cell_type": "code", "execution_count": 16, "id": "698bbcee-c12c-4d98-8a32-bf2ff14b4edd", "metadata": {}, "outputs": [], "source": [ "roc = perf.ROC(EBM_gs.best_estimator_.predict_proba, feature_names=X_train.columns)\n", "roc_explanation = roc.explain_perf(X_test, y_test)" ] }, { "cell_type": "code", "execution_count": 17, "id": "0b9828f5-58f9-4837-9a16-82bcbb14f5d2", "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show(roc_explanation)" ] }, { "cell_type": "markdown", "id": "2901721e-ecea-4708-b2e9-6d25a71b8656", "metadata": {}, "source": [ "### Global Explanation " ] }, { "cell_type": "code", "execution_count": 18, "id": "f9c0aeb8-eee9-4661-ad6a-93dc7f554768", "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "ebm_global = EBM_gs.best_estimator_.explain_global()\n", "show(ebm_global)" ] }, { "cell_type": "markdown", "id": "f538147e-481f-448a-911c-9974a586f0dd", "metadata": {}, "source": [ "### Local Explanation" ] }, { "cell_type": "code", "execution_count": 19, "id": "082bfcf7-b9e4-43e4-a5ed-3679bd4decd3", "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "ebm_local = EBM_gs.best_estimator_.explain_local(X_test[10:15], y_test[10:15])\n", "show(ebm_local)" ] }, { "cell_type": "markdown", "id": "1d14986b-aa22-4c6a-8ee5-6b76bbc169b1", "metadata": { "tags": [] }, "source": [ "## LightGBM" ] }, { "cell_type": "markdown", "id": "601514bf-01af-4ad3-b270-4ad8a46731ec", "metadata": {}, "source": [ "### Hyperparameter Optimization" ] }, { "cell_type": "code", "execution_count": 20, "id": "0ad31b47-4006-4322-90b8-32178f255be0", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "RandomizedSearchCV(cv=3, estimator=LGBMClassifier(random_state=314),\n", " param_distributions={'is_unbalance': [True, False],\n", " 'learning_rate': [0.01, 0.03, 0.05,\n", " 0.07],\n", " 'max_depth': [-1, 5, 10, 15],\n", " 'n_estimators': [200, 500, 700, 1000],\n", " 'num_leaves': [20, 30, 40, 50, 60]},\n", " random_state=314, scoring='roc_auc', verbose=False)" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "param_test = {\n", " \"num_leaves\": [20, 30, 40, 50, 60],\n", " \"max_depth\": [-1, 5, 10, 15],\n", " \"learning_rate\": [0.01, 0.03, 0.05, 0.07],\n", " \"n_estimators\": [200, 500, 700, 1000],\n", " \"is_unbalance\": [True, False],\n", "}\n", "\n", "n_HP_points_to_test = 10\n", "LGBM_clf = LGBMClassifier(random_state=314, n_jobs=-1)\n", "LGBM_gs = RandomizedSearchCV(\n", " estimator=LGBM_clf,\n", " param_distributions=param_test,\n", " n_iter=n_HP_points_to_test,\n", " scoring=\"roc_auc\",\n", " cv=3,\n", " refit=True,\n", " random_state=314,\n", " verbose=False,\n", ")\n", "\n", "LGBM_gs.fit(X_train, y_train)" ] }, { "cell_type": "markdown", "id": "145cac3b-2820-4228-a712-1e6302c74046", "metadata": {}, "source": [ "### Model Performance " ] }, { "cell_type": "code", "execution_count": 21, "id": "0007d060-322e-492f-9336-87d7a74f2c99", "metadata": {}, "outputs": [], "source": [ "roc = perf.ROC(LGBM_gs.best_estimator_.predict_proba, feature_names=X_train.columns)\n", "roc_explanation = roc.explain_perf(X_test, y_test)" ] }, { "cell_type": "code", "execution_count": 22, "id": "ab7c5552-14a7-469c-9de9-f4756bd54314", "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show(roc_explanation)" ] }, { "cell_type": "markdown", "id": "2d93af9f-f1a8-42a9-8a63-8ddd1215e70b", "metadata": {}, "source": [ "### Local Explanations" ] }, { "cell_type": "code", "execution_count": 23, "id": "57d4600a-d835-4910-85f4-17807d1c0029", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using 200 background data samples could cause slower run times. Consider using shap.sample(data, K) or shap.kmeans(data, K) to summarize the background as K samples.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "e9cbcbbf7f6a439884d4a391c040ccdf", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/5 [00:00\n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "shap = ShapKernel(predict_fn=LGBM_gs.predict_proba, data=X_train)\n", "shap_local = shap.explain_local(X_test[10:15], y_test[10:15])\n", "\n", "show(shap_local)" ] }, { "cell_type": "code", "execution_count": 24, "id": "931e64af-120b-49f8-b360-f05bcfb00c4b", "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "lime = LimeTabular(predict_fn=LGBM_gs.predict_proba, data=X_train)\n", "lime_local = lime.explain_local(X_test[10:15], y_test[10:15])\n", "\n", "show(lime_local)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.6" } }, "nbformat": 4, "nbformat_minor": 5 }