{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "### 1. Joint probability distribution\n", "The previous two sections covered Bayesian estimation for linear regression and one approximate way to compute it, the evidence approximation. The graphical model for the Bayesian estimate is shown below \n", "![avatar](./source/15_VI_贝叶斯线性回归图模型.png)\n", "For convenience we treat $\\beta$ as a constant. Recall from the previous section that the likelihood and the conjugate prior are: \n", "\n", "$$\n", "p(t\\mid w)=\\prod_{n=1}^N N(t_n\\mid w^T\\phi(x_n),\\beta^{-1})\\\\\n", "p(w\\mid \\alpha)=N(w\\mid 0,\\alpha^{-1}I)\n", "$$ \n", "Here $X=\\{x_1,x_2,...,x_N\\}$ and $I$ is the identity matrix. Unlike the previous section, where $\\alpha$ was fixed, we now treat $\\alpha$ as drawn from a Gamma distribution, i.e. \n", "$$\n", "p(\\alpha)=Gam(\\alpha\\mid a_0,b_0)\n", "$$ \n", "The joint distribution over all variables is therefore \n", "\n", "$$\n", "p(t,w,\\alpha)=p(t\\mid w)p(w\\mid\\alpha)p(\\alpha)\n", "$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2. Variational distribution\n", "Next we construct a variational distribution $q(w,\\alpha)$ to approximate the posterior $p(w,\\alpha\\mid t)$. Under the mean-field assumption, let: \n", "\n", "$$\n", "q(w,\\alpha)=q(w)q(\\alpha)\n", "$$ \n", "\n", "Using the formula derived at the end of the first VI section, we can write down the optimal solution directly: \n", "\n", "$$\n", "ln\\ q^*(\\alpha)=ln\\ p(\\alpha)+E_w[ln\\ p(w\\mid\\alpha)]+const\\\\\n", "=(a_0-1)ln\\ \\alpha-b_0\\alpha+\\frac{M}{2}ln\\ \\alpha-\\frac{\\alpha}{2}E[w^Tw]+const\n", "$$ \n", "\n", "This has the form of the logarithm of a Gamma distribution (here $M$ is the dimension of $w$). Matching the coefficients gives \n", "\n", "$$\n", "q^*(\\alpha)=Gam(\\alpha\\mid a_N,b_N)\n", "$$ \n", "\n", "where \n", "\n", "$$\n", "a_N=a_0+\\frac{M}{2}\\\\\n", "b_N=b_0+\\frac{1}{2}E[w^Tw]\n", "$$ \n", "\n", "In the same way we can solve for $q^*(w)$: \n", "\n", "$$\n", "ln\\ q^*(w)=ln\\ p(t\\mid w)+E_\\alpha[ln\\ p(w\\mid \\alpha)]+const\\\\\n", "=-\\frac{\\beta}{2}\\sum_{n=1}^N\\left \\{w^T\\phi(x_n)-t_n\\right \\}^2-\\frac{1}{2}E[\\alpha]w^Tw+const\\\\\n", "=-\\frac{1}{2}w^T(E[\\alpha]I+\\beta\\Phi^T\\Phi)w+\\beta w^T\\Phi^Tt+const\n", "$$ \n", "\n", "This is a quadratic form in $w$, so $q^*(w)$ is a Gaussian distribution. Completing the square gives: \n", "\n", "$$\n", "q^*(w)=N(w\\mid m_N,S_N)\n", "$$ \n", "\n", "where \n", "\n", "$$\n", "m_N=\\beta S_N\\Phi^Tt\\\\\n", "S_N=(E[\\alpha]I+\\beta\\Phi^T\\Phi)^{-1}\n", "$$ \n", "\n", "The attentive reader will have noticed that replacing $E[\\alpha]$ with $\\alpha$ recovers the result derived in the previous section." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 3. Iterative solution\n", "\n", 
"The factors $q^*(w)$ and $q^*(\\alpha)$ are coupled: $q^*(w)$ needs $E[\\alpha]$, while $q^*(\\alpha)$ needs $E[w^Tw]$. From the properties of the Gamma and Gaussian distributions, both expectations have simple closed forms \n", "\n", "$$\n", "E[\\alpha]=\\frac{a_N}{b_N}\\\\\n", "E[w^Tw]=m_N^Tm_N+Tr(S_N)\n", "$$ \n", "\n", "So we initialize $E[\\alpha]$ with any value greater than 0 and then iterate: $E_0[\\alpha]\\rightarrow E_0[w^Tw]\\rightarrow E_1[\\alpha]\\rightarrow E_1[w^Tw]\\rightarrow \\cdots$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 4. Application: predictive distribution\n", "For a new sample $\\hat{x}$ we need a prediction $\\hat{t}$. Its predictive distribution uses the same Bayesian integral as in the previous section, with the posterior replaced by the variational distribution \n", "\n", "$$\n", "p(\\hat{t}\\mid \\hat{x},t)=\\int p(\\hat{t}\\mid\\hat{x},w)p(w\\mid t)dw\\\\\n", "\\simeq \\int p(\\hat{t}\\mid\\hat{x},w)q(w)dw\\\\\n", "=\\int N(\\hat{t}\\mid w^T\\phi(\\hat{x}),\\beta^{-1})N(w\\mid m_N,S_N)dw\\\\\n", "=N(\\hat{t}\\mid m_N^T\\phi(\\hat{x}),\\sigma^2(\\hat{x}))\n", "$$ \n", "\n", "Here $\\sigma^2(\\hat{x})=\\frac{1}{\\beta}+\\phi(\\hat{x})^TS_N\\phi(\\hat{x})$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 5. Code implementation" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "\"\"\"\n", "VI implementation of linear regression; the code is packaged in ml_models.vi\n", "\"\"\"\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", "\n", "class LinearRegression(object):\n", "    def __init__(self, basis_func=None, beta=1e-12, tol=1e-5, epochs=100):\n", "        \"\"\"\n", "        :param basis_func: list of basis functions, chosen from rbf, sigmoid, poly_{num}, linear; None defaults to linear. In poly_{num}, {num} is the highest polynomial degree\n", "        :param beta: precision of the Gaussian noise on the targets t; a small value can be used here\n", "        :param tol: stop when the mean absolute change of the parameters between two iterations is below tol\n", "        :param epochs: maximum number of iterations\n", "        \"\"\"\n", "        if basis_func is None:\n", "            self.basis_func = ['linear']\n", "        else:\n", "            self.basis_func = basis_func\n", "        self.beta = beta\n", "        self.tol = tol\n", "        self.epochs = epochs\n", "        # feature mean and standard deviation\n", "        self.feature_mean = None\n", "        self.feature_std = None\n", "        # trained parameters, i.e. m_N\n", "        self.w = None\n", "\n", "    def _map_basis(self, X):\n", "        \"\"\"\n", "        Apply the basis-function mapping to X\n", "        :param X:\n", "        :return:\n", "        
\"\"\"\n", "        x_list = []\n", "        for basis_func in self.basis_func:\n", "            if basis_func == \"linear\":\n", "                x_list.append(X)\n", "            elif basis_func == \"rbf\":\n", "                x_list.append(np.exp(-0.5 * X * X))\n", "            elif basis_func == \"sigmoid\":\n", "                x_list.append(1 / (1 + np.exp(-1 * X)))\n", "            elif basis_func.startswith(\"poly\"):\n", "                p = int(basis_func.split(\"_\")[1])\n", "                for degree in range(1, p + 1):\n", "                    x_list.append(np.power(X, degree))\n", "        return np.concatenate(x_list, axis=1)\n", "\n", "    def fit(self, X, y):\n", "        # standardize the features, map them through the basis functions, add a bias column\n", "        self.feature_mean = np.mean(X, axis=0)\n", "        self.feature_std = np.std(X, axis=0) + 1e-8\n", "        X_ = (X - self.feature_mean) / self.feature_std\n", "        X_ = self._map_basis(X_)\n", "        X_ = np.c_[np.ones(X_.shape[0]), X_]\n", "        n_sample, n_feature = X_.shape\n", "\n", "        E_alpha = self.beta  # start from the same value as beta and let the iteration adjust it (any value greater than 0 works)\n", "        current_w = None\n", "        for _ in range(0, self.epochs):\n", "            S_N = np.linalg.inv(E_alpha * np.eye(n_feature) + self.beta * X_.T @ X_)\n", "            self.w = self.beta * S_N @ X_.T @ y.reshape((-1, 1))  # i.e. m_N\n", "            if current_w is not None and np.mean(np.abs(current_w - self.w)) < self.tol:\n", "                break\n", "            current_w = self.w\n", "            E_w = (self.w.T @ self.w)[0][0] + np.trace(S_N)  # E[w^T w] = m_N^T m_N + Tr(S_N)\n", "            # assumes a_0 = b_0 = 0, so E[alpha] = a_N / b_N = M / E[w^T w];\n", "            # note that n_feature - 1 is used here in place of M = n_feature\n", "            E_alpha = (n_feature - 1) / E_w\n", "        # the version in the ml_models package does not return anything; returned here for analysis\n", "        return E_alpha, self.beta\n", "\n", "    def predict(self, X):\n", "        X_ = (X - self.feature_mean) / self.feature_std\n", "        X_ = self._map_basis(X_)\n", "        X_ = np.c_[np.ones(X_.shape[0]), X_]\n", "        return (self.w.T @ X_.T).reshape(-1)\n", "\n", "    def plot_fit_boundary(self, x, y):\n", "        \"\"\"\n", "        Plot the fitted result\n", "        :param x:\n", "        :param y:\n", "        :return:\n", "        \"\"\"\n", "        plt.scatter(x[:, 0], y)\n", "        plt.plot(x[:, 0], self.predict(x), 'r')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 6. Testing" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Generate synthetic samples\n", "X=np.linspace(0,100,100)\n", 
"X=np.c_[X,np.ones(100)]\n", "w=np.asarray([3,2])\n", "Y=X.dot(w)\n", "X=X.astype('float')\n", "Y=Y.astype('float')\n", "X[:,0]+=np.random.normal(size=(X[:,0].shape))*3#添加噪声\n", "Y=Y.reshape(100,1)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "#加噪声\n", "X=np.concatenate([X,np.asanyarray([[100,1],[101,1],[102,1],[103,1],[104,1]])])\n", "Y=np.concatenate([Y,np.asanyarray([[3000],[3300],[3600],[3800],[3900]])])" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYAAAAD8CAYAAAB+UHOxAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAH6hJREFUeJzt3X+UVeV97/H3h2HAQauDZjQ6YCFKNf6oYOYq0asxxgiSNJA0aU17E1bqvfSHWTU/FrnYrtb8MI2JbbVZ13qXjUaSm2isMUhSG0LRJE2zVAZBEJEwEZQZVMbCECsEYfjeP/Zz4DCcmTlnOMw5zP681jrrnP2cZ+/9bDbzfPd5nmc/WxGBmZnlz6haF8DMzGrDAcDMLKccAMzMcsoBwMwspxwAzMxyygHAzCynHADMzHLKAcDMLKccAMzMcmp0rQswkDe96U0xadKkWhfDzOyosmLFilcjomWwfHUdACZNmkR7e3uti2FmdlSR9EI5+cpuApLUIGmlpB+k5cmSnpC0QdJ3JI1J6WPTckf6flLRNm5M6eslzajskMzMrJoq6QO4AVhXtPxl4LaImAJsB65L6dcB2yPiTOC2lA9J5wDXAucCM4F/lNRweMU3M7OhKisASJoAvAf4WloWcCXwYMqyEJiTPs9Oy6Tv35Xyzwbuj4jdEbER6AAuqsZBmJlZ5cr9BXA78BlgX1o+CeiJiL1puRNoTZ9bgc0A6fsdKf/+9BLr7CdpnqR2Se3d3d0VHIqZmVVi0AAg6b3A1ohYUZxcImsM8t1A6xxIiLgrItoioq2lZdBObDMzG6JyRgFdCrxP0izgGOB4sl8EzZJGp6v8CcCWlL8TmAh0ShoNnABsK0ovKF7HzMyARSu7uHXJerb07OK05ibmzziLOdMOaSypikF/AUTEjRExISImkXXiPhoRfwg8BnwwZZsLPJw+L07LpO8fjeyxY4uBa9MoocnAFODJqh2JmdlRbtHKLm58aA1dPbsIoKtnFzc+tIZFK7uOyP4O507g/w18SlIHWRv/3Sn9buCklP4pYAFARKwFHgCeBX4IXB8RvYexfzOzEeXWJevZtefganHXnl5uXbL+iOyvohvBIuLHwI/T5+cpMYonIn4NfKif9b8IfLHSQpqZ5cGWnl0VpR8uzwVkZlYnTmtuqij9cNX1VBBmZnlQ6Pjt6tmFOHh4ZFNjA/NnnHVE9usAYGZWQ4WO30Lbf2HMfACtR3gUkAOAmVkNler4LVT+/7HgyiO6b/cBmJnV0HB3/BZzADAzq6Hh7vgt5gBgZlZD82ecRVPjwRMjH8mO32LuAzAzq6FCB+9wTf9QzAHAzKzG5kxrHZYKvy83AZmZ5ZQDgJlZTrkJyMysBoZz2uf+OACYmQ2zvnf/FqZ9BoY1CLgJyMxsmA33tM/9c
QAwMxtmtbz7t1g5zwQ+RtKTkp6WtFbS51L6vZI2SlqVXlNTuiR9VVKHpNWSLiza1lxJG9Jrbn/7NDMbyWp592+xcn4B7AaujIgLgKnATEnT03fzI2Jqeq1KadeQPe5xCjAPuBNA0onATcDFZA+SuUnS+OodipnZ0aGWd/8WK+eZwBER/5UWG9MrBlhlNvCNtN7jZA+PPxWYASyNiG0RsR1YCsw8vOKbmR195kxr5UsfOJ/W5iZENvPnlz5wfn2OApLUAKwAzgTuiIgnJP0p8EVJfw0sAxZExG6gFdhctHpnSusv3cwsd2p192+xsjqBI6I3IqYCE4CLJJ0H3AicDfw34ESyh8RD9iyDQzYxQPpBJM2T1C6pvbu7u5zimZnZEFQ0CigiesgeCj8zIl5KzTy7ga9z4AHxncDEotUmAFsGSO+7j7sioi0i2lpaWiopnpmZVaCcUUAtkprT5ybgKuC51K6PJAFzgGfSKouBj6bRQNOBHRHxErAEuFrS+NT5e3VKMzOzGiinD+BUYGHqBxgFPBARP5D0qKQWsqadVcCfpPyPALOADmAn8DGAiNgm6QvA8pTv8xGxrXqHYmZmlVDEQAN6aqutrS3a29trXQwzs6OKpBUR0TZYPt8JbGaWUw4AZmY55dlAzcyGST1MAV3MAcDMbBjUyxTQxdwEZGY2DOplCuhiDgBmZsOgXqaALuYAYGY2DOplCuhiDgBmZsOgXqaALuZOYDOzYVDo6PUoIDOzHKqHKaCLuQnIzCynHADMzHLKAcDMLKccAMzMcsoBwMwspxwAzMxyqpxHQh4j6UlJT0taK+lzKX2ypCckbZD0HUljUvrYtNyRvp9UtK0bU/p6STOO1EGZmdWLRSu7uPSWR5m84F+49JZHWbSyq9ZF2q+cXwC7gSsj4gJgKjAzPev3y8BtETEF2A5cl/JfB2yPiDOB21I+JJ0DXAucC8wE/jE9ZtLMbEQqzADa1bOL4MAMoPUSBAYNAJH5r7TYmF4BXAk8mNIXkj0YHmB2WiZ9/6704PjZwP0RsTsiNpI9M/iiqhyFmVkdqscZQIuV1QcgqUHSKmArsBT4JdATEXtTlk6gcHtbK7AZIH2/AzipOL3EOsX7miepXVJ7d3d35UdkZlYn6nEG0GJlBYCI6I2IqcAEsqv2t5bKlt7Vz3f9pffd110R0RYRbS0tLeUUz8ysLtXjDKDFKhoFFBE9wI+B6UCzpMJcQhOALelzJzARIH1/ArCtOL3EOmZmI049zgBarJxRQC2SmtPnJuAqYB3wGPDBlG0u8HD6vDgtk75/NCIipV+bRglNBqYAT1brQMzM6s2caa186QPn09rchIDW5ia+9IHz62ZCuHJmAz0VWJhG7IwCHoiIH0h6Frhf0s3ASuDulP9u4JuSOsiu/K8FiIi1kh4AngX2AtdHRC9mZiNYvc0AWkzZxXl9amtri/b29loXw8zsqCJpRUS0DZbPdwKbmeWUA4CZWU45AJiZ5ZQDgJlZTjkAmJnllAOAmVlOOQCYmeWUA4CZWU45AJiZ5ZQDgJlZTjkAmJnllAOAmVlOOQCYmeWUA4CZWU45AJiZ5VQ5TwSbKOkxSeskrZV0Q0r/rKQuSavSa1bROjdK6pC0XtKMovSZKa1D0oIjc0hmZlaOcp4Ithf4dEQ8Jek3gBWSlqbvbouIvy3OLOkcsqeAnQucBvybpN9KX98BvJvs+cDLJS2OiGercSBmZlaZQQNARLwEvJQ+vyZpHTDQ881mA/dHxG5gY3o05EXpu46IeB5A0v0prwOAmVkNVNQHIGkSMA14IiV9XNJqSfdIGp/SWoHNRat1prT+0s3MrAbKDgCSjgO+C3wiIn4F3AmcAUwl+4Xwd4WsJVaPAdL77meepHZJ7d3d3eUWz8zMKlRWAJDUSFb5fysiHgKIiFciojci9gH/xIFmnk5gYtHqE4AtA6QfJCLuioi2iGhraWmp9HjMzKxM5YwCEnA3sC4i/r4o/dSibO8Hn
kmfFwPXShoraTIwBXgSWA5MkTRZ0hiyjuLF1TkMMzOrVDmjgC4FPgKskbQqpf0F8GFJU8macTYBfwwQEWslPUDWubsXuD4iegEkfRxYAjQA90TE2ioei5mZVUARhzTD1422trZob2+vdTHMzI4qklZERNtg+XwnsJlZTjkAmJnllAOAmVlOOQCYmeWUA4CZWU45AJiZ5ZQDgJlZTjkAmJnllAOAmVlOOQCYmeWUA4CZWU45AJiZ5ZQDgJlZTjkAmJnllAOAmVlOOQCYmeVUOY+EnCjpMUnrJK2VdENKP1HSUkkb0vv4lC5JX5XUIWm1pAuLtjU35d8gae6ROywzMxtMOb8A9gKfjoi3AtOB6yWdAywAlkXEFGBZWga4huw5wFOAecCdkAUM4CbgYrIHyN9UCBpmZjb8Bg0AEfFSRDyVPr8GrANagdnAwpRtITAnfZ4NfCMyjwPN6QHyM4ClEbEtIrYDS4GZVT0aMzMrW0V9AJImAdOAJ4BTIuIlyIIEcHLK1gpsLlqtM6X1l953H/MktUtq7+7urqR4ZmZWgbIDgKTjgO8Cn4iIXw2UtURaDJB+cELEXRHRFhFtLS0t5RbPzMwqVFYAkNRIVvl/KyIeSsmvpKYd0vvWlN4JTCxafQKwZYB0MzOrgXJGAQm4G1gXEX9f9NVioDCSZy7wcFH6R9NooOnAjtREtAS4WtL41Pl7dUozM7MaGF1GnkuBjwBrJK1KaX8B3AI8IOk64EXgQ+m7R4BZQAewE/gYQERsk/QFYHnK9/mI2FaVozAzs4op4pBm+LrR1tYW7e3ttS6GmdlRRdKKiGgbLJ/vBDYzyykHADOznHIAMDPLKQcAM7OccgAwM8spBwAzs5xyADAzyykHADOznHIAMDPLKQcAM7OccgAwM8spBwAzs5xyADAzyykHADOznHIAMDPLqXKeCHaPpK2SnilK+6ykLkmr0mtW0Xc3SuqQtF7SjKL0mSmtQ9KC6h+KmZlVopxfAPcCM0uk3xYRU9PrEQBJ5wDXAuemdf5RUoOkBuAO4BrgHODDKa+ZmdXIoI+EjIifSppU5vZmA/dHxG5go6QO4KL0XUdEPA8g6f6U99mKS2xmZlVxOH0AH5e0OjURjU9prcDmojydKa2/dDMzq5GhBoA7gTOAqcBLwN+ldJXIGwOkH0LSPEntktq7u7uHWDwzMxvMkAJARLwSEb0RsQ/4Jw4083QCE4uyTgC2DJBeatt3RURbRLS1tLQMpXhmZlaGIQUASacWLb4fKIwQWgxcK2mspMnAFOBJYDkwRdJkSWPIOooXD73YZmZ2uAbtBJZ0H3AF8CZJncBNwBWSppI142wC/hggItZKeoCsc3cvcH1E9KbtfBxYAjQA90TE2qofjZmZlU0RJZvi60JbW1u0t7fXuhhmZkcVSSsiom2wfL4T2MwspxwAzMxyygHAzCynHADMzHLKAcDMLKccAMzMcsoBwMwspxwAzMxyygHAzCynHADMzHLKAcDMLKccAMzMcsoBwMwspxwAzMxyygHAzCynHADMzHJq0AAg6R5JWyU9U5R2oqSlkjak9/EpXZK+KqlD0mpJFxatMzfl3yBp7pE5HDMzK1c5vwDuBWb2SVsALIuIKcCytAxwDdlzgKcA84A7IQsYZI+SvJjsAfI3FYKGmZnVxqABICJ+CmzrkzwbWJg+LwTmFKV/IzKPA83pAfIzgKURsS0itgNLOTSomJnZMBpqH8ApEfESQHo/OaW3ApuL8nWmtP7SDyFpnqR2Se3d3d1DLJ6ZmQ2m2p3AKpEWA6QfmhhxV0S0RURbS0tLVQtnZmYHDDUAvJKadkjvW1N6JzCxKN8EYMsA6WZmViNDDQCLgcJInrnAw0XpH02jgaYDO1IT0RLgaknjU+fv1SnNzMxqZPRgGSTdB1wBvElSJ9lonluAByRdB7wIfChlfwSYBXQAO4GPAUTENklfAJanfJ+PiL4dy2ZmNowUUbIpvi60tbVFe3t7rYthZ
nZUkbQiItoGy+c7gc3McsoBwMwspxwAzMxyygHAzCynHADMzHLKAcDMLKccAMzMcsoBwMwspxwAzMxyygHAzCynHADMzHLKAcDMLKccAMzMcsoBwMwspxwAzMxy6rACgKRNktZIWiWpPaWdKGmppA3pfXxKl6SvSuqQtFrShdU4ADMzG5pq/AJ4Z0RMLXr4wAJgWURMAZalZYBrgCnpNQ+4swr7NjOzIToSTUCzgYXp80JgTlH6NyLzONBceLC8mZkNv8MNAAH8SNIKSfNS2inpQfCk95NTeiuwuWjdzpRmZmY1MOhD4QdxaURskXQysFTScwPkVYm0Qx5InALJPIDTTz/9MItnZmb9OaxfABGxJb1vBb4HXAS8UmjaSe9bU/ZOYGLR6hOALSW2eVdEtEVEW0tLy+EUz8zMBjDkXwCSjgVGRcRr6fPVwOeBxcBc4Jb0/nBaZTHwcUn3AxcDOwpNRWZmudPbC93d8Mor2evllw/+PGkS3HzzES3C4TQBnQJ8T1JhO9+OiB9KWg48IOk64EXgQyn/I8AsoAPYCXzsMPZtZlZ/Xn8d7rkH/vzPB8538snw6quwb9+h3zU1wZvfDGPGHJkyFhlyAIiI54ELSqT/J/CuEukBXD/U/ZmZ1cTu3fDNb8Ktt8IvflGdbc6enVXyb34znHJK9ip8Pu44UKku0+o73E5gM7OjT28vLF+eXakvX37k9nPzzTBnDpxzzrBV6pVwADCzkWHfPli1CrZuhWOOydrS/+qvYMOG6u5n+nSYPz+7im9oqO62h5kDgJnVrwj413+Fr3wFfvKTI7efqVPhM5+BD34QGhsHzLpoZRe3LlnPlr/8ISc0NSJBz849nNbcxDvPbuGx57rp6tlFg0RvBK3NTcyfcRbtL2zjvic20xtBg8SHL57IzXPOP3LHVAZlTfP1qa2tLdrb22tdDDOrpoisMv/KV7LKvdrOOw9uv51l28XtK7exbs8YfuO4Y4iAHbsOrqi39OzitFRBz5l26H2phcq+uEIXJW5gGsQooER3L/9j+ulHJAhIWlE0PU//+RwAzKxc+69+S1ScP/r+z3n6nn/mt9c+zuWbVtG059dV3//rp03k/7bN4e4z3sH4k8cftP/isp3Q1Mjrb+xlT2959VtTYwNf+sD5BwWBRSu7uPGhNeza01v14yhokPjll2ZVfbvlBgA3AZlZSYtWdnH/XQ/z9tX/zpn/uZnLN61kx/lXccPunfzemn/LMt14IP/V6VWpX5/Uwv952/u595yr+K+x4/an962UF63sYv6DT++v1Hf27GL+g0/vz19cWffs2lNRGXbt6eXWJesPCgC3Lll/RCt/gN4aX4A7AJiNAANdmff1k3sXc+an/4TWbQPfhzmHAzM5FsxtXzxoWTpOnMBPJ1/ITydP48mJ57FzTFO/ecePa+TXe/aVrGj7Vsqf+/7aQ67o9/QGn/v+WsaNGX3YlfWWnl0DLh8JDTUeGeQAYFZjlVTepdbr6tnF6T0v8ydPPsRHVj6SfXlj/+u9YwhlfPGEU/jJW97GvRf+Di+MP5W9DdWpOrbvHPhKvbgS7i/v9p176BlkO+U4rbnpkOWuKgWB/voAPnzxxBKpw8cBwKxIuZVxJZX2opVdfO77aw+qwMaPa+Sm3zkXOLjpoqtnF5/8ziraX9jGzZe3wrJl8KMfZa8XXzxou6Wu0Idir0bx/o/8HWtOnVKFrVVX30p5oHyHU1k3NTYwf8ZZB6XNn3FWRX0Ax45pYNcbvYdU9IVzXY+jgBwALNf6dhy+tnsvvfuyZoauPm3Mhavtvorz9Q0Cfdutj929k0teXM1lG1cy7fan+M3tL1WlEu/r2ZMnM/+aG9jylrey8q+zlvlLb3m0ale01dDU2MDY0aP6ba/vWyk3NzWWzNvc1Fiysu7vqnv8uEbe89unDjoKqLDcdxRQ6wCjiAa6MJgzrbXmFX5fHgVkdaVvhfzG3l527sn+jAtXUpU0jxT+EIv/YPtud
zDHjmlgX9DvlWDTG79m+uY1vHvz0/xBzzpYv778Ay7HSSfBjBlw9dVw1VXQ2srkBf9S9lDETbe8B6CidcohwWiJPfsq32p/v4D6ft93VM78f376oP01jhK3fuiCfitfYEjNa0c7DwO1qio1Hrq1Cn9QxdstV3NTI++9oP8ruGoM3xu79w3aOp/lso1P8Y6NT/HW7k1D3lYpO8Yey79Pmsa/T85eW44/+ZA8AjamyruvSq7mCwGgmr8ACiN04NAKtm/aYGPuK21Oy2OFXikHABvUYH9M5VbOzU2NfPZ9/V+ZL1rZxWcXr93/832wq7+hKB4y2F9Fd8yeX/N7q5cy78nvMeFX2WMqvv6232F3QyMtr2+n5fUeWn/VzRnbOodUhp2NYxk3a+aBq/Uzzhi00m1uamTHrj0lr8xbm5v4jwVXllyv3CDX3NTIqpuurmidUhpHwXHHNO6/49UVb31zADgKlVshl3tlNdA+unp2lbyjsVCZt7+wjW89/mLZTQalbqQp7K/vz3aAxgZx3NjRg44CGcyYvXuY/exj/PETD3HmECvu3Q2NdB/bzKvHjufVcSdw1S8PTA72xqjR2VX6pGn8bNJUOk6a2O+kXsWVbUHfPoBijQ3i1g9eUPLfur9/z77bLvx/aB7XyI6dew5q8y5uHum7Tn9t2tX+hWe14QBQR8ppm3zn2S18Z/nmQyqKxlFw64emAoNfLRcqlIGuxMu5AmxsUNl3UBYrdcU6pGaHCM5/uYMrnm/nsk0ruajz2YrLUo4vvPM6Hjz/KnYc0//0u42jxO9fNJHvruga+N++RGVbMNAooFJ3sQ71CtvNI1bgAFAlff+oJp3UxOPPbz/oDr7WEqMABpo7pHGUQJRdyY4SnNDUWNbV8vhxjftHffR1pEeBlGqzLnQ8KvZxztaNXLZxJZdteopLX1h9ZAoxfTpP/O4f8UfbT+P1Cls6BhodMlBzmK+Urd7UbQCQNBP4B6AB+FpE3NJf3qEGgP7anAe6Mv6Lh1bvHxUiwR9enD2QvtxmkKbGBn73ba2DXikOh039dBxWexQIEZz16gtctvEpLt+4kss3razKZv/1ty7hml/8HID2Cefyk0lTWdH6Vp6ceF6/NyCVmjag3FFAlYwuMjsa1OVcQJIagDuAd5M9JH65pMURUbXf+KXanLfv3DPgOO1PPbCK4ibqCPh/jx98081gdu3p3X+TR70q62aZCCZv38LlG5/iso1P0fJ6Dxe8XN351FeeehY/O+NCLvij32P7uVP5yqPPZ30Syv7t4eCO5Tay8/RCBf0fc6a1ukI3G8Rw3wh2EdCRHidJekD8bKBqAeDWJetLjkve0xuHTPZUyD+EYcwlHcnKv7mpkd17S8+Z0jdfSRs3csdrT9K96Ptc+vwKxu3ZXdXy9Zx9Hs1z3puNfrnkEhg7dv93/f0iuzydi9kXTRp0+67QzapvuANAK7C5aLkTuLg4g6R5wDyA008/veIdDDSBU6nvqjnhU6HNfzBD6QP47PuyYZPFoz5O3bSe//XEQ5z5n5s5vedlTtj9erbCZ0tvZ2pZeztgw5vfwtcvmMXO1olc+T8/wPvefma/eZsH2I4rb7P6NNwBoNRQi4NqwYi4C7gLsj6ASncwUDNHqXlFqjXh02B9AIWO4NaiUUB9R4YAnPqrbj7xs2/z+2uWHryB1FNy2NMGnH32gXHql1+ePYC6H1OAvznc/ZlZ3RruANAJFE9/NwHYUs0dzJ9xVr/jzvtO9lTI37cPoKBhlPbPCwNZJX7mycfyfPfOfkcBtf3miYeMAjp39G6+/NoKzvtqqk7TTI1VnQPmz/4sq9Tf+U44/vhqbtnMRqjhDgDLgSmSJgNdwLXAH1RzB4WmhnJHARXSSo0CKlTmJTsat2+HhQvhk5/Mlosq9apV7G9/e7b9K6/M5oMxM6uiWgwDnQXcTjYM9J6I+GJ/eYf9PoBt2+BnP8um3l2yBDo6svRLLoGf/7x6+/nTP4VPfQrO7
L9N3cxsqOpyGChARDwCPDJsO3z9dfjxjw/Mqf7cc5Vvo5zK/2Mfyyr1886rfPtmZjUwMp8HsG0bXHEFrFkztPVPOCFrT58xA979bhjCaCQzs3o3MgPA8cfDGWdkV+P33ZeljRuXVeqF1xln1LaMZmY1NjIDwOjR8L3vZZ+//e3alsXMrE6NqnUBzMysNhwAzMxyygHAzCynHADMzHLKAcDMLKccAMzMcsoBwMwspxwAzMxyqq4fCi+pG3jhMDfzJuDVKhSn3uXlOCE/x+rjHFmG8zh/MyJaBstU1wGgGiS1lzMr3tEuL8cJ+TlWH+fIUo/H6SYgM7OccgAwM8upPASAu2pdgGGSl+OE/Byrj3NkqbvjHPF9AGZmVloefgGYmVkJIzoASJopab2kDkkLal2eapE0UdJjktZJWivphpR+oqSlkjak9/G1Lms1SGqQtFLSD9LyZElPpOP8jqQxtS7j4ZLULOlBSc+l8/r2EXw+P5n+3z4j6T5Jx4yEcyrpHklbJT1TlFbyHCrz1VQ3rZZ0YS3KPGIDgKQG4A7gGuAc4MOSzqltqapmL/DpiHgrMB24Ph3bAmBZREwBlqXlkeAGYF3R8peB29Jxbgeuq0mpqusfgB9GxNnABWTHO+LOp6RW4M+Btog4D2gArmVknNN7gZl90vo7h9cAU9JrHnDnMJXxICM2AAAXAR0R8XxEvAHcD8yucZmqIiJeioin0ufXyCqLVrLjW5iyLQTm1KaE1SNpAvAe4GtpWcCVwIMpy1F/nJKOBy4H7gaIiDcioocReD6T0UCTpNHAOOAlRsA5jYifAtv6JPd3DmcD34jM40CzpFOHp6QHjOQA0ApsLlruTGkjiqRJwDTgCeCUiHgJsiABnFy7klXN7cBngH1p+SSgJyL2puWRcF7fAnQDX09NXV+TdCwj8HxGRBfwt8CLZBX/DmAFI++cFvR3DuuifhrJAUAl0kbUkCdJxwHfBT4REb+qdXmqTdJ7ga0RsaI4uUTWo/28jgYuBO6MiGnA64yA5p5SUhv4bGAycBpwLFlzSF9H+zkdTF38Px7JAaATmFi0PAHYUqOyVJ2kRrLK/1sR8VBKfqXwMzK9b61V+arkUuB9kjaRNeFdSfaLoDk1H8DIOK+dQGdEPJGWHyQLCCPtfAJcBWyMiO6I2AM8BFzCyDunBf2dw7qon0ZyAFgOTEmjC8aQdTQtrnGZqiK1g98NrIuIvy/6ajEwN32eCzw83GWrpoi4MSImRMQksvP3aET8IfAY8MGUbSQc58vAZklnpaR3Ac8yws5n8iIwXdK49P+4cKwj6pwW6e8cLgY+mkYDTQd2FJqKhlVEjNgXMAv4BfBL4C9rXZ4qHtd/J/u5uBpYlV6zyNrHlwEb0vuJtS5rFY/5CuAH6fNbgCeBDuCfgbG1Ll8Vjm8q0J7O6SJg/Eg9n8DngOeAZ4BvAmNHwjkF7iPr19hDdoV/XX/nkKwJ6I5UN60hGxU17GX2ncBmZjk1kpuAzMxsAA4AZmY55QBgZpZTDgBmZjnlAGBmllMOAGZmOeUAYGaWUw4AZmY59f8BXALME0CKeFoAAAAASUVORK5CYII=\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "%matplotlib inline\n", "lr=LinearRegression()\n", "alpha,beta=lr.fit(X[:,:-1],Y)\n", "lr.plot_fit_boundary(X[:,:-1],Y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ 
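"A short derivation sketch of why the ratio $E[\\alpha]/\\beta$ can be read as an L2 regularization coefficient: substituting $S_N$ into $m_N$ and dividing through by $\\beta$ gives the expression below. The symbol $\\lambda$ is notation introduced only for this sketch. \n", "\n", "$$\n", "m_N=\\beta(E[\\alpha]I+\\beta\\Phi^T\\Phi)^{-1}\\Phi^Tt=(\\lambda I+\\Phi^T\\Phi)^{-1}\\Phi^Tt,\\quad \\lambda=\\frac{E[\\alpha]}{\\beta}\n", "$$ \n", "\n", "This has the same form as the ridge regression solution that minimizes $\\|t-\\Phi w\\|^2+\\lambda\\|w\\|^2$, so a larger $\\lambda$ shrinks $w$ more strongly toward zero. \n", "\n", 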
"This removes the need to tune the hyperparameter `alpha` by hand. In this section `alpha` is also treated as a random variable, assumed to be drawn from a Gamma distribution, and a better value of `alpha` is obtained from its expectation. As shown below, the L2 regularization coefficient `alpha/beta` found by the iteration is about 105. We could likewise place a Gamma distribution on $\\beta$ and solve for it within the same variational framework, which would remove the need to set `beta`." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "104.99882616406987" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "alpha/beta" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.4" } }, "nbformat": 4, "nbformat_minor": 2 }