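{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "### 1. Principle\n", "This section combines the predictions of a tree model with a probability distribution. We assume the target $y$ follows some distribution whose parameter is produced by the tree model, and the goal is to make the probability (likelihood) of the observed targets as high as possible, as illustrated below:\n", "![maximum likelihood estimation](./source/10_集成学习_极大似然估计.png)\n", "\n", "The point of highest probability is usually determined by one of the distribution's parameters (typically the mean), so we let the tree model's output represent that parameter (possibly through a transformation such as $\\exp$) and train the output to approach the maximum likelihood estimate, i.e.: \n", "\n", "$$\n", "\\hat{y}\\rightarrow \\mu_{ML}\n", "$$ \n", "\n", "The following sections derive the losses, first derivatives and second derivatives for Poisson, gamma and Tweedie regression." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### 2. Poisson regression\n", "The Poisson distribution is: \n", "\n", "$$\n", "P(y\\mid\\lambda)=\\frac{\\lambda^y}{y!}e^{-\\lambda}\n", "$$ \n", "\n", "where $y\\geq 0$ is the target (a count) and $\\lambda$ is the model parameter, which is exactly the mean of the distribution. The Poisson distribution requires $\\lambda>0$, so we fit $\\lambda$ with the exponential of $\\hat{y}$, i.e. we set: \n", "\n", "$$\n", "\\lambda=e^{\\hat{y}}\n", "$$ \n", "\n", "Substituting this, the likelihood of $N$ samples is: \n", "\n", "$$\n", "\\prod_{i=1}^N\\frac{e^{y_i\\hat{y_i}}e^{-e^{\\hat{y_i}}}}{y_i!}\n", "$$ \n", "\n", "Since $y_i!$ does not depend on $\\hat{y}$, it can be dropped; taking the negative logarithm then turns the problem into a minimization: \n", "\n", "$$\n", "L(y,\\hat{y})=\\sum_{i=1}^N(e^{\\hat{y_i}}-y_i\\hat{y_i})\n", "$$ \n", "\n", "So the first and second derivatives are: \n", "\n", "$$\n", "\\frac{\\partial L(y,\\hat{y})}{\\partial \\hat{y}}=e^{\\hat{y}}-y\\\\\n", "\\frac{\\partial^2 L(y,\\hat{y})}{{\\partial \\hat{y}}^2}=e^{\\hat{y}}\\\\\n", "$$ " ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### 3. Gamma regression\n", "The gamma distribution is: \n", "\n", "$$\n", "p(y\\mid\\alpha,\\lambda)=\\frac{1}{\\Gamma(\\alpha)\\lambda^\\alpha}y^{\\alpha-1}e^{-y/\\lambda}\n", "$$ \n", "\n", "where $y>0$ is the target, $\\alpha$ is the shape parameter, $\\lambda$ is the scale parameter, and $\\Gamma(\\cdot)$ is the gamma function (it drops out of the derivation later, so its definition is not given here). The mean of the gamma distribution is $\\alpha\\lambda$, which is inconvenient to map to $\\hat{y}$ directly, so we reparameterize with $\\alpha=1/\\phi,\\lambda=\\phi\\mu$: the mean becomes $\\alpha\\lambda=\\mu$ and the variance $\\alpha\\lambda^2=\\phi\\mu^2$. The gamma distribution is now: \n", "\n", "$$\n", "p(y\\mid\\mu,\\phi)=\\frac{1}{y\\Gamma(1/\\phi)}(\\frac{y}{\\mu\\phi})^{1/\\phi}exp[-\\frac{y}{\\mu\\phi}]\n", "$$ \n", "\n", "Here $\\mu$ is the mean parameter of the gamma distribution and $\\phi$ is its dispersion parameter: for a given mean, the larger the dispersion parameter, the more spread out the distribution. Rewriting the expression further: \n", "\n", "$$\n", 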
"p(y\\mid\\mu,\\phi)=exp[\\frac{-y/\\mu-ln\\mu}{\\phi}+\\frac{1-\\phi}{\\phi}lny-\\frac{ln\\phi}{\\phi}-ln\\Gamma(\\frac{1}{\\phi})]\n", "$$ \n", "\n", "As with the Poisson distribution, we set: \n", "\n", "$$\n", "\\mu=e^{\\hat{y}}\n", "$$ \n", "\n", "Since $\\phi$ does not depend on $\\mu$, it can be treated as a constant in the maximum likelihood estimation. Taking the negative log-likelihood as the loss, and dropping the terms without $\\mu$ as well as the positive factor $1/\\phi$, gives: \n", "\n", "$$\n", "L(y,\\hat{y})=\\sum_{i=1}^N(\\frac{y_i}{e^{\\hat{y_i}}}+\\hat{y_i})\n", "$$ \n", "\n", "Hence the first and second derivatives are: \n", "\n", "$$\n", "\\frac{\\partial L(y,\\hat{y})}{\\partial \\hat{y}}=1-ye^{-\\hat{y}}\\\\\n", "\\frac{\\partial^2 L(y,\\hat{y})}{{\\partial \\hat{y}}^2}=ye^{-\\hat{y}}\\\\\n", "$$ \n", "\n", "Note: when $y$ and $\\hat{y}$ are vectors, the products $ye^{-\\hat{y}}$ above are element-wise." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### 4. Tweedie regression\n", "Tweedie regression is built on the Tweedie family of distributions, whose variance is $\\phi\\mu^p$ and which contains several familiar distributions as special cases, e.g. the Gaussian ($p=0$), Poisson ($p=1$) and gamma ($p=2$) distributions. The family is controlled by the power hyperparameter $p$, and the log-likelihood takes a different form for different values of $p$: \n", "\n", "$$\n", "g(y,\\phi)+\\left\\{\\begin{matrix}\n", "\\frac{1}{\\phi}(y\\log\\mu-\\mu) & p=1\\\\ \n", "\\frac{1}{\\phi}(-\\frac{y}{\\mu}-\\log\\mu) & p=2 \\\\ \n", "\\frac{1}{\\phi}(y\\frac{\\mu^{1-p}}{1-p}-\\frac{\\mu^{2-p}}{2-p}) & p\\neq 1,p\\neq 2\n", "\\end{matrix}\\right.\n", "$$ \n", "\n", "where $g(y,\\phi)$ collects the terms that do not involve $\\mu$. As before, we set: \n", "\n", "$$\n", "\\mu=e^{\\hat{y}}\n", "$$ \n", "\n", "Since everything other than $\\mu$ can be treated as a constant (and the positive factor $1/\\phi$ does not change the minimizer), the negative log-likelihood simplifies to the loss: \n", "\n", "$$\n", "L(y,\\hat{y})=\\left\\{\\begin{matrix}\n", "\\sum_{i=1}^N(e^{\\hat{y_i}}-y_i\\hat{y_i}) & p=1\\\\ \n", "\\sum_{i=1}^N(\\hat{y_i}+y_ie^{-\\hat{y_i}}) & p=2 \\\\ \n", "\\sum_{i=1}^N(\\frac{exp[\\hat{y_i}(2-p)]}{2-p}-y_i\\frac{exp[\\hat{y_i}(1-p)]}{1-p}) & p\\neq 1,p\\neq 2\n", "\\end{matrix}\\right.\n", "$$ \n", "\n", "So the first derivative is: \n", "\n", "$$\n", "\\frac{\\partial L(y,\\hat{y})}{\\partial \\hat{y}}=\\left\\{\\begin{matrix}\n", "e^{\\hat{y}}-y & p=1\\\\ \n", "1-ye^{-\\hat{y}} & p=2 \\\\ \n", "e^{\\hat{y}(2-p)}-ye^{\\hat{y}(1-p)} & p\\neq 1,p\\neq 2\n", "\\end{matrix}\\right.\n", "$$\n", "\n", "and the second derivative is: \n", "\n", "$$\n", "\\frac{\\partial^2 L(y,\\hat{y})}{{\\partial \\hat{y}}^2}=\\left\\{\\begin{matrix}\n", "e^{\\hat{y}} & p=1\\\\ \n", "ye^{-\\hat{y}} & p=2 \\\\ \n", "(2-p)e^{\\hat{y}(2-p)}-(1-p)ye^{\\hat{y}(1-p)} & p\\neq 1,p\\neq 2\n", "\\end{matrix}\\right.\n", "$$\n", "\n", "For $1<p<2$ (and $y\\geq 0$) the general second derivative is always positive, whereas for $p<1$ or $p>2$ it can become negative for some values of $\\hat{y}$." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### 5. Code implementation\n", "Building on the computation framework from the previous section, only small adjustments are needed: each new loss supplies its own first and second derivatives in `_get_gradient_hess`, and for the Poisson, gamma and Tweedie losses `predict` maps the raw score back to the mean with $\\exp$." ] },
{ "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import os\n", "os.chdir('../')\n", "import matplotlib.pyplot as plt\n", "%matplotlib inline\n", "from ml_models.ensemble import XGBoostBaseTree\n", "from ml_models import utils\n", "import copy\n", "import numpy as np\n", "\n", "\"\"\"\n", "Implementation of the XGBoost regression model (the packaged version lives in ml_models.ensemble)\n", "\"\"\"\n", "\n", "class XGBoostRegressor(object):\n", "    def __init__(self, base_estimator=None, n_estimators=10, learning_rate=1.0, loss='squarederror', p=2.5):\n", "        \"\"\"\n", "        :param base_estimator: base learner\n", "        :param n_estimators: number of base learners (boosting rounds)\n", "        :param learning_rate: learning rate; shrinks the weight of each base learner to reduce overfitting\n", "        :param loss: loss function; supports squarederror, logistic, poisson, gamma, tweedie\n", "        :param p: Tweedie power parameter, only used when loss='tweedie'\n", "        \"\"\"\n", "        self.base_estimator = base_estimator\n", "        self.n_estimators = n_estimators\n", "        self.learning_rate = learning_rate\n", "        if self.base_estimator is None:\n", "            # use a decision stump by default\n", "            self.base_estimator = XGBoostBaseTree()\n", "        # homogeneous learners: copy the single base learner n_estimators times\n", "        if not isinstance(self.base_estimator, list):\n", "            estimator = self.base_estimator\n", "            self.base_estimator = [copy.deepcopy(estimator) for _ in range(0, self.n_estimators)]\n", "        # heterogeneous learners: a list of base learners was passed in\n", "        else:\n", "            self.n_estimators = len(self.base_estimator)\n", "        self.loss = loss\n", "        self.p = p\n", "\n", "    def _get_gradient_hess(self, y, y_pred):\n", "        \"\"\"\n", "        Compute the first- and second-order derivatives of the loss\n", "        :param y: true values\n", "        :param y_pred: current raw predictions (before the exp link)\n", "        :return: gradient, hessian\n", "        \"\"\"\n", "        if self.loss == 'squarederror':\n", "            return y_pred - y, np.ones_like(y)\n", "        elif self.loss == 'logistic':\n", "            # log loss with 0/1 labels: the gradient w.r.t. the raw score is sigmoid(y_pred) - y\n", "            return utils.sigmoid(y_pred) - y, utils.sigmoid(y_pred) * (1 - utils.sigmoid(y_pred))\n", "        elif self.loss == 'poisson':\n", "            return np.exp(y_pred) - y, np.exp(y_pred)\n", "        elif self.loss == 'gamma':\n", "            return 1.0 - y * np.exp(-1.0 * y_pred), y * np.exp(-1.0 * y_pred)\n", "        elif self.loss == 'tweedie':\n", "            if self.p == 1:\n", "                return np.exp(y_pred) - y, np.exp(y_pred)\n", "            elif self.p == 2:\n", "                return 1.0 - y * np.exp(-1.0 * y_pred), y * np.exp(-1.0 * y_pred)\n", "            else:\n", "                return np.exp(y_pred * (2.0 - self.p)) - y * np.exp(y_pred * (1.0 - self.p)), (2.0 - self.p) * np.exp(\n", "                    y_pred * (2.0 - self.p)) - (1.0 - self.p) * y * np.exp(y_pred * (1.0 - self.p))\n", "\n", "    def fit(self, x, y):\n", "        # start boosting from a raw score of 0 (float, so the += below also works for integer targets)\n", "        y_pred = np.zeros_like(y, dtype=float)\n", "        g, h = self._get_gradient_hess(y, y_pred)\n", "        for index in range(0, self.n_estimators):\n", "            self.base_estimator[index].fit(x, g, h)\n", "            y_pred += self.base_estimator[index].predict(x) * self.learning_rate\n", "            g, h = self._get_gradient_hess(y, y_pred)\n", "\n", "    def predict(self, x):\n", "        # scale every tree by learning_rate, exactly as in fit()\n", "        rst_np = self.learning_rate * np.sum(\n", "            [self.base_estimator[i].predict(x) for i in range(0, self.n_estimators)], axis=0)\n", "        if self.loss in [\"poisson\", \"gamma\", \"tweedie\"]:\n", "            # log link: map the raw score back to the mean, mu = exp(raw score)\n", "            return np.exp(rst_np)\n", "        else:\n", "            return rst_np" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Test the Poisson, gamma and Tweedie regressions" ] },
{ "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "data = np.linspace(1, 10, num=100)\n", "target = np.sin(data) + np.random.random(size=100) + 1  # add noise; the +1 shift keeps the targets non-negative\n", "data = data.reshape((-1, 1))" ] },
{ "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "model = 
XGBoostRegressor(loss='poisson')\n", "model.fit(data, target)\n", "plt.scatter(data, target)\n", "plt.plot(data, model.predict(data), color='r')" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "model = XGBoostRegressor(loss='gamma')\n", "model.fit(data, target)\n", "plt.scatter(data, target)\n", "plt.plot(data, model.predict(data), color='r')" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAIABJREFUeJzt3XuUXWV9//H3N8lM5pIhI7mRTBISMI1y0aQE5NJaGluDoJKFF8D+bP0t+ktRUelq81sBW1FXW2Lpgh/WaxSqVgu0SANqLFZivaCCkUQil2gUhEwCCcJkkswkmUm+vz/OOZMzJ+ey9zl7n33OPp/XWrPCnLPn7IczZ7772d/n+zyPuTsiIpIuE5JugIiIRE/BXUQkhRTcRURSSMFdRCSFFNxFRFJIwV1EJIUU3EVEUkjBXUQkhRTcRURSaFJSJ54+fbovWLAgqdOLiDSln/70py+4+4xKxyUW3BcsWMCmTZuSOr2ISFMys98EOa5iWsbMOszsYTP7mZk9ZmYfKXLMZDO7y8y2m9lDZrYgfJNFRCQqQXLuh4Dl7v5qYAlwkZmdW3DMVcBL7v5y4BbgY9E2U0REwqgY3D1jf/bbtuxX4VKSlwJfzP733cDrzMwia6WIiIQSqFrGzCaa2RZgN/Df7v5QwSF9wLMA7j4K7AWmRdlQEREJLlBwd/cj7r4EmAucY2ZnFBxSrJd+3ELxZrbKzDaZ2aY9e/aEb62IiAQSqs7d3QeA/wEuKnhqBzAPwMwmAVOBF4v8/Dp3X+buy2bMqFjJIyIiVapYCmlmM4ARdx8ws07gjzh+wPQ+4M+AHwFvBTa6tngSaUnrN/dz0/3b2DkwzJzeTlavWMzKpX1JN6vlBKlznw180cwmkunp/7u7f93MPgpscvf7gNuAfzWz7WR67FfE1mIRaVjrN/dz3T1bGR45AkD/wDDX3bMVQAG+zioGd3d/FFha5PEP5f33QeBt0TZNRJJSbe/7pvu3jQX2nOGRI9x0/zYF9zpLbIaqiDSmWnrfOweGQz0u8dHCYSIyTrnedyVzejtDPS7xUXAXkXFq6X2vXrGYzraJ4x7rbJvI6hWLI2mbBKe0jIiMM6e3k/4igTxI7zuXtqm2WkaVNtGxpCoWly1b5loVUqTxFObcIdP7vvGyM2MNt
MXOa2RmQ/Yp0I8xs5+6+7JKxzVdzz3IlV1Xf5Hq1dr7rlaxXH+u66mSyvCaKrgHGcVXna1I7VYu7av730ulnL5KKsNpqgHVIKP4tYz0i0hw6zf3c8HajSxc8w0uWLuR9Zv7a3q9IDl9lVQG11TBPcgofqlj+geGI/kAisixO+T+gWGcY3fItfx9Fau0KaSSyuCaKrgHqaEt98uP4gNYTNQ9GJFGF8cd8sqlfdx42Zn0Zf+GC5eaVUllOE0V3IPU0Fa6+kedoomjByPS6OKaibpyaR8PrlnO02sv4ZbLl9DX24mRqZaJu1onbZpqQDXIKH7+McVqdSHavJ3W0pBWVEstfFBJDOqmSVMFdwj2C88dc8HajbF/ALWWhrSi1SsWF62FV9qkcTRdcC+lWG17PT6A9ejBiDSapGrhJbhUzFAtN6MO4v0AJjWbTyQJmiCYvNTOUC2mXN77wTXLY/3wqQcjrUITBJtLKoJ70nlvDfxIK1DxQHNpqlLIUrSGtEj8ku5ESTipCO5aQ1okvLCT79SJai6pCO75M9s04UGksmom31XTidLs7eSkIucOynuLhFFN/jxs8YAGYJOVmuAuIsFVmz8P04nSAGyyFNwToFphSVo9Jt9pADZZqci511steUQtNCaNoB5FCBqATZZ67iHVmkfUrarUQ6W7w6gn3yW1/IeUVnH5ATObB3wJOAk4Cqxz91sLjrkQuBd4KvvQPe7+0XKv26wbZJdajKyvt5MH1yyv+PML13yDYu+4AU+tvaT2BkrLq/eSGEGX/5ja2YYZDAyNKB1Zg6DLDwRJy4wCf+XurwTOBd5rZqcVOe777r4k+1U2sDezWvOIulWVuNV7q8lKd6MPrlnOLZcv4dDoUV4aGlE6sk4qpmXcfRewK/vf+8zsCaAPeDzmtjWkSgNRlW6HdasqUSr2eav3QGaQ8ykdWX+hBlTNbAGwFHioyNPnmdnPzOybZnZ6iZ9fZWabzGzTnj17Qje2EZQbiAoyWKoJVxKVUp+33q62osfHdXcY5G40zguOJkoVF3hA1cymAF8FrnX3wYKnHwFOdvf9ZnYxsB5YVPga7r4OWAeZnHvVrU5QuYGoC9ZuDNQ70YQriUKp3vDkSRPobJtYt7vDIHejcZVeaqJUaYGCu5m1kQnsX3H3ewqfzw/27r7BzD5lZtPd/YXomto4SgVn1fVKPZX6XO0dHuGWy5eEqoSpZe5FkMqbuNKRSveUVjG4m5kBtwFPuPvNJY45CXje3d3MziGT7vltpC1tAtqVSeqp3OctzN1hFL3fSueLa98DdahKC9JzvwB4J7DVzLZkH7semA/g7p8B3gq828xGgWHgCk9qi6cEabBU6imqz1u9er9xpCPVoSotSLXMD8iUYZc75hPAJ6JqVLPSrkxST1F93pq596sOVWmaoRoxDZZKPUXxeRvr/bpz3jNb6Rw5CMC07nb4esw34IsXw6Ljai8CU4eqtFRskC0i1cvl3Jdsf4Q77vxgfU8+dy488wxY2eSA5GmpDbJFpHq5Xu7O678KwHtW3cLbX/s7XLh4Zrwn3rABbrgBnngCTis26V1qoeAuIpkAP3UfnHwyn/rstfU56bRpmeD+wAMK7jFoqeCuddRFytiyBZYurd/5Fi6EhQvZdc83eOuB0/V3GbGWWc9d66iLlLF/P/ziF7BkSV1P+/SSc+n+0Q/Y9eJ+/V1GrGWCe71XyhNpKo8+Cu717bkD/9J+CiccOsAZz/9q7DH9XUaj+dIyR4/CtvC/+M7tv+DUIo/bC8ATs8v/8IwZMH166HPmU0pIGtqW7PzEOvfcN8x4BR8BLvjNz3h09u+MPd4MNfaNrvmC++HDVQ2+fLvck7dV+OHubti0CV7xitDnBS1uJE1g82Y48USYN6+up23vm8OT00/m/Kd/xqfPfdvY45phWrvmC+5tbXDnnaF/7CdPvcgdDz/D4SNHxx5rnziBK8+Zz9kLTyz9g6OjcPXV8Dd/A3ffXU2LtbiRJ
CrQXeOWLZlee53rzVevWMzD/7mEtz/yTSaPHubQpHbNMI1I8wX3iRPh8stD/9jZQH+RD/nZQYLr9u3w4Q/zf97zSb59woLQaZVmnt4tzS3QXePICGzdCtdcU/f2rVzax4+vXEnHw/fyu/1P8syrX6OUZUQ0QzWAr/9gG+eteA2/mD6fK6/4BzALtSdlrfuuilSr1GcPMp+/1SsWs7LtJTjzTPjyl+FP/qTOLQQGBzMpoTVr4O/+rv7nbzKaoRqhG3/Qzx+dfwUf+fZn+YOnHuG7p5wVKq2ixY0kKeXuDnO9+Hnt2zgL6j6YOuaEE+Dss+ErX4Ghofqcc8IEuOoqeOUr63O+BCi4B7BzYJh/W3IRV/1kPR/ceBunvLjj2JOdWyv+/ErglIMDbHxyN3uHR5ja2cbyV8zkVd/bBt+r8MOvfW3dy9MkPUotiZszPHKEX/zPDziroyOziFdS/vRPMz33z3++Pufbtw+OHIFbbqnP+RLQssE9TGli5g8E/n75Vfzzvf/IDQ987tiTDwQ736uyX2O+HrCh550HP/xhwINFxit211jo5Gd+kUnLTEowHLz73Zmveunry6SDUqwlg3vY0sTcH8j9v3M+Sz5wB5OOHqFz0gQ+9KbTuORVc0qe5xuP7uSfN25n196DzJ7awfuWv7zs8cd597vhxz8O9z8nkid/SdyiPXh3ztjzFFx0ZZ1blrCenkzvPcVaMriHLU0cv2Y0Yz39SyrsSXndt59leGQSdExh7yH4628/y8gJvcErAebOhd27MzMHtSRqS1i/uZ+Pf+1nHHxuN7NO6ODqPziVFWecVNNrrpwOK/9kEff//DnWfvNJDo4e++zPO7iXE4b3lc23p3ICXk9PZsmFFGvJ4F5NaWLYTREiqW2fNQuGh+HAAZgyJfC5pTnl7ii//qlVnPpidm2Vj0X3+iuyX0WddVbZNqVuAt6UKeq5p1E99l2MpLZ9ZnY97eefV3BvATfdv43hw6Oc/NIuHjj1bO5fdB4AL+tq47qLY6zq6O3NVKuUalMDTMCL/O6hpyezSUiKtWRwr0dpYiQXkFmzAPju9x/j+v/4DTsHhpna2YYZDAyNpOcWWYDMhb9j9BCT/Cg/mXs6//7q1wOZDYyvu+qSxNoU5vE4xHL30AI595ZZFTLfyqV93HjZmfT1dmJkJnMEnZAU1OoVi+lsmzjusdAXkGzP/T82bBpbqnhgeISXhka0PGoKzentpOdQJmjun9w17vGklDp3PdsUy4quyrmnV9wbWUeycW+2594z+FLJQ7RGTXqsXrGYz3z+KQD2t2eCZ9KT3RphAl4sdw8t0HNv2eCeL65qgJovIDNmADD9QOngDlqjJi1WLu1j6vlz4FNwoL3r2PIACV64I+mk1CiWMbIpUzLFCqOjydb3xyid/1chFMvn/eVdW7j2ri30Jp3fbm9nsGMK04b2lj1My6Omxx/O6QDgc+9bDhdemGxjsuK+y60klruHnp7Mv/v3ZwaUU6jlg3uxfF5uKbWB4ZGxx5IqAbOTZjHrYOmZdEnftkvEcrMmTzgh2XY0kDB3D4HvwhXcwczmAV8CTgKOAuvc/daCYwy4FbgYGALe5e6PRN/c6IVJaSSR3+6ZN4ezhw7T19upaplWkMsD54KPAMHuHkJV1eTe3xTn3YP03EeBv3L3R8ysB/ipmf23uz+ed8wbgEXZr9cAn87+G6socuWVFlYqVPf89qxZTH/sMS0N3CrUc69aqJr83LyRFAf3iqWQ7r4r1wt3933AE0BhBL0U+JJn/BjoNbMKG5PWJneVzpUIVlsWWKxksZy657dnzsxMYpLW0CLBff3mfi5Yu5GFa77BBWs3RlLOG6qqpgV67qHq3M1sAbAUeKjgqT7g2bzvd3D8BQAzW2Vmm8xs0549e8K1tEBUta/5Ne+QmTBSSiL57Zkz4cUXM7vlSPoNDmaqNzo6km5JbKLqmBUKVZOfn3NPqcDB3cymAF8FrnX3whG+YjHxuC2e3
H2duy9z92UzsmV+1Yqy9nXl0j4eXLOcp9dewi2XLxmb3NTb2cbLutpim+gUSLbWnRovhtIk9u3L9NpTvFBcLJOSCDlxsAV67oGqZcysjUxg/4q731PkkB1A/rbpc4GdtTevtLjWh0m67Os4ufVldu+GOSGWC5bmNDiY+sHUuJY0CFpVs35zP7f/xxbuA266exOLTv/Dxvqbj0iQahkDbgOecPebSxx2H3CNmd1JZiB1r7vviq6Zx2uEmXN1keu5796dbDukPgYHU59vj3Phvkqds1xKyA5m7oxG9+5NxyqXRQRJy1wAvBNYbmZbsl8Xm9nVZnZ19pgNwK+B7cDngPfE09xj6rE+TEPIXxlS0q8Fgnsk6y5VKZcSGmrr4ChG1+GDkaSEGlHFnru7/4Dy44y4uwPvjapRQTVcCiUO6rm3lsHBsWUn0irJJQ3GUj9mHGjvYMrh4fGPp0jLz1BteD09MHmyeu6tYt8+OOWUpFsRu1o6ZrXMb8lPCQ21d9KdDe5pXMKjJZf8bSpmmdSMeu5NL1BtdwukZWpRaxllfkpof3snUw4Pp3OsDgX35jBrloJ7kwsclBTcy6q1jDJ/rO5Aeycn+uF0jtWhtEziAt1izpwJzz2XTAMlEoGmxh85ktkvV8G9pCjKKMdSQvfPzbznKQzsoJ57ogL35tRzb3qBglJutqSCe0mR7gyV8g07FNwTFPgWM5dz9+Mm/UqTCBSUcuvKpHwSUy0iLaOcMkXBXeIR+BZz1iw4fBj2lt+0QxpXoKDUIouG1SLS+S0p30dVOfcEBZ6plz+RqcTGAnFtFSjRCFTbreAeSGTzW1KellFwT1DgJRTy15dZfOy5XEDvHxjGOLZSW1K7Rkl5FYNSLtAouNdHTw8MDWUGVScGX/a7WSi4JyjwTL0is1QLd50pzMYnsWuU1Kgg5667segUfS9zG3bs3w9TpybbwBgouCcs0C1mkfVlig3GFkrjlOpUy0vLhNoyTsoq9V6e7KMsBS5d+1886lNSdwHVgGozmD49M1M1r+ceJHCncUp1quUF97jWPG9Fpd7Lu54cAGD/Cy9FumlIo1BwbwaTJsG0aeN67pUCd1qnVKda3ubYca153opKvWe7aQcYW18G0nUBVXBvFgXryxQrrcst3Zna5Y9TpOg6M4OD0NUFkyZFO1mnxZV6z4baM1sZ5gd3SM8FVDn3ZjFrFmzdCp/9LAArgQVHB/jWtucYGB6ht7ON1592EkvmZ0slH/4lPFzD+SZNgre8pWTppVSvVA741b/excLsYGo1m9FoALa4Uu/lhGxVUvfhg+OOT8sFVMG9WZx+OnznO3D11WMPLcl+jVkf8Tn37YNrr434RaVUDnj79p0szAacsGueawC2tFLvZfdvuuGz0H14aOzYNKUzFdybxa23wvXX1+dcR4/C3LkwMFCf87WYUrf9Ew7sh95jNe5hJusEWpisxVS8kzkpk5WeO+kIBqm721FwbxYTJsDs2fU7X0dHZoKHRK7UzORpRw7CCdOrek0NwI4X6E4mW+e++vw5rF59SdnXasZ0lwZUG0igzRzqpbs7s/ysRK7UOjML249UPTtVA7DjBSol7e7O/FtmfZlaNwdJkoJ7g2i4D5GCe2xKLX41dWS46hUhk9x0uhEFupOZMKHiypDNPN9AaZkGEWXONJLbSAX3WBXNp5fZhanS7zTJTacbUeBF+SosHtbM6S4F9wYR1YcosqqJri4F93rbt69ocA/6O41stcQUCFxKWqHnHvgi0YCUlmkQpT4sDqHy76XuAK69a0u4PH53twZU6+nQocya/UWCezOnBpISeN33Cmu6N3O6Sz33BlGsp5ETpvddrqcfqhff3Q179lRotUSmzC5MzZwaSFKgO5kKaZlmTndVDO5mdjvwRmC3u59R5PkLgXuBp7IP3ePuH42yka0g/0NU7DYwaP691G1k2NehuxuefrpiuyUiZTbqaObUQMPr6YH+8nezzZruCpKW+QJwUYVjv
u/uS7JfCuxVWrm0jwfXLB9bI6ZQkJ5asdvIal5HA6p1VmajjmZODTS8FO/GVDG4u/v3gBfr0BbJqqVmOT/XGPb1x9GAan2V6blHum+ojDdlSmr3UY0q536emf0M2An8tbs/FtHrtqRqFo3Kl7uNLKyyCPU6GlCNROCy1Ar7pzZraqDhpbjnHkVwfwQ42d33m9nFZJavWlTsQDNbBawCmD9/fgSnTqeoBnFqep3ubhgezqwzM0FFVdUIVZZaZkBVYtTTk7lDTeHnvObg7u6Def+9wcw+ZWbT3f2FIseuA9YBLFu2rHDbT8kTVU+t6tfJTc0eGhpbg0PCCTUxTZtjJyN3Md2/P3Xvfc3B3cxOAp53dzezc8jk8X9bc8skWbngfuCAgnuVQpUwVkjL5GvWhawaUt4m2et/tW/sfZ3a2YYZDAyNNO17HKQU8g7gQmC6me0AbgDaANz9M8BbgXeb2SgwDFzh7uqVN7uursy/GlStWqgSxsHBTFog976XoHXbI5btuX/7oe1c98jQ2Ps6MDwydkizvsdBqmWudPfZ7t7m7nPd/TZ3/0w2sOPun3D309391e5+rrv/MP5mS+zy0zJSlVAljIODmUBjpQphMzRbNWLZ4H7HA48VnUCY04zvsWaoSnH5aRmpSqgB7RLryhTSbNWIZYP70G8HYEr5Io9me48V3KU4BfdIBB7QLrMiZD7NVo1YNuc+r/0IP6pwaK3vcb3HStJV+yPRUXCvStUbrgQM7pqtGrFsz/3ti3vLzuyu9T1OYr8GBfcWVjYQaUA1tJr+gAMGd81WjVg2uC+b1jbufe3tbONlXW2RvcdJjJUoLdOiKlZdaEA1tJo2XBkchHnzAp1Hs1UjlKtz37cv1vc1ibESBfcWVTEQKS1TVrH8aU1/wAEHVCViAfZRjUISYyUK7i2qYiBq9eDuDs89V/Spb27dxc0bnuTw6BGmA4f3w81ffI6XtxkDw6PHHT/7hA7Ytav8+fbuVXBPwoQJmc96zOvL1LpeVDUU3FtUxZ5ER0em5rpVg/sHPwg33lj0qTdkv0Ip/lLjnXhi2FeVKEydCrfeCp/+dGynWAm80Z3RI447fODPb+LiVZfFml5TcG9RFXsSZq297O/27TB7Ntxww3FPXf+fW0v+2NvPmse3Hn+OgeERejvbeP1pJ7Fkfm/l802cCCtX1tJiqdbHPw4PPRTo0F8+v4+Hn36R/QdHmdIxiXMWnMiiWcEWe5vEsYC77j1vggXxjpsouLeoQBNsWnnZ38FBmD8f/uIvjnvquy9tLHrX09fbyT+sWc6SkKcay9//00NNu45JU3vLWzJfFYwVIcwe3yG68bIzgcbbik/BvYWVqg7IBZs7DhmPP/QrDm7uT/yDWndlShOjzJ9qrZjmUaoI4cP3Pcah0aMN9ztUcBfgWEDvHxjGAAeG2jpg6EBDfFDrbnAQ5swp+lSUmybXVD4pdVWqCCF/kbGcRvgdKrjLcb3H3JKew20ddB0+2BAf1Hpav7mfc3fs4XsTZnPr2o1FA3dUNdFaK6Z5VNp8vlDSv0PNUJWivUeA4bbJdI4cApL/oNZL7kLXdfAA+9u7Yp8mXst+uVJfpZZ+eFlXW9Hjk/4dKrhLycA91DaZztFMcE/6g1ovN92/jeHDo0w5NMS+yZklGOKcJq61YppHqaUfbnjT6Q35O1RaRkrebg63ddA5crAhPqj1snNgmK6Rg0zA2d/eNe7xOESZv5f4lUvHNdrvUMFdilZ/GJkB1Z7RQy21MNWc3k5Gns3sEnlgcue4x+OitWKaXyP+DpWWkaK3m7dcvoTL/2AxMyeMNtyHNk6rVyxm+tFMKirXc2+lOxdJD/XcBSjR82jBSUwrl/bRe/5s+Azsn9xFX4PcYkvjafSNyhXcpbTubjh8GEZHYVLpj0qjf8jDunB2BwC3v/918Pu/n3BrpBE1w+QzBXcpLX9lyKlTix7SDB/y0AYHM/8WzFBN20VMqtcMk8+Uc5fSAiz7m
8QOM7ErEtyT2CZNGlczTD5TcJfSAmy11wwf8tCKBPdUXsSkakEnn1W9p24EFNyltABb7aVyhmUuuPccW8o1lRcxqVqQyWdJ3+0puEtpAdIyqZxhuW9fZrOS9vaxh1J5EZOqBdmoPOm7vYoDqmZ2O/BGYLe7n1HkeQNuBS4GhoB3ufsjUTdUEhAguKdyhmWR5X6T2CZNGluliUtJ3+0FqZb5AvAJ4Eslnn8DsCj79Rrg09l/pdkF3Ee1EWfn1aRIcE/lRUxilcSm2PkqBnd3/56ZLShzyKXAl9zdgR+bWa+ZzXb3CjsCS8MrM6Ca6rLAEht1pO4iJrFK+m4vijr3PuDZvO93ZB9TcG92JQZUU1nbnq/MLkzFpPpCJ1VL+m4viuBuRR7zIo9hZquAVQDz58+P4NQSqxJpmWaYwFGT3P6pAaT+Qic1SfJuL4rgvgOYl/f9XGBnsQPdfR2wDmDZsmVFLwASn9A9zBLBvdaBonr0dGs6R4iee+ovdNK0ogju9wHXmNmdZAZS9yrf3niq6mG2t2fWlCkI7rUMFEXd0y0WxIHazhEiuCddESFSSsU6dzO7A/gRsNjMdpjZVWZ2tZldnT1kA/BrYDvwOeA9sbVWqlZ1zW1X13HBvZba9ihrf0tNEvnI1x6r7Rwhgrvq36VRBamWubLC8w68N7IWSSyq7mEWWfa3loGiKHu6pS4UxfaDDXyOQ4cyK2EGDO5JV0SIlKJVIVtE1amU7u6ipZDVDhRFWfsb9oIQ6BwlVoQsJemKCJFSFNxbRNU9zBLBve7tKKLUhaK3s41Do0erO0fI4A6qf5fGpLVlWkSQtTCKiji4V92OIkrl/j/85tOrP0cVwV2kEann3kKq6mF2dcH+/cm3I09+hczUzjY62iYwMDRyXEqkqnMouEtKKLhLed3dsHt30q0YU1hKOTA8QmfbRG65fEk0qREFd0kJBXcpLy8t0wjT7GOfNFRkLXeRZqTgLuVlg3ujTLOPfdKQeu6SEhpQlfKywT3pjQdyYp80VBDck9wmTaQWCu5SXnaG6s6Xim+1V67HHEdgjH3np8FBmDgROjsT3yZNpBYK7lJedzccPcrJPcUzeKV6zHEFxihLKYvKLT1g1jB3KyLVUM5dysuuDLn69+by1w/sCDwxKM6Bz1gnDeWtK6NFwaSZqecu5WWD+yWnTg3VY27awJgX3LUomDQz9dylvLw13VcufUXgHnPS+0dWLS+4a1EwaWbquUt5ZfZRLSf2gc+45AX32PP7IjFSz13KK7GPaiVNu1ri4CCceurYt1oUTMJqhMl+oOAulZTYai+IpgyMITfHFsnXKJP9QMFdKqkhuDeKUD0pBXepQSPtqavgLuU1eXAP1ZMaHc2knxTcpUqNVCWmAVUpr8oB1UYRaiLSvn2ZfxXcpUqNVD6r4C7lVTmg2ihK9Zj6B4aPXxJBi4ZJjRqpSkxpGSkvF9z/7d/g8cerfpnf/PYAj/bvZejQEbomT+RVfVM5eVp38YOnTYO//3toa6v6fDml6u2hSIpGwV1q1EhVYgruUt7EifCGN8DWrfCtb1X1EkMjR2gfGuGsvMdsGwx1tdFV0MthdBSeew5+7/fgzW+uvt1ZxSYi5Rs32KXgLhFolCoxBXepbMOGmn78j9duLNp77uvt5ME1y8c/ODICs2fDHXdEEtzze1KlevBjqRsFd0kR5dwldqEqCNra4O1vh3vvjWzv1pVL+3hwzXL6Kg12KbhLiii4S+xCVxC84x0wPJwJ8BGqONil4C4pEii4m9lFZrbNzLab2Zoiz7/LzPaY2Zbs159H31RpVqErCM4/H+bNywziRqjiWjEK7pIiFXPuZjYR+CTwx8AO4Cdmdp+7F5ZO3OXu18TQRmlyoSsIJkyAK6+Em2+GF16A6dMjbUvZ2akAU6ZEdj6RpAQZUD0H2O7uvwYwszuBS4H6ugTlAAAHJElEQVTq6+Kk5YSuIHjHO+Af/xHuvhuuvjq+huUbHISenszFR
aTJBfkU9wHP5n2/I/tYobeY2aNmdreZzYukddK6XvUqeOUrI0/NlJUL7iIpEKTnbkUe84Lvvwbc4e6HzOxq4IvA8sIfMrNVwCqA+fPnh2yqJCGx5UvNMr33v/1bWLQo/vNBpr5+7tz6nEskZkGC+w4gvyc+F9iZf4C7/zbv288BHyv2Qu6+DlgHsGzZssILhDSYxJcvXbUKnnoKDh6M/1w5K1bU71wiMQoS3H8CLDKzhUA/cAXwjvwDzGy2u+/Kfvtm4IlIWymJSHz50pkz4bbb4j+PSApVDO7uPmpm1wD3AxOB2939MTP7KLDJ3e8D3m9mbwZGgReBd8XYZqmTRlq+VETCCbT8gLtvADYUPPahvP++Drgu2qZJ0pp2k2sR0QxVKS3J5UvXb+7ngrUbWbjmG8cvzSsiFWnhMCkpqeVLEx/IFUkBBXcpK4nlSxMfyBVJAaVlpOFoIFekduq5S8PITZgqNQFCA7kiwSm4S0MozLMXinsgN7GZuCIxUXCXhlAsz57TF3Ow1QCupJFy7tIQSuXTDXhwzfJYg2y5AVyRZqXgLg0h9G5NEdIArqSRgrs0hCQnTCV5YRGJi3Lu0hCSmDCVG0TtHxjGGL+Odb0uLCJxUXCXhlHPCVOFg6gOYwE+7gFckXpQcJeWVGwQNRfYH1xz3D4zIk1HwV0i1Sz14hpElbRTcJfIJF0vHubCouWMJe1ULSORSbJePHdh6R8Yxjl2YSm1VHCS1Tki9aDgLpFJMtUR9sKycmkfN152Jn29nRiZXPuNl53ZkCkkkWooLSORSTLVUc2FJYnljEXqRT13iYwmIok0DgV3iUySqQ7l0EXGU1pGIpVUqiOpLQFFGpWCuyQqyrp45dBFjlFwl8QkXRcvkmbKuUtitI66SHzUc5fEhClfbJZlDUQaRaCeu5ldZGbbzGy7ma0p8vxkM7sr+/xDZrYg6oZK+gQtXww7+1REAgR3M5sIfBJ4A3AacKWZnVZw2FXAS+7+cuAW4GNRN1TSJ2j5otI3IuEF6bmfA2x391+7+2HgTuDSgmMuBb6Y/e+7gdeZmUXXTEmjoHXxWsFRJLwgOfc+4Nm873cAryl1jLuPmtleYBrwQhSNlPQKUr6oFRxFwgvScy/WA/cqjsHMVpnZJjPbtGfPniDtE9HsU5EqBAnuO4B5ed/PBXaWOsbMJgFTgRcLX8jd17n7MndfNmPGjOpaLC1HKziKhBckLfMTYJGZLQT6gSuAdxQccx/wZ8CPgLcCG939uJ67SLU0+1QknIrBPZtDvwa4H5gI3O7uj5nZR4FN7n4fcBvwr2a2nUyP/Yo4Gy0iIuUFmsTk7huADQWPfSjvvw8Cb4u2aSIiUi0tPyAikkIK7iIiKaTgLiKSQgruIiIpZElVLJrZHuA3iZw8OtPRLNx8ej+O0Xsxnt6PY2p9L05294oThRIL7mlgZpvcfVnS7WgUej+O0Xsxnt6PY+r1XigtIyKSQgruIiIppOBem3VJN6DB6P04Ru/FeHo/jqnLe6Gcu4hICqnnLiKSQgruVTCzeWb2HTN7wsweM7MPJN2mpJnZRDPbbGZfT7otSTOzXjO728yezH5Gzku6TUkxs7/M/o383MzuMLOOpNtUT2Z2u5ntNrOf5z12opn9t5n9Mvvvy+I4t4J7dUaBv3L3VwLnAu8tsq9sq/kA8ETSjWgQtwL/5e6vAF5Ni74vZtYHvB9Y5u5nkFlVttVWjP0CcFHBY2uAB9x9EfBA9vvIKbhXwd13ufsj2f/eR+aPt2UXGzezucAlwOeTbkvSzOwE4LVklsHG3Q+7+0CyrUrUJKAzu4lPF8dv9JNq7v49jt+4KH/P6S8CK+M4t4J7jcxsAbAUeCjZliTq/wH/FziadEMawCnAHuBfsmmqz5tZd9KNSoK79wP/BDwD7AL2uvu3km1VQ5jl7rsg01EEZsZxEgX3GpjZFOCrwLXuPph0e5JgZm8Edrv7T5NuS4OYBPwu8Gl3XwocIKbb7kaXz
SVfCiwE5gDdZva/km1V61Bwr5KZtZEJ7F9x93uSbk+CLgDebGZPA3cCy83sy8k2KVE7gB3unruTu5tMsG9FfwQ85e573H0EuAc4P+E2NYLnzWw2QPbf3XGcRMG9CmZmZHKqT7j7zUm3J0nufp27z3X3BWQGyza6e8v2ztz9OeBZM1ucfeh1wOMJNilJzwDnmllX9m/mdbTo4HKB3J7TZP+9N46TBNpmT45zAfBOYKuZbck+dn12O0KR9wFfMbN24NfA/064PYlw94fM7G7gETIVZptpsZmqZnYHcCEw3cx2ADcAa4F/N7OryFwAY9miVDNURURSSGkZEZEUUnAXEUkhBXcRkRRScBcRSSEFdxGRFFJwFxFJIQV3EZEUUnAXEUmh/w+5sqxAUTktwQAAAABJRU5ErkJggg==\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "model = XGBoostRegressor(loss='tweedie',p=2.5)\n", "model.fit(data, target)\n", "plt.scatter(data, target)\n", "plt.plot(data, model.predict(data), color='r')" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "model = XGBoostRegressor(loss='tweedie',p=1.5)\n", "model.fit(data, target)\n", "plt.scatter(data, target)\n", "plt.plot(data, model.predict(data), color='r')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "上面的拟合结果,看不出明显区别....,接下来对tweedie分布中`p`取极端值做一个简单探索...,可以发现取值过大或者过小都有可能陷入欠拟合" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "model = XGBoostRegressor(loss='tweedie',p=0.1)\n", "model.fit(data, target)\n", "plt.scatter(data, target)\n", "plt.plot(data, model.predict(data), color='r')" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "" ] }, 
"metadata": {}, "output_type": "display_data" } ], "source": [ "model = XGBoostRegressor(loss='tweedie',p=20)\n", "model.fit(data, target)\n", "plt.scatter(data, target)\n", "plt.plot(data, model.predict(data), color='r')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "References:\n", "\n", "- http://www.doc88.com/p-9029670237688.html\n", "- http://docs.h2o.ai/h2o/latest-stable/h2o-docs/data-science/glm.html#families" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.4" } }, "nbformat": 4, "nbformat_minor": 2 }