{ "cells": [ { "cell_type": "markdown", "id": "75018d5c", "metadata": {}, "source": [ "# Plan real-world action using counterfactual example analysis and causal analysis" ] }, { "cell_type": "markdown", "id": "d4939847", "metadata": {}, "source": [ "This notebook demonstrates the use of the Responsible AI Toolbox to make decisions from diabetes progression data. It walks through the API calls necessary to create a widget with causal inferencing insights, then guides a visual analysis of the data." ] }, { "cell_type": "markdown", "id": "231caa35", "metadata": {}, "source": [ "* [Launch Responsible AI Toolbox](#Launch-Responsible-AI-Toolbox)\n", " * [Train a Model](#Train-a-Model)\n", " * [Create Model and Data Insights](#Create-Model-and-Data-Insights)\n", "* [Take Real-World Action](#Take-Real-World-Action)\n", " * [What-If Counterfactuals Analysis](#What-If-Counterfactuals-Analysis)\n", " * [Causal Analysis](#Causal-Analysis)" ] }, { "cell_type": "markdown", "id": "8cfa82d1", "metadata": {}, "source": [ "## Launch Responsible AI Toolbox" ] }, { "cell_type": "markdown", "id": "789b30d1", "metadata": {}, "source": [ "The following section examines the code necessary to create the dataset. It then generates insights using the `responsibleai` API that can be visually analyzed." ] }, { "cell_type": "markdown", "id": "3e43e464", "metadata": {}, "source": [ "### Train a Model\n", "*The following section can be skipped. It loads a dataset for illustrative purposes.*" ] }, { "cell_type": "code", "execution_count": null, "id": "a670ba8c", "metadata": {}, "outputs": [], "source": [ "import shap\n", "import sklearn\n", "import pandas as pd\n", "\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.ensemble import RandomForestRegressor" ] }, { "cell_type": "markdown", "id": "a4f53194", "metadata": {}, "source": [ "First, load the diabetes dataset and specify the different types of features. Then, clean it and put it into a DataFrame with named columns." 
] }, { "cell_type": "code", "execution_count": null, "id": "479ad4f8", "metadata": {}, "outputs": [], "source": [ "import sklearn.datasets  # make sure the datasets submodule is loaded\n", "\n", "data = sklearn.datasets.load_diabetes()\n", "target_feature = 'y'\n", "continuous_features = data.feature_names\n", "data_df = pd.DataFrame(data.data, columns=data.feature_names)" ] }, { "cell_type": "markdown", "id": "c7cdd8ae", "metadata": {}, "source": [ "After loading and cleaning the data, split the datapoints into training and test sets. Then assemble separate training and test DataFrames that include the target column." ] }, { "cell_type": "code", "execution_count": null, "id": "4e02d132", "metadata": {}, "outputs": [], "source": [ "X_train, X_test, y_train, y_test = train_test_split(data_df, data.target, test_size=0.2, random_state=7)\n", "\n", "train_data = X_train.copy()\n", "test_data = X_test.copy()\n", "train_data[target_feature] = y_train\n", "test_data[target_feature] = y_test\n", "data.feature_names" ] }, { "cell_type": "markdown", "id": "46119282", "metadata": {}, "source": [ "You may define `features_to_drop` to drop any features from `X_train`. The model will be trained without the features in `features_to_drop`. If `features_to_drop` is left empty, `X_train_after_drop` will be the same as `X_train`." ] }, { "cell_type": "code", "execution_count": null, "id": "3db1ef3b", "metadata": {}, "outputs": [], "source": [ "features_to_drop = []\n", "X_train_after_drop = X_train.drop(features_to_drop, axis=1)" ] }, { "cell_type": "markdown", "id": "59853607", "metadata": {}, "source": [ "Train a random forest regressor on the training data." 
] }, { "cell_type": "code", "execution_count": null, "id": "6612038f", "metadata": {}, "outputs": [], "source": [ "model = RandomForestRegressor(random_state=0)\n", "model.fit(X_train_after_drop, y_train)" ] }, { "cell_type": "markdown", "id": "29805164", "metadata": {}, "source": [ "### Create Model and Data Insights" ] }, { "cell_type": "code", "execution_count": null, "id": "c65f788f", "metadata": {}, "outputs": [], "source": [ "from raiwidgets import ResponsibleAIDashboard\n", "from responsibleai import RAIInsights" ] }, { "cell_type": "markdown", "id": "400de1d9", "metadata": {}, "source": [ "To use Responsible AI Toolbox, initialize a RAIInsights object upon which different components can be loaded.\n", "\n", "RAIInsights accepts the model, the train dataset, the test dataset, the target feature string and the task type string as its arguments.\n", "\n", "You may also create the `FeatureMetadata` container, identify any feature of your choice as the `identity_feature`, specify a list of strings of categorical feature names via the `categorical_features` parameter, and specify dropped features via the `dropped_features` parameter. The `FeatureMetadata` may also be passed into the `RAIInsights`." 
] }, { "cell_type": "code", "execution_count": null, "id": "d965f768", "metadata": {}, "outputs": [], "source": [ "from responsibleai.feature_metadata import FeatureMetadata\n", "# Use 's1' as the identity feature and register features_to_drop as dropped features\n", "feature_metadata = FeatureMetadata(identity_feature_name='s1', categorical_features=[], dropped_features=features_to_drop)\n" ] }, { "cell_type": "code", "execution_count": null, "id": "d965f769", "metadata": {}, "outputs": [], "source": [ "rai_insights = RAIInsights(model, train_data, test_data, target_feature, 'regression',\n", "                           feature_metadata=feature_metadata)" ] }, { "cell_type": "markdown", "id": "38fbbe06", "metadata": {}, "source": [ "Add the components of the toolbox that are focused on decision-making." ] }, { "cell_type": "code", "execution_count": null, "id": "24567d8d", "metadata": {}, "outputs": [], "source": [ "# Counterfactuals: accepts the total number of counterfactuals to generate\n", "# and the range that their predicted value should fall within\n", "rai_insights.counterfactual.add(total_CFs=20, desired_range=[50, 120])\n", "# Causal inference: estimates the causal effect of the treatment features on the target\n", "rai_insights.causal.add(treatment_features=['bmi', 'bp', 's2'])" ] }, { "cell_type": "markdown", "id": "571b2235", "metadata": {}, "source": [ "Once all the desired components have been loaded, compute insights on the test set." ] }, { "cell_type": "code", "execution_count": null, "id": "a7dec636", "metadata": {}, "outputs": [], "source": [ "rai_insights.compute()" ] }, { "cell_type": "markdown", "id": "0ad206fd", "metadata": {}, "source": [ "Compose some cohorts which can be injected into the `ResponsibleAIDashboard`." 
] }, { "cell_type": "code", "execution_count": null, "id": "7a039b34", "metadata": {}, "outputs": [], "source": [ "from raiutils.cohort import Cohort, CohortFilter, CohortFilterMethods\n", "\n", "# Cohort on age and bmi features in the dataset\n", "cohort_filter_age = CohortFilter(\n", "    method=CohortFilterMethods.METHOD_LESS,\n", "    arg=[40],\n", "    column='age')\n", "cohort_filter_bmi = CohortFilter(\n", "    method=CohortFilterMethods.METHOD_GREATER,\n", "    arg=[0],\n", "    column='bmi')\n", "\n", "user_cohort_age_and_bmi = Cohort(name='Cohort Age and BMI')\n", "user_cohort_age_and_bmi.add_cohort_filter(cohort_filter_age)\n", "user_cohort_age_and_bmi.add_cohort_filter(cohort_filter_bmi)\n", "\n", "# Cohort on index\n", "cohort_filter_index = CohortFilter(\n", "    method=CohortFilterMethods.METHOD_LESS,\n", "    arg=[20],\n", "    column='Index')\n", "\n", "user_cohort_index = Cohort(name='Cohort Index')\n", "user_cohort_index.add_cohort_filter(cohort_filter_index)\n", "\n", "# Cohort on predicted y values\n", "cohort_filter_predicted_y = CohortFilter(\n", "    method=CohortFilterMethods.METHOD_LESS,\n", "    arg=[165.0],\n", "    column='Predicted Y')\n", "\n", "user_cohort_predicted_y = Cohort(name='Cohort Predicted Y')\n", "user_cohort_predicted_y.add_cohort_filter(cohort_filter_predicted_y)\n", "\n", "# Cohort on true y values\n", "cohort_filter_true_y = CohortFilter(\n", "    method=CohortFilterMethods.METHOD_GREATER,\n", "    arg=[45.0],\n", "    column='True Y')\n", "\n", "user_cohort_true_y = Cohort(name='Cohort True Y')\n", "user_cohort_true_y.add_cohort_filter(cohort_filter_true_y)\n", "\n", "# Cohort on regression error\n", "cohort_filter_regression_error = CohortFilter(\n", "    method=CohortFilterMethods.METHOD_GREATER,\n", "    arg=[20.0],\n", "    column='Error')\n", "\n", "user_cohort_regression_error = Cohort(name='Cohort Regression Error')\n", "user_cohort_regression_error.add_cohort_filter(cohort_filter_regression_error)\n", "\n", "cohort_list = [user_cohort_age_and_bmi,\n", " 
user_cohort_index,\n", " user_cohort_predicted_y,\n", " user_cohort_true_y,\n", " user_cohort_regression_error]" ] }, { "cell_type": "markdown", "id": "54a43b5c", "metadata": {}, "source": [ "Finally, visualize and explore the model insights. Use the resulting widget or follow the link to view this in a new tab." ] }, { "cell_type": "code", "execution_count": null, "id": "ad84c884", "metadata": {}, "outputs": [], "source": [ "ResponsibleAIDashboard(rai_insights, cohort_list=cohort_list)" ] }, { "cell_type": "markdown", "id": "fb2ab57e", "metadata": {}, "source": [ "## Take Real-World Action" ] }, { "cell_type": "markdown", "id": "84325421", "metadata": {}, "source": [ "### What-If Counterfactuals Analysis" ] }, { "cell_type": "markdown", "id": "d292d247", "metadata": {}, "source": [ "Let's imagine that the diabetes progression scores predicted by the model are used to determine medical insurance rates. If the score is greater than 120, there is a higher rate. Patient 43's model score of 268.08 results in this increased rate, and they want to know how they should change their health to get a lower rate prediction from the model (leading to lower insurance price).\n", "\n", "The What-If counterfactuals component shows how slightly different feature values affect model predictions. This can be used to solve Patient 43's problem." ] }, { "attachments": {}, "cell_type": "markdown", "id": "d459156b", "metadata": {}, "source": [ "![What-If Counterfactuals component with datapoint 43 selected on the scatter plot with axes \"Predicted Y\" and \"Index\"](./img/regression-decision-making-1.png)" ] }, { "cell_type": "markdown", "id": "d7b86696", "metadata": {}, "source": [ "What can Patient 43 do to create the desired change? The top ranked features bar plot shows that `bmi` and `s5` are the best to perturb to bring the model score within 120." 
] }, { "attachments": {}, "cell_type": "markdown", "id": "b16d1a6c", "metadata": {}, "source": [ "![Top-ranked features (descending) for datapoint 43 to perturb to reduce model prediction below 120: bmi, s5, s4, s3, age, bp, sex, s1, s2, s6](./img/regression-decision-making-2.png)" ] }, { "cell_type": "markdown", "id": "709c3019", "metadata": {}, "source": [ "Let's see how that can be achieved. Change `bmi` to -0.04 and `s5` to -0.042 and see what the result is." ] }, { "attachments": {}, "cell_type": "markdown", "id": "5faa62ea", "metadata": {}, "source": [ "![Counterfactual creation panel. BMI has been changed to -0.04 and s5 has been changed to -0.042](./img/regression-decision-making-3.png)" ] }, { "cell_type": "markdown", "id": "a9f67339", "metadata": {}, "source": [ "As we can see, the model's prediction has dropped to 115.4, which is below the threshold of 120. Thus, Patient 43 should work on reducing their [body mass index and serum triglycerides level](https://scikit-learn.org/stable/datasets/toy_dataset.html#diabetes-dataset) to bring the model score under the insurance threshold." ] }, { "attachments": {}, "cell_type": "markdown", "id": "22d445d7", "metadata": {}, "source": [ "![Counterfactual of datapoint 43 selected on the counterfactuals scatter plot with axes \"Predicted Y\" and \"Index\". Predicted Y is 115.4](./img/regression-decision-making-4.png)" ] }, { "cell_type": "markdown", "id": "b4f78fd8", "metadata": {}, "source": [ "Note that this result does not mean that reducing `bmi` and `s5` *causes* the diabetes progression score to go down. It only shows that the change lowers the model's prediction. To investigate causal relationships, continue reading:" ] }, { "cell_type": "markdown", "id": "b134cdb5", "metadata": {}, "source": [ "### Causal Analysis" ] }, { "cell_type": "markdown", "id": "da76466d", "metadata": {}, "source": [ "Now suppose that a doctor wishes to know how to reduce the progression of diabetes in her patients. 
This can be explored in the Causal Inference component of the Responsible AI Toolbox.\n", "\n", "The \"Aggregate causal effects\" tab shows how perturbing each treatment feature changes disease progression on average. It appears that increasing `s2` (LDL) by one unit would increase diabetes progression by 606.89 units. Note that the features in this dataset are mean-centered and scaled, so a one-unit change is much larger than the spread of values actually observed." ] }, { "attachments": {}, "cell_type": "markdown", "id": "90b838d8", "metadata": {}, "source": [ "![Overall causal analysis table](./img/regression-decision-making-5.png)\n", "![Causal analysis diagram](./img/regression-decision-making-6.png)" ] }, { "cell_type": "markdown", "id": "f6078481", "metadata": {}, "source": [ "Let's revisit Patient 43. Instead of simply reducing the model score, they've decided to focus on actually improving their health to manage their diabetes better. The \"Individual causal what-if\" tab shows that decreasing their `bmi` to 0.05 reduces diabetes progression from 242 to 237.982." ] }, { "attachments": {}, "cell_type": "markdown", "id": "93105414", "metadata": {}, "source": [ "![individual causal analysis table](./img/regression-decision-making-7.png)" ] }, { "cell_type": "markdown", "id": "a6fa7384", "metadata": {}, "source": [ "To put that into a formal intervention policy, switch to the \"Treatment policy\" tab. This view helps build policies for future interventions. You can identify which parts of your sample experience the largest responses to changes in causal features, or treatments, and construct rules to define which future populations should be targeted for particular interventions." ] }, { "attachments": {}, "cell_type": "markdown", "id": "d1af0772", "metadata": {}, "source": [ "![treatment_policy](./img/regression-decision-making-8.png)" ] }, { "cell_type": "markdown", "id": "ac8025e4", "metadata": {}, "source": [ "Is that change the best overall treatment for them? Let's investigate different policies. 
Going back to the \"Treatment policy\" tab, we see that the above intervention on the `s2` feature outperforms an \"always increase\" intervention." ] }, { "attachments": {}, "cell_type": "markdown", "id": "ce677d35", "metadata": {}, "source": [ "![image.png](./img/regression-decision-making-9.png)" ] }, { "cell_type": "markdown", "id": "3355ea1c", "metadata": {}, "source": [ "Finally, a list shows which datapoints (patients) in the current data sample have the largest causal response to the selected treatment (a change in the `s2` feature), based on all features included in the estimated causal model." ] }, { "attachments": {}, "cell_type": "markdown", "id": "3cb02322", "metadata": {}, "source": [ "![causal-table](./img/regression-decision-making-10.png)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.11" } }, "nbformat": 4, "nbformat_minor": 5 }