{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/pandas/__init__.py:7: DeprecationWarning: bad escape \\s\n",
" from pandas import hashtable, tslib, lib\n"
]
}
],
"source": [
"import h2o\n",
"import pandas\n",
"import pprint\n",
"import operator\n",
"import matplotlib\n",
"from h2o.estimators.glm import H2OGeneralizedLinearEstimator\n",
"from h2o.estimators.gbm import H2OGradientBoostingEstimator\n",
"from h2o.estimators.random_forest import H2ORandomForestEstimator\n",
"from h2o.estimators.deeplearning import H2ODeepLearningEstimator\n",
"from tabulate import tabulate"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/html": [
"
H2O cluster uptime: | \n",
"11 seconds 120 milliseconds |
\n",
"H2O cluster version: | \n",
"3.7.0.99999 |
\n",
"H2O cluster name: | \n",
"spIdea |
\n",
"H2O cluster total nodes: | \n",
"1 |
\n",
"H2O cluster total free memory: | \n",
"12.44 GB |
\n",
"H2O cluster total cores: | \n",
"8 |
\n",
"H2O cluster allowed cores: | \n",
"8 |
\n",
"H2O cluster healthy: | \n",
"True |
\n",
"H2O Connection ip: | \n",
"127.0.0.1 |
\n",
"H2O Connection port: | \n",
"54321 |
\n",
"H2O Connection proxy: | \n",
"None |
\n",
"Python Version: | \n",
"3.5.0 |
"
],
"text/plain": [
"------------------------------ ---------------------------\n",
"H2O cluster uptime: 11 seconds 120 milliseconds\n",
"H2O cluster version: 3.7.0.99999\n",
"H2O cluster name: spIdea\n",
"H2O cluster total nodes: 1\n",
"H2O cluster total free memory: 12.44 GB\n",
"H2O cluster total cores: 8\n",
"H2O cluster allowed cores: 8\n",
"H2O cluster healthy: True\n",
"H2O Connection ip: 127.0.0.1\n",
"H2O Connection port: 54321\n",
"H2O Connection proxy:\n",
"Python Version: 3.5.0\n",
"------------------------------ ---------------------------"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Connect to a cluster\n",
"h2o.init()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# set this to True if interactive (matplotlib) plots are desired\n",
"interactive = False\n",
"if not interactive: matplotlib.use('Agg', warn=False)\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Import and Parse airlines data\n",
"\n",
"Parse Progress: [##################################################] 100%\n",
"Rows:43,978 Cols:31\n",
"\n",
"Chunk compression summary: \n"
]
},
{
"data": {
"text/html": [
"chunk_type | \n",
"chunk_name | \n",
"count | \n",
"count_percentage | \n",
"size | \n",
"size_percentage |
\n",
"C0L | \n",
"Constant Integers | \n",
"10 | \n",
"5.376344 | \n",
" 800 B | \n",
"0.0504024 |
\n",
"C0D | \n",
"Constant Reals | \n",
"23 | \n",
"12.365591 | \n",
" 1.8 KB | \n",
"0.1159254 |
\n",
"CBS | \n",
"Bits | \n",
"2 | \n",
"1.0752689 | \n",
" 2.0 KB | \n",
"0.1272030 |
\n",
"CX0 | \n",
"Sparse Bits | \n",
"10 | \n",
"5.376344 | \n",
" 1.9 KB | \n",
"0.1247459 |
\n",
"C1 | \n",
"1-Byte Integers | \n",
"40 | \n",
"21.505377 | \n",
" 287.8 KB | \n",
"18.564957 |
\n",
"C1N | \n",
"1-Byte Integers (w/o NAs) | \n",
"19 | \n",
"10.215054 | \n",
" 133.1 KB | \n",
"8.58617 |
\n",
"C1S | \n",
"1-Byte Fractions | \n",
"6 | \n",
"3.2258065 | \n",
" 43.4 KB | \n",
"2.8024976 |
\n",
"C2 | \n",
"2-Byte Integers | \n",
"76 | \n",
"40.860214 | \n",
" 1.1 MB | \n",
"69.628105 |
"
],
"text/plain": [
"chunk_type chunk_name count count_percentage size size_percentage\n",
"------------ ------------------------- ------- ------------------ -------- -----------------\n",
"C0L Constant Integers 10 5.37634 800 B 0.0504024\n",
"C0D Constant Reals 23 12.3656 1.8 KB 0.115925\n",
"CBS Bits 2 1.07527 2.0 KB 0.127203\n",
"CX0 Sparse Bits 10 5.37634 1.9 KB 0.124746\n",
"C1 1-Byte Integers 40 21.5054 287.8 KB 18.565\n",
"C1N 1-Byte Integers (w/o NAs) 19 10.2151 133.1 KB 8.58617\n",
"C1S 1-Byte Fractions 6 3.22581 43.4 KB 2.8025\n",
"C2 2-Byte Integers 76 40.8602 1.1 MB 69.6281"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Frame distribution summary: \n"
]
},
{
"data": {
"text/html": [
" | \n",
"size | \n",
"number_of_rows | \n",
"number_of_chunks_per_column | \n",
"number_of_chunks |
\n",
"172.16.2.84:54321 | \n",
" 1.5 MB | \n",
"43978.0 | \n",
"6.0 | \n",
"186.0 |
\n",
"mean | \n",
" 1.5 MB | \n",
"43978.0 | \n",
"6.0 | \n",
"186.0 |
\n",
"min | \n",
" 1.5 MB | \n",
"43978.0 | \n",
"6.0 | \n",
"186.0 |
\n",
"max | \n",
" 1.5 MB | \n",
"43978.0 | \n",
"6.0 | \n",
"186.0 |
\n",
"stddev | \n",
" 0 B | \n",
"0.0 | \n",
"0.0 | \n",
"0.0 |
\n",
"total | \n",
" 1.5 MB | \n",
"43978.0 | \n",
"6.0 | \n",
"186.0 |
"
],
"text/plain": [
" size number_of_rows number_of_chunks_per_column number_of_chunks\n",
"----------------- ------ ---------------- ----------------------------- ------------------\n",
"172.16.2.84:54321 1.5 MB 43978 6 186\n",
"mean 1.5 MB 43978 6 186\n",
"min 1.5 MB 43978 6 186\n",
"max 1.5 MB 43978 6 186\n",
"stddev 0 B 0 0 0\n",
"total 1.5 MB 43978 6 186"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n"
]
},
{
"data": {
"text/html": [
"\n",
" | Year | Month | DayofMonth | DayOfWeek | DepTime | CRSDepTime | ArrTime | CRSArrTime | UniqueCarrier | FlightNum | TailNum | ActualElapsedTime | CRSElapsedTime | AirTime | ArrDelay | DepDelay | Origin | Dest | Distance | TaxiIn | TaxiOut | Cancelled | CancellationCode | Diverted | CarrierDelay | WeatherDelay | NASDelay | SecurityDelay | LateAircraftDelay | IsArrDelayed | IsDepDelayed |
\n",
"type | int | int | int | int | int | int | int | int | enum | int | enum | int | int | int | int | int | enum | enum | int | int | int | int | enum | int | int | int | int | int | int | enum | enum |
\n",
"mins | 1987.0 | 1.0 | 1.0 | 1.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 1.0 | 0.0 | 16.0 | 17.0 | 14.0 | -63.0 | -16.0 | 0.0 | 0.0 | 11.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
\n",
"mean | 1997.5 | 1.409090909090909 | 14.601073263904679 | 3.820614852880991 | 1345.8466613820763 | 1313.2228614307164 | 1504.6341303788884 | 1485.289167310927 | NaN | 818.8429896766577 | NaN | 124.8145291354043 | 125.02156260661899 | 114.31611109078277 | 9.317111936984313 | 10.0073906556001 | NaN | NaN | 730.1821905650501 | 5.381368059530628 | 14.168634184732056 | 0.024694165264450407 | NaN | 0.0024785119832643593 | 4.047800291055627 | 0.2893764692712417 | 4.855031904175534 | 0.017015560282100096 | 7.620060450016789 | 0.555755150302424 | 0.5250579835372226 |
\n",
"maxs | 2008.0 | 10.0 | 31.0 | 7.0 | 2400.0 | 2359.0 | 2400.0 | 2359.0 | 9.0 | 3949.0 | 3500.0 | 475.0 | 437.0 | 402.0 | 475.0 | 473.0 | 131.0 | 133.0 | 3365.0 | 128.0 | 254.0 | 1.0 | 3.0 | 1.0 | 369.0 | 201.0 | 323.0 | 14.0 | 373.0 | 1.0 | 1.0 |
\n",
"sigma | 6.344360901711177 | 1.874711371343963 | 9.175790425861443 | 1.9050131191328936 | 465.340899124234 | 476.25113999259946 | 484.34748790351614 | 492.75043412270094 | NaN | 777.4043691636349 | NaN | 73.97444166059017 | 73.4015946300093 | 69.63632951506109 | 29.840221962414848 | 26.438809042916454 | NaN | NaN | 578.438008230424 | 4.201979939864828 | 9.905085747204327 | 0.15519314135784237 | NaN | 0.049723487218862286 | 16.20572990448423 | 4.416779898734124 | 18.619776221475682 | 0.40394018210151184 | 23.487565874106213 | 0.4968872883428837 | 0.49937738031758017 |
\n",
"zeros | 0 | 0 | 0 | 0 | 0 | 569 | 0 | 569 | 724 | 0 | 2 | 0 | 0 | -8878 | 1514 | 6393 | 59 | 172 | 0 | -8255 | -8321 | 42892 | 81 | 43869 | -23296 | -21800 | -23252 | -21726 | -23500 | 19537 | 20887 |
\n",
"missing | 0 | 0 | 0 | 0 | 1086 | 0 | 1195 | 0 | 0 | 0 | 32 | 1195 | 13 | 16649 | 1195 | 1086 | 0 | 0 | 35 | 16026 | 16024 | 0 | 9774 | 0 | 35045 | 35045 | 35045 | 35045 | 35045 | 0 | 0 |
\n",
"0 | 1987.0 | 10.0 | 14.0 | 3.0 | 741.0 | 730.0 | 912.0 | 849.0 | PS | 1451.0 | NA | 91.0 | 79.0 | nan | 23.0 | 11.0 | SAN | SFO | 447.0 | nan | nan | 0.0 | NA | 0.0 | nan | nan | nan | nan | nan | YES | YES |
\n",
"1 | 1987.0 | 10.0 | 15.0 | 4.0 | 729.0 | 730.0 | 903.0 | 849.0 | PS | 1451.0 | NA | 94.0 | 79.0 | nan | 14.0 | -1.0 | SAN | SFO | 447.0 | nan | nan | 0.0 | NA | 0.0 | nan | nan | nan | nan | nan | YES | NO |
\n",
"2 | 1987.0 | 10.0 | 17.0 | 6.0 | 741.0 | 730.0 | 918.0 | 849.0 | PS | 1451.0 | NA | 97.0 | 79.0 | nan | 29.0 | 11.0 | SAN | SFO | 447.0 | nan | nan | 0.0 | NA | 0.0 | nan | nan | nan | nan | nan | YES | YES |
\n",
"3 | 1987.0 | 10.0 | 18.0 | 7.0 | 729.0 | 730.0 | 847.0 | 849.0 | PS | 1451.0 | NA | 78.0 | 79.0 | nan | -2.0 | -1.0 | SAN | SFO | 447.0 | nan | nan | 0.0 | NA | 0.0 | nan | nan | nan | nan | nan | NO | NO |
\n",
"4 | 1987.0 | 10.0 | 19.0 | 1.0 | 749.0 | 730.0 | 922.0 | 849.0 | PS | 1451.0 | NA | 93.0 | 79.0 | nan | 33.0 | 19.0 | SAN | SFO | 447.0 | nan | nan | 0.0 | NA | 0.0 | nan | nan | nan | nan | nan | YES | YES |
\n",
"5 | 1987.0 | 10.0 | 21.0 | 3.0 | 728.0 | 730.0 | 848.0 | 849.0 | PS | 1451.0 | NA | 80.0 | 79.0 | nan | -1.0 | -2.0 | SAN | SFO | 447.0 | nan | nan | 0.0 | NA | 0.0 | nan | nan | nan | nan | nan | NO | NO |
\n",
"6 | 1987.0 | 10.0 | 22.0 | 4.0 | 728.0 | 730.0 | 852.0 | 849.0 | PS | 1451.0 | NA | 84.0 | 79.0 | nan | 3.0 | -2.0 | SAN | SFO | 447.0 | nan | nan | 0.0 | NA | 0.0 | nan | nan | nan | nan | nan | YES | NO |
\n",
"7 | 1987.0 | 10.0 | 23.0 | 5.0 | 731.0 | 730.0 | 902.0 | 849.0 | PS | 1451.0 | NA | 91.0 | 79.0 | nan | 13.0 | 1.0 | SAN | SFO | 447.0 | nan | nan | 0.0 | NA | 0.0 | nan | nan | nan | nan | nan | YES | YES |
\n",
"8 | 1987.0 | 10.0 | 24.0 | 6.0 | 744.0 | 730.0 | 908.0 | 849.0 | PS | 1451.0 | NA | 84.0 | 79.0 | nan | 19.0 | 14.0 | SAN | SFO | 447.0 | nan | nan | 0.0 | NA | 0.0 | nan | nan | nan | nan | nan | YES | YES |
\n",
"9 | 1987.0 | 10.0 | 25.0 | 7.0 | 729.0 | 730.0 | 851.0 | 849.0 | PS | 1451.0 | NA | 82.0 | 79.0 | nan | 2.0 | -1.0 | SAN | SFO | 447.0 | nan | nan | 0.0 | NA | 0.0 | nan | nan | nan | nan | nan | YES | NO |
\n",
"
"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from h2o.utils.shared_utils import _locate # private function. used to find files within h2o git project directory.\n",
"# air_path = [_locate(\"bigdata/laptop/airlines_all.05p.csv\")]\n",
"# air_path = [_locate(\"bigdata/laptop/flights-nyc/flights14.csv.zip\")]\n",
"air_path = [_locate(\"smalldata/airlines/allyears2k_headers.zip\")]\n",
"\n",
"# ----------\n",
"\n",
"# 1- Load data - 1 row per flight. Has columns showing the origin,\n",
"# destination, departure and arrival time, carrier information, and\n",
"# whether the flight was delayed.\n",
"print(\"Import and Parse airlines data\")\n",
"data = h2o.import_file(path=air_path)\n",
"data.describe()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"glm Model Build Progress: [##################################################] 100%\n"
]
}
],
"source": [
"# ----------\n",
"\n",
"# 2- Data exploration and munging. Generate scatter plots \n",
"# of various columns and plot fitted GLM model.\n",
"\n",
"# Function to fit a GLM model and plot the fitted (x,y) values\n",
"def scatter_plot(data, x, y, max_points = 1000, fit = True):\n",
" if(fit):\n",
" lr = H2OGeneralizedLinearEstimator(family = \"gaussian\")\n",
" lr.train(x=x, y=y, training_frame=data)\n",
" coeff = lr.coef()\n",
" df = data[[x,y]]\n",
" runif = df[y].runif()\n",
" df_subset = df[runif < float(max_points)/data.nrow]\n",
" df_py = h2o.as_list(df_subset)\n",
" \n",
" if(fit): h2o.remove(lr._id)\n",
"\n",
" # If x variable is string, generate box-and-whisker plot\n",
" if(df_py[x].dtype == \"object\"):\n",
" if interactive: df_py.boxplot(column = y, by = x)\n",
" # Otherwise, generate a scatter plot\n",
" else:\n",
" if interactive: df_py.plot(x = x, y = y, kind = \"scatter\")\n",
" \n",
" if(fit):\n",
" x_min = min(df_py[x])\n",
" x_max = max(df_py[x])\n",
" y_min = coeff[\"Intercept\"] + coeff[x]*x_min\n",
" y_max = coeff[\"Intercept\"] + coeff[x]*x_max\n",
" plt.plot([x_min, x_max], [y_min, y_max], \"k-\")\n",
" if interactive: plt.show()\n",
"\n",
"scatter_plot(data, \"Distance\", \"AirTime\", fit = True)\n",
"scatter_plot(data, \"UniqueCarrier\", \"ArrDelay\", max_points = 5000, fit = False)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/html": [
"\n",
" Month | sum_Cancelled | nrow_Year |
\n",
" 1 | 1067 | 41979 |
\n",
" 10 | 19 | 1999 |
\n",
"
"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Rows:2 Cols:3\n",
"\n",
"Chunk compression summary: \n"
]
},
{
"data": {
"text/html": [
"chunk_type | \n",
"chunk_name | \n",
"count | \n",
"count_percentage | \n",
"size | \n",
"size_percentage |
\n",
"C1N | \n",
"1-Byte Integers (w/o NAs) | \n",
"1 | \n",
"33.333336 | \n",
" 70 B | \n",
"30.434782 |
\n",
"C2 | \n",
"2-Byte Integers | \n",
"1 | \n",
"33.333336 | \n",
" 72 B | \n",
"31.304348 |
\n",
"C2S | \n",
"2-Byte Fractions | \n",
"1 | \n",
"33.333336 | \n",
" 88 B | \n",
"38.260868 |
"
],
"text/plain": [
"chunk_type chunk_name count count_percentage size size_percentage\n",
"------------ ------------------------- ------- ------------------ ------ -----------------\n",
"C1N 1-Byte Integers (w/o NAs) 1 33.3333 70 B 30.4348\n",
"C2 2-Byte Integers 1 33.3333 72 B 31.3043\n",
"C2S 2-Byte Fractions 1 33.3333 88 B 38.2609"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Frame distribution summary: \n"
]
},
{
"data": {
"text/html": [
" | \n",
"size | \n",
"number_of_rows | \n",
"number_of_chunks_per_column | \n",
"number_of_chunks |
\n",
"172.16.2.84:54321 | \n",
" 230 B | \n",
"2.0 | \n",
"1.0 | \n",
"3.0 |
\n",
"mean | \n",
" 230 B | \n",
"2.0 | \n",
"1.0 | \n",
"3.0 |
\n",
"min | \n",
" 230 B | \n",
"2.0 | \n",
"1.0 | \n",
"3.0 |
\n",
"max | \n",
" 230 B | \n",
"2.0 | \n",
"1.0 | \n",
"3.0 |
\n",
"stddev | \n",
" 0 B | \n",
"0.0 | \n",
"0.0 | \n",
"0.0 |
\n",
"total | \n",
" 230 B | \n",
"2.0 | \n",
"1.0 | \n",
"3.0 |
"
],
"text/plain": [
" size number_of_rows number_of_chunks_per_column number_of_chunks\n",
"----------------- ------ ---------------- ----------------------------- ------------------\n",
"172.16.2.84:54321 230 B 2 1 3\n",
"mean 230 B 2 1 3\n",
"min 230 B 2 1 3\n",
"max 230 B 2 1 3\n",
"stddev 0 B 0 0 0\n",
"total 230 B 2 1 3"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n"
]
},
{
"data": {
"text/html": [
"\n",
" | Month | sum_Cancelled | nrow_Year |
\n",
"type | int | int | int |
\n",
"mins | 1.0 | 19.0 | 1999.0 |
\n",
"mean | 5.5 | 543.0 | 21989.0 |
\n",
"maxs | 10.0 | 1067.0 | 41979.0 |
\n",
"sigma | 6.363961030678928 | 741.0479066835018 | 28270.12911183817 |
\n",
"zeros | 0 | 0 | 0 |
\n",
"missing | 0 | 0 | 0 |
\n",
"0 | 1.0 | 1067.0 | 41979.0 |
\n",
"1 | 10.0 | 19.0 | 1999.0 |
\n",
"
"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Group flights by month\n",
"grouped = data.group_by(\"Month\")\n",
"bpd = grouped.count().sum(\"Cancelled\").frame\n",
"bpd.show()\n",
"bpd.describe()\n",
"bpd.dim\n",
"\n",
"# Convert columns to factors\n",
"data[\"Year\"] = data[\"Year\"] .asfactor()\n",
"data[\"Month\"] = data[\"Month\"] .asfactor()\n",
"data[\"DayOfWeek\"] = data[\"DayOfWeek\"].asfactor()\n",
"data[\"Cancelled\"] = data[\"Cancelled\"].asfactor()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Parse Progress: [##################################################] 100%\n",
"\n",
"glm Model Build Progress: [##################################################] 100%\n"
]
}
],
"source": [
"# Calculate and plot travel time\n",
"hour1 = data[\"CRSArrTime\"] / 100\n",
"mins1 = data[\"CRSArrTime\"] % 100\n",
"arrTime = hour1*60 + mins1\n",
"\n",
"hour2 = data[\"CRSDepTime\"] / 100\n",
"mins2 = data[\"CRSDepTime\"] % 100\n",
"depTime = hour2*60 + mins2\n",
"\n",
"# TODO: Replace this once list comprehension is supported. See PUBDEV-1286.\n",
"# data[\"TravelTime\"] = [x if x > 0 else None for x in (arrTime - depTime)]\n",
"data[\"TravelTime\"] = (arrTime-depTime > 0).ifelse((arrTime-depTime), h2o.H2OFrame([[None]] * data.nrow))\n",
"scatter_plot(data, \"Distance\", \"TravelTime\")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"glm Model Build Progress: [##################################################] 100%\n"
]
}
],
"source": [
"# Impute missing travel times and re-plot\n",
"data.impute(column = \"Distance\", by = [\"Origin\", \"Dest\"])\n",
"scatter_plot(data, \"Distance\", \"TravelTime\")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"glm Model Build Progress: [##################################################] 100%\n",
"\n",
"gbm Model Build Progress: [##################################################] 100%\n",
"\n",
"gbm Model Build Progress: [##################################################] 100%\n",
"\n",
"drf Model Build Progress: [##################################################] 100%\n",
"\n",
"drf Model Build Progress: [##################################################] 100%\n",
"\n",
"deeplearning Model Build Progress: [##################################################] 100%\n"
]
}
],
"source": [
"# ----------\n",
"# 3- Fit a model on train; using test as validation\n",
"\n",
"# Create test/train split\n",
"s = data[\"Year\"].runif()\n",
"train = data[s <= 0.75]\n",
"test = data[s > 0.75]\n",
"\n",
"# Set predictor and response variables\n",
"myY = \"IsDepDelayed\"\n",
"myX = [\"Origin\", \"Dest\", \"Year\", \"UniqueCarrier\", \"DayOfWeek\", \"Month\", \"Distance\", \"FlightNum\"]\n",
"\n",
"# Simple GLM - Predict Delays\n",
"data_glm = H2OGeneralizedLinearEstimator(family=\"binomial\", standardize=True)\n",
"data_glm.train(x =myX,\n",
" y =myY,\n",
" training_frame =train,\n",
" validation_frame=test)\n",
"\n",
"# Simple GBM\n",
"data_gbm = H2OGradientBoostingEstimator(balance_classes=True,\n",
" ntrees =3,\n",
" max_depth =1,\n",
" distribution =\"bernoulli\",\n",
" learn_rate =0.1,\n",
" min_rows =2)\n",
"\n",
"data_gbm.train(x =myX,\n",
" y =myY,\n",
" training_frame =train,\n",
" validation_frame=test)\n",
"\n",
"# Complex GBM\n",
"data_gbm2 = H2OGradientBoostingEstimator(balance_classes=True,\n",
" ntrees =50,\n",
" max_depth =5,\n",
" distribution =\"bernoulli\",\n",
" learn_rate =0.1,\n",
" min_rows =2)\n",
"\n",
"data_gbm2.train(x =myX,\n",
" y =myY,\n",
" training_frame =train,\n",
" validation_frame=test)\n",
"\n",
"# Simple Random Forest\n",
"data_rf = H2ORandomForestEstimator(ntrees =5,\n",
" max_depth =2,\n",
" balance_classes=True)\n",
"\n",
"data_rf.train(x =myX,\n",
" y =myY,\n",
" training_frame =train,\n",
" validation_frame=test)\n",
"\n",
"# Complex Random Forest\n",
"data_rf2 = H2ORandomForestEstimator(ntrees =10,\n",
" max_depth =5,\n",
" balance_classes=True)\n",
"\n",
"data_rf2.train(x =myX,\n",
" y =myY,\n",
" training_frame =train,\n",
" validation_frame=test)\n",
"\n",
"# Deep Learning with 5 epochs\n",
"data_dl = H2ODeepLearningEstimator(hidden =[10,10],\n",
" epochs =5,\n",
" variable_importances=True,\n",
" balance_classes =True,\n",
" loss =\"Automatic\")\n",
"\n",
"data_dl.train(x =myX,\n",
" y =myY,\n",
" training_frame =train,\n",
" validation_frame=test)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Variable Importances:\n",
"\n",
"| Predictor | Normalized Coefficient |\n",
"|------------------+--------------------------|\n",
"| Year.2008 | 2.1663 |\n",
"| Dest.HTS | 1.59911 |\n",
"| Year.2003 | 1.59565 |\n",
"| Origin.MDW | 1.58362 |\n",
"| Year.2007 | 1.37479 |\n",
"| Origin.HPN | 1.34354 |\n",
"| Origin.LIH | 1.32598 |\n",
"| Dest.LYH | 1.29275 |\n",
"| Origin.LBB | 1.21984 |\n",
"| Origin.LEX | 1.21291 |\n",
"| Origin.ERI | 1.20959 |\n",
"| Origin.TLH | 1.17343 |\n",
"| Origin.CAE | 1.15044 |\n",
"| UniqueCarrier.HP | 1.12944 |\n",
"| Origin.PSP | 1.11685 |\n",
"| Origin.HNL | 1.11194 |\n",
"| Origin.TRI | 1.02187 |\n",
"| UniqueCarrier.TW | 1.0169 |\n",
"| Year.2001 | 0.979973 |\n",
"| Year.2002 | 0.944374 |\n",
"| Origin.SDF | 0.939753 |\n",
"| Origin.ATL | 0.935832 |\n",
"| Origin.GRR | 0.884671 |\n",
"| Origin.PBI | 0.882257 |\n",
"| Origin.CHO | 0.878584 |\n",
"| Origin.OGG | 0.864754 |\n",
"| Origin.SRQ | 0.856535 |\n",
"| Year.2004 | 0.846669 |\n",
"| Origin.MYR | 0.835173 |\n",
"| Origin.ACY | 0.804102 |\n",
"| Origin.ORD | 0.787865 |\n",
"| Year.1994 | 0.781128 |\n",
"| Origin.MAF | 0.766548 |\n",
"| Origin.TUL | 0.765077 |\n",
"| Origin.MRY | 0.759124 |\n",
"| Year.2006 | 0.749834 |\n",
"| Origin.STL | 0.737706 |\n",
"| Origin.LYH | 0.728328 |\n",
"| Dest.CHO | 0.728328 |\n",
"| Origin.CMH | 0.703809 |\n",
"| Dest.GSO | 0.694797 |\n",
"| Origin.BTV | 0.678703 |\n",
"| Origin.ROA | 0.672739 |\n",
"| Dest.ISP | 0.666122 |\n",
"| Dest.LIH | 0.647256 |\n",
"| Origin.AUS | 0.646233 |\n",
"| Origin.IAH | 0.637049 |\n",
"| Dest.FLL | 0.624057 |\n",
"| Origin.MLB | 0.611271 |\n",
"| Dest.PBI | 0.609092 |\n",
"| Origin.PIT | 0.604604 |\n",
"| Origin.PWM | 0.603332 |\n",
"| Dest.ICT | 0.601697 |\n",
"| Year.1996 | 0.601507 |\n",
"| Origin.TYS | 0.590041 |\n",
"| Origin.MSY | 0.587653 |\n",
"| Year.1990 | 0.564752 |\n",
"| Dest.DAY | 0.564026 |\n",
"| Origin.SYR | 0.560879 |\n",
"| Dest.IAH | 0.553572 |\n",
"| Dest.EUG | 0.54793 |\n",
"| Origin.JAX | 0.542031 |\n",
"| Origin.BOI | 0.541044 |\n",
"| Dest.TOL | 0.528751 |\n",
"| Dest.TPA | 0.51248 |\n",
"| Dest.BUF | 0.512192 |\n",
"| Dest.PSP | 0.508527 |\n",
"| Origin.ALB | 0.506946 |\n",
"| Origin.SAV | 0.50483 |\n",
"| Origin.CRW | 0.504431 |\n",
"| Dest.PNS | 0.503218 |\n",
"| UniqueCarrier.CO | 0.499991 |\n",
"| Dest.SFO | 0.499403 |\n",
"| Origin.PHL | 0.498516 |\n",
"| Year.1997 | 0.492557 |\n",
"| Origin.OKC | 0.491762 |\n",
"| Origin.LGA | 0.488253 |\n",
"| Origin.MIA | 0.480325 |\n",
"| Origin.OMA | 0.477082 |\n",
"| Dest.CHS | 0.475901 |\n",
"| Dest.CAK | 0.473522 |\n",
"| Origin.FLL | 0.469294 |\n",
"| Origin.ICT | 0.464117 |\n",
"| Dest.GEG | 0.461246 |\n",
"| Origin.EGE | 0.461207 |\n",
"| Dest.ABQ | 0.461191 |\n",
"| Dest.EYW | 0.452089 |\n",
"| Year.2005 | 0.45045 |\n",
"| Dest.IND | 0.449927 |\n",
"| UniqueCarrier.WN | 0.446792 |\n",
"| Origin.IND | 0.446311 |\n",
"| Origin.GSO | 0.442529 |\n",
"| Origin.MCO | 0.434966 |\n",
"| Origin.LAX | 0.433672 |\n",
"| Origin.BDL | 0.418545 |\n",
"| Dest.CAE | 0.414453 |\n",
"| Dest.SMF | 0.409427 |\n",
"| Origin.CRP | 0.403216 |\n",
"| Origin.DFW | 0.399445 |\n",
"| Dest.BDL | 0.395146 |\n",
"| Dest.CVG | 0.391672 |\n",
"| Dest.UCA | 0.39075 |\n",
"| Origin.DSM | 0.387103 |\n",
"| Origin.MEM | 0.383554 |\n",
"| Origin.EYW | 0.375727 |\n",
"| Dest.CLE | 0.372843 |\n",
"| Dest.FAT | 0.369287 |\n",
"| UniqueCarrier.PI | 0.366404 |\n",
"| Origin.SLC | 0.354344 |\n",
"| Origin.JFK | 0.34159 |\n",
"| Origin.BWI | 0.339737 |\n",
"| Dest.MIA | 0.338326 |\n",
"| Origin.ROC | 0.328992 |\n",
"| Origin.OAK | 0.327167 |\n",
"| Dest.BGM | 0.323214 |\n",
"| Origin.IAD | 0.320497 |\n",
"| Dest.JAX | 0.319508 |\n",
"| Dest.MKE | 0.31828 |\n",
"| Year.1992 | 0.31714 |\n",
"| Dest.MCO | 0.315641 |\n",
"| Dest.FAY | 0.315447 |\n",
"| Dest.COS | 0.314929 |\n",
"| Origin.RNO | 0.314859 |\n",
"| Origin.MCI | 0.313843 |\n",
"| Dest.SAT | 0.305571 |\n",
"| Year.1995 | 0.29602 |\n",
"| Origin.SAN | 0.292782 |\n",
"| Dest.OGG | 0.281564 |\n",
"| Year.1991 | 0.274708 |\n",
"| Dest.BUR | 0.270584 |\n",
"| Dest.ALB | 0.268558 |\n",
"| Dest.TUL | 0.26762 |\n",
"| Origin.DAY | 0.264843 |\n",
"| Origin.BUR | 0.264689 |\n",
"| Origin.CLT | 0.256984 |\n",
"| Origin.ONT | 0.256321 |\n",
"| Origin.MKE | 0.254529 |\n",
"| Origin.HRL | 0.253809 |\n",
"| DayOfWeek.5 | 0.244342 |\n",
"| UniqueCarrier.US | 0.239344 |\n",
"| Dest.BTV | 0.23824 |\n",
"| Origin.ABE | 0.234584 |\n",
"| Origin.TPA | 0.22891 |\n",
"| Dest.STT | 0.225113 |\n",
"| Origin.STX | 0.223986 |\n",
"| Dest.GSP | 0.221914 |\n",
"| Origin.BHM | 0.219408 |\n",
"| Dest.IAD | 0.219399 |\n",
"| Origin.BOS | 0.21936 |\n",
"| Origin.MDT | 0.217089 |\n",
"| Dest.PVD | 0.21636 |\n",
"| Dest.RSW | 0.208373 |\n",
"| Origin.ELP | 0.207048 |\n",
"| Origin.DEN | 0.205402 |\n",
"| Dest.LIT | 0.204071 |\n",
"| Month.10 | 0.203185 |\n",
"| Year.1987 | 0.203185 |\n",
"| Dest.BWI | 0.202309 |\n",
"| Origin.MSP | 0.201702 |\n",
"| Dest.PDX | 0.201547 |\n",
"| Dest.ROC | 0.199012 |\n",
"| Origin.TUS | 0.197624 |\n",
"| Dest.KOA | 0.197388 |\n",
"| Dest.CLT | 0.191233 |\n",
"| Dest.OAJ | 0.188976 |\n",
"| Year.1999 | 0.186221 |\n",
"| Origin.SJC | 0.182876 |\n",
"| Dest.DAL | 0.179589 |\n",
"| Origin.BUF | 0.178246 |\n",
"| DayOfWeek.2 | 0.17761 |\n",
"| Origin.DAL | 0.175027 |\n",
"| Origin.CLE | 0.173502 |\n",
"| Dest.GRR | 0.169856 |\n",
"| Dest.PWM | 0.16768 |\n",
"| UniqueCarrier.AA | 0.167342 |\n",
"| Year.1993 | 0.166087 |\n",
"| Dest.RNO | 0.165744 |\n",
"| Distance | 0.163211 |\n",
"| Dest.LBB | 0.157175 |\n",
"| Dest.HRL | 0.156284 |\n",
"| Dest.ABE | 0.155532 |\n",
"| Dest.CMH | 0.154857 |\n",
"| Dest.CRP | 0.151555 |\n",
"| Dest.SNA | 0.151435 |\n",
"| Origin.SFO | 0.150441 |\n",
"| Dest.SEA | 0.149936 |\n",
"| Dest.ROA | 0.148303 |\n",
"| Year.2000 | 0.146046 |\n",
"| Dest.ORF | 0.134053 |\n",
"| Dest.SAN | 0.133593 |\n",
"| DayOfWeek.6 | 0.132748 |\n",
"| Dest.MSP | 0.132271 |\n",
"| Origin.COS | 0.128671 |\n",
"| Dest.HOU | 0.127342 |\n",
"| Dest.TUS | 0.120346 |\n",
"| DayOfWeek.4 | 0.119748 |\n",
"| Dest.DSM | 0.116603 |\n",
"| Dest.LAX | 0.11609 |\n",
"| Dest.SLC | 0.114966 |\n",
"| Dest.AVP | 0.112227 |\n",
"| Dest.STL | 0.110793 |\n",
"| Origin.ORF | 0.108536 |\n",
"| Dest.BHM | 0.108348 |\n",
"| UniqueCarrier.UA | 0.107298 |\n",
"| Origin.DTW | 0.105773 |\n",
"| Dest.MDW | 0.10405 |\n",
"| Dest.DFW | 0.0989164 |\n",
"| Origin.CVG | 0.0967693 |\n",
"| Origin.SMF | 0.0959796 |\n",
"| Origin.RSW | 0.0934595 |\n",
"| Origin.SWF | 0.0927228 |\n",
"| Month.1 | 0.092347 |\n",
"| Dest.PHL | 0.0848795 |\n",
"| Dest.PHX | 0.0848389 |\n",
"| Origin.RDU | 0.0839633 |\n",
"| Origin.DCA | 0.0832363 |\n",
"| Dest.OAK | 0.0818515 |\n",
"| Dest.MCI | 0.0815358 |\n",
"| Dest.EWR | 0.0785491 |\n",
"| Dest.DEN | 0.0783454 |\n",
"| Dest.DTW | 0.0774459 |\n",
"| Year.1989 | 0.0762646 |\n",
"| Dest.LAS | 0.0743316 |\n",
"| Dest.MDT | 0.0731147 |\n",
"| Dest.RIC | 0.0723303 |\n",
"| Dest.OMA | 0.0661859 |\n",
"| UniqueCarrier.PS | 0.0645156 |\n",
"| Year.1998 | 0.05845 |\n",
"| Dest.MHT | 0.0576363 |\n",
"| Origin.BNA | 0.0553462 |\n",
"| Origin.PHX | 0.0522407 |\n",
"| Origin.GNV | 0.0504304 |\n",
"| Dest.MSY | 0.0501866 |\n",
"| Origin.PVD | 0.0490418 |\n",
"| Origin.MFR | 0.0437977 |\n",
"| Origin.SNA | 0.0421396 |\n",
"| FlightNum | 0.0376186 |\n",
"| Origin.SEA | 0.0372322 |\n",
"| Dest.BNA | 0.0347007 |\n",
"| Origin.PHF | 0.029703 |\n",
"| Dest.LGA | 0.0291171 |\n",
"| Intercept | 0.026855 |\n",
"| Dest.ORD | 0.0244753 |\n",
"| DayOfWeek.7 | 0.0234737 |\n",
"| Dest.SJC | 0.0177833 |\n",
"| Dest.AVL | 0.0172911 |\n",
"| Dest.BOS | 0.0162872 |\n",
"| DayOfWeek.1 | 0.0153713 |\n",
"| Origin.PDX | 0.0112833 |\n",
"| Origin.RIC | 0.011192 |\n",
"| Origin.SAT | 0.0110852 |\n",
"| Year.1988 | 0.00996483 |\n",
"| Origin.BGM | 0.00952641 |\n",
"| Dest.PIT | 0.00935131 |\n",
"| Dest.ATL | 0.00882664 |\n",
"| Origin.CHS | 0.00818887 |\n",
"| Origin.ABQ | 0.00803383 |\n",
"| Dest.ILM | 0.00255637 |\n",
"| UniqueCarrier.DL | 0.00110988 |\n",
"| Origin.GEG | 0 |\n",
"| Origin.SBN | 0 |\n",
"| Origin.STT | 0 |\n",
"| Origin.ANC | 0 |\n",
"| Dest.AMA | 0 |\n",
"| Dest.RDU | 0 |\n",
"| Dest.FNT | 0 |\n",
"| Dest.LEX | 0 |\n",
"| Origin.HOU | 0 |\n",
"| Origin.LAS | 0 |\n",
"| Dest.ACY | 0 |\n",
"| Dest.AUS | 0 |\n",
"| Dest.SDF | 0 |\n",
"| Dest.DCA | 0 |\n",
"| Dest.MRY | 0 |\n",
"| Dest.SCK | 0 |\n",
"| Origin.EWR | 0 |\n",
"| Dest.PHF | 0 |\n",
"| Dest.BOI | 0 |\n",
"| Origin.AVP | 0 |\n",
"| Origin.LAN | 0 |\n",
"| Dest.SBN | 0 |\n",
"| Dest.JFK | 0 |\n",
"| Dest.SJU | 0 |\n",
"| Origin.UCA | 0 |\n",
"| DayOfWeek.3 | 0 |\n",
"| Dest.SYR | 0 |\n",
"| Origin.KOA | 0 |\n",
"| Origin.MHT | 0 |\n",
"| Origin.LIT | 0 |\n",
"| Dest.JAN | 0 |\n",
"| Origin.SCK | 0 |\n",
"| Dest.ERI | 0 |\n",
"| Dest.ELM | 0 |\n",
"| Dest.HNL | 0 |\n",
"| Dest.OKC | 0 |\n",
"| Dest.HPN | 0 |\n",
"| Origin.BIL | 0 |\n",
"| Dest.ORH | 0 |\n",
"| Dest.MYR | 0 |\n",
"| Dest.SRQ | 0 |\n",
"| Dest.ANC | 0 |\n",
"| Dest.CHA | 0 |\n",
"| Dest.SWF | 0 |\n",
"| Origin.JAN | 0 |\n",
"| Origin.AMA | 0 |\n",
"| Dest.ONT | 0 |\n",
"| Dest.ELP | 0 |\n",
"| Origin.ISP | 0 |\n",
"| Dest.MAF | 0 |\n",
"| Origin.SJU | 0 |\n"
]
},
{
"data": {
"text/plain": [
"[('Year', 860.6602783203125, 1.0, 0.5018886676744018),\n",
" ('Dest', 593.151123046875, 0.6891814784394192, 0.3458923739998345),\n",
" ('UniqueCarrier', 87.23373413085938, 0.1013567563511901, 0.05086980740489776),\n",
" ('DayOfWeek', 80.93794250488281, 0.09404168467358974, 0.04719845582668416),\n",
" ('Distance', 65.31503295898438, 0.07588944744429815, 0.03808805366836533),\n",
" ('FlightNum', 27.54490852355957, 0.032004391532181486, 0.01606264142581647),\n",
" ('Month', 0.0, 0.0, 0.0),\n",
" ('Origin', 0.0, 0.0, 0.0)]"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Variable importances from each algorithm\n",
"# Calculate magnitude of normalized GLM coefficients\n",
"from six import iteritems\n",
"glm_varimp = data_glm.coef_norm()\n",
"for k,v in iteritems(glm_varimp):\n",
" glm_varimp[k] = abs(glm_varimp[k])\n",
" \n",
"# Sort in descending order by magnitude\n",
"glm_sorted = sorted(glm_varimp.items(), key = operator.itemgetter(1), reverse = True)\n",
"table = tabulate(glm_sorted, headers = [\"Predictor\", \"Normalized Coefficient\"], tablefmt = \"orgtbl\")\n",
"print(\"Variable Importances:\\n\\n\" + table)\n",
"\n",
"data_gbm.varimp()\n",
"data_rf.varimp()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"ModelMetricsBinomial: gbm\n",
"** Reported on test data. **\n",
"\n",
"MSE: 0.20407778554922562\n",
"R^2: 0.18116065189707653\n",
"LogLoss: 0.5945117554029998\n",
"AUC: 0.7467255149856272\n",
"Gini: 0.49345102997125445\n",
"\n",
"Confusion Matrix (Act/Pred) for max f1 @ threshold = 0.3514986726263641: \n"
]
},
{
"data": {
"text/html": [
" | \n",
"NO | \n",
"YES | \n",
"Error | \n",
"Rate |
\n",
"NO | \n",
"1968.0 | \n",
"3199.0 | \n",
"0.6191 | \n",
" (3199.0/5167.0) |
\n",
"YES | \n",
"657.0 | \n",
"5118.0 | \n",
"0.1138 | \n",
" (657.0/5775.0) |
\n",
"Total | \n",
"2625.0 | \n",
"8317.0 | \n",
"0.3524 | \n",
" (3856.0/10942.0) |
"
],
"text/plain": [
" NO YES Error Rate\n",
"----- ---- ----- ------- ----------------\n",
"NO 1968 3199 0.6191 (3199.0/5167.0)\n",
"YES 657 5118 0.1138 (657.0/5775.0)\n",
"Total 2625 8317 0.3524 (3856.0/10942.0)"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Maximum Metrics: Maximum metrics at their respective thresholds\n",
"\n"
]
},
{
"data": {
"text/html": [
"metric | \n",
"threshold | \n",
"value | \n",
"idx |
\n",
"max f1 | \n",
"0.3514987 | \n",
"0.7263696 | \n",
"287.0 |
\n",
"max f2 | \n",
"0.1882069 | \n",
"0.8505254 | \n",
"372.0 |
\n",
"max f0point5 | \n",
"0.5203289 | \n",
"0.7060683 | \n",
"199.0 |
\n",
"max accuracy | \n",
"0.4815086 | \n",
"0.6868945 | \n",
"220.0 |
\n",
"max precision | \n",
"0.9607084 | \n",
"1.0 | \n",
"0.0 |
\n",
"max absolute_MCC | \n",
"0.5011300 | \n",
"0.3721374 | \n",
"209.0 |
\n",
"max min_per_class_accuracy | \n",
"0.5067588 | \n",
"0.6851171 | \n",
"206.0 |
"
],
"text/plain": [
"metric threshold value idx\n",
"-------------------------- ----------- -------- -----\n",
"max f1 0.351499 0.72637 287\n",
"max f2 0.188207 0.850525 372\n",
"max f0point5 0.520329 0.706068 199\n",
"max accuracy 0.481509 0.686895 220\n",
"max precision 0.960708 1 0\n",
"max absolute_MCC 0.50113 0.372137 209\n",
"max min_per_class_accuracy 0.506759 0.685117 206"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Gains/Lift Table: Avg response rate: 52.78 %\n",
"\n"
]
},
{
"data": {
"text/html": [
" | \n",
"group | \n",
"lower_threshold | \n",
"cumulative_data_fraction | \n",
"response_rate | \n",
"cumulative_response_rate | \n",
"capture_rate | \n",
"cumulative_capture_rate | \n",
"lift | \n",
"cumulative_lift | \n",
"gain | \n",
"cumulative_gain |
\n",
" | \n",
"1 | \n",
"0.8528015 | \n",
"0.0500823 | \n",
"0.8978102 | \n",
"0.8978102 | \n",
"0.0851948 | \n",
"0.0851948 | \n",
"1.7010977 | \n",
"1.7010977 | \n",
"70.1097734 | \n",
"70.1097734 |
\n",
" | \n",
"2 | \n",
"0.7962821 | \n",
"0.1001645 | \n",
"0.8321168 | \n",
"0.8649635 | \n",
"0.0789610 | \n",
"0.1641558 | \n",
"1.5766272 | \n",
"1.6388625 | \n",
"57.6627168 | \n",
"63.8862451 |
\n",
" | \n",
"3 | \n",
"0.7494101 | \n",
"0.1500640 | \n",
"0.7893773 | \n",
"0.8398295 | \n",
"0.0746320 | \n",
"0.2387879 | \n",
"1.4956478 | \n",
"1.5912405 | \n",
"49.5647844 | \n",
"59.1240542 |
\n",
" | \n",
"4 | \n",
"0.7173250 | \n",
"0.2000548 | \n",
"0.7787934 | \n",
"0.8245774 | \n",
"0.0737662 | \n",
"0.3125541 | \n",
"1.4755944 | \n",
"1.5623422 | \n",
"47.5594387 | \n",
"56.2342211 |
\n",
" | \n",
"5 | \n",
"0.6855407 | \n",
"0.2501371 | \n",
"0.7408759 | \n",
"0.8078188 | \n",
"0.0703030 | \n",
"0.3828571 | \n",
"1.4037514 | \n",
"1.5305893 | \n",
"40.3751382 | \n",
"53.0589279 |
\n",
" | \n",
"6 | \n",
"0.6556998 | \n",
"0.3002193 | \n",
"0.7043796 | \n",
"0.7905632 | \n",
"0.0668398 | \n",
"0.4496970 | \n",
"1.3346011 | \n",
"1.4978947 | \n",
"33.4601068 | \n",
"49.7894747 |
\n",
" | \n",
"7 | \n",
"0.6236469 | \n",
"0.3500274 | \n",
"0.6495413 | \n",
"0.7704961 | \n",
"0.0612987 | \n",
"0.5109957 | \n",
"1.2306980 | \n",
"1.4598733 | \n",
"23.0697963 | \n",
"45.9873272 |
\n",
" | \n",
"8 | \n",
"0.5865343 | \n",
"0.4001097 | \n",
"0.5985401 | \n",
"0.7489721 | \n",
"0.0567965 | \n",
"0.5677922 | \n",
"1.1340652 | \n",
"1.4190914 | \n",
"13.4065156 | \n",
"41.9091443 |
\n",
" | \n",
"9 | \n",
"0.5478266 | \n",
"0.4500091 | \n",
"0.5787546 | \n",
"0.7300975 | \n",
"0.0547186 | \n",
"0.6225108 | \n",
"1.0965771 | \n",
"1.3833293 | \n",
"9.6577074 | \n",
"38.3329289 |
\n",
" | \n",
"10 | \n",
"0.5128311 | \n",
"0.5 | \n",
"0.5557587 | \n",
"0.7126668 | \n",
"0.0526407 | \n",
"0.6751515 | \n",
"1.0530063 | \n",
"1.3503030 | \n",
"5.3006323 | \n",
"35.0303030 |
\n",
" | \n",
"11 | \n",
"0.4815168 | \n",
"0.5508134 | \n",
"0.5161871 | \n",
"0.6945412 | \n",
"0.0496970 | \n",
"0.7248485 | \n",
"0.9780292 | \n",
"1.3159602 | \n",
"-2.1970787 | \n",
"31.5960199 |
\n",
" | \n",
"12 | \n",
"0.4483592 | \n",
"0.6001645 | \n",
"0.45 | \n",
"0.6744328 | \n",
"0.0420779 | \n",
"0.7669264 | \n",
"0.8526234 | \n",
"1.2778603 | \n",
"-14.7376623 | \n",
"27.7860324 |
\n",
" | \n",
"13 | \n",
"0.4159386 | \n",
"0.6501554 | \n",
"0.4223035 | \n",
"0.6550464 | \n",
"0.04 | \n",
"0.8069264 | \n",
"0.8001463 | \n",
"1.2411286 | \n",
"-19.9853748 | \n",
"24.1128584 |
\n",
" | \n",
"14 | \n",
"0.3884948 | \n",
"0.7000548 | \n",
"0.3736264 | \n",
"0.6349869 | \n",
"0.0353247 | \n",
"0.8422511 | \n",
"0.7079168 | \n",
"1.2031216 | \n",
"-29.2083155 | \n",
"20.3121585 |
\n",
" | \n",
"15 | \n",
"0.3570956 | \n",
"0.7499543 | \n",
"0.3864469 | \n",
"0.6184499 | \n",
"0.0365368 | \n",
"0.8787879 | \n",
"0.7322081 | \n",
"1.1717886 | \n",
"-26.7791891 | \n",
"17.1788566 |
\n",
" | \n",
"16 | \n",
"0.3277598 | \n",
"0.8000366 | \n",
"0.3485401 | \n",
"0.6015536 | \n",
"0.0330736 | \n",
"0.9118615 | \n",
"0.6603855 | \n",
"1.1397748 | \n",
"-33.9614497 | \n",
"13.9774757 |
\n",
" | \n",
"17 | \n",
"0.2981756 | \n",
"0.8500274 | \n",
"0.2888483 | \n",
"0.5831631 | \n",
"0.0273593 | \n",
"0.9392208 | \n",
"0.5472862 | \n",
"1.1049300 | \n",
"-45.2713819 | \n",
"10.4929982 |
\n",
" | \n",
"18 | \n",
"0.2699267 | \n",
"0.8999269 | \n",
"0.2637363 | \n",
"0.5654514 | \n",
"0.0249351 | \n",
"0.9641558 | \n",
"0.4997060 | \n",
"1.0713713 | \n",
"-50.0293992 | \n",
"7.1371306 |
\n",
" | \n",
"19 | \n",
"0.2239691 | \n",
"0.9499177 | \n",
"0.2230347 | \n",
"0.5474312 | \n",
"0.0211255 | \n",
"0.9852814 | \n",
"0.4225881 | \n",
"1.0372281 | \n",
"-57.7411936 | \n",
"3.7228104 |
\n",
" | \n",
"20 | \n",
"0.0694869 | \n",
"1.0 | \n",
"0.1551095 | \n",
"0.5277829 | \n",
"0.0147186 | \n",
"1.0 | \n",
"0.2938888 | \n",
"1.0 | \n",
"-70.6111164 | \n",
"0.0 |
"
],
"text/plain": [
" group lower_threshold cumulative_data_fraction response_rate cumulative_response_rate capture_rate cumulative_capture_rate lift cumulative_lift gain cumulative_gain\n",
"-- ------- ----------------- -------------------------- --------------- -------------------------- -------------- ------------------------- -------- ----------------- -------- -----------------\n",
" 1 0.852802 0.0500823 0.89781 0.89781 0.0851948 0.0851948 1.7011 1.7011 70.1098 70.1098\n",
" 2 0.796282 0.100165 0.832117 0.864964 0.078961 0.164156 1.57663 1.63886 57.6627 63.8862\n",
" 3 0.74941 0.150064 0.789377 0.839829 0.074632 0.238788 1.49565 1.59124 49.5648 59.1241\n",
" 4 0.717325 0.200055 0.778793 0.824577 0.0737662 0.312554 1.47559 1.56234 47.5594 56.2342\n",
" 5 0.685541 0.250137 0.740876 0.807819 0.070303 0.382857 1.40375 1.53059 40.3751 53.0589\n",
" 6 0.6557 0.300219 0.70438 0.790563 0.0668398 0.449697 1.3346 1.49789 33.4601 49.7895\n",
" 7 0.623647 0.350027 0.649541 0.770496 0.0612987 0.510996 1.2307 1.45987 23.0698 45.9873\n",
" 8 0.586534 0.40011 0.59854 0.748972 0.0567965 0.567792 1.13407 1.41909 13.4065 41.9091\n",
" 9 0.547827 0.450009 0.578755 0.730097 0.0547186 0.622511 1.09658 1.38333 9.65771 38.3329\n",
" 10 0.512831 0.5 0.555759 0.712667 0.0526407 0.675152 1.05301 1.3503 5.30063 35.0303\n",
" 11 0.481517 0.550813 0.516187 0.694541 0.049697 0.724848 0.978029 1.31596 -2.19708 31.596\n",
" 12 0.448359 0.600165 0.45 0.674433 0.0420779 0.766926 0.852623 1.27786 -14.7377 27.786\n",
" 13 0.415939 0.650155 0.422303 0.655046 0.04 0.806926 0.800146 1.24113 -19.9854 24.1129\n",
" 14 0.388495 0.700055 0.373626 0.634987 0.0353247 0.842251 0.707917 1.20312 -29.2083 20.3122\n",
" 15 0.357096 0.749954 0.386447 0.61845 0.0365368 0.878788 0.732208 1.17179 -26.7792 17.1789\n",
" 16 0.32776 0.800037 0.34854 0.601554 0.0330736 0.911861 0.660386 1.13977 -33.9614 13.9775\n",
" 17 0.298176 0.850027 0.288848 0.583163 0.0273593 0.939221 0.547286 1.10493 -45.2714 10.493\n",
" 18 0.269927 0.899927 0.263736 0.565451 0.0249351 0.964156 0.499706 1.07137 -50.0294 7.13713\n",
" 19 0.223969 0.949918 0.223035 0.547431 0.0211255 0.985281 0.422588 1.03723 -57.7412 3.72281\n",
" 20 0.0694869 1 0.155109 0.527783 0.0147186 1 0.293889 1 -70.6111 0"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"text/plain": []
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Score the GBM model data_gbm2 against the test frame; the metrics object\n",
"# is the last expression, so Jupyter displays its summary as the cell output.\n",
"gbm2_test_perf = data_gbm2.model_performance(test)\n",
"gbm2_test_perf"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.0"
}
},
"nbformat": 4,
"nbformat_minor": 0
}