{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Feature Selection for Uplift Modeling\n", " \n", " \n", "This notebook includes two sections: \n", "- **Feature selection**: demonstrate how to use Filter methods to select the most important numeric features\n", "- **Performance evaluation**: evaluate AUUC performance using only the top selected features\n", " \n", "*(Paper reference: [Zhao, Zhenyu, et al. \"Feature Selection Methods for Uplift Modeling.\" arXiv preprint arXiv:2005.03447 (2020).](https://arxiv.org/abs/2005.03447))*" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "The sklearn.utils.testing module is deprecated in version 0.22 and will be removed in version 0.24. The corresponding classes / functions should instead be imported from sklearn.utils. 
Anything that cannot be imported from sklearn.utils is now part of the private API.\n" ] } ], "source": [ "from causalml.dataset import make_uplift_classification" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Import FilterSelect class for Filter methods" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "from causalml.feature_selection.filters import FilterSelect" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "from causalml.inference.tree import UpliftRandomForestClassifier\n", "from causalml.inference.meta import BaseXRegressor, BaseRRegressor, BaseSRegressor, BaseTRegressor\n", "from causalml.metrics import plot_gain, auuc_score" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "from sklearn.model_selection import train_test_split\n", "from sklearn.ensemble import RandomForestRegressor" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "import logging\n", "\n", "logger = logging.getLogger('causalml')\n", "logging.basicConfig(level=logging.INFO)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Generate dataset\n", "\n", "Generate synthetic data using causalml's built-in `make_uplift_classification` function." 
] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# define parameters for simulation\n", "\n", "y_name = 'conversion'\n", "treatment_group_keys = ['control', 'treatment1']\n", "n = 100000\n", "n_classification_features = 50\n", "n_classification_informative = 10\n", "n_classification_repeated = 0\n", "n_uplift_increase_dict = {'treatment1': 8}\n", "n_uplift_decrease_dict = {'treatment1': 4}\n", "delta_uplift_increase_dict = {'treatment1': 0.1}\n", "delta_uplift_decrease_dict = {'treatment1': -0.1}\n", "\n", "random_seed = 20200808" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "df, X_names = make_uplift_classification(\n", " treatment_name=treatment_group_keys,\n", " y_name=y_name,\n", " n_samples=n,\n", " n_classification_features=n_classification_features,\n", " n_classification_informative=n_classification_informative,\n", " n_classification_repeated=n_classification_repeated,\n", " n_uplift_increase_dict=n_uplift_increase_dict,\n", " n_uplift_decrease_dict=n_uplift_decrease_dict,\n", " delta_uplift_increase_dict = delta_uplift_increase_dict, \n", " delta_uplift_decrease_dict = delta_uplift_decrease_dict,\n", " random_seed=random_seed\n", ")" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
| \n", " | treatment_group_key | \n", "x1_informative | \n", "x2_informative | \n", "x3_informative | \n", "x4_informative | \n", "x5_informative | \n", "x6_informative | \n", "x7_informative | \n", "x8_informative | \n", "x9_informative | \n", "... | \n", "x56_uplift_increase | \n", "x57_uplift_increase | \n", "x58_uplift_increase | \n", "x59_increase_mix | \n", "x60_uplift_decrease | \n", "x61_uplift_decrease | \n", "x62_uplift_decrease | \n", "x63_uplift_decrease | \n", "conversion | \n", "treatment_effect | \n", "
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | \n", "control | \n", "0.653960 | \n", "-0.217603 | \n", "1.856916 | \n", "-0.075662 | \n", "0.080971 | \n", "-0.338374 | \n", "-1.011470 | \n", "0.528000 | \n", "0.115418 | \n", "... | \n", "1.533832 | \n", "-2.183001 | \n", "1.839608 | \n", "0.755302 | \n", "1.835047 | \n", "-0.458431 | \n", "-1.927525 | \n", "2.765331 | \n", "0 | \n", "0 | \n", "
| 1 | \n", "control | \n", "3.439658 | \n", "0.477855 | \n", "-0.377658 | \n", "-1.317121 | \n", "0.861815 | \n", "-0.393180 | \n", "0.503727 | \n", "2.323846 | \n", "1.229948 | \n", "... | \n", "-1.192333 | \n", "-1.581815 | \n", "2.423700 | \n", "2.396904 | \n", "0.296043 | \n", "-1.961940 | \n", "-1.444725 | \n", "1.469213 | \n", "1 | \n", "0 | \n", "
| 2 | \n", "treatment1 | \n", "0.130907 | \n", "-0.333536 | \n", "0.474847 | \n", "-0.352067 | \n", "-0.024502 | \n", "1.437105 | \n", "0.566178 | \n", "-0.232508 | \n", "0.866236 | \n", "... | \n", "-0.301982 | \n", "-0.933816 | \n", "0.475274 | \n", "1.540994 | \n", "0.698066 | \n", "0.545091 | \n", "-0.084405 | \n", "-2.337347 | \n", "1 | \n", "0 | \n", "
| 3 | \n", "treatment1 | \n", "-2.156683 | \n", "1.120198 | \n", "0.174293 | \n", "-1.741426 | \n", "0.488993 | \n", "0.638340 | \n", "-0.721928 | \n", "1.802134 | \n", "1.097178 | \n", "... | \n", "-2.129098 | \n", "-1.183581 | \n", "0.000318 | \n", "1.105735 | \n", "-0.629281 | \n", "-0.737041 | \n", "-1.525081 | \n", "1.416042 | \n", "0 | \n", "0 | \n", "
| 4 | \n", "control | \n", "-2.708572 | \n", "-0.799698 | \n", "-2.199595 | \n", "0.574077 | \n", "0.083142 | \n", "-0.389140 | \n", "1.492101 | \n", "1.725202 | \n", "1.194315 | \n", "... | \n", "1.582041 | \n", "-1.176077 | \n", "1.686322 | \n", "0.480035 | \n", "1.780710 | \n", "0.862094 | \n", "0.128872 | \n", "-2.851344 | \n", "0 | \n", "0 | \n", "
5 rows × 66 columns
\n", "| \n", " | mean | \n", "size | \n", "
|---|---|---|
| \n", " | conversion | \n", "conversion | \n", "
| treatment_group_key | \n", "\n", " | \n", " |
| control | \n", "0.499050 | \n", "100000 | \n", "
| treatment1 | \n", "0.599680 | \n", "100000 | \n", "
| All | \n", "0.549365 | \n", "200000 | \n", "