{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Probability Calibration for 3-class classification\n\nThis example illustrates how sigmoid calibration changes predicted\nprobabilities for a 3-class classification problem. The figure shows the\nstandard 2-simplex, where the three corners correspond to the three\nclasses. Arrows point from the probability vectors predicted by an uncalibrated\nclassifier to the probability vectors predicted by the same classifier after\nsigmoid calibration on a hold-out validation set. Colors indicate the true\nclass of an instance (red: class 1, green: class 2, blue: class 3).\n" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Data\nBelow, we generate a classification dataset with 2000 samples, 2 features\nand 3 target classes. We then split the data as follows:\n\n* train: 600 samples (for training the classifier)\n* valid: 400 samples (for calibrating predicted probabilities)\n* test: 1000 samples (for evaluation)\n\nNote that we also create `X_train_valid` and `y_train_valid`, which consist\nof both the train and valid subsets. These are used when we want to train\nthe classifier but not calibrate the predicted probabilities.\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Authors: The scikit-learn developers\n# SPDX-License-Identifier: BSD-3-Clause\n\nimport numpy as np\n\nfrom sklearn.datasets import make_blobs\n\nnp.random.seed(0)\n\nX, y = make_blobs(\n    n_samples=2000, n_features=2, centers=3, random_state=42, cluster_std=5.0\n)\nX_train, y_train = X[:600], y[:600]\nX_valid, y_valid = X[600:1000], y[600:1000]\nX_train_valid, y_train_valid = X[:1000], y[:1000]\nX_test, y_test = X[1000:], y[1000:]" ] },
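{ "cell_type": "markdown", "metadata": {}, "source": [ "As a small, optional sanity check (not needed for the rest of the example),\nwe can confirm the size and class balance of each subset with `np.bincount`\non the integer targets:\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Optional sanity check: print the number of samples and the per-class counts\n# for each of the subsets created above.\nfor name, labels in [(\"train\", y_train), (\"valid\", y_valid), (\"test\", y_test)]:\n    print(f\"{name}: {labels.shape[0]} samples, class counts {np.bincount(labels)}\")" ] },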
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Fitting and calibration\n\nFirst, we will train a :class:`~sklearn.ensemble.RandomForestClassifier`\nwith 25 base estimators (trees) on the concatenated train and validation\ndata (1000 samples). This is the uncalibrated classifier.\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from sklearn.ensemble import RandomForestClassifier\n\nclf = RandomForestClassifier(n_estimators=25)\nclf.fit(X_train_valid, y_train_valid)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "To train the calibrated classifier, we start with the same\n:class:`~sklearn.ensemble.RandomForestClassifier` but train it using only\nthe train data subset (600 samples). We then calibrate it, with\n`method='sigmoid'`, using the valid data subset (400 samples), in a 2-stage\nprocess.\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from sklearn.calibration import CalibratedClassifierCV\nfrom sklearn.frozen import FrozenEstimator\n\nclf = RandomForestClassifier(n_estimators=25)\nclf.fit(X_train, y_train)\ncal_clf = CalibratedClassifierCV(FrozenEstimator(clf), method=\"sigmoid\")\ncal_clf.fit(X_valid, y_valid)" ] },
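{ "cell_type": "markdown", "metadata": {}, "source": [ "As a quick, illustrative check (the three test samples shown here are an\narbitrary choice), we can compare the raw and calibrated probability vectors\nfor a few points before visualizing all of them:\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Illustrative check on a few arbitrary test samples: each row is a\n# probability vector over the three classes, before and after calibration.\nprint(\"Uncalibrated:\")\nprint(clf.predict_proba(X_test[:3]).round(3))\nprint(\"Calibrated:\")\nprint(cal_clf.predict_proba(X_test[:3]).round(3))" ] },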
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Compare probabilities\nBelow, we plot a 2-simplex with arrows showing the change in predicted\nprobabilities of the test samples.\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n\nplt.figure(figsize=(10, 10))\ncolors = [\"r\", \"g\", \"b\"]\n\nclf_probs = clf.predict_proba(X_test)\ncal_clf_probs = cal_clf.predict_proba(X_test)\n# Plot arrows\nfor i in range(clf_probs.shape[0]):\n    plt.arrow(\n        clf_probs[i, 0],\n        clf_probs[i, 1],\n        cal_clf_probs[i, 0] - clf_probs[i, 0],\n        cal_clf_probs[i, 1] - clf_probs[i, 1],\n        color=colors[y_test[i]],\n        head_width=1e-2,\n    )\n\n# Plot perfect predictions at each vertex\nplt.plot([1.0], [0.0], \"ro\", ms=20, label=\"Class 1\")\nplt.plot([0.0], [1.0], \"go\", ms=20, label=\"Class 2\")\nplt.plot([0.0], [0.0], \"bo\", ms=20, label=\"Class 3\")\n\n# Plot boundaries of unit simplex\nplt.plot([0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], \"k\", label=\"Simplex\")\n\n# Annotate the 6 points around the simplex and the mid point inside the simplex\nplt.annotate(\n    r\"($\\\frac{1}{3}$, $\\\frac{1}{3}$, $\\\frac{1}{3}$)\",\n    xy=(1.0 / 3, 1.0 / 3),\n    xytext=(1.0 / 3, 0.23),\n    xycoords=\"data\",\n    arrowprops=dict(facecolor=\"black\", shrink=0.05),\n    horizontalalignment=\"center\",\n    verticalalignment=\"center\",\n)\nplt.plot([1.0 / 3], [1.0 / 3], \"ko\", ms=5)\nplt.annotate(\n    r\"($\\\frac{1}{2}$, $0$, $\\\frac{1}{2}$)\",\n    xy=(0.5, 0.0),\n    xytext=(0.5, 0.1),\n    xycoords=\"data\",\n    arrowprops=dict(facecolor=\"black\", shrink=0.05),\n    horizontalalignment=\"center\",\n    verticalalignment=\"center\",\n)\nplt.annotate(\n    r\"($0$, $\\\frac{1}{2}$, $\\\frac{1}{2}$)\",\n    xy=(0.0, 0.5),\n    xytext=(0.1, 0.5),\n    xycoords=\"data\",\n    arrowprops=dict(facecolor=\"black\", shrink=0.05),\n    horizontalalignment=\"center\",\n    verticalalignment=\"center\",\n)\nplt.annotate(\n    r\"($\\\frac{1}{2}$, $\\\frac{1}{2}$, $0$)\",\n    xy=(0.5, 0.5),\n    xytext=(0.6, 0.6),\n    xycoords=\"data\",\n    arrowprops=dict(facecolor=\"black\", shrink=0.05),\n    horizontalalignment=\"center\",\n    verticalalignment=\"center\",\n)\nplt.annotate(\n    r\"($0$, $0$, $1$)\",\n    xy=(0, 0),\n    xytext=(0.1, 0.1),\n    xycoords=\"data\",\n    arrowprops=dict(facecolor=\"black\", shrink=0.05),\n    horizontalalignment=\"center\",\n    verticalalignment=\"center\",\n)\nplt.annotate(\n    r\"($1$, $0$, $0$)\",\n    xy=(1, 0),\n    xytext=(1, 0.1),\n    xycoords=\"data\",\n    arrowprops=dict(facecolor=\"black\", shrink=0.05),\n    horizontalalignment=\"center\",\n    verticalalignment=\"center\",\n)\nplt.annotate(\n    r\"($0$, $1$, $0$)\",\n    xy=(0, 1),\n    xytext=(0.1, 1),\n    xycoords=\"data\",\n    arrowprops=dict(facecolor=\"black\", shrink=0.05),\n    horizontalalignment=\"center\",\n    verticalalignment=\"center\",\n)\n# Add grid lines over the simplex (the default rectangular grid is disabled)\nplt.grid(False)\nfor x in [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]:\n    plt.plot([0, x], [x, 0], \"k\", alpha=0.2)\n    plt.plot([0, 0 + (1 - x) / 2], [x, x + (1 - x) / 2], \"k\", alpha=0.2)\n    plt.plot([x, x + (1 - x) / 2], [0, 0 + (1 - x) / 2], \"k\", alpha=0.2)\n\nplt.title(\"Change of predicted probabilities on test samples after sigmoid calibration\")\nplt.xlabel(\"Probability class 1\")\nplt.ylabel(\"Probability class 2\")\nplt.xlim(-0.05, 1.05)\nplt.ylim(-0.05, 1.05)\n_ = plt.legend(loc=\"best\")" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "In the figure above, each vertex of the simplex represents\na perfectly predicted class (e.g., 1, 0, 0). The mid point\ninside the simplex represents predicting the three classes with equal\nprobability (i.e., 1/3, 1/3, 1/3). Each arrow starts at the\nuncalibrated probabilities and ends, with the arrow head, at the calibrated\nprobabilities. The color of the arrow represents the true class of that test\nsample.\n\nThe uncalibrated classifier is overly confident in its predictions and\nincurs a large log loss. The calibrated classifier incurs a lower log loss\ndue to two factors. First, notice in the figure above that the arrows\ngenerally point away from the edges of the simplex, where the probability\nof one class is 0. Second, a large proportion of the arrows point towards\nthe true class, e.g., green arrows (samples where the true class is 'green')\ngenerally point towards the green vertex. This results in fewer\nover-confident predicted probabilities of 0 and, at the same time, an\nincrease in the predicted probabilities of the correct class. Thus, the\ncalibrated classifier produces more accurate predicted probabilities that\nincur a lower log loss.\n\nWe can show this objectively by comparing the log loss of the uncalibrated\nand calibrated classifiers on the predictions of the 1000 test samples.\nNote that an alternative would have been to increase the number of base\nestimators (trees) of the :class:`~sklearn.ensemble.RandomForestClassifier`,\nwhich would have resulted in a similar decrease in log loss; this\nalternative is sketched after the comparison below.\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from sklearn.metrics import log_loss\n\nscore = log_loss(y_test, clf_probs)\ncal_score = log_loss(y_test, cal_clf_probs)\n\nprint(\"Log-loss of\")\nprint(f\" * uncalibrated classifier: {score:.3f}\")\nprint(f\" * calibrated classifier: {cal_score:.3f}\")" ] },
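{ "cell_type": "markdown", "metadata": {}, "source": [ "The cell below sketches that alternative: a forest with more trees (250 here,\nan arbitrary choice for illustration) trained on the concatenated train and\nvalidation data, evaluated with the same log loss metric. The `big_clf` and\n`big_score` names are used only for this illustration.\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Sketch of the alternative mentioned above: an uncalibrated forest with more\n# trees (250 is an arbitrary choice), trained on the 1000 train + valid samples.\nbig_clf = RandomForestClassifier(n_estimators=250)\nbig_clf.fit(X_train_valid, y_train_valid)\nbig_score = log_loss(y_test, big_clf.predict_proba(X_test))\nprint(f\" * uncalibrated classifier with 250 trees: {big_score:.3f}\")" ] },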
{ "cell_type": "markdown", "metadata": {}, "source": [ "Finally, we generate a grid of possible uncalibrated probabilities over\nthe 2-simplex, compute the corresponding calibrated probabilities and\nplot arrows for each. The arrows are colored according to the highest\nuncalibrated probability. This illustrates the learned calibration map:\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "plt.figure(figsize=(10, 10))\n# Generate a grid of probability values over the 2-simplex\np1d = np.linspace(0, 1, 20)\np0, p1 = np.meshgrid(p1d, p1d)\np2 = 1 - p0 - p1\np = np.c_[p0.ravel(), p1.ravel(), p2.ravel()]\np = p[p[:, 2] >= 0]\n\n# Use the three class-wise calibrators to compute calibrated probabilities\ncalibrated_classifier = cal_clf.calibrated_classifiers_[0]\nprediction = np.vstack(\n    [\n        calibrator.predict(this_p)\n        for calibrator, this_p in zip(calibrated_classifier.calibrators, p.T)\n    ]\n).T\n\n# Re-normalize the calibrated predictions to make sure they stay inside the\n# simplex. This same renormalization step is performed internally by the\n# predict method of CalibratedClassifierCV on multiclass problems.\nprediction /= prediction.sum(axis=1)[:, None]\n\n# Plot changes in predicted probabilities induced by the calibrators\nfor i in range(prediction.shape[0]):\n    plt.arrow(\n        p[i, 0],\n        p[i, 1],\n        prediction[i, 0] - p[i, 0],\n        prediction[i, 1] - p[i, 1],\n        head_width=1e-2,\n        color=colors[np.argmax(p[i])],\n    )\n\n# Plot the boundaries of the unit simplex\nplt.plot([0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], \"k\", label=\"Simplex\")\n\nplt.grid(False)\nfor x in [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]:\n    plt.plot([0, x], [x, 0], \"k\", alpha=0.2)\n    plt.plot([0, 0 + (1 - x) / 2], [x, x + (1 - x) / 2], \"k\", alpha=0.2)\n    plt.plot([x, x + (1 - x) / 2], [0, 0 + (1 - x) / 2], \"k\", alpha=0.2)\n\nplt.title(\"Learned sigmoid calibration map\")\nplt.xlabel(\"Probability class 1\")\nplt.ylabel(\"Probability class 2\")\nplt.xlim(-0.05, 1.05)\nplt.ylim(-0.05, 1.05)\n\nplt.show()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.21" } }, "nbformat": 4, "nbformat_minor": 0 }