{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Decision Tree Regression\nIn this example, we demonstrate the effect of changing the maximum depth of a\ndecision tree on how it fits to the data. We perform this once on a 1D regression\ntask and once on a multi-output regression task.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Authors: The scikit-learn developers\n# SPDX-License-Identifier: BSD-3-Clause" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Decision Tree on a 1D Regression Task\n\nHere we fit a tree on a 1D regression task.\n\nThe `decision trees ` is\nused to fit a sine curve with addition noisy observation. As a result, it\nlearns local linear regressions approximating the sine curve.\n\nWe can see that if the maximum depth of the tree (controlled by the\n`max_depth` parameter) is set too high, the decision trees learn too fine\ndetails of the training data and learn from the noise, i.e. they overfit.\n\n### Create a random 1D dataset\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import numpy as np\n\nrng = np.random.RandomState(1)\nX = np.sort(5 * rng.rand(80, 1), axis=0)\ny = np.sin(X).ravel()\ny[::5] += 3 * (0.5 - rng.rand(16))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Fit regression model\nHere we fit two models with different maximum depths\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from sklearn.tree import DecisionTreeRegressor\n\nregr_1 = DecisionTreeRegressor(max_depth=2)\nregr_2 = DecisionTreeRegressor(max_depth=5)\nregr_1.fit(X, y)\nregr_2.fit(X, y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Predict\nGet predictions on the test set\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "X_test = np.arange(0.0, 5.0, 0.01)[:, np.newaxis]\ny_1 = regr_1.predict(X_test)\ny_2 = regr_2.predict(X_test)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Plot the results\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n\nplt.figure()\nplt.scatter(X, y, s=20, edgecolor=\"black\", c=\"darkorange\", label=\"data\")\nplt.plot(X_test, y_1, color=\"cornflowerblue\", label=\"max_depth=2\", linewidth=2)\nplt.plot(X_test, y_2, color=\"yellowgreen\", label=\"max_depth=5\", linewidth=2)\nplt.xlabel(\"data\")\nplt.ylabel(\"target\")\nplt.title(\"Decision Tree Regression\")\nplt.legend()\nplt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As you can see, the model with a depth of 5 (yellow) learns the details of the\ntraining data to the point that it overfits to the noise. On the other hand,\nthe model with a depth of 2 (blue) learns the major tendencies in the data well\nand does not overfit. In real use cases, you need to make sure that the tree\nis not overfitting the training data, which can be done using cross-validation.\n\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Decision Tree Regression with Multi-Output Targets\n\nHere the `decision trees `\nis used to predict simultaneously the noisy `x` and `y` observations of a circle\ngiven a single underlying feature. 
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Decision Tree Regression with Multi-Output Targets\n\nHere a decision tree is used to simultaneously predict the noisy `x` and `y`\nobservations of a circle given a single underlying feature. As a result, it\nlearns local, piecewise-constant approximations of the circle.\n\nWe can see that if the maximum depth of the tree (controlled by the\n`max_depth` parameter) is set too high, the tree learns overly fine details of\nthe training data and fits the noise, i.e. it overfits.\n\n" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Create a random dataset\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "rng = np.random.RandomState(1)\nX = np.sort(200 * rng.rand(100, 1) - 100, axis=0)\ny = np.array([np.pi * np.sin(X).ravel(), np.pi * np.cos(X).ravel()]).T\ny[::5, :] += 0.5 - rng.rand(20, 2)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Fit regression model\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "regr_1 = DecisionTreeRegressor(max_depth=2)\nregr_2 = DecisionTreeRegressor(max_depth=5)\nregr_3 = DecisionTreeRegressor(max_depth=8)\nregr_1.fit(X, y)\nregr_2.fit(X, y)\nregr_3.fit(X, y)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Predict\nGet predictions on the test set.\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "X_test = np.arange(-100.0, 100.0, 0.01)[:, np.newaxis]\ny_1 = regr_1.predict(X_test)\ny_2 = regr_2.predict(X_test)\ny_3 = regr_3.predict(X_test)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Plot the results\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "plt.figure()\ns = 25\nplt.scatter(y[:, 0], y[:, 1], c=\"yellow\", s=s, edgecolor=\"black\", label=\"data\")\nplt.scatter(\n    y_1[:, 0],\n    y_1[:, 1],\n    c=\"cornflowerblue\",\n    s=s,\n    edgecolor=\"black\",\n    label=\"max_depth=2\",\n)\nplt.scatter(y_2[:, 0], y_2[:, 1], c=\"red\", s=s, edgecolor=\"black\", label=\"max_depth=5\")\nplt.scatter(y_3[:, 0], y_3[:, 1], c=\"blue\", s=s, edgecolor=\"black\", label=\"max_depth=8\")\nplt.xlim([-6, 6])\nplt.ylim([-6, 6])\nplt.xlabel(\"target 1\")\nplt.ylabel(\"target 2\")\nplt.title(\"Multi-output Decision Tree Regression\")\nplt.legend(loc=\"best\")\nplt.show()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "As you can see, the higher the value of `max_depth`, the more details of the data\nare captured by the model. However, the deeper models also overfit the training\ndata and are influenced by the noise.\n\n" ] } ],
"metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.21" } }, "nbformat": 4, "nbformat_minor": 0 }