{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import h2o\n", "from h2o.estimators.gbm import H2OGradientBoostingEstimator" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
H2O cluster uptime: 5 seconds 730 milliseconds
H2O cluster version: 3.7.0.99999
H2O cluster name: spIdea
H2O cluster total nodes: 1
H2O cluster total free memory: 12.44 GB
H2O cluster total cores: 8
H2O cluster allowed cores: 8
H2O cluster healthy: True
H2O Connection ip: 127.0.0.1
H2O Connection port: 54321
H2O Connection proxy: None
Python Version: 3.5.0
" ], "text/plain": [ "------------------------------ --------------------------\n", "H2O cluster uptime: 5 seconds 730 milliseconds\n", "H2O cluster version: 3.7.0.99999\n", "H2O cluster name: spIdea\n", "H2O cluster total nodes: 1\n", "H2O cluster total free memory: 12.44 GB\n", "H2O cluster total cores: 8\n", "H2O cluster allowed cores: 8\n", "H2O cluster healthy: True\n", "H2O Connection ip: 127.0.0.1\n", "H2O Connection port: 54321\n", "H2O Connection proxy:\n", "Python Version: 3.5.0\n", "------------------------------ --------------------------" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Attach this session to an H2O cluster; the connection summary is displayed below.\n", "h2o.init()" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Parse Progress: [##################################################] 100%\n" ] } ], "source": [ "# `_locate` is a private h2o helper used to find files within the h2o git project directory,\n", "# so this cell expects to run from inside an h2o source checkout.\n", "from h2o.utils.shared_utils import _locate\n", "\n", "prostate_csv = _locate(\"smalldata/logreg/prostate.csv\")\n", "df = h2o.import_file(path=prostate_csv)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Rows:380 Cols:9\n", "\n", "Chunk compression summary: \n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
chunk_typechunk_namecountcount_percentagesizesize_percentage
CBSBits111.111112 118 B2.4210093
C1N1-Byte Integers (w/o NAs)555.555557 2.2 KB45.958145
C22-Byte Integers111.111112 828 B16.9881
C2S2-Byte Fractions222.222223 1.6 KB34.632744
" ], "text/plain": [ "chunk_type chunk_name count count_percentage size size_percentage\n", "------------ ------------------------- ------- ------------------ ------ -----------------\n", "CBS Bits 1 11.1111 118 B 2.42101\n", "C1N 1-Byte Integers (w/o NAs) 5 55.5556 2.2 KB 45.9581\n", "C2 2-Byte Integers 1 11.1111 828 B 16.9881\n", "C2S 2-Byte Fractions 2 22.2222 1.6 KB 34.6327" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Frame distribution summary: \n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
sizenumber_of_rowsnumber_of_chunks_per_columnnumber_of_chunks
172.16.2.84:54321 4.8 KB380.01.09.0
mean 4.8 KB380.01.09.0
min 4.8 KB380.01.09.0
max 4.8 KB380.01.09.0
stddev 0 B0.00.00.0
total 4.8 KB380.01.09.0
" ], "text/plain": [ " size number_of_rows number_of_chunks_per_column number_of_chunks\n", "----------------- ------ ---------------- ----------------------------- ------------------\n", "172.16.2.84:54321 4.8 KB 380 1 9\n", "mean 4.8 KB 380 1 9\n", "min 4.8 KB 380 1 9\n", "max 4.8 KB 380 1 9\n", "stddev 0 B 0 0 0\n", "total 4.8 KB 380 1 9" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n" ] }, { "data": { "text/html": [ "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
ID CAPSULE AGE RACE DPROS DCAPS PSA VOL GLEASON
type int int int int int int real real int
mins 1.0 0.0 43.0 0.0 1.0 1.0 0.3 0.0 0.0
mean 190.5 0.402631578947368466.039473684210491.08684210526315722.27105263157894881.107894736842104815.40863157894737515.8129210526315736.3842105263157904
maxs 380.0 1.0 79.0 2.0 4.0 2.0 139.7000000000000297.60000000000001 9.0
sigma 109.840793879141270.49107433896305526.5270712691733110.30877325802527931.00010761815028610.310656449351493919.99757266856046 18.3476199672711751.0919533744261092
zeros 0 227 0 3 0 0 0 167 2
missing0 0 0 0 0 0 0 0 0
0 1.0 0.0 65.0 1.0 2.0 1.0 1.40000000000000010.0 6.0
1 2.0 0.0 72.0 1.0 3.0 2.0 6.7 0.0 7.0
2 3.0 0.0 70.0 1.0 1.0 2.0 4.9 0.0 6.0
3 4.0 0.0 76.0 2.0 2.0 1.0 51.2 20.0 7.0
4 5.0 0.0 69.0 1.0 1.0 1.0 12.3 55.9 6.0
5 6.0 1.0 71.0 1.0 3.0 2.0 3.30000000000000030.0 8.0
6 7.0 0.0 68.0 2.0 4.0 2.0 31.9000000000000020.0 7.0
7 8.0 0.0 61.0 2.0 4.0 2.0 66.7 27.2 7.0
8 9.0 0.0 69.0 1.0 1.0 1.0 3.9 24.0 7.0
9 10.0 0.0 68.0 2.0 1.0 2.0 13.0 0.0 6.0
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "df.describe()" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Remove ID from training frame\n", "train = df.drop(\"ID\")" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# For VOL & GLEASON, a zero really means \"missing\".\n", "# train['VOL'] / train['GLEASON'] are extracted column frames: replacing zeros in them does\n", "# not update `train` by itself, so each cleaned column is assigned back explicitly.\n", "# Without the write-back, the training frame keeps its zeros and no values become NA.\n", "vol = train['VOL']\n", "vol[vol == 0] = None\n", "train['VOL'] = vol\n", "gle = train['GLEASON']\n", "gle[gle == 0] = None\n", "train['GLEASON'] = gle" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Convert CAPSULE to a categorical (enum) column so it can serve as a binary classification target\n", "train['CAPSULE'] = train['CAPSULE'].asfactor()" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Rows:380 Cols:8\n", "\n", "Chunk compression summary: \n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
chunk_typechunk_namecountcount_percentagesizesize_percentage
CBSBits112.5 118 B2.9164608
C1N1-Byte Integers (w/o NAs)562.5 2.2 KB55.363323
C2S2-Byte Fractions225.0 1.6 KB41.72022
" ], "text/plain": [ "chunk_type chunk_name count count_percentage size size_percentage\n", "------------ ------------------------- ------- ------------------ ------ -----------------\n", "CBS Bits 1 12.5 118 B 2.91646\n", "C1N 1-Byte Integers (w/o NAs) 5 62.5 2.2 KB 55.3633\n", "C2S 2-Byte Fractions 2 25 1.6 KB 41.7202" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Frame distribution summary: \n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
sizenumber_of_rowsnumber_of_chunks_per_columnnumber_of_chunks
172.16.2.84:54321 4.0 KB380.01.08.0
mean 4.0 KB380.01.08.0
min 4.0 KB380.01.08.0
max 4.0 KB380.01.08.0
stddev 0 B0.00.00.0
total 4.0 KB380.01.08.0
" ], "text/plain": [ " size number_of_rows number_of_chunks_per_column number_of_chunks\n", "----------------- ------ ---------------- ----------------------------- ------------------\n", "172.16.2.84:54321 4.0 KB 380 1 8\n", "mean 4.0 KB 380 1 8\n", "min 4.0 KB 380 1 8\n", "max 4.0 KB 380 1 8\n", "stddev 0 B 0 0 0\n", "total 4.0 KB 380 1 8" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n" ] }, { "data": { "text/html": [ "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
CAPSULE AGE RACE DPROS DCAPS PSA VOL GLEASON
type enum int int int int real real int
mins 0.0 43.0 0.0 1.0 1.0 0.3 0.0 0.0
mean 0.402631578947368466.039473684210491.08684210526315722.27105263157894881.107894736842104815.40863157894737515.8129210526315736.3842105263157904
maxs 1.0 79.0 2.0 4.0 2.0 139.7000000000000297.60000000000001 9.0
sigma 0.49107433896305526.5270712691733110.30877325802527931.00010761815028610.310656449351493919.99757266856046 18.3476199672711751.0919533744261092
zeros 227 0 3 0 0 0 167 2
missing0 0 0 0 0 0 0 0
0 0 65.0 1.0 2.0 1.0 1.40000000000000010.0 6.0
1 0 72.0 1.0 3.0 2.0 6.7 0.0 7.0
2 0 70.0 1.0 1.0 2.0 4.9 0.0 6.0
3 0 76.0 2.0 2.0 1.0 51.2 20.0 7.0
4 0 69.0 1.0 1.0 1.0 12.3 55.9 6.0
5 1 71.0 1.0 3.0 2.0 3.30000000000000030.0 8.0
6 0 68.0 2.0 4.0 2.0 31.9000000000000020.0 7.0
7 0 61.0 2.0 4.0 2.0 66.7 27.2 7.0
8 0 69.0 1.0 1.0 1.0 3.9 24.0 7.0
9 0 68.0 2.0 1.0 2.0 13.0 0.0 6.0
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# See that the data is ready\n", "train.describe()" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "gbm Model Build Progress: [##################################################] 100%\n" ] } ], "source": [ "# Run GBM: predict CAPSULE from every other column of `train`.\n", "# Predictors are selected by name so the cell does not depend on CAPSULE being column 0.\n", "# No validation_frame is passed: scoring the training frame as \"validation\" would only\n", "# duplicate the training metrics. Use a separate hold-out frame for a real validation estimate.\n", "response = \"CAPSULE\"\n", "predictors = [col for col in train.names if col != response]\n", "\n", "my_gbm = H2OGradientBoostingEstimator(distribution=\"bernoulli\", ntrees=50, learn_rate=0.1)\n", "my_gbm.train(x=predictors, y=response, training_frame=train)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "ModelMetricsBinomial: gbm\n", "** Reported on test data. **\n", "\n", "MSE: 0.07584147467507414\n", "R^2: 0.6846762562816877\n", "LogLoss: 0.2744668128481441\n", "AUC: 0.9780311537243385\n", "Gini: 0.9560623074486769\n", "\n", "Confusion Matrix (Act/Pred) for max f1 @ threshold = 0.4549496668047897: \n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
01ErrorRate
0216.011.00.0485 (11.0/227.0)
114.0139.00.0915 (14.0/153.0)
Total230.0150.00.0658 (25.0/380.0)
" ], "text/plain": [ " 0 1 Error Rate\n", "----- --- --- ------- ------------\n", "0 216 11 0.0485 (11.0/227.0)\n", "1 14 139 0.0915 (14.0/153.0)\n", "Total 230 150 0.0658 (25.0/380.0)" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Maximum Metrics: Maximum metrics at their respective thresholds\n", "\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
metricthresholdvalueidx
max f10.45494970.9174917149.0
max f20.30320100.9394314196.0
max f0point50.47283130.9244265146.0
max accuracy0.45494970.9342105149.0
max precision0.97479381.00.0
max absolute_MCC0.45494970.8629130149.0
max min_per_class_accuracy0.43739950.9215686156.0
" ], "text/plain": [ "metric threshold value idx\n", "-------------------------- ----------- -------- -----\n", "max f1 0.45495 0.917492 149\n", "max f2 0.303201 0.939431 196\n", "max f0point5 0.472831 0.924426 146\n", "max accuracy 0.45495 0.934211 149\n", "max precision 0.974794 1 0\n", "max absolute_MCC 0.45495 0.862913 149\n", "max min_per_class_accuracy 0.437399 0.921569 156" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Gains/Lift Table: Avg response rate: 40.26 %\n", "\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
grouplower_thresholdcumulative_data_fractionresponse_ratecumulative_response_ratecapture_ratecumulative_capture_rateliftcumulative_liftgaincumulative_gain
10.94057500.051.01.00.12418300.12418302.48366012.4836601148.3660131148.3660131
20.89219800.11.01.00.12418300.24836602.48366012.4836601148.3660131148.3660131
30.82636950.151.01.00.12418300.37254902.48366012.4836601148.3660131148.3660131
40.75954600.20.94736840.98684210.11764710.49019612.35294122.4509804135.2941176145.0980392
50.70819260.251.00.98947370.12418300.61437912.48366012.4575163148.3660131145.7516340
60.63643120.30.89473680.97368420.11111110.72549022.22222222.4183007122.2222222141.8300654
70.54786510.350.68421050.93233080.08496730.81045751.69934642.315592969.9346405131.5592904
80.44998270.40.78947370.91447370.09803920.90849671.96078432.271241896.0784314127.1241830
90.39278700.450.21052630.83625730.02614380.93464050.52287582.0769789-47.7124183107.6978940
100.32076570.50.31578950.78421050.03921570.97385620.78431371.9477124-21.568627594.7712418
110.24257440.550.15789470.72727270.01960780.99346410.39215691.8062983-60.784313780.6298277
120.19776160.60.00.66666670.00.99346410.01.6557734-100.065.5773420
130.15869410.650.05263160.61943320.00653591.00.13071901.5384615-86.928104653.8461538
140.13535910.70.00.57518800.01.00.01.4285714-100.042.8571429
150.10941010.750.00.53684210.01.00.01.3333333-100.033.3333333
160.09238280.80.00.50328950.01.00.01.25-100.025.0
170.06659330.850.00.47368420.01.00.01.1764706-100.017.6470588
180.04779680.90.00.44736840.01.00.01.1111111-100.011.1111111
190.02769730.950.00.42382270.01.00.01.0526316-100.05.2631579
200.01255661.00.00.40263160.01.00.01.0-100.00.0
" ], "text/plain": [ " group lower_threshold cumulative_data_fraction response_rate cumulative_response_rate capture_rate cumulative_capture_rate lift cumulative_lift gain cumulative_gain\n", "-- ------- ----------------- -------------------------- --------------- -------------------------- -------------- ------------------------- -------- ----------------- -------- -----------------\n", " 1 0.940575 0.05 1 1 0.124183 0.124183 2.48366 2.48366 148.366 148.366\n", " 2 0.892198 0.1 1 1 0.124183 0.248366 2.48366 2.48366 148.366 148.366\n", " 3 0.82637 0.15 1 1 0.124183 0.372549 2.48366 2.48366 148.366 148.366\n", " 4 0.759546 0.2 0.947368 0.986842 0.117647 0.490196 2.35294 2.45098 135.294 145.098\n", " 5 0.708193 0.25 1 0.989474 0.124183 0.614379 2.48366 2.45752 148.366 145.752\n", " 6 0.636431 0.3 0.894737 0.973684 0.111111 0.72549 2.22222 2.4183 122.222 141.83\n", " 7 0.547865 0.35 0.684211 0.932331 0.0849673 0.810458 1.69935 2.31559 69.9346 131.559\n", " 8 0.449983 0.4 0.789474 0.914474 0.0980392 0.908497 1.96078 2.27124 96.0784 127.124\n", " 9 0.392787 0.45 0.210526 0.836257 0.0261438 0.934641 0.522876 2.07698 -47.7124 107.698\n", " 10 0.320766 0.5 0.315789 0.784211 0.0392157 0.973856 0.784314 1.94771 -21.5686 94.7712\n", " 11 0.242574 0.55 0.157895 0.727273 0.0196078 0.993464 0.392157 1.8063 -60.7843 80.6298\n", " 12 0.197762 0.6 0 0.666667 0 0.993464 0 1.65577 -100 65.5773\n", " 13 0.158694 0.65 0.0526316 0.619433 0.00653595 1 0.130719 1.53846 -86.9281 53.8462\n", " 14 0.135359 0.7 0 0.575188 0 1 0 1.42857 -100 42.8571\n", " 15 0.10941 0.75 0 0.536842 0 1 0 1.33333 -100 33.3333\n", " 16 0.0923828 0.8 0 0.503289 0 1 0 1.25 -100 25\n", " 17 0.0665933 0.85 0 0.473684 0 1 0 1.17647 -100 17.6471\n", " 18 0.0477968 0.9 0 0.447368 0 1 0 1.11111 -100 11.1111\n", " 19 0.0276973 0.95 0 0.423823 0 1 0 1.05263 -100 5.26316\n", " 20 0.0125566 1 0 0.402632 0 1 0 1 -100 0" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": 
[ "\n" ] } ], "source": [ "# Score the model on the same frame it was trained on. The printout says\n", "# \"Reported on test data\", but these are in-sample metrics and will be optimistic.\n", "my_gbm_metrics = my_gbm.model_performance(train)\n", "my_gbm_metrics.show()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.5.0" } }, "nbformat": 4, "nbformat_minor": 0 }