{ "cells": [ { "cell_type": "markdown", "id": "restricted-republic", "metadata": {}, "source": [ "# Assess income level predictions on adult census data" ] }, { "cell_type": "markdown", "id": "adolescent-fusion", "metadata": {}, "source": [ "This notebook demonstrates the use of the `responsibleai` API to assess a model trained on census data. It walks through the API calls necessary to create a widget with model analysis insights, then guides a visual analysis of the model." ] }, { "cell_type": "markdown", "id": "exempt-cartoon", "metadata": {}, "source": [ "* [Launch Responsible AI Toolbox](#Launch-Responsible-AI-Toolbox)\n", " * [Train a Model](#Train-a-Model)\n", " * [Create Model and Data Insights](#Create-Model-and-Data-Insights)\n", "* [Assess Your Model](#Assess-Your-Model)\n", " * [Aggregate Analysis](#Aggregate-Analysis)\n", " * [Individual Analysis](#Individual-Analysis)" ] }, { "cell_type": "markdown", "id": "continent-dream", "metadata": {}, "source": [ "## Launch Responsible AI Toolbox" ] }, { "cell_type": "markdown", "id": "welsh-crisis", "metadata": {}, "source": [ "The following section examines the code necessary to create datasets and a model. It then generates insights using the `responsibleai` API that can be visually analyzed." ] }, { "cell_type": "markdown", "id": "sophisticated-bryan", "metadata": {}, "source": [ "### Train a Model\n", "*The following section can be skipped. 
It loads a dataset and trains a model for illustrative purposes.*" ] }, { "cell_type": "code", "execution_count": null, "id": "indie-message", "metadata": {}, "outputs": [], "source": [ "import zipfile\n", "\n", "from sklearn.pipeline import Pipeline\n", "from sklearn.impute import SimpleImputer\n", "from sklearn.preprocessing import StandardScaler, OneHotEncoder\n", "from sklearn.compose import ColumnTransformer\n", "\n", "import pandas as pd\n", "from lightgbm import LGBMClassifier" ] }, { "cell_type": "markdown", "id": "clinical-henry", "metadata": {}, "source": [ "First, load the census dataset and specify the different types of features. Compose a pipeline which contains a preprocessor and estimator." ] }, { "cell_type": "code", "execution_count": null, "id": "resistant-consequence", "metadata": {}, "outputs": [], "source": [ "from raiutils.dataset import fetch_dataset\n", "from sklearn.pipeline import Pipeline\n", "from sklearn.impute import SimpleImputer\n", "from sklearn.preprocessing import StandardScaler, OneHotEncoder\n", "from sklearn.compose import ColumnTransformer\n", "\n", "def split_label(dataset, target_feature):\n", " X = dataset.drop([target_feature], axis=1)\n", " y = dataset[[target_feature]]\n", " return X, y\n", "\n", "def create_classification_pipeline(X):\n", " pipe_cfg = {\n", " 'num_cols': X.dtypes[X.dtypes == 'int64'].index.values.tolist(),\n", " 'cat_cols': X.dtypes[X.dtypes == 'object'].index.values.tolist(),\n", " }\n", " num_pipe = Pipeline([\n", " ('num_imputer', SimpleImputer(strategy='median')),\n", " ('num_scaler', StandardScaler())\n", " ])\n", " cat_pipe = Pipeline([\n", " ('cat_imputer', SimpleImputer(strategy='constant', fill_value='?')),\n", " ('cat_encoder', OneHotEncoder(handle_unknown='ignore', sparse=False))\n", " ])\n", " feat_pipe = ColumnTransformer([\n", " ('num_pipe', num_pipe, pipe_cfg['num_cols']),\n", " ('cat_pipe', cat_pipe, pipe_cfg['cat_cols'])\n", " ])\n", "\n", " # Append classifier to preprocessing 
pipeline.\n", " # Now we have a full prediction pipeline.\n", " pipeline = Pipeline(steps=[('preprocessor', feat_pipe),\n", " ('model', LGBMClassifier(random_state=0))])\n", "\n", " return pipeline\n", "\n", "outdirname = 'responsibleai.12.28.21'\n", "zipfilename = outdirname + '.zip'\n", "\n", "fetch_dataset('https://publictestdatasets.blob.core.windows.net/data/' + zipfilename, zipfilename)\n", "\n", "with zipfile.ZipFile(zipfilename, 'r') as unzip:\n", " unzip.extractall('.')\n", "\n", "target_feature = 'income'\n", "categorical_features = ['workclass', 'education', 'marital-status',\n", " 'occupation', 'relationship', 'race', 'gender', 'native-country']\n", "\n", "\n", "train_data = pd.read_csv('adult-train.csv', skipinitialspace=True)\n", "test_data = pd.read_csv('adult-test.csv', skipinitialspace=True)\n", "\n", "X_train_original, y_train = split_label(train_data, target_feature)\n", "X_test_original, y_test = split_label(test_data, target_feature)\n", "\n", "pipeline = create_classification_pipeline(X_train_original)\n", "\n", "y_train = y_train[target_feature].to_numpy()\n", "y_test = y_test[target_feature].to_numpy()\n", "\n", "\n", "# Take 500 samples from the test data\n", "test_data_sample = test_data.sample(n=500, random_state=5)" ] }, { "cell_type": "markdown", "id": "potential-proportion", "metadata": {}, "source": [ "Train the classification pipeline composed in the previous cell on the training data." 
] }, { "cell_type": "code", "execution_count": null, "id": "biological-album", "metadata": {}, "outputs": [], "source": [ "model = pipeline.fit(X_train_original, y_train)" ] }, { "cell_type": "markdown", "id": "continued-praise", "metadata": {}, "source": [ "### Create Model and Data Insights" ] }, { "cell_type": "code", "execution_count": null, "id": "residential-identification", "metadata": {}, "outputs": [], "source": [ "from raiwidgets import ResponsibleAIDashboard\n", "from responsibleai import RAIInsights" ] }, { "cell_type": "markdown", "id": "cheap-juice", "metadata": {}, "source": [ "To use Responsible AI Toolbox, initialize a `RAIInsights` object on which different components can be loaded.\n", "\n", "`RAIInsights` accepts the model, the training dataset, the test dataset, the name of the target feature, and the task type string as its arguments.\n", "\n", "You may also create a `FeatureMetadata` container to identify any feature of your choice as the `identity_feature`, specify a list of categorical feature names via the `categorical_features` parameter, and specify dropped features via the `dropped_features` parameter. The `FeatureMetadata` object can then be passed to `RAIInsights`." ] }, { "cell_type": "code", "execution_count": null, "id": "bulgarian-hepatitis", "metadata": {}, "outputs": [], "source": [ "from responsibleai.feature_metadata import FeatureMetadata\n", "feature_metadata = FeatureMetadata(categorical_features=categorical_features, dropped_features=[])\n" ] }, { "cell_type": "code", "execution_count": null, "id": "bulgarian-hepatitis", "metadata": {}, "outputs": [], "source": [ "rai_insights = RAIInsights(model, train_data, test_data_sample, target_feature, 'classification',\n", " feature_metadata=feature_metadata)" ] }, { "cell_type": "markdown", "id": "original-rolling", "metadata": {}, "source": [ "Add the components of the toolbox that are focused on model assessment." 
] }, { "cell_type": "code", "execution_count": null, "id": "governing-antique", "metadata": {}, "outputs": [], "source": [ "# Interpretability\n", "rai_insights.explainer.add()\n", "# Error Analysis\n", "rai_insights.error_analysis.add()\n", "# Counterfactuals: accepts the total number of counterfactuals to generate\n", "# and the class they should have ('opposite' flips the predicted class)\n", "rai_insights.counterfactual.add(total_CFs=10, desired_class='opposite')" ] }, { "cell_type": "markdown", "id": "unexpected-bicycle", "metadata": {}, "source": [ "Once all the desired components have been loaded, compute insights on the test set." ] }, { "cell_type": "code", "execution_count": null, "id": "average-calibration", "metadata": {}, "outputs": [], "source": [ "rai_insights.compute()" ] }, { "cell_type": "markdown", "id": "b84c6c0d", "metadata": {}, "source": [ "Compose some cohorts which can be injected into the `ResponsibleAIDashboard`." ] }, { "cell_type": "code", "execution_count": null, "id": "0994b7d6", "metadata": {}, "outputs": [], "source": [ "from raiutils.cohort import Cohort, CohortFilter, CohortFilterMethods\n", "\n", "# Cohort on age and hours-per-week features in the dataset\n", "cohort_filter_age = CohortFilter(\n", " method=CohortFilterMethods.METHOD_LESS,\n", " arg=[65],\n", " column='age')\n", "cohort_filter_hours_per_week = CohortFilter(\n", " method=CohortFilterMethods.METHOD_GREATER,\n", " arg=[40],\n", " column='hours-per-week')\n", "\n", "user_cohort_age_and_hours_per_week = Cohort(name='Cohort Age and Hours-Per-Week')\n", "user_cohort_age_and_hours_per_week.add_cohort_filter(cohort_filter_age)\n", "user_cohort_age_and_hours_per_week.add_cohort_filter(cohort_filter_hours_per_week)\n", "\n", "# Cohort on marital-status feature in the dataset\n", "cohort_filter_marital_status = CohortFilter(\n", " method=CohortFilterMethods.METHOD_INCLUDES,\n", " arg=[\"Never-married\", \"Divorced\"],\n", " column='marital-status')\n", "\n", 
"user_cohort_marital_status = Cohort(name='Cohort Marital-Status')\n", "user_cohort_marital_status.add_cohort_filter(cohort_filter_marital_status)\n", "\n", "# Cohort on index of the row in the dataset\n", "cohort_filter_index = CohortFilter(\n", " method=CohortFilterMethods.METHOD_LESS,\n", " arg=[20],\n", " column='Index')\n", "\n", "user_cohort_index = Cohort(name='Cohort Index')\n", "user_cohort_index.add_cohort_filter(cohort_filter_index)\n", "\n", "# Cohort on predicted target value\n", "cohort_filter_predicted_y = CohortFilter(\n", " method=CohortFilterMethods.METHOD_INCLUDES,\n", " arg=['>50K'],\n", " column='Predicted Y')\n", "\n", "user_cohort_predicted_y = Cohort(name='Cohort Predicted Y')\n", "user_cohort_predicted_y.add_cohort_filter(cohort_filter_predicted_y)\n", "\n", "# Cohort on true target value\n", "cohort_filter_true_y = CohortFilter(\n", " method=CohortFilterMethods.METHOD_INCLUDES,\n", " arg=['>50K'],\n", " column='True Y')\n", "\n", "user_cohort_true_y = Cohort(name='Cohort True Y')\n", "user_cohort_true_y.add_cohort_filter(cohort_filter_true_y)\n", "\n", "cohort_list = [user_cohort_age_and_hours_per_week,\n", " user_cohort_marital_status,\n", " user_cohort_index,\n", " user_cohort_predicted_y,\n", " user_cohort_true_y]" ] }, { "cell_type": "markdown", "id": "elder-fleet", "metadata": {}, "source": [ "Finally, visualize and explore the model insights. Use the resulting widget or follow the link to view this in a new tab." 
] }, { "cell_type": "code", "execution_count": null, "id": "thousand-louis", "metadata": {}, "outputs": [], "source": [ "ResponsibleAIDashboard(rai_insights, cohort_list=cohort_list)" ] }, { "cell_type": "markdown", "id": "empty-parallel", "metadata": {}, "source": [ "## Assess Your Model" ] }, { "cell_type": "markdown", "id": "valued-victoria", "metadata": {}, "source": [ "### Aggregate Analysis" ] }, { "cell_type": "markdown", "id": "rubber-drove", "metadata": {}, "source": [ "The Error Analysis component is displayed at the top of the dashboard widget. To visualize how error is broken down across cohorts, use the tree map view to understand how it filters through the nodes." ] }, { "cell_type": "markdown", "id": "specified-painting", "metadata": {}, "source": [ "![Error Analysis tree map with \"Marital Status == 2,\" \"Capital Gain <= 1287.5,\" \"Capital Loss <= 1494.5\" path selected](./img/classification-assessment-1.png)" ] }, { "cell_type": "markdown", "id": "sitting-heritage", "metadata": {}, "source": [ "Over 40% of the error in this model is concentrated in datapoints of people who are married, have higher education and minimal capital gain. \n", "\n", "Let's see what else we can discover about this cohort.\n", "\n", "First, save the cohort by clicking \"Save as a new cohort\" on the right side panel of the Error Analysis component." ] }, { "cell_type": "markdown", "id": "average-chocolate", "metadata": {}, "source": [ "![Cohort creation sidebar and tree map cohort creation popup](./img/classification-assessment-2.png)" ] }, { "cell_type": "markdown", "id": "mature-reception", "metadata": {}, "source": [ "To switch to this cohort for analysis, click \"Switch global cohort\" and select the recently saved cohort from the dropdown." 
] }, { "cell_type": "markdown", "id": "regulated-shame", "metadata": {}, "source": [ "![Popup with dropdown to shift cohort from \"All data\" to \"Married, Low Capital Loss/Gain\" accompanied by cohort statistics](./img/classification-assessment-3.png)" ] }, { "cell_type": "markdown", "id": "surrounded-alpha", "metadata": {}, "source": [ "The Model Overview component allows the comparison of statistics across multiple saved cohorts.\n", "\n", "The diagram indicates that the model is misclassifying datapoints of married individuals with low capital gains and high education as lower income (false negative)." ] }, { "cell_type": "markdown", "id": "married-wheel", "metadata": {}, "source": [ "![Bar chart of classification outcomes (true negative, true positive, false negative, false positive) compared across cohorts](./img/classification-assessment-4.png)" ] }, { "cell_type": "markdown", "id": "decreased-commission", "metadata": {}, "source": [ "Comparing the ground truth statistics of the overall data with those of the erroneous cohort reveals opposite patterns in high-income representation. The overall data contains more individuals with an actual income of <= 50K, whereas the cohort of married individuals with low capital gains and high education contains more individuals with an actual income of > 50K. Given the small size of the dataset and this reversed pattern, the model makes more mistakes when predicting high-income individuals. One action item is to collect more data for both cohorts and retrain the model." 
] }, { "cell_type": "markdown", "id": "legitimate-specific", "metadata": {}, "source": [ "![image-3.png](./img/classification-assessment-5.png)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "published-dancing", "metadata": {}, "source": [ "![image.png](./img/classification-assessment-6.png)" ] }, { "cell_type": "markdown", "id": "certain-huntington", "metadata": {}, "source": [ "The Interpretability component displays feature importances for model predictions at an individual and aggregate level. The plot below indicates that the `marital-status` attribute influences model predictions the most on average." ] }, { "attachments": {}, "cell_type": "markdown", "id": "danish-stadium", "metadata": {}, "source": [ "![Top 5 features of the cohort, in descending importance: relationship, age, capital gain, education-num, hours per week](./img/classification-assessment-7.png)" ] }, { "cell_type": "markdown", "id": "devoted-management", "metadata": {}, "source": [ "The lower half of this tab shows how marital status affects model predictions. Being a husband or wife (married-civ-spouse) is more likely to pull the prediction away from <= 50K, possibly because couples have a higher cumulative income." ] }, { "attachments": {}, "cell_type": "markdown", "id": "nervous-confusion", "metadata": {}, "source": [ "![Feature importance stratified by relationship](./img/classification-assessment-8.png)" ] }, { "cell_type": "markdown", "id": "absolute-bench", "metadata": {}, "source": [ "### Individual Analysis" ] }, { "cell_type": "markdown", "id": "lesbian-market", "metadata": {}, "source": [ "Let's revisit Data Explorer. In the \"Individual datapoints\" view, we can see the prediction probabilities of each point. Point 510 is one that was just above the threshold to be classified as income of > 50K." 
] }, { "attachments": {}, "cell_type": "markdown", "id": "under-seminar", "metadata": {}, "source": [ "![Scatter plot of prediction probabilities (rounded to 0.2) on the y-axis and index on the x-axis](./img/classification-assessment-9.png)" ] }, { "cell_type": "markdown", "id": "intermediate-maximum", "metadata": {}, "source": [ "What factors led the model to make this decision?\n", "\n", "The \"Individual feature importance\" tab in the Interpretability component's Feature Importances section lets you select points for further analysis." ] }, { "attachments": {}, "cell_type": "markdown", "id": "medium-broadcasting", "metadata": {}, "source": [ "![Table of datapoints with row 510 selected](./img/classification-assessment-10.png)" ] }, { "cell_type": "markdown", "id": "different-frank", "metadata": {}, "source": [ "Under this, the feature importance plot shows `capital-gain` and `native-country` as the most significant factors leading to the <= 50K classification. Changing these may cause the threshold to be crossed and the model to predict the opposite class. Note that, depending on the context, the high importance of `native-country` might be considered a fairness issue." ] }, { "attachments": {}, "cell_type": "markdown", "id": "super-cambodia", "metadata": {}, "source": [ "![Feature importance plot for classification of 0 (descending, positive to negative): age, hours per week, capital gain, race, education-num, workclass, sex, country, occupation, marital status, relationship, capital loss](./img/classification-assessment-11.png)" ] }, { "cell_type": "markdown", "id": "ruled-quantum", "metadata": {}, "source": [ "The What-If Counterfactuals component focuses on how to change features slightly in order to change model predictions. As seen in its top-ranked features bar plot, changing this person's `marital-status`, `capital-loss`, and `education-num` has the highest impact on flipping the prediction to > 50K." 
] }, { "attachments": {}, "cell_type": "markdown", "id": "rough-uncle", "metadata": {}, "source": [ "![Top-ranked features (descending) for datapoint 510 to perturb to flip model prediction: age, hours per week, capital gain, capital loss, marital status, occupation, education-num, workclass, relationship, race, sex, country](./img/classification-assessment-12.png)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.12" } }, "nbformat": 4, "nbformat_minor": 5 }