{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Permutation Importance with Multicollinear or Correlated Features\n\nIn this example, we compute the\n:func:`~sklearn.inspection.permutation_importance` of the features for a trained\n:class:`~sklearn.ensemble.RandomForestClassifier` using the\n`breast_cancer_dataset`. The model easily reaches about 97% accuracy on a\ntest dataset. Because this dataset contains multicollinear features, the\npermutation importance shows that none of the features are important, in\ncontradiction with the high test accuracy.\n\nWe demonstrate a possible approach to handling multicollinearity, which consists of\nhierarchical clustering on the features' Spearman rank-order correlations,\npicking a threshold, and keeping a single feature from each cluster.\n\n

Note: See also `sphx_glr_auto_examples_inspection_plot_permutation_importance.py`

\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Authors: The scikit-learn developers\n# SPDX-License-Identifier: BSD-3-Clause" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Random Forest Feature Importance on Breast Cancer Data\n\nFirst, we define a function to ease the plotting:\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import matplotlib\n\nfrom sklearn.inspection import permutation_importance\nfrom sklearn.utils.fixes import parse_version\n\n\ndef plot_permutation_importance(clf, X, y, ax):\n    result = permutation_importance(clf, X, y, n_repeats=10, random_state=42, n_jobs=2)\n    perm_sorted_idx = result.importances_mean.argsort()\n\n    # `labels` argument in boxplot is deprecated in matplotlib 3.9 and has been\n    # renamed to `tick_labels`. The following code handles this, but as a\n    # scikit-learn user you probably can write simpler code by using `labels=...`\n    # (matplotlib < 3.9) or `tick_labels=...` (matplotlib >= 3.9).\n    tick_labels_parameter_name = (\n        \"tick_labels\"\n        if parse_version(matplotlib.__version__) >= parse_version(\"3.9\")\n        else \"labels\"\n    )\n    tick_labels_dict = {tick_labels_parameter_name: X.columns[perm_sorted_idx]}\n    ax.boxplot(result.importances[perm_sorted_idx].T, vert=False, **tick_labels_dict)\n    ax.axvline(x=0, color=\"k\", linestyle=\"--\")\n    return ax" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We then train a :class:`~sklearn.ensemble.RandomForestClassifier` on the\n`breast_cancer_dataset` and evaluate its accuracy on a test set:\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from sklearn.datasets import load_breast_cancer\nfrom sklearn.ensemble import RandomForestClassifier\nfrom sklearn.model_selection import train_test_split\n\nX, y = load_breast_cancer(return_X_y=True, 
as_frame=True)\nX_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)\n\nclf = RandomForestClassifier(n_estimators=100, random_state=42)\nclf.fit(X_train, y_train)\nprint(f\"Baseline accuracy on test data: {clf.score(X_test, y_test):.2}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we plot the tree-based feature importance and the permutation\nimportance. The permutation importance is calculated on the training set to\nshow how much the model relies on each feature during training.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\nimport numpy as np\nimport pandas as pd\n\nmdi_importances = pd.Series(clf.feature_importances_, index=X_train.columns)\ntree_importance_sorted_idx = np.argsort(clf.feature_importances_)\n\nfig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 8))\nmdi_importances.sort_values().plot.barh(ax=ax1)\nax1.set_xlabel(\"Gini importance\")\nplot_permutation_importance(clf, X_train, y_train, ax2)\nax2.set_xlabel(\"Decrease in accuracy score\")\nfig.suptitle(\n    \"Impurity-based vs. permutation importances on multicollinear features (train set)\"\n)\n_ = fig.tight_layout()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The plot on the left shows the Gini importance of the model. As the\nscikit-learn implementation of\n:class:`~sklearn.ensemble.RandomForestClassifier` uses a random subset of\n$\\sqrt{n_\\text{features}}$ features at each split, it is able to dilute\nthe dominance of any single correlated feature. As a result, the individual\nfeature importance may be distributed more evenly among the correlated\nfeatures. 
Since the features have high cardinality and the classifier is\nnot overfitted, we can place reasonable trust in those values.\n\nThe permutation importance on the right plot shows that permuting a feature\ndrops the accuracy by at most `0.012`, which would suggest that none of the\nfeatures are important. This is in contradiction with the high test accuracy\ncomputed as baseline: some feature must be important.\n\nSimilarly, the change in accuracy score computed on the test set appears to be\ndriven by chance:\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "fig, ax = plt.subplots(figsize=(7, 6))\nplot_permutation_importance(clf, X_test, y_test, ax)\nax.set_title(\"Permutation Importances on multicollinear features\\n(test set)\")\nax.set_xlabel(\"Decrease in accuracy score\")\n_ = ax.figure.tight_layout()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Nevertheless, one can still compute a meaningful permutation importance in the\npresence of correlated features, as demonstrated in the following section.\n\n## Handling Multicollinear Features\n\nWhen features are collinear, permuting one feature has little effect on the\nmodel's performance because the model can get the same information from a correlated\nfeature. Note that this is not the case for all predictive models and depends\non their underlying implementation.\n\nOne way to handle multicollinear features is by performing hierarchical\nclustering on the Spearman rank-order correlations, picking a threshold, and\nkeeping a single feature from each cluster. 
First, we plot a dendrogram of the feature clusters and a heatmap of the\nfeature correlations:\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from scipy.cluster import hierarchy\nfrom scipy.spatial.distance import squareform\nfrom scipy.stats import spearmanr\n\nfig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 8))\ncorr = spearmanr(X).correlation\n\n# Ensure the correlation matrix is symmetric\ncorr = (corr + corr.T) / 2\nnp.fill_diagonal(corr, 1)\n\n# We convert the correlation matrix to a distance matrix before performing\n# hierarchical clustering using Ward's linkage.\ndistance_matrix = 1 - np.abs(corr)\ndist_linkage = hierarchy.ward(squareform(distance_matrix))\ndendro = hierarchy.dendrogram(\n    dist_linkage, labels=X.columns.to_list(), ax=ax1, leaf_rotation=90\n)\ndendro_idx = np.arange(0, len(dendro[\"ivl\"]))\n\nax2.imshow(corr[dendro[\"leaves\"], :][:, dendro[\"leaves\"]])\nax2.set_xticks(dendro_idx)\nax2.set_yticks(dendro_idx)\nax2.set_xticklabels(dendro[\"ivl\"], rotation=\"vertical\")\nax2.set_yticklabels(dendro[\"ivl\"])\n_ = fig.tight_layout()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we manually pick a threshold by visual inspection of the dendrogram to\ngroup our features into clusters. We then keep one feature from each cluster,\nselect those features from our dataset, and train a new random forest.\nThe test accuracy of the new random forest does not change much compared to the\nrandom forest trained on the complete dataset.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from collections import defaultdict\n\ncluster_ids = hierarchy.fcluster(dist_linkage, 1, criterion=\"distance\")\ncluster_id_to_feature_ids = defaultdict(list)\nfor idx, cluster_id in enumerate(cluster_ids):\n    cluster_id_to_feature_ids[cluster_id].append(idx)\nselected_features = [v[0] for v in 
cluster_id_to_feature_ids.values()]\nselected_features_names = X.columns[selected_features]\n\nX_train_sel = X_train[selected_features_names]\nX_test_sel = X_test[selected_features_names]\n\nclf_sel = RandomForestClassifier(n_estimators=100, random_state=42)\nclf_sel.fit(X_train_sel, y_train)\nprint(\n    \"Baseline accuracy on test data with features removed:\"\n    f\" {clf_sel.score(X_test_sel, y_test):.2}\"\n)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can finally explore the permutation importance of the selected subset of\nfeatures:\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "fig, ax = plt.subplots(figsize=(7, 6))\nplot_permutation_importance(clf_sel, X_test_sel, y_test, ax)\nax.set_title(\"Permutation Importances on selected subset of features\\n(test set)\")\nax.set_xlabel(\"Decrease in accuracy score\")\nax.figure.tight_layout()\nplt.show()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.21" } }, "nbformat": 4, "nbformat_minor": 0 }