{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Introduction to scikit-learn\n",
    "\n",
    "## Basic preprocessing and model fitting\n",
    "\n",
    "In this notebook, we present how to build predictive models on tabular\n",
    "datasets.\n",
    "\n",
    "In particular we will highlight:\n",
    "* the importance of scaling numerical variables;\n",
    "* how to train predictive models when you only have numerical variables;\n",
    "* how to evaluate the performance of a model via cross-validation.\n",
    "\n",
    "## Introducing the dataset\n",
    "\n",
    "To this aim, we will use data from the 1994 Census bureau database. The goal\n",
    "with this data is to regress wages from heterogeneous data such as age,\n",
    "employment, education, family information, etc.\n",
    "\n",
    "Let's first load the data located in the `datasets` folder."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "df = pd.read_csv(\n",
    "    \"https://www.openml.org/data/get_csv/1595261/adult-census.csv\")\n",
    "\n",
    "# Or use the local copy:\n",
    "# df = pd.read_csv('../datasets/adult-census.csv')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's have a look at the first records of this data frame:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>workclass</th>\n",
       "      <th>fnlwgt</th>\n",
       "      <th>education</th>\n",
       "      <th>education-num</th>\n",
       "      <th>marital-status</th>\n",
       "      <th>occupation</th>\n",
       "      <th>relationship</th>\n",
       "      <th>race</th>\n",
       "      <th>sex</th>\n",
       "      <th>capital-gain</th>\n",
       "      <th>capital-loss</th>\n",
       "      <th>hours-per-week</th>\n",
       "      <th>native-country</th>\n",
       "      <th>class</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>25</td>\n",
       "      <td>Private</td>\n",
       "      <td>226802</td>\n",
       "      <td>11th</td>\n",
       "      <td>7</td>\n",
       "      <td>Never-married</td>\n",
       "      <td>Machine-op-inspct</td>\n",
       "      <td>Own-child</td>\n",
       "      <td>Black</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>38</td>\n",
       "      <td>Private</td>\n",
       "      <td>89814</td>\n",
       "      <td>HS-grad</td>\n",
       "      <td>9</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Farming-fishing</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>50</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>28</td>\n",
       "      <td>Local-gov</td>\n",
       "      <td>336951</td>\n",
       "      <td>Assoc-acdm</td>\n",
       "      <td>12</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Protective-serv</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&gt;50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>44</td>\n",
       "      <td>Private</td>\n",
       "      <td>160323</td>\n",
       "      <td>Some-college</td>\n",
       "      <td>10</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Machine-op-inspct</td>\n",
       "      <td>Husband</td>\n",
       "      <td>Black</td>\n",
       "      <td>Male</td>\n",
       "      <td>7688</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&gt;50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>18</td>\n",
       "      <td>?</td>\n",
       "      <td>103497</td>\n",
       "      <td>Some-college</td>\n",
       "      <td>10</td>\n",
       "      <td>Never-married</td>\n",
       "      <td>?</td>\n",
       "      <td>Own-child</td>\n",
       "      <td>White</td>\n",
       "      <td>Female</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>30</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   age   workclass  fnlwgt      education  education-num       marital-status  \\\n",
       "0   25     Private  226802           11th              7        Never-married   \n",
       "1   38     Private   89814        HS-grad              9   Married-civ-spouse   \n",
       "2   28   Local-gov  336951     Assoc-acdm             12   Married-civ-spouse   \n",
       "3   44     Private  160323   Some-college             10   Married-civ-spouse   \n",
       "4   18           ?  103497   Some-college             10        Never-married   \n",
       "\n",
       "           occupation relationship    race      sex  capital-gain  \\\n",
       "0   Machine-op-inspct    Own-child   Black     Male             0   \n",
       "1     Farming-fishing      Husband   White     Male             0   \n",
       "2     Protective-serv      Husband   White     Male             0   \n",
       "3   Machine-op-inspct      Husband   Black     Male          7688   \n",
       "4                   ?    Own-child   White   Female             0   \n",
       "\n",
       "   capital-loss  hours-per-week  native-country   class  \n",
       "0             0              40   United-States   <=50K  \n",
       "1             0              50   United-States   <=50K  \n",
       "2             0              40   United-States    >50K  \n",
       "3             0              40   United-States    >50K  \n",
       "4             0              30   United-States   <=50K  "
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The target variable in our study will be the \"class\" column while we will use\n",
    "the other columns as input variables for our model. This target column divides\n",
    "the samples (also known as records) into two groups: high income (>50K) vs low\n",
    "income (<=50K). The resulting prediction problem is therefore a binary\n",
    "classification problem.\n",
    "\n",
    "For simplicity, we will ignore the \"fnlwgt\" (final weight) column that was\n",
    "crafted by the creators of the dataset when sampling the dataset to be\n",
    "representative of the full census database."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([' <=50K', ' <=50K', ' >50K', ..., ' <=50K', ' <=50K', ' >50K'],\n",
       "      dtype=object)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "target_name = \"class\"\n",
    "target = df[target_name].to_numpy()\n",
    "target"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>workclass</th>\n",
       "      <th>education</th>\n",
       "      <th>education-num</th>\n",
       "      <th>marital-status</th>\n",
       "      <th>occupation</th>\n",
       "      <th>relationship</th>\n",
       "      <th>race</th>\n",
       "      <th>sex</th>\n",
       "      <th>capital-gain</th>\n",
       "      <th>capital-loss</th>\n",
       "      <th>hours-per-week</th>\n",
       "      <th>native-country</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>25</td>\n",
       "      <td>Private</td>\n",
       "      <td>11th</td>\n",
       "      <td>7</td>\n",
       "      <td>Never-married</td>\n",
       "      <td>Machine-op-inspct</td>\n",
       "      <td>Own-child</td>\n",
       "      <td>Black</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>38</td>\n",
       "      <td>Private</td>\n",
       "      <td>HS-grad</td>\n",
       "      <td>9</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Farming-fishing</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>50</td>\n",
       "      <td>United-States</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>28</td>\n",
       "      <td>Local-gov</td>\n",
       "      <td>Assoc-acdm</td>\n",
       "      <td>12</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Protective-serv</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>44</td>\n",
       "      <td>Private</td>\n",
       "      <td>Some-college</td>\n",
       "      <td>10</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Machine-op-inspct</td>\n",
       "      <td>Husband</td>\n",
       "      <td>Black</td>\n",
       "      <td>Male</td>\n",
       "      <td>7688</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>18</td>\n",
       "      <td>?</td>\n",
       "      <td>Some-college</td>\n",
       "      <td>10</td>\n",
       "      <td>Never-married</td>\n",
       "      <td>?</td>\n",
       "      <td>Own-child</td>\n",
       "      <td>White</td>\n",
       "      <td>Female</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>30</td>\n",
       "      <td>United-States</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   age   workclass      education  education-num       marital-status  \\\n",
       "0   25     Private           11th              7        Never-married   \n",
       "1   38     Private        HS-grad              9   Married-civ-spouse   \n",
       "2   28   Local-gov     Assoc-acdm             12   Married-civ-spouse   \n",
       "3   44     Private   Some-college             10   Married-civ-spouse   \n",
       "4   18           ?   Some-college             10        Never-married   \n",
       "\n",
       "           occupation relationship    race      sex  capital-gain  \\\n",
       "0   Machine-op-inspct    Own-child   Black     Male             0   \n",
       "1     Farming-fishing      Husband   White     Male             0   \n",
       "2     Protective-serv      Husband   White     Male             0   \n",
       "3   Machine-op-inspct      Husband   Black     Male          7688   \n",
       "4                   ?    Own-child   White   Female             0   \n",
       "\n",
       "   capital-loss  hours-per-week  native-country  \n",
       "0             0              40   United-States  \n",
       "1             0              50   United-States  \n",
       "2             0              40   United-States  \n",
       "3             0              40   United-States  \n",
       "4             0              30   United-States  "
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data = df.drop(columns=[target_name, \"fnlwgt\"])\n",
    "data.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can check the number of samples and the number of features available in\n",
    "the dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The dataset contains 48842 samples and 13 features\n"
     ]
    }
   ],
   "source": [
    "print(\n",
    "    f\"The dataset contains {data.shape[0]} samples and {data.shape[1]} \"\n",
    "    \"features\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Working with numerical data\n",
    "\n",
    "Numerical data is the most natural type of data used in machine learning\n",
    "and can (almost) directly be fed to predictive models. We can quickly have a\n",
    "look at such data by selecting the subset of numerical columns from the\n",
    "original data.\n",
    "\n",
    "We will use this subset of data to fit a linear classification model to\n",
    "predict the income class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Index(['age', 'workclass', 'education', 'education-num', 'marital-status',\n",
       "       'occupation', 'relationship', 'race', 'sex', 'capital-gain',\n",
       "       'capital-loss', 'hours-per-week', 'native-country'],\n",
       "      dtype='object')"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "age                int64\n",
       "workclass         object\n",
       "education         object\n",
       "education-num      int64\n",
       "marital-status    object\n",
       "occupation        object\n",
       "relationship      object\n",
       "race              object\n",
       "sex               object\n",
       "capital-gain       int64\n",
       "capital-loss       int64\n",
       "hours-per-week     int64\n",
       "native-country    object\n",
       "dtype: object"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.dtypes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['age', 'education-num', 'capital-gain', 'capital-loss', 'hours-per-week']"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\n",
    "# \"i\" denotes integer type, \"f\" denotes float type\n",
    "numerical_columns = [\n",
    "    c for c in data.columns if data[c].dtype.kind in [\"i\", \"f\"]]\n",
    "numerical_columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>education-num</th>\n",
       "      <th>capital-gain</th>\n",
       "      <th>capital-loss</th>\n",
       "      <th>hours-per-week</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>25</td>\n",
       "      <td>7</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>38</td>\n",
       "      <td>9</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>50</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>28</td>\n",
       "      <td>12</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>44</td>\n",
       "      <td>10</td>\n",
       "      <td>7688</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>18</td>\n",
       "      <td>10</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>30</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   age  education-num  capital-gain  capital-loss  hours-per-week\n",
       "0   25              7             0             0              40\n",
       "1   38              9             0             0              50\n",
       "2   28             12             0             0              40\n",
       "3   44             10          7688             0              40\n",
       "4   18             10             0             0              30"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data_numeric = data[numerical_columns]\n",
    "data_numeric.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "When building a machine learning model, it is important to leave out a\n",
    "subset of the data which we can use later to evaluate the trained model.\n",
    "The data used to fit a model a called training data while the one used to\n",
    "assess a model are called testing data.\n",
    "\n",
    "Scikit-learn provides an helper function `train_test_split` which will\n",
    "split the dataset into a training and a testing set. It will ensure that\n",
    "the data are shuffled randomly before splitting the data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The training dataset contains 36631 samples and 5 features\n",
      "The testing dataset contains 12211 samples and 5 features\n"
     ]
    }
   ],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "data_train, data_test, target_train, target_test = train_test_split(\n",
    "    data_numeric, target, random_state=42)\n",
    "\n",
    "print(\n",
    "    f\"The training dataset contains {data_train.shape[0]} samples and \"\n",
    "    f\"{data_train.shape[1]} features\")\n",
    "print(\n",
    "    f\"The testing dataset contains {data_test.shape[0]} samples and \"\n",
    "    f\"{data_test.shape[1]} features\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "lines_to_next_cell": 0
   },
   "source": [
    "We will build a linear classification model called \"Logistic Regression\". The\n",
    "`fit` method is called to train the model from the input (features) and\n",
    "target data. Only the training data should be given for this purpose.\n",
    "\n",
    "In addition, check the time required to train the model and the number of\n",
    "iterations done by the solver to find a solution."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The model LogisticRegression was trained in 0.381 seconds for [100] iterations\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/lesteve/miniconda3/envs/scikit-learn-tutorial/lib/python3.7/site-packages/sklearn/linear_model/logistic.py:947: ConvergenceWarning: lbfgs failed to converge. Increase the number of iterations.\n",
      "  \"of iterations.\", ConvergenceWarning)\n"
     ]
    }
   ],
   "source": [
    "from sklearn.linear_model import LogisticRegression\n",
    "import time\n",
    "\n",
    "model = LogisticRegression(solver='lbfgs')\n",
    "start = time.time()\n",
    "model.fit(data_train, target_train)\n",
    "elapsed_time = time.time() - start\n",
    "\n",
    "print(f\"The model {model.__class__.__name__} was trained in \"\n",
    "      f\"{elapsed_time:.3f} seconds for {model.n_iter_} iterations\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's ignore the convergence warning for now and instead let's try\n",
    "to use our model to make some predictions on the first three records\n",
    "of the held out test set:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([' <=50K', ' <=50K', ' >50K', ' <=50K', ' <=50K'], dtype=object)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "target_predicted = model.predict(data_test)\n",
    "target_predicted[:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([' <=50K', ' <=50K', ' >50K', ' <=50K', ' <=50K'], dtype=object)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "target_test[:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>education-num</th>\n",
       "      <th>capital-gain</th>\n",
       "      <th>capital-loss</th>\n",
       "      <th>hours-per-week</th>\n",
       "      <th>predicted-class</th>\n",
       "      <th>expected-class</th>\n",
       "      <th>correct</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>7762</td>\n",
       "      <td>56</td>\n",
       "      <td>9</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>23881</td>\n",
       "      <td>25</td>\n",
       "      <td>9</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>30507</td>\n",
       "      <td>43</td>\n",
       "      <td>13</td>\n",
       "      <td>14344</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>28911</td>\n",
       "      <td>32</td>\n",
       "      <td>9</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>19484</td>\n",
       "      <td>39</td>\n",
       "      <td>13</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>30</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       age  education-num  capital-gain  capital-loss  hours-per-week  \\\n",
       "7762    56              9             0             0              40   \n",
       "23881   25              9             0             0              40   \n",
       "30507   43             13         14344             0              40   \n",
       "28911   32              9             0             0              40   \n",
       "19484   39             13             0             0              30   \n",
       "\n",
       "      predicted-class expected-class  correct  \n",
       "7762            <=50K          <=50K     True  \n",
       "23881           <=50K          <=50K     True  \n",
       "30507            >50K           >50K     True  \n",
       "28911           <=50K          <=50K     True  \n",
       "19484           <=50K          <=50K     True  "
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predictions = data_test.copy()\n",
    "predictions['predicted-class'] = target_predicted\n",
    "predictions['expected-class'] = target_test\n",
    "predictions['correct'] = target_predicted == target_test\n",
    "predictions.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To quantitatively evaluate our model, we can use the method `score`. It will\n",
    "compute the classification accuracy when dealing with a classificiation\n",
    "problem."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The test accuracy using a LogisticRegression is 0.818\n"
     ]
    }
   ],
   "source": [
    "print(f\"The test accuracy using a {model.__class__.__name__} is \"\n",
    "      f\"{model.score(data_test, target_test):.3f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is mathematically equivalent as computing the average number of time\n",
    "the model makes a correct prediction on the test set:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8177053476373761"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(target_test == target_predicted).mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Exercise 1\n",
    "\n",
    "- What would be the score of a model that always predicts `' >50K'`?\n",
    "- What would be the score of a model that always predicts `' <= 50K'`?\n",
    "- Is 81% or 82% accuracy a good score for this problem?\n",
    "\n",
    "Hint: You can compute the cross-validated of a [DummyClassifier](https://scikit-learn.org/stable/modules/model_evaluation.html#dummy-estimators) the performance of such baselines.\n",
    "\n",
    "Use the dedicated notebook to do this exercise."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's now consider the `ConvergenceWarning` message that was raised previously\n",
    "when calling the `fit` method to train our model. This warning informs us that\n",
    "our model stopped learning because it reached the maximum number of\n",
    "iterations allowed by the user. This could potentially be detrimental for the\n",
    "model accuracy. We can follow the (bad) advice given in the warning message\n",
    "and increase the maximum number of iterations allowed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = LogisticRegression(solver='lbfgs', max_iter=50000)\n",
    "start = time.time()\n",
    "model.fit(data_train, target_train)\n",
    "elapsed_time = time.time() - start"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The accuracy using a LogisticRegression is 0.818 with a fitting time of 0.353 seconds in [105] iterations\n"
     ]
    }
   ],
   "source": [
    "print(\n",
    "    f\"The accuracy using a {model.__class__.__name__} is \"\n",
    "    f\"{model.score(data_test, target_test):.3f} with a fitting time of \"\n",
    "    f\"{elapsed_time:.3f} seconds in {model.n_iter_} iterations\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now observe a longer training time but no significant improvement in\n",
    "the predictive performance. Instead of increasing the number of iterations, we\n",
    "can try to help the solver converge faster by scaling the data first. A range of\n",
    "preprocessing algorithms in scikit-learn allows us to transform the input data\n",
    "before training a model. We can easily combine these sequential operations\n",
    "with a scikit-learn `Pipeline`, which chains operations together and can be\n",
    "used like any other classifier or regressor. The helper function\n",
    "`make_pipeline` creates a `Pipeline` from its arguments: the successive\n",
    "transformations to perform, followed by the classifier or regressor model.\n",
    "\n",
    "In our case, we will standardize the data and then train a new logistic\n",
    "regression model on that new version of the dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>education-num</th>\n",
       "      <th>capital-gain</th>\n",
       "      <th>capital-loss</th>\n",
       "      <th>hours-per-week</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>count</td>\n",
       "      <td>36631.000000</td>\n",
       "      <td>36631.000000</td>\n",
       "      <td>36631.000000</td>\n",
       "      <td>36631.000000</td>\n",
       "      <td>36631.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>mean</td>\n",
       "      <td>38.642352</td>\n",
       "      <td>10.078131</td>\n",
       "      <td>1087.077721</td>\n",
       "      <td>89.665311</td>\n",
       "      <td>40.431247</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>std</td>\n",
       "      <td>13.725748</td>\n",
       "      <td>2.570143</td>\n",
       "      <td>7522.692939</td>\n",
       "      <td>407.110175</td>\n",
       "      <td>12.423952</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>min</td>\n",
       "      <td>17.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>25%</td>\n",
       "      <td>28.000000</td>\n",
       "      <td>9.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>40.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>50%</td>\n",
       "      <td>37.000000</td>\n",
       "      <td>10.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>40.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>75%</td>\n",
       "      <td>48.000000</td>\n",
       "      <td>12.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>45.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>max</td>\n",
       "      <td>90.000000</td>\n",
       "      <td>16.000000</td>\n",
       "      <td>99999.000000</td>\n",
       "      <td>4356.000000</td>\n",
       "      <td>99.000000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                age  education-num  capital-gain  capital-loss  hours-per-week\n",
       "count  36631.000000   36631.000000  36631.000000  36631.000000    36631.000000\n",
       "mean      38.642352      10.078131   1087.077721     89.665311       40.431247\n",
       "std       13.725748       2.570143   7522.692939    407.110175       12.423952\n",
       "min       17.000000       1.000000      0.000000      0.000000        1.000000\n",
       "25%       28.000000       9.000000      0.000000      0.000000       40.000000\n",
       "50%       37.000000      10.000000      0.000000      0.000000       40.000000\n",
       "75%       48.000000      12.000000      0.000000      0.000000       45.000000\n",
       "max       90.000000      16.000000  99999.000000   4356.000000       99.000000"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data_train.describe()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 0.17177061,  0.35868902, -0.14450843,  5.71188483, -2.28845333],\n",
       "       [ 0.02605707,  1.1368665 , -0.14450843, -0.22025127, -0.27618374],\n",
       "       [-0.33822677,  1.1368665 , -0.14450843, -0.22025127,  0.77019645],\n",
       "       ...,\n",
       "       [-0.77536738, -0.03039972, -0.14450843, -0.22025127, -0.03471139],\n",
       "       [ 0.53605445,  0.35868902, -0.14450843, -0.22025127, -0.03471139],\n",
       "       [ 1.48319243,  1.52595523, -0.14450843, -0.22025127, -2.69090725]])"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "scaler = StandardScaler()\n",
    "data_train_scaled = scaler.fit_transform(data_train)\n",
    "data_train_scaled"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>education-num</th>\n",
       "      <th>capital-gain</th>\n",
       "      <th>capital-loss</th>\n",
       "      <th>hours-per-week</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>count</td>\n",
       "      <td>3.663100e+04</td>\n",
       "      <td>3.663100e+04</td>\n",
       "      <td>3.663100e+04</td>\n",
       "      <td>3.663100e+04</td>\n",
       "      <td>3.663100e+04</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>mean</td>\n",
       "      <td>-2.273364e-16</td>\n",
       "      <td>1.219606e-16</td>\n",
       "      <td>3.530310e-17</td>\n",
       "      <td>3.840667e-17</td>\n",
       "      <td>1.844684e-16</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>std</td>\n",
       "      <td>1.000014e+00</td>\n",
       "      <td>1.000014e+00</td>\n",
       "      <td>1.000014e+00</td>\n",
       "      <td>1.000014e+00</td>\n",
       "      <td>1.000014e+00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>min</td>\n",
       "      <td>-1.576792e+00</td>\n",
       "      <td>-3.532198e+00</td>\n",
       "      <td>-1.445084e-01</td>\n",
       "      <td>-2.202513e-01</td>\n",
       "      <td>-3.173852e+00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>25%</td>\n",
       "      <td>-7.753674e-01</td>\n",
       "      <td>-4.194885e-01</td>\n",
       "      <td>-1.445084e-01</td>\n",
       "      <td>-2.202513e-01</td>\n",
       "      <td>-3.471139e-02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>50%</td>\n",
       "      <td>-1.196565e-01</td>\n",
       "      <td>-3.039972e-02</td>\n",
       "      <td>-1.445084e-01</td>\n",
       "      <td>-2.202513e-01</td>\n",
       "      <td>-3.471139e-02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>75%</td>\n",
       "      <td>6.817680e-01</td>\n",
       "      <td>7.477778e-01</td>\n",
       "      <td>-1.445084e-01</td>\n",
       "      <td>-2.202513e-01</td>\n",
       "      <td>3.677425e-01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>max</td>\n",
       "      <td>3.741752e+00</td>\n",
       "      <td>2.304133e+00</td>\n",
       "      <td>1.314865e+01</td>\n",
       "      <td>1.047970e+01</td>\n",
       "      <td>4.714245e+00</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                age  education-num  capital-gain  capital-loss  hours-per-week\n",
       "count  3.663100e+04   3.663100e+04  3.663100e+04  3.663100e+04    3.663100e+04\n",
       "mean  -2.273364e-16   1.219606e-16  3.530310e-17  3.840667e-17    1.844684e-16\n",
       "std    1.000014e+00   1.000014e+00  1.000014e+00  1.000014e+00    1.000014e+00\n",
       "min   -1.576792e+00  -3.532198e+00 -1.445084e-01 -2.202513e-01   -3.173852e+00\n",
       "25%   -7.753674e-01  -4.194885e-01 -1.445084e-01 -2.202513e-01   -3.471139e-02\n",
       "50%   -1.196565e-01  -3.039972e-02 -1.445084e-01 -2.202513e-01   -3.471139e-02\n",
       "75%    6.817680e-01   7.477778e-01 -1.445084e-01 -2.202513e-01    3.677425e-01\n",
       "max    3.741752e+00   2.304133e+00  1.314865e+01  1.047970e+01    4.714245e+00"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data_train_scaled = pd.DataFrame(data_train_scaled,\n",
    "                                 columns=data_train.columns)\n",
    "data_train_scaled.describe()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.pipeline import make_pipeline\n",
    "\n",
    "model = make_pipeline(StandardScaler(),\n",
    "                      LogisticRegression(solver='lbfgs'))\n",
    "start = time.time()\n",
    "model.fit(data_train, target_train)\n",
    "elapsed_time = time.time() - start"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The accuracy using a Pipeline is 0.818 with a fitting time of 0.086 seconds in [13] iterations\n"
     ]
    }
   ],
   "source": [
    "print(\n",
    "    f\"The accuracy using a {model.__class__.__name__} is \"\n",
    "    f\"{model.score(data_test, target_test):.3f} with a fitting time of \"\n",
    "    f\"{elapsed_time:.3f} seconds in {model[-1].n_iter_} iterations\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "source": [
    "We can see that the training time and the number of iterations are much lower\n",
    "while the predictive performance (accuracy) stays the same.\n",
    "\n",
    "In the previous example, we split the original data into a training set and a\n",
    "testing set. This strategy has several issues: when the amount of data is\n",
    "limited, the subset used to train or test will be small; and because the\n",
    "split was made at random, we have no information regarding the confidence\n",
    "of the results obtained.\n",
    "\n",
    "Instead, we can use cross-validation. Cross-validation consists of\n",
    "repeating this splitting into training and testing sets and aggregating\n",
    "the model performance. By repeating the experiment, one can get an estimate of\n",
    "the variability of the model performance.\n",
    "\n",
    "The function `cross_val_score` implements this experimental protocol: it takes\n",
    "the model, the data and the target. Since there exist several\n",
    "cross-validation strategies, `cross_val_score` takes a parameter `cv` which\n",
    "defines the splitting strategy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The different scores obtained are: \n",
      "[0.81216092 0.8096018  0.81337019 0.81326781 0.82207207]\n"
     ]
    }
   ],
   "source": [
    "from sklearn.model_selection import cross_val_score\n",
    "\n",
    "scores = cross_val_score(model, data_numeric, target, cv=5)\n",
    "print(f\"The different scores obtained are: \\n{scores}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The mean cross-validation accuracy is: 0.814 +/- 0.004\n"
     ]
    }
   ],
   "source": [
    "print(f\"The mean cross-validation accuracy is: \"\n",
    "      f\"{scores.mean():.3f} +/- {scores.std():.3f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that by computing the standard deviation of the cross-validation scores\n",
    "we can get an idea of the uncertainty of our estimation of the predictive\n",
    "performance of the model: in the above results, only the first 2 decimals seem\n",
    "to be trustworthy. Using a single train / test split would not allow us to\n",
    "know anything about the level of uncertainty of the accuracy of the model.\n",
    "\n",
    "Setting `cv=5` created 5 distinct splits to get 5 variations for the training\n",
    "and testing sets. Each training set is used to fit one model which is then\n",
    "scored on the matching test set. This strategy is called K-fold\n",
    "cross-validation, where `K` corresponds to the number of splits.\n",
    "\n",
    "The following matplotlib code helps visualize how the dataset is partitioned\n",
    "into train and test samples at each iteration of the cross-validation\n",
    "procedure:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "from sklearn.model_selection import KFold\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from matplotlib.patches import Patch\n",
    "\n",
    "cmap_cv = plt.cm.coolwarm\n",
    "\n",
    "\n",
    "def plot_cv_indices(cv, X, y, ax, lw=20):\n",
    "    \"\"\"Create a sample plot for indices of a cross-validation object.\"\"\"\n",
    "    splits = list(cv.split(X=X, y=y))\n",
    "    n_splits = len(splits)\n",
    "\n",
    "    # Generate the training/testing visualizations for each CV split\n",
    "    for ii, (train, test) in enumerate(splits):\n",
    "        # Fill in indices with the training/test groups\n",
    "        indices = np.zeros(shape=X.shape[0], dtype=np.int32)\n",
    "        indices[train] = 1\n",
    "\n",
    "        # Visualize the results\n",
    "        ax.scatter(range(len(indices)), [ii + .5] * len(indices),\n",
    "                   c=indices, marker='_', lw=lw, cmap=cmap_cv,\n",
    "                   vmin=-.2, vmax=1.2)\n",
    "\n",
    "    # Formatting: one tick per CV split, so ticks and labels match in number\n",
    "    yticklabels = list(range(n_splits))\n",
    "    ax.set(yticks=np.arange(n_splits) + .5,\n",
    "           yticklabels=yticklabels, xlabel='Sample index',\n",
    "           ylabel=\"CV iteration\", ylim=[n_splits + .2,\n",
    "                                        -.2], xlim=[0, 100])\n",
    "    ax.set_title('{}'.format(type(cv).__name__), fontsize=15)\n",
    "    return ax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 720x432 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Some random data points\n",
    "n_points = 100\n",
    "X = np.random.randn(n_points, 10)\n",
    "y = np.random.randn(n_points)\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(10, 6))\n",
    "cv = KFold(5)\n",
    "_ = plot_cv_indices(cv, X, y, ax)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary\n",
    "In this notebook we have seen:\n",
    "* how to train predictive models when you only have numerical variables;\n",
    "* the importance of scaling numerical variables through\n",
    "  `sklearn.preprocessing.StandardScaler`;\n",
    "* how to chain multiple steps (e.g. preprocessing with `StandardScaler` and\n",
    "  a `LogisticRegression` model) in a single scikit-learn estimator through\n",
    "  `sklearn.pipeline.Pipeline`;\n",
    "* how to evaluate the performance of a model via cross-validation through\n",
    "  `sklearn.model_selection.cross_val_score`."
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "encoding": "# -*- coding: utf-8 -*-",
   "formats": "python_scripts//py:percent,notebooks//ipynb"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}