{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Online prediction with BigQuery ML\n", "\n", "`ML.PREDICT` in BigQuery ML is primarily meant for batch predictions. What if you want to build a web application to provide online predictions? Here, I show the basic Python code to do online prediction. You can wrap this code in App Engine or another web framework/toolkit to provide scalable, fast, online predictions." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "!pip install google-cloud # Reset Session after installing" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": true }, "outputs": [], "source": [ "PROJECT = 'cloud-training-demos' # change as needed" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create a model\n", "\n", "Let's start by creating a simple prediction model to [predict arrival delays of aircraft](https://towardsdatascience.com/how-to-train-and-predict-regression-and-classification-ml-models-using-only-sql-using-bigquery-ml-f219b180b947). I'll use this to illustrate the process.\n", "\n", "First, if necessary, create the BigQuery dataset that will store the model." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "%bash\n", "bq mk -d flights" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then, run a `CREATE MODEL` statement. This will take about 5 minutes." 
] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/html": [ "Done" ], "text/plain": [ "QueryResultsTable job_1EM_XyCZV-9AGbXua8W72EfKMZ1g" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%bq query\n", "\n", "CREATE OR REPLACE MODEL flights.arrdelay\n", "OPTIONS\n", " (model_type='linear_reg', input_label_cols=['arr_delay']) AS\n", "SELECT\n", " arr_delay,\n", " carrier,\n", " origin,\n", " dest,\n", " dep_delay,\n", " taxi_out,\n", " distance\n", "FROM\n", " `cloud-training-demos.flights.tzcorr`\n", "WHERE\n", " arr_delay IS NOT NULL" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Batch prediction with the model\n", "\n", "Once you have a trained model, batch prediction can be done within BigQuery itself.\n", "\n", "For example, to find the predicted arrival delays for a flight from DFW to LAX for a range of departure delays:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/html": [ "\n", "
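<table>
<tr><th>predicted_arr_delay</th><th>carrier</th><th>origin</th><th>dest</th><th>dep_delay</th><th>taxi_out</th><th>distance</th></tr>
<tr><td>-8.31085723761</td><td>AA</td><td>DFW</td><td>LAX</td><td>-3</td><td>18</td><td>1235</td></tr>
<tr><td>-7.32016373789</td><td>AA</td><td>DFW</td><td>LAX</td><td>-2</td><td>18</td><td>1235</td></tr>
<tr><td>-6.32947023816</td><td>AA</td><td>DFW</td><td>LAX</td><td>-1</td><td>18</td><td>1235</td></tr>
<tr><td>-5.33877673844</td><td>AA</td><td>DFW</td><td>LAX</td><td>0</td><td>18</td><td>1235</td></tr>
<tr><td>-4.34808323871</td><td>AA</td><td>DFW</td><td>LAX</td><td>1</td><td>18</td><td>1235</td></tr>
<tr><td>-3.35738973899</td><td>AA</td><td>DFW</td><td>LAX</td><td>2</td><td>18</td><td>1235</td></tr>
<tr><td>-2.36669623926</td><td>AA</td><td>DFW</td><td>LAX</td><td>3</td><td>18</td><td>1235</td></tr>
<tr><td>-1.37600273954</td><td>AA</td><td>DFW</td><td>LAX</td><td>4</td><td>18</td><td>1235</td></tr>
<tr><td>-0.385309239811</td><td>AA</td><td>DFW</td><td>LAX</td><td>5</td><td>18</td><td>1235</td></tr>
<tr><td>0.605384259914</td><td>AA</td><td>DFW</td><td>LAX</td><td>6</td><td>18</td><td>1235</td></tr>
<tr><td>1.59607775964</td><td>AA</td><td>DFW</td><td>LAX</td><td>7</td><td>18</td><td>1235</td></tr>
<tr><td>2.58677125936</td><td>AA</td><td>DFW</td><td>LAX</td><td>8</td><td>18</td><td>1235</td></tr>
<tr><td>3.57746475909</td><td>AA</td><td>DFW</td><td>LAX</td><td>9</td><td>18</td><td>1235</td></tr>
<tr><td>4.56815825881</td><td>AA</td><td>DFW</td><td>LAX</td><td>10</td><td>18</td><td>1235</td></tr>
</table>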
\n", "
(rows: 14, time: 1.2s, 16KB processed, job: job_WVW0Xk2F-GX7NphFVwFZXwenWOIM)
\n", " \n", " \n", " " ], "text/plain": [ "QueryResultsTable job_WVW0Xk2F-GX7NphFVwFZXwenWOIM" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%bq query\n", "SELECT * FROM ml.PREDICT(MODEL flights.arrdelay, (\n", "SELECT \n", " 'AA' as carrier,\n", " 'DFW' as origin,\n", " 'LAX' as dest,\n", " dep_delay,\n", " 18 as taxi_out,\n", " 1235 as distance\n", "FROM\n", " UNNEST(GENERATE_ARRAY(-3, 10)) as dep_delay\n", "))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Online prediction in Python\n", "\n", "The batch prediction technique above is not suitable for online prediction, though. A typical BigQuery query has a latency of 1-2 seconds, which is too high for a web application.\n", "\n", "For online prediction, it is better to grab the weights and do the computation yourself. The first query below lists the weight of each numeric input; the second unnests `category_weights` to list the weight of every category of each categorical input." ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/html": [ "\n", "
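<table>
<tr><th>input</th><th>input_weight</th></tr>
<tr><td>carrier</td><td>&nbsp;</td></tr>
<tr><td>origin</td><td>&nbsp;</td></tr>
<tr><td>dest</td><td>&nbsp;</td></tr>
<tr><td>dep_delay</td><td>36.5569545237</td></tr>
<tr><td>taxi_out</td><td>8.15557957221</td></tr>
<tr><td>distance</td><td>-1.88324519311</td></tr>
<tr><td>__INTERCEPT__</td><td>1.09017737502</td></tr>
</table>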
\n", "
(rows: 7, time: 1.3s, 16KB processed, job: job_5YCmyG8UY5qfA8aNa9hy_RQBgzI1)
\n", " \n", " \n", " " ], "text/plain": [ "QueryResultsTable job_5YCmyG8UY5qfA8aNa9hy_RQBgzI1" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%bq query\n", "SELECT\n", " processed_input AS input,\n", " model.weight AS input_weight\n", "FROM\n", " ml.WEIGHTS(MODEL flights.arrdelay) AS model" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/html": [ "\n", "
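<table>
<tr><th>input</th><th>input_weight</th><th>category_name</th><th>category_weight</th></tr>
<tr><td>carrier</td><td>&nbsp;</td><td>B6</td><td>2.41360895323</td></tr>
<tr><td>carrier</td><td>&nbsp;</td><td>EV</td><td>2.27889474393</td></tr>
<tr><td>carrier</td><td>&nbsp;</td><td>AA</td><td>-0.0548843604154</td></tr>
<tr><td>carrier</td><td>&nbsp;</td><td>US</td><td>1.82939081237</td></tr>
<tr><td>carrier</td><td>&nbsp;</td><td>DL</td><td>-1.28516116856</td></tr>
<tr><td>carrier</td><td>&nbsp;</td><td>F9</td><td>5.76556618145</td></tr>
<tr><td>carrier</td><td>&nbsp;</td><td>WN</td><td>2.18336049823</td></tr>
<tr><td>carrier</td><td>&nbsp;</td><td>HA</td><td>6.17381375535</td></tr>
<tr><td>carrier</td><td>&nbsp;</td><td>VX</td><td>5.05422774942</td></tr>
<tr><td>carrier</td><td>&nbsp;</td><td>AS</td><td>2.022529139</td></tr>
<tr><td>carrier</td><td>&nbsp;</td><td>UA</td><td>-1.9492217572</td></tr>
<tr><td>carrier</td><td>&nbsp;</td><td>OO</td><td>1.11110679975</td></tr>
<tr><td>carrier</td><td>&nbsp;</td><td>NK</td><td>7.09579009395</td></tr>
<tr><td>carrier</td><td>&nbsp;</td><td>MQ</td><td>1.19620456586</td></tr>
<tr><td>origin</td><td>&nbsp;</td><td>WRG</td><td>8.99721250183</td></tr>
<tr><td>origin</td><td>&nbsp;</td><td>MSY</td><td>4.96015143256</td></tr>
<tr><td>origin</td><td>&nbsp;</td><td>AMA</td><td>6.20849983281</td></tr>
<tr><td>origin</td><td>&nbsp;</td><td>SAF</td><td>4.19530797642</td></tr>
<tr><td>origin</td><td>&nbsp;</td><td>GUC</td><td>3.60929742614</td></tr>
<tr><td>origin</td><td>&nbsp;</td><td>MKG</td><td>2.23621423776</td></tr>
<tr><td>origin</td><td>&nbsp;</td><td>CHO</td><td>3.63391263959</td></tr>
<tr><td>origin</td><td>&nbsp;</td><td>ESC</td><td>1.01325898826</td></tr>
<tr><td>origin</td><td>&nbsp;</td><td>MCI</td><td>4.16187598831</td></tr>
<tr><td>origin</td><td>&nbsp;</td><td>GSP</td><td>3.06866659933</td></tr>
<tr><td>origin</td><td>&nbsp;</td><td>GJT</td><td>3.6366332815</td></tr>
</table>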
\n", "
(rows: 658, time: 1.1s, 16KB processed, job: job_h3LTu2B6REiDJWXHslEMZ9MyIjH_)
\n", " \n", " \n", " " ], "text/plain": [ "QueryResultsTable job_h3LTu2B6REiDJWXHslEMZ9MyIjH_" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%bq query\n", "SELECT\n", " processed_input AS input,\n", " model.weight AS input_weight,\n", " category.category AS category_name,\n", " category.weight AS category_weight\n", "FROM\n", " ml.WEIGHTS(MODEL flights.arrdelay) AS model,\n", " UNNEST(category_weights) AS category" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here's how to do that in Python.\n", "\n", "P.S. I'm assuming that you are in an environment where you are already authenticated with Google Cloud. If not, see [this article on how to access BigQuery using private keys](https://towardsdatascience.com/how-to-enable-pandas-to-access-bigquery-from-a-service-account-205a216f0f68)." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def query_to_dataframe(query):\n", "    # Run a standard-SQL query in PROJECT and return the result as a pandas DataFrame.\n", "    import pandas as pd\n", "    import pkgutil\n", "    privatekey = None # pkgutil.get_data(KEYDIR, 'privatekey.json')\n", "    return pd.read_gbq(query,\n", "                       project_id=PROJECT,\n", "                       dialect='standard',\n", "                       private_key=privatekey)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You need to pull three pieces of information:\n", "* The weights for each of your numeric columns\n", "* The scaling parameters (mean and standard deviation) for each of your numeric columns\n", "* The vocabulary and weights for each of your categorical columns\n", "\n", "I pull them using three separate BigQuery queries below." ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requesting query... 
ok.\n", "Job ID: 4f89a94d-11a1-4cfa-bbb2-f046ffe63df1\n", "Query running...\n", "Query done.\n", "Cache hit.\n", "\n", "Retrieving results...\n", "Got 7 rows.\n", "\n", "Total time taken 1.02 s.\n", "Finished at 2018-08-12 04:20:46.\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
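<table>
<tr><th></th><th>input</th><th>input_weight</th></tr>
<tr><th>3</th><td>dep_delay</td><td>36.556955</td></tr>
<tr><th>4</th><td>taxi_out</td><td>8.155580</td></tr>
<tr><th>5</th><td>distance</td><td>-1.883245</td></tr>
<tr><th>6</th><td>__INTERCEPT__</td><td>1.090177</td></tr>
</table>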
\n", "
" ], "text/plain": [ " input input_weight\n", "3 dep_delay 36.556955\n", "4 taxi_out 8.155580\n", "5 distance -1.883245\n", "6 __INTERCEPT__ 1.090177" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "numeric_query = \"\"\"\n", "SELECT\n", " processed_input AS input,\n", " model.weight AS input_weight\n", "FROM\n", " ml.WEIGHTS(MODEL flights.arrdelay) AS model\n", "\"\"\"\n", "numeric_weights = query_to_dataframe(numeric_query).dropna()\n", "numeric_weights" ] }, { "cell_type": "code", "execution_count": 84, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requesting query... ok.\n", "Job ID: a3d9e2fb-f940-48ba-a5aa-2e7645d82db1\n", "Query running...\n", "Query done.\n", "Processed: 0.0 B Billed: 0.0 B\n", "Standard price: $0.00 USD\n", "\n", "Retrieving results...\n", "Got 6 rows.\n", "\n", "Total time taken 2.63 s.\n", "Finished at 2018-08-12 05:00:15.\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
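<table>
<tr><th></th><th>input</th><th>min</th><th>max</th><th>mean</th><th>stddev</th></tr>
<tr><th>3</th><td>dep_delay</td><td>-82.0</td><td>1988.0</td><td>9.169095</td><td>36.900368</td></tr>
<tr><th>4</th><td>taxi_out</td><td>1.0</td><td>225.0</td><td>16.099878</td><td>8.901454</td></tr>
<tr><th>5</th><td>distance</td><td>31.0</td><td>4983.0</td><td>825.795621</td><td>608.756947</td></tr>
</table>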
\n", "
" ], "text/plain": [ " input min max mean stddev\n", "3 dep_delay -82.0 1988.0 9.169095 36.900368\n", "4 taxi_out 1.0 225.0 16.099878 8.901454\n", "5 distance 31.0 4983.0 825.795621 608.756947" ] }, "execution_count": 84, "metadata": {}, "output_type": "execute_result" } ], "source": [ "scaling_query = \"\"\"\n", "SELECT\n", " input, min, max, mean, stddev\n", "FROM\n", " ml.FEATURE_INFO(MODEL flights.arrdelay) AS model\n", "\"\"\"\n", "scaling_df = query_to_dataframe(scaling_query).dropna()\n", "scaling_df" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requesting query... ok.\n", "Job ID: b662c5fd-32af-4624-88db-d7aeed97300f\n", "Query running...\n", "Query done.\n", "Cache hit.\n", "\n", "Retrieving results...\n", "Got 658 rows.\n", "\n", "Total time taken 1.03 s.\n", "Finished at 2018-08-12 04:17:36.\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
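<table>
<tr><th></th><th>input</th><th>input_weight</th><th>category_name</th><th>category_weight</th></tr>
<tr><th>0</th><td>carrier</td><td>NaN</td><td>B6</td><td>2.413609</td></tr>
<tr><th>1</th><td>carrier</td><td>NaN</td><td>EV</td><td>2.278895</td></tr>
<tr><th>2</th><td>carrier</td><td>NaN</td><td>AA</td><td>-0.054884</td></tr>
<tr><th>3</th><td>carrier</td><td>NaN</td><td>US</td><td>1.829391</td></tr>
<tr><th>4</th><td>carrier</td><td>NaN</td><td>DL</td><td>-1.285161</td></tr>
</table>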
\n", "
" ], "text/plain": [ " input input_weight category_name category_weight\n", "0 carrier NaN B6 2.413609\n", "1 carrier NaN EV 2.278895\n", "2 carrier NaN AA -0.054884\n", "3 carrier NaN US 1.829391\n", "4 carrier NaN DL -1.285161" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "categorical_query = \"\"\"\n", "SELECT\n", " processed_input AS input,\n", " model.weight AS input_weight,\n", " category.category AS category_name,\n", " category.weight AS category_weight\n", "FROM\n", " ml.WEIGHTS(MODEL flights.arrdelay) AS model,\n", " UNNEST(category_weights) AS category\n", "\"\"\"\n", "categorical_weights = query_to_dataframe(categorical_query)\n", "categorical_weights.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With the three pieces of information in place, you can simply do the math for linear regression: standardize each numeric input as (value - mean) / stddev and multiply it by its weight, add the intercept, and add the weight of the matching category for each categorical input." ] }, { "cell_type": "code", "execution_count": 87, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def compute_prediction(rowdict, numeric_weights, scaling_df, categorical_weights):\n", "    input_values = rowdict\n", "    # numeric inputs (the intercept is treated as a numeric input fixed at 1.0)\n", "    pred = 0\n", "    for column_name in numeric_weights['input'].unique():\n", "        wt = numeric_weights[ numeric_weights['input'] == column_name ]['input_weight'].values[0]\n", "        if column_name != '__INTERCEPT__':\n", "            # alternative min-max scaling, left disabled:\n", "            #minv = scaling_df[ scaling_df['input'] == column_name ]['min'].values[0]\n", "            #maxv = scaling_df[ scaling_df['input'] == column_name ]['max'].values[0]\n", "            #scaled_value = (input_values[column_name] - minv)/(maxv - minv)\n", "            # standardize using the mean and stddev from ML.FEATURE_INFO\n", "            meanv = scaling_df[ scaling_df['input'] == column_name ]['mean'].values[0]\n", "            stddev = scaling_df[ scaling_df['input'] == column_name ]['stddev'].values[0]\n", "            scaled_value = (input_values[column_name] - meanv)/stddev\n", "        else:\n", "            scaled_value = 1.0\n", "        contrib = wt * scaled_value\n", "        print('col={} wt={} scaled_value={} contrib={}'.format(column_name, wt, scaled_value, contrib))\n", "        pred = pred + contrib\n", "    # categorical inputs: add the weight of the category that matches the input value\n", "    for column_name in categorical_weights['input'].unique():\n", "        category_weights = categorical_weights[ categorical_weights['input'] == column_name ]\n", "        wt = category_weights[ category_weights['category_name'] == input_values[column_name] ]['category_weight'].values[0]\n", "        print('col={} wt={} value={} contrib={}'.format(column_name, wt, input_values[column_name], wt))\n", "        pred = pred + wt\n", "    return pred" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here is an example of the prediction code in action:" ] }, { "cell_type": "code", "execution_count": 88, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "col=dep_delay wt=36.5569545237 scaled_value=-0.329782492822 contrib=-12.0558435928\n", "col=taxi_out wt=8.15557957221 scaled_value=0.213461991601 contrib=1.74090625815\n", "col=distance wt=-1.88324519311 scaled_value=0.672196648431 contrib=-1.26591110699\n", "col=__INTERCEPT__ wt=1.09017737502 scaled_value=1.0 contrib=1.09017737502\n", "col=carrier wt=-0.0548843604154 value=AA contrib=-0.0548843604154\n", "col=origin wt=0.966535564037 value=DFW contrib=0.966535564037\n", "col=dest wt=1.26816262538 value=LAX contrib=1.26816262538\n", "-8.310857237611458\n" ] } ], "source": [ "rowdict = {\n", "  'carrier' : 'AA',\n", "  'origin': 'DFW',\n", "  'dest': 'LAX',\n", "  'dep_delay': -3,\n", "  'taxi_out': 18,\n", "  'distance': 1235\n", "}\n", "print(compute_prediction(rowdict, numeric_weights, scaling_df, categorical_weights))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As you can see, the result (-8.310857237611458) matches the batch prediction for a departure delay of -3 (-8.31085723761), which tells us that the computation is correct.\n", "\n", "Now, all that remains is to wrap this code in a web application to serve online predictions.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Copyright 2017 Google Inc. 
Licensed under the Apache License, Version 2.0 (the \\\"License\\\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \\\"AS IS\\\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License" ] } ], "metadata": { "kernelspec": { "display_name": "Python 2", "language": "python", "name": "python2" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.15" } }, "nbformat": 4, "nbformat_minor": 2 }