{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%reload_ext autoreload\n", "%autoreload 2\n", "%matplotlib inline " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Causal Inference in `ktrain`\n", "\n", "## What is the causal impact of having a PhD on making over 50K/year?\n", "\n", "As of v0.27.x, ktrain supports causal inference using [meta-learners](https://arxiv.org/abs/1706.03461). We will use the well-studied [Adult Census](https://raw.githubusercontent.com/amaiya/ktrain/master/tests/resources/tabular_data/adults.csv) dataset from the UCI ML repository, which was extracted from the 1994 U.S. Census database. The objective is to estimate how much earning a PhD increases the probability of making over $50K in salary. (Note that this dataset serves only as a simple demonstration of estimation. In a real-world scenario, you would spend more time identifying which variables to control for and which to leave out.) Unlike conventional supervised machine learning, there is typically no ground truth for causal inference models unless you are using a simulated dataset. So, for illustration purposes, we will simply check whether our estimates agree with intuition and then inspect their robustness.\n", "\n", "Let's begin by loading the dataset and creating a binary treatment (1 for PhD and 0 for no PhD)." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--2021-07-20 14:17:32-- https://raw.githubusercontent.com/amaiya/ktrain/master/tests/resources/tabular_data/adults.csv\n", "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.111.133, 185.199.110.133, ...\n", "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.\n", "HTTP request sent, awaiting response... 
200 OK\n", "Length: 4573758 (4.4M) [text/plain]\n", "Saving to: ‘/tmp/adults.csv’\n", "\n", "/tmp/adults.csv 100%[===================>] 4.36M 26.3MB/s in 0.2s \n", "\n", "2021-07-20 14:17:32 (26.3 MB/s) - ‘/tmp/adults.csv’ saved [4573758/4573758]\n", "\n" ] } ], "source": [ "!wget https://raw.githubusercontent.com/amaiya/ktrain/master/tests/resources/tabular_data/adults.csv -O /tmp/adults.csv" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\"><th></th><th>age</th><th>workclass</th><th>fnlwgt</th><th>education</th><th>education-num</th><th>marital-status</th><th>occupation</th><th>relationship</th><th>race</th><th>sex</th><th>capital-gain</th><th>capital-loss</th><th>hours-per-week</th><th>native-country</th><th>class</th><th>treatment</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>0</th><td>25</td><td>Private</td><td>178478</td><td>Bachelors</td><td>13</td><td>Never-married</td><td>Tech-support</td><td>Own-child</td><td>White</td><td>Female</td><td>0</td><td>0</td><td>40</td><td>United-States</td><td>&lt;=50K</td><td>0</td></tr>\n",
"    <tr><th>1</th><td>23</td><td>State-gov</td><td>61743</td><td>5th-6th</td><td>3</td><td>Never-married</td><td>Transport-moving</td><td>Not-in-family</td><td>White</td><td>Male</td><td>0</td><td>0</td><td>35</td><td>United-States</td><td>&lt;=50K</td><td>0</td></tr>\n",
"    <tr><th>2</th><td>46</td><td>Private</td><td>376789</td><td>HS-grad</td><td>9</td><td>Never-married</td><td>Other-service</td><td>Not-in-family</td><td>White</td><td>Male</td><td>0</td><td>0</td><td>15</td><td>United-States</td><td>&lt;=50K</td><td>0</td></tr>\n",
"    <tr><th>3</th><td>55</td><td>?</td><td>200235</td><td>HS-grad</td><td>9</td><td>Married-civ-spouse</td><td>?</td><td>Husband</td><td>White</td><td>Male</td><td>0</td><td>0</td><td>50</td><td>United-States</td><td>&gt;50K</td><td>0</td></tr>\n",
"    <tr><th>4</th><td>36</td><td>Private</td><td>224541</td><td>7th-8th</td><td>4</td><td>Married-civ-spouse</td><td>Handlers-cleaners</td><td>Husband</td><td>White</td><td>Male</td><td>0</td><td>0</td><td>40</td><td>El-Salvador</td><td>&lt;=50K</td><td>0</td></tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>" ], "text/plain": [ " age workclass fnlwgt education education-num marital-status \\\n", "0 25 Private 178478 Bachelors 13 Never-married \n", "1 23 State-gov 61743 5th-6th 3 Never-married \n", "2 46 Private 376789 HS-grad 9 Never-married \n", "3 55 ? 200235 HS-grad 9 Married-civ-spouse \n", "4 36 Private 224541 7th-8th 4 Married-civ-spouse \n", "\n", " occupation relationship race sex capital-gain \\\n", "0 Tech-support Own-child White Female 0 \n", "1 Transport-moving Not-in-family White Male 0 \n", "2 Other-service Not-in-family White Male 0 \n", "3 ? Husband White Male 0 \n", "4 Handlers-cleaners Husband White Male 0 \n", "\n", " capital-loss hours-per-week native-country class treatment \n", "0 0 40 United-States <=50K 0 \n", "1 0 35 United-States <=50K 0 \n", "2 0 15 United-States <=50K 0 \n", "3 0 50 United-States >50K 0 \n", "4 0 40 El-Salvador <=50K 0 " ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import pandas as pd\n", "df = pd.read_csv('/tmp/adults.csv')\n", "# strip stray whitespace from column names and string values\n", "df = df.rename(columns=lambda x: x.strip())\n", "df = df.applymap(lambda x: x.strip() if isinstance(x, str) else x) \n", "# use a set so that `in` tests membership rather than substring containment\n", "filter_set = {'Doctorate'}\n", "df['treatment'] = df['education'].apply(lambda x: 1 if x in filter_set else 0)\n", "df.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, let's call the `causal_inference_model` function to create a `CausalInferenceModel` instance and then call `fit` to estimate the individualized treatment effect for each row in this dataset. By default, a [T-Learner](https://arxiv.org/abs/1706.03461) metalearner is used with LightGBM models as base learners. These can be adjusted using the `method` and `learner` parameters. Since this example is for illustration only, we will ignore the `fnlwgt` column, which represents the number of people the census believes the entry represents. 
In practice, one might incorporate domain knowledge when choosing which variables to include and which to ignore. For instance, variables thought to be common effects of both the treatment and the outcome might be excluded as [colliders](https://en.wikipedia.org/wiki/Collider_(statistics)). Finally, we will also exclude the education-related columns, as they are already captured in the treatment. " ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "replaced ['<=50K', '>50K'] in column \"class\" with [0, 1]\n", "outcome column (categorical): class\n", "treatment column: treatment\n", "numerical/categorical covariates: ['age', 'workclass', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'capital-gain', 'capital-loss', 'hours-per-week', 'native-country']\n", "preprocess time: 0.5897183418273926 sec\n", "start fitting causal inference model\n", "time to fit causal inference model: 0.9125957489013672 sec\n" ] } ], "source": [ "from ktrain.tabular.causalinference import causal_inference_model\n", "cm = causal_inference_model(df,\n", " treatment_col='treatment', \n", " outcome_col='class',\n", " ignore_cols=['fnlwgt', 'education','education-num']).fit()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As shown above, the dataset is automatically preprocessed and the model is fitted, with each step taking under a second.\n", "\n", "### Average Treatment Effect (ATE)\n", "\n", "The estimated average treatment effect over all examples is 0.20. That is, having a PhD is estimated to increase the probability of making over $50K by 20 percentage points."
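] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To make the quantity being estimated concrete, here is a sketch in potential-outcomes notation. The notation is introduced here for illustration and is not defined elsewhere in this notebook: $Y(1)$ and $Y(0)$ denote the outcome with and without the treatment, and $\\hat{\\mu}_1$, $\\hat{\\mu}_0$ denote outcome models fitted separately on the treated and untreated rows, which is the general T-Learner recipe.\n", "\n", "$$\\text{ATE} = \\mathbb{E}[Y(1) - Y(0)], \\qquad \\hat{\\tau}(x) = \\hat{\\mu}_1(x) - \\hat{\\mu}_0(x)$$\n", "\n", "With a binary outcome such as this one, $\\hat{\\mu}_1(x)$ and $\\hat{\\mu}_0(x)$ can be read as predicted probabilities, which is why the effect is interpreted in percentage points."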
] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'ate': 0.20340645077516034}" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "cm.estimate_ate()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Conditional Average Treatment Effects (CATE)\n", "\n", "We can also estimate treatment effects after conditioning on attributes. \n", "\n", "For those with Master's degrees, we find that the effect is lower than for the overall population, as expected, but still positive (which is qualitatively [consistent with studies by the Census Bureau](https://www.wes.org/advisor-blog/salary-difference-masters-phd)):" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'ate': 0.17672418257642838}" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "cm.estimate_ate(cm.df['education'] == 'Masters')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For those who dropped out of school, we find that the effect is higher (also as expected):" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'ate': 0.2586697863578173}" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "cm.estimate_ate(cm.df['education'].isin(['Preschool', '1st-4th', '5th-6th', '7th-8th', '9th', '10th', '12th']))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Individualized Treatment Effects (ITE)\n", "\n", "The CATEs above illustrate how causal effects vary across different subpopulations in the dataset. In fact, `CausalInferenceModel.df` stores a DataFrame representation of your dataset that has been augmented with a column called `treatment_effect` that stores the **individualized** treatment effect for each row in your dataset."
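] }, { "cell_type": "markdown", "metadata": {}, "source": [ "One way to relate these quantities is sketched below. It rests on an assumption that this notebook does not state explicitly, namely that the estimate for a subgroup is the mean of the per-row estimates. Writing $\\hat{\\tau}(x_i)$ for the `treatment_effect` value of row $i$ and $S$ for a set of rows (both symbols are introduced here for illustration):\n", "\n", "$$\\widehat{\\text{CATE}}(S) = \\frac{1}{|S|} \\sum_{i \\in S} \\hat{\\tau}(x_i)$$\n", "\n", "Under that assumption, taking $S$ to be all rows gives the ATE estimate, and taking $S$ to be a single row gives that row's ITE."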
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For instance, these individuals are predicted to benefit the most from a PhD, with an estimated increase of nearly 100 percentage points in the probability of making over 50K (see the **treatment_effect** column). " ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\"><th></th><th>age</th><th>workclass</th><th>education</th><th>marital-status</th><th>occupation</th><th>relationship</th><th>race</th><th>sex</th><th>hours-per-week</th><th>native-country</th><th>class</th><th>treatment</th><th>treatment_effect</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>19283</th><td>40</td><td>Private</td><td>HS-grad</td><td>Never-married</td><td>Adm-clerical</td><td>Not-in-family</td><td>White</td><td>Female</td><td>38</td><td>United-States</td><td>0</td><td>0</td><td>0.991928</td></tr>\n",
"    <tr><th>16500</th><td>35</td><td>Private</td><td>Assoc-voc</td><td>Divorced</td><td>Adm-clerical</td><td>Not-in-family</td><td>White</td><td>Female</td><td>40</td><td>United-States</td><td>0</td><td>0</td><td>0.991656</td></tr>\n",
"    <tr><th>30597</th><td>72</td><td>Private</td><td>Assoc-voc</td><td>Separated</td><td>Other-service</td><td>Unmarried</td><td>White</td><td>Female</td><td>25</td><td>United-States</td><td>0</td><td>0</td><td>0.991625</td></tr>\n",
"    <tr><th>9888</th><td>27</td><td>Private</td><td>HS-grad</td><td>Divorced</td><td>Machine-op-inspct</td><td>Not-in-family</td><td>White</td><td>Male</td><td>40</td><td>United-States</td><td>0</td><td>0</td><td>0.989816</td></tr>\n",
"    <tr><th>29341</th><td>39</td><td>Private</td><td>HS-grad</td><td>Divorced</td><td>Other-service</td><td>Unmarried</td><td>Amer-Indian-Eskimo</td><td>Female</td><td>40</td><td>United-States</td><td>0</td><td>0</td><td>0.989737</td></tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>" ], "text/plain": [ " age workclass education marital-status occupation \\\n", "19283 40 Private HS-grad Never-married Adm-clerical \n", "16500 35 Private Assoc-voc Divorced Adm-clerical \n", "30597 72 Private Assoc-voc Separated Other-service \n", "9888 27 Private HS-grad Divorced Machine-op-inspct \n", "29341 39 Private HS-grad Divorced Other-service \n", "\n", " relationship race sex hours-per-week \\\n", "19283 Not-in-family White Female 38 \n", "16500 Not-in-family White Female 40 \n", "30597 Unmarried White Female 25 \n", "9888 Not-in-family White Male 40 \n", "29341 Unmarried Amer-Indian-Eskimo Female 40 \n", "\n", " native-country class treatment treatment_effect \n", "19283 United-States 0 0 0.991928 \n", "16500 United-States 0 0 0.991656 \n", "30597 United-States 0 0 0.991625 \n", "9888 United-States 0 0 0.989816 \n", "29341 United-States 0 0 0.989737 " ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "drop_cols = ['fnlwgt', 'education-num', 'capital-gain', 'capital-loss'] # omitted for readability\n", "cm.df.sort_values('treatment_effect', ascending=False).drop(drop_cols, axis=1).head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Examining how the treatment effect varies across units in the population can be useful for a variety of applications. [Uplift modeling](https://en.wikipedia.org/wiki/Uplift_modelling) is often used by companies for targeted promotions by identifying those individuals with the highest estimated treatment effects. Measuring the impact of such campaigns afterwards is yet another way to evaluate the model." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Making Predictions on New Examples\n", "\n", "Finally, we can predict treatment effects on new examples, as long as they are formatted like the original DataFrame. 
For instance, let's make a prediction for one of the rows we already examined:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\"><th></th><th>age</th><th>workclass</th><th>fnlwgt</th><th>education</th><th>education-num</th><th>marital-status</th><th>occupation</th><th>relationship</th><th>race</th><th>sex</th><th>capital-gain</th><th>capital-loss</th><th>hours-per-week</th><th>native-country</th><th>class</th><th>treatment</th><th>treatment_effect</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>19283</th><td>40</td><td>Private</td><td>207025</td><td>HS-grad</td><td>9</td><td>Never-married</td><td>Adm-clerical</td><td>Not-in-family</td><td>White</td><td>Female</td><td>6849</td><td>0</td><td>38</td><td>United-States</td><td>0</td><td>0</td><td>0.991928</td></tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>
" ], "text/plain": [ " age workclass fnlwgt education education-num marital-status \\\n", "19283 40 Private 207025 HS-grad 9 Never-married \n", "\n", " occupation relationship race sex capital-gain capital-loss \\\n", "19283 Adm-clerical Not-in-family White Female 6849 0 \n", "\n", " hours-per-week native-country class treatment treatment_effect \n", "19283 38 United-States 0 0 0.991928 " ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_example = cm.df.sort_values('treatment_effect', ascending=False).iloc[[0]]\n", "df_example" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[0.99192821]])" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "cm.predict(df_example)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Evaluating Robustness\n", "\n", "As mentioned above, there is no ground truth for this problem to validate our estimates. In the cells above, we simply inspected the estimates to see if they correspond to our intuition on what should happen. Another approach to validating causal estimates is to evaluate robustness to various data manipulations (i.e., sensitivity analysis). For instance, the Placebo Treatment test replaces the treatment with a random covariate. We see below that this causes our estimate to drop to near zero, which is expected and exactly what we want. Such tests might be used to compare different models." ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\"><th></th><th>Method</th><th>ATE</th><th>New ATE</th><th>New ATE LB</th><th>New ATE UB</th><th>Distance from Desired (should be near 0)</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>0</th><td>Placebo Treatment</td><td>0.203406</td><td>0.00164019</td><td>-0.00408386</td><td>0.00736424</td><td>0.00164019</td></tr>\n",
"    <tr><th>0</th><td>Random Cause</td><td>0.203406</td><td>0.230316</td><td>0.216585</td><td>0.244046</td><td>0.0269094</td></tr>\n",
"    <tr><th>0</th><td>Subset Data(sample size @0.8)</td><td>0.203406</td><td>0.194687</td><td>0.173465</td><td>0.215908</td><td>-0.00871967</td></tr>\n",
"    <tr><th>0</th><td>Random Replace</td><td>0.203406</td><td>0.214506</td><td>0.201208</td><td>0.227804</td><td>0.0110997</td></tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>
" ], "text/plain": [ " Method ATE New ATE New ATE LB \\\n", "0 Placebo Treatment 0.203406 0.00164019 -0.00408386 \n", "0 Random Cause 0.203406 0.230316 0.216585 \n", "0 Subset Data(sample size @0.8) 0.203406 0.194687 0.173465 \n", "0 Random Replace 0.203406 0.214506 0.201208 \n", "\n", " New ATE UB Distance from Desired (should be near 0) \n", "0 0.00736424 0.00164019 \n", "0 0.244046 0.0269094 \n", "0 0.215908 -0.00871967 \n", "0 0.227804 0.0110997 " ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "cm.evaluate_robustness()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**ktrain** uses the **CausalNLP** package for inferring causality. For more information, see the [CausalNLP documentation](https://amaiya.github.io/causalnlp)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" } }, "nbformat": 4, "nbformat_minor": 2 }