{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 3.3 Machine Learning - Complex Features\n",
    "\n",
    "This notebook demonstrates how to work with complex data sets that include both numerical and categorical features.\n",
    "\n",
    "We will use the `Adults` dataset, which includes numerical and categorical features of adults; our task is to predict whether their income is above $50K. (`data/adult.data.bz2`)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['age', 'workclass', 'fnlwgt', 'education', 'education-num', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'capital-gain', 'capital-loss', 'hours-per-week', 'native-country']\n"
     ]
    }
   ],
   "source": [
    "# Parse the data set description to obtain the names of the features.\n",
    "# Attribute lines have the form `name: ...`, where the name is made only of\n",
    "# lowercase letters and hyphens; everything before the first ':' is kept.\n",
    "\n",
    "import re\n",
    "\n",
    "featureNames = []\n",
    "with open('data/adult.names.txt', 'r') as namesFile:\n",
    "    for descriptionLine in namesFile.readlines():\n",
    "        if re.match(r'^[a-z\\-]+:.*', descriptionLine):\n",
    "            featureNames.append(descriptionLine.split(':')[0])\n",
    "\n",
    "print(featureNames)\n",
    "# the target column is appended after the feature columns\n",
    "columnNames = featureNames + ['income']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "root\n",
      " |-- age: long (nullable = true)\n",
      " |-- workclass: string (nullable = true)\n",
      " |-- fnlwgt: double (nullable = true)\n",
      " |-- education: string (nullable = true)\n",
      " |-- education-num: double (nullable = true)\n",
      " |-- marital-status: string (nullable = true)\n",
      " |-- occupation: string (nullable = true)\n",
      " |-- relationship: string (nullable = true)\n",
      " |-- race: string (nullable = true)\n",
      " |-- sex: string (nullable = true)\n",
      " |-- capital-gain: double (nullable = true)\n",
      " |-- capital-loss: double (nullable = true)\n",
      " |-- hours-per-week: double (nullable = true)\n",
      " |-- native-country: string (nullable = true)\n",
      " |-- income: string (nullable = true)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# load the data into a DataFrame\n",
    "# The columns are renamed directly on the parsed CSV with toDF(); this avoids\n",
    "# shipping every row through a Python RDD and inferring the schema a second time.\n",
    "inputDF = spark.read.csv('data/adult.data.bz2', inferSchema=True).toDF(*columnNames)\n",
    "inputDF.printSchema()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>workclass</th>\n",
       "      <th>fnlwgt</th>\n",
       "      <th>education</th>\n",
       "      <th>education-num</th>\n",
       "      <th>marital-status</th>\n",
       "      <th>occupation</th>\n",
       "      <th>relationship</th>\n",
       "      <th>race</th>\n",
       "      <th>sex</th>\n",
       "      <th>capital-gain</th>\n",
       "      <th>capital-loss</th>\n",
       "      <th>hours-per-week</th>\n",
       "      <th>native-country</th>\n",
       "      <th>income</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>39</td>\n",
       "      <td>State-gov</td>\n",
       "      <td>77516.0</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13.0</td>\n",
       "      <td>Never-married</td>\n",
       "      <td>Adm-clerical</td>\n",
       "      <td>Not-in-family</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>2174.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>40.0</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>50</td>\n",
       "      <td>Self-emp-not-inc</td>\n",
       "      <td>83311.0</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13.0</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Exec-managerial</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>13.0</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>38</td>\n",
       "      <td>Private</td>\n",
       "      <td>215646.0</td>\n",
       "      <td>HS-grad</td>\n",
       "      <td>9.0</td>\n",
       "      <td>Divorced</td>\n",
       "      <td>Handlers-cleaners</td>\n",
       "      <td>Not-in-family</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>40.0</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "DataFrame[age: bigint, workclass: string, fnlwgt: double, education: string, education-num: double, marital-status: string, occupation: string, relationship: string, race: string, sex: string, capital-gain: double, capital-loss: double, hours-per-week: double, native-country: string, income: string]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# preview the first three rows of the loaded data\n",
    "previewDF = inputDF.limit(3)\n",
    "display(previewDF)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>workclass</th>\n",
       "      <th>fnlwgt</th>\n",
       "      <th>education</th>\n",
       "      <th>education-num</th>\n",
       "      <th>marital-status</th>\n",
       "      <th>occupation</th>\n",
       "      <th>relationship</th>\n",
       "      <th>race</th>\n",
       "      <th>sex</th>\n",
       "      <th>capital-gain</th>\n",
       "      <th>capital-loss</th>\n",
       "      <th>hours-per-week</th>\n",
       "      <th>native-country</th>\n",
       "      <th>income</th>\n",
       "      <th>label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>39</td>\n",
       "      <td>State-gov</td>\n",
       "      <td>77516.0</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13.0</td>\n",
       "      <td>Never-married</td>\n",
       "      <td>Adm-clerical</td>\n",
       "      <td>Not-in-family</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>2174.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>40.0</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>50</td>\n",
       "      <td>Self-emp-not-inc</td>\n",
       "      <td>83311.0</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13.0</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Exec-managerial</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>13.0</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>38</td>\n",
       "      <td>Private</td>\n",
       "      <td>215646.0</td>\n",
       "      <td>HS-grad</td>\n",
       "      <td>9.0</td>\n",
       "      <td>Divorced</td>\n",
       "      <td>Handlers-cleaners</td>\n",
       "      <td>Not-in-family</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>40.0</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>53</td>\n",
       "      <td>Private</td>\n",
       "      <td>234721.0</td>\n",
       "      <td>11th</td>\n",
       "      <td>7.0</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Handlers-cleaners</td>\n",
       "      <td>Husband</td>\n",
       "      <td>Black</td>\n",
       "      <td>Male</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>40.0</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>28</td>\n",
       "      <td>Private</td>\n",
       "      <td>338409.0</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13.0</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Prof-specialty</td>\n",
       "      <td>Wife</td>\n",
       "      <td>Black</td>\n",
       "      <td>Female</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>40.0</td>\n",
       "      <td>Cuba</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>37</td>\n",
       "      <td>Private</td>\n",
       "      <td>284582.0</td>\n",
       "      <td>Masters</td>\n",
       "      <td>14.0</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Exec-managerial</td>\n",
       "      <td>Wife</td>\n",
       "      <td>White</td>\n",
       "      <td>Female</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>40.0</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>49</td>\n",
       "      <td>Private</td>\n",
       "      <td>160187.0</td>\n",
       "      <td>9th</td>\n",
       "      <td>5.0</td>\n",
       "      <td>Married-spouse-absent</td>\n",
       "      <td>Other-service</td>\n",
       "      <td>Not-in-family</td>\n",
       "      <td>Black</td>\n",
       "      <td>Female</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>16.0</td>\n",
       "      <td>Jamaica</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>52</td>\n",
       "      <td>Self-emp-not-inc</td>\n",
       "      <td>209642.0</td>\n",
       "      <td>HS-grad</td>\n",
       "      <td>9.0</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Exec-managerial</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>45.0</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "DataFrame[age: bigint, workclass: string, fnlwgt: double, education: string, education-num: double, marital-status: string, occupation: string, relationship: string, race: string, sex: string, capital-gain: double, capital-loss: double, hours-per-week: double, native-country: string, income: string, label: double]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from pyspark.sql.functions import trim\n",
    "\n",
    "# Create the numerical label column: 1.0 when the trimmed income value\n",
    "# starts with '>' (income > 50K), 0.0 otherwise.\n",
    "isHighIncome = trim(inputDF.income).startswith('>')\n",
    "dataDF = inputDF.withColumn('label', isHighIncome.cast('double'))\n",
    "display(dataDF.limit(8))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'native-country']\n",
      "root\n",
      " |-- age: long (nullable = true)\n",
      " |-- workclass: string (nullable = true)\n",
      " |-- fnlwgt: double (nullable = true)\n",
      " |-- education: string (nullable = true)\n",
      " |-- education-num: double (nullable = true)\n",
      " |-- marital-status: string (nullable = true)\n",
      " |-- occupation: string (nullable = true)\n",
      " |-- relationship: string (nullable = true)\n",
      " |-- race: string (nullable = true)\n",
      " |-- sex: string (nullable = true)\n",
      " |-- capital-gain: double (nullable = true)\n",
      " |-- capital-loss: double (nullable = true)\n",
      " |-- hours-per-week: double (nullable = true)\n",
      " |-- native-country: string (nullable = true)\n",
      " |-- income: string (nullable = true)\n",
      " |-- label: double (nullable = true)\n",
      " |-- cat_workclass: double (nullable = true)\n",
      " |-- cat_education: double (nullable = true)\n",
      " |-- cat_marital-status: double (nullable = true)\n",
      " |-- cat_occupation: double (nullable = true)\n",
      " |-- cat_relationship: double (nullable = true)\n",
      " |-- cat_race: double (nullable = true)\n",
      " |-- cat_sex: double (nullable = true)\n",
      " |-- cat_native-country: double (nullable = true)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from pyspark.ml.feature import StringIndexer\n",
    "from pyspark.ml import Pipeline\n",
    "from pyspark.sql.types import StringType\n",
    "\n",
    "# Detect the categorical features from the schema: every string column\n",
    "# except the raw `income` target.\n",
    "catrgoricalFeatures = []\n",
    "for field in dataDF.schema.fields:\n",
    "    if field.name != 'income' and field.dataType == StringType():\n",
    "        catrgoricalFeatures.append(field.name)\n",
    "print(catrgoricalFeatures)\n",
    "\n",
    "# One StringIndexer per categorical feature, writing the index to `cat_<feature>`.\n",
    "# handleInvalid='skip' asks Spark to drop rows whose value was not seen during fit.\n",
    "categoryIndexers = []\n",
    "for feature in catrgoricalFeatures:\n",
    "    categoryIndexers.append(\n",
    "        StringIndexer(inputCol=feature, outputCol=\"cat_%s\"% feature, handleInvalid='skip'))\n",
    "indexerPipeline = Pipeline(stages = categoryIndexers)\n",
    "\n",
    "# fit the indexers on the full data set and inspect the resulting schema\n",
    "pipelineModel = indexerPipeline.fit(dataDF)\n",
    "indexed_df = pipelineModel.transform(dataDF)\n",
    "indexed_df.printSchema()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "root\n",
      " |-- age: long (nullable = true)\n",
      " |-- workclass: string (nullable = true)\n",
      " |-- fnlwgt: double (nullable = true)\n",
      " |-- education: string (nullable = true)\n",
      " |-- education-num: double (nullable = true)\n",
      " |-- marital-status: string (nullable = true)\n",
      " |-- occupation: string (nullable = true)\n",
      " |-- relationship: string (nullable = true)\n",
      " |-- race: string (nullable = true)\n",
      " |-- sex: string (nullable = true)\n",
      " |-- capital-gain: double (nullable = true)\n",
      " |-- capital-loss: double (nullable = true)\n",
      " |-- hours-per-week: double (nullable = true)\n",
      " |-- native-country: string (nullable = true)\n",
      " |-- income: string (nullable = true)\n",
      " |-- label: double (nullable = true)\n",
      " |-- cat_workclass: double (nullable = true)\n",
      " |-- cat_education: double (nullable = true)\n",
      " |-- cat_marital-status: double (nullable = true)\n",
      " |-- cat_occupation: double (nullable = true)\n",
      " |-- cat_relationship: double (nullable = true)\n",
      " |-- cat_race: double (nullable = true)\n",
      " |-- cat_sex: double (nullable = true)\n",
      " |-- cat_native-country: double (nullable = true)\n",
      " |-- cat_vector: vector (nullable = true)\n",
      " |-- cat_features: vector (nullable = true)\n",
      "\n",
      "+--------------------+\n",
      "|        cat_features|\n",
      "+--------------------+\n",
      "|[4.0,2.0,1.0,3.0,...|\n",
      "|(8,[0,1,3],[1.0,2...|\n",
      "|(8,[2,3,4],[2.0,9...|\n",
      "|(8,[1,3,5],[5.0,9...|\n",
      "|[0.0,2.0,0.0,0.0,...|\n",
      "+--------------------+\n",
      "only showing top 5 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from pyspark.ml.feature import VectorAssembler, VectorIndexer\n",
    "\n",
    "# assemble and mark the categorical features\n",
    "#\n",
    "# VectorIndexer only marks a feature as categorical when it has at most\n",
    "# `maxCategories` distinct values (default 20); anything above the limit is\n",
    "# silently left as a continuous feature. High-cardinality columns such as\n",
    "# `native-country` presumably exceed the default (verify against the data), so\n",
    "# the limit is raised to 50, matching the maxBins=50 used by the random forest.\n",
    "\n",
    "categorical_assembler = Pipeline(stages = [\n",
    "    VectorAssembler(inputCols = [\"cat_%s\"%c for c in catrgoricalFeatures ], outputCol='cat_vector'),\n",
    "    VectorIndexer(inputCol='cat_vector', outputCol='cat_features', maxCategories=50)\n",
    "])\n",
    "\n",
    "# fit on the indexed full data set and inspect the result\n",
    "categorical_assembler_model = categorical_assembler.fit(indexed_df)\n",
    "cat_df = categorical_assembler_model.transform(indexed_df)\n",
    "cat_df.printSchema()\n",
    "cat_df.select('cat_features').show(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "root\n",
      " |-- age: long (nullable = true)\n",
      " |-- workclass: string (nullable = true)\n",
      " |-- fnlwgt: double (nullable = true)\n",
      " |-- education: string (nullable = true)\n",
      " |-- education-num: double (nullable = true)\n",
      " |-- marital-status: string (nullable = true)\n",
      " |-- occupation: string (nullable = true)\n",
      " |-- relationship: string (nullable = true)\n",
      " |-- race: string (nullable = true)\n",
      " |-- sex: string (nullable = true)\n",
      " |-- capital-gain: double (nullable = true)\n",
      " |-- capital-loss: double (nullable = true)\n",
      " |-- hours-per-week: double (nullable = true)\n",
      " |-- native-country: string (nullable = true)\n",
      " |-- income: string (nullable = true)\n",
      " |-- label: double (nullable = true)\n",
      " |-- cat_workclass: double (nullable = true)\n",
      " |-- cat_education: double (nullable = true)\n",
      " |-- cat_marital-status: double (nullable = true)\n",
      " |-- cat_occupation: double (nullable = true)\n",
      " |-- cat_relationship: double (nullable = true)\n",
      " |-- cat_race: double (nullable = true)\n",
      " |-- cat_sex: double (nullable = true)\n",
      " |-- cat_native-country: double (nullable = true)\n",
      " |-- cat_vector: vector (nullable = true)\n",
      " |-- cat_features: vector (nullable = true)\n",
      " |-- features: vector (nullable = true)\n",
      "\n",
      "+--------------------+\n",
      "|            features|\n",
      "+--------------------+\n",
      "|[39.0,77516.0,40....|\n",
      "|(11,[0,1,2,3,4,6]...|\n",
      "|(11,[0,1,2,5,6,7]...|\n",
      "|(11,[0,1,2,4,6,8]...|\n",
      "|[28.0,338409.0,40...|\n",
      "+--------------------+\n",
      "only showing top 5 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Combine the selected numerical columns with the categorical feature vector\n",
    "# into a single `features` column.\n",
    "numericalFeatures = ['age', 'fnlwgt', 'hours-per-week']\n",
    "feature_assembler = VectorAssembler(inputCols = numericalFeatures + ['cat_features'],\n",
    "                                    outputCol='features')\n",
    "feature_df = feature_assembler.transform(cat_df)\n",
    "\n",
    "feature_df.printSchema()\n",
    "feature_df.select('features').show(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# split data into training and testing sets with stratified sampling by label\n",
    "# A fixed seed keeps the split (and every metric computed from it) reproducible.\n",
    "# NOTE: subtract() is a set difference, so exact duplicate rows are collapsed\n",
    "# in the test set.\n",
    "trainingDF = dataDF.sampleBy('label', fractions = {0.0: 0.7, 1.0:0.7}, seed=42).cache()\n",
    "testDF = dataDF.subtract(trainingDF).cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "22658\n",
      "5414\n",
      "9890\n",
      "2425\n"
     ]
    }
   ],
   "source": [
    "from pyspark.sql.functions import col\n",
    "\n",
    "# sanity check: for each split print the total row count followed by the\n",
    "# number of rows with a positive label\n",
    "isPositive = col('label') == 1.0\n",
    "for splitDF in (trainingDF, testDF):\n",
    "    print(splitDF.count())\n",
    "    print(splitDF.where(isPositive).count())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from pyspark.ml.classification import RandomForestClassifier\n",
    "\n",
    "# Build and train the random forest classifier. The pipeline chains the string\n",
    "# indexers, the categorical assembler and the feature assembler in front of the\n",
    "# classifier, so it is fitted directly on the training split.\n",
    "rfClassifier = RandomForestClassifier(featuresCol='features', labelCol='label', maxBins=50)\n",
    "rfStages = [indexerPipeline, categorical_assembler, feature_assembler, rfClassifier]\n",
    "rfPipeline = Pipeline(stages = rfStages)\n",
    "rfPipelineModel = rfPipeline.fit(trainingDF)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Random forest AUC: 0.886826830643\n"
     ]
    }
   ],
   "source": [
    "from pyspark.ml.evaluation import BinaryClassificationEvaluator\n",
    "\n",
    "# score the test split with the fitted pipeline and report the random forest AUC\n",
    "evaluator = BinaryClassificationEvaluator(rawPredictionCol=\"rawPrediction\")\n",
    "rfAuc = evaluator.evaluate(rfPipelineModel.transform(testDF))\n",
    "print(\"Random forest AUC: %s\" % rfAuc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Random forest accuracy: 0.824350288199\n"
     ]
    }
   ],
   "source": [
    "from pyspark.ml.evaluation import MulticlassClassificationEvaluator\n",
    "# report the random forest accuracy on the test split\n",
    "evaluator = MulticlassClassificationEvaluator(predictionCol=\"prediction\", metricName='accuracy')\n",
    "rfAccuracy = evaluator.evaluate(rfPipelineModel.transform(testDF))\n",
    "print(\"Random forest accuracy: %s\" % rfAccuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# keep only the true label and the predicted class of the test split, cached for reuse\n",
    "rfTestPredictions = rfPipelineModel.transform(testDF)\n",
    "predictionDF = rfTestPredictions.select('label','prediction').cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(7104, 360)\n",
      "(1377, 1048)\n"
     ]
    }
   ],
   "source": [
    "# compute the confusion matrix\n",
    "# The values are raw counts (not rates), hence the names tp/tn/fp/fn. They are\n",
    "# gathered with a single aggregation instead of four separate count() jobs.\n",
    "confusionCounts = {}\n",
    "for row in predictionDF.groupBy('label', 'prediction').count().collect():\n",
    "    confusionCounts[(row['label'], row['prediction'])] = int(row['count'])\n",
    "\n",
    "# a (label, prediction) combination that never occurs counts as 0\n",
    "tp = confusionCounts.get((1.0, 1.0), 0)\n",
    "tn = confusionCounts.get((0.0, 0.0), 0)\n",
    "fp = confusionCounts.get((0.0, 1.0), 0)\n",
    "fn = confusionCounts.get((1.0, 0.0), 0)\n",
    "\n",
    "# rows: actual label (0, 1); columns: predicted label (0, 1)\n",
    "print(tn, fp)\n",
    "print(fn, tp)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To use logistic regression we need to encode the categorical features differently: using one-hot encoding."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "root\n",
      " |-- age: long (nullable = true)\n",
      " |-- workclass: string (nullable = true)\n",
      " |-- fnlwgt: double (nullable = true)\n",
      " |-- education: string (nullable = true)\n",
      " |-- education-num: double (nullable = true)\n",
      " |-- marital-status: string (nullable = true)\n",
      " |-- occupation: string (nullable = true)\n",
      " |-- relationship: string (nullable = true)\n",
      " |-- race: string (nullable = true)\n",
      " |-- sex: string (nullable = true)\n",
      " |-- capital-gain: double (nullable = true)\n",
      " |-- capital-loss: double (nullable = true)\n",
      " |-- hours-per-week: double (nullable = true)\n",
      " |-- native-country: string (nullable = true)\n",
      " |-- income: string (nullable = true)\n",
      " |-- label: double (nullable = true)\n",
      " |-- cat_workclass: double (nullable = true)\n",
      " |-- cat_education: double (nullable = true)\n",
      " |-- cat_marital-status: double (nullable = true)\n",
      " |-- cat_occupation: double (nullable = true)\n",
      " |-- cat_relationship: double (nullable = true)\n",
      " |-- cat_race: double (nullable = true)\n",
      " |-- cat_sex: double (nullable = true)\n",
      " |-- cat_native-country: double (nullable = true)\n",
      " |-- enc_workclass: vector (nullable = true)\n",
      " |-- enc_education: vector (nullable = true)\n",
      " |-- enc_marital-status: vector (nullable = true)\n",
      " |-- enc_occupation: vector (nullable = true)\n",
      " |-- enc_relationship: vector (nullable = true)\n",
      " |-- enc_race: vector (nullable = true)\n",
      " |-- enc_sex: vector (nullable = true)\n",
      " |-- enc_native-country: vector (nullable = true)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from pyspark.ml.feature import OneHotEncoder\n",
    "\n",
    "# Create and test one OneHotEncoder per categorical feature: cat_<feature> -> enc_<feature>.\n",
    "# NOTE: this rebinds `categorical_assembler`, `categorical_assembler_model` and\n",
    "# `cat_df`, which earlier cells used for the VectorIndexer-based encoding.\n",
    "oneHotEncoders = []\n",
    "for feature in catrgoricalFeatures:\n",
    "    oneHotEncoders.append(\n",
    "        OneHotEncoder(inputCol=\"cat_%s\" % feature, outputCol= \"enc_%s\"% feature))\n",
    "categorical_assembler = Pipeline(stages = oneHotEncoders)\n",
    "\n",
    "categorical_assembler_model = categorical_assembler.fit(indexed_df)\n",
    "cat_df = categorical_assembler_model.transform(indexed_df)\n",
    "cat_df.printSchema()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from pyspark.ml.classification import LogisticRegression\n",
    "\n",
    "\n",
    "# Train the logistic regression classifier on the numerical columns plus the\n",
    "# one-hot encoded categorical features.\n",
    "lrClassifier = LogisticRegression(featuresCol='features', labelCol='label')\n",
    "encodedFeatureCols = [\"enc_%s\"% feature for feature in catrgoricalFeatures ]\n",
    "encFetureAssembler = VectorAssembler(inputCols = ['age', 'fnlwgt', 'hours-per-week'] + encodedFeatureCols,\n",
    "                                     outputCol='features')\n",
    "lrStages = [indexerPipeline, categorical_assembler, encFetureAssembler, lrClassifier]\n",
    "lrPipeline = Pipeline(stages = lrStages)\n",
    "# NOTE: from here on `lrPipeline` refers to the fitted model, not the Pipeline definition\n",
    "lrPipeline = lrPipeline.fit(trainingDF)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Logisting regression AUC: 0.891875117402\n"
     ]
    }
   ],
   "source": [
    "# evaluate the logistic regression AUC on the test split\n",
    "evaluator = BinaryClassificationEvaluator(rawPredictionCol=\"rawPrediction\")\n",
    "print(\"Logistic regression AUC: %s\" % evaluator.evaluate(lrPipeline.transform(testDF)))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "PySpark",
   "language": "python",
   "name": "pyspark"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}