{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# OOB Errors for Random Forests\n\nThe ``RandomForestClassifier`` is trained using *bootstrap aggregation*, where\neach new tree is fit from a bootstrap sample of the training observations\n$z_i = (x_i, y_i)$. The *out-of-bag* (OOB) error is the average error for\neach $z_i$ calculated using predictions from the trees that do not\ncontain $z_i$ in their respective bootstrap sample. This allows the\n``RandomForestClassifier`` to be fit and validated whilst being trained [1]_.\n\nThe example below demonstrates how the OOB error can be measured at the\naddition of each new tree during training. The resulting plot allows a\npractitioner to approximate a suitable value of ``n_estimators`` at which the\nerror stabilizes.\n\n.. [1] T. Hastie, R. Tibshirani and J. Friedman, \"Elements of Statistical\n Learning Ed. 2\", p592-593, Springer, 2009.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Authors: The scikit-learn developers\n# SPDX-License-Identifier: BSD-3-Clause\n\nfrom collections import OrderedDict\n\nimport matplotlib.pyplot as plt\n\nfrom sklearn.datasets import make_classification\nfrom sklearn.ensemble import RandomForestClassifier\n\nRANDOM_STATE = 123\n\n# Generate a binary classification dataset.\nX, y = make_classification(\n n_samples=500,\n n_features=25,\n n_clusters_per_class=1,\n n_informative=15,\n random_state=RANDOM_STATE,\n)\n\n# NOTE: Setting the `warm_start` construction parameter to `True` disables\n# support for parallelized ensembles but is necessary for tracking the OOB\n# error trajectory during training.\nensemble_clfs = [\n (\n \"RandomForestClassifier, max_features='sqrt'\",\n RandomForestClassifier(\n warm_start=True,\n oob_score=True,\n max_features=\"sqrt\",\n random_state=RANDOM_STATE,\n ),\n ),\n (\n \"RandomForestClassifier, max_features='log2'\",\n RandomForestClassifier(\n 
warm_start=True,\n max_features=\"log2\",\n oob_score=True,\n random_state=RANDOM_STATE,\n ),\n ),\n (\n \"RandomForestClassifier, max_features=None\",\n RandomForestClassifier(\n warm_start=True,\n max_features=None,\n oob_score=True,\n random_state=RANDOM_STATE,\n ),\n ),\n]\n\n# Map a classifier name to a list of (n_estimators, OOB error rate) pairs.\nerror_rate = OrderedDict((label, []) for label, _ in ensemble_clfs)\n\n# Range of `n_estimators` values to explore.\nmin_estimators = 15\nmax_estimators = 150\n\nfor label, clf in ensemble_clfs:\n for i in range(min_estimators, max_estimators + 1, 5):\n clf.set_params(n_estimators=i)\n clf.fit(X, y)\n\n # Record the OOB error for each `n_estimators=i` setting.\n oob_error = 1 - clf.oob_score_\n error_rate[label].append((i, oob_error))\n\n# Generate the \"OOB error rate\" vs. \"n_estimators\" plot.\nfor label, clf_err in error_rate.items():\n xs, ys = zip(*clf_err)\n plt.plot(xs, ys, label=label)\n\nplt.xlim(min_estimators, max_estimators)\nplt.xlabel(\"n_estimators\")\nplt.ylabel(\"OOB error rate\")\nplt.legend(loc=\"upper right\")\nplt.show()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.21" } }, "nbformat": 4, "nbformat_minor": 0 }