{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Chapter 4 - Classification"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[4.1 An Overview of Classification](#4.1-An-Overview-of-Classification)\n",
    "\n",
    "[4.3 Logistic regression](#4.3-Logistic-regression)\n",
    "\n",
    "[4.4 Linear Discriminant Analysis](#4.4-Linear-Discriminant-Analysis)\n",
    "> [4.4.4 Quadratic Discriminant Analysis](#4.4.4-Quadratic-Discriminant-Analysis)\n",
    "\n",
    "[4.5 A Comparison of Classification Methods](#4.5-A-Comparison-of-Classification-Methods)\n",
    "\n",
    "[4.6 Lab: Logistic Regression, LDA, QDA, and KNN](#4.6-Lab:-Logistic-Regression,-LDA,-QDA,-and-KNN)\n",
    "> [4.6.1 The Stock Market Data](#4.6.1-The-Stock-Market-Data)<br>\n",
    "> [4.6.2 Logistic regression](#4.6.2-Logistic-regression)<br>\n",
    "> [4.6.3 Linear Discriminant Analysis (LDA)](#4.6.3-Linear-Discriminant-Analysis-%28LDA%29)<br>\n",
    "> [4.6.4 Quadratic Discriminant Analysis (QDA)](#4.6.4-Quadratic-Discriminant-Analysis-%28QDA%29)<br>\n",
    "> [4.6.5 K-Nearest Neighbors](#4.6.5-K-Nearest-Neighbors)<br>\n",
    "> [4.6.6 An Application to Caravan Insurance Data](#4.6.6-An-Application-to-Caravan-Insurance-Data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import sklearn.linear_model as skl_lm\n",
    "from sklearn.discriminant_analysis import LinearDiscriminantAnalysis, QuadraticDiscriminantAnalysis\n",
    "from sklearn.metrics import confusion_matrix, classification_report, precision_score, roc_curve, auc, log_loss\n",
    "from sklearn import preprocessing\n",
    "from sklearn import neighbors\n",
    "\n",
    "from scipy import stats\n",
    "\n",
    "import scikitplot as skplt\n",
    "\n",
    "import statsmodels.api as sm\n",
    "import statsmodels.formula.api as smf\n",
    "\n",
    "from ipywidgets import widgets\n",
    "\n",
    "from classification_helper import print_classification_statistics, plot_ROC, print_OLS_error_table, plot_classification\n",
    "\n",
    "%matplotlib inline\n",
    "# 'seaborn-white' was renamed 'seaborn-v0_8-white' in matplotlib 3.6 and the old name removed in 3.8\n",
    "plt.style.use('seaborn-v0_8-white' if 'seaborn-v0_8-white' in plt.style.available else 'seaborn-white')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 4.1 An Overview of Classification"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>default</th>\n",
       "      <th>student</th>\n",
       "      <th>balance</th>\n",
       "      <th>income</th>\n",
       "      <th>default2</th>\n",
       "      <th>student2</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>No</td>\n",
       "      <td>No</td>\n",
       "      <td>729.526495</td>\n",
       "      <td>44361.625074</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>No</td>\n",
       "      <td>Yes</td>\n",
       "      <td>817.180407</td>\n",
       "      <td>12106.134700</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>No</td>\n",
       "      <td>No</td>\n",
       "      <td>1073.549164</td>\n",
       "      <td>31767.138947</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  default student      balance        income  default2  student2\n",
       "1      No      No   729.526495  44361.625074         0         0\n",
       "2      No     Yes   817.180407  12106.134700         0         1\n",
       "3      No      No  1073.549164  31767.138947         0         0"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# In R, I exported the dataset from package 'ISLR' to an Excel file\n",
    "df_default = pd.read_excel('Data/Default.xlsx')\n",
    "\n",
    "# Encode the Yes/No columns as 0/1 with an explicit mapping. factorize() numbers values\n",
    "# in order of first appearance, so its codes would silently flip (No=1, Yes=0) if the\n",
    "# first row happened to be a 'Yes'.\n",
    "yes_no = {'No': 0, 'Yes': 1}\n",
    "df_default['default2'] = df_default.default.map(yes_no)\n",
    "df_default['student2'] = df_default.student.map(yes_no)\n",
    "# Any value other than 'No'/'Yes' maps to NaN; fail loudly instead of propagating it.\n",
    "assert df_default[['default2', 'student2']].notna().all().all(), 'unexpected values in default/student'\n",
    "df_default.head(3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Plot data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 864x360 with 3 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "fig = plt.figure(figsize=(12,5))\n",
    "gs = mpl.gridspec.GridSpec(1, 4)\n",
    "ax1 = plt.subplot(gs[0,:-2])\n",
    "ax2 = plt.subplot(gs[0,-2])\n",
    "ax3 = plt.subplot(gs[0,-1])\n",
    "\n",
    "# Take a fraction of the samples where target value (default) is 'no'\n",
    "# (fixed random_state so the figure is identical on every re-run)\n",
    "df_no = df_default[df_default.default2 == 0].sample(frac=.08, random_state=1)\n",
    "# Take all samples where target value is 'yes'\n",
    "df_yes = df_default[df_default.default2 == 1]\n",
    "\n",
    "ax1.scatter(df_yes.balance, df_yes.income, s=40, c='orange', marker='+', linewidths=1)\n",
    "ax1.scatter(df_no.balance, df_no.income, s=40, marker='o', linewidths=1,\n",
    "            edgecolors='lightblue', facecolors='white', alpha=.6)\n",
    "\n",
    "ax1.set_ylim(bottom=0)\n",
    "ax1.set_ylabel('Income')\n",
    "ax1.set_xlim(left=-100)\n",
    "ax1.set_xlabel('Balance')\n",
    "\n",
    "c_palette = {'No':'lightblue', 'Yes':'orange'}\n",
    "# x/y must be passed by keyword: seaborn >= 0.12 no longer accepts them positionally\n",
    "sns.boxplot(x='default', y='balance', data=df_default, orient='v', ax=ax2, palette=c_palette)\n",
    "sns.boxplot(x='default', y='income', data=df_default, orient='v', ax=ax3, palette=c_palette)\n",
    "gs.tight_layout(fig)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 4.3 Logistic regression"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### scikit-learn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,\n",
      "          intercept_scaling=1, max_iter=100, multi_class='ovr', n_jobs=1,\n",
      "          penalty='l2', random_state=None, solver='newton-cg', tol=0.0001,\n",
      "          verbose=0, warm_start=False)\n",
      "classes:  [0 1]\n",
      "coefficients:  [-10.651330005794106, 0.0054989165568046445]\n"
     ]
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Single predictor (balance) as an (n, 1) design matrix; response is the 0/1 default indicator\n",
    "X_train = df_default.balance.values.reshape(-1,1)\n",
    "y = df_default.default2\n",
    "\n",
    "# Balance values (step 1) at which the fitted probability curve is evaluated\n",
    "X_plot = np.arange(df_default.balance.min(), df_default.balance.max()).reshape(-1,1)\n",
    "\n",
    "clf = skl_lm.LogisticRegression(solver='newton-cg').fit(X_train, y)\n",
    "prob = clf.predict_proba(X_plot)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "# Observed 0/1 outcomes and the fitted P(default = 1 | balance)\n",
    "ax.scatter(X_train, y, color='orange')\n",
    "ax.plot(X_plot, prob[:,1], color='lightblue')\n",
    "\n",
    "# Dashed reference lines at probability 1 and 0 spanning the plotted balance range\n",
    "x_lo, x_hi = ax.xaxis.get_data_interval()\n",
    "for level in (1, 0):\n",
    "    ax.hlines(level, xmin=x_lo, xmax=x_hi, linestyles='dashed', lw=1)\n",
    "ax.set_ylabel('Probability of default')\n",
    "ax.set_xlabel('Balance')\n",
    "ax.set_yticks([0, 0.25, 0.5, 0.75, 1.])\n",
    "ax.set_xlim(xmin=-100)\n",
    "print(clf)\n",
    "print('classes: ', clf.classes_)\n",
    "print('coefficients: ', [*clf.intercept_, *clf.coef_[0].tolist()])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### statsmodels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Optimization terminated successfully.\n",
      "         Current function value: 0.079823\n",
      "         Iterations 10\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Coef.</th>\n",
       "      <th>Std.Err.</th>\n",
       "      <th>z</th>\n",
       "      <th>P&gt;|z|</th>\n",
       "      <th>[0.025</th>\n",
       "      <th>0.975]</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>const</th>\n",
       "      <td>-10.651331</td>\n",
       "      <td>0.361169</td>\n",
       "      <td>-29.491287</td>\n",
       "      <td>3.723665e-191</td>\n",
       "      <td>-11.359208</td>\n",
       "      <td>-9.943453</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>balance</th>\n",
       "      <td>0.005499</td>\n",
       "      <td>0.000220</td>\n",
       "      <td>24.952404</td>\n",
       "      <td>2.010855e-137</td>\n",
       "      <td>0.005067</td>\n",
       "      <td>0.005931</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "             Coef.  Std.Err.          z          P>|z|     [0.025    0.975]\n",
       "const   -10.651331  0.361169 -29.491287  3.723665e-191 -11.359208 -9.943453\n",
       "balance   0.005499  0.000220  24.952404  2.010855e-137   0.005067  0.005931"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X_train = sm.add_constant(df_default.balance)\n",
    "y = df_default.default2\n",
    "# Logit is a model class of statsmodels.api (sm); statsmodels.formula.api (smf) is the formula\n",
    "# interface and recent statsmodels releases no longer re-export the model classes there.\n",
    "# Series.ravel() is deprecated in recent pandas; Logit accepts the Series directly.\n",
    "est = sm.Logit(y, X_train).fit()\n",
    "est.summary2().tables[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Optimization terminated successfully.\n",
      "         Current function value: 0.145434\n",
      "         Iterations 7\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Coef.</th>\n",
       "      <th>Std.Err.</th>\n",
       "      <th>z</th>\n",
       "      <th>P&gt;|z|</th>\n",
       "      <th>[0.025</th>\n",
       "      <th>0.975]</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>const</th>\n",
       "      <td>-3.504128</td>\n",
       "      <td>0.070713</td>\n",
       "      <td>-49.554094</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>-3.642723</td>\n",
       "      <td>-3.365532</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>student2</th>\n",
       "      <td>0.404887</td>\n",
       "      <td>0.115019</td>\n",
       "      <td>3.520177</td>\n",
       "      <td>0.000431</td>\n",
       "      <td>0.179454</td>\n",
       "      <td>0.630320</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "             Coef.  Std.Err.          z     P>|z|    [0.025    0.975]\n",
       "const    -3.504128  0.070713 -49.554094  0.000000 -3.642723 -3.365532\n",
       "student2  0.404887  0.115019   3.520177  0.000431  0.179454  0.630320"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X_train = sm.add_constant(df_default.student2)\n",
    "y = df_default.default2\n",
    "# sm.Logit (model class) rather than smf.Logit: statsmodels.formula.api only reliably exposes the formula functions\n",
    "est = sm.Logit(y, X_train).fit()\n",
    "est.summary2().tables[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Optimization terminated successfully.\n",
      "         Current function value: 0.078577\n",
      "         Iterations 10\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<table class=\"simpletable\">\n",
       "<tr>\n",
       "        <td>Model:</td>              <td>Logit</td>      <td>Pseudo R-squared:</td>    <td>0.462</td>   \n",
       "</tr>\n",
       "<tr>\n",
       "  <td>Dependent Variable:</td>     <td>default2</td>           <td>AIC:</td>         <td>1579.5448</td> \n",
       "</tr>\n",
       "<tr>\n",
       "         <td>Date:</td>        <td>2018-06-23 13:11</td>       <td>BIC:</td>         <td>1608.3862</td> \n",
       "</tr>\n",
       "<tr>\n",
       "   <td>No. Observations:</td>        <td>10000</td>       <td>Log-Likelihood:</td>    <td>-785.77</td>  \n",
       "</tr>\n",
       "<tr>\n",
       "       <td>Df Model:</td>              <td>3</td>            <td>LL-Null:</td>        <td>-1460.3</td>  \n",
       "</tr>\n",
       "<tr>\n",
       "     <td>Df Residuals:</td>          <td>9996</td>         <td>LLR p-value:</td>    <td>3.2575e-292</td>\n",
       "</tr>\n",
       "<tr>\n",
       "      <td>Converged:</td>           <td>1.0000</td>           <td>Scale:</td>         <td>1.0000</td>   \n",
       "</tr>\n",
       "<tr>\n",
       "    <td>No. Iterations:</td>        <td>10.0000</td>             <td></td>               <td></td>      \n",
       "</tr>\n",
       "</table>\n",
       "<table class=\"simpletable\">\n",
       "<tr>\n",
       "      <td></td>       <th>Coef.</th>  <th>Std.Err.</th>     <th>z</th>     <th>P>|z|</th>  <th>[0.025</th>  <th>0.975]</th> \n",
       "</tr>\n",
       "<tr>\n",
       "  <th>const</th>    <td>-10.8690</td>  <td>0.4923</td>  <td>-22.0793</td> <td>0.0000</td> <td>-11.8339</td> <td>-9.9042</td>\n",
       "</tr>\n",
       "<tr>\n",
       "  <th>balance</th>   <td>0.0057</td>   <td>0.0002</td>   <td>24.7365</td> <td>0.0000</td>  <td>0.0053</td>  <td>0.0062</td> \n",
       "</tr>\n",
       "<tr>\n",
       "  <th>income</th>    <td>0.0000</td>   <td>0.0000</td>   <td>0.3698</td>  <td>0.7115</td>  <td>-0.0000</td> <td>0.0000</td> \n",
       "</tr>\n",
       "<tr>\n",
       "  <th>student2</th>  <td>-0.6468</td>  <td>0.2363</td>   <td>-2.7376</td> <td>0.0062</td>  <td>-1.1098</td> <td>-0.1837</td>\n",
       "</tr>\n",
       "</table>"
      ],
      "text/plain": [
       "<class 'statsmodels.iolib.summary2.Summary'>\n",
       "\"\"\"\n",
       "                          Results: Logit\n",
       "==================================================================\n",
       "Model:              Logit            Pseudo R-squared: 0.462      \n",
       "Dependent Variable: default2         AIC:              1579.5448  \n",
       "Date:               2018-06-23 13:11 BIC:              1608.3862  \n",
       "No. Observations:   10000            Log-Likelihood:   -785.77    \n",
       "Df Model:           3                LL-Null:          -1460.3    \n",
       "Df Residuals:       9996             LLR p-value:      3.2575e-292\n",
       "Converged:          1.0000           Scale:            1.0000     \n",
       "No. Iterations:     10.0000                                       \n",
       "-------------------------------------------------------------------\n",
       "             Coef.    Std.Err.     z      P>|z|    [0.025    0.975]\n",
       "-------------------------------------------------------------------\n",
       "const       -10.8690    0.4923  -22.0793  0.0000  -11.8339  -9.9042\n",
       "balance       0.0057    0.0002   24.7365  0.0000    0.0053   0.0062\n",
       "income        0.0000    0.0000    0.3698  0.7115   -0.0000   0.0000\n",
       "student2     -0.6468    0.2363   -2.7376  0.0062   -1.1098  -0.1837\n",
       "==================================================================\n",
       "\n",
       "\"\"\""
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Multiple logistic regression on balance, income and the student dummy.\n",
    "# y is (re)defined here so the cell does not depend on which earlier cell ran last.\n",
    "X_train = sm.add_constant(df_default[['balance', 'income', 'student2']])\n",
    "y = df_default.default2\n",
    "est = sm.Logit(y, X_train).fit()\n",
    "est.summary2()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 4.4 Linear Discriminant Analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_LDA(mean1=-2, mean2=1, mean3=2, sigma1=0.5, sigma2=0.5):\n",
    "    \"\"\"Simulate three Gaussian classes and compare the Bayes and LDA decision rules.\n",
    "\n",
    "    Class k is drawn from N(mu_k, Sigma) with mu_1 = mean1*(1, 1), mu_2 = mean2*(1, 1),\n",
    "    mu_3 = mean3*(1, -1) and a shared covariance Sigma = diag(sigma1, sigma2)\n",
    "    (sigma1/sigma2 are *variances*). 500 points are drawn per class. The true Bayes\n",
    "    regions are shaded, the LDA boundaries estimated from the sample are drawn as\n",
    "    black lines, and the per-class percentage of sampled points that each rule\n",
    "    classifies correctly is printed.\n",
    "    \"\"\"\n",
    "    mean1 = mean1*np.array([1, 1])\n",
    "    mean2 = mean2*np.array([1, 1])\n",
    "    mean3 = mean3*np.array([1, -1])\n",
    "    cov = np.array([[sigma1, 0], [0, sigma2]])\n",
    "    N = 500  # samples per class\n",
    "    K = 3\n",
    "\n",
    "    # if you sample from a t-distribution instead of a normal, the LDA results are really bad,\n",
    "    # e.g. sample1 = multivariate_t(mean1, cov, N)\n",
    "    def multivariate_t(means, S, n, df=1):\n",
    "        \"\"\"Draw `n` samples from a multivariate t distribution with `df` degrees of freedom.\"\"\"\n",
    "        m = np.asarray(means)\n",
    "        d = len(m)\n",
    "        x = np.random.chisquare(df, n)/df\n",
    "        z = np.random.multivariate_normal(np.zeros(d), S, n)\n",
    "        return m + z/np.sqrt(x)[:,None]\n",
    "\n",
    "    def delta_fun(mean, inv_cov, pi):\n",
    "        \"\"\"Linear discriminant delta(X) = X S^-1 mu - 1/2 mu' S^-1 mu + log(pi), evaluated row-wise.\"\"\"\n",
    "        w = np.dot(inv_cov, mean)\n",
    "        return lambda X: np.dot(X, w) - 1/2*np.dot(mean, w) + np.log(pi)\n",
    "\n",
    "    def wins(deltas, k, X):\n",
    "        \"\"\"Boolean mask of rows of X where class k's discriminant strictly exceeds every other class's.\"\"\"\n",
    "        scores = [delta(X) for delta in deltas]\n",
    "        return np.logical_and.reduce([scores[k] > s for j, s in enumerate(scores) if j != k])\n",
    "\n",
    "    sample1 = np.random.multivariate_normal(mean1, cov, (N,))\n",
    "    sample2 = np.random.multivariate_normal(mean2, cov, (N,))\n",
    "    sample3 = np.random.multivariate_normal(mean3, cov, (N,))\n",
    "    samples = (sample1, sample2, sample3)\n",
    "\n",
    "    all_points = np.vstack(samples)\n",
    "    minX, minY = all_points.min(axis=0)\n",
    "    maxX, maxY = all_points.max(axis=0)\n",
    "\n",
    "    # equal class priors pi_k = 1/K (priors must sum to 1; equal priors do not move the boundaries)\n",
    "    priors = [1/K]*K\n",
    "\n",
    "    # grid of points to plot the bayes and LDA regions/lines\n",
    "    N_points_grid = 200\n",
    "    xx, yy = np.meshgrid(np.linspace(minX, maxX, N_points_grid), np.linspace(minY, maxY, N_points_grid))\n",
    "    X = np.c_[xx.ravel(), yy.ravel()]\n",
    "\n",
    "    # Bayes rule: true means and covariance\n",
    "    inv_cov = np.linalg.inv(cov)\n",
    "    bayes_deltas = [delta_fun(m, inv_cov, p) for m, p in zip((mean1, mean2, mean3), priors)]\n",
    "    region1, region2, region3 = (wins(bayes_deltas, k, X) for k in range(K))\n",
    "\n",
    "    # LDA: plug-in estimates. With equal class sizes the pooled covariance\n",
    "    # sum_k (N-1)*S_k / (K*N - K) is just the average of the per-class covariances.\n",
    "    est_means = [s.mean(axis=0) for s in samples]\n",
    "    est_cov = sum(np.cov(s, rowvar=False) for s in samples)/K\n",
    "    inv_est_cov = np.linalg.inv(est_cov)\n",
    "    lda_deltas = [delta_fun(m, inv_est_cov, p) for m, p in zip(est_means, priors)]\n",
    "    est_region1 = wins(lda_deltas, 0, X)\n",
    "    est_region3 = wins(lda_deltas, 2, X)\n",
    "\n",
    "    ### Plot Bayes regions and LDA lines\n",
    "    _, ax = plt.subplots(figsize=(8,8))\n",
    "\n",
    "    # Bayes regions\n",
    "    for region, color in ((region1, 'g'), (region2, 'orange'), (region3, 'b')):\n",
    "        ax.contourf(xx, yy, region.reshape(xx.shape), alpha=0.5, colors=color, levels=[0.5, 1.0])\n",
    "\n",
    "    # Samples\n",
    "    for sample, color, label in zip(samples, ('green', 'orange', 'blue'), ('1', '2', '3')):\n",
    "        ax.scatter(sample[:,0], sample[:,1], s=20, c=color, marker='o', label=label)\n",
    "    ax.set_xlabel('X1')\n",
    "    ax.set_ylabel('X2')\n",
    "\n",
    "    # LDA lines: the edges of regions 1 and 3 together contain every pairwise LDA boundary\n",
    "    ax.contour(xx, yy, est_region1.reshape(xx.shape), alpha=0.5, colors='k')\n",
    "    ax.contour(xx, yy, est_region3.reshape(xx.shape), alpha=0.5, colors='k')\n",
    "\n",
    "    # statistics: % of each class's sampled points assigned to that class\n",
    "    bayes_acc = [wins(bayes_deltas, k, s).mean()*100 for k, s in enumerate(samples)]\n",
    "    print('Bayes accuracy: ', *np.round(bayes_acc, 1))\n",
    "    lda_acc = [wins(lda_deltas, k, s).mean()*100 for k, s in enumerate(samples)]\n",
    "    print('LDA accuracy: ', *np.round(lda_acc, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a68581678eee484bbd7bff34f6b73a4b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "interactive(children=(FloatSlider(value=-2.0, description='mean1', max=2.0, min=-2.0, step=0.5), FloatSlider(v…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# interactive() silently drops keyword arguments that are not parameters of generate_LDA,\n",
    "# so continuous_update=False must be set on each slider to avoid re-plotting while dragging.\n",
    "interactive_plot = widgets.interactive(\n",
    "    generate_LDA,\n",
    "    mean1=widgets.FloatSlider(value=-2, min=-2, max=2, step=0.5, continuous_update=False),\n",
    "    mean2=widgets.FloatSlider(value=1, min=-2, max=2, step=0.5, continuous_update=False),\n",
    "    mean3=widgets.FloatSlider(value=2, min=-2, max=2, step=0.5, continuous_update=False),\n",
    "    sigma1=widgets.FloatSlider(value=0.5, min=0.1, max=5, step=0.1, continuous_update=False),\n",
    "    sigma2=widgets.FloatSlider(value=0.5, min=0.1, max=5, step=0.1, continuous_update=False))\n",
    "output = interactive_plot.children[-1]\n",
    "output.layout.height = '15cm'\n",
    "interactive_plot"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### LDA on the `Default` dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LinearDiscriminantAnalysis(n_components=None, priors=None, shrinkage=None,\n",
       "              solver='svd', store_covariance=False, tol=0.0001)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Predictors: balance, income and the 0/1 student dummy; response: the 0/1 default indicator\n",
    "X = df_default[['balance', 'income', 'student2']].values\n",
    "y = df_default.default2.values\n",
    "\n",
    "# fit() returns the estimator itself; the bare name below displays the fitted model\n",
    "lda = LinearDiscriminantAnalysis(solver='svd').fit(X, y)\n",
    "lda"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Classification Report:\n",
      "             precision    recall  f1-score   support\n",
      "\n",
      "       Down      0.974     0.998     0.986      9667\n",
      "         Up      0.782     0.237     0.364       333\n",
      "\n",
      "avg / total      0.968     0.972     0.965     10000\n",
      "\n",
      "Confusion Matrix:\n",
      "           Predicted          \n",
      "                True     False\n",
      "Real True   0.997724  0.002276\n",
      "     False  0.762763  0.237237\n",
      "\n"
     ]
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Class 0 is default == 'No' and class 1 is default == 'Yes' (the Default data has no Up/Down labels)\n",
    "print_classification_statistics(lda, X, y, labels=['No', 'Yes'])\n",
    "plot_ROC(lda, X, y, label='LDA Classification')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.4.4 Quadratic Discriminant Analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_QDA(mean1=-2, mean2=1, sigma1=1, sigma2=0.5):\n",
    "    \"\"\"Simulate two Gaussian classes with different covariances and compare the Bayes and QDA rules.\n",
    "\n",
    "    Class 1 ~ N(mean1*(1, 1), sigma1*I) and class 2 ~ N(mean2*(1, 1), sigma2*I)\n",
    "    (sigma1/sigma2 are *variances*). 500 points are drawn per class. The true Bayes\n",
    "    regions are shaded, the QDA boundary estimated from the sample is drawn as a\n",
    "    black line, and the per-class percentage of sampled points that each rule\n",
    "    classifies correctly is printed.\n",
    "    \"\"\"\n",
    "    mean1 = mean1*np.array([1, 1])\n",
    "    mean2 = mean2*np.array([1, 1])\n",
    "    cov1 = np.array([[sigma1, 0], [0, sigma1]])\n",
    "    cov2 = np.array([[sigma2, 0], [0, sigma2]])\n",
    "\n",
    "    N = 500  # samples per class\n",
    "    K = 2\n",
    "    sample1 = np.random.multivariate_normal(mean1, cov1, (N,))\n",
    "    sample2 = np.random.multivariate_normal(mean2, cov2, (N,))\n",
    "    all_points = np.vstack([sample1, sample2])\n",
    "    minX, minY = all_points.min(axis=0)\n",
    "    maxX, maxY = all_points.max(axis=0)\n",
    "\n",
    "    # equal class priors pi_k = 1/K\n",
    "    pi1 = pi2 = 1/K\n",
    "\n",
    "    def qda_delta(mean, cov, pi):\n",
    "        \"\"\"Quadratic discriminant, evaluated row-wise on a 2-D array of points:\n",
    "        delta(x) = -1/2 (x-mu)' S^-1 (x-mu) - 1/2 log|S| + log(pi).\n",
    "        The log|S| term shifts the boundary whenever the class covariances differ.\n",
    "        \"\"\"\n",
    "        inv_cov = np.linalg.inv(cov)\n",
    "        log_det = np.linalg.slogdet(cov)[1]\n",
    "        def delta(Xin):\n",
    "            centered = Xin - mean\n",
    "            return -1/2*(centered.dot(inv_cov)*centered).sum(axis=1) - 1/2*log_det + np.log(pi)\n",
    "        return delta\n",
    "\n",
    "    # grid of points to plot the Bayes regions and the QDA boundary\n",
    "    N_points_grid = 150\n",
    "    xx, yy = np.meshgrid(np.linspace(minX, maxX, N_points_grid), np.linspace(minY, maxY, N_points_grid))\n",
    "    X = np.c_[xx.ravel(), yy.ravel()]\n",
    "\n",
    "    # Bayes rule: true means and covariances\n",
    "    delta1_fun = qda_delta(mean1, cov1, pi1)\n",
    "    delta2_fun = qda_delta(mean2, cov2, pi2)\n",
    "    region1 = delta1_fun(X) > delta2_fun(X)\n",
    "    region2 = delta2_fun(X) > delta1_fun(X)\n",
    "\n",
    "    # QDA: plug-in estimates of each class's mean and covariance\n",
    "    est_delta1_fun = qda_delta(sample1.mean(axis=0), np.cov(sample1, rowvar=False), pi1)\n",
    "    est_delta2_fun = qda_delta(sample2.mean(axis=0), np.cov(sample2, rowvar=False), pi2)\n",
    "    est_region1 = est_delta1_fun(X) > est_delta2_fun(X)\n",
    "\n",
    "    _, ax = plt.subplots(figsize=(8,8))\n",
    "\n",
    "    # Bayes regions\n",
    "    ax.contourf(xx, yy, region1.reshape(xx.shape), alpha=0.5, colors='g', levels=[0.5, 1.0])\n",
    "    ax.contourf(xx, yy, region2.reshape(xx.shape), alpha=0.5, colors='orange', levels=[0.5, 1.0])\n",
    "\n",
    "    # Samples\n",
    "    ax.scatter(sample1[:,0], sample1[:,1], s=20, c='green', marker='o')\n",
    "    ax.scatter(sample2[:,0], sample2[:,1], s=20, c='orange', marker='o')\n",
    "    ax.set_xlabel('X1')\n",
    "    ax.set_ylabel('X2')\n",
    "\n",
    "    # QDA boundary\n",
    "    ax.contour(xx, yy, est_region1.reshape(xx.shape), alpha=0.5, colors='k')\n",
    "\n",
    "    # statistics: % of each class's sampled points assigned to that class\n",
    "    pred_green = np.mean(delta1_fun(sample1) > delta2_fun(sample1))*100\n",
    "    pred_orange = np.mean(delta2_fun(sample2) > delta1_fun(sample2))*100\n",
    "    print('Bayes accuracy: ', np.round(pred_green, 1), np.round(pred_orange, 1))\n",
    "    est_pred_green = np.mean(est_delta1_fun(sample1) > est_delta2_fun(sample1))*100\n",
    "    est_pred_orange = np.mean(est_delta2_fun(sample2) > est_delta1_fun(sample2))*100\n",
    "    print('QDA accuracy: ', np.round(est_pred_green, 1), np.round(est_pred_orange, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6e3ffe5ce9b04013a5b8426437b0f6a5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "interactive(children=(FloatSlider(value=-2.0, description='mean1', max=2.0, min=-2.0, step=0.5), FloatSlider(v…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# interactive() silently drops keyword arguments that are not parameters of plot_QDA,\n",
    "# so continuous_update=False must be set on each slider to avoid re-plotting while dragging.\n",
    "interactive_plot = widgets.interactive(\n",
    "    plot_QDA,\n",
    "    mean1=widgets.FloatSlider(value=-2, min=-2, max=2, step=0.5, continuous_update=False),\n",
    "    mean2=widgets.FloatSlider(value=1, min=-2, max=2, step=0.5, continuous_update=False),\n",
    "    sigma1=widgets.FloatSlider(value=1, min=0.1, max=5, step=0.1, continuous_update=False),\n",
    "    sigma2=widgets.FloatSlider(value=0.5, min=0.1, max=5, step=0.1, continuous_update=False))\n",
    "output = interactive_plot.children[-1]\n",
    "output.layout.height = '15cm'\n",
    "interactive_plot"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 4.5 A Comparison of Classification Methods"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**TODO:** compare logistic regression, LDA, QDA and KNN on simulated data under different scenarios (not yet implemented in this notebook)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 4.6 Lab: Logistic Regression, LDA, QDA, and KNN"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4.6.1 The Stock Market Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Lag1</th>\n",
       "      <th>Lag2</th>\n",
       "      <th>Lag3</th>\n",
       "      <th>Lag4</th>\n",
       "      <th>Lag5</th>\n",
       "      <th>Volume</th>\n",
       "      <th>Today</th>\n",
       "      <th>Direction</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Year</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>2001-01-01</th>\n",
       "      <td>0.381</td>\n",
       "      <td>-0.192</td>\n",
       "      <td>-2.624</td>\n",
       "      <td>-1.055</td>\n",
       "      <td>5.010</td>\n",
       "      <td>1.1913</td>\n",
       "      <td>0.959</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2001-01-01</th>\n",
       "      <td>0.959</td>\n",
       "      <td>0.381</td>\n",
       "      <td>-0.192</td>\n",
       "      <td>-2.624</td>\n",
       "      <td>-1.055</td>\n",
       "      <td>1.2965</td>\n",
       "      <td>1.032</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2001-01-01</th>\n",
       "      <td>1.032</td>\n",
       "      <td>0.959</td>\n",
       "      <td>0.381</td>\n",
       "      <td>-0.192</td>\n",
       "      <td>-2.624</td>\n",
       "      <td>1.4112</td>\n",
       "      <td>-0.623</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2001-01-01</th>\n",
       "      <td>-0.623</td>\n",
       "      <td>1.032</td>\n",
       "      <td>0.959</td>\n",
       "      <td>0.381</td>\n",
       "      <td>-0.192</td>\n",
       "      <td>1.2760</td>\n",
       "      <td>0.614</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2001-01-01</th>\n",
       "      <td>0.614</td>\n",
       "      <td>-0.623</td>\n",
       "      <td>1.032</td>\n",
       "      <td>0.959</td>\n",
       "      <td>0.381</td>\n",
       "      <td>1.2057</td>\n",
       "      <td>0.213</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "             Lag1   Lag2   Lag3   Lag4   Lag5  Volume  Today  Direction\n",
       "Year                                                                   \n",
       "2001-01-01  0.381 -0.192 -2.624 -1.055  5.010  1.1913  0.959          1\n",
       "2001-01-01  0.959  0.381 -0.192 -2.624 -1.055  1.2965  1.032          1\n",
       "2001-01-01  1.032  0.959  0.381 -0.192 -2.624  1.4112 -0.623          0\n",
       "2001-01-01 -0.623  1.032  0.959  0.381 -0.192  1.2760  0.614          1\n",
       "2001-01-01  0.614 -0.623  1.032  0.959  0.381  1.2057  0.213          1"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_stock = pd.read_csv('Data/Smarket.csv', usecols=range(1,10), index_col=0, parse_dates=True)\n",
    "# convert direction to binary. Up is 1, Down is 0.\n",
    "# Map only the Direction column: a frame-wide `replace` would rewrite matching values in\n",
    "# every column, and its implicit object->int downcast is deprecated in recent pandas.\n",
    "df_stock['Direction'] = df_stock['Direction'].map({'Up': 1, 'Down': 0})\n",
    "assert df_stock['Direction'].notna().all(), 'unexpected Direction label in Smarket.csv'\n",
    "df_stock.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Lag1</th>\n",
       "      <th>Lag2</th>\n",
       "      <th>Lag3</th>\n",
       "      <th>Lag4</th>\n",
       "      <th>Lag5</th>\n",
       "      <th>Volume</th>\n",
       "      <th>Today</th>\n",
       "      <th>Direction</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>count</th>\n",
       "      <td>1250.000000</td>\n",
       "      <td>1250.000000</td>\n",
       "      <td>1250.000000</td>\n",
       "      <td>1250.000000</td>\n",
       "      <td>1250.00000</td>\n",
       "      <td>1250.000000</td>\n",
       "      <td>1250.000000</td>\n",
       "      <td>1250.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>mean</th>\n",
       "      <td>0.003834</td>\n",
       "      <td>0.003919</td>\n",
       "      <td>0.001716</td>\n",
       "      <td>0.001636</td>\n",
       "      <td>0.00561</td>\n",
       "      <td>1.478305</td>\n",
       "      <td>0.003138</td>\n",
       "      <td>0.518400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>std</th>\n",
       "      <td>1.136299</td>\n",
       "      <td>1.136280</td>\n",
       "      <td>1.138703</td>\n",
       "      <td>1.138774</td>\n",
       "      <td>1.14755</td>\n",
       "      <td>0.360357</td>\n",
       "      <td>1.136334</td>\n",
       "      <td>0.499861</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>min</th>\n",
       "      <td>-4.922000</td>\n",
       "      <td>-4.922000</td>\n",
       "      <td>-4.922000</td>\n",
       "      <td>-4.922000</td>\n",
       "      <td>-4.92200</td>\n",
       "      <td>0.356070</td>\n",
       "      <td>-4.922000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25%</th>\n",
       "      <td>-0.639500</td>\n",
       "      <td>-0.639500</td>\n",
       "      <td>-0.640000</td>\n",
       "      <td>-0.640000</td>\n",
       "      <td>-0.64000</td>\n",
       "      <td>1.257400</td>\n",
       "      <td>-0.639500</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>50%</th>\n",
       "      <td>0.039000</td>\n",
       "      <td>0.039000</td>\n",
       "      <td>0.038500</td>\n",
       "      <td>0.038500</td>\n",
       "      <td>0.03850</td>\n",
       "      <td>1.422950</td>\n",
       "      <td>0.038500</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>75%</th>\n",
       "      <td>0.596750</td>\n",
       "      <td>0.596750</td>\n",
       "      <td>0.596750</td>\n",
       "      <td>0.596750</td>\n",
       "      <td>0.59700</td>\n",
       "      <td>1.641675</td>\n",
       "      <td>0.596750</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>max</th>\n",
       "      <td>5.733000</td>\n",
       "      <td>5.733000</td>\n",
       "      <td>5.733000</td>\n",
       "      <td>5.733000</td>\n",
       "      <td>5.73300</td>\n",
       "      <td>3.152470</td>\n",
       "      <td>5.733000</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "              Lag1         Lag2         Lag3         Lag4        Lag5  \\\n",
       "count  1250.000000  1250.000000  1250.000000  1250.000000  1250.00000   \n",
       "mean      0.003834     0.003919     0.001716     0.001636     0.00561   \n",
       "std       1.136299     1.136280     1.138703     1.138774     1.14755   \n",
       "min      -4.922000    -4.922000    -4.922000    -4.922000    -4.92200   \n",
       "25%      -0.639500    -0.639500    -0.640000    -0.640000    -0.64000   \n",
       "50%       0.039000     0.039000     0.038500     0.038500     0.03850   \n",
       "75%       0.596750     0.596750     0.596750     0.596750     0.59700   \n",
       "max       5.733000     5.733000     5.733000     5.733000     5.73300   \n",
       "\n",
       "            Volume        Today    Direction  \n",
       "count  1250.000000  1250.000000  1250.000000  \n",
       "mean      1.478305     0.003138     0.518400  \n",
       "std       0.360357     1.136334     0.499861  \n",
       "min       0.356070    -4.922000     0.000000  \n",
       "25%       1.257400    -0.639500     0.000000  \n",
       "50%       1.422950     0.038500     1.000000  \n",
       "75%       1.641675     0.596750     1.000000  \n",
       "max       3.152470     5.733000     1.000000  "
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# summary statistics of every column (default quartiles spelled out)\n",
    "df_stock.describe(percentiles=[0.25, 0.5, 0.75])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Lag1</th>\n",
       "      <th>Lag2</th>\n",
       "      <th>Lag3</th>\n",
       "      <th>Lag4</th>\n",
       "      <th>Lag5</th>\n",
       "      <th>Volume</th>\n",
       "      <th>Today</th>\n",
       "      <th>Direction</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>Lag1</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>-0.026294</td>\n",
       "      <td>-0.010803</td>\n",
       "      <td>-0.002986</td>\n",
       "      <td>-0.005675</td>\n",
       "      <td>0.040910</td>\n",
       "      <td>-0.026155</td>\n",
       "      <td>-0.039757</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Lag2</th>\n",
       "      <td>-0.026294</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>-0.025897</td>\n",
       "      <td>-0.010854</td>\n",
       "      <td>-0.003558</td>\n",
       "      <td>-0.043383</td>\n",
       "      <td>-0.010250</td>\n",
       "      <td>-0.024081</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Lag3</th>\n",
       "      <td>-0.010803</td>\n",
       "      <td>-0.025897</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>-0.024051</td>\n",
       "      <td>-0.018808</td>\n",
       "      <td>-0.041824</td>\n",
       "      <td>-0.002448</td>\n",
       "      <td>0.006132</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Lag4</th>\n",
       "      <td>-0.002986</td>\n",
       "      <td>-0.010854</td>\n",
       "      <td>-0.024051</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>-0.027084</td>\n",
       "      <td>-0.048414</td>\n",
       "      <td>-0.006900</td>\n",
       "      <td>0.004215</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Lag5</th>\n",
       "      <td>-0.005675</td>\n",
       "      <td>-0.003558</td>\n",
       "      <td>-0.018808</td>\n",
       "      <td>-0.027084</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>-0.022002</td>\n",
       "      <td>-0.034860</td>\n",
       "      <td>0.005423</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Volume</th>\n",
       "      <td>0.040910</td>\n",
       "      <td>-0.043383</td>\n",
       "      <td>-0.041824</td>\n",
       "      <td>-0.048414</td>\n",
       "      <td>-0.022002</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.014592</td>\n",
       "      <td>0.022951</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Today</th>\n",
       "      <td>-0.026155</td>\n",
       "      <td>-0.010250</td>\n",
       "      <td>-0.002448</td>\n",
       "      <td>-0.006900</td>\n",
       "      <td>-0.034860</td>\n",
       "      <td>0.014592</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.730563</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Direction</th>\n",
       "      <td>-0.039757</td>\n",
       "      <td>-0.024081</td>\n",
       "      <td>0.006132</td>\n",
       "      <td>0.004215</td>\n",
       "      <td>0.005423</td>\n",
       "      <td>0.022951</td>\n",
       "      <td>0.730563</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "               Lag1      Lag2      Lag3      Lag4      Lag5    Volume  \\\n",
       "Lag1       1.000000 -0.026294 -0.010803 -0.002986 -0.005675  0.040910   \n",
       "Lag2      -0.026294  1.000000 -0.025897 -0.010854 -0.003558 -0.043383   \n",
       "Lag3      -0.010803 -0.025897  1.000000 -0.024051 -0.018808 -0.041824   \n",
       "Lag4      -0.002986 -0.010854 -0.024051  1.000000 -0.027084 -0.048414   \n",
       "Lag5      -0.005675 -0.003558 -0.018808 -0.027084  1.000000 -0.022002   \n",
       "Volume     0.040910 -0.043383 -0.041824 -0.048414 -0.022002  1.000000   \n",
       "Today     -0.026155 -0.010250 -0.002448 -0.006900 -0.034860  0.014592   \n",
       "Direction -0.039757 -0.024081  0.006132  0.004215  0.005423  0.022951   \n",
       "\n",
       "              Today  Direction  \n",
       "Lag1      -0.026155  -0.039757  \n",
       "Lag2      -0.010250  -0.024081  \n",
       "Lag3      -0.002448   0.006132  \n",
       "Lag4      -0.006900   0.004215  \n",
       "Lag5      -0.034860   0.005423  \n",
       "Volume     0.014592   0.022951  \n",
       "Today      1.000000   0.730563  \n",
       "Direction  0.730563   1.000000  "
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# pairwise Pearson correlations between all columns\n",
    "df_stock.corr(method='pearson')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 0.04090991,  0.04338321,  0.04182369,  0.04841425,  0.03486008,\n",
       "        0.02295096,  0.7305629 ,  0.        ])"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# very small correlations (Today and Direction are obviously correlated, since\n",
    "# Direction records whether Today's return was up or down)\n",
    "corr = df_stock.corr().values\n",
    "# largest |correlation| of each column with any column to its right, labelled by column;\n",
    "# the last entry is always 0 because no column follows it\n",
    "pd.Series(np.max(np.abs(np.triu(corr, k=1)), axis=1), index=df_stock.columns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# volume increases with year\n",
    "# seaborn >= 0.12 accepts only `data` positionally, so x/y must be keywords; grouping by\n",
    "# the integer year labels the ticks directly instead of relabelling them afterwards\n",
    "ax = sns.boxplot(x=df_stock.index.year, y=df_stock['Volume'])\n",
    "ax.set(xlabel='Year');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# volume by year and direction\n",
    "# map the 0/1 codes back to labels and fix their order so the legend reads Down, Up\n",
    "direction_label = df_stock['Direction'].map({0: 'Down', 1: 'Up'})\n",
    "ax = sns.boxplot(x=df_stock.index.year, y=df_stock['Volume'],\n",
    "                 hue=direction_label, hue_order=['Down', 'Up'])\n",
    "ax.set(xlabel='Year');"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# predictors: every column except the response and the same-day return it encodes\n",
    "predictor_cols = df_stock.columns.difference(['Today', 'Direction'])\n",
    "X = df_stock[predictor_cols]\n",
    "y = df_stock['Direction']\n",
    "# time-based split on the date index: train up to 2004, test from 2005 on\n",
    "X_train, y_train = X.loc[:'2004'], y.loc[:'2004']\n",
    "X_test, y_test = X.loc['2005':], y.loc['2005':]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4.6.2 Logistic regression"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Logistic regression, without a test/train split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "No. Observations: 1250\n",
      "Df Residuals: 1243\n",
      "Df Model: 6\n",
      "Log-Likelihood: -863.79\n",
      "AIC: 1741.58\n",
      "           Coefficients  Standard Errors  t values  p values\n",
      "Intercept       -0.1259            0.241    -0.523     0.601\n",
      "Lag1            -0.0731            0.050    -1.457     0.145\n",
      "Lag2            -0.0423            0.050    -0.845     0.399\n",
      "Lag3             0.0111            0.050     0.222     0.824\n",
      "Lag4             0.0094            0.050     0.187     0.851\n",
      "Lag5             0.0103            0.050     0.208     0.835\n",
      "Volume           0.1354            0.158     0.855     0.393\n",
      "\n",
      "Classification Report:\n",
      "             precision    recall  f1-score   support\n",
      "\n",
      "       Down      0.507     0.241     0.327       602\n",
      "         Up      0.526     0.782     0.629       648\n",
      "\n",
      "avg / total      0.517     0.522     0.483      1250\n",
      "\n",
      "Confusion Matrix:\n",
      "           Predicted          \n",
      "                True     False\n",
      "Real True   0.240864  0.759136\n",
      "     False  0.217593  0.782407\n",
      "\n"
     ]
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# a huge C (inverse regularisation strength) makes the penalty negligible, so this is\n",
    "# effectively plain maximum-likelihood logistic regression\n",
    "logistic = skl_lm.LogisticRegression(C=1e10).fit(X, y)\n",
    "print_OLS_error_table(logistic, X, y)\n",
    "print_classification_statistics(logistic, X, y, labels=['Down', 'Up'])\n",
    "plot_ROC(logistic, X, y, label='Logistic Classification')\n",
    "\n",
    "# same results as statsmodels:\n",
    "# sm.Logit(y, sm.add_constant(X)).fit().summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Logistic regression, with test/train split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "No. Observations: 998\n",
      "Df Residuals: 991\n",
      "Df Model: 6\n",
      "Log-Likelihood: -690.55\n",
      "AIC: 1395.11\n",
      "           Coefficients  Standard Errors  t values  p values\n",
      "Intercept        0.1912            0.334     0.573     0.567\n",
      "Lag1            -0.0542            0.052    -1.046     0.296\n",
      "Lag2            -0.0458            0.052    -0.884     0.377\n",
      "Lag3             0.0072            0.052     0.139     0.889\n",
      "Lag4             0.0064            0.052     0.125     0.901\n",
      "Lag5            -0.0042            0.051    -0.083     0.934\n",
      "Volume          -0.1162            0.240    -0.485     0.628\n",
      "\n",
      "Classification Report:\n",
      "             precision    recall  f1-score   support\n",
      "\n",
      "       Down      0.443     0.694     0.540       111\n",
      "         Up      0.564     0.312     0.402       141\n",
      "\n",
      "avg / total      0.511     0.480     0.463       252\n",
      "\n",
      "Confusion Matrix:\n",
      "           Predicted          \n",
      "                True     False\n",
      "Real True   0.693694  0.306306\n",
      "     False  0.687943  0.312057\n",
      "\n"
     ]
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# fit on the training years only; the coefficient table describes that fit, while the\n",
    "# classification report and ROC curve are computed on the held-out test years\n",
    "logistic_test = skl_lm.LogisticRegression(C=1e10).fit(X_train, y_train)\n",
    "print_OLS_error_table(logistic_test, X_train, y_train)\n",
    "print_classification_statistics(logistic_test, X_test, y_test, labels=['Down', 'Up'])\n",
    "plot_ROC(logistic_test, X_test, y_test, label='Logistic Classification Train/Test')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4.6.3 Linear Discriminant Analysis (LDA)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "# use only Lag1 and Lag2, the predictors with the smallest p-values in the training fit above\n",
    "lag_cols = ['Lag1', 'Lag2']\n",
    "X_train2, X_test2 = X_train[lag_cols], X_test[lag_cols]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Prior probabilities of groups: \n",
      "      Down        Up\n",
      "  0.491984  0.508016\n",
      "\n",
      "Group means: \n",
      "          Lag1      Lag2\n",
      "Down  0.042790  0.033894\n",
      "Up   -0.039546 -0.031325\n",
      "\n",
      "Coefficients of linear discriminant: \n",
      "           LDA\n",
      "Lag1 -0.642019\n",
      "Lag2 -0.513529\n",
      "\n"
     ]
    }
   ],
   "source": [
    "lda = LinearDiscriminantAnalysis().fit(X_train2, y_train)\n",
    "# rows/columns follow lda.classes_, i.e. 0 (Down) then 1 (Up)\n",
    "class_labels = ['Down', 'Up']\n",
    "lda_summary = [\n",
    "    ('Prior probabilities of groups: ',\n",
    "     pd.DataFrame(data=lda.priors_.reshape((1, 2)), columns=class_labels, index=[''])),\n",
    "    ('Group means: ',\n",
    "     pd.DataFrame(data=lda.means_, columns=X_train2.columns, index=class_labels)),\n",
    "    ('Coefficients of linear discriminant: ',\n",
    "     pd.DataFrame(data=lda.scalings_, columns=['LDA'], index=X_train2.columns)),\n",
    "]\n",
    "for heading, table in lda_summary:\n",
    "    print(heading)\n",
    "    print(table)\n",
    "    print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Classification Report:\n",
      "             precision    recall  f1-score   support\n",
      "\n",
      "       Down      0.500     0.315     0.387       111\n",
      "         Up      0.582     0.752     0.656       141\n",
      "\n",
      "avg / total      0.546     0.560     0.538       252\n",
      "\n",
      "Confusion Matrix:\n",
      "           Predicted          \n",
      "                True     False\n",
      "Real True   0.315315  0.684685\n",
      "     False  0.248227  0.751773\n",
      "\n"
     ]
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 576x576 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# evaluate the Lag1/Lag2 LDA model on the held-out test years\n",
    "lda_eval = (lda, X_test2, y_test)\n",
    "print_classification_statistics(*lda_eval, labels=['Down', 'Up'])\n",
    "plot_ROC(*lda_eval, label='LDA Train/Test, only Lag1 and Lag2')\n",
    "plot_classification(*lda_eval)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4.6.4 Quadratic Discriminant Analysis (QDA)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Prior probabilities of groups: \n",
      "      Down        Up\n",
      "  0.491984  0.508016\n",
      "\n",
      "Group means: \n",
      "          Lag1      Lag2\n",
      "Down  0.042790  0.033894\n",
      "Up   -0.039546 -0.031325\n",
      "\n"
     ]
    }
   ],
   "source": [
    "qda = QuadraticDiscriminantAnalysis().fit(X_train2, y_train)\n",
    "# rows/columns follow qda.classes_, i.e. 0 (Down) then 1 (Up)\n",
    "class_labels = ['Down', 'Up']\n",
    "qda_summary = [\n",
    "    ('Prior probabilities of groups: ',\n",
    "     pd.DataFrame(data=qda.priors_.reshape((1, 2)), columns=class_labels, index=[''])),\n",
    "    ('Group means: ',\n",
    "     pd.DataFrame(data=qda.means_, columns=X_train2.columns, index=class_labels)),\n",
    "]\n",
    "for heading, table in qda_summary:\n",
    "    print(heading)\n",
    "    print(table)\n",
    "    print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Classification Report:\n",
      "             precision    recall  f1-score   support\n",
      "\n",
      "       Down      0.600     0.270     0.373       111\n",
      "         Up      0.599     0.858     0.706       141\n",
      "\n",
      "avg / total      0.599     0.599     0.559       252\n",
      "\n",
      "Confusion Matrix:\n",
      "           Predicted          \n",
      "                True     False\n",
      "Real True   0.270270  0.729730\n",
      "     False  0.141844  0.858156\n",
      "\n"
     ]
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 576x576 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# evaluate the Lag1/Lag2 QDA model on the held-out test years\n",
    "qda_eval = (qda, X_test2, y_test)\n",
    "print_classification_statistics(*qda_eval, labels=['Down', 'Up'])\n",
    "plot_ROC(*qda_eval, label='QDA Train/Test, only Lag1 and Lag2')\n",
    "plot_classification(*qda_eval)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4.6.5 K-Nearest Neighbors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',\n",
       "           metric_params=None, n_jobs=1, n_neighbors=3, p=2,\n",
       "           weights='uniform')"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# K = 3 on the raw lags (no scaling: Lag1 and Lag2 have nearly equal spread)\n",
    "n_neighbors = 3\n",
    "knn = neighbors.KNeighborsClassifier(n_neighbors=n_neighbors).fit(X_train2, y_train)\n",
    "knn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Classification Report:\n",
      "             precision    recall  f1-score   support\n",
      "\n",
      "       Down      0.466     0.432     0.449       111\n",
      "         Up      0.577     0.610     0.593       141\n",
      "\n",
      "avg / total      0.528     0.532     0.529       252\n",
      "\n",
      "Confusion Matrix:\n",
      "           Predicted          \n",
      "                True     False\n",
      "Real True   0.432432  0.567568\n",
      "     False  0.390071  0.609929\n",
      "\n"
     ]
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 576x576 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# evaluate the Lag1/Lag2 KNN model on the held-out test years\n",
    "knn_eval = (knn, X_test2, y_test)\n",
    "print_classification_statistics(*knn_eval, labels=['Down', 'Up'])\n",
    "plot_ROC(*knn_eval, label='KNN Train/Test, only Lag1 and Lag2')\n",
    "plot_classification(*knn_eval)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4.6.6 An Application to Caravan Insurance Data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### KNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Unnamed: 0</th>\n",
       "      <th>MOSTYPE</th>\n",
       "      <th>MAANTHUI</th>\n",
       "      <th>MGEMOMV</th>\n",
       "      <th>MGEMLEEF</th>\n",
       "      <th>MOSHOOFD</th>\n",
       "      <th>MGODRK</th>\n",
       "      <th>MGODPR</th>\n",
       "      <th>MGODOV</th>\n",
       "      <th>MGODGE</th>\n",
       "      <th>...</th>\n",
       "      <th>APERSONG</th>\n",
       "      <th>AGEZONG</th>\n",
       "      <th>AWAOREG</th>\n",
       "      <th>ABRAND</th>\n",
       "      <th>AZEILPL</th>\n",
       "      <th>APLEZIER</th>\n",
       "      <th>AFIETS</th>\n",
       "      <th>AINBOED</th>\n",
       "      <th>ABYSTAND</th>\n",
       "      <th>Purchase</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>33</td>\n",
       "      <td>1</td>\n",
       "      <td>3</td>\n",
       "      <td>2</td>\n",
       "      <td>8</td>\n",
       "      <td>0</td>\n",
       "      <td>5</td>\n",
       "      <td>1</td>\n",
       "      <td>3</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>No</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2</td>\n",
       "      <td>37</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>2</td>\n",
       "      <td>8</td>\n",
       "      <td>1</td>\n",
       "      <td>4</td>\n",
       "      <td>1</td>\n",
       "      <td>4</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>No</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>3</td>\n",
       "      <td>37</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>2</td>\n",
       "      <td>8</td>\n",
       "      <td>0</td>\n",
       "      <td>4</td>\n",
       "      <td>2</td>\n",
       "      <td>4</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>No</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4</td>\n",
       "      <td>9</td>\n",
       "      <td>1</td>\n",
       "      <td>3</td>\n",
       "      <td>3</td>\n",
       "      <td>3</td>\n",
       "      <td>2</td>\n",
       "      <td>3</td>\n",
       "      <td>2</td>\n",
       "      <td>4</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>No</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>5</td>\n",
       "      <td>40</td>\n",
       "      <td>1</td>\n",
       "      <td>4</td>\n",
       "      <td>2</td>\n",
       "      <td>10</td>\n",
       "      <td>1</td>\n",
       "      <td>4</td>\n",
       "      <td>1</td>\n",
       "      <td>4</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>No</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 87 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "   Unnamed: 0  MOSTYPE  MAANTHUI  MGEMOMV  MGEMLEEF  MOSHOOFD  MGODRK  MGODPR  \\\n",
       "0           1       33         1        3         2         8       0       5   \n",
       "1           2       37         1        2         2         8       1       4   \n",
       "2           3       37         1        2         2         8       0       4   \n",
       "3           4        9         1        3         3         3       2       3   \n",
       "4           5       40         1        4         2        10       1       4   \n",
       "\n",
       "   MGODOV  MGODGE    ...     APERSONG  AGEZONG  AWAOREG  ABRAND  AZEILPL  \\\n",
       "0       1       3    ...            0        0        0       1        0   \n",
       "1       1       4    ...            0        0        0       1        0   \n",
       "2       2       4    ...            0        0        0       1        0   \n",
       "3       2       4    ...            0        0        0       1        0   \n",
       "4       1       4    ...            0        0        0       1        0   \n",
       "\n",
       "   APLEZIER  AFIETS  AINBOED  ABYSTAND  Purchase  \n",
       "0         0       0        0         0        No  \n",
       "1         0       0        0         0        No  \n",
       "2         0       0        0         0        No  \n",
       "3         0       0        0         0        No  \n",
       "4         0       0        0         0        No  \n",
       "\n",
       "[5 rows x 87 columns]"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_caravan = pd.read_csv('Data/Caravan.csv')\n",
    "df_caravan['Purchase'] = df_caravan['Purchase'].astype('category')\n",
    "df_caravan.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "No     5474\n",
       "Yes     348\n",
       "Name: Purchase, dtype: int64"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# class balance of the response: the 'Yes' class is the rare one\n",
    "df_caravan.Purchase.value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "y = df_caravan.Purchase\n",
    "# 'Unnamed: 0' is the CSV's row-number column (values 1, 2, 3, ...), not a predictor.\n",
    "# Keeping it would hand KNN / logistic regression a spurious feature that also\n",
    "# tracks the positional train/test split, so drop it together with the response.\n",
    "X = df_caravan.drop(columns=['Purchase', 'Unnamed: 0'], errors='ignore').astype('float64')\n",
    "# standardize each predictor (mean 0, std 1) so no variable dominates KNN distances;\n",
    "# NOTE: the statistics are computed over all rows, including the rows later held out as test set\n",
    "X_scaled = preprocessing.scale(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "# hold out the first N_TEST customers as the test set and train on the remaining rows\n",
    "N_TEST = 1000\n",
    "X_test, X_train = X_scaled[:N_TEST], X_scaled[N_TEST:]\n",
    "# .iloc makes the positional slicing explicit (y is a pandas Series)\n",
    "y_test, y_train = y.iloc[:N_TEST], y.iloc[N_TEST:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using 1 neighbors\n",
      "Classification Report:\n",
      "             precision    recall  f1-score   support\n",
      "\n",
      "         No      0.948     0.937     0.943       941\n",
      "        Yes      0.157     0.186     0.171        59\n",
      "\n",
      "avg / total      0.902     0.893     0.897      1000\n",
      "\n",
      "Confusion Matrix:\n",
      "           Predicted          \n",
      "                True     False\n",
      "Real True   0.937301  0.062699\n",
      "     False  0.813559  0.186441\n",
      "\n",
      "Using 3 neighbors\n",
      "Classification Report:\n",
      "             precision    recall  f1-score   support\n",
      "\n",
      "         No      0.946     0.979     0.962       941\n",
      "        Yes      0.231     0.102     0.141        59\n",
      "\n",
      "avg / total      0.903     0.927     0.913      1000\n",
      "\n",
      "Confusion Matrix:\n",
      "           Predicted          \n",
      "                True     False\n",
      "Real True   0.978746  0.021254\n",
      "     False  0.898305  0.101695\n",
      "\n",
      "Using 5 neighbors\n",
      "Classification Report:\n",
      "             precision    recall  f1-score   support\n",
      "\n",
      "         No      0.944     0.993     0.968       941\n",
      "        Yes      0.364     0.068     0.114        59\n",
      "\n",
      "avg / total      0.910     0.938     0.918      1000\n",
      "\n",
      "Confusion Matrix:\n",
      "           Predicted          \n",
      "                True     False\n",
      "Real True   0.992561  0.007439\n",
      "     False  0.932203  0.067797\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# fit KNN with k = 1, 3 and 5 neighbours and report test-set statistics for each\n",
    "for k in [1, 3, 5]:\n",
    "    print(f'Using {k} neighbors')\n",
    "    knn = neighbors.KNeighborsClassifier(n_neighbors=k)\n",
    "    knn.fit(X_train, y_train)\n",
    "    print_classification_statistics(knn, X_test, y_test, labels=['No', 'Yes'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Logistic Regression"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Classification Report:\n",
      "             precision    recall  f1-score   support\n",
      "\n",
      "         No      0.941     0.994     0.966       941\n",
      "        Yes      0.000     0.000     0.000        59\n",
      "\n",
      "avg / total      0.885     0.935     0.909      1000\n",
      "\n",
      "Confusion Matrix:\n",
      "           Predicted          \n",
      "                True     False\n",
      "Real True   0.993624  0.006376\n",
      "     False  1.000000  0.000000\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# C is the inverse of the regularization strength: a very large C makes the\n",
    "# penalty negligible, approximating an unpenalized logistic regression\n",
    "logistic = skl_lm.LogisticRegression(C=1e10).fit(X_train, y_train)\n",
    "print_classification_statistics(logistic, X_test, y_test, labels=['No', 'Yes'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "             precision    recall  f1-score   support\n",
      "\n",
      "         No       0.95      0.98      0.96       941\n",
      "        Yes       0.34      0.19      0.24        59\n",
      "\n",
      "avg / total       0.91      0.93      0.92      1000\n",
      "\n",
      "Pred   No  Yes\n",
      "True          \n",
      "No    920   21\n",
      "Yes    48   11\n"
     ]
    }
   ],
   "source": [
    "# predict 'Yes' when the estimated chance of buying exceeds 25% instead of the default 50%\n",
    "threshold = 0.25\n",
    "pred_p = logistic.predict_proba(X_test)\n",
    "# predict_proba columns follow logistic.classes_, so look up the 'Yes' column explicitly\n",
    "yes_col = list(logistic.classes_).index('Yes')\n",
    "# build the labels directly; a chained `cm_df.Pred.replace(..., inplace=True)` does not\n",
    "# update cm_df under pandas Copy-on-Write and would leave Pred boolean\n",
    "cm_df = pd.DataFrame({'True': y_test,\n",
    "                      'Pred': np.where(pred_p[:, yes_col] > threshold, 'Yes', 'No')})\n",
    "print(classification_report(y_test, cm_df.Pred))\n",
    "print(pd.crosstab(cm_df['True'], cm_df['Pred']))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "state": {
     "013d5d7b1aa446d190005ec1533089c0": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "04d85ef6d6bf40b5bc00c8405e54764b": {
      "model_module": "@jupyter-widgets/output",
      "model_module_version": "1.0.0",
      "model_name": "OutputModel",
      "state": {
       "layout": "IPY_MODEL_2c8bb972c2e5454bbe1cc7b674889caa",
       "outputs": [
        {
         "name": "stdout",
         "output_type": "stream",
         "text": "Bayes accuracy:  93.4 82.2\nLDA accuracy:  94.0 81.6\n"
        },
        {
         "data": {
          "image/png": "\n",
          "text/plain": "<Figure size 576x576 with 1 Axes>"
         },
         "metadata": {},
         "output_type": "display_data"
        }
       ]
      }
     },
     "060dd90dce484d8488b21ac24a8689a2": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.2.0",
      "model_name": "SliderStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "127f59fb10e6405f9c032090244584ac": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "130a9ffcdb2f4e2a8af156c2b1e5d19d": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.2.0",
      "model_name": "VBoxModel",
      "state": {
       "_dom_classes": [
        "widget-interact"
       ],
       "children": [
        "IPY_MODEL_57cd4b3d181f471fab0b9e136ea4beaa",
        "IPY_MODEL_3092abfa779b4e01862e66f0e28b7064",
        "IPY_MODEL_734cdfcb36a74bf9947f2e04a7b69d2f",
        "IPY_MODEL_9185d46ef72f460ba90d7fad794836c6",
        "IPY_MODEL_1505547a028a4959aa8f8225b611d507",
        "IPY_MODEL_36bba20749084534bec436b74418a25f"
       ],
       "layout": "IPY_MODEL_867eb8d642324162ab76f0ec96bc4462"
      }
     },
     "1505547a028a4959aa8f8225b611d507": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.2.0",
      "model_name": "FloatSliderModel",
      "state": {
       "description": "sigma2",
       "layout": "IPY_MODEL_45f98f7f180b47728ee09dc6b66eadee",
       "max": 5,
       "min": 0.1,
       "step": 0.1,
       "style": "IPY_MODEL_5267ded6b1884918af7d6043cc3b9560",
       "value": 0.5
      }
     },
     "270c25a8c2764db39c5c18e0874bbdaa": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "2c8bb972c2e5454bbe1cc7b674889caa": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {
       "height": "15cm"
      }
     },
     "2d488c83ab604ab681ce6414d1a89628": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {
       "height": "15cm"
      }
     },
     "3092abfa779b4e01862e66f0e28b7064": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.2.0",
      "model_name": "FloatSliderModel",
      "state": {
       "description": "mean2",
       "layout": "IPY_MODEL_e9e7ffa7d08b4bbda69a30e02348813e",
       "max": 2,
       "min": -2,
       "step": 0.5,
       "style": "IPY_MODEL_f50eebfa501248babd0ec54efb2931a5",
       "value": 1
      }
     },
     "36bba20749084534bec436b74418a25f": {
      "model_module": "@jupyter-widgets/output",
      "model_module_version": "1.0.0",
      "model_name": "OutputModel",
      "state": {
       "layout": "IPY_MODEL_2d488c83ab604ab681ce6414d1a89628",
       "outputs": [
        {
         "name": "stdout",
         "output_type": "stream",
         "text": "Bayes accuracy:  87.8 96.8 89.4\nLDA accuracy:  87.6 96.8 89.8\n"
        },
        {
         "data": {
          "image/png": "\n",
          "text/plain": "<Figure size 576x576 with 1 Axes>"
         },
         "metadata": {},
         "output_type": "display_data"
        }
       ]
      }
     },
     "3bd6103280d2445891fb4c6f003303ca": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.2.0",
      "model_name": "SliderStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "45f98f7f180b47728ee09dc6b66eadee": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "489f535e20da4b1da4108d8a552f05dd": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "5267ded6b1884918af7d6043cc3b9560": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.2.0",
      "model_name": "SliderStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "57cd4b3d181f471fab0b9e136ea4beaa": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.2.0",
      "model_name": "FloatSliderModel",
      "state": {
       "description": "mean1",
       "layout": "IPY_MODEL_7bbb0964c5c84a778ad66c5b2824b0e6",
       "max": 2,
       "min": -2,
       "step": 0.5,
       "style": "IPY_MODEL_b659be192f754c87b171c849edac40ea",
       "value": -2
      }
     },
     "69ea6971801a422a90211f4a0df9751f": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "70776e09c27d4f9f8c18c6259b80ca33": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.2.0",
      "model_name": "FloatSliderModel",
      "state": {
       "description": "sigma2",
       "layout": "IPY_MODEL_69ea6971801a422a90211f4a0df9751f",
       "max": 5,
       "min": 0.1,
       "step": 0.1,
       "style": "IPY_MODEL_729dbd1b05f84ff082e5c80bfcebab20",
       "value": 0.9
      }
     },
     "729dbd1b05f84ff082e5c80bfcebab20": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.2.0",
      "model_name": "SliderStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "734cdfcb36a74bf9947f2e04a7b69d2f": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.2.0",
      "model_name": "FloatSliderModel",
      "state": {
       "description": "mean3",
       "layout": "IPY_MODEL_127f59fb10e6405f9c032090244584ac",
       "max": 2,
       "min": -2,
       "step": 0.5,
       "style": "IPY_MODEL_060dd90dce484d8488b21ac24a8689a2",
       "value": 2
      }
     },
     "7bbb0964c5c84a778ad66c5b2824b0e6": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "8017aecdb1124b2ba4934d9668b5ece2": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.2.0",
      "model_name": "SliderStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "867eb8d642324162ab76f0ec96bc4462": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "9185d46ef72f460ba90d7fad794836c6": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.2.0",
      "model_name": "FloatSliderModel",
      "state": {
       "description": "sigma1",
       "layout": "IPY_MODEL_013d5d7b1aa446d190005ec1533089c0",
       "max": 5,
       "min": 0.1,
       "step": 0.1,
       "style": "IPY_MODEL_3bd6103280d2445891fb4c6f003303ca",
       "value": 2.3
      }
     },
     "935faf5295db42a69a040fdddb5b8cbd": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "9576e45393bd4da986afcaadb9ea01b3": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.2.0",
      "model_name": "SliderStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "962ce5ec503f483884c8b1401be529a0": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.2.0",
      "model_name": "VBoxModel",
      "state": {
       "_dom_classes": [
        "widget-interact"
       ],
       "children": [
        "IPY_MODEL_d24d66952038418b93b6c54715023112",
        "IPY_MODEL_e7d07c9ceb3c479bb156d2b5f1ea8a67",
        "IPY_MODEL_cbea21d7a40a40be986b9d0e3c66e8b8",
        "IPY_MODEL_70776e09c27d4f9f8c18c6259b80ca33",
        "IPY_MODEL_04d85ef6d6bf40b5bc00c8405e54764b"
       ],
       "layout": "IPY_MODEL_c419dea5bd784b7fa242a308b465b6f4"
      }
     },
     "9be199aed5b040c5867c67e4a560a69f": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.2.0",
      "model_name": "SliderStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "b659be192f754c87b171c849edac40ea": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.2.0",
      "model_name": "SliderStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "c419dea5bd784b7fa242a308b465b6f4": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "cbea21d7a40a40be986b9d0e3c66e8b8": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.2.0",
      "model_name": "FloatSliderModel",
      "state": {
       "description": "sigma1",
       "layout": "IPY_MODEL_935faf5295db42a69a040fdddb5b8cbd",
       "max": 5,
       "min": 0.1,
       "step": 0.1,
       "style": "IPY_MODEL_9be199aed5b040c5867c67e4a560a69f",
       "value": 3.3
      }
     },
     "d24d66952038418b93b6c54715023112": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.2.0",
      "model_name": "FloatSliderModel",
      "state": {
       "description": "mean1",
       "layout": "IPY_MODEL_489f535e20da4b1da4108d8a552f05dd",
       "max": 2,
       "min": -2,
       "step": 0.5,
       "style": "IPY_MODEL_9576e45393bd4da986afcaadb9ea01b3",
       "value": -2
      }
     },
     "e7d07c9ceb3c479bb156d2b5f1ea8a67": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.2.0",
      "model_name": "FloatSliderModel",
      "state": {
       "description": "mean2",
       "layout": "IPY_MODEL_270c25a8c2764db39c5c18e0874bbdaa",
       "max": 2,
       "min": -2,
       "step": 0.5,
       "style": "IPY_MODEL_8017aecdb1124b2ba4934d9668b5ece2",
       "value": 0.5
      }
     },
     "e9e7ffa7d08b4bbda69a30e02348813e": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "f50eebfa501248babd0ec54efb2931a5": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.2.0",
      "model_name": "SliderStyleModel",
      "state": {
       "description_width": ""
      }
     }
    },
    "version_major": 2,
    "version_minor": 0
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}