{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Introduction\n", "\n", "CausalML is a Python package that provides a suite of uplift modeling and causal inference methods using machine learning algorithms based on recent research. The package currently supports the following methods:\n", "\n", "Tree-based algorithms\n", "* Uplift tree/random forests on KL divergence, Euclidean Distance, and Chi-Square\n", "* Uplift tree/random forests on Contextual Treatment Selection\n", "\n", "Meta-learner algorithms\n", "* S-learner\n", "* T-learner\n", "* X-learner\n", "* R-learner\n", "\n", "In this notebook, we use synthetic data to demonstrate the use of the tree-based algorithms." ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "ExecuteTime": { "end_time": "2019-12-20T17:20:10.960638Z", "start_time": "2019-12-20T17:20:10.957253Z" } }, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "\n", "from causalml.dataset import make_uplift_classification\n", "from causalml.inference.tree import UpliftRandomForestClassifier\n", "from causalml.metrics import plot_gain\n", "\n", "from sklearn.model_selection import train_test_split" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Generate synthetic dataset\n", "\n", "The CausalML package contains various functions to generate synthetic datasets for uplift modeling. Here we generate a classification dataset using the make_uplift_classification() function." ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "ExecuteTime": { "end_time": "2019-12-20T17:20:10.998192Z", "start_time": "2019-12-20T17:20:10.962810Z" } }, "outputs": [], "source": [ "df, x_names = make_uplift_classification()" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "ExecuteTime": { "end_time": "2019-12-20T17:20:11.024259Z", "start_time": "2019-12-20T17:20:11.000415Z" } }, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " treatment_group_key x1_informative x2_informative x3_informative \\\n", "0 control -0.542888 1.976361 -0.531359 \n", "1 treatment3 0.258654 0.552412 1.434239 \n", "2 treatment1 1.697012 -2.762600 -0.662874 \n", "3 treatment2 -1.441644 1.823648 0.789423 \n", "4 control -0.625074 3.002388 -0.096288 \n", "\n", " x4_informative x5_informative x6_irrelevant x7_irrelevant \\\n", "0 -2.354211 -0.380629 -2.614321 -0.128893 \n", "1 -1.422311 0.089131 0.790293 1.159513 \n", "2 -1.682340 1.217443 0.837982 1.042981 \n", "3 -0.295398 0.718509 -0.492993 0.947824 \n", "4 1.938235 3.392424 -0.465860 -0.919897 \n", "\n", " x8_irrelevant x9_irrelevant ... x11_uplift_increase \\\n", "0 0.448689 -2.275192 ... 0.656869 \n", "1 1.578868 0.166540 ... 1.050526 \n", "2 0.177398 -0.112409 ... 1.072329 \n", "3 -1.307887 0.123340 ... 1.398966 \n", "4 -1.072592 -1.331181 ... 1.398327 \n", "\n", " x12_uplift_increase x13_increase_mix x14_uplift_increase \\\n", "0 -1.315304 0.742654 1.891699 \n", "1 -1.391878 -0.623243 2.443972 \n", "2 -1.132497 1.050179 1.573054 \n", "3 -2.084619 0.058481 1.369439 \n", "4 -1.403984 0.760430 1.917635 \n", "\n", " x15_uplift_increase x16_increase_mix x17_uplift_increase \\\n", "0 -2.428395 1.541875 -0.817705 \n", "1 -2.889253 2.018585 -1.109296 \n", "2 -1.788427 1.341609 -0.749227 \n", "3 0.422538 1.087176 -0.966666 \n", "4 -2.347675 1.560946 -0.833067 \n", "\n", " x18_uplift_increase x19_increase_mix conversion \n", "0 -0.610194 -0.591581 0 \n", "1 -0.380362 -1.667606 0 \n", "2 -2.091521 -0.471386 0 \n", "3 -1.785592 -1.268379 1 \n", "4 -1.407884 -0.781343 0 \n", "\n", "[5 rows x 21 columns]" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.head()" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "ExecuteTime": { "end_time": "2019-12-20T17:20:11.069258Z", "start_time": "2019-12-20T17:20:11.026453Z" } }, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " mean size\n", " conversion conversion\n", "treatment_group_key \n", "control 0.511 1000\n", "treatment1 0.514 1000\n", "treatment2 0.559 1000\n", "treatment3 0.600 1000\n", "All 0.546 4000" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Look at the conversion rate and sample size in each group\n", "df.pivot_table(values='conversion',\n", " index='treatment_group_key',\n", " aggfunc=[np.mean, np.size],\n", " margins=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Run the uplift random forest classifier\n", "\n", "In this section, we first fit the uplift random forest classifier using training data. We then use the fitted model to make a prediction using testing data. The prediction returns an ndarray in which each column contains the predicted uplift if the unit was in the corresponding treatment group." ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "ExecuteTime": { "end_time": "2019-12-20T17:20:11.076379Z", "start_time": "2019-12-20T17:20:11.071002Z" } }, "outputs": [], "source": [ "# Split data to training and testing samples for model validation (next section)\n", "df_train, df_test = train_test_split(df, test_size=0.2, random_state=111)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "ExecuteTime": { "end_time": "2019-12-20T17:20:11.081266Z", "start_time": "2019-12-20T17:20:11.078310Z" } }, "outputs": [], "source": [ "uplift_model = UpliftRandomForestClassifier(control_name='control')" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "ExecuteTime": { "end_time": "2019-12-20T17:20:18.741946Z", "start_time": "2019-12-20T17:20:11.083854Z" } }, "outputs": [], "source": [ "uplift_model.fit(df_train[x_names].values,\n", " treatment=df_train['treatment_group_key'].values,\n", " y=df_train['conversion'].values)" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "ExecuteTime": { "end_time": 
"2019-12-20T17:20:18.887350Z", "start_time": "2019-12-20T17:20:18.743659Z" } }, "outputs": [], "source": [ "y_pred = uplift_model.predict(df_test[x_names].values)" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "ExecuteTime": { "end_time": "2019-12-20T17:20:18.892035Z", "start_time": "2019-12-20T17:20:18.888977Z" } }, "outputs": [], "source": [ "# Put the predictions into a DataFrame for a neater presentation\n", "result = pd.DataFrame(y_pred,\n", " columns=uplift_model.classes_)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Create the uplift curve\n", "\n", "The performance of the model can be evaluated with the help of the [uplift curve](http://proceedings.mlr.press/v67/gutierrez17a/gutierrez17a.pdf)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create a synthetic population\n", "\n", "The uplift curve is calculated on a synthetic population consisting of the units that were in the control group and the units that happened to be in the treatment group recommended by the model. We use the synthetic population to calculate the _actual_ treatment effect within _predicted_ treatment effect quantiles. Because treatment assignment in the data is randomized, the predicted quantiles contain roughly equal numbers of treatment and control observations, and there is no self-selection into treatment groups." 
] }, { "cell_type": "code", "execution_count": 23, "metadata": { "ExecuteTime": { "end_time": "2019-12-20T17:20:18.902754Z", "start_time": "2019-12-20T17:20:18.894201Z" } }, "outputs": [], "source": [ "# If all deltas are negative, assign to control; otherwise assign to the treatment\n", "# with the highest delta\n", "best_treatment = np.where((result < 0).all(axis=1),\n", " 'control',\n", " result.idxmax(axis=1))\n", "\n", "# Create indicator variables for whether a unit happened to have the\n", "# recommended treatment or was in the control group\n", "actual_is_best = np.where(df_test['treatment_group_key'] == best_treatment, 1, 0)\n", "actual_is_control = np.where(df_test['treatment_group_key'] == 'control', 1, 0)" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "ExecuteTime": { "end_time": "2019-12-20T17:20:18.907817Z", "start_time": "2019-12-20T17:20:18.904664Z" } }, "outputs": [], "source": [ "# Keep only the units that received the recommended treatment or were in the control group\n", "synthetic = (actual_is_best == 1) | (actual_is_control == 1)\n", "synth = result[synthetic]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Calculate the observed treatment effect per predicted treatment effect quantile\n", "\n", "We use the observed treatment effect to calculate the uplift curve, which answers the question: how much of the total cumulative uplift could we have captured by targeting a subset of the population sorted according to the predicted uplift, from highest to lowest?\n", "\n", "CausalML provides the plot_gain() function, which calculates the uplift curve given a DataFrame containing the treatment assignment, the observed outcome, and the predicted treatment effect." 
] }, { "cell_type": "code", "execution_count": 25, "metadata": { "ExecuteTime": { "end_time": "2019-12-20T17:20:18.916662Z", "start_time": "2019-12-20T17:20:18.909528Z" } }, "outputs": [], "source": [ "auuc_metrics = (synth.assign(is_treated = 1 - actual_is_control[synthetic],\n", " conversion = df_test.loc[synthetic, 'conversion'].values,\n", " uplift_tree = synth.max(axis=1))\n", " .drop(columns=list(uplift_model.classes_)))" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "ExecuteTime": { "end_time": "2019-12-20T17:20:19.216604Z", "start_time": "2019-12-20T17:20:18.918481Z" } }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plot_gain(auuc_metrics, outcome_col='conversion', treatment_col='is_treated')" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.0" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 2 }