{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Regression\n", "> A Summary of lecture \"Supervised Learning with scikit-learn\", via datacamp\n", "\n", "- toc: true \n", "- badges: true\n", "- comments: true\n", "- author: Chanseok Kang\n", "- categories: [Python, Datacamp, Machine_Learning]\n", "- image: images/ridge_cv.png" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Introduction to Regression" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Importing data for supervised learning\n", "In this chapter, you will work with Gapminder data that we have consolidated into one CSV file available in the workspace as ```'gapminder.csv'```. Specifically, your goal will be to use this data to predict the life expectancy in a given country based on features such as the country's GDP, fertility rate, and population. As in Chapter 1, the dataset has been preprocessed.\n", "\n", "Since the target variable here is quantitative, this is a regression problem. To begin, you will fit a linear regression with just one feature: ```'fertility'```, which is the average number of children a woman in a given country gives birth to. In later exercises, you will use all the features to build regression models.\n", "\n", "Before that, however, you need to import the data and get it into the form needed by scikit-learn. This involves creating feature and target variable arrays. Furthermore, since you are going to use only one feature to begin with, you need to do some reshaping using NumPy's ```.reshape()``` method. Don't worry too much about this reshaping right now, but it is something you will have to do occasionally when working with scikit-learn so it is useful to practice." 
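, "\n", "\n",
"As a quick illustration of that reshaping step, here is a minimal sketch that uses a small made-up array in place of the real fertility column. The values and the names `x` and `x_2d` are assumptions for illustration only; the point is that `.reshape(-1, 1)` turns a flat array of n values into an n-by-1 column, which is the 2-D layout scikit-learn expects for a feature matrix.\n",
"\n", "
```python
import numpy as np

# Hypothetical stand-in for a single feature column such as df['fertility'].values
x = np.array([2.73, 6.43, 2.24, 1.40, 1.96])
print(x.shape)       # one-dimensional: (5,)

# -1 lets NumPy infer the number of rows; 1 forces a single column
x_2d = x.reshape(-1, 1)
print(x_2d.shape)    # two-dimensional: (5, 1)
```
", "\n",
"Passing `-1` asks NumPy to infer that dimension from the length of the array, so the same call works for any number of rows."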
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Dimensions of y before reshaping: (139,)\n", "Dimensions of X before reshaping: (139,)\n", "Dimensions of y after reshaping: (139, 1)\n", "Dimensions of X after reshaping: (139, 1)\n" ] } ], "source": [ "# Read the CSV file into a DataFrame: df\n", "df = pd.read_csv('./dataset/gm_2008_region.csv')\n", "df.drop(labels=['Region'], axis='columns', inplace=True)\n", "\n", "# Create arrays for features and target variable\n", "y = df['life'].values\n", "X = df['fertility'].values\n", "\n", "# Print the dimensions of X and y before reshaping\n", "print(\"Dimensions of y before reshaping: {}\".format(y.shape))\n", "print(\"Dimensions of X before reshaping: {}\".format(X.shape))\n", "\n", "# Reshape X and y\n", "y = y.reshape(-1, 1)\n", "X = X.reshape(-1, 1)\n", "\n", "# Print the dimensions of X and y after reshaping\n", "print(\"Dimensions of y after reshaping: {}\".format(y.shape))\n", "print(\"Dimensions of X after reshaping: {}\".format(X.shape))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Exploring the Gapminder data\n", "As always, it is important to explore your data before building models. The heatmap below shows the correlation between the different features of the Gapminder dataset. Green cells indicate positive correlation, while red cells indicate negative correlation. Take a moment to explore this: which features are positively correlated with life, and which ones are negatively correlated? Does this match your intuition?" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "sns.heatmap(df.corr(), square=True, cmap='RdYlGn')" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\"><th></th><th>population</th><th>fertility</th><th>HIV</th><th>CO2</th><th>BMI_male</th><th>GDP</th><th>BMI_female</th><th>life</th><th>child_mortality</th></tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr><th>count</th><td>1.390000e+02</td><td>139.000000</td><td>139.000000</td><td>139.000000</td><td>139.000000</td><td>139.000000</td><td>139.000000</td><td>139.000000</td><td>139.000000</td></tr>\n",
" <tr><th>mean</th><td>3.549977e+07</td><td>3.005108</td><td>1.915612</td><td>4.459874</td><td>24.623054</td><td>16638.784173</td><td>126.701914</td><td>69.602878</td><td>45.097122</td></tr>\n",
" <tr><th>std</th><td>1.095121e+08</td><td>1.615354</td><td>4.408974</td><td>6.268349</td><td>2.209368</td><td>19207.299083</td><td>4.471997</td><td>9.122189</td><td>45.724667</td></tr>\n",
" <tr><th>min</th><td>2.773150e+05</td><td>1.280000</td><td>0.060000</td><td>0.008618</td><td>20.397420</td><td>588.000000</td><td>117.375500</td><td>45.200000</td><td>2.700000</td></tr>\n",
" <tr><th>25%</th><td>3.752776e+06</td><td>1.810000</td><td>0.100000</td><td>0.496190</td><td>22.448135</td><td>2899.000000</td><td>123.232200</td><td>62.200000</td><td>8.100000</td></tr>\n",
" <tr><th>50%</th><td>9.705130e+06</td><td>2.410000</td><td>0.400000</td><td>2.223796</td><td>25.156990</td><td>9938.000000</td><td>126.519600</td><td>72.000000</td><td>24.000000</td></tr>\n",
" <tr><th>75%</th><td>2.791973e+07</td><td>4.095000</td><td>1.300000</td><td>6.589156</td><td>26.497575</td><td>23278.500000</td><td>130.275900</td><td>76.850000</td><td>74.200000</td></tr>\n",
" <tr><th>max</th><td>1.197070e+09</td><td>7.590000</td><td>25.900000</td><td>48.702062</td><td>28.456980</td><td>126076.000000</td><td>135.492000</td><td>82.600000</td><td>192.000000</td></tr>\n",
" </tbody>\n",
"</table>\n",
"</div>
" ], "text/plain": [ " population fertility HIV CO2 BMI_male \\\n", "count 1.390000e+02 139.000000 139.000000 139.000000 139.000000 \n", "mean 3.549977e+07 3.005108 1.915612 4.459874 24.623054 \n", "std 1.095121e+08 1.615354 4.408974 6.268349 2.209368 \n", "min 2.773150e+05 1.280000 0.060000 0.008618 20.397420 \n", "25% 3.752776e+06 1.810000 0.100000 0.496190 22.448135 \n", "50% 9.705130e+06 2.410000 0.400000 2.223796 25.156990 \n", "75% 2.791973e+07 4.095000 1.300000 6.589156 26.497575 \n", "max 1.197070e+09 7.590000 25.900000 48.702062 28.456980 \n", "\n", " GDP BMI_female life child_mortality \n", "count 139.000000 139.000000 139.000000 139.000000 \n", "mean 16638.784173 126.701914 69.602878 45.097122 \n", "std 19207.299083 4.471997 9.122189 45.724667 \n", "min 588.000000 117.375500 45.200000 2.700000 \n", "25% 2899.000000 123.232200 62.200000 8.100000 \n", "50% 9938.000000 126.519600 72.000000 24.000000 \n", "75% 23278.500000 130.275900 76.850000 74.200000 \n", "max 126076.000000 135.492000 82.600000 192.000000 " ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.describe()" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "RangeIndex: 139 entries, 0 to 138\n", "Data columns (total 9 columns):\n", " # Column Non-Null Count Dtype \n", "--- ------ -------------- ----- \n", " 0 population 139 non-null float64\n", " 1 fertility 139 non-null float64\n", " 2 HIV 139 non-null float64\n", " 3 CO2 139 non-null float64\n", " 4 BMI_male 139 non-null float64\n", " 5 GDP 139 non-null float64\n", " 6 BMI_female 139 non-null float64\n", " 7 life 139 non-null float64\n", " 8 child_mortality 139 non-null float64\n", "dtypes: float64(9)\n", "memory usage: 9.9 KB\n" ] } ], "source": [ "df.info()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## The basics of linear regression\n", "- Regression mechanics\n", " - $y = ax + b$\n", " - 
$y$ = target\n", " - $x$ = single feature\n", " - $a, b$ = parameters of model\n", " - Define an error function for any given line\n", " - Choose the line that minimizes the error function\n", "- The loss function\n", " - Ordinary least squares (OLS): Minimize the sum of squares of the residuals\n", "- Linear regression in higher dimensions\n", "$$ y = a_1 x_1 + a_2 x_2 + b $$\n", " - To fit a linear regression model here:\n", " - Need to specify 3 parameters: $a_1$, $a_2$, and $b$\n", " - In higher dimensions:\n", " - Must specify a coefficient for each feature and the intercept $b$\n", " $ y = a_1x_1 + a_2x_2 + a_3x_3 + \\dots + a_nx_n + b $\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Fit & predict for regression\n", "Now, you will fit a linear regression and predict life expectancy using just one feature. You saw Andy do this earlier using the 'RM' feature of the Boston housing dataset. In this exercise, you will use the 'fertility' feature of the Gapminder dataset. Since the goal is to predict life expectancy, the target variable here is 'life'." 
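, "\n", "\n",
"To make the mechanics concrete, here is a minimal sketch of ordinary least squares for a single feature, written in plain Python. The toy values and the helper name `predict` are illustrative assumptions and do not come from the Gapminder data; scikit-learn's `LinearRegression` solves the same minimization problem more generally.\n",
"\n", "
```python
# Toy data (hypothetical values, not taken from the Gapminder file)
x = [1.0, 2.0, 3.0, 4.0, 5.0]
y = [80.0, 75.0, 71.0, 64.0, 60.0]

n = len(x)
x_mean = sum(x) / n
y_mean = sum(y) / n

# Closed-form OLS solution for y = a*x + b with a single feature
s_xy = sum((xi - x_mean) * (yi - y_mean) for xi, yi in zip(x, y))
s_xx = sum((xi - x_mean) ** 2 for xi in x)
a = s_xy / s_xx
b = y_mean - a * x_mean


def predict(x_new):
    # Evaluate the fitted line at a new feature value
    return a * x_new + b


# Sum of squared residuals: the quantity that OLS minimizes
ssr = sum((yi - predict(xi)) ** 2 for xi, yi in zip(x, y))
print(a, b, ssr)
```
", "\n",
"On this toy data the fitted slope is negative, and moving `a` or `b` away from the fitted values can only increase the sum of squared residuals."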
] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "X_fertility = df['fertility'].values.reshape(-1, 1)\n", "y = df['life'].values.reshape(-1, 1)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX4AAAEGCAYAAABiq/5QAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nO3de5jU9ZXn8fepvtHdgN1ig0ZgMI4hwzgogjrKPhmNqyYZZhziNY54S4KMuZmJiZmZ+MRdZvYJatZ1NksATbxEzeh4WbNO1uhkZHcWLwkIIQ4JcUyMgAitdkMDTd/q7B9V1VZ3V3XX7VdVv/p9Xs/D03R1Xb7VDef37fM93/M1d0dERKIjVukBiIhIeSnwi4hEjAK/iEjEKPCLiESMAr+ISMTUV3oAuTjqqKN8zpw5lR6GiEiobNq06W137xh9eygC/5w5c9i4cWOlhyEiEipm9ttMtyvVIyISMQr8IiIRo8AvIhIxCvwiIhGjwC8iEjGhqOophXjceedgP/2DQzTW1zGttZFYzCo9LBGRsotE4I/Hne17evj0/RvZ2dXLzPZm7rpyEXNnTFHwF5HIiUSq552D/cNBH2BnVy+fvn8j7xzsr/DIRETKLxIz/v7BoeGgn7Kzq5f+waEKjShB6ScRqYRIBP7G+jpmtjePCP4z25tprK+rWPBV+klEKiUSqZ5prY3cdeUiZrY3AwwH2fbmBrbv6WHp6g0sXvUcS1dvYPueHuLx4E8lU/pJRColEjP+WMyYO2MKT1y/eMTMPlvwfeL6xXRMaSr6dcf7baJa008iUvsiEfghEfyntTby9sE+egcG2b1viLh7YMF3olTOeOknEZEgRSLVA+8F4o+vfp4vfH8Lv9pzALDh9E9KKYJvPO68tf8wB/sGuXnJPBbMahuTysmWfprW2ljUa4uITKSmZ/zpqRYz445nt9MxuYkbz5/LTY9tpWNyE7ddNJ8vP7p1xKy8mOCbaab/Py4/hZjBm/sOE4/Hgezpp2pZ2FXFkUjtqtnAnykAr7pwPu7OTY8lAv3Orl5ufXo7Ky84keOnT6a5YWSASw9+DfUx6mNGb//4gTDTusFnHnqZe64+lcc27WBmezMdUyYRixmxmJVkLaHUVHEkUttqNtWTKQDf9NhWjj5i0oi8+uYd3Vxz70+pM+iY0jQi6KdX/Hx89fNsf6uHzz60edzqn2yLtvt6B7jqzOO4859/VfWVO6o4EqltNRv4swXgulhuef1Mwe/Lj25lxVnHs7Orlzue3c5b+w+zq+sQnT19wxeB1KLt6Od/52A/Nz22lQsXzsq6eByPO509fWOes9xUcSRS22o28GcLwK2N9ay5YuGIRdU1VyykvblhxH2zBb+25gYWzGrjqjOP45K1L4yp/8+0aLvqwvmsWf8aO7t6mdbaOHyRSQ/07x7sq9iegtGyfe9UcSRSGwIN/Gb2RTP7NzN7xcy+b2aTzOw4M3vJzF41s4fNLJAylmxVM4lxwT1Xn8oT15/JzUvm8fc//hVdvQMjHp8t+HX3DrDirOOH1wlgZCoktWj7yHVn8OiKM7h5yTxu/9F2Nu/oZmZ7M
9OnNDGttXFMKulnO/ZVTXpFFUcitS2wxV0zOxb4PDDP3XvN7BHgMuBjwB3u/g9mtgb4JPDtUr/+6KqZhvoYBw4P8vFvPz9isXfN+tfYvKObr//JyDRGKvilL3DedtF8bn16O3/zx783biokFjOOnjqJfb0D3PDwluHHr122kPcd0UwsZnT29I0I9C2NdVWTXqn2iiMRKU7QVT31QLOZDQAtwG7gw8Dlya/fB9xCAIEfGFE109nTx5Xf/cmYxd6bl8xj5VPbxqQxMl046mPGty5fgJmNu/kqVQ10ZEsDj1x3Bu4+4c7d7t6BqtrQVa0VRyJSvMBSPe6+C7gdeINEwN8HbAK63X0webedwLGZHm9my81so5lt7OzsLHo8mXL2HZOb+MCMyTzwydNxfEw+PRX8jm1vYfqUSRzZmvj70VMnjUmFrF22EMPZ2XWIHV2HeK3zACseeJlL1r5A32Acx9m9r3d40XZ0KmnN+te47aL5Sq+ISODMPZjFQzNrBx4DLgW6gX9Mfv51d//d5H1mAT909z8Y77kWLVrkGzduLGo87x7s42c79tHSWEd37wA/3raHpaccO2bzVq616uk1/nGHQ/0DvHtwYMTzffPik3hs086Mr3NCx2Re7TwwIpV0/7WnMXlSPQODcaVXRKRoZrbJ3ReNuT3AwH8x8BF3/2Ty8yuBM4CLgaPdfdDMzgBucffzx3uuYgN/pg1Jq//8FL71L6/yzLa9w/eb2d6cd4O2zp4+Xtm1D4Cbn3xlTKrmnqtP5Zp7fzrm9ieuXzzcKE55dBEJQrbAH2RVzxvAH5pZi5kZcA6wDXgOuCh5n6uAJwMcA5C5Jv/6B1/mwoWzRtyvkMXU/sEhWhrrsi7O1sUs66JteiopffOYiEiQgszxvwQ8CrwM/Dz5WuuAm4C/NLN/B6YB3wlqDCnZavJH588LWUxtrK/jUP8Qh/qHMpZ/pn8s5nVEREol0Dp+d/+6u3/Q3U9092Xu3ufuv3b309z9d939YnfvC3IMkLkm/7x50zn6iEk8uuIM1i5byHnzphe0mDqttZHfmdbCka0NYxZn77jkJGIxWLdsoRZtRaRqBJbjL6VS5/jPmzedz5/zAVY8sGlEjf3c6VOor8//WhiPO929/QwMxjk0MMRQ3Hlr32G++cyv6DzQx2MrziAWiymXLyJllS3HX7PdOdONrsk3My5Z+8KInP9139tU8MlbsZhxZGsTb7x7kLNv/z9jvt435Mw+4r3nTbVqKORCoHbJIlKsSAR+GLkhaVfXoUB2yU5qyHyq1qSG936LKKblsdoli0gp1GyTtvFk68MDFNUV86jWpow9bo5qfW+2X0zLY7VLFpFSiMyMP122PjyffWgznQf6Ct5IlUuPm2JaHqtdsoiUQiQDf3qA7h0Y4rW9B7j16UQHzQWz2tiz/zBXfrewHb0T9bgp5pB1HdAuIqUQyVQPvBeg6wyuufenbN7RDcCKs44fbq8ApU+nFNPyON/HVsvBLiJSXSI54083ehbd1txQcDoll4qbYloe5/NYLQSLSDaRnfGnjJ5FZ9uBO1E6ZfTBKuOdoFVMq4ZcHhuPO2/tP8zBvkFuXjKPBbPatBAsIsMiH/jTZ9Ebbjqbk2YdUVAqploqblIXoEvWvsBFa15g5VPbuPH8ucPBXwvBIhLJVE887rx9sI/DA0PUmdHcODJl0tbcmHcqJpeKm3Jsvsp0ARrvwBkRiZ7IBf5Mue/bLprPjKmTmDOtlVjMCjp9aqKKm3Ll3MdrSKceQSICEUz1ZJoRf/nRrfz2nUNFpWUmqrjJ9Lp3PLudt/YfLmnVTbbNae9ra9bCrogAEZzxZ5sRtzTWFZX/nqjiZvTrLpjVxlVnHjfcM6hUvwFk2px215WLOHrqJAV9EQEiGPizpWQO9Q8Vnf8eL0U0+nVXnHU8Nz02dr9AtkZxua4PFFMuKiLRELlUT6aUzG0Xzed3prUUnf8eb8PU6Ned1tqYdTF49PMMDsbHlIr+Y
vd+3j2YOT2kk71EZDyRmPGPni2f0DGZx68/k8MDceoMmhvraGsufFac6se/u/sw16X1+E9P3WRqDZ3pN4+G+tiYReCHPnX6mPWB6x7YxMoLTuToIybllR4aHIyz90AfA0NxGupiTJ/cVNAZBCISXoH9jzezuWa2Je3PfjO7wcxuMbNdabd/LKgxQOaNVa92HuCo1iZmH9nCse0tHNla+Kw49fw/27FvOOhD5jr+9Jn40VMnjfnNY+2yhTTW2Zggv7enL+u6xKfv38hb+w/ntDA8OBjnl8ka/z+6bT2XrH2BX+7pYXAwXtB7F5FwCvLM3e3ufrK7nwwsBA4BTyS/fEfqa+7+w6DGAMFvrEo9f7bD1rMtGMdixgkdk3noU6fz6IozuHnJPO7851+xq+swHZNH5vjfOdifsVKnu3eAnV29vNndm3WXcLq9B/qGTx1LjW/FA5vYeyDw0y9FpIqUK9VzDvCau//WrLz55ng8zs1L5tHW3EB37wBr1r/G5h3dJdvBmqrW6e4dyLtzZlfvAJff/dKIx2zb3cPKC07kmnt/OnzbY5t2sHbZQq773ntppFUXzuf2H21nZnsz7xzs54aHt0x4gtjAUDzjxWlwSDN+kSgpV+C/DPh+2uefNbMrgY3Al9y9a/QDzGw5sBxg9uzZBb1oYoduPyuf2jYiYN73/G9KtoM1Va2zZv1rrLpw/nClTrZWD+nrDUPuGQPxcUe1Dl9EZrY388Vz5w6vSxzqG+I3bx/k9h9tp/NA3/AFIJd2DA11sYwXp/q6kSeE6WhHkdoW+GHrZtYIvAn8vrvvMbMZwNuAAyuBY9z92vGeo9DD1jt7+li6esOYQPfQp05nZntLSQJa+o7cjslNfP6cEzjuqFZamuo4atTawejdu/dcfSo3P/nKmPE9fv2ZGDYcfNubG+jqHaB/cIiG+hg4vPFuYsNZ6jeYme3NE874Uzn+9EPm11yxkA/OSBwyr46eIrWlkoetfxR42d33AKQ+Jgd1F/BUUC+cbbNWXbLKBoqf4eZTNz96veHvf/wqt100f7j/f/pRjenjGxOMly3iyNZGbnh4y7i/XYxWXx/jgzOm8Mh1ZzA4FKd+VFVPtvWQQg+hF5HqVI7A/wnS0jxmdoy7705+uhR4JagXLrZ/Tj6bplKBcbzHjL4Qbd7Rza1Pb+fh5X84PN7Rr/H2wb6xwfh7G3n8L84suKd/Q10Md6ehLlayYyFFJDwCDfxm1gKcC1yXdvOtZnYyiVTP66O+VlKZ2hesXbaQ9uYGYPwZ7rTWxrzTHhNdSDJdiDoP9NFYX5d1Rn14IHMwPjwYZ/aRLXl9PwoZn452FKk9ge7ccfdD7j7N3fel3bbM3f/A3ee7+5+mzf5LLlvJ5KudB4jHfdwZbiFloBM9ppBjF+uSG73SzWxvpq6AlHspxqfjHEXCr+Z37qZKJjsmN7HirOP55H94P2/tO8yMqU3jznALSXtM9JhC+ug0N9aNWQe47aL5tDbV0dnTV9IzAyYanxZ/RWpDzQf+/sEhOiY3ceP5c0eUWq69YiHHtk9i7RULx7RZmNbaOLxpKp+0Ry6pknx7/bc1NzJj6iRWXnAiLY11w0dD7t7Xl3cALnZ8WvwVqQ0136Slsb6Oz59zwphOmNc9sInNb+zja//zFVZecCLrbzyLx68/czh4FpKWKeQxE4nFjDnTWjnx2COY2d7MicceQVNDXUG7kYsdnxZ/RWpDzc/4p7U2ctxRrVl73Wze0c019/50uA4+NWMuJC0TVEvk0bPwXV2HCgrAxY5Pi78itaHmZ/yxmNHSlPlUqu7egeHPMwXOQtobl6MlcrZTtnIJwMWML4jfaESk/Go+8AMc1dqUsQf/mvWvDd8nTDPXSgXg9N8YNtx0Nk9cv1gLuyIhVPOpHngvYKV63by1/zCTGmJ0JrtSnjdvOl/743n0Dw7R2dNXVHqmHL1uKnnKViEH0YtIdYlE4IdEwDKMK76TKO380nkf4N5rTqWhL
sbAUHy4S+bM9mbuv/Y0Jk+qZ2AwPm5QHR3k25sbeLXzQFnKHRWARaRQkQn8kL2087aL5tMxuYmdXb10TG5iz/7DXPndreMG70w17WuXLeTOf/6Vyh1FpKpFIsefkq2088uPbmXFWccDiUPQU5ulUl/PVCqZqab9uu9t4sKFs0bcT+WOIlJtIhX4xyvtbEv272lrbsipVDJbTfvoBdYwLRqLSDREKtWTXto5uhb9UH8isKd2xk5Uq54qqUy1gmhrbuBQ/xDHtk0acYiKyh1FpNoEfhBLKRR6EEsm2frNzJjaRG//EM2NdezZP3E7hHjcef2dg+zZf3hkP/1li5hxROK5dIKViFRStoNYIhf4YeKSy1xLMvf2HObjq58f89uBFnNz/x7qqEeR4FTyBK6qM1EpZK6lkgODmQ8vD/tibrHBONcunpXu9qmLjkRVpBZ3Sy3X1glh6mGfCsZLV29g8arnWLp6A9v39OQ15lzPMijkzINSKcX7FAkrBf4i5HpwSZgCTCmCca5dPCvZ7bOSFx2RSotkqqdUv+Ln0johbD3s04PxglltwxVL/YNDxOOe0/cp1y6elez2qRbTEmWBzfjNbK6ZbUn7s9/MbjCzI83sWTN7NfmxPagxZFLqGfhE3S7DFmBSwXjBrDZuPH8uK5/axqXrXuTSdS/m/H3KtYlcJbt9FtPhVCTsylLVY2Z1wC7gdOAzwLvu/g0z+yrQ7u43jff4Ulb1dPb0sXT1hrJV4pT79YqVujC+te8wNz/5SsHjrvaqnkovLIuUQ6Wres4BXnP335rZBcBZydvvA9YD4wb+Uso0A++Y3ET/4BC7ug6VPPikZrWjA0y1bupKpa9am+qK+k0l18qoSjWbq2SHU5FKK1fgvwz4fvLvM9x9N4C77zaz6ZkeYGbLgeUAs2fPLtlARueVF8xq4ysfmcul614MZOYXdIAJYsYcixnNDfVF5d+DGFepn1MdTiWqAk/1mFkj8Cbw++6+x8y63b0t7etd7j5unj/Inbv3XH1qUSmNSgoyXVHMcwcxLqVmRPJXyVTPR4GX3X1P8vM9ZnZMcrZ/DLC3DGMYNnoGPuQeqsVXeG/m2zswyFv7Dg+3lC5lxVAxv6kEUckUtuookWpWjsD/Cd5L8wD8ALgK+Eby45NlGMMI6b/id/b0heoA8Uwz31UXzuf2H21n847ukl60Ck2FBFHJFLbqKJFqFugGLjNrAc4FHk+7+RvAuWb2avJr3whyDBMJ2wHi6TPfBbPauHnJPJrqY9x60XwWzGoryUWr2J3GQZRKqvxSpHQi2aRttDD1bNnVdYjFq54brrMffZLYjKmTmDOttaK5dOX4RaqDunMSrgCfSTzu7Ow6xOV3v8TNS+ax8qltY1JUj19/JtOnTCr4NUq17yAMVT0ita7SdfwVVwszxncO9vO3/7SNVRfOp6k+ljHnPTAYL+o1SpVLD6JUUuWXIqURmSZttdCUq39wiGe27eX2H23niOaGQHLeyqWL1L7IBP5aqApJBeXNO7r5yqNbWXXh/JIvSodtsVtE8heZVM94nSDDkjtOb/+weUc39z3/Gx761OnUxYyG+hj1MWP3vt7AO46KSLhFZnE3W47/hI7JvNp5IDS5/0wXKSD06xdRFJYJh4SXqnrI/B/tnYP9oeqemUnYOoBKbRQbSPXLFvgjk+OHzL3zayH3XwvvIWpqodhAwitSgT+TWqhiqYX3EDW6WEslRT7w10IVSy28h6jRxVoqKVI5/mxqYZGtFt5DlCjHL+UQ+Z2749GO0NzpAlMaKpuVSlLgrwHlmj1qllpamnBIpUQ+x18LylUhokoUkdqgGX8NyLdCpNB0jSpRRGqDAn8NGK8dxWjFpGvyeR0RqV5K9dSAfMo5i0nXqGxUpDYEOuM3szbgbuBEwIFrgfOBTwOdybv9tbv/MMhx1Lp8KkSKSdeoEkWkNgSd6rkTeNrdLzKzRqCFROC/w91vD/i1IyXXCpFi0zWqRBEJv8BSP
WY2FfgQ8B0Ad+939+6gXk9yo3SNiAQ5438/iXTOPWZ2ErAJ+ELya581syuBjcCX3L1r9IPNbDmwHGD27NkBDrO2ZargUbpGJNoCa9lgZouAF4HF7v6Smd0J7Ae+BbxNIue/EjjG3a8d77mCbtlQq7ThqrZpF7VMpKi2zGb2ATP7sZm9kvx8vpl9bYKH7QR2uvtLyc8fBU5x9z3uPuTuceAu4LTc34bkI6wbruJxp7Onj11dh+js6SMer/5+UuWWuqgvXb2BxaueY+nqDWzf06PvleQk1xz/XcBfAQMA7r4VuGy8B7j7W8AOM5ubvOkcYJuZHZN2t6XAK3mNWHIWxg1XCmi5CetFXapDroG/xd1/Muq2wRwe9zngQTPbCpwM/BfgVjP7efK2s4Ev5jxayUsYW/8qoOUmjBd1qR65Lu6+bWbHk8jLY2YXAbsnepC7bwFG55eW5TVCKVj64ezpOf5qruBRQHvPeDl87aKWYuQa+D8DrAM+aGa7gN8Afx7YqKQkwrjhKkoBbbzAPtHCfBgv6lI9xq3qMbMvuPudZrbY3TeYWSsQc/ee8g1RVT0pUajiiEol0kTvs7Onj6WrN4y5AD5x/eLhDXQT/XuIwr8XGV+hB7FcQ2L37X8nUZFzMIjBycRKGRCrOSCE8beUQmRby0gF9lxSXuPtoo7KBVQKM9Hi7i/M7HVgrpltTfuTWpyVMinVomcYqmZSAe3Y9hY6pjSVLVCVs4x0osBe7MK8FsllPOMGfnf/BPCHwL8Df5L2Z0nyo5RJqRY9FRAyK/cFcaLAnq21RntzQ04XJy2Sy3gmXNxN1uOfVIaxyDhKteipgJDZRKmXUptocTZTyqu9uYFXOw/klL6J0iK55G/cwG9mj7j7JWb2c5KlnKkvAe7u8wMdnQwrVRWHAkJm5b4g5rKWMTqH39nTxx3PbufmJfNoa26gu3eAO57dzt8tnT/m4lSuqp9qXi+S7Caa8aeaqi0JeiAyvlIteqoMMLNKXBDzbXEdj8e56szjuOmxrcM/u1UXzicej2d87qAXybWAHF6BNWkrJZVzlpZmaWOFIYi92d3LJWtfGHNxeuS6M3hfW/M4jwxGLiWnUlkFlXOaWQ8jUzzDXyKR6plaovFJGVXbYSrVcCEKQxmpu2dMR1Vq8qb1ovAaN/C7+5RyDUSiqZpm2uW8IBZysau29ZlqG4/kToety7BKtEOOYnlpoaWj1XZ6WrWNR3IX9Jm7EhKVmnkXmi6ohvRQoQotHZ0oHVXu70kY0mOSmQJ/mVVrwCp3HXtKIemCakoPFaKY3Hi2dFSlvifVtl4kuVGqp4yquV1CpRbqCkkXhD09FMQ5CWH/nkh5KfCXUTX/56zUoS3p6YINN53NE9cvnnCWGvZqkiBy42H/nkh5KdVTRtX8n7OSG7vyTReEvZqk1LnxeNwxs1B/T6S8FPjLqJoDVpgW6mph93GpcuOp9OEdz25n1YXzR+zqDdv3RMon0J27ZtYG3A2cSGIj2LXAduBhYA7wOnCJu3eN9zy1snM37IuS1aRaF8mLle/7St89u2BWGyvOOp5prY28r62Zo6dOqonviRQu287doAP/fcC/uvvdZtYItAB/Dbzr7t8ws68C7e5+03jPUyuBH2o3YEnxCpkY7Oo6xOJVz425fcNNZ3Nse0ver69/m9WhVD+LQk/gKpiZTQU+BFwN4O79QL+ZXQCclbzbfcB6YNzAX0tU/ibZFFJSW6r0oX4brR7l+FkEWdXzfqATuMfMNpvZ3ckze2e4+26A5MfpmR5sZsvNbKOZbezs7AxwmCLVoZDF/1JVCJWr4qwSu8PDphw/iyAXd+uBU4DPuftLZnYn8NVcH+zu64B1kEj1BDNEkepRyOy9VIvy5ag4028VuSnHzyLIGf9OYKe7v5T8/FESF4I9ZnYMQPLj3gDHIBIahc7eS3FGcTn2cVTzPpZqUo6fRWCBP3lk4w4zm5u86RxgG/AD4KrkbVcBTwY1BpEwKWQzW
6mUo+FaNe9jqSbl+FkEXcf/OeDBZEXPr4FrSFxsHjGzTwJvABcHPAaR0Mh18b/UFTjl2MdRzftYqkk5fhaBBn533wKMKSUiMfsXkQIElSsPuuKsFjbelUvQPwsdvSiRE/Z69UKOPKyW91yqcVTL+6l2Za/jF6lGtVBZkm+uvFrecymDfjW8nzBTd06JlFqoLMm36qMa3nMpW5JXw/sJOwV+iZRaqCzJt+qjGt5zKYN1NbyfsFOqRyIln8qSas0j51v1UQ3VNKUM1qPfz4JZbXz+nBMY8sSu4Gr5OVUzzfglUnKdLVfzaWmQ36atajgUvdhNSemtHhzn/mtPY2Z7MwtmtfGVj8zl5idf4UO3rq+6n1O1UlWPRE4uM/lCKmeqWaV/eylmQTbbY2dMbaK3f4hL171YMz+nUlNVj0hSLjXStZZHrnRX2GI2JY3XtTT1ebow/5zKRYFfJINqyIuXQqVn+ukKvfiMdxGuxM+pmr6nhVKOXySDasiL5yNTu+NqX6fI1XjrA+X+OdXK91Q5fpEswjKzy5YDnza5kY+vfj70+e+J1gfK+XMK29qPcvwieap0XjxX2XLgD3369JrIf0+0PlDOn1OtrP0o8ItUmXxmsPG40z84xDcvPonu3gHWrH+NzTu62dnVS51ZTaxTQPVchGtl7Uc5fpEqkk8OOXXfS9e9yKXrXmTlU9u48fy5LJjVxsz2Zpob60K1TlEp+RwHGba1n2yU4xepIvnkkLPdd+UFJ3L0EZOYO2MKQCjWKSqlkP0FYVn7gew5fs34RapIPjnkbPc9fvrk4cBVimMZa1khPYRq4XuqwC9SRfJpbZDtvs0NdaEMRpVQK4u1+Qo08JvZ62b2czPbYmYbk7fdYma7krdtMbOPBTmGKMknVynVKZ8ccqnzzVH891OOg82rUaA5fjN7HVjk7m+n3XYLcMDdb8/1eZTjn5gOp6gd+Vb16HCTwtX6+1Ydf40br59JNZTBSe7yKV0sVZljlP/9NNXHWHnBibQ01nGof4im+trPgAcd+B14xswcWOvu65K3f9bMrgQ2Al9y967RDzSz5cBygNmzZwc8zPCLaq5SSiOq/37eOdjPld/9SWh24pZK0Je2xe5+CvBR4DNm9iHg28DxwMnAbuCbmR7o7uvcfZG7L+ro6Ah4mOEX1VyllEZU//1E9YIXaOB39zeTH/cCTwCnufsedx9y9zhwF3BakGOIilrZWCKVEdV/P1G94AW2uGtmrUDM3XuSf38W+M/Az9x9d/I+XwROd/fLxnsuLe7mJkwbS6T6RPHfjxZ3S28G8ISZpV7nIXd/2sy+Z2Ynk8j/vw5cF+AYIqVa+plIOEXx308xB8SEWWCB391/DZyU4fZlQb2miEi+InnBq/QARESkvBT4RUQiRoFfRCRiFPhFRCJGgV9EJGIU+EVEIkaBX0QkYtSdU0RKJoq7f8NIgV9ESqLW2x/UEqV6RKQkCjm/VipDM34RKcjotE5UWxyHkQK/SA0KOteeKa3z0KdOZ2Z7825KpcoAAAihSURBVJhDTWq9xXEYKdUjUmNSQXnp6g0sXvUcS1dvYPuenpIenp4prfO3/7SNtcsWRq6nfxhpxi9SY8pxfm6mtM4z2/ay8oITI9fiOIwU+EVqTDly7amTq0andWKxWORaHIeRUj0iNaYcxwlG9ajGWhHY0YulpKMXRXJXrnp6bdaqfpU4ehEzex3oAYaAQXdfZGZHAg8Dc0gcvXiJu3cFOQ6RKCnXcYJRPLmqVpQj1XO2u5+cdtX5KvBjdz8B+HHycxEpoVRQPra9hY4pTZqJywiVyPFfANyX/Pt9wJ9VYAwiIpEVdOB34Bkz22Rmy5O3zXD33QDJj9MzPdDMlpvZRjPb2NnZGfAwRUSiI+hyzsXu/qaZTQeeNbNf5vpAd18HrIPE4m5QAxQRiZpAZ/zu/mby417gCeA0YI+ZHQOQ/Lg3yDGIiMhIgQV+M2s1sympvwPnAa8APwCuSt7tKuDJo
MYg0RGPO509fezqOkRnT19J2xOI1JogUz0zgCfMLPU6D7n702b2U+ARM/sk8AZwcYBjkAhQH3iR/AQW+N3918BJGW5/BzgnqNeV6ClHbxqRWqKWDRJ66gMvkh8Ffgm9cvSmEaklCvwSemoYJpIftWWW0CtXbxqRWqHALzVBDcNEcqdUj4hIxCjwi4hEjAK/iEjEKPCLiESMAr+ISMQo8IuIRIwCv4hIxKiOX0IjHnfeOdivTVoiRVLgl1BQ62WR0lGqR0IhW+vldw72V3hkIuGjwC+hoNbLIqWjwC+hoNbLIqWjwC9lUeyZuGq9LFI6gS/umlkdsBHY5e5LzOxe4I+Afcm7XO3uW4Ieh1ROKRZm1XpZpHTKMeP/AvCLUbd92d1PTv5R0K9xpVqYTbVePra9hY4pTQr6IgUKNPCb2Uzgj4G7g3wdqW5amBWpLkHP+P8b8BUgPur2vzOzrWZ2h5llPD3DzJab2UYz29jZ2RnwMCVIWpgVqS6BBX4zWwLsdfdNo770V8AHgVOBI4GbMj3e3de5+yJ3X9TR0RHUMKUMtDArUl2CXNxdDPypmX0MmARMNbMH3P2K5Nf7zOwe4MYAxyBVQAuzItUlsBm/u/+Vu8909znAZcC/uPsVZnYMgJkZ8GfAK0GNQaqHFmZFqkclevU8aGYdgAFbgBUVGIOISGSVJfC7+3pgffLvHy7Ha4qISGbauSsiEjEK/CIiEaPALyISMeaeX7OsSjCzTuC3wFHA2xUeTjE0/soK+/gh/O9B4y+v33H3MRuhQhH4U8xso7svqvQ4CqXxV1bYxw/hfw8af3VQqkdEJGIU+EVEIiZsgX9dpQdQJI2/ssI+fgj/e9D4q0CocvwiIlK8sM34RUSkSAr8IiIRE4rAb2bfNbO9ZhbKTp5mNsvMnjOzX5jZv5nZFyo9pnyY2SQz+4mZ/Sw5/v9U6TEVwszqzGyzmT1V6bHky8xeN7Ofm9kWM9tY6fHky8zazOxRM/tl8v/BGZUeU67MbG7y+576s9/Mbqj0uIoRihy/mX0IOADc7+4nVno8+Uq2oj7G3V82synAJuDP3H1bhYeWk2QL7VZ3P2BmDcD/A77g7i9WeGh5MbO/BBYBU919SaXHkw8zex1Y5O5h2jw0zMzuA/7V3e82s0agxd27Kz2ufJlZHbALON3df1vp8RQqFDN+d/+/wLuVHkeh3H23u7+c/HsPicPnj63sqHLnCQeSnzYk/1T/jCGNzn+uHDObCnwI+A6Au/eHMegnnQO8FuagDyEJ/LXEzOYAC4CXKjuS/CTTJFuAvcCz7h6q8ZP9/OewcOAZM9tkZssrPZg8vR/oBO5JptruNrPWSg+qQJcB36/0IIqlwF9GZjYZeAy4wd33V3o8+XD3IXc/GZgJnGZmoUm5jXP+c5gsdvdTgI8Cn0mmP8OiHjgF+La7LwAOAl+t7JDyl0xR/Snwj5UeS7EU+MskmRt/DHjQ3R+v9HgKlfwVfT3wkQoPJR+p859fB/4B+LCZPVDZIeXH3d9MftwLPAGcVtkR5WUnsDPtt8RHSVwIwuajwMvuvqfSAymWAn8ZJBdHvwP8wt3/a6XHky8z6zCztuTfm4H/CPyysqPKXbbznys8rJyZWWuyKIBkiuQ8QnRWtbu/Bewws7nJm84BQlHYMMonqIE0D1TmzN28mdn3gbOAo8xsJ/B1d/9OZUeVl8XAMuDnyTw5wF+7+w8rOKZ8HAPcl6xoiAGPuHvoSiJDbAbwRGL+QD3wkLs/Xdkh5e1zJM7bbgR+DVxT4fHkxcxagHOB6yo9llIIRTmniIiUjlI9IiIRo8AvIhIxCvwiIhGjwC8iEjEK/CIiEaPAL5FmZp9Pdot8MMf7zzGzy9M+X2Rmf5/8+9Vm9q3k31eY2ZVpt78viPGLFCIUdfwiAboe+Ki7/2aiO5pZPTAHuBx4CMDdNwJj2iS7+5q0T68mseHqzeKHK1I8BX6JLDNbQ6KB2A/M7B+A44E/IPH/4hZ3f
9LMribR1XMS0Aq0AL+X3Ih3H7AZuHF0m2czu4VEK/HXSbSCftDMeoG/AT7l7kuT9zsX+At3/3iw71bkPUr1SGS5+woSs/CzSQT1f3H3U5Of35bWQfIM4Cp3/zCJ5mL/6u4nu/sdObzGoyR+I/jzZJO7H5K4cHQk73INcE8p35fIRBT4RRLOA76anMmvJzHDn5382rPuXpLzIDyxVf57wBXJ/kdnAP+7FM8tkiulekQSDLjQ3bePuNHsdBJthEvpHuB/AYeBf3T3wRI/v8i4NOMXSfgR8LlkJ1XMbEGW+/UAU/J87hGPSbZYfhP4GnBv3iMVKZICv0jCShJHSm41s1eSn2eyFRhMHjz/xRyf+15gTfKg7ubkbQ8CO8Jy7rLUFnXnFKmAZL3/5pC1F5caocAvUmZmtonEusG57t5X6fFI9Cjwi4hEjHL8IiIRo8AvIhIxCvwiIhGjwC8iEjEK/CIiEfP/AYBIziIy92T5AAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "sns.scatterplot(x='fertility', y='life', data=df)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As you can see, there is a strongly negative correlation, so a linear regression should be able to capture this trend. Your job is to fit a linear regression and then predict the life expectancy, overlaying these predicted values on the plot to generate a regression line. You will also compute and print the $R^2$ score using sckit-learn's ```.score()``` method." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.6192442167740037\n" ] }, { "data": { "text/plain": [ "[]" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "from sklearn.linear_model import LinearRegression\n", "\n", "# Create the regressor: reg\n", "reg = LinearRegression()\n", "\n", "# Create th prediction space\n", "prediction_space = np.linspace(min(X_fertility), max(X_fertility)).reshape(-1, 1)\n", "\n", "# Fit the model to the data\n", "reg.fit(X_fertility, y)\n", "\n", "# compute predictions over the prediction space: y_pred\n", "y_pred = reg.predict(prediction_space)\n", "\n", "# Print $R^2$\n", "print(reg.score(X_fertility, y))\n", "\n", "# Plot regression line on scatter plot\n", "sns.scatterplot(x='fertility', y='life', data=df)\n", "plt.plot(prediction_space, y_pred, color='black', linewidth=3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Train/test split for regression\n", "Train and test sets are vital to ensure that your supervised learning model is able to generalize well to new data. This was true for classification models, and is equally true for linear regression models.\n", "\n", "In this exercise, you will split the Gapminder dataset into training and testing sets, and then fit and predict a linear regression over all features. In addition to computing the $R^2$ score, you will also compute the Root Mean Squared Error (RMSE), which is another commonly used metric to evaluate regression models." 
] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "R^2: 0.7298987360907494\n", "Root Mean Squared Error: 4.194027914110243\n" ] } ], "source": [ "from sklearn.linear_model import LinearRegression\n", "from sklearn.metrics import mean_squared_error\n", "from sklearn.model_selection import train_test_split\n", "\n", "# Create training and test sets\n", "# Note: as run here, X still holds only the 'fertility' column from the first code cell\n", "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)\n", "\n", "# Create the regressor: reg_all\n", "reg_all = LinearRegression()\n", "\n", "# Fit the regressor to the training data\n", "reg_all.fit(X_train, y_train)\n", "\n", "# Predict on the test data: y_pred\n", "y_pred = reg_all.predict(X_test)\n", "\n", "# Compute and print R^2 and RMSE\n", "print(\"R^2: {}\".format(reg_all.score(X_test, y_test)))\n", "rmse = np.sqrt(mean_squared_error(y_test, y_pred))\n", "print(\"Root Mean Squared Error: {}\".format(rmse))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Cross-validation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Cross-validation motivation\n", " - Model performance depends on the way the data is split\n", " - A score from a single split is not representative of the model's ability to generalize\n", " - Solution : **Cross-validation**!\n", "- k-fold Cross-validation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 5-fold cross-validation\n", "Cross-validation is a vital step in evaluating a model. It maximizes the amount of data that is used to train the model: over the course of the procedure, the model is both trained and tested on all of the available data.\n", "\n", "In this exercise, you will practice 5-fold cross-validation on the Gapminder data. By default, scikit-learn's ```cross_val_score()``` function uses $R^2$ as the metric of choice for regression. Since you are performing 5-fold cross-validation, the function will return 5 scores. 
Your job is to compute these 5 scores and then take their average." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0.71001079 0.75007717 0.55271526 0.547501 0.52410561]\n", "Average 5-Fold CV Score: 0.6168819644425119\n" ] } ], "source": [ "from sklearn.model_selection import cross_val_score\n", "\n", "# Create a linear regression object: reg\n", "reg = LinearRegression()\n", "\n", "# Compute 5-fold cross-validation scores: cv_scores\n", "cv_scores = cross_val_score(reg, X, y, cv=5)\n", "\n", "# Print the 5-fold cross-validation scores\n", "print(cv_scores)\n", "\n", "print(\"Average 5-Fold CV Score: {}\".format(np.mean(cv_scores)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### K-Fold CV comparison\n", "Cross-validation is essential, but remember that the more folds you use, the more computationally expensive cross-validation becomes. In this exercise, you will explore this for yourself. Your job is to perform 3-fold cross-validation and then 10-fold cross-validation on the Gapminder dataset.\n", "\n", "In the IPython Shell, you can use ```%timeit``` to see how long 3-fold CV takes compared to 10-fold CV by running the following with ```cv=3``` and then with ```cv=10```:\n", "```python\n", "%timeit cross_val_score(reg, X, y, cv = ____)\n", "```" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.98 ms ± 36.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n", "0.6294715754653507\n", "6.27 ms ± 230 µs per loop (mean ± std. dev. 
of 7 runs, 100 loops each)\n", "0.5883937741571185\n" ] } ], "source": [ "# Create a linear regression object: reg\n", "reg = LinearRegression()\n", "\n", "# Perform 3-fold CV\n", "%timeit cross_val_score(reg, X, y, cv=3)\n", "cvscores_3 = cross_val_score(reg, X, y, cv=3)\n", "print(np.mean(cvscores_3))\n", "\n", "# Perform 10-fold CV\n", "%timeit cross_val_score(reg, X, y, cv=10)\n", "cvscores_10 = cross_val_score(reg, X, y, cv=10)\n", "print(np.mean(cvscores_10))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Regularized regression\n", "- Why regularize?\n", " - Recall: Linear regression minimizes a loss function\n", " - It chooses a coefficient for each feature variable\n", " - Large coefficients can lead to overfitting\n", " - Penalizing large coefficients : **Regularization**\n", "- Ridge regression\n", " - Loss function = $ \\text{OLS loss function} + \\alpha \\sum^{n}_{i=1}a_i^2 $\n", " - Alpha : Hyperparameter we need to choose (also written $\\lambda$)\n", " - Picking alpha is similar to picking k in k-NN\n", " - Alpha controls model complexity\n", " - Alpha = 0: get back OLS (Can lead to overfitting)\n", " - Very high alpha: Can lead to underfitting\n", "- Lasso regression\n", " - Loss function = $ \\text{OLS loss function} + \\alpha \\sum^{n}_{i=1}|a_i| $\n", " - Can be used to select important features of a dataset\n", " - Shrinks the coefficients of less important features to exactly 0" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Regularization I: Lasso\n", "In the video, you saw how Lasso singled out the ```'RM'``` feature as being the most important for predicting Boston house prices, while shrinking the coefficients of certain other features to 0. 
Its ability to perform feature selection in this way becomes even more useful when you are dealing with data involving thousands of features.\n", "\n", "In this exercise, you will fit a lasso regression to the Gapminder data you have been working with and plot the coefficients. Just as with the Boston data, you will find that the coefficients of some features are shrunk to 0, with only the most important ones remaining." ] }, { "cell_type": "code", "execution_count": 56, "metadata": {}, "outputs": [], "source": [ "X = df.drop('life', axis='columns').values\n", "y = df['life'].values" ] }, { "cell_type": "code", "execution_count": 58, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[-0. -0. -0. 0. 0. 0.\n", " -0. -0.07087587]\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Import Lasso\n", "from sklearn.linear_model import Lasso\n", "\n", "# Instantiate a lasso regressor: lasso\n", "# (normalize was deprecated in scikit-learn 1.0 and removed in 1.2)\n", "lasso = Lasso(alpha=0.4, normalize=True)\n", "\n", "# Fit the regressor to the data\n", "lasso.fit(X, y)\n", "\n", "# Compute and print the coefficients\n", "lasso_coef = lasso.coef_\n", "print(lasso_coef)\n", "\n", "# Plot the coefficients, labelled with the feature names in the same order as the columns of X\n", "df_columns = df.drop('life', axis='columns').columns\n", "plt.plot(range(len(df_columns)), lasso_coef)\n", "plt.xticks(range(len(df_columns)), df_columns.values, rotation=60)\n", "plt.margins(0.02)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Regularization II: Ridge\n", "Lasso is great for feature selection, but when building regression models, Ridge regression should be your first choice.\n", "\n", "Recall that lasso performs regularization by adding to the loss function a penalty term equal to the sum of the absolute values of the coefficients, multiplied by some alpha. This is also known as L1 regularization because the regularization term is the L1 norm of the coefficients. This is not the only way to regularize, however.\n", "\n", "If instead you took the sum of the squared values of the coefficients multiplied by some alpha - like in Ridge regression - you would be computing the squared L2 norm. In this exercise, you will practice fitting ridge regression models over a range of different alphas, and plot cross-validated $R^2$ scores for each." 
] }, { "cell_type": "code", "execution_count": 59, "metadata": {}, "outputs": [], "source": [ "def display_plot(cv_scores, cv_scores_std):\n", "    # alpha_space is read from the enclosing (global) scope\n", "    # Convert to arrays so the arithmetic below is element-wise\n", "    cv_scores = np.asarray(cv_scores)\n", "    cv_scores_std = np.asarray(cv_scores_std)\n", "\n", "    fig = plt.figure()\n", "    ax = fig.add_subplot(1,1,1)\n", "    ax.plot(alpha_space, cv_scores)\n", "\n", "    # Standard error of the mean over the 10 CV folds\n", "    std_error = cv_scores_std / np.sqrt(10)\n", "\n", "    ax.fill_between(alpha_space, cv_scores + std_error, cv_scores - std_error, alpha=0.2)\n", "    ax.set_ylabel('CV Score +/- Std Error')\n", "    ax.set_xlabel('Alpha')\n", "    ax.axhline(np.max(cv_scores), linestyle='--', color='.5')\n", "    ax.set_xlim([alpha_space[0], alpha_space[-1]])\n", "    ax.set_xscale('log')" ] }, { "cell_type": "code", "execution_count": 61, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "from sklearn.linear_model import Ridge\n", "\n", "# Set up the array of alphas and lists to store scores\n", "alpha_space = np.logspace(-4, 0, 50)\n", "ridge_scores = []\n", "ridge_scores_std = []\n", "\n", "# Create a ridge regressor: ridge\n", "ridge = Ridge(normalize=True)\n", "\n", "# Compute scores over range of alphas\n", "for alpha in alpha_space:\n", "    \n", "    # Specify the alpha value to use: ridge.alpha\n", "    ridge.alpha = alpha\n", "    \n", "    # Perform 10-fold CV: ridge_cv_scores\n", "    ridge_cv_scores = cross_val_score(ridge, X, y, cv=10)\n", "    \n", "    # Append the mean of ridge_cv_scores to ridge_scores\n", "    ridge_scores.append(np.mean(ridge_cv_scores))\n", "    \n", "    # Append the std of ridge_cv_scores to ridge_scores_std\n", "    ridge_scores_std.append(np.std(ridge_cv_scores))\n", "    \n", "# Display the plot\n", "display_plot(ridge_scores, ridge_scores_std)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.6" } }, "nbformat": 4, "nbformat_minor": 4 }