{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# `FeatureBinarizerFromTrees`\n", "\n", "The `FeatureBinarizerFromTrees` transformer binarizes features for BooleanRuleCG (BRCG), LogisticRuleRegression (LogRR), and LinearRuleRegression (LinearRR) models. It generates binary features (i.e., rules) from the splits of fitted decision trees. Because each threshold is a split point that a tree selected to predict the target, the thresholds are tailored to the prediction task rather than placed at fixed quantiles, and only input features that the trees actually split on are returned. Compared to `FeatureBinarizer`, `FeatureBinarizerFromTrees` reduces the number of features needed to produce an accurate model. This shortens training times and, more importantly, often results in simpler rule sets.\n", "\n", "This notebook demonstrates basic usage of `FeatureBinarizerFromTrees`, compares it with `FeatureBinarizer`, and concludes with a formal performance comparison." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Initialize" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "pycharm": { "is_executing": false, "name": "#%%\n" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using TensorFlow backend.\n" ] } ], "source": [ "import warnings\n", "warnings.filterwarnings('ignore')\n", "\n", "from feature_binarizer_from_trees_demo import fbt_vs_fb_crime, format_results, get_corr_columns, print_metrics\n", "\n", "import numpy as np\n", "import pandas as pd\n", "from pandas import DataFrame\n", "import pickle\n", "from time import time\n", "\n", "from aix360.algorithms.rbm import BooleanRuleCG, BRCGExplainer, FeatureBinarizer, FeatureBinarizerFromTrees, \\\n", "    GLRMExplainer, LogisticRuleRegression\n", "from aix360.datasets.heloc_dataset import HELOCDataset, nan_preprocessing\n", "from aix360.datasets import MEPSDataset\n", "\n", "import shap\n", "from sklearn.datasets import load_breast_cancer\n", "from sklearn.model_selection import train_test_split\n", "\n", "def print_brcg_rules(rules):\n", "    # Print a BRCG rule set in disjunctive normal form\n", "    print('Predict Y=1 if ANY of the following rules are satisfied, otherwise Y=0:\\n')\n", "    for rule in rules:\n", "        print(f' - {rule}')\n", "    print()\n", "\n", "def fit_predict_bcrg(X_train_b, y_train, X_test_b, y_test, lambda0=0.001, lambda1=0.001):\n", "    # Fit a BRCG model, report training time and test metrics, then print the rule set\n", "    brcg = BooleanRuleCG(lambda0, lambda1, silent=True)\n", "    explainer = BRCGExplainer(brcg)\n", "    t = time()\n", "    explainer.fit(X_train_b, y_train)\n", "    print(f'Model trained in {time() - t:0.1f} seconds\\n')\n", "    print_metrics(y_test, explainer.predict(X_test_b))\n", "    print_brcg_rules(explainer.explain()['rules'])\n", "\n", "def fit_predict_logrr(X_train_b, X_train_std, y_train, X_test_b, X_test_std, y_test):\n", "    # Fit a LogRR model on binarized plus standardized ordinal features, report metrics, return the explanation\n", "    logrr = LogisticRuleRegression(lambda0=0.005, lambda1=0.001, useOrd=True, maxSolverIter=1000)\n", "    explainer = GLRMExplainer(logrr)\n", "    t = time()\n", "    explainer.fit(X_train_b, y_train, X_train_std)\n", "    print(f'Model trained in {time() - t:0.1f} seconds\\n')\n", "    print_metrics(y_test, explainer.predict(X_test_b, X_test_std))\n", "    return explainer.explain()\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Communities and Crime Data\n", "\n", "Create a binary classification problem that predicts whether a community is in the top 25% of violent crime rates, using a subset of the [UCI Communities and Crime](https://archive.ics.uci.edu/ml/datasets/Communities+and+Crime+Unnormalized) data.\n" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "pycharm": { "is_executing": false } }, "outputs": [], "source": [ "X, y = shap.datasets.communitiesandcrime()\n", "# The top quartile of the target is the positive class (builtin int replaces np.int, which NumPy 1.24 removed)\n", "y = (y >= np.percentile(y, 75)).astype(int)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "After dropping highly correlated columns, there are 88 ordinal features."
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(1994, 88)\n" ] } ], "source": [ "X.drop(columns=get_corr_columns(X), inplace=True)\n", "print(X.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Split the data, stratified on the label: 75% training, 25% test (the `train_test_split` default)." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Using `FeatureBinarizerFromTrees` with BRCG\n", "\n", "The code below initializes the default transformer and transforms the data. The default transformer fits one decision tree with a maximum depth of 4. Such a tree has at most 15 split nodes, and each split yields two binary features (`<=` and `>`), so up to 30 features are created. In this case, only 28 features are generated, perhaps because the fitted tree did not need every possible split or because the binarizer dropped duplicate features." ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "28 features.\n" ] } ], "source": [ "fbt = FeatureBinarizerFromTrees(randomState=0)\n", "X_train_b = fbt.fit_transform(X_train, y_train)\n", "X_test_b = fbt.transform(X_test)\n", "print(f'{X_train_b.shape[1]} features.')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "An obvious question may have crossed your mind by now: \"If you can sufficiently describe the data with features from a simple decision tree, why not use the decision tree as the explainable model?\" A decision tree may indeed be a satisfactory model in some cases. However, rule sets generated by BRCG can be simpler and more accessible than a decision tree. Furthermore, this is just an introductory example. As we will show, we often need more than one simple decision tree to generate features for an accurate model.\n", "\n", "Here are the binarized features for `PctKidsBornNeverMar`. The binarizer selected two thresholds: 2.64 and 4.26. Each threshold has two complementary features, one with operator `<=` and one with operator `>`. Additional operators are supported for categorical and binary features, but they are not shown in this notebook." ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "operation <= > \n", "value 2.64 4.26 2.64 4.26\n", "1765 0 0 1 1\n", "2164 1 1 0 0\n", "1691 0 1 1 0\n", "697 0 1 1 0\n", "39 1 1 0 0" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X_train_b['PctKidsBornNeverMar'].head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here, we fit the model, predict the test set, and display the training time, test metrics, and rule set.\n", "\n", "The model trains in roughly 4 seconds and, with the default BRCG parameters, produces a simple model: a single rule with four conditions that reaches almost 84% test accuracy." ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model trained in 4.3 seconds\n", "\n", "Accuracy = 0.83567\n", "Precision = 0.74157\n", "Recall = 0.52800\n", "F1 = 0.61682\n", "F1 Weighted = 0.82562\n", "\n", "Predict Y=1 if ANY of the following rules are satisfied, otherwise Y=0:\n", "\n", " - FemalePctDiv > 10.88 AND PctKidsBornNeverMar > 4.26 AND PctPopUnderPov > 9.80 AND PctSpeakEnglOnly <= 97.82\n", "\n" ] } ], "source": [ "fit_predict_bcrg(X_train_b, y_train, X_test_b, y_test)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For this data set, changing a few binarizer parameters improves the fit. The parameter values used here were selected manually for demonstration purposes. Better values are likely possible, especially if the BRCG parameters are tuned as well."
] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "36 features.\n" ] } ], "source": [ "fbt = FeatureBinarizerFromTrees(treeNum=3, treeDepth=3, treeFeatureSelection=0.5, threshRound=0, randomState=0)\n", "X_train_b = fbt.fit_transform(X_train, y_train)\n", "X_test_b = fbt.transform(X_test)\n", "print(f'{X_train_b.shape[1]} features.')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here we describe the key parameters used above. (See the `FeatureBinarizerFromTrees` API documentation for the full list of arguments.)\n", "\n", "- `treeNum` - The number of trees to fit. A value greater than one encourages a greater variety of features and thresholds.\n", "- `treeDepth` - The maximum depth of the fitted decision trees. Deeper trees have more splits and therefore generate more features.\n", "- `treeFeatureSelection` - The proportion of input features, chosen at random, to consider at each split. When more than one tree is fit, this encourages a greater variety of features. See the API documentation for the full list of accepted values.\n", "- `threshRound` - The number of decimal places to which threshold values are rounded. Rounding prevents near-duplicate thresholds such as 1.01 and 1.0. Most features in the crime data are ratios or integers, so rounding to the nearest integer (`threshRound=0`) is acceptable." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The model trains in around 10 seconds, and test accuracy improves from about 0.836 to 0.868 on this split. More features improved the fit in this case, but more features are not always better. For both explainability and accuracy, we suggest starting with a small number of features and increasing it incrementally until accuracy plateaus or the explanation is sufficient."
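, "\n", "\n", "As a rough guide for choosing a starting point, the tree settings bound the number of generated features. The bound below is a back-of-the-envelope sketch under two assumptions: each split node yields two binary features (`<=` and `>`), as in the `PctKidsBornNeverMar` example above, and a binary tree of depth $D$ has at most $2^D - 1$ split nodes. With $T$ = `treeNum` trees of depth $D$ = `treeDepth`, the number of features before duplicates are removed is at most\n", "\n", "$$F_{\\max} = 2 \\, T \\, (2^{D} - 1).$$\n", "\n", "For the default settings ($T = 1$, $D = 4$), this gives $2 \\cdot 1 \\cdot 15 = 30$. For the settings above ($T = 3$, $D = 3$), it gives $2 \\cdot 3 \\cdot 7 = 42$, compared with the 36 features actually generated. The observed count can be lower because a tree may use fewer splits or repeat a feature-threshold pair."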
] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model trained in 10.1 seconds\n", "\n", "Accuracy = 0.86774\n", "Precision = 0.81720\n", "Recall = 0.60800\n", "F1 = 0.69725\n", "F1 Weighted = 0.86074\n", "\n", "Predict Y=1 if ANY of the following rules are satisfied, otherwise Y=0:\n", "\n", " - HousVacant > 2172.00 AND PctKidsBornNeverMar > 3.00\n", " - FemalePctDiv > 12.00 AND PctKidsBornNeverMar > 4.00 AND PctPopUnderPov > 10.00 AND PctSpeakEnglOnly <= 98.00 AND racePctWhite <= 79.00\n", "\n" ] } ], "source": [ "fit_predict_bcrg(X_train_b, y_train, X_test_b, y_test)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Using `FeatureBinarizerFromTrees` with Linear Models\n", "\n", "To pass ordinal features to LogRR or LinearRR in addition to the rules, set `returnOrd=True`. As with the standard `FeatureBinarizer`, the transformer then returns a data frame of standardized ordinal features along with the binarized features. Passing the standardized features to the linear model can improve accuracy. (Make sure to set `useOrd=True` for the linear model.)" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "28 features.\n" ] } ], "source": [ "fbt = FeatureBinarizerFromTrees(treeNum=2, treeDepth=4, treeFeatureSelection=None, threshRound=0, returnOrd=True,\n", "                                randomState=0)\n", "X_train_b, X_train_std = fbt.fit_transform(X_train, y_train)\n", "X_test_b, X_test_std = fbt.transform(X_test)\n", "print(f'{X_train_b.shape[1]} features.')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The explanation for the fitted linear model lists the intercept first, followed by the rules and ordinal features in descending order of coefficient magnitude. For this feature set, the linear model does not noticeably improve accuracy over the BRCG model above (0.864 vs. 0.868)."
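, "\n", "\n", "To read the coefficients in the explanation, it helps to write out the model form. The following is a sketch that assumes LogRR uses the standard logistic link over binary rule features $r_j(x)$ and, when `useOrd=True`, standardized ordinal features $z_k(x)$:\n", "\n", "$$\\log \\frac{P(Y=1 \\mid x)}{1 - P(Y=1 \\mid x)} = \\beta_0 + \\sum_j \\beta_j \\, r_j(x) + \\sum_k \\gamma_k \\, z_k(x).$$\n", "\n", "Under this assumption, $\\beta_0$ corresponds to the `(intercept)` row, each $\\beta_j$ to a row containing a condition such as `<=` or `>`, and each $\\gamma_k$ to a row naming a bare feature (for example, `PctKidsBornNeverMar`). Because the ordinal features are standardized, $\\gamma_k$ is the change in log-odds per standard deviation of that feature, holding the other terms fixed."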
] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model trained in 5.5 seconds\n", "\n", "Accuracy = 0.86373\n", "Precision = 0.79381\n", "Recall = 0.61600\n", "F1 = 0.69369\n", "F1 Weighted = 0.85759\n", "\n" ] }, { "data": { "text/plain": [ " rule/numerical feature coefficient\n", "0 (intercept) -1.9073\n", "1 PctKidsBornNeverMar <= 4.00 AND PctSpeakEnglOn... -3.12689\n", "2 FemalePctDiv > 11.00 AND PctPopUnderPov > 10.0... 2.35018\n", "3 PctKidsBornNeverMar <= 4.00 2.16666\n", "4 PctKidsBornNeverMar 2.09972\n", "5 FemalePctDiv > 11.00 AND OwnOccQrange > 37500.... 1.70801\n", "6 FemalePctDiv 1.07142\n", "7 HousVacant <= 2172.00 -0.868909\n", "8 FemalePctDiv > 11.00 AND PctPopUnderPov > 10.0... -0.682629\n", "9 HousVacant 0.65058\n", "10 PctKidsBornNeverMar <= 3.00 -0.640423\n", "11 FemalePctDiv > 11.00 AND PctEmplManu <= 15.00 ... 0.597499\n", "12 PctSpeakEnglOnly -0.510433\n", "13 pctWInvInc <= 38.00 0.484888\n", "14 pctWInvInc -0.476527\n", "15 pctWWage <= 74.00 0.381191\n", "16 PctEmplManu -0.315508" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "fit_predict_logrr(X_train_b, X_train_std, y_train, X_test_b, X_test_std, y_test)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Compare with `FeatureBinarizer`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The standard `FeatureBinarizer` places thresholds at quantiles of each ordinal feature, with the number of thresholds specified by the user. The default setting of 9 thresholds creates 1,528 features for these data when negations are enabled (each threshold then yields both a `<=` and a `>` feature). This is a very large feature space compared with the 28 to 36 features generated by `FeatureBinarizerFromTrees` above." ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1528 features.\n" ] } ], "source": [ "fb = FeatureBinarizer(negations=True)\n", "X_train_b = fb.fit_transform(X_train)\n", "X_test_b = fb.transform(X_test)\n", "print(f'{X_train_b.shape[1]} features.')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here are the binary features associated with the `PctKidsBornNeverMar` input feature." ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "operation <= > \\\n", "value 0.570 0.930 1.240 1.670 2.130 2.770 3.528 4.942 7.432 0.570 0.930 \n", "1765 0 0 0 0 0 0 0 0 0 1 1 \n", "2164 0 0 1 1 1 1 1 1 1 1 1 \n", "1691 0 0 0 0 0 0 0 1 1 1 1 \n", "697 0 0 0 0 0 0 1 1 1 1 1 \n", "39 0 1 1 1 1 1 1 1 1 1 0 \n", "\n", "operation \n", "value 1.240 1.670 2.130 2.770 3.528 4.942 7.432 \n", "1765 1 1 1 1 1 1 1 \n", "2164 0 0 0 0 0 0 0 \n", "1691 1 1 1 1 1 0 0 \n", "697 1 1 1 1 0 0 0 \n", "39 0 0 0 0 0 0 0 " ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X_train_b['PctKidsBornNeverMar'].head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With 1,528 features, the model takes more than 5 minutes to train. Test accuracy is also lower (0.832 vs. 0.868 with the tuned `FeatureBinarizerFromTrees` features), and the rule set is more complex: five rules instead of two." ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model trained in 322.8 seconds\n", "\n", "Accuracy = 0.83166\n", "Precision = 0.72527\n", "Recall = 0.52800\n", "F1 = 0.61111\n", "F1 Weighted = 0.82207\n", "\n", "Predict Y=1 if ANY of the following rules are satisfied, otherwise Y=0:\n", "\n", " - FemalePctDiv > 15.19 AND PctKidsBornNeverMar > 7.43\n", " - PctUnemployed > 5.56 AND PctImmigRec5 <= 29.48 AND NumStreet > 14.00\n", " - PctFam2Par <= 60.46 AND PctSpeakEnglOnly <= 96.91 AND PersPerRentOccHous > 2.23\n", " - PctTeen2Par <= 68.20 AND PctKidsBornNeverMar > 2.77 AND MedOwnCostPctIncNoMtg <= 13.10 AND PctSameHouse85 > 42.21\n", " - pctWInvInc <= 38.79 AND blackPerCap > 6280.60 AND PctKidsBornNeverMar > 4.94 AND OwnOccQrange > 31100.00 AND RentQrange > 145.00\n", "\n" ] } ], "source": [ "fit_predict_bcrg(X_train_b, y_train, X_test_b, y_test)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A more reasonable number of thresholds for this data set is 4, which generates 688 features. Accuracy (0.854) is now comparable with the `FeatureBinarizerFromTrees` results, but training still takes roughly ten times longer (about 103 seconds vs. 10 seconds), and the rule set is again complex." ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "688 features.\n", "Model trained in 102.8 seconds\n", "\n", "Accuracy = 0.85371\n", "Precision = 0.76000\n", "Recall = 0.60800\n", "F1 = 0.67556\n", "F1 Weighted = 0.84795\n", "\n", "Predict Y=1 if ANY of the following rules are satisfied, otherwise Y=0:\n", "\n", " - pctWInvInc <= 38.79 AND blackPerCap > 6280.60 AND PersPerFam > 3.05 AND PctKidsBornNeverMar > 4.94\n", " - racepctblack > 16.73 AND blackPerCap <= 11133.40 AND PctImmigRecent > 19.66 AND PctUsePubTrans > 0.25\n", " - agePct12t29 <= 27.62 AND PctTeen2Par <= 68.20 AND PctKidsBornNeverMar > 2.77 AND PctSameState85 <= 91.12\n", " - racepctblack > 1.72 AND PctEmploy <= 68.48 AND MalePctDivorce > 8.43 AND FemalePctDiv > 13.42 AND PctKidsBornNeverMar > 2.77 AND PctImmigRecent > 5.68 AND PctHousOccup <= 96.34 AND NumInShelters > 5.00\n", "\n" ] } ], "source": [ "fb = FeatureBinarizer(negations=True, numThresh=4)\n", "X_train_b = fb.fit_transform(X_train)\n", "X_test_b = fb.transform(X_test)\n", "print(f'{X_train_b.shape[1]} features.')\n", "fit_predict_bcrg(X_train_b, y_train, X_test_b, y_test)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The LogRR model below has accuracy comparable to the one trained with `FeatureBinarizerFromTrees` (0.862 vs. 0.864), but it takes about 3 minutes to train and has far more terms (54 rules and ordinal features vs. 16)."
] }, { "cell_type": "code", "execution_count": 15, "metadata": { "scrolled": true }, "outputs": [], "source": [ "fb = FeatureBinarizer(negations=True, returnOrd=True, numThresh=4)\n", "X_train_b, X_train_std = fb.fit_transform(X_train)\n", "X_test_b, X_test_std = fb.transform(X_test)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model trained in 180.1 seconds\n", "\n", "Accuracy = 0.86172\n", "Precision = 0.75000\n", "Recall = 0.67200\n", "F1 = 0.70886\n", "F1 Weighted = 0.85911\n", "\n" ] }, { "data": { "text/plain": [ " rule/numerical feature coefficient\n", "0 (intercept) -1.6929\n", "1 FemalePctDiv > 9.41 AND PctFam2Par <= 83.38 AN... 1.89106\n", "2 blackPerCap <= 11133.40 AND FemalePctDiv > 9.41 1.44801\n", "3 agePct12t29 -1.37336\n", "4 FemalePctDiv > 9.41 AND MedRentPctHousInc > 24.00 1.24788\n", "5 PctKidsBornNeverMar 1.06566\n", "6 FemalePctDiv > 9.41 AND OwnOccLowQuart > 39380.00 -0.972519\n", "7 PctHousOccup -0.908306\n", "8 MedOwnCostPctIncNoMtg -0.902969\n", "9 racepctblack <= 5.28 -0.800059\n", "10 racepctblack > 1.72 AND FemalePctDiv > 9.41 0.787698\n", "11 racepctblack <= 16.73 -0.767632\n", "12 PctOccupManu <= 14.70 -0.75916\n", "13 PctTeen2Par <= 68.20 0.711791\n", "14 NumStreet <= 3.00 -0.632624\n", "15 PctPopUnderPov <= 7.42 -0.592511\n", "16 RentQrange <= 192.00 -0.589909\n", "17 FemalePctDiv > 9.41 AND OwnOccQrange > 31100.00 0.560425\n", "18 PctPersDenseHous <= 1.90 -0.53483\n", "19 PctImmigRecent <= 10.34 -0.518697\n", "20 PctVacMore6Mos <= 37.50 0.480573\n", "21 HispPerCap > 6657.60 AND FemalePctDiv > 9.41 A... 0.460934\n", "22 PctKidsBornNeverMar <= 2.77 -0.45671\n", "23 pctWPubAsst <= 4.74 -0.443499\n", "24 PctWOFullPlumb 0.430075\n", "25 LandArea <= 17.80 -0.42746\n", "26 population <= 18846.60 -0.410339\n", "27 PctWOFullPlumb <= 0.64 -0.404707\n", "28 FemalePctDiv > 9.41 AND PctWorkMomYoungKids <=... 0.394237\n", "29 MedRentPctHousInc <= 26.90 -0.389177\n", "30 pctUrban <= 99.47 -0.3877\n", "31 MalePctDivorce <= 11.51 -0.378395\n", "32 PctForeignBorn 0.37134\n", "33 FemalePctDiv > 9.41 AND PctUsePubTrans > 0.25 0.358929\n", "34 HousVacant <= 1573.80 0.347633\n", "35 HousVacant 0.347582\n", "36 HousVacant <= 761.80 0.326033\n", "37 PctNotHSGrad <= 24.46 -0.313271\n", "38 FemalePctDiv > 9.41 AND PctUsePubTrans > 0.81 0.301539\n", "39 PctPopUnderPov > 7.42 AND FemalePctDiv > 9.41 0.2864\n", "40 PctUnemployed <= 6.23 -0.229864\n", "41 FemalePctDiv <= 13.42 -0.226307\n", "42 FemalePctDiv <= 11.60 -0.22487\n", "43 LandArea 0.20781\n", "44 LemasPctOfficDrugUn 0.200326\n", "45 PersPerRentOccHous <= 2.39 -0.196087\n", "46 pctWInvInc <= 38.79 0.190517\n", "47 NumInShelters <= 5.00 -0.173206\n", "48 racePctWhite -0.157085\n", "49 NumStreet <= 0.00 0.156811\n", "50 PctPersDenseHous 0.109269\n", "51 racePctWhite > 84.60 -0.0930171\n", "52 FemalePctDiv > 9.41 AND MedRentPctHousInc > 26.90 -0.0847854\n", "53 FemalePctDiv > 9.41 AND RentQrange > 134.00 -0.0266475\n", "54 PctKidsBornNeverMar <= 4.94 0.024893" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "fit_predict_logrr(X_train_b, X_train_std, y_train, X_test_b, X_test_std, y_test)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Formal Performance Comparison" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, we compare `FeatureBinarizerFromTrees` and `FeatureBinarizer` more formally over 30 random train-test splits, using the following settings for the binarizers and models:\n", "\n", "- `FeatureBinarizerFromTrees`: `treeNum=2, treeDepth=4, treeFeatureSelection=None, returnOrd=True`\n", "- `FeatureBinarizer`: `numThresh=4, negations=True, returnOrd=True`\n", "- `BooleanRuleCG`: default parameters\n", "- `LogisticRuleRegression`: `lambda0=0.005, lambda1=0.001, useOrd=True, maxSolverIter=1000`\n", "\n", "This process takes over two hours to run, so we saved the results and load them below for display. To re-run the comparison, uncomment the code in the next cell." ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Wall time: 2h 17min 35s\n" ] } ], "source": [ "# %%time\n", "# df = fbt_vs_fb_crime(iterations=30, treeNum=2, treeDepth=4, numThresh=4, filename='./data/crime.pkl')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In the table below, 'fb' and 'fbt' indicate models fit with `FeatureBinarizer` and `FeatureBinarizerFromTrees`, respectively.\n", "\n", "For these data and settings, models trained using `FeatureBinarizerFromTrees` fit, on average, in less than one tenth of the time and generate rule sets with far fewer clauses: about half as many for BRCG (7.2 vs. 13.9) and about a third as many for LogRR (16.8 vs. 47.0). In other words, the explanations are much less complex. There are no statistically significant differences in the mean scoring metrics." ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ " time accuracy precision recall \\\n", " mean std mean std mean std mean \n", "brcg fb 103.903 (33.13) 0.855 (0.012) 0.768 (0.03) 0.607 \n", " fbt 7.209 (1.803) 0.854 (0.014) 0.756 (0.046) 0.624 \n", "logrr fb 151.336 (48.178) 0.866 (0.013) 0.746 (0.033) 0.708 \n", " fbt 8.887 (3.489) 0.863 (0.011) 0.745 (0.031) 0.691 \n", "\n", " f1 rules clauses \n", " std mean std mean std mean std \n", "brcg fb (0.057) 0.676 (0.036) 3.233 (0.935) 13.867 (3.213) \n", " fbt (0.06) 0.681 (0.036) 2.533 (0.776) 7.2 (1.901) \n", "logrr fb (0.044) 0.726 (0.029) 45.967 (4.491) 46.967 (4.491) \n", " fbt (0.035) 0.717 (0.023) 15.833 (1.802) 16.833 (1.802) " ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "with open('./data/crime.pkl', 'rb') as fl:\n", "    df = pickle.load(fl)\n", "\n", "format_results(df)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.4" }, "pycharm": { "stem_cell": { "cell_type": "raw", "metadata": { "collapsed": false }, "source": [] } } }, "nbformat": 4, "nbformat_minor": 4 }