{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Gradient Boosting regression\n\nThis example demonstrates Gradient Boosting to produce a predictive\nmodel from an ensemble of weak predictive models. Gradient boosting can be used\nfor regression and classification problems. Here, we will train a model to\ntackle a diabetes regression task. We will obtain the results from\n:class:`~sklearn.ensemble.GradientBoostingRegressor` with least squares loss\nand 500 regression trees of depth 4.\n\nNote: For larger datasets (n_samples >= 10000), please refer to\n:class:`~sklearn.ensemble.HistGradientBoostingRegressor`. See\n`sphx_glr_auto_examples_ensemble_plot_hgbt_regression.py` for an example\nshowcasing some other advantages of\n:class:`~ensemble.HistGradientBoostingRegressor`.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Authors: The scikit-learn developers\n# SPDX-License-Identifier: BSD-3-Clause\n\nimport matplotlib\nimport matplotlib.pyplot as plt\nimport numpy as np\n\nfrom sklearn import datasets, ensemble\nfrom sklearn.inspection import permutation_importance\nfrom sklearn.metrics import mean_squared_error\nfrom sklearn.model_selection import train_test_split\nfrom sklearn.utils.fixes import parse_version" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load the data\n\nFirst we need to load the data.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "diabetes = datasets.load_diabetes()\nX, y = diabetes.data, diabetes.target" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data preprocessing\n\nNext, we will split our dataset to use 90% for training and leave the rest\nfor testing. We will also set the regression model parameters. You can play\nwith these parameters to see how the results change.\n\n`n_estimators` : the number of boosting stages that will be performed.\nLater, we will plot deviance against boosting iterations.\n\n`max_depth` : limits the number of nodes in the tree.\nThe best value depends on the interaction of the input variables.\n\n`min_samples_split` : the minimum number of samples required to split an\ninternal node.\n\n`learning_rate` : how much the contribution of each tree will shrink.\n\n`loss` : loss function to optimize. The least squares function is used in\nthis case however, there are many other options (see\n:class:`~sklearn.ensemble.GradientBoostingRegressor` ).\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "X_train, X_test, y_train, y_test = train_test_split(\n X, y, test_size=0.1, random_state=13\n)\n\nparams = {\n \"n_estimators\": 500,\n \"max_depth\": 4,\n \"min_samples_split\": 5,\n \"learning_rate\": 0.01,\n \"loss\": \"squared_error\",\n}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Fit regression model\n\nNow we will initiate the gradient boosting regressors and fit it with our\ntraining data. Let's also look and the mean squared error on the test data.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "reg = ensemble.GradientBoostingRegressor(**params)\nreg.fit(X_train, y_train)\n\nmse = mean_squared_error(y_test, reg.predict(X_test))\nprint(\"The mean squared error (MSE) on test set: {:.4f}\".format(mse))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Plot training deviance\n\nFinally, we will visualize the results. To do that we will first compute the\ntest set deviance and then plot it against boosting iterations.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "test_score = np.zeros((params[\"n_estimators\"],), dtype=np.float64)\nfor i, y_pred in enumerate(reg.staged_predict(X_test)):\n test_score[i] = mean_squared_error(y_test, y_pred)\n\nfig = plt.figure(figsize=(6, 6))\nplt.subplot(1, 1, 1)\nplt.title(\"Deviance\")\nplt.plot(\n np.arange(params[\"n_estimators\"]) + 1,\n reg.train_score_,\n \"b-\",\n label=\"Training Set Deviance\",\n)\nplt.plot(\n np.arange(params[\"n_estimators\"]) + 1, test_score, \"r-\", label=\"Test Set Deviance\"\n)\nplt.legend(loc=\"upper right\")\nplt.xlabel(\"Boosting Iterations\")\nplt.ylabel(\"Deviance\")\nfig.tight_layout()\nplt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Plot feature importance\n\n

Warning

Careful, impurity-based feature importances can be misleading for\n **high cardinality** features (many unique values). As an alternative,\n the permutation importances of ``reg`` can be computed on a\n held out test set. See `permutation_importance` for more details.

\n\nFor this example, the impurity-based and permutation methods identify the\nsame 2 strongly predictive features but not in the same order. The third most\npredictive feature, \"bp\", is also the same for the 2 methods. The remaining\nfeatures are less predictive and the error bars of the permutation plot\nshow that they overlap with 0.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "feature_importance = reg.feature_importances_\nsorted_idx = np.argsort(feature_importance)\npos = np.arange(sorted_idx.shape[0]) + 0.5\nfig = plt.figure(figsize=(12, 6))\nplt.subplot(1, 2, 1)\nplt.barh(pos, feature_importance[sorted_idx], align=\"center\")\nplt.yticks(pos, np.array(diabetes.feature_names)[sorted_idx])\nplt.title(\"Feature Importance (MDI)\")\n\nresult = permutation_importance(\n reg, X_test, y_test, n_repeats=10, random_state=42, n_jobs=2\n)\nsorted_idx = result.importances_mean.argsort()\nplt.subplot(1, 2, 2)\n\n# `labels` argument in boxplot is deprecated in matplotlib 3.9 and has been\n# renamed to `tick_labels`. The following code handles this, but as a\n# scikit-learn user you probably can write simpler code by using `labels=...`\n# (matplotlib < 3.9) or `tick_labels=...` (matplotlib >= 3.9).\ntick_labels_parameter_name = (\n \"tick_labels\"\n if parse_version(matplotlib.__version__) >= parse_version(\"3.9\")\n else \"labels\"\n)\ntick_labels_dict = {\n tick_labels_parameter_name: np.array(diabetes.feature_names)[sorted_idx]\n}\nplt.boxplot(result.importances[sorted_idx].T, vert=False, **tick_labels_dict)\nplt.title(\"Permutation Importance (test set)\")\nfig.tight_layout()\nplt.show()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.21" } }, "nbformat": 4, "nbformat_minor": 0 }