{ "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "Ok1vxsLqqw3w" }, "source": [ "# Estimating Treatment Effect Using Machine Learning" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "B16h5bb8eFmw" }, "source": [ "Welcome to the first assignment of **AI for Medical Treatment**!\n", "\n", "You will be using different methods to evaluate the results of a [randomized control trial](https://en.wikipedia.org/wiki/Randomized_controlled_trial) (RCT).\n", "\n", "**You will learn:**\n", "- How to analyze data from a randomized control trial using both:\n", " - traditional statistical methods\n", " - and the more recent machine learning techniques\n", "- Interpreting Multivariate Models\n", " - Quantifying treatment effect\n", " - Calculating baseline risk\n", " - Calculating predicted risk reduction\n", "- Evaluating Treatment Effect Models\n", " - Comparing predicted and empirical risk reductions\n", " - Computing C-statistic-for-benefit\n", "- Interpreting ML models for Treatment Effect Estimation\n", " - Implement T-learner" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### This assignment covers the following topics:\n", "\n", "- [1. Dataset](#1)\n", " - [1.1 Why RCT?](#1-1)\n", " - [1.2 Data Processing](#1-2)\n", " - [Exercise 1](#ex-01)\n", " - [Exercise 2](#ex-02)\n", "- [2. Modeling Treatment Effect](#2)\n", " - [2.1 Constant Treatment Effect](#2-1)\n", " - [Exercise 3](#ex-03)\n", " - [2.2 Absolute Risk Reduction](#2-2)\n", " - [Exercise 4](#ex-04)\n", " - [2.3 Model Limitations](#2-3)\n", " - [Exercise 5](#ex-05)\n", " - [Exercise 6](#ex-06)\n", "- [3. Evaluation Metric](#3)\n", " - [3.1 C-statistic-for-benefit](#3-1)\n", " - [Exercise 7](#ex-07)\n", " - [Exercise 8](#ex-08)\n", "- [4. 
Machine Learning Approaches](#4)\n", " - [4.1 T-Learner](#4-1)\n", " - [Exercise 9](#ex-09)\n", " - [Exercise 10](#ex-10)\n", " - [Exercise 11](#ex-11)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "Tklnk8tneq2U" }, "source": [ "## Packages\n", "\n", "We'll first import all the packages that we need for this assignment. \n", "\n", "\n", "- `pandas` is what we'll use to manipulate our data\n", "- `numpy` is a library for mathematical and scientific operations\n", "- `matplotlib` is a plotting library\n", "- `sklearn` contains a lot of efficient tools for machine learning and statistical modeling\n", "- `random` allows us to generate random numbers in Python\n", "- `lifelines` is an open-source library that implements the c-statistic\n", "- `itertools` will help us with hyperparameter search\n", "\n", "## Import Packages\n", "\n", "Run the next cell to import all the necessary packages, dependencies and custom util functions." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "ExecuteTime": { "end_time": "2020-04-04T15:29:41.602385Z", "start_time": "2020-04-04T15:29:39.274097Z" }, "colab": {}, "colab_type": "code", "id": "Z5zOXfAIH-41" }, "outputs": [], "source": [ "import pandas as pd\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import sklearn\n", "import random\n", "import lifelines\n", "import itertools\n", "\n", "plt.rcParams['figure.figsize'] = [10, 7]" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "pVEHJZ79mvQx" }, "source": [ "\n", "## 1 Dataset\n", "\n", "### 1.1 Why RCT?\n", "\n", "In this assignment, we'll be examining data from an RCT, measuring the effect of a particular drug combination on colon cancer. Specifically, we'll be looking at the effect of [Levamisole](https://en.wikipedia.org/wiki/Levamisole) and [Fluorouracil](https://en.wikipedia.org/wiki/Fluorouracil) on patients who have had surgery to remove their colon cancer. 
After surgery, the curability of the patient depends on the remaining residual cancer. In this study, it was found that this particular drug combination had a clear beneficial effect, when compared with [Chemotherapy](https://en.wikipedia.org/wiki/Chemotherapy). \n", "\n", "### 1.2 Data Processing\n", "In this first section, we will load in the dataset and calculate basic statistics. Run the next cell to load the dataset. We also do some preprocessing to convert categorical features to one-hot representations." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "ExecuteTime": { "end_time": "2020-04-04T15:29:41.612018Z", "start_time": "2020-04-04T15:29:41.602385Z" }, "colab": {}, "colab_type": "code", "id": "QOV_BJGyLtjR" }, "outputs": [], "source": [ "data = pd.read_csv(\"levamisole_data.csv\", index_col=0)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "RlqE8036sj3y" }, "source": [ "Let's look at our data to familiarize ourselves with the various fields. " ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "ExecuteTime": { "end_time": "2020-04-04T15:29:45.698204Z", "start_time": "2020-04-04T15:29:45.677460Z" }, "colab": { "base_uri": "https://localhost:8080/", "height": 221 }, "colab_type": "code", "id": "RPS1stb7si4N", "outputId": "a64b50c6-5df2-467a-abee-0d73f82d7825" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Data Dimensions: (607, 14)\n" ] }, { "data": { "text/html": [ "
\n", "| | sex | age | obstruct | perfor | adhere | nodes | node4 | outcome | TRTMT | differ_2.0 | differ_3.0 | extent_2 | extent_3 | extent_4 |\n",
"|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|\n",
"| 1 | 1 | 43 | 0 | 0 | 0 | 5.0 | 1 | 1 | True | 1 | 0 | 0 | 1 | 0 |\n",
"| 2 | 1 | 63 | 0 | 0 | 0 | 1.0 | 0 | 0 | True | 1 | 0 | 0 | 1 | 0 |\n",
"| 3 | 0 | 71 | 0 | 0 | 1 | 7.0 | 1 | 1 | False | 1 | 0 | 1 | 0 | 0 |\n",
"| 4 | 0 | 66 | 1 | 0 | 0 | 6.0 | 1 | 1 | True | 1 | 0 | 0 | 1 | 0 |\n",
"| 5 | 1 | 69 | 0 | 0 | 0 | 22.0 | 1 | 1 | False | 1 | 0 | 0 | 1 | 0 |\n",
"\n", "