{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "from pyspark.sql import SparkSession\n",
    "\n",
    "# Make the parent directory importable so the shared `utils` package resolves.\n",
    "# Guarded so that re-running this cell does not keep appending duplicates.\n",
    "if '..' not in sys.path:\n",
    "    sys.path.append('..')\n",
    "\n",
    "# NOTE(review): this wildcard import supplies the helpers called further down\n",
    "# (findMissingValuesCols, crosstabPercentage, ...). Replace it with an explicit\n",
    "# name list once the full set of names the notebook needs is confirmed.\n",
    "from utils.pysparkutils import *\n",
    "\n",
    "# Reuse the active session if there is one, otherwise start one named 'income'.\n",
    "spark = SparkSession.builder.appName('income').getOrCreate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "root\n",
      " |-- age: integer (nullable = true)\n",
      " |-- workclass: string (nullable = true)\n",
      " |-- fnlwgt: integer (nullable = true)\n",
      " |-- education: string (nullable = true)\n",
      " |-- education-num: integer (nullable = true)\n",
      " |-- marital-status: string (nullable = true)\n",
      " |-- occupation: string (nullable = true)\n",
      " |-- relationship: string (nullable = true)\n",
      " |-- race: string (nullable = true)\n",
      " |-- sex: string (nullable = true)\n",
      " |-- capital-gain: integer (nullable = true)\n",
      " |-- capital-loss: integer (nullable = true)\n",
      " |-- hours-per-week: integer (nullable = true)\n",
      " |-- native-country: string (nullable = true)\n",
      " |-- class: string (nullable = true)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from pyspark.sql.types import *\n",
    "\n",
    "# The CSV is read without a header row (Spark's default), so the column\n",
    "# names are listed here and attached with toDF() after the read.\n",
    "headers = [\"age\", \"workclass\", \"fnlwgt\", \"education\", \"education-num\",\n",
    "           \"marital-status\", \"occupation\", \"relationship\", \"race\", \"sex\",\n",
    "           \"capital-gain\", \"capital-loss\", \"hours-per-week\", \"native-country\",\n",
    "           \"class\"]\n",
    "\n",
    "# Column types are inferred from the data. Leading/trailing whitespace is\n",
    "# stripped from every field while parsing.\n",
    "train = spark.read.csv('./adult.data.txt',\n",
    "                       inferSchema=True,\n",
    "                       ignoreLeadingWhiteSpace=True,\n",
    "                       ignoreTrailingWhiteSpace=True).toDF(*headers)\n",
    "train.printSchema()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# format_number is used by the crosstab cells further down; import it\n",
    "# explicitly instead of relying on it leaking in through a wildcard import.\n",
    "from pyspark.sql.functions import format_number, monotonically_increasing_id, udf\n",
    "\n",
    "# Tag every row with a unique id (increasing, but not necessarily consecutive).\n",
    "train = train.withColumn('id', monotonically_increasing_id())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "32561"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Name of the target column, reused by the groupby / crosstab cells below.\n",
    "labelCol = \"class\"\n",
    "\n",
    "# Number of rows in the training set.\n",
    "train.count()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Exploratory Data Analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>class</th>\n",
       "      <th>count</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>24720</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>7841</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   class  count\n",
       "0  <=50K  24720\n",
       "1   >50K   7841"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Row count per value of the label column, collected to pandas for display.\n",
    "train.groupBy(labelCol).count().toPandas()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The training set is imbalanced: about 76% of the rows (24,720) are `<=50K` and only about 24% (7,841) are `>50K`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[]"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# findMissingValuesCols arrives via the wildcard import from utils.pysparkutils.\n",
    "# NOTE(review): presumably it returns the columns that contain missing values;\n",
    "# what it counts as \"missing\" is not visible here - confirm against the helper.\n",
    "findMissingValuesCols(train)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "There are no missing values in our training data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+---+\n",
      "|age|\n",
      "+---+\n",
      "| 31|\n",
      "| 85|\n",
      "| 65|\n",
      "| 53|\n",
      "| 78|\n",
      "| 34|\n",
      "| 81|\n",
      "| 28|\n",
      "| 76|\n",
      "| 27|\n",
      "| 26|\n",
      "| 44|\n",
      "| 22|\n",
      "| 47|\n",
      "| 52|\n",
      "| 86|\n",
      "| 40|\n",
      "| 20|\n",
      "| 57|\n",
      "| 54|\n",
      "+---+\n",
      "only showing top 20 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Peek at the distinct ages in the data. show() prints only the first 20 rows,\n",
    "# and no ordering is applied here.\n",
    "train.select('age').distinct().show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>race</th>\n",
       "      <th>class</th>\n",
       "      <th>count</th>\n",
       "      <th>percentage</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Amer-Indian-Eskimo</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>275</td>\n",
       "      <td>88.42</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Amer-Indian-Eskimo</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>36</td>\n",
       "      <td>11.58</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Asian-Pac-Islander</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>763</td>\n",
       "      <td>73.44</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Asian-Pac-Islander</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>276</td>\n",
       "      <td>26.56</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Black</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>2737</td>\n",
       "      <td>87.61</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>Black</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>387</td>\n",
       "      <td>12.39</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>Other</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>246</td>\n",
       "      <td>90.77</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>Other</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>25</td>\n",
       "      <td>9.23</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>White</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>20699</td>\n",
       "      <td>74.41</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>White</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>7117</td>\n",
       "      <td>25.59</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                 race  class  count percentage\n",
       "0  Amer-Indian-Eskimo  <=50K    275      88.42\n",
       "1  Amer-Indian-Eskimo   >50K     36      11.58\n",
       "2  Asian-Pac-Islander  <=50K    763      73.44\n",
       "3  Asian-Pac-Islander   >50K    276      26.56\n",
       "4               Black  <=50K   2737      87.61\n",
       "5               Black   >50K    387      12.39\n",
       "6               Other  <=50K    246      90.77\n",
       "7               Other   >50K     25       9.23\n",
       "8               White  <=50K  20699      74.41\n",
       "9               White   >50K   7117      25.59"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "percentageCol = \"percentage\"\n",
    "\n",
    "# Class split within each race; the share is formatted to 2 decimals for display.\n",
    "df = crosstabPercentage(train, \"race\", labelCol)\n",
    "df = df.withColumn(\n",
    "    percentageCol,\n",
    "    format_number(df[percentageCol], 2),\n",
    ")\n",
    "df.toPandas()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>class</th>\n",
       "      <th>count</th>\n",
       "      <th>percentage</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>17</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>395</td>\n",
       "      <td>100.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>18</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>550</td>\n",
       "      <td>100.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>19</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>710</td>\n",
       "      <td>99.72</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>19</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>2</td>\n",
       "      <td>0.28</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>20</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>753</td>\n",
       "      <td>100.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>21</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>717</td>\n",
       "      <td>99.58</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>21</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>3</td>\n",
       "      <td>0.42</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>22</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>752</td>\n",
       "      <td>98.30</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>22</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>13</td>\n",
       "      <td>1.70</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>23</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>865</td>\n",
       "      <td>98.63</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>23</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>12</td>\n",
       "      <td>1.37</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>24</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>767</td>\n",
       "      <td>96.12</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>24</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>31</td>\n",
       "      <td>3.88</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>25</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>788</td>\n",
       "      <td>93.70</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>25</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>53</td>\n",
       "      <td>6.30</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>26</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>722</td>\n",
       "      <td>91.97</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>26</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>63</td>\n",
       "      <td>8.03</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>27</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>754</td>\n",
       "      <td>90.30</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>27</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>81</td>\n",
       "      <td>9.70</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>28</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>748</td>\n",
       "      <td>86.27</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>28</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>119</td>\n",
       "      <td>13.73</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>29</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>679</td>\n",
       "      <td>83.52</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22</th>\n",
       "      <td>29</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>134</td>\n",
       "      <td>16.48</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23</th>\n",
       "      <td>30</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>690</td>\n",
       "      <td>80.14</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <td>30</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>171</td>\n",
       "      <td>19.86</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25</th>\n",
       "      <td>31</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>705</td>\n",
       "      <td>79.39</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>26</th>\n",
       "      <td>31</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>183</td>\n",
       "      <td>20.61</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>27</th>\n",
       "      <td>32</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>639</td>\n",
       "      <td>77.17</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>28</th>\n",
       "      <td>32</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>189</td>\n",
       "      <td>22.83</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>29</th>\n",
       "      <td>33</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>684</td>\n",
       "      <td>78.17</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>108</th>\n",
       "      <td>72</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>9</td>\n",
       "      <td>13.43</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>109</th>\n",
       "      <td>73</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>54</td>\n",
       "      <td>84.38</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>110</th>\n",
       "      <td>73</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>10</td>\n",
       "      <td>15.62</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>111</th>\n",
       "      <td>74</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>39</td>\n",
       "      <td>76.47</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>112</th>\n",
       "      <td>74</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>12</td>\n",
       "      <td>23.53</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>113</th>\n",
       "      <td>75</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>38</td>\n",
       "      <td>84.44</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>114</th>\n",
       "      <td>75</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>7</td>\n",
       "      <td>15.56</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>115</th>\n",
       "      <td>76</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>41</td>\n",
       "      <td>89.13</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>116</th>\n",
       "      <td>76</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>5</td>\n",
       "      <td>10.87</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>117</th>\n",
       "      <td>77</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>24</td>\n",
       "      <td>82.76</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>118</th>\n",
       "      <td>77</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>5</td>\n",
       "      <td>17.24</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>119</th>\n",
       "      <td>78</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>18</td>\n",
       "      <td>78.26</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>120</th>\n",
       "      <td>78</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>5</td>\n",
       "      <td>21.74</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>121</th>\n",
       "      <td>79</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>13</td>\n",
       "      <td>59.09</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>122</th>\n",
       "      <td>79</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>9</td>\n",
       "      <td>40.91</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>123</th>\n",
       "      <td>80</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>20</td>\n",
       "      <td>90.91</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>124</th>\n",
       "      <td>80</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>2</td>\n",
       "      <td>9.09</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>125</th>\n",
       "      <td>81</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>17</td>\n",
       "      <td>85.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>126</th>\n",
       "      <td>81</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>3</td>\n",
       "      <td>15.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>127</th>\n",
       "      <td>82</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>12</td>\n",
       "      <td>100.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>128</th>\n",
       "      <td>83</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>4</td>\n",
       "      <td>66.67</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>129</th>\n",
       "      <td>83</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>2</td>\n",
       "      <td>33.33</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>130</th>\n",
       "      <td>84</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>9</td>\n",
       "      <td>90.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>131</th>\n",
       "      <td>84</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>1</td>\n",
       "      <td>10.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>132</th>\n",
       "      <td>85</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>3</td>\n",
       "      <td>100.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>133</th>\n",
       "      <td>86</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>1</td>\n",
       "      <td>100.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>134</th>\n",
       "      <td>87</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>1</td>\n",
       "      <td>100.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>135</th>\n",
       "      <td>88</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>3</td>\n",
       "      <td>100.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>136</th>\n",
       "      <td>90</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>35</td>\n",
       "      <td>81.40</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>137</th>\n",
       "      <td>90</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>8</td>\n",
       "      <td>18.60</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>138 rows × 4 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "     age  class  count percentage\n",
       "0     17  <=50K    395     100.00\n",
       "1     18  <=50K    550     100.00\n",
       "2     19  <=50K    710      99.72\n",
       "3     19   >50K      2       0.28\n",
       "4     20  <=50K    753     100.00\n",
       "5     21  <=50K    717      99.58\n",
       "6     21   >50K      3       0.42\n",
       "7     22  <=50K    752      98.30\n",
       "8     22   >50K     13       1.70\n",
       "9     23  <=50K    865      98.63\n",
       "10    23   >50K     12       1.37\n",
       "11    24  <=50K    767      96.12\n",
       "12    24   >50K     31       3.88\n",
       "13    25  <=50K    788      93.70\n",
       "14    25   >50K     53       6.30\n",
       "15    26  <=50K    722      91.97\n",
       "16    26   >50K     63       8.03\n",
       "17    27  <=50K    754      90.30\n",
       "18    27   >50K     81       9.70\n",
       "19    28  <=50K    748      86.27\n",
       "20    28   >50K    119      13.73\n",
       "21    29  <=50K    679      83.52\n",
       "22    29   >50K    134      16.48\n",
       "23    30  <=50K    690      80.14\n",
       "24    30   >50K    171      19.86\n",
       "25    31  <=50K    705      79.39\n",
       "26    31   >50K    183      20.61\n",
       "27    32  <=50K    639      77.17\n",
       "28    32   >50K    189      22.83\n",
       "29    33  <=50K    684      78.17\n",
       "..   ...    ...    ...        ...\n",
       "108   72   >50K      9      13.43\n",
       "109   73  <=50K     54      84.38\n",
       "110   73   >50K     10      15.62\n",
       "111   74  <=50K     39      76.47\n",
       "112   74   >50K     12      23.53\n",
       "113   75  <=50K     38      84.44\n",
       "114   75   >50K      7      15.56\n",
       "115   76  <=50K     41      89.13\n",
       "116   76   >50K      5      10.87\n",
       "117   77  <=50K     24      82.76\n",
       "118   77   >50K      5      17.24\n",
       "119   78  <=50K     18      78.26\n",
       "120   78   >50K      5      21.74\n",
       "121   79  <=50K     13      59.09\n",
       "122   79   >50K      9      40.91\n",
       "123   80  <=50K     20      90.91\n",
       "124   80   >50K      2       9.09\n",
       "125   81  <=50K     17      85.00\n",
       "126   81   >50K      3      15.00\n",
       "127   82  <=50K     12     100.00\n",
       "128   83  <=50K      4      66.67\n",
       "129   83   >50K      2      33.33\n",
       "130   84  <=50K      9      90.00\n",
       "131   84   >50K      1      10.00\n",
       "132   85  <=50K      3     100.00\n",
       "133   86  <=50K      1     100.00\n",
       "134   87  <=50K      1     100.00\n",
       "135   88  <=50K      3     100.00\n",
       "136   90  <=50K     35      81.40\n",
       "137   90   >50K      8      18.60\n",
       "\n",
       "[138 rows x 4 columns]"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Class split for every age, sorted by age; the share is formatted to 2 decimals.\n",
    "df = crosstabPercentage(train, \"age\", labelCol)\n",
    "df = (\n",
    "    df.orderBy(\"age\")\n",
    "      .withColumn(\"percentage\", format_number(df[\"percentage\"], 2))\n",
    ")\n",
    "df.toPandas()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>sex</th>\n",
       "      <th>class</th>\n",
       "      <th>count</th>\n",
       "      <th>percentage</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Female</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>9592</td>\n",
       "      <td>89.05</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Female</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>1179</td>\n",
       "      <td>10.95</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Male</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>15128</td>\n",
       "      <td>69.43</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Male</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>6662</td>\n",
       "      <td>30.57</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "      sex  class  count percentage\n",
       "0  Female  <=50K   9592      89.05\n",
       "1  Female   >50K   1179      10.95\n",
       "2    Male  <=50K  15128      69.43\n",
       "3    Male   >50K   6662      30.57"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Class split within each sex; the share is formatted to 2 decimals for display.\n",
    "df = crosstabPercentage(train, \"sex\", labelCol)\n",
    "df = df.withColumn(\n",
    "    \"percentage\",\n",
    "    format_number(df[\"percentage\"], 2),\n",
    ")\n",
    "df.toPandas()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`crosstabPercentage` is a simple way to explore the relationship between a categorical feature and the label. For instance, the table above shows how useful the `sex` feature is for predicting income: 30.57% of men earn `>50K`, compared with only 10.95% of women. So a man in this dataset is almost three times as likely as a woman to be in the `>50K` class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>education</th>\n",
       "      <th>class</th>\n",
       "      <th>count</th>\n",
       "      <th>percentage</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>10th</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>871</td>\n",
       "      <td>93.35</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>10th</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>62</td>\n",
       "      <td>6.65</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>11th</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>1115</td>\n",
       "      <td>94.89</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>11th</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>60</td>\n",
       "      <td>5.11</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>12th</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>400</td>\n",
       "      <td>92.38</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>12th</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>33</td>\n",
       "      <td>7.62</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>1st-4th</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>162</td>\n",
       "      <td>96.43</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>1st-4th</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>6</td>\n",
       "      <td>3.57</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>5th-6th</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>317</td>\n",
       "      <td>95.20</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>5th-6th</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>16</td>\n",
       "      <td>4.80</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>7th-8th</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>606</td>\n",
       "      <td>93.81</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>7th-8th</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>40</td>\n",
       "      <td>6.19</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>9th</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>487</td>\n",
       "      <td>94.75</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>9th</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>27</td>\n",
       "      <td>5.25</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>Assoc-acdm</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>802</td>\n",
       "      <td>75.16</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>Assoc-acdm</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>265</td>\n",
       "      <td>24.84</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>Assoc-voc</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>1021</td>\n",
       "      <td>73.88</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>Assoc-voc</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>361</td>\n",
       "      <td>26.12</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>Bachelors</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>3134</td>\n",
       "      <td>58.52</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>Bachelors</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>2221</td>\n",
       "      <td>41.48</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>Doctorate</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>107</td>\n",
       "      <td>25.91</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>Doctorate</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>306</td>\n",
       "      <td>74.09</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22</th>\n",
       "      <td>HS-grad</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>8826</td>\n",
       "      <td>84.05</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23</th>\n",
       "      <td>HS-grad</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>1675</td>\n",
       "      <td>15.95</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <td>Masters</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>764</td>\n",
       "      <td>44.34</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25</th>\n",
       "      <td>Masters</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>959</td>\n",
       "      <td>55.66</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>26</th>\n",
       "      <td>Preschool</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>51</td>\n",
       "      <td>100.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>27</th>\n",
       "      <td>Prof-school</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>153</td>\n",
       "      <td>26.56</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>28</th>\n",
       "      <td>Prof-school</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>423</td>\n",
       "      <td>73.44</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>29</th>\n",
       "      <td>Some-college</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>5904</td>\n",
       "      <td>80.98</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>30</th>\n",
       "      <td>Some-college</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>1387</td>\n",
       "      <td>19.02</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       education  class  count percentage\n",
       "0           10th  <=50K    871      93.35\n",
       "1           10th   >50K     62       6.65\n",
       "2           11th  <=50K   1115      94.89\n",
       "3           11th   >50K     60       5.11\n",
       "4           12th  <=50K    400      92.38\n",
       "5           12th   >50K     33       7.62\n",
       "6        1st-4th  <=50K    162      96.43\n",
       "7        1st-4th   >50K      6       3.57\n",
       "8        5th-6th  <=50K    317      95.20\n",
       "9        5th-6th   >50K     16       4.80\n",
       "10       7th-8th  <=50K    606      93.81\n",
       "11       7th-8th   >50K     40       6.19\n",
       "12           9th  <=50K    487      94.75\n",
       "13           9th   >50K     27       5.25\n",
       "14    Assoc-acdm  <=50K    802      75.16\n",
       "15    Assoc-acdm   >50K    265      24.84\n",
       "16     Assoc-voc  <=50K   1021      73.88\n",
       "17     Assoc-voc   >50K    361      26.12\n",
       "18     Bachelors  <=50K   3134      58.52\n",
       "19     Bachelors   >50K   2221      41.48\n",
       "20     Doctorate  <=50K    107      25.91\n",
       "21     Doctorate   >50K    306      74.09\n",
       "22       HS-grad  <=50K   8826      84.05\n",
       "23       HS-grad   >50K   1675      15.95\n",
       "24       Masters  <=50K    764      44.34\n",
       "25       Masters   >50K    959      55.66\n",
       "26     Preschool  <=50K     51     100.00\n",
       "27   Prof-school  <=50K    153      26.56\n",
       "28   Prof-school   >50K    423      73.44\n",
       "29  Some-college  <=50K   5904      80.98\n",
       "30  Some-college   >50K   1387      19.02"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Class split (count and percentage) for each education level,\n",
    "# with the percentage rendered to two decimal places.\n",
    "df = (\n",
    "    crosstabPercentage(train, 'education', labelCol)\n",
    "    .withColumn('percentage', format_number('percentage', 2))\n",
    ")\n",
    "df.toPandas()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%script false\n",
    "# Disabled cell: class split per 'education-num' value, ordered numerically.\n",
    "# NOTE(review): 'education-numClass' does not match the column names shown by the other\n",
    "# crosstabPercentage outputs in this notebook - confirm the column name before re-enabling.\n",
    "educationNumClass = crosstabPercentage(train, 'education-num', labelCol)\n",
    "# Format the existing 'percentage' column (the original read a non-existent 'percentage-').\n",
    "educationNumClass = educationNumClass.withColumn('percentage', \n",
    "                                    format_number(educationNumClass['percentage'], 2))\n",
    "# Cast the key to a double so the ordering is numeric rather than lexicographic.\n",
    "educationNumClass = educationNumClass.withColumn('education-numClassF', educationNumClass['education-numClass'].cast(DoubleType()))\\\n",
    "                                     .orderBy('education-numClassF').drop('education-numClass')\n",
    "# Move the numeric key to the front of the column list.\n",
    "cols = educationNumClass.columns\n",
    "cols.remove('education-numClassF')\n",
    "cols.insert(0, 'education-numClassF')\n",
    "educationNumClass = educationNumClass.select(cols)\n",
    "educationNumClass.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see above that this is a sparse matrix, so it's hard to spot the non-zero values. We will therefore focus only on the non-zero values to find out whether there is any relationship between these features and whether one of them is redundant."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%script false\n",
    "\n",
    "import json\n",
    "\n",
    "from pyspark.sql.functions import coalesce, lit, when\n",
    "\n",
    "# For every row, keep each non-zero cell under the key\n",
    "# '<value of education-num_education>_<column name>'.\n",
    "pivotCol = 'education-num_education'\n",
    "d = {}\n",
    "for row in df.toLocalIterator():\n",
    "    rowDict = row.asDict()\n",
    "    educationNum = rowDict[pivotCol]\n",
    "    d.update({educationNum+'_'+k: v\n",
    "              for k, v in rowDict.items()\n",
    "              if k != pivotCol and v != 0})\n",
    "\n",
    "s = json.dumps(d, indent=4)\n",
    "print(s)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see it's obvious that these features are redundant. Only one of them should suffice for our classification task.\n",
    "\n",
    "Let's try a more rigorous chi-square test instead of something hand-wavy.\n",
    "\n",
    "First we will define a utility method that indexes the categorical string columns, encodes them into one-hot vectors, and finally assembles all the feature vectors into one vector for later downstream analysis."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>workclass</th>\n",
       "      <th>class</th>\n",
       "      <th>count</th>\n",
       "      <th>percentage</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>?</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>1645</td>\n",
       "      <td>89.60</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>?</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>191</td>\n",
       "      <td>10.40</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Federal-gov</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>589</td>\n",
       "      <td>61.35</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Federal-gov</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>371</td>\n",
       "      <td>38.65</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Local-gov</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>1476</td>\n",
       "      <td>70.52</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>Local-gov</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>617</td>\n",
       "      <td>29.48</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>Never-worked</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>7</td>\n",
       "      <td>100.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>Private</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>17733</td>\n",
       "      <td>78.13</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>Private</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>4963</td>\n",
       "      <td>21.87</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>Self-emp-inc</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>494</td>\n",
       "      <td>44.27</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>Self-emp-inc</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>622</td>\n",
       "      <td>55.73</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>Self-emp-not-inc</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>1817</td>\n",
       "      <td>71.51</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>Self-emp-not-inc</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>724</td>\n",
       "      <td>28.49</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>State-gov</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>945</td>\n",
       "      <td>72.80</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>State-gov</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>353</td>\n",
       "      <td>27.20</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>Without-pay</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>14</td>\n",
       "      <td>100.00</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "           workclass  class  count percentage\n",
       "0                  ?  <=50K   1645      89.60\n",
       "1                  ?   >50K    191      10.40\n",
       "2        Federal-gov  <=50K    589      61.35\n",
       "3        Federal-gov   >50K    371      38.65\n",
       "4          Local-gov  <=50K   1476      70.52\n",
       "5          Local-gov   >50K    617      29.48\n",
       "6       Never-worked  <=50K      7     100.00\n",
       "7            Private  <=50K  17733      78.13\n",
       "8            Private   >50K   4963      21.87\n",
       "9       Self-emp-inc  <=50K    494      44.27\n",
       "10      Self-emp-inc   >50K    622      55.73\n",
       "11  Self-emp-not-inc  <=50K   1817      71.51\n",
       "12  Self-emp-not-inc   >50K    724      28.49\n",
       "13         State-gov  <=50K    945      72.80\n",
       "14         State-gov   >50K    353      27.20\n",
       "15       Without-pay  <=50K     14     100.00"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Class split (count and percentage) for each workclass value,\n",
    "# with the percentage rendered to two decimal places.\n",
    "df = (\n",
    "    crosstabPercentage(train, 'workclass', labelCol)\n",
    "    .withColumn('percentage', format_number('percentage', 2))\n",
    ")\n",
    "df.toPandas()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>hours-per-week</th>\n",
       "      <th>class</th>\n",
       "      <th>count</th>\n",
       "      <th>percentage</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>18</td>\n",
       "      <td>90.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>2</td>\n",
       "      <td>10.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>24</td>\n",
       "      <td>75.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>2</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>8</td>\n",
       "      <td>25.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>3</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>38</td>\n",
       "      <td>97.44</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>3</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>1</td>\n",
       "      <td>2.56</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>4</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>51</td>\n",
       "      <td>94.44</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>4</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>3</td>\n",
       "      <td>5.56</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>5</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>53</td>\n",
       "      <td>88.33</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>5</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>7</td>\n",
       "      <td>11.67</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>6</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>56</td>\n",
       "      <td>87.50</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>6</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>8</td>\n",
       "      <td>12.50</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>7</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>22</td>\n",
       "      <td>84.62</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>7</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>4</td>\n",
       "      <td>15.38</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>8</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>134</td>\n",
       "      <td>92.41</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>8</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>11</td>\n",
       "      <td>7.59</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>9</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>17</td>\n",
       "      <td>94.44</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>9</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>1</td>\n",
       "      <td>5.56</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>10</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>258</td>\n",
       "      <td>92.81</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>10</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>20</td>\n",
       "      <td>7.19</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>11</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>11</td>\n",
       "      <td>100.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>12</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>161</td>\n",
       "      <td>93.06</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22</th>\n",
       "      <td>12</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>12</td>\n",
       "      <td>6.94</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23</th>\n",
       "      <td>13</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>21</td>\n",
       "      <td>91.30</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <td>13</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>2</td>\n",
       "      <td>8.70</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25</th>\n",
       "      <td>14</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>32</td>\n",
       "      <td>94.12</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>26</th>\n",
       "      <td>14</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>2</td>\n",
       "      <td>5.88</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>27</th>\n",
       "      <td>15</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>389</td>\n",
       "      <td>96.29</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>28</th>\n",
       "      <td>15</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>15</td>\n",
       "      <td>3.71</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>29</th>\n",
       "      <td>16</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>192</td>\n",
       "      <td>93.66</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>143</th>\n",
       "      <td>78</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>6</td>\n",
       "      <td>75.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>144</th>\n",
       "      <td>78</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>2</td>\n",
       "      <td>25.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>145</th>\n",
       "      <td>80</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>76</td>\n",
       "      <td>57.14</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>146</th>\n",
       "      <td>80</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>57</td>\n",
       "      <td>42.86</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>147</th>\n",
       "      <td>81</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>3</td>\n",
       "      <td>100.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>148</th>\n",
       "      <td>82</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>1</td>\n",
       "      <td>100.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>149</th>\n",
       "      <td>84</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>28</td>\n",
       "      <td>62.22</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>150</th>\n",
       "      <td>84</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>17</td>\n",
       "      <td>37.78</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>151</th>\n",
       "      <td>85</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>9</td>\n",
       "      <td>69.23</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>152</th>\n",
       "      <td>85</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>4</td>\n",
       "      <td>30.77</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>153</th>\n",
       "      <td>86</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>2</td>\n",
       "      <td>100.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>154</th>\n",
       "      <td>87</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>1</td>\n",
       "      <td>100.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>155</th>\n",
       "      <td>88</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>2</td>\n",
       "      <td>100.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>156</th>\n",
       "      <td>89</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>1</td>\n",
       "      <td>50.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>157</th>\n",
       "      <td>89</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>1</td>\n",
       "      <td>50.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>158</th>\n",
       "      <td>90</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>19</td>\n",
       "      <td>65.52</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>159</th>\n",
       "      <td>90</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>10</td>\n",
       "      <td>34.48</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>160</th>\n",
       "      <td>91</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>3</td>\n",
       "      <td>100.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>161</th>\n",
       "      <td>92</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>1</td>\n",
       "      <td>100.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>162</th>\n",
       "      <td>94</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>1</td>\n",
       "      <td>100.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>163</th>\n",
       "      <td>95</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>1</td>\n",
       "      <td>50.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>164</th>\n",
       "      <td>95</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>1</td>\n",
       "      <td>50.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>165</th>\n",
       "      <td>96</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>4</td>\n",
       "      <td>80.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>166</th>\n",
       "      <td>96</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>1</td>\n",
       "      <td>20.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>167</th>\n",
       "      <td>97</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>1</td>\n",
       "      <td>50.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>168</th>\n",
       "      <td>97</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>1</td>\n",
       "      <td>50.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>169</th>\n",
       "      <td>98</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>8</td>\n",
       "      <td>72.73</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>170</th>\n",
       "      <td>98</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>3</td>\n",
       "      <td>27.27</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>171</th>\n",
       "      <td>99</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>60</td>\n",
       "      <td>70.59</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>172</th>\n",
       "      <td>99</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>25</td>\n",
       "      <td>29.41</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>173 rows × 4 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "     hours-per-week  class  count percentage\n",
       "0                 1  <=50K     18      90.00\n",
       "1                 1   >50K      2      10.00\n",
       "2                 2  <=50K     24      75.00\n",
       "3                 2   >50K      8      25.00\n",
       "4                 3  <=50K     38      97.44\n",
       "5                 3   >50K      1       2.56\n",
       "6                 4  <=50K     51      94.44\n",
       "7                 4   >50K      3       5.56\n",
       "8                 5  <=50K     53      88.33\n",
       "9                 5   >50K      7      11.67\n",
       "10                6  <=50K     56      87.50\n",
       "11                6   >50K      8      12.50\n",
       "12                7  <=50K     22      84.62\n",
       "13                7   >50K      4      15.38\n",
       "14                8  <=50K    134      92.41\n",
       "15                8   >50K     11       7.59\n",
       "16                9  <=50K     17      94.44\n",
       "17                9   >50K      1       5.56\n",
       "18               10  <=50K    258      92.81\n",
       "19               10   >50K     20       7.19\n",
       "20               11  <=50K     11     100.00\n",
       "21               12  <=50K    161      93.06\n",
       "22               12   >50K     12       6.94\n",
       "23               13  <=50K     21      91.30\n",
       "24               13   >50K      2       8.70\n",
       "25               14  <=50K     32      94.12\n",
       "26               14   >50K      2       5.88\n",
       "27               15  <=50K    389      96.29\n",
       "28               15   >50K     15       3.71\n",
       "29               16  <=50K    192      93.66\n",
       "..              ...    ...    ...        ...\n",
       "143              78  <=50K      6      75.00\n",
       "144              78   >50K      2      25.00\n",
       "145              80  <=50K     76      57.14\n",
       "146              80   >50K     57      42.86\n",
       "147              81  <=50K      3     100.00\n",
       "148              82  <=50K      1     100.00\n",
       "149              84  <=50K     28      62.22\n",
       "150              84   >50K     17      37.78\n",
       "151              85  <=50K      9      69.23\n",
       "152              85   >50K      4      30.77\n",
       "153              86  <=50K      2     100.00\n",
       "154              87  <=50K      1     100.00\n",
       "155              88  <=50K      2     100.00\n",
       "156              89  <=50K      1      50.00\n",
       "157              89   >50K      1      50.00\n",
       "158              90  <=50K     19      65.52\n",
       "159              90   >50K     10      34.48\n",
       "160              91  <=50K      3     100.00\n",
       "161              92  <=50K      1     100.00\n",
       "162              94  <=50K      1     100.00\n",
       "163              95  <=50K      1      50.00\n",
       "164              95   >50K      1      50.00\n",
       "165              96  <=50K      4      80.00\n",
       "166              96   >50K      1      20.00\n",
       "167              97  <=50K      1      50.00\n",
       "168              97   >50K      1      50.00\n",
       "169              98  <=50K      8      72.73\n",
       "170              98   >50K      3      27.27\n",
       "171              99  <=50K     60      70.59\n",
       "172              99   >50K     25      29.41\n",
       "\n",
       "[173 rows x 4 columns]"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = crosstabPercentage(train, 'hours-per-week', labelCol)\n",
    "df = df.withColumn('percentage', format_number(df['percentage'], 2))\n",
    "df.toPandas()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Outlier Detection"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will use Attribute-Value Frequency (AVF) outlier detection in categorical features. The beauty of this algorithm is that it's very simple, highly parallelizable, and fit well with distributed programming paradigm. `attributeValueFrequency` function is implemented in `pysparkutils.py` file in `utils` directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>avfScore</th>\n",
       "      <th>count</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>64873</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>84761</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>85760</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>103252</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>52051</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>49136</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>91948</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>92741</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>99489</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>89041</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>95526</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>78598</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>84745</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>76448</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>80545</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>97699</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>86445</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>95149</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>101519</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>99688</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>96113</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>86132</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22</th>\n",
       "      <td>74783</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23</th>\n",
       "      <td>88291</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <td>75232</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25</th>\n",
       "      <td>81293</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>26</th>\n",
       "      <td>66091</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>27</th>\n",
       "      <td>73190</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>28</th>\n",
       "      <td>107959</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>29</th>\n",
       "      <td>77034</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8159</th>\n",
       "      <td>121083</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8160</th>\n",
       "      <td>90753</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8161</th>\n",
       "      <td>83486</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8162</th>\n",
       "      <td>88325</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8163</th>\n",
       "      <td>96995</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8164</th>\n",
       "      <td>100990</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8165</th>\n",
       "      <td>90597</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8166</th>\n",
       "      <td>100460</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8167</th>\n",
       "      <td>92453</td>\n",
       "      <td>14</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8168</th>\n",
       "      <td>95690</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8169</th>\n",
       "      <td>104486</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8170</th>\n",
       "      <td>119995</td>\n",
       "      <td>13</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8171</th>\n",
       "      <td>131889</td>\n",
       "      <td>9</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8172</th>\n",
       "      <td>110660</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8173</th>\n",
       "      <td>102790</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8174</th>\n",
       "      <td>122394</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8175</th>\n",
       "      <td>124053</td>\n",
       "      <td>76</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8176</th>\n",
       "      <td>134386</td>\n",
       "      <td>45</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8177</th>\n",
       "      <td>113783</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8178</th>\n",
       "      <td>94066</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8179</th>\n",
       "      <td>73337</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8180</th>\n",
       "      <td>91838</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8181</th>\n",
       "      <td>90865</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8182</th>\n",
       "      <td>100947</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8183</th>\n",
       "      <td>71321</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8184</th>\n",
       "      <td>110738</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8185</th>\n",
       "      <td>99204</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8186</th>\n",
       "      <td>38322</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8187</th>\n",
       "      <td>89328</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8188</th>\n",
       "      <td>65479</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>8189 rows × 2 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "      avfScore  count\n",
       "0        64873      1\n",
       "1        84761      1\n",
       "2        85760      1\n",
       "3       103252      1\n",
       "4        52051      1\n",
       "5        49136      1\n",
       "6        91948      1\n",
       "7        92741      2\n",
       "8        99489      1\n",
       "9        89041      1\n",
       "10       95526      1\n",
       "11       78598      1\n",
       "12       84745      1\n",
       "13       76448      1\n",
       "14       80545      1\n",
       "15       97699      1\n",
       "16       86445      1\n",
       "17       95149      3\n",
       "18      101519      5\n",
       "19       99688      4\n",
       "20       96113      1\n",
       "21       86132      1\n",
       "22       74783      1\n",
       "23       88291      2\n",
       "24       75232      1\n",
       "25       81293      5\n",
       "26       66091      1\n",
       "27       73190      1\n",
       "28      107959      1\n",
       "29       77034      1\n",
       "...        ...    ...\n",
       "8159    121083      1\n",
       "8160     90753      1\n",
       "8161     83486      1\n",
       "8162     88325      3\n",
       "8163     96995      1\n",
       "8164    100990      4\n",
       "8165     90597      1\n",
       "8166    100460      1\n",
       "8167     92453     14\n",
       "8168     95690      4\n",
       "8169    104486      2\n",
       "8170    119995     13\n",
       "8171    131889      9\n",
       "8172    110660      1\n",
       "8173    102790      1\n",
       "8174    122394      2\n",
       "8175    124053     76\n",
       "8176    134386     45\n",
       "8177    113783      3\n",
       "8178     94066      1\n",
       "8179     73337      1\n",
       "8180     91838      2\n",
       "8181     90865      1\n",
       "8182    100947      3\n",
       "8183     71321      1\n",
       "8184    110738      4\n",
       "8185     99204      1\n",
       "8186     38322      1\n",
       "8187     89328      1\n",
       "8188     65479      1\n",
       "\n",
       "[8189 rows x 2 columns]"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import seaborn as sns\n",
    "\n",
    "avfScoreCol = 'avfScore'\n",
    "categoricalCols = ['workclass', 'education', 'marital-status',\n",
    "                  'occupation', 'relationship', 'race', 'sex',\n",
    "                  'native-country']\n",
    "avfScore = attributeValueFrequency(train, categoricalCols)\n",
    "pdf = avfScore.groupby(avfScoreCol).count().toPandas()\n",
    "pdf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[<matplotlib.axes._subplots.AxesSubplot object at 0x7ff049874978>]], dtype=object)"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# sns.barplot(x=\"avfScore\", y=\"count\", data=pdf)\n",
    "# pdf['count'].hist(by=pdf['avfScore'])\n",
    "# pdf.plot(x='avfScore', y='count')\n",
    "# pdf.plot(x='avfScore', y='count', kind='bar')\n",
    "\n",
    "# Ideally, we want to use Spark for aggregation, and just plot the data by converting to \n",
    "# Pandas. Unfortuantely I couldn't figure out a way yet, visualization is not my strongest skill.\n",
    "# This approach is NOT recommended for large datasets which are residing over multiple machines\n",
    "# Since this will bring whole dataframe to the driver node and the driver node might run out of\n",
    "# memory.\n",
    "avfScore.select(avfScoreCol).toPandas().hist(avfScoreCol)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In AVF, the lower the score of a datapoint, the more likely that datapoint is an outlier. We can safely remove the rows whose score is below 70000."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "root\n",
      " |-- age: integer (nullable = true)\n",
      " |-- workclass: string (nullable = true)\n",
      " |-- fnlwgt: integer (nullable = true)\n",
      " |-- education: string (nullable = true)\n",
      " |-- education-num: integer (nullable = true)\n",
      " |-- marital-status: string (nullable = true)\n",
      " |-- occupation: string (nullable = true)\n",
      " |-- relationship: string (nullable = true)\n",
      " |-- race: string (nullable = true)\n",
      " |-- sex: string (nullable = true)\n",
      " |-- capital-gain: integer (nullable = true)\n",
      " |-- capital-loss: integer (nullable = true)\n",
      " |-- hours-per-week: integer (nullable = true)\n",
      " |-- native-country: string (nullable = true)\n",
      " |-- class: string (nullable = true)\n",
      " |-- id: long (nullable = false)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "train = avfScore.filter(col(avfScoreCol) > 70000).drop(avfScoreCol)\n",
    "train.printSchema()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Feature Selection"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Chi Sqaure based categorical feature selection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "pValues: [1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0]\n"
     ]
    }
   ],
   "source": [
    "from pyspark.ml.feature import OneHotEncoderEstimator, StringIndexer, VectorAssembler\n",
    "from pyspark.ml.stat import ChiSquareTest\n",
    "from pyspark.ml import Pipeline\n",
    "\n",
    "indexed = train.select('education-num', 'education')\n",
    "\n",
    "indexer = StringIndexer(inputCol='education', outputCol='educationIndexed')\n",
    "indexed = indexer.fit(indexed).transform(indexed)\n",
    "ohe = OneHotEncoderEstimator(inputCols=['education-num',], outputCols=['education-numOHE',])\n",
    "indexed = ohe.fit(indexed).transform(indexed)\n",
    "\n",
    "# The null hypothesis is that the occurrence of the outcomes is statistically independent.\n",
    "# In general, small p-values (1% to 5%) would cause you to reject the null hypothesis. \n",
    "# This very large p-value (92.65%) means that the null hypothesis should not be rejected.\n",
    "testResult = ChiSquareTest.test(indexed, 'education-numOHE', 'educationIndexed')\n",
    "r = testResult.head()\n",
    "print(\"pValues: \" + str(r.pValues))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can accept the hypothesis that features are dependent. We will drop the 'education' feature since the info. is covered"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "train = train.drop('education')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Clustering"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<pyspark.ml.clustering.KMeansSummary at 0x7ff049711390>"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from pyspark.ml.clustering import KMeans\n",
    "\n",
    "_, _, indexedDf =  autoIndexer(train, labelCol)\n",
    "\n",
    "kmeans = KMeans(k=2, featuresCol='assembled')\n",
    "model = kmeans.fit(indexedDf)\n",
    "indexedDf = model.transform(indexedDf)\n",
    "model.summary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.15349812641524918"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "from sklearn.metrics import adjusted_mutual_info_score\n",
    "indexer = StringIndexer(inputCol=labelCol, outputCol=labelCol+'Indexed')\n",
    "indexedDf = indexer.fit(indexedDf).transform(indexedDf)\n",
    "classIndexed = [row[0] for row in indexedDf.select('classIndexed').collect()]\n",
    "prediction = [row[0] for row in indexedDf.select('prediction').collect()]\n",
    "\n",
    "adjusted_mutual_info_score(classIndexed, prediction)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Classification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.ml.feature import StringIndexer\n",
    "from pyspark.ml import Pipeline\n",
    "\n",
    "stringTypes = [dtype[0] for dtype in train.dtypes if dtype[1] == 'string']\n",
    "indexedTypes = [stringType+'Indexed' for stringType in stringTypes]\n",
    "\n",
    "indexers = [StringIndexer(inputCol=stringType, outputCol=stringType+'Indexed', handleInvalid='skip') \\\n",
    "            for stringType in stringTypes]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DataFrame[age: int, workclass: string, fnlwgt: int, education-num: int, marital-status: string, occupation: string, relationship: string, race: string, sex: string, capital-gain: int, capital-loss: int, hours-per-week: int, native-country: string, class: string, id: bigint, workclassIndexed: double, marital-statusIndexed: double, occupationIndexed: double, relationshipIndexed: double, raceIndexed: double, sexIndexed: double, native-countryIndexed: double, classIndexed: double, workclassIndexedOneHotEncoded: vector, raceIndexedOneHotEncoded: vector, occupationIndexedOneHotEncoded: vector, relationshipIndexedOneHotEncoded: vector, native-countryIndexedOneHotEncoded: vector, marital-statusIndexedOneHotEncoded: vector, sexIndexedOneHotEncoded: vector, classIndexedOneHotEncoded: vector, assembled: vector, rawPrediction: vector, probability: vector, prediction: double]"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from pyspark.ml.feature import OneHotEncoderEstimator, VectorAssembler\n",
    "from pyspark.ml.classification import GBTClassifier\n",
    "\n",
    "oheTypes = [indexedType+'OneHotEncoded' for indexedType in indexedTypes]\n",
    "ohe = OneHotEncoderEstimator(inputCols=indexedTypes, outputCols=oheTypes)\n",
    "\n",
    "# Fix columns\n",
    "oheTypes.remove('classIndexedOneHotEncoded')\n",
    "cols = train.columns[:]\n",
    "for oheType in oheTypes:\n",
    "    cols.append(oheType)\n",
    "for stringType in stringTypes:\n",
    "    cols.remove(stringType)\n",
    "\n",
    "cols.remove('id')\n",
    "\n",
    "assembler = VectorAssembler(inputCols=cols, outputCol='assembled')\n",
    "classifier = GBTClassifier(featuresCol='assembled', labelCol='classIndexed')\n",
    "pipeline = Pipeline(stages=[*indexers, ohe, assembler, classifier])\n",
    "model = pipeline.fit(train)\n",
    "train = model.transform(train)\n",
    "train"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since we have class imbalance problem, that's why we will use area under ROC curve as metric."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9179985970287455"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from pyspark.ml.evaluation import BinaryClassificationEvaluator\n",
    "evaluator = BinaryClassificationEvaluator(labelCol='classIndexed', metricName='areaUnderROC')\n",
    "metric = evaluator.evaluate(train)\n",
    "metric"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>class</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>&lt;=50K.</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>&lt;=50K.</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>&gt;50K.</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>&gt;50K.</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>&lt;=50K.</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>&lt;=50K.</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>&lt;=50K.</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>&gt;50K.</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>&lt;=50K.</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>&lt;=50K.</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    class\n",
       "0  <=50K.\n",
       "1  <=50K.\n",
       "2   >50K.\n",
       "3   >50K.\n",
       "4  <=50K.\n",
       "5  <=50K.\n",
       "6  <=50K.\n",
       "7   >50K.\n",
       "8  <=50K.\n",
       "9  <=50K."
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "headers = [\"age\", \"workclass\", \"fnlwgt\", \"education\", \"education-num\",\n",
    "           \"marital-status\", \"occupation\", \"relationship\", \"race\", \"sex\",\n",
    "           \"capital-gain\", \"capital-loss\", \"hours-per-week\", \"native-country\",\n",
    "           \"class\"]\n",
    "\n",
    "test = spark.read.csv('./adult.test.txt',\n",
    "                      inferSchema='true', \n",
    "                      ignoreLeadingWhiteSpace='true',\n",
    "                      ignoreTrailingWhiteSpace='true').toDF(*headers)\n",
    "test.select('class').limit(10).toPandas()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see the class labels in the test dataset are different than in train - '>50K' and '>50K.'. So we have to remove the extrac dot from the class label, before evaluating."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>class</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>&gt;50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>&gt;50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>&gt;50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   class\n",
       "0  <=50K\n",
       "1  <=50K\n",
       "2   >50K\n",
       "3   >50K\n",
       "4  <=50K\n",
       "5  <=50K\n",
       "6  <=50K\n",
       "7   >50K\n",
       "8  <=50K\n",
       "9  <=50K"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from pyspark.sql.types import StringType\n",
    "stripDot = udf(lambda s: s[:-1], StringType())\n",
    "\n",
    "test = test.withColumn('classTrailed', stripDot(test['class'])).drop('class').withColumnRenamed('classTrailed', 'class')\n",
    "test.select('class').limit(10).toPandas()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9096987015789428"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test = model.transform(test)\n",
    "metric = evaluator.evaluate(test)\n",
    "metric\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}