{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Protodash: NHANES (CDC) data example\n", "- This notebook shows how to use the ProtodashExplainer defined in [AIX360](https://github.com/IBM/AIX360/) to generate prototypes from (training/test) data. It uses one of the [NHANES CDC questionnaire datasets](https://wwwn.cdc.gov/nchs/nhanes/search/datapage.aspx?Component=Questionnaire&CycleBeginYear=2013), the one related to incomes of individuals.\n", "- ProtodashExplainer is an implementation of the [Protodash algorithm](https://arxiv.org/abs/1707.01212)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Protodash Explainer examples" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Import statements" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "from sklearn.preprocessing import OneHotEncoder\n", "\n", "from aix360.algorithms.protodash import ProtodashExplainer, get_Gaussian_Data\n", "from aix360.datasets import CDCDataset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Load NHANES dataset from CDC" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "nhanes = CDCDataset()\n", "nhanes_files = nhanes.get_csv_file_names()\n", "(nhanesinfo, _, _) = nhanes._cdc_files_info()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "#### Explore NHANES Income questionnaire dataset\n", "\n", "Now let us explore the income questionnaire dataset and find out the types of responses received in the survey. Each column in this dataset corresponds to a question and each row holds the answers given by one respondent to those questions. Both column names and answers by respondents are encoded. 
For example, 'SEQN' denotes the sequence number assigned to a respondent and 'IND235' corresponds to a question about monthly family income. As the table below shows, for most questions a value of 1 means \"Yes\" and a value of 2 means \"No\". More details about the income questionnaire and how its questions and answers are encoded are available [here](https://wwwn.cdc.gov/Nchs/Nhanes/2013-2014/INQ_H.htm).\n", "\n", "|Column |Description | Values and Meaning|\n", "|-------|----------------------------|---------|\n", "|SEQN | Respondent sequence number | |\n", "|INQ020 | Income from wages/salaries |1->Yes, 2->No, 7->Refused, 9->Don't know|\n", "|INQ012 | Income from self employment|1->Yes, 2->No, 7->Refused, 9->Don't know|\n", "|INQ030 | Income from Social Security or RR |1->Yes, 2->No, 7->Refused, 9->Don't know|\n", "|INQ060 | Income from other disability pension |1->Yes, 2->No, 7->Refused, 9->Don't know|\n", "|INQ080 | Income from retirement/survivor pension |1->Yes, 2->No, 7->Refused, 9->Don't know|\n", "|INQ090 | Income from Supplemental Security Income |1->Yes, 2->No, 7->Refused, 9->Don't know|\n", "|INQ132 | Income from state/county cash assistance |1->Yes, 2->No, 7->Refused, 9->Don't know|\n", "|INQ140 | Income from interest/dividends or rental |1->Yes, 2->No, 7->Refused, 9->Don't know|\n", "|INQ150 | Income from other sources |1->Yes, 2->No, 7->Refused, 9->Don't know|\n", "|IND235 | Monthly family income |1-12->Increasing income brackets, 77->Refused, 99->Don't know|\n", "|INDFMMPI | Family monthly poverty level index |0-5->Higher value more affluent|\n", "|INDFMMPC | Family monthly poverty level category |1-3->Increasing INDFMMPI brackets, 7->Refused, 9->Don't know|\n", "|INQ244 | Family has savings more than $5000 |1->Yes, 2->No, 7->Refused, 9->Don't know|\n", "|IND247 | Total savings/cash assets for the family |1-6->Increasing savings brackets, 77->Refused, 99->Don't know|" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { 
"name": "stdout", "output_type": "stream", "text": [ "Answers given by some respondents to the income questionnaire:\n" ] }, { "data": { "text/html": [ "
<div>\n",
 "<table border=\"1\" class=\"dataframe\">\n",
 " <thead>\n",
 " <tr style=\"text-align: right;\"><th></th><th>Respondent sequence number</th><th>Income from wages/salaries</th><th>Income from self employment</th><th>Income from Social Security or RR</th><th>Income from other disability pension</th><th>Income from retirement/survivor pension</th><th>Income from Supplemental Security Income</th><th>Income from state/county cash assistance</th><th>Income from interest/dividends or rental</th><th>Income from other sources</th><th>Monthly family income</th><th>Family monthly poverty level index</th><th>Family monthly poverty level category</th><th>Family has savings more than $5000</th><th>Total savings/cash assets for the family</th></tr>\n",
 " </thead>\n",
 " <tbody>\n",
 " <tr><th>0</th><td>73557.0</td><td>2.0</td><td>2.0</td><td>1.0</td><td>2.0</td><td>2.0</td><td>2.0</td><td>2.0</td><td>2.0</td><td>2.0</td><td>4.0</td><td>0.86</td><td>1.0</td><td>9.0</td><td>NaN</td></tr>\n",
 " <tr><th>1</th><td>73558.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>2.0</td><td>2.0</td><td>2.0</td><td>2.0</td><td>1.0</td><td>2.0</td><td>5.0</td><td>0.92</td><td>1.0</td><td>1.0</td><td>NaN</td></tr>\n",
 " <tr><th>2</th><td>73559.0</td><td>2.0</td><td>2.0</td><td>1.0</td><td>2.0</td><td>1.0</td><td>2.0</td><td>2.0</td><td>1.0</td><td>2.0</td><td>10.0</td><td>4.37</td><td>3.0</td><td>NaN</td><td>NaN</td></tr>\n",
 " <tr><th>3</th><td>73560.0</td><td>1.0</td><td>2.0</td><td>2.0</td><td>2.0</td><td>2.0</td><td>2.0</td><td>2.0</td><td>2.0</td><td>1.0</td><td>9.0</td><td>2.52</td><td>3.0</td><td>NaN</td><td>NaN</td></tr>\n",
 " <tr><th>4</th><td>73561.0</td><td>2.0</td><td>2.0</td><td>1.0</td><td>2.0</td><td>2.0</td><td>2.0</td><td>2.0</td><td>2.0</td><td>2.0</td><td>11.0</td><td>5.00</td><td>3.0</td><td>NaN</td><td>NaN</td></tr>\n",
 " </tbody>\n",
 "</table>\n",
 "</div>
" ], "text/plain": [ " Respondent sequence number Income from wages/salaries \\\n", "0 73557.0 2.0 \n", "1 73558.0 1.0 \n", "2 73559.0 2.0 \n", "3 73560.0 1.0 \n", "4 73561.0 2.0 \n", "\n", " Income from self employment Income from Social Security or RR \\\n", "0 2.0 1.0 \n", "1 1.0 1.0 \n", "2 2.0 1.0 \n", "3 2.0 2.0 \n", "4 2.0 1.0 \n", "\n", " Income from other disability pension \\\n", "0 2.0 \n", "1 2.0 \n", "2 2.0 \n", "3 2.0 \n", "4 2.0 \n", "\n", " Income from retirement/survivor pension \\\n", "0 2.0 \n", "1 2.0 \n", "2 1.0 \n", "3 2.0 \n", "4 2.0 \n", "\n", " Income from Supplemental Security Income \\\n", "0 2.0 \n", "1 2.0 \n", "2 2.0 \n", "3 2.0 \n", "4 2.0 \n", "\n", " Income from state/county cash assistance \\\n", "0 2.0 \n", "1 2.0 \n", "2 2.0 \n", "3 2.0 \n", "4 2.0 \n", "\n", " Income from interest/dividends or rental Income from other sources \\\n", "0 2.0 2.0 \n", "1 1.0 2.0 \n", "2 1.0 2.0 \n", "3 2.0 1.0 \n", "4 2.0 2.0 \n", "\n", " Monthly family income Family monthly poverty level index \\\n", "0 4.0 0.86 \n", "1 5.0 0.92 \n", "2 10.0 4.37 \n", "3 9.0 2.52 \n", "4 11.0 5.00 \n", "\n", " Family monthly poverty level category Family has savings more than $5000 \\\n", "0 1.0 9.0 \n", "1 1.0 1.0 \n", "2 3.0 NaN \n", "3 3.0 NaN \n", "4 3.0 NaN \n", "\n", " Total savings/cash assets for the family \n", "0 NaN \n", "1 NaN \n", "2 NaN \n", "3 NaN \n", "4 NaN " ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# replace encoded column names by the associated question text. 
\n", "df_inc = nhanes.get_csv_file('INQ_H.csv')\n", "dict_inc = {\n", "    'SEQN': 'Respondent sequence number',\n", "    'INQ020': 'Income from wages/salaries',\n", "    'INQ012': 'Income from self employment',\n", "    'INQ030': 'Income from Social Security or RR',\n", "    'INQ060': 'Income from other disability pension',\n", "    'INQ080': 'Income from retirement/survivor pension',\n", "    'INQ090': 'Income from Supplemental Security Income',\n", "    'INQ132': 'Income from state/county cash assistance',\n", "    'INQ140': 'Income from interest/dividends or rental',\n", "    'INQ150': 'Income from other sources',\n", "    'IND235': 'Monthly family income',\n", "    'INDFMMPI': 'Family monthly poverty level index',\n", "    'INDFMMPC': 'Family monthly poverty level category',\n", "    'INQ244': 'Family has savings more than $5000',\n", "    'IND247': 'Total savings/cash assets for the family'\n", "}\n", "# map each encoded column name to its question text (raises KeyError on an unmapped column)\n", "df_inc.columns = [dict_inc[col] for col in df_inc.columns]\n", "print(\"Answers given by some respondents to the income questionnaire:\")\n", "df_inc.head(5)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of respondents to Income questionnaire: 10175\n", "Distribution of answers to 'monthly family income' and 'Family savings' questions:\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "print(\"Number of respondents to Income questionnaire:\", df_inc.shape[0])\n", "print(\"Distribution of answers to \\'monthly family income\\' and \\'Family savings\\' questions:\")\n", "\n", "fig, axes = plt.subplots(1, 2, figsize=(10,5))\n", "fig.subplots_adjust(wspace=0.5)\n", "hist1 = df_inc['Monthly family income'].value_counts().plot(kind='bar', ax=axes[0])\n", "hist2 = df_inc['Family has savings more than $5000'].value_counts().plot(kind='bar', ax=axes[1])\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "#### Summarize NHANES Income Questionnaire dataset using Prototypes\n", "\n", "Consider a social scientist who would like to quickly obtain a summary report of this dataset in terms of types of people that span this dataset. Is it possible to summarize this dataset by looking at answers given by a few representative/prototypical respondents? \n", "\n", "We now show how the ProtodashExplainer can be used to obtain a few prototypical respondents (about 10 in this example) that span the diverse set of individuals answering the income questionnaire making it easy for the social scientist to summarize the dataset." 
] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "# convert pandas dataframe to numpy\n", "data = df_inc.to_numpy()\n", "\n", "# sort the rows by sequence number (1st column);\n", "# indexing with an index array returns a copy, so df_inc itself is not modified below\n", "idx = np.argsort(data[:, 0])\n", "data = data[idx, :]\n", "\n", "# replace NaNs (missing values) with 0, which then acts as its own category\n", "data[np.isnan(data)] = 0\n", "\n", "# delete 1st column (sequence numbers)\n", "original = data[:, 1:]\n", "\n", "# one-hot encode all features, treating every distinct value as a category\n", "# (scikit-learn >= 1.2 renames the `sparse` argument to `sparse_output`)\n", "onehot_encoder = OneHotEncoder(sparse=False)\n", "onehot_encoded = onehot_encoder.fit_transform(original)\n", "\n", "explainer = ProtodashExplainer()\n", "\n", "# call protodash explainer: the same encoded data is passed twice, so the\n", "# m=10 prototypes are selected from the dataset they summarize\n", "# S contains indices of the selected prototypes\n", "# W contains importance weights associated with the selected prototypes\n", "(W, S, _) = explainer.explain(onehot_encoded, onehot_encoded, m=10)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
 "<table border=\"1\" class=\"dataframe\">\n",
 " <thead>\n",
 " <tr style=\"text-align: right;\"><th></th><th>Respondent sequence number</th><th>Income from wages/salaries</th><th>Income from self employment</th><th>Income from Social Security or RR</th><th>Income from other disability pension</th><th>Income from retirement/survivor pension</th><th>Income from Supplemental Security Income</th><th>Income from state/county cash assistance</th><th>Income from interest/dividends or rental</th><th>Income from other sources</th><th>Monthly family income</th><th>Family monthly poverty level index</th><th>Family monthly poverty level category</th><th>Family has savings more than $5000</th><th>Total savings/cash assets for the family</th><th>Weights of Prototypes</th></tr>\n",
 " </thead>\n",
 " <tbody>\n",
 " <tr><th>8</th><td>73565.0</td><td>1.0</td><td>2.0</td><td>2.0</td><td>2.0</td><td>2.0</td><td>2.0</td><td>2.0</td><td>2.0</td><td>2.0</td><td>12.0</td><td>5.00</td><td>3.0</td><td>NaN</td><td>NaN</td><td>0.18</td></tr>\n",
 " <tr><th>7475</th><td>81032.0</td><td>2.0</td><td>2.0</td><td>1.0</td><td>2.0</td><td>2.0</td><td>2.0</td><td>2.0</td><td>1.0</td><td>2.0</td><td>3.0</td><td>1.08</td><td>1.0</td><td>2.0</td><td>1.0</td><td>0.10</td></tr>\n",
 " <tr><th>2449</th><td>76006.0</td><td>1.0</td><td>1.0</td><td>2.0</td><td>2.0</td><td>2.0</td><td>2.0</td><td>2.0</td><td>2.0</td><td>1.0</td><td>1.0</td><td>0.00</td><td>1.0</td><td>2.0</td><td>1.0</td><td>0.15</td></tr>\n",
 " <tr><th>2912</th><td>76469.0</td><td>1.0</td><td>2.0</td><td>2.0</td><td>1.0</td><td>1.0</td><td>2.0</td><td>2.0</td><td>2.0</td><td>2.0</td><td>6.0</td><td>1.32</td><td>2.0</td><td>1.0</td><td>NaN</td><td>0.07</td></tr>\n",
 " <tr><th>6895</th><td>80452.0</td><td>1.0</td><td>2.0</td><td>2.0</td><td>2.0</td><td>2.0</td><td>1.0</td><td>1.0</td><td>2.0</td><td>2.0</td><td>5.0</td><td>0.86</td><td>1.0</td><td>2.0</td><td>1.0</td><td>0.09</td></tr>\n",
 " <tr><th>3899</th><td>77456.0</td><td>1.0</td><td>2.0</td><td>1.0</td><td>2.0</td><td>2.0</td><td>2.0</td><td>2.0</td><td>2.0</td><td>2.0</td><td>7.0</td><td>2.71</td><td>3.0</td><td>NaN</td><td>NaN</td><td>0.12</td></tr>\n",
 " <tr><th>1475</th><td>75032.0</td><td>1.0</td><td>2.0</td><td>2.0</td><td>2.0</td><td>2.0</td><td>2.0</td><td>2.0</td><td>1.0</td><td>2.0</td><td>4.0</td><td>1.65</td><td>2.0</td><td>2.0</td><td>3.0</td><td>0.06</td></tr>\n",
 " <tr><th>690</th><td>74247.0</td><td>2.0</td><td>2.0</td><td>2.0</td><td>2.0</td><td>2.0</td><td>2.0</td><td>2.0</td><td>2.0</td><td>2.0</td><td>8.0</td><td>3.05</td><td>3.0</td><td>NaN</td><td>NaN</td><td>0.09</td></tr>\n",
 " <tr><th>5077</th><td>78634.0</td><td>1.0</td><td>2.0</td><td>2.0</td><td>2.0</td><td>1.0</td><td>2.0</td><td>2.0</td><td>2.0</td><td>2.0</td><td>2.0</td><td>0.44</td><td>1.0</td><td>2.0</td><td>1.0</td><td>0.07</td></tr>\n",
 " <tr><th>132</th><td>73689.0</td><td>1.0</td><td>2.0</td><td>2.0</td><td>2.0</td><td>2.0</td><td>2.0</td><td>2.0</td><td>1.0</td><td>2.0</td><td>11.0</td><td>4.30</td><td>3.0</td><td>NaN</td><td>NaN</td><td>0.07</td></tr>\n",
 " </tbody>\n",
 "</table>\n",
 "</div>
" ], "text/plain": [ " Respondent sequence number Income from wages/salaries \\\n", "8 73565.0 1.0 \n", "7475 81032.0 2.0 \n", "2449 76006.0 1.0 \n", "2912 76469.0 1.0 \n", "6895 80452.0 1.0 \n", "3899 77456.0 1.0 \n", "1475 75032.0 1.0 \n", "690 74247.0 2.0 \n", "5077 78634.0 1.0 \n", "132 73689.0 1.0 \n", "\n", " Income from self employment Income from Social Security or RR \\\n", "8 2.0 2.0 \n", "7475 2.0 1.0 \n", "2449 1.0 2.0 \n", "2912 2.0 2.0 \n", "6895 2.0 2.0 \n", "3899 2.0 1.0 \n", "1475 2.0 2.0 \n", "690 2.0 2.0 \n", "5077 2.0 2.0 \n", "132 2.0 2.0 \n", "\n", " Income from other disability pension \\\n", "8 2.0 \n", "7475 2.0 \n", "2449 2.0 \n", "2912 1.0 \n", "6895 2.0 \n", "3899 2.0 \n", "1475 2.0 \n", "690 2.0 \n", "5077 2.0 \n", "132 2.0 \n", "\n", " Income from retirement/survivor pension \\\n", "8 2.0 \n", "7475 2.0 \n", "2449 2.0 \n", "2912 1.0 \n", "6895 2.0 \n", "3899 2.0 \n", "1475 2.0 \n", "690 2.0 \n", "5077 1.0 \n", "132 2.0 \n", "\n", " Income from Supplemental Security Income \\\n", "8 2.0 \n", "7475 2.0 \n", "2449 2.0 \n", "2912 2.0 \n", "6895 1.0 \n", "3899 2.0 \n", "1475 2.0 \n", "690 2.0 \n", "5077 2.0 \n", "132 2.0 \n", "\n", " Income from state/county cash assistance \\\n", "8 2.0 \n", "7475 2.0 \n", "2449 2.0 \n", "2912 2.0 \n", "6895 1.0 \n", "3899 2.0 \n", "1475 2.0 \n", "690 2.0 \n", "5077 2.0 \n", "132 2.0 \n", "\n", " Income from interest/dividends or rental Income from other sources \\\n", "8 2.0 2.0 \n", "7475 1.0 2.0 \n", "2449 2.0 1.0 \n", "2912 2.0 2.0 \n", "6895 2.0 2.0 \n", "3899 2.0 2.0 \n", "1475 1.0 2.0 \n", "690 2.0 2.0 \n", "5077 2.0 2.0 \n", "132 1.0 2.0 \n", "\n", " Monthly family income Family monthly poverty level index \\\n", "8 12.0 5.00 \n", "7475 3.0 1.08 \n", "2449 1.0 0.00 \n", "2912 6.0 1.32 \n", "6895 5.0 0.86 \n", "3899 7.0 2.71 \n", "1475 4.0 1.65 \n", "690 8.0 3.05 \n", "5077 2.0 0.44 \n", "132 11.0 4.30 \n", "\n", " Family monthly poverty level category \\\n", "8 3.0 \n", "7475 1.0 \n", "2449 1.0 
\n", "2912 2.0 \n", "6895 1.0 \n", "3899 3.0 \n", "1475 2.0 \n", "690 3.0 \n", "5077 1.0 \n", "132 3.0 \n", "\n", " Family has savings more than $5000 \\\n", "8 NaN \n", "7475 2.0 \n", "2449 2.0 \n", "2912 1.0 \n", "6895 2.0 \n", "3899 NaN \n", "1475 2.0 \n", "690 NaN \n", "5077 2.0 \n", "132 NaN \n", "\n", " Total savings/cash assets for the family Weights of Prototypes \n", "8 NaN 0.18 \n", "7475 1.0 0.10 \n", "2449 1.0 0.15 \n", "2912 NaN 0.07 \n", "6895 1.0 0.09 \n", "3899 NaN 0.12 \n", "1475 3.0 0.06 \n", "690 NaN 0.09 \n", "5077 1.0 0.07 \n", "132 NaN 0.07 " ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Display the prototypes along with their computed weights\n", "inc_prototypes = df_inc.iloc[S, :].copy()\n", "# Compute normalized importance weights for prototypes\n", "inc_prototypes[\"Weights of Prototypes\"] = np.around(W/np.sum(W), 2) \n", "inc_prototypes" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Explanation:\n", "The 10 people shown above (i.e. the 10 prototypes) are representative of the income questionnaire according to Protodash. First, the distribution plot for the family finance questions showed roughly 5 times as many people without savings in excess of $5000 as with such savings. Our prototypes have a similar spread, which is reassuring. For monthly family income, we also get a more even spread over the more commonly occurring categories. This serves as a spot check that our prototypes actually match the distribution of values in the dataset.\n", "\n", "Looking at the other questions in the questionnaire and the corresponding answers given by the prototypical people above, the social scientist realizes that most people are employed (3rd question) and work for an organization, earning through salary/wages (1st two questions). Most of them are also young (5th question) and fit to work (4th question). 
However, they don't seem to have much savings (last question). The social scientist could also convey these insights to the government authorities who shape future public policy decisions." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Summarize Gaussian (simulated) data using prototypes" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(300, 100) (4000, 100)\n" ] } ], "source": [ "# generate normalized gaussian data X, Y with 100 features and 300 & 4000 observations respectively\n", "(X, Y) = get_Gaussian_Data(100, 300, 4000)\n", "print(X.shape, Y.shape)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "(W, S, setValues) = explainer.explain(X, Y, m=5, kernelType='Gaussian', sigma=2)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[3940 2539 2168 2189 1170] [0.20611992 0.24524021 0.19131892 0.17175183 0.16123833]\n" ] } ], "source": [ "print(S, W)" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "#Print prototypes\n", "#print(Y[S, :]) " ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.15" } }, "nbformat": 4, "nbformat_minor": 2 }