{ "cells": [ { "cell_type": "code", "execution_count": null, "id": "97873840", "metadata": {}, "outputs": [], "source": [ "import sys\n", "import warnings\n", "# This notebook won't work on Python 2.x\n", "if sys.version_info < (3, 0):\n", " warnings.warn(\"Notebook not executed - please use Python 3.x to run\")\n", " exit(0)" ] }, { "cell_type": "markdown", "id": "d76219dc", "metadata": {}, "source": [ "## Using Isotonic Regression to calibrate a classification model" ] }, { "cell_type": "markdown", "id": "1b044b23", "metadata": {}, "source": [ "In many classification use cases we are interested not only in predicting class labels but also in outputting probabilities that can be interpreted as confidence levels. In this notebook we will demonstrate how Isotonic Regression can be used to calibrate a GBM classifier." ] }, { "cell_type": "markdown", "id": "b220b04d", "metadata": {}, "source": [ "We will show what the calibration looks like in scikit-learn using `CalibratedClassifierCV` and how the same can be accomplished in H2O." ] }, { "cell_type": "markdown", "id": "1fdd4bde", "metadata": {}, "source": [ "Please refer to https://scikit-learn.org/stable/modules/calibration.html for the theoretical background of calibrating probabilities." ] }, { "cell_type": "code", "execution_count": 1, "id": "fc562a4a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "versionFromGradle='3.37.0',projectVersion='3.37.0.99999',branch='master',lastCommitHash='a1c95a407aec53a6cbc551484bd02d7d80b3bcb6',gitDescribe='jenkins-master-5950-dirty',compiledOn='2022-09-13 10:48:53',compiledBy='kurkami'\n" ] } ], "source": [ "import h2o" ] }, { "cell_type": "code", "execution_count": 2, "id": "d2895e39", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Checking whether there is an H2O instance running at http://localhost:54321 ..... 
not found.\n", "Attempting to start a local H2O server...\n", " Java Version: openjdk version \"1.8.0_342\"; OpenJDK Runtime Environment (build 1.8.0_342-8u342-b07-0ubuntu1~22.04-b07); OpenJDK 64-Bit Server VM (build 25.342-b07, mixed mode)\n", " Starting server from /home/kurkami/git/h2o/h2o-3/build/h2o.jar\n", " Ice root: /tmp/tmp_k1gozye\n", " JVM stdout: /tmp/tmp_k1gozye/h2o_kurkami_started_from_python.out\n", " JVM stderr: /tmp/tmp_k1gozye/h2o_kurkami_started_from_python.err\n", " Server is running at http://127.0.0.1:54321\n", "Connecting to H2O server at http://127.0.0.1:54321 ... successful.\n" ] }, { "data": { "text/html": [ "\n", " \n", "
\n", " \n", " \n", " \n", " \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
H2O_cluster_uptime: 01 secs
H2O_cluster_timezone: America/New_York
H2O_data_parsing_timezone: UTC
H2O_cluster_version: 3.37.0.99999
H2O_cluster_version_age: 1 hour and 4 minutes
H2O_cluster_name: H2O_from_python_kurkami_9vbn5o
H2O_cluster_total_nodes: 1
H2O_cluster_free_memory: 3.409 Gb
H2O_cluster_total_cores: 12
H2O_cluster_allowed_cores: 12
H2O_cluster_status: locked, healthy
H2O_connection_url: http://127.0.0.1:54321
H2O_connection_proxy: {\"http\": null, \"https\": null}
H2O_internal_security: False
Python_version: 3.10.4 final
\n", "
\n" ], "text/plain": [ "-------------------------- ------------------------------\n", "H2O_cluster_uptime: 01 secs\n", "H2O_cluster_timezone: America/New_York\n", "H2O_data_parsing_timezone: UTC\n", "H2O_cluster_version: 3.37.0.99999\n", "H2O_cluster_version_age: 1 hour and 4 minutes\n", "H2O_cluster_name: H2O_from_python_kurkami_9vbn5o\n", "H2O_cluster_total_nodes: 1\n", "H2O_cluster_free_memory: 3.409 Gb\n", "H2O_cluster_total_cores: 12\n", "H2O_cluster_allowed_cores: 12\n", "H2O_cluster_status: locked, healthy\n", "H2O_connection_url: http://127.0.0.1:54321\n", "H2O_connection_proxy: {\"http\": null, \"https\": null}\n", "H2O_internal_security: False\n", "Python_version: 3.10.4 final\n", "-------------------------- ------------------------------" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "h2o.init(strict_version_check=False)" ] }, { "cell_type": "markdown", "id": "7b1a88a9", "metadata": {}, "source": [ "#### Create synthetic data" ] }, { "cell_type": "code", "execution_count": 3, "id": "54a12b7a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Parse progress: |████████████████████████████████████████████████████████████████| (done) 100%\n", "Parse progress: |████████████████████████████████████████████████████████████████| (done) 100%\n" ] }, { "data": { "text/html": [ "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
x1         x2         y
0.78399    0.399454   0
0.401748   -0.23744   0
-1.72528   -1.79556   0
1.34722    1.05784    1
-3.55901   -3.23764   0
0.575518   0.424405   1
-0.580976  0.639303   0
1.30574    -1.27541   1
-0.770629  -1.00661   0
-2.65608   -2.49828   0
[5000 rows x 3 columns]
" ], "text/plain": [ " x1 x2 y\n", "--------- --------- ---\n", " 0.78399 0.399454 0\n", " 0.401748 -0.23744 0\n", "-1.72528 -1.79556 0\n", " 1.34722 1.05784 1\n", "-3.55901 -3.23764 0\n", " 0.575518 0.424405 1\n", "-0.580976 0.639303 0\n", " 1.30574 -1.27541 1\n", "-0.770629 -1.00661 0\n", "-2.65608 -2.49828 0\n", "[5000 rows x 3 columns]\n" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.datasets import make_classification\n", "X, y = make_classification(n_samples=5000, n_features=2, n_redundant=0, random_state=42)\n", "X_df = h2o.H2OFrame(X, column_names=[\"x1\", \"x2\"])\n", "y_df = h2o.H2OFrame(y, column_names=[\"y\"]).asfactor()\n", "df = X_df.cbind(y_df)\n", "df" ] }, { "cell_type": "markdown", "id": "28c3b9b8", "metadata": {}, "source": [ "### Method 1: Use a separate set of observations for calibration" ] }, { "cell_type": "markdown", "id": "4bf4f08b", "metadata": {}, "source": [ "The simplest way of calibrating a classifier is to set aside a subset of the training set and use it for model calibration. In the code below we will split the dataset into a training set and a calibration set." ] }, { "cell_type": "markdown", "id": "56d2eda3", "metadata": {}, "source": [ "##### scikit-learn" ] }, { "cell_type": "code", "execution_count": 4, "id": "6b859959", "metadata": {}, "outputs": [], "source": [ "# split data\n", "from sklearn.model_selection import train_test_split\n", "X_train, X_calib, y_train, y_calib = train_test_split(X, y, random_state=42)" ] }, { "cell_type": "code", "execution_count": 5, "id": "b9f28ccc", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
CalibratedClassifierCV(base_estimator=GradientBoostingClassifier(learning_rate=1.0,\n",
       "                                                                 max_depth=1,\n",
       "                                                                 random_state=0),\n",
       "                       cv='prefit', method='isotonic')
" ], "text/plain": [ "CalibratedClassifierCV(base_estimator=GradientBoostingClassifier(learning_rate=1.0,\n", " max_depth=1,\n", " random_state=0),\n", " cv='prefit', method='isotonic')" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# train a calibrated classifier\n", "from sklearn.ensemble import GradientBoostingClassifier\n", "from sklearn.calibration import CalibratedClassifierCV\n", "base_clf = GradientBoostingClassifier(n_estimators=100, learning_rate=1.0,\n", " max_depth=1, random_state=0).fit(X_train, y_train)\n", "calibrated_clf = CalibratedClassifierCV(base_estimator=base_clf, cv=\"prefit\", method=\"isotonic\")\n", "calibrated_clf.fit(X_calib, y_calib)" ] }, { "cell_type": "code", "execution_count": 6, "id": "75eaeaf3", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[1. , 0. ],\n", " [0.09793814, 0.90206186],\n", " [0.72972973, 0.27027027],\n", " ...,\n", " [0.25641026, 0.74358974],\n", " [0.95652174, 0.04347826],\n", " [0.02836879, 0.97163121]])" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# predict calibrated probabilities\n", "calibrated_clf.predict_proba(X_calib)" ] }, { "cell_type": "markdown", "id": "113a2636", "metadata": {}, "source": [ "##### H2O" ] }, { "cell_type": "code", "execution_count": 7, "id": "d0c5e81b", "metadata": {}, "outputs": [], "source": [ "# split data\n", "df_train, df_calib = df.split_frame(ratios=[.8], destination_frames=[\"df_train\", \"df_calib\"], seed=42)" ] }, { "cell_type": "code", "execution_count": 8, "id": "1a216272", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "gbm Model Build progress: |██████████████████████████████████████████████████████| (done) 100%\n" ] }, { "data": { "text/html": [ "
H2OGradientBoostingEstimator : Gradient Boosting Machine\n",
       "Model Key: GBM_model_python_1663084324496_1\n",
       "
\n", "
\n", " \n", "
\n", " \n", " \n", " \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", " \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
Model Summary:
number_of_trees  number_of_internal_trees  model_size_in_bytes  min_depth  max_depth  mean_depth  min_leaves  max_leaves  mean_leaves
50.0             50.0                      19176.0              5.0        5.0        5.0         18.0        32.0        25.82
\n", "
\n", "
\n",
       "\n",
       "[tips]\n",
       "Use `model.show()` for more details.\n",
       "Use `model.explain()` to inspect the model.\n",
       "--\n",
       "Use `h2o.display.toggle_user_tips()` to switch on/off this section.
" ], "text/plain": [ "H2OGradientBoostingEstimator : Gradient Boosting Machine\n", "Model Key: GBM_model_python_1663084324496_1\n", "\n", "\n", "Model Summary: \n", " number_of_trees number_of_internal_trees model_size_in_bytes min_depth max_depth mean_depth min_leaves max_leaves mean_leaves\n", "-- ----------------- -------------------------- --------------------- ----------- ----------- ------------ ------------ ------------ -------------\n", " 50 50 19176 5 5 5 18 32 25.82\n", "\n", "[tips]\n", "Use `model.show()` for more details.\n", "Use `model.explain()` to inspect the model.\n", "--\n", "Use `h2o.display.toggle_user_tips()` to switch on/off this section." ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# train a calibrated classifier\n", "from h2o.estimators.gbm import H2OGradientBoostingEstimator\n", "model = H2OGradientBoostingEstimator(\n", " calibrate_model=True, calibration_frame=df_calib, calibration_method=\"IsotonicRegression\"\n", ")\n", "model.train(\n", " y=\"y\", training_frame=df_train\n", ")" ] }, { "cell_type": "code", "execution_count": 9, "id": "38d135e0", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "gbm prediction progress: |███████████████████████████████████████████████████████| (done) 100%\n" ] }, { "data": { "text/html": [ "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
predict  p0         p1          cal_p0     cal_p1
0        0.990326   0.00967415  1          0
0        0.964438   0.0355615   0.987288   0.0127119
0        0.975859   0.0241413   0.987288   0.0127119
0        0.96125    0.0387503   0.987288   0.0127119
0        0.990921   0.00907917  1          0
0        0.987689   0.0123109   0.987288   0.0127119
1        0.0255159  0.974484    0.0277778  0.972222
1        0.0317869  0.968213    0.0277778  0.972222
0        0.976713   0.0232871   0.987288   0.0127119
0        0.990921   0.00907917  1          0
[985 rows x 5 columns]
" ], "text/plain": [ " predict p0 p1 cal_p0 cal_p1\n", "--------- --------- ---------- --------- ---------\n", " 0 0.990326 0.00967415 1 0\n", " 0 0.964438 0.0355615 0.987288 0.0127119\n", " 0 0.975859 0.0241413 0.987288 0.0127119\n", " 0 0.96125 0.0387503 0.987288 0.0127119\n", " 0 0.990921 0.00907917 1 0\n", " 0 0.987689 0.0123109 0.987288 0.0127119\n", " 1 0.0255159 0.974484 0.0277778 0.972222\n", " 1 0.0317869 0.968213 0.0277778 0.972222\n", " 0 0.976713 0.0232871 0.987288 0.0127119\n", " 0 0.990921 0.00907917 1 0\n", "[985 rows x 5 columns]\n" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# calibrated probabilities (cal_p0, cal_p1) are returned alongside the original probabilities (p0, p1)\n", "model.predict(df_calib)" ] }, { "cell_type": "markdown", "id": "5e2d181b", "metadata": {}, "source": [ "### Method 2: Use CV holdout predictions to calibrate the classifier" ] }, { "cell_type": "markdown", "id": "aa5953f8", "metadata": {}, "source": [ "In this method we use the full training set and cross-validation to get unbiased predictions. Then we train an Isotonic Regression model on the CV holdout predictions. In H2O this is done by first training (and possibly tuning) the base classifier, then training the Isotonic Regression model and injecting it into the original classifier." ] }, { "cell_type": "markdown", "id": "0101dc15", "metadata": {}, "source": [ "##### scikit-learn" ] }, { "cell_type": "code", "execution_count": 10, "id": "de1cfe70", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
CalibratedClassifierCV(base_estimator=GradientBoostingClassifier(learning_rate=1.0,\n",
       "                                                                 max_depth=1,\n",
       "                                                                 random_state=0),\n",
       "                       cv=5, ensemble=False, method='isotonic')
" ], "text/plain": [ "CalibratedClassifierCV(base_estimator=GradientBoostingClassifier(learning_rate=1.0,\n", " max_depth=1,\n", " random_state=0),\n", " cv=5, ensemble=False, method='isotonic')" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# train a calibrated classifier\n", "from sklearn.ensemble import GradientBoostingClassifier\n", "from sklearn.calibration import CalibratedClassifierCV\n", "base_clf = GradientBoostingClassifier(n_estimators=100, learning_rate=1.0,\n", " max_depth=1, random_state=0)\n", "calibrated_clf = CalibratedClassifierCV(base_estimator=base_clf, ensemble=False, cv=5, method=\"isotonic\")\n", "calibrated_clf.fit(X, y)" ] }, { "cell_type": "code", "execution_count": 11, "id": "61efd795", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "1" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# with ensemble=False, a single calibrator is fitted on the CV holdout predictions\n", "len(calibrated_clf.calibrated_classifiers_)" ] }, { "cell_type": "code", "execution_count": 12, "id": "51ee53e7", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[0.25446429, 0.74553571],\n", " [0.68292683, 0.31707317],\n", " [0.95238095, 0.04761905],\n", " ...,\n", " [0.91246871, 0.08753129],\n", " [0.02298851, 0.97701149],\n", " [0.95238095, 0.04761905]])" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# predict calibrated probabilities\n", "calibrated_clf.predict_proba(X)" ] }, { "cell_type": "markdown", "id": "a1d2c58a", "metadata": {}, "source": [ "##### H2O" ] }, { "cell_type": "code", "execution_count": 13, "id": "5ac0eba5", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "gbm Model Build progress: |██████████████████████████████████████████████████████| (done) 100%\n" ] }, { "data": { "text/html": [ "
H2OGradientBoostingEstimator : Gradient Boosting Machine\n",
       "Model Key: GBM_model_python_1663084324496_53\n",
       "
\n", "
\n", " \n", "
\n", " \n", " \n", " \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", " \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
Model Summary:
number_of_trees  number_of_internal_trees  model_size_in_bytes  min_depth  max_depth  mean_depth  min_leaves  max_leaves  mean_leaves
50.0             50.0                      19611.0              5.0        5.0        5.0         16.0        32.0        26.52
\n", "
\n", "
\n",
       "\n",
       "[tips]\n",
       "Use `model.show()` for more details.\n",
       "Use `model.explain()` to inspect the model.\n",
       "--\n",
       "Use `h2o.display.toggle_user_tips()` to switch on/off this section.
" ], "text/plain": [ "H2OGradientBoostingEstimator : Gradient Boosting Machine\n", "Model Key: GBM_model_python_1663084324496_53\n", "\n", "\n", "Model Summary: \n", " number_of_trees number_of_internal_trees model_size_in_bytes min_depth max_depth mean_depth min_leaves max_leaves mean_leaves\n", "-- ----------------- -------------------------- --------------------- ----------- ----------- ------------ ------------ ------------ -------------\n", " 50 50 19611 5 5 5 16 32 26.52\n", "\n", "[tips]\n", "Use `model.show()` for more details.\n", "Use `model.explain()` to inspect the model.\n", "--\n", "Use `h2o.display.toggle_user_tips()` to switch on/off this section." ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# train a classifier using 5-fold CV, make sure you keep the CV holdout predictions\n", "from h2o.estimators.gbm import H2OGradientBoostingEstimator\n", "model = H2OGradientBoostingEstimator(\n", " nfolds=5, keep_cross_validation_predictions=True\n", ")\n", "model.train(\n", " y=\"y\", training_frame=df\n", ")" ] }, { "cell_type": "code", "execution_count": 14, "id": "2c30d91f", "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
predict  p0         p1         y
0        0.51785    0.48215    0
0        0.896655   0.103345   0
0        0.988225   0.0117751  0
1        0.0468475  0.953152   1
0        0.988428   0.0115717  0
1        0.378415   0.621585   1
0        0.987002   0.0129979  0
1        0.0143761  0.985624   1
0        0.969988   0.0300118  0
0        0.988784   0.0112155  0
[5000 rows x 4 columns]
" ], "text/plain": [ " predict p0 p1 y\n", "--------- --------- --------- ---\n", " 0 0.51785 0.48215 0\n", " 0 0.896655 0.103345 0\n", " 0 0.988225 0.0117751 0\n", " 1 0.0468475 0.953152 1\n", " 0 0.988428 0.0115717 0\n", " 1 0.378415 0.621585 1\n", " 0 0.987002 0.0129979 0\n", " 1 0.0143761 0.985624 1\n", " 0 0.969988 0.0300118 0\n", " 0 0.988784 0.0112155 0\n", "[5000 rows x 4 columns]\n" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# CV holdout predictions will serve as the training frame for Isotonic Regression calibrator\n", "xval_calib = model.cross_validation_holdout_predictions().cbind(df[[\"y\"]])\n", "xval_calib" ] }, { "cell_type": "code", "execution_count": 15, "id": "85a14ecf", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "isotonicregression Model Build progress: |███████████████████████████████████████| (done) 100%\n" ] }, { "data": { "text/html": [ "
H2OIsotonicRegressionEstimator : Isotonic Regression\n",
       "Model Key: IsotonicRegression_model_python_1663084324496_621\n",
       "
\n", "
\n", " \n", "
\n", " \n", " \n", " \n", "\n", "\n", " \n", "\n", "\n", "
Isotonic Regression Model: summary
number_of_observations  number_of_thresholds
5000.0                  60.0
\n", "
\n", "
\n",
       "\n",
       "[tips]\n",
       "Use `model.show()` for more details.\n",
       "Use `model.explain()` to inspect the model.\n",
       "--\n",
       "Use `h2o.display.toggle_user_tips()` to switch on/off this section.
" ], "text/plain": [ "H2OIsotonicRegressionEstimator : Isotonic Regression\n", "Model Key: IsotonicRegression_model_python_1663084324496_621\n", "\n", "\n", "Isotonic Regression Model: summary\n", " number_of_observations number_of_thresholds\n", "-- ------------------------ ----------------------\n", " 5000 60\n", "\n", "[tips]\n", "Use `model.show()` for more details.\n", "Use `model.explain()` to inspect the model.\n", "--\n", "Use `h2o.display.toggle_user_tips()` to switch on/off this section." ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# train Isotonic Regression model with actual labels as the target and holdout p1 predictions as a (single) feature\n", "from h2o.estimators.isotonicregression import H2OIsotonicRegressionEstimator\n", "h2o_calibrator = H2OIsotonicRegressionEstimator()\n", "h2o_calibrator.train(training_frame=xval_calib, x=\"p1\", y=\"y\")" ] }, { "cell_type": "code", "execution_count": 16, "id": "8120c71b", "metadata": {}, "outputs": [], "source": [ "# inject the calibrator model into the original GBM model\n", "model.calibrate(h2o_calibrator)" ] }, { "cell_type": "code", "execution_count": 17, "id": "02a560f5", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "gbm prediction progress: |███████████████████████████████████████████████████████| (done) 100%\n" ] }, { "data": { "text/html": [ "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
predict  p0         p1         cal_p0      cal_p1
0        0.569895   0.430105   0.609756    0.390244
0        0.908601   0.0913987  0.900958    0.0990415
0        0.987997   0.012003   0.997354    0.0026455
1        0.0576622  0.942338   0.06875     0.93125
0        0.988754   0.0112455  0.997354    0.0026455
1        0.374216   0.625784   0.352941    0.647059
0        0.980122   0.0198777  0.986861    0.0131387
1        0.0130121  0.986988   0.00689655  0.993103
0        0.949659   0.0503407  0.942308    0.0576923
0        0.988754   0.0112455  0.997354    0.0026455
[5000 rows x 5 columns]
" ], "text/plain": [ " predict p0 p1 cal_p0 cal_p1\n", "--------- --------- --------- ---------- ---------\n", " 0 0.569895 0.430105 0.609756 0.390244\n", " 0 0.908601 0.0913987 0.900958 0.0990415\n", " 0 0.987997 0.012003 0.997354 0.0026455\n", " 1 0.0576622 0.942338 0.06875 0.93125\n", " 0 0.988754 0.0112455 0.997354 0.0026455\n", " 1 0.374216 0.625784 0.352941 0.647059\n", " 0 0.980122 0.0198777 0.986861 0.0131387\n", " 1 0.0130121 0.986988 0.00689655 0.993103\n", " 0 0.949659 0.0503407 0.942308 0.0576923\n", " 0 0.988754 0.0112455 0.997354 0.0026455\n", "[5000 rows x 5 columns]\n" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# predict with calibrated probabilities\n", "model.predict(df)" ] }, { "cell_type": "code", "execution_count": null, "id": "bb649f9f", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.4" } }, "nbformat": 4, "nbformat_minor": 5 }