{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.datasets import load_boston\n",
    "from sklearn.tree import DecisionTreeRegressor as skDecisionTreeRegressor"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Implementation 1\n",
    "- similat to scikit-learn criterion=\"mse\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TreeNode():\n",
    "    def __init__(self):\n",
    "        self.left_child = -1\n",
    "        self.right_child = -1\n",
    "        self.feature = None\n",
    "        self.threshold = None\n",
    "        self.impurity = None\n",
    "        self.n_node = None\n",
    "        self.value = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DecisionTreeRegressor():\n",
    "    def __init__(self, max_depth=2):\n",
    "        self.max_depth = max_depth\n",
    "\n",
    "    def _build_tree(self, X, y, cur_depth, parent, is_left):\n",
    "        if cur_depth == self.max_depth:\n",
    "            cur_node = TreeNode()\n",
    "            cur_node.impurity = np.mean(np.square(y)) - np.square(np.mean(y))\n",
    "            cur_node.n_node = X.shape[0]\n",
    "            cur_node.value = np.mean(y)\n",
    "            cur_id = len(self.tree_)\n",
    "            self.tree_.append(cur_node)\n",
    "            if parent is not None:\n",
    "                if is_left:\n",
    "                    self.tree_[parent].left_child = cur_id\n",
    "                else:\n",
    "                    self.tree_[parent].right_child = cur_id\n",
    "            return\n",
    "        best_improvement = -np.inf\n",
    "        best_feature = None\n",
    "        best_threshold = None\n",
    "        best_left_ind = None\n",
    "        best_right_ind = None\n",
    "        sum_total = np.sum(y)\n",
    "        for i in range(X.shape[1]):\n",
    "            sum_left = 0\n",
    "            sum_right = sum_total\n",
    "            n_left = 0\n",
    "            n_right = X.shape[0]\n",
    "            ind = np.argsort(X[:, i])\n",
    "            for j in range(ind.shape[0] - 1):\n",
    "                sum_left += y[ind[j]]\n",
    "                sum_right -= y[ind[j]]\n",
    "                n_left += 1\n",
    "                n_right -= 1\n",
    "                if j + 1 < ind.shape[0] - 1 and np.isclose(X[ind[j], i], X[ind[j + 1], i]):\n",
    "                    continue\n",
    "                cur_improvement = sum_left * sum_left / n_left + sum_right * sum_right / n_right\n",
    "                if cur_improvement > best_improvement:\n",
    "                    best_improvement = cur_improvement\n",
    "                    best_feature = i\n",
    "                    best_threshold = X[ind[j], i]\n",
    "                    best_left_ind = ind[:j + 1]\n",
    "                    best_right_ind = ind[j + 1:]\n",
    "        cur_node = TreeNode()\n",
    "        cur_node.feature = best_feature\n",
    "        cur_node.threshold = best_threshold\n",
    "        cur_node.impurity = np.mean(np.square(y)) - np.square(np.mean(y))\n",
    "        cur_node.n_node = X.shape[0]\n",
    "        cur_node.value = np.mean(y)\n",
    "        cur_id = len(self.tree_)\n",
    "        self.tree_.append(cur_node)\n",
    "        if parent is not None:\n",
    "            if is_left:\n",
    "                self.tree_[parent].left_child = cur_id\n",
    "            else:\n",
    "                self.tree_[parent].right_child = cur_id\n",
    "        if cur_depth < self.max_depth:\n",
    "            self._build_tree(X[best_left_ind], y[best_left_ind], cur_depth + 1, cur_id, True)\n",
    "            self._build_tree(X[best_right_ind], y[best_right_ind], cur_depth + 1, cur_id, False)\n",
    "\n",
    "    def fit(self, X, y):\n",
    "        self.n_features = X.shape[1]\n",
    "        self.tree_ = []\n",
    "        self._build_tree(X, y, 0, None, None)\n",
    "        return self\n",
    "\n",
    "    def apply(self, X):\n",
    "        pred = np.zeros(X.shape[0], dtype=np.int)\n",
    "        for i in range(X.shape[0]):\n",
    "            cur_node = 0\n",
    "            while self.tree_[cur_node].left_child != -1:\n",
    "                if X[i][self.tree_[cur_node].feature] <= self.tree_[cur_node].threshold:\n",
    "                    cur_node = self.tree_[cur_node].left_child\n",
    "                else:\n",
    "                    cur_node = self.tree_[cur_node].right_child\n",
    "            pred[i] = cur_node\n",
    "        return pred\n",
    "\n",
    "    def predict(self, X):\n",
    "        pred = self.apply(X)\n",
    "        return np.array([self.tree_[p].value for p in pred])\n",
    "\n",
    "    @property\n",
    "    def feature_importances_(self):\n",
    "        importances = np.zeros(self.n_features)\n",
    "        for node in self.tree_:\n",
    "            if node.left_child != -1:\n",
    "                left_child = self.tree_[node.left_child]\n",
    "                right_child = self.tree_[node.right_child]\n",
    "                importances[node.feature] += (node.n_node * node.impurity\n",
    "                                              - left_child.n_node * left_child.impurity\n",
    "                                              - right_child.n_node * right_child.impurity)\n",
    "        # feature importance used in GBDT\n",
    "        # importances /=  self.tree_[0].n_node\n",
    "        return importances / np.sum(importances)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "X, y = load_boston(return_X_y=True)\n",
    "clf1 = DecisionTreeRegressor(max_depth=2).fit(X, y)\n",
    "clf2 = skDecisionTreeRegressor(max_depth=2, random_state=0).fit(X, y)\n",
    "assert np.array_equal([node.left_child for node in clf1.tree_],\n",
    "                      clf2.tree_.children_left)\n",
    "assert np.array_equal([node.right_child for node in clf1.tree_],\n",
    "                      clf2.tree_.children_right)\n",
    "clf1_leaf = [node.left_child != -1 for node in clf1.tree_]\n",
    "clf2_leaf = clf2.tree_.children_left != -1\n",
    "assert np.array_equal(clf1_leaf, clf2_leaf)\n",
    "assert np.array_equal(np.array([node.feature for node in clf1.tree_])[clf1_leaf],\n",
    "                      clf2.tree_.feature[clf2_leaf])\n",
    "# threshold is slightly different\n",
    "# because scikit-learn users the average value between current point and the next point\n",
    "assert np.allclose([node.impurity for node in clf1.tree_],\n",
    "                    clf2.tree_.impurity)\n",
    "assert np.array_equal([node.n_node for node in clf1.tree_],\n",
    "                      clf2.tree_.n_node_samples)\n",
    "assert np.allclose([node.value for node in clf1.tree_],\n",
    "                   clf2.tree_.value.ravel())\n",
    "assert np.allclose(clf1.feature_importances_, clf2.feature_importances_)\n",
    "pred1 = clf1.apply(X)\n",
    "pred2 = clf2.apply(X)\n",
    "assert np.array_equal(pred1, pred2)\n",
    "pred1 = clf1.predict(X)\n",
    "pred2 = clf2.predict(X)\n",
    "assert np.allclose(pred1, pred2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Implementation 2\n",
    "- similat to scikit-learn criterion=\"friedman_mse\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DecisionTreeRegressor():\n",
    "    def __init__(self, max_depth=2):\n",
    "        self.max_depth = max_depth\n",
    "\n",
    "    def _build_tree(self, X, y, cur_depth, parent, is_left):\n",
    "        if cur_depth == self.max_depth:\n",
    "            cur_node = TreeNode()\n",
    "            cur_node.impurity = np.mean(np.square(y)) - np.square(np.mean(y))\n",
    "            cur_node.n_node = X.shape[0]\n",
    "            cur_node.value = np.mean(y)\n",
    "            cur_id = len(self.tree_)\n",
    "            self.tree_.append(cur_node)\n",
    "            if parent is not None:\n",
    "                if is_left:\n",
    "                    self.tree_[parent].left_child = cur_id\n",
    "                else:\n",
    "                    self.tree_[parent].right_child = cur_id\n",
    "            return\n",
    "        best_improvement = -np.inf\n",
    "        best_feature = None\n",
    "        best_threshold = None\n",
    "        best_left_ind = None\n",
    "        best_right_ind = None\n",
    "        sum_total = np.sum(y)\n",
    "        for i in range(X.shape[1]):\n",
    "            sum_left = 0\n",
    "            sum_right = sum_total\n",
    "            n_left = 0\n",
    "            n_right = X.shape[0]\n",
    "            ind = np.argsort(X[:, i])\n",
    "            for j in range(ind.shape[0] - 1):\n",
    "                sum_left += y[ind[j]]\n",
    "                sum_right -= y[ind[j]]\n",
    "                n_left += 1\n",
    "                n_right -= 1\n",
    "                if j + 1 < ind.shape[0] - 1 and np.isclose(X[ind[j], i], X[ind[j + 1], i]):\n",
    "                    continue\n",
    "                cur_improvement = n_left * n_right * np.square(sum_left / n_left - sum_right / n_right)\n",
    "                if cur_improvement > best_improvement:\n",
    "                    best_improvement = cur_improvement\n",
    "                    best_feature = i\n",
    "                    best_threshold = X[ind[j], i]\n",
    "                    best_left_ind = ind[:j + 1]\n",
    "                    best_right_ind = ind[j + 1:]\n",
    "        cur_node = TreeNode()\n",
    "        cur_node.feature = best_feature\n",
    "        cur_node.threshold = best_threshold\n",
    "        cur_node.impurity = np.mean(np.square(y)) - np.square(np.mean(y))\n",
    "        cur_node.n_node = X.shape[0]\n",
    "        cur_node.value = np.mean(y)\n",
    "        cur_id = len(self.tree_)\n",
    "        self.tree_.append(cur_node)\n",
    "        if parent is not None:\n",
    "            if is_left:\n",
    "                self.tree_[parent].left_child = cur_id\n",
    "            else:\n",
    "                self.tree_[parent].right_child = cur_id\n",
    "        if cur_depth < self.max_depth:\n",
    "            self._build_tree(X[best_left_ind], y[best_left_ind], cur_depth + 1, cur_id, True)\n",
    "            self._build_tree(X[best_right_ind], y[best_right_ind], cur_depth + 1, cur_id, False)\n",
    "\n",
    "    def fit(self, X, y):\n",
    "        self.n_features = X.shape[1]\n",
    "        self.tree_ = []\n",
    "        self._build_tree(X, y, 0, None, None)\n",
    "        return self\n",
    "\n",
    "    def apply(self, X):\n",
    "        pred = np.zeros(X.shape[0], dtype=np.int)\n",
    "        for i in range(X.shape[0]):\n",
    "            cur_node = 0\n",
    "            while self.tree_[cur_node].left_child != -1:\n",
    "                if X[i][self.tree_[cur_node].feature] <= self.tree_[cur_node].threshold:\n",
    "                    cur_node = self.tree_[cur_node].left_child\n",
    "                else:\n",
    "                    cur_node = self.tree_[cur_node].right_child\n",
    "            pred[i] = cur_node\n",
    "        return pred\n",
    "\n",
    "    def predict(self, X):\n",
    "        pred = self.apply(X)\n",
    "        return np.array([self.tree_[p].value for p in pred])\n",
    "\n",
    "    @property\n",
    "    def feature_importances_(self):\n",
    "        importances = np.zeros(self.n_features)\n",
    "        for node in self.tree_:\n",
    "            if node.left_child != -1:\n",
    "                left_child = self.tree_[node.left_child]\n",
    "                right_child = self.tree_[node.right_child]\n",
    "                importances[node.feature] += (node.n_node * node.impurity\n",
    "                                              - left_child.n_node * left_child.impurity\n",
    "                                              - right_child.n_node * right_child.impurity)\n",
    "        # feature importance used in GBDT\n",
    "        # importances /=  self.tree_[0].n_node\n",
    "        return importances / np.sum(importances)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "X, y = load_boston(return_X_y=True)\n",
    "clf1 = DecisionTreeRegressor(max_depth=2).fit(X, y)\n",
    "clf2 = skDecisionTreeRegressor(max_depth=2, criterion=\"friedman_mse\", random_state=0).fit(X, y)\n",
    "assert np.array_equal([node.left_child for node in clf1.tree_],\n",
    "                      clf2.tree_.children_left)\n",
    "assert np.array_equal([node.right_child for node in clf1.tree_],\n",
    "                      clf2.tree_.children_right)\n",
    "clf1_leaf = [node.left_child != -1 for node in clf1.tree_]\n",
    "clf2_leaf = clf2.tree_.children_left != -1\n",
    "assert np.array_equal(clf1_leaf, clf2_leaf)\n",
    "assert np.array_equal(np.array([node.feature for node in clf1.tree_])[clf1_leaf],\n",
    "                      clf2.tree_.feature[clf2_leaf])\n",
    "# threshold is slightly different\n",
    "# because scikit-learn users the average value between current point and the next point\n",
    "assert np.allclose([node.impurity for node in clf1.tree_],\n",
    "                    clf2.tree_.impurity)\n",
    "assert np.array_equal([node.n_node for node in clf1.tree_],\n",
    "                      clf2.tree_.n_node_samples)\n",
    "assert np.allclose([node.value for node in clf1.tree_],\n",
    "                   clf2.tree_.value.ravel())\n",
    "assert np.allclose(clf1.feature_importances_, clf2.feature_importances_)\n",
    "pred1 = clf1.apply(X)\n",
    "pred2 = clf2.apply(X)\n",
    "assert np.array_equal(pred1, pred2)\n",
    "pred1 = clf1.predict(X)\n",
    "pred2 = clf2.predict(X)\n",
    "assert np.allclose(pred1, pred2)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dev",
   "language": "python",
   "name": "dev"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}