{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Aerospike Connect for Spark - SparkML Prediction Model Tutorial\n", "## Tested with Java 8, Spark 3.0.0, Python 3.7, and Aerospike Spark Connector 3.0.0" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Summary\n", "Build a linear regression model to predict birth weight using Aerospike Database and Spark.\n", "Here are the features used:\n", "- gestation weeks\n", "- mother’s age\n", "- father’s age\n", "- mother’s weight gain during pregnancy\n", "- [Apgar score](https://en.wikipedia.org/wiki/Apgar_score)\n", "\n", "Aerospike stores the Natality dataset published by the CDC. The data is read into Apache Spark through the Aerospike Spark Connector, and Spark ML is used to build and evaluate the model. The model can later be converted to PMML and deployed on your inference server for predictions." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Prerequisites\n", "\n", "1. Start an Aerospike server if one is not already available: `docker run -d --name aerospike -p 3000:3000 -p 3001:3001 -p 3002:3002 -p 3003:3003 aerospike`\n", "2. The feature key file must be located at the path set in `AS_FEATURE_KEY_PATH`\n", "3. [Download the connector](https://www.aerospike.com/enterprise/download/connectors/aerospike-spark/3.0.0/)" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "#IP Address or DNS name for one host in your Aerospike cluster. \n", "#A seed address for the Aerospike database cluster is required\n", "AS_HOST =\"127.0.0.1\"\n", "# Name of one of your namespaces. 
Type 'show namespaces' at the aql prompt if you are not sure\n", "AS_NAMESPACE = \"test\" \n", "AS_FEATURE_KEY_PATH = \"/etc/aerospike/features.conf\"\n", "AEROSPIKE_SPARK_JAR_VERSION=\"3.0.0\"\n", "\n", "AS_PORT = 3000 # Usually 3000, but change here if not\n", "AS_CONNECTION_STRING = AS_HOST + \":\"+ str(AS_PORT)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "#Locate the Spark installation - this'll use the SPARK_HOME environment variable\n", "\n", "import findspark\n", "findspark.init()" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "aerospike-spark-assembly-3.0.0.jar already downloaded\n" ] } ], "source": [ "# The code below downloads the Spark Connector JAR if you haven't done so already.\n", "import urllib.request\n", "import os\n", "\n", "def aerospike_spark_jar_download_url(version=AEROSPIKE_SPARK_JAR_VERSION):\n", "    # Build the download URL for the requested connector version\n", "    DOWNLOAD_PREFIX=\"https://www.aerospike.com/enterprise/download/connectors/aerospike-spark/\"\n", "    DOWNLOAD_SUFFIX=\"/artifact/jar\"\n", "    AEROSPIKE_SPARK_JAR_DOWNLOAD_URL = DOWNLOAD_PREFIX+version+DOWNLOAD_SUFFIX\n", "    return AEROSPIKE_SPARK_JAR_DOWNLOAD_URL\n", "\n", "def download_aerospike_spark_jar(version=AEROSPIKE_SPARK_JAR_VERSION):\n", "    # Download the JAR into the working directory unless it is already there\n", "    JAR_NAME=\"aerospike-spark-assembly-\"+version+\".jar\"\n", "    if(not(os.path.exists(JAR_NAME))) :\n", "        urllib.request.urlretrieve(aerospike_spark_jar_download_url(version),JAR_NAME)\n", "    else :\n", "        print(JAR_NAME+\" already downloaded\")\n", "    return os.path.join(os.getcwd(),JAR_NAME)\n", "\n", "AEROSPIKE_JAR_PATH=download_aerospike_spark_jar()\n", "os.environ[\"PYSPARK_SUBMIT_ARGS\"] = '--jars ' + AEROSPIKE_JAR_PATH + ' pyspark-shell'" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "import pyspark\n", "from pyspark.context import SparkContext\n", "from pyspark.sql.context import 
SQLContext\n", "from pyspark.sql.session import SparkSession\n", "from pyspark.ml.linalg import Vectors\n", "from pyspark.ml.regression import LinearRegression\n", "from pyspark.sql.types import StringType, StructField, StructType, ArrayType, IntegerType, MapType, LongType, DoubleType" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Spark Version: 3.0.0\n" ] } ], "source": [ "#Get a Spark session object and set the required Aerospike configuration properties\n", "sc = SparkContext.getOrCreate()\n", "print(\"Spark Version:\", sc.version)\n", "\n", "spark = SparkSession(sc)\n", "sqlContext = SQLContext(sc)\n", "\n", "spark.conf.set(\"aerospike.namespace\",AS_NAMESPACE)\n", "spark.conf.set(\"aerospike.seedhost\",AS_CONNECTION_STRING)\n", "spark.conf.set(\"aerospike.keyPath\",AS_FEATURE_KEY_PATH )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 1: Load Data into a DataFrame" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+-----+--------------------+---------+------------+-------+-------------+---------------+-------------+----------+----------+----------+\n", "|__key| __digest| __expiry|__generation| __ttl| weight_pnd|weight_gain_pnd|gstation_week|apgar_5min|mother_age|father_age|\n", "+-----+--------------------+---------+------------+-------+-------------+---------------+-------------+----------+----------+----------+\n", "| null|[00 E0 68 A0 09 5...|354071840| 1|2367835| 6.9996768185| 99| 36| 99| 13| 15|\n", "| null|[01 B0 1F 4D D6 9...|354071839| 1|2367834| 5.291094288| 18| 40| 9| 14| 99|\n", "| null|[02 C0 93 23 F1 1...|354071837| 1|2367832| 6.8122838958| 24| 39| 9| 42| 36|\n", "| null|[02 B0 C4 C7 3B F...|354071838| 1|2367833|7.67649596284| 99| 39| 99| 14| 99|\n", "| null|[02 70 2A 45 E4 2...|354071843| 1|2367838| 7.8594796403| 40| 39| 8| 13| 99|\n", 
"+-----+--------------------+---------+------------+-------+-------------+---------------+-------------+----------+----------+----------+\n", "only showing top 5 rows\n", "\n", "Inferred Schema along with Metadata.\n", "root\n", " |-- __key: string (nullable = true)\n", " |-- __digest: binary (nullable = false)\n", " |-- __expiry: integer (nullable = false)\n", " |-- __generation: integer (nullable = false)\n", " |-- __ttl: integer (nullable = false)\n", " |-- weight_pnd: double (nullable = true)\n", " |-- weight_gain_pnd: long (nullable = true)\n", " |-- gstation_week: long (nullable = true)\n", " |-- apgar_5min: long (nullable = true)\n", " |-- mother_age: long (nullable = true)\n", " |-- father_age: long (nullable = true)\n", "\n" ] } ], "source": [ "as_data=spark \\\n", ".read \\\n", ".format(\"aerospike\") \\\n", ".option(\"aerospike.set\", \"natality\").load()\n", "\n", "as_data.show(5)\n", "\n", "print(\"Inferred Schema along with Metadata.\")\n", "as_data.printSchema()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### To speed up the load process at scale, use the [knobs](https://www.aerospike.com/docs/connect/processing/spark/performance.html) available in the Aerospike Spark Connector. \n", "For example, **spark.conf.set(\"aerospike.partition.factor\", 15 )** will map 4096 Aerospike partitions to 32K Spark partitions. 
(Note: Please configure this carefully based on the available resources (CPU threads) in your system.)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 2 - Prep data" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+------------------+---------------+-------------+----------+----------+----------+\n", "| weight_pnd|weight_gain_pnd|gstation_week|apgar_5min|mother_age|father_age|\n", "+------------------+---------------+-------------+----------+----------+----------+\n", "| 7.5398093604| 38| 39| 9| 42| 41|\n", "| 7.3634395508| 25| 37| 9| 14| 18|\n", "| 7.06361087448| 26| 39| 9| 42| 28|\n", "|6.1244416383599996| 20| 37| 9| 44| 41|\n", "| 7.06361087448| 49| 38| 9| 14| 18|\n", "+------------------+---------------+-------------+----------+----------+----------+\n", "only showing top 5 rows\n", "\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
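<table>\n", "<tr><th></th><th>0</th><th>1</th><th>2</th><th>3</th><th>4</th></tr>\n", "<tr><th>summary</th><td>count</td><td>mean</td><td>stddev</td><td>min</td><td>max</td></tr>\n", "<tr><th>weight_pnd</th><td>7897</td><td>7.107382278885445</td><td>1.4971156613504653</td><td>0.5291094288</td><td>12.62587374474</td></tr>\n", "<tr><th>weight_gain_pnd</th><td>7897</td><td>29.840952260352033</td><td>12.356615508036581</td><td>1</td><td>89</td></tr>\n", "<tr><th>gstation_week</th><td>7897</td><td>38.20881347347094</td><td>2.7768826637158064</td><td>18</td><td>47</td></tr>\n", "<tr><th>apgar_5min</th><td>7897</td><td>8.866278333544384</td><td>0.7499400449037321</td><td>0</td><td>10</td></tr>\n", "<tr><th>mother_age</th><td>7897</td><td>39.85716094719514</td><td>9.289970716881799</td><td>11</td><td>54</td></tr>\n", "<tr><th>father_age</th><td>7897</td><td>39.920222869444096</td><td>9.791763879366636</td><td>11</td><td>78</td></tr>\n", "</table>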
\n", "
" ], "text/plain": [ " 0 1 2 3 \\\n", "summary count mean stddev min \n", "weight_pnd 7897 7.107382278885445 1.4971156613504653 0.5291094288 \n", "weight_gain_pnd 7897 29.840952260352033 12.356615508036581 1 \n", "gstation_week 7897 38.20881347347094 2.7768826637158064 18 \n", "apgar_5min 7897 8.866278333544384 0.7499400449037321 0 \n", "mother_age 7897 39.85716094719514 9.289970716881799 11 \n", "father_age 7897 39.920222869444096 9.791763879366636 11 \n", "\n", " 4 \n", "summary max \n", "weight_pnd 12.62587374474 \n", "weight_gain_pnd 89 \n", "gstation_week 47 \n", "apgar_5min 10 \n", "mother_age 54 \n", "father_age 78 " ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# When set to true, this Spark 3.0 setting turns on Adaptive Query Execution (AQE), which uses \n", "# runtime statistics to choose the most efficient query execution plan. It will speed up any joins that you\n", "# plan to use in the data prep step.\n", "spark.conf.set(\"spark.sql.adaptive.enabled\", 'true')\n", "\n", "# Run a query in Spark SQL to filter out NULL values and placeholder codes (such as 99 and 88).\n", "as_data.createOrReplaceTempView(\"natality\")\n", "\n", "sql_query = \"\"\"\n", "SELECT *\n", "from natality\n", "where weight_pnd is not null\n", "and mother_age is not null\n", "and father_age is not null\n", "and father_age < 80\n", "and gstation_week is not null\n", "and weight_gain_pnd < 90\n", "and apgar_5min != 99\n", "and apgar_5min != 88\n", "\"\"\"\n", "clean_data = spark.sql(sql_query)\n", "\n", "#Drop the Aerospike metadata from the dataset because it is not required. 
\n", "#The metadata is added because we are inferring the schema as opposed to providing a strict schema\n", "columns_to_drop = ['__key','__digest','__expiry','__generation','__ttl' ]\n", "clean_data = clean_data.drop(*columns_to_drop)\n", "\n", "# dropping null values\n", "clean_data = clean_data.dropna()\n", "\n", "\n", "clean_data.cache()\n", "clean_data.show(5)\n", "\n", "#Descriptive Analysis of the data\n", "clean_data.describe().toPandas().transpose()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 3 Visualize Data" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Apgar Score: Scores of 7 and above are generally normal; 4 to 6, fairly low; and 3 and below are generally regarded as critically low and cause for immediate resuscitative efforts.\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import pandas as pd\n", "import math\n", "\n", "\n", "pdf = clean_data.toPandas()\n", "\n", "#Histogram - Father Age\n", "pdf[['father_age']].plot(kind='hist',bins=10,rwidth=0.8)\n", "plt.xlabel('Fathers Age (years)',fontsize=12)\n", "plt.legend(loc=None)\n", "plt.style.use('seaborn-whitegrid')\n", "plt.show()\n", "\n", "'''\n", "pdf[['mother_age']].plot(kind='hist',bins=10,rwidth=0.8)\n", "plt.xlabel('Mothers Age (years)',fontsize=12)\n", "plt.legend(loc=None)\n", "plt.style.use('seaborn-whitegrid')\n", "plt.show()\n", "'''\n", "\n", "pdf[['weight_pnd']].plot(kind='hist',bins=10,rwidth=0.8)\n", "plt.xlabel('Babys Weight (Pounds)',fontsize=12)\n", "plt.legend(loc=None)\n", "plt.style.use('seaborn-whitegrid')\n", "plt.show()\n", "\n", "pdf[['gstation_week']].plot(kind='hist',bins=10,rwidth=0.8)\n", "plt.xlabel('Gestation (Weeks)',fontsize=12)\n", "plt.legend(loc=None)\n", "plt.style.use('seaborn-whitegrid')\n", "plt.show()\n", "\n", "pdf[['weight_gain_pnd']].plot(kind='hist',bins=10,rwidth=0.8)\n", "plt.xlabel('mother’s weight gain during pregnancy',fontsize=12)\n", "plt.legend(loc=None)\n", "plt.style.use('seaborn-whitegrid')\n", "plt.show()\n", "\n", "#Histogram - Apgar Score\n", "print(\"Apgar Score: Scores of 7 and above are generally normal; 4 to 6, fairly low; and 3 and below are generally \\\n", "regarded as critically low and cause for immediate resuscitative efforts.\")\n", "pdf[['apgar_5min']].plot(kind='hist',bins=10,rwidth=0.8)\n", "plt.xlabel('Apgar score',fontsize=12)\n", "plt.legend(loc=None)\n", "plt.style.use('seaborn-whitegrid')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 4 - Create Model\n", "\n", "**Steps used for model creation:**\n", "1. Split cleaned data into Training and Test sets\n", "2. Vectorize features on which the model will be trained\n", "3. 
Create a linear regression model (Choose any ML algorithm that provides the best fit for the given dataset)\n", "4. Train model (Although not shown here, you could use K-fold cross-validation and Grid Search to choose the best hyper-parameters for the model)\n", "5. Evaluate model" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "# Define a function that collects the features of interest\n", "# (mother_age, father_age, gstation_week, weight_gain_pnd, and apgar_5min) into a vector.\n", "# Package the vector in a tuple containing the label (`weight_pnd`) for that\n", "# row.\n", "\n", "def vector_from_inputs(r):\n", "    return (r[\"weight_pnd\"], Vectors.dense(float(r[\"mother_age\"]),\n", "                                            float(r[\"father_age\"]),\n", "                                            float(r[\"gstation_week\"]),\n", "                                            float(r[\"weight_gain_pnd\"]),\n", "                                            float(r[\"apgar_5min\"])))\n", "\n", "\n" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+------------------+---------------+-------------+----------+----------+----------+\n", "| weight_pnd|weight_gain_pnd|gstation_week|apgar_5min|mother_age|father_age|\n", "+------------------+---------------+-------------+----------+----------+----------+\n", "| 4.0565056208| 50| 33| 9| 44| 41|\n", "| 4.68702769012| 70| 36| 9| 44| 40|\n", "| 4.87442061282| 23| 33| 9| 43| 46|\n", "|6.1244416383599996| 20| 37| 9| 44| 41|\n", "|6.2501051276999995| 12| 38| 9| 44| 45|\n", "| 6.56316153974| 40| 38| 9| 47| 45|\n", "| 6.7681914434| 33| 39| 10| 47| 45|\n", "| 6.87621795178| 19| 38| 9| 44| 46|\n", "| 7.06361087448| 26| 39| 9| 42| 28|\n", "| 7.1099079495| 35| 39| 10| 43| 61|\n", "| 7.24879917456| 40| 37| 9| 44| 44|\n", "| 7.5398093604| 38| 39| 9| 42| 41|\n", "| 7.5618555866| 50| 38| 9| 42| 35|\n", "| 7.7492485093| 40| 38| 9| 44| 48|\n", "| 7.87491199864| 59| 41| 9| 43| 46|\n", "| 8.18796841068| 22| 40| 9| 42| 34|\n", "| 9.31232594688| 28| 41| 9| 45| 44|\n", "| 4.5856150496| 23| 36| 9| 42| 
43|\n", "| 5.1257475915| 25| 36| 9| 54| 54|\n", "| 5.3131405142| 55| 36| 9| 47| 45|\n", "+------------------+---------------+-------------+----------+----------+----------+\n", "only showing top 20 rows\n", "\n", "(5499, 6)\n", "+------------------+---------------+-------------+----------+----------+----------+\n", "| weight_pnd|weight_gain_pnd|gstation_week|apgar_5min|mother_age|father_age|\n", "+------------------+---------------+-------------+----------+----------+----------+\n", "| 3.62439958728| 50| 35| 9| 42| 37|\n", "| 5.3351867404| 6| 38| 9| 43| 48|\n", "| 6.8122838958| 24| 39| 9| 42| 36|\n", "| 6.9776305923| 27| 39| 9| 46| 42|\n", "| 7.06361087448| 49| 38| 9| 14| 18|\n", "| 7.3634395508| 25| 37| 9| 14| 18|\n", "| 7.4075320032| 18| 38| 9| 45| 45|\n", "| 7.68751907594| 25| 38| 10| 42| 49|\n", "| 3.09088091324| 42| 32| 9| 43| 46|\n", "| 5.62619692624| 24| 39| 9| 44| 50|\n", "|6.4992274837599995| 20| 39| 9| 42| 47|\n", "|6.5918216337999995| 63| 35| 9| 42| 38|\n", "| 6.686620406459999| 36| 38| 10| 14| 17|\n", "| 6.6910296517| 37| 40| 9| 42| 42|\n", "| 6.8122838958| 13| 35| 9| 14| 15|\n", "| 7.1870697412| 40| 36| 8| 14| 15|\n", "| 7.4075320032| 19| 40| 9| 43| 45|\n", "| 7.4736706818| 41| 37| 9| 43| 53|\n", "| 7.62578964258| 35| 38| 8| 43| 46|\n", "| 7.62578964258| 39| 39| 9| 42| 37|\n", "+------------------+---------------+-------------+----------+----------+----------+\n", "only showing top 20 rows\n", "\n", "(2398, 6)\n" ] } ], "source": [ "#Split that data 70% training and 30% Evaluation data\n", "train, test = clean_data.randomSplit([0.7, 0.3])\n", "\n", "#Check the shape of the data\n", "train.show()\n", "print((train.count(), len(train.columns)))\n", "test.show()\n", "print((test.count(), len(test.columns)))" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ 
"Coefficients:[0.00858931617782676,0.0008477851947958541,0.27948866120791893,0.009329081045860402,0.18817058385589935]\n", "Intercept:-5.893364345930709\n", "R^2:0.3970187134779115\n", "+--------------------+\n", "| residuals|\n", "+--------------------+\n", "| -1.845934264937739|\n", "| -2.2396120149639067|\n", "| -0.7717836944756593|\n", "| -0.6160804608336026|\n", "| -0.6986641251138215|\n", "| -0.672589930891391|\n", "| -0.8699157049741881|\n", "|-0.13870265354963962|\n", "|-0.26366319351660383|\n", "| -0.5260646593713352|\n", "| 0.3191520988648042|\n", "| 0.08956511232072462|\n", "| 0.28423773834709554|\n", "| 0.5367216316177004|\n", "|-0.34304851596998454|\n", "| 0.613435294490146|\n", "| 1.3680838827256254|\n", "| -1.887922569557201|\n", "| -1.4788456210255978|\n", "| -1.5035698497034602|\n", "+--------------------+\n", "only showing top 20 rows\n", "\n" ] } ], "source": [ "# Create an input DataFrame for Spark ML using the above function.\n", "training_data = train.rdd.map(vector_from_inputs).toDF([\"label\",\n", " \"features\"])\n", " \n", "# Construct a new LinearRegression object and fit the training data.\n", "lr = LinearRegression(maxIter=5, regParam=0.2, solver=\"normal\")\n", "\n", "#Voila! 
your first model using Spark ML is trained\n", "model = lr.fit(training_data)\n", "\n", "# Print the model summary.\n", "print(\"Coefficients:\" + str(model.coefficients))\n", "print(\"Intercept:\" + str(model.intercept))\n", "print(\"R^2:\" + str(model.summary.r2))\n", "model.summary.residuals.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Evaluate Model" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+------------------+--------------------+\n", "| label| features|\n", "+------------------+--------------------+\n", "| 3.62439958728|[42.0,37.0,35.0,5...|\n", "| 5.3351867404|[43.0,48.0,38.0,6...|\n", "| 6.8122838958|[42.0,36.0,39.0,2...|\n", "| 6.9776305923|[46.0,42.0,39.0,2...|\n", "| 7.06361087448|[14.0,18.0,38.0,4...|\n", "| 7.3634395508|[14.0,18.0,37.0,2...|\n", "| 7.4075320032|[45.0,45.0,38.0,1...|\n", "| 7.68751907594|[42.0,49.0,38.0,2...|\n", "| 3.09088091324|[43.0,46.0,32.0,4...|\n", "| 5.62619692624|[44.0,50.0,39.0,2...|\n", "|6.4992274837599995|[42.0,47.0,39.0,2...|\n", "|6.5918216337999995|[42.0,38.0,35.0,6...|\n", "| 6.686620406459999|[14.0,17.0,38.0,3...|\n", "| 6.6910296517|[42.0,42.0,40.0,3...|\n", "| 6.8122838958|[14.0,15.0,35.0,1...|\n", "| 7.1870697412|[14.0,15.0,36.0,4...|\n", "| 7.4075320032|[43.0,45.0,40.0,1...|\n", "| 7.4736706818|[43.0,53.0,37.0,4...|\n", "| 7.62578964258|[43.0,46.0,38.0,3...|\n", "| 7.62578964258|[42.0,37.0,39.0,3...|\n", "+------------------+--------------------+\n", "only showing top 20 rows\n", "\n", "MAE: 0.9094828902906563\n", "RMSE: 1.1665322992147173\n", "R-squared value: 0.378390902740944\n" ] } ], "source": [ "eval_data = test.rdd.map(vector_from_inputs).toDF([\"label\",\n", " \"features\"])\n", "\n", "eval_data.show()\n", "\n", "evaluation_summary = model.evaluate(eval_data)\n", "\n", "\n", "print(\"MAE:\", evaluation_summary.meanAbsoluteError)\n", "print(\"RMSE:\", 
evaluation_summary.rootMeanSquaredError)\n", "print(\"R-squared value:\", evaluation_summary.r2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 5 - Batch Prediction" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+------------------+--------------------+-----------------+\n", "| label| features| prediction|\n", "+------------------+--------------------+-----------------+\n", "| 3.62439958728|[42.0,37.0,35.0,5...|6.440847435018738|\n", "| 5.3351867404|[43.0,48.0,38.0,6...| 6.88674880594522|\n", "| 6.8122838958|[42.0,36.0,39.0,2...|7.315398187463249|\n", "| 6.9776305923|[46.0,42.0,39.0,2...|7.382829406480911|\n", "| 7.06361087448|[14.0,18.0,38.0,4...|7.013375565916365|\n", "| 7.3634395508|[14.0,18.0,37.0,2...|6.509988959607797|\n", "| 7.4075320032|[45.0,45.0,38.0,1...|7.013333055266812|\n", "| 7.68751907594|[42.0,49.0,38.0,2...|7.244430398689434|\n", "| 3.09088091324|[43.0,46.0,32.0,4...|5.543968185959089|\n", "| 5.62619692624|[44.0,50.0,39.0,2...|7.344445812546044|\n", "|6.4992274837599995|[42.0,47.0,39.0,2...|7.287407500422561|\n", "|6.5918216337999995|[42.0,38.0,35.0,6...| 6.56297327380972|\n", "| 6.686620406459999|[14.0,17.0,38.0,3...|7.079420310981281|\n", "| 6.6910296517|[42.0,42.0,40.0,3...|7.721251613436126|\n", "| 6.8122838958|[14.0,15.0,35.0,1...|5.836519309057246|\n", "| 7.1870697412|[14.0,15.0,36.0,4...|6.179722574647495|\n", "| 7.4075320032|[43.0,45.0,40.0,1...|7.564460826372854|\n", "| 7.4736706818|[43.0,53.0,37.0,4...|6.938016907316393|\n", "| 7.62578964258|[43.0,46.0,38.0,3...| 6.96742600202968|\n", "| 7.62578964258|[42.0,37.0,39.0,3...|7.456182188345951|\n", "+------------------+--------------------+-----------------+\n", "only showing top 20 rows\n", "\n" ] } ], "source": [ "#eval_data contains the records (ideally production) that you'd like to use for the prediction\n", "\n", "predictions = model.transform(eval_data)\n", 
"predictions.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Compare the labels with the predictions; for an accurate model they should closely match. The label is the actual weight of the baby and the prediction is the predicted weight." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Saving the Predictions to Aerospike for ML Application's consumption" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+------------------+--------------------+-----------------+----------+\n", "| label| features| prediction| _id|\n", "+------------------+--------------------+-----------------+----------+\n", "| 3.62439958728|[42.0,37.0,35.0,5...|6.440847435018738| 0|\n", "| 5.3351867404|[43.0,48.0,38.0,6...| 6.88674880594522| 1|\n", "| 6.8122838958|[42.0,36.0,39.0,2...|7.315398187463249| 2|\n", "| 6.9776305923|[46.0,42.0,39.0,2...|7.382829406480911| 3|\n", "| 7.06361087448|[14.0,18.0,38.0,4...|7.013375565916365| 4|\n", "| 7.3634395508|[14.0,18.0,37.0,2...|6.509988959607797| 5|\n", "| 7.4075320032|[45.0,45.0,38.0,1...|7.013333055266812| 6|\n", "| 7.68751907594|[42.0,49.0,38.0,2...|7.244430398689434| 7|\n", "| 3.09088091324|[43.0,46.0,32.0,4...|5.543968185959089|8589934592|\n", "| 5.62619692624|[44.0,50.0,39.0,2...|7.344445812546044|8589934593|\n", "|6.4992274837599995|[42.0,47.0,39.0,2...|7.287407500422561|8589934594|\n", "|6.5918216337999995|[42.0,38.0,35.0,6...| 6.56297327380972|8589934595|\n", "| 6.686620406459999|[14.0,17.0,38.0,3...|7.079420310981281|8589934596|\n", "| 6.6910296517|[42.0,42.0,40.0,3...|7.721251613436126|8589934597|\n", "| 6.8122838958|[14.0,15.0,35.0,1...|5.836519309057246|8589934598|\n", "| 7.1870697412|[14.0,15.0,36.0,4...|6.179722574647495|8589934599|\n", "| 7.4075320032|[43.0,45.0,40.0,1...|7.564460826372854|8589934600|\n", "| 7.4736706818|[43.0,53.0,37.0,4...|6.938016907316393|8589934601|\n", "| 7.62578964258|[43.0,46.0,38.0,3...| 
6.96742600202968|8589934602|\n", "| 7.62578964258|[42.0,37.0,39.0,3...|7.456182188345951|8589934603|\n", "+------------------+--------------------+-----------------+----------+\n", "only showing top 20 rows\n", "\n", "#records: 2398\n" ] } ], "source": [ "# Aerospike is a key/value database, so a key is needed to store the predictions in the database. \n", "# Add an _id column to the predictions using Spark SQL to serve as that key.\n", "\n", "predictions.createOrReplaceTempView(\"predict_view\")\n", " \n", "sql_query = \"\"\"\n", "SELECT *, monotonically_increasing_id() as _id\n", "from predict_view\n", "\"\"\"\n", "predict_df = spark.sql(sql_query)\n", "predict_df.show()\n", "print(\"#records:\", predict_df.count())" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "# Now we are ready to write the predictions to Aerospike\n", "predict_df \\\n", ".write \\\n", ".mode('overwrite') \\\n", ".format(\"aerospike\") \\\n", ".option(\"aerospike.writeset\", \"predictions\")\\\n", ".option(\"aerospike.updateByKey\", \"_id\") \\\n", ".save()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### You can verify that the data was written to Aerospike by using either [AQL](https://www.aerospike.com/docs/tools/aql/data_management.html) or the [Aerospike Data Browser](https://github.com/aerospike/aerospike-data-browser)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 6 - Deploy\n", "### Here are a few options:\n", "1. Save the model to a PMML file by converting it using JPMML/[pyspark2pmml](https://github.com/jpmml/pyspark2pmml) and load it into your production environment for inference.\n", "2. Use Aerospike as an [edge database for high velocity ingestion](https://medium.com/aerospike-developer-blog/add-horsepower-to-ai-ml-pipeline-15ca42a10982) for your inference pipeline." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.5" } }, "nbformat": 4, "nbformat_minor": 2 }