{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 3.3 Machine Learning - Complex Features\n", "\n", "This notebook demonstrates how to work with complex data sets that include both numerical and categorical features.\n", "\n", "We will use the `Adults` dataset, which includes numerical and categorical features of adults; our task is to predict whether their income is above $50K. (`data/adult.data.bz2`)" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['age', 'workclass', 'fnlwgt', 'education', 'education-num', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'capital-gain', 'capital-loss', 'hours-per-week', 'native-country']\n" ] } ], "source": [ "# let's analyze the data set description to obtain the names of the features\n", "\n", "import re\n", "with open('data/adult.names.txt', 'r') as f:\n", " featureNames = [ line.split(':')[0] for line in f.readlines() if re.match(r'^[a-z\\-]+:.*', line)]\n", "\n", "print(featureNames)\n", "columnNames = featureNames + ['income']" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "root\n", " |-- age: long (nullable = true)\n", " |-- workclass: string (nullable = true)\n", " |-- fnlwgt: double (nullable = true)\n", " |-- education: string (nullable = true)\n", " |-- education-num: double (nullable = true)\n", " |-- marital-status: string (nullable = true)\n", " |-- occupation: string (nullable = true)\n", " |-- relationship: string (nullable = true)\n", " |-- race: string (nullable = true)\n", " |-- sex: string (nullable = true)\n", " |-- capital-gain: double (nullable = true)\n", " |-- capital-loss: double (nullable = true)\n", " |-- hours-per-week: double (nullable = true)\n", " |-- native-country: string (nullable = true)\n", " |-- income: string (nullable = true)\n", "\n" ] } ], "source": [ "# load the data into a DataFrame\n", 
"inputDF = spark.createDataFrame(spark.read.csv('data/adult.data.bz2', inferSchema=True).rdd, columnNames)\n", "inputDF.printSchema()" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>age</th>\n", " <th>workclass</th>\n", " <th>fnlwgt</th>\n", " <th>education</th>\n", " <th>education-num</th>\n", " <th>marital-status</th>\n", " <th>occupation</th>\n", " <th>relationship</th>\n", " <th>race</th>\n", " <th>sex</th>\n", " <th>capital-gain</th>\n", " <th>capital-loss</th>\n", " <th>hours-per-week</th>\n", " <th>native-country</th>\n", " <th>income</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>39</td>\n", " <td>State-gov</td>\n", " <td>77516.0</td>\n", " <td>Bachelors</td>\n", " <td>13.0</td>\n", " <td>Never-married</td>\n", " <td>Adm-clerical</td>\n", " <td>Not-in-family</td>\n", " <td>White</td>\n", " <td>Male</td>\n", " <td>2174.0</td>\n", " <td>0.0</td>\n", " <td>40.0</td>\n", " <td>United-States</td>\n", " <td><=50K</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>50</td>\n", " <td>Self-emp-not-inc</td>\n", " <td>83311.0</td>\n", " <td>Bachelors</td>\n", " <td>13.0</td>\n", " <td>Married-civ-spouse</td>\n", " <td>Exec-managerial</td>\n", " <td>Husband</td>\n", " <td>White</td>\n", " <td>Male</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>13.0</td>\n", " <td>United-States</td>\n", " <td><=50K</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>38</td>\n", " <td>Private</td>\n", " <td>215646.0</td>\n", " <td>HS-grad</td>\n", " <td>9.0</td>\n", " <td>Divorced</td>\n", " <td>Handlers-cleaners</td>\n", " <td>Not-in-family</td>\n", " <td>White</td>\n", " <td>Male</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>40.0</td>\n", " <td>United-States</td>\n", " <td><=50K</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], 
"text/plain": [ "DataFrame[age: bigint, workclass: string, fnlwgt: double, education: string, education-num: double, marital-status: string, occupation: string, relationship: string, race: string, sex: string, capital-gain: double, capital-loss: double, hours-per-week: double, native-country: string, income: string]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "display(inputDF.limit(3))" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>age</th>\n", " <th>workclass</th>\n", " <th>fnlwgt</th>\n", " <th>education</th>\n", " <th>education-num</th>\n", " <th>marital-status</th>\n", " <th>occupation</th>\n", " <th>relationship</th>\n", " <th>race</th>\n", " <th>sex</th>\n", " <th>capital-gain</th>\n", " <th>capital-loss</th>\n", " <th>hours-per-week</th>\n", " <th>native-country</th>\n", " <th>income</th>\n", " <th>label</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>39</td>\n", " <td>State-gov</td>\n", " <td>77516.0</td>\n", " <td>Bachelors</td>\n", " <td>13.0</td>\n", " <td>Never-married</td>\n", " <td>Adm-clerical</td>\n", " <td>Not-in-family</td>\n", " <td>White</td>\n", " <td>Male</td>\n", " <td>2174.0</td>\n", " <td>0.0</td>\n", " <td>40.0</td>\n", " <td>United-States</td>\n", " <td><=50K</td>\n", " <td>0.0</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>50</td>\n", " <td>Self-emp-not-inc</td>\n", " <td>83311.0</td>\n", " <td>Bachelors</td>\n", " <td>13.0</td>\n", " <td>Married-civ-spouse</td>\n", " <td>Exec-managerial</td>\n", " <td>Husband</td>\n", " <td>White</td>\n", " <td>Male</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>13.0</td>\n", " <td>United-States</td>\n", " <td><=50K</td>\n", " <td>0.0</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>38</td>\n", " <td>Private</td>\n", " 
<td>215646.0</td>\n", " <td>HS-grad</td>\n", " <td>9.0</td>\n", " <td>Divorced</td>\n", " <td>Handlers-cleaners</td>\n", " <td>Not-in-family</td>\n", " <td>White</td>\n", " <td>Male</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>40.0</td>\n", " <td>United-States</td>\n", " <td><=50K</td>\n", " <td>0.0</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>53</td>\n", " <td>Private</td>\n", " <td>234721.0</td>\n", " <td>11th</td>\n", " <td>7.0</td>\n", " <td>Married-civ-spouse</td>\n", " <td>Handlers-cleaners</td>\n", " <td>Husband</td>\n", " <td>Black</td>\n", " <td>Male</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>40.0</td>\n", " <td>United-States</td>\n", " <td><=50K</td>\n", " <td>0.0</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>28</td>\n", " <td>Private</td>\n", " <td>338409.0</td>\n", " <td>Bachelors</td>\n", " <td>13.0</td>\n", " <td>Married-civ-spouse</td>\n", " <td>Prof-specialty</td>\n", " <td>Wife</td>\n", " <td>Black</td>\n", " <td>Female</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>40.0</td>\n", " <td>Cuba</td>\n", " <td><=50K</td>\n", " <td>0.0</td>\n", " </tr>\n", " <tr>\n", " <th>5</th>\n", " <td>37</td>\n", " <td>Private</td>\n", " <td>284582.0</td>\n", " <td>Masters</td>\n", " <td>14.0</td>\n", " <td>Married-civ-spouse</td>\n", " <td>Exec-managerial</td>\n", " <td>Wife</td>\n", " <td>White</td>\n", " <td>Female</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>40.0</td>\n", " <td>United-States</td>\n", " <td><=50K</td>\n", " <td>0.0</td>\n", " </tr>\n", " <tr>\n", " <th>6</th>\n", " <td>49</td>\n", " <td>Private</td>\n", " <td>160187.0</td>\n", " <td>9th</td>\n", " <td>5.0</td>\n", " <td>Married-spouse-absent</td>\n", " <td>Other-service</td>\n", " <td>Not-in-family</td>\n", " <td>Black</td>\n", " <td>Female</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>16.0</td>\n", " <td>Jamaica</td>\n", " <td><=50K</td>\n", " <td>0.0</td>\n", " </tr>\n", " <tr>\n", " <th>7</th>\n", " <td>52</td>\n", " 
<td>Self-emp-not-inc</td>\n", " <td>209642.0</td>\n", " <td>HS-grad</td>\n", " <td>9.0</td>\n", " <td>Married-civ-spouse</td>\n", " <td>Exec-managerial</td>\n", " <td>Husband</td>\n", " <td>White</td>\n", " <td>Male</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>45.0</td>\n", " <td>United-States</td>\n", " <td>>50K</td>\n", " <td>1.0</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ "DataFrame[age: bigint, workclass: string, fnlwgt: double, education: string, education-num: double, marital-status: string, occupation: string, relationship: string, race: string, sex: string, capital-gain: double, capital-loss: double, hours-per-week: double, native-country: string, income: string, label: double]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from pyspark.sql.functions import trim\n", "\n", "# create the numerical label column ( 1.0 if income > 50K)\n", "dataDF = inputDF.withColumn('label',trim(inputDF.income).startswith('>').cast('double'))\n", "display(dataDF.limit(8))" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'native-country']\n", "root\n", " |-- age: long (nullable = true)\n", " |-- workclass: string (nullable = true)\n", " |-- fnlwgt: double (nullable = true)\n", " |-- education: string (nullable = true)\n", " |-- education-num: double (nullable = true)\n", " |-- marital-status: string (nullable = true)\n", " |-- occupation: string (nullable = true)\n", " |-- relationship: string (nullable = true)\n", " |-- race: string (nullable = true)\n", " |-- sex: string (nullable = true)\n", " |-- capital-gain: double (nullable = true)\n", " |-- capital-loss: double (nullable = true)\n", " |-- hours-per-week: double (nullable = true)\n", " |-- native-country: string (nullable = true)\n", " |-- income: string (nullable = true)\n", " |-- 
label: double (nullable = true)\n", " |-- cat_workclass: double (nullable = true)\n", " |-- cat_education: double (nullable = true)\n", " |-- cat_marital-status: double (nullable = true)\n", " |-- cat_occupation: double (nullable = true)\n", " |-- cat_relationship: double (nullable = true)\n", " |-- cat_race: double (nullable = true)\n", " |-- cat_sex: double (nullable = true)\n", " |-- cat_native-country: double (nullable = true)\n", "\n" ] } ], "source": [ "from pyspark.ml.feature import StringIndexer\n", "from pyspark.ml import Pipeline\n", "from pyspark.sql.types import StringType\n", "\n", "# every string-typed column except the raw target 'income' is treated as a categorical feature\n", "catrgoricalFeatures = []\n", "for field in dataDF.schema.fields:\n", "    if field.name != 'income' and field.dataType == StringType():\n", "        catrgoricalFeatures.append(field.name)\n", "print(catrgoricalFeatures)\n", "\n", "# one StringIndexer per categorical column; the index is written to 'cat_<column>'\n", "indexerStages = []\n", "for feature in catrgoricalFeatures:\n", "    indexerStages.append(StringIndexer(inputCol=feature, outputCol=\"cat_%s\" % feature, handleInvalid='skip'))\n", "indexerPipeline = Pipeline(stages = indexerStages)\n", "\n", "# fit and apply the indexers on the full data set to inspect the resulting columns\n", "pipelineModel = indexerPipeline.fit(dataDF)\n", "indexed_df = pipelineModel.transform(dataDF)\n", "indexed_df.printSchema()" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "root\n", " |-- age: long (nullable = true)\n", " |-- workclass: string (nullable = true)\n", " |-- fnlwgt: double (nullable = true)\n", " |-- education: string (nullable = true)\n", " |-- education-num: double (nullable = true)\n", " |-- marital-status: string (nullable = true)\n", " |-- occupation: string (nullable = true)\n", " |-- relationship: string (nullable = true)\n", " |-- race: string (nullable = true)\n", " |-- sex: string (nullable = true)\n", " |-- capital-gain: double (nullable = true)\n", " |-- capital-loss: double (nullable = true)\n", " |-- hours-per-week: double (nullable = true)\n", " |-- native-country: string (nullable = true)\n", " |-- income: string (nullable = true)\n", 
" |-- label: double (nullable = true)\n", " |-- cat_workclass: double (nullable = true)\n", " |-- cat_education: double (nullable = true)\n", " |-- cat_marital-status: double (nullable = true)\n", " |-- cat_occupation: double (nullable = true)\n", " |-- cat_relationship: double (nullable = true)\n", " |-- cat_race: double (nullable = true)\n", " |-- cat_sex: double (nullable = true)\n", " |-- cat_native-country: double (nullable = true)\n", " |-- cat_vector: vector (nullable = true)\n", " |-- cat_features: vector (nullable = true)\n", "\n", "+--------------------+\n", "| cat_features|\n", "+--------------------+\n", "|[4.0,2.0,1.0,3.0,...|\n", "|(8,[0,1,3],[1.0,2...|\n", "|(8,[2,3,4],[2.0,9...|\n", "|(8,[1,3,5],[5.0,9...|\n", "|[0.0,2.0,0.0,0.0,...|\n", "+--------------------+\n", "only showing top 5 rows\n", "\n" ] } ], "source": [ "from pyspark.ml.feature import VectorAssembler, VectorIndexer\n", "\n", "# assemble and mark the categorical features\n", "# NOTE: VectorIndexer treats a feature with more than maxCategories distinct values (default 20) as\n", "# continuous. High-cardinality columns such as native-country (roughly 40 values in the Adult data,\n", "# TODO confirm with a distinct count) would silently lose their categorical marking, so the threshold\n", "# is raised to 50, matching the maxBins=50 used by the random forest classifier.\n", "\n", "categorical_assembler = Pipeline(stages = [\n", " VectorAssembler(inputCols = [\"cat_%s\"%c for c in catrgoricalFeatures ], outputCol='cat_vector'), \n", " VectorIndexer(inputCol='cat_vector', outputCol='cat_features', maxCategories=50)\n", "])\n", "\n", "categorical_assembler_model = categorical_assembler.fit(indexed_df)\n", "cat_df = categorical_assembler_model.transform(indexed_df)\n", "cat_df.printSchema()\n", "cat_df.select('cat_features').show(5)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "root\n", " |-- age: long (nullable = true)\n", " |-- workclass: string (nullable = true)\n", " |-- fnlwgt: double (nullable = true)\n", " |-- education: string (nullable = true)\n", " |-- education-num: double (nullable = true)\n", " |-- marital-status: string (nullable = true)\n", " |-- occupation: string (nullable = true)\n", " |-- relationship: string (nullable = true)\n", " |-- race: string (nullable = true)\n", " |-- sex: string (nullable = 
true)\n", " |-- capital-gain: double (nullable = true)\n", " |-- capital-loss: double (nullable = true)\n", " |-- hours-per-week: double (nullable = true)\n", " |-- native-country: string (nullable = true)\n", " |-- income: string (nullable = true)\n", " |-- label: double (nullable = true)\n", " |-- cat_workclass: double (nullable = true)\n", " |-- cat_education: double (nullable = true)\n", " |-- cat_marital-status: double (nullable = true)\n", " |-- cat_occupation: double (nullable = true)\n", " |-- cat_relationship: double (nullable = true)\n", " |-- cat_race: double (nullable = true)\n", " |-- cat_sex: double (nullable = true)\n", " |-- cat_native-country: double (nullable = true)\n", " |-- cat_vector: vector (nullable = true)\n", " |-- cat_features: vector (nullable = true)\n", " |-- features: vector (nullable = true)\n", "\n", "+--------------------+\n", "| features|\n", "+--------------------+\n", "|[39.0,77516.0,40....|\n", "|(11,[0,1,2,3,4,6]...|\n", "|(11,[0,1,2,5,6,7]...|\n", "|(11,[0,1,2,4,6,8]...|\n", "|[28.0,338409.0,40...|\n", "+--------------------+\n", "only showing top 5 rows\n", "\n" ] } ], "source": [ "# assemble both numerical and categorical features\n", "feature_assembler = VectorAssembler(inputCols = ['age', 'fnlwgt', 'hours-per-week', 'cat_features'] , outputCol='features')\n", "feature_df = feature_assembler.transform(cat_df)\n", "\n", "feature_df.printSchema()\n", "feature_df.select('features').show(5)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# split data into training and testing sets with stratified sampling by label\n", "# a fixed seed makes the split reproducible across kernel restarts and keeps the\n", "# train/test sets disjoint even if the cached sample has to be recomputed\n", "trainingDF = dataDF.sampleBy('label', fractions = {0.0: 0.7, 1.0:0.7}, seed=42).cache()\n", "# NOTE: subtract() is a set difference (EXCEPT DISTINCT): rows equal to a training row are\n", "# excluded and duplicated rows are collapsed, so the two counts need not add up to the input size\n", "testDF = dataDF.subtract(trainingDF).cache()" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "22658\n", "5414\n", "9890\n", "2425\n" ] } ], "source": [ "from pyspark.sql.functions import col\n", 
"\n", "# just check that it worked\n", "print(trainingDF.count())\n", "print(trainingDF.where(col('label') == 1.0).count())\n", "print(testDF.count())\n", "print(testDF.where(col('label') == 1.0).count())" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from pyspark.ml.classification import RandomForestClassifier\n", "\n", "# build and train random forest classifier\n", "# the forest is a randomized learner: pin the seed explicitly so the reported metrics are\n", "# reproducible regardless of the Python version / hash seed used by the kernel\n", "rfClassifier = RandomForestClassifier(featuresCol='features', labelCol='label', maxBins=50, seed=42)\n", "# the indexing and assembling stages are re-fitted on the training data only\n", "rfPipeline = Pipeline(stages = [indexerPipeline, categorical_assembler, feature_assembler, rfClassifier])\n", "rfPipelineModel = rfPipeline.fit(trainingDF)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Random forest AUC: 0.886826830643\n" ] } ], "source": [ "from pyspark.ml.evaluation import BinaryClassificationEvaluator\n", "\n", "# evaluate random forest AUC\n", "evaluator = BinaryClassificationEvaluator(rawPredictionCol=\"rawPrediction\")\n", "print(\"Random forest AUC: %s\" % evaluator.evaluate(rfPipelineModel.transform(testDF)))" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Random forest accuracy: 0.824350288199\n" ] } ], "source": [ "from pyspark.ml.evaluation import MulticlassClassificationEvaluator\n", "# evaluate random forest accuracy\n", "evaluator = MulticlassClassificationEvaluator(predictionCol=\"prediction\", metricName='accuracy')\n", "print(\"Random forest accuracy: %s\" % evaluator.evaluate(rfPipelineModel.transform(testDF)))" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "collapsed": true }, "outputs": [], "source": [ "predictionDF = rfPipelineModel.transform(testDF).select('label','prediction').cache()" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": 
"stream", "text": [ "(7104, 360)\n", "(1377, 1048)\n" ] } ], "source": [ "# compute the confusion matrix\n", "def countOutcomes(actual, predicted):\n", "    # number of test rows with the given (label, prediction) pair\n", "    return predictionDF.where((col('label') == actual) & (col('prediction') == predicted)).count()\n", "\n", "truePositives = countOutcomes(1.0, 1.0)\n", "trueNegatives = countOutcomes(0.0, 0.0)\n", "falsePositives = countOutcomes(0.0, 1.0)\n", "falseNegatives = countOutcomes(1.0, 0.0)\n", "\n", "# rows are the actual class (0 then 1), columns the predicted class (0 then 1)\n", "print(trueNegatives, falsePositives)\n", "print(falseNegatives, truePositives)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To use logistic regression we need to encode categorical features differently - using one hot encoding" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "root\n", " |-- age: long (nullable = true)\n", " |-- workclass: string (nullable = true)\n", " |-- fnlwgt: double (nullable = true)\n", " |-- education: string (nullable = true)\n", " |-- education-num: double (nullable = true)\n", " |-- marital-status: string (nullable = true)\n", " |-- occupation: string (nullable = true)\n", " |-- relationship: string (nullable = true)\n", " |-- race: string (nullable = true)\n", " |-- sex: string (nullable = true)\n", " |-- capital-gain: double (nullable = true)\n", " |-- capital-loss: double (nullable = true)\n", " |-- hours-per-week: double (nullable = true)\n", " |-- native-country: string (nullable = true)\n", " |-- income: string (nullable = true)\n", " |-- label: double (nullable = true)\n", " |-- cat_workclass: double (nullable = true)\n", " |-- cat_education: double (nullable = true)\n", " |-- cat_marital-status: double (nullable = true)\n", " |-- cat_occupation: double (nullable = true)\n", " |-- cat_relationship: double (nullable = true)\n", " |-- cat_race: double (nullable = true)\n", " |-- cat_sex: double (nullable = true)\n", " |-- cat_native-country: double (nullable = true)\n", " |-- enc_workclass: vector (nullable = 
true)\n", " |-- enc_education: vector (nullable = true)\n", " |-- enc_marital-status: vector (nullable = true)\n", " |-- enc_occupation: vector (nullable = true)\n", " |-- enc_relationship: vector (nullable = true)\n", " |-- enc_race: vector (nullable = true)\n", " |-- enc_sex: vector (nullable = true)\n", " |-- enc_native-country: vector (nullable = true)\n", "\n" ] } ], "source": [ "from pyspark.ml.feature import OneHotEncoder\n", "\n", "# create and test the OneHotEncoder for categorical features\n", "# (it gets its own names so that categorical_assembler / cat_df, which the random forest\n", "# pipeline relies on, are not overwritten and the earlier cells stay re-runnable)\n", "categorical_encoder = Pipeline(stages = [ OneHotEncoder(inputCol=\"cat_%s\" % feature, \n", " outputCol= \"enc_%s\"% feature) for feature in catrgoricalFeatures])\n", "\n", "categorical_encoder_model = categorical_encoder.fit(indexed_df)\n", "enc_df = categorical_encoder_model.transform(indexed_df)\n", "enc_df.printSchema()" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from pyspark.ml.classification import LogisticRegression\n", "\n", "\n", "# train the logistic regression classifier\n", "lrClassifier = LogisticRegression(featuresCol='features', labelCol='label')\n", "# numerical columns plus the one-hot encoded vector of every categorical column\n", "encFetureAssembler = VectorAssembler(inputCols = ['age', 'fnlwgt', 'hours-per-week'] +\n", " [\"enc_%s\"% feature for feature in catrgoricalFeatures ] \n", " , outputCol='features')\n", "lrPipeline = Pipeline(stages = [indexerPipeline, categorical_encoder, encFetureAssembler, lrClassifier])\n", "# keep the fitted model under its own name instead of rebinding lrPipeline\n", "lrPipelineModel = lrPipeline.fit(trainingDF)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Logisting regression AUC: 0.891875117402\n" ] } ], "source": [ "evaluator = BinaryClassificationEvaluator(rawPredictionCol=\"rawPrediction\")\n", "print(\"Logisting regression AUC: %s\" % evaluator.evaluate(lrPipelineModel.transform(testDF)))" ] } ], "metadata": { "kernelspec": { "display_name": "PySpark", "language": "python", "name": "pyspark" }, "language_info": 
{ "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.11" } }, "nbformat": 4, "nbformat_minor": 1 }