{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Introduction" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Previously I've shown how to create a [linear model and neural net from scratch](https://www.kaggle.com/code/jhoward/linear-model-and-neural-net-from-scratch), and used it to create a solid submission to Kaggle's [Titanic](https://www.kaggle.com/competitions/titanic/) competition. However, for *tabular* data (i.e. data that looks like spreadsheet or database tables, such as the data for the Titanic competition) it's more common to see good results by using ensembles of decision trees, such as Random Forests and Gradient Boosting Machines.\n", "\n", "In this notebook, we're going to learn all about Random Forests, by building one from scratch, and using it to submit to the Titanic competition! That might sound like a pretty big stretch, but I think you'll be surprised to discover how straightforward it actually is.\n", "\n", "We'll start by importing the basic set of libraries we normally need for data science work, and setting numpy to use our display space more efficiently (by letting printed arrays use lines up to 130 characters wide):" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19", "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5", "execution": { "iopub.execute_input": "2022-05-23T04:37:24.640765Z", "iopub.status.busy": "2022-05-23T04:37:24.640339Z", "iopub.status.idle": "2022-05-23T04:37:25.174055Z", "shell.execute_reply": "2022-05-23T04:37:25.172992Z", "shell.execute_reply.started": "2022-05-23T04:37:24.640663Z" } }, "outputs": [], "source": [ "from fastai.imports import *\n", "np.set_printoptions(linewidth=130)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data preprocessing" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'll create `DataFrame`s from the CSV files just like we did in the \"*linear model and neural net from scratch*\" notebook, and do much the same preprocessing (so go back and check that out if you're not already familiar with the dataset):" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2022-05-23T04:37:26.152841Z", "iopub.status.busy": "2022-05-23T04:37:26.152518Z", "iopub.status.idle": "2022-05-23T04:37:26.205303Z", "shell.execute_reply": "2022-05-23T04:37:26.204112Z", "shell.execute_reply.started": "2022-05-23T04:37:26.152806Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "titanic.zip: Skipping, found more recently modified local copy (use --force to force download)\n" ] } ], "source": [ "import os\n", "iskaggle = os.environ.get('KAGGLE_KERNEL_RUN_TYPE', '')\n", "\n", "if iskaggle: path = Path('../input/titanic')\n", "else:\n", " import zipfile,kaggle\n", " path = Path('titanic')\n", " kaggle.api.competition_download_cli(str(path))\n", " zipfile.ZipFile(f'{path}.zip').extractall(path)\n", "\n", "df = pd.read_csv(path/'train.csv')\n", "tst_df = pd.read_csv(path/'test.csv')\n", "modes = df.mode().iloc[0]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "One difference with Random Forests, however, is that we don't generally have to create *dummy variables* like we did for non-numeric columns in the linear models and neural network. Instead, we can just convert those fields to *categorical variables*. Internally, Pandas stores a categorical column as a list of all the unique values in the column, and replaces each value with a number. That number is just an index for looking up the value in the list of unique values."
] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2022-05-23T04:37:26.941371Z", "iopub.status.busy": "2022-05-23T04:37:26.941013Z", "iopub.status.idle": "2022-05-23T04:37:26.9692Z", "shell.execute_reply": "2022-05-23T04:37:26.968016Z", "shell.execute_reply.started": "2022-05-23T04:37:26.941337Z" } }, "outputs": [], "source": [ "def proc_data(df):\n", " df['Fare'] = df.Fare.fillna(0)\n", " df.fillna(modes, inplace=True)\n", " df['LogFare'] = np.log1p(df['Fare'])\n", " df['Embarked'] = pd.Categorical(df.Embarked)\n", " df['Sex'] = pd.Categorical(df.Sex)\n", "\n", "proc_data(df)\n", "proc_data(tst_df)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'll make a list of the continuous, categorical, and dependent variables. Note that we no longer consider `Pclass` a categorical variable. That's because it's *ordered* (i.e. 1st, 2nd, and 3rd class have an order), and decision trees, as we'll see, only care about the order of a column's values, not their absolute magnitude."
] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2022-05-23T04:37:27.728358Z", "iopub.status.busy": "2022-05-23T04:37:27.728077Z", "iopub.status.idle": "2022-05-23T04:37:27.733013Z", "shell.execute_reply": "2022-05-23T04:37:27.732245Z", "shell.execute_reply.started": "2022-05-23T04:37:27.728328Z" } }, "outputs": [], "source": [ "cats=[\"Sex\",\"Embarked\"]\n", "conts=['Age', 'SibSp', 'Parch', 'LogFare',\"Pclass\"]\n", "dep=\"Survived\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Even though we've made the `cats` columns categorical, Pandas still displays them using their original values:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2022-05-23T04:37:28.64687Z", "iopub.status.busy": "2022-05-23T04:37:28.645995Z", "iopub.status.idle": "2022-05-23T04:37:28.657163Z", "shell.execute_reply": "2022-05-23T04:37:28.656275Z", "shell.execute_reply.started": "2022-05-23T04:37:28.646831Z" } }, "outputs": [ { "data": { "text/plain": [ "0 male\n", "1 female\n", "2 female\n", "3 female\n", "4 male\n", "Name: Sex, dtype: category\n", "Categories (2, object): ['female', 'male']" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.Sex.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "However, behind the scenes they're now stored as integers, which are indices into the `Categories` list shown in the output above. 
We can view the stored values by looking in the `cat.codes` attribute:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2022-05-23T04:37:29.606313Z", "iopub.status.busy": "2022-05-23T04:37:29.605974Z", "iopub.status.idle": "2022-05-23T04:37:29.613586Z", "shell.execute_reply": "2022-05-23T04:37:29.612892Z", "shell.execute_reply.started": "2022-05-23T04:37:29.606276Z" } }, "outputs": [ { "data": { "text/plain": [ "0 1\n", "1 0\n", "2 0\n", "3 0\n", "4 1\n", "dtype: int8" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.Sex.cat.codes.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Binary splits" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Before we create a Random Forest or Gradient Boosting Machine, we'll first need to learn how to create a *decision tree*, from which both of these models are built.\n", "\n", "And to create a decision tree, we'll first need to create a *binary split*, since that's what a decision tree is built from.\n", "\n", "A binary split is where all rows are placed into one of two groups, based on whether they're above or below some threshold of some column. For example, we could split the rows of our dataset into males and females, by using the threshold `0.5` and the column `Sex` (since the values in the column are `0` for `female` and `1` for `male`). 
We can use a plot to see how that would split up our data -- we'll use the [Seaborn](https://seaborn.pydata.org/) library, which is a layer on top of [matplotlib](https://matplotlib.org/) that makes some useful charts easier to create, and more aesthetically pleasing by default:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2022-05-23T04:37:31.229486Z", "iopub.status.busy": "2022-05-23T04:37:31.228601Z", "iopub.status.idle": "2022-05-23T04:37:32.45072Z", "shell.execute_reply": "2022-05-23T04:37:32.449794Z", "shell.execute_reply.started": "2022-05-23T04:37:31.229441Z" } }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAApkAAAFNCAYAAABL6HT2AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAAsTAAALEwEAmpwYAAAmj0lEQVR4nO3de7xdZX3n8c+XcLHcREq4lJBCNUrB8RqpttZiUUFHjW1F46VCy0yGKWjbqSI4rbXVVMSOU6rQNOMgeCOmaiU6qUixeEHQhBaBcKkZQAjXACJgHWjCb/5YK7A52UlOkrXOPif5vF+v8zp7PevZa/+Onjx8z3rWWk+qCkmSJKlLO4y6AEmSJG17DJmSJEnqnCFTkiRJnTNkSpIkqXOGTEmSJHXOkClJkqTOGTI15SRZkORPOjjOuUk+0EVNkjTRkqxIcuSo65A2xJCpTiR5cZLvJPlxkvuSXJrkBX18VlWdWFXv7+PYWyPJzUleNuo6JG0bho0pSY5P8m2Aqjq8qi7ZxDEOTlJJduyxVGkof+m01ZLsCXwF+K/AYmBn4FeBh7fgWAFSVY92WuRWSrJjVa0ZdR2SNJk4NmpjPJOpLjwdoKrOr6q1VfXTqvpaVV0FkOR9ST69rvPYv6yTXJJkfpJLgX8D3pNk+eAHJPnDJEva149Ncye5LsmrB/rtmOSeJM9rt/8uyZ3tGdZvJjl8PD9Qe7bg0iT/M8l9wPuSPDXJ15Pc237GZ5Ls1fb/FDAT+HKSh5Kc0ra/sD3De3+S7zu1Jakrg2c6kxyRZHmSB5LcleQjbbdvtt/vb8emFyXZIckfJ/lhkruTfDLJkweO+7Z2371J/mTM57wvyeeTfDrJA8Dx7Wdf1o5zdyT5WJKdB45XSX4vyQ+SPJjk/e14ellb7+LB/tp2GDLVhX8F1iY5L8krkzxlC47x28A8YA/go8Azkswa2P9m4LND3nc+8KaB7aOBe6rqn9vtfwBmAfsC/wx8ZjNq+iXgxva984EAHwR+DvhF4CDgfQBV9dvALcBrqmr3qjojyYHA/wE+AOwNvBP4QpLpm1GDJI3HmcCZVbUn8FSaWSWAl7Tf92rHpsuA49uvlwK/AOwOfAwgyWHA2cBbgAOAJwMHjvmsOcDngb1oxtS1wB8C+wAvAo4Cfm/Me44Bng+8EDgFWNh+xkHAM3niOK5thCFTW62qHgBeDBTwv4DVSZYk2W8zDnNuVa2oqjVV9WPgAtpBpw2bhwJLhrzvs8Brk+zabj8hjFbVOVX1YFU9TBMInz34F/sm3F5VH21r+mlVrayqi6rq4apaDXwE+LWNvP+twNKq
WlpVj1bVRcBy4FXj/HxJ+lJ7hvD+JPfTBMBh/h14WpJ9quqhqrp8I8d8C/CRqrqxqh4CTgPmtrNLrwe+XFXfrqpHgPfSjO2DLquqL7Xj2k+r6oqqurwdK28G/pb1x8YPVdUDVbUCuAb4Wvv5P6Y5GfDccf8voinDkKlOVNV1VXV8Vc2g+av054C/2oxD3Dpm+7M8/pftm4EvVdW/DfnclcB1wGvaoPna9r0kmZbk9CT/t53Wubl92z5bUlOSfZMsSnJbe7xPb+JYPw8cO+Y/EC+mOTsgSePxuqraa90X658hXOcEmkuXrk+ybPAyoiF+DvjhwPYPae7R2K/d99jY14679455/9ix8elJvtJemvQA8BesPzbeNfD6p0O2d99IvZqiDJnqXFVdD5xLEzYBfgLsOtBl/2FvG7P9NWCfJM+hCZvDpsrXWTdlPge4tg2e0ITTOcDLaKZ8Dm7bM44fY1hNH2zbntVOSb11zLHG9r8V+NTgfyCqareqOn2cny9J41JVP6iqN9Fc3vMh4PNJdmP9cQngdpo/gteZCayhCX53ADPW7UjyM8DPjv24Mdt/A1wPzGrHxvcw/nFW2zBDprZakkOT/FGSGe32QTShb910zZXAS5LMbKeqT9vUMdu7FT8PfJjmesaLNtJ9EfAKmrvbB8PoHjR3uN9LE3L/YjN+rGH2AB6iuYD+QOBdY/bfRXN90zqfpjnDenR7VvVJSY5c97+TJHUlyVuTTG+fzHF/27wWWA08yhPHpvOBP0xySJLdacbGzw2Mu69J8svtzTh/xqYD4x7AA8BDSQ6lGYslQ6Y68SDNTTLfTfITmnB5DfBHAO21iJ8DrgKuoHnc0Xh8luYs5N9t7BEZVXUHcBnwy+3nrPNJmmmg24BreTz0bqk/A54H/Jjmhp4vjtn/QeCP26nxd1bVrTRnUt9DM9DfShNM/XcnqWvHACuSPERzE9Dcqvp/7XT3fODSdmx6IXAO8CmaO89vAv4f8HaA9prJt9P88X4Hzfh+Nxt/JN07aWaOHqS5Lv9zG+mr7Uiqhp1JlyRJ27v2TOf9NFPhN424HE0xnlGRJEmPSfKaJLu213T+JXA1j984KY2bIVOSJA2aQ3Nz0O00zxmeW057ags4XS5JkqTOeSZTkiRJnTNkStIkkGSvdk3o65Nc164xvXeSi9o1ny8aXLI1yWlJVia5IcnRo6xdkoaZctPl++yzTx188MGjLkPSNuaKK664p6pGtq58kvOAb1XVx9vnE+5K8/ir+6rq9CSnAk+pqne360ufDxxBs0LLPwJPr6q1Gzq+Y6ekPmxs7NxxoovZWgcffDDLly8fdRmStjFJfrjpXr199p7AS4DjAdo1ox9JMgc4su12HnAJ8G6aGzMWVdXDwE1JVtIEzss29BmOnZL6sLGx0+lySRq9X6B5YP8nkvxLko+3j4/Zr11sYN2iA/u2/Q/kietHr2rbJGnSMGRK0ujtSLOa1N9U1XOBnwCnbqT/sGX+1rv2Kcm8JMuTLF+9enU3lUrSOBkyJWn0VgGrquq77fbnaULnXUkOAGi/3z3Q/6CB98+geabhE1TVwqqaXVWzp08f2eWmkrZThkxJGrGquhO4Nckz2qajgGuBJcBxbdtxwAXt6yXA3CS7JDmE5oHZ35vAkiVpk3q98SfJMcCZwDTg41V1+pj9TwY+Dcxsa/nLqvpEnzVJ0iT1duAz7Z3lNwK/Q3MiYHGSE4BbgGMBqmpFksU0QXQNcNLG7iyXpFHoLWQmmQacBbycZmpnWZIlVXXtQLeTgGur6jVJpgM3JPlMe2elJG03qupKYPaQXUdtoP98YH6fNUnS1uhzuvwIYGVV3diGxkU0j90YVMAeSQLsDtxH81e5JEmSprA+Q+Z4HrHxMeAXaS5Yvxr4/ap6tMeaJEmSNAH6DJnjecTG0cCVNCtWPAf4WPtQ4iceyMdwSJIkTSl9hszxPGLjd4AvVmMlcBNw6NgD+RgOSZKkqaXPu8uXAbPa
x2vcBswF3jymzy00F7V/K8l+wDNo7qrUZjjllFO488472X///TnjjDNGXY4kSVJ/IbOq1iQ5GbiQ5hFG57SP3Tix3b8AeD9wbpKraabX311V9/RV07bqzjvv5Lbbbht1GZKk7dgtf/4fRl2CtsLM917d+TF7fU5mVS0Flo5pWzDw+nbgFX3WIEmSpInnij+SJEnqnCFTkiRJnTNkSpIkqXOGTEmSJHXOkClJkqTOGTIlSZLUOUOmJEmSOmfIlCRJUucMmZIkSeqcIVOSJEmdM2RKkiSpc4ZMSZIkdc6QKUmSpM4ZMiVJktS5HUddwER5/rs+OeoSerPHPQ8yDbjlnge32Z/zig+/bdQlSJKkzeCZTEmSJHXOkClJkqTOGTIlSZLUOUOmJEmSOmfIlCRJUucMmZIkSeqcIVOSJEmdM2RKkiSpc4ZMSZIkda7XkJnkmCQ3JFmZ5NQh+9+V5Mr265oka5Ps3WdNkiRJ6l9vITPJNOAs4JXAYcCbkhw22KeqPlxVz6mq5wCnAd+oqvv6qkmSJEkTo88zmUcAK6vqxqp6BFgEzNlI/zcB5/dYjyRJkiZInyHzQODWge1Vbdt6kuwKHAN8YQP75yVZnmT56tWrOy9UkiRJ3eozZGZIW22g72uASzc0VV5VC6tqdlXNnj59emcFSpIkqR99hsxVwEED2zOA2zfQdy5OlUvajiW5OcnV7Y2Qy9u2vZNclOQH7fenDPQ/rb2p8oYkR4+uckkars+QuQyYleSQJDvTBMklYzsleTLwa8AFPdayTXt0591Yu8uePLrzbqMuRdLWeWl7M+TsdvtU4OKqmgVc3G7T3kQ5Fzic5lKjs9ubLSVp0tixrwNX1ZokJwMXAtOAc6pqRZIT2/0L2q6/AXytqn7SVy3bup/MesWoS5DUjznAke3r84BLgHe37Yuq6mHgpiQraW62vGwENUrSUL2FTICqWgosHdO2YMz2ucC5fdYhSVNAAV9LUsDfVtVCYL+qugOgqu5Ism/b90Dg8oH3Dr2xMsk8YB7AzJkz+6xdktbTa8iUJI3br1TV7W2QvCjJ9RvpO64bK9uguhBg9uzZG7rxUpJ64bKSkjQJVNXt7fe7gb+nmf6+K8kBAO33u9vum3NjpSSNhCFTkkYsyW5J9lj3GngFcA3NzZLHtd2O4/EbJJcAc5PskuQQYBbwvYmtWpI2zulySRq9/YC/TwLNuPzZqvpqkmXA4iQnALcAxwK0N1EuBq4F1gAnVdXa0ZQuScMZMiVpxKrqRuDZQ9rvBY7awHvmA/N7Lk2StpjT5ZIkSeqcIVOSJEmdM2RKkiSpc4ZMSZIkdc6QKUmSpM4ZMiVJktQ5Q6YkSZI6Z8iUJElS5wyZkiRJ6pwhU5IkSZ0zZEqSJKlzhkxJkiR1zpApSZKkzhkyJUmS1DlDpiRJkjpnyJQkSVLnDJmSJEnqnCFTkiRJnes1ZCY5JskNSVYmOXUDfY5McmWSFUm+0Wc9kiRJmhg79nXgJNOAs4CXA6uAZUmWVNW1A332As4GjqmqW5Ls21c9kiRJmjh9nsk8AlhZVTdW1SPAImDOmD5vBr5YVbcAVNXdPdYjSZKkCdJnyDwQuHVge1XbNujpwFOSXJLkiiRv67EeSZIkTZDepsuBDGmrIZ//fOAo4GeAy5JcXlX/+oQDJfOAeQAzZ87soVRJkiR1qc8zmauAgwa2ZwC3D+nz1ar6SVXdA3wTePbYA1XVwqqaXVWzp0+f3lvBkiRJ6kafIXMZMCvJIUl2BuYCS8b0uQD41SQ7JtkV+CXguh5rkiRJ0gTobbq8qtYkORm4EJgGnFNVK5Kc2O5fUFXXJfkqcBXwKPDxqrqmr5okSZI0Mfq8JpOqWgosHdO2YMz2h4EP91mHJEmSJpYr/kiSJKlzhkxJkiR1zpApSZKkzhkyJUmS1DlDpiRJkjpnyJQkSVLnDJmSJEnqnCFTkiRJnTNkSpIkqXOGTEmaBJJMS/IvSb7Sbu+d5KIkP2i/
P2Wg72lJVia5IcnRo6takjbMkClJk8PvA9cNbJ8KXFxVs4CL222SHAbMBQ4HjgHOTjJtgmuVpE0yZErSiCWZAfxH4OMDzXOA89rX5wGvG2hfVFUPV9VNwErgiAkqVZLGzZApSaP3V8ApwKMDbftV1R0A7fd92/YDgVsH+q1q29aTZF6S5UmWr169uvOiJWljDJmSNEJJXg3cXVVXjPctQ9pqWMeqWlhVs6tq9vTp07e4RknaEjuOugBJ2s79CvDaJK8CngTsmeTTwF1JDqiqO5IcANzd9l8FHDTw/hnA7RNasSSNg2cyJWmEquq0qppRVQfT3NDz9ap6K7AEOK7tdhxwQft6CTA3yS5JDgFmAd+b4LIlaZM8kylJk9PpwOIkJwC3AMcCVNWKJIuBa4E1wElVtXZ0ZUrScIZMSZokquoS4JL29b3AURvoNx+YP2GFSdIWcLpckiRJnTNkSpIkqXOGTEmSJHXOkClJkqTOGTIlSZLUOUOmJEmSOtdryExyTJIbkqxMcuqQ/Ucm+XGSK9uv9/ZZjyRJkiZGb8/JTDINOAt4Oc0yaMuSLKmqa8d0/VZVvbqvOiRJkjTx+jyTeQSwsqpurKpHgEXAnB4/T5IkSZNEnyHzQODWge1VbdtYL0ry/ST/kOTwHuuRJEnSBOlzWckMaasx2/8M/HxVPZTkVcCXgFnrHSiZB8wDmDlzZsdlSpIkqWt9nslcBRw0sD0DuH2wQ1U9UFUPta+XAjsl2WfsgapqYVXNrqrZ06dP77FkSZIkdaHPkLkMmJXkkCQ7A3OBJYMdkuyfJO3rI9p67u2xJkmSJE2A3qbLq2pNkpOBC4FpwDlVtSLJie3+BcDrgf+aZA3wU2BuVY2dUpckSdIU0+c1meumwJeOaVsw8PpjwMf6rEGSJEkTzxV/JEmS1DlDpiRJkjpnyJQkSVLnDJmSJEnq3EZv/EnyIOs/QP0xVbVn5xVJkiRpyttoyKyqPQCS/DlwJ/ApmpV83gLs0Xt1kiRJmpLGO11+dFWdXVUPtqv0/A3wW30WJkmSpKlrvCFzbZK3JJmWZIckbwHW9lmYJEmSpq7xhsw3A28A7mq/jm3bJEmSpPWMa8WfqroZmNNvKZIkSdpWjOtMZpKnJ7k4yTXt9rOS/HG/pUmSJGmqGu90+f8CTgP+HaCqrgLm9lWUJE1FSS4eT5skbQ/GNV0O7FpV30sy2Lamh3okacpJ8iRgV2CfJE+hedQbwJ7Az42sMEkaofGGzHuSPJX2wexJXg/c0VtVkjS1/BfgD2gC5RU8HjIfAM4aUU2SNFLjDZknAQuBQ5PcBtxE80B2SdruVdWZwJlJ3l5VHx11PZI0GYw3ZP6wql6WZDdgh6p6sM+iJGkqqqqPJvll4GAGxteq+uTIipKkERlvyLwpyVeBzwFf77EeSZqyknwKeCpwJY8vWFGAIVPSdme8IfMZwGtops3/d5KvAIuq6tu9VSZJU89s4LCqqlEXIkmjNt6Hsf8UWAwsbu+cPBP4BjCtx9okaaq5BtifbejGyOe/y5OwU9kVH37bqEvQdmy8ZzJJ8mvAG4FXAstolpmUJD1uH+DaJN8DHl7XWFWvHV1JkjQa4wqZSW6iucZoMfCuqvpJn0VJ0hT1vi15U/uczW8Cu9CMy5+vqj9NsjfNtfAHAzcDb6iqH7XvOQ04gebaz3dU1YVbW7wkdWm8ZzKfXVUP9FqJJE1xVfWNLXzrw8CvV9VDSXYCvp3kH4DfBC6uqtOTnAqcCrw7yWE0q64dTvNszn9M8vSqWruhD5CkibbRkJnklKo6A5ifZL0L2avqHb1VJklTTJIHaRetAHYGdgJ+UlV7bux97Y1CD7WbO7VfBcwBjmzbzwMuAd7dti+qqodpnv6xEjgCuKyrn0WSttamzmRe135f3nchkjTVVdUeg9tJXkcT/jYpyTSa1YKeBpxVVd9Nsl9V3dEe+44k+7bdDwQuH3j7qrZNkiaNjYbMqvpy+/KqqvqXzT14
kmNo7kSfBny8qk7fQL8X0AyYb6yqz2/u50jSZFRVX2qnucfTdy3wnCR7AX+f5Jkb6Z4hbevNNiWZB8wDmDlz5njKkKTOjPeazI8kOQD4O5opmhWbekP7V/lZwMtp/spelmRJVV07pN+HAC9alzSlJfnNgc0daJ6buVnPzKyq+5NcAhwD3JXkgPYs5gHA3W23VcBBA2+bAdw+5FgLaZYEZvbs2T67U9KE2mE8narqpTTXBa0GFia5Oskfb+JtRwArq+rGqnoEWERzHdFYbwe+wOODpyRNVa8Z+DoaeJDh494TJJnensEkyc8ALwOuB5YAx7XdjgMuaF8vAeYm2SXJIcAs4Hvd/RiStPXG/ZzMqroT+Osk/wScArwX+MBG3nIgcOvA9irglwY7JDkQ+A3g14EXjLcWSZqMqup3tvCtBwDntTM7OwCLq+orSS6jWQTjBOAW4Nj2c1YkWQxcC6wBTvLOckmTzXifk/mLNA9ifz1wL81ZyT/a1NuGtI2drvkr4N1VtTYZ1v2xz/e6IkmTXpIZwEeBX6EZ774N/H5VrdrY+6rqKuC5Q9rvBY7awHvmA/O3tmZJ6st4z2R+AjgfeEVVrXfdzwaM55qh2cCiNmDuA7wqyZqq+tJgJ68rkjRFfAL4LO0ZR+CtbdvLR1aRJI3IJkNmO33zf6vqzM089jJgVnu90G00Dw5+82CHqjpk4HPOBb4yNmBK0hQyvao+MbB9bpI/GFUxkjRKm7zxp73O52eT7Lw5B66qNcDJNHeNX0dzjdGKJCcmOXGLqpWkye2eJG9NMq39eivNJUaStN0Z73T5D4FLkywBHlu3vKo+srE3VdVSYOmYtgUb6Hv8OGuRpMnqd4GPAf+T5prM7wBbejOQJE1p4w2Zt7dfOwB7bKKvJG2v3g8cV1U/AkiyN/CXNOFTkrYr4wqZVfVnfRciSduAZ60LmABVdV+S9e4al6TtwXgfYfRPDFm1oqp+vfOKJGnq2iHJU8acyRz384glaVsy3sHvnQOvnwT8Fs0DgCVJj/sfwHeSfJ7mD/M34LMsJW2nxjtdfsWYpkuTfKOHeiRpyqqqTyZZTrOKWYDfrKprR1yWJI3EeKfL9x7Y3IHmIer791KRJE1hbag0WEra7o13uvwKHr8mcw1wM3BCHwVJkiRp6ttoyEzyAuDWdSvzJDmO5nrMm/EvdUmSJG3Aplb8+VvgEYAkLwE+CJwH/Jh2LXFJkiRprE1Nl0+rqvva128EFlbVF4AvJLmy18okSZI0ZW3qTOa0JOuC6FHA1wf2+ew3SZIkDbWpoHg+8I0k9wA/Bb4FkORpNFPmkiRJ0no2GjKran6Si4EDgK9V1bo7zHcA3t53cZIkSZqaNjnlXVWXD2n7137KkSRJ0rZgU9dkSpIkSZvNkClJkqTOGTIlSZLUOUOmJEmSOmfIlCRJUucMmZIkSeqcIVOSJEmdM2RKkiSpc4ZMSZIkdc6QKUmSpM71GjKTHJPkhiQrk5w6ZP+cJFcluTLJ8iQv7rMeSZIkTYxNrl2+pZJMA84CXg6sApYlWVJV1w50uxhYUlWV5FnAYuDQvmqSJEnSxOjzTOYRwMqqurGqHgEWAXMGO1TVQ1VV7eZuQCFJkqQpr8+QeSBw68D2qrbtCZL8RpLrgf8D/O6wAyWZ106nL1+9enUvxUqSJKk7fYbMDGlb70xlVf19VR0KvA54/7ADVdXCqppdVbOnT5/ebZWSJEnqXJ8hcxVw0MD2DOD2DXWuqm8CT02yT481SZIkaQL0GTKXAbOSHJJkZ2AusGSwQ5KnJUn7+nnAzsC9PdYkSZKkCdDb3eVVtSbJycCFwDTgnKpakeTEdv8C4LeAtyX5d+CnwBsHbgSSJEnSFNVbyASoqqXA0jFtCwZefwj4UJ81SNJkl+Qg4JPA/sCjwMKqOjPJ3sDngIOBm4E3VNWP2vecBpwArAXeUVUXjqB0SdogV/yRpNFbA/xRVf0i8ELgpCSHAacCF1fVLJrnCp8K
0O6bCxwOHAOc3T6bWJImDUOmJI1YVd1RVf/cvn4QuI7mkW9zgPPabufRPIWDtn1RVT1cVTcBK2meTSxJk4YhU5ImkSQHA88FvgvsV1V3QBNEgX3bbuN6DrEkjZIhU5ImiSS7A18A/qCqHthY1yFt69006UIWkkbJkClJk0CSnWgC5meq6ott811JDmj3HwDc3baP6znELmQhaZQMmZI0Yu3zgv83cF1VfWRg1xLguPb1ccAFA+1zk+yS5BBgFvC9iapXksaj10cYSZLG5VeA3wauTnJl2/Ye4HRgcZITgFuAYwHaZw4vBq6luTP9pKpaO+FVS9JGGDIlacSq6tsMv84S4KgNvGc+ML+3oiRpKzldLkmSpM4ZMiVJktQ5Q6YkSZI6Z8iUJElS5wyZkiRJ6pwhU5IkSZ0zZEqSJKlzhkxJkiR1zpApSZKkzhkyJUmS1DlDpiRJkjrn2uXSduSUU07hzjvvZP/99+eMM84YdTmSpG2YIVPajtx5553cdtttoy5DkrQdcLpckiRJnTNkSpIkqXO9hswkxyS5IcnKJKcO2f+WJFe1X99J8uw+65EkSdLE6C1kJpkGnAW8EjgMeFOSw8Z0uwn4tap6FvB+YGFf9UiSJGni9Hkm8whgZVXdWFWPAIuAOYMdquo7VfWjdvNyYEaP9UiSJGmC9BkyDwRuHdhe1bZtyAnAP/RYjyRJkiZIn48wypC2GtoxeSlNyHzxBvbPA+YBzJw5s6v6JEmS1JM+z2SuAg4a2J4B3D62U5JnAR8H5lTVvcMOVFULq2p2Vc2ePn16L8VKkiSpO32GzGXArCSHJNkZmAssGeyQZCbwReC3q+pfe6xFkiRJE6i36fKqWpPkZOBCYBpwTlWtSHJiu38B8F7gZ4GzkwCsqarZfdUkSZKkidHrspJVtRRYOqZtwcDr/wT8pz5rkLbELX/+H0ZdQi/W3Lc3sCNr7vvhNvszznzv1aMuQZKEK/5IkiSpB4ZMSZIkdc6QKUmSpM4ZMiVJktQ5Q6YkSZI6Z8iUJElS5wyZkiRJ6pwhU5IkSZ0zZEqSJKlzhkxJkiR1zpApSZKkzvW6drmkyWWfJz0KrGm/S5LUH0OmtB1557PuH3UJGiLJOcCrgbur6plt297A54CDgZuBN1TVj9p9pwEnAGuBd1TVhSMoW5I2yulySRq9c4FjxrSdClxcVbOAi9ttkhwGzAUOb99zdpJpE1eqJI2PIVOSRqyqvgncN6Z5DnBe+/o84HUD7Yuq6uGquglYCRwxEXVK0uYwZErS5LRfVd0B0H7ft20/ELh1oN+qtk2SJhVDpiRNLRnSVkM7JvOSLE+yfPXq1T2XJUlPZMiUpMnpriQHALTf727bVwEHDfSbAdw+7ABVtbCqZlfV7OnTp/darCSNZciUpMlpCXBc+/o44IKB9rlJdklyCDAL+N4I6pOkjfIRRpI0YknOB44E9kmyCvhT4HRgcZITgFuAYwGqakWSxcC1wBrgpKpaO5LCJWkjDJmSNGJV9aYN7DpqA/3nA/P7q0iStp7T5ZIkSeqcIVOSJEmdM2RKkiSpc72GzCTHJLkhycokpw7Zf2iSy5I8nOSdfdYiSZKkidPbjT/tWrpnAS+nea7bsiRLquragW73Ae/g8eXSJEmStA3o80zmEcDKqrqxqh4BFtGsufuYqrq7qpYB/95jHZIkSZpgfYZM19eVJEnaTvUZMse9vu4mD+T6u5IkSVNKnyFz3Ovrborr70qSJE0tfYbMZcCsJIck2RmYS7PmriRJkrZxvd1dXlVrkpwMXAhMA85p19w9sd2/IMn+wHJgT+DRJH8AHFZVD/RVlyRJkvrX69rlVbUUWDqmbcHA6ztpptElSZK0DXHFH0mSJHXOkClJkqTOGTIlSZLUOUOmJEmSOmfIlCRJUucMmZIkSeqcIVOSJEmdM2RKkiSpc4ZMSZIkdc6QKUmSpM4ZMiVJktQ5Q6YkSZI6Z8iUJElS5wyZkiRJ6pwhU5IkSZ0z
ZEqSJKlzhkxJkiR1zpApSZKkzhkyJUmS1DlDpiRJkjpnyJQkSVLnDJmSJEnqnCFTkiRJnTNkSpIkqXO9hswkxyS5IcnKJKcO2Z8kf93uvyrJ8/qsR5K2FZsaXyVp1HoLmUmmAWcBrwQOA96U5LAx3V4JzGq/5gF/01c9krStGOf4Kkkj1eeZzCOAlVV1Y1U9AiwC5ozpMwf4ZDUuB/ZKckCPNUnStmA846skjVSfIfNA4NaB7VVt2+b2kSQ9kWOnpElvxx6PnSFttQV9SDKPZjod4KEkN2xlbduifYB7Rl1EX/KXx426hG3JNv27wp8OG1bG5ee7LKNnjp3d2ab/PTh2dmqb/l3pY+zsM2SuAg4a2J4B3L4FfaiqhcDCrgvcliRZXlWzR12HJj9/V7YJjp0d8d+Dxsvflc3X53T5MmBWkkOS7AzMBZaM6bMEeFt7l/kLgR9X1R091iRJ24LxjK+SNFK9ncmsqjVJTgYuBKYB51TViiQntvsXAEuBVwErgX8DfqeveiRpW7Gh8XXEZUnSE6Rqvct4NAUlmddOjUkb5e+K9Dj/PWi8/F3ZfIZMSZIkdc5lJSVJktQ5Q+YkkeQdSa5L8pmejv++JO/s49ia2pIcmeQro65D2hKOnRoVx85N6/MRRto8vwe8sqpuGnUhkjSFOHZKk5RnMieBJAuAXwCWJPnvSc5JsizJvySZ0/Y5PsmXknw5yU1JTk7y39o+lyfZu+33n9v3fj/JF5LsOuTznprkq0muSPKtJIdO7E+sriU5OMn1ST6e5Jokn0nysiSXJvlBkiPar++0vzPfSfKMIcfZbdjvnzQZOXZqazl29suQOQlU1Yk0D1J+KbAb8PWqekG7/eEku7Vdnwm8mWbd4vnAv1XVc4HLgLe1fb5YVS+oqmcD1wEnDPnIhcDbq+r5wDuBs/v5yTTBngacCTwLOJTmd+XFNP8fvwe4HnhJ+zvzXuAvhhzjv7Ph3z9pUnHsVEccO3vidPnk8wrgtQPXAD0JmNm+/qeqehB4MMmPgS+37VfT/OMAeGaSDwB7AbvTPEfvMUl2B34Z+LvksSWkdunh59DEu6mqrgZIsgK4uKoqydXAwcCTgfOSzKJZgnCnIcfY0O/fdX0XL20lx05tKcfOnhgyJ58Av1VVT1hjOMkvAQ8PND06sP0oj/9/eS7wuqr6fpLjgSPHHH8H4P6qek6nVWsy2NTvx/tp/mP7G0kOBi4Zcoyhv3/SFODYqS3l2NkTp8snnwuBt6f9UznJczfz/XsAdyTZCXjL2J1V9QBwU5Jj2+MnybO3smZNDU8GbmtfH7+BPlv7+yeNimOn+uLYuYUMmZPP+2lOxV+V5Jp2e3P8CfBd4CKa60iGeQtwQpLvAysAL1DePpwBfDDJpTRLEQ6ztb9/0qg4dqovjp1byBV/JEmS1DnPZEqSJKlzhkxJkiR1zpApSZKkzhkyJUmS1DlDpiRJkjpnyNSU1q5XvCLJVUmubB+8LEnaCMdOTQRX/NGUleRFwKuB51XVw0n2AXYecVmSNKk5dmqieCZTU9kBwD1V9TBAVd1TVbcneX6SbyS5IsmFSQ5I8uQkNyR5BkCS85P855FWL0mj4dipCeHD2DVlJdkd+DawK/CPwOeA7wDfAOZU1eokbwSOrqrfTfJy4M+BM4Hjq+qYEZUuSSPj2KmJ4nS5pqyqeijJ84FfBV5KM1B+AHgmcFG7hOw04I62/0XtusNnAa45LGm75NipieKZTG0zkrweOAl4UlW9aMj+HWj+Uj8EeFVVXTXBJUrSpOPYqb54TaamrCTPSDJroOk5wHXA9PbCdpLslOTwdv8ftvvfBJyTZKeJrFeSJgPHTk0Uz2Rqymqnez4K7AWsAVYC84AZwF8DT6a5JOSvaP4KvwA4oqoeTPIR4MGq+tOJr1ySRsexUxPFkClJkqTOOV0uSZKkzhkyJUmS1DlDpiRJkjpnyJQkSVLnDJmSJEnqnCFTkiRJnTNkSpIkqXOGTEmS
JHXu/wNqeVkkNDm5xAAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "import seaborn as sns\n", "\n", "fig,axs = plt.subplots(1,2, figsize=(11,5))\n", "sns.barplot(data=df, y=dep, x=\"Sex\", ax=axs[0]).set(title=\"Survival rate\")\n", "sns.countplot(data=df, x=\"Sex\", ax=axs[1]).set(title=\"Histogram\");" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here we see (on the left) that if we split the data into males and females, we'd have groups with very different survival rates: >70% for females, and <20% for males. We can also see (on the right) that the split would be reasonably even, with over 300 passengers (out of around 900) in each group.\n", "\n", "We could create a very simple \"model\" which says that all females survive, and no males do. Before trying it, we'd better split our data into a training and a validation set, so we can see how accurate this approach turns out to be:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2022-05-23T04:37:32.688753Z", "iopub.status.busy": "2022-05-23T04:37:32.688257Z", "iopub.status.idle": "2022-05-23T04:37:32.908453Z", "shell.execute_reply": "2022-05-23T04:37:32.907399Z", "shell.execute_reply.started": "2022-05-23T04:37:32.688718Z" } }, "outputs": [], "source": [ "from numpy import random\n", "from sklearn.model_selection import train_test_split\n", "\n", "random.seed(42)\n", "trn_df,val_df = train_test_split(df, test_size=0.25)\n", "trn_df[cats] = trn_df[cats].apply(lambda x: x.cat.codes)\n", "val_df[cats] = val_df[cats].apply(lambda x: x.cat.codes)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "(In the previous step we also replaced the categorical variables with their integer codes, since some of the models we'll be building in a moment require that.)\n", "\n", "Now we can create our independent variables (the `x` variables) and our dependent variable (the `y` variable):" ] }, { "cell_type": "code", 
"execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2022-05-23T04:37:33.154552Z", "iopub.status.busy": "2022-05-23T04:37:33.153932Z", "iopub.status.idle": "2022-05-23T04:37:33.164771Z", "shell.execute_reply": "2022-05-23T04:37:33.163736Z", "shell.execute_reply.started": "2022-05-23T04:37:33.154505Z" } }, "outputs": [], "source": [ "def xs_y(df):\n", " xs = df[cats+conts].copy()\n", " return xs,df[dep] if dep in df else None\n", "\n", "trn_xs,trn_y = xs_y(trn_df)\n", "val_xs,val_y = xs_y(val_df)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here are the predictions for our extremely simple model, where `female` is coded as `0`:" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2022-05-23T04:37:33.932259Z", "iopub.status.busy": "2022-05-23T04:37:33.931549Z", "iopub.status.idle": "2022-05-23T04:37:33.940681Z", "shell.execute_reply": "2022-05-23T04:37:33.93969Z", "shell.execute_reply.started": "2022-05-23T04:37:33.932218Z" } }, "outputs": [], "source": [ "preds = val_xs.Sex==0" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'll use mean absolute error to measure how good this model is. Since both the predictions and the targets are `0` or `1`, this is simply the fraction of validation passengers that the model gets wrong:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "execution": { "iopub.execute_input": "2022-05-23T04:37:34.580131Z", "iopub.status.busy": "2022-05-23T04:37:34.579605Z", "iopub.status.idle": "2022-05-23T04:37:34.588669Z", "shell.execute_reply": "2022-05-23T04:37:34.58794Z", "shell.execute_reply.started": "2022-05-23T04:37:34.580088Z" } }, "outputs": [ { "data": { "text/plain": [ "0.21524663677130046" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.metrics import mean_absolute_error\n", "mean_absolute_error(val_y, preds)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Alternatively, we could try splitting on a continuous column. 
We have to use a somewhat different chart to see how this might work -- here's an example of how we could look at `LogFare`:" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "execution": { "iopub.execute_input": "2022-05-23T04:37:36.079566Z", "iopub.status.busy": "2022-05-23T04:37:36.078932Z", "iopub.status.idle": "2022-05-23T04:37:36.428338Z", "shell.execute_reply": "2022-05-23T04:37:36.427385Z", "shell.execute_reply.started": "2022-05-23T04:37:36.079515Z" } }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "df_fare = trn_df[trn_df.LogFare>0]\n", "fig,axs = plt.subplots(1,2, figsize=(11,5))\n", "sns.boxenplot(data=df_fare, x=dep, y=\"LogFare\", ax=axs[0])\n", "sns.kdeplot(data=df_fare, x=\"LogFare\", ax=axs[1]);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The [boxenplot](https://seaborn.pydata.org/generated/seaborn.boxenplot.html) above shows quantiles of `LogFare` for each of the two groups, `Survived==0` and `Survived==1` (passengers with a `LogFare` of zero are excluded). It shows that the median `LogFare` for passengers who didn't survive is around `2.5`, and for those who did it's around `3.2`. So it seems that people who paid more for their tickets were more likely to get put on a lifeboat.\n", "\n", "Let's create a simple model based on this observation, using a threshold of `2.7`, which lies between those two values:" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "execution": { "iopub.execute_input": "2022-05-23T04:37:38.68033Z", "iopub.status.busy": "2022-05-23T04:37:38.680007Z", "iopub.status.idle": "2022-05-23T04:37:38.68637Z", "shell.execute_reply": "2022-05-23T04:37:38.685259Z", "shell.execute_reply.started": "2022-05-23T04:37:38.680295Z" } }, "outputs": [], "source": [ "preds = val_xs.LogFare>2.7" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "...and test it out:" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "execution": { "iopub.execute_input": "2022-05-23T04:37:40.107202Z", "iopub.status.busy": "2022-05-23T04:37:40.106891Z", "iopub.status.idle": "2022-05-23T04:37:40.114544Z", "shell.execute_reply": "2022-05-23T04:37:40.113805Z", "shell.execute_reply.started": "2022-05-23T04:37:40.107167Z" } }, "outputs": [ { "data": { "text/plain": [ "0.336322869955157" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mean_absolute_error(val_y, preds)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is quite a bit less accurate than our model that used `Sex` 
as the single binary split.\n", "\n", "Ideally, we'd like an easier way to try out more columns and breakpoints. We could write a function that returns how good our model is, so that we can quickly try out a few different splits. We'll create a `score` function to do this. Instead of returning the mean absolute error, we'll calculate a measure of *impurity* -- that is, how dissimilar the rows within each of the two groups created by the split are. A good split produces two groups whose rows are each similar to one another, so lower impurity is better.\n", "\n", "We can measure the similarity of rows inside a group by taking the standard deviation of the dependent variable: the higher it is, the more the rows differ from each other. We'll then multiply this by the number of rows, since a bigger group has more impact than a smaller group (a group with one row or fewer is simply given a score of `0`):" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "execution": { "iopub.execute_input": "2022-05-23T04:37:40.630562Z", "iopub.status.busy": "2022-05-23T04:37:40.62976Z", "iopub.status.idle": "2022-05-23T04:37:40.636571Z", "shell.execute_reply": "2022-05-23T04:37:40.635377Z", "shell.execute_reply.started": "2022-05-23T04:37:40.630514Z" } }, "outputs": [], "source": [ "def _side_score(side, y):\n", " tot = side.sum()  # number of rows on this side of the split\n", " if tot<=1: return 0  # too few rows to compute a standard deviation\n", " return y[side].std()*tot  # spread of the target, weighted by group size" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now that we've got that written, we can calculate the score for a split by adding up the scores for the \"left hand side\" (lhs) and \"right hand side\" (rhs), and dividing by the total number of rows:" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "execution": { "iopub.execute_input": "2022-05-23T04:37:41.160896Z", "iopub.status.busy": "2022-05-23T04:37:41.160555Z", "iopub.status.idle": "2022-05-23T04:37:41.166645Z", "shell.execute_reply": "2022-05-23T04:37:41.165803Z", "shell.execute_reply.started": "2022-05-23T04:37:41.160861Z" } }, "outputs": [], "source": [ " \n", "def score(col, y, split):\n", " lhs = col<=split  # boolean mask selecting the left-hand-side rows\n", " return (_side_score(lhs,y) + 
_side_score(~lhs,y))/len(y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For instance, here's the impurity score for the split on `Sex`:" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "execution": { "iopub.execute_input": "2022-05-23T04:37:42.530883Z", "iopub.status.busy": "2022-05-23T04:37:42.529881Z", "iopub.status.idle": "2022-05-23T04:37:42.542981Z", "shell.execute_reply": "2022-05-23T04:37:42.541924Z", "shell.execute_reply.started": "2022-05-23T04:37:42.530826Z" } }, "outputs": [ { "data": { "text/plain": [ "0.40787530982063946" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "score(trn_xs[\"Sex\"], trn_y, 0.5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "...and for `LogFare`:" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "execution": { "iopub.execute_input": "2022-05-23T04:37:43.657577Z", "iopub.status.busy": "2022-05-23T04:37:43.656642Z", "iopub.status.idle": "2022-05-23T04:37:43.665306Z", "shell.execute_reply": "2022-05-23T04:37:43.66471Z", "shell.execute_reply.started": "2022-05-23T04:37:43.657534Z" } }, "outputs": [ { "data": { "text/plain": [ "0.47180873952099694" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "score(trn_xs[\"LogFare\"], trn_y, 2.7)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As we'd expect from our earlier tests, `Sex` appears to be a better split, since it gives the lower impurity score.\n", "\n", "To make it easier to find the best binary split, we can create a simple interactive tool (note that this only works in Kaggle if you click \"Copy and Edit\" in the top right to open the notebook editor):" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "execution": { "iopub.execute_input": "2022-05-23T04:37:44.159662Z", "iopub.status.busy": "2022-05-23T04:37:44.159112Z", "iopub.status.idle": "2022-05-23T04:37:44.20316Z", "shell.execute_reply": "2022-05-23T04:37:44.202647Z", 
"shell.execute_reply.started": "2022-05-23T04:37:44.159612Z" } }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "cca27efa8e1643e990264062f595df43", "version_major": 2, "version_minor": 0 }, "text/plain": [ "interactive(children=(Dropdown(description='nm', options=('Age', 'SibSp', 'Parch', 'LogFare', 'Pclass'), value…" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "def iscore(nm, split):\n", " col = trn_xs[nm]\n", " return score(col, trn_y, split)\n", "\n", "from ipywidgets import interact\n", "interact(nm=conts, split=15.5)(iscore);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Try selecting different columns and split points using the dropdown and slider above. What splits can you find that increase the purity of the data?\n", "\n", "We can do the same thing for the categorical variables:" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "execution": { "iopub.execute_input": "2022-05-23T04:37:45.542149Z", "iopub.status.busy": "2022-05-23T04:37:45.541373Z", "iopub.status.idle": "2022-05-23T04:37:45.587574Z", "shell.execute_reply": "2022-05-23T04:37:45.58671Z", "shell.execute_reply.started": "2022-05-23T04:37:45.542101Z" } }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "d6078c0b725e464998e37b5c00c5de08", "version_major": 2, "version_minor": 0 }, "text/plain": [ "interactive(children=(Dropdown(description='nm', options=('Sex', 'Embarked'), value='Sex'), IntSlider(value=2,…" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "interact(nm=cats, split=2)(iscore);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "That works well enough, but it's rather slow and fiddly. Perhaps we could get the computer to automatically find the best split point for a column for us? 
For example, to find the best split point for `Age` we'd first need to make a list of all the possible split points (i.e. all the unique values of that field)...:" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "execution": { "iopub.execute_input": "2022-05-23T04:37:47.225685Z", "iopub.status.busy": "2022-05-23T04:37:47.225336Z", "iopub.status.idle": "2022-05-23T04:37:47.234728Z", "shell.execute_reply": "2022-05-23T04:37:47.233526Z", "shell.execute_reply.started": "2022-05-23T04:37:47.225646Z" } }, "outputs": [ { "data": { "text/plain": [ "array([ 0.42, 0.67, 0.75, 0.83, 0.92, 1. , 2. , 3. , 4. , 5. , 6. , 7. , 8. , 9. , 10. , 11. , 12. ,\n", " 13. , 14. , 14.5 , 15. , 16. , 17. , 18. , 19. , 20. , 21. , 22. , 23. , 24. , 24.5 , 25. , 26. , 27. ,\n", " 28. , 28.5 , 29. , 30. , 31. , 32. , 32.5 , 33. , 34. , 34.5 , 35. , 36. , 36.5 , 37. , 38. , 39. , 40. ,\n", " 40.5 , 41. , 42. , 43. , 44. , 45. , 45.5 , 46. , 47. , 48. , 49. , 50. , 51. , 52. , 53. , 54. , 55. ,\n", " 55.5 , 56. , 57. , 58. , 59. , 60. , 61. , 62. , 64. , 65. , 70. , 70.5 , 74. , 80. 
])" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "nm = \"Age\"\n", "col = trn_xs[nm]\n", "unq = col.unique()\n", "unq.sort()\n", "unq" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "...and find which index of those values is where `score()` is the lowest:" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "execution": { "iopub.execute_input": "2022-05-23T04:37:49.070221Z", "iopub.status.busy": "2022-05-23T04:37:49.069942Z", "iopub.status.idle": "2022-05-23T04:37:49.150766Z", "shell.execute_reply": "2022-05-23T04:37:49.14992Z", "shell.execute_reply.started": "2022-05-23T04:37:49.070191Z" } }, "outputs": [ { "data": { "text/plain": [ "6.0" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "scores = np.array([score(col, trn_y, o) for o in unq if not np.isnan(o)])\n", "unq[scores.argmin()]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Based on this, it looks like `6` is the optimal cutoff for the `Age` column, according to our training set.\n", "\n", "We can write a little function that implements this idea:" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "execution": { "iopub.execute_input": "2022-05-23T04:37:50.094341Z", "iopub.status.busy": "2022-05-23T04:37:50.093867Z", "iopub.status.idle": "2022-05-23T04:37:50.175498Z", "shell.execute_reply": "2022-05-23T04:37:50.174679Z", "shell.execute_reply.started": "2022-05-23T04:37:50.094305Z" } }, "outputs": [ { "data": { "text/plain": [ "(6.0, 0.478316717508991)" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def min_col(df, nm):\n", " col,y = df[nm],df[dep]\n", " unq = col.dropna().unique()\n", " scores = np.array([score(col, y, o) for o in unq if not np.isnan(o)])\n", " idx = scores.argmin()\n", " return unq[idx],scores[idx]\n", "\n", "min_col(trn_df, \"Age\")" ] }, { "cell_type": "markdown", "metadata": 
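{}, "source": [ "As a quick, self-contained check of this search, here is a sketch on a tiny made-up dataset (the data and the `best_split` helper are hypothetical; `_side_score` and `score` are copied from above). When one threshold separates the two outcomes perfectly, both sides have a standard deviation of zero, so that threshold scores `0` and is the one selected:

```python
import numpy as np
import pandas as pd

# Same impurity measure as this notebook: std of the target on each side of
# the split, weighted by that side's row count, divided by the total rows.
def _side_score(side, y):
    tot = side.sum()
    if tot <= 1: return 0
    return y[side].std() * tot

def score(col, y, split):
    lhs = col <= split
    return (_side_score(lhs, y) + _side_score(~lhs, y)) / len(y)

def best_split(col, y):
    # Hypothetical helper: try every unique value as a threshold, keep the lowest score
    unq = np.sort(col.dropna().unique())
    scores = np.array([score(col, y, o) for o in unq])
    idx = scores.argmin()
    return unq[idx], scores[idx]

# Made-up toy data: everyone aged 10 or under survives, nobody older does
age = pd.Series([2., 5., 8., 10., 25., 30., 40., 60.])
survived = pd.Series([1, 1, 1, 1, 0, 0, 0, 0])
print(best_split(age, survived))  # threshold 10.0 with a score of 0.0
```

This is the same idea as `min_col`, just run on data small enough to check by hand." ] }, { "cell_type": "markdown", "metadata": 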
{}, "source": [ "Let's try all the columns:" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "execution": { "iopub.execute_input": "2022-05-23T04:37:51.295064Z", "iopub.status.busy": "2022-05-23T04:37:51.294528Z", "iopub.status.idle": "2022-05-23T04:37:51.577522Z", "shell.execute_reply": "2022-05-23T04:37:51.57682Z", "shell.execute_reply.started": "2022-05-23T04:37:51.29503Z" } }, "outputs": [ { "data": { "text/plain": [ "{'Sex': (0, 0.40787530982063946),\n", " 'Embarked': (0, 0.47883342573147836),\n", " 'Age': (6.0, 0.478316717508991),\n", " 'SibSp': (4, 0.4783740258817434),\n", " 'Parch': (0, 0.4805296527841601),\n", " 'LogFare': (2.4390808375825834, 0.4620823937736597),\n", " 'Pclass': (2, 0.46048261885806596)}" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "cols = cats+conts\n", "{o:min_col(trn_df, o) for o in cols}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "According to this, `Sex<=0` is the best split we can use.\n", "\n", "We've just re-invented the [OneR](https://link.springer.com/article/10.1023/A:1022631118932) classifier (or at least, a minor variant of it), which was found to be one of the most effective classifiers in real-world datasets, compared to the algorithms in use in 1993. Since it's so simple and surprisingly effective, it makes for a great *baseline* -- that is, a starting point that you can use to compare your more sophisticated models to.\n", "\n", "We found earlier that our OneR rule had an error of around `0.215`, so we'll keep that in mind as we try out more sophisticated approaches." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Creating a decision tree" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "How can we improve our OneR classifier, which predicts survival based only on `Sex`?\n", "\n", "How about we take each of our two groups, `female` and `male`, and create one more binary split for each of them. 
That is: find the single best split for females, and the single best split for males. To do this, all we have to do is repeat the previous section's steps, once for males, and once for females.\n", "\n", "First, we'll remove `Sex` from the list of possible splits (since we've already used it, and there's only one possible split for that binary column), and create our two groups:" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "execution": { "iopub.execute_input": "2022-05-23T04:37:52.822448Z", "iopub.status.busy": "2022-05-23T04:37:52.821912Z", "iopub.status.idle": "2022-05-23T04:37:52.829494Z", "shell.execute_reply": "2022-05-23T04:37:52.82831Z", "shell.execute_reply.started": "2022-05-23T04:37:52.822405Z" } }, "outputs": [], "source": [ "cols.remove(\"Sex\")\n", "ismale = trn_df.Sex==1\n", "males,females = trn_df[ismale],trn_df[~ismale]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let's find the single best binary split for males...:" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "execution": { "iopub.execute_input": "2022-05-23T04:37:54.372844Z", "iopub.status.busy": "2022-05-23T04:37:54.372319Z", "iopub.status.idle": "2022-05-23T04:37:54.585566Z", "shell.execute_reply": "2022-05-23T04:37:54.584693Z", "shell.execute_reply.started": "2022-05-23T04:37:54.372801Z" } }, "outputs": [ { "data": { "text/plain": [ "{'Embarked': (0, 0.3875581870410906),\n", " 'Age': (6.0, 0.3739828371010595),\n", " 'SibSp': (4, 0.3875864227586273),\n", " 'Parch': (0, 0.3874704821461959),\n", " 'LogFare': (2.803360380906535, 0.3804856231758151),\n", " 'Pclass': (1, 0.38155442004360934)}" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "{o:min_col(males, o) for o in cols}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "...and for females:" ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "execution": { "iopub.execute_input": "2022-05-23T04:37:55.970401Z", 
"iopub.status.busy": "2022-05-23T04:37:55.969784Z", "iopub.status.idle": "2022-05-23T04:37:56.171762Z", "shell.execute_reply": "2022-05-23T04:37:56.170711Z", "shell.execute_reply.started": "2022-05-23T04:37:55.970348Z" } }, "outputs": [ { "data": { "text/plain": [ "{'Embarked': (0, 0.4295252982857327),\n", " 'Age': (50.0, 0.4225927658431649),\n", " 'SibSp': (4, 0.42319212059713535),\n", " 'Parch': (3, 0.4193314500446158),\n", " 'LogFare': (4.256321678298823, 0.41350598332911376),\n", " 'Pclass': (2, 0.3335388911567601)}" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "{o:min_col(females, o) for o in cols}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can see that the best next binary split for males is `Age<=6`, and for females is `Pclass<=2`.\n", "\n", "By adding these rules, we have created a *decision tree*, where our model will first check whether `Sex` is female or male, and depending on the result will then check either the above `Age` or `Pclass` rules, as appropriate. 
We could then repeat the process, creating additional rules for each of the four groups we've now created.\n", "\n", "Rather than writing that code manually, we can use `DecisionTreeClassifier`, from *sklearn*, which does exactly that for us:" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "execution": { "iopub.execute_input": "2022-05-23T04:37:57.022761Z", "iopub.status.busy": "2022-05-23T04:37:57.022405Z", "iopub.status.idle": "2022-05-23T04:37:57.201325Z", "shell.execute_reply": "2022-05-23T04:37:57.200214Z", "shell.execute_reply.started": "2022-05-23T04:37:57.022724Z" } }, "outputs": [], "source": [ "from sklearn.tree import DecisionTreeClassifier, export_graphviz\n", "\n", "m = DecisionTreeClassifier(max_leaf_nodes=4).fit(trn_xs, trn_y);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "One handy feature of sklearn's tree module is that it also provides a function, `export_graphviz`, for drawing a tree representing the rules. Here's a small `draw_tree` helper that wraps it:" ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "execution": { "iopub.execute_input": "2022-05-23T04:37:57.253844Z", "iopub.status.busy": "2022-05-23T04:37:57.253469Z", "iopub.status.idle": "2022-05-23T04:37:57.274531Z", "shell.execute_reply": "2022-05-23T04:37:57.27374Z", "shell.execute_reply.started": "2022-05-23T04:37:57.253805Z" } }, "outputs": [], "source": [ "import graphviz\n", "\n", "def draw_tree(t, df, size=10, ratio=0.6, precision=2, **kwargs):\n", " s=export_graphviz(t, out_file=None, feature_names=df.columns, filled=True, rounded=True,\n", " special_characters=True, rotate=False, precision=precision, **kwargs)\n", " return graphviz.Source(re.sub('Tree {', f'Tree {{ size={size}; ratio={ratio}', s))" ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "execution": { "iopub.execute_input": "2022-05-23T04:37:58.207169Z", "iopub.status.busy": "2022-05-23T04:37:58.206306Z", "iopub.status.idle": "2022-05-23T04:37:59.286087Z", 
"shell.execute_reply": "2022-05-23T04:37:59.284841Z", 
"shell.execute_reply.started": "2022-05-23T04:37:58.207112Z" } }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "Tree\n", "\n", "\n", "\n", "0\n", "\n", "Sex ≤ 0.5\n", "gini = 0.47\n", "samples = 668\n", "value = [415, 253]\n", "\n", "\n", "\n", "1\n", "\n", "Pclass ≤ 2.5\n", "gini = 0.38\n", "samples = 229\n", "value = [59, 170]\n", "\n", "\n", "\n", "0->1\n", "\n", "\n", "True\n", "\n", "\n", "\n", "2\n", "\n", "Age ≤ 6.5\n", "gini = 0.31\n", "samples = 439\n", "value = [356, 83]\n", "\n", "\n", "\n", "0->2\n", "\n", "\n", "False\n", "\n", "\n", "\n", "3\n", "\n", "gini = 0.06\n", "samples = 120\n", "value = [4, 116]\n", "\n", "\n", "\n", "1->3\n", "\n", "\n", "\n", "\n", "\n", "4\n", "\n", "gini = 0.5\n", "samples = 109\n", "value = [55, 54]\n", "\n", "\n", "\n", "1->4\n", "\n", "\n", "\n", "\n", "\n", "5\n", "\n", "gini = 0.41\n", "samples = 21\n", "value = [6, 15]\n", "\n", "\n", "\n", "2->5\n", "\n", "\n", "\n", "\n", "\n", "6\n", "\n", "gini = 0.27\n", "samples = 418\n", "value = [350, 68]\n", "\n", "\n", "\n", "2->6\n", "\n", "\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "draw_tree(m, trn_xs, size=10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can see that it's found exactly the same splits as we did!\n", "\n", "In this picture, the more orange nodes have a lower survival rate, and blue have higher survival. Each node shows how many rows (\"*samples*\") match that set of rules, and shows how many perish or survive (\"*values*\"). There's also something called \"*gini*\". That's another measure of impurity, and it's very similar to the `score()` we created earlier. 
It's defined as follows:" ] }, { "cell_type": "code", "execution_count": 31, "metadata": { "execution": { "iopub.execute_input": "2022-05-23T04:37:59.76912Z", "iopub.status.busy": "2022-05-23T04:37:59.76879Z", "iopub.status.idle": "2022-05-23T04:37:59.776055Z", "shell.execute_reply": "2022-05-23T04:37:59.774896Z", "shell.execute_reply.started": "2022-05-23T04:37:59.769086Z" } }, "outputs": [], "source": [ "def gini(cond):\n", " act = df.loc[cond, dep]\n", " return 1 - act.mean()**2 - (1-act).mean()**2" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "What this calculates is the probability that, if you pick two rows from a group at random (with replacement), they'll have *different* `Survived` results. If everyone in the group has the same outcome, the gini is `0.0`; for a two-class target like ours, it reaches its maximum of `0.5` when the group is split evenly. So, just like our `score()`, lower is better:" ] }, { "cell_type": "code", "execution_count": 32, "metadata": { "execution": { "iopub.execute_input": "2022-05-23T04:38:01.166595Z", "iopub.status.busy": "2022-05-23T04:38:01.166139Z", "iopub.status.idle": "2022-05-23T04:38:01.177753Z", "shell.execute_reply": "2022-05-23T04:38:01.176747Z", "shell.execute_reply.started": "2022-05-23T04:38:01.166562Z" } }, "outputs": [ { "data": { "text/plain": [ "(0.3828350034484158, 0.3064437162277842)" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "gini(df.Sex=='female'), gini(df.Sex=='male')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's see how this model compares to our OneR version:" ] }, { "cell_type": "code", "execution_count": 33, "metadata": { "execution": { "iopub.execute_input": "2022-05-23T04:38:02.057241Z", "iopub.status.busy": "2022-05-23T04:38:02.056429Z", "iopub.status.idle": "2022-05-23T04:38:02.068392Z", "shell.execute_reply": "2022-05-23T04:38:02.067471Z", "shell.execute_reply.started": "2022-05-23T04:38:02.057192Z" } }, 
"outputs": [ { "data": { "text/plain": [ "0.2242152466367713" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], 
"source": [ "mean_absolute_error(val_y, m.predict(val_xs))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It's a tiny bit worse. Since this is such a small dataset (we've only got around 200 rows in our validation set) this small difference isn't really meaningful. Perhaps we'll see better results if we create a bigger tree:" ] }, { "cell_type": "code", "execution_count": 49, "metadata": { "execution": { "iopub.execute_input": "2022-05-23T04:38:02.926983Z", "iopub.status.busy": "2022-05-23T04:38:02.926216Z", "iopub.status.idle": "2022-05-23T04:38:02.981361Z", "shell.execute_reply": "2022-05-23T04:38:02.98003Z", "shell.execute_reply.started": "2022-05-23T04:38:02.926945Z" } }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "Tree\n", "\n", "\n", "\n", "0\n", "\n", "Sex ≤ 0.5\n", "gini = 0.47\n", "samples = 668\n", "value = [415, 253]\n", "\n", "\n", "\n", "1\n", "\n", "Pclass ≤ 2.5\n", "gini = 0.38\n", "samples = 229\n", "value = [59, 170]\n", "\n", "\n", "\n", "0->1\n", "\n", "\n", "True\n", "\n", "\n", "\n", "8\n", "\n", "LogFare ≤ 3.31\n", "gini = 0.31\n", "samples = 439\n", "value = [356, 83]\n", "\n", "\n", "\n", "0->8\n", "\n", "\n", "False\n", "\n", "\n", "\n", "2\n", "\n", "Age ≤ 28.5\n", "gini = 0.06\n", "samples = 120\n", "value = [4, 116]\n", "\n", "\n", "\n", "1->2\n", "\n", "\n", "\n", "\n", "\n", "5\n", "\n", "LogFare ≤ 2.7\n", "gini = 0.5\n", "samples = 109\n", "value = [55, 54]\n", "\n", "\n", "\n", "1->5\n", "\n", "\n", "\n", "\n", "\n", "3\n", "\n", "gini = 0.11\n", "samples = 53\n", "value = [3, 50]\n", "\n", "\n", "\n", "2->3\n", "\n", "\n", "\n", "\n", "\n", "4\n", "\n", "gini = 0.03\n", "samples = 67\n", "value = [1, 66]\n", "\n", "\n", "\n", "2->4\n", "\n", "\n", "\n", "\n", "\n", "6\n", "\n", "gini = 0.49\n", "samples = 59\n", "value = [25, 34]\n", "\n", "\n", "\n", "5->6\n", "\n", "\n", "\n", "\n", "\n", "7\n", "\n", "gini = 0.48\n", "samples = 50\n", "value = [30, 20]\n", "\n", "\n", "\n", 
"5->7\n", "\n", "\n", "\n", "\n", "\n", "9\n", "\n", "Age ≤ 20.5\n", "gini = 0.24\n", "samples = 320\n", "value = [275, 45]\n", "\n", "\n", "\n", "8->9\n", "\n", "\n", "\n", "\n", "\n", "18\n", "\n", "SibSp ≤ 0.5\n", "gini = 0.43\n", "samples = 119\n", "value = [81, 38]\n", "\n", "\n", "\n", "8->18\n", "\n", "\n", "\n", "\n", "\n", "10\n", "\n", "gini = 0.43\n", "samples = 55\n", "value = [38, 17]\n", "\n", "\n", "\n", "9->10\n", "\n", "\n", "\n", "\n", "\n", "11\n", "\n", "Age ≤ 32.5\n", "gini = 0.19\n", "samples = 265\n", "value = [237, 28]\n", "\n", "\n", "\n", "9->11\n", "\n", "\n", "\n", "\n", "\n", "12\n", "\n", "Age ≤ 24.75\n", "gini = 0.22\n", "samples = 181\n", "value = [158, 23]\n", "\n", "\n", "\n", "11->12\n", "\n", "\n", "\n", "\n", "\n", "17\n", "\n", "gini = 0.11\n", "samples = 84\n", "value = [79, 5]\n", "\n", "\n", "\n", "11->17\n", "\n", "\n", "\n", "\n", "\n", "13\n", "\n", "LogFare ≤ 2.18\n", "gini = 0.16\n", "samples = 114\n", "value = [104, 10]\n", "\n", "\n", "\n", "12->13\n", "\n", "\n", "\n", "\n", "\n", "16\n", "\n", "gini = 0.31\n", "samples = 67\n", "value = [54, 13]\n", "\n", "\n", "\n", "12->16\n", "\n", "\n", "\n", "\n", "\n", "14\n", "\n", "gini = 0.21\n", "samples = 50\n", "value = [44, 6]\n", "\n", "\n", "\n", "13->14\n", "\n", "\n", "\n", "\n", "\n", "15\n", "\n", "gini = 0.12\n", "samples = 64\n", "value = [60, 4]\n", "\n", "\n", "\n", "13->15\n", "\n", "\n", "\n", "\n", "\n", "19\n", "\n", "gini = 0.48\n", "samples = 60\n", "value = [36, 24]\n", "\n", "\n", "\n", "18->19\n", "\n", "\n", "\n", "\n", "\n", "20\n", "\n", "gini = 0.36\n", "samples = 59\n", "value = [45, 14]\n", "\n", "\n", "\n", "18->20\n", "\n", "\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 49, "metadata": {}, "output_type": "execute_result" } ], "source": [ "m = DecisionTreeClassifier(min_samples_leaf=50)\n", "m.fit(trn_xs, trn_y)\n", "draw_tree(m, trn_xs, size=25)" ] }, { "cell_type": "code", "execution_count": 35, "metadata": { 
"execution": { "iopub.execute_input": "2022-05-23T04:38:03.953946Z", "iopub.status.busy": "2022-05-23T04:38:03.953608Z", "iopub.status.idle": "2022-05-23T04:38:03.965301Z", "shell.execute_reply": "2022-05-23T04:38:03.964474Z", "shell.execute_reply.started": "2022-05-23T04:38:03.953912Z" } }, "outputs": [ { "data": { "text/plain": [ "0.18385650224215247" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mean_absolute_error(val_y, m.predict(val_xs))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It looks like this is an improvement, although again it's a bit hard to tell with small datasets like this. Let's try submitting it to Kaggle:" ] }, { "cell_type": "code", "execution_count": 36, "metadata": { "execution": { "iopub.execute_input": "2022-05-23T04:38:04.158937Z", "iopub.status.busy": "2022-05-23T04:38:04.158557Z", "iopub.status.idle": "2022-05-23T04:38:04.183949Z", "shell.execute_reply": "2022-05-23T04:38:04.182841Z", "shell.execute_reply.started": "2022-05-23T04:38:04.158902Z" } }, "outputs": [], "source": [ "tst_df[cats] = tst_df[cats].apply(lambda x: x.cat.codes)\n", "tst_xs,_ = xs_y(tst_df)\n", "\n", "def subm(preds, suff):\n", " tst_df['Survived'] = preds\n", " sub_df = tst_df[['PassengerId','Survived']]\n", " sub_df.to_csv(f'sub-{suff}.csv', index=False)\n", "\n", "subm(m.predict(tst_xs), 'tree')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "When I submitted this, I got a score of 0.765, which isn't as good as our linear models or most of our neural nets, but it's pretty close to those results.\n", "\n", "Hopefully you can now see why we didn't really need to create dummy variables, but instead just converted the labels into numbers using some (potentially arbitary) ordering of categories. 
For instance, here's how the first few items of `Embarked` are labeled:" ] }, { "cell_type": "code", "execution_count": 37, "metadata": { "execution": { "iopub.execute_input": "2022-05-23T04:40:33.30121Z", "iopub.status.busy": "2022-05-23T04:40:33.300898Z", "iopub.status.idle": "2022-05-23T04:40:33.310952Z", "shell.execute_reply": "2022-05-23T04:40:33.309927Z", "shell.execute_reply.started": "2022-05-23T04:40:33.301175Z" } }, "outputs": [ { "data": { "text/plain": [ "0 S\n", "1 C\n", "2 S\n", "3 S\n", "4 S\n", "Name: Embarked, dtype: category\n", "Categories (3, object): ['C', 'Q', 'S']" ] }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.Embarked.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "...resulting in these integer codes:" ] }, { "cell_type": "code", "execution_count": 38, "metadata": { "execution": { "iopub.execute_input": "2022-05-23T04:41:09.343513Z", "iopub.status.busy": "2022-05-23T04:41:09.34246Z", "iopub.status.idle": "2022-05-23T04:41:09.351674Z", "shell.execute_reply": "2022-05-23T04:41:09.350699Z", "shell.execute_reply.started": "2022-05-23T04:41:09.343456Z" } }, "outputs": [ { "data": { "text/plain": [ "0 2\n", "1 0\n", "2 2\n", "3 2\n", "4 2\n", "dtype: int8" ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.Embarked.cat.codes.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "So let's say we wanted to split into \"C\" in one group, vs \"Q\" or \"S\" in the other group. Then we just have to split on codes `<=0` (since `C` is mapped to category `0`). Note that if we wanted to split into \"Q\" in one group, we'd need to use two binary splits, first to separate \"C\" from \"Q\" and \"S\", and then a second split to separate \"Q\" from \"S\". 
For this reason, sometimes it can still be helpful to use dummy variables for categorical variables with few levels (like this one).\n", "\n", "In practice, I often use dummy variables for <4 levels, and numeric codes for >=4 levels." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## The random forest" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can't make the decision tree much bigger than the example above, since some leaf nodes already have only 50 rows in them. That's not a lot of data on which to base a prediction.\n", "\n", "So how could we use bigger trees? One big insight came from Leo Breiman: what if we create lots of bigger trees, and take the average of their predictions? Taking the average prediction of a bunch of models in this way is known as [bagging](https://link.springer.com/article/10.1007/BF00058655).\n", "\n", "The idea is that we want each model's errors in the averaged ensemble to be uncorrelated with those of every other model. That way, if we average the predictions, the average will tend toward the true target value -- that's because, as long as the errors aren't biased in one direction, the average of lots of uncorrelated random errors gets closer and closer to zero as we add more models. That's quite an amazing insight!\n", "\n", "One way we can create a bunch of uncorrelated models is to train each of them on a different random subset of the data. 
Here's how we can create a tree on a random subset of the data:" ] }, { "cell_type": "code", "execution_count": 48, "metadata": { "execution": { "iopub.execute_input": "2022-05-23T04:51:20.332256Z", "iopub.status.busy": "2022-05-23T04:51:20.331981Z", "iopub.status.idle": "2022-05-23T04:51:20.338451Z", "shell.execute_reply": "2022-05-23T04:51:20.337562Z", "shell.execute_reply.started": "2022-05-23T04:51:20.332229Z" } }, "outputs": [], "source": [ "def get_tree(prop=0.75):\n", " n = len(trn_y)\n", " idxs = random.choice(n, int(n*prop))\n", " return DecisionTreeClassifier(min_samples_leaf=5).fit(trn_xs.iloc[idxs], trn_y.iloc[idxs])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we can create as many trees as we want:" ] }, { "cell_type": "code", "execution_count": 49, "metadata": { "execution": { "iopub.execute_input": "2022-05-23T04:51:23.113575Z", "iopub.status.busy": "2022-05-23T04:51:23.113271Z", "iopub.status.idle": "2022-05-23T04:51:23.392216Z", "shell.execute_reply": "2022-05-23T04:51:23.391179Z", "shell.execute_reply.started": "2022-05-23T04:51:23.113547Z" } }, "outputs": [], "source": [ "trees = [get_tree() for t in range(100)]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Our prediction will be the average of these trees' predictions:" ] }, { "cell_type": "code", "execution_count": 50, "metadata": { "execution": { "iopub.execute_input": "2022-05-23T04:51:24.13079Z", "iopub.status.busy": "2022-05-23T04:51:24.130425Z", "iopub.status.idle": "2022-05-23T04:51:24.27515Z", "shell.execute_reply": "2022-05-23T04:51:24.274245Z", "shell.execute_reply.started": "2022-05-23T04:51:24.130754Z" } }, "outputs": [ { "data": { "text/plain": [ "0.2272645739910314" ] }, "execution_count": 50, "metadata": {}, "output_type": "execute_result" } ], "source": [ "all_probs = [t.predict(val_xs) for t in trees]\n", "avg_probs = np.stack(all_probs).mean(0)\n", "\n", "mean_absolute_error(val_y, avg_probs)" ] }, { "cell_type": "markdown", "metadata": {}, 
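"source": [ "The claim that averaging uncorrelated errors drives them toward zero can be checked with a small simulation. This is a sketch on synthetic numbers, not on the Titanic data, and `averaged_rmse` and its parameters are hypothetical names: each simulated model makes an error with standard deviation `sigma`, and any two models' errors have correlation `rho`. The ensemble's error is the mean of the individual errors:

```python
import numpy as np

def averaged_rmse(n_models, rho, sigma=0.3, n_trials=2000, seed=0):
    # Each model's error = a term shared by all models + its own independent term,
    # scaled so every error has std sigma and any two errors have correlation rho.
    rng = np.random.default_rng(seed)
    shared = rng.normal(0, sigma, size=(n_trials, 1))
    own = rng.normal(0, sigma, size=(n_trials, n_models))
    errors = np.sqrt(rho) * shared + np.sqrt(1 - rho) * own
    # Averaging the models' predictions averages their errors
    ens_err = errors.mean(axis=1)
    return np.sqrt((ens_err ** 2).mean())

for n in (1, 10, 100):
    print(n, round(averaged_rmse(n, rho=0.0), 3), round(averaged_rmse(n, rho=0.8), 3))
```

With independent errors the ensemble's RMSE should shrink roughly like `sigma/sqrt(n)` (about `0.03` for 100 models), whereas with `rho=0.8` it should stay close to `sigma*sqrt(rho)` (about `0.27`) no matter how many models are averaged. That is why it matters that the trees are as uncorrelated as possible, and it motivates the extra randomness described next." ] }, { "cell_type": "markdown", "metadata": {}, 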
"source": [ "This is nearly identical to what `sklearn`'s `RandomForestClassifier` does. The main extra piece in a \"real\" random forest is that as well as choosing a random sample of data for each tree, it also picks a random subset of columns for each split. Here's how we repeat the above process with a random forest:" ] }, { "cell_type": "code", "execution_count": 51, "metadata": { "execution": { "iopub.execute_input": "2022-05-23T04:51:24.302789Z", "iopub.status.busy": "2022-05-23T04:51:24.302442Z", "iopub.status.idle": "2022-05-23T04:51:24.586127Z", "shell.execute_reply": "2022-05-23T04:51:24.584999Z", "shell.execute_reply.started": "2022-05-23T04:51:24.302754Z" } }, "outputs": [ { "data": { "text/plain": [ "0.18834080717488788" ] }, "execution_count": 51, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.ensemble import RandomForestClassifier\n", "\n", "rf = RandomForestClassifier(100, min_samples_leaf=5)\n", "rf.fit(trn_xs, trn_y);\n", "mean_absolute_error(val_y, rf.predict(val_xs))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can submit that to Kaggle too:" ] }, { "cell_type": "code", "execution_count": 52, "metadata": { "execution": { "iopub.execute_input": "2022-05-23T04:51:26.868302Z", "iopub.status.busy": "2022-05-23T04:51:26.867535Z", "iopub.status.idle": "2022-05-23T04:51:26.902301Z", "shell.execute_reply": "2022-05-23T04:51:26.901088Z", "shell.execute_reply.started": "2022-05-23T04:51:26.868252Z" } }, "outputs": [], "source": [ "subm(rf.predict(tst_xs), 'rf')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "I found that gave nearly an identical result as our single tree (which, in turn, was slightly lower than our linear and neural net models in the previous notebook)." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "One particularly nice feature of random forests is that they can tell us which independent variables were the most important in the model, using `feature_importances_`:" ] }, { "cell_type": "code", "execution_count": 53, "metadata": { "execution": { "iopub.execute_input": "2022-05-23T04:53:08.654508Z", "iopub.status.busy": "2022-05-23T04:53:08.654163Z", "iopub.status.idle": "2022-05-23T04:53:08.906439Z", "shell.execute_reply": "2022-05-23T04:53:08.904724Z", "shell.execute_reply.started": "2022-05-23T04:53:08.654445Z" } }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "pd.DataFrame(dict(cols=trn_xs.columns, imp=m.feature_importances_)).plot('cols', 'imp', 'barh');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can see that `Sex` is by far the most important predictor, with `Pclass` a distant second, and `LogFare` and `Age` behind that. In datasets with many columns, I generally recommend creating a feature importance plot as soon as possible, in order to find which columns are worth studying more closely. (Note also that we didn't really need to take the `log()` of `Fare`, since random forests only care about order, and `log()` doesn't change the order -- we only did it to make our graphs earlier easier to read.)\n", "\n", "For details about deriving and understanding feature importances, and the many other important diagnostic tools provided by random forests, take a look at [chapter 9](https://github.com/fastai/fastbook/blob/master/09_tabular.ipynb) of [our book](https://www.amazon.com/Deep-Learning-Coders-fastai-PyTorch/dp/1492045527)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Conclusion" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "So what can we take away from all this?\n", "\n", "I think the first thing I'd note from this is that, clearly, more complex models aren't always better. Our \"OneR\" model, consisting of a single binary split, was nearly as good as our more complex models. Perhaps in practice a simple model like this might be much easier to use, and could be worth considering. Our random forest wasn't an improvement on the single decision tree at all.\n", "\n", "So we should always be careful to benchmark simple models, and see if they're good enough for our needs. 
In practice, you will often find that simple models will have trouble providing adequate accuracy for more complex tasks, such as recommendation systems, NLP, computer vision, or multivariate time series. But there's no need to guess -- it's so easy to try a few different models, there's no reason not to give the simpler ones a go too!\n", "\n", "Another thing I think we can take away is that random forests aren't actually that complicated at all. We were able to implement the key features of them in a notebook quite quickly. And they aren't sensitive to issues like normalization, interactions, or non-linear transformations, which make them extremely easy to work with, and hard to mess up!" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you found this notebook useful, please remember to click the little up-arrow at the top to upvote it, since I like to know when people have found my work useful, and it helps others find it too. (BTW, be sure you're looking at my [original notebook here](https://www.kaggle.com/jhoward/how-random-forests-work) when you do that, and are not on your own copy of it, otherwise your upvote won't get counted!) And if you have any questions or comments, please pop them below -- I read every comment I receive!" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.10" }, "toc": { "base_numbering": 1, "nav_menu": { "height": "133.002px", "width": "196.553px" }, "number_sections": false, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 4 }