{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import sys\n",
    "\n",
    "import pandas as pd\n",
    "from pyspark.sql import SparkSession\n",
    "\n",
    "# Make the parent directory importable so the local `utils` package resolves.\n",
    "# Guarded so that re-running this cell does not keep appending duplicates.\n",
    "if '..' not in sys.path:\n",
    "    sys.path.append('..')\n",
    "# Wildcard import kept last, as before; it supplies the helpers called below,\n",
    "# e.g. findMissingValuesCols, calcNormalizedPointwiseMutualInformation, toPandasDF.\n",
    "from utils.pysparkutils import *\n",
    "\n",
    "# Reuse the active Spark session if there is one, otherwise start a new one.\n",
    "spark = SparkSession.builder.appName(\"titanic\").getOrCreate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "root\n",
      " |-- PassengerId: integer (nullable = true)\n",
      " |-- Survived: integer (nullable = true)\n",
      " |-- Pclass: integer (nullable = true)\n",
      " |-- Name: string (nullable = true)\n",
      " |-- Sex: string (nullable = true)\n",
      " |-- Age: double (nullable = true)\n",
      " |-- SibSp: integer (nullable = true)\n",
      " |-- Parch: integer (nullable = true)\n",
      " |-- Ticket: string (nullable = true)\n",
      " |-- Fare: double (nullable = true)\n",
      " |-- Cabin: string (nullable = true)\n",
      " |-- Embarked: string (nullable = true)\n",
      "\n",
      "root\n",
      " |-- PassengerId: integer (nullable = true)\n",
      " |-- Pclass: integer (nullable = true)\n",
      " |-- Name: string (nullable = true)\n",
      " |-- Sex: string (nullable = true)\n",
      " |-- Age: double (nullable = true)\n",
      " |-- SibSp: integer (nullable = true)\n",
      " |-- Parch: integer (nullable = true)\n",
      " |-- Ticket: string (nullable = true)\n",
      " |-- Fare: double (nullable = true)\n",
      " |-- Cabin: string (nullable = true)\n",
      " |-- Embarked: string (nullable = true)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Both files are read with the same settings: a header row and inferred column types.\n",
    "csv_options = {'header': \"true\", 'inferSchema': \"true\"}\n",
    "train = spark.read.csv('./train.csv', **csv_options)\n",
    "test = spark.read.csv('./test.csv', **csv_options)\n",
    "\n",
    "for frame in (train, test):\n",
    "    frame.printSchema()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>PassengerId</th>\n",
       "      <th>Survived</th>\n",
       "      <th>Pclass</th>\n",
       "      <th>Name</th>\n",
       "      <th>Sex</th>\n",
       "      <th>Age</th>\n",
       "      <th>SibSp</th>\n",
       "      <th>Parch</th>\n",
       "      <th>Ticket</th>\n",
       "      <th>Fare</th>\n",
       "      <th>Cabin</th>\n",
       "      <th>Embarked</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>3</td>\n",
       "      <td>Braund, Mr. Owen Harris</td>\n",
       "      <td>male</td>\n",
       "      <td>22.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>A/5 21171</td>\n",
       "      <td>7.2500</td>\n",
       "      <td>None</td>\n",
       "      <td>S</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>Cumings, Mrs. John Bradley (Florence Briggs Th...</td>\n",
       "      <td>female</td>\n",
       "      <td>38.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>PC 17599</td>\n",
       "      <td>71.2833</td>\n",
       "      <td>C85</td>\n",
       "      <td>C</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>3</td>\n",
       "      <td>1</td>\n",
       "      <td>3</td>\n",
       "      <td>Heikkinen, Miss. Laina</td>\n",
       "      <td>female</td>\n",
       "      <td>26.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>STON/O2. 3101282</td>\n",
       "      <td>7.9250</td>\n",
       "      <td>None</td>\n",
       "      <td>S</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>Futrelle, Mrs. Jacques Heath (Lily May Peel)</td>\n",
       "      <td>female</td>\n",
       "      <td>35.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>113803</td>\n",
       "      <td>53.1000</td>\n",
       "      <td>C123</td>\n",
       "      <td>S</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>5</td>\n",
       "      <td>0</td>\n",
       "      <td>3</td>\n",
       "      <td>Allen, Mr. William Henry</td>\n",
       "      <td>male</td>\n",
       "      <td>35.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>373450</td>\n",
       "      <td>8.0500</td>\n",
       "      <td>None</td>\n",
       "      <td>S</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>6</td>\n",
       "      <td>0</td>\n",
       "      <td>3</td>\n",
       "      <td>Moran, Mr. James</td>\n",
       "      <td>male</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>330877</td>\n",
       "      <td>8.4583</td>\n",
       "      <td>None</td>\n",
       "      <td>Q</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>7</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>McCarthy, Mr. Timothy J</td>\n",
       "      <td>male</td>\n",
       "      <td>54.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>17463</td>\n",
       "      <td>51.8625</td>\n",
       "      <td>E46</td>\n",
       "      <td>S</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>8</td>\n",
       "      <td>0</td>\n",
       "      <td>3</td>\n",
       "      <td>Palsson, Master. Gosta Leonard</td>\n",
       "      <td>male</td>\n",
       "      <td>2.0</td>\n",
       "      <td>3</td>\n",
       "      <td>1</td>\n",
       "      <td>349909</td>\n",
       "      <td>21.0750</td>\n",
       "      <td>None</td>\n",
       "      <td>S</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>9</td>\n",
       "      <td>1</td>\n",
       "      <td>3</td>\n",
       "      <td>Johnson, Mrs. Oscar W (Elisabeth Vilhelmina Berg)</td>\n",
       "      <td>female</td>\n",
       "      <td>27.0</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "      <td>347742</td>\n",
       "      <td>11.1333</td>\n",
       "      <td>None</td>\n",
       "      <td>S</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>10</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>Nasser, Mrs. Nicholas (Adele Achem)</td>\n",
       "      <td>female</td>\n",
       "      <td>14.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>237736</td>\n",
       "      <td>30.0708</td>\n",
       "      <td>None</td>\n",
       "      <td>C</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>11</td>\n",
       "      <td>1</td>\n",
       "      <td>3</td>\n",
       "      <td>Sandstrom, Miss. Marguerite Rut</td>\n",
       "      <td>female</td>\n",
       "      <td>4.0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>PP 9549</td>\n",
       "      <td>16.7000</td>\n",
       "      <td>G6</td>\n",
       "      <td>S</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>12</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>Bonnell, Miss. Elizabeth</td>\n",
       "      <td>female</td>\n",
       "      <td>58.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>113783</td>\n",
       "      <td>26.5500</td>\n",
       "      <td>C103</td>\n",
       "      <td>S</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>13</td>\n",
       "      <td>0</td>\n",
       "      <td>3</td>\n",
       "      <td>Saundercock, Mr. William Henry</td>\n",
       "      <td>male</td>\n",
       "      <td>20.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>A/5. 2151</td>\n",
       "      <td>8.0500</td>\n",
       "      <td>None</td>\n",
       "      <td>S</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>14</td>\n",
       "      <td>0</td>\n",
       "      <td>3</td>\n",
       "      <td>Andersson, Mr. Anders Johan</td>\n",
       "      <td>male</td>\n",
       "      <td>39.0</td>\n",
       "      <td>1</td>\n",
       "      <td>5</td>\n",
       "      <td>347082</td>\n",
       "      <td>31.2750</td>\n",
       "      <td>None</td>\n",
       "      <td>S</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>15</td>\n",
       "      <td>0</td>\n",
       "      <td>3</td>\n",
       "      <td>Vestrom, Miss. Hulda Amanda Adolfina</td>\n",
       "      <td>female</td>\n",
       "      <td>14.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>350406</td>\n",
       "      <td>7.8542</td>\n",
       "      <td>None</td>\n",
       "      <td>S</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>16</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>Hewlett, Mrs. (Mary D Kingcome)</td>\n",
       "      <td>female</td>\n",
       "      <td>55.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>248706</td>\n",
       "      <td>16.0000</td>\n",
       "      <td>None</td>\n",
       "      <td>S</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>17</td>\n",
       "      <td>0</td>\n",
       "      <td>3</td>\n",
       "      <td>Rice, Master. Eugene</td>\n",
       "      <td>male</td>\n",
       "      <td>2.0</td>\n",
       "      <td>4</td>\n",
       "      <td>1</td>\n",
       "      <td>382652</td>\n",
       "      <td>29.1250</td>\n",
       "      <td>None</td>\n",
       "      <td>Q</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>18</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>Williams, Mr. Charles Eugene</td>\n",
       "      <td>male</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>244373</td>\n",
       "      <td>13.0000</td>\n",
       "      <td>None</td>\n",
       "      <td>S</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>19</td>\n",
       "      <td>0</td>\n",
       "      <td>3</td>\n",
       "      <td>Vander Planke, Mrs. Julius (Emelia Maria Vande...</td>\n",
       "      <td>female</td>\n",
       "      <td>31.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>345763</td>\n",
       "      <td>18.0000</td>\n",
       "      <td>None</td>\n",
       "      <td>S</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>20</td>\n",
       "      <td>1</td>\n",
       "      <td>3</td>\n",
       "      <td>Masselmani, Mrs. Fatima</td>\n",
       "      <td>female</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>2649</td>\n",
       "      <td>7.2250</td>\n",
       "      <td>None</td>\n",
       "      <td>C</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    PassengerId  Survived  Pclass  \\\n",
       "0             1         0       3   \n",
       "1             2         1       1   \n",
       "2             3         1       3   \n",
       "3             4         1       1   \n",
       "4             5         0       3   \n",
       "5             6         0       3   \n",
       "6             7         0       1   \n",
       "7             8         0       3   \n",
       "8             9         1       3   \n",
       "9            10         1       2   \n",
       "10           11         1       3   \n",
       "11           12         1       1   \n",
       "12           13         0       3   \n",
       "13           14         0       3   \n",
       "14           15         0       3   \n",
       "15           16         1       2   \n",
       "16           17         0       3   \n",
       "17           18         1       2   \n",
       "18           19         0       3   \n",
       "19           20         1       3   \n",
       "\n",
       "                                                 Name     Sex   Age  SibSp  \\\n",
       "0                             Braund, Mr. Owen Harris    male  22.0      1   \n",
       "1   Cumings, Mrs. John Bradley (Florence Briggs Th...  female  38.0      1   \n",
       "2                              Heikkinen, Miss. Laina  female  26.0      0   \n",
       "3        Futrelle, Mrs. Jacques Heath (Lily May Peel)  female  35.0      1   \n",
       "4                            Allen, Mr. William Henry    male  35.0      0   \n",
       "5                                    Moran, Mr. James    male   NaN      0   \n",
       "6                             McCarthy, Mr. Timothy J    male  54.0      0   \n",
       "7                      Palsson, Master. Gosta Leonard    male   2.0      3   \n",
       "8   Johnson, Mrs. Oscar W (Elisabeth Vilhelmina Berg)  female  27.0      0   \n",
       "9                 Nasser, Mrs. Nicholas (Adele Achem)  female  14.0      1   \n",
       "10                    Sandstrom, Miss. Marguerite Rut  female   4.0      1   \n",
       "11                           Bonnell, Miss. Elizabeth  female  58.0      0   \n",
       "12                     Saundercock, Mr. William Henry    male  20.0      0   \n",
       "13                        Andersson, Mr. Anders Johan    male  39.0      1   \n",
       "14               Vestrom, Miss. Hulda Amanda Adolfina  female  14.0      0   \n",
       "15                   Hewlett, Mrs. (Mary D Kingcome)   female  55.0      0   \n",
       "16                               Rice, Master. Eugene    male   2.0      4   \n",
       "17                       Williams, Mr. Charles Eugene    male   NaN      0   \n",
       "18  Vander Planke, Mrs. Julius (Emelia Maria Vande...  female  31.0      1   \n",
       "19                            Masselmani, Mrs. Fatima  female   NaN      0   \n",
       "\n",
       "    Parch            Ticket     Fare Cabin Embarked  \n",
       "0       0         A/5 21171   7.2500  None        S  \n",
       "1       0          PC 17599  71.2833   C85        C  \n",
       "2       0  STON/O2. 3101282   7.9250  None        S  \n",
       "3       0            113803  53.1000  C123        S  \n",
       "4       0            373450   8.0500  None        S  \n",
       "5       0            330877   8.4583  None        Q  \n",
       "6       0             17463  51.8625   E46        S  \n",
       "7       1            349909  21.0750  None        S  \n",
       "8       2            347742  11.1333  None        S  \n",
       "9       0            237736  30.0708  None        C  \n",
       "10      1           PP 9549  16.7000    G6        S  \n",
       "11      0            113783  26.5500  C103        S  \n",
       "12      0         A/5. 2151   8.0500  None        S  \n",
       "13      5            347082  31.2750  None        S  \n",
       "14      0            350406   7.8542  None        S  \n",
       "15      0            248706  16.0000  None        S  \n",
       "16      1            382652  29.1250  None        Q  \n",
       "17      0            244373  13.0000  None        S  \n",
       "18      0            345763  18.0000  None        S  \n",
       "19      0              2649   7.2250  None        C  "
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Bring only 20 rows to the driver for a quick look at the raw data.\n",
    "train_preview = train.limit(20)\n",
    "train_preview.toPandas()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this section we will explore missing data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('Age', 0.19865319865319866),\n",
       " ('Cabin', 0.7710437710437711),\n",
       " ('Embarked', 0.002244668911335578)]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Helper from utils.pysparkutils; the output lists (column, fraction) pairs.\n",
    "missing_fractions = findMissingValuesCols(train)\n",
    "missing_fractions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
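    "We can see that almost 80% of the values in the Cabin column are missing, so we will drop that column.\n",
    "Very few values are missing in the Embarked column, so we will simply drop those rows.\n",
    "About 20% of the Age values are missing as well; for those we set up a median imputer."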
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.ml.feature import Imputer\n",
    "\n",
    "# Configure a median imputer that reads 'Age' and writes 'imputedAge'.\n",
    "ageImputer = Imputer(\n",
    "    strategy='median',\n",
    "    inputCols=['Age'],\n",
    "    outputCols=['imputedAge'],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "root\n",
      " |-- PassengerId: integer (nullable = true)\n",
      " |-- Survived: integer (nullable = true)\n",
      " |-- Pclass: integer (nullable = true)\n",
      " |-- Name: string (nullable = true)\n",
      " |-- Sex: string (nullable = true)\n",
      " |-- Age: double (nullable = true)\n",
      " |-- SibSp: integer (nullable = true)\n",
      " |-- Parch: integer (nullable = true)\n",
      " |-- Ticket: string (nullable = true)\n",
      " |-- Fare: double (nullable = true)\n",
      " |-- Embarked: string (nullable = true)\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "889"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Keep only rows whose Embarked value is present, then remove the Cabin column.\n",
    "embarked_present = train.Embarked.isNotNull()\n",
    "train = train.filter(embarked_present).drop('Cabin')\n",
    "train.printSchema()\n",
    "train.count()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Exploratory Data Analysis\n",
    "In the next few sections, we will explore the training data and the relationship between the features and the label.\n",
    "As we already know, most of the passengers on the Titanic did not survive. Our training data agrees: only a little over one-third of the passengers (340 of 889, about 38%) survived. Survival also differs by passenger class and sex, as the tables below show."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Survived</th>\n",
       "      <th>count</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>340</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0</td>\n",
       "      <td>549</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   Survived  count\n",
       "0         1    340\n",
       "1         0    549"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Name of the target column; later cells reuse `labelCol`.\n",
    "labelCol = 'Survived'\n",
    "survived_counts = train.groupby(labelCol).count()\n",
    "survived_counts.toPandas()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Survived_Sex</th>\n",
       "      <th>female</th>\n",
       "      <th>male</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>231</td>\n",
       "      <td>109</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0</td>\n",
       "      <td>81</td>\n",
       "      <td>468</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  Survived_Sex  female  male\n",
       "0            1     231   109\n",
       "1            0      81   468"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Contingency table: Survived values as rows, Sex values as columns.\n",
    "sex_by_survival = train.crosstab(labelCol, 'Sex')\n",
    "sex_by_survival.toPandas()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Pointwise Mutual Information (PMI) is a useful metric for exploring the relationship between two categorical features. PMI gives a scalar value for each pair of values, one taken from each feature. The value denotes how much information observing one value of the pair provides about the other.\n",
    "\n",
    "PMI can be normalized to the range [-1, +1]; the result is called Normalized PMI and takes the value:\n",
    "* -1 (in the limit) for never occurring together,\n",
    "* 0 for independence,\n",
    "* +1 for complete co-occurrence\n",
    "\n",
    "\n",
    "The `calcNormalizedPointwiseMutualInformation` function is implemented in the `pysparkutils.py` file in the `utils` directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Sex</th>\n",
       "      <th>Survived</th>\n",
       "      <th>Normalized PMI</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>female</td>\n",
       "      <td>0</td>\n",
       "      <td>-0.361721</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>female</td>\n",
       "      <td>1</td>\n",
       "      <td>0.490151</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>male</td>\n",
       "      <td>0</td>\n",
       "      <td>0.424895</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>male</td>\n",
       "      <td>1</td>\n",
       "      <td>-0.336078</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "      Sex  Survived  Normalized PMI\n",
       "2  female         0       -0.361721\n",
       "1  female         1        0.490151\n",
       "0    male         0        0.424895\n",
       "3    male         1       -0.336078"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Normalized PMI for each (Sex, Survived) value pair, shown as a pandas table.\n",
    "pmi_feature = 'Sex'\n",
    "pmis = calcNormalizedPointwiseMutualInformation(train, pmi_feature, labelCol)\n",
    "toPandasDF(pmis, 'Normalized PMI', pmi_feature, labelCol)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Survived_Pclass</th>\n",
       "      <th>1</th>\n",
       "      <th>2</th>\n",
       "      <th>3</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>134</td>\n",
       "      <td>87</td>\n",
       "      <td>119</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0</td>\n",
       "      <td>80</td>\n",
       "      <td>97</td>\n",
       "      <td>372</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  Survived_Pclass    1   2    3\n",
       "0               1  134  87  119\n",
       "1               0   80  97  372"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Contingency table: Survived values as rows, Pclass values as columns.\n",
    "pclass_by_survival = train.crosstab(labelCol, 'Pclass')\n",
    "pclass_by_survival.toPandas()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Pclass</th>\n",
       "      <th>Survived</th>\n",
       "      <th>Normalized PMI</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>-0.208445</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0.260544</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>-0.071421</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>0.091268</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>3</td>\n",
       "      <td>0</td>\n",
       "      <td>0.234674</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>3</td>\n",
       "      <td>1</td>\n",
       "      <td>-0.226840</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   Pclass  Survived  Normalized PMI\n",
       "0       1         0       -0.208445\n",
       "2       1         1        0.260544\n",
       "4       2         0       -0.071421\n",
       "3       2         1        0.091268\n",
       "5       3         0        0.234674\n",
       "1       3         1       -0.226840"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Normalized PMI for each (Pclass, Survived) value pair, shown as a pandas table.\n",
    "pmi_feature = 'Pclass'\n",
    "pmis = calcNormalizedPointwiseMutualInformation(train, pmi_feature, labelCol)\n",
    "toPandasDF(pmis, 'Normalized PMI', pmi_feature, labelCol)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Survived_Embarked</th>\n",
       "      <th>C</th>\n",
       "      <th>Q</th>\n",
       "      <th>S</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>93</td>\n",
       "      <td>30</td>\n",
       "      <td>217</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0</td>\n",
       "      <td>75</td>\n",
       "      <td>47</td>\n",
       "      <td>427</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  Survived_Embarked   C   Q    S\n",
       "0                 1  93  30  217\n",
       "1                 0  75  47  427"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Contingency table: Survived values as rows, Embarked values as columns.\n",
    "embarked_by_survival = train.crosstab(labelCol, 'Embarked')\n",
    "embarked_by_survival.toPandas()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Embarked</th>\n",
       "      <th>Survived</th>\n",
       "      <th>Normalized PMI</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>C</td>\n",
       "      <td>0</td>\n",
       "      <td>-0.131229</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>C</td>\n",
       "      <td>1</td>\n",
       "      <td>0.163804</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Q</td>\n",
       "      <td>0</td>\n",
       "      <td>-0.003966</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Q</td>\n",
       "      <td>1</td>\n",
       "      <td>0.005472</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>S</td>\n",
       "      <td>0</td>\n",
       "      <td>0.096935</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>S</td>\n",
       "      <td>1</td>\n",
       "      <td>-0.089810</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  Embarked  Survived  Normalized PMI\n",
       "5        C         0       -0.131229\n",
       "3        C         1        0.163804\n",
       "4        Q         0       -0.003966\n",
       "0        Q         1        0.005472\n",
       "1        S         0        0.096935\n",
       "2        S         1       -0.089810"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Normalized PMI for every (Embarked, label) pair, rendered as a pandas table.\n",
    "pmis = calcNormalizedPointwiseMutualInformation(\n",
    "    train, 'Embarked', labelCol\n",
    ")\n",
    "toPandasDF(\n",
    "    pmis, 'Normalized PMI', 'Embarked', labelCol\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Survived_SibSp</th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "      <th>2</th>\n",
       "      <th>3</th>\n",
       "      <th>4</th>\n",
       "      <th>5</th>\n",
       "      <th>8</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>208</td>\n",
       "      <td>112</td>\n",
       "      <td>13</td>\n",
       "      <td>4</td>\n",
       "      <td>3</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0</td>\n",
       "      <td>398</td>\n",
       "      <td>97</td>\n",
       "      <td>15</td>\n",
       "      <td>12</td>\n",
       "      <td>15</td>\n",
       "      <td>5</td>\n",
       "      <td>7</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  Survived_SibSp    0    1   2   3   4  5  8\n",
       "0              1  208  112  13   4   3  0  0\n",
       "1              0  398   97  15  12  15  5  7"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Contingency table of the label (rows) against `SibSp` (columns).\n",
    "train.crosstab(col1=labelCol, col2='SibSp').toPandas()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>SibSp</th>\n",
       "      <th>Survived</th>\n",
       "      <th>Normalized PMI</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0.076614</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>-0.074483</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>-0.128928</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0.162829</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>-0.034825</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>0.045891</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>3</td>\n",
       "      <td>0</td>\n",
       "      <td>0.045135</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>3</td>\n",
       "      <td>1</td>\n",
       "      <td>-0.078675</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>4</td>\n",
       "      <td>0</td>\n",
       "      <td>0.073413</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>4</td>\n",
       "      <td>1</td>\n",
       "      <td>-0.145939</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>5</td>\n",
       "      <td>0</td>\n",
       "      <td>0.093038</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>8</td>\n",
       "      <td>0</td>\n",
       "      <td>0.099500</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    SibSp  Survived  Normalized PMI\n",
       "4       0         0        0.076614\n",
       "6       0         1       -0.074483\n",
       "0       1         0       -0.128928\n",
       "3       1         1        0.162829\n",
       "7       2         0       -0.034825\n",
       "5       2         1        0.045891\n",
       "10      3         0        0.045135\n",
       "1       3         1       -0.078675\n",
       "2       4         0        0.073413\n",
       "11      4         1       -0.145939\n",
       "8       5         0        0.093038\n",
       "9       8         0        0.099500"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Normalized PMI for every (SibSp, label) pair, rendered as a pandas table.\n",
    "pmis = calcNormalizedPointwiseMutualInformation(\n",
    "    train, 'SibSp', labelCol\n",
    ")\n",
    "toPandasDF(\n",
    "    pmis, 'Normalized PMI', 'SibSp', labelCol\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Survived_Parch</th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "      <th>2</th>\n",
       "      <th>3</th>\n",
       "      <th>4</th>\n",
       "      <th>5</th>\n",
       "      <th>6</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>231</td>\n",
       "      <td>65</td>\n",
       "      <td>40</td>\n",
       "      <td>3</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0</td>\n",
       "      <td>445</td>\n",
       "      <td>53</td>\n",
       "      <td>40</td>\n",
       "      <td>2</td>\n",
       "      <td>4</td>\n",
       "      <td>4</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  Survived_Parch    0   1   2  3  4  5  6\n",
       "0              1  231  65  40  3  0  1  0\n",
       "1              0  445  53  40  2  4  4  1"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Contingency table of the label (rows) against `Parch` (columns).\n",
    "train.crosstab(col1=labelCol, col2='Parch').toPandas()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Parch</th>\n",
       "      <th>Survived</th>\n",
       "      <th>Normalized PMI</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0.092309</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>-0.083569</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>-0.112913</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0.139486</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>-0.068086</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>0.086419</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>3</td>\n",
       "      <td>0</td>\n",
       "      <td>-0.071231</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>3</td>\n",
       "      <td>1</td>\n",
       "      <td>0.079123</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4</td>\n",
       "      <td>0</td>\n",
       "      <td>0.089196</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>5</td>\n",
       "      <td>0</td>\n",
       "      <td>0.047902</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>5</td>\n",
       "      <td>1</td>\n",
       "      <td>-0.095475</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>6</td>\n",
       "      <td>0</td>\n",
       "      <td>0.070986</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    Parch  Survived  Normalized PMI\n",
       "5       0         0        0.092309\n",
       "7       0         1       -0.083569\n",
       "0       1         0       -0.112913\n",
       "4       1         1        0.139486\n",
       "8       2         0       -0.068086\n",
       "6       2         1        0.086419\n",
       "11      3         0       -0.071231\n",
       "1       3         1        0.079123\n",
       "3       4         0        0.089196\n",
       "9       5         0        0.047902\n",
       "10      5         1       -0.095475\n",
       "2       6         0        0.070986"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Normalized PMI for every (Parch, label) pair, rendered as a pandas table.\n",
    "pmis = calcNormalizedPointwiseMutualInformation(\n",
    "    train, 'Parch', labelCol\n",
    ")\n",
    "toPandasDF(\n",
    "    pmis, 'Normalized PMI', 'Parch', labelCol\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
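    "Next we calculate the normalized entropy of each categorical feature. Entropy plays the role for categorical features that variance plays for numeric ones: it measures how evenly the observations are spread across a feature's categories."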
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Feature</th>\n",
       "      <th>Entropy</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Sex</td>\n",
       "      <td>0.934919</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Pclass</td>\n",
       "      <td>0.907245</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Embarked</td>\n",
       "      <td>0.692048</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>SibSp</td>\n",
       "      <td>0.477435</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Parch</td>\n",
       "      <td>0.402510</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    Feature   Entropy\n",
       "0       Sex  0.934919\n",
       "1    Pclass  0.907245\n",
       "2  Embarked  0.692048\n",
       "3     SibSp  0.477435\n",
       "4     Parch  0.402510"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Categorical columns to score; the result is shown as one (Feature, Entropy) row per column.\n",
    "columns = [\n",
    "    'Sex',\n",
    "    'Pclass',\n",
    "    'Embarked',\n",
    "    'SibSp',\n",
    "    'Parch'\n",
    "]\n",
    "entropies = calcNormalizedEntropy(train, *columns)\n",
    "dictToPandasDF(entropies, 'Feature', 'Entropy')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next we test whether each categorical feature is independent of the label, using Pearson's chi-squared test."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "pValues: [0.0,0.0,6.02813466444e-06,4.02603175464e-07,4.93843854699e-11,0.0101897422598,0.0315461645121,4.25298058386e-05,0.00150622342036,0.612884928604,0.116808580457]\n",
      "degreesOfFreedom: [2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]\n",
      "statistics: [100.980407261,260.756342249,20.4792462347,25.6818141585,43.2014376768,6.60142083307,4.62299205997,16.755010532,10.0709867693,0.255995235761,2.45959925254]\n"
     ]
    }
   ],
   "source": [
    "from pyspark.ml import Pipeline\n",
    "from pyspark.ml.stat import ChiSquareTest\n",
    "# Imputer is imported explicitly so this cell does not depend on a star import or on another cell.\n",
    "from pyspark.ml.feature import Bucketizer, Imputer, OneHotEncoderEstimator, StringIndexer, VectorAssembler, VectorIndexer\n",
    "\n",
    "# Index the string-valued categorical columns.\n",
    "edaEmbarkedIndexer = StringIndexer(inputCol='Embarked', outputCol='indexedEmbarked')\n",
    "edaSexIndexer = StringIndexer(inputCol='Sex', outputCol='indexedSex')\n",
    "\n",
    "# Fill missing ages with the median so that Age can be bucketed below.\n",
    "edaAgeImputer = Imputer(inputCols=['Age'], outputCols=['imputedAge'], strategy='median')\n",
    "\n",
    "# Discretize the continuous columns so they enter the test as categories.\n",
    "ageSplits = [0, 16, 32, 48, 64, 200]\n",
    "edaAgeBucketizer = Bucketizer(splits=ageSplits, inputCol='imputedAge', outputCol='bucketedAge')\n",
    "\n",
    "fareSplits = [-float('inf'), 7.91, 14.454, 31, float('inf')]\n",
    "edaFareBucketizer = Bucketizer(splits=fareSplits, inputCol='Fare', outputCol='bucketedFare')\n",
    "\n",
    "# EDA-only stages carry the `eda` prefix so they cannot be confused with the\n",
    "# classification stages that are defined under the same generic names further down.\n",
    "edaOneHotEncoder = OneHotEncoderEstimator(inputCols=['indexedSex', 'indexedEmbarked', 'bucketedFare', 'bucketedAge'],\n",
    "                                          outputCols=['oneHotSex', 'oneHotEmbarked', 'oneHotFare', 'oneHotAge'])\n",
    "inputCols = ['Pclass', 'oneHotSex', 'oneHotEmbarked', 'oneHotFare', 'oneHotAge']\n",
    "edaAssembler = VectorAssembler(inputCols=inputCols, outputCol='features')\n",
    "\n",
    "edaPipeline = Pipeline(stages=[edaEmbarkedIndexer, edaSexIndexer, edaAgeImputer, edaAgeBucketizer,\n",
    "                               edaFareBucketizer, edaOneHotEncoder, edaAssembler])\n",
    "chiSqTrain = edaPipeline.fit(train).transform(train)\n",
    "\n",
    "# The test reports one result per slot of the assembled `features` vector. Because the vector\n",
    "# holds one-hot indicator columns, that is one result per indicator, not per original feature.\n",
    "r = ChiSquareTest.test(chiSqTrain, 'features', 'Survived').head()\n",
    "for field in ('pValues', 'degreesOfFreedom', 'statistics'):\n",
    "    print(field + \": \" + str(r[field]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Classification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.ml.feature import StringIndexer\n",
    "\n",
    "# String -> numeric index stages for the classification pipeline.\n",
    "# NOTE(review): with handleInvalid='skip' rows that cannot be indexed are filtered out rather\n",
    "# than raising - confirm that no rows of the test set are lost this way.\n",
    "embarkedIndexer = StringIndexer(\n",
    "    inputCol='Embarked',\n",
    "    outputCol='indexedEmbarked',\n",
    "    handleInvalid='skip'\n",
    ")\n",
    "sexFeatureIndexer = StringIndexer(\n",
    "    inputCol='Sex',\n",
    "    outputCol='indexedSex',\n",
    "    handleInvalid='skip'\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.ml.feature import Bucketizer\n",
    "\n",
    "# Bucket boundaries: five age bands between 0 and 200, four fare bands over the whole real line.\n",
    "ageSplits = [0, 16, 32, 48, 64, 200]\n",
    "fareSplits = [-float('inf'), 7.91, 14.454, 31, float('inf')]\n",
    "\n",
    "# NOTE(review): with handleInvalid='skip' rows whose value cannot be bucketed are filtered out\n",
    "# rather than scored - confirm that no rows of the test set are lost this way.\n",
    "ageBucketizer = Bucketizer(\n",
    "    splits=ageSplits,\n",
    "    inputCol='imputedAge',\n",
    "    outputCol='bucketedAge',\n",
    "    handleInvalid='skip'\n",
    ")\n",
    "fareBucketizer = Bucketizer(\n",
    "    splits=fareSplits,\n",
    "    inputCol='Fare',\n",
    "    outputCol='bucketedFare',\n",
    "    handleInvalid='skip'\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.ml.feature import OneHotEncoderEstimator, VectorIndexer\n",
    "from pyspark.ml.feature import VectorAssembler\n",
    "from pyspark.ml.classification import RandomForestClassifier\n",
    "\n",
    "oneHotEncoderEstimator = OneHotEncoderEstimator(\n",
    "    inputCols=['indexedSex', 'indexedEmbarked', 'bucketedFare', 'bucketedAge'],\n",
    "    outputCols=['oneHotSex', 'oneHotEmbarked', 'oneHotFare', 'oneHotAge'])\n",
    "\n",
    "# NOTE(review): the assembler reads the indexed / bucketed columns directly, so the one-hot\n",
    "# columns produced above are computed but never fed to the classifier.\n",
    "assembler = VectorAssembler(\n",
    "    inputCols=['Pclass', 'SibSp', 'Parch', 'bucketedAge',\n",
    "               'bucketedFare', 'indexedEmbarked', 'indexedSex'],\n",
    "    outputCol='features')\n",
    "\n",
    "# Explicit seed so that the forest, and therefore every metric below, is reproducible across\n",
    "# kernel restarts.\n",
    "rf = RandomForestClassifier(labelCol=labelCol, featuresCol='features', seed=42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.ml.tuning import CrossValidator, ParamGridBuilder\n",
    "from pyspark.ml.evaluation import BinaryClassificationEvaluator\n",
    "from pyspark.ml import Pipeline\n",
    "\n",
    "# Full training pipeline: impute -> index -> bucketize -> encode -> assemble -> random forest.\n",
    "# NOTE(review): `ageImputer` is not defined in this cell; it must already exist in the kernel.\n",
    "pipeline = Pipeline(stages=[ageImputer, embarkedIndexer, sexFeatureIndexer, ageBucketizer,\n",
    "                            fareBucketizer, oneHotEncoderEstimator, assembler, rf])\n",
    "\n",
    "# 4 x 2 = 8 hyper-parameter combinations.\n",
    "grid = (ParamGridBuilder()\n",
    "        .addGrid(rf.numTrees, [15, 20, 25, 30])\n",
    "        .addGrid(rf.maxDepth, [5, 8])\n",
    "        .build())\n",
    "\n",
    "# 10-fold cross-validation, selecting on area under the ROC curve.\n",
    "# The explicit seed fixes the fold assignment so repeated runs compare the same splits.\n",
    "cv = CrossValidator(estimator=pipeline,\n",
    "                    estimatorParamMaps=grid,\n",
    "                    evaluator=BinaryClassificationEvaluator(labelCol=labelCol, metricName='areaUnderROC'),\n",
    "                    numFolds=10,\n",
    "                    seed=42)\n",
    "\n",
    "model = cv.fit(train)\n",
    "# NOTE(review): `train` is rebound to the transformed frame, so this cell is not idempotent.\n",
    "# Re-running it on the already-transformed frame is expected to fail because the pipeline's\n",
    "# output columns already exist - reload `train` first.\n",
    "train = model.transform(train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9265509482481523"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Score with the evaluator that was used during cross-validation. `train` is the frame the\n",
    "# model was fitted on, so this is a training-set metric rather than a held-out estimate.\n",
    "evaluator = model.getEvaluator()\n",
    "evaluator.evaluate(dataset=train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Apply the fitted cross-validated model to the test set; `test` is rebound to the scored frame.\n",
    "test = model.transform(dataset=test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>PassengerId</th>\n",
       "      <th>Pclass</th>\n",
       "      <th>Name</th>\n",
       "      <th>Sex</th>\n",
       "      <th>Age</th>\n",
       "      <th>SibSp</th>\n",
       "      <th>Parch</th>\n",
       "      <th>Ticket</th>\n",
       "      <th>Fare</th>\n",
       "      <th>Cabin</th>\n",
       "      <th>...</th>\n",
       "      <th>bucketedAge</th>\n",
       "      <th>bucketedFare</th>\n",
       "      <th>oneHotSex</th>\n",
       "      <th>oneHotEmbarked</th>\n",
       "      <th>oneHotFare</th>\n",
       "      <th>oneHotAge</th>\n",
       "      <th>features</th>\n",
       "      <th>rawPrediction</th>\n",
       "      <th>probability</th>\n",
       "      <th>prediction</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>892</td>\n",
       "      <td>3</td>\n",
       "      <td>Kelly, Mr. James</td>\n",
       "      <td>male</td>\n",
       "      <td>34.5</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>330911</td>\n",
       "      <td>7.8292</td>\n",
       "      <td>None</td>\n",
       "      <td>...</td>\n",
       "      <td>2.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>(1.0)</td>\n",
       "      <td>(0.0, 0.0)</td>\n",
       "      <td>(1.0, 0.0, 0.0)</td>\n",
       "      <td>(0.0, 0.0, 1.0, 0.0)</td>\n",
       "      <td>(3.0, 0.0, 0.0, 2.0, 0.0, 2.0, 0.0)</td>\n",
       "      <td>[29.2960574888, 0.703942511236]</td>\n",
       "      <td>[0.976535249625, 0.0234647503745]</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>893</td>\n",
       "      <td>3</td>\n",
       "      <td>Wilkes, Mrs. James (Ellen Needs)</td>\n",
       "      <td>female</td>\n",
       "      <td>47.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>363272</td>\n",
       "      <td>7.0000</td>\n",
       "      <td>None</td>\n",
       "      <td>...</td>\n",
       "      <td>2.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>(0.0)</td>\n",
       "      <td>(1.0, 0.0)</td>\n",
       "      <td>(1.0, 0.0, 0.0)</td>\n",
       "      <td>(0.0, 0.0, 1.0, 0.0)</td>\n",
       "      <td>[3.0, 1.0, 0.0, 2.0, 0.0, 0.0, 1.0]</td>\n",
       "      <td>[17.9363522365, 12.0636477635]</td>\n",
       "      <td>[0.597878407882, 0.402121592118]</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>894</td>\n",
       "      <td>2</td>\n",
       "      <td>Myles, Mr. Thomas Francis</td>\n",
       "      <td>male</td>\n",
       "      <td>62.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>240276</td>\n",
       "      <td>9.6875</td>\n",
       "      <td>None</td>\n",
       "      <td>...</td>\n",
       "      <td>3.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>(1.0)</td>\n",
       "      <td>(0.0, 0.0)</td>\n",
       "      <td>(0.0, 1.0, 0.0)</td>\n",
       "      <td>(0.0, 0.0, 0.0, 1.0)</td>\n",
       "      <td>[2.0, 0.0, 0.0, 3.0, 1.0, 2.0, 0.0]</td>\n",
       "      <td>[26.2911066277, 3.70889337228]</td>\n",
       "      <td>[0.876370220924, 0.123629779076]</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>895</td>\n",
       "      <td>3</td>\n",
       "      <td>Wirz, Mr. Albert</td>\n",
       "      <td>male</td>\n",
       "      <td>27.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>315154</td>\n",
       "      <td>8.6625</td>\n",
       "      <td>None</td>\n",
       "      <td>...</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>(1.0)</td>\n",
       "      <td>(1.0, 0.0)</td>\n",
       "      <td>(0.0, 1.0, 0.0)</td>\n",
       "      <td>(0.0, 1.0, 0.0, 0.0)</td>\n",
       "      <td>(3.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0)</td>\n",
       "      <td>[25.9086506078, 4.09134939223]</td>\n",
       "      <td>[0.863621686926, 0.136378313074]</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>896</td>\n",
       "      <td>3</td>\n",
       "      <td>Hirvonen, Mrs. Alexander (Helga E Lindqvist)</td>\n",
       "      <td>female</td>\n",
       "      <td>22.0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>3101298</td>\n",
       "      <td>12.2875</td>\n",
       "      <td>None</td>\n",
       "      <td>...</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>(0.0)</td>\n",
       "      <td>(1.0, 0.0)</td>\n",
       "      <td>(0.0, 1.0, 0.0)</td>\n",
       "      <td>(0.0, 1.0, 0.0, 0.0)</td>\n",
       "      <td>[3.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0]</td>\n",
       "      <td>[21.2731321988, 8.72686780122]</td>\n",
       "      <td>[0.709104406626, 0.290895593374]</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>897</td>\n",
       "      <td>3</td>\n",
       "      <td>Svensson, Mr. Johan Cervin</td>\n",
       "      <td>male</td>\n",
       "      <td>14.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>7538</td>\n",
       "      <td>9.2250</td>\n",
       "      <td>None</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>(1.0)</td>\n",
       "      <td>(1.0, 0.0)</td>\n",
       "      <td>(0.0, 1.0, 0.0)</td>\n",
       "      <td>(1.0, 0.0, 0.0, 0.0)</td>\n",
       "      <td>(3.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0)</td>\n",
       "      <td>[21.79357844, 8.20642155996]</td>\n",
       "      <td>[0.726452614668, 0.273547385332]</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>898</td>\n",
       "      <td>3</td>\n",
       "      <td>Connolly, Miss. Kate</td>\n",
       "      <td>female</td>\n",
       "      <td>30.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>330972</td>\n",
       "      <td>7.6292</td>\n",
       "      <td>None</td>\n",
       "      <td>...</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>(0.0)</td>\n",
       "      <td>(0.0, 0.0)</td>\n",
       "      <td>(1.0, 0.0, 0.0)</td>\n",
       "      <td>(0.0, 1.0, 0.0, 0.0)</td>\n",
       "      <td>[3.0, 0.0, 0.0, 1.0, 0.0, 2.0, 1.0]</td>\n",
       "      <td>[6.58546403436, 23.4145359656]</td>\n",
       "      <td>[0.219515467812, 0.780484532188]</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>899</td>\n",
       "      <td>2</td>\n",
       "      <td>Caldwell, Mr. Albert Francis</td>\n",
       "      <td>male</td>\n",
       "      <td>26.0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>248738</td>\n",
       "      <td>29.0000</td>\n",
       "      <td>None</td>\n",
       "      <td>...</td>\n",
       "      <td>1.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>(1.0)</td>\n",
       "      <td>(1.0, 0.0)</td>\n",
       "      <td>(0.0, 0.0, 1.0)</td>\n",
       "      <td>(0.0, 1.0, 0.0, 0.0)</td>\n",
       "      <td>[2.0, 1.0, 1.0, 1.0, 2.0, 0.0, 0.0]</td>\n",
       "      <td>[26.9219114219, 3.07808857809]</td>\n",
       "      <td>[0.897397047397, 0.102602952603]</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>900</td>\n",
       "      <td>3</td>\n",
       "      <td>Abrahim, Mrs. Joseph (Sophie Halaut Easu)</td>\n",
       "      <td>female</td>\n",
       "      <td>18.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>2657</td>\n",
       "      <td>7.2292</td>\n",
       "      <td>None</td>\n",
       "      <td>...</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>(0.0)</td>\n",
       "      <td>(0.0, 1.0)</td>\n",
       "      <td>(1.0, 0.0, 0.0)</td>\n",
       "      <td>(0.0, 1.0, 0.0, 0.0)</td>\n",
       "      <td>[3.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0]</td>\n",
       "      <td>[5.95376028838, 24.0462397116]</td>\n",
       "      <td>[0.198458676279, 0.801541323721]</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>901</td>\n",
       "      <td>3</td>\n",
       "      <td>Davies, Mr. John Samuel</td>\n",
       "      <td>male</td>\n",
       "      <td>21.0</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>A/4 48871</td>\n",
       "      <td>24.1500</td>\n",
       "      <td>None</td>\n",
       "      <td>...</td>\n",
       "      <td>1.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>(1.0)</td>\n",
       "      <td>(1.0, 0.0)</td>\n",
       "      <td>(0.0, 0.0, 1.0)</td>\n",
       "      <td>(0.0, 1.0, 0.0, 0.0)</td>\n",
       "      <td>[3.0, 2.0, 0.0, 1.0, 2.0, 0.0, 0.0]</td>\n",
       "      <td>[27.683257006, 2.31674299405]</td>\n",
       "      <td>[0.922775233532, 0.0772247664683]</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>902</td>\n",
       "      <td>3</td>\n",
       "      <td>Ilieff, Mr. Ylio</td>\n",
       "      <td>male</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>349220</td>\n",
       "      <td>7.8958</td>\n",
       "      <td>None</td>\n",
       "      <td>...</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>(1.0)</td>\n",
       "      <td>(1.0, 0.0)</td>\n",
       "      <td>(1.0, 0.0, 0.0)</td>\n",
       "      <td>(0.0, 1.0, 0.0, 0.0)</td>\n",
       "      <td>(3.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0)</td>\n",
       "      <td>[27.1796680333, 2.8203319667]</td>\n",
       "      <td>[0.905988934443, 0.0940110655565]</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>903</td>\n",
       "      <td>1</td>\n",
       "      <td>Jones, Mr. Charles Cresson</td>\n",
       "      <td>male</td>\n",
       "      <td>46.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>694</td>\n",
       "      <td>26.0000</td>\n",
       "      <td>None</td>\n",
       "      <td>...</td>\n",
       "      <td>2.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>(1.0)</td>\n",
       "      <td>(1.0, 0.0)</td>\n",
       "      <td>(0.0, 0.0, 1.0)</td>\n",
       "      <td>(0.0, 0.0, 1.0, 0.0)</td>\n",
       "      <td>(1.0, 0.0, 0.0, 2.0, 2.0, 0.0, 0.0)</td>\n",
       "      <td>[14.2653782322, 15.7346217678]</td>\n",
       "      <td>[0.475512607739, 0.524487392261]</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>904</td>\n",
       "      <td>1</td>\n",
       "      <td>Snyder, Mrs. John Pillsbury (Nelle Stevenson)</td>\n",
       "      <td>female</td>\n",
       "      <td>23.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>21228</td>\n",
       "      <td>82.2667</td>\n",
       "      <td>B45</td>\n",
       "      <td>...</td>\n",
       "      <td>1.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>(0.0)</td>\n",
       "      <td>(1.0, 0.0)</td>\n",
       "      <td>(0.0, 0.0, 0.0)</td>\n",
       "      <td>(0.0, 1.0, 0.0, 0.0)</td>\n",
       "      <td>[1.0, 1.0, 0.0, 1.0, 3.0, 0.0, 1.0]</td>\n",
       "      <td>[0.295454545455, 29.7045454545]</td>\n",
       "      <td>[0.00984848484848, 0.990151515152]</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>905</td>\n",
       "      <td>2</td>\n",
       "      <td>Howard, Mr. Benjamin</td>\n",
       "      <td>male</td>\n",
       "      <td>63.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>24065</td>\n",
       "      <td>26.0000</td>\n",
       "      <td>None</td>\n",
       "      <td>...</td>\n",
       "      <td>3.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>(1.0)</td>\n",
       "      <td>(1.0, 0.0)</td>\n",
       "      <td>(0.0, 0.0, 1.0)</td>\n",
       "      <td>(0.0, 0.0, 0.0, 1.0)</td>\n",
       "      <td>[2.0, 1.0, 0.0, 3.0, 2.0, 0.0, 0.0]</td>\n",
       "      <td>[27.5151896549, 2.4848103451]</td>\n",
       "      <td>[0.917172988497, 0.0828270115032]</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>906</td>\n",
       "      <td>1</td>\n",
       "      <td>Chaffee, Mrs. Herbert Fuller (Carrie Constance...</td>\n",
       "      <td>female</td>\n",
       "      <td>47.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>W.E.P. 5734</td>\n",
       "      <td>61.1750</td>\n",
       "      <td>E31</td>\n",
       "      <td>...</td>\n",
       "      <td>2.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>(0.0)</td>\n",
       "      <td>(1.0, 0.0)</td>\n",
       "      <td>(0.0, 0.0, 0.0)</td>\n",
       "      <td>(0.0, 0.0, 1.0, 0.0)</td>\n",
       "      <td>[1.0, 1.0, 0.0, 2.0, 3.0, 0.0, 1.0]</td>\n",
       "      <td>[0.0, 30.0]</td>\n",
       "      <td>[0.0, 1.0]</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>907</td>\n",
       "      <td>2</td>\n",
       "      <td>del Carlo, Mrs. Sebastiano (Argenia Genovesi)</td>\n",
       "      <td>female</td>\n",
       "      <td>24.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>SC/PARIS 2167</td>\n",
       "      <td>27.7208</td>\n",
       "      <td>None</td>\n",
       "      <td>...</td>\n",
       "      <td>1.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>(0.0)</td>\n",
       "      <td>(0.0, 1.0)</td>\n",
       "      <td>(0.0, 0.0, 1.0)</td>\n",
       "      <td>(0.0, 1.0, 0.0, 0.0)</td>\n",
       "      <td>[2.0, 1.0, 0.0, 1.0, 2.0, 1.0, 1.0]</td>\n",
       "      <td>[2.02801517479, 27.9719848252]</td>\n",
       "      <td>[0.0676005058263, 0.932399494174]</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>908</td>\n",
       "      <td>2</td>\n",
       "      <td>Keane, Mr. Daniel</td>\n",
       "      <td>male</td>\n",
       "      <td>35.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>233734</td>\n",
       "      <td>12.3500</td>\n",
       "      <td>None</td>\n",
       "      <td>...</td>\n",
       "      <td>2.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>(1.0)</td>\n",
       "      <td>(0.0, 0.0)</td>\n",
       "      <td>(0.0, 1.0, 0.0)</td>\n",
       "      <td>(0.0, 0.0, 1.0, 0.0)</td>\n",
       "      <td>[2.0, 0.0, 0.0, 2.0, 1.0, 2.0, 0.0]</td>\n",
       "      <td>[26.5861882672, 3.41381173276]</td>\n",
       "      <td>[0.886206275575, 0.113793724425]</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>909</td>\n",
       "      <td>3</td>\n",
       "      <td>Assaf, Mr. Gerios</td>\n",
       "      <td>male</td>\n",
       "      <td>21.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>2692</td>\n",
       "      <td>7.2250</td>\n",
       "      <td>None</td>\n",
       "      <td>...</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>(1.0)</td>\n",
       "      <td>(0.0, 1.0)</td>\n",
       "      <td>(1.0, 0.0, 0.0)</td>\n",
       "      <td>(0.0, 1.0, 0.0, 0.0)</td>\n",
       "      <td>(3.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0)</td>\n",
       "      <td>[25.8019873912, 4.19801260875]</td>\n",
       "      <td>[0.860066246375, 0.139933753625]</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>910</td>\n",
       "      <td>3</td>\n",
       "      <td>Ilmakangas, Miss. Ida Livija</td>\n",
       "      <td>female</td>\n",
       "      <td>27.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>STON/O2. 3101270</td>\n",
       "      <td>7.9250</td>\n",
       "      <td>None</td>\n",
       "      <td>...</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>(0.0)</td>\n",
       "      <td>(1.0, 0.0)</td>\n",
       "      <td>(0.0, 1.0, 0.0)</td>\n",
       "      <td>(0.0, 1.0, 0.0, 0.0)</td>\n",
       "      <td>[3.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0]</td>\n",
       "      <td>[20.8412096143, 9.15879038571]</td>\n",
       "      <td>[0.694706987143, 0.305293012857]</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>911</td>\n",
       "      <td>3</td>\n",
       "      <td>\"Assaf Khalil, Mrs. Mariana (Miriam\"\")\"\"\"</td>\n",
       "      <td>female</td>\n",
       "      <td>45.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>2696</td>\n",
       "      <td>7.2250</td>\n",
       "      <td>None</td>\n",
       "      <td>...</td>\n",
       "      <td>2.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>(0.0)</td>\n",
       "      <td>(0.0, 1.0)</td>\n",
       "      <td>(1.0, 0.0, 0.0)</td>\n",
       "      <td>(0.0, 0.0, 1.0, 0.0)</td>\n",
       "      <td>[3.0, 0.0, 0.0, 2.0, 0.0, 1.0, 1.0]</td>\n",
       "      <td>[17.3386478658, 12.6613521342]</td>\n",
       "      <td>[0.577954928861, 0.422045071139]</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>20 rows × 24 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "    PassengerId  Pclass                                               Name  \\\n",
       "0           892       3                                   Kelly, Mr. James   \n",
       "1           893       3                   Wilkes, Mrs. James (Ellen Needs)   \n",
       "2           894       2                          Myles, Mr. Thomas Francis   \n",
       "3           895       3                                   Wirz, Mr. Albert   \n",
       "4           896       3       Hirvonen, Mrs. Alexander (Helga E Lindqvist)   \n",
       "5           897       3                         Svensson, Mr. Johan Cervin   \n",
       "6           898       3                               Connolly, Miss. Kate   \n",
       "7           899       2                       Caldwell, Mr. Albert Francis   \n",
       "8           900       3          Abrahim, Mrs. Joseph (Sophie Halaut Easu)   \n",
       "9           901       3                            Davies, Mr. John Samuel   \n",
       "10          902       3                                   Ilieff, Mr. Ylio   \n",
       "11          903       1                         Jones, Mr. Charles Cresson   \n",
       "12          904       1      Snyder, Mrs. John Pillsbury (Nelle Stevenson)   \n",
       "13          905       2                               Howard, Mr. Benjamin   \n",
       "14          906       1  Chaffee, Mrs. Herbert Fuller (Carrie Constance...   \n",
       "15          907       2      del Carlo, Mrs. Sebastiano (Argenia Genovesi)   \n",
       "16          908       2                                  Keane, Mr. Daniel   \n",
       "17          909       3                                  Assaf, Mr. Gerios   \n",
       "18          910       3                       Ilmakangas, Miss. Ida Livija   \n",
       "19          911       3          \"Assaf Khalil, Mrs. Mariana (Miriam\"\")\"\"\"   \n",
       "\n",
       "       Sex   Age  SibSp  Parch            Ticket     Fare Cabin    ...      \\\n",
       "0     male  34.5      0      0            330911   7.8292  None    ...       \n",
       "1   female  47.0      1      0            363272   7.0000  None    ...       \n",
       "2     male  62.0      0      0            240276   9.6875  None    ...       \n",
       "3     male  27.0      0      0            315154   8.6625  None    ...       \n",
       "4   female  22.0      1      1           3101298  12.2875  None    ...       \n",
       "5     male  14.0      0      0              7538   9.2250  None    ...       \n",
       "6   female  30.0      0      0            330972   7.6292  None    ...       \n",
       "7     male  26.0      1      1            248738  29.0000  None    ...       \n",
       "8   female  18.0      0      0              2657   7.2292  None    ...       \n",
       "9     male  21.0      2      0         A/4 48871  24.1500  None    ...       \n",
       "10    male   NaN      0      0            349220   7.8958  None    ...       \n",
       "11    male  46.0      0      0               694  26.0000  None    ...       \n",
       "12  female  23.0      1      0             21228  82.2667   B45    ...       \n",
       "13    male  63.0      1      0             24065  26.0000  None    ...       \n",
       "14  female  47.0      1      0       W.E.P. 5734  61.1750   E31    ...       \n",
       "15  female  24.0      1      0     SC/PARIS 2167  27.7208  None    ...       \n",
       "16    male  35.0      0      0            233734  12.3500  None    ...       \n",
       "17    male  21.0      0      0              2692   7.2250  None    ...       \n",
       "18  female  27.0      1      0  STON/O2. 3101270   7.9250  None    ...       \n",
       "19  female  45.0      0      0              2696   7.2250  None    ...       \n",
       "\n",
       "   bucketedAge  bucketedFare  oneHotSex  oneHotEmbarked       oneHotFare  \\\n",
       "0          2.0           0.0      (1.0)      (0.0, 0.0)  (1.0, 0.0, 0.0)   \n",
       "1          2.0           0.0      (0.0)      (1.0, 0.0)  (1.0, 0.0, 0.0)   \n",
       "2          3.0           1.0      (1.0)      (0.0, 0.0)  (0.0, 1.0, 0.0)   \n",
       "3          1.0           1.0      (1.0)      (1.0, 0.0)  (0.0, 1.0, 0.0)   \n",
       "4          1.0           1.0      (0.0)      (1.0, 0.0)  (0.0, 1.0, 0.0)   \n",
       "5          0.0           1.0      (1.0)      (1.0, 0.0)  (0.0, 1.0, 0.0)   \n",
       "6          1.0           0.0      (0.0)      (0.0, 0.0)  (1.0, 0.0, 0.0)   \n",
       "7          1.0           2.0      (1.0)      (1.0, 0.0)  (0.0, 0.0, 1.0)   \n",
       "8          1.0           0.0      (0.0)      (0.0, 1.0)  (1.0, 0.0, 0.0)   \n",
       "9          1.0           2.0      (1.0)      (1.0, 0.0)  (0.0, 0.0, 1.0)   \n",
       "10         1.0           0.0      (1.0)      (1.0, 0.0)  (1.0, 0.0, 0.0)   \n",
       "11         2.0           2.0      (1.0)      (1.0, 0.0)  (0.0, 0.0, 1.0)   \n",
       "12         1.0           3.0      (0.0)      (1.0, 0.0)  (0.0, 0.0, 0.0)   \n",
       "13         3.0           2.0      (1.0)      (1.0, 0.0)  (0.0, 0.0, 1.0)   \n",
       "14         2.0           3.0      (0.0)      (1.0, 0.0)  (0.0, 0.0, 0.0)   \n",
       "15         1.0           2.0      (0.0)      (0.0, 1.0)  (0.0, 0.0, 1.0)   \n",
       "16         2.0           1.0      (1.0)      (0.0, 0.0)  (0.0, 1.0, 0.0)   \n",
       "17         1.0           0.0      (1.0)      (0.0, 1.0)  (1.0, 0.0, 0.0)   \n",
       "18         1.0           1.0      (0.0)      (1.0, 0.0)  (0.0, 1.0, 0.0)   \n",
       "19         2.0           0.0      (0.0)      (0.0, 1.0)  (1.0, 0.0, 0.0)   \n",
       "\n",
       "               oneHotAge                             features  \\\n",
       "0   (0.0, 0.0, 1.0, 0.0)  (3.0, 0.0, 0.0, 2.0, 0.0, 2.0, 0.0)   \n",
       "1   (0.0, 0.0, 1.0, 0.0)  [3.0, 1.0, 0.0, 2.0, 0.0, 0.0, 1.0]   \n",
       "2   (0.0, 0.0, 0.0, 1.0)  [2.0, 0.0, 0.0, 3.0, 1.0, 2.0, 0.0]   \n",
       "3   (0.0, 1.0, 0.0, 0.0)  (3.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0)   \n",
       "4   (0.0, 1.0, 0.0, 0.0)  [3.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0]   \n",
       "5   (1.0, 0.0, 0.0, 0.0)  (3.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0)   \n",
       "6   (0.0, 1.0, 0.0, 0.0)  [3.0, 0.0, 0.0, 1.0, 0.0, 2.0, 1.0]   \n",
       "7   (0.0, 1.0, 0.0, 0.0)  [2.0, 1.0, 1.0, 1.0, 2.0, 0.0, 0.0]   \n",
       "8   (0.0, 1.0, 0.0, 0.0)  [3.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0]   \n",
       "9   (0.0, 1.0, 0.0, 0.0)  [3.0, 2.0, 0.0, 1.0, 2.0, 0.0, 0.0]   \n",
       "10  (0.0, 1.0, 0.0, 0.0)  (3.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0)   \n",
       "11  (0.0, 0.0, 1.0, 0.0)  (1.0, 0.0, 0.0, 2.0, 2.0, 0.0, 0.0)   \n",
       "12  (0.0, 1.0, 0.0, 0.0)  [1.0, 1.0, 0.0, 1.0, 3.0, 0.0, 1.0]   \n",
       "13  (0.0, 0.0, 0.0, 1.0)  [2.0, 1.0, 0.0, 3.0, 2.0, 0.0, 0.0]   \n",
       "14  (0.0, 0.0, 1.0, 0.0)  [1.0, 1.0, 0.0, 2.0, 3.0, 0.0, 1.0]   \n",
       "15  (0.0, 1.0, 0.0, 0.0)  [2.0, 1.0, 0.0, 1.0, 2.0, 1.0, 1.0]   \n",
       "16  (0.0, 0.0, 1.0, 0.0)  [2.0, 0.0, 0.0, 2.0, 1.0, 2.0, 0.0]   \n",
       "17  (0.0, 1.0, 0.0, 0.0)  (3.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0)   \n",
       "18  (0.0, 1.0, 0.0, 0.0)  [3.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0]   \n",
       "19  (0.0, 0.0, 1.0, 0.0)  [3.0, 0.0, 0.0, 2.0, 0.0, 1.0, 1.0]   \n",
       "\n",
       "                      rawPrediction                         probability  \\\n",
       "0   [29.2960574888, 0.703942511236]   [0.976535249625, 0.0234647503745]   \n",
       "1    [17.9363522365, 12.0636477635]    [0.597878407882, 0.402121592118]   \n",
       "2    [26.2911066277, 3.70889337228]    [0.876370220924, 0.123629779076]   \n",
       "3    [25.9086506078, 4.09134939223]    [0.863621686926, 0.136378313074]   \n",
       "4    [21.2731321988, 8.72686780122]    [0.709104406626, 0.290895593374]   \n",
       "5      [21.79357844, 8.20642155996]    [0.726452614668, 0.273547385332]   \n",
       "6    [6.58546403436, 23.4145359656]    [0.219515467812, 0.780484532188]   \n",
       "7    [26.9219114219, 3.07808857809]    [0.897397047397, 0.102602952603]   \n",
       "8    [5.95376028838, 24.0462397116]    [0.198458676279, 0.801541323721]   \n",
       "9     [27.683257006, 2.31674299405]   [0.922775233532, 0.0772247664683]   \n",
       "10    [27.1796680333, 2.8203319667]   [0.905988934443, 0.0940110655565]   \n",
       "11   [14.2653782322, 15.7346217678]    [0.475512607739, 0.524487392261]   \n",
       "12  [0.295454545455, 29.7045454545]  [0.00984848484848, 0.990151515152]   \n",
       "13    [27.5151896549, 2.4848103451]   [0.917172988497, 0.0828270115032]   \n",
       "14                      [0.0, 30.0]                          [0.0, 1.0]   \n",
       "15   [2.02801517479, 27.9719848252]   [0.0676005058263, 0.932399494174]   \n",
       "16   [26.5861882672, 3.41381173276]    [0.886206275575, 0.113793724425]   \n",
       "17   [25.8019873912, 4.19801260875]    [0.860066246375, 0.139933753625]   \n",
       "18   [20.8412096143, 9.15879038571]    [0.694706987143, 0.305293012857]   \n",
       "19   [17.3386478658, 12.6613521342]    [0.577954928861, 0.422045071139]   \n",
       "\n",
       "   prediction  \n",
       "0         0.0  \n",
       "1         0.0  \n",
       "2         0.0  \n",
       "3         0.0  \n",
       "4         0.0  \n",
       "5         0.0  \n",
       "6         1.0  \n",
       "7         0.0  \n",
       "8         1.0  \n",
       "9         0.0  \n",
       "10        0.0  \n",
       "11        1.0  \n",
       "12        1.0  \n",
       "13        0.0  \n",
       "14        1.0  \n",
       "15        1.0  \n",
       "16        0.0  \n",
       "17        0.0  \n",
       "18        0.0  \n",
       "19        0.0  \n",
       "\n",
       "[20 rows x 24 columns]"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test.limit(20).toPandas()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Write the predictions to CSV file in Kaggle specified format."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.sql.types import IntegerType\n",
    "\n",
    "csvPath = 'prediction.csv'\n",
    "test.select('PassengerId', 'prediction')\\\n",
    "    .coalesce(1)\\\n",
    "    .withColumn('Survived', test['prediction'].cast(IntegerType()))\\\n",
    "    .drop('prediction')\\\n",
    "    .write.csv(csvPath, header='true', mode='ignore')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}