{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from sklearn.datasets import load_iris, load_breast_cancer\n", "from sklearn.ensemble import RandomForestClassifier\n", "from sklearn.inspection import permutation_importance as skpermutation_importance" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "def accuracy_score(y_true, y_pred):\n", " return np.mean(y_true == y_pred)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "def permutation_importance(estimator, X, y, n_repeats=5, random_state=0):\n", " baseline_score = accuracy_score(y, estimator.predict(X))\n", " importances = np.zeros((X.shape[1], n_repeats))\n", " rng = np.random.RandomState(random_state)\n", " for i in range(X.shape[1]):\n", " original_features = X[:, i].copy()\n", " for j in range(n_repeats):\n", " rng.shuffle(X[:, i])\n", " importances[i, j] = accuracy_score(y, estimator.predict(X))\n", " X[:, i] = original_features\n", " importances = baseline_score - importances\n", " return importances, importances.mean(axis=1), importances.std(axis=1)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "X, y = load_iris(return_X_y=True)\n", "clf = RandomForestClassifier(random_state=0).fit(X, y)\n", "importance1, mean1, std1 = permutation_importance(clf, X, y, random_state=0)\n", "importance2 = skpermutation_importance(clf, X, y, random_state=0)\n", "assert np.allclose(importance1, importance2.importances)\n", "assert np.allclose(mean1, importance2.importances_mean)\n", "assert np.allclose(std1, importance2.importances_std)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "X, y = load_breast_cancer(return_X_y=True)\n", "clf = RandomForestClassifier(random_state=0).fit(X, y)\n", "importance1, mean1, std1 = permutation_importance(clf, X, y, random_state=0)\n", "importance2 = skpermutation_importance(clf, X, y, random_state=0)\n", "assert np.allclose(importance1, importance2.importances)\n", "assert np.allclose(mean1, importance2.importances_mean)\n", "assert np.allclose(std1, importance2.importances_std)" ] } ], "metadata": { "kernelspec": { "display_name": "dev", "language": "python", "name": "dev" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" } }, "nbformat": 4, "nbformat_minor": 2 }