{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from pyspark.sql import SparkSession\n", "\n", "import sys\n", "sys.path.append('..')\n", "from utils.pysparkutils import *\n", "\n", "spark = SparkSession.builder.appName('income').getOrCreate()" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "root\n", " |-- age: integer (nullable = true)\n", " |-- workclass: string (nullable = true)\n", " |-- fnlwgt: integer (nullable = true)\n", " |-- education: string (nullable = true)\n", " |-- education-num: integer (nullable = true)\n", " |-- marital-status: string (nullable = true)\n", " |-- occupation: string (nullable = true)\n", " |-- relationship: string (nullable = true)\n", " |-- race: string (nullable = true)\n", " |-- sex: string (nullable = true)\n", " |-- capital-gain: integer (nullable = true)\n", " |-- capital-loss: integer (nullable = true)\n", " |-- hours-per-week: integer (nullable = true)\n", " |-- native-country: string (nullable = true)\n", " |-- class: string (nullable = true)\n", "\n" ] } ], "source": [ "from pyspark.sql.types import *\n", "\n", "# schema = StructType([\n", "# StructField(\"age\", IntegerType(), True), \n", "# StructField(\"workclass\", StringType(), True),\n", "# StructField(\"fnlwgt\", FloatType(), True),\n", "# StructField(\"education\", StringType(), True),\n", "# StructField(\"education-num\", FloatType(), True),\n", "# StructField(\"marital-status\", StringType(), True),\n", "# StructField(\"occupation\", StringType(), True),\n", "# StructField(\"relationship\", StringType(), True),\n", "# StructField(\"race\", StringType(), True),\n", "# StructField(\"sex\", StringType(), True),\n", "# StructField(\"capital-gain\", FloatType(), True),\n", "# StructField(\"capital-loss\", FloatType(), True),\n", "# StructField(\"hours-per-week\", FloatType(), True),\n", "# StructField(\"native-country\", 
StringType(), True),\n", "# StructField(\"class\", StringType(), True)]\n", "# )\n", "\n", "# train = spark.read.csv('./adult.data.txt', schema=schema, inferSchema='true')\n", "\n", "headers = [\"age\", \"workclass\", \"fnlwgt\", \"education\", \"education-num\",\n", " \"marital-status\", \"occupation\", \"relationship\", \"race\", \"sex\",\n", " \"capital-gain\", \"capital-loss\", \"hours-per-week\", \"native-country\",\n", " \"class\"]\n", "\n", "train = spark.read.csv('./adult.data.txt',\n", " inferSchema='true', \n", " ignoreLeadingWhiteSpace='true',\n", " ignoreTrailingWhiteSpace='true').toDF(*headers)\n", "train.printSchema()" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "from pyspark.sql.functions import udf, monotonically_increasing_id\n", "\n", "train = train.withColumn('id', monotonically_increasing_id())" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "32561" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "labelCol = 'class'\n", "train.count()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Exploratory Data Analysis" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>class</th>\n", " <th>count</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td><=50K</td>\n", " <td>24720</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>>50K</td>\n", " <td>7841</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], 
"text/plain": [ " class count\n", "0 <=50K 24720\n", "1 >50K 7841" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train.groupby(labelCol).count().toPandas()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can see there is a class imbalance problem in our training set." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "findMissingValuesCols(train)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "There is no missing values in our training data." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+---+\n", "|age|\n", "+---+\n", "| 31|\n", "| 85|\n", "| 65|\n", "| 53|\n", "| 78|\n", "| 34|\n", "| 81|\n", "| 28|\n", "| 76|\n", "| 27|\n", "| 26|\n", "| 44|\n", "| 22|\n", "| 47|\n", "| 52|\n", "| 86|\n", "| 40|\n", "| 20|\n", "| 57|\n", "| 54|\n", "+---+\n", "only showing top 20 rows\n", "\n" ] } ], "source": [ "train.select('age').distinct().show()" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>race</th>\n", " <th>class</th>\n", " <th>count</th>\n", " <th>percentage</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>Amer-Indian-Eskimo</td>\n", " <td><=50K</td>\n", " <td>275</td>\n", " <td>88.42</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>Amer-Indian-Eskimo</td>\n", " 
<td>>50K</td>\n", " <td>36</td>\n", " <td>11.58</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>Asian-Pac-Islander</td>\n", " <td><=50K</td>\n", " <td>763</td>\n", " <td>73.44</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>Asian-Pac-Islander</td>\n", " <td>>50K</td>\n", " <td>276</td>\n", " <td>26.56</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>Black</td>\n", " <td><=50K</td>\n", " <td>2737</td>\n", " <td>87.61</td>\n", " </tr>\n", " <tr>\n", " <th>5</th>\n", " <td>Black</td>\n", " <td>>50K</td>\n", " <td>387</td>\n", " <td>12.39</td>\n", " </tr>\n", " <tr>\n", " <th>6</th>\n", " <td>Other</td>\n", " <td><=50K</td>\n", " <td>246</td>\n", " <td>90.77</td>\n", " </tr>\n", " <tr>\n", " <th>7</th>\n", " <td>Other</td>\n", " <td>>50K</td>\n", " <td>25</td>\n", " <td>9.23</td>\n", " </tr>\n", " <tr>\n", " <th>8</th>\n", " <td>White</td>\n", " <td><=50K</td>\n", " <td>20699</td>\n", " <td>74.41</td>\n", " </tr>\n", " <tr>\n", " <th>9</th>\n", " <td>White</td>\n", " <td>>50K</td>\n", " <td>7117</td>\n", " <td>25.59</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " race class count percentage\n", "0 Amer-Indian-Eskimo <=50K 275 88.42\n", "1 Amer-Indian-Eskimo >50K 36 11.58\n", "2 Asian-Pac-Islander <=50K 763 73.44\n", "3 Asian-Pac-Islander >50K 276 26.56\n", "4 Black <=50K 2737 87.61\n", "5 Black >50K 387 12.39\n", "6 Other <=50K 246 90.77\n", "7 Other >50K 25 9.23\n", "8 White <=50K 20699 74.41\n", "9 White >50K 7117 25.59" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "percentageCol = 'percentage'\n", "df = crosstabPercentage(train, 'race', labelCol)\n", "df = df.withColumn(percentageCol, format_number(df[percentageCol], 2))\n", "df.toPandas()" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " 
}\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>age</th>\n", " <th>class</th>\n", " <th>count</th>\n", " <th>percentage</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>17</td>\n", " <td><=50K</td>\n", " <td>395</td>\n", " <td>100.00</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>18</td>\n", " <td><=50K</td>\n", " <td>550</td>\n", " <td>100.00</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>19</td>\n", " <td><=50K</td>\n", " <td>710</td>\n", " <td>99.72</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>19</td>\n", " <td>>50K</td>\n", " <td>2</td>\n", " <td>0.28</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>20</td>\n", " <td><=50K</td>\n", " <td>753</td>\n", " <td>100.00</td>\n", " </tr>\n", " <tr>\n", " <th>5</th>\n", " <td>21</td>\n", " <td><=50K</td>\n", " <td>717</td>\n", " <td>99.58</td>\n", " </tr>\n", " <tr>\n", " <th>6</th>\n", " <td>21</td>\n", " <td>>50K</td>\n", " <td>3</td>\n", " <td>0.42</td>\n", " </tr>\n", " <tr>\n", " <th>7</th>\n", " <td>22</td>\n", " <td><=50K</td>\n", " <td>752</td>\n", " <td>98.30</td>\n", " </tr>\n", " <tr>\n", " <th>8</th>\n", " <td>22</td>\n", " <td>>50K</td>\n", " <td>13</td>\n", " <td>1.70</td>\n", " </tr>\n", " <tr>\n", " <th>9</th>\n", " <td>23</td>\n", " <td><=50K</td>\n", " <td>865</td>\n", " <td>98.63</td>\n", " </tr>\n", " <tr>\n", " <th>10</th>\n", " <td>23</td>\n", " <td>>50K</td>\n", " <td>12</td>\n", " <td>1.37</td>\n", " </tr>\n", " <tr>\n", " <th>11</th>\n", " <td>24</td>\n", " <td><=50K</td>\n", " <td>767</td>\n", " <td>96.12</td>\n", " </tr>\n", " <tr>\n", " <th>12</th>\n", " <td>24</td>\n", " <td>>50K</td>\n", " <td>31</td>\n", " <td>3.88</td>\n", " </tr>\n", " <tr>\n", " <th>13</th>\n", " <td>25</td>\n", " 
<td><=50K</td>\n", " <td>788</td>\n", " <td>93.70</td>\n", " </tr>\n", " <tr>\n", " <th>14</th>\n", " <td>25</td>\n", " <td>>50K</td>\n", " <td>53</td>\n", " <td>6.30</td>\n", " </tr>\n", " <tr>\n", " <th>15</th>\n", " <td>26</td>\n", " <td><=50K</td>\n", " <td>722</td>\n", " <td>91.97</td>\n", " </tr>\n", " <tr>\n", " <th>16</th>\n", " <td>26</td>\n", " <td>>50K</td>\n", " <td>63</td>\n", " <td>8.03</td>\n", " </tr>\n", " <tr>\n", " <th>17</th>\n", " <td>27</td>\n", " <td><=50K</td>\n", " <td>754</td>\n", " <td>90.30</td>\n", " </tr>\n", " <tr>\n", " <th>18</th>\n", " <td>27</td>\n", " <td>>50K</td>\n", " <td>81</td>\n", " <td>9.70</td>\n", " </tr>\n", " <tr>\n", " <th>19</th>\n", " <td>28</td>\n", " <td><=50K</td>\n", " <td>748</td>\n", " <td>86.27</td>\n", " </tr>\n", " <tr>\n", " <th>20</th>\n", " <td>28</td>\n", " <td>>50K</td>\n", " <td>119</td>\n", " <td>13.73</td>\n", " </tr>\n", " <tr>\n", " <th>21</th>\n", " <td>29</td>\n", " <td><=50K</td>\n", " <td>679</td>\n", " <td>83.52</td>\n", " </tr>\n", " <tr>\n", " <th>22</th>\n", " <td>29</td>\n", " <td>>50K</td>\n", " <td>134</td>\n", " <td>16.48</td>\n", " </tr>\n", " <tr>\n", " <th>23</th>\n", " <td>30</td>\n", " <td><=50K</td>\n", " <td>690</td>\n", " <td>80.14</td>\n", " </tr>\n", " <tr>\n", " <th>24</th>\n", " <td>30</td>\n", " <td>>50K</td>\n", " <td>171</td>\n", " <td>19.86</td>\n", " </tr>\n", " <tr>\n", " <th>25</th>\n", " <td>31</td>\n", " <td><=50K</td>\n", " <td>705</td>\n", " <td>79.39</td>\n", " </tr>\n", " <tr>\n", " <th>26</th>\n", " <td>31</td>\n", " <td>>50K</td>\n", " <td>183</td>\n", " <td>20.61</td>\n", " </tr>\n", " <tr>\n", " <th>27</th>\n", " <td>32</td>\n", " <td><=50K</td>\n", " <td>639</td>\n", " <td>77.17</td>\n", " </tr>\n", " <tr>\n", " <th>28</th>\n", " <td>32</td>\n", " <td>>50K</td>\n", " <td>189</td>\n", " <td>22.83</td>\n", " </tr>\n", " <tr>\n", " <th>29</th>\n", " <td>33</td>\n", " <td><=50K</td>\n", " <td>684</td>\n", " <td>78.17</td>\n", " </tr>\n", " <tr>\n", " 
<th>...</th>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " </tr>\n", " <tr>\n", " <th>108</th>\n", " <td>72</td>\n", " <td>>50K</td>\n", " <td>9</td>\n", " <td>13.43</td>\n", " </tr>\n", " <tr>\n", " <th>109</th>\n", " <td>73</td>\n", " <td><=50K</td>\n", " <td>54</td>\n", " <td>84.38</td>\n", " </tr>\n", " <tr>\n", " <th>110</th>\n", " <td>73</td>\n", " <td>>50K</td>\n", " <td>10</td>\n", " <td>15.62</td>\n", " </tr>\n", " <tr>\n", " <th>111</th>\n", " <td>74</td>\n", " <td><=50K</td>\n", " <td>39</td>\n", " <td>76.47</td>\n", " </tr>\n", " <tr>\n", " <th>112</th>\n", " <td>74</td>\n", " <td>>50K</td>\n", " <td>12</td>\n", " <td>23.53</td>\n", " </tr>\n", " <tr>\n", " <th>113</th>\n", " <td>75</td>\n", " <td><=50K</td>\n", " <td>38</td>\n", " <td>84.44</td>\n", " </tr>\n", " <tr>\n", " <th>114</th>\n", " <td>75</td>\n", " <td>>50K</td>\n", " <td>7</td>\n", " <td>15.56</td>\n", " </tr>\n", " <tr>\n", " <th>115</th>\n", " <td>76</td>\n", " <td><=50K</td>\n", " <td>41</td>\n", " <td>89.13</td>\n", " </tr>\n", " <tr>\n", " <th>116</th>\n", " <td>76</td>\n", " <td>>50K</td>\n", " <td>5</td>\n", " <td>10.87</td>\n", " </tr>\n", " <tr>\n", " <th>117</th>\n", " <td>77</td>\n", " <td><=50K</td>\n", " <td>24</td>\n", " <td>82.76</td>\n", " </tr>\n", " <tr>\n", " <th>118</th>\n", " <td>77</td>\n", " <td>>50K</td>\n", " <td>5</td>\n", " <td>17.24</td>\n", " </tr>\n", " <tr>\n", " <th>119</th>\n", " <td>78</td>\n", " <td><=50K</td>\n", " <td>18</td>\n", " <td>78.26</td>\n", " </tr>\n", " <tr>\n", " <th>120</th>\n", " <td>78</td>\n", " <td>>50K</td>\n", " <td>5</td>\n", " <td>21.74</td>\n", " </tr>\n", " <tr>\n", " <th>121</th>\n", " <td>79</td>\n", " <td><=50K</td>\n", " <td>13</td>\n", " <td>59.09</td>\n", " </tr>\n", " <tr>\n", " <th>122</th>\n", " <td>79</td>\n", " <td>>50K</td>\n", " <td>9</td>\n", " <td>40.91</td>\n", " </tr>\n", " <tr>\n", " <th>123</th>\n", " <td>80</td>\n", " <td><=50K</td>\n", " <td>20</td>\n", " <td>90.91</td>\n", 
" </tr>\n", " <tr>\n", " <th>124</th>\n", " <td>80</td>\n", " <td>>50K</td>\n", " <td>2</td>\n", " <td>9.09</td>\n", " </tr>\n", " <tr>\n", " <th>125</th>\n", " <td>81</td>\n", " <td><=50K</td>\n", " <td>17</td>\n", " <td>85.00</td>\n", " </tr>\n", " <tr>\n", " <th>126</th>\n", " <td>81</td>\n", " <td>>50K</td>\n", " <td>3</td>\n", " <td>15.00</td>\n", " </tr>\n", " <tr>\n", " <th>127</th>\n", " <td>82</td>\n", " <td><=50K</td>\n", " <td>12</td>\n", " <td>100.00</td>\n", " </tr>\n", " <tr>\n", " <th>128</th>\n", " <td>83</td>\n", " <td><=50K</td>\n", " <td>4</td>\n", " <td>66.67</td>\n", " </tr>\n", " <tr>\n", " <th>129</th>\n", " <td>83</td>\n", " <td>>50K</td>\n", " <td>2</td>\n", " <td>33.33</td>\n", " </tr>\n", " <tr>\n", " <th>130</th>\n", " <td>84</td>\n", " <td><=50K</td>\n", " <td>9</td>\n", " <td>90.00</td>\n", " </tr>\n", " <tr>\n", " <th>131</th>\n", " <td>84</td>\n", " <td>>50K</td>\n", " <td>1</td>\n", " <td>10.00</td>\n", " </tr>\n", " <tr>\n", " <th>132</th>\n", " <td>85</td>\n", " <td><=50K</td>\n", " <td>3</td>\n", " <td>100.00</td>\n", " </tr>\n", " <tr>\n", " <th>133</th>\n", " <td>86</td>\n", " <td><=50K</td>\n", " <td>1</td>\n", " <td>100.00</td>\n", " </tr>\n", " <tr>\n", " <th>134</th>\n", " <td>87</td>\n", " <td><=50K</td>\n", " <td>1</td>\n", " <td>100.00</td>\n", " </tr>\n", " <tr>\n", " <th>135</th>\n", " <td>88</td>\n", " <td><=50K</td>\n", " <td>3</td>\n", " <td>100.00</td>\n", " </tr>\n", " <tr>\n", " <th>136</th>\n", " <td>90</td>\n", " <td><=50K</td>\n", " <td>35</td>\n", " <td>81.40</td>\n", " </tr>\n", " <tr>\n", " <th>137</th>\n", " <td>90</td>\n", " <td>>50K</td>\n", " <td>8</td>\n", " <td>18.60</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "<p>138 rows × 4 columns</p>\n", "</div>" ], "text/plain": [ " age class count percentage\n", "0 17 <=50K 395 100.00\n", "1 18 <=50K 550 100.00\n", "2 19 <=50K 710 99.72\n", "3 19 >50K 2 0.28\n", "4 20 <=50K 753 100.00\n", "5 21 <=50K 717 99.58\n", "6 21 >50K 3 0.42\n", "7 22 <=50K 752 
98.30\n", "8 22 >50K 13 1.70\n", "9 23 <=50K 865 98.63\n", "10 23 >50K 12 1.37\n", "11 24 <=50K 767 96.12\n", "12 24 >50K 31 3.88\n", "13 25 <=50K 788 93.70\n", "14 25 >50K 53 6.30\n", "15 26 <=50K 722 91.97\n", "16 26 >50K 63 8.03\n", "17 27 <=50K 754 90.30\n", "18 27 >50K 81 9.70\n", "19 28 <=50K 748 86.27\n", "20 28 >50K 119 13.73\n", "21 29 <=50K 679 83.52\n", "22 29 >50K 134 16.48\n", "23 30 <=50K 690 80.14\n", "24 30 >50K 171 19.86\n", "25 31 <=50K 705 79.39\n", "26 31 >50K 183 20.61\n", "27 32 <=50K 639 77.17\n", "28 32 >50K 189 22.83\n", "29 33 <=50K 684 78.17\n", ".. ... ... ... ...\n", "108 72 >50K 9 13.43\n", "109 73 <=50K 54 84.38\n", "110 73 >50K 10 15.62\n", "111 74 <=50K 39 76.47\n", "112 74 >50K 12 23.53\n", "113 75 <=50K 38 84.44\n", "114 75 >50K 7 15.56\n", "115 76 <=50K 41 89.13\n", "116 76 >50K 5 10.87\n", "117 77 <=50K 24 82.76\n", "118 77 >50K 5 17.24\n", "119 78 <=50K 18 78.26\n", "120 78 >50K 5 21.74\n", "121 79 <=50K 13 59.09\n", "122 79 >50K 9 40.91\n", "123 80 <=50K 20 90.91\n", "124 80 >50K 2 9.09\n", "125 81 <=50K 17 85.00\n", "126 81 >50K 3 15.00\n", "127 82 <=50K 12 100.00\n", "128 83 <=50K 4 66.67\n", "129 83 >50K 2 33.33\n", "130 84 <=50K 9 90.00\n", "131 84 >50K 1 10.00\n", "132 85 <=50K 3 100.00\n", "133 86 <=50K 1 100.00\n", "134 87 <=50K 1 100.00\n", "135 88 <=50K 3 100.00\n", "136 90 <=50K 35 81.40\n", "137 90 >50K 8 18.60\n", "\n", "[138 rows x 4 columns]" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = crosstabPercentage(train, 'age', labelCol)\n", "df = df.orderBy('age').withColumn('percentage', \n", " format_number(df['percentage'], 2))\n", "df.toPandas()" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th 
{\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>sex</th>\n", " <th>class</th>\n", " <th>count</th>\n", " <th>percentage</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>Female</td>\n", " <td><=50K</td>\n", " <td>9592</td>\n", " <td>89.05</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>Female</td>\n", " <td>>50K</td>\n", " <td>1179</td>\n", " <td>10.95</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>Male</td>\n", " <td><=50K</td>\n", " <td>15128</td>\n", " <td>69.43</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>Male</td>\n", " <td>>50K</td>\n", " <td>6662</td>\n", " <td>30.57</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " sex class count percentage\n", "0 Female <=50K 9592 89.05\n", "1 Female >50K 1179 10.95\n", "2 Male <=50K 15128 69.43\n", "3 Male >50K 6662 30.57" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = crosstabPercentage(train, 'sex', labelCol)\n", "df = df.withColumn('percentage', format_number(df['percentage'], 2))\n", "df.toPandas()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`crosstabPercentage` is a simple way to explore the relationship between a particular categorical feature and the label. For instance, the table above shows that the `sex` feature is useful for predicting income: 30.57% of men earn >50K, compared with only 10.95% of women. So a male individual in this dataset is much more likely to earn >50K." 
] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>education</th>\n", " <th>class</th>\n", " <th>count</th>\n", " <th>percentage</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>10th</td>\n", " <td><=50K</td>\n", " <td>871</td>\n", " <td>93.35</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>10th</td>\n", " <td>>50K</td>\n", " <td>62</td>\n", " <td>6.65</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>11th</td>\n", " <td><=50K</td>\n", " <td>1115</td>\n", " <td>94.89</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>11th</td>\n", " <td>>50K</td>\n", " <td>60</td>\n", " <td>5.11</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>12th</td>\n", " <td><=50K</td>\n", " <td>400</td>\n", " <td>92.38</td>\n", " </tr>\n", " <tr>\n", " <th>5</th>\n", " <td>12th</td>\n", " <td>>50K</td>\n", " <td>33</td>\n", " <td>7.62</td>\n", " </tr>\n", " <tr>\n", " <th>6</th>\n", " <td>1st-4th</td>\n", " <td><=50K</td>\n", " <td>162</td>\n", " <td>96.43</td>\n", " </tr>\n", " <tr>\n", " <th>7</th>\n", " <td>1st-4th</td>\n", " <td>>50K</td>\n", " <td>6</td>\n", " <td>3.57</td>\n", " </tr>\n", " <tr>\n", " <th>8</th>\n", " <td>5th-6th</td>\n", " <td><=50K</td>\n", " <td>317</td>\n", " <td>95.20</td>\n", " </tr>\n", " <tr>\n", " <th>9</th>\n", " <td>5th-6th</td>\n", " <td>>50K</td>\n", " <td>16</td>\n", " <td>4.80</td>\n", " </tr>\n", " <tr>\n", " <th>10</th>\n", " <td>7th-8th</td>\n", " <td><=50K</td>\n", " <td>606</td>\n", " <td>93.81</td>\n", " </tr>\n", " <tr>\n", " <th>11</th>\n", " 
<td>7th-8th</td>\n", " <td>>50K</td>\n", " <td>40</td>\n", " <td>6.19</td>\n", " </tr>\n", " <tr>\n", " <th>12</th>\n", " <td>9th</td>\n", " <td><=50K</td>\n", " <td>487</td>\n", " <td>94.75</td>\n", " </tr>\n", " <tr>\n", " <th>13</th>\n", " <td>9th</td>\n", " <td>>50K</td>\n", " <td>27</td>\n", " <td>5.25</td>\n", " </tr>\n", " <tr>\n", " <th>14</th>\n", " <td>Assoc-acdm</td>\n", " <td><=50K</td>\n", " <td>802</td>\n", " <td>75.16</td>\n", " </tr>\n", " <tr>\n", " <th>15</th>\n", " <td>Assoc-acdm</td>\n", " <td>>50K</td>\n", " <td>265</td>\n", " <td>24.84</td>\n", " </tr>\n", " <tr>\n", " <th>16</th>\n", " <td>Assoc-voc</td>\n", " <td><=50K</td>\n", " <td>1021</td>\n", " <td>73.88</td>\n", " </tr>\n", " <tr>\n", " <th>17</th>\n", " <td>Assoc-voc</td>\n", " <td>>50K</td>\n", " <td>361</td>\n", " <td>26.12</td>\n", " </tr>\n", " <tr>\n", " <th>18</th>\n", " <td>Bachelors</td>\n", " <td><=50K</td>\n", " <td>3134</td>\n", " <td>58.52</td>\n", " </tr>\n", " <tr>\n", " <th>19</th>\n", " <td>Bachelors</td>\n", " <td>>50K</td>\n", " <td>2221</td>\n", " <td>41.48</td>\n", " </tr>\n", " <tr>\n", " <th>20</th>\n", " <td>Doctorate</td>\n", " <td><=50K</td>\n", " <td>107</td>\n", " <td>25.91</td>\n", " </tr>\n", " <tr>\n", " <th>21</th>\n", " <td>Doctorate</td>\n", " <td>>50K</td>\n", " <td>306</td>\n", " <td>74.09</td>\n", " </tr>\n", " <tr>\n", " <th>22</th>\n", " <td>HS-grad</td>\n", " <td><=50K</td>\n", " <td>8826</td>\n", " <td>84.05</td>\n", " </tr>\n", " <tr>\n", " <th>23</th>\n", " <td>HS-grad</td>\n", " <td>>50K</td>\n", " <td>1675</td>\n", " <td>15.95</td>\n", " </tr>\n", " <tr>\n", " <th>24</th>\n", " <td>Masters</td>\n", " <td><=50K</td>\n", " <td>764</td>\n", " <td>44.34</td>\n", " </tr>\n", " <tr>\n", " <th>25</th>\n", " <td>Masters</td>\n", " <td>>50K</td>\n", " <td>959</td>\n", " <td>55.66</td>\n", " </tr>\n", " <tr>\n", " <th>26</th>\n", " <td>Preschool</td>\n", " <td><=50K</td>\n", " <td>51</td>\n", " <td>100.00</td>\n", " </tr>\n", " <tr>\n", " 
<th>27</th>\n", " <td>Prof-school</td>\n", " <td><=50K</td>\n", " <td>153</td>\n", " <td>26.56</td>\n", " </tr>\n", " <tr>\n", " <th>28</th>\n", " <td>Prof-school</td>\n", " <td>>50K</td>\n", " <td>423</td>\n", " <td>73.44</td>\n", " </tr>\n", " <tr>\n", " <th>29</th>\n", " <td>Some-college</td>\n", " <td><=50K</td>\n", " <td>5904</td>\n", " <td>80.98</td>\n", " </tr>\n", " <tr>\n", " <th>30</th>\n", " <td>Some-college</td>\n", " <td>>50K</td>\n", " <td>1387</td>\n", " <td>19.02</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " education class count percentage\n", "0 10th <=50K 871 93.35\n", "1 10th >50K 62 6.65\n", "2 11th <=50K 1115 94.89\n", "3 11th >50K 60 5.11\n", "4 12th <=50K 400 92.38\n", "5 12th >50K 33 7.62\n", "6 1st-4th <=50K 162 96.43\n", "7 1st-4th >50K 6 3.57\n", "8 5th-6th <=50K 317 95.20\n", "9 5th-6th >50K 16 4.80\n", "10 7th-8th <=50K 606 93.81\n", "11 7th-8th >50K 40 6.19\n", "12 9th <=50K 487 94.75\n", "13 9th >50K 27 5.25\n", "14 Assoc-acdm <=50K 802 75.16\n", "15 Assoc-acdm >50K 265 24.84\n", "16 Assoc-voc <=50K 1021 73.88\n", "17 Assoc-voc >50K 361 26.12\n", "18 Bachelors <=50K 3134 58.52\n", "19 Bachelors >50K 2221 41.48\n", "20 Doctorate <=50K 107 25.91\n", "21 Doctorate >50K 306 74.09\n", "22 HS-grad <=50K 8826 84.05\n", "23 HS-grad >50K 1675 15.95\n", "24 Masters <=50K 764 44.34\n", "25 Masters >50K 959 55.66\n", "26 Preschool <=50K 51 100.00\n", "27 Prof-school <=50K 153 26.56\n", "28 Prof-school >50K 423 73.44\n", "29 Some-college <=50K 5904 80.98\n", "30 Some-college >50K 1387 19.02" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = crosstabPercentage(train, 'education', labelCol)\n", "df = df.withColumn('percentage', format_number(df['percentage'], 2))\n", "df.toPandas()" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "%%script false\n", "educationNumClass = crosstabPercentage(train, 'education-num', 
labelCol)\n", "educationNumClass = educationNumClass.withColumn('percentage', \n", " format_number(educationNumClass['percentage-'], 2))\n", "educationNumClass = educationNumClass.withColumn('education-numClassF', educationNumClass['education-numClass'].cast(DoubleType()))\\\n", " .orderBy('education-numClassF').drop('education-numClass')\n", "cols = educationNumClass.columns\n", "cols.remove('education-numClassF')\n", "cols.insert(0, 'education-numClassF')\n", "educationNumClass = educationNumClass.select(cols)\n", "educationNumClass.show()\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A cross-tabulation of `education-num` against `education` is a sparse matrix, so its non-zero entries are hard to spot. We therefore keep only the non-zero entries to check whether the two features are related and whether one of them is redundant. (The cell above and the cell below are disabled with `%%script false`, so their output is not shown.)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "%%script false\n", "\n", "from pyspark.sql.functions import coalesce, lit, when\n", "\n", "iterator = df.toLocalIterator()\n", "d = {}\n", "for row in iterator:\n", " rowDict = row.asDict()\n", " educationNum = rowDict['education-num_education']\n", " for k, v in rowDict.items():\n", " if k != 'education-num_education' and v != 0:\n", " d[educationNum+'_'+k] = v\n", "\n", "import json\n", "s = json.dumps(d, indent=4)\n", "print(s)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The non-zero entries suggest that these two features are redundant, so only one of them should be needed for our classification task.\n", "\n", "Next, let's replace this hand-wavy inspection with a more rigorous chi-square test.\n", "\n", "First, we will define a utility method that indexes the categorical string columns, encodes them as one-hot vectors, and finally assembles all the feature vectors into a single vector for downstream analysis." 
] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>workclass</th>\n", " <th>class</th>\n", " <th>count</th>\n", " <th>percentage</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>?</td>\n", " <td><=50K</td>\n", " <td>1645</td>\n", " <td>89.60</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>?</td>\n", " <td>>50K</td>\n", " <td>191</td>\n", " <td>10.40</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>Federal-gov</td>\n", " <td><=50K</td>\n", " <td>589</td>\n", " <td>61.35</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>Federal-gov</td>\n", " <td>>50K</td>\n", " <td>371</td>\n", " <td>38.65</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>Local-gov</td>\n", " <td><=50K</td>\n", " <td>1476</td>\n", " <td>70.52</td>\n", " </tr>\n", " <tr>\n", " <th>5</th>\n", " <td>Local-gov</td>\n", " <td>>50K</td>\n", " <td>617</td>\n", " <td>29.48</td>\n", " </tr>\n", " <tr>\n", " <th>6</th>\n", " <td>Never-worked</td>\n", " <td><=50K</td>\n", " <td>7</td>\n", " <td>100.00</td>\n", " </tr>\n", " <tr>\n", " <th>7</th>\n", " <td>Private</td>\n", " <td><=50K</td>\n", " <td>17733</td>\n", " <td>78.13</td>\n", " </tr>\n", " <tr>\n", " <th>8</th>\n", " <td>Private</td>\n", " <td>>50K</td>\n", " <td>4963</td>\n", " <td>21.87</td>\n", " </tr>\n", " <tr>\n", " <th>9</th>\n", " <td>Self-emp-inc</td>\n", " <td><=50K</td>\n", " <td>494</td>\n", " <td>44.27</td>\n", " </tr>\n", " <tr>\n", " <th>10</th>\n", " <td>Self-emp-inc</td>\n", " <td>>50K</td>\n", " <td>622</td>\n", " <td>55.73</td>\n", " 
</tr>\n", " <tr>\n", " <th>11</th>\n", " <td>Self-emp-not-inc</td>\n", " <td><=50K</td>\n", " <td>1817</td>\n", " <td>71.51</td>\n", " </tr>\n", " <tr>\n", " <th>12</th>\n", " <td>Self-emp-not-inc</td>\n", " <td>>50K</td>\n", " <td>724</td>\n", " <td>28.49</td>\n", " </tr>\n", " <tr>\n", " <th>13</th>\n", " <td>State-gov</td>\n", " <td><=50K</td>\n", " <td>945</td>\n", " <td>72.80</td>\n", " </tr>\n", " <tr>\n", " <th>14</th>\n", " <td>State-gov</td>\n", " <td>>50K</td>\n", " <td>353</td>\n", " <td>27.20</td>\n", " </tr>\n", " <tr>\n", " <th>15</th>\n", " <td>Without-pay</td>\n", " <td><=50K</td>\n", " <td>14</td>\n", " <td>100.00</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " workclass class count percentage\n", "0 ? <=50K 1645 89.60\n", "1 ? >50K 191 10.40\n", "2 Federal-gov <=50K 589 61.35\n", "3 Federal-gov >50K 371 38.65\n", "4 Local-gov <=50K 1476 70.52\n", "5 Local-gov >50K 617 29.48\n", "6 Never-worked <=50K 7 100.00\n", "7 Private <=50K 17733 78.13\n", "8 Private >50K 4963 21.87\n", "9 Self-emp-inc <=50K 494 44.27\n", "10 Self-emp-inc >50K 622 55.73\n", "11 Self-emp-not-inc <=50K 1817 71.51\n", "12 Self-emp-not-inc >50K 724 28.49\n", "13 State-gov <=50K 945 72.80\n", "14 State-gov >50K 353 27.20\n", "15 Without-pay <=50K 14 100.00" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = crosstabPercentage(train, 'workclass', labelCol)\n", "df = df.withColumn('percentage', format_number(df['percentage'], 2))\n", "df.toPandas()" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: 
right;\">\n", " <th></th>\n", " <th>hours-per-week</th>\n", " <th>class</th>\n", " <th>count</th>\n", " <th>percentage</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>1</td>\n", " <td><=50K</td>\n", " <td>18</td>\n", " <td>90.00</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>1</td>\n", " <td>>50K</td>\n", " <td>2</td>\n", " <td>10.00</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>2</td>\n", " <td><=50K</td>\n", " <td>24</td>\n", " <td>75.00</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>2</td>\n", " <td>>50K</td>\n", " <td>8</td>\n", " <td>25.00</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>3</td>\n", " <td><=50K</td>\n", " <td>38</td>\n", " <td>97.44</td>\n", " </tr>\n", " <tr>\n", " <th>5</th>\n", " <td>3</td>\n", " <td>>50K</td>\n", " <td>1</td>\n", " <td>2.56</td>\n", " </tr>\n", " <tr>\n", " <th>6</th>\n", " <td>4</td>\n", " <td><=50K</td>\n", " <td>51</td>\n", " <td>94.44</td>\n", " </tr>\n", " <tr>\n", " <th>7</th>\n", " <td>4</td>\n", " <td>>50K</td>\n", " <td>3</td>\n", " <td>5.56</td>\n", " </tr>\n", " <tr>\n", " <th>8</th>\n", " <td>5</td>\n", " <td><=50K</td>\n", " <td>53</td>\n", " <td>88.33</td>\n", " </tr>\n", " <tr>\n", " <th>9</th>\n", " <td>5</td>\n", " <td>>50K</td>\n", " <td>7</td>\n", " <td>11.67</td>\n", " </tr>\n", " <tr>\n", " <th>10</th>\n", " <td>6</td>\n", " <td><=50K</td>\n", " <td>56</td>\n", " <td>87.50</td>\n", " </tr>\n", " <tr>\n", " <th>11</th>\n", " <td>6</td>\n", " <td>>50K</td>\n", " <td>8</td>\n", " <td>12.50</td>\n", " </tr>\n", " <tr>\n", " <th>12</th>\n", " <td>7</td>\n", " <td><=50K</td>\n", " <td>22</td>\n", " <td>84.62</td>\n", " </tr>\n", " <tr>\n", " <th>13</th>\n", " <td>7</td>\n", " <td>>50K</td>\n", " <td>4</td>\n", " <td>15.38</td>\n", " </tr>\n", " <tr>\n", " <th>14</th>\n", " <td>8</td>\n", " <td><=50K</td>\n", " <td>134</td>\n", " <td>92.41</td>\n", " </tr>\n", " <tr>\n", " <th>15</th>\n", " <td>8</td>\n", " <td>>50K</td>\n", " <td>11</td>\n", " 
<td>7.59</td>\n", " </tr>\n", " <tr>\n", " <th>16</th>\n", " <td>9</td>\n", " <td><=50K</td>\n", " <td>17</td>\n", " <td>94.44</td>\n", " </tr>\n", " <tr>\n", " <th>17</th>\n", " <td>9</td>\n", " <td>>50K</td>\n", " <td>1</td>\n", " <td>5.56</td>\n", " </tr>\n", " <tr>\n", " <th>18</th>\n", " <td>10</td>\n", " <td><=50K</td>\n", " <td>258</td>\n", " <td>92.81</td>\n", " </tr>\n", " <tr>\n", " <th>19</th>\n", " <td>10</td>\n", " <td>>50K</td>\n", " <td>20</td>\n", " <td>7.19</td>\n", " </tr>\n", " <tr>\n", " <th>20</th>\n", " <td>11</td>\n", " <td><=50K</td>\n", " <td>11</td>\n", " <td>100.00</td>\n", " </tr>\n", " <tr>\n", " <th>21</th>\n", " <td>12</td>\n", " <td><=50K</td>\n", " <td>161</td>\n", " <td>93.06</td>\n", " </tr>\n", " <tr>\n", " <th>22</th>\n", " <td>12</td>\n", " <td>>50K</td>\n", " <td>12</td>\n", " <td>6.94</td>\n", " </tr>\n", " <tr>\n", " <th>23</th>\n", " <td>13</td>\n", " <td><=50K</td>\n", " <td>21</td>\n", " <td>91.30</td>\n", " </tr>\n", " <tr>\n", " <th>24</th>\n", " <td>13</td>\n", " <td>>50K</td>\n", " <td>2</td>\n", " <td>8.70</td>\n", " </tr>\n", " <tr>\n", " <th>25</th>\n", " <td>14</td>\n", " <td><=50K</td>\n", " <td>32</td>\n", " <td>94.12</td>\n", " </tr>\n", " <tr>\n", " <th>26</th>\n", " <td>14</td>\n", " <td>>50K</td>\n", " <td>2</td>\n", " <td>5.88</td>\n", " </tr>\n", " <tr>\n", " <th>27</th>\n", " <td>15</td>\n", " <td><=50K</td>\n", " <td>389</td>\n", " <td>96.29</td>\n", " </tr>\n", " <tr>\n", " <th>28</th>\n", " <td>15</td>\n", " <td>>50K</td>\n", " <td>15</td>\n", " <td>3.71</td>\n", " </tr>\n", " <tr>\n", " <th>29</th>\n", " <td>16</td>\n", " <td><=50K</td>\n", " <td>192</td>\n", " <td>93.66</td>\n", " </tr>\n", " <tr>\n", " <th>...</th>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " </tr>\n", " <tr>\n", " <th>143</th>\n", " <td>78</td>\n", " <td><=50K</td>\n", " <td>6</td>\n", " <td>75.00</td>\n", " </tr>\n", " <tr>\n", " <th>144</th>\n", " <td>78</td>\n", " <td>>50K</td>\n", " 
<td>2</td>\n", " <td>25.00</td>\n", " </tr>\n", " <tr>\n", " <th>145</th>\n", " <td>80</td>\n", " <td><=50K</td>\n", " <td>76</td>\n", " <td>57.14</td>\n", " </tr>\n", " <tr>\n", " <th>146</th>\n", " <td>80</td>\n", " <td>>50K</td>\n", " <td>57</td>\n", " <td>42.86</td>\n", " </tr>\n", " <tr>\n", " <th>147</th>\n", " <td>81</td>\n", " <td><=50K</td>\n", " <td>3</td>\n", " <td>100.00</td>\n", " </tr>\n", " <tr>\n", " <th>148</th>\n", " <td>82</td>\n", " <td><=50K</td>\n", " <td>1</td>\n", " <td>100.00</td>\n", " </tr>\n", " <tr>\n", " <th>149</th>\n", " <td>84</td>\n", " <td><=50K</td>\n", " <td>28</td>\n", " <td>62.22</td>\n", " </tr>\n", " <tr>\n", " <th>150</th>\n", " <td>84</td>\n", " <td>>50K</td>\n", " <td>17</td>\n", " <td>37.78</td>\n", " </tr>\n", " <tr>\n", " <th>151</th>\n", " <td>85</td>\n", " <td><=50K</td>\n", " <td>9</td>\n", " <td>69.23</td>\n", " </tr>\n", " <tr>\n", " <th>152</th>\n", " <td>85</td>\n", " <td>>50K</td>\n", " <td>4</td>\n", " <td>30.77</td>\n", " </tr>\n", " <tr>\n", " <th>153</th>\n", " <td>86</td>\n", " <td><=50K</td>\n", " <td>2</td>\n", " <td>100.00</td>\n", " </tr>\n", " <tr>\n", " <th>154</th>\n", " <td>87</td>\n", " <td><=50K</td>\n", " <td>1</td>\n", " <td>100.00</td>\n", " </tr>\n", " <tr>\n", " <th>155</th>\n", " <td>88</td>\n", " <td><=50K</td>\n", " <td>2</td>\n", " <td>100.00</td>\n", " </tr>\n", " <tr>\n", " <th>156</th>\n", " <td>89</td>\n", " <td><=50K</td>\n", " <td>1</td>\n", " <td>50.00</td>\n", " </tr>\n", " <tr>\n", " <th>157</th>\n", " <td>89</td>\n", " <td>>50K</td>\n", " <td>1</td>\n", " <td>50.00</td>\n", " </tr>\n", " <tr>\n", " <th>158</th>\n", " <td>90</td>\n", " <td><=50K</td>\n", " <td>19</td>\n", " <td>65.52</td>\n", " </tr>\n", " <tr>\n", " <th>159</th>\n", " <td>90</td>\n", " <td>>50K</td>\n", " <td>10</td>\n", " <td>34.48</td>\n", " </tr>\n", " <tr>\n", " <th>160</th>\n", " <td>91</td>\n", " <td><=50K</td>\n", " <td>3</td>\n", " <td>100.00</td>\n", " </tr>\n", " <tr>\n", " <th>161</th>\n", " 
<td>92</td>\n", " <td><=50K</td>\n", " <td>1</td>\n", " <td>100.00</td>\n", " </tr>\n", " <tr>\n", " <th>162</th>\n", " <td>94</td>\n", " <td><=50K</td>\n", " <td>1</td>\n", " <td>100.00</td>\n", " </tr>\n", " <tr>\n", " <th>163</th>\n", " <td>95</td>\n", " <td><=50K</td>\n", " <td>1</td>\n", " <td>50.00</td>\n", " </tr>\n", " <tr>\n", " <th>164</th>\n", " <td>95</td>\n", " <td>>50K</td>\n", " <td>1</td>\n", " <td>50.00</td>\n", " </tr>\n", " <tr>\n", " <th>165</th>\n", " <td>96</td>\n", " <td><=50K</td>\n", " <td>4</td>\n", " <td>80.00</td>\n", " </tr>\n", " <tr>\n", " <th>166</th>\n", " <td>96</td>\n", " <td>>50K</td>\n", " <td>1</td>\n", " <td>20.00</td>\n", " </tr>\n", " <tr>\n", " <th>167</th>\n", " <td>97</td>\n", " <td><=50K</td>\n", " <td>1</td>\n", " <td>50.00</td>\n", " </tr>\n", " <tr>\n", " <th>168</th>\n", " <td>97</td>\n", " <td>>50K</td>\n", " <td>1</td>\n", " <td>50.00</td>\n", " </tr>\n", " <tr>\n", " <th>169</th>\n", " <td>98</td>\n", " <td><=50K</td>\n", " <td>8</td>\n", " <td>72.73</td>\n", " </tr>\n", " <tr>\n", " <th>170</th>\n", " <td>98</td>\n", " <td>>50K</td>\n", " <td>3</td>\n", " <td>27.27</td>\n", " </tr>\n", " <tr>\n", " <th>171</th>\n", " <td>99</td>\n", " <td><=50K</td>\n", " <td>60</td>\n", " <td>70.59</td>\n", " </tr>\n", " <tr>\n", " <th>172</th>\n", " <td>99</td>\n", " <td>>50K</td>\n", " <td>25</td>\n", " <td>29.41</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "<p>173 rows × 4 columns</p>\n", "</div>" ], "text/plain": [ " hours-per-week class count percentage\n", "0 1 <=50K 18 90.00\n", "1 1 >50K 2 10.00\n", "2 2 <=50K 24 75.00\n", "3 2 >50K 8 25.00\n", "4 3 <=50K 38 97.44\n", "5 3 >50K 1 2.56\n", "6 4 <=50K 51 94.44\n", "7 4 >50K 3 5.56\n", "8 5 <=50K 53 88.33\n", "9 5 >50K 7 11.67\n", "10 6 <=50K 56 87.50\n", "11 6 >50K 8 12.50\n", "12 7 <=50K 22 84.62\n", "13 7 >50K 4 15.38\n", "14 8 <=50K 134 92.41\n", "15 8 >50K 11 7.59\n", "16 9 <=50K 17 94.44\n", "17 9 >50K 1 5.56\n", "18 10 <=50K 258 92.81\n", "19 10 >50K 20 7.19\n", 
"20 11 <=50K 11 100.00\n", "21 12 <=50K 161 93.06\n", "22 12 >50K 12 6.94\n", "23 13 <=50K 21 91.30\n", "24 13 >50K 2 8.70\n", "25 14 <=50K 32 94.12\n", "26 14 >50K 2 5.88\n", "27 15 <=50K 389 96.29\n", "28 15 >50K 15 3.71\n", "29 16 <=50K 192 93.66\n", ".. ... ... ... ...\n", "143 78 <=50K 6 75.00\n", "144 78 >50K 2 25.00\n", "145 80 <=50K 76 57.14\n", "146 80 >50K 57 42.86\n", "147 81 <=50K 3 100.00\n", "148 82 <=50K 1 100.00\n", "149 84 <=50K 28 62.22\n", "150 84 >50K 17 37.78\n", "151 85 <=50K 9 69.23\n", "152 85 >50K 4 30.77\n", "153 86 <=50K 2 100.00\n", "154 87 <=50K 1 100.00\n", "155 88 <=50K 2 100.00\n", "156 89 <=50K 1 50.00\n", "157 89 >50K 1 50.00\n", "158 90 <=50K 19 65.52\n", "159 90 >50K 10 34.48\n", "160 91 <=50K 3 100.00\n", "161 92 <=50K 1 100.00\n", "162 94 <=50K 1 100.00\n", "163 95 <=50K 1 50.00\n", "164 95 >50K 1 50.00\n", "165 96 <=50K 4 80.00\n", "166 96 >50K 1 20.00\n", "167 97 <=50K 1 50.00\n", "168 97 >50K 1 50.00\n", "169 98 <=50K 8 72.73\n", "170 98 >50K 3 27.27\n", "171 99 <=50K 60 70.59\n", "172 99 >50K 25 29.41\n", "\n", "[173 rows x 4 columns]" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = crosstabPercentage(train, 'hours-per-week', labelCol)\n", "df = df.withColumn('percentage', format_number(df['percentage'], 2))\n", "df.toPandas()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Outlier Detection" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will use Attribute-Value Frequency (AVF) outlier detection in categorical features. The beauty of this algorithm is that it's very simple, highly parallelizable, and fit well with distributed programming paradigm. `attributeValueFrequency` function is implemented in `pysparkutils.py` file in `utils` directory." 
] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>avfScore</th>\n", " <th>count</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>64873</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>84761</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>85760</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>103252</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>52051</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>5</th>\n", " <td>49136</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>6</th>\n", " <td>91948</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>7</th>\n", " <td>92741</td>\n", " <td>2</td>\n", " </tr>\n", " <tr>\n", " <th>8</th>\n", " <td>99489</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>9</th>\n", " <td>89041</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>10</th>\n", " <td>95526</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>11</th>\n", " <td>78598</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>12</th>\n", " <td>84745</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>13</th>\n", " <td>76448</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>14</th>\n", " <td>80545</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>15</th>\n", " <td>97699</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>16</th>\n", " <td>86445</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>17</th>\n", " <td>95149</td>\n", " <td>3</td>\n", " </tr>\n", " <tr>\n", 
" <th>18</th>\n", " <td>101519</td>\n", " <td>5</td>\n", " </tr>\n", " <tr>\n", " <th>19</th>\n", " <td>99688</td>\n", " <td>4</td>\n", " </tr>\n", " <tr>\n", " <th>20</th>\n", " <td>96113</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>21</th>\n", " <td>86132</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>22</th>\n", " <td>74783</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>23</th>\n", " <td>88291</td>\n", " <td>2</td>\n", " </tr>\n", " <tr>\n", " <th>24</th>\n", " <td>75232</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>25</th>\n", " <td>81293</td>\n", " <td>5</td>\n", " </tr>\n", " <tr>\n", " <th>26</th>\n", " <td>66091</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>27</th>\n", " <td>73190</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>28</th>\n", " <td>107959</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>29</th>\n", " <td>77034</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>...</th>\n", " <td>...</td>\n", " <td>...</td>\n", " </tr>\n", " <tr>\n", " <th>8159</th>\n", " <td>121083</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>8160</th>\n", " <td>90753</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>8161</th>\n", " <td>83486</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>8162</th>\n", " <td>88325</td>\n", " <td>3</td>\n", " </tr>\n", " <tr>\n", " <th>8163</th>\n", " <td>96995</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>8164</th>\n", " <td>100990</td>\n", " <td>4</td>\n", " </tr>\n", " <tr>\n", " <th>8165</th>\n", " <td>90597</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>8166</th>\n", " <td>100460</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>8167</th>\n", " <td>92453</td>\n", " <td>14</td>\n", " </tr>\n", " <tr>\n", " <th>8168</th>\n", " <td>95690</td>\n", " <td>4</td>\n", " </tr>\n", " <tr>\n", " <th>8169</th>\n", " <td>104486</td>\n", " <td>2</td>\n", " </tr>\n", " <tr>\n", " <th>8170</th>\n", " <td>119995</td>\n", " <td>13</td>\n", " </tr>\n", " 
<tr>\n", " <th>8171</th>\n", " <td>131889</td>\n", " <td>9</td>\n", " </tr>\n", " <tr>\n", " <th>8172</th>\n", " <td>110660</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>8173</th>\n", " <td>102790</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>8174</th>\n", " <td>122394</td>\n", " <td>2</td>\n", " </tr>\n", " <tr>\n", " <th>8175</th>\n", " <td>124053</td>\n", " <td>76</td>\n", " </tr>\n", " <tr>\n", " <th>8176</th>\n", " <td>134386</td>\n", " <td>45</td>\n", " </tr>\n", " <tr>\n", " <th>8177</th>\n", " <td>113783</td>\n", " <td>3</td>\n", " </tr>\n", " <tr>\n", " <th>8178</th>\n", " <td>94066</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>8179</th>\n", " <td>73337</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>8180</th>\n", " <td>91838</td>\n", " <td>2</td>\n", " </tr>\n", " <tr>\n", " <th>8181</th>\n", " <td>90865</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>8182</th>\n", " <td>100947</td>\n", " <td>3</td>\n", " </tr>\n", " <tr>\n", " <th>8183</th>\n", " <td>71321</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>8184</th>\n", " <td>110738</td>\n", " <td>4</td>\n", " </tr>\n", " <tr>\n", " <th>8185</th>\n", " <td>99204</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>8186</th>\n", " <td>38322</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>8187</th>\n", " <td>89328</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>8188</th>\n", " <td>65479</td>\n", " <td>1</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "<p>8189 rows × 2 columns</p>\n", "</div>" ], "text/plain": [ " avfScore count\n", "0 64873 1\n", "1 84761 1\n", "2 85760 1\n", "3 103252 1\n", "4 52051 1\n", "5 49136 1\n", "6 91948 1\n", "7 92741 2\n", "8 99489 1\n", "9 89041 1\n", "10 95526 1\n", "11 78598 1\n", "12 84745 1\n", "13 76448 1\n", "14 80545 1\n", "15 97699 1\n", "16 86445 1\n", "17 95149 3\n", "18 101519 5\n", "19 99688 4\n", "20 96113 1\n", "21 86132 1\n", "22 74783 1\n", "23 88291 2\n", "24 75232 1\n", "25 81293 5\n", "26 66091 
1\n", "27 73190 1\n", "28 107959 1\n", "29 77034 1\n", "... ... ...\n", "8159 121083 1\n", "8160 90753 1\n", "8161 83486 1\n", "8162 88325 3\n", "8163 96995 1\n", "8164 100990 4\n", "8165 90597 1\n", "8166 100460 1\n", "8167 92453 14\n", "8168 95690 4\n", "8169 104486 2\n", "8170 119995 13\n", "8171 131889 9\n", "8172 110660 1\n", "8173 102790 1\n", "8174 122394 2\n", "8175 124053 76\n", "8176 134386 45\n", "8177 113783 3\n", "8178 94066 1\n", "8179 73337 1\n", "8180 91838 2\n", "8181 90865 1\n", "8182 100947 3\n", "8183 71321 1\n", "8184 110738 4\n", "8185 99204 1\n", "8186 38322 1\n", "8187 89328 1\n", "8188 65479 1\n", "\n", "[8189 rows x 2 columns]" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import seaborn as sns\n", "\n", "avfScoreCol = 'avfScore'\n", "categoricalCols = ['workclass', 'education', 'marital-status',\n", " 'occupation', 'relationship', 'race', 'sex',\n", " 'native-country']\n", "avfScore = attributeValueFrequency(train, categoricalCols)\n", "pdf = avfScore.groupby(avfScoreCol).count().toPandas()\n", "pdf" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[<matplotlib.axes._subplots.AxesSubplot object at 0x7ff049874978>]], dtype=object)" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYAAAAEICAYAAABWJCMKAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAHSxJREFUeJzt3XuUHOWZ3/Hvz4iLQDaSuExkSbGE0doLVoxhDhd747SQVwhMEJtj1mJ1jMBylLOLHXtXObZY4ijmsge8YDBnd7EVo1gQjJAxBB3Axoqgs/GeRYC4CQFaDSDDIEUy6IIHMNmxn/xR70BpmJ7pmanpnlb9Puf06aqn3qp63q6Zfrou3aWIwMzMyud9zU7AzMyawwXAzKykXADMzErKBcDMrKRcAMzMSsoFwMyspFwAzHIk/amkHZK6JB3R7HzMRpILgFki6UDgO8CciBgXEa9JmifpCUmvS3pV0jpJ05qbqVkxxjQ7AbNRpA04BNgEIOlY4Gbg3wEPAOOAOcDvilqhJAGKiMKWaVYv7wHYfknSUknPS/q1pGck/ZGkgyXtkfSxXLujJL0l6WRgcwrvkfQAcALwYkSsi8yvI+InEfFSmvcASX+ZW88GSVPTtE9KekTS3vT8ydw6q5KulPQPwJvAMZIOl3STpO2SXpF0haQDGvV6WTm5ANj+6nngXwOHA98C/gcwEbgTOD/X7o+B/x0RDwPHp9j4iDgdeAz4qKTrJM2SNK7XOv4iLess4APAF4E3JU0E7gVuAI4gO6x0b69zCl8AFgPvB34JrAS6gWOBT5DtaXxpuC+CWX9cAGy/FBE/johtEfG7iLgd2AKcDPyIfQvAn6RYX8t4AagAk4HVwKuSfpgrBF8C/nNEbE57CE9GxGvAZ4EtEXFLRHRHxG3Ac8C/zS3+hxGxKSK6yQrTmcDXIuKNiNgJXAfML+TFMKvBBcD2S5IuSCdv90jaA3wMOJLsWP5YSadI+hDZYZ67ai0nIh6KiD+OiKPI9ig+DVyaJk8l29Po7YNkn+rzfklWSHq8nBv+EHAgsD2X7/eBo+vsrtmQ+CSw7XfSG/t/A2YD/xgRv5X0BOlkq6TVZHsBO4B7IuLX9Sw3Ih6RdCdZMYHsTfzDwNO9mm4je1PP+5fAz/KLyw2/DLwNHJn2CMwawnsAtj86jOwN9lcAki7i3TdtyA75fB5YQI3DP2m+P5D07yUdncY/CpwDPJSa/AC4XNIMZf5VOs5/H/B7kv5E0hhJnweOA+7paz0RsR34OXCtpA9Iep+kD0v6N0N+Bczq4AJg+52IeAa4FvhHsk/5M4F/yE1fD7xBdqjmp/0sag/ZG/5GSV1kn+DvAr6dpn+H7NzAz4HXgZuAsek8wNnAEuA14OvA2RHxaj/rugA4CHgG2A3cAUyqu9NmQyDfEMbMrJy8B2BmVlIuAGZmJeUCYGZWUi4AZmYlNaq/B3DkkUfGtGnTClnWG2+8wWGHHVbIspql1fvg/Juv1fvQ6vlDY/qwYcOGV9OXF/s1qgvAtGnTePTRRwtZVrVapVKpFLKsZmn1Pjj/5mv1PrR6/tCYPkjq/U30PvkQkJlZSbkAmJmVlAuAmVlJuQCYmZWUC4CZWUm5AJiZlZQLgJlZSbkAmJmVlAuAmVlJjepvApvZ6DJt6b1NWe/Wqz7blPXu77wHYGZWUi4AZmYl5QJgZlZSLgBmZiXlAmBmVlJ1FQBJfy5pk6SnJd0m6RBJ0yWtl7RF0u2SDkptD07jHWn6tNxyLknxzZLOGJkumZlZPQYsAJImA/8RaI+IjwEHAPOBq4HrImIGsBtYlGZZBOyOiGOB61I7JB2X5jsemAv8naQDiu2OmZnVq95DQGOAsZLGAIcC24HTgTvS9JXAuWl4XhonTZ8tSSm+KiLejogXgQ7g5OF3wczMhmLAAhARrwDXAC+RvfHvBTYAeyKiOzXrBCan4cnAy2ne7tT+iHy8j3nMzKzBBvwmsKQ
JZJ/epwN7gB8DZ/bRNHpmqTGtVrz3+hYDiwHa2tqoVqsDpViXrq6uwpbVLK3eB+fffMPtw5KZ3QM3GgE9OXsbFKuen4L4DPBiRPwKQNKdwCeB8ZLGpE/5U4BtqX0nMBXoTIeMDgd25eI98vO8IyKWA8sB2tvbo6ibJ/tm0s3n/JtvuH24sFk/BbGgAngbFK2ecwAvAadKOjQdy58NPAM8CHwutVkI3J2G16Rx0vQHIiJSfH66Smg6MAN4uJhumJnZYA24BxAR6yXdATwGdAOPk31CvxdYJemKFLspzXITcIukDrJP/vPTcjZJWk1WPLqBiyPitwX3x8zM6lTXr4FGxDJgWa/wC/RxFU9E/AY4r8ZyrgSuHGSOZmY2AvxNYDOzkvL9AMzMahiJ+x8smdld18n0RtwDwXsAZmYl5QJgZlZSLgBmZiXlAmBmVlIuAGZmJeUCYGZWUi4AZmYl5QJgZlZSLgBmZiXlAmBmVlIuAGZmJeUCYGZWUi4AZmYl5QJgZlZSAxYASR+R9ETu8bqkr0maKGmtpC3peUJqL0k3SOqQ9JSkE3PLWpjab5G0sPZazcxspA1YACJic0ScEBEnACcBbwJ3AUuBdRExA1iXxgHOJLvf7wxgMXAjgKSJZHcVO4XsTmLLeoqGmZk13mAPAc0Gno+IXwLzgJUpvhI4Nw3PA26OzEPAeEmTgDOAtRGxKyJ2A2uBucPugZmZDYkiov7G0grgsYj4G0l7ImJ8btruiJgg6R7gqoj4RYqvA74BVIBDIuKKFP8m8FZEXNNrHYvJ9hxoa2s7adWqVcPqYI+uri7GjRtXyLKapdX74Pybb7h92PjK3gKzqd/MyYcDjd8GI9HftrGw462B2/X0eShmzZq1ISLaB2pX9y0hJR0EnANcMlDTPmLRT3zfQMRyYDlAe3t7VCqVelPsV7VapahlNUur98H5N99w+1DPrQxHwtYFFaDx22Ak+rtkZjfXbhz4rbenzyNpMIeAziT79L8jje9Ih3ZIzztTvBOYmptvCrCtn7iZmTXBYArA+cBtufE1QM+VPAuBu3PxC9LVQKcCeyNiO3A/MEfShHTyd06KmZlZE9R1CEjSocAfAv8hF74KWC1pEfAScF6K3wecBXSQXTF0EUBE7JJ0OfBIandZROwadg/MzGxI6ioAEfEmcESv2GtkVwX1bhvAxTWWswJYMfg0zcysaP4msJlZSbkAmJmVlAuAmVlJuQCYmZWUC4CZWUm5AJiZlZQLgJlZSdX9W0BmZs0yLf0mz5KZ3U37PaL9kfcAzMxKygXAzKykXADMzErKBcDMrKRcAMzMSsoFwMyspFwAzMxKygXAzKyk6ioAksZLukPSc5KelXSapImS1krakp4npLaSdIOkDklPSToxt5yFqf0WSQtrr9HMzEZavXsA3wV+FhEfBT4OPAssBdZFxAxgXRqH7ObxM9JjMXAjgKSJwDLgFOBkYFlP0TAzs8YbsABI+gDwaeAmgIj4fxGxB5gHrEzNVgLnpuF5wM2ReQgYL2kScAawNiJ2RcRuYC0wt9DemJlZ3ZTdwrefBtIJwHLgGbJP/xuArwKvRMT4XLvdETFB0j3AVRHxixRfB3wDqACHRMQVKf5N4K2IuKbX+haT7TnQ1tZ20qpVq4roJ11dXYwbN66QZTVLq/fB+TffcPuw8ZW9BWYzeG1jYcdbTU1h2Ortw8zJhw95HbNmzdoQEe0Dtavnx+DGACcCX4mI9ZK+y7uHe/qiPmLRT3zfQMRysoJDe3t7VCqVOlIcWLVapahlNUur98H5N99w+9DsH2JbMrObaze29m9Y1tuHrQsqI55LPecAOoHOiFifxu8gKwg70qEd0vPOXPupufmnANv6iZuZWRMMWAAi4v8CL0v6SArNJjsctAbouZJnIXB3Gl4DXJCuBjoV2BsR24H7gTmSJqSTv3NSzMzMmqDefamvALdKOgh4AbiIrHislrQIeAk4L7W9DzgL6ADeTG2JiF2SLgc
eSe0ui4hdhfTCzMwGra4CEBFPAH2dUJjdR9sALq6xnBXAisEkaGZmI8PfBDYzKykXADOzknIBMDMrKRcAM7OScgEwMyspFwAzs5JyATAzKykXADOzknIBMDMrKRcAM7OScgEwMyspFwAzs5JyATAzKykXADOzknIBMDMrqboKgKStkjZKekLSoyk2UdJaSVvS84QUl6QbJHVIekrSibnlLEztt0haWGt9ZmY28gazBzArIk7I3Wl+KbAuImYA63j3RvFnAjPSYzFwI2QFA1gGnAKcDCzrKRpmZtZ4wzkENA9YmYZXAufm4jdH5iFgfLpp/BnA2ojYFRG7gbXA3GGs38zMhqHeAhDAzyVtkLQ4xdrSzd5Jz0en+GTg5dy8nSlWK25mZk1Q703hPxUR2yQdDayV9Fw/bdVHLPqJ7ztzVmAWA7S1tVGtVutMsX9dXV2FLatZWr0Pzr/5htuHJTO7i0tmCNrGNj+H4aq3D434W6v3pvDb0vNOSXeRHcPfIWlSRGxPh3h2puadwNTc7FOAbSle6RWv9rGu5cBygPb29qhUKr2bDEm1WqWoZTVLq/fB+TffcPtw4dJ7i0tmCJbM7ObajfV+bh2d6u3D1gWVEc9lwENAkg6T9P6eYWAO8DSwBui5kmchcHcaXgNckK4GOhXYmw4R3Q/MkTQhnfydk2JmZtYE9ZTSNuAuST3tfxQRP5P0CLBa0iLgJeC81P4+4CygA3gTuAggInZJuhx4JLW7LCJ2FdYTMzMblAELQES8AHy8j/hrwOw+4gFcXGNZK4AVg0/TzMyK5m8Cm5mVlAuAmVlJuQCYmZWUC4CZWUm5AJiZlZQLgJlZSbkAmJmVlAuAmVlJuQCYmZWUC4CZWUm5AJiZlZQLgJlZSbkAmJmVlAuAmVlJuQCYmZWUC4CZWUnVXQAkHSDpcUn3pPHpktZL2iLpdkkHpfjBabwjTZ+WW8YlKb5Z0hlFd8bMzOo3mLsrfxV4FvhAGr8auC4iVkn6HrAIuDE9746IYyXNT+0+L+k4YD5wPPBB4H9J+r2I+G1BfTErhWnDuDH7kpndTb+xu40ede0BSJoCfBb4QRoXcDpwR2qyEjg3Dc9L46Tps1P7ecCqiHg7Il4ku2fwyUV0wszMBq/ePYDrga8D70/jRwB7IqI7jXcCk9PwZOBlgIjolrQ3tZ8MPJRbZn6ed0haDCwGaGtro1qt1tuXfnV1dRW2rGZp9T44/2Ismdk9cKMa2sYOb/5ma/X8of4+NOJvbcACIOlsYGdEbJBU6Qn30TQGmNbfPO8GIpYDywHa29ujUqn0bjIk1WqVopbVLK3eB+dfjOEcwlkys5trNw7myO/o0ur5Q/192LqgMuK51PNKfgo4R9JZwCFk5wCuB8ZLGpP2AqYA21L7TmAq0ClpDHA4sCsX75Gfx8zMGmzAcwARcUlETImIaWQncR+IiAXAg8DnUrOFwN1peE0aJ01/ICIixeenq4SmAzOAhwvriZmZDcpw9qW+AaySdAXwOHBTit8E3CKpg+yT/3yAiNgkaTXwDNANXOwrgMzMmmdQBSAiqkA1Db9AH1fxRMRvgPNqzH8lcOVgkzQzs+L5m8BmZiXlAmBmVlIuAGZmJeUCYGZWUi4AZmYl5QJgZlZSLgBmZiXlAmBmVlIuAGZmJeUCYGZWUq39u6pmTTScO3OZjQbeAzAzKykXADOzknIBMDMrKRcAM7OScgEwMyupAQuApEMkPSzpSUmbJH0rxadLWi9pi6TbJR2U4gen8Y40fVpuWZek+GZJZ4xUp8zMbGD17AG8DZweER8HTgDmSjoVuBq4LiJmALuBRan9ImB3RBwLXJfaIek4sttDHg/MBf5O0gFFdsbMzOpXz03hIyK60uiB6RHA6cAdKb4SODcNz0vjpOmzJSnFV0XE2xHxItBBH7eUNDOzxqjri2Dpk/oG4Fjgb4HngT0R0Z2adAKT0/Bk4GWAiOiWtBc4IsUfyi02P09+XYuBxQBtbW1
Uq9XB9aiGrq6uwpbVLK3eh/0t/yUzu2s3HqXaxrZm3j1aPX+ovw+N+F+pqwBExG+BEySNB+4Cfr+vZulZNabVivde13JgOUB7e3tUKpV6UhxQtVqlqGU1S6v3YX/L/8IW/CbwkpndXLuxdX8AoNXzh/r7sHVBZcRzGdRVQBGxB6gCpwLjJfX0YgqwLQ13AlMB0vTDgV35eB/zmJlZg9VzFdBR6ZM/ksYCnwGeBR4EPpeaLQTuTsNr0jhp+gMRESk+P10lNB2YATxcVEfMzGxw6tmXmgSsTOcB3gesjoh7JD0DrJJ0BfA4cFNqfxNwi6QOsk/+8wEiYpOk1cAzQDdwcTq0ZGZmTTBgAYiIp4BP9BF/gT6u4omI3wDn1VjWlcCVg0/TzMyK5m8Cm5mVlAuAmVlJuQCYmZWUC4CZWUm5AJiZlZQLgJlZSbkAmJmVlAuAmVlJuQCYmZWUC4CZWUm5AJiZlZQLgJlZSbkAmJmVlAuAmVlJuQCYmZVUPXcEmyrpQUnPStok6aspPlHSWklb0vOEFJekGyR1SHpK0om5ZS1M7bdIWlhrnWZmNvLq2QPoBpZExO+T3Qv4YknHAUuBdRExA1iXxgHOJLvd4wxgMXAjZAUDWAacQnYjmWU9RcPMzBpvwAIQEdsj4rE0/Guy+wFPBuYBK1OzlcC5aXgecHNkHiK7efwk4AxgbUTsiojdwFpgbqG9MTOzuim7X3udjaVpwN8DHwNeiojxuWm7I2KCpHuAqyLiFym+DvgGUAEOiYgrUvybwFsRcU2vdSwm23Ogra3tpFWrVg25c3ldXV2MGzeukGU1S6v3YX/Lf+Mre5uYzdC0jYUdbzU7i6Fr9fyh/j7MnHz4kNcxa9asDRHRPlC7em4KD4CkccBPgK9FxOuSajbtIxb9xPcNRCwHlgO0t7dHpVKpN8V+VatVilpWs7R6H/a3/C9cem/zkhmiJTO7uXZj3f/2o06r5w/192HrgsqI51LXVUCSDiR78781Iu5M4R3p0A7peWeKdwJTc7NPAbb1Ezczsyao5yogATcBz0bEd3KT1gA9V/IsBO7OxS9IVwOdCuyNiO3A/cAcSRPSyd85KWZmZk1Qz77Up4AvABslPZFifwlcBayWtAh4CTgvTbsPOAvoAN4ELgKIiF2SLgceSe0ui4hdhfTCzMwGbcACkE7m1jrgP7uP9gFcXGNZK4AVg0nQzMxGhr8JbGZWUi4AZmYl5QJgZlZSLgBmZiXlAmBmVlIuAGZmJeUCYGZWUi4AZmYl5QJgZlZSLgBmZiXlAmBmVlIuAGZmJeUCYGZWUq19ax0rvWkNvCvXkpndLXkXMLNavAdgZlZSLgBmZiVVzy0hV0jaKenpXGyipLWStqTnCSkuSTdI6pD0lKQTc/MsTO23SFrY17rMzKxx6tkD+CEwt1dsKbAuImYA69I4wJnAjPRYDNwIWcEAlgGnACcDy3qKhpmZNceABSAi/h7ofe/eecDKNLwSODcXvzkyDwHjJU0CzgDWRsSuiNgNrOW9RcXMzBpoqFcBtUXEdoCI2C7p6BSfDLyca9eZYrXi7yFpMdneA21tbVSr1SGmuK+urq7CltUsrd6Hkch/yczuQpfXn7axjV3fSGj1PrR6/lB/Hxrxv170ZaB93Tw++om/NxixHFgO0N7eHpVKpZDEqtUqRS2rWVq9DyORfyMvy1wys5trN7b2ldOt3odWzx/q78PWBZURz2WoVwHtSId2SM87U7wTmJprNwXY1k/czMyaZKildA2wELgqPd+di39Z0iqyE7570yGi+4G/yp34nQNcMvS0bbSp5wtZ/iKV2egyYAGQdBtQAY6U1El2Nc9VwGpJi4CXgPNS8/uAs4AO4E3gIoCI2CXpcuCR1O6yiOh9YtnMzBpowAIQEefXmDS7j7YBXFxjOSuAFYPKzszMRoy/CWxmVlIuAGZmJeUCYGZWUi4AZmYl5QJgZlZSLgBmZiXlAmBmVlIuAGZmJeUCYGZWUi4
AZmYl5QJgZlZSLgBmZiXV2ndWsH3U85PMZmY9vAdgZlZSLgBmZiXlAmBmVlINPwcgaS7wXeAA4AcRcVWjcxhpI3Us3rdUNLMiNXQPQNIBwN8CZwLHAedLOq6ROZiZWabRewAnAx0R8QJAunn8POCZkVhZ/pO4Pz2bme1L2W18G7Qy6XPA3Ij4Uhr/AnBKRHw512YxsDiNfgTYXNDqjwReLWhZzdLqfXD+zdfqfWj1/KExffhQRBw1UKNG7wGoj9g+FSgilgPLC1+x9GhEtBe93EZq9T44/+Zr9T60ev4wuvrQ6KuAOoGpufEpwLYG52BmZjS+ADwCzJA0XdJBwHxgTYNzMDMzGnwIKCK6JX0ZuJ/sMtAVEbGpQasv/LBSE7R6H5x/87V6H1o9fxhFfWjoSWAzMxs9/E1gM7OScgEwMyuplioAkqZKelDSs5I2Sfpqik+UtFbSlvQ8IcUl6QZJHZKeknRiblkLU/stkhbm4idJ2pjmuUFSX5euDrcfB0h6XNI9aXy6pPUpl9vTCXIkHZzGO9L0abllXJLimyWdkYvPTbEOSUuLzj2tY7ykOyQ9l7bFaa20DST9efr7eVrSbZIOGe3bQNIKSTslPZ2LjfhrXmsdBeX/1+lv6ClJd0kan5s2qNd2KNuviD7kpv0nSSHpyDQ+6rZBnyKiZR7AJODENPx+4J/IflLi28DSFF8KXJ2GzwJ+Svb9g1OB9Sk+EXghPU9IwxPStIeB09I8PwXOHIF+/AXwI+CeNL4amJ+Gvwf8aRr+M+B7aXg+cHsaPg54EjgYmA48T3ZS/YA0fAxwUGpz3AjkvxL4Uho+CBjfKtsAmAy8CIzNvfYXjvZtAHwaOBF4Ohcb8de81joKyn8OMCYNX53Lf9Cv7WC3X1F9SPGpZBe2/BI4crRugz77VNSCmvEA7gb+kOzbwpNSbBKwOQ1/Hzg/135zmn4+8P1c/PspNgl4Lhffp11BOU8B1gGnA/ekjf1q7h/hNOD+NHw/cFoaHpPaCbgEuCS3zPvTfO/Mm+L7tCso/w+QvYGqV7wltgFZAXg5/QOOSdvgjFbYBsA09n0DHfHXvNY6isi/17Q/Am7t6zUb6LUdyv9QkX0A7gA+Dmzl3QIwKrdB70dLHQLKS7tynwDWA20RsR0gPR+dmvX8s/foTLH+4p19xIt0PfB14Hdp/AhgT0R097HOd/JM0/em9oPtV5GOAX4F/Hdlh7F+IOkwWmQbRMQrwDXAS8B2std0A621DXo04jWvtY6ifZHsUy8D5NlXfCj/Q4WQdA7wSkQ82WtSS2yDliwAksYBPwG+FhGv99e0j1gMIV4ISWcDOyNiQz7czzpHVf7JGLLd4Bsj4hPAG2S7pbWMqj6k46fzyA4tfBA4jOzXaWutc1TlX6eWylnSpUA3cGtPqEY+Q8l/xPom6VDgUuC/9DW5xnpH1TZouQIg6UCyN/9bI+LOFN4haVKaPgnYmeK1fnqiv/iUPuJF+RRwjqStwCqyw0DXA+Ml9XwpL7/Od/JM0w8Hdg2Q/0j/1EYn0BkR69P4HWQFoVW2wWeAFyPiVxHxz8CdwCdprW3QoxGvea11FCKdBD0bWBDpGMcQ8n+VwW+/InyY7IPEk+l/egrwmKR/MYQ+NGcbFHUsqREPsip5M3B9r/hfs+9Jkm+n4c+y74mYh1N8Itlx7Anp8SIwMU17JLXtORFz1gj1pcK7J4F/zL4nsP4sDV/MviewVqfh49n3JNkLZCfIxqTh6bx7kuz4Ecj9/wAfScP/Nb3+LbENgFOATcChafkrga+0wjbgvecARvw1r7WOgvKfS/ZT8Ef1ajfo13aw26+oPvSatpV3zwGMym3wnpyLWlAjHsAfkO0WPQU8kR5nkR3TWwdsSc89L6jIbkDzPLARaM8t64tAR3pclIu3A0+nef6GYZwwGqAvFd4tAMeQXQHQkf6QD07xQ9J4R5p+TG7+S1OOm8ldJZNej39K0y4dodxPAB5N2+F/pj/kltk
GwLeA59I6biF7oxnV2wC4jeycxT+TfVpc1IjXvNY6Csq/g+x4eM//8veG+toOZfsV0Yde07fybgEYddugr4d/CsLMrKRa7hyAmZkVwwXAzKykXADMzErKBcDMrKRcAMzMSsoFwMyspFwAzMxK6v8DGIPNHOq9sRAAAAAASUVORK5CYII=\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# sns.barplot(x=\"avfScore\", y=\"count\", data=pdf)\n", "# pdf['count'].hist(by=pdf['avfScore'])\n", "# pdf.plot(x='avfScore', y='count')\n", "# pdf.plot(x='avfScore', y='count', kind='bar')\n", "\n", "# Ideally, we want to use Spark for aggregation, and just plot the data by converting to \n", "# Pandas. Unfortuantely I couldn't figure out a way yet, visualization is not my strongest skill.\n", "# This approach is NOT recommended for large datasets which are residing over multiple machines\n", "# Since this will bring whole dataframe to the driver node and the driver node might run out of\n", "# memory.\n", "avfScore.select(avfScoreCol).toPandas().hist(avfScoreCol)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In AVF, the lower the score of a datapoint, the more likely that datapoint is an outlier. We can safely remove the rows whose score is below 70000." 
] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "root\n", " |-- age: integer (nullable = true)\n", " |-- workclass: string (nullable = true)\n", " |-- fnlwgt: integer (nullable = true)\n", " |-- education: string (nullable = true)\n", " |-- education-num: integer (nullable = true)\n", " |-- marital-status: string (nullable = true)\n", " |-- occupation: string (nullable = true)\n", " |-- relationship: string (nullable = true)\n", " |-- race: string (nullable = true)\n", " |-- sex: string (nullable = true)\n", " |-- capital-gain: integer (nullable = true)\n", " |-- capital-loss: integer (nullable = true)\n", " |-- hours-per-week: integer (nullable = true)\n", " |-- native-country: string (nullable = true)\n", " |-- class: string (nullable = true)\n", " |-- id: long (nullable = false)\n", "\n" ] } ], "source": [ "train = avfScore.filter(col(avfScoreCol) > 70000).drop(avfScoreCol)\n", "train.printSchema()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Feature Selection" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Chi Sqaure based categorical feature selection" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "pValues: [1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0]\n" ] } ], "source": [ "from pyspark.ml.feature import OneHotEncoderEstimator, StringIndexer, VectorAssembler\n", "from pyspark.ml.stat import ChiSquareTest\n", "from pyspark.ml import Pipeline\n", "\n", "indexed = train.select('education-num', 'education')\n", "\n", "indexer = StringIndexer(inputCol='education', outputCol='educationIndexed')\n", "indexed = indexer.fit(indexed).transform(indexed)\n", "ohe = OneHotEncoderEstimator(inputCols=['education-num',], outputCols=['education-numOHE',])\n", "indexed = ohe.fit(indexed).transform(indexed)\n", "\n", "# The null hypothesis is that 
the occurrence of the outcomes is statistically independent.\n", "# In general, small p-values (1% to 5%) would cause you to reject the null hypothesis. \n", "# This very large p-value (92.65%) means that the null hypothesis should not be rejected.\n", "testResult = ChiSquareTest.test(indexed, 'education-numOHE', 'educationIndexed')\n", "r = testResult.head()\n", "print(\"pValues: \" + str(r.pValues))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Check the printed pValues: a small p-value means the null hypothesis of independence is rejected, i.e. the two features are dependent. We conclude that 'education' and 'education-num' are dependent. Since 'education-num' already covers the information in 'education', we drop the 'education' feature." ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "train = train.drop('education')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Clustering" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "<pyspark.ml.clustering.KMeansSummary at 0x7ff049711390>" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from pyspark.ml.clustering import KMeans\n", "\n", "_, _, indexedDf = autoIndexer(train, labelCol)\n", "\n", "kmeans = KMeans(k=2, featuresCol='assembled')\n", "model = kmeans.fit(indexedDf)\n", "indexedDf = model.transform(indexedDf)\n", "model.summary" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.15349812641524918" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import numpy as np\n", "from sklearn.metrics import adjusted_mutual_info_score\n", "indexer = StringIndexer(inputCol=labelCol, outputCol=labelCol+'Indexed')\n", "indexedDf = indexer.fit(indexedDf).transform(indexedDf)\n", "classIndexed = [row[0] for row in indexedDf.select('classIndexed').collect()]\n", "prediction = [row[0] for row in indexedDf.select('prediction').collect()]\n", "\n", "adjusted_mutual_info_score(classIndexed, prediction)" ] }, { "cell_type": 
"markdown", "metadata": {}, "source": [ "# Classification" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "from pyspark.ml.feature import StringIndexer\n", "from pyspark.ml import Pipeline\n", "\n", "stringTypes = [dtype[0] for dtype in train.dtypes if dtype[1] == 'string']\n", "indexedTypes = [stringType+'Indexed' for stringType in stringTypes]\n", "\n", "indexers = [StringIndexer(inputCol=stringType, outputCol=stringType+'Indexed', handleInvalid='skip') \\\n", " for stringType in stringTypes]" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DataFrame[age: int, workclass: string, fnlwgt: int, education-num: int, marital-status: string, occupation: string, relationship: string, race: string, sex: string, capital-gain: int, capital-loss: int, hours-per-week: int, native-country: string, class: string, id: bigint, workclassIndexed: double, marital-statusIndexed: double, occupationIndexed: double, relationshipIndexed: double, raceIndexed: double, sexIndexed: double, native-countryIndexed: double, classIndexed: double, workclassIndexedOneHotEncoded: vector, raceIndexedOneHotEncoded: vector, occupationIndexedOneHotEncoded: vector, relationshipIndexedOneHotEncoded: vector, native-countryIndexedOneHotEncoded: vector, marital-statusIndexedOneHotEncoded: vector, sexIndexedOneHotEncoded: vector, classIndexedOneHotEncoded: vector, assembled: vector, rawPrediction: vector, probability: vector, prediction: double]" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from pyspark.ml.feature import OneHotEncoderEstimator, VectorAssembler\n", "from pyspark.ml.classification import GBTClassifier\n", "\n", "oheTypes = [indexedType+'OneHotEncoded' for indexedType in indexedTypes]\n", "ohe = OneHotEncoderEstimator(inputCols=indexedTypes, outputCols=oheTypes)\n", "\n", "# Fix columns\n", "oheTypes.remove('classIndexedOneHotEncoded')\n", 
"cols = train.columns[:]\n", "for oheType in oheTypes:\n", " cols.append(oheType)\n", "for stringType in stringTypes:\n", " cols.remove(stringType)\n", "\n", "cols.remove('id')\n", "\n", "assembler = VectorAssembler(inputCols=cols, outputCol='assembled')\n", "classifier = GBTClassifier(featuresCol='assembled', labelCol='classIndexed')\n", "pipeline = Pipeline(stages=[*indexers, ohe, assembler, classifier])\n", "model = pipeline.fit(train)\n", "train = model.transform(train)\n", "train" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Since the classes are imbalanced (24720 vs. 7841 rows in the training data), we use the area under the ROC curve as the metric rather than accuracy. Note that the score below is computed on the same training data the model was fitted on, so it is optimistic; the held-out test set is evaluated in the next section." ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.9179985970287455" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from pyspark.ml.evaluation import BinaryClassificationEvaluator\n", "evaluator = BinaryClassificationEvaluator(labelCol='classIndexed', metricName='areaUnderROC')\n", "metric = evaluator.evaluate(train)\n", "metric" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Evaluation" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>class</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td><=50K.</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td><=50K.</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>>50K.</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>>50K.</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " 
<td><=50K.</td>\n", " </tr>\n", " <tr>\n", " <th>5</th>\n", " <td><=50K.</td>\n", " </tr>\n", " <tr>\n", " <th>6</th>\n", " <td><=50K.</td>\n", " </tr>\n", " <tr>\n", " <th>7</th>\n", " <td>>50K.</td>\n", " </tr>\n", " <tr>\n", " <th>8</th>\n", " <td><=50K.</td>\n", " </tr>\n", " <tr>\n", " <th>9</th>\n", " <td><=50K.</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " class\n", "0 <=50K.\n", "1 <=50K.\n", "2 >50K.\n", "3 >50K.\n", "4 <=50K.\n", "5 <=50K.\n", "6 <=50K.\n", "7 >50K.\n", "8 <=50K.\n", "9 <=50K." ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "headers = [\"age\", \"workclass\", \"fnlwgt\", \"education\", \"education-num\",\n", " \"marital-status\", \"occupation\", \"relationship\", \"race\", \"sex\",\n", " \"capital-gain\", \"capital-loss\", \"hours-per-week\", \"native-country\",\n", " \"class\"]\n", "\n", "test = spark.read.csv('./adult.test.txt',\n", " inferSchema='true', \n", " ignoreLeadingWhiteSpace='true',\n", " ignoreTrailingWhiteSpace='true').toDF(*headers)\n", "test.select('class').limit(10).toPandas()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The class labels in the test dataset differ from those in the training dataset: the test labels carry a trailing dot ('>50K.' instead of '>50K'). We therefore have to remove the extra dot from the class label before evaluating."
] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>class</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td><=50K</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td><=50K</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>>50K</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>>50K</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td><=50K</td>\n", " </tr>\n", " <tr>\n", " <th>5</th>\n", " <td><=50K</td>\n", " </tr>\n", " <tr>\n", " <th>6</th>\n", " <td><=50K</td>\n", " </tr>\n", " <tr>\n", " <th>7</th>\n", " <td>>50K</td>\n", " </tr>\n", " <tr>\n", " <th>8</th>\n", " <td><=50K</td>\n", " </tr>\n", " <tr>\n", " <th>9</th>\n", " <td><=50K</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " class\n", "0 <=50K\n", "1 <=50K\n", "2 >50K\n", "3 >50K\n", "4 <=50K\n", "5 <=50K\n", "6 <=50K\n", "7 >50K\n", "8 <=50K\n", "9 <=50K" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from pyspark.sql.types import StringType\n", "stripDot = udf(lambda s: s[:-1], StringType())\n", "\n", "test = test.withColumn('classTrailed', stripDot(test['class'])).drop('class').withColumnRenamed('classTrailed', 'class')\n", "test.select('class').limit(10).toPandas()" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.9096987015789428" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test = model.transform(test)\n", "metric = 
evaluator.evaluate(test)\n", "metric\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.4" } }, "nbformat": 4, "nbformat_minor": 2 }