{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import h2o\n",
"from h2o.estimators.gbm import H2OGradientBoostingEstimator"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Warning: Version mismatch. H2O is version 3.5.0.99999, but the python package is version UNKNOWN.\n"
]
},
{
"data": {
"text/html": [
"
H2O cluster uptime: | \n",
"52 minutes 26 seconds 170 milliseconds |
\n",
"H2O cluster version: | \n",
"3.5.0.99999 |
\n",
"H2O cluster name: | \n",
"ludirehak |
\n",
"H2O cluster total nodes: | \n",
"1 |
\n",
"H2O cluster total memory: | \n",
"4.44 GB |
\n",
"H2O cluster total cores: | \n",
"8 |
\n",
"H2O cluster allowed cores: | \n",
"8 |
\n",
"H2O cluster healthy: | \n",
"True |
\n",
"H2O Connection ip: | \n",
"127.0.0.1 |
\n",
"H2O Connection port: | \n",
"54321 |
"
],
"text/plain": [
"-------------------------- --------------------------------------\n",
"H2O cluster uptime: 52 minutes 26 seconds 170 milliseconds\n",
"H2O cluster version: 3.5.0.99999\n",
"H2O cluster name: ludirehak\n",
"H2O cluster total nodes: 1\n",
"H2O cluster total memory: 4.44 GB\n",
"H2O cluster total cores: 8\n",
"H2O cluster allowed cores: 8\n",
"H2O cluster healthy: True\n",
"H2O Connection ip: 127.0.0.1\n",
"H2O Connection port: 54321\n",
"-------------------------- --------------------------------------"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"h2o.init()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Parse Progress: [##################################################] 100%\n",
"Imported /Users/ludirehak/h2o-3/smalldata/airlines/AirlinesTrain.csv.zip. Parsed 24,421 rows and 12 cols\n"
]
}
],
"source": [
"from h2o.utils.shared_utils import _locate # private function. used to find files within h2o git project directory.\n",
"\n",
"# Airlines dataset\n",
"air = h2o.import_file(path=_locate(\"smalldata/airlines/AirlinesTrain.csv.zip\"))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Construct validation and training datasets by sampling (20/80).\n",
"# runif() yields one uniform random value per row; the fixed seed keeps the\n",
"# row assignment identical on every run, so the split is reproducible.\n",
"r = air[0].runif(seed=1234)\n",
"air_train = air[r < 0.8]\n",
"air_valid = air[r >= 0.8]\n",
"\n",
"# Predictor columns and the response column used to train the model\n",
"myX = [\"Origin\", \"Dest\", \"Distance\", \"UniqueCarrier\", \"fMonth\", \"fDayofMonth\", \"fDayOfWeek\"]\n",
"myY = \"IsDepDelayed\""
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"gbm Model Build Progress: [##################################################] 100%\n"
]
}
],
"source": [
"# Configure the gradient boosting model (bernoulli distribution, 100 shallow trees)\n",
"gbm = H2OGradientBoostingEstimator(\n",
"    distribution=\"bernoulli\",\n",
"    ntrees=100,\n",
"    max_depth=3,\n",
"    learn_rate=0.01\n",
")\n",
"\n",
"# Train on the training split and pass the validation split as validation_frame\n",
"gbm.train(x=myX, y=myY, training_frame=air_train, validation_frame=air_valid)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Confusion Matrix (Act/Pred) for max f1 @ threshold = 0.438866890551:\n"
]
},
{
"data": {
"text/html": [
" | \n",
"NO | \n",
"YES | \n",
"Error | \n",
"Rate |
\n",
"NO | \n",
"2172.0 | \n",
"6695.0 | \n",
"0.755 | \n",
" (6695.0/8867.0) |
\n",
"YES | \n",
"790.0 | \n",
"9867.0 | \n",
"0.0741 | \n",
" (790.0/10657.0) |
\n",
"Total | \n",
"2962.0 | \n",
"16562.0 | \n",
"0.3834 | \n",
" (7485.0/19524.0) |
"
],
"text/plain": [
" NO YES Error Rate\n",
"----- ---- ----- ------- ----------------\n",
"NO 2172 6695 0.755 (6695.0/8867.0)\n",
"YES 790 9867 0.0741 (790.0/10657.0)\n",
"Total 2962 16562 0.3834 (7485.0/19524.0)"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"Confusion Matrix (Act/Pred) for max f2 @ threshold = 0.381490353472:\n"
]
},
{
"data": {
"text/html": [
" | \n",
"NO | \n",
"YES | \n",
"Error | \n",
"Rate |
\n",
"NO | \n",
"172.0 | \n",
"8695.0 | \n",
"0.9806 | \n",
" (8695.0/8867.0) |
\n",
"YES | \n",
"23.0 | \n",
"10634.0 | \n",
"0.0022 | \n",
" (23.0/10657.0) |
\n",
"Total | \n",
"195.0 | \n",
"19329.0 | \n",
"0.4465 | \n",
" (8718.0/19524.0) |
"
],
"text/plain": [
" NO YES Error Rate\n",
"----- ---- ----- ------- ----------------\n",
"NO 172 8695 0.9806 (8695.0/8867.0)\n",
"YES 23 10634 0.0022 (23.0/10657.0)\n",
"Total 195 19329 0.4465 (8718.0/19524.0)"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"Confusion Matrix (Act/Pred) for max precision @ threshold = 0.685762034833:\n"
]
},
{
"data": {
"text/html": [
" | \n",
"NO | \n",
"YES | \n",
"Error | \n",
"Rate |
\n",
"NO | \n",
"8866.0 | \n",
"1.0 | \n",
"0.0001 | \n",
" (1.0/8867.0) |
\n",
"YES | \n",
"10630.0 | \n",
"27.0 | \n",
"0.9975 | \n",
" (10630.0/10657.0) |
\n",
"Total | \n",
"19496.0 | \n",
"28.0 | \n",
"0.5445 | \n",
" (10631.0/19524.0) |
"
],
"text/plain": [
" NO YES Error Rate\n",
"----- ----- ----- ------- -----------------\n",
"NO 8866 1 0.0001 (1.0/8867.0)\n",
"YES 10630 27 0.9975 (10630.0/10657.0)\n",
"Total 19496 28 0.5445 (10631.0/19524.0)"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"Confusion Matrix (Act/Pred) for max accuracy @ threshold = 0.509389999822:\n"
]
},
{
"data": {
"text/html": [
" | \n",
"NO | \n",
"YES | \n",
"Error | \n",
"Rate |
\n",
"NO | \n",
"4671.0 | \n",
"4196.0 | \n",
"0.4732 | \n",
" (4196.0/8867.0) |
\n",
"YES | \n",
"2557.0 | \n",
"8100.0 | \n",
"0.2399 | \n",
" (2557.0/10657.0) |
\n",
"Total | \n",
"7228.0 | \n",
"12296.0 | \n",
"0.3459 | \n",
" (6753.0/19524.0) |
"
],
"text/plain": [
" NO YES Error Rate\n",
"----- ---- ----- ------- ----------------\n",
"NO 4671 4196 0.4732 (4196.0/8867.0)\n",
"YES 2557 8100 0.2399 (2557.0/10657.0)\n",
"Total 7228 12296 0.3459 (6753.0/19524.0)"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"Confusion Matrix (Act/Pred) for max f0point5 @ threshold = 0.54046757144:\n"
]
},
{
"data": {
"text/html": [
" | \n",
"NO | \n",
"YES | \n",
"Error | \n",
"Rate |
\n",
"NO | \n",
"5378.0 | \n",
"3489.0 | \n",
"0.3935 | \n",
" (3489.0/8867.0) |
\n",
"YES | \n",
"3297.0 | \n",
"7360.0 | \n",
"0.3094 | \n",
" (3297.0/10657.0) |
\n",
"Total | \n",
"8675.0 | \n",
"10849.0 | \n",
"0.3476 | \n",
" (6786.0/19524.0) |
"
],
"text/plain": [
" NO YES Error Rate\n",
"----- ---- ----- ------- ----------------\n",
"NO 5378 3489 0.3935 (3489.0/8867.0)\n",
"YES 3297 7360 0.3094 (3297.0/10657.0)\n",
"Total 8675 10849 0.3476 (6786.0/19524.0)"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# Confusion matrices for the training dataset, selected by metric(s)\n",
"print(gbm.confusion_matrix())  # with no arguments the max-f1 threshold is used\n",
"print(gbm.confusion_matrix(metrics=\"f2\"))\n",
"print(gbm.confusion_matrix(metrics=\"precision\"))\n",
"\n",
"# Passing a list of metrics returns a list of confusion matrices\n",
"cms = gbm.confusion_matrix(metrics=[\"accuracy\", \"f0point5\"])\n",
"for cm in cms:\n",
"    print(cm)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Could not find exact threshold 0.77; using closest threshold found 0.685762034833.\n",
"\n",
"Confusion Matrix (Act/Pred) @ threshold = 0.685762034833:\n"
]
},
{
"data": {
"text/html": [
" | \n",
"NO | \n",
"YES | \n",
"Error | \n",
"Rate |
\n",
"NO | \n",
"8866.0 | \n",
"1.0 | \n",
"0.0001 | \n",
" (1.0/8867.0) |
\n",
"YES | \n",
"10630.0 | \n",
"27.0 | \n",
"0.9975 | \n",
" (10630.0/10657.0) |
\n",
"Total | \n",
"19496.0 | \n",
"28.0 | \n",
"0.5445 | \n",
" (10631.0/19524.0) |
"
],
"text/plain": [
" NO YES Error Rate\n",
"----- ----- ----- ------- -----------------\n",
"NO 8866 1 0.0001 (1.0/8867.0)\n",
"YES 10630 27 0.9975 (10630.0/10657.0)\n",
"Total 19496 28 0.5445 (10631.0/19524.0)"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Could not find exact threshold 0.1; using closest threshold found 0.373879538649.\n",
"Could not find exact threshold 0.5; using closest threshold found 0.49962104911.\n",
"Could not find exact threshold 0.99; using closest threshold found 0.685762034833.\n",
"\n",
"Confusion Matrix (Act/Pred) @ threshold = 0.373879538649:\n"
]
},
{
"data": {
"text/html": [
" | \n",
"NO | \n",
"YES | \n",
"Error | \n",
"Rate |
\n",
"NO | \n",
"0.0 | \n",
"8867.0 | \n",
"1.0 | \n",
" (8867.0/8867.0) |
\n",
"YES | \n",
"0.0 | \n",
"10657.0 | \n",
"0.0 | \n",
" (0.0/10657.0) |
\n",
"Total | \n",
"0.0 | \n",
"19524.0 | \n",
"0.4542 | \n",
" (8867.0/19524.0) |
"
],
"text/plain": [
" NO YES Error Rate\n",
"----- ---- ----- ------- ----------------\n",
"NO 0 8867 1 (8867.0/8867.0)\n",
"YES 0 10657 0 (0.0/10657.0)\n",
"Total 0 19524 0.4542 (8867.0/19524.0)"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"Confusion Matrix (Act/Pred) @ threshold = 0.49962104911:\n"
]
},
{
"data": {
"text/html": [
" | \n",
"NO | \n",
"YES | \n",
"Error | \n",
"Rate |
\n",
"NO | \n",
"4463.0 | \n",
"4404.0 | \n",
"0.4967 | \n",
" (4404.0/8867.0) |
\n",
"YES | \n",
"2400.0 | \n",
"8257.0 | \n",
"0.2252 | \n",
" (2400.0/10657.0) |
\n",
"Total | \n",
"6863.0 | \n",
"12661.0 | \n",
"0.3485 | \n",
" (6804.0/19524.0) |
"
],
"text/plain": [
" NO YES Error Rate\n",
"----- ---- ----- ------- ----------------\n",
"NO 4463 4404 0.4967 (4404.0/8867.0)\n",
"YES 2400 8257 0.2252 (2400.0/10657.0)\n",
"Total 6863 12661 0.3485 (6804.0/19524.0)"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"Confusion Matrix (Act/Pred) @ threshold = 0.685762034833:\n"
]
},
{
"data": {
"text/html": [
" | \n",
"NO | \n",
"YES | \n",
"Error | \n",
"Rate |
\n",
"NO | \n",
"8866.0 | \n",
"1.0 | \n",
"0.0001 | \n",
" (1.0/8867.0) |
\n",
"YES | \n",
"10630.0 | \n",
"27.0 | \n",
"0.9975 | \n",
" (10630.0/10657.0) |
\n",
"Total | \n",
"19496.0 | \n",
"28.0 | \n",
"0.5445 | \n",
" (10631.0/19524.0) |
"
],
"text/plain": [
" NO YES Error Rate\n",
"----- ----- ----- ------- -----------------\n",
"NO 8866 1 0.0001 (1.0/8867.0)\n",
"YES 10630 27 0.9975 (10630.0/10657.0)\n",
"Total 19496 28 0.5445 (10631.0/19524.0)"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# Confusion matrices for the training dataset at explicit threshold(s).\n",
"# When a requested threshold is not available, the closest one found is used instead.\n",
"print(gbm.confusion_matrix(thresholds=0.77))\n",
"\n",
"# Passing a list of thresholds returns a list of confusion matrices\n",
"cms = gbm.confusion_matrix(thresholds=[0.1, 0.5, 0.99])\n",
"for cm in cms:\n",
"    print(cm)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Confusion Matrix (Act/Pred) for max f2 @ threshold = 0.385734623697:\n"
]
},
{
"data": {
"text/html": [
" | \n",
"NO | \n",
"YES | \n",
"Error | \n",
"Rate |
\n",
"NO | \n",
"80.0 | \n",
"2119.0 | \n",
"0.9636 | \n",
" (2119.0/2199.0) |
\n",
"YES | \n",
"13.0 | \n",
"2685.0 | \n",
"0.0048 | \n",
" (13.0/2698.0) |
\n",
"Total | \n",
"93.0 | \n",
"4804.0 | \n",
"0.4354 | \n",
" (2132.0/4897.0) |
"
],
"text/plain": [
" NO YES Error Rate\n",
"----- ---- ----- ------- ---------------\n",
"NO 80 2119 0.9636 (2119.0/2199.0)\n",
"YES 13 2685 0.0048 (13.0/2698.0)\n",
"Total 93 4804 0.4354 (2132.0/4897.0)"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"Confusion Matrix (Act/Pred) for max precision @ threshold = 0.683022938978:\n"
]
},
{
"data": {
"text/html": [
" | \n",
"NO | \n",
"YES | \n",
"Error | \n",
"Rate |
\n",
"NO | \n",
"2191.0 | \n",
"8.0 | \n",
"0.0036 | \n",
" (8.0/2199.0) |
\n",
"YES | \n",
"2632.0 | \n",
"66.0 | \n",
"0.9755 | \n",
" (2632.0/2698.0) |
\n",
"Total | \n",
"4823.0 | \n",
"74.0 | \n",
"0.5391 | \n",
" (2640.0/4897.0) |
"
],
"text/plain": [
" NO YES Error Rate\n",
"----- ---- ----- ------- ---------------\n",
"NO 2191 8 0.0036 (8.0/2199.0)\n",
"YES 2632 66 0.9755 (2632.0/2698.0)\n",
"Total 4823 74 0.5391 (2640.0/4897.0)"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"Confusion Matrix (Act/Pred) for max accuracy @ threshold = 0.518825062343:\n"
]
},
{
"data": {
"text/html": [
" | \n",
"NO | \n",
"YES | \n",
"Error | \n",
"Rate |
\n",
"NO | \n",
"1188.0 | \n",
"1011.0 | \n",
"0.4598 | \n",
" (1011.0/2199.0) |
\n",
"YES | \n",
"684.0 | \n",
"2014.0 | \n",
"0.2535 | \n",
" (684.0/2698.0) |
\n",
"Total | \n",
"1872.0 | \n",
"3025.0 | \n",
"0.3461 | \n",
" (1695.0/4897.0) |
"
],
"text/plain": [
" NO YES Error Rate\n",
"----- ---- ----- ------- ---------------\n",
"NO 1188 1011 0.4598 (1011.0/2199.0)\n",
"YES 684 2014 0.2535 (684.0/2698.0)\n",
"Total 1872 3025 0.3461 (1695.0/4897.0)"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"Confusion Matrix (Act/Pred) for max f0point5 @ threshold = 0.540424490283:\n"
]
},
{
"data": {
"text/html": [
" | \n",
"NO | \n",
"YES | \n",
"Error | \n",
"Rate |
\n",
"NO | \n",
"1316.0 | \n",
"883.0 | \n",
"0.4015 | \n",
" (883.0/2199.0) |
\n",
"YES | \n",
"818.0 | \n",
"1880.0 | \n",
"0.3032 | \n",
" (818.0/2698.0) |
\n",
"Total | \n",
"2134.0 | \n",
"2763.0 | \n",
"0.3474 | \n",
" (1701.0/4897.0) |
"
],
"text/plain": [
" NO YES Error Rate\n",
"----- ---- ----- ------- ---------------\n",
"NO 1316 883 0.4015 (883.0/2199.0)\n",
"YES 818 1880 0.3032 (818.0/2698.0)\n",
"Total 2134 2763 0.3474 (1701.0/4897.0)"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# Confusion matrices for the validation dataset (valid=True), selected by metric(s)\n",
"print(gbm.confusion_matrix(metrics=\"f2\", valid=True))\n",
"print(gbm.confusion_matrix(metrics=\"precision\", valid=True))\n",
"\n",
"# Passing a list of metrics returns a list of confusion matrices\n",
"cms = gbm.confusion_matrix(metrics=[\"accuracy\", \"f0point5\"], valid=True)\n",
"for cm in cms:\n",
"    print(cm)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Could not find exact threshold 0.77; using closest threshold found 0.685762034833.\n",
"\n",
"Confusion Matrix (Act/Pred) @ threshold = 0.685762034833:\n"
]
},
{
"data": {
"text/html": [
" | \n",
"NO | \n",
"YES | \n",
"Error | \n",
"Rate |
\n",
"NO | \n",
"8866.0 | \n",
"1.0 | \n",
"0.0001 | \n",
" (1.0/8867.0) |
\n",
"YES | \n",
"10630.0 | \n",
"27.0 | \n",
"0.9975 | \n",
" (10630.0/10657.0) |
\n",
"Total | \n",
"19496.0 | \n",
"28.0 | \n",
"0.5445 | \n",
" (10631.0/19524.0) |
"
],
"text/plain": [
" NO YES Error Rate\n",
"----- ----- ----- ------- -----------------\n",
"NO 8866 1 0.0001 (1.0/8867.0)\n",
"YES 10630 27 0.9975 (10630.0/10657.0)\n",
"Total 19496 28 0.5445 (10631.0/19524.0)"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Could not find exact threshold 0.25; using closest threshold found 0.373879538649.\n",
"Could not find exact threshold 0.33; using closest threshold found 0.373879538649.\n",
"Could not find exact threshold 0.44; using closest threshold found 0.44006560762.\n",
"\n",
"Confusion Matrix (Act/Pred) @ threshold = 0.373879538649:\n"
]
},
{
"data": {
"text/html": [
" | \n",
"NO | \n",
"YES | \n",
"Error | \n",
"Rate |
\n",
"NO | \n",
"0.0 | \n",
"8867.0 | \n",
"1.0 | \n",
" (8867.0/8867.0) |
\n",
"YES | \n",
"0.0 | \n",
"10657.0 | \n",
"0.0 | \n",
" (0.0/10657.0) |
\n",
"Total | \n",
"0.0 | \n",
"19524.0 | \n",
"0.4542 | \n",
" (8867.0/19524.0) |
"
],
"text/plain": [
" NO YES Error Rate\n",
"----- ---- ----- ------- ----------------\n",
"NO 0 8867 1 (8867.0/8867.0)\n",
"YES 0 10657 0 (0.0/10657.0)\n",
"Total 0 19524 0.4542 (8867.0/19524.0)"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"Confusion Matrix (Act/Pred) @ threshold = 0.373879538649:\n"
]
},
{
"data": {
"text/html": [
" | \n",
"NO | \n",
"YES | \n",
"Error | \n",
"Rate |
\n",
"NO | \n",
"0.0 | \n",
"8867.0 | \n",
"1.0 | \n",
" (8867.0/8867.0) |
\n",
"YES | \n",
"0.0 | \n",
"10657.0 | \n",
"0.0 | \n",
" (0.0/10657.0) |
\n",
"Total | \n",
"0.0 | \n",
"19524.0 | \n",
"0.4542 | \n",
" (8867.0/19524.0) |
"
],
"text/plain": [
" NO YES Error Rate\n",
"----- ---- ----- ------- ----------------\n",
"NO 0 8867 1 (8867.0/8867.0)\n",
"YES 0 10657 0 (0.0/10657.0)\n",
"Total 0 19524 0.4542 (8867.0/19524.0)"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"Confusion Matrix (Act/Pred) @ threshold = 0.44006560762:\n"
]
},
{
"data": {
"text/html": [
" | \n",
"NO | \n",
"YES | \n",
"Error | \n",
"Rate |
\n",
"NO | \n",
"2235.0 | \n",
"6632.0 | \n",
"0.7479 | \n",
" (6632.0/8867.0) |
\n",
"YES | \n",
"856.0 | \n",
"9801.0 | \n",
"0.0803 | \n",
" (856.0/10657.0) |
\n",
"Total | \n",
"3091.0 | \n",
"16433.0 | \n",
"0.3835 | \n",
" (7488.0/19524.0) |
"
],
"text/plain": [
" NO YES Error Rate\n",
"----- ---- ----- ------- ----------------\n",
"NO 2235 6632 0.7479 (6632.0/8867.0)\n",
"YES 856 9801 0.0803 (856.0/10657.0)\n",
"Total 3091 16433 0.3835 (7488.0/19524.0)"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# Show various confusion matrices for validation dataset (based on threshold(s)).\n",
"# valid=True is required here: without it the training metrics are reported instead.\n",
"print(gbm.confusion_matrix(thresholds=0.77, valid=True))\n",
"\n",
"# A list of thresholds returns one confusion matrix per threshold\n",
"cms = gbm.confusion_matrix(thresholds=[0.25, 0.33, 0.44], valid=True)\n",
"print(cms[0])\n",
"print(cms[1])\n",
"print(cms[2])"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Could not find exact threshold 0.77; using closest threshold found 0.685762034833.\n",
"\n",
"Confusion Matrix (Act/Pred) @ threshold = 0.685762034833:\n"
]
},
{
"data": {
"text/html": [
" | \n",
"NO | \n",
"YES | \n",
"Error | \n",
"Rate |
\n",
"NO | \n",
"8866.0 | \n",
"1.0 | \n",
"0.0001 | \n",
" (1.0/8867.0) |
\n",
"YES | \n",
"10630.0 | \n",
"27.0 | \n",
"0.9975 | \n",
" (10630.0/10657.0) |
\n",
"Total | \n",
"19496.0 | \n",
"28.0 | \n",
"0.5445 | \n",
" (10631.0/19524.0) |
"
],
"text/plain": [
" NO YES Error Rate\n",
"----- ----- ----- ------- -----------------\n",
"NO 8866 1 0.0001 (1.0/8867.0)\n",
"YES 10630 27 0.9975 (10630.0/10657.0)\n",
"Total 19496 28 0.5445 (10631.0/19524.0)"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"Confusion Matrix (Act/Pred) for max f1 @ threshold = 0.438866890551:\n"
]
},
{
"data": {
"text/html": [
" | \n",
"NO | \n",
"YES | \n",
"Error | \n",
"Rate |
\n",
"NO | \n",
"2172.0 | \n",
"6695.0 | \n",
"0.755 | \n",
" (6695.0/8867.0) |
\n",
"YES | \n",
"790.0 | \n",
"9867.0 | \n",
"0.0741 | \n",
" (790.0/10657.0) |
\n",
"Total | \n",
"2962.0 | \n",
"16562.0 | \n",
"0.3834 | \n",
" (7485.0/19524.0) |
"
],
"text/plain": [
" NO YES Error Rate\n",
"----- ---- ----- ------- ----------------\n",
"NO 2172 6695 0.755 (6695.0/8867.0)\n",
"YES 790 9867 0.0741 (790.0/10657.0)\n",
"Total 2962 16562 0.3834 (7485.0/19524.0)"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Could not find exact threshold 0.25; using closest threshold found 0.373879538649.\n",
"Could not find exact threshold 0.33; using closest threshold found 0.373879538649.\n",
"\n",
"Confusion Matrix (Act/Pred) @ threshold = 0.373879538649:\n"
]
},
{
"data": {
"text/html": [
" | \n",
"NO | \n",
"YES | \n",
"Error | \n",
"Rate |
\n",
"NO | \n",
"0.0 | \n",
"8867.0 | \n",
"1.0 | \n",
" (8867.0/8867.0) |
\n",
"YES | \n",
"0.0 | \n",
"10657.0 | \n",
"0.0 | \n",
" (0.0/10657.0) |
\n",
"Total | \n",
"0.0 | \n",
"19524.0 | \n",
"0.4542 | \n",
" (8867.0/19524.0) |
"
],
"text/plain": [
" NO YES Error Rate\n",
"----- ---- ----- ------- ----------------\n",
"NO 0 8867 1 (8867.0/8867.0)\n",
"YES 0 10657 0 (0.0/10657.0)\n",
"Total 0 19524 0.4542 (8867.0/19524.0)"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"Confusion Matrix (Act/Pred) @ threshold = 0.373879538649:\n"
]
},
{
"data": {
"text/html": [
" | \n",
"NO | \n",
"YES | \n",
"Error | \n",
"Rate |
\n",
"NO | \n",
"0.0 | \n",
"8867.0 | \n",
"1.0 | \n",
" (8867.0/8867.0) |
\n",
"YES | \n",
"0.0 | \n",
"10657.0 | \n",
"0.0 | \n",
" (0.0/10657.0) |
\n",
"Total | \n",
"0.0 | \n",
"19524.0 | \n",
"0.4542 | \n",
" (8867.0/19524.0) |
"
],
"text/plain": [
" NO YES Error Rate\n",
"----- ---- ----- ------- ----------------\n",
"NO 0 8867 1 (8867.0/8867.0)\n",
"YES 0 10657 0 (0.0/10657.0)\n",
"Total 0 19524 0.4542 (8867.0/19524.0)"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"Confusion Matrix (Act/Pred) for max f2 @ threshold = 0.381490353472:\n"
]
},
{
"data": {
"text/html": [
" | \n",
"NO | \n",
"YES | \n",
"Error | \n",
"Rate |
\n",
"NO | \n",
"172.0 | \n",
"8695.0 | \n",
"0.9806 | \n",
" (8695.0/8867.0) |
\n",
"YES | \n",
"23.0 | \n",
"10634.0 | \n",
"0.0022 | \n",
" (23.0/10657.0) |
\n",
"Total | \n",
"195.0 | \n",
"19329.0 | \n",
"0.4465 | \n",
" (8718.0/19524.0) |
"
],
"text/plain": [
" NO YES Error Rate\n",
"----- ---- ----- ------- ----------------\n",
"NO 172 8695 0.9806 (8695.0/8867.0)\n",
"YES 23 10634 0.0022 (23.0/10657.0)\n",
"Total 195 19329 0.4465 (8718.0/19524.0)"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"Confusion Matrix (Act/Pred) for max f0point5 @ threshold = 0.54046757144:\n"
]
},
{
"data": {
"text/html": [
" | \n",
"NO | \n",
"YES | \n",
"Error | \n",
"Rate |
\n",
"NO | \n",
"5378.0 | \n",
"3489.0 | \n",
"0.3935 | \n",
" (3489.0/8867.0) |
\n",
"YES | \n",
"3297.0 | \n",
"7360.0 | \n",
"0.3094 | \n",
" (3297.0/10657.0) |
\n",
"Total | \n",
"8675.0 | \n",
"10849.0 | \n",
"0.3476 | \n",
" (6786.0/19524.0) |
"
],
"text/plain": [
" NO YES Error Rate\n",
"----- ---- ----- ------- ----------------\n",
"NO 5378 3489 0.3935 (3489.0/8867.0)\n",
"YES 3297 7360 0.3094 (3297.0/10657.0)\n",
"Total 8675 10849 0.3476 (6786.0/19524.0)"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# Show various confusion matrices for validation dataset (based on metric(s) AND threshold(s)).\n",
"# valid=True is required here: without it the training metrics are reported instead.\n",
"# When both arguments are given, one matrix is returned per threshold and per metric.\n",
"cms = gbm.confusion_matrix(thresholds=0.77, metrics=\"f1\", valid=True)\n",
"print(cms[0])\n",
"print(cms[1])\n",
"\n",
"# Two thresholds plus two metrics -> four confusion matrices\n",
"cms = gbm.confusion_matrix(thresholds=[0.25, 0.33], metrics=[\"f2\", \"f0point5\"], valid=True)\n",
"print(cms[0])\n",
"print(cms[1])\n",
"print(cms[2])\n",
"print(cms[3])"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Parse Progress: [##################################################] 100%\n",
"Imported /Users/ludirehak/h2o-3/smalldata/airlines/AirlinesTest.csv.zip. Parsed 2,691 rows and 12 cols\n"
]
}
],
"source": [
"# Test dataset\n",
"air_test = h2o.import_file(path=_locate(\"smalldata/airlines/AirlinesTest.csv.zip\"))\n",
"\n",
"# Test performance\n",
"gbm_perf = gbm.model_performance(air_test)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Confusion Matrix (Act/Pred) for max f0point5 @ threshold = 0.532641218074:\n"
]
},
{
"data": {
"text/html": [
" | \n",
"NO | \n",
"YES | \n",
"Error | \n",
"Rate |
\n",
"NO | \n",
"694.0 | \n",
"523.0 | \n",
"0.4297 | \n",
" (523.0/1217.0) |
\n",
"YES | \n",
"398.0 | \n",
"1076.0 | \n",
"0.27 | \n",
" (398.0/1474.0) |
\n",
"Total | \n",
"1092.0 | \n",
"1599.0 | \n",
"0.3423 | \n",
" (921.0/2691.0) |
"
],
"text/plain": [
" NO YES Error Rate\n",
"----- ---- ----- ------- --------------\n",
"NO 694 523 0.4297 (523.0/1217.0)\n",
"YES 398 1076 0.27 (398.0/1474.0)\n",
"Total 1092 1599 0.3423 (921.0/2691.0)"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"Confusion Matrix (Act/Pred) for max min_per_class_accuracy @ threshold = 0.550904005776:\n"
]
},
{
"data": {
"text/html": [
" | \n",
"NO | \n",
"YES | \n",
"Error | \n",
"Rate |
\n",
"NO | \n",
"779.0 | \n",
"438.0 | \n",
"0.3599 | \n",
" (438.0/1217.0) |
\n",
"YES | \n",
"530.0 | \n",
"944.0 | \n",
"0.3596 | \n",
" (530.0/1474.0) |
\n",
"Total | \n",
"1309.0 | \n",
"1382.0 | \n",
"0.3597 | \n",
" (968.0/2691.0) |
"
],
"text/plain": [
" NO YES Error Rate\n",
"----- ---- ----- ------- --------------\n",
"NO 779 438 0.3599 (438.0/1217.0)\n",
"YES 530 944 0.3596 (530.0/1474.0)\n",
"Total 1309 1382 0.3597 (968.0/2691.0)"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"Confusion Matrix (Act/Pred) for max accuracy @ threshold = 0.532641218074:\n"
]
},
{
"data": {
"text/html": [
" | \n",
"NO | \n",
"YES | \n",
"Error | \n",
"Rate |
\n",
"NO | \n",
"694.0 | \n",
"523.0 | \n",
"0.4297 | \n",
" (523.0/1217.0) |
\n",
"YES | \n",
"398.0 | \n",
"1076.0 | \n",
"0.27 | \n",
" (398.0/1474.0) |
\n",
"Total | \n",
"1092.0 | \n",
"1599.0 | \n",
"0.3423 | \n",
" (921.0/2691.0) |
"
],
"text/plain": [
" NO YES Error Rate\n",
"----- ---- ----- ------- --------------\n",
"NO 694 523 0.4297 (523.0/1217.0)\n",
"YES 398 1076 0.27 (398.0/1474.0)\n",
"Total 1092 1599 0.3423 (921.0/2691.0)"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"Confusion Matrix (Act/Pred) for max accuracy @ threshold = 0.532641218074:\n"
]
},
{
"data": {
"text/html": [
" | \n",
"NO | \n",
"YES | \n",
"Error | \n",
"Rate |
\n",
"NO | \n",
"694.0 | \n",
"523.0 | \n",
"0.4297 | \n",
" (523.0/1217.0) |
\n",
"YES | \n",
"398.0 | \n",
"1076.0 | \n",
"0.27 | \n",
" (398.0/1474.0) |
\n",
"Total | \n",
"1092.0 | \n",
"1599.0 | \n",
"0.3423 | \n",
" (921.0/2691.0) |
"
],
"text/plain": [
" NO YES Error Rate\n",
"----- ---- ----- ------- --------------\n",
"NO 694 523 0.4297 (523.0/1217.0)\n",
"YES 398 1076 0.27 (398.0/1474.0)\n",
"Total 1092 1599 0.3423 (921.0/2691.0)"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# Confusion matrices for the test dataset (from the performance object), selected by metric(s)\n",
"print(gbm_perf.confusion_matrix(metrics=\"f0point5\"))\n",
"print(gbm_perf.confusion_matrix(metrics=\"min_per_class_accuracy\"))\n",
"\n",
"# Passing a list of metrics returns a list of confusion matrices\n",
"cms = gbm_perf.confusion_matrix(metrics=[\"accuracy\", \"f0point5\"])\n",
"for cm in cms:\n",
"    print(cm)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Could not find exact threshold 0.5; using closest threshold found 0.499551746996.\n",
"\n",
"Confusion Matrix (Act/Pred) @ threshold = 0.499551746996:\n"
]
},
{
"data": {
"text/html": [
" | \n",
"NO | \n",
"YES | \n",
"Error | \n",
"Rate |
\n",
"NO | \n",
"576.0 | \n",
"641.0 | \n",
"0.5267 | \n",
" (641.0/1217.0) |
\n",
"YES | \n",
"311.0 | \n",
"1163.0 | \n",
"0.211 | \n",
" (311.0/1474.0) |
\n",
"Total | \n",
"887.0 | \n",
"1804.0 | \n",
"0.3538 | \n",
" (952.0/2691.0) |
"
],
"text/plain": [
" NO YES Error Rate\n",
"----- ---- ----- ------- --------------\n",
"NO 576 641 0.5267 (641.0/1217.0)\n",
"YES 311 1163 0.211 (311.0/1474.0)\n",
"Total 887 1804 0.3538 (952.0/2691.0)"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Could not find exact threshold 0.01; using closest threshold found 0.37382486349.\n",
"Could not find exact threshold 0.75; using closest threshold found 0.6857620914.\n",
"Could not find exact threshold 0.88; using closest threshold found 0.6857620914.\n",
"\n",
"Confusion Matrix (Act/Pred) @ threshold = 0.37382486349:\n"
]
},
{
"data": {
"text/html": [
" | \n",
"NO | \n",
"YES | \n",
"Error | \n",
"Rate |
\n",
"NO | \n",
"0.0 | \n",
"1217.0 | \n",
"1.0 | \n",
" (1217.0/1217.0) |
\n",
"YES | \n",
"0.0 | \n",
"1474.0 | \n",
"0.0 | \n",
" (0.0/1474.0) |
\n",
"Total | \n",
"0.0 | \n",
"2691.0 | \n",
"0.4522 | \n",
" (1217.0/2691.0) |
"
],
"text/plain": [
" NO YES Error Rate\n",
"----- ---- ----- ------- ---------------\n",
"NO 0 1217 1 (1217.0/1217.0)\n",
"YES 0 1474 0 (0.0/1474.0)\n",
"Total 0 2691 0.4522 (1217.0/2691.0)"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"Confusion Matrix (Act/Pred) @ threshold = 0.6857620914:\n"
]
},
{
"data": {
"text/html": [
" | \n",
"NO | \n",
"YES | \n",
"Error | \n",
"Rate |
\n",
"NO | \n",
"1216.0 | \n",
"1.0 | \n",
"0.0008 | \n",
" (1.0/1217.0) |
\n",
"YES | \n",
"1473.0 | \n",
"1.0 | \n",
"0.9993 | \n",
" (1473.0/1474.0) |
\n",
"Total | \n",
"2689.0 | \n",
"2.0 | \n",
"0.5478 | \n",
" (1474.0/2691.0) |
"
],
"text/plain": [
" NO YES Error Rate\n",
"----- ---- ----- ------- ---------------\n",
"NO 1216 1 0.0008 (1.0/1217.0)\n",
"YES 1473 1 0.9993 (1473.0/1474.0)\n",
"Total 2689 2 0.5478 (1474.0/2691.0)"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"Confusion Matrix (Act/Pred) @ threshold = 0.6857620914:\n"
]
},
{
"data": {
"text/html": [
" | \n",
"NO | \n",
"YES | \n",
"Error | \n",
"Rate |
\n",
"NO | \n",
"1216.0 | \n",
"1.0 | \n",
"0.0008 | \n",
" (1.0/1217.0) |
\n",
"YES | \n",
"1473.0 | \n",
"1.0 | \n",
"0.9993 | \n",
" (1473.0/1474.0) |
\n",
"Total | \n",
"2689.0 | \n",
"2.0 | \n",
"0.5478 | \n",
" (1474.0/2691.0) |
"
],
"text/plain": [
" NO YES Error Rate\n",
"----- ---- ----- ------- ---------------\n",
"NO 1216 1 0.0008 (1.0/1217.0)\n",
"YES 1473 1 0.9993 (1473.0/1474.0)\n",
"Total 2689 2 0.5478 (1474.0/2691.0)"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# Confusion matrices for the test dataset (from the performance object) at explicit threshold(s).\n",
"# When a requested threshold is not available, the closest one found is used instead.\n",
"print(gbm_perf.confusion_matrix(thresholds=0.5))\n",
"\n",
"# Passing a list of thresholds returns a list of confusion matrices\n",
"cms = gbm_perf.confusion_matrix(thresholds=[0.01, 0.75, 0.88])\n",
"for cm in cms:\n",
"    print(cm)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[2172, 6695], [790, 9867]]\n",
"[[389, 828], [172, 1302]]\n"
]
}
],
"source": [
"# to_list() converts a ConfusionMatrix to a python list of lists: [ [tns,fps], [fns,tps] ]\n",
"# First the model's default (training) matrix, then the one from the test performance object.\n",
"for cm_source in (gbm, gbm_perf):\n",
"    cm = cm_source.confusion_matrix()\n",
"    print(cm.to_list())"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.9"
}
},
"nbformat": 4,
"nbformat_minor": 0
}