{ "metadata": { "name": "", "signature": "sha256:99d5f098d44bc45277f739cac7a9b172e47815f026da055fb0ad3e4af55a664e" }, "nbformat": 3, "nbformat_minor": 0, "worksheets": [ { "cells": [ { "cell_type": "heading", "level": 1, "metadata": {}, "source": [ "MLlib: Classification with Logistic Regression " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[Introduction to Spark with Python, by Jose A. Dianes](https://github.com/jadianes/spark-py-notebooks)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this notebook we will use Spark's machine learning library [MLlib](https://spark.apache.org/docs/latest/mllib-guide.html) to build a **Logistic Regression** classifier for network attack detection. We will use the complete [KDD Cup 1999](http://kdd.ics.uci.edu/databases/kddcup99/kddcup99.html) dataset in order to test Spark's capabilities with large datasets. " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Additionally, we will introduce two ways of performing **model selection**: by using a correlation matrix and by using hypothesis testing. " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "At the time of processing this notebook, our Spark cluster contains: \n", "\n", "- Eight nodes, with one of them acting as master and the rest as workers. \n", "- Each node contains 8 GB of RAM, with 6 GB being used on each node. \n", "- Each node has a 2.4 GHz Intel dual-core processor. \n" ] }, { "cell_type": "heading", "level": 2, "metadata": {}, "source": [ "Getting the data and creating the RDD" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As we said, this time we will use the complete dataset provided for the [KDD Cup 1999](http://kdd.ics.uci.edu/databases/kddcup99/kddcup99.html), containing nearly five million network interactions. The file is provided as a Gzip file that we will download locally. 
" ] }, { "cell_type": "code", "collapsed": false, "input": [ "import urllib\n", "f = urllib.urlretrieve (\"http://kdd.ics.uci.edu/databases/kddcup99/kddcup.data.gz\", \"kddcup.data.gz\")" ], "language": "python", "metadata": {}, "outputs": [], "prompt_number": 1 }, { "cell_type": "code", "collapsed": false, "input": [ "data_file = \"./kddcup.data.gz\"\n", "raw_data = sc.textFile(data_file)\n", "\n", "print \"Train data size is {}\".format(raw_data.count())" ], "language": "python", "metadata": {}, "outputs": [ { "output_type": "stream", "stream": "stdout", "text": [ "Train data size is 4898431\n" ] } ], "prompt_number": 2 }, { "cell_type": "markdown", "metadata": {}, "source": [ "The [KDD Cup 1999](http://kdd.ics.uci.edu/databases/kddcup99/kddcup99.html) also provide test data that we will load in a separate RDD. " ] }, { "cell_type": "code", "collapsed": false, "input": [ "ft = urllib.urlretrieve(\"http://kdd.ics.uci.edu/databases/kddcup99/corrected.gz\", \"corrected.gz\")" ], "language": "python", "metadata": {}, "outputs": [], "prompt_number": 3 }, { "cell_type": "code", "collapsed": false, "input": [ "test_data_file = \"./corrected.gz\"\n", "test_raw_data = sc.textFile(test_data_file)\n", "\n", "print \"Test data size is {}\".format(test_raw_data.count())" ], "language": "python", "metadata": {}, "outputs": [ { "output_type": "stream", "stream": "stdout", "text": [ "Test data size is 311029\n" ] } ], "prompt_number": 4 }, { "cell_type": "heading", "level": 2, "metadata": {}, "source": [ "Labeled Points" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A labeled point is a local vector associated with a label/response. In [MLlib](https://spark.apache.org/docs/latest/mllib-data-types.html#labeled-point), labeled points are used in supervised learning algorithms and they are stored as doubles. For binary classification, a label should be either 0 (negative) or 1 (positive). 
" ] }, { "cell_type": "heading", "level": 3, "metadata": {}, "source": [ "Preparing the training data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In our case, we are interested in detecting network attacks in general. We don't need to detect which type of attack we are dealing with. Therefore we will tag each network interaction as non attack (i.e. 'normal' tag) or attack (i.e. anything else but 'normal'). " ] }, { "cell_type": "code", "collapsed": false, "input": [ "from pyspark.mllib.regression import LabeledPoint\n", "from numpy import array\n", "\n", "def parse_interaction(line):\n", " line_split = line.split(\",\")\n", " # leave_out = [1,2,3,41]\n", " clean_line_split = line_split[0:1]+line_split[4:41]\n", " attack = 1.0\n", " if line_split[41]=='normal.':\n", " attack = 0.0\n", " return LabeledPoint(attack, array([float(x) for x in clean_line_split]))\n", "\n", "training_data = raw_data.map(parse_interaction)" ], "language": "python", "metadata": {}, "outputs": [], "prompt_number": 5 }, { "cell_type": "heading", "level": 3, "metadata": {}, "source": [ "Preparing the test data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Similarly, we process our test data file. " ] }, { "cell_type": "code", "collapsed": false, "input": [ "test_data = test_raw_data.map(parse_interaction)" ], "language": "python", "metadata": {}, "outputs": [], "prompt_number": 6 }, { "cell_type": "heading", "level": 2, "metadata": {}, "source": [ "Detecting network attacks using Logistic Regression" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[Logistic regression](http://en.wikipedia.org/wiki/Logistic_regression) is widely used to predict a binary response. Spark implements [two algorithms](https://spark.apache.org/docs/latest/mllib-linear-methods.html#logistic-regression) to solve logistic regression: mini-batch gradient descent and L-BFGS. L-BFGS is recommended over mini-batch gradient descent for faster convergence. 
" ] }, { "cell_type": "heading", "level": 3, "metadata": {}, "source": [ "Training a classifier" ] }, { "cell_type": "code", "collapsed": false, "input": [ "from pyspark.mllib.classification import LogisticRegressionWithLBFGS\n", "from time import time\n", "\n", "# Build the model\n", "t0 = time()\n", "logit_model = LogisticRegressionWithLBFGS.train(training_data)\n", "tt = time() - t0\n", "\n", "print \"Classifier trained in {} seconds\".format(round(tt,3))" ], "language": "python", "metadata": {}, "outputs": [ { "output_type": "stream", "stream": "stdout", "text": [ "Classifier trained in 365.599 seconds\n" ] } ], "prompt_number": 7 }, { "cell_type": "heading", "level": 3, "metadata": {}, "source": [ "Evaluating the model on new data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In order to measure the classification error on our test data, we use `map` on the `test_data` RDD and the model to predict each test point class. " ] }, { "cell_type": "code", "collapsed": false, "input": [ "labels_and_preds = test_data.map(lambda p: (p.label, logit_model.predict(p.features)))" ], "language": "python", "metadata": {}, "outputs": [], "prompt_number": 8 }, { "cell_type": "markdown", "metadata": {}, "source": [ "Classification results are returned in pars, with the actual test label and the predicted one. This is used to calculate the classification error by using `filter` and `count` as follows." ] }, { "cell_type": "code", "collapsed": false, "input": [ "t0 = time()\n", "test_accuracy = labels_and_preds.filter(lambda (v, p): v == p).count() / float(test_data.count())\n", "tt = time() - t0\n", "\n", "print \"Prediction made in {} seconds. Test accuracy is {}\".format(round(tt,3), round(test_accuracy,4))" ], "language": "python", "metadata": {}, "outputs": [ { "output_type": "stream", "stream": "stdout", "text": [ "Prediction made in 34.291 seconds. 
Test accuracy is 0.9164\n" ] } ], "prompt_number": 9 }, { "cell_type": "markdown", "metadata": {}, "source": [ "That's a decent accuracy. We know that there is room for improvement with better variable selection and also by including categorical variables (e.g. we have excluded 'protocol' and 'service')." ] }, { "cell_type": "heading", "level": 2, "metadata": {}, "source": [ "Model selection" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Model or feature selection helps us build more interpretable and efficient models (or a classifier in this case). For illustrative purposes, we will follow two different approaches: correlation matrices and hypothesis testing. " ] }, { "cell_type": "heading", "level": 3, "metadata": {}, "source": [ "Using a correlation matrix" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In a [previous notebook](https://github.com/jadianes/spark-py-notebooks/blob/master/nb7-mllib-statistics/nb7-mllib-statistics.ipynb) we calculated a correlation matrix in order to find predictors that are highly correlated. There are many possible ways to use it to simplify our model. We can pick different combinations of correlated variables and keep just those that represent them. The reader can try different combinations. Here we will choose the following for illustrative purposes: " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- From the description of the [KDD Cup 99 task](http://kdd.ics.uci.edu/databases/kddcup99/task.html) we know that the variable `dst_host_same_src_port_rate` is the percentage of the last 100 connections to the same port, for the same destination host. In our correlation matrix (and auxiliary dataframes) we find that this one is highly and positively correlated to `src_bytes` and `srv_count`. The former is the number of bytes sent from source to destination. The latter is the number of connections to the same service as the current connection in the past 2 seconds. 
We decide not to include `dst_host_same_src_port_rate` in our model since we include the other two. " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Variables `serror_rate` and `srv_serror_rate` (% of connections that have *SYN* errors for same host and same service respectively) are highly positively correlated. Moreover, the sets of variables that they correlate highly with are pretty much the same. They appear to contribute very similarly to our model. We will keep just `serror_rate`. " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- A similar situation happens with `rerror_rate` and `srv_rerror_rate` (% of connections that have *REJ* errors) so we will keep just `rerror_rate`. " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- The same applies to the `dst_host_`-prefixed versions of the previous variables (e.g. `dst_host_srv_serror_rate`). " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will stop here, although the reader can keep experimenting with removing correlated variables as before (e.g. `same_srv_rate` and `diff_srv_rate` are good candidates). The variables we will drop are: \n", "\n", "- `dst_host_same_src_port_rate` (column 35). \n", "- `srv_serror_rate` (column 25). \n", "- `srv_rerror_rate` (column 27). \n", "- `dst_host_srv_serror_rate` (column 38). \n", "- `dst_host_srv_rerror_rate` (column 40). " ] }, { "cell_type": "heading", "level": 4, "metadata": {}, "source": [ "Evaluating the new model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's proceed with the evaluation of our reduced model. First we need to provide training and testing datasets containing just the selected variables. For that we will define a new function to parse the raw data that keeps just what we need. 
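
Before relying on a slice-and-concatenate expression, it can help to check that it drops exactly the intended columns. The sketch below does this on a list of 42 placeholder values (the column indices themselves), comparing the slices against an explicit `leave_out` filter. Both helper names are hypothetical and used only for this check.

```python
# Columns to drop: categorical columns 1-3, the five correlated ones, and the label (41)
leave_out = [1, 2, 3, 25, 27, 35, 38, 40, 41]

def keep_by_slicing(fields):
    # Slice around every dropped column and concatenate the pieces
    return (fields[0:1] + fields[4:25] + fields[26:27] +
            fields[28:35] + fields[36:38] + fields[39:40])

def keep_by_filter(fields):
    # Same selection written as a list comprehension over the indices
    dropped = set(leave_out)
    return [f for i, f in enumerate(fields) if i not in dropped]

# Use the column indices themselves as placeholder field values
fields = list(range(42))
kept = keep_by_slicing(fields)
same_selection = kept == keep_by_filter(fields)
```

If `same_selection` is true, the two approaches select the same 33 columns, so the choice between them is purely a matter of speed and readability.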
" ] }, { "cell_type": "code", "collapsed": false, "input": [ "def parse_interaction_corr(line):\n", " line_split = line.split(\",\")\n", " # leave_out = [1,2,3,25,27,35,38,40,41]\n", " clean_line_split = line_split[0:1]+line_split[4:25]+line_split[26:27]+line_split[28:35]+line_split[36:38]+line_split[39:40]\n", " attack = 1.0\n", " if line_split[41]=='normal.':\n", " attack = 0.0\n", " return LabeledPoint(attack, array([float(x) for x in clean_line_split]))\n", "\n", "corr_reduced_training_data = raw_data.map(parse_interaction_corr)\n", "corr_reduced_test_data = test_raw_data.map(parse_interaction_corr)" ], "language": "python", "metadata": {}, "outputs": [], "prompt_number": 10 }, { "cell_type": "markdown", "metadata": {}, "source": [ "*Note: when selecting elements in the split, a list comprehension with a `leave_out` list for filtering is more Pythonic than slicing and concatenation indeed, but we have found it less efficient. This is very important when dealing with large datasets. The `parse_interaction` functions will be called for every element in the RDD, so we need to make them as efficient as possible.* " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we can train the model. " ] }, { "cell_type": "code", "collapsed": false, "input": [ "# Build the model\n", "t0 = time()\n", "logit_model_2 = LogisticRegressionWithLBFGS.train(corr_reduced_training_data)\n", "tt = time() - t0\n", "\n", "print \"Classifier trained in {} seconds\".format(round(tt,3))" ], "language": "python", "metadata": {}, "outputs": [ { "output_type": "stream", "stream": "stdout", "text": [ "Classifier trained in 319.124 seconds\n" ] } ], "prompt_number": 11 }, { "cell_type": "markdown", "metadata": {}, "source": [ "And evaluate its accuracy on the test data. 
" ] }, { "cell_type": "code", "collapsed": false, "input": [ "labels_and_preds = corr_reduced_test_data.map(lambda p: (p.label, logit_model_2.predict(p.features)))\n", "t0 = time()\n", "test_accuracy = labels_and_preds.filter(lambda (v, p): v == p).count() / float(corr_reduced_test_data.count())\n", "tt = time() - t0\n", "\n", "print \"Prediction made in {} seconds. Test accuracy is {}\".format(round(tt,3), round(test_accuracy,4))" ], "language": "python", "metadata": {}, "outputs": [ { "output_type": "stream", "stream": "stdout", "text": [ "Prediction made in 34.102 seconds. Test accuracy is 0.8599\n" ] } ], "prompt_number": 12 }, { "cell_type": "markdown", "metadata": {}, "source": [ "As expected, we have reduced accuracy and also training time. However this doesn't seem a good trade! At least not for logistic regression and considering the predictors we decided to leave out. We have lost quite a lot of accuracy and have not gained a lot of execution time during training. Moreover prediction time didn't improve." ] }, { "cell_type": "heading", "level": 3, "metadata": {}, "source": [ "Using hypothesis testing" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Hypothesis testing is a powerful tool in statistical inference and learning to determine whether a result is statistically significant. MLlib supports Pearson's chi-squared (\u03c72) tests for goodness of fit and independence. The goodness of fit test requires an input type of `Vector`, whereas the independence test requires a `Matrix` as input. Moreover, MLlib also supports the input type `RDD[LabeledPoint]` to enable feature selection via chi-squared independence tests. Again, these methods are part of the [`Statistics`](https://spark.apache.org/docs/latest/api/python/pyspark.mllib.html#pyspark.mllib.stat.Statistics) package. " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In our case we want to perform some sort of feature selection, so we will provide an RDD of `LabeledPoint`. 
Internally, MLlib will calculate a contingency matrix and perform Pearson's chi-squared (\u03c7\u00b2) test. Features need to be categorical. Real-valued features will be treated as categorical, with each distinct value counted as a separate category. There is a limit of 1000 different values, so we need either to leave out some features or categorise them. In this case, we will consider just features that either take boolean values or just a few different numeric values in our dataset. We could overcome this limitation by defining a more complex `parse_interaction` function that categorises each feature properly. " ] }, { "cell_type": "code", "collapsed": false, "input": [ "feature_names = [\"land\",\"wrong_fragment\",\n", "    \"urgent\",\"hot\",\"num_failed_logins\",\"logged_in\",\"num_compromised\",\n", "    \"root_shell\",\"su_attempted\",\"num_root\",\"num_file_creations\",\n", "    \"num_shells\",\"num_access_files\",\"num_outbound_cmds\",\n", "    \"is_hot_login\",\"is_guest_login\",\"count\",\"srv_count\",\"serror_rate\",\n", "    \"srv_serror_rate\",\"rerror_rate\",\"srv_rerror_rate\",\"same_srv_rate\",\n", "    \"diff_srv_rate\",\"srv_diff_host_rate\",\"dst_host_count\",\"dst_host_srv_count\",\n", "    \"dst_host_same_srv_rate\",\"dst_host_diff_srv_rate\",\"dst_host_same_src_port_rate\",\n", "    \"dst_host_srv_diff_host_rate\",\"dst_host_serror_rate\",\"dst_host_srv_serror_rate\",\n", "    \"dst_host_rerror_rate\",\"dst_host_srv_rerror_rate\"]" ], "language": "python", "metadata": {}, "outputs": [], "prompt_number": 13 }, { "cell_type": "code", "collapsed": false, "input": [ "def parse_interaction_categorical(line):\n", "    line_split = line.split(\",\")\n", "    # keep columns 6 to 40, matching the names in feature_names\n", "    clean_line_split = line_split[6:41]\n", "    attack = 1.0\n", "    if line_split[41]=='normal.':\n", "        attack = 0.0\n", "    return LabeledPoint(attack, array([float(x) for x in clean_line_split]))\n", "\n", "training_data_categorical = raw_data.map(parse_interaction_categorical)" ], "language": "python", "metadata": {}, "outputs": [], "prompt_number": 14 }, { 
"cell_type": "code", "collapsed": false, "input": [ "from pyspark.mllib.stat import Statistics\n", "\n", "chi = Statistics.chiSqTest(training_data_categorical)" ], "language": "python", "metadata": {}, "outputs": [], "prompt_number": 15 }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we can check the resulting values after putting them into a *Pandas* data frame. " ] }, { "cell_type": "code", "collapsed": false, "input": [ "import pandas as pd\n", "pd.set_option('display.max_colwidth', 30)\n", "\n", "records = [(result.statistic, result.pValue) for result in chi]\n", "\n", "chi_df = pd.DataFrame(data=records, index= feature_names, columns=[\"Statistic\",\"p-value\"])\n", "\n", "chi_df " ], "language": "python", "metadata": {}, "outputs": [ { "html": [ "
<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\"><th></th><th>Statistic</th><th>p-value</th></tr>\n", " </thead>\n", " <tbody>\n",
 " <tr><th>land</th><td>0.464984</td><td>0.495304</td></tr>\n",
 " <tr><th>wrong_fragment</th><td>306.855508</td><td>0.000000</td></tr>\n",
 " <tr><th>urgent</th><td>38.718442</td><td>0.000000</td></tr>\n",
 " <tr><th>hot</th><td>19463.314328</td><td>0.000000</td></tr>\n",
 " <tr><th>num_failed_logins</th><td>127.769106</td><td>0.000000</td></tr>\n",
 " <tr><th>logged_in</th><td>3273098.055702</td><td>0.000000</td></tr>\n",
 " <tr><th>num_compromised</th><td>2011.862724</td><td>0.000000</td></tr>\n",
 " <tr><th>root_shell</th><td>1044.917853</td><td>0.000000</td></tr>\n",
 " <tr><th>su_attempted</th><td>433.999968</td><td>0.000000</td></tr>\n",
 " <tr><th>num_root</th><td>22871.675646</td><td>0.000000</td></tr>\n",
 " <tr><th>num_file_creations</th><td>9179.739210</td><td>0.000000</td></tr>\n",
 " <tr><th>num_shells</th><td>1380.027801</td><td>0.000000</td></tr>\n",
 " <tr><th>num_access_files</th><td>18734.770753</td><td>0.000000</td></tr>\n",
 " <tr><th>num_outbound_cmds</th><td>0.000000</td><td>1.000000</td></tr>\n",
 " <tr><th>is_hot_login</th><td>8.070987</td><td>0.004498</td></tr>\n",
 " <tr><th>is_guest_login</th><td>13500.511066</td><td>0.000000</td></tr>\n",
 " <tr><th>count</th><td>4546398.359513</td><td>0.000000</td></tr>\n",
 " <tr><th>srv_count</th><td>2296059.697747</td><td>0.000000</td></tr>\n",
 " <tr><th>serror_rate</th><td>268419.919232</td><td>0.000000</td></tr>\n",
 " <tr><th>srv_serror_rate</th><td>302626.967659</td><td>0.000000</td></tr>\n",
 " <tr><th>rerror_rate</th><td>9860.453313</td><td>0.000000</td></tr>\n",
 " <tr><th>srv_rerror_rate</th><td>32476.385256</td><td>0.000000</td></tr>\n",
 " <tr><th>same_srv_rate</th><td>399912.444046</td><td>0.000000</td></tr>\n",
 " <tr><th>diff_srv_rate</th><td>390999.770051</td><td>0.000000</td></tr>\n",
 " <tr><th>srv_diff_host_rate</th><td>1365458.716670</td><td>0.000000</td></tr>\n",
 " <tr><th>dst_host_count</th><td>2520478.944345</td><td>0.000000</td></tr>\n",
 " <tr><th>dst_host_srv_count</th><td>1439086.154193</td><td>0.000000</td></tr>\n",
 " <tr><th>dst_host_same_srv_rate</th><td>1237931.905888</td><td>0.000000</td></tr>\n",
 " <tr><th>dst_host_diff_srv_rate</th><td>1339002.406545</td><td>0.000000</td></tr>\n",
 " <tr><th>dst_host_same_src_port_rate</th><td>2915195.296767</td><td>0.000000</td></tr>\n",
 " <tr><th>dst_host_srv_diff_host_rate</th><td>2226290.874366</td><td>0.000000</td></tr>\n",
 " <tr><th>dst_host_serror_rate</th><td>407454.601492</td><td>0.000000</td></tr>\n",
 " <tr><th>dst_host_srv_serror_rate</th><td>455099.001618</td><td>0.000000</td></tr>\n",
 " <tr><th>dst_host_rerror_rate</th><td>136478.956325</td><td>0.000000</td></tr>\n",
 " <tr><th>dst_host_srv_rerror_rate</th><td>254547.369785</td><td>0.000000</td></tr>\n",
 " </tbody>\n", "</table>
" ], "metadata": {}, "output_type": "pyout", "prompt_number": 16, "text": [ " Statistic p-value\n", "land 0.464984 0.495304\n", "wrong_fragment 306.855508 0.000000\n", "urgent 38.718442 0.000000\n", "hot 19463.314328 0.000000\n", "num_failed_logins 127.769106 0.000000\n", "logged_in 3273098.055702 0.000000\n", "num_compromised 2011.862724 0.000000\n", "root_shell 1044.917853 0.000000\n", "su_attempted 433.999968 0.000000\n", "num_root 22871.675646 0.000000\n", "num_file_creations 9179.739210 0.000000\n", "num_shells 1380.027801 0.000000\n", "num_access_files 18734.770753 0.000000\n", "num_outbound_cmds 0.000000 1.000000\n", "is_hot_login 8.070987 0.004498\n", "is_guest_login 13500.511066 0.000000\n", "count 4546398.359513 0.000000\n", "srv_count 2296059.697747 0.000000\n", "serror_rate 268419.919232 0.000000\n", "srv_serror_rate 302626.967659 0.000000\n", "rerror_rate 9860.453313 0.000000\n", "srv_rerror_rate 32476.385256 0.000000\n", "same_srv_rate 399912.444046 0.000000\n", "diff_srv_rate 390999.770051 0.000000\n", "srv_diff_host_rate 1365458.716670 0.000000\n", "dst_host_count 2520478.944345 0.000000\n", "dst_host_srv_count 1439086.154193 0.000000\n", "dst_host_same_srv_rate 1237931.905888 0.000000\n", "dst_host_diff_srv_rate 1339002.406545 0.000000\n", "dst_host_same_src_port_rate 2915195.296767 0.000000\n", "dst_host_srv_diff_host_rate 2226290.874366 0.000000\n", "dst_host_serror_rate 407454.601492 0.000000\n", "dst_host_srv_serror_rate 455099.001618 0.000000\n", "dst_host_rerror_rate 136478.956325 0.000000\n", "dst_host_srv_rerror_rate 254547.369785 0.000000" ] } ], "prompt_number": 16 }, { "cell_type": "markdown", "metadata": {}, "source": [ "From that we conclude that predictors `land` and `num_outbound_cmds` could be removed from our model without affecting our accuracy dramatically. Let's try this. 
" ] }, { "cell_type": "heading", "level": 4, "metadata": {}, "source": [ "Evaluating the new model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "So the only modification to our first `parse_interaction` function will be to remove columns 6 and 19, corresponding to the two predictors that we want not to be part of our model. " ] }, { "cell_type": "code", "collapsed": false, "input": [ "def parse_interaction_chi(line):\n", " line_split = line.split(\",\")\n", " # leave_out = [1,2,3,6,19,41]\n", " clean_line_split = line_split[0:1] + line_split[4:6] + line_split[7:19] + line_split[20:41]\n", " attack = 1.0\n", " if line_split[41]=='normal.':\n", " attack = 0.0\n", " return LabeledPoint(attack, array([float(x) for x in clean_line_split]))\n", "\n", "training_data_chi = raw_data.map(parse_interaction_chi)\n", "test_data_chi = test_raw_data.map(parse_interaction_chi)" ], "language": "python", "metadata": {}, "outputs": [], "prompt_number": 17 }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we build the logistic regression classifier again. " ] }, { "cell_type": "code", "collapsed": false, "input": [ "# Build the model\n", "t0 = time()\n", "logit_model_chi = LogisticRegressionWithLBFGS.train(training_data_chi)\n", "tt = time() - t0\n", "\n", "print \"Classifier trained in {} seconds\".format(round(tt,3))" ], "language": "python", "metadata": {}, "outputs": [ { "output_type": "stream", "stream": "stdout", "text": [ "Classifier trained in 356.507 seconds\n" ] } ], "prompt_number": 18 }, { "cell_type": "markdown", "metadata": {}, "source": [ "And evaluate in test data. " ] }, { "cell_type": "code", "collapsed": false, "input": [ "labels_and_preds = test_data_chi.map(lambda p: (p.label, logit_model_chi.predict(p.features)))\n", "t0 = time()\n", "test_accuracy = labels_and_preds.filter(lambda (v, p): v == p).count() / float(test_data_chi.count())\n", "tt = time() - t0\n", "\n", "print \"Prediction made in {} seconds. 
Test accuracy is {}\".format(round(tt,3), round(test_accuracy,4))" ], "language": "python", "metadata": {}, "outputs": [ { "output_type": "stream", "stream": "stdout", "text": [ "Prediction made in 34.656 seconds. Test accuracy is 0.9164\n" ] } ], "prompt_number": 19 }, { "cell_type": "markdown", "metadata": {}, "source": [ "So we can see that, by using hypothesis testing, we have been able to remove two predictors without diminishing testing accuracy at all. Training time improved a bit as well. This might now seem like a big model reduction, but it is something when dealing with big data sources. Moreover, we should be able to categorise those five predictors we have left out for different reasons and, either include them in the model or leave them out if they aren't statistically significant. \n", "\n", "Additionally, we could try to remove some of those predictors that are highly correlated, trying not to reduce accuracy too much. In the end, we should end up with a model easier to understand and use. " ] } ], "metadata": {} } ] }