{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Assess disease progression predictions on diabetes data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This notebook demonstrates the use of the `responsibleai` API to assess a regression model trained on diabetes progression data. It walks through the API calls necessary to create a widget with model analysis insights, then guides a visual analysis of the model." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* [Launch Responsible AI Toolbox](#Launch-Responsible-AI-Toolbox)\n", " * [Train a Model](#Train-a-Model)\n", " * [Create Model and Data Insights](#Create-Model-and-Data-Insights)\n", "* [Assess Your Model](#Assess-Your-Model)\n", " * [Aggregate Analysis](#Aggregate-Analysis)\n", " * [Individual Analysis](#Individual-Analysis)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Launch Responsible AI Toolbox" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The following section examines the code necessary to create datasets and a model. It then generates insights using the `responsibleai` API that can be visually analyzed." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Train a Model\n", "*The following section can be skipped. It loads a dataset and trains a model for illustrative purposes.*" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import shap\n", "# Import the submodule explicitly so that sklearn.datasets is available\n", "import sklearn.datasets\n", "import pandas as pd\n", "\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.ensemble import RandomForestRegressor" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, load the diabetes dataset and specify the different types of features. Then, clean it and put it into a DataFrame with named columns." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data = sklearn.datasets.load_diabetes()\n", "target_feature = 'y'\n", "continuous_features = data.feature_names\n", "data_df = pd.DataFrame(data.data, columns=data.feature_names)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "After loading and cleaning the data, split the datapoints into training and test sets. Then assemble separate training and test DataFrames that each include the target column." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "X_train, X_test, y_train, y_test = train_test_split(data_df, data.target, test_size=0.2, random_state=7)\n", "\n", "train_data = X_train.copy()\n", "test_data = X_test.copy()\n", "train_data[target_feature] = y_train\n", "test_data[target_feature] = y_test\n", "data.feature_names" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Train a random forest regressor on the training data." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model = RandomForestRegressor()\n", "model.fit(X_train, y_train)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create Model and Data Insights" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from raiwidgets import ResponsibleAIDashboard\n", "from responsibleai import RAIInsights" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To use Responsible AI Toolbox, initialize a RAIInsights object upon which different components can be loaded.\n", "\n", "RAIInsights accepts the model, the train dataset, the test dataset, the target feature string, the task type string, and a list of strings of categorical feature names as its arguments." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "rai_insights = RAIInsights(model, train_data, test_data, target_feature, 'regression',\n", " categorical_features=[])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Add the components of the toolbox that are focused on model assessment." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Interpretability\n", "rai_insights.explainer.add()\n", "# Error Analysis\n", "rai_insights.error_analysis.add()\n", "# Counterfactuals: accepts the total number of counterfactuals to generate and\n", "# the range that their predicted value should fall within\n", "rai_insights.counterfactual.add(total_CFs=20, desired_range=[50, 120])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Once all the desired components have been loaded, compute insights on the test set." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "rai_insights.compute()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, visualize and explore the model insights. Use the resulting widget or follow the link to view this in a new tab." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ResponsibleAIDashboard(rai_insights)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Assess Your Model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Aggregate Analysis" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Upon opening the dashboard widget, the Error Analysis component is displayed at the top. The tree map view of this component visualizes the cohort breakdown of error in nodes." 
] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "![Error Analysis tree map with \"age <= 0.02\" path selected](./img/regression-assessment-1.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For this model, over 70% of the error is concentrated in datapoints whose age feature is at most 0.02. Note that this value has been mean-centered and scaled by the standard deviation times the square root of the number of samples.\n", "\n", "Let's explore this cohort further.\n", "\n", "First, save the cohort of interest. This can be done by clicking \"Save as a new cohort\" on the right side panel." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "![Cohort creation sidebar and tree map cohort creation popup](./img/regression-assessment-2.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then, click \"Switch global cohort\" and select the recently saved cohort to focus the analysis on it." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "![Popup with dropdown to shift cohort from \"All data\" to \"Age <= 0.02\" accompanied by cohort statistics](./img/regression-assessment-3.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The heatmap can be used to understand this cohort further by visualizing the error rates of subcohorts. Investigating feature `s5`, [thought to represent the log of the serum triglycerides level](https://scikit-learn.org/stable/datasets/toy_dataset.html#diabetes-dataset), shows a marked difference in error rates between bins of data." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "![Heatmap of error stratified by s5](./img/regression-assessment-4.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's use the Model Overview component to examine how the distribution of error in the overall dataset, stratified by `s5`, compares." 
] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "![Model Overview error boxplots stratified by s5](./img/regression-assessment-5.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It appears that the three middle bins of the five have the largest interquartile ranges of error. How might the error patterns relate to the underlying distribution of the data?\n", "\n", "The Data Explorer component can help to visualize how data is spread across these bins. Interestingly, the bins with the most varied error also have the most datapoints. Intuitively, it makes sense that the more points in the bin, the more variation of error within it." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "![Data Explorer distribution of s5 values across dataset](./img/regression-assessment-6.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Moving on to understanding how the model makes decisions, the \"Aggregate feature importance\" plot on the Interpretability component lists the top features used across all datapoints in making predictions.\n", "\n", "`s5` and `bmi`, body mass index, are the features that most influence the model's decisions." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "![Top 8 features of the cohort, in descending importance: s5, bmi, sex, s3, bp, s6, age, s2](./img/regression-assessment-7.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Individual Analysis" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The view of the Data Explorer component can be modified to display a scatterplot of various metrics on individual datapoints. Here, we see that the prediction with the largest-magnitude negative error belongs to patient 13." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "![Data explorer scatterplot of error vs. 
index with datapoint 13 selected](./img/regression-assessment-8.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Patient 13's actual diabetes progression score is 52, barely a quarter of the model's prediction. How did the model go so wrong? Use the \"Individual feature importance\" view to select the datapoint corresponding to this patient for further analysis." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "![Table of datapoints with row 13 selected](./img/regression-assessment-9.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Scrolling lower, we can see that the model's prediction depended most on `s5`, like the average feature importances seen earlier." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "![Feature importance plot (descending): s5, sex, s3, bmi](./img/regression-assessment-10.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "How would the model's prediction change if the `s5` value of this data point differed? There are two ways to explore the answer to this question.\n", "\n", "The \"Individual conditional expectation\", or ICE plot, illustrates how the model's prediction for this point would have changed along a range of `s5` values. The point's `s5` value is approximately 0.03. This visualization shows that a lower `s5` value would result in a significant decrease in the model's prediction." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "![Individual conditional expectation plot for s5](./img/regression-assessment-11.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A similar analysis can be conducted elsewhere. The What-If Counterfactuals component is used to generate artificial points similar to a real point in the dataset, with the values of one or a few features set slightly differently. 
This can be used to understand how different feature combinations create different model outputs.\n", "\n", "The top-ranked features plot of this component illustrates what changes in feature values would be most effective in moving the model prediction for patient 13 into the range specified when adding the counterfactual component to the dashboard (see code above)." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "![Top ranked features for datapoint 13 (descending): s5, bmi, s1, s3, bp, s4, s6, s2, sex, age](./img/regression-assessment-12.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "From the ICE plot, we know that decreasing `s5` will also decrease the model prediction. Let's see how changing `s5` and `bmi` together will shift the prediction. \n", "\n", "Select point 13 on the What-If Counterfactuals scatterplot and click \"Create what-if counterfactual\" to the right of the scatterplot. This brings up a side panel with sample counterfactuals and an option to create a new one." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "![Panel to create counterfactuals of row 13](./img/regression-assessment-13.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Creating a what-if counterfactual of point 13 with the value -0.1 for `s5` and `bmi` has the following result. We see that the magnitude of error drops from about 150 to about 50." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "![Counterfactuals scatterplot of error vs. 
index showing a 100pt decrease in error with the counterfactual](./img/regression-assessment-14.png)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.6" } }, "nbformat": 4, "nbformat_minor": 5 }