{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "The previous section followed Section 10.3 of PRML and used variational inference to implement Bayesian linear regression, with fairly good results. It still left one hyperparameter, $\\beta$, to be set by hand. In this section we assume that $\\beta$ is also drawn from a Gamma distribution and again solve the model with variational inference. Much of the derivation therefore repeats the previous section, but it is reproduced here for completeness." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1. Joint probability distribution\n", "The graphical model for Bayesian linear regression is shown below \n", "![avatar](./source/15_VI_贝叶斯线性回归图模型.png)\n", "We assume the following likelihood and conjugate priors: \n", "\n", "$$\n", "p(t\\mid w,\\beta)=\\prod_{n=1}^N N(t_n\\mid w^T\\phi(x_n),\\beta^{-1})\\\\\n", "p(w\\mid \\alpha)=N(w\\mid 0,\\alpha^{-1}I)\\\\\n", "p(\\alpha)=Gam(\\alpha\\mid a_0,b_0)\\\\\n", "p(\\beta)=Gam(\\beta\\mid a_1,b_1)\n", "$$ \n", "Here $X=\\{x_1,x_2,...,x_N\\}$ denotes the inputs (omitted from the conditioning for brevity) and $I$ is the identity matrix, so the joint distribution over all variables is: \n", "\n", "$$\n", "p(t,w,\\alpha,\\beta)=p(t\\mid w,\\beta)p(w\\mid\\alpha)p(\\alpha)p(\\beta)\n", "$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2. Variational distribution\n", "Next, we construct a variational distribution $q(w,\\alpha,\\beta)$ to approximate the posterior $p(w,\\alpha,\\beta\\mid t)$. Under the mean-field assumption, let: \n", "\n", "$$\n", "q(w,\\alpha,\\beta)=q(w)q(\\alpha)q(\\beta)\n", "$$ \n", "\n", "Using the formula derived at the end of the first VI section, we can write down the optimal solution directly: \n", "\n", "$$\n", "ln\\ q^*(\\alpha)=ln\\ p(\\alpha)+E_w[ln\\ p(w\\mid\\alpha)]+const\\\\\n", "=(a_0-1)ln\\ \\alpha-b_0\\alpha+\\frac{M}{2}ln\\ \\alpha-\\frac{\\alpha}{2}E[w^Tw]+const\n", "$$ \n", "\n", "This has the form of the logarithm of a Gamma distribution; matching the coefficients of $ln\\ \\alpha$ and $\\alpha$ gives \n", "\n", "$$\n", "q^*(\\alpha)=Gam(\\alpha\\mid a_N,b_N)\n", "$$ \n", "\n", "where, with $M$ the dimension of $w$, \n", "\n", "$$\n", "a_N=a_0+\\frac{M}{2}\\\\\n", "b_N=b_0+\\frac{1}{2}E[w^Tw]\n", "$$ \n", "\n", "In the same way, we can solve for $q^*(w)$: \n", "\n", "$$\n", "ln\\ q^*(w)=E_\\beta[ln\\ p(t\\mid w,\\beta)]+E_\\alpha[ln\\ p(w\\mid \\alpha)]+const\\\\\n", "=-\\frac{E[\\beta]}{2}\\sum_{n=1}^N\\left \\{w^T\\phi(x_n)-t_n\\right \\}^2-\\frac{1}{2}E[\\alpha]w^Tw+const\\\\\n", "=-\\frac{1}{2}w^T(E[\\alpha]I+E[\\beta]\\Phi^T\\Phi)w+E[\\beta] w^T\\Phi^Tt+const\n", "$$ \n", "\n", "This is a quadratic form in $w$, so $q^*(w)$ is a Gaussian; completing the square gives: \n", "\n", "$$\n", "q^*(w)=N(w\\mid m_N,S_N)\n", "$$ \n", "\n", "where \n", "\n", "$$\n", "m_N=E[\\beta] 
S_N\\Phi^Tt\\\\\n", "S_N=(E[\\alpha]I+E[\\beta]\\Phi^T\\Phi)^{-1}\n", "$$ \n", "\n", "Compared with the previous section, $\\alpha,\\beta$ are replaced by $E[\\alpha],E[\\beta]$ \n", "\n", "Next comes $q^*(\\beta)$:\n", "\n", "$$\n", "ln\\ q^*(\\beta)=ln\\ p(\\beta)+E_w[ln\\ p(t\\mid w,\\beta)]+const\\\\\n", "=(a_1-1)ln\\ \\beta-b_1\\beta+\\frac{N}{2}ln\\ \\beta-\\frac{\\beta}{2}E_w[(t-\\Phi w)^T(t-\\Phi w)]+const\n", "$$ \n", "\n", "Clearly this is again the logarithm of a Gamma distribution, namely: \n", "$$\n", "q^*(\\beta)=Gam(\\beta\\mid a_M,b_M)\n", "$$ \n", "where \n", "\n", "$$\n", "a_M=a_1+\\frac{N}{2}\\\\\n", "b_M=b_1+\\frac{1}{2}E_w[(t-\\Phi w)^T(t-\\Phi w)]\n", "$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 3. Iterative solution\n", "\n", "As before, $q^*(w),q^*(\\alpha),q^*(\\beta)$ are coupled: $q^*(w)$ needs $E[\\alpha],E[\\beta]$, $q^*(\\alpha)$ needs $E[w^Tw]$, and $q^*(\\beta)$ needs $E_w[(t-\\Phi w)^T(t-\\Phi w)]$. From the properties of the Gamma and Gaussian distributions, these expectations are \n", "\n", "$$\n", "E[\\alpha]=\\frac{a_N}{b_N}\\\\\n", "E[\\beta]=\\frac{a_M}{b_M}\\\\\n", "E[w^Tw]=m_N^Tm_N+Tr(S_N)\\\\\n", "E_w[(t-\\Phi w)^T(t-\\Phi w)]=E_w[w^T\\Phi^T\\Phi w]-2E_w[w^T\\Phi^Tt]+t^Tt\\\\\n", "=m_N^T\\Phi^T\\Phi m_N+Tr(\\Phi S_N\\Phi^T)-2m_N^T\\Phi^Tt+t^Tt\n", "$$ \n", "\n", "So we initialize $E[\\alpha],E[\\beta]$ with arbitrary positive values and then iterate: $E_0[\\alpha]\\rightarrow E_0[\\beta]\\rightarrow E_0[w^Tw]\\rightarrow E_0[(t-\\Phi w)^T(t-\\Phi w)] \\rightarrow \\cdots$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 4. Predictive distribution\n", "\n", "The procedure is the same as in the previous section \n", "\n", "$$\n", "p(\\hat{t}\\mid \\hat{x},t)=\\int p(\\hat{t}\\mid\\hat{x},w)p(w\\mid t)dw\\\\\n", "\\simeq \\int p(\\hat{t}\\mid\\hat{x},w)q(w)dw\\\\\n", "=\\int N(\\hat{t}\\mid w^T\\phi(\\hat{x}),\\beta^{-1})N(w\\mid m_N,S_N)dw\\\\\n", "=N(\\hat{t}\\mid m_N^T\\phi(\\hat{x}),\\sigma^2(\\hat{x}))\n", "$$ \n", "\n", "Here $\\sigma^2(\\hat{x})=\\frac{1}{\\beta}+\\phi(\\hat{x})^TS_N\\phi(\\hat{x})$. Since $\\beta$ is now inferred rather than fixed, one option is to plug in its variational mean $E[\\beta]=a_M/b_M$." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 5. Code implementation" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, 
"outputs": [], "source": [ "\"\"\"\n", "代码同样封装到ml_models.vi中,上一vi版本的类名改为LinearRegression1\n", "\"\"\"\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", "\n", "class LinearRegression(object):\n", " def __init__(self, basis_func=None, tol=1e-5, epochs=100):\n", " \"\"\"\n", " :param basis_func: list,基函数列表,包括rbf,sigmoid,poly_{num},linear,fm,默认None为linear,其中poly_{num}中的{num}表示多项式的最高阶数,fm表示构建交叉因子\n", " :param tol: 两次迭代参数平均绝对值变化小于tol则停止\n", " :param epochs: 默认迭代次数\n", " \"\"\"\n", " if basis_func is None:\n", " self.basis_func = ['linear']\n", " else:\n", " self.basis_func = basis_func\n", " self.tol = tol\n", " self.epochs = epochs\n", " # 特征均值、标准差\n", " self.feature_mean = None\n", " self.feature_std = None\n", " # 训练参数,也就是m_N\n", " self.w = None\n", "\n", " def _map_basis(self, X):\n", " \"\"\"\n", " 将X进行基函数映射\n", " :param X:\n", " :return:\n", " \"\"\"\n", " n_sample, n_feature = X.shape\n", " x_list = []\n", " for basis_func in self.basis_func:\n", " if basis_func == \"linear\":\n", " x_list.append(X)\n", " elif basis_func == \"rbf\":\n", " x_list.append(np.exp(-0.5 * X * X))\n", " elif basis_func == \"sigmoid\":\n", " x_list.append(1 / (1 + np.exp(-1 * X)))\n", " elif basis_func.startswith(\"poly\"):\n", " p = int(basis_func.split(\"_\")[1])\n", " for pow in range(1, p + 1):\n", " x_list.append(np.power(X, pow))\n", " elif basis_func == 'fm':\n", " X_fm = np.zeros(shape=(n_sample, int(n_feature * (n_feature - 1) / 2)))\n", " c = 0\n", " for i in range(0, n_feature - 1):\n", " for j in range(i + 1, n_feature):\n", " X_fm[:, c] = X[:, i] * X[:, j]\n", " c += 1\n", " x_list.append(X_fm)\n", " return np.concatenate(x_list, axis=1)\n", "\n", " def fit(self, X, y):\n", " self.feature_mean = np.mean(X, axis=0)\n", " self.feature_std = np.std(X, axis=0) + 1e-8\n", " X_ = (X - self.feature_mean) / self.feature_std\n", " X_ = self._map_basis(X_)\n", " X_ = np.c_[np.ones(X_.shape[0]), X_]\n", " n_sample, n_feature = X_.shape\n", "\n", " E_alpha = 1.0\n", 
" E_beta = 1.0\n", " current_w = None\n", " for _ in range(0, self.epochs):\n", " S_N = np.linalg.inv(E_alpha * np.eye(n_feature) + E_beta * X_.T @ X_)\n", " self.w = E_beta * S_N @ X_.T @ y.reshape((-1, 1)) # 即m_N\n", " if current_w is not None and np.mean(np.abs(current_w - self.w)) < self.tol:\n", " break\n", " current_w = self.w\n", " # 更新 E_alph,E_beta\n", " E_w = (self.w.T @ self.w)[0][0] + np.trace(S_N)\n", " E_alpha = (n_feature - 1) / E_w # 这里假设a_0,b_0都为0\n", " E_w_phi = np.dot(y.reshape(-1), y.reshape(-1)) + np.sum(X_ @ S_N * X_) + (self.w.T @ X_.T @ X_ @ self.w)[0][\n", " 0] - 2 * (self.w.T @ X_.T @ y.reshape((-1, 1)))[0][0]\n", " E_beta = (n_sample - 1) / E_w_phi # 这里假设a_1,b_1都为0\n", " \n", " # ml_models.vi包中,没有这一部分\n", " return E_alpha, E_beta\n", "\n", " def predict(self, X):\n", " X_ = (X - self.feature_mean) / self.feature_std\n", " X_ = self._map_basis(X_)\n", " X_ = np.c_[np.ones(X_.shape[0]), X_]\n", " return (self.w.T @ X_.T).reshape(-1)\n", "\n", " def plot_fit_boundary(self, x, y):\n", " \"\"\"\n", " 绘制拟合结果\n", " :param x:\n", " :param y:\n", " :return:\n", " \"\"\"\n", " plt.scatter(x[:, 0], y)\n", " plt.plot(x[:, 0], self.predict(x), 'r')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 六.测试" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "#造伪样本\n", "X=np.linspace(0,100,100)\n", "X=np.c_[X,np.ones(100)]\n", "w=np.asarray([3,2])\n", "Y=X.dot(w)\n", "X=X.astype('float')\n", "Y=Y.astype('float')\n", "X[:,0]+=np.random.normal(size=(X[:,0].shape))*3#添加噪声\n", "Y=Y.reshape(100,1)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "#加噪声\n", "X=np.concatenate([X,np.asanyarray([[100,1],[101,1],[102,1],[103,1],[104,1]])])\n", "Y=np.concatenate([Y,np.asanyarray([[3000],[3300],[3600],[3800],[3900]])])" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYAAAAD8CAYAAAB+UHOxAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAG61JREFUeJzt3X+UVPV9//Hne5cRFlQWcP3BAi4qByNi1KxKJU0tMYI/IlRNYqNfaWO/xNRzNDahgdQTNT+xtFWTJvZoTEtaj6KGg6htkaAxif0KLqL8EFYI/toFXSwsqCy4wPv7x72Dd3dndmbY+bVzX49z9uzMZz5353O9+HnN/Xw+9465OyIiEj9VpW6AiIiUhgJARCSmFAAiIjGlABARiSkFgIhITCkARERiSgEgIhJTCgARkZhSAIiIxNSAUjegN8ccc4w3NDSUuhkiIv3KqlWr3nP3ukz1yjoAGhoaaGpqKnUzRET6FTN7M5t6WQ8BmVm1ma02syfD52PNbIWZbTKzhWZ2RFg+MHy+OXy9IfI35oblzWY2NbddEhGRfMplDuBmYEPk+Z3AXe4+DtgJXB+WXw/sdPdTgLvCepjZacDVwARgGvAzM6vuW/NFRORwZRUAZjYKuBT4efjcgCnAY2GVBcCM8PH08Dnh658N608HHnb3fe7+OrAZODcfOyEiIrnL9gzgbuBvgYPh8xFAu7vvD5+3APXh43rgbYDw9V1h/UPlKbYREZEiyxgAZnYZ0Obuq6LFKap6htd62yb6frPMrMnMmrZv356peSIicpiyWQU0GbjczC4BBgFHE5wR1JrZgPBT/ihga1i/BRgNtJjZAGAosCNSnhTd5hB3vw+4D6CxsVHfViMisbJ4dSvzlzaztb2DkbU1zJ46nhlnFWawJOMZgLvPdfdR7t5AMIn7jLtfAzwLXBVWmwk8Hj5eEj4nfP0ZD752bAlwdbhKaCwwDliZtz0REennFq9uZe6itbS2d+BAa3sHcxetZfHq1oK8X1+uBP4W8DdmtplgjP+BsPwBYERY/jfAHAB3Xw88ArwK/Ddwo7sf6MP7i4hUlPlLm+no7NotdnQeYP7S5oK8X04Xgrn7b4DfhI+3kGIVj7vvBb6QZvsfAD/ItZEiInGwtb0jp/K+0r2ARETKxMjampzK+6qsbwUhIhIHyYnf1vYOjK7LI2sS1cyeOr4g76sAEBEpoeTEb3LsP7lm3oH6Aq8CUgCIiJRQqonfZOf//JwpBX1vzQGIiJRQsSd+oxQAIiIlVOyJ3ygFgIhICc2eOp6aRNcbIxdy4jdKcwAiIiWUnOAt1u0fohQAIiIlNuOs+qJ0+N1pCEhEJKZ0BiAiUgLFvOtnOgoAEZEi637xV/Kun0BRQ0BDQCIiRVbsu36mowAQESmyUl78FaUAEBEpslJe/BWlABARKbJSXvwVpUlgEZEiK+XFX1EKABGREijVxV9RGgISEYkpBYCISEwpAEREYkoBICISUwoAEZGYUgCIiMSUAkBEJKZ0HYCISJGUwy2goxQAIiJFUC63gI7SEJCISBGUyy2goxQAIiJFUC63gI5SAIiIFEG53AI6SgEgIlIE5XIL6ChNAouIFEG53AI6SgEgIlIk5XAL6CgNAYmIxJQCQEQkphQAIiIxpQAQEYkpBYCISExlDAAzG2RmK83sFTNbb2Z3hOVjzWyFmW0ys4VmdkRYPjB8vjl8vSHyt+aG5c1mNrVQOyUiIpllcwawD5ji7p8EzgSmmdkk4E7gLncfB+wErg/rXw/sdPdTgLvCepjZacDVwARgGvAzM+t6VYSISIVZvLqVyfOeYeycp5g87xkWr24tdZMOyRgAHvggfJoIfxyYAjwWli8AZoSPp4fPCV//rJlZWP6wu+9z99eBzcC5edkLEZEylLwDaGt7B87HdwAtlxDIag7AzKrN7GWgDVgG/AFod/f9YZUWIHl1Qz3wNkD4+i5gRLQ8xTYiIhWnHO8
AGpVVALj7AXc/ExhF8Kn9E6mqhb8tzWvpyrsws1lm1mRmTdu3b8+meSIiZakc7wAaldMqIHdvB34DTAJqzSx5K4lRwNbwcQswGiB8fSiwI1qeYpvoe9zn7o3u3lhXV5dL80REyko53gE0KptVQHVmVhs+rgEuBDYAzwJXhdVmAo+Hj5eEzwlff8bdPSy/OlwlNBYYB6zM146IiJSbcrwDaFQ2N4M7AVgQrtipAh5x9yfN7FXgYTP7PrAaeCCs/wDw72a2meCT/9UA7r7ezB4BXgX2Aze6+wFERCpUOd4BNMqCD+flqbGx0ZuamkrdDBGRfsXMVrl7Y6Z6uhJYRCSmFAAiIjGlABARiSkFgIhITCkARERiSgEgIhJTCgARkZhSAIiIxJQCQEQkphQAIiIxpQAQEYkpBYCISEwpAEREYkoBICISUwoAEZGYUgCIiMSUAkBEJKYUACIiMaUAEBGJKQWAiEhMKQBERGJKASAiElMKABGRmFIAiIjElAJARCSmFAAiIjGlABARiSkFgIhITCkARERiSgEgIhJTCgARkZhSAIiIxJQCQEQkphQAIiIxpQAQEYkpBYCISEwpAEREYkoBICISUxkDwMxGm9mzZrbBzNab2c1h+XAzW2Zmm8Lfw8JyM7Mfm9lmM1tjZmdH/tbMsP4mM5tZuN0SEZFMsjkD2A98w90/AUwCbjSz04A5wHJ3HwcsD58DXAyMC39mAfdCEBjAbcB5wLnAbcnQEBGR4ssYAO6+zd1fCh+/D2wA6oHpwIKw2gJgRvh4OvBLD7wA1JrZCcBUYJm773D3ncAyYFpe90ZERLKW0xyAmTUAZwErgOPcfRsEIQEcG1arB96ObNYSlqUrFxGREsg6AMzsSOBXwNfdfXdvVVOUeS/l3d9nlpk1mVnT9u3bs22eiIjkKKsAMLMEQef/oLsvCovfDYd2CH+3heUtwOjI5qOArb2Ud+Hu97l7o7s31tXV5bIvIiKSg2xWARnwALDB3f8p8tISILmSZybweKT8unA10CRgVzhEtBS4yMyGhZO/F4VlIiJSAgOyqDMZ+D/AWjN7OSz7NjAPeMTMrgfeAr4QvvafwCXAZmAP8JcA7r7DzL4HvBjW+66778jLXoiISM7MvccwfNlobGz0pqamUjdDRKRfMbNV7t6YqZ6uBBYRiSkFgIhITCkARERiSgEgIhJTCgARkZhSAIiIxJQCQEQkphQAIiIxpQAQEYkpBYCISEwpAEREYkoBICISUwoAEZGYUgCIiMSUAkBEJKYUACIiMaUAEBGJKQWAiEhMKQBERGJKASAiElMKABGRmFIAiIjElAJARCSmFAAiIjGlABARiSkFgIhITCkARERiSgEgIhJTCgARkZhSAIiIxJQCQEQkphQAIiIxpQAQEYkpBYCISEwpAEREYkoBICISUwoAEZGYUgCIiMSUAkBEJKYyBoCZ/cLM2sxsXaRsuJktM7NN4e9hYbmZ2Y/NbLOZrTGzsyPbzAzrbzKzmYXZHRERyVY2ZwD/BkzrVjYHWO7u44Dl4XOAi4Fx4c8s4F4IAgO4DTgPOBe4LRkaIiJSGhkDwN1/C+zoVjwdWBA+XgDMiJT/0gMvALVmdgIwFVjm7jvcfSewjJ6hIiIiRXS4cwDHufs2gPD3sWF5PfB2pF5LWJauvAczm2VmTWbWtH379sNsnoiIZJLvSWBLUea9lPcsdL/P3RvdvbGuri6vjRMRkY8dbgC8Gw7tEP5uC8tbgNGReqOArb2Ui4hIiRxuACwBkit5ZgKPR8qvC1cDTQJ2hUNES4GLzGxYOPl7UVgmIiIlMiBTBTN7CLgAOMbMWghW88wDHjGz64G3gC+E1f8TuATYDOwB/hLA3XeY2feAF8N633X37hPLIiJSROaecii+LDQ2NnpTU1OpmyEi0q+Y2Sp3b8xUL+MZgIiIFJg7bNsGa9cGP2vWQGMj3HRTQd9WASAiUkwffADr1n3c0Sc7/R2RUfGRI2HMmII3RQE
gIlIIe/bAT34Cc+cGn/BTGTIETj8drrwSJk78+GfEiKI0UQEgItIX7vDd78L990Nra+b6558Ps2fDGWdAQwNUle6enAoAEZFsPfQQfPnLuW1z441w661w/PGFaVMfKABERLrbtw/uvBNuuy237X73O/j0pwvTpgJQAIhIvLW1wfLluX+y37ULjj66MG0qEn0hjIjEw9atcMMNYNb157jjeu/8ly8Pxvm7//Tzzh90BiAilaazM5iQnTMH3n8/c/158+Dss+Goo2DSpMK3r4woAESk/3rvvWAN/eOPwz33ZK6fSAQd/l//NQwaVPj2lTkFgIiUv7Y2eOstePXVjy+eWrMG3nkn/TYzZ8LttwdLLSUlBYCIlI/OTliwILh46r33UtcZOBAmTICpU4OLps44I/h93HHBmL5kTQEgIqXR1ASXXALZfvPfrFlwyy1wyikwQF1XPui/oogU1htvwHnnBcM42br66uDq2nHjCtYsUQCISL64w8aNwUqa3buz22bMmGCs/o47NHxTAgoAEcnd0qVw993B0snhw4NJ2XXreu/4J04M1tTru77LhgJARNJrbYXJk+HNN9PXqa0NOvdrrw1W3DQ3B8sszz67aM2Uw6MAEBE4cCCYYP3JT3Lb7tln4U/+RMM3/ZQCQCRunn46WEKZi1NOgeeeC76oRCqG7gUkUqna2oIvG+l+75tMnf+SJT3ve7Npkzr/CqQAEOnv3OHb3059k7P169Nv99WvBhdede/sP//54rVdSkpDQCL9yXPPwRVXdP3+2EymTIFHHina1wxK/6EAEClHO3YEQzVNTbltt3AhfPGLhWmTVBwNAYmUkjs89VSwjDI6fDNiRO+d//33B99a1X34Rp2/5EBnACLFsnZt8DWDDz6Y/Tbnnw+LFgXj+VJUi1e3Mn9pM1vbOxhZW8PsqeOZcVZ9qZuVV+bupW5DWo2Njd6U6ymwSKnt3g2XXx6M12frM5+BH/4wuOhKiipVR9/05g4efOEteusd68O6QFZBUcxAMbNV7t6YsZ4CQKQP7r47uIAqF/feC1/5ChxxRGHaFAPRznRoTQIzaN/T2aVj7a3DTb7W2t6BQZeOPlFldB7Mrl9MVBkHgQOR+olqY/5Vn+zSuS9e3crcRWvp6DxwqKwmUc2PrphYkBBQAIjk0yuvwJe+FNzmIFvV1fDyy8Fa/JjI9Cm3t467YUQNL2zZyQF3qs348/NG8/0ZE7l18VoeWvH2ofJJJw3jpbd2delMo2oS1Vz5qXp+tao1ZYcL9OiM823Y4ASrv3PRoeeT5z1Da3tHj3r1tTU8P2dK3t9fASByOPbsgauugv/6r9y2+9nP4GtfK0yb8qiQwxC3Ll7Lf7zwVsrXhg1OcOkZJ7DwxbfpPJB9nzPu2CFsavsw57ZUm3EgRd9WX1sDkLIzzrc35l166PHYOU+lHE4y4PVIvXzJNgA0CSzxdf/9wZeM5OLuu+GGG4JvpSqibIY8Mm3z0f4D7Ok8eOj11vYO5i5aC5BxyKT73719yXraOzqBjzv3dJ0/wM49nb2+ns7hdP5Ays4fYGsROv5URtbWpAydkWEglYrOAKTy7d0Lv/89fO5z2W9TVQVbtsCJJxauXRGpOl8g7Th1KvXdxr9nP/ZKVp+2k9tlM0a9eHUrsx99Jesx8lLpyxlAosrAyOlMpbvamgQv3/bxEFC5zgHoDEAqx8GDwbdPJb80PPnz2mvBa+n8/d/DN7+Z1R0tM3XUyY4n2xUi3T9NQ9A5fePRV7pMLGbTFUU/0d/xxPqsO7Ct7R3MX9rcY0y8o/MA85c2d2nv/KXNZd/59zYHkDwm3TvjZMCmO277Dxzg3fc/6vFe444dwhvv7eny3yRRZdx++YQu9ZL/DcttWanOAKR/WrkS5swJbkd86aXBF4ivWwcfRoYMTjoJJk5k47ENLNxVQ2f7bv5n8qWcf8Zont24/dD/iH96al2P50++sq3HEEe0LClRbeCk7BTTfZKsrUkc6iA
KMRlZn2a4obf6W9s7shqjTjeWXSjp5gAmnzycN/63o8+rgHLpjLtPRicnqcvxegFNAku/k+rT8Oj2d3hiwdep3ftB7xtfcEFwNe3EiXDGGTBhAos37erx98pFTaKaQYkqdu7Jf9uyGS6KtuNHV0w8dAbTXfdVKulWs2Sjusrwg06qc7H6HFYBJcslPQ0BSVFEP/3UDk7gDrs6Ont8su7+WvfT7KGDBnDUu6185ZWnuen/LczqvXefPJ7vT76WR48/k5HDBvf4hFfopX590dF5oGBtG1lbw4f79mcMvmGDE9z2+QmH/pulGqNOHqek2VPHp5wDSFQbXzpnNItWtXSZaO7+XpD7MMj3Z0xUh18gOgOIiXRj191XcyQ7hHRXR0Y/iZ1UN5jNbR/mPCRw9N4PmPC/bzK+7Q3Gtb3O+O1vMn77Gxz1UfpPlr878Uy+esXfseeImqwmLfvySbU/i651TzdZW5+m4+3LKqBokEjpaQioTKXriO94Yv2h4YDkGHEu/0N1/7t/emodT63ZlvMQQ/KTXPcJtCqDXOf+Egc6mbhtM4senJ22TvugI2mua2Bj3Yk01zWwacRo3hsyjNeHp993I/2yuuSQRbHGqg9nDiCptibBvv0He5wJDDmimo/2H+zyNw04/+ThvLrt/S7HdHCiioGJ6pRLQstxbFqKQwGQJ7l8Kkq3Tjs5FJKqw6oyMLMuKz6i0n1ai7p18dqM9y3JRboldGm5c+nG3zP3N//KqN1tAGw85kRO3tFC4mDXzq35mDEsOn0Kzcc0sLGugXeOGpHz98lmM2lZjDOA7sMa6VYBRcM9KfpJPdW/L3Xe0hdlGwBmNg24B6gGfu7u89LVPdwAOJxOO9WSvlS6r91NNXGZb72tF168upVbFr5ctJUZZ7ds4MdPzD/U0afz65PPobmugS3DRzF07wc8esaFvD9wSJ/f34C7vnRmxknLdHMAyRU90bmJXXs6u0xMVgGDElWHxrK7b3M4HbI6dCmmspwENrNq4KfA54AW4EUzW+Lur+brPbr/j9/9asdknVRrr7++8OWMQx3RtdHFmmhMtR47af7S5rx3/tVm1Oz9kP+7chE3/8/DWW/39LhJ/MMfX8trdQ0Z66YaHklUGYlqSzmJCEHnf82kMVlNWuay7roYnfOMs+rV4UvZKfYqoHOBze6+BcDMHgamA3kLgEwXtGTqtLMZ505eTp7qvQol3SXsfb603Z0/emsNP1z6z4zduS3rzW76/Gye+MQf49b1O4UGJ6qwzoM5rwLKdOOw7q9n08Fn2+mqc5a4KnYA1ANvR563AOfl8w0ydZT56LST9+8o5n1F0t0zJN1kaCpjd7Tyzd/+kkubnwdgxejTObXtdYbuS3+/lXvPu4rHL7qGG/7snB6rgM4/adihi3EO95Nzb/UzdczquEX6ptgBkGq2r8tnbjObBcwCGDNmTM5vkOmmS33ttA0OfXrNpfPtzbWTxvS6YifVeuykVMshB3+0l9u3PM0XH7+v1/c90vez47I/Y2j7NhgwAH70IzjrrC51vhb+QNDhaj22SOUodgC0AKMjz0cBW6MV3P0+4D4IJoFzfYN068Pz1WlHx6BTvRf0XJqXbklmcky7+4UuWY9JuzOjdTVTHvomR/8hy/vUn3BC0NFfey0Tqqtz2ncRqSxFXQVkZgOA14DPAq3Ai8CX3X19qvqFWAWUaQ4gUWUcOWgAO/d0drmkPt3FLrlMIPZpsnHLFvjOd3L7PtlbboGbby7aHS1FpDyU8zLQS4C7CZaB/sLdf5CubqGuAzice6sXza5dwc3Nnn8++20uuCD4VD9pUsGaJSL9R9kGQC7K4UKwgnGHf/xHmJ3+Ktkehg2DO++Ev/gLSCQK1jQR6d/K8jqA2HrpJfjUp3LbZuDA4PtkTz21MG0SkdirylxFsrZvH9xzD9TUBLc3SP5k6vz/5V+CM4Loz9696vxFpKB0BnC4nnsOvvUtWLEi+20uuwwWLoTBgwvXLhG
RLCkAMmlrC75m8Lwcrlf7q78KVuyMHp25rohIiWgIKKmzE+69F4YO7Tp8c9xx6Tv/xkZYtqzn8M3996vzF5GyF88zgNdegxtuCL5PNhs33BAszbzwQhg0qLBtExEpksoOgN27YdGi4OKpX/86u22uuw5uvx3Gji1o00RESq0yA+CDD+Coo9K/fvzxwZeH33ILTJuW8xeSiIhUgsoMgIED4Zxz4MUX4cor4Zpr4OKLNXwjIhJRmQGQSMDKlaVuhYhIWdMqIBGRmFIAiIjElAJARCSmFAAiIjGlABARiSkFgIhITCkARERiSgEgIhJTZf2VkGa2HXizAH/6GOC9AvzdcqJ9rBxx2E/tY36d6O51mSqVdQAUipk1ZfN9mf2Z9rFyxGE/tY+loSEgEZGYUgCIiMRUXAPgvlI3oAi0j5UjDvupfSyBWM4BiIhIfM8ARERiL1YBYGbTzKzZzDab2ZxStycfzGy0mT1rZhvMbL2Z3RyWDzezZWa2Kfw9rNRt7Sszqzaz1Wb2ZPh8rJmtCPdxoZkdUeo29pWZ1ZrZY2a2MTymf1Rpx9LMbgn/ra4zs4fMbFAlHEsz+4WZtZnZukhZymNngR+HfdEaMzu7FG2OTQCYWTXwU+Bi4DTgz83stNK2Ki/2A99w908Ak4Abw/2aAyx393HA8vB5f3czsCHy/E7grnAfdwLXl6RV+XUP8N/ufirwSYL9rZhjaWb1wE1Ao7ufDlQDV1MZx/LfgGndytIdu4uBceHPLODeIrWxi9gEAHAusNndt7j7R8DDwPQSt6nP3H2bu78UPn6foMOoJ9i3BWG1BcCM0rQwP8xsFHAp8PPwuQFTgMfCKpWwj0cDnwEeAHD3j9y9nQo7lgTfRFhjZgOAwcA2KuBYuvtvgR3ditMdu+nALz3wAlBrZicUp6Ufi1MA1ANvR563hGUVw8wagLOAFcBx7r4NgpAAji1dy/LibuBvgYPh8xFAu7vvD59XwvE8CdgO/Gs41PVzMxtCBR1Ld28F/gF4i6Dj3wWsovKOZVK6Y1cW/VGcAsBSlFXMEigzOxL4FfB1d99d6vbkk5ldBrS5+6pocYqq/f14DgDOBu5197OAD+nHwz2phGPg04GxwEhgCMFwSHf9/VhmUhb/fuMUAC3A6MjzUcDWErUlr8wsQdD5P+jui8Lid5OnlOHvtlK1Lw8mA5eb2RsEQ3dTCM4IasNhBKiM49kCtLj7ivD5YwSBUEnH8kLgdXff7u6dwCLgfCrvWCalO3Zl0R/FKQBeBMaFqw2OIJh4WlLiNvVZOBb+ALDB3f8p8tISYGb4eCbweLHbli/uPtfdR7l7A8Fxe8bdrwGeBa4Kq/XrfQRw93eAt81sfFj0WeBVKuhYEgz9TDKzweG/3eQ+VtSxjEh37JYA14WrgSYBu5JDRUXl7rH5AS4BXgP+APxdqduTp336NMGp4xrg5fDnEoIx8uXApvD38FK3NU/7ewHwZPj4JGAlsBl4FBhY6vblYf/OBJrC47kYGFZpxxK4A9gIrAP+HRhYCccSeIhgXqOT4BP+9emOHcEQ0E/DvmgtwaqoordZVwKLiMRUnIaAREQkQgEgIhJTCgARkZhSAIiIxJQCQEQkphQAIiIxpQAQEYkpBYCISEz9f5K5HGskKF9UAAAAAElFTkSuQmCC\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "%matplotlib inline\n", "lr=LinearRegression()\n", "alpha,beta=lr.fit(X[:,:-1],Y)\n", "lr.plot_fit_boundary(X[:,:-1],Y)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ 
"2.0081038570209815" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "alpha/beta" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "可以发现L2正则化项的约束还没有证据近似的结果紧...,有些忧伤....,该问题留到后面再优化..." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.4" } }, "nbformat": 4, "nbformat_minor": 2 }