{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Introduction\n", "\n", "CausalML is a Python package that provides a suite of uplift modeling and causal inference methods using machine learning algorithms based on recent research. The package currently supports the following methods:\n", "\n", "Tree-based algorithms\n", "* Uplift trees/random forests based on KL divergence, Euclidean distance, and chi-square\n", "* Uplift trees/random forests based on Contextual Treatment Selection\n", "\n", "Meta-learner algorithms\n", "* S-learner\n", "* T-learner\n", "* X-learner\n", "* R-learner\n", "\n", "In this notebook, we use synthetic data to demonstrate the tree-based algorithms." ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "ExecuteTime": { "end_time": "2019-12-20T17:20:10.960638Z", "start_time": "2019-12-20T17:20:10.957253Z" } }, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "\n", "from causalml.dataset import make_uplift_classification\n", "from causalml.inference.tree import UpliftRandomForestClassifier\n", "from causalml.metrics import plot_gain\n", "\n", "from sklearn.model_selection import train_test_split" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Generate synthetic dataset\n", "\n", "The CausalML package contains several functions that generate synthetic datasets for uplift modeling. Here we generate a classification dataset with the `make_uplift_classification()` function." ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "ExecuteTime": { "end_time": "2019-12-20T17:20:10.998192Z", "start_time": "2019-12-20T17:20:10.962810Z" } }, "outputs": [], "source": [ "df, x_names = make_uplift_classification()" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "ExecuteTime": { "end_time": "2019-12-20T17:20:11.024259Z", "start_time": "2019-12-20T17:20:11.000415Z" } }, "outputs": [ { "data": { "text/html": [ "
| | treatment_group_key | x1_informative | x2_informative | x3_informative | x4_informative | x5_informative | x6_irrelevant | x7_irrelevant | x8_irrelevant | x9_irrelevant | ... | x11_uplift_increase | x12_uplift_increase | x13_increase_mix | x14_uplift_increase | x15_uplift_increase | x16_increase_mix | x17_uplift_increase | x18_uplift_increase | x19_increase_mix | conversion |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | control | -0.542888 | 1.976361 | -0.531359 | -2.354211 | -0.380629 | -2.614321 | -0.128893 | 0.448689 | -2.275192 | ... | 0.656869 | -1.315304 | 0.742654 | 1.891699 | -2.428395 | 1.541875 | -0.817705 | -0.610194 | -0.591581 | 0 |
| 1 | treatment3 | 0.258654 | 0.552412 | 1.434239 | -1.422311 | 0.089131 | 0.790293 | 1.159513 | 1.578868 | 0.166540 | ... | 1.050526 | -1.391878 | -0.623243 | 2.443972 | -2.889253 | 2.018585 | -1.109296 | -0.380362 | -1.667606 | 0 |
| 2 | treatment1 | 1.697012 | -2.762600 | -0.662874 | -1.682340 | 1.217443 | 0.837982 | 1.042981 | 0.177398 | -0.112409 | ... | 1.072329 | -1.132497 | 1.050179 | 1.573054 | -1.788427 | 1.341609 | -0.749227 | -2.091521 | -0.471386 | 0 |
| 3 | treatment2 | -1.441644 | 1.823648 | 0.789423 | -0.295398 | 0.718509 | -0.492993 | 0.947824 | -1.307887 | 0.123340 | ... | 1.398966 | -2.084619 | 0.058481 | 1.369439 | 0.422538 | 1.087176 | -0.966666 | -1.785592 | -1.268379 | 1 |
| 4 | control | -0.625074 | 3.002388 | -0.096288 | 1.938235 | 3.392424 | -0.465860 | -0.919897 | -1.072592 | -1.331181 | ... | 1.398327 | -1.403984 | 0.760430 | 1.917635 | -2.347675 | 1.560946 | -0.833067 | -1.407884 | -0.781343 | 0 |

5 rows × 21 columns
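The next output summarizes the data by `treatment_group_key`: the mean of `conversion` (the conversion rate), the number of rows in each group, and an overall `All` row. A minimal pandas sketch of that aggregation follows. It assumes a DataFrame with the same two columns and uses a small hypothetical stand-in instead of the output of `make_uplift_classification()`, so it runs without CausalML installed; `summarize_conversion` is a hypothetical helper, not a CausalML function. pandas' `DataFrame.pivot_table` with `margins=True` produces a similar layout with an `All` row.

```python
import pandas as pd


def summarize_conversion(df, group_col="treatment_group_key", outcome_col="conversion"):
    """Conversion rate and group size per treatment group, plus an overall row.

    Hypothetical helper for illustration; not part of the CausalML API.
    """
    # Per-group conversion rate (mean of a 0/1 outcome) and row count
    summary = df.groupby(group_col)[outcome_col].agg(["mean", "size"])
    # Overall row computed on the full data, like pivot_table's margins row
    overall = pd.DataFrame(
        {"mean": [df[outcome_col].mean()], "size": [len(df)]}, index=["All"]
    )
    return pd.concat([summary, overall])


# Small hypothetical stand-in for the DataFrame returned by make_uplift_classification()
toy = pd.DataFrame(
    {
        "treatment_group_key": [
            "control", "control", "treatment1", "treatment1", "treatment1", "treatment2",
        ],
        "conversion": [0, 1, 1, 1, 0, 1],
    }
)

print(summarize_conversion(toy))
```

Applied to the generated `df`, the same call would return one row per treatment group plus `All`, matching the layout of the summary table.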
| treatment_group_key | mean (conversion) | size (conversion) |
|---|---|---|
| control | 0.511 | 1000 |
| treatment1 | 0.514 | 1000 |
| treatment2 | 0.559 | 1000 |
| treatment3 | 0.600 | 1000 |
| All | 0.546 | 4000 |