{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Partitioning feature space"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Make sure to get latest dtreeviz**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "! pip install -q -U dtreeviz\n",
    "! pip install -q graphviz==0.17 # 0.18 deletes the `run` func I need"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "pycharm": {
     "is_executing": false
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "from sklearn.linear_model import LinearRegression, Ridge, Lasso, LogisticRegression\n",
    "from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor\n",
    "from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor\n",
    "from sklearn.datasets import load_boston, load_iris, load_wine, load_digits, \\\n",
    " load_breast_cancer, load_diabetes\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import mean_squared_error, accuracy_score\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "%config InlineBackend.figure_format = 'retina'\n",
    "\n",
    "from sklearn import tree\n",
    "from dtreeviz.trees import *\n",
    "from dtreeviz.models.shadow_decision_tree import ShadowDecTree"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_mse_leaves(X,y,max_depth):\n",
    "    \"\"\"\n",
    "    Fit a DecisionTreeRegressor with the given max_depth to (X, y) and print\n",
    "    the MSE at the root, the MSE within each leaf, and the average MSE per\n",
    "    record (leaf MSEs weighted by leaf size). Each MSE is measured against\n",
    "    the mean of the y values in that node. Returns nothing.\n",
    "    \"\"\"\n",
    "    t = DecisionTreeRegressor(max_depth=max_depth)\n",
    "    t.fit(X,y)\n",
    "    shadow = ShadowDecTree.get_shadow_tree(t, X, y, feature_names=['sqfeet'], target_name='rent')\n",
    "    _, leaves, _ = shadow._get_tree_nodes()  # only the leaves are needed here\n",
    "    n_node_samples = t.tree_.n_node_samples\n",
    "    # Plain array so the per-node sample lists index rows by position even when\n",
    "    # y is a pandas Series (assumes get_node_samples() yields row positions -- TODO confirm)\n",
    "    y_values = np.asarray(y)\n",
    "\n",
    "    # Root MSE: error of predicting the overall mean of y for every record\n",
    "    mse = mean_squared_error(y_values, [np.mean(y_values)]*len(y_values))\n",
    "    print(f\"Root {0:3d} has {n_node_samples[0]:3d} samples with MSE ={mse:6.2f}\")\n",
    "    print(\"-----------------------------------------\")\n",
    "\n",
    "    avg_mse_per_record = 0.0\n",
    "    node2samples = shadow.get_node_samples()\n",
    "    for node in leaves:\n",
    "        leafy = y_values[node2samples[node.id]]\n",
    "        n = len(leafy)\n",
    "        mse = mean_squared_error(leafy, [np.mean(leafy)]*n)\n",
    "        avg_mse_per_record += mse * n  # weight each leaf by its record count\n",
    "        print(f\"Node {node.id:3d} has {n_node_samples[node.id]:3d} samples with MSE ={mse:6.2f}\")\n",
    "\n",
    "    avg_mse_per_record /= len(y_values)\n",
    "    print(f\"Average MSE per record is {avg_mse_per_record:.1f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Regression"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "pycharm": {
     "is_executing": false
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "
\n", " | MPG | \n", "CYL | \n", "ENG | \n", "WGT | \n", "
---|---|---|---|---|
0 | \n", "18.0 | \n", "8 | \n", "307.0 | \n", "3504 | \n", "
1 | \n", "15.0 | \n", "8 | \n", "350.0 | \n", "3693 | \n", "
2 | \n", "18.0 | \n", "8 | \n", "318.0 | \n", "3436 | \n", "
\n", "mean_squared_error(y, [np.mean(y)]*len(y)) # about 60.76\n", "\n", "
\n", "rtreeviz_univar(dt, X, y,\n", " feature_names='Horsepower',\n", " markersize=5,\n", " mean_linewidth=1,\n", " target_name='MPG',\n", " fontsize=9,\n", " show={'splits'})\n", "\n", "
\n", "
\n", "lefty = y[X['ENG']<split]\n", "righty = y[X['ENG']>=split]\n", "mleft = np.mean(lefty)\n", "mright = np.mean(righty)\n", "\n", "mse_left = mean_squared_error(lefty, [mleft]\\*len(lefty))\n", "mse_right = mean_squared_error(righty, [mright]\\*len(righty))\n", "
\n", " | alcohol | \n", "malic_acid | \n", "ash | \n", "alcalinity_of_ash | \n", "magnesium | \n", "total_phenols | \n", "flavanoids | \n", "nonflavanoid_phenols | \n", "proanthocyanins | \n", "color_intensity | \n", "hue | \n", "od280/od315_of_diluted_wines | \n", "proline | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | \n", "14.23 | \n", "1.71 | \n", "2.43 | \n", "15.6 | \n", "127.0 | \n", "2.80 | \n", "3.06 | \n", "0.28 | \n", "2.29 | \n", "5.64 | \n", "1.04 | \n", "3.92 | \n", "1065.0 | \n", "
1 | \n", "13.20 | \n", "1.78 | \n", "2.14 | \n", "11.2 | \n", "100.0 | \n", "2.65 | \n", "2.76 | \n", "0.26 | \n", "1.28 | \n", "4.38 | \n", "1.05 | \n", "3.40 | \n", "1050.0 | \n", "
2 | \n", "13.16 | \n", "2.36 | \n", "2.67 | \n", "18.6 | \n", "101.0 | \n", "2.80 | \n", "3.24 | \n", "0.30 | \n", "2.81 | \n", "5.68 | \n", "1.03 | \n", "3.17 | \n", "1185.0 | \n", "