{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "accelerator": "GPU", "colab": { "name": "uplift_metrics_tutorial_advanced.ipynb", "provenance": [], "collapsed_sections": [], "toc_visible": true }, "kernelspec": { "display_name": "sklift-env", "language": "python", "name": "sklift-env" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.1" }, "pycharm": { "stem_cell": { "cell_type": "raw", "metadata": { "collapsed": false }, "source": [] } }, "widgets": { "application/vnd.jupyter.widget-state+json": { "12a2acaf31694c63a2813f34d37300dd": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "model_module_version": "1.5.0", "state": { "_view_name": "HBoxView", "_dom_classes": [], "_model_name": "HBoxModel", "_view_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_view_count": null, "_view_module_version": "1.5.0", "box_style": "", "layout": "IPY_MODEL_7553ee7087e245e694f7f85837e088ec", "_model_module": "@jupyter-widgets/controls", "children": [ "IPY_MODEL_60d81df4c86240999f4be223efec1533", "IPY_MODEL_db7f043fbefb4dcda7d3486cffd28faf", "IPY_MODEL_f7418cb987f9437c9627dceea703ecd6" ] } }, "7553ee7087e245e694f7f85837e088ec": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_view_name": "LayoutView", "grid_template_rows": null, "right": null, "justify_content": null, "_view_module": "@jupyter-widgets/base", "overflow": null, "_model_module_version": "1.2.0", "_view_count": null, "flex_flow": null, "width": null, "min_width": null, "border": null, "align_items": null, "bottom": null, "_model_module": "@jupyter-widgets/base", "top": null, "grid_column": null, "overflow_y": null, "overflow_x": null, "grid_auto_flow": null, "grid_area": null, "grid_template_columns": null, "flex": null, 
"_model_name": "LayoutModel", "justify_items": null, "grid_row": null, "max_height": null, "align_content": null, "visibility": null, "align_self": null, "height": null, "min_height": null, "padding": null, "grid_auto_rows": null, "grid_gap": null, "max_width": null, "order": null, "_view_module_version": "1.2.0", "grid_template_areas": null, "object_position": null, "object_fit": null, "grid_auto_columns": null, "margin": null, "display": null, "left": null } }, "60d81df4c86240999f4be223efec1533": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "model_module_version": "1.5.0", "state": { "_view_name": "HTMLView", "style": "IPY_MODEL_9f9467cab1f2430fa89e2fc396ced9f8", "_dom_classes": [], "description": "", "_model_name": "HTMLModel", "placeholder": "​", "_view_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "value": "100%", "_view_count": null, "_view_module_version": "1.5.0", "description_tooltip": null, "_model_module": "@jupyter-widgets/controls", "layout": "IPY_MODEL_ecf054c1199c45389e7dbf94c5ac41df" } }, "db7f043fbefb4dcda7d3486cffd28faf": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "model_module_version": "1.5.0", "state": { "_view_name": "ProgressView", "style": "IPY_MODEL_f309aeffa2824070afc241d3477c6bcd", "_dom_classes": [], "description": "", "_model_name": "FloatProgressModel", "bar_style": "", "max": 144735744, "_view_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "value": 144735744, "_view_count": null, "_view_module_version": "1.5.0", "orientation": "horizontal", "min": 0, "description_tooltip": null, "_model_module": "@jupyter-widgets/controls", "layout": "IPY_MODEL_353821f9d47f4434b71b956c3946ad01" } }, "f7418cb987f9437c9627dceea703ecd6": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "model_module_version": "1.5.0", "state": { "_view_name": "HTMLView", "style": "IPY_MODEL_60a724915a0a4c9f86197e8b7066b513", 
"_dom_classes": [], "description": "", "_model_name": "HTMLModel", "placeholder": "​", "_view_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "value": " 145M/145M [00:20<00:00, 29.6MiB/s]", "_view_count": null, "_view_module_version": "1.5.0", "description_tooltip": null, "_model_module": "@jupyter-widgets/controls", "layout": "IPY_MODEL_16efe430cd6a4b608dd690598ece62cb" } }, "9f9467cab1f2430fa89e2fc396ced9f8": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", "state": { "_view_name": "StyleView", "_model_name": "DescriptionStyleModel", "description_width": "", "_view_module": "@jupyter-widgets/base", "_model_module_version": "1.5.0", "_view_count": null, "_view_module_version": "1.2.0", "_model_module": "@jupyter-widgets/controls" } }, "ecf054c1199c45389e7dbf94c5ac41df": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_view_name": "LayoutView", "grid_template_rows": null, "right": null, "justify_content": null, "_view_module": "@jupyter-widgets/base", "overflow": null, "_model_module_version": "1.2.0", "_view_count": null, "flex_flow": null, "width": null, "min_width": null, "border": null, "align_items": null, "bottom": null, "_model_module": "@jupyter-widgets/base", "top": null, "grid_column": null, "overflow_y": null, "overflow_x": null, "grid_auto_flow": null, "grid_area": null, "grid_template_columns": null, "flex": null, "_model_name": "LayoutModel", "justify_items": null, "grid_row": null, "max_height": null, "align_content": null, "visibility": null, "align_self": null, "height": null, "min_height": null, "padding": null, "grid_auto_rows": null, "grid_gap": null, "max_width": null, "order": null, "_view_module_version": "1.2.0", "grid_template_areas": null, "object_position": null, "object_fit": null, "grid_auto_columns": null, "margin": null, "display": null, "left": null } }, 
"f309aeffa2824070afc241d3477c6bcd": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "model_module_version": "1.5.0", "state": { "_view_name": "StyleView", "_model_name": "ProgressStyleModel", "description_width": "", "_view_module": "@jupyter-widgets/base", "_model_module_version": "1.5.0", "_view_count": null, "_view_module_version": "1.2.0", "bar_color": null, "_model_module": "@jupyter-widgets/controls" } }, "353821f9d47f4434b71b956c3946ad01": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_view_name": "LayoutView", "grid_template_rows": null, "right": null, "justify_content": null, "_view_module": "@jupyter-widgets/base", "overflow": null, "_model_module_version": "1.2.0", "_view_count": null, "flex_flow": null, "width": null, "min_width": null, "border": null, "align_items": null, "bottom": null, "_model_module": "@jupyter-widgets/base", "top": null, "grid_column": null, "overflow_y": null, "overflow_x": null, "grid_auto_flow": null, "grid_area": null, "grid_template_columns": null, "flex": null, "_model_name": "LayoutModel", "justify_items": null, "grid_row": null, "max_height": null, "align_content": null, "visibility": null, "align_self": null, "height": null, "min_height": null, "padding": null, "grid_auto_rows": null, "grid_gap": null, "max_width": null, "order": null, "_view_module_version": "1.2.0", "grid_template_areas": null, "object_position": null, "object_fit": null, "grid_auto_columns": null, "margin": null, "display": null, "left": null } }, "60a724915a0a4c9f86197e8b7066b513": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", "state": { "_view_name": "StyleView", "_model_name": "DescriptionStyleModel", "description_width": "", "_view_module": "@jupyter-widgets/base", "_model_module_version": "1.5.0", "_view_count": null, "_view_module_version": "1.2.0", "_model_module": 
"@jupyter-widgets/controls" } }, "16efe430cd6a4b608dd690598ece62cb": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_view_name": "LayoutView", "grid_template_rows": null, "right": null, "justify_content": null, "_view_module": "@jupyter-widgets/base", "overflow": null, "_model_module_version": "1.2.0", "_view_count": null, "flex_flow": null, "width": null, "min_width": null, "border": null, "align_items": null, "bottom": null, "_model_module": "@jupyter-widgets/base", "top": null, "grid_column": null, "overflow_y": null, "overflow_x": null, "grid_auto_flow": null, "grid_area": null, "grid_template_columns": null, "flex": null, "_model_name": "LayoutModel", "justify_items": null, "grid_row": null, "max_height": null, "align_content": null, "visibility": null, "align_self": null, "height": null, "min_height": null, "padding": null, "grid_auto_rows": null, "grid_gap": null, "max_width": null, "order": null, "_view_module_version": "1.2.0", "grid_template_areas": null, "object_position": null, "object_fit": null, "grid_auto_columns": null, "margin": null, "display": null, "left": null } } } } }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "ATqE_EOoymEd" }, "source": [ "# 🎯 Uplift modeling `metrics` advanced\n", "\n", "
\n", "
\n", " \n", " \n", " \n", "
\n", " SCIKIT-UPLIFT REPO | \n", " SCIKIT-UPLIFT DOCS | \n", " USER GUIDE\n", "
\n", "
" ] }, { "cell_type": "code", "metadata": { "id": "jmg3AprtymEg" }, "source": [ "import sys\n", "\n", "# install uplift library scikit-uplift and other libraries \n", "!{sys.executable} -m pip install scikit-uplift dill catboost\n", "from IPython.display import clear_output\n", "clear_output()" ], "execution_count": 1, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "cHFiqvnrymEh" }, "source": [ "# 📝 Load data\n", "\n", "We are going to use a `Lenta dataset` from the BigTarget Hackathon hosted in summer 2020 by Lenta and Microsoft.\n", "\n", "Lenta is a russian food retailer. \n", "\n", "### Data description\n", "\n", "✏️ Dataset can be loaded from `sklift.datasets` module using `fetch_lenta` function.\n", "\n", "Read more about dataset in the api docs. \n", "\n", "This is an uplift modeling dataset containing data about Lenta's customers grociery shopping, marketing campaigns communications as `treatment` and store visits as `target`.\n", "\n", "#### ✏️ Major columns:\n", "\n", "- `group` - treatment / control flag\n", "- `response_att` - binary target\n", "- `CardHolder` - customer id\n", "- `gender` - customer gender \n", "- `age` - customer age" ] }, { "cell_type": "code", "metadata": { "id": "5o0Hm-iqymEi", "scrolled": true, "colab": { "base_uri": "https://localhost:8080/", "height": 49, "referenced_widgets": [ "12a2acaf31694c63a2813f34d37300dd", "7553ee7087e245e694f7f85837e088ec", "60d81df4c86240999f4be223efec1533", "db7f043fbefb4dcda7d3486cffd28faf", "f7418cb987f9437c9627dceea703ecd6", "9f9467cab1f2430fa89e2fc396ced9f8", "ecf054c1199c45389e7dbf94c5ac41df", "f309aeffa2824070afc241d3477c6bcd", "353821f9d47f4434b71b956c3946ad01", "60a724915a0a4c9f86197e8b7066b513", "16efe430cd6a4b608dd690598ece62cb" ] }, "outputId": "233891ac-d09c-440f-8d54-50ea3ebfc3cd" }, "source": [ "from sklift.datasets import fetch_lenta\n", "\n", "# returns sklearn Bunch object\n", "# with data, target, treatment keys\n", "# data features (pd.DataFrame), target 
(pd.Series), treatment (pd.Series) values \n", "dataset = fetch_lenta()" ], "execution_count": 2, "outputs": [ { "output_type": "display_data", "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "12a2acaf31694c63a2813f34d37300dd", "version_minor": 0, "version_major": 2 }, "text/plain": [ " 0%| | 0.00/145M [00:00\n", "\n", "Dataset features shape: (687029, 193)\n", "Dataset target shape: (687029,)\n", "Dataset treatment shape: (687029,)\n" ], "name": "stdout" } ] }, { "cell_type": "markdown", "metadata": { "id": "F4-jlzDbymEk" }, "source": [ "# 📝 EDA" ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 406 }, "id": "t7kA0MxxymEk", "outputId": "d5d5637e-ce85-4fcb-eb8c-695733919b99" }, "source": [ "import pandas as pd\n", "\n", "# Preview the first and the last five rows of the feature table.\n", "# DataFrame.append was removed in pandas 2.0, so the two slices are joined with pd.concat,\n", "# which keeps the original row labels exactly as head().append(tail()) did.\n", "pd.concat([dataset.data.head(), dataset.data.tail()])" ], "execution_count": 4, "outputs": [ { "output_type": "execute_result", "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
agecheque_count_12m_g20cheque_count_12m_g21cheque_count_12m_g25cheque_count_12m_g32cheque_count_12m_g33cheque_count_12m_g38cheque_count_12m_g39cheque_count_12m_g41cheque_count_12m_g42cheque_count_12m_g45cheque_count_12m_g46cheque_count_12m_g48cheque_count_12m_g52cheque_count_12m_g56cheque_count_12m_g57cheque_count_12m_g58cheque_count_12m_g79cheque_count_3m_g20cheque_count_3m_g21cheque_count_3m_g25cheque_count_3m_g42cheque_count_3m_g45cheque_count_3m_g52cheque_count_3m_g56cheque_count_3m_g57cheque_count_3m_g79cheque_count_6m_g20cheque_count_6m_g21cheque_count_6m_g25cheque_count_6m_g32cheque_count_6m_g33cheque_count_6m_g38cheque_count_6m_g39cheque_count_6m_g40cheque_count_6m_g41cheque_count_6m_g42cheque_count_6m_g45cheque_count_6m_g46cheque_count_6m_g48...perdelta_days_between_visits_15_30dpromo_share_15dresponse_smsresponse_vibersale_count_12m_g32sale_count_12m_g33sale_count_12m_g49sale_count_12m_g54sale_count_12m_g57sale_count_3m_g24sale_count_3m_g33sale_count_3m_g57sale_count_6m_g24sale_count_6m_g25sale_count_6m_g32sale_count_6m_g33sale_count_6m_g44sale_count_6m_g54sale_count_6m_g57sale_sum_12m_g24sale_sum_12m_g25sale_sum_12m_g26sale_sum_12m_g27sale_sum_12m_g32sale_sum_12m_g44sale_sum_12m_g54sale_sum_3m_g24sale_sum_3m_g26sale_sum_3m_g32sale_sum_3m_g33sale_sum_6m_g24sale_sum_6m_g25sale_sum_6m_g26sale_sum_6m_g32sale_sum_6m_g33sale_sum_6m_g44sale_sum_6m_g54stdev_days_between_visits_15dstdev_discount_depth_15dstdev_discount_depth_1m
047.03.022.019.03.028.08.07.06.01.013.012.016.03.015.011.00.04.00.07.08.00.05.01.06.06.01.00.012.09.01.06.04.02.05.01.00.05.05.06.0...1.33930.58210.9230770.07142910.084.31498.016.011.0137.28228.7766.0169.65810.6807.028.77621.08.09.04469.86658.851286.327736.05418.803233.31811.732321.61182.82283.843648.233141.25356.67237.25283.843648.231195.37535.421.70780.27980.3008
157.01.00.02.01.01.01.00.01.00.01.00.01.00.00.00.00.01.00.00.02.00.01.00.00.00.01.01.00.02.01.01.01.00.03.01.00.01.00.01.0...0.00000.00001.0000000.0000001.01.0002.02.00.00.0001.0000.01.7442.0001.01.0000.02.00.0113.3962.6958.7193.3587.010.00122.980.0058.7187.01179.83113.3962.6958.7187.01179.830.00122.980.00000.00000.0000
238.07.00.015.04.09.05.09.014.07.06.010.014.05.011.00.03.02.02.00.03.02.01.01.00.00.02.06.00.09.02.05.01.07.07.08.03.02.06.06.0...0.00000.72561.0000000.2500005.021.10250.0109.00.00.0007.5940.025.29411.0843.011.15831.059.00.01564.91971.09177.933257.49975.212555.276351.290.000.000.00783.871239.19533.4683.37593.131217.431336.833709.820.0000NaN0.0803
365.06.03.025.02.010.014.011.08.01.00.02.06.07.02.00.00.00.01.00.05.00.00.01.00.00.00.02.01.011.02.03.05.05.04.02.01.00.01.03.0...0.00000.00000.9090910.0000002.012.54449.039.00.00.0002.7780.02.00034.2122.03.7782.013.00.0358.223798.18680.931425.07175.73602.813544.760.00119.9973.24346.74139.681849.91360.40175.73496.73172.581246.210.00000.00000.0000
461.00.01.02.00.02.01.00.03.02.01.01.05.05.00.00.00.01.00.01.01.00.00.02.00.00.01.00.01.02.00.02.01.00.08.02.02.01.01.04.0...0.00000.78651.0000000.1000000.01.45425.025.00.00.0000.4540.03.03612.0000.01.4548.023.00.0226.98168.05960.371560.210.00342.451039.850.0066.180.0087.94226.98168.05461.370.00237.93225.51995.271.41420.34950.3495
68702435.00.00.04.00.02.00.01.00.03.02.02.03.02.01.00.01.00.00.00.03.02.01.02.01.00.00.00.00.03.00.02.00.00.05.00.02.02.02.02.0...1.33330.40020.0000000.1666670.03.00014.02.00.019.8563.0000.019.85629.0000.03.00015.01.00.0550.09695.32111.87114.210.001173.84147.68550.09111.870.00330.96550.09669.33111.870.00330.961173.84119.992.64580.36460.3282
68702533.00.00.00.00.00.00.00.00.00.00.00.02.00.00.00.00.00.0NaNNaNNaNNaNNaNNaNNaNNaNNaN0.00.00.00.00.00.00.01.00.00.00.00.02.0...0.00000.00001.0000000.0000000.00.0001.01.00.0NaNNaNNaN0.0000.0000.00.0000.01.00.00.000.000.000.000.000.0028.01NaNNaNNaNNaN0.000.000.000.000.000.0028.010.00000.00000.0000
68702636.00.00.03.00.00.00.00.01.00.00.01.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.0...0.00000.98471.0000000.0000000.00.0005.03.00.00.0000.0000.00.0000.0000.00.00015.00.00.00.00155.9723.9941.510.00615.7787.470.000.000.000.000.000.000.000.000.00449.010.000.0000NaNNaN
68702737.00.01.02.00.00.00.00.00.01.00.01.00.01.00.00.00.00.00.00.01.00.00.01.00.00.00.00.00.01.00.00.00.00.01.00.00.00.00.00.0...0.00000.83181.0000000.0000000.00.0001.00.00.00.0000.0000.00.0000.4760.00.0000.00.00.00.0081.9029.820.000.000.000.000.000.000.000.000.0046.720.000.000.000.000.000.0000NaNNaN
68702840.00.01.00.00.02.00.00.02.02.02.02.03.01.01.02.01.04.00.01.00.01.00.00.01.01.03.00.01.00.00.01.00.00.00.00.01.00.02.02.0...0.00000.00001.0000000.1000000.06.45225.017.03.06.6601.3441.06.6600.0000.01.34418.04.01.0531.250.000.00916.440.002407.561304.03290.010.000.00228.47290.010.000.000.00228.47752.32596.860.00000.00000.0000
\n", "

10 rows × 193 columns

\n", "
" ], "text/plain": [ " age ... stdev_discount_depth_1m\n", "0 47.0 ... 0.3008\n", "1 57.0 ... 0.0000\n", "2 38.0 ... 0.0803\n", "3 65.0 ... 0.0000\n", "4 61.0 ... 0.3495\n", "687024 35.0 ... 0.3282\n", "687025 33.0 ... 0.0000\n", "687026 36.0 ... NaN\n", "687027 37.0 ... NaN\n", "687028 40.0 ... 0.0000\n", "\n", "[10 rows x 193 columns]" ] }, "metadata": {}, "execution_count": 4 } ] }, { "cell_type": "markdown", "metadata": { "id": "cNSQsJcqymEk" }, "source": [ "### 🤔 target share for `treatment / control` " ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 142 }, "id": "d0BPrhjnymEl", "outputId": "a57f8423-1dd4-42ed-f2d4-63b939077770" }, "source": [ "import pandas as pd \n", "\n", "pd.crosstab(dataset.treatment, dataset.target, normalize='index')" ], "execution_count": 5, "outputs": [ { "output_type": "execute_result", "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
response_att01
group
control0.8974210.102579
test0.8898740.110126
\n", "
" ], "text/plain": [ "response_att 0 1\n", "group \n", "control 0.897421 0.102579\n", "test 0.889874 0.110126" ] }, "metadata": {}, "execution_count": 5 } ] }, { "cell_type": "code", "metadata": { "id": "q48zr_exymEl" }, "source": [ "# make treatment binary\n", "treat_dict = {\n", " 'test': 1,\n", " 'control': 0\n", "}\n", "\n", "dataset.treatment = dataset.treatment.map(treat_dict)" ], "execution_count": 6, "outputs": [] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "9oGypq_JymEm", "outputId": "90909780-1bb9-4166-90a2-1f343803959e" }, "source": [ "# fill NaNs in the categorical feature `gender` \n", "# for CatBoostClassifier\n", "dataset.data['gender'] = dataset.data['gender'].fillna(value='Не определен')\n", "\n", "print(dataset.data['gender'].value_counts(dropna=False))" ], "execution_count": 7, "outputs": [ { "output_type": "stream", "text": [ "Ж 433448\n", "М 243910\n", "Не определен 9671\n", "Name: gender, dtype: int64\n" ], "name": "stdout" } ] }, { "cell_type": "markdown", "metadata": { "id": "MKHT1JzsymEm" }, "source": [ "### ✂️ train test split\n", "\n", "- stratify by two columns: treatment and target. \n", "\n", "`Intuition:` In a binary classification problem definition we stratify train set by splitting target `0/1` column. In uplift modeling we have two columns instead of one. 
" ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "u4lM58UMymEm", "outputId": "84cc2abf-b854-41a4-8bb8-bbddeb97baab" }, "source": [ "from sklearn.model_selection import train_test_split\n", "\n", "stratify_cols = pd.concat([dataset.treatment, dataset.target], axis=1)\n", "\n", "X_train, X_val, trmnt_train, trmnt_val, y_train, y_val = train_test_split(\n", " dataset.data,\n", " dataset.treatment,\n", " dataset.target,\n", " stratify=stratify_cols,\n", " test_size=0.3,\n", " random_state=42\n", ")\n", "\n", "print(f\"Train shape: {X_train.shape}\")\n", "print(f\"Validation shape: {X_val.shape}\")" ], "execution_count": 8, "outputs": [ { "output_type": "stream", "text": [ "Train shape: (480920, 193)\n", "Validation shape: (206109, 193)\n" ], "name": "stdout" } ] }, { "cell_type": "markdown", "metadata": { "id": "BYzpcKwuymEn" }, "source": [ "# 👾 Class Transformation uplift model and Two Models\n", "\n", "### For example, let's take the models [ Class Transformation ](https://github.com/maks-sh/scikit-uplift/blob/c9dd56aa0277e81ef7c4be62bf2fd33432e46f36/sklift/models/models.py#L181) and [Two Models](https://github.com/maks-sh/scikit-uplift/blob/c9dd56aa0277e81ef7c4be62bf2fd33432e46f36/sklift/models/models.py#L271). 
Let's display their uplift scores on one graph" ] }, { "cell_type": "code", "metadata": { "id": "PBwZVdIEymEn" }, "source": [ "from catboost import CatBoostClassifier\n", "from sklearn.base import clone\n", "\n", "from sklift.models import TwoModels\n", "from sklift.models import ClassTransformation\n", "\n", "first_estimator = CatBoostClassifier(verbose=100,\n", " task_type=\"GPU\",\n", " devices='0:1',\n", " cat_features=['gender'],\n", " random_state=42,\n", " thread_count=1)\n", "second_estimator = clone(first_estimator)\n", "\n", "# Give every uplift model its own estimator instance. The sklift wrappers call fit() on the\n", "# estimator objects they receive, so sharing `first_estimator` between both models would let\n", "# two_model.fit() overwrite the classifier that transform_model has already trained.\n", "transform_model = ClassTransformation(estimator=clone(first_estimator))\n", "two_model = TwoModels(estimator_trmnt=first_estimator, estimator_ctrl=second_estimator)" ], "execution_count": 9, "outputs": [] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "6IcMWundymEn", "outputId": "5f59657c-2a7d-47aa-cb21-89a900e47e23" }, "source": [ "transform_model = transform_model.fit(\n", " X=X_train, \n", " y=y_train, \n", " treatment=trmnt_train\n", ")\n", "\n", "two_model = two_model.fit(\n", " X=X_train, \n", " y=y_train, \n", " treatment=trmnt_train\n", ")" ], "execution_count": 10, "outputs": [ { "output_type": "stream", "text": [ "Learning rate set to 0.024003\n", "0:\tlearn: 0.6893849\ttotal: 59ms\tremaining: 58.9s\n", "100:\tlearn: 0.6100331\ttotal: 5.39s\tremaining: 48s\n", "200:\tlearn: 0.6019326\ttotal: 12s\tremaining: 47.8s\n", "300:\tlearn: 0.6000429\ttotal: 18.7s\tremaining: 43.4s\n", "400:\tlearn: 0.5992161\ttotal: 25.4s\tremaining: 37.9s\n", "500:\tlearn: 0.5986674\ttotal: 32s\tremaining: 31.8s\n", "600:\tlearn: 0.5982996\ttotal: 38.5s\tremaining: 25.5s\n", "700:\tlearn: 0.5980941\ttotal: 44.8s\tremaining: 19.1s\n", "800:\tlearn: 0.5979237\ttotal: 51.2s\tremaining: 12.7s\n", "900:\tlearn: 0.5976503\ttotal: 57.5s\tremaining: 6.32s\n", "999:\tlearn: 0.5975015\ttotal: 1m 3s\tremaining: 0us\n", "Learning rate set to 0.02591\n", "0:\tlearn: 0.6711650\ttotal: 23.1ms\tremaining: 23.1s\n", 
"100:\tlearn: 0.2887976\ttotal: 3.03s\tremaining: 27s\n", "200:\tlearn: 0.2763838\ttotal: 6.53s\tremaining: 26s\n", "300:\tlearn: 0.2729584\ttotal: 10.2s\tremaining: 23.7s\n", "400:\tlearn: 0.2713649\ttotal: 13.9s\tremaining: 20.8s\n", "500:\tlearn: 0.2703728\ttotal: 17.6s\tremaining: 17.6s\n", "600:\tlearn: 0.2696703\ttotal: 21.3s\tremaining: 14.1s\n", "700:\tlearn: 0.2691328\ttotal: 24.9s\tremaining: 10.6s\n", "800:\tlearn: 0.2686616\ttotal: 28.6s\tremaining: 7.11s\n", "900:\tlearn: 0.2682632\ttotal: 32.3s\tremaining: 3.55s\n", "999:\tlearn: 0.2678762\ttotal: 36s\tremaining: 0us\n", "Learning rate set to 0.024384\n", "0:\tlearn: 0.6735712\ttotal: 44.9ms\tremaining: 44.9s\n", "100:\tlearn: 0.3063022\ttotal: 4.82s\tremaining: 42.9s\n", "200:\tlearn: 0.2925770\ttotal: 10.2s\tremaining: 40.4s\n", "300:\tlearn: 0.2895685\ttotal: 15.6s\tremaining: 36.3s\n", "400:\tlearn: 0.2880540\ttotal: 21.3s\tremaining: 31.9s\n", "500:\tlearn: 0.2872389\ttotal: 26.9s\tremaining: 26.8s\n", "600:\tlearn: 0.2866951\ttotal: 32.6s\tremaining: 21.6s\n", "700:\tlearn: 0.2863474\ttotal: 38.1s\tremaining: 16.3s\n", "800:\tlearn: 0.2860138\ttotal: 43.6s\tremaining: 10.8s\n", "900:\tlearn: 0.2857359\ttotal: 49.2s\tremaining: 5.41s\n", "999:\tlearn: 0.2854954\ttotal: 54.8s\tremaining: 0us\n" ], "name": "stdout" } ] }, { "cell_type": "markdown", "metadata": { "id": "POrn2kgwymEo" }, "source": [ "### Uplift prediction" ] }, { "cell_type": "code", "metadata": { "id": "Xx_hHajjymEo" }, "source": [ "uplift_transform_model_val = transform_model.predict(X_val)\n", "uplift_transform_model_train = transform_model.predict(X_train)\n", "\n", "uplift_two_model = two_model.predict(X_val)" ], "execution_count": 11, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "-5PofV6aymEp" }, "source": [ "# 🚀🚀🚀 Uplift metrics" ] }, { "cell_type": "markdown", "metadata": { "id": "SmvFxIALymEp" }, "source": [ "### 🚀 `uplift@k`\n", "\n", "- uplift at first k%\n", "- usually falls between [0; 1] depending on 
k, model quality and data\n", "\n", "\n", "### `uplift@k` = `target mean at k% in the treatment group` - `target mean at k% in the control group`\n", "\n", "___\n", "\n", "How to compute `uplift@k`:\n", "\n", "1. sort by predicted uplift\n", "2. select first k%\n", "3. compute the target mean in the treatment group\n", "4. compute the target mean in the control group\n", "5. subtract the mean in the control group from the mean in the treatment group\n", "\n", "---\n", "\n", "Code parameter options:\n", "\n", "- `strategy='overall'` - sort treatment and control together by predicted uplift\n", "- `strategy='by_group'` - sort treatment and control separately by predicted uplift" ] }, { "cell_type": "markdown", "metadata": { "id": "KWBOEv0Z6daH" }, "source": [ "## `🚀uplift@k with a small step of the k parameter`\n", "\n" ] }, { "cell_type": "code", "metadata": { "id": "ZWjC06aQymEp" }, "source": [ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", "from sklift.metrics import uplift_at_k\n", "\n", "values_uplift_k_transform = []\n", "values_uplift_k_two = []\n", "values_k = []\n", "for k in np.arange(0.01,1,0.01):\n", " values_uplift_k_transform.append(uplift_at_k(y_val, uplift_transform_model_val, trmnt_val, strategy='overall', k=k))\n", " values_uplift_k_two.append(uplift_at_k(y_val, uplift_two_model, trmnt_val, strategy='overall', k=k))\n", " values_k.append(k)" ], "execution_count": 12, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "oshHc_VWlKmw" }, "source": [ "### `For ClassTransformation model`" ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 295 }, "id": "O6tfXwaLlHJV", "outputId": "03c001f3-b386-4838-a6dd-e7a215b235f2" }, "source": [ "plt.plot(values_k, values_uplift_k_transform)\n", "plt.title('Dependence of uplift@k on k')\n", "plt.xlabel('The value of k')\n", "plt.ylabel('The value of uplift@k')\n", "plt.show()" ], "execution_count": 13, "outputs": [ { "output_type": "display_data", "data": { 
"image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" } } ] }, { "cell_type": "markdown", "metadata": { "id": "7eRHptiLlXpb" }, "source": [ "### `For TwoModels`" ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 295 }, "id": "lQGkD3dTlEQn", "outputId": "ff2debcb-030e-4158-ed64-24216a341985" }, "source": [ "plt.plot(values_k, values_uplift_k_two)\n", "plt.title('Dependence of uplift@k on k')\n", "plt.xlabel('The value of k')\n", "plt.ylabel('The value of uplift@k')\n", "plt.show()" ], "execution_count": 14, "outputs": [ { "output_type": "display_data", "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" } } ] }, { "cell_type": "markdown", "metadata": { "id": "KD5ZUzlEJN1b" }, "source": [ "# 🚀 `ASD metric`\n", "### `The average squared deviation (ASD) is a model stability metric that shows how much the model overfits the training data. Larger values of ASD mean greater overfit.`\n", "\n", "## Code parameter options:\n", "\n", "- `strategy='overall'` - The first step is taking the first k observations of all test data ordered by uplift prediction (overall both groups - control and treatment) and conversions in treatment and control groups calculated only on them. Then the difference between these conversions is calculated.\n", "- `strategy='by_group'` - Separately calculates conversions in top k observations in each group (control and treatment) sorted by uplift predictions. Then the difference between these conversions is calculated\n", "- `bins=10` - Determines the number of bins (and the relative percentile) in the data." ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "FnqffihzymEv", "outputId": "30a329cc-387b-4500-8c3e-f8f6ab043f37" }, "source": [ "from sklift.metrics import average_squared_deviation\n", "\n", "asd_overall = average_squared_deviation(y_train, uplift_transform_model_train, trmnt_train, y_val,\n", " uplift_transform_model_val, trmnt_val, strategy='overall')\n", "asd_by_group = average_squared_deviation(y_train, uplift_transform_model_train, trmnt_train, y_val, \n", " uplift_transform_model_val, trmnt_val, strategy='by_group')\n", "\n", "print(f\"average squared deviation by overall strategy for the ClassTransformation model: {asd_overall:.6f}\")\n", "print(f\"average squared deviation by group strategy for the ClassTransformation model: {asd_by_group:.6f}\")" ], "execution_count": 15, "outputs": [ { "output_type": "stream", "text": [ "average squared deviation by overall strategy for the ClassTransformation model: 0.000007\n", "average squared 
deviation by group strategy for the ClassTransformation model: 0.000011\n" ], "name": "stdout" } ] }, { "cell_type": "markdown", "metadata": { "id": "VSg7zXHGG_76" }, "source": [ "# `↗️Display 2 different model uplift scores on one qini plot`\n" ] }, { "cell_type": "markdown", "metadata": { "id": "BrRjY_zlYThJ" }, "source": [ "### `Only qini curves`" ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 296 }, "id": "-WFTIgynAI28", "outputId": "b37fb698-7522-44ac-c786-21ea6f8e21a7" }, "source": [ "from sklift.viz import plot_qini_curve\n", "\n", "fig, ax_roc = plt.subplots(1, 1)\n", "plot_qini_curve(y_val, uplift_transform_model_val, trmnt_val, name='Transform model', random=False, perfect=False, ax=ax_roc)\n", "plot_qini_curve(y_val, uplift_two_model, trmnt_val, name='Two models', random=False, perfect=False, ax=ax_roc)" ], "execution_count": 16, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "" ] }, "metadata": {}, "execution_count": 16 }, { "output_type": "display_data", "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" } } ] }, { "cell_type": "markdown", "metadata": { "id": "QI7Rn8viY2eF" }, "source": [ "### `Qini curves with a random curve and with a perfect curve`" ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 296 }, "id": "plgwP2Srq7ed", "outputId": "3cc745f4-b95d-4238-89b7-69127b234339" }, "source": [ "fig, ax_roc = plt.subplots(1, 1)\n", "plot_qini_curve(y_val, uplift_transform_model_val, trmnt_val, name='Transform model', random=True, perfect=True, ax=ax_roc)\n", "plot_qini_curve(y_val, uplift_two_model, trmnt_val, name='Two models', random=True, perfect=True, ax=ax_roc)" ], "execution_count": 17, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "" ] }, "metadata": {}, "execution_count": 17 }, { "output_type": "display_data", "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" } } ] }, { "cell_type": "code", "metadata": { "id": "7vCf0C89Oito" }, "source": [ "" ], "execution_count": null, "outputs": [] } ] }