{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Implementation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This section demonstrates how to fit bagging, random forest, and boosting models using `scikit-learn`. We will again use the {doc}`penguins ` dataset for classification and the {doc}`tips ` dataset for regression." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "## Import packages\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "import pandas as pd" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Bagging and Random Forests" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Recall that bagging and random forests can handle both classification and regression tasks. For this example, we will do classification on the `penguins` dataset. Since `scikit-learn` trees do not currently support categorical predictors, we must first convert those predictors to dummy variables." ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "## Load penguins data\n", "penguins = sns.load_dataset('penguins')\n", "penguins = penguins.dropna().reset_index(drop = True)\n", "X = penguins.drop(columns = 'species')\n", "y = penguins['species']\n", "\n", "## Train-test split\n", "np.random.seed(1)\n", "test_frac = 0.25\n", "test_size = int(len(y)*test_frac)\n", "test_idxs = np.random.choice(np.arange(len(y)), test_size, replace = False)\n", "X_train = X.drop(test_idxs)\n", "y_train = y.drop(test_idxs)\n", "X_test = X.loc[test_idxs]\n", "y_test = y.loc[test_idxs]\n", "\n", "## Get dummies\n", "X_train = pd.get_dummies(X_train, drop_first = True)\n", "X_test = pd.get_dummies(X_test, drop_first = True)\n", "\n", "## Align test columns with training columns in case a category is missing from one split\n", "X_test = X_test.reindex(columns = X_train.columns, fill_value = 0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Bagging" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A simple bagging classifier is fit below. 
The most important arguments are `n_estimators` and `estimator` (named `base_estimator` in `scikit-learn` versions before 1.2), which determine the number and type of weak learners the bagging model uses. The default `estimator` is a decision tree, but this can be changed, as in the second example below, which uses Gaussian Naive Bayes estimators. " ] }, { "cell_type": "code", "execution_count": 52, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.963855421686747\n", "0.9156626506024096\n" ] } ], "source": [ "from sklearn.ensemble import BaggingClassifier\n", "from sklearn.naive_bayes import GaussianNB\n", "\n", "## Decision tree bagger (default estimator)\n", "bagger1 = BaggingClassifier(n_estimators = 50, random_state = 123)\n", "bagger1.fit(X_train, y_train)\n", "\n", "## Naive Bayes bagger\n", "bagger2 = BaggingClassifier(estimator = GaussianNB(), random_state = 123)\n", "bagger2.fit(X_train, y_train)\n", "\n", "## Evaluate test accuracy\n", "print(np.mean(bagger1.predict(X_test) == y_test))\n", "print(np.mean(bagger2.predict(X_test) == y_test))\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Random Forests" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "An example of a random forest in `scikit-learn` is given below. The most important arguments to the random forest are the number of estimators (decision trees), `max_features` (the number of predictors to consider at each split), and any chosen parameters for the decision trees (such as the maximum depth). Guidelines for setting each of these parameters are given below. \n", "\n", "- `n_estimators`: In general, more base estimators are better, though with diminishing marginal returns. Adding base estimators does not cause overfitting, but beyond some point it provides no further benefit. \n", "- `max_features`: For `RandomForestClassifier`, this argument defaults to the square root of the total number of features (which is made explicit in the example below). 
If this value equals the total number of features, the random forest reduces to a bagging model. Lowering this value reduces the correlation between trees, but it also prevents each split from considering potentially valuable predictors. \n", "- Decision tree parameters: These are generally left at their defaults, which let the individual trees grow deep. Deep trees have low bias but high variance, and averaging across the ensemble then reduces that variance.\n" ] }, { "cell_type": "code", "execution_count": 54, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.9879518072289156\n" ] } ], "source": [ "from sklearn.ensemble import RandomForestClassifier\n", "\n", "## max_features = square root of the number of predictors (the classifier default)\n", "rf = RandomForestClassifier(n_estimators = 100, max_features = int(np.sqrt(X_train.shape[1])), random_state = 123)\n", "rf.fit(X_train, y_train)\n", "print(np.mean(rf.predict(X_test) == y_test))\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Boosting" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```{note}\n", "The `AdaBoostClassifier` from `scikit-learn` uses a slightly different algorithm than the one introduced in the {doc}`concept section `, though results should be similar. The `AdaBoostRegressor` class uses the same algorithm we introduced: *AdaBoost.R2*.\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### AdaBoost Classification" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `AdaBoostClassifier` in `scikit-learn` can handle multiclass target variables directly, but for consistency, let's use the same binary target we used in our AdaBoost construction: whether the penguin's species is *Adelie*." 
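] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Before binarizing, here is a minimal sketch of the multiclass case: an `AdaBoostClassifier` fit directly to the three-class `species` target. It assumes `X_train`, `X_test`, `y_train`, and `y_test` still hold the dummy-encoded predictors and species labels from the bagging section. The name `abc_multi` and the `random_state` value are illustrative choices, not part of the original example." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sklearn.ensemble import AdaBoostClassifier\n", "\n", "## Fit to the original three-class target; no binary encoding is needed\n", "abc_multi = AdaBoostClassifier(n_estimators = 50, random_state = 123)\n", "abc_multi.fit(X_train, y_train)\n", "\n", "## The fitted model stores every class label it saw during training\n", "print(abc_multi.classes_)\n", "\n", "## Test accuracy on the multiclass problem\n", "print(np.mean(abc_multi.predict(X_test) == y_test))" 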
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "## Make binary\n", "y_train = (y_train == 'Adelie')\n", "y_test = (y_test == 'Adelie')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can then fit the classifier with the `AdaBoostClassifier` class as below. The predictors were already converted to dummy variables above, so no further preprocessing is needed. By default, the classifier uses 50 decision trees, each with a maximum depth of 1 (decision stumps), as its weak learners. \n" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.9759036144578314" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.ensemble import AdaBoostClassifier\n", "\n", "## Build model (50 depth-1 trees by default)\n", "abc = AdaBoostClassifier(n_estimators = 50)\n", "abc.fit(X_train, y_train)\n", "y_test_hat = abc.predict(X_test)\n", "\n", "## Evaluate test accuracy\n", "np.mean(y_test_hat == y_test)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A different weak learner can easily be used in place of a decision tree. The example below uses logistic regression, passed through the `estimator` argument (named `base_estimator` before `scikit-learn` 1.2). " ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "from sklearn.linear_model import LogisticRegression\n", "abc = AdaBoostClassifier(estimator = LogisticRegression(max_iter = 1000))\n", "abc.fit(X_train, y_train);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### AdaBoost Regression" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "AdaBoost regression works almost identically in `scikit-learn`, using the `AdaBoostRegressor` class. An example with the `tips` dataset is shown below." 
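] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In AdaBoost.R2, each weak learner's errors are turned into a per-observation loss, and that loss drives the reweighting of observations. In `scikit-learn`, this is controlled by the `loss` argument of `AdaBoostRegressor`, which accepts `'linear'` (the default), `'square'`, or `'exponential'`. The sketch below compares the three options on a synthetic dataset from `make_regression`. The data, the 400/100 split, and the `random_state` values are illustrative assumptions. The variable names use a `syn_` prefix so the `tips` example that follows is unaffected." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sklearn.datasets import make_regression\n", "from sklearn.ensemble import AdaBoostRegressor\n", "\n", "## Synthetic regression data (illustrative only)\n", "syn_X, syn_y = make_regression(n_samples = 500, n_features = 5, noise = 10, random_state = 123)\n", "syn_X_train, syn_X_test = syn_X[:400], syn_X[400:]\n", "syn_y_train, syn_y_test = syn_y[:400], syn_y[400:]\n", "\n", "## Compare the loss functions AdaBoost.R2 can use to reweight observations\n", "for loss in ['linear', 'square', 'exponential']:\n", "    syn_abr = AdaBoostRegressor(n_estimators = 50, loss = loss, random_state = 123)\n", "    syn_abr.fit(syn_X_train, syn_y_train)\n", "    ## score() returns the test-set R^2\n", "    print(loss, syn_abr.score(syn_X_test, syn_y_test))" 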
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "## Load penguins data\n", "tips = sns.load_dataset('tips')\n", "tips = tips.dropna().reset_index(drop = True)\n", "X = tips.drop(columns = 'tip')\n", "y = tips['tip']\n", "\n", "## Train-test split\n", "np.random.seed(1)\n", "test_frac = 0.25\n", "test_size = int(len(y)*test_frac)\n", "test_idxs = np.random.choice(np.arange(len(y)), test_size, replace = False)\n", "X_train = X.drop(test_idxs)\n", "y_train = y.drop(test_idxs)\n", "X_test = X.loc[test_idxs]\n", "y_test = y.loc[test_idxs]" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAcAAAAFTCAYAAACu19yeAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nO3df5RcZZ3n8c+nu9OxCYHE0MSQH7DOMHHQQYFWceLMIrAuIgPmwLjsDqIelywTRNzDDCo7DrPseBxGVx2HwciPUfAnimRkGGVQgXHMCk6H3xA4AoIJxKQJCXSSNk2nvvtH3e6pVKq6q7qr6t6q+36d0yf14+beb9+C+uR57vM81xEhAADypivtAgAASAMBCADIJQIQAJBLBCAAIJcIQABALhGAAIBcIgABALlEAAJoG7ZPtX1q2nWgM5iJ8ADage1DJN2ePP1PEbEtzXrQ/ghAAG3B9t9JWiupW9LpEXFByiWhzRGAQE7YflrSf4+IH6ZdC5AFXANEbtjeWfJTsD1S8vyPprnPp22fPMU2b7X9/2y/aPsF2+tsv3F6vwWARulJuwCgVSLiwPHHrWoN2T5I0q2S/ljStyT1Svo9SXuaeVwAU6MFCCRsH2b7O7aHbP/C9odK3vuI7WdtD9t+3PZJtr8iaZmkf0xakZdU2O1vSVJEfCMi9kbESETcHhEPluz7o7afTPb9qO2VJe89bftPbT9oe5ft62wvtP39ZPsf2p5ftv3Hkv1st/0l26+o9/ct2+5A23ttLyp57XW2N9ueW7btR23fVPba39j+fLXzWPHDmObxgbpEBD/85O5H0tOSTi553iVpvaQ/V7GV9mpJT0n6z5KWS9oo6bBk2yMk/Ual/VQ4zkGStkm6XtI7JM2vsM0fSjosqeG/SNolaVHJ/u+WtFDSYklbJd0r6RhJsyXdIemyst/rYUlLJb1S0jpJf1le62S/b5Xf4xFJ7yx5fqukCytsd7ik3ZIOSp53S9os6fjJzmMNn1dNx+eHn3p+aAECRW+U1B8Rl0fEaEQ8JekaSWdL2qti2Bxle1ZEPB0RT9ay04h4SdJbJUWyvyHbt9heWLLNtyPiuYgoRMSNkn4u6U0lu/nbiNgSEc9K+ldJ90TEfRGxR8VRkceUHfbKiNgYES9I+oSk/1rn71vJv0k6VpJs/76koyR9scLv+4yKAf2u5KUTJe2OiLs1g/NY6/GBehCAQNHhkg6zvWP8R9KlkhZGxBOSPizpLyRttf1N24fVuuOI2BAR74uIJZJep2Jr73Pj79s+1/b9Jcd9n
aRDSnaxpeTxSIXnB2pfG0seP5Mcr+bft8qvMRFAkv5a0scjYrTKtl/Xv4fuf0uea4bnsZ7jAzUhAIGijZJ+ERHzSn7mRsSpkhQRX4+It6oYHCHpiuTv1TWPKCIek/RlFUNOtg9XseX1QUkLImKeil2YnsHvsrTk8TJJz1XYZtLft4J/k3Ss7TMl9Un6xiTH/7akE2wvkbRSSQBKk57HqdRzfKAmBCBQ9DNJLyWDNPpsdycDLd5oe7ntE23PlvRrFVtde5O/t0XF62cV2X6N7YuTMJDtpSq2ju5ONpmjYhAMJe+/X0k4zsAFtpfYfqWKrbob6/l9q+zzAUmvkvR/JX00IgrVDh4RQ5LukvQlFUN2gyRNcR6nUvPxgVoRgICkiNgr6Q8kvUHSLyQ9L+laSQereN3qr5LXfiXpUBWDRZI+KenPkm7EP6mw62FJb5Z0j+1dKgbfw5IuTo77qIpf6j9VMUx/R8WBKzPxdRWXDHsq+fnLOn/f/STXGx+S9HREfL/GGk5WSetPk5zHZFTrpeU7mcHxgSmxEgzQQZo1v9F2r6QnJL07GdDSUmkfH52JFiCAWlwmaV2K4ZP28dGBCEAAVdk+1vaLkn5f0oV5Oz46G12gAIBcogUIAMglAhAAkEsdczeIU045JW677ba0ywAAZEvVRSU6pgX4/PPPp10CAKCNdEwAAgBQDwIQAJBLBCAAIJdSGQSTLNc0rOJCuGMRMVD2/gmSvqviGoWSdHNEXN7KGgEAnS3NUaBvi4jJRq78a0Sc1rJqAAC5QhcoACCX0grAkHS77fW2V1XZ5i22H0huk/LaVhYHAOh8aXWBroiI52wfKukHth+LiB+XvH+vpMMjYqftUyX9g6Qjy3eShOcqSVq2bFkr6gYAdIhUWoAR8Vzy51ZJayW9qez9lyJiZ/L4e5Jm2T6kwn6ujoiBiBjo7+9vQeVA6xQKoaHhPXp2+24NDe9RocDC9UAjtbwFaHuOpK6IGE4ev13S5WXbvErSlogI229SMai3tbpWIC2FQujxLcM674ZBbdo+oiXz+3TNuQNavnCuurqqruwEoA5ptAAXSvqJ7Qck/UzSP0XEbbbPt31+ss1Zkh5Otvm8pLOD+zYhR7btGp0IP0natH1E590wqG27RlOuDOgcLW8BRsRTkl5f4fU1JY+vlHRlK+sCsmR0bO9E+I3btH1Eo2N7U6oI6DxMgwAyqLenW0vm9+3z2pL5fert6U6pIqDzEIBABi2Y06trzh2YCMHxa4AL5vSmXBnQOTrmfoBAJ+nqspYvnKu1q1dodGyvenu6tWBOLwNggAYiAIGM6uqy+ufOTrsMoGPRBQoAyCUCEACQSwQgACCXCEAAQC4RgACAXCIAAQC5RAACAHKJAAQA5BIBCADIJQIQAJBLLIUGNEihENq2a5S1O4E2QQACDcAd3IH2Qxco0ADcwR1oP7QAgQaY7h3c6TYF0kMAAg0wfgf30hCc6g7udJsC6aILFGiA6dzBnW5TIF20AIEGmM4d3KfbbQqgMQhAoEHqvYP7dLpNATQOXaBASqbTbQqgcWgBAimZTrcpgMYhAIEU1dttCqBxUukCtf207Yds3297sML7tv1520/YftD2sWnUCQDoXGm2AN8WEc9Xee8dko5Mft4s6QvJnwAANERWB8GcIemGKLpb0jzbi9IuCgDQOdIKwJB0u+31tldVeH+xpI0lzzclrwEA0BBpdYGuiIjnbB8q6Qe2H4uIH5e8X2kYXJS/kITnKklatmxZcyoFAHSkVFqAEfFc8udWSWslvalsk02SlpY8XyLpuQr7uToiBiJioL+/v1nlAgA6UMsD0PYc23PHH0t6u6SHyza7RdK5yWjQ4yW9GBGbW1wqAKCDpdEFulDSWtvjx/96RNxm+3xJiog1kr4n6VRJT0jaLen9KdQJAOhgLQ/AiHhK0usrvL6m5HFIuqCVdQEA8iWr0yAAAGgqAhAAkEsEIAAglwhAAEAucTcIAEBmFAqhb
btGW3KLMAIQAJAJhULo8S3DOu+GQW3aPjJxk+jlC+c2JQTpAgUAZMK2XaMT4SdJm7aP6LwbBrVt12hTjkcAAgAyYXRs70T4jdu0fUSjY3ubcjwCEACQCb093Voyv2+f15bM71NvT3dTjkcAAgAyYcGcXl1z7sBECI5fA1wwp7cpx2MQDAAgE7q6rOUL52rt6hWMAgUA5EtXl9U/d3ZrjtWSowAAkDG0AIEUtXLSL4B9EYBASlo96RfAvugCBVLS6km/APZFAAIpafWkXwD7IgCBlLR60i+AfRGAQEpaPekXwL4YBAPUoZGjNls96RfAvghAoEbNGLXZykm/APZFFyhQI0ZtAp2FAARqxKhNoLMQgECNGLUJdBYCEKgRozaBzpLaIBjb3ZIGJT0bEaeVvfc+SZ+S9Gzy0pURcW1rKwT2xahNoLOkOQr0IkkbJB1U5f0bI+KDLawHmBKjNoHOkUoXqO0lkt4piVYdACAVaV0D/JykSyQVJtnmTNsP2r7J9tIW1QUAyImWB6Dt0yRtjYj1k2z2j5KOiIijJf1Q0vVV9rXK9qDtwaGhoSZUCwDoVI6I1h7Q/qSk90gak/QKFa8B3hwR51TZvlvSCxFx8GT7HRgYiMHBwUaXCwBob1VHqbW8BRgRH4uIJRFxhKSzJd1RHn62F5U8PV3FwTIAADRMZtYCtX25pMGIuEXSh2yfrmIr8QVJ70uzNgBA52l5F2iz0AUKAKggO12gAABkAQEIAMglAhAAkEsEIAAglwhAAEAuEYAAgFwiAAEAuUQAAgByiQAEAOQSAQgAyCUCEACQSwQgACCXCEAAQC4RgACAXCIAAQC5RAACAHKJAAQA5BIBCADIJQIQAJBLBCAAIJcIQABALhGAAIBcIgABALlEAAIAcokABADkUmoBaLvb9n22b63w3mzbN9p+wvY9to9ofYUAgE6WZgvwIkkbqrz3AUnbI+I3JX1W0hUtqwoAkAupBKDtJZLeKenaKpucIen65PFNkk6y7VbUBgDIh7RagJ+TdImkQpX3F0vaKEkRMSbpRUkLWlMaACAPWh6Atk+TtDUi1k+2WYXXosK+VtketD04NDTUsBoBAJ0vjRbgCkmn235a0jclnWj7q2XbbJK0VJJs90g6WNIL5TuKiKsjYiAiBvr7+5tbNQCgo7Q8ACPiYxGxJCKOkHS2pDsi4pyyzW6R9N7k8VnJNvu1AAEAmK6etAsYZ/tySYMRcYuk6yR9xfYTKrb8zk61ONSlUAht2zWq0bG96u3p1oI5verqYgwTgGxxpzSsBgYGYnBwMO0ycq9QCD2+ZVjn3TCoTdtHtGR+n645d0DLF84lBAGkoeoXDyvBoKG27RqdCD9J2rR9ROfdMKhtu0ZTrgwA9kUAoqFGx/ZOhN+4TdtHNDq2N6WKAKAyAhAN1dvTrSXz+/Z5bcn8PvX2dKdUEQBURgCioRbM6dU15w5MhOD4NcAFc3pTrgwA9pWZUaDoDF1d1vKFc7V29QpGgQLINAIQDdfVZfXPnZ12GQAwKbpAAQC5RAACAHKJAAQA5BIBCADIJQIQAJBLBCAAIJeYBtEmsnyHhSzX1k6acR75bIDqCMA2kOU7LGS5tnbSjPPIZwNMji7QNpDlOyxkubZ20ozzyGcDTI4AbANZvsNClmtrJ804j3w2wOQIwBoUCqGh4T16dvtuDQ3vUaHQ2psIZ/kOC1murZ004zzy2QCTIwCnMH4dZeVV67Tiiju18qp1enzLcEtDMMt3WJjfN0trzjlun9rWnHOc5vfNSrmy9tKMzzjL/90AWeCI2r/Ibf9Q0sUR8UDzSpqegYGBGBwcbPh+h4b3aOVV6/bpSloyv09rV69o6YLPY2MFbd25R2N7C+rp7tKhB85WT0/6/34ZGt6j/7X2QZ153FLN65ulHSMv6zvrN+oTK49uywWx0xw1yShQoCmq/gc/6ShQ20dJujQizkleukTSZ20/k7y+uXE1ZlMWrqMUCqGfD+3M5Gi+0
bG9uv3Rrbr90a37vH7ZH9R+frLyJZ32qMlm3EWDO3MA1U3VhPiRpD8bfxIR90bEiZJulXSb7cts91X92x0gC9dRsjyab6bnJwtdzOOyfJ4BNN5UAfh2SZ8ofcG2JT0u6QuSLpT0c9vvaU556cvCdZQstEKrmen5yVLoZPk8A2i8SbtAI+IhSX80/tz2TyS9WtIjku6W9D5Jj0m6yPbvRcSq5pWajizc4Xy8lVV+HTILo/lmen6yFDpZPs8AGq/elWDOl/RI7D9y5kLbGxpUU+akfR1lvJVVfm0qK6P5ZnJ+shQ6aZznrFz/BPKorlGgk+7IfnVEPNWQnU1Ds0aBZkWnflGmPfCkUj2tOs9Z+92BDlX1f6aGBWCtbL9C0o8lzVaxBXpTRFxWts37JH1K0rPJS1dGxLWT7bfTA7CTdWq4TyUrU2yADje9aRBNskfSiRGx0/YsST+x/f2IuLtsuxsj4oMp1IcWS7uLOS1Zuv4J5FHLZ1JH0c7k6azkp/Vj3oGUZWGKDZBnqSwlYrvb9v2Stkr6QUTcU2GzM20/aPsm20tbXCLQdFmYYgPkWcuvAe5zcHuepLWSLoyIh0teXyBpZ0TssX2+pHcnE/DL//4qSaskadmyZcc988wzLaocaIy8Xv8EWig7g2D2K8C+TNKuiPh0lfe7Jb0QEQdPth8GwQAAKqgagC3vArXdn7T8lCyjdrKKk+lLt1lU8vR0SR07xxAAkI40RoEuknR90rLrkvStiLjV9uWSBiPiFkkfsn26pDFJL6i44gwAAA2Tehdoo9AFCgCoIDtdoAAAZAEBCADIJQIQAJBLBCAAIJcIQABALhGAAIBcIgABALmUxkR4AC3GmqPA/ghAZA5f1o2VlTvP87kiawjANpHlL49G1pb2l3WWz/N0bds1OnE+peJNd8+7YbCld55P+3MFKuEaYBsY//JYedU6rbjiTq28ap0e3zKsQiH9ZewaXVu1L+tfvfRrPbt9t4aG9zTt987yeZ6JLNx5vtrnum3XaMtqAMoRgG0gy18eja6t2pf1cztGmh5KWT7PM5GFO89nIYSBcgRgG8jyl0eja6v2ZT0eQs0MpSyf55nIwp3nsxDCQDkCsA1k+cuj0bVV+rK+4syjteauJye2aVYoZfk8z0RXl7V84VytXb1C6z7yNq1dvaLl196yEMJAOW6H1AayPICgGbWVDkSxrb+45WHd/ujWifeXzO9rygCOLJ/nTtCJA4zQFqr+R0YAtoksf3k0urbS/fX1dmvH7pf1zLbdOqC3W7tH9+rwBQfoiAVz6j5GLXW2+jxn+XMFOkTV/6GYBtEmurrcsiHr9WpkbeWtsLcfdag+dNJv6ePffXifVtlM91utddfK80yLE0gX1wCRKeUjMc88bqnO/+r6GY/MzOIIzyzWBOQJLcA2kZeusvKRmPP6ZjVkZGYWR3hmsSYgT2gBtoFOnaBdSflIzB0jLzdkZGYWR3hmsSYgTwjANpCnrrLy4fLfWb9Ra845bsbD57M4DD+LNQF5wijQNvDs9t1accWd+72+7iNv0+L5B6RQUXOVd/fO75ul7SMvz7j7N4vdyFmsCegwjAJtZ+NdZaXXizq5q6zSSMxGjMzM4kjaLNYE5AVdoG0gi11lhUJoaHhP0xeoBoBmoQXYBkqXsspCVxnz1wB0gpa3AG2/wvbPbD9g+xHb/7vCNrNt32j7Cdv32D6i1XVmzXhX2eL5B6h/7uxUgyZPg3IAdK40ukD3SDoxIl4v6Q2STrF9fNk2H5C0PSJ+U9JnJV3R4hoxCeavAegELQ/AKNqZPJ2V/JRfQDpD0vXJ45sknWSbvrWMYP4agE6QyiAY292275e0VdIPIuKesk0WS9ooSRExJulFSQtaWyWqyeKgHACoVyqDYCJir6Q32J4naa3t10XEwyWbVGrt7TfM0PYqSaskadmyZU2pFfvL2qAcAJiOVKdBRMQOSXdJOqXsrU2SlkqS7R5JB0t6ocLfvzoiBiJio
L+/v8nVolSWBuUAwHSkMQq0P2n5yXafpJMlPVa22S2S3ps8PkvSHdEpS9YAADIhjS7QRZKut92tYgB/KyJutX25pMGIuEXSdZK+YvsJFVt+Z6dQJ5qMZcAApIm1QJEKJtMDaJGqXygshYZUMJkeQNpYCi2n0u5+zOpk+rTPC4DWIQBzKAvdj1m8w0UWzguA1qELtESW73DQqNoKhdCvXvp1Q7sfp1PbZJPp0/oc6JYF8oUWYCLL//pvVG3j+9m1Z6xh3Y/11lbaxbjwoNm6efXv6uWxwkR3o6TUPoesdssCaA5agIks/+u/UbWN72fbrtGGreVZT23jYbnyqnVaccWdOv3Kddq2c1SLDu6bmEyf5ufAGqdAvhCAiSz/6386tVXqRhzfz5q7ntQVZx7dkLU866mtlnBL83NgjVMgX+gCTWRxUMa4emur1i254MBeLZnfp/s27tCn//lxffy0o7RgTq/mHdCr6fYu1lNbLeGW5ufAGqdAvtACTGT5X//11latpdXT5Yn93Ldxh/7PrY9qz1hBf/rtB3Tu3/9sWt2M9dRWSxdj2p8Da5wC+cFKMCWyPAesntqe3b5bK664c7/X133kbVp0cJ82vziiTdtHtGPkZa2560ndt3HHxPuL5x/QtNpqHTCT5c8BQNup+uVBF2iJ8X/9Z1E9tU3WjdjVZfX2dOvibz/QsG7GWmurtYsxy58DgM5BF2gHmqobMc1uRroYAWQFXaAt0OguvVr2N9U2jayJLksAGUYXaFoaPcG+1v1N1Y3YqG7GLC8gAACToQu0yRo9sTtrE/azVg8A1IoAbLJGT+zO2oT9rNUDALUiAJus0ctrZW25rqzVAwC1IgCbrNEjLtOeKJ71egCgVowCbYE0RoG2UtbqAYASjAJNU6MndmdtonjW6gGAWtAFCgDIJVqADdYu3YHtUicANAsB2EDtMim8XeoEgGaiC7SB2mVSeLvUCQDNRAA2ULtMCm+XOgGgmQjABmqXSeHtUicANFPLA9D2Utt32t5g+xHbF1XY5gTbL9q+P/n581bXOR3tMim8XeoEgGZq+UR424skLYqIe23PlbRe0rsi4tGSbU6Q9CcRcVqt+83KRPh2GV3J7ZAA5ER2JsJHxGZJm5PHw7Y3SFos6dFJ/2KbaJdJ4dwOCUDepXoN0PYRko6RdE+Ft99i+wHb37f92ip/f5XtQduDQ0NDTawU1TCiFEC7Si0AbR8o6TuSPhwRL5W9fa+kwyPi9ZL+VtI/VNpHRFwdEQMRMdDf39/cglERI0oBtKtUAtD2LBXD72sRcXP5+xHxUkTsTB5/T9Is24e0uEzUgBGlANpVGqNALek6SRsi4jNVtnlVsp1sv0nFOre1rkrUihGlANpVGkuhrZD0HkkP2b4/ee1SScskKSLWSDpL0h/bHpM0Iuns6JT7NjVJWiMxu7qs5Qvnau3qFYwCBdBWuB9gB2AkJgBUVfVLkJVgOgAjMQGgfgRgB2AkJgDUj9shpaD0et2sni71dFkjo9Wvn011fW98JGZpCNY7EpPVXADkDQHYYpWu133qrKP117c9rqGde/a7dlfL9b3xkZjl29Q6EpNriADyiEEwLTY0vEcrr1q3X2vt46cdpf/xlfVaMr9Pa1evmFimrNr2pdtIM2vB1XoMAGhD2VkLNO+qXa+b1zdr4nHptbtar+/NZG1PriECyCMGwbRYtZVTdoy8PPG49NpdK1ZaYTUXAHlEALZYpZVTPnXW0Vpz15MVr921YqUVVnMBkEdcA0xBo0eBNromRoEC6CBcA8ySitfr5tS5fStqAoAORhcoACCXCEAAQC4RgACAXOIaYIMxmAQA2gMB2EAsKQYA7YMu0AZq5m2JxsYKem7HiJ7ZtkvP7RjR2FhhxvsEgDyjBdhA9S4pVmt36dhYQY9tGdb5X10/0bJcc85xes3Cuerp4d8wADAdfHs2UD1Lio13l668ap1WXHGnVl61To9vGVahsP/CBFt37pkIP6kYq
ud/db227tzTnF8EAHKAAGygepYUq6e79OW9hYoty7G9dIMCwHTRBdpAXV3W8oVztXb1iim7NevpLp3V3VXxhrc93fz7BQCmi2/QBhtfUmzx/APUP3d21dGf9XSXHnrgbK0557h9WpZrzjlOhx7I0mUAMF0shp2SeqdMjI0VtHXnHo3tLainu0uHHjibATAAMLWqc9AIwBQxaR4Amo67QWRRPXdgICwBoLEIwCYZ77J8eW9Bs2rospws4FhhBgAar+UXkWwvtX2n7Q22H7F9UYVtbPvztp+w/aDtY1td50yMT1x/9xd/qv/4qbv07i/+VI9tGa66estUcwKbucIMAORVGqMoxiRdHBG/Lel4SRfYPqpsm3dIOjL5WSXpC60tcWbqnbg+VcDVu8IMAGBqLQ/AiNgcEfcmj4clbZC0uGyzMyTdEEV3S5pne1GLS522eieuTxVw9UyZAADUJtVx9LaPkHSMpHvK3losaWPJ803aPyRle5XtQduDQ0NDzSqzbuMT10tNNnF9qoCrZ4UZAEBtUpsGYftASf8i6RMRcXPZe/8k6ZMR8ZPk+Y8kXRIR66vtL0vTIOpdvLqWQS6MAgWAacnWPEDbsyTdKumfI+IzFd7/oqS7IuIbyfPHJZ0QEZur7TNLASjVP3GdgAOApsjOPEDblnSdpA2Vwi9xi6QP2v6mpDdLenGy8Muinp4uHTavb+oNE/XMCQQAzFwa8wBXSHqPpIds35+8dqmkZZIUEWskfU/SqZKekLRb0vtTqBMA0MFaHoDJdb1J+/ai2C97QWsqAgDkEaspAwByiQAEAOQSAQgAyCUWwy7BVAQAyA8CMMEdFwAgX+gCTXDHBQDIFwIwwR0XACBfCMAEd1wAgHwhABPccQEA8oVBMImuLmv5wrlau3oFo0ABIAcIwBIsSA0A+UEXKAAglwhAAEAuEYAAgFwiAAEAuUQAAgByiQAEAOQSAQgAyCUCEACQS46ItGtoCNtDkp5Ju44WOETS82kX0aY4d9PDeZs+zt30NercPR8Rp1R6o2MCMC9sD0bEQNp1tCPO3fRw3qaPczd9rTh3dIECAHKJAAQA5BIB2H6uTruANsa5mx7O2/Rx7qav6eeOa4AAgFyiBQgAyCUCsA3YXmr7TtsbbD9i+6K0a2o3trtt32f71rRraSe259m+yfZjyX9/b0m7pnZh+38m/78+bPsbtl+Rdk1ZZfvvbW+1/XDJa6+0/QPbP0/+nN/o4xKA7WFM0sUR8duSjpd0ge2jUq6p3VwkaUPaRbShv5F0W0S8RtLrxTmsie3Fkj4kaSAiXiepW9LZ6VaVaV+WVD5X76OSfhQRR0r6UfK8oQjANhARmyPi3uTxsIpfQovTrap92F4i6Z2Srk27lnZi+yBJvy/pOkmKiNGI2JFuVW2lR1Kf7R5JB0h6LuV6MisifizphbKXz5B0ffL4eknvavRxCcA2Y/sIScdIuifdStrK5yRdIqmQdiFt5tWShiR9Kek+vtb2nLSLagcR8aykT0v6paTNkl6MiNvTrartLIyIzVKxESDp0EYfgABsI7YPlPQdSR+OiJfSrqcd2D5N0taIWJ92LW2oR9Kxkr4QEcdI2qUmdEN1ouR61RmS/oOkwyTNsX1OulWhHAHYJmzPUjH8vhYRN6ddTxtZIel0209L+qakE21/Nd2S2sYmSZsiYry34SYVAxFTO1nSLyJiKCJelnSzpN9NuaZ2s8X2IklK/tza6AMQgG3AtlW8DrMhIj6Tdj3tJCI+FhFLIuIIFQch3BER/Eu8BhHxK0kbbS9PXjpJ0qMpltROftG+9WMAAAFhSURBVCnpeNsHJP//niQGENXrFknvTR6/V9J3G32AnkbvEE2xQtJ7JD1k+/7ktUsj4nsp1oR8uFDS12z3SnpK0vtTrqctRMQ9tm+SdK+Ko7jvE6vCVGX7G5JOkHSI7U2SLpP0V5K+ZfsDKv6D4g8bflxWggEA5BFdoACAXCIAAQC5RAACAHKJA
AQA5BIBCADIJQIQAJBLBCAAIJcIQKBD2P4d2+tKnh9r+440awKyjInwQIew3aXiLXcWR8Re23eqeB/Je1MuDcgklkIDOkREFGw/Ium1to+U9EvCD6iOAAQ6y90qrh27WvvfYRtACQIQ6Cx3S/qypL9LbsoKoAquAQIdJOn6/BdJR0bErrTrAbKMUaBAZ7lI0scIP2BqBCDQAWz/hu3HJPVFxPVp1wO0A7pAAQC5RAsQAJBLBCAAIJcIQABALhGAAIBcIgABALlEAAIAcokABADkEgEIAMil/w+EU7WztR3ubAAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "from sklearn.ensemble import AdaBoostRegressor\n", "\n", "## Get dummies\n", "X_train = pd.get_dummies(X_train, drop_first = True)\n", "X_test = pd.get_dummies(X_test, drop_first = True)\n", "\n", "## Build model\n", "abr = AdaBoostRegressor(n_estimators = 50)\n", "abr.fit(X_train, y_train)\n", "y_test_hat = abr.predict(X_test)\n", "\n", "## Visualize predictions (newer seaborn versions require x and y as keyword arguments)\n", "fig, ax = plt.subplots(figsize = (7, 5))\n", "sns.scatterplot(x = y_test, y = y_test_hat, ax = ax)\n", "ax.set(xlabel = r'$y$', ylabel = r'$\\hat{y}$', title = r'Test Sample $y$ vs. $\\hat{y}$')\n", "sns.despine()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.6" } }, "nbformat": 4, "nbformat_minor": 4 }