{ "cells": [ { "cell_type": "markdown", "id": "cde08d59", "metadata": {}, "source": [ "# Interpreting numeric split points in POJOs of H2O tree-based models\n", "This notebook explains how to correctly interpret the numeric split points that appear in the POJOs of H2O tree-based models.\n", "\n", "*Motivation*: we have seen users parse H2O POJOs and translate the Java code into another representation (SQL statements, ...). While we do not encourage using POJOs for this purpose, we want to clarify how to interpret the numeric values correctly." ] }, { "cell_type": "markdown", "id": "7a4cb0c2", "metadata": {}, "source": [ "## Floating-point numbers in computers" ] }, { "cell_type": "markdown", "id": "1ee81250", "metadata": {}, "source": [ "Computers, and software such as H2O, represent real numbers in floating-point format: a fixed-length sequence of bits (0/1) stores the number with limited precision. H2O mainly uses 32-bit and 64-bit floating-point representations.\n", "\n", "Let's take a look at one floating-point number, 25.695312, and compare how it behaves in the 32-bit and 64-bit representations." 
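, "\n", "\n", "Before comparing them, it helps to know how finely each format can resolve numbers near 25.695312. The sketch below is an illustration added for this explanation, not part of the H2O workflow. It uses NumPy's `np.spacing`, which returns the distance from a value to the nearest adjacent value of the same floating-point type.\n", "\n", "```\n", "import numpy as np\n", "\n", "# Gap between 25.695312 and the nearest adjacent representable number in each format\n", "print(np.spacing(np.float32(25.695312)))  # about 1.9e-06 for 32-bit floats\n", "print(np.spacing(np.float64(25.695312)))  # about 3.6e-15 for 64-bit floats\n", "```\n", "\n", "A decimal value is rounded to the nearest representable number. Its stored 32-bit value can therefore differ from the written decimal by up to half of that gap (roughly 1e-06), while the 64-bit error is at most about 2e-15."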
] }, { "cell_type": "code", "execution_count": 247, "id": "8a86cfe9", "metadata": {}, "outputs": [], "source": [ "import numpy as np" ] }, { "cell_type": "code", "execution_count": 248, "id": "963a3b21", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "25.695312" ] }, "execution_count": 248, "metadata": {}, "output_type": "execute_result" } ], "source": [ "f32 = np.float32(\"25.695312\")\n", "f32" ] }, { "cell_type": "code", "execution_count": 249, "id": "963c723b", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "25.695312" ] }, "execution_count": 249, "metadata": {}, "output_type": "execute_result" } ], "source": [ "f64 = np.float64(\"25.695312\")\n", "f64" ] }, { "cell_type": "markdown", "id": "41bdf3b2", "metadata": {}, "source": [ "Both values print as `25.695312`, but if we compare them, we see that they are not actually the same number:" ] }, { "cell_type": "code", "execution_count": 250, "id": "87b1356e", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "False" ] }, "execution_count": 250, "metadata": {}, "output_type": "execute_result" } ], "source": [ "f32 == f64" ] }, { "cell_type": "markdown", "id": "63724476", "metadata": {}, "source": [ "When two numbers of different precision are compared, they are first converted to a common precision. Typically, the lower-precision number is converted to the higher-precision representation. Here, `f32` is converted to float64. We can do the same conversion explicitly:" ] }, { "cell_type": "code", "execution_count": 251, "id": "aee0f3f7", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "False" ] }, "execution_count": 251, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.float64(f32) == f64" ] }, { "cell_type": "markdown", "id": "905da309", "metadata": {}, "source": [ "The comparison returns `False` because the converted number is actually different:" ] }, { "cell_type": "code", "execution_count": 252, "id": "95ac967d", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "25.6953125" ] }, "execution_count": 252, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.float64(f32)" ] }, { "cell_type": "markdown", "id": "e3cc6315", "metadata": {}, "source": [ "Notice the extra 7th decimal digit after the conversion: the 32-bit number closest to 25.695312 is 25.6953125, which is slightly larger than the 64-bit value." ] }, { "cell_type": "code", "execution_count": 253, "id": "3fa8c70a", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "4.999999987376214e-07" ] }, "execution_count": 253, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.float64(f32) - f64" ] }, { "cell_type": "code", "execution_count": 254, "id": "fbdd1ed8", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 254, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.float64(f32) > f64" ] }, { "cell_type": "markdown", "id": "004ffe3a", "metadata": {}, "source": [ "## Examining a GBM POJO" ] }, { "cell_type": "markdown", "id": "317a0819", "metadata": {}, "source": [ "Understanding how computers compare numbers of different precision is critical for correctly interpreting split points in tree-based POJOs. Let's now train a simple GBM model." 
] }, { "cell_type": "code", "execution_count": 255, "id": "b99fd407", "metadata": {}, "outputs": [], "source": [ "import h2o\n", "from h2o.estimators.gbm import H2OGradientBoostingEstimator" ] }, { "cell_type": "code", "execution_count": 256, "id": "5648b643", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Checking whether there is an H2O instance running at http://localhost:54321 . connected.\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "-------------------------- ------------------------------------------------------------------\n", "H2O_cluster_uptime: 09 secs\n", "H2O_cluster_timezone: America/New_York\n", "H2O_data_parsing_timezone: UTC\n", "H2O_cluster_version: 3.35.0.99999\n", "H2O_cluster_version_age: 2 hours and 53 minutes\n", "H2O_cluster_name: mkurka\n", "H2O_cluster_total_nodes: 1\n", "H2O_cluster_free_memory: 7.094 Gb\n", "H2O_cluster_total_cores: 16\n", "H2O_cluster_allowed_cores: 16\n", "H2O_cluster_status: locked, healthy\n", "H2O_connection_url: http://localhost:54321\n", "H2O_connection_proxy: {\"http\": null, \"https\": null}\n", "H2O_internal_security: False\n", "H2O_API_Extensions: Amazon S3, XGBoost, Algos, AutoML, Core V3, TargetEncoder, Core V4\n", "Python_version: 3.8.2 final\n", "-------------------------- ------------------------------------------------------------------" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Connect to a pre-existing cluster\n", "h2o.init()" ] }, { "cell_type": "code", "execution_count": 257, "id": "75b09b65", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Parse progress: |████████████████████████████████████████████████████████████████| (done) 100%\n" ] } ], "source": [ "from h2o.utils.shared_utils import _locate # private function. 
used to find files within h2o git project directory.\n", "\n", "df = h2o.upload_file(path=_locate(\"smalldata/logreg/prostate.csv\"))" ] }, { "cell_type": "code", "execution_count": 258, "id": "e1135ab7", "metadata": {}, "outputs": [], "source": [ "# Remove ID from training frame\n", "train = df.drop(\"ID\")" ] }, { "cell_type": "code", "execution_count": 259, "id": "12b4a681", "metadata": {}, "outputs": [], "source": [ "# For VOL & GLEASON, a zero really means \"missing\"\n", "vol = train['VOL']\n", "vol[vol == 0] = None\n", "gle = train['GLEASON']\n", "gle[gle == 0] = None" ] }, { "cell_type": "code", "execution_count": 260, "id": "8a2f7ba7", "metadata": {}, "outputs": [], "source": [ "# Convert CAPSULE to a logical factor\n", "train['CAPSULE'] = train['CAPSULE'].asfactor()" ] }, { "cell_type": "code", "execution_count": 261, "id": "1cd1a4eb", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "gbm Model Build progress: |██████████████████████████████████████████████████████| (done) 100%\n", "Model Details\n", "=============\n", "H2OGradientBoostingEstimator : Gradient Boosting Machine\n", "Model Key: GBM_model_python_1636137917875_1\n", "\n", "\n", "Model Summary: \n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ " number_of_trees number_of_internal_trees model_size_in_bytes \\\n", "0 1.0 1.0 360.0 \n", "\n", " min_depth max_depth mean_depth min_leaves max_leaves mean_leaves \n", "0 5.0 5.0 5.0 24.0 24.0 24.0 " ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "ModelMetricsBinomial: gbm\n", "** Reported on train data. **\n", "\n", "MSE: 0.22019689456071448\n", "RMSE: 0.4692514193486414\n", "LogLoss: 0.6319753099030868\n", "Mean Per-Class Error: 0.20582476749877632\n", "AUC: 0.8816907085888687\n", "AUCPR: 0.8515845076604194\n", "Gini: 0.7633814171777373\n", "\n", "Confusion Matrix (Act/Pred) for max f1 @ threshold = 0.4008312811161997: \n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ " 0 1 Error Rate\n", "0 0 176.0 51.0 0.2247 (51.0/227.0)\n", "1 1 29.0 124.0 0.1895 (29.0/153.0)\n", "2 Total 205.0 175.0 0.2105 (80.0/380.0)" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Maximum Metrics: Maximum metrics at their respective thresholds\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ " metric threshold value idx\n", "0 max f1 0.400831 0.756098 10.0\n", "1 max f2 0.379840 0.831486 16.0\n", "2 max f0point5 0.429293 0.783866 6.0\n", "3 max accuracy 0.429293 0.807895 6.0\n", "4 max precision 0.463528 1.000000 0.0\n", "5 max recall 0.372774 1.000000 18.0\n", "6 max specificity 0.463528 1.000000 0.0\n", "7 max absolute_mcc 0.412406 0.595958 7.0\n", "8 max min_per_class_accuracy 0.404036 0.777778 9.0\n", "9 max mean_per_class_accuracy 0.404036 0.794175 9.0\n", "10 max tns 0.463528 227.000000 0.0\n", "11 max fns 0.463528 121.000000 0.0\n", "12 max fps 0.363105 227.000000 19.0\n", "13 max tps 0.372774 153.000000 18.0\n", "14 max tnr 0.463528 1.000000 0.0\n", "15 max fnr 0.463528 0.790850 0.0\n", "16 max fpr 0.363105 1.000000 19.0\n", "17 max tpr 0.372774 1.000000 18.0" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Gains/Lift Table: Avg response rate: 40.26 %, avg score: 40.30 %\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ " group cumulative_data_fraction lower_threshold lift \\\n", "0 1 0.084211 0.463528 2.483660 \n", "1 2 0.128947 0.457452 2.337562 \n", "2 3 0.157895 0.444791 2.032086 \n", "3 4 0.218421 0.432692 1.835749 \n", "4 5 0.300000 0.429622 1.682479 \n", "5 6 0.426316 0.404036 1.241830 \n", "6 7 0.521053 0.392412 0.827887 \n", "7 8 0.660526 0.383949 0.562338 \n", "8 9 0.763158 0.379840 0.445785 \n", "9 10 0.813158 0.373285 0.261438 \n", "10 11 1.000000 0.363105 0.034981 \n", "\n", " cumulative_lift response_rate score cumulative_response_rate \\\n", "0 2.483660 1.000000 0.463528 1.000000 \n", "1 2.432973 0.941176 0.457452 0.979592 \n", "2 2.359477 0.818182 0.444791 0.950000 \n", "3 2.214348 0.739130 0.436693 0.891566 \n", "4 2.069717 0.677419 0.430389 0.833333 \n", "5 1.824417 0.500000 0.412442 0.734568 \n", "6 1.643230 0.333333 0.395728 0.661616 \n", "7 1.414994 0.226415 0.385145 0.569721 \n", "8 1.284652 0.179487 0.380533 0.517241 \n", "9 1.221736 0.105263 0.373285 0.491909 \n", "10 1.000000 0.014085 0.364467 0.402632 \n", "\n", " cumulative_score capture_rate cumulative_capture_rate gain \\\n", "0 0.463528 0.209150 0.209150 148.366013 \n", "1 0.461420 0.104575 0.313725 133.756248 \n", "2 0.458372 0.058824 0.372549 103.208556 \n", "3 0.452364 0.111111 0.483660 83.574879 \n", "4 0.446389 0.137255 0.620915 68.247944 \n", "5 0.436330 0.156863 0.777778 24.183007 \n", "6 0.428948 0.078431 0.856209 -17.211329 \n", "7 0.419699 0.078431 0.934641 -43.766186 \n", "8 0.414432 0.045752 0.980392 -55.421485 \n", "9 0.411902 0.013072 0.993464 -73.856209 \n", "10 0.403039 0.006536 1.000000 -96.501887 \n", "\n", " cumulative_gain kolmogorov_smirnov \n", "0 148.366013 0.209150 \n", "1 143.297319 0.309320 \n", "2 135.947712 0.359333 \n", "3 121.434759 0.444013 \n", "4 106.971678 0.537215 \n", "5 82.441701 0.588350 \n", "6 64.322968 0.561055 \n", "7 41.499362 0.458870 \n", "8 28.465179 0.363652 \n", "9 22.173573 0.301834 \n", "10 0.000000 0.000000 " ] }, "metadata": {}, 
"output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "Scoring History: \n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ " timestamp duration number_of_trees training_rmse \\\n", "0 2021-11-05 14:45:28 0.022 sec 0.0 0.490428 \n", "1 2021-11-05 14:45:28 0.182 sec 1.0 0.469251 \n", "\n", " training_logloss training_auc training_pr_auc training_lift \\\n", "0 0.674064 0.500000 0.402632 1.00000 \n", "1 0.631975 0.881691 0.851585 2.48366 \n", "\n", " training_classification_error \n", "0 0.597368 \n", "1 0.210526 " ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Variable Importances: \n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ " variable relative_importance scaled_importance percentage\n", "0 GLEASON 20.125320 1.000000 0.496931\n", "1 PSA 8.138151 0.404374 0.200946\n", "2 VOL 6.416112 0.318808 0.158426\n", "3 DPROS 5.819649 0.289170 0.143698\n", "4 AGE 0.000000 0.000000 0.000000\n", "5 RACE 0.000000 0.000000 0.000000\n", "6 DCAPS 0.000000 0.000000 0.000000" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [] }, "execution_count": 261, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Run GBM\n", "my_gbm = H2OGradientBoostingEstimator(ntrees=1, seed=1234)\n", "\n", "my_gbm.train(y=\"CAPSULE\", training_frame=train)" ] }, { "cell_type": "code", "execution_count": 262, "id": "e4828673", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/*\n", " Licensed under the Apache License, Version 2.0\n", " http://www.apache.org/licenses/LICENSE-2.0.html\n", "\n", " AUTOGENERATED BY H2O at 2021-11-05T14:45:28.555-04:00\n", " 3.35.0.99999\n", " \n", " Standalone prediction code with sample test data for GBMModel named GBM_model_python_1636137917875_1\n", "\n", " How to download, compile and execute:\n", " mkdir tmpdir\n", " cd tmpdir\n", " curl http://192.168.86.229:54321/3/h2o-genmodel.jar > h2o-genmodel.jar\n", " curl http://192.168.86.229:54321/3/Models.java/GBM_model_python_1636137917875_1 > GBM_model_python_1636137917875_1.java\n", " javac -cp h2o-genmodel.jar -J-Xmx2g -J-XX:MaxPermSize=128m GBM_model_python_1636137917875_1.java\n", "\n", " (Note: Try java argument -XX:+PrintCompilation to show runtime JIT compiler behavior.)\n", "*/\n", "import java.util.Map;\n", "import hex.genmodel.GenModel;\n", "import hex.genmodel.annotations.ModelPojo;\n", "\n", "@ModelPojo(name=\"GBM_model_python_1636137917875_1\", algorithm=\"gbm\")\n", "public class GBM_model_python_1636137917875_1 extends GenModel {\n", " public hex.ModelCategory getModelCategory() { return hex.ModelCategory.Binomial; }\n", "\n", " 
public boolean isSupervised() { return true; }\n", " public int nfeatures() { return 7; }\n", " public int nclasses() { return 2; }\n", "\n", " // Names of columns used by model.\n", " public static final String[] NAMES = NamesHolder_GBM_model_python_1636137917875_1.VALUES;\n", " // Number of output classes included in training data response column.\n", " public static final int NCLASSES = 2;\n", "\n", " // Column domains. The last array contains domain of response column.\n", " public static final String[][] DOMAINS = new String[][] {\n", " /* AGE */ null,\n", " /* RACE */ null,\n", " /* DPROS */ null,\n", " /* DCAPS */ null,\n", " /* PSA */ null,\n", " /* VOL */ null,\n", " /* GLEASON */ null,\n", " /* CAPSULE */ GBM_model_python_1636137917875_1_ColInfo_7.VALUES\n", " };\n", " // Prior class distribution\n", " public static final double[] PRIOR_CLASS_DISTRIB = {0.5973684210526315,0.4026315789473684};\n", " // Class distribution used for model building\n", " public static final double[] MODEL_CLASS_DISTRIB = {0.5973684210526315,0.4026315789473684};\n", "\n", " public GBM_model_python_1636137917875_1() { super(NAMES,DOMAINS,\"CAPSULE\"); }\n", " public String getUUID() { return Long.toString(4988040225257658559L); }\n", "\n", " // Pass in data in a double[], pre-aligned to the Model's requirements.\n", " // Jam predictions into the preds[] array; preds[0] is reserved for the\n", " // main prediction (class for classifiers or value for regression),\n", " // and remaining columns hold a probability distribution for classifiers.\n", " public final double[] score0( double[] data, double[] preds ) {\n", " java.util.Arrays.fill(preds,0);\n", " GBM_model_python_1636137917875_1_Forest_0.score0(data,preds);\n", " preds[2] = preds[1] + -0.3945120960889672;\n", " preds[2] = 1./(1. 
+ Math.min(1e19, Math.exp(-(preds[2]))));\n", " preds[1] = 1.0-preds[2];\n", " preds[0] = hex.genmodel.GenModel.getPrediction(preds, PRIOR_CLASS_DISTRIB, data, 0.4008312811161997);\n", " return preds;\n", " }\n", "}\n", "// The class representing training column names\n", "class NamesHolder_GBM_model_python_1636137917875_1 implements java.io.Serializable {\n", " public static final String[] VALUES = new String[7];\n", " static {\n", " NamesHolder_GBM_model_python_1636137917875_1_0.fill(VALUES);\n", " }\n", " static final class NamesHolder_GBM_model_python_1636137917875_1_0 implements java.io.Serializable {\n", " static final void fill(String[] sa) {\n", " sa[0] = \"AGE\";\n", " sa[1] = \"RACE\";\n", " sa[2] = \"DPROS\";\n", " sa[3] = \"DCAPS\";\n", " sa[4] = \"PSA\";\n", " sa[5] = \"VOL\";\n", " sa[6] = \"GLEASON\";\n", " }\n", " }\n", "}\n", "// The class representing column CAPSULE\n", "class GBM_model_python_1636137917875_1_ColInfo_7 implements java.io.Serializable {\n", " public static final String[] VALUES = new String[2];\n", " static {\n", " GBM_model_python_1636137917875_1_ColInfo_7_0.fill(VALUES);\n", " }\n", " static final class GBM_model_python_1636137917875_1_ColInfo_7_0 implements java.io.Serializable {\n", " static final void fill(String[] sa) {\n", " sa[0] = \"0\";\n", " sa[1] = \"1\";\n", " }\n", " }\n", "}\n", "\n", "class GBM_model_python_1636137917875_1_Forest_0 {\n", " public static void score0(double[] fdata, double[] preds) {\n", " preds[1] += GBM_model_python_1636137917875_1_Tree_0_class_0.score0(fdata);\n", " }\n", "}\n", "class GBM_model_python_1636137917875_1_Tree_0_class_0 {\n", " static final double score0(double[] data) {\n", " double pred = (Double.isNaN(data[6]) || data[6 /* GLEASON */] < 6.5f ? \n", " (Double.isNaN(data[2]) || data[2 /* DPROS */] < 2.5f ? \n", " (data[6 /* GLEASON */] < 5.5f ? \n", " (Double.isNaN(data[5]) || data[5 /* VOL */] < 19.44375f ? \n", " -0.16740088f : \n", " (data[5 /* VOL */] < 35.319237f ? 
\n", " -0.0842475f : \n", " -0.16740088f)) : \n", " (data[2 /* DPROS */] < 1.5f ? \n", " (Double.isNaN(data[5]) || data[5 /* VOL */] < 25.695312f ? \n", " -0.09571693f : \n", " -0.16740088f) : \n", " (Double.isNaN(data[5]) || data[5 /* VOL */] < 23.359375f ? \n", " -0.07830798f : \n", " 0.005835324f))) : \n", " (Double.isNaN(data[4]) || data[4 /* PSA */] < 6.645508f ? \n", " (data[5 /* VOL */] < 4.4484377f ? \n", " -0.007490538f : \n", " (data[4 /* PSA */] < 3.6390624f ? \n", " -0.16740088f : \n", " -0.1258242f)) : \n", " (data[6 /* GLEASON */] < 5.5f ? \n", " -0.05400991f : \n", " (Double.isNaN(data[5]) || data[5 /* VOL */] < 19.1625f ? \n", " 0.17277204f : \n", " 0.040482566f)))) : \n", " (data[4 /* PSA */] < 14.730078f ? \n", " (Double.isNaN(data[2]) || data[2 /* DPROS */] < 2.5f ? \n", " (Double.isNaN(data[5]) || data[5 /* VOL */] < 7.1585937f ? \n", " (data[4 /* PSA */] < 7.995f ? \n", " 0.10977705f : \n", " -0.039472606f) : \n", " -0.12363595f) : \n", " (Double.isNaN(data[5]) || data[5 /* VOL */] < 17.264843f ? \n", " (Double.isNaN(data[4]) || data[4 /* PSA */] < 8.267187f ? \n", " 0.1524198f : \n", " -0.042670812f) : \n", " 0.22390914f)) : \n", " (Double.isNaN(data[6]) || data[6 /* GLEASON */] < 7.5f ? \n", " (Double.isNaN(data[5]) || data[5 /* VOL */] < 24.657812f ? \n", " (data[4 /* PSA */] < 18.55625f ? \n", " 0.24836601f : \n", " 0.11424766f) : \n", " 0.01078493f) : \n", " (data[4 /* PSA */] < 22.60625f ? \n", " 0.12363594f : \n", " 0.24836601f))));\n", " return pred;\n", " } // constant pool size = 94B, number of visited nodes = 23, static init size = 0B\n", "}\n", "\n", "\n", "\n" ] } ], "source": [ "# Get the POJO\n", "my_gbm.download_pojo()" ] }, { "cell_type": "markdown", "id": "b975f47d", "metadata": {}, "source": [ "Take a close look at the POJO code. You should see statements like this one:\n", "```\n", "Double.isNaN(data[5]) || data[5 /* VOL */] < 25.695312f ? -0.09571693f : -0.16740088f\n", "```\n", "This code represents one split decision in a GBM tree. `data` represents a single input row. The split decision looks at column `VOL` (element 5 of the `data` array) to decide whether the observation goes to the left sub-tree or to the right sub-tree. An observation with a missing `VOL` value (`NaN`) also goes left.\n", "\n", "It is important to notice that `data` is declared as an array of doubles:\n", "```\n", "double[] data\n", "```\n", "This means the input values are 64-bit floating-point numbers.\n", "The split point itself, however, is written in 32-bit precision. In Java, this is indicated by the `f` suffix on the numeric literal, e.g. `25.695312f`. When Java compares a `double` with a `float`, the `float` is first widened to `double`.\n", "\n", "This is the same scenario as outlined at the beginning of this notebook. We are comparing numbers of two different precisions, so we need to pay attention to how the numbers are interpreted." ] }, { "cell_type": "code", "execution_count": 263, "id": "bbd0019f", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "25.695312" ] }, "execution_count": 263, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data = np.array([0, 0, 0, 0, 0, np.float64(25.695312)])\n", "data[5]" ] }, { "cell_type": "markdown", "id": "531f6131", "metadata": {}, "source": [ "Rewritten in Python, the Java comparison looks like this:" ] }, { "cell_type": "code", "execution_count": 264, "id": "876fb69b", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 264, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data[5] < np.float32(25.695312)" ] }, { "cell_type": "markdown", "id": "13857e4a", "metadata": {}, "source": [ "This means that the observation represented by the array `data` goes to the left sub-tree of the current node. If we ignored the fact that the split point uses 32-bit precision and treated it as a 64-bit number, we would misclassify the observation into the right sub-tree." 
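, "\n", "\n", "When translating a POJO into another representation (SQL statements, for example), one way to preserve the model's behavior is to replace each `f`-suffixed split point with the exact 64-bit value of that 32-bit float. Widening a 32-bit float to 64 bits is exact, and Python's `repr` of a float round-trips, so the printed literal denotes exactly the same number. The sketch below is an illustration, not an H2O API. The helper `exact_split_point` and the SQL fragment are hypothetical, and the sketch assumes the target system parses the literal as a 64-bit double.\n", "\n", "```\n", "import numpy as np\n", "\n", "def exact_split_point(literal):\n", "    # Hypothetical helper: parse a POJO split literal such as '25.695312f' as a\n", "    # 32-bit float, then widen it (exactly) to a 64-bit Python float\n", "    return float(np.float32(literal.rstrip('f')))\n", "\n", "threshold = exact_split_point('25.695312f')\n", "print(repr(threshold))  # 25.6953125\n", "\n", "# Same 64-bit input value as data[5] above; it still goes to the left sub-tree\n", "value = np.float64(25.695312)\n", "print(value < threshold)  # True\n", "\n", "# Hypothetical SQL predicate for the same split; NULL plays the role of NaN\n", "print('vol IS NULL OR vol < {!r}'.format(threshold))\n", "```\n", "\n", "The expert option further below has the same effect on the H2O side. It makes the POJO generator write each split point as a double, so that in the second POJO `19.44375f` appears as `19.443750381469727`."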
] }, { "cell_type": "code", "execution_count": 265, "id": "b82a0fd4", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "False" ] }, "execution_count": 265, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data[5] < np.float64(25.695312)" ] }, { "cell_type": "markdown", "id": "abe6f5ad", "metadata": {}, "source": [ "## Expert options" ] }, { "cell_type": "markdown", "id": "1359db5f", "metadata": {}, "source": [ "### Forcing split points in the POJO to be written in 64-bit precision" ] }, { "cell_type": "markdown", "id": "591ca716", "metadata": {}, "source": [ "H2O allows users to modify the POJO output by setting the property `sys.ai.h2o.java.output.doubles`. Setting this property to `true` causes the POJO generator to output split points in 64-bit precision (doubles) instead of the default 32-bit precision.\n", "\n", "The property can be set even on a running H2O instance by invoking a Rapids expression:" ] }, { "cell_type": "code", "execution_count": 266, "id": "ef731eed", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'Old values of sys.ai.h2o.java.output.doubles (per node): null'" ] }, "execution_count": 266, "metadata": {}, "output_type": "execute_result" } ], "source": [ "h2o.rapids(\"(setproperty \\\"{}\\\" \\\"{}\\\")\".format(\"sys.ai.h2o.java.output.doubles\", \"true\"))[\"string\"]" ] }, { "cell_type": "code", "execution_count": 267, "id": "98b0c4bc", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/*\n", " Licensed under the Apache License, Version 2.0\n", " http://www.apache.org/licenses/LICENSE-2.0.html\n", "\n", " AUTOGENERATED BY H2O at 2021-11-05T14:45:28.619-04:00\n", " 3.35.0.99999\n", " \n", " Standalone prediction code with sample test data for GBMModel named GBM_model_python_1636137917875_1\n", "\n", " How to download, compile and execute:\n", " mkdir tmpdir\n", " cd tmpdir\n", " curl http://192.168.86.229:54321/3/h2o-genmodel.jar > h2o-genmodel.jar\n", " curl 
http://192.168.86.229:54321/3/Models.java/GBM_model_python_1636137917875_1 > GBM_model_python_1636137917875_1.java\n", " javac -cp h2o-genmodel.jar -J-Xmx2g -J-XX:MaxPermSize=128m GBM_model_python_1636137917875_1.java\n", "\n", " (Note: Try java argument -XX:+PrintCompilation to show runtime JIT compiler behavior.)\n", "*/\n", "import java.util.Map;\n", "import hex.genmodel.GenModel;\n", "import hex.genmodel.annotations.ModelPojo;\n", "\n", "@ModelPojo(name=\"GBM_model_python_1636137917875_1\", algorithm=\"gbm\")\n", "public class GBM_model_python_1636137917875_1 extends GenModel {\n", " public hex.ModelCategory getModelCategory() { return hex.ModelCategory.Binomial; }\n", "\n", " public boolean isSupervised() { return true; }\n", " public int nfeatures() { return 7; }\n", " public int nclasses() { return 2; }\n", "\n", " // Names of columns used by model.\n", " public static final String[] NAMES = NamesHolder_GBM_model_python_1636137917875_1.VALUES;\n", " // Number of output classes included in training data response column.\n", " public static final int NCLASSES = 2;\n", "\n", " // Column domains. 
The last array contains domain of response column.\n", " public static final String[][] DOMAINS = new String[][] {\n", " /* AGE */ null,\n", " /* RACE */ null,\n", " /* DPROS */ null,\n", " /* DCAPS */ null,\n", " /* PSA */ null,\n", " /* VOL */ null,\n", " /* GLEASON */ null,\n", " /* CAPSULE */ GBM_model_python_1636137917875_1_ColInfo_7.VALUES\n", " };\n", " // Prior class distribution\n", " public static final double[] PRIOR_CLASS_DISTRIB = {0.5973684210526315,0.4026315789473684};\n", " // Class distribution used for model building\n", " public static final double[] MODEL_CLASS_DISTRIB = {0.5973684210526315,0.4026315789473684};\n", "\n", " public GBM_model_python_1636137917875_1() { super(NAMES,DOMAINS,\"CAPSULE\"); }\n", " public String getUUID() { return Long.toString(4988040225257658559L); }\n", "\n", " // Pass in data in a double[], pre-aligned to the Model's requirements.\n", " // Jam predictions into the preds[] array; preds[0] is reserved for the\n", " // main prediction (class for classifiers or value for regression),\n", " // and remaining columns hold a probability distribution for classifiers.\n", " public final double[] score0( double[] data, double[] preds ) {\n", " java.util.Arrays.fill(preds,0);\n", " GBM_model_python_1636137917875_1_Forest_0.score0(data,preds);\n", " preds[2] = preds[1] + -0.3945120960889672;\n", " preds[2] = 1./(1. 
+ Math.min(1e19, Math.exp(-(preds[2]))));\n", " preds[1] = 1.0-preds[2];\n", " preds[0] = hex.genmodel.GenModel.getPrediction(preds, PRIOR_CLASS_DISTRIB, data, 0.4008312811161997);\n", " return preds;\n", " }\n", "}\n", "// The class representing training column names\n", "class NamesHolder_GBM_model_python_1636137917875_1 implements java.io.Serializable {\n", " public static final String[] VALUES = new String[7];\n", " static {\n", " NamesHolder_GBM_model_python_1636137917875_1_0.fill(VALUES);\n", " }\n", " static final class NamesHolder_GBM_model_python_1636137917875_1_0 implements java.io.Serializable {\n", " static final void fill(String[] sa) {\n", " sa[0] = \"AGE\";\n", " sa[1] = \"RACE\";\n", " sa[2] = \"DPROS\";\n", " sa[3] = \"DCAPS\";\n", " sa[4] = \"PSA\";\n", " sa[5] = \"VOL\";\n", " sa[6] = \"GLEASON\";\n", " }\n", " }\n", "}\n", "// The class representing column CAPSULE\n", "class GBM_model_python_1636137917875_1_ColInfo_7 implements java.io.Serializable {\n", " public static final String[] VALUES = new String[2];\n", " static {\n", " GBM_model_python_1636137917875_1_ColInfo_7_0.fill(VALUES);\n", " }\n", " static final class GBM_model_python_1636137917875_1_ColInfo_7_0 implements java.io.Serializable {\n", " static final void fill(String[] sa) {\n", " sa[0] = \"0\";\n", " sa[1] = \"1\";\n", " }\n", " }\n", "}\n", "\n", "class GBM_model_python_1636137917875_1_Forest_0 {\n", " public static void score0(double[] fdata, double[] preds) {\n", " preds[1] += GBM_model_python_1636137917875_1_Tree_0_class_0.score0(fdata);\n", " }\n", "}\n", "class GBM_model_python_1636137917875_1_Tree_0_class_0 {\n", " static final double score0(double[] data) {\n", " double pred = (Double.isNaN(data[6]) || data[6 /* GLEASON */] < 6.5 ? \n", " (Double.isNaN(data[2]) || data[2 /* DPROS */] < 2.5 ? \n", " (data[6 /* GLEASON */] < 5.5 ? \n", " (Double.isNaN(data[5]) || data[5 /* VOL */] < 19.443750381469727 ? 
\n", " -0.16740088164806366 : \n", " (data[5 /* VOL */] < 35.319236755371094 ? \n", " -0.08424749970436096 : \n", " -0.16740088164806366)) : \n", " (data[2 /* DPROS */] < 1.5 ? \n", " (Double.isNaN(data[5]) || data[5 /* VOL */] < 25.6953125 ? \n", " -0.0957169309258461 : \n", " -0.16740088164806366) : \n", " (Double.isNaN(data[5]) || data[5 /* VOL */] < 23.359375 ? \n", " -0.07830797880887985 : \n", " 0.005835324060171843))) : \n", " (Double.isNaN(data[4]) || data[4 /* PSA */] < 6.6455078125 ? \n", " (data[5 /* VOL */] < 4.448437690734863 ? \n", " -0.007490538060665131 : \n", " (data[4 /* PSA */] < 3.6390624046325684 ? \n", " -0.16740088164806366 : \n", " -0.1258241981267929)) : \n", " (data[6 /* GLEASON */] < 5.5 ? \n", " -0.05400991067290306 : \n", " (Double.isNaN(data[5]) || data[5 /* VOL */] < 19.162500381469727 ? \n", " 0.17277203500270844 : \n", " 0.04048256576061249)))) : \n", " (data[4 /* PSA */] < 14.730077743530273 ? \n", " (Double.isNaN(data[2]) || data[2 /* DPROS */] < 2.5 ? \n", " (Double.isNaN(data[5]) || data[5 /* VOL */] < 7.158593654632568 ? \n", " (data[4 /* PSA */] < 7.994999885559082 ? \n", " 0.1097770482301712 : \n", " -0.03947260603308678) : \n", " -0.12363594770431519) : \n", " (Double.isNaN(data[5]) || data[5 /* VOL */] < 17.264842987060547 ? \n", " (Double.isNaN(data[4]) || data[4 /* PSA */] < 8.267187118530273 ? \n", " 0.1524198055267334 : \n", " -0.04267081245779991) : \n", " 0.2239091396331787)) : \n", " (Double.isNaN(data[6]) || data[6 /* GLEASON */] < 7.5 ? \n", " (Double.isNaN(data[5]) || data[5 /* VOL */] < 24.657812118530273 ? \n", " (data[4 /* PSA */] < 18.556249618530273 ? \n", " 0.24836601316928864 : \n", " 0.11424765735864639) : \n", " 0.010784929618239403) : \n", " (data[4 /* PSA */] < 22.606250762939453 ? 
\n", " 0.12363594025373459 : \n", " 0.24836601316928864))));\n", " return pred;\n", " } // constant pool size = 94B, number of visited nodes = 23, static init size = 0B\n", "}\n", "\n", "\n", "\n" ] } ], "source": [ "my_gbm.download_pojo()" ] }, { "cell_type": "markdown", "id": "1339df44", "metadata": {}, "source": [ "In the modified POJO output, you can now see that the original split is coded as\n", "```\n", "Double.isNaN(data[5]) || data[5 /* VOL */] < 25.6953125 ? -0.0957169309258461 : -0.16740088164806366\n", "```\n", "Notice the extra last decimal place (`25.6953125` instead of `25.695312`) and observe that there is no longer an `f` suffix at the end of the number. Compare it with the original version:\n", "```\n", "Double.isNaN(data[5]) || data[5 /* VOL */] < 25.695312f ? -0.09571693f : -0.16740088f\n", "```" ] }, { "cell_type": "markdown", "id": "be18742b", "metadata": {}, "source": [ "The 64-bit precision output may be more intuitive for users who want to understand how the POJO decides which way a given observation traverses the tree: the printed literals can be used as-is in a 64-bit comparison (for example, when translating the POJO into SQL), without the implicit float-to-double conversion that the `f`-suffixed literals require." ] }, { "cell_type": "markdown", "id": "14a15cc0", "metadata": {}, "source": [ "### Converting an existing MOJO into a POJO with 64-bit number representation" ] }, { "cell_type": "markdown", "id": "00bfe93f", "metadata": {}, "source": [ "Suppose we already have a MOJO model created by an older version of H2O, and we want to see what its POJO would look like with numbers represented in 64-bit precision.\n", "\n", "For this use case, H2O provides the conversion tool `MojoConvertTool` as part of `h2o.jar`."
] }, { "cell_type": "code", "execution_count": 268, "id": "9a271b59", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'/Users/mkurka/git/h2o/h2o-3/GBM_model_python_1636137917875_1.zip'" ] }, "execution_count": 268, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mojo_path = my_gbm.download_mojo()\n", "mojo_path" ] }, { "cell_type": "code", "execution_count": 269, "id": "972627a8", "metadata": {}, "outputs": [], "source": [ "# Find h2o.jar (this is using internal functions)\n", "from h2o.backend import H2OLocalServer\n", "h2o_jar = H2OLocalServer()._find_jar()" ] }, { "cell_type": "code", "execution_count": 270, "id": "e8e45507", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "java -cp h2o.jar water.tools.MojoConvertTool source_mojo.zip target_pojo.java\n" ] }, { "data": { "text/plain": [ "1" ] }, "execution_count": 270, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Invoke MojoConvertTool without arguments to print out usage instructions\n", "import subprocess\n", "subprocess.call([\"java\", \"-cp\", h2o_jar, \"water.tools.MojoConvertTool\"], stderr=subprocess.STDOUT, shell=False)" ] }, { "cell_type": "code", "execution_count": 271, "id": "6c905c1c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Starting local H2O instance to facilitate MOJO to POJO conversion.\n", "\n", "14:45:29.416 [main] INFO hex.tree.xgboost.util.NativeLibrary - Loaded library from lib/osx_64/libxgboost4j_minimal.dylib (/var/folders/v1/fkjmcbkd11v2mrm4dm6345ym0000gn/T/libxgboost4j_minimal6279070988842798503.dylib)\n", "11-05 14:45:29.543 127.0.0.1:54321 79164 main INFO water.default: ----- H2O started -----\n", "11-05 14:45:29.544 127.0.0.1:54321 79164 main INFO water.default: Build git branch: master\n", "11-05 14:45:29.544 127.0.0.1:54321 79164 main INFO water.default: Build git hash: b9ba1af5f07c6dbc6369e41113ea43947109e054\n", "11-05 14:45:29.544 
127.0.0.1:54321 79164 main INFO water.default: Build git describe: jenkins-master-5625-7-gb9ba1af5f0\n", "11-05 14:45:29.544 127.0.0.1:54321 79164 main INFO water.default: Build project version: 3.35.0.99999\n", "11-05 14:45:29.544 127.0.0.1:54321 79164 main INFO water.default: Build age: 2 hours and 53 minutes\n", "11-05 14:45:29.545 127.0.0.1:54321 79164 main INFO water.default: Built by: 'mkurka'\n", "11-05 14:45:29.545 127.0.0.1:54321 79164 main INFO water.default: Built on: '2021-11-05 11:51:34'\n", "11-05 14:45:29.545 127.0.0.1:54321 79164 main INFO water.default: Found H2O Core extensions: [XGBoost, KrbStandalone]\n", "11-05 14:45:29.545 127.0.0.1:54321 79164 main INFO water.default: Processed H2O arguments: [-disable_web, -ip, localhost, -disable_net]\n", "11-05 14:45:29.545 127.0.0.1:54321 79164 main INFO water.default: Java availableProcessors: 16\n", "11-05 14:45:29.545 127.0.0.1:54321 79164 main INFO water.default: Java heap totalMemory: 491.0 MB\n", "11-05 14:45:29.546 127.0.0.1:54321 79164 main INFO water.default: Java heap maxMemory: 7.11 GB\n", "11-05 14:45:29.546 127.0.0.1:54321 79164 main INFO water.default: Java version: Java 1.8.0_311 (from Oracle Corporation)\n", "11-05 14:45:29.546 127.0.0.1:54321 79164 main INFO water.default: JVM launch parameters: []\n", "11-05 14:45:29.546 127.0.0.1:54321 79164 main INFO water.default: JVM process id: 79164@michals-mbp.lan\n", "11-05 14:45:29.546 127.0.0.1:54321 79164 main INFO water.default: OS version: Mac OS X 10.16 (x86_64)\n", "11-05 14:45:29.547 127.0.0.1:54321 79164 main INFO water.default: Machine physical memory: 32.00 GB\n", "11-05 14:45:29.548 127.0.0.1:54321 79164 main INFO water.default: Machine locale: en_US\n", "11-05 14:45:29.549 127.0.0.1:54321 79164 main INFO water.default: X-h2o-cluster-id: 1636137928927\n", "11-05 14:45:29.549 127.0.0.1:54321 79164 main INFO water.default: User name: 'mkurka'\n", "11-05 14:45:29.549 127.0.0.1:54321 79164 main INFO water.default: IPv6 stack selected: 
false\n", "11-05 14:45:29.549 127.0.0.1:54321 79164 main INFO water.default: H2O node running in unencrypted mode.\n", "11-05 14:45:30.081 127.0.0.1:54321 79164 main INFO water.default: Kerberos not configured\n", "11-05 14:45:30.081 127.0.0.1:54321 79164 main INFO water.default: Log dir: '/tmp/h2o-mkurka/h2ologs'\n", "11-05 14:45:30.081 127.0.0.1:54321 79164 main INFO water.default: Cur dir: '/Users/mkurka/git/h2o/h2o-3'\n", "11-05 14:45:30.087 127.0.0.1:54321 79164 main INFO water.default: Subsystem for distributed import from HTTP/HTTPS successfully initialized\n", "11-05 14:45:30.088 127.0.0.1:54321 79164 main INFO water.default: HDFS subsystem successfully initialized\n", "11-05 14:45:30.090 127.0.0.1:54321 79164 main INFO water.default: S3 subsystem successfully initialized\n", "11-05 14:45:30.102 127.0.0.1:54321 79164 main INFO water.default: GCS subsystem successfully initialized\n", "11-05 14:45:30.103 127.0.0.1:54321 79164 main INFO water.default: Flow dir: '/Users/mkurka/h2oflows'\n", "11-05 14:45:30.108 127.0.0.1:54321 79164 main INFO water.default: Cloud of size 1 formed [localhost/127.0.0.1:54321]\n", "11-05 14:45:30.116 127.0.0.1:54321 79164 main INFO water.default: Registered parsers: [GUESS, ARFF, XLS, SVMLight, AVRO, PARQUET, CSV]\n", "11-05 14:45:30.117 127.0.0.1:54321 79164 main INFO water.default: XGBoost extension initialized\n", "11-05 14:45:30.117 127.0.0.1:54321 79164 main INFO water.default: KrbStandalone extension initialized\n", "11-05 14:45:30.117 127.0.0.1:54321 79164 main INFO water.default: Registered 2 core extensions in: 377ms\n", "11-05 14:45:30.117 127.0.0.1:54321 79164 main INFO water.default: Registered H2O core extensions: [XGBoost, KrbStandalone]\n", "11-05 14:45:30.305 127.0.0.1:54321 79164 main INFO hex.tree.xgboost.XGBoostExtension: Found XGBoost backend with library: xgboost4j_minimal\n", "11-05 14:45:30.306 127.0.0.1:54321 79164 main WARN hex.tree.xgboost.XGBoostExtension: Your system supports only minimal version of 
XGBoost (no GPUs, no multithreading)!\n", "11-05 14:45:30.406 127.0.0.1:54321 79164 main INFO water.default: Registered: 257 REST APIs in: 289ms\n", "11-05 14:45:30.406 127.0.0.1:54321 79164 main INFO water.default: Registered REST API extensions: [Amazon S3, XGBoost, Algos, AutoML, Core V3, TargetEncoder, Core V4]\n", "11-05 14:45:30.506 127.0.0.1:54321 79164 main INFO water.default: Registered: 311 schemas in 100ms\n", "11-05 14:45:30.506 127.0.0.1:54321 79164 main INFO water.default: Locking cloud to new members, because H2O is started in a single node configuration.\n", "\n", "Converting /Users/mkurka/git/h2o/h2o-3/GBM_model_python_1636137917875_1.zip to pojo.java...\n", "11-05 14:45:30.642 127.0.0.1:54321 79164 FJ-1-15 INFO water.default: Starting model Generic_model_1636137928927_1\n", "11-05 14:45:30.747 127.0.0.1:54321 79164 FJ-1-15 INFO water.default: Completing model Generic_model_1636137928927_1\n", "DONE\n" ] }, { "data": { "text/plain": [ "0" ] }, "execution_count": 271, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Add path to MOJO file and write output to \"pojo.java\"\n", "subprocess.call([\"java\", \"-cp\", h2o_jar, \"water.tools.MojoConvertTool\", mojo_path, \"pojo.java\"], stderr=subprocess.STDOUT, shell=False)" ] }, { "cell_type": "code", "execution_count": 272, "id": "3b029b27", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/*\n", " Licensed under the Apache License, Version 2.0\n", " http://www.apache.org/licenses/LICENSE-2.0.html\n", "\n", " AUTOGENERATED BY H2O at 2021-11-05T14:45:30.759-04:00\n", " 3.35.0.99999\n", " \n", " Standalone prediction code with sample test data for GBMModel named Generic_model_1636137928927_1\n", "\n", " How to download, compile and execute:\n", " mkdir tmpdir\n", " cd tmpdir\n", " curl http:/localhost/127.0.0.1:54321/3/h2o-genmodel.jar > h2o-genmodel.jar\n", " curl http:/localhost/127.0.0.1:54321/3/Models.java/Generic_model_1636137928927_1 > 
Generic_model_1636137928927_1.java\n", " javac -cp h2o-genmodel.jar -J-Xmx2g -J-XX:MaxPermSize=128m Generic_model_1636137928927_1.java\n", "\n", " (Note: Try java argument -XX:+PrintCompilation to show runtime JIT compiler behavior.)\n", "*/\n", "import java.util.Map;\n", "import hex.genmodel.GenModel;\n", "import hex.genmodel.annotations.ModelPojo;\n", "\n", "@ModelPojo(name=\"Generic_model_1636137928927_1\", algorithm=\"gbm\")\n", "public class Generic_model_1636137928927_1 extends GenModel {\n", " public hex.ModelCategory getModelCategory() { return hex.ModelCategory.Binomial; }\n", "\n", " public boolean isSupervised() { return true; }\n", " public int nfeatures() { return 7; }\n", " public int nclasses() { return 2; }\n", "\n", " // Names of columns used by model.\n", " public static final String[] NAMES = NamesHolder_Generic_model_1636137928927_1.VALUES;\n", " // Number of output classes included in training data response column.\n", " public static final int NCLASSES = 2;\n", "\n", " // Column domains. 
The last array contains domain of response column.\n", " public static final String[][] DOMAINS = new String[][] {\n", " /* AGE */ null,\n", " /* RACE */ null,\n", " /* DPROS */ null,\n", " /* DCAPS */ null,\n", " /* PSA */ null,\n", " /* VOL */ null,\n", " /* GLEASON */ null,\n", " /* CAPSULE */ Generic_model_1636137928927_1_ColInfo_7.VALUES\n", " };\n", " // Prior class distribution\n", " public static final double[] PRIOR_CLASS_DISTRIB = {0.5973684210526315,0.4026315789473684};\n", " // Class distribution used for model building\n", " public static final double[] MODEL_CLASS_DISTRIB = {0.5973684210526315,0.4026315789473684};\n", "\n", " public Generic_model_1636137928927_1() { super(NAMES,DOMAINS,\"CAPSULE\"); }\n", " public String getUUID() { return Long.toString(4988040225257658559L); }\n", "\n", " // Pass in data in a double[], pre-aligned to the Model's requirements.\n", " // Jam predictions into the preds[] array; preds[0] is reserved for the\n", " // main prediction (class for classifiers or value for regression),\n", " // and remaining columns hold a probability distribution for classifiers.\n", " public final double[] score0( double[] data, double[] preds ) {\n", " java.util.Arrays.fill(preds,0);\n", " Generic_model_1636137928927_1_Forest_0.score0(data,preds);\n", " preds[2] = preds[1] + -0.3945120960889672;\n", " preds[2] = 1./(1. 
+ Math.min(1e19, Math.exp(-(preds[2]))));\n", " preds[1] = 1.0-preds[2];\n", " preds[0] = hex.genmodel.GenModel.getPrediction(preds, PRIOR_CLASS_DISTRIB, data, 0.4008312811161997);\n", " return preds;\n", " }\n", "}\n", "// The class representing training column names\n", "class NamesHolder_Generic_model_1636137928927_1 implements java.io.Serializable {\n", " public static final String[] VALUES = new String[7];\n", " static {\n", " NamesHolder_Generic_model_1636137928927_1_0.fill(VALUES);\n", " }\n", " static final class NamesHolder_Generic_model_1636137928927_1_0 implements java.io.Serializable {\n", " static final void fill(String[] sa) {\n", " sa[0] = \"AGE\";\n", " sa[1] = \"RACE\";\n", " sa[2] = \"DPROS\";\n", " sa[3] = \"DCAPS\";\n", " sa[4] = \"PSA\";\n", " sa[5] = \"VOL\";\n", " sa[6] = \"GLEASON\";\n", " }\n", " }\n", "}\n", "// The class representing column CAPSULE\n", "class Generic_model_1636137928927_1_ColInfo_7 implements java.io.Serializable {\n", " public static final String[] VALUES = new String[2];\n", " static {\n", " Generic_model_1636137928927_1_ColInfo_7_0.fill(VALUES);\n", " }\n", " static final class Generic_model_1636137928927_1_ColInfo_7_0 implements java.io.Serializable {\n", " static final void fill(String[] sa) {\n", " sa[0] = \"0\";\n", " sa[1] = \"1\";\n", " }\n", " }\n", "}\n", "\n", "class Generic_model_1636137928927_1_Forest_0 {\n", " public static void score0(double[] fdata, double[] preds) {\n", " preds[1] += Generic_model_1636137928927_1_Tree_0_class_0.score0(fdata);\n", " }\n", "}\n", "class Generic_model_1636137928927_1_Tree_0_class_0 {\n", " static final double score0(double[] data) {\n", " double pred = (Double.isNaN(data[6]) || data[6 /* GLEASON */] < 6.5f ? \n", " (Double.isNaN(data[2]) || data[2 /* DPROS */] < 2.5f ? \n", " (data[6 /* GLEASON */] < 5.5f ? \n", " (Double.isNaN(data[5]) || data[5 /* VOL */] < 19.44375f ? \n", " -0.16740088f : \n", " (data[5 /* VOL */] < 35.319237f ? 
\n", " -0.0842475f : \n", " -0.16740088f)) : \n", " (data[2 /* DPROS */] < 1.5f ? \n", " (Double.isNaN(data[5]) || data[5 /* VOL */] < 25.695312f ? \n", " -0.09571693f : \n", " -0.16740088f) : \n", " (Double.isNaN(data[5]) || data[5 /* VOL */] < 23.359375f ? \n", " -0.07830798f : \n", " 0.005835324f))) : \n", " (Double.isNaN(data[4]) || data[4 /* PSA */] < 6.645508f ? \n", " (data[5 /* VOL */] < 4.4484377f ? \n", " -0.007490538f : \n", " (data[4 /* PSA */] < 3.6390624f ? \n", " -0.16740088f : \n", " -0.1258242f)) : \n", " (data[6 /* GLEASON */] < 5.5f ? \n", " -0.05400991f : \n", " (Double.isNaN(data[5]) || data[5 /* VOL */] < 19.1625f ? \n", " 0.17277204f : \n", " 0.040482566f)))) : \n", " (data[4 /* PSA */] < 14.730078f ? \n", " (Double.isNaN(data[2]) || data[2 /* DPROS */] < 2.5f ? \n", " (Double.isNaN(data[5]) || data[5 /* VOL */] < 7.1585937f ? \n", " (data[4 /* PSA */] < 7.995f ? \n", " 0.10977705f : \n", " -0.039472606f) : \n", " -0.12363595f) : \n", " (Double.isNaN(data[5]) || data[5 /* VOL */] < 17.264843f ? \n", " (Double.isNaN(data[4]) || data[4 /* PSA */] < 8.267187f ? \n", " 0.1524198f : \n", " -0.042670812f) : \n", " 0.22390914f)) : \n", " (Double.isNaN(data[6]) || data[6 /* GLEASON */] < 7.5f ? \n", " (Double.isNaN(data[5]) || data[5 /* VOL */] < 24.657812f ? \n", " (data[4 /* PSA */] < 18.55625f ? \n", " 0.24836601f : \n", " 0.11424766f) : \n", " 0.01078493f) : \n", " (data[4 /* PSA */] < 22.60625f ? 
\n", " 0.12363594f : \n", " 0.24836601f))));\n", " return pred;\n", " } // constant pool size = 94B, number of visited nodes = 23, static init size = 0B\n", "}\n", "\n", "\n", "\n" ] } ], "source": [ "# Display the content of the POJO\n", "with open('pojo.java', 'r') as f:\n", " print(f.read())" ] }, { "cell_type": "code", "execution_count": 273, "id": "87f33a45", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Starting local H2O instance to facilitate MOJO to POJO conversion.\n", "\n", "14:45:31.502 [main] INFO hex.tree.xgboost.util.NativeLibrary - Loaded library from lib/osx_64/libxgboost4j_minimal.dylib (/var/folders/v1/fkjmcbkd11v2mrm4dm6345ym0000gn/T/libxgboost4j_minimal978915340387551523.dylib)\n", "11-05 14:45:31.628 127.0.0.1:54321 79166 main INFO water.default: ----- H2O started -----\n", "11-05 14:45:31.628 127.0.0.1:54321 79166 main INFO water.default: Build git branch: master\n", "11-05 14:45:31.628 127.0.0.1:54321 79166 main INFO water.default: Build git hash: b9ba1af5f07c6dbc6369e41113ea43947109e054\n", "11-05 14:45:31.628 127.0.0.1:54321 79166 main INFO water.default: Build git describe: jenkins-master-5625-7-gb9ba1af5f0\n", "11-05 14:45:31.629 127.0.0.1:54321 79166 main INFO water.default: Build project version: 3.35.0.99999\n", "11-05 14:45:31.629 127.0.0.1:54321 79166 main INFO water.default: Build age: 2 hours and 53 minutes\n", "11-05 14:45:31.629 127.0.0.1:54321 79166 main INFO water.default: Built by: 'mkurka'\n", "11-05 14:45:31.629 127.0.0.1:54321 79166 main INFO water.default: Built on: '2021-11-05 11:51:34'\n", "11-05 14:45:31.629 127.0.0.1:54321 79166 main INFO water.default: Found H2O Core extensions: [XGBoost, KrbStandalone]\n", "11-05 14:45:31.629 127.0.0.1:54321 79166 main INFO water.default: Processed H2O arguments: [-disable_web, -ip, localhost, -disable_net]\n", "11-05 14:45:31.630 127.0.0.1:54321 79166 main INFO water.default: Java availableProcessors: 16\n", "11-05 14:45:31.630 
127.0.0.1:54321 79166 main INFO water.default: Java heap totalMemory: 491.0 MB\n", "11-05 14:45:31.630 127.0.0.1:54321 79166 main INFO water.default: Java heap maxMemory: 7.11 GB\n", "11-05 14:45:31.630 127.0.0.1:54321 79166 main INFO water.default: Java version: Java 1.8.0_311 (from Oracle Corporation)\n", "11-05 14:45:31.630 127.0.0.1:54321 79166 main INFO water.default: JVM launch parameters: [-Dsys.ai.h2o.java.output.doubles=true]\n", "11-05 14:45:31.631 127.0.0.1:54321 79166 main INFO water.default: JVM process id: 79166@michals-mbp.lan\n", "11-05 14:45:31.631 127.0.0.1:54321 79166 main INFO water.default: OS version: Mac OS X 10.16 (x86_64)\n", "11-05 14:45:31.631 127.0.0.1:54321 79166 main INFO water.default: Machine physical memory: 32.00 GB\n", "11-05 14:45:31.631 127.0.0.1:54321 79166 main INFO water.default: Machine locale: en_US\n", "11-05 14:45:31.633 127.0.0.1:54321 79166 main INFO water.default: X-h2o-cluster-id: 1636137931013\n", "11-05 14:45:31.633 127.0.0.1:54321 79166 main INFO water.default: User name: 'mkurka'\n", "11-05 14:45:31.633 127.0.0.1:54321 79166 main INFO water.default: IPv6 stack selected: false\n", "11-05 14:45:31.633 127.0.0.1:54321 79166 main INFO water.default: H2O node running in unencrypted mode.\n", "11-05 14:45:32.130 127.0.0.1:54321 79166 main INFO water.default: Kerberos not configured\n", "11-05 14:45:32.130 127.0.0.1:54321 79166 main INFO water.default: Log dir: '/tmp/h2o-mkurka/h2ologs'\n", "11-05 14:45:32.130 127.0.0.1:54321 79166 main INFO water.default: Cur dir: '/Users/mkurka/git/h2o/h2o-3'\n", "11-05 14:45:32.136 127.0.0.1:54321 79166 main INFO water.default: Subsystem for distributed import from HTTP/HTTPS successfully initialized\n", "11-05 14:45:32.137 127.0.0.1:54321 79166 main INFO water.default: HDFS subsystem successfully initialized\n", "11-05 14:45:32.140 127.0.0.1:54321 79166 main INFO water.default: S3 subsystem successfully initialized\n", "11-05 14:45:32.151 127.0.0.1:54321 79166 main INFO 
water.default: GCS subsystem successfully initialized\n", "11-05 14:45:32.152 127.0.0.1:54321 79166 main INFO water.default: Flow dir: '/Users/mkurka/h2oflows'\n", "11-05 14:45:32.158 127.0.0.1:54321 79166 main INFO water.default: Cloud of size 1 formed [localhost/127.0.0.1:54321]\n", "11-05 14:45:32.164 127.0.0.1:54321 79166 main INFO water.default: Registered parsers: [GUESS, ARFF, XLS, SVMLight, AVRO, PARQUET, CSV]\n", "11-05 14:45:32.165 127.0.0.1:54321 79166 main INFO water.default: XGBoost extension initialized\n", "11-05 14:45:32.166 127.0.0.1:54321 79166 main INFO water.default: KrbStandalone extension initialized\n", "11-05 14:45:32.166 127.0.0.1:54321 79166 main INFO water.default: Registered 2 core extensions in: 376ms\n", "11-05 14:45:32.166 127.0.0.1:54321 79166 main INFO water.default: Registered H2O core extensions: [XGBoost, KrbStandalone]\n", "11-05 14:45:32.353 127.0.0.1:54321 79166 main INFO hex.tree.xgboost.XGBoostExtension: Found XGBoost backend with library: xgboost4j_minimal\n", "11-05 14:45:32.353 127.0.0.1:54321 79166 main WARN hex.tree.xgboost.XGBoostExtension: Your system supports only minimal version of XGBoost (no GPUs, no multithreading)!\n", "11-05 14:45:32.467 127.0.0.1:54321 79166 main INFO water.default: Registered: 257 REST APIs in: 301ms\n", "11-05 14:45:32.467 127.0.0.1:54321 79166 main INFO water.default: Registered REST API extensions: [Amazon S3, XGBoost, Algos, AutoML, Core V3, TargetEncoder, Core V4]\n", "11-05 14:45:32.564 127.0.0.1:54321 79166 main INFO water.default: Registered: 311 schemas in 96ms\n", "11-05 14:45:32.565 127.0.0.1:54321 79166 main INFO water.default: Locking cloud to new members, because H2O is started in a single node configuration.\n", "\n", "Converting /Users/mkurka/git/h2o/h2o-3/GBM_model_python_1636137917875_1.zip to pojo64.java...\n", "11-05 14:45:32.700 127.0.0.1:54321 79166 FJ-1-15 INFO water.default: Starting model Generic_model_1636137931013_1\n", "11-05 14:45:32.804 127.0.0.1:54321 79166 
FJ-1-15 INFO water.default: Completing model Generic_model_1636137931013_1\n", "DONE\n" ] }, { "data": { "text/plain": [ "0" ] }, "execution_count": 273, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Now specify system property sys.ai.h2o.java.output.doubles to output numbers in 64-bit precision\n", "subprocess.call([\"java\", \"-Dsys.ai.h2o.java.output.doubles=true\", \"-cp\", h2o_jar, \"water.tools.MojoConvertTool\", mojo_path, \"pojo64.java\"], stderr=subprocess.STDOUT, shell=False)" ] }, { "cell_type": "code", "execution_count": 274, "id": "00759e2e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/*\n", " Licensed under the Apache License, Version 2.0\n", " http://www.apache.org/licenses/LICENSE-2.0.html\n", "\n", " AUTOGENERATED BY H2O at 2021-11-05T14:45:32.815-04:00\n", " 3.35.0.99999\n", " \n", " Standalone prediction code with sample test data for GBMModel named Generic_model_1636137931013_1\n", "\n", " How to download, compile and execute:\n", " mkdir tmpdir\n", " cd tmpdir\n", " curl http:/localhost/127.0.0.1:54321/3/h2o-genmodel.jar > h2o-genmodel.jar\n", " curl http:/localhost/127.0.0.1:54321/3/Models.java/Generic_model_1636137931013_1 > Generic_model_1636137931013_1.java\n", " javac -cp h2o-genmodel.jar -J-Xmx2g -J-XX:MaxPermSize=128m Generic_model_1636137931013_1.java\n", "\n", " (Note: Try java argument -XX:+PrintCompilation to show runtime JIT compiler behavior.)\n", "*/\n", "import java.util.Map;\n", "import hex.genmodel.GenModel;\n", "import hex.genmodel.annotations.ModelPojo;\n", "\n", "@ModelPojo(name=\"Generic_model_1636137931013_1\", algorithm=\"gbm\")\n", "public class Generic_model_1636137931013_1 extends GenModel {\n", " public hex.ModelCategory getModelCategory() { return hex.ModelCategory.Binomial; }\n", "\n", " public boolean isSupervised() { return true; }\n", " public int nfeatures() { return 7; }\n", " public int nclasses() { return 2; }\n", "\n", " // Names of columns used 
by model.\n", " public static final String[] NAMES = NamesHolder_Generic_model_1636137931013_1.VALUES;\n", " // Number of output classes included in training data response column.\n", " public static final int NCLASSES = 2;\n", "\n", " // Column domains. The last array contains domain of response column.\n", " public static final String[][] DOMAINS = new String[][] {\n", " /* AGE */ null,\n", " /* RACE */ null,\n", " /* DPROS */ null,\n", " /* DCAPS */ null,\n", " /* PSA */ null,\n", " /* VOL */ null,\n", " /* GLEASON */ null,\n", " /* CAPSULE */ Generic_model_1636137931013_1_ColInfo_7.VALUES\n", " };\n", " // Prior class distribution\n", " public static final double[] PRIOR_CLASS_DISTRIB = {0.5973684210526315,0.4026315789473684};\n", " // Class distribution used for model building\n", " public static final double[] MODEL_CLASS_DISTRIB = {0.5973684210526315,0.4026315789473684};\n", "\n", " public Generic_model_1636137931013_1() { super(NAMES,DOMAINS,\"CAPSULE\"); }\n", " public String getUUID() { return Long.toString(4988040225257658559L); }\n", "\n", " // Pass in data in a double[], pre-aligned to the Model's requirements.\n", " // Jam predictions into the preds[] array; preds[0] is reserved for the\n", " // main prediction (class for classifiers or value for regression),\n", " // and remaining columns hold a probability distribution for classifiers.\n", " public final double[] score0( double[] data, double[] preds ) {\n", " java.util.Arrays.fill(preds,0);\n", " Generic_model_1636137931013_1_Forest_0.score0(data,preds);\n", " preds[2] = preds[1] + -0.3945120960889672;\n", " preds[2] = 1./(1. 
+ Math.min(1e19, Math.exp(-(preds[2]))));\n", " preds[1] = 1.0-preds[2];\n", " preds[0] = hex.genmodel.GenModel.getPrediction(preds, PRIOR_CLASS_DISTRIB, data, 0.4008312811161997);\n", " return preds;\n", " }\n", "}\n", "// The class representing training column names\n", "class NamesHolder_Generic_model_1636137931013_1 implements java.io.Serializable {\n", " public static final String[] VALUES = new String[7];\n", " static {\n", " NamesHolder_Generic_model_1636137931013_1_0.fill(VALUES);\n", " }\n", " static final class NamesHolder_Generic_model_1636137931013_1_0 implements java.io.Serializable {\n", " static final void fill(String[] sa) {\n", " sa[0] = \"AGE\";\n", " sa[1] = \"RACE\";\n", " sa[2] = \"DPROS\";\n", " sa[3] = \"DCAPS\";\n", " sa[4] = \"PSA\";\n", " sa[5] = \"VOL\";\n", " sa[6] = \"GLEASON\";\n", " }\n", " }\n", "}\n", "// The class representing column CAPSULE\n", "class Generic_model_1636137931013_1_ColInfo_7 implements java.io.Serializable {\n", " public static final String[] VALUES = new String[2];\n", " static {\n", " Generic_model_1636137931013_1_ColInfo_7_0.fill(VALUES);\n", " }\n", " static final class Generic_model_1636137931013_1_ColInfo_7_0 implements java.io.Serializable {\n", " static final void fill(String[] sa) {\n", " sa[0] = \"0\";\n", " sa[1] = \"1\";\n", " }\n", " }\n", "}\n", "\n", "class Generic_model_1636137931013_1_Forest_0 {\n", " public static void score0(double[] fdata, double[] preds) {\n", " preds[1] += Generic_model_1636137931013_1_Tree_0_class_0.score0(fdata);\n", " }\n", "}\n", "class Generic_model_1636137931013_1_Tree_0_class_0 {\n", " static final double score0(double[] data) {\n", " double pred = (Double.isNaN(data[6]) || data[6 /* GLEASON */] < 6.5 ? \n", " (Double.isNaN(data[2]) || data[2 /* DPROS */] < 2.5 ? \n", " (data[6 /* GLEASON */] < 5.5 ? \n", " (Double.isNaN(data[5]) || data[5 /* VOL */] < 19.443750381469727 ? \n", " -0.16740088164806366 : \n", " (data[5 /* VOL */] < 35.319236755371094 ? 
\n", " -0.08424749970436096 : \n", " -0.16740088164806366)) : \n", " (data[2 /* DPROS */] < 1.5 ? \n", " (Double.isNaN(data[5]) || data[5 /* VOL */] < 25.6953125 ? \n", " -0.0957169309258461 : \n", " -0.16740088164806366) : \n", " (Double.isNaN(data[5]) || data[5 /* VOL */] < 23.359375 ? \n", " -0.07830797880887985 : \n", " 0.005835324060171843))) : \n", " (Double.isNaN(data[4]) || data[4 /* PSA */] < 6.6455078125 ? \n", " (data[5 /* VOL */] < 4.448437690734863 ? \n", " -0.007490538060665131 : \n", " (data[4 /* PSA */] < 3.6390624046325684 ? \n", " -0.16740088164806366 : \n", " -0.1258241981267929)) : \n", " (data[6 /* GLEASON */] < 5.5 ? \n", " -0.05400991067290306 : \n", " (Double.isNaN(data[5]) || data[5 /* VOL */] < 19.162500381469727 ? \n", " 0.17277203500270844 : \n", " 0.04048256576061249)))) : \n", " (data[4 /* PSA */] < 14.730077743530273 ? \n", " (Double.isNaN(data[2]) || data[2 /* DPROS */] < 2.5 ? \n", " (Double.isNaN(data[5]) || data[5 /* VOL */] < 7.158593654632568 ? \n", " (data[4 /* PSA */] < 7.994999885559082 ? \n", " 0.1097770482301712 : \n", " -0.03947260603308678) : \n", " -0.12363594770431519) : \n", " (Double.isNaN(data[5]) || data[5 /* VOL */] < 17.264842987060547 ? \n", " (Double.isNaN(data[4]) || data[4 /* PSA */] < 8.267187118530273 ? \n", " 0.1524198055267334 : \n", " -0.04267081245779991) : \n", " 0.2239091396331787)) : \n", " (Double.isNaN(data[6]) || data[6 /* GLEASON */] < 7.5 ? \n", " (Double.isNaN(data[5]) || data[5 /* VOL */] < 24.657812118530273 ? \n", " (data[4 /* PSA */] < 18.556249618530273 ? \n", " 0.24836601316928864 : \n", " 0.11424765735864639) : \n", " 0.010784929618239403) : \n", " (data[4 /* PSA */] < 22.606250762939453 ? 
\n", " 0.12363594025373459 : \n", " 0.24836601316928864))));\n", " return pred;\n", " } // constant pool size = 94B, number of visited nodes = 23, static init size = 0B\n", "}\n", "\n", "\n", "\n" ] } ], "source": [ "# Display the content of the POJO with 64-bit number representation\n", "with open('pojo64.java', 'r') as f:\n", " print(f.read())" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.2" } }, "nbformat": 4, "nbformat_minor": 5 }