{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Feature importances with a forest of trees\n\nThis example shows the use of a forest of trees to evaluate the importance of\nfeatures on an artificial classification task. The blue bars are the feature\nimportances of the forest, along with their inter-trees variability represented\nby the error bars.\n\nAs expected, the plot suggests that 3 features are informative, while the\nremaining are not.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Authors: The scikit-learn developers\n# SPDX-License-Identifier: BSD-3-Clause\n\nimport matplotlib.pyplot as plt" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data generation and model fitting\nWe generate a synthetic dataset with only 3 informative features. We will\nexplicitly not shuffle the dataset to ensure that the informative features\nwill correspond to the three first columns of X. In addition, we will split\nour dataset into training and testing subsets.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from sklearn.datasets import make_classification\nfrom sklearn.model_selection import train_test_split\n\nX, y = make_classification(\n n_samples=1000,\n n_features=10,\n n_informative=3,\n n_redundant=0,\n n_repeated=0,\n n_classes=2,\n random_state=0,\n shuffle=False,\n)\nX_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=42)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A random forest classifier will be fitted to compute the feature importances.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from sklearn.ensemble import RandomForestClassifier\n\nfeature_names = [f\"feature {i}\" for i in range(X.shape[1])]\nforest = RandomForestClassifier(random_state=0)\nforest.fit(X_train, y_train)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Feature importance based on mean decrease in impurity\nFeature importances are provided by the fitted attribute\n`feature_importances_` and they are computed as the mean and standard\ndeviation of accumulation of the impurity decrease within each tree.\n\n

Warning

Impurity-based feature importances can be misleading for **high\n cardinality** features (many unique values). See\n `permutation_importance` as an alternative below.

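, { "cell_type": "markdown", "metadata": {}, "source": [ "As a quick sanity check of the statement above (a minimal sketch added for\nillustration, not part of the original example), the forest-level\n`feature_importances_` should agree, up to normalization and floating point,\nwith the per-tree importances averaged over `forest.estimators_`.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import numpy as np\n\n# Per-tree importances; each row sums to 1 for a fitted tree.\nper_tree = np.array([tree.feature_importances_ for tree in forest.estimators_])\nmean_importances = per_tree.mean(axis=0)\n\n# Should print True: the attribute is the normalized mean over trees.\nprint(np.allclose(forest.feature_importances_, mean_importances / mean_importances.sum()))" ] }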
\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import time\n\nimport numpy as np\n\nstart_time = time.time()\nimportances = forest.feature_importances_\nstd = np.std([tree.feature_importances_ for tree in forest.estimators_], axis=0)\nelapsed_time = time.time() - start_time\n\nprint(f\"Elapsed time to compute the importances: {elapsed_time:.3f} seconds\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's plot the impurity-based importance.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import pandas as pd\n\nforest_importances = pd.Series(importances, index=feature_names)\n\nfig, ax = plt.subplots()\nforest_importances.plot.bar(yerr=std, ax=ax)\nax.set_title(\"Feature importances using MDI\")\nax.set_ylabel(\"Mean decrease in impurity\")\nfig.tight_layout()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We observe that, as expected, the three first features are found important.\n\n## Feature importance based on feature permutation\nPermutation feature importance overcomes limitations of the impurity-based\nfeature importance: they do not have a bias toward high-cardinality features\nand can be computed on a left-out test set.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from sklearn.inspection import permutation_importance\n\nstart_time = time.time()\nresult = permutation_importance(\n forest, X_test, y_test, n_repeats=10, random_state=42, n_jobs=2\n)\nelapsed_time = time.time() - start_time\nprint(f\"Elapsed time to compute the importances: {elapsed_time:.3f} seconds\")\n\nforest_importances = pd.Series(result.importances_mean, index=feature_names)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The computation for full permutation importance is more costly. Features are\nshuffled n times and the model refitted to estimate the importance of it.\nPlease see `permutation_importance` for more details. We can now plot\nthe importance ranking.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "fig, ax = plt.subplots()\nforest_importances.plot.bar(yerr=result.importances_std, ax=ax)\nax.set_title(\"Feature importances using permutation on full model\")\nax.set_ylabel(\"Mean accuracy decrease\")\nfig.tight_layout()\nplt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The same features are detected as most important using both methods. Although\nthe relative importances vary. As seen on the plots, MDI is less likely than\npermutation importance to fully omit a feature.\n\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.21" } }, "nbformat": 4, "nbformat_minor": 0 }