{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from pyspark.sql import SparkSession\n", "\n", "import sys\n", "sys.path.append('..')\n", "from utils.pysparkutils import *\n", "\n", "spark = SparkSession.builder.appName('income').getOrCreate()" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "root\n", " |-- age: integer (nullable = true)\n", " |-- workclass: string (nullable = true)\n", " |-- fnlwgt: integer (nullable = true)\n", " |-- education: string (nullable = true)\n", " |-- education-num: integer (nullable = true)\n", " |-- marital-status: string (nullable = true)\n", " |-- occupation: string (nullable = true)\n", " |-- relationship: string (nullable = true)\n", " |-- race: string (nullable = true)\n", " |-- sex: string (nullable = true)\n", " |-- capital-gain: integer (nullable = true)\n", " |-- capital-loss: integer (nullable = true)\n", " |-- hours-per-week: integer (nullable = true)\n", " |-- native-country: string (nullable = true)\n", " |-- class: string (nullable = true)\n", "\n" ] } ], "source": [ "from pyspark.sql.types import *\n", "\n", "# schema = StructType([\n", "# StructField(\"age\", IntegerType(), True), \n", "# StructField(\"workclass\", StringType(), True),\n", "# StructField(\"fnlwgt\", FloatType(), True),\n", "# StructField(\"education\", StringType(), True),\n", "# StructField(\"education-num\", FloatType(), True),\n", "# StructField(\"marital-status\", StringType(), True),\n", "# StructField(\"occupation\", StringType(), True),\n", "# StructField(\"relationship\", StringType(), True),\n", "# StructField(\"race\", StringType(), True),\n", "# StructField(\"sex\", StringType(), True),\n", "# StructField(\"capital-gain\", FloatType(), True),\n", "# StructField(\"capital-loss\", FloatType(), True),\n", "# StructField(\"hours-per-week\", FloatType(), True),\n", "# StructField(\"native-country\", 
StringType(), True),\n", "# StructField(\"class\", StringType(), True)]\n", "# )\n", "\n", "# train = spark.read.csv('./adult.data.txt', schema=schema, inferSchema='true')\n", "\n", "headers = [\"age\", \"workclass\", \"fnlwgt\", \"education\", \"education-num\",\n", " \"marital-status\", \"occupation\", \"relationship\", \"race\", \"sex\",\n", " \"capital-gain\", \"capital-loss\", \"hours-per-week\", \"native-country\",\n", " \"class\"]\n", "\n", "train = spark.read.csv('./adult.data.txt',\n", " inferSchema='true', \n", " ignoreLeadingWhiteSpace='true',\n", " ignoreTrailingWhiteSpace='true').toDF(*headers)\n", "train.printSchema()" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "from pyspark.sql.functions import udf, monotonically_increasing_id\n", "\n", "train = train.withColumn('id', monotonically_increasing_id())" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "32561" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "labelCol = 'class'\n", "train.count()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Exploratory Data Analysis" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>class</th>\n", " <th>count</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td><=50K</td>\n", " <td>24720</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>>50K</td>\n", " <td>7841</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], 
"text/plain": [ " class count\n", "0 <=50K 24720\n", "1 >50K 7841" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train.groupby(labelCol).count().toPandas()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "There is a class imbalance in our training set: 24720 rows are labeled <=50K versus only 7841 labeled >50K (roughly 76% vs. 24%)." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "findMissingValuesCols(train)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "There are no missing values in our training data according to `findMissingValuesCols`. Note, however, that unknown values in this dataset are recorded as the string `?` (see the `workclass` breakdown below), so they are not counted as missing here." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+---+\n", "|age|\n", "+---+\n", "| 31|\n", "| 85|\n", "| 65|\n", "| 53|\n", "| 78|\n", "| 34|\n", "| 81|\n", "| 28|\n", "| 76|\n", "| 27|\n", "| 26|\n", "| 44|\n", "| 22|\n", "| 47|\n", "| 52|\n", "| 86|\n", "| 40|\n", "| 20|\n", "| 57|\n", "| 54|\n", "+---+\n", "only showing top 20 rows\n", "\n" ] } ], "source": [ "train.select('age').distinct().show()" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>race</th>\n", " <th>class</th>\n", " <th>count</th>\n", " <th>percentage</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>Amer-Indian-Eskimo</td>\n", " <td><=50K</td>\n", " <td>275</td>\n", " <td>88.42</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>Amer-Indian-Eskimo</td>\n", " 
<td>>50K</td>\n", " <td>36</td>\n", " <td>11.58</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>Asian-Pac-Islander</td>\n", " <td><=50K</td>\n", " <td>763</td>\n", " <td>73.44</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>Asian-Pac-Islander</td>\n", " <td>>50K</td>\n", " <td>276</td>\n", " <td>26.56</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>Black</td>\n", " <td><=50K</td>\n", " <td>2737</td>\n", " <td>87.61</td>\n", " </tr>\n", " <tr>\n", " <th>5</th>\n", " <td>Black</td>\n", " <td>>50K</td>\n", " <td>387</td>\n", " <td>12.39</td>\n", " </tr>\n", " <tr>\n", " <th>6</th>\n", " <td>Other</td>\n", " <td><=50K</td>\n", " <td>246</td>\n", " <td>90.77</td>\n", " </tr>\n", " <tr>\n", " <th>7</th>\n", " <td>Other</td>\n", " <td>>50K</td>\n", " <td>25</td>\n", " <td>9.23</td>\n", " </tr>\n", " <tr>\n", " <th>8</th>\n", " <td>White</td>\n", " <td><=50K</td>\n", " <td>20699</td>\n", " <td>74.41</td>\n", " </tr>\n", " <tr>\n", " <th>9</th>\n", " <td>White</td>\n", " <td>>50K</td>\n", " <td>7117</td>\n", " <td>25.59</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " race class count percentage\n", "0 Amer-Indian-Eskimo <=50K 275 88.42\n", "1 Amer-Indian-Eskimo >50K 36 11.58\n", "2 Asian-Pac-Islander <=50K 763 73.44\n", "3 Asian-Pac-Islander >50K 276 26.56\n", "4 Black <=50K 2737 87.61\n", "5 Black >50K 387 12.39\n", "6 Other <=50K 246 90.77\n", "7 Other >50K 25 9.23\n", "8 White <=50K 20699 74.41\n", "9 White >50K 7117 25.59" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "percentageCol = 'percentage'\n", "df = crosstabPercentage(train, 'race', labelCol)\n", "df = df.withColumn(percentageCol, format_number(df[percentageCol], 2))\n", "df.toPandas()" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " 
}\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>age</th>\n", " <th>class</th>\n", " <th>count</th>\n", " <th>percentage</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>17</td>\n", " <td><=50K</td>\n", " <td>395</td>\n", " <td>100.00</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>18</td>\n", " <td><=50K</td>\n", " <td>550</td>\n", " <td>100.00</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>19</td>\n", " <td><=50K</td>\n", " <td>710</td>\n", " <td>99.72</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>19</td>\n", " <td>>50K</td>\n", " <td>2</td>\n", " <td>0.28</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>20</td>\n", " <td><=50K</td>\n", " <td>753</td>\n", " <td>100.00</td>\n", " </tr>\n", " <tr>\n", " <th>5</th>\n", " <td>21</td>\n", " <td><=50K</td>\n", " <td>717</td>\n", " <td>99.58</td>\n", " </tr>\n", " <tr>\n", " <th>6</th>\n", " <td>21</td>\n", " <td>>50K</td>\n", " <td>3</td>\n", " <td>0.42</td>\n", " </tr>\n", " <tr>\n", " <th>7</th>\n", " <td>22</td>\n", " <td><=50K</td>\n", " <td>752</td>\n", " <td>98.30</td>\n", " </tr>\n", " <tr>\n", " <th>8</th>\n", " <td>22</td>\n", " <td>>50K</td>\n", " <td>13</td>\n", " <td>1.70</td>\n", " </tr>\n", " <tr>\n", " <th>9</th>\n", " <td>23</td>\n", " <td><=50K</td>\n", " <td>865</td>\n", " <td>98.63</td>\n", " </tr>\n", " <tr>\n", " <th>10</th>\n", " <td>23</td>\n", " <td>>50K</td>\n", " <td>12</td>\n", " <td>1.37</td>\n", " </tr>\n", " <tr>\n", " <th>11</th>\n", " <td>24</td>\n", " <td><=50K</td>\n", " <td>767</td>\n", " <td>96.12</td>\n", " </tr>\n", " <tr>\n", " <th>12</th>\n", " <td>24</td>\n", " <td>>50K</td>\n", " <td>31</td>\n", " <td>3.88</td>\n", " </tr>\n", " <tr>\n", " <th>13</th>\n", " <td>25</td>\n", " 
<td><=50K</td>\n", " <td>788</td>\n", " <td>93.70</td>\n", " </tr>\n", " <tr>\n", " <th>14</th>\n", " <td>25</td>\n", " <td>>50K</td>\n", " <td>53</td>\n", " <td>6.30</td>\n", " </tr>\n", " <tr>\n", " <th>15</th>\n", " <td>26</td>\n", " <td><=50K</td>\n", " <td>722</td>\n", " <td>91.97</td>\n", " </tr>\n", " <tr>\n", " <th>16</th>\n", " <td>26</td>\n", " <td>>50K</td>\n", " <td>63</td>\n", " <td>8.03</td>\n", " </tr>\n", " <tr>\n", " <th>17</th>\n", " <td>27</td>\n", " <td><=50K</td>\n", " <td>754</td>\n", " <td>90.30</td>\n", " </tr>\n", " <tr>\n", " <th>18</th>\n", " <td>27</td>\n", " <td>>50K</td>\n", " <td>81</td>\n", " <td>9.70</td>\n", " </tr>\n", " <tr>\n", " <th>19</th>\n", " <td>28</td>\n", " <td><=50K</td>\n", " <td>748</td>\n", " <td>86.27</td>\n", " </tr>\n", " <tr>\n", " <th>20</th>\n", " <td>28</td>\n", " <td>>50K</td>\n", " <td>119</td>\n", " <td>13.73</td>\n", " </tr>\n", " <tr>\n", " <th>21</th>\n", " <td>29</td>\n", " <td><=50K</td>\n", " <td>679</td>\n", " <td>83.52</td>\n", " </tr>\n", " <tr>\n", " <th>22</th>\n", " <td>29</td>\n", " <td>>50K</td>\n", " <td>134</td>\n", " <td>16.48</td>\n", " </tr>\n", " <tr>\n", " <th>23</th>\n", " <td>30</td>\n", " <td><=50K</td>\n", " <td>690</td>\n", " <td>80.14</td>\n", " </tr>\n", " <tr>\n", " <th>24</th>\n", " <td>30</td>\n", " <td>>50K</td>\n", " <td>171</td>\n", " <td>19.86</td>\n", " </tr>\n", " <tr>\n", " <th>25</th>\n", " <td>31</td>\n", " <td><=50K</td>\n", " <td>705</td>\n", " <td>79.39</td>\n", " </tr>\n", " <tr>\n", " <th>26</th>\n", " <td>31</td>\n", " <td>>50K</td>\n", " <td>183</td>\n", " <td>20.61</td>\n", " </tr>\n", " <tr>\n", " <th>27</th>\n", " <td>32</td>\n", " <td><=50K</td>\n", " <td>639</td>\n", " <td>77.17</td>\n", " </tr>\n", " <tr>\n", " <th>28</th>\n", " <td>32</td>\n", " <td>>50K</td>\n", " <td>189</td>\n", " <td>22.83</td>\n", " </tr>\n", " <tr>\n", " <th>29</th>\n", " <td>33</td>\n", " <td><=50K</td>\n", " <td>684</td>\n", " <td>78.17</td>\n", " </tr>\n", " <tr>\n", " 
<th>...</th>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " </tr>\n", " <tr>\n", " <th>108</th>\n", " <td>72</td>\n", " <td>>50K</td>\n", " <td>9</td>\n", " <td>13.43</td>\n", " </tr>\n", " <tr>\n", " <th>109</th>\n", " <td>73</td>\n", " <td><=50K</td>\n", " <td>54</td>\n", " <td>84.38</td>\n", " </tr>\n", " <tr>\n", " <th>110</th>\n", " <td>73</td>\n", " <td>>50K</td>\n", " <td>10</td>\n", " <td>15.62</td>\n", " </tr>\n", " <tr>\n", " <th>111</th>\n", " <td>74</td>\n", " <td><=50K</td>\n", " <td>39</td>\n", " <td>76.47</td>\n", " </tr>\n", " <tr>\n", " <th>112</th>\n", " <td>74</td>\n", " <td>>50K</td>\n", " <td>12</td>\n", " <td>23.53</td>\n", " </tr>\n", " <tr>\n", " <th>113</th>\n", " <td>75</td>\n", " <td><=50K</td>\n", " <td>38</td>\n", " <td>84.44</td>\n", " </tr>\n", " <tr>\n", " <th>114</th>\n", " <td>75</td>\n", " <td>>50K</td>\n", " <td>7</td>\n", " <td>15.56</td>\n", " </tr>\n", " <tr>\n", " <th>115</th>\n", " <td>76</td>\n", " <td><=50K</td>\n", " <td>41</td>\n", " <td>89.13</td>\n", " </tr>\n", " <tr>\n", " <th>116</th>\n", " <td>76</td>\n", " <td>>50K</td>\n", " <td>5</td>\n", " <td>10.87</td>\n", " </tr>\n", " <tr>\n", " <th>117</th>\n", " <td>77</td>\n", " <td><=50K</td>\n", " <td>24</td>\n", " <td>82.76</td>\n", " </tr>\n", " <tr>\n", " <th>118</th>\n", " <td>77</td>\n", " <td>>50K</td>\n", " <td>5</td>\n", " <td>17.24</td>\n", " </tr>\n", " <tr>\n", " <th>119</th>\n", " <td>78</td>\n", " <td><=50K</td>\n", " <td>18</td>\n", " <td>78.26</td>\n", " </tr>\n", " <tr>\n", " <th>120</th>\n", " <td>78</td>\n", " <td>>50K</td>\n", " <td>5</td>\n", " <td>21.74</td>\n", " </tr>\n", " <tr>\n", " <th>121</th>\n", " <td>79</td>\n", " <td><=50K</td>\n", " <td>13</td>\n", " <td>59.09</td>\n", " </tr>\n", " <tr>\n", " <th>122</th>\n", " <td>79</td>\n", " <td>>50K</td>\n", " <td>9</td>\n", " <td>40.91</td>\n", " </tr>\n", " <tr>\n", " <th>123</th>\n", " <td>80</td>\n", " <td><=50K</td>\n", " <td>20</td>\n", " <td>90.91</td>\n", 
" </tr>\n", " <tr>\n", " <th>124</th>\n", " <td>80</td>\n", " <td>>50K</td>\n", " <td>2</td>\n", " <td>9.09</td>\n", " </tr>\n", " <tr>\n", " <th>125</th>\n", " <td>81</td>\n", " <td><=50K</td>\n", " <td>17</td>\n", " <td>85.00</td>\n", " </tr>\n", " <tr>\n", " <th>126</th>\n", " <td>81</td>\n", " <td>>50K</td>\n", " <td>3</td>\n", " <td>15.00</td>\n", " </tr>\n", " <tr>\n", " <th>127</th>\n", " <td>82</td>\n", " <td><=50K</td>\n", " <td>12</td>\n", " <td>100.00</td>\n", " </tr>\n", " <tr>\n", " <th>128</th>\n", " <td>83</td>\n", " <td><=50K</td>\n", " <td>4</td>\n", " <td>66.67</td>\n", " </tr>\n", " <tr>\n", " <th>129</th>\n", " <td>83</td>\n", " <td>>50K</td>\n", " <td>2</td>\n", " <td>33.33</td>\n", " </tr>\n", " <tr>\n", " <th>130</th>\n", " <td>84</td>\n", " <td><=50K</td>\n", " <td>9</td>\n", " <td>90.00</td>\n", " </tr>\n", " <tr>\n", " <th>131</th>\n", " <td>84</td>\n", " <td>>50K</td>\n", " <td>1</td>\n", " <td>10.00</td>\n", " </tr>\n", " <tr>\n", " <th>132</th>\n", " <td>85</td>\n", " <td><=50K</td>\n", " <td>3</td>\n", " <td>100.00</td>\n", " </tr>\n", " <tr>\n", " <th>133</th>\n", " <td>86</td>\n", " <td><=50K</td>\n", " <td>1</td>\n", " <td>100.00</td>\n", " </tr>\n", " <tr>\n", " <th>134</th>\n", " <td>87</td>\n", " <td><=50K</td>\n", " <td>1</td>\n", " <td>100.00</td>\n", " </tr>\n", " <tr>\n", " <th>135</th>\n", " <td>88</td>\n", " <td><=50K</td>\n", " <td>3</td>\n", " <td>100.00</td>\n", " </tr>\n", " <tr>\n", " <th>136</th>\n", " <td>90</td>\n", " <td><=50K</td>\n", " <td>35</td>\n", " <td>81.40</td>\n", " </tr>\n", " <tr>\n", " <th>137</th>\n", " <td>90</td>\n", " <td>>50K</td>\n", " <td>8</td>\n", " <td>18.60</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "<p>138 rows × 4 columns</p>\n", "</div>" ], "text/plain": [ " age class count percentage\n", "0 17 <=50K 395 100.00\n", "1 18 <=50K 550 100.00\n", "2 19 <=50K 710 99.72\n", "3 19 >50K 2 0.28\n", "4 20 <=50K 753 100.00\n", "5 21 <=50K 717 99.58\n", "6 21 >50K 3 0.42\n", "7 22 <=50K 752 
98.30\n", "8 22 >50K 13 1.70\n", "9 23 <=50K 865 98.63\n", "10 23 >50K 12 1.37\n", "11 24 <=50K 767 96.12\n", "12 24 >50K 31 3.88\n", "13 25 <=50K 788 93.70\n", "14 25 >50K 53 6.30\n", "15 26 <=50K 722 91.97\n", "16 26 >50K 63 8.03\n", "17 27 <=50K 754 90.30\n", "18 27 >50K 81 9.70\n", "19 28 <=50K 748 86.27\n", "20 28 >50K 119 13.73\n", "21 29 <=50K 679 83.52\n", "22 29 >50K 134 16.48\n", "23 30 <=50K 690 80.14\n", "24 30 >50K 171 19.86\n", "25 31 <=50K 705 79.39\n", "26 31 >50K 183 20.61\n", "27 32 <=50K 639 77.17\n", "28 32 >50K 189 22.83\n", "29 33 <=50K 684 78.17\n", ".. ... ... ... ...\n", "108 72 >50K 9 13.43\n", "109 73 <=50K 54 84.38\n", "110 73 >50K 10 15.62\n", "111 74 <=50K 39 76.47\n", "112 74 >50K 12 23.53\n", "113 75 <=50K 38 84.44\n", "114 75 >50K 7 15.56\n", "115 76 <=50K 41 89.13\n", "116 76 >50K 5 10.87\n", "117 77 <=50K 24 82.76\n", "118 77 >50K 5 17.24\n", "119 78 <=50K 18 78.26\n", "120 78 >50K 5 21.74\n", "121 79 <=50K 13 59.09\n", "122 79 >50K 9 40.91\n", "123 80 <=50K 20 90.91\n", "124 80 >50K 2 9.09\n", "125 81 <=50K 17 85.00\n", "126 81 >50K 3 15.00\n", "127 82 <=50K 12 100.00\n", "128 83 <=50K 4 66.67\n", "129 83 >50K 2 33.33\n", "130 84 <=50K 9 90.00\n", "131 84 >50K 1 10.00\n", "132 85 <=50K 3 100.00\n", "133 86 <=50K 1 100.00\n", "134 87 <=50K 1 100.00\n", "135 88 <=50K 3 100.00\n", "136 90 <=50K 35 81.40\n", "137 90 >50K 8 18.60\n", "\n", "[138 rows x 4 columns]" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = crosstabPercentage(train, 'age', labelCol)\n", "df = df.orderBy('age').withColumn('percentage', \n", " format_number(df['percentage'], 2))\n", "df.toPandas()" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th 
{\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>sex</th>\n", " <th>class</th>\n", " <th>count</th>\n", " <th>percentage</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>Female</td>\n", " <td><=50K</td>\n", " <td>9592</td>\n", " <td>89.05</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>Female</td>\n", " <td>>50K</td>\n", " <td>1179</td>\n", " <td>10.95</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>Male</td>\n", " <td><=50K</td>\n", " <td>15128</td>\n", " <td>69.43</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>Male</td>\n", " <td>>50K</td>\n", " <td>6662</td>\n", " <td>30.57</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " sex class count percentage\n", "0 Female <=50K 9592 89.05\n", "1 Female >50K 1179 10.95\n", "2 Male <=50K 15128 69.43\n", "3 Male >50K 6662 30.57" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = crosstabPercentage(train, 'sex', labelCol)\n", "df = df.withColumn('percentage', format_number(df['percentage'], 2))\n", "df.toPandas()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`crosstabPercentage` is a simple way to explore the relationship between a particular categorical feature and the label. For instance, the table above shows the usefulness of the `sex` feature in predicting salary: 30.57% of men earn >50K, compared with only 10.95% of women. So if a person is male, he is more likely to earn a >50K salary." 
] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>education</th>\n", " <th>class</th>\n", " <th>count</th>\n", " <th>percentage</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>10th</td>\n", " <td><=50K</td>\n", " <td>871</td>\n", " <td>93.35</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>10th</td>\n", " <td>>50K</td>\n", " <td>62</td>\n", " <td>6.65</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>11th</td>\n", " <td><=50K</td>\n", " <td>1115</td>\n", " <td>94.89</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>11th</td>\n", " <td>>50K</td>\n", " <td>60</td>\n", " <td>5.11</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>12th</td>\n", " <td><=50K</td>\n", " <td>400</td>\n", " <td>92.38</td>\n", " </tr>\n", " <tr>\n", " <th>5</th>\n", " <td>12th</td>\n", " <td>>50K</td>\n", " <td>33</td>\n", " <td>7.62</td>\n", " </tr>\n", " <tr>\n", " <th>6</th>\n", " <td>1st-4th</td>\n", " <td><=50K</td>\n", " <td>162</td>\n", " <td>96.43</td>\n", " </tr>\n", " <tr>\n", " <th>7</th>\n", " <td>1st-4th</td>\n", " <td>>50K</td>\n", " <td>6</td>\n", " <td>3.57</td>\n", " </tr>\n", " <tr>\n", " <th>8</th>\n", " <td>5th-6th</td>\n", " <td><=50K</td>\n", " <td>317</td>\n", " <td>95.20</td>\n", " </tr>\n", " <tr>\n", " <th>9</th>\n", " <td>5th-6th</td>\n", " <td>>50K</td>\n", " <td>16</td>\n", " <td>4.80</td>\n", " </tr>\n", " <tr>\n", " <th>10</th>\n", " <td>7th-8th</td>\n", " <td><=50K</td>\n", " <td>606</td>\n", " <td>93.81</td>\n", " </tr>\n", " <tr>\n", " <th>11</th>\n", " 
<td>7th-8th</td>\n", " <td>>50K</td>\n", " <td>40</td>\n", " <td>6.19</td>\n", " </tr>\n", " <tr>\n", " <th>12</th>\n", " <td>9th</td>\n", " <td><=50K</td>\n", " <td>487</td>\n", " <td>94.75</td>\n", " </tr>\n", " <tr>\n", " <th>13</th>\n", " <td>9th</td>\n", " <td>>50K</td>\n", " <td>27</td>\n", " <td>5.25</td>\n", " </tr>\n", " <tr>\n", " <th>14</th>\n", " <td>Assoc-acdm</td>\n", " <td><=50K</td>\n", " <td>802</td>\n", " <td>75.16</td>\n", " </tr>\n", " <tr>\n", " <th>15</th>\n", " <td>Assoc-acdm</td>\n", " <td>>50K</td>\n", " <td>265</td>\n", " <td>24.84</td>\n", " </tr>\n", " <tr>\n", " <th>16</th>\n", " <td>Assoc-voc</td>\n", " <td><=50K</td>\n", " <td>1021</td>\n", " <td>73.88</td>\n", " </tr>\n", " <tr>\n", " <th>17</th>\n", " <td>Assoc-voc</td>\n", " <td>>50K</td>\n", " <td>361</td>\n", " <td>26.12</td>\n", " </tr>\n", " <tr>\n", " <th>18</th>\n", " <td>Bachelors</td>\n", " <td><=50K</td>\n", " <td>3134</td>\n", " <td>58.52</td>\n", " </tr>\n", " <tr>\n", " <th>19</th>\n", " <td>Bachelors</td>\n", " <td>>50K</td>\n", " <td>2221</td>\n", " <td>41.48</td>\n", " </tr>\n", " <tr>\n", " <th>20</th>\n", " <td>Doctorate</td>\n", " <td><=50K</td>\n", " <td>107</td>\n", " <td>25.91</td>\n", " </tr>\n", " <tr>\n", " <th>21</th>\n", " <td>Doctorate</td>\n", " <td>>50K</td>\n", " <td>306</td>\n", " <td>74.09</td>\n", " </tr>\n", " <tr>\n", " <th>22</th>\n", " <td>HS-grad</td>\n", " <td><=50K</td>\n", " <td>8826</td>\n", " <td>84.05</td>\n", " </tr>\n", " <tr>\n", " <th>23</th>\n", " <td>HS-grad</td>\n", " <td>>50K</td>\n", " <td>1675</td>\n", " <td>15.95</td>\n", " </tr>\n", " <tr>\n", " <th>24</th>\n", " <td>Masters</td>\n", " <td><=50K</td>\n", " <td>764</td>\n", " <td>44.34</td>\n", " </tr>\n", " <tr>\n", " <th>25</th>\n", " <td>Masters</td>\n", " <td>>50K</td>\n", " <td>959</td>\n", " <td>55.66</td>\n", " </tr>\n", " <tr>\n", " <th>26</th>\n", " <td>Preschool</td>\n", " <td><=50K</td>\n", " <td>51</td>\n", " <td>100.00</td>\n", " </tr>\n", " <tr>\n", " 
<th>27</th>\n", " <td>Prof-school</td>\n", " <td><=50K</td>\n", " <td>153</td>\n", " <td>26.56</td>\n", " </tr>\n", " <tr>\n", " <th>28</th>\n", " <td>Prof-school</td>\n", " <td>>50K</td>\n", " <td>423</td>\n", " <td>73.44</td>\n", " </tr>\n", " <tr>\n", " <th>29</th>\n", " <td>Some-college</td>\n", " <td><=50K</td>\n", " <td>5904</td>\n", " <td>80.98</td>\n", " </tr>\n", " <tr>\n", " <th>30</th>\n", " <td>Some-college</td>\n", " <td>>50K</td>\n", " <td>1387</td>\n", " <td>19.02</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " education class count percentage\n", "0 10th <=50K 871 93.35\n", "1 10th >50K 62 6.65\n", "2 11th <=50K 1115 94.89\n", "3 11th >50K 60 5.11\n", "4 12th <=50K 400 92.38\n", "5 12th >50K 33 7.62\n", "6 1st-4th <=50K 162 96.43\n", "7 1st-4th >50K 6 3.57\n", "8 5th-6th <=50K 317 95.20\n", "9 5th-6th >50K 16 4.80\n", "10 7th-8th <=50K 606 93.81\n", "11 7th-8th >50K 40 6.19\n", "12 9th <=50K 487 94.75\n", "13 9th >50K 27 5.25\n", "14 Assoc-acdm <=50K 802 75.16\n", "15 Assoc-acdm >50K 265 24.84\n", "16 Assoc-voc <=50K 1021 73.88\n", "17 Assoc-voc >50K 361 26.12\n", "18 Bachelors <=50K 3134 58.52\n", "19 Bachelors >50K 2221 41.48\n", "20 Doctorate <=50K 107 25.91\n", "21 Doctorate >50K 306 74.09\n", "22 HS-grad <=50K 8826 84.05\n", "23 HS-grad >50K 1675 15.95\n", "24 Masters <=50K 764 44.34\n", "25 Masters >50K 959 55.66\n", "26 Preschool <=50K 51 100.00\n", "27 Prof-school <=50K 153 26.56\n", "28 Prof-school >50K 423 73.44\n", "29 Some-college <=50K 5904 80.98\n", "30 Some-college >50K 1387 19.02" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = crosstabPercentage(train, 'education', labelCol)\n", "df = df.withColumn('percentage', format_number(df['percentage'], 2))\n", "df.toPandas()" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "%%script false\n", "educationNumClass = crosstabPercentage(train, 'education-num', 
labelCol)\n", "educationNumClass = educationNumClass.withColumn('percentage', \n", " format_number(educationNumClass['percentage-'], 2))\n", "educationNumClass = educationNumClass.withColumn('education-numClassF', educationNumClass['education-numClass'].cast(DoubleType()))\\\n", " .orderBy('education-numClassF').drop('education-numClass')\n", "cols = educationNumClass.columns\n", "cols.remove('education-numClassF')\n", "cols.insert(0, 'education-numClassF')\n", "educationNumClass = educationNumClass.select(cols)\n", "educationNumClass.show()\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can see above that this is a sparse matrix; it's hard to find the non-zero values. So we will focus only on the non-zero values to find out whether there is any relationship between these features and whether one of them is redundant." ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "%%script false\n", "\n", "from pyspark.sql.functions import coalesce, lit, when\n", "\n", "iterator = df.toLocalIterator()\n", "d = {}\n", "for row in iterator:\n", " rowDict = row.asDict()\n", " educationNum = rowDict['education-num_education']\n", " for k, v in rowDict.items():\n", " if k != 'education-num_education' and v != 0:\n", " d[educationNum+'_'+k] = v\n", "\n", "import json\n", "s = json.dumps(d, indent=4)\n", "print(s)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It is clear that these features are redundant, so only one of them should suffice for our classification task.\n", "\n", "Let's try a more rigorous chi-square test instead of relying on this hand-wavy inspection.\n", "\n", "First we will define a utility method that indexes the categorical string columns, encodes them into one-hot vectors, and finally assembles all the feature vectors into one vector for downstream analysis." 
] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>workclass</th>\n", " <th>class</th>\n", " <th>count</th>\n", " <th>percentage</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>?</td>\n", " <td><=50K</td>\n", " <td>1645</td>\n", " <td>89.60</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>?</td>\n", " <td>>50K</td>\n", " <td>191</td>\n", " <td>10.40</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>Federal-gov</td>\n", " <td><=50K</td>\n", " <td>589</td>\n", " <td>61.35</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>Federal-gov</td>\n", " <td>>50K</td>\n", " <td>371</td>\n", " <td>38.65</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>Local-gov</td>\n", " <td><=50K</td>\n", " <td>1476</td>\n", " <td>70.52</td>\n", " </tr>\n", " <tr>\n", " <th>5</th>\n", " <td>Local-gov</td>\n", " <td>>50K</td>\n", " <td>617</td>\n", " <td>29.48</td>\n", " </tr>\n", " <tr>\n", " <th>6</th>\n", " <td>Never-worked</td>\n", " <td><=50K</td>\n", " <td>7</td>\n", " <td>100.00</td>\n", " </tr>\n", " <tr>\n", " <th>7</th>\n", " <td>Private</td>\n", " <td><=50K</td>\n", " <td>17733</td>\n", " <td>78.13</td>\n", " </tr>\n", " <tr>\n", " <th>8</th>\n", " <td>Private</td>\n", " <td>>50K</td>\n", " <td>4963</td>\n", " <td>21.87</td>\n", " </tr>\n", " <tr>\n", " <th>9</th>\n", " <td>Self-emp-inc</td>\n", " <td><=50K</td>\n", " <td>494</td>\n", " <td>44.27</td>\n", " </tr>\n", " <tr>\n", " <th>10</th>\n", " <td>Self-emp-inc</td>\n", " <td>>50K</td>\n", " <td>622</td>\n", " <td>55.73</td>\n", " 
</tr>\n", " <tr>\n", " <th>11</th>\n", " <td>Self-emp-not-inc</td>\n", " <td><=50K</td>\n", " <td>1817</td>\n", " <td>71.51</td>\n", " </tr>\n", " <tr>\n", " <th>12</th>\n", " <td>Self-emp-not-inc</td>\n", " <td>>50K</td>\n", " <td>724</td>\n", " <td>28.49</td>\n", " </tr>\n", " <tr>\n", " <th>13</th>\n", " <td>State-gov</td>\n", " <td><=50K</td>\n", " <td>945</td>\n", " <td>72.80</td>\n", " </tr>\n", " <tr>\n", " <th>14</th>\n", " <td>State-gov</td>\n", " <td>>50K</td>\n", " <td>353</td>\n", " <td>27.20</td>\n", " </tr>\n", " <tr>\n", " <th>15</th>\n", " <td>Without-pay</td>\n", " <td><=50K</td>\n", " <td>14</td>\n", " <td>100.00</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " workclass class count percentage\n", "0 ? <=50K 1645 89.60\n", "1 ? >50K 191 10.40\n", "2 Federal-gov <=50K 589 61.35\n", "3 Federal-gov >50K 371 38.65\n", "4 Local-gov <=50K 1476 70.52\n", "5 Local-gov >50K 617 29.48\n", "6 Never-worked <=50K 7 100.00\n", "7 Private <=50K 17733 78.13\n", "8 Private >50K 4963 21.87\n", "9 Self-emp-inc <=50K 494 44.27\n", "10 Self-emp-inc >50K 622 55.73\n", "11 Self-emp-not-inc <=50K 1817 71.51\n", "12 Self-emp-not-inc >50K 724 28.49\n", "13 State-gov <=50K 945 72.80\n", "14 State-gov >50K 353 27.20\n", "15 Without-pay <=50K 14 100.00" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = crosstabPercentage(train, 'workclass', labelCol)\n", "df = df.withColumn('percentage', format_number(df['percentage'], 2))\n", "df.toPandas()" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: 
right;\">\n", " <th></th>\n", " <th>hours-per-week</th>\n", " <th>class</th>\n", " <th>count</th>\n", " <th>percentage</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>1</td>\n", " <td><=50K</td>\n", " <td>18</td>\n", " <td>90.00</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>1</td>\n", " <td>>50K</td>\n", " <td>2</td>\n", " <td>10.00</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>2</td>\n", " <td><=50K</td>\n", " <td>24</td>\n", " <td>75.00</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>2</td>\n", " <td>>50K</td>\n", " <td>8</td>\n", " <td>25.00</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>3</td>\n", " <td><=50K</td>\n", " <td>38</td>\n", " <td>97.44</td>\n", " </tr>\n", " <tr>\n", " <th>5</th>\n", " <td>3</td>\n", " <td>>50K</td>\n", " <td>1</td>\n", " <td>2.56</td>\n", " </tr>\n", " <tr>\n", " <th>6</th>\n", " <td>4</td>\n", " <td><=50K</td>\n", " <td>51</td>\n", " <td>94.44</td>\n", " </tr>\n", " <tr>\n", " <th>7</th>\n", " <td>4</td>\n", " <td>>50K</td>\n", " <td>3</td>\n", " <td>5.56</td>\n", " </tr>\n", " <tr>\n", " <th>8</th>\n", " <td>5</td>\n", " <td><=50K</td>\n", " <td>53</td>\n", " <td>88.33</td>\n", " </tr>\n", " <tr>\n", " <th>9</th>\n", " <td>5</td>\n", " <td>>50K</td>\n", " <td>7</td>\n", " <td>11.67</td>\n", " </tr>\n", " <tr>\n", " <th>10</th>\n", " <td>6</td>\n", " <td><=50K</td>\n", " <td>56</td>\n", " <td>87.50</td>\n", " </tr>\n", " <tr>\n", " <th>11</th>\n", " <td>6</td>\n", " <td>>50K</td>\n", " <td>8</td>\n", " <td>12.50</td>\n", " </tr>\n", " <tr>\n", " <th>12</th>\n", " <td>7</td>\n", " <td><=50K</td>\n", " <td>22</td>\n", " <td>84.62</td>\n", " </tr>\n", " <tr>\n", " <th>13</th>\n", " <td>7</td>\n", " <td>>50K</td>\n", " <td>4</td>\n", " <td>15.38</td>\n", " </tr>\n", " <tr>\n", " <th>14</th>\n", " <td>8</td>\n", " <td><=50K</td>\n", " <td>134</td>\n", " <td>92.41</td>\n", " </tr>\n", " <tr>\n", " <th>15</th>\n", " <td>8</td>\n", " <td>>50K</td>\n", " <td>11</td>\n", " 
<td>7.59</td>\n", " </tr>\n", " <tr>\n", " <th>16</th>\n", " <td>9</td>\n", " <td><=50K</td>\n", " <td>17</td>\n", " <td>94.44</td>\n", " </tr>\n", " <tr>\n", " <th>17</th>\n", " <td>9</td>\n", " <td>>50K</td>\n", " <td>1</td>\n", " <td>5.56</td>\n", " </tr>\n", " <tr>\n", " <th>18</th>\n", " <td>10</td>\n", " <td><=50K</td>\n", " <td>258</td>\n", " <td>92.81</td>\n", " </tr>\n", " <tr>\n", " <th>19</th>\n", " <td>10</td>\n", " <td>>50K</td>\n", " <td>20</td>\n", " <td>7.19</td>\n", " </tr>\n", " <tr>\n", " <th>20</th>\n", " <td>11</td>\n", " <td><=50K</td>\n", " <td>11</td>\n", " <td>100.00</td>\n", " </tr>\n", " <tr>\n", " <th>21</th>\n", " <td>12</td>\n", " <td><=50K</td>\n", " <td>161</td>\n", " <td>93.06</td>\n", " </tr>\n", " <tr>\n", " <th>22</th>\n", " <td>12</td>\n", " <td>>50K</td>\n", " <td>12</td>\n", " <td>6.94</td>\n", " </tr>\n", " <tr>\n", " <th>23</th>\n", " <td>13</td>\n", " <td><=50K</td>\n", " <td>21</td>\n", " <td>91.30</td>\n", " </tr>\n", " <tr>\n", " <th>24</th>\n", " <td>13</td>\n", " <td>>50K</td>\n", " <td>2</td>\n", " <td>8.70</td>\n", " </tr>\n", " <tr>\n", " <th>25</th>\n", " <td>14</td>\n", " <td><=50K</td>\n", " <td>32</td>\n", " <td>94.12</td>\n", " </tr>\n", " <tr>\n", " <th>26</th>\n", " <td>14</td>\n", " <td>>50K</td>\n", " <td>2</td>\n", " <td>5.88</td>\n", " </tr>\n", " <tr>\n", " <th>27</th>\n", " <td>15</td>\n", " <td><=50K</td>\n", " <td>389</td>\n", " <td>96.29</td>\n", " </tr>\n", " <tr>\n", " <th>28</th>\n", " <td>15</td>\n", " <td>>50K</td>\n", " <td>15</td>\n", " <td>3.71</td>\n", " </tr>\n", " <tr>\n", " <th>29</th>\n", " <td>16</td>\n", " <td><=50K</td>\n", " <td>192</td>\n", " <td>93.66</td>\n", " </tr>\n", " <tr>\n", " <th>...</th>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " </tr>\n", " <tr>\n", " <th>143</th>\n", " <td>78</td>\n", " <td><=50K</td>\n", " <td>6</td>\n", " <td>75.00</td>\n", " </tr>\n", " <tr>\n", " <th>144</th>\n", " <td>78</td>\n", " <td>>50K</td>\n", " 
<td>2</td>\n", " <td>25.00</td>\n", " </tr>\n", " <tr>\n", " <th>145</th>\n", " <td>80</td>\n", " <td><=50K</td>\n", " <td>76</td>\n", " <td>57.14</td>\n", " </tr>\n", " <tr>\n", " <th>146</th>\n", " <td>80</td>\n", " <td>>50K</td>\n", " <td>57</td>\n", " <td>42.86</td>\n", " </tr>\n", " <tr>\n", " <th>147</th>\n", " <td>81</td>\n", " <td><=50K</td>\n", " <td>3</td>\n", " <td>100.00</td>\n", " </tr>\n", " <tr>\n", " <th>148</th>\n", " <td>82</td>\n", " <td><=50K</td>\n", " <td>1</td>\n", " <td>100.00</td>\n", " </tr>\n", " <tr>\n", " <th>149</th>\n", " <td>84</td>\n", " <td><=50K</td>\n", " <td>28</td>\n", " <td>62.22</td>\n", " </tr>\n", " <tr>\n", " <th>150</th>\n", " <td>84</td>\n", " <td>>50K</td>\n", " <td>17</td>\n", " <td>37.78</td>\n", " </tr>\n", " <tr>\n", " <th>151</th>\n", " <td>85</td>\n", " <td><=50K</td>\n", " <td>9</td>\n", " <td>69.23</td>\n", " </tr>\n", " <tr>\n", " <th>152</th>\n", " <td>85</td>\n", " <td>>50K</td>\n", " <td>4</td>\n", " <td>30.77</td>\n", " </tr>\n", " <tr>\n", " <th>153</th>\n", " <td>86</td>\n", " <td><=50K</td>\n", " <td>2</td>\n", " <td>100.00</td>\n", " </tr>\n", " <tr>\n", " <th>154</th>\n", " <td>87</td>\n", " <td><=50K</td>\n", " <td>1</td>\n", " <td>100.00</td>\n", " </tr>\n", " <tr>\n", " <th>155</th>\n", " <td>88</td>\n", " <td><=50K</td>\n", " <td>2</td>\n", " <td>100.00</td>\n", " </tr>\n", " <tr>\n", " <th>156</th>\n", " <td>89</td>\n", " <td><=50K</td>\n", " <td>1</td>\n", " <td>50.00</td>\n", " </tr>\n", " <tr>\n", " <th>157</th>\n", " <td>89</td>\n", " <td>>50K</td>\n", " <td>1</td>\n", " <td>50.00</td>\n", " </tr>\n", " <tr>\n", " <th>158</th>\n", " <td>90</td>\n", " <td><=50K</td>\n", " <td>19</td>\n", " <td>65.52</td>\n", " </tr>\n", " <tr>\n", " <th>159</th>\n", " <td>90</td>\n", " <td>>50K</td>\n", " <td>10</td>\n", " <td>34.48</td>\n", " </tr>\n", " <tr>\n", " <th>160</th>\n", " <td>91</td>\n", " <td><=50K</td>\n", " <td>3</td>\n", " <td>100.00</td>\n", " </tr>\n", " <tr>\n", " <th>161</th>\n", " 
<td>92</td>\n", " <td><=50K</td>\n", " <td>1</td>\n", " <td>100.00</td>\n", " </tr>\n", " <tr>\n", " <th>162</th>\n", " <td>94</td>\n", " <td><=50K</td>\n", " <td>1</td>\n", " <td>100.00</td>\n", " </tr>\n", " <tr>\n", " <th>163</th>\n", " <td>95</td>\n", " <td><=50K</td>\n", " <td>1</td>\n", " <td>50.00</td>\n", " </tr>\n", " <tr>\n", " <th>164</th>\n", " <td>95</td>\n", " <td>>50K</td>\n", " <td>1</td>\n", " <td>50.00</td>\n", " </tr>\n", " <tr>\n", " <th>165</th>\n", " <td>96</td>\n", " <td><=50K</td>\n", " <td>4</td>\n", " <td>80.00</td>\n", " </tr>\n", " <tr>\n", " <th>166</th>\n", " <td>96</td>\n", " <td>>50K</td>\n", " <td>1</td>\n", " <td>20.00</td>\n", " </tr>\n", " <tr>\n", " <th>167</th>\n", " <td>97</td>\n", " <td><=50K</td>\n", " <td>1</td>\n", " <td>50.00</td>\n", " </tr>\n", " <tr>\n", " <th>168</th>\n", " <td>97</td>\n", " <td>>50K</td>\n", " <td>1</td>\n", " <td>50.00</td>\n", " </tr>\n", " <tr>\n", " <th>169</th>\n", " <td>98</td>\n", " <td><=50K</td>\n", " <td>8</td>\n", " <td>72.73</td>\n", " </tr>\n", " <tr>\n", " <th>170</th>\n", " <td>98</td>\n", " <td>>50K</td>\n", " <td>3</td>\n", " <td>27.27</td>\n", " </tr>\n", " <tr>\n", " <th>171</th>\n", " <td>99</td>\n", " <td><=50K</td>\n", " <td>60</td>\n", " <td>70.59</td>\n", " </tr>\n", " <tr>\n", " <th>172</th>\n", " <td>99</td>\n", " <td>>50K</td>\n", " <td>25</td>\n", " <td>29.41</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "<p>173 rows × 4 columns</p>\n", "</div>" ], "text/plain": [ " hours-per-week class count percentage\n", "0 1 <=50K 18 90.00\n", "1 1 >50K 2 10.00\n", "2 2 <=50K 24 75.00\n", "3 2 >50K 8 25.00\n", "4 3 <=50K 38 97.44\n", "5 3 >50K 1 2.56\n", "6 4 <=50K 51 94.44\n", "7 4 >50K 3 5.56\n", "8 5 <=50K 53 88.33\n", "9 5 >50K 7 11.67\n", "10 6 <=50K 56 87.50\n", "11 6 >50K 8 12.50\n", "12 7 <=50K 22 84.62\n", "13 7 >50K 4 15.38\n", "14 8 <=50K 134 92.41\n", "15 8 >50K 11 7.59\n", "16 9 <=50K 17 94.44\n", "17 9 >50K 1 5.56\n", "18 10 <=50K 258 92.81\n", "19 10 >50K 20 7.19\n", 
"20 11 <=50K 11 100.00\n", "21 12 <=50K 161 93.06\n", "22 12 >50K 12 6.94\n", "23 13 <=50K 21 91.30\n", "24 13 >50K 2 8.70\n", "25 14 <=50K 32 94.12\n", "26 14 >50K 2 5.88\n", "27 15 <=50K 389 96.29\n", "28 15 >50K 15 3.71\n", "29 16 <=50K 192 93.66\n", ".. ... ... ... ...\n", "143 78 <=50K 6 75.00\n", "144 78 >50K 2 25.00\n", "145 80 <=50K 76 57.14\n", "146 80 >50K 57 42.86\n", "147 81 <=50K 3 100.00\n", "148 82 <=50K 1 100.00\n", "149 84 <=50K 28 62.22\n", "150 84 >50K 17 37.78\n", "151 85 <=50K 9 69.23\n", "152 85 >50K 4 30.77\n", "153 86 <=50K 2 100.00\n", "154 87 <=50K 1 100.00\n", "155 88 <=50K 2 100.00\n", "156 89 <=50K 1 50.00\n", "157 89 >50K 1 50.00\n", "158 90 <=50K 19 65.52\n", "159 90 >50K 10 34.48\n", "160 91 <=50K 3 100.00\n", "161 92 <=50K 1 100.00\n", "162 94 <=50K 1 100.00\n", "163 95 <=50K 1 50.00\n", "164 95 >50K 1 50.00\n", "165 96 <=50K 4 80.00\n", "166 96 >50K 1 20.00\n", "167 97 <=50K 1 50.00\n", "168 97 >50K 1 50.00\n", "169 98 <=50K 8 72.73\n", "170 98 >50K 3 27.27\n", "171 99 <=50K 60 70.59\n", "172 99 >50K 25 29.41\n", "\n", "[173 rows x 4 columns]" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = crosstabPercentage(train, 'hours-per-week', labelCol)\n", "df = df.withColumn('percentage', format_number(df['percentage'], 2))\n", "df.toPandas()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Outlier Detection" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We use Attribute-Value Frequency (AVF) to detect outliers in the categorical features. The appeal of this algorithm is that it is very simple, highly parallelizable, and fits the distributed programming paradigm well. The `attributeValueFrequency` function is implemented in `utils/pysparkutils.py`." 
] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>avfScore</th>\n", " <th>count</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>64873</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>84761</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>85760</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>103252</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>52051</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>5</th>\n", " <td>49136</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>6</th>\n", " <td>91948</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>7</th>\n", " <td>92741</td>\n", " <td>2</td>\n", " </tr>\n", " <tr>\n", " <th>8</th>\n", " <td>99489</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>9</th>\n", " <td>89041</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>10</th>\n", " <td>95526</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>11</th>\n", " <td>78598</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>12</th>\n", " <td>84745</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>13</th>\n", " <td>76448</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>14</th>\n", " <td>80545</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>15</th>\n", " <td>97699</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>16</th>\n", " <td>86445</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>17</th>\n", " <td>95149</td>\n", " <td>3</td>\n", " </tr>\n", " <tr>\n", 
" <th>18</th>\n", " <td>101519</td>\n", " <td>5</td>\n", " </tr>\n", " <tr>\n", " <th>19</th>\n", " <td>99688</td>\n", " <td>4</td>\n", " </tr>\n", " <tr>\n", " <th>20</th>\n", " <td>96113</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>21</th>\n", " <td>86132</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>22</th>\n", " <td>74783</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>23</th>\n", " <td>88291</td>\n", " <td>2</td>\n", " </tr>\n", " <tr>\n", " <th>24</th>\n", " <td>75232</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>25</th>\n", " <td>81293</td>\n", " <td>5</td>\n", " </tr>\n", " <tr>\n", " <th>26</th>\n", " <td>66091</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>27</th>\n", " <td>73190</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>28</th>\n", " <td>107959</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>29</th>\n", " <td>77034</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>...</th>\n", " <td>...</td>\n", " <td>...</td>\n", " </tr>\n", " <tr>\n", " <th>8159</th>\n", " <td>121083</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>8160</th>\n", " <td>90753</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>8161</th>\n", " <td>83486</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>8162</th>\n", " <td>88325</td>\n", " <td>3</td>\n", " </tr>\n", " <tr>\n", " <th>8163</th>\n", " <td>96995</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>8164</th>\n", " <td>100990</td>\n", " <td>4</td>\n", " </tr>\n", " <tr>\n", " <th>8165</th>\n", " <td>90597</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>8166</th>\n", " <td>100460</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>8167</th>\n", " <td>92453</td>\n", " <td>14</td>\n", " </tr>\n", " <tr>\n", " <th>8168</th>\n", " <td>95690</td>\n", " <td>4</td>\n", " </tr>\n", " <tr>\n", " <th>8169</th>\n", " <td>104486</td>\n", " <td>2</td>\n", " </tr>\n", " <tr>\n", " <th>8170</th>\n", " <td>119995</td>\n", " <td>13</td>\n", " </tr>\n", " 
<tr>\n", " <th>8171</th>\n", " <td>131889</td>\n", " <td>9</td>\n", " </tr>\n", " <tr>\n", " <th>8172</th>\n", " <td>110660</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>8173</th>\n", " <td>102790</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>8174</th>\n", " <td>122394</td>\n", " <td>2</td>\n", " </tr>\n", " <tr>\n", " <th>8175</th>\n", " <td>124053</td>\n", " <td>76</td>\n", " </tr>\n", " <tr>\n", " <th>8176</th>\n", " <td>134386</td>\n", " <td>45</td>\n", " </tr>\n", " <tr>\n", " <th>8177</th>\n", " <td>113783</td>\n", " <td>3</td>\n", " </tr>\n", " <tr>\n", " <th>8178</th>\n", " <td>94066</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>8179</th>\n", " <td>73337</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>8180</th>\n", " <td>91838</td>\n", " <td>2</td>\n", " </tr>\n", " <tr>\n", " <th>8181</th>\n", " <td>90865</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>8182</th>\n", " <td>100947</td>\n", " <td>3</td>\n", " </tr>\n", " <tr>\n", " <th>8183</th>\n", " <td>71321</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>8184</th>\n", " <td>110738</td>\n", " <td>4</td>\n", " </tr>\n", " <tr>\n", " <th>8185</th>\n", " <td>99204</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>8186</th>\n", " <td>38322</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>8187</th>\n", " <td>89328</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>8188</th>\n", " <td>65479</td>\n", " <td>1</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "<p>8189 rows × 2 columns</p>\n", "</div>" ], "text/plain": [ " avfScore count\n", "0 64873 1\n", "1 84761 1\n", "2 85760 1\n", "3 103252 1\n", "4 52051 1\n", "5 49136 1\n", "6 91948 1\n", "7 92741 2\n", "8 99489 1\n", "9 89041 1\n", "10 95526 1\n", "11 78598 1\n", "12 84745 1\n", "13 76448 1\n", "14 80545 1\n", "15 97699 1\n", "16 86445 1\n", "17 95149 3\n", "18 101519 5\n", "19 99688 4\n", "20 96113 1\n", "21 86132 1\n", "22 74783 1\n", "23 88291 2\n", "24 75232 1\n", "25 81293 5\n", "26 66091 
1\n", "27 73190 1\n", "28 107959 1\n", "29 77034 1\n", "... ... ...\n", "8159 121083 1\n", "8160 90753 1\n", "8161 83486 1\n", "8162 88325 3\n", "8163 96995 1\n", "8164 100990 4\n", "8165 90597 1\n", "8166 100460 1\n", "8167 92453 14\n", "8168 95690 4\n", "8169 104486 2\n", "8170 119995 13\n", "8171 131889 9\n", "8172 110660 1\n", "8173 102790 1\n", "8174 122394 2\n", "8175 124053 76\n", "8176 134386 45\n", "8177 113783 3\n", "8178 94066 1\n", "8179 73337 1\n", "8180 91838 2\n", "8181 90865 1\n", "8182 100947 3\n", "8183 71321 1\n", "8184 110738 4\n", "8185 99204 1\n", "8186 38322 1\n", "8187 89328 1\n", "8188 65479 1\n", "\n", "[8189 rows x 2 columns]" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import seaborn as sns\n", "\n", "avfScoreCol = 'avfScore'\n", "categoricalCols = ['workclass', 'education', 'marital-status',\n", " 'occupation', 'relationship', 'race', 'sex',\n", " 'native-country']\n", "avfScore = attributeValueFrequency(train, categoricalCols)\n", "pdf = avfScore.groupby(avfScoreCol).count().toPandas()\n", "pdf" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[<matplotlib.axes._subplots.AxesSubplot object at 0x7ff049874978>]], dtype=object)" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# sns.barplot(x=\"avfScore\", y=\"count\", data=pdf)\n", "# pdf['count'].hist(by=pdf['avfScore'])\n", "# pdf.plot(x='avfScore', y='count')\n", "# pdf.plot(x='avfScore', y='count', kind='bar')\n", "\n", "# Ideally, we want to use Spark for aggregation, and just plot the data by converting to \n", "# Pandas. 
Unfortunately I couldn't figure out a way yet, visualization is not my strongest skill.\n", "# This approach is NOT recommended for large datasets which are residing over multiple machines\n", "# Since this will bring whole dataframe to the driver node and the driver node might run out of\n", "# memory.\n", "avfScore.select(avfScoreCol).toPandas().hist(avfScoreCol)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In AVF, the lower a data point's score, the more likely it is to be an outlier. We remove the rows whose score is 70000 or lower (the cell below keeps only rows with `avfScore > 70000`). The cutoff of 70000 is a heuristic threshold (presumably read off the histogram above), so it should be revisited if the data changes." ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "root\n", " |-- age: integer (nullable = true)\n", " |-- workclass: string (nullable = true)\n", " |-- fnlwgt: integer (nullable = true)\n", " |-- education: string (nullable = true)\n", " |-- education-num: integer (nullable = true)\n", " |-- marital-status: string (nullable = true)\n", " |-- occupation: string (nullable = true)\n", " |-- relationship: string (nullable = true)\n", " |-- race: string (nullable = true)\n", " |-- sex: string (nullable = true)\n", " |-- capital-gain: integer (nullable = true)\n", " |-- capital-loss: integer (nullable = true)\n", " |-- hours-per-week: integer (nullable = true)\n", " |-- native-country: string (nullable = true)\n", " |-- class: string (nullable = true)\n", " |-- id: long (nullable = false)\n", "\n" ] } ], "source": [ "train = avfScore.filter(col(avfScoreCol) > 70000).drop(avfScoreCol)\n", "train.printSchema()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Feature Selection" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Chi-Square-based categorical feature selection" ] }, { "cell_type": "code", "execution_count": 19, 
"metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "pValues: [1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0]\n" ] } ], 
"source": [ "from pyspark.ml.feature import OneHotEncoderEstimator, StringIndexer, VectorAssembler\n", "from pyspark.ml.stat import ChiSquareTest\n", "from pyspark.ml import Pipeline\n", "\n", "indexed = train.select('education-num', 'education')\n", "\n", "indexer = StringIndexer(inputCol='education', outputCol='educationIndexed')\n", "indexed = indexer.fit(indexed).transform(indexed)\n", "ohe = OneHotEncoderEstimator(inputCols=['education-num',], outputCols=['education-numOHE',])\n", "indexed = ohe.fit(indexed).transform(indexed)\n", "\n", "# The null hypothesis is that the occurrence of the outcomes is statistically independent.\n", "# In general, p-values below the chosen significance level (commonly 1% to 5%) lead you to reject the null hypothesis.\n", "# Here almost every p-value is 0.0, so independence is rejected: education-num and education are dependent.\n", "# (The single 1.0 is presumably a one-hot slot that is constant across all rows.)\n", "testResult = ChiSquareTest.test(indexed, 'education-numOHE', 'educationIndexed')\n", "r = testResult.head()\n", "print(\"pValues: \" + str(r.pValues))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The p-values are essentially 0, so we reject the null hypothesis of independence: `education` and `education-num` are strongly dependent. We therefore drop the `education` feature, since its information is already covered by `education-num`." 
] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "train = train.drop('education')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Clustering" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "<pyspark.ml.clustering.KMeansSummary at 0x7ff049711390>" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from pyspark.ml.clustering import KMeans\n", "\n", "_, _, indexedDf = autoIndexer(train, labelCol)\n", "\n", "kmeans = KMeans(k=2, featuresCol='assembled')\n", "model = kmeans.fit(indexedDf)\n", "indexedDf = model.transform(indexedDf)\n", "model.summary" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.15349812641524918" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import numpy as np\n", "from sklearn.metrics import adjusted_mutual_info_score\n", "indexer = StringIndexer(inputCol=labelCol, outputCol=labelCol+'Indexed')\n", "indexedDf = indexer.fit(indexedDf).transform(indexedDf)\n", "classIndexed = [row[0] for row in indexedDf.select('classIndexed').collect()]\n", "prediction = [row[0] for row in indexedDf.select('prediction').collect()]\n", "\n", "adjusted_mutual_info_score(classIndexed, prediction)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Classification" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "from pyspark.ml.feature import StringIndexer\n", "from pyspark.ml import Pipeline\n", "\n", "stringTypes = [dtype[0] for dtype in train.dtypes if dtype[1] == 'string']\n", "indexedTypes = [stringType+'Indexed' for stringType in stringTypes]\n", "\n", "indexers = [StringIndexer(inputCol=stringType, outputCol=stringType+'Indexed', handleInvalid='skip') \\\n", " for stringType in stringTypes]" ] }, { "cell_type": "code", 
"execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DataFrame[age: int, workclass: string, fnlwgt: int, education-num: int, marital-status: string, occupation: string, relationship: string, race: string, sex: string, capital-gain: int, capital-loss: int, hours-per-week: int, native-country: string, class: string, id: bigint, workclassIndexed: double, marital-statusIndexed: double, occupationIndexed: double, relationshipIndexed: double, raceIndexed: double, sexIndexed: double, native-countryIndexed: double, classIndexed: double, workclassIndexedOneHotEncoded: vector, raceIndexedOneHotEncoded: vector, occupationIndexedOneHotEncoded: vector, relationshipIndexedOneHotEncoded: vector, native-countryIndexedOneHotEncoded: vector, marital-statusIndexedOneHotEncoded: vector, sexIndexedOneHotEncoded: vector, classIndexedOneHotEncoded: vector, assembled: vector, rawPrediction: vector, probability: vector, prediction: double]" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from pyspark.ml.feature import OneHotEncoderEstimator, VectorAssembler\n", "from pyspark.ml.classification import GBTClassifier\n", "\n", "oheTypes = [indexedType+'OneHotEncoded' for indexedType in indexedTypes]\n", "ohe = OneHotEncoderEstimator(inputCols=indexedTypes, outputCols=oheTypes)\n", "\n", "# Fix columns\n", "oheTypes.remove('classIndexedOneHotEncoded')\n", "cols = train.columns[:]\n", "for oheType in oheTypes:\n", " cols.append(oheType)\n", "for stringType in stringTypes:\n", " cols.remove(stringType)\n", "\n", "cols.remove('id')\n", "\n", "assembler = VectorAssembler(inputCols=cols, outputCol='assembled')\n", "classifier = GBTClassifier(featuresCol='assembled', labelCol='classIndexed')\n", "pipeline = Pipeline(stages=[*indexers, ohe, assembler, classifier])\n", "model = pipeline.fit(train)\n", "train = model.transform(train)\n", "train" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Since we have class 
imbalance (roughly 3:1 `<=50K` to `>50K` in the raw training data), we use the area under the ROC curve as the metric. Note that the score below is computed on the same training data the model was fit on, so it is optimistic; the held-out test score is computed in the Evaluation section." ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.9179985970287455" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from pyspark.ml.evaluation import BinaryClassificationEvaluator\n", "evaluator = BinaryClassificationEvaluator(labelCol='classIndexed', metricName='areaUnderROC')\n", "metric = evaluator.evaluate(train)\n", "metric" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Evaluation" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>class</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td><=50K.</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td><=50K.</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>>50K.</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>>50K.</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td><=50K.</td>\n", " </tr>\n", " <tr>\n", " <th>5</th>\n", " <td><=50K.</td>\n", " </tr>\n", " <tr>\n", " <th>6</th>\n", " <td><=50K.</td>\n", " </tr>\n", " <tr>\n", " <th>7</th>\n", " <td>>50K.</td>\n", " </tr>\n", " <tr>\n", " <th>8</th>\n", " <td><=50K.</td>\n", " </tr>\n", " <tr>\n", " <th>9</th>\n", " <td><=50K.</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", 
"</div>" ], "text/plain": [ " class\n", "0 <=50K.\n", "1 <=50K.\n", "2 >50K.\n", "3 >50K.\n", "4 <=50K.\n", "5 <=50K.\n", "6 <=50K.\n", "7 >50K.\n", "8 <=50K.\n", "9 <=50K." 
] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "headers = [\"age\", \"workclass\", \"fnlwgt\", \"education\", \"education-num\",\n", " \"marital-status\", \"occupation\", \"relationship\", \"race\", \"sex\",\n", " \"capital-gain\", \"capital-loss\", \"hours-per-week\", \"native-country\",\n", " \"class\"]\n", "\n", "test = spark.read.csv('./adult.test.txt',\n", " inferSchema='true', \n", " ignoreLeadingWhiteSpace='true',\n", " ignoreTrailingWhiteSpace='true').toDF(*headers)\n", "test.select('class').limit(10).toPandas()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The class labels in the test set differ from those in the training set: the test labels carry a trailing dot (`'>50K.'` instead of `'>50K'`). We therefore strip that trailing dot from the test labels before evaluating." ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>class</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td><=50K</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td><=50K</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>>50K</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>>50K</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td><=50K</td>\n", " </tr>\n", " <tr>\n", " <th>5</th>\n", " <td><=50K</td>\n", " </tr>\n", " <tr>\n", " <th>6</th>\n", " <td><=50K</td>\n", " </tr>\n", " <tr>\n", " <th>7</th>\n", " <td>>50K</td>\n", " </tr>\n", " <tr>\n", " <th>8</th>\n", " <td><=50K</td>\n", " </tr>\n", " <tr>\n", " <th>9</th>\n", " <td><=50K</td>\n", " </tr>\n", " </tbody>\n", 
"</table>\n", 
"</div>" ], "text/plain": [ " class\n", "0 <=50K\n", "1 <=50K\n", "2 >50K\n", "3 >50K\n", "4 <=50K\n", "5 <=50K\n", "6 <=50K\n", "7 >50K\n", "8 <=50K\n", "9 <=50K" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from pyspark.sql.functions import regexp_replace\n", "\n", "# Remove only a *trailing* dot ('>50K.' -> '>50K'). Unlike a Python UDF doing s[:-1], this\n", "# runs natively in the JVM, passes null labels through instead of raising, and leaves\n", "# labels that have no trailing dot untouched.\n", "test = test.withColumn('class', regexp_replace(test['class'], r'\\.$', ''))\n", "test.select('class').limit(10).toPandas()" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.9096987015789428" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Score the held-out test set with the pipeline fitted on train. The result goes into a new\n", "# variable so re-running this cell does not re-transform an already-transformed frame\n", "# (the StringIndexers would fail because their output columns already exist).\n", "# NOTE: the StringIndexers use handleInvalid='skip', so test rows containing categories\n", "# unseen during training are dropped before scoring.\n", "testPredictions = model.transform(test)\n", "metric = evaluator.evaluate(testPredictions)\n", "metric" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.4" } }, "nbformat": 4, "nbformat_minor": 2 }