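{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "### 1. Theoretical derivation of the evidence approximation\n", "In the previous section the hyperparameters $\\alpha,\\beta$ had to be set by hand, which considerably weakened our “customized” Bayesian linear regression. Ideally we should also take the posterior distribution of the hyperparameters $\\alpha,\\beta$ into account, in which case the predictive distribution becomes: \n", "\n", "$$\n", "p(\\hat{t}\\mid t)=\\int\\int\\int p(\\hat{t}\\mid w,\\beta)p(w\\mid t,\\alpha,\\beta)p(\\alpha,\\beta\\mid t)\\ dw\\ d\\alpha\\ d\\beta\n", "$$ \n", "\n", "Here $t=\\{t_1,t_2,...,t_N \\}$ denotes the labels of the training data and $\\hat{t}$ the label to be predicted. For brevity we omit the training inputs $X=\\{x_1,x_2,...,x_N\\}$ and the prediction input $\\hat{x}$; the distributions inside the integral have the same meaning as in the previous section. As for the hyperparameter posterior $p(\\alpha,\\beta\\mid t)$, the derivation of posterior distributions in [14_03: Probability distributions, the Gaussian (normal) distribution and its conjugate priors](https://nbviewer.jupyter.org/github/zhulei227/ML_Notes/blob/master/notebooks/14_03_%E6%A6%82%E7%8E%87%E5%88%86%E5%B8%83%EF%BC%9A%E9%AB%98%E6%96%AF%E5%88%86%E5%B8%83%EF%BC%88%E6%AD%A3%E6%80%81%E5%88%86%E5%B8%83%EF%BC%89%E5%8F%8A%E5%85%B6%E5%85%B1%E8%BD%AD%E5%85%88%E9%AA%8C.ipynb) gives a general conclusion: the more training data there is, the more sharply peaked the posterior becomes. So we may assume that the posterior $p(\\alpha,\\beta\\mid t)$ is sharply peaked at $\\hat{\\alpha},\\hat{\\beta}$, and the predictive distribution can then be approximated as: \n", "\n", "$$\n", "p(\\hat{t}\\mid t)\\simeq p(\\hat{t}\\mid t,\\hat{\\alpha},\\hat{\\beta})=\\int p(\\hat{t}\\mid w,\\hat{\\beta})p(w\\mid t,\\hat{\\alpha},\\hat{\\beta})dw\\int\\int p(\\alpha,\\beta\\mid t)\\ d\\alpha\\ d\\beta=\\int p(\\hat{t}\\mid w,\\hat{\\beta})p(w\\mid t,\\hat{\\alpha},\\hat{\\beta})dw\n", "$$ \n", "\n", "So our goal now becomes finding the mode $\\hat{\\alpha},\\hat{\\beta}$ of the posterior $p(\\alpha,\\beta\\mid t)$ (a maximum a posteriori estimate). By Bayes' theorem we know that: \n", "\n", "$$\n", "p(\\alpha,\\beta \\mid t)\\propto p(t\\mid \\alpha,\\beta)p(\\alpha,\\beta)\n", "$$ \n", "\n", "We may assume that the prior $p(\\alpha,\\beta)$ carries little background information, i.e. it is relatively flat. In that case $\\hat{\\alpha},\\hat{\\beta}$ can be obtained as the maximum likelihood estimate of the marginal likelihood $p(t\\mid\\alpha,\\beta)$, and $\\ln p(t\\mid\\alpha,\\beta)$ is called the evidence function. Next we derive its closed form." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2. Deriving the evidence function\n", "\n", "The marginal likelihood $p(t\\mid \\alpha,\\beta)$ is obtained by integrating over the parameters $w$: \n", "\n", "$$\n", "p(t\\mid\\alpha,\\beta)=\\int\\ p(t\\mid w,\\beta)p(w\\mid\\alpha)dw\\\\\n", 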
"=\\left(\\frac{\\beta}{2\\pi}\\right)^{\\frac{N}{2}}\\left(\\frac{\\alpha}{2\\pi}\\right)^{\\frac{M}{2}}\\int exp\\{-E(w)\\}dw\n", "$$ \n", "\n", "这里,$M$是$w$的维度,且: \n", "\n", "$$\n", "E(w)=\\beta E_D(w)+\\alpha E_W(w)\\\\\n", "=\\frac{\\beta}{2}\\left\\|t-\\Phi w\\right\\|^2+\\frac{\\alpha}{2}w^Tw\n", "$$ \n", "\n", "接下来,我们对$w$配平方项: \n", "\n", "$$\n", "E(w)=E(m_N)+\\frac{1}{2}(w-m_N)^TA(w-m_N)\n", "$$ \n", "\n", "其中: \n", "\n", "$$\n", "A=\\alpha I+\\beta\\Phi^T\\Phi\\\\\n", "m_N=\\beta A^{-1}\\Phi^Tt\\\\\n", "E(m_N)=\\frac{\\beta}{2}\\left\\|t-\\Phi m_N\\right\\|^2+\\frac{\\beta}{2}m_N^Tm_N\n", "$$ \n", "\n", "通过对$w$的配方,接下来可以方便的求解出积分项: \n", "\n", "$$\n", "\\int exp\\{-E(w)\\}dw\\\\\n", "=exp\\{-E(m_N)\\}\\int exp\\left\\{-\\frac{1}{2}(w-m_N)^TA(w-m_N)\\right\\}dw\\\\\n", "=exp\\left\\{-E(m_N)\\right\\}(2\\pi)^{\\frac{M}{2}}\\left|A\\right|^{-\\frac{1}{2}}\n", "$$ \n", "\n", "所以,证据函数就可以写出来了 \n", "\n", "$$\n", "ln\\ p(t\\mid\\alpha,\\beta)=\\frac{M}{2}ln\\ \\alpha+\\frac{N}{2}ln\\ \\beta-E(m_N)-\\frac{1}{2}ln\\ |A|-\\frac{N}{2}ln(2\\pi)\n", "$$ \n", "\n", "接下来就是优化的问题了" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 三.极大化证据函数\n", "\n", "首先考虑$p(t\\mid\\alpha,\\beta)$关于$\\alpha$的最大化,我们首先定义下面的特征向量方程: \n", "\n", "$$\n", "(\\beta\\Phi^T\\Phi)\\mu_i=\\lambda_i\\mu_i,i=1,2,...,M\n", "$$ \n", "\n", "由于$A=\\alpha\\ I+\\beta\\Phi^T\\Phi$,所以A的特征值为$\\alpha+\\lambda_i$,所以$ln|A|$对$\\alpha$的导数: \n", "\n", "$$\n", "\\frac{d}{d\\alpha}ln\\ |A|=\\frac{d}{d\\ln\\alpha}\\prod_i(\\lambda_i+\\alpha)=\\frac{d}{d\\ln\\alpha}\\sum_i ln(\\lambda_i+\\alpha)=\\sum_i\\frac{1}{\\lambda_i+\\alpha}\n", "$$ \n", "因此,$ln\\ p(t\\mid\\alpha,\\beta)$关于$\\alpha$的驻点满足: \n", "\n", "$$\n", "0=\\frac{M}{2\\alpha}-\\frac{1}{2}m_N^Tm_N-\\frac{1}{2}\\sum_i\\frac{1}{\\lambda_i+\\alpha}\n", "$$ \n", "\n", "我们令: \n", "\n", "$$\n", "\\gamma=\\sum_i\\frac{\\lambda_i}{\\lambda_i+\\alpha}\n", "$$ \n", "\n", "整理可能$\\alpha$满足: \n", "\n", "$$\n", "\\alpha=\\frac{\\gamma}{m_N^Tm_N}\n", "$$ \n", "\n", 
"注意,这里是个迭代公式,因为$\\gamma,m_N$均与$\\alpha$相关,接下来,我们看看$\\beta$的求解,由于特征值$\\lambda_i$正比于$\\beta$,所以$\\frac{d\\lambda_i}{d\\beta}=\\frac{1}{\\beta}$,所以: \n", "\n", "$$\n", "\\frac{d}{d\\beta}ln\\ |A|=\\frac{d}{d\\beta}\\sum_i ln(\\lambda_i+\\alpha)=\\frac{1}{\\beta}\\sum_i\\frac{\\lambda_i}{\\lambda_i+\\alpha}=\\frac{\\gamma}{\\beta}\n", "$$ \n", "\n", "类似地,在驻点处满足如下关系: \n", "\n", "$$\n", "0=\\frac{N}{2\\beta}-\\frac{1}{2}\\sum_{n=1}^N\\left\\{t_n-m_N^T\\phi(x_n)\\right\\}^2-\\frac{\\gamma}{2\\beta}\n", "$$ \n", "\n", "整理后可得: \n", "\n", "$$\n", "\\frac{1}{\\beta}=\\frac{1}{N-r}\\sum_{n=1}^N\\left\\{t_n-m_N^T\\phi(x_n)\\right\\}^2\n", "$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 四.代码实现\n", "\n", "代码实现主要就是反复迭代上面推导出来的两个公式: \n", "\n", "$$\n", "\\alpha=\\frac{\\gamma}{m_N^Tm_N}\\\\\n", "\\frac{1}{\\beta}=\\frac{1}{N-r}\\sum_{n=1}^N\\left\\{t_n-m_N^T\\phi(x_n)\\right\\}^2\n", "$$" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "\"\"\"\n", "线性回归的bayes估计,封装到ml_models.bayes中\n", "\"\"\"\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", "\n", "class LinearRegression(object):\n", " def __init__(self, basis_func=None, tol=1e-7, epochs=100, normalized=True):\n", " \"\"\"\n", " :param basis_func: list,基函数列表,包括rbf,sigmoid,poly_{num},linear,fm,默认None为linear,其中poly_{num}中的{num}表示多项式的最高阶数,fm表示构建交叉因子\n", " :param tol: 两次迭代参数平均绝对值变化小于tol则停止\n", " :param epochs: 默认迭代次数\n", " :param normalized:是否归一化\n", " \"\"\"\n", " if basis_func is None:\n", " self.basis_func = ['linear']\n", " else:\n", " self.basis_func = basis_func\n", " self.tol = tol\n", " self.epochs = epochs\n", " self.normalized = normalized\n", " # 特征均值、标准差\n", " self.feature_mean = None\n", " self.feature_std = None\n", " # 训练参数\n", " self.w = None\n", "\n", " def _map_basis(self, X):\n", " \"\"\"\n", " 将X进行基函数映射\n", " :param X:\n", " :return:\n", " \"\"\"\n", " n_sample, n_feature = X.shape\n", " x_list = []\n", " for basis_func in 
self.basis_func:\n", "            if basis_func == \"linear\":\n", "                x_list.append(X)\n", "            elif basis_func == \"rbf\":\n", "                x_list.append(np.exp(-0.5 * X * X))\n", "            elif basis_func == \"sigmoid\":\n", "                x_list.append(1 / (1 + np.exp(-1 * X)))\n", "            elif basis_func.startswith(\"poly\"):\n", "                p = int(basis_func.split(\"_\")[1])\n", "                for degree in range(1, p + 1):\n", "                    x_list.append(np.power(X, degree))\n", "            elif basis_func == 'fm':\n", "                X_fm = np.zeros(shape=(n_sample, int(n_feature * (n_feature - 1) / 2)))\n", "                c = 0\n", "                for i in range(0, n_feature - 1):\n", "                    for j in range(i + 1, n_feature):\n", "                        X_fm[:, c] = X[:, i] * X[:, j]\n", "                        c += 1\n", "                x_list.append(X_fm)\n", "        return np.concatenate(x_list, axis=1)\n", "\n", "    def fit(self, X, y):\n", "        if self.normalized:\n", "            self.feature_mean = np.mean(X, axis=0)\n", "            self.feature_std = np.std(X, axis=0) + 1e-8\n", "            X_ = (X - self.feature_mean) / self.feature_std\n", "        else:\n", "            X_ = X\n", "        X_ = self._map_basis(X_)\n", "        X_ = np.c_[np.ones(X_.shape[0]), X_]\n", "        n_sample, n_feature = X_.shape\n", "        alpha = 1\n", "        beta = 1\n", "        current_w = None\n", "        for _ in range(0, self.epochs):\n", "            A = alpha * np.eye(n_feature) + beta * X_.T @ X_\n", "            # m_N = beta * A^{-1} Phi^T t; solve the linear system instead of forming the inverse explicitly\n", "            self.w = beta * np.linalg.solve(A, X_.T @ y.reshape((-1, 1)))\n", "            if current_w is not None and np.mean(np.abs(current_w - self.w)) < self.tol:\n", "                break\n", "            current_w = self.w\n", "            # Update alpha and beta\n", "            w_norm_sq = np.dot(self.w.reshape(-1), self.w.reshape(-1))  # m_N^T m_N\n", "            sq_err = np.sum(np.power(y.reshape(-1) - (X_ @ self.w).reshape(-1), 2))  # sum of squared residuals\n", "            if n_sample // n_feature >= 100:\n", "                # Simplified updates from PRML Eqs. 3.98 and 3.99 (N >> M), avoiding the cost of computing eigenvalues\n", "                alpha = n_feature / w_norm_sq\n", "                beta = n_sample / sq_err\n", "            else:\n", "                # beta * Phi^T Phi is symmetric, so eigvalsh returns real eigenvalues\n", "                lambdas = np.linalg.eigvalsh(beta * X_.T @ X_)\n", "                gamma = np.sum(lambdas / (lambdas + alpha))\n", "                alpha = gamma / w_norm_sq\n", "                beta = (n_sample - gamma) / sq_err\n", "        
# The version in the ml_models package does not return anything; alpha and beta are returned here for analysis\n", "        return alpha, beta\n", "\n", "    def predict(self, X):\n", "        if self.normalized:\n", "            X_ = (X - self.feature_mean) / self.feature_std\n", "        else:\n", "            X_ = X\n", "        X_ = self._map_basis(X_)\n", "        X_ = np.c_[np.ones(X_.shape[0]), X_]\n", "        return (self.w.T @ X_.T).reshape(-1)\n", "\n", "    def plot_fit_boundary(self, x, y):\n", "        \"\"\"\n", "        Plot the fitted result\n", "        :param x:\n", "        :param y:\n", "        :return:\n", "        \"\"\"\n", "        plt.scatter(x[:, 0], y)\n", "        plt.plot(x[:, 0], self.predict(x), 'r')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Test\n", "Let's test again on the example from the previous section." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# Generate synthetic samples\n", "X=np.linspace(0,100,100)\n", "X=np.c_[X,np.ones(100)]\n", "w=np.asarray([3,2])\n", "Y=X.dot(w)\n", "X=X.astype('float')\n", "Y=Y.astype('float')\n", "X[:,0]+=np.random.normal(size=(X[:,0].shape))*3  # add noise\n", "Y=Y.reshape(100,1)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# Add a few outlier points (noise)\n", "X=np.concatenate([X,np.asanyarray([[100,1],[101,1],[102,1],[103,1],[104,1]])])\n", "Y=np.concatenate([Y,np.asanyarray([[3000],[3300],[3600],[3800],[3900]])])" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYAAAAD8CAYAAAB+UHOxAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAG+1JREFUeJzt3X98VPWd7/HXJyFCsNqgoouBGrQU0eLPVKh2+4O2gugVblu3qHvL3Yd90N2H17pV8YKyq63rFstubd1H62OptqW19RdSZK276KLd9tpVCYIgKII/SUCJQqhK0BA+949zJjlJZpKZMJmZzPf9fDzyyJzvfGfme5jwfZ/5fr/njLk7IiISnopiN0BERIpDASAiEigFgIhIoBQAIiKBUgCIiARKASAiEigFgIhIoBQAIiKBUgCIiARqSLEb0JujjjrK6+rqit0MEZFBZc2aNW+5+8i+6pV0ANTV1dHQ0FDsZoiIDCpm9lo29bIeAjKzSjNba2YPxdtjzewpM9tiZvea2SFx+dB4e2t8f13iOebH5ZvNbGpuuyQiIvmUyxzAlcDzie1bgFvdfRywG7gsLr8M2O3uHwVujethZicBs4CTgWnAj82s8uCaLyIi/ZVVAJjZaOB84I5424ApwNK4yhJgZnx7RrxNfP/n4/ozgHvc/X13fwXYCpyVj50QEZHcZfsJ4AfAtcCBePtIoMXd98fbjUBtfLsW2AYQ378nrt9RnuYxIiJSYH0GgJldAOx09zXJ4jRVvY/7entM8vXmmFmDmTU0Nzf31TwREemnbFYBnQNcaGbTgWHA4USfCGrMbEh8lD8a2B7XbwTGAI1mNgT4MLArUZ6SfEwHd18MLAaor6/Xt9WISFCWr21i0crNbG9p5diaauZOHc/M0wdmsKTPTwDuPt/dR7t7HdEk7mPufinwOPCVuNps4MH49op4m/j+xzz62rEVwKx4ldBYYBzwdN72RERkkFu+ton5yzbQ1NKKA00trcxftoHla5sG5PUO5kzg/wtcZWZbicb474zL7wSOjMuvAuYBuPtG4D5gE/AfwOXu3n4Qry8iUlYWrdxMa1vXbrG1rZ1FKzcPyOvldCKYu/8O+F18+2XSrOJx933ARRkefzNwc66NFBEJwfaW1pzKD5auBSQiUiKOranOqfxglfSlIEREQpCa+G1qacXoujyyuqqSuVPHD8jrKgBERIooNfGbGvtPrZl3oHaAVwEpAEREiijdxG+q839i3pQBfW3NAYiIFFGhJ36TFAAiIkVU6InfJAWAiEgRzZ06nuqqrhdGHsiJ3yTNAYiIFFFqgrdQl39IUgCIiBTZzNNrC9Lhd6chIBGRQOkTgIhIERTyqp+ZKABERAqs+8lfqat+AgUNAQ0BiYgUWKGv+pmJAkBEpMCKefJXkgJARKTAinnyV5ICQESkwIp58leSJoFFRAqsmCd/JSkARESKoFgnfyVpCEhEJFAKABGRQCkAREQCpQAQEQmUAkBEJFAKABGRQCkAREQCpfMAREQKpBQuAZ2kABARKYBSuQR0koaAREQKoFQuAZ2kABARKYBSuQR0kgJARKQASuUS0EkKABGRAiiVS0AnaRJYRKQASuUS0EkKABGRAimFS0AnaQhIRCRQCgARkUApAEREAqUAEBEJlAJARCRQfQaAmQ0zs6fN7Fkz22hm347Lx5rZU2a2xczuNbND4vKh8fbW+P66xHPNj8s3m9nUgdopERHpWzafAN4Hprj7qcBpwDQzmwzcAtzq7uOA3cBlcf3LgN3u/lHg1rgeZnYSMAs4GZgG/NjMup4VISJSZpavbeKchY8xdt5vOWfhYyxf21TsJnXoMwA88m68WRX/ODAFWBqXLwFmxrdnxNvE93/ezCwuv8fd33f3V4CtwFl52QsRkRKUugJoU0srTucVQEslBLKaAzCzSjNbB+wEHgVeAlrcfX9cpRFInd1QC2wDiO/fAxyZLE/zGBGRslO
KVwBNyioA3L3d3U8DRhMdtU9IVy3+bRnuy1TehZnNMbMGM2tobm7OpnkiIiWpFK8AmpTTKiB3bwF+B0wGaswsdSmJ0cD2+HYjMAYgvv/DwK5keZrHJF9jsbvXu3v9yJEjc2meiEhJKcUrgCZlswpopJnVxLergS8AzwOPA1+Jq80GHoxvr4i3ie9/zN09Lp8VrxIaC4wDns7XjoiIlJpSvAJoUjYXgxsFLIlX7FQA97n7Q2a2CbjHzP4BWAvcGde/E/ilmW0lOvKfBeDuG83sPmATsB+43N3bEREpU6V4BdAkiw7OS1N9fb03NDQUuxkiIoOKma1x9/q+6ulMYBGRQCkAREQCpQAQEQmUAkBEJFAKABGRQCkAREQCpQAQEQmUAkBEJFAKABGRQCkAREQCpQAQEQmUAkBEJFAKABGRQCkAREQCpQAQEQmUAkBEJFAKABGRQCkAREQCpQAQEQmUAkBEJFAKABGRQCkAREQCpQAQEQmUAkBEJFAKABGRQCkAREQCpQAQEQmUAkBEJFAKABGRQCkAREQCpQAQEQmUAkBEJFAKABGRQCkAREQCpQAQEQmUAkBEJFAKABGRQCkAREQC1WcAmNkYM3vczJ43s41mdmVcfoSZPWpmW+LfI+JyM7PbzGyrma03szMSzzU7rr/FzGYP3G6JiEhfsvkEsB+42t0nAJOBy83sJGAesMrdxwGr4m2A84Bx8c8c4HaIAgO4AZgEnAXckAoNEREpvD4DwN13uPsz8e13gOeBWmAGsCSutgSYGd+eAfzCI08CNWY2CpgKPOruu9x9N/AoMC2veyMiIlnLaQ7AzOqA04GngGPcfQdEIQEcHVerBbYlHtYYl2UqFxGRIsg6AMzsQ8ADwN+6+596q5qmzHsp7/46c8yswcwampubs22eiIjkKKsAMLMqos7/V+6+LC5+Mx7aIf69My5vBMYkHj4a2N5LeRfuvtjd6929fuTIkbnsi4iI5CCbVUAG3Ak87+7fT9y1Akit5JkNPJgo/1q8GmgysCceIloJnGtmI+LJ33PjMhERKYIhWdQ5B/hfwAYzWxeXXQcsBO4zs8uA14GL4vseBqYDW4G9wF8BuPsuM7sJWB3X+46778rLXoiISM7MvccwfMmor6/3hoaGYjdDRGRQMbM17l7fVz2dCSwiEigFgIhIoBQAIiKBUgCIiARKASAiEigFgIhIoBQAIiKBUgCIiARKASAiEigFgIhIoBQAIiKBUgCIiARKASAiEigFgIhIoBQAIiKBUgCIiARKASAiEigFgIhIoBQAIiKBUgCIiARKASAiEigFgIhIoBQAIiKBUgCIiARKASAiEigFgIhIoBQAIiKBUgCIiARKASAiEigFgIhIoBQAIiKBUgCIiARKASAiEigFgIhIoBQAIiKBUgCIiARKASAiEigFgIhIoBQAIiKB6jMAzOynZrbTzJ5LlB1hZo+a2Zb494i43MzsNjPbambrzeyMxGNmx/W3mNnsgdkdERHJVjafAH4OTOtWNg9Y5e7jgFXxNsB5wLj4Zw5wO0SBAdwATALOAm5IhYaIiBRHnwHg7r8HdnUrngEsiW8vAWYmyn/hkSeBGjMbBUwFHnX3Xe6+G3iUnqEiIiIF1N85gGPcfQdA/PvouLwW2Jao1xiXZSrvwczmmFmDmTU0Nzf3s3kiItKXfE8CW5oy76W8Z6H7Ynevd/f6kSNH5rVxIiLSqb8B8GY8tEP8e2dc3giMSdQbDWzvpVxERIqkvwGwAkit5JkNPJgo/1q8GmgysCceIloJnGtmI+LJ33PjMhERKZIhfVUws7uBzwJHmVkj0WqehcB9ZnYZ8DpwUVz9YWA6sBXYC/wVgLvvMrObgNVxve+4e/eJZRERKSBzTzsUXxLq6+u9oaGh2M0QERlUzGyNu9f3Va/PTwAiIlIge/fCxo2wYQPU1sLUqQP6cgoAEZFCa2+HrVujjj7189xzUVlqVOYv/kIBICIyaLnDI4/A/Pmwdm1
n+bBhsG9fdLuiAj76UTjlFLj0Upg4Mbp9/PED3jwFgIhIPjz9NEyalF3dv/mbqKOfOBFOOgmGDx/YtmWgABARycXevfAv/xId1We7iObss+Ef/xE+85mBbVuOFAAiIum4w44d0fj8z34G996b3eOuvRZuvBGqqwe0efmgABAReeedaPXN+vVdJ2Z39XK60uWXw/XXw6hRhWtnnikARCQcb78NixbBLbdkrnPoofDxj8OXvtQ5Tj9xIhx1VOHaWSAKABEpP+6wYEE07p6Nm27q7Ojr6qKVOQFQAIjI4PbP/wzXXJPbY668Eq67Do4+uu+6ZUwBICKDw+uvw3HH5f64X/8aLr44/+0pAwoAESkt7e0wpJ9d03vvFW1N/WAUxkCXiJSmO+8Es64/2XT+zz4bjfN3/1HnnxMFgIgMvB074GMf69nZf/3rvT/uW99K39Gfckph2l3mNAQkIvnjDtu2wfe/Dz/8YfaPmzAB1q2DQw4ZuLZJDwoAEemflpauJ02lrmi5Z0/mx4wZA088Ef2WolMAiEjvdu+G730PFi7sLBs9GhobO7cPPzxaQ3/JJZ1r6Y89Fk49teDNlewpAEQksn9/dHmDxYv7rvvpT3c9S3bMmGhMXwYVBYBIiH77W7jggtwes2ABXH011NQMTJuk4BQAIuXs/ffhhRe6jtOvXAkHDvT+uH//d5g2rTBtlKJRAIiUA3d47bXOTj51VcsXX4yGdiBaYTNhQvRVg/fcE5XNnRuN7Qdy7RvpSgEgMtisXh19GcmqVdH2Jz8Zrb55553OOnV10dj8zJmdXzE4bhxUVUX33313wZstpUcBIFKqdu6Em2+G227rvd4hh8Ds2Z0TsiefHK3KEemDAkCk2Pbvj75xav786Hr12ZgwAb77XbjwQq2+kX5TAIgU0ltvdZ2QfeCBaJ19b26+Ga64Ag47rDBtlH5bvraJRSs3s72llWNrqpk7dTwzT68tdrMyMs/2S42LoL6+3hsaGordDJHcvfFG9CUjP/5xZ9mwYbBvX+f2kUdG6+fXrYu2Z82C73wnGquXkrB8bRM3rthIS2sbACOGV3HD/ziZmafXsmD5Bu5+ahvt7lSacfzI4Wzd+R7JHtUAB2prqvnciSN5/IXmgoSDma1x9/o+6ykARA7CBx/A3/99718xmDJ9OkyZ0jlW/2d/puGbbnI5gu5eN9XBNrW0UmlGuzu1iefIVL+37Xuf3kbbga59ZFWlcVbdCJ54qZfvC85CdVUl3/3SxAEJAQWASL41N0fDNnfckd0qGrNonP6KK4K6THGqo+3eEWfV4a7eRlt7Z59UVWks+sqpPTrJ5WubmL9sA61t7X22p7qqki+fWcsDa5qyqp+SOnofSLU11Twxb0ren1cBINJfe/fCpk09L3T25pu9P+6Pf4yWZJa4vo6EMx119zYckqyTbceci7+c/BH+YebEju1zFj5GU0tr1o9PBVGpMeCVhefn/3mzDABNAku42tvh5Zejzv2uu+A3v+lZZ9iwaFnleed1vfbNMceUzPBNtsMm3TtwgKaWVu568vUu23OXPsuNKzayp7WND1dXYQa797b1eL7de9uYu/RZgI7XW7Ryc947f6CjjakQ2J5D5w+UZOcPcGxNdVFfXwEgYZgzB37yk+zrL10adfQnnACVlXlpQrZH3sl6H66uoq39AO990NmpjhhexfmnjOoY704OVTS1tDJ/2QaAfh+Zt7V7R0gkwyJT3UUrN3e8Vq4dcy7ufmpbRwAcW1NdtE8A+ZwDmDt1fF7a1F8KACkvzzwDZ56Zff1vfpPffeJcbnoZXt5L3iYek+UVBt3mEdMeec9ftoGG13Z1GatO1wHv3tvW5bHdu7XWtvYunTIM3JE5dO30c+2Yc5HswOdOHZ91oBlw8aQxOc8BpOYOHnp2R1argCYfP4JX327NepK5FJaIag5ABqd9++BHP+LAvHlUpK51k4VfnzqNvzvvio7/tO3uaSf7DLi027jzQI1vJ+X
rSLX72PLYeb8dsAnN5ETmwfwbVVUa7e1OpsvUVZrx0nend2wnw7hmeFXaYSronD/IdRVQKXTQ/aU5ABn0lq9t4vpl66nfvJol99+Qtk6mS5i98akpTD/7CnZZVc874w421dGm6xidnuPOA3kUnZKvYYruY8sDdWReVWldhjGScwHdPxXVHVnNH1/alXGdfOp5rv/Nhi5DXikXT+r6LWIzT6/tMczV2yR19/qiTwCSpUwTjd0/Bl88aQz1xx3R0QEkVRhcMqnrUXUX774LGzfChg289Nh/0/zEaia/vqHXdv3qtGncdvYs3jzsqI6y2rjzy0eHlzzqHMij6OTrHWwIpFtfnsuReQVkPApPSrcKqC/ZTlin+7vK+HcjPWgZqPSQy0k2yf+A6VRXVTJ6xDC27Hwv53b8ed3h/PKzI3teuvjllzvqtFYNY/NRH+G0HS92lF039XJ+fdp5fT5/am1Ovv6yX42HUnJdepir/q5Xh55H0plWAXU/Mjfr+EBETXUVN154MkCXSWgzaNnbNuiHRUKiABhkclmb3VvdmuFV7Gtrp7UtOoZLTUDWVFfx3gf7u5xkk+lMxAXLN3SZZOw3d/7Pf9/LNX+4C4AHJ3yGE3Y1Mu6t1xnaHo/XVlRElz5ILrE85RSOX7yJA9a/a9QP1CeAfM4BJFfyHMwqoHIYr5b8UwAMgOVrm/j2v23smGxKHTGl/tNmui+b5+2rY6msMA4bOoSW1ra8nqGY7kzEE+Y/nPMwxJmNm3jgV9f2We/3dafzwsg6dh43jgXzvxpd1bK651ro/h5tp0INyEtn3f0EpINdBZTL34VIf5XsJLCZTQN+CFQCd7j7wny/Rl+dcbqhECBtWfcTZ5JaWtu46r51zF+2vuOIO3nf3Pu7niSTSTaTi+0HOtdm5zOy063b7q3zr2n9E7+89++Y+OZLOb3ON2Zex8rxZ3dsG7DgjDMy1p87dTxX3/8s7d3XT3ZzzglHdFl61/1IOPl3kElqxQ/Q57izJhKlnBT0E4CZVQIvAl8EGoHVwMXuvild/f58Ali+tom5S5/tMtQBUFVhLLroVIAe91dYdITd/RokvS1Jy1Y21/ooxORiJhk/ARw4wKdeXccNqxYz7u1tWT/fv53451x9/lV8MCTN6ps+Xre75WubeqwISX36yWVisL+XPhAZrEr1E8BZwFZ3fxnAzO4BZgBpA6A/Fq3c3KPzB2g7EJ2xuLfbODhEH88PdCtL9xz9kc2ZkQN58kxvOs5EfOed6CsF40nZ/3z8ScZu7D1491Ufyv/83z/g+cNH9f91+5Cvo20dtYukV+gAqAWSh5ONwKR8vkBvHe72ltaCH2lnc62PXM5qPBjV3s5J773JqG1b+cSfGjn3QDOj7toCr77aWelDH2LsxIk0fOFL1P/nMp447hQWfXo2644d32WFyTDgGzkMpQ2mL8kQCUWhAyDd1bO69MlmNgeYA/CRj3wk5xfo7Wi60EfaVRWW9ZEuZDdenZLpiyZqhlex74P9fGbj/+Pa/1rC8bu3A7Bl1Akc/9Y2Kts+iJ6gshLGj4dJk+DrX+9cgXPccVBRQeqz4znxT6Z2p+vIsy0TkeIqdAA0AsnT+UYD25MV3H0xsBiiOYBcX2Du1PEZ5wDmTh3f66Rul/pZzAEMr6qgrd17fGEE5L7aI9WZZnst9Y6j6E2b4Lrr4MEHe33+cad9DCZ+ubOjP/FEGDo0q7aJSHkqdACsBsaZ2VigCZgFXJLPF0h3NN29M557/7NdOu2qCuOrZ43p0cECWV3/PJ/DGxnHq994A34yDx5+ONqen8WTTZsWfZ9sL6ttRCRcBT8PwMymAz8gWgb6U3e/OVPdgToPoKS/uLm9HT73OfjDH7J/TG1t9ClgzhwYoss7iYROJ4INBuvXw/XXw0MPZf+YuXNh3jw44oiBa5eIDGqlugw0TLt2RV8a/r3v5fa4226Lvk9WRGQAKADyyR3uuw+uuQYaG7N7TF0dLFwIF10UXRd
HRKRA1OP01zPPwNSp0ffCmkVfHVhRAbNmZe78FyyA3bujoEj9vPIKfPWr6vxFpODU6/Tl1Vfh4os7O/rUz5lnwiOPdNZ7+22YPRsOPRRmzIjOqk129O5w001QU1O0XRERSdIQUEp7O/zrv8Lll2f/mBNOiIZvvvzlKBQAfv7zAWmeiEi+hRkAK1fCBRdADt8lyw03wFVXweGHD1y7REQKqLwDYM8euOSSzpOnsvXII/DFLw5Mm0RESkR5BsC778Jhh/Vd7/bb4Rvf6By+EREJSHlOAg8dCp/4ROf2pZf2XH3jDn/91+r8RSRY5fkJoKoKnn662K0QESlp5fkJQERE+qQAEBEJlAJARCRQCgARkUApAEREAqUAEBEJlAJARCRQCgARkUCV9FdCmlkz8Foen/Io4K08Pl+pCmE/Q9hH0H6Wk0Lu43HuPrKvSiUdAPlmZg3ZfE/mYBfCfoawj6D9LCeluI8aAhIRCZQCQEQkUKEFwOJiN6BAQtjPEPYRtJ/lpOT2Mag5ABER6RTaJwAREYkFEwBmNs3MNpvZVjObV+z25IOZjTGzx83seTPbaGZXxuVHmNmjZrYl/j2i2G3NBzOrNLO1ZvZQvD3WzJ6K9/NeMzuk2G08WGZWY2ZLzeyF+H39ZLm9n2b2rfjv9Tkzu9vMhpXDe2lmPzWznWb2XKIs7Xtnkdvi/mi9mZ1RjDYHEQBmVgn8CDgPOAm42MxOKm6r8mI/cLW7TwAmA5fH+zUPWOXu44BV8XY5uBJ4PrF9C3BrvJ+7gcuK0qr8+iHwH+5+InAq0f6WzftpZrXAN4F6d/84UAnMojzey58D07qVZXrvzgPGxT9zgNsL1MYugggA4Cxgq7u/7O4fAPcAM4rcpoPm7jvc/Zn49jtEnUUt0b4tiastAWYWp4X5Y2ajgfOBO+JtA6YAS+Mqg34/zexw4NPAnQDu/oG7t1B+7+cQoNrMhgDDgR2UwXvp7r8HdnUrzvTezQB+4ZEngRozG1WYlnYKJQBqgW2J7ca4rGyYWR1wOvAUcIy774AoJICji9eyvPkBcC1wIN4+Emhx9/3xdjm8p8cDzcDP4qGuO8zsUMro/XT3JuCfgNeJOv49wBrK771MyfTelUSfFEoApPvm97JZ/mRmHwIeAP7W3f9U7Pbkm5ldAOx09zXJ4jRVB/t7OgQ4A7jd3U8H3mMQD/ekE4+BzwDGAscChxINh3Q32N/LvpTE328oAdAIjElsjwa2F6kteWVmVUSd/6/cfVlc/Gbq42T8e2ex2pcn5wAXmtmrRMN3U4g+EdTEwwhQHu9pI9Do7k/F20uJAqGc3s8vAK+4e7O7twHLgLMpv/cyJdN7VxJ9UigBsBoYF680OIRo0mlFkdt00OJx8DuB5939+4m7VgCz49uzgQcL3bZ8cvf57j7a3euI3rvH3P1S4HHgK3G1ctjPN4BtZjY+Lvo8sInyej9fByab2fD47ze1j2X1XiZkeu9WAF+LVwNNBvakhooKyt2D+AGmAy8CLwHXF7s9edqnTxF9bFwPrIt/phONj68CtsS/jyh2W/O4z58FHopvHw88DWwF7geGFrt9edi/04CG+D1dDowot/cT+DbwAvAc8EtgaDm8l8DdRPMabURH+Jdleu+IhoB+FPdHG4hWRRW8zToTWEQkUKEMAYmISDcKABGRQCkAREQCpQAQEQmUAkBEJFAKABGRQCkAREQCpQAQEQnU/wcvX8gBqPvrlgAAAABJRU5ErkJggg==\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "lr=LinearRegression()\n", "alpha,beta=lr.fit(X[:,:-1],Y)\n", "lr.plot_fit_boundary(X[:,:-1],Y)" ] }, { "cell_type": "code", "execution_count": 6, 
"metadata": {}, "outputs": [ { "data": { "text/plain": [ "4.080800093752806" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "alpha/beta" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "可以发现证据近似最终选择的L2正则化系数为4,而从上一节的测试可以发现最优的L2正则化系数应该在100左右,可以发现证据近似与我们的理想情况还有一定的距离,接下来我们看看利用变分推断拟合的效果会怎样" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.4" } }, "nbformat": 4, "nbformat_minor": 2 }