{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tree models with Spark ML\n",
    "\n",
    "Train a decision tree, a random forest and a gradient-boosted tree classifier on `sample_libsvm_data.txt`, compare their accuracy on a held-out test split, and inspect the random forest's feature importances."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# spark\n",
    "from pyspark import SparkConf\n",
    "from pyspark.sql import SparkSession\n",
    "\n",
    "# model\n",
    "from pyspark.ml.classification import (RandomForestClassifier,\n",
    "                                       GBTClassifier,\n",
    "                                       DecisionTreeClassifier)\n",
    "\n",
    "# evaluation\n",
    "from pyspark.ml.evaluation import MulticlassClassificationEvaluator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# config\n",
    "DATA_PATH = 'sample_libsvm_data.txt'\n",
    "SEED = 42  # fixes the train/test split so every metric below is reproducible\n",
    "\n",
    "conf = SparkConf().setAppName(\"building a TREE MODEL\")\n",
    "# getOrCreate() reuses a running session, so this cell is safe to re-run;\n",
    "# calling SparkContext(conf=conf) directly raises if a context already exists.\n",
    "spark = (SparkSession.builder\n",
    "         .config(conf=conf)\n",
    "         .enableHiveSupport()\n",
    "         .getOrCreate())\n",
    "sc = spark.sparkContext"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = spark.read.format('libsvm').load(DATA_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train / test split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# seeded so the split, and therefore every score below, is identical on each run\n",
    "train_data, test_data = data.randomSplit([0.7, 0.3], seed=SEED)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Build, train and predict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# build the models\n",
    "dtc = DecisionTreeClassifier()\n",
    "rfc = RandomForestClassifier()\n",
    "gbt = GBTClassifier()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# train\n",
    "dtc_model = dtc.fit(train_data)\n",
    "rfc_model = rfc.fit(train_data)\n",
    "gbt_model = gbt.fit(train_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# predict\n",
    "dtc_preds = dtc_model.transform(test_data)\n",
    "rfc_preds = rfc_model.transform(test_data)\n",
    "gbt_preds = gbt_model.transform(test_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# show the decision-tree results\n",
    "dtc_preds.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# MulticlassClassificationEvaluator works on the binary class dataset as well\n",
    "acc_eval = MulticlassClassificationEvaluator(metricName='accuracy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score every trained model on the same held-out test set.\n",
    "# A list (not a dict) keeps the print order stable on Python 3.5.\n",
    "model_preds = [\n",
    "    ('DTC (DECISION TREE)', dtc_preds),\n",
    "    ('RFC (RANDOM FOREST)', rfc_preds),\n",
    "    ('GBT (GRADIENT-BOOSTED TREES)', gbt_preds),\n",
    "]\n",
    "for label, preds in model_preds:\n",
    "    print('{} ACCURACY : {:.4f}'.format(label, acc_eval.evaluate(preds)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Feature importance\n",
    "\n",
    "Per-feature importances of the random forest, as a sparse vector over the feature indices."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rfc_model.featureImportances"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "_Notes: end of 13.46 - next: 14._"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Environment (conda_pyspark_)",
   "language": "python",
   "name": "conda_pyspark_"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}