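{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "### References\n", "* [Classic machine learning algorithms: least squares](http://www.cnblogs.com/armysheng/p/3422923.html)\n", "* [Autograd mechanics](https://pytorch.org/docs/stable/notes/autograd.html)\n", "* [The BP (back propagation) algorithm from the computational graph perspective](https://blog.csdn.net/u013527419/article/details/70184690)\n", "* [PyTorch study notes (7): the autograd mechanism](https://blog.csdn.net/manong_wxd/article/details/78734358)\n", "* [Autograd mechanics (Chinese translation of the PyTorch docs)](https://pytorch-cn.readthedocs.io/zh/latest/notes/autograd/)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Computational graphs" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This note was inspired by the cs231n lecture [Deep Learning Hardware and Software](http://cs231n.stanford.edu/syllabus.html), which introduces PyTorch's autograd mechanism and dynamic computational graphs." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Gradient descent with back propagation needs the gradient of every layer in order to update the weights. [Caffe Layers](https://github.com/BVLC/caffe/tree/master/src/caffe/layers) handle this by defining a `forward` and a `backward` operation for each layer: `forward` computes the layer's output, and `backward` computes its gradients." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A computational graph describes a sequence of operations together with the data they take in and produce." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Take the example from the lecture, which implements $c=\\sum{(x*y+z)}$: the gradients with respect to x and y are computed with PyTorch autograd and compared with a NumPy implementation, and the computational graph is drawn in the middle." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "![](https://tuchuang-1252747889.cosgz.myqcloud.com/2018-11-25-123659.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "An example of a TensorFlow dataflow graph: data passes through the nodes, is processed, and flows out, like water:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "![](https://tuchuang-1252747889.cosgz.myqcloud.com/2018-11-25-tensors_flowing.jpg)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**PyTorch uses dynamic graphs: the graph is built during each forward pass and freed after `backward`, so it is recreated on every training iteration. This adds overhead, but it is much more flexible.** TensorFlow, by contrast, uses static graphs by default, which cannot be changed during training once defined. After merging with Caffe2, PyTorch 1.0 added a static-graph mode (TorchScript), while TensorFlow introduced dynamic graphs through [Eager Execution](https://www.tensorflow.org/guide/eager). Even so, PyTorch remains the more flexible of the two." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Autograd mechanics" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, note that PyTorch stores both data and parameters (weights, biases, etc.) as Tensors. A PyTorch Tensor plays the same role as a Caffe Blob or a TensorFlow Tensor. PyTorch already defines many Tensor operations, and they can be accelerated on the GPU." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "PyTorch automatically records the operations of the computational graph; once the forward computation has finished, calling `backward` computes the gradients. In older versions (before 0.4), data had to be wrapped in a `Variable`, like this:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "![](https://tuchuang-1252747889.cosgz.myqcloud.com/2018-11-25-68960-7084a4be66464e40.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`data` holds the data, `grad` holds the gradient, and `creator` points to the function that created the Variable (in current versions this attribute is called `grad_fn`)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is no longer necessary: Tensor and Variable have been merged, and setting `requires_grad=True` on a Tensor is enough to mark that its gradient should be computed." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note: **an output does not require a gradient only if none of its input Tensors require one; if any input requires a gradient, so does the output.**" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "False" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import torch\n", "x = torch.randn(5, 5)\n", "y = torch.randn(5, 5)\n", "z = x + y\n", "z.requires_grad" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x = torch.randn(5, 5, requires_grad=True)\n", "y = torch.randn(5, 5)\n", "z = x + y\n", "z.requires_grad" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To define a custom autograd operation, implement its `forward` and `backward` (by subclassing `torch.autograd.Function`), as in this example from cs231n:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "![](https://tuchuang-1252747889.cosgz.myqcloud.com/2018-11-25-%E5%B1%8F%E5%B9%95%E5%BF%AB%E7%85%A7%202018-11-25%20%E4%B8%8B%E5%8D%889.10.56.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "After `backward` and the parameter update, the gradients must be cleared, because PyTorch accumulates gradients into `.grad` instead of overwriting them. Computations inside a `torch.no_grad()` block do not create nodes in the graph." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Implementing linear fitting" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Linear fitting is middle-school material: fit a straight line to a set of points so that the points lie as close to the line as possible. It is usually solved with least squares, and it even has a closed-form solution, but here we use gradient descent instead. Note that ordinary least squares minimizes the vertical residuals $y_i - (a x_i + b)$, whereas the loss below uses the perpendicular distance from each point to the line. The distance from a point $(x_0, y_0)$ to the line $Ax + By + C = 0$ is:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "$$\n", "d = \\left| \\frac{A x_0 + B y_0 + C}{\\sqrt{A^2 + B^2}} \\right|\n", "$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next I generate some points and assume the line is $y=ax+b$, i.e. $ax - y + b = 0$ (so $A=a$, $B=-1$, $C=b$). The distance from a point to this line is then:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "$$\n", "d = \\left| \\frac{a x_0 - y_0 + b}{\\sqrt{a^2 + 1}} \\right|\n", "$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The loss function is therefore defined as:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "$$\n", "L = \\frac{1}{2N}\\sum_{i=1}^{N} \\frac{(a x_i - y_i + b)^2}{a^2+1}\n", "$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "$L$ could of course be differentiated by hand, but the partial derivatives with respect to $a$ and $b$ are tedious to derive and the result is messy, which may be why MSE and MAE on the vertical residuals are the more common choices. Instead, PyTorch's computational graph can compute them automatically. (The code below uses `torch.mean` without the factor $\\frac{1}{2}$; this only rescales the gradient.)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Generate some random points:" ] }, { "cell_type": "code", "execution_count": 72, "metadata": {}, "outputs": [ { "data": { "image/png": 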
"iVBORw0KGgoAAAANSUhEUgAAAXoAAAD8CAYAAAB5Pm/hAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvqOYd8AAAIABJREFUeJzt3XeYVNX9x/H3dwtd6SA1FBGVjhtgBQXFglhQ0UQTg1Ejxp89ogLGHgTBHguiYiQxNkBBLAjICuhaAKWDoCIsvYPA9vP7494Nw7JlWHbYnTuf1/PwzM6dO8MZx+ezh+98z7nmnENERIIrrqwHICIikaWgFxEJOAW9iEjAKehFRAJOQS8iEnAKehGRgFPQi4gEnIJeRCTgFPQiIgGXUNYDAKhTp45r1qxZWQ9DRCSqzJs3b6tzrm5x55WLoG/WrBlz584t62GIiEQVM/slnPNUuhERCTgFvYhIwCnoRUQCTkEvIhJwCnoRkYBT0IuIBFyxQW9mTcxsppktM7MlZnabf/xBM1tnZt/7f/qGPGeIma0ysxVmdm4k34CIiBQtnD76bOBO59x8MzsGmGdm0/zHnnLOPR56spmdDFwBtAEaAtPN7ATnXE5pDlxEpLxzzmFmZT2M4mf0zrkNzrn5/s97gGVAoyKe0g94yzmX4Zz7GVgFdCmNwYqIRIvZk2Zx/r0T2Dzzi0MfTE2F4cO926PgsGr0ZtYM6AR87R+62cwWmtlYM6vpH2sErA15WhpF/2IQEQmMPelZDHlxOn9K3UP6lm1sH3DdwYGemgq9e8N993m3RyHsww56M6sGTABud87tBl4EWgIdgQ3AE3mnFvB0V8DrDTSzuWY2d8uWLYc9cBGR8mb2pFn0eWgKb69O54ZvJvDRv27lxA2rICXlwEkpKZCZCTk53m3oYxES1l43ZpaIF/JvOOcmAjjnNoU8/jIwxb+bBjQJeXpjYH3+13TOjQHGACQlJR3yi0BEpEykpnrh26sXJCeH9ZQ96Vk8+trnvPlLBi23bWL8p8/TedNKcDlQoYL3Wnl69fKOZWYe+liEFBv05n2T8CqwzDn3ZMjxBs65Df7dS4DF/s+Tgf+a2ZN4X8a2Ar4p1VGLiERCXlklL4RnzCg27Gf9sIXBExaycVc6N3wzkTtm/YdK5ML110PTpof+wkhO9l73MH+ZHIlwZvTdgT8Bi8zse//YUOBKM+uIV5ZZDdwA4JxbYmbvAEvxOnZuUseNiESFgsoqhQTxnvQshn24jLe+XUvLulUZf1p1Oj/7FpDr/ZIYMKDwEE9OPioBn6fYoHfOzaHguvtHRTxnGDDsCMYlInL0hVlW+fyHLQyZsJCNu9O5oWcL7jjrBColxh/1mXq4ysV+9CIi5UIxZZXd6VkMm7KMt+eu5fh61Zhw46l0alrz4OeXo4DPo6AXEQlVSFh/7tfiN+1O5689W3L7Wa28WXx+JfgyN9IU9CIiRSh2Fh+qBF/mHg3a1ExEpBApKzZz7lOzeHfeWm7s1ZIpXRLp9Mbowhc5lUGPfDg0oxeR2FVImWV3ehb/mLKUd+am0apeNV78v+50TFsGZ5x5YLY+c+ahs/Uy6JEPh4JeRGJPaiqMGwevvQbZ2QeVWWau2MzQiYvYtDudG3u15Lbefi3+0XGQkeE9PyPDe37+oC+DHvlwKOhFJLbk1dHT08H5i/IzM9k1cxbD1lU5eBbfpMbhv3457LxRjV5EYkteHT0v5M2Y2aoL52a2Y/y8NG7s1ZIPbulxaMgPGODN/M0OLIiKEprRi0hsCamj76pyDMP+8ijvVGhKq2OrMvryDoXP4pOTvV8S5awsEw4FvYjEFr+OPnPaPIbktmBLJtzUswW39m5FxYQC+uLzPzeKAj6Pgl5EYsqu/Vk8klaF8fuac0L9qrx0WQc6lKQWH0UU9CIS/cJcjTpz+WYGT1zI1l8zuemMltxadTsV//Ni1JViDpeCXkSiWxirUXftz+K
RKUsZPy+NE+pX4+UBSbRfuwx6n1XuVrFGgrpuRCS6FbMadebyzZzz1Oe89906bjrD66hp37hGuV3FGgma0YtIdOvVCxISIDfXu+3VC1JT2TVzNg9X78SEtZm0rn+MN4tvXOPg55XDVayRoKAXkeiX1xPvHCxaxGdPj2PIGQPZunM/t5xYlZsHdD+0o6acrmKNBAW9iES3lBSv/OIcuxIq8dC3u5l40RBab1nNK+8Po93Nf4aE3gU/N0rbJQ+Xgl5EoptfgpnRuD1DzrmJbdVqcss347n5izepGB8X6JJMuBT0IhLVdnVI4qFRk5i4NpPWx8TzaueqtNtVCVpfU/R1W2OIgl5EotaMZZsYMnER2/ZmccuZx3NL1W1UODukZTKK9qOJJLVXikjU2bUvi7+98z3XvT6XWlUrMOmm7tx5TmsqzPr8QMtkerq3lbBoRi8i5UgYK1ynL93E0PcWsX1vJrf2bsXNZxxPhQR/zprXaul/OcvYsSrfoKAXkfKimBWuO/dl8vAHS5n43TpOPO4Yxv75t7RtVP3g10hOhmuugZde8oI+J8f7xaGgFxEpBwpaqeoH9PT3PmfINzvYQeKhs/j8BgyA11+PiYVQ4VLQi0j5UMBK1Z37Mnlo7Oe8l5bJiZvTeG3GC7S96FVIOOHA8/KXe2JoIVS4FPQiUj7kC+hp1Vsw9KlZ7NiTwa1fvs3NX7xFBXMHl2IKK/fEyEKocKnrRkTKj+Rkdt52J7f/Uonrx82ldtUKvN+zOn+bO8EL+fylmBjamOxIaEYvIuXGp0s2MvS9xezcl8ltrStx04YvqbC2Flx9tXdC/g6aGNqY7Ego6EWkzO3Ym8mDHyxh0vfrOanBsbyeVJE2l54LGRnerpRxcVCx4qELoFSPD4uCXkTK1NQlG7k3bxbfuxU3nXE8FUY95s3Sc3O9k3JzD+nE+R/V44tVbI3ezJqY2UwzW2ZmS8zsNv94LTObZmYr/dua/nEzs2fNbJWZLTSzzpF+EyISfXbszeTWN7/jhn/Po94xFZl8cw/uOPsEr20yryQT50dUXJxKM0cgnBl9NnCnc26+mR0DzDOzacCfgRnOuRFmNhgYDNwDnAe08v90BV70b0VEgLxZ/CJ27svijrNO4P/OaElifMi8M7QkU7s2bNum0swRKDbonXMbgA3+z3vMbBnQCOgH9PJPex1IwQv6fsA455wDvjKzGmbWwH8dEQm6IrYxCK3Fn9zgWMZd25WTGx5b8OuoJFNqDqtGb2bNgE7A10D9vPB2zm0ws3r+aY2AtSFPS/OPKehFgq6IbQw+mfA5f5+7g50kFjyLl4gJ+7+ymVUDJgC3O+d2F3VqAcdcAa830MzmmtncLVu2hDsMESnPQvva/d0jt+/N5JZ/TuOv3/5KvY1rmfzGIG6rulUhfxSF9V/azBLxQv4N59xE//AmM2vgP94A2OwfTwOahDy9MbA+/2s658Y455Kcc0l169Yt6fhFpDzJ2z0SwDk+mbWEc0ZO55N16fxtzhtMGncHJ69fqYVNR1k4XTcGvAosc849GfLQZMBfxcDVwKSQ4wP87ptuwC7V50VihL975PYq1bnlwrv464X3UD9rH5N71eDW+e+TaKh7pgyEU6PvDvwJWGRm3/vHhgIjgHfM7DpgDXC5/9hHQF9gFbAPuKZURywi5donZ13B3xN7sKtiVf6W+hY3Pnojid1P1cKmMmRec0zZSkpKcnPnzi3rYYjIEdi+N5P7Jy1mysINtK0ez6icZZzUu5tCPYLMbJ5zLqm487QyVkSO2MeLNvD39xezOz2LO88+gb/2aklifJ+yHpb4FPQiUmL5Z/FvVF7BiT+mwXj/6zxdxq9cUNCLSIl8tGgD9/mz+EEnVuaGm/uRuH/fgf1pwLtmqy7lV+bUyCoiB6SmwvDh3m0htv2awU3/nc//vTGfhjUqM+WW07h5w9ckpu8/OOQBsrLUSlkOaEYvIp5iLs4N+Wbx55zADT391a15m5DlbSu
cJzFRrZTlgIJeRDxFXJx7268Z3D95CR8u3EC7RtX57+XdaH3cMQeem38Tsu++846rRl8uKOhFxFPI1Zo+XLiB+yYtZk/+WXx+2oSs3FLQi8S60N0mQxY1bWvXmfvfmM+Hi7xZ/OP5Z/ESNRT0IrGsoLr8kCHeLP6pWfyans1d57bmhtNbkKBNyKKWgl4kluWry2/9bDYP/FSRDxdtoH3j6oy6rINm8QGgX9EiQRBGW2SB8ury8fFMadOTczLaMm3pJu46tzUTbzyV1j8vLtnrSrmiGb1ItAujLbJQycls/Wga909fzUfZNehQ71hGXd6BE+ofc2SvK+WKZvQi0a6gtsgwOOf4YMF6zp61l+muFnf3ac2EG0/1Qv4IXlfKH83oRaJdIW2RB8l3Hdetv2Zw3/uL+XjxRjo0rn5gFh96/po1By4ioj3ko5qCXiTahS5WKmiv95ASjEtIYMrAodxfI4m9OcbdfVoz8LR8HTWhJZv4eLj+ei18inIKepEgKGqxkl+C2VLxGO4/50Y+rnIKHdb8wONXJtGq1/GFnk9Ojne/aVOFfJRT0ItEk3wlmHC4nj35oE0vHuh5HXsrVOaelNe4ft5kElo/BOf2OPQJ4ZSCJKoo6EWiRQm6YLbsyeC+HyvyyXl30CF7B4+/cSettvxSdIAXVwqSqKOgF4kWRWw6lp9zjskL1vPA5CXsy8zhnj4ncv1pzUno3zK8ANe+NYGioBeJFmGWVLbsyeDv7y9i6pJNdGhSgycub8/x9fyOGgV4TFLQi0SLYkoq+Wfxg887kb/0aK49akRBLxJV8s/I/S9nN3c7nfvSKjJ1ySY6NqnB46GzeIl5CnqRaJWaiuvdm8ktuvHAxubsq3oMQ06uzF/SZhH/YxzUU4lGPAp6kSi1+bM53HvenUxr1Y1O65czKnsDx//1Be1NI4dQ8U4kyjjneP+7dZyd3obPm3dmaMprjJ/wAMfn7tXeNFIgzehFosjmPenc+95ipi3dRKemNRh1vOP4vXWh05+gUyctdJICKehFooBzjknfex01+7NyGNr3RK7r0YL4V16GV1/1ZvEVK8Itt8D330P//irbyP8o6EXKuc2707n3fW8W37lpDUZd3oGWdat5HTc33QTZ2d6JGRnw5JPgHMyeDe3aKewFUNCLlFvOOd7/fh0PTl5KelYO9/Y9iWt7NCc+zrwTUlIgN/fgJ+Xmen+KWTkrsUVfxopEWgku87d5dzrXj5vHHW8voGXdqnx022lcf3qLAyEPXg2+YkWIi/P2jR80yLsfH68avRyk2Bm9mY0FLgA2O+fa+sceBK4HtvinDXXOfeQ/NgS4DsgBbnXOTY3AuEWiw2FuROa+/JL3pi3gwezfkOHs0Fl8qIJWyl58sTYjk0OEU7r5F/AcMC7f8aecc4+HHjCzk4ErgDZAQ2C6mZ3gnMsphbGKRJ/D2Ihs88wvGPridKa3SOKU9csYeVVXWp7eoujXz79SVnvZSAGKLd0452YB28N8vX7AW865DOfcz8AqoMsRjE8kuuVtRFZEOcU5x8T5aZw1dSuzm7Tj75+9wjtvDqHlvNlHfbgSTEfyZezNZjYAmAvc6ZzbATQCvgo5J80/JhKbitmIbPPudIa+t4jpyzZzSq1KjHz6JlpuXq0au5Sqkgb9i8AjgPNvnwCuBQooJOIKegEzGwgMBGjatGkJhyESBQoop+Svxf/9/JO4pntz4k/7t2rsUupKFPTOuU15P5vZy8AU/24a0CTk1MbA+kJeYwwwBiApKanAXwYiQbTpsy+4d7RXi09av4yRf+pKi9P8Wrxq7BIBJWqvNLMGIXcvARb7P08GrjCzimbWHGgFfHNkQxSJYiGtlc45JsxL4+xPtzKnSTvumzGGt98cQou5qsVLZIXTXvkm0AuoY2ZpwANALzPriFeWWQ3cAOCcW2Jm7wBLgWzgJnXcSMwKaa3cVLM+Q4e+xoxNWSTVqsSop26k+ZY1qsXLUVFs0Dvnrizg8Kt
FnD8MGHYkgxIJhJQUXGYmE07qycO9B5K5KYP7LmjLn09tRvxpb6gWL0eNtkAQiZCNXU9naP/7+az5Kfx23TJGDuhG8x7NvQdVi5ejSEEvUsqcc0yYv46HZv1K1vGncF/ieq658VTiTj21rIcmMUpBL1Ia/Gu3bux6OkN+SWTmii10aVaLkZe1p1mdqmE9V2UciRQFvUhJ5QV07dq4229nfKsePLy5JVlVqnL/BSfz51ObEVfQHjX5X+Mw9sIRKQkFvUhJ5AV0RgYbq9ViyPl3M7Plb+mydgkjT4qnWY/zw3udw9gLR6SkFPQiJZGSgsvI4N02Z/LImX8hKz6BB2aM4eol04m7eXr4r5O3F44u/ycRpKAXKYGNXU9nSP/7mdkiiS5rFzPy42dp9tu2MGP64c3Ii9kLR6Q0KOhFDoNzjnfnpfHIrF/JbtmZB2a8zNVzPyAuIb7k12lVq6VEmIJeJEwbdu1n8IRFfP7DFro29zpqftN8A3z3oVdjv/12XadVyiVdSlCkGM453vl2Lec8OYtvft7OQxe14c3ru/Gb2lVh2zbvYty5ud7FuR988LAuGShyNGhGL1KE9Tv3M2Rivll87ZC++LwvUzMyvLCfPh1mz1abpJQrmtGLFMA5x9vfruHcp/LN4n9YePCFvvO+TD3rLO8i3bm5B9okRcoJzehF8lm/cz+DJy5ilj+LH3VZB5rWrlL44qbkZK9kM3u22iSlXFLQS3Ad5tYCzjnembuWf0xZRo5zPNyvDVd1/c2B1a1FLW5Sm6SUYwp6CabD3Fpg3c79DJ6wkNkrt9KtRS1G9vdn8aGKW9ykNkkppxT0Ekxhbi3g1eLX8o8Pl5HrHI/0a8MfQ2fxoTRrlyiloJdgCmNrgXUz5jB4Zhqzs4+hWwuvFt+kVpVDzjuIZu0ShRT0EkxFzL6dc7z1VgrDvtlKriXwyJwx/PGie4grLuRFopSCXoKrgNn3gVr8PrptXMWoj5+hya5N8O+GoAuDSEAp6CUmOOd469u1DMurxe9fxB/fupc4XFkPTSTiFPQSeKEdNcktajPysvY0WVEdRidCVhYkJsKAAWU9TJGIUdBLYB0yi29fhT+unkncCueVdFJS1EEjMUFBL8GRmgrjxgGQdtkfGfxzInNWbeXUlrV5rEU2TS48p+BVrSIBp6CXYEhNhZ49cVlZ/LdDHx6tvBEqV2bYJW35Q5em2IgRumSfxCwFvQTDuHGkVa7B4EtvZU6zTnRf/T0j2lWkSVf/2q26ZJ/EMAW9RD3nHP+t0JRHr30egGGfPMcfFk/Dbpt94CStapUYpqCXqJa2Yx/3TFjIF5Xb033NAkZ89CxNft0KL7xwaJirJi8xSkEvUck5x3+/WcOjHy4D4NFL2nFlTk2snWnGLpKPgl7KjzC3FV67fR/3jJ3Nl1uz6VE3gRHXnkbjmlWAplrdKlIABb2UD2FsK5yb63jjmzUM/2Axtn8/j6aM5crlKdipumyfSFGKvZSgmY01s81mtjjkWC0zm2ZmK/3bmv5xM7NnzWyVmS00s86RHLwESP5thceNO+iSfWu37+OqV7/mvvcX09n2MPVft/CH7z7GdNk+kWKFM6P/F/AcMC7k2GBghnNuhJkN9u/fA5wHtPL/dAVe9G9Fihba/piQAGPHQk4OuRUq8sYrUxi+PIM4M4Zf2o4rstOwJ3dBfLxaJUXCUGzQO+dmmVmzfIf7Ab38n18HUvCCvh8wzjnngK/MrIaZNXDObSitAUtAhbY/rlkDL7/M2mp1uLvvbaQu3Mdpreowon97GtWoDDRVq6TIYShpjb5+Xng75zaYWT3/eCNgbch5af4xBb0U/2Wr3/6Y++WXvDF/A8O7X0UcjuEdqnDFFV0ws0POFZHilfaXsQVcf63gfWDNbCAwEKBp06alPAwpd8K8huva7fu4e3EcqWdcz2kJexhxRmMa9e5RBgMWCY6SBv2mvJKMmTUANvvH04AmIec1BtYX9ALOuTHAGICkpCR
tCh50oV+2pqd7X7aGBH1uruM/b6YwYtEe4hLiGXFpO37/2yYHz+JFpESK7bopxGTgav/nq4FJIccH+N033YBdqs8L4JVrEvx5hXPel61+R82abfv4w5Ofcv+ifZzy80KmjrmRK3LWHRryqakHdeKISHiKndGb2Zt4X7zWMbM04AFgBPCOmV0HrAEu90//COgLrAL2AddEYMwSjZKT4Zpr4KWXvKDPySF3Zgr/oQEjPl5OfGYmj33yAr9bMBWLjz90d8kwSz8icqhwum6uLOSh3gWc64CbjnRQElADBsDrr0NmJmvqNObuCp35atISTmtVh8eaZ9Hw2VmFt0zm77PXNsMiYdPKWDl6kpPJnT6d/3y6mBFZjYjfA4/1b8fvkvxafFEtk9pmWKTEFPRy1KzZto+7Fhlfpzfi9BPqMuLSdjSsUfnACUW1TGqbYZESU9BLxOXmOv791S+M+Hg5CXHGyP7tuTyp8eF31Kh3XqREFPQSUb9s28vdY+fw9bZsetZLZPi1px08ixeRiFPQS0Tk5jrGpa7msY+WkrBvHyNnvsLlP8zGkgvolglze2IRKRkFvZS6X7bt5a7xC/nm5+30zNrM8LH30HD3Fq+jRm2TIkedgl5Kzf9m8Z+s8GrxHaty+dV/9bYSBm/BlNomRY46Bb2UitBZfK/WdRl+aTsavPC0F+AAZt6CKbVNihx1Cno5Irm5jtdTVzPykxUkxBujLmvPZaf4HTX5Q3zAgENfQG2TIhGnoJcSW711L3ePX8g3q71Z/IhL23Nc9UoHTgg3xNU2KRJRCno5bLm5jn99uZqRU5eTGB/H45d3oH/nRtjLL8OECdC/Pwwc6J2sEBcpcwp6OSyrt+7lrvEL+Hb1Ds5oXZfhebP4MWPghhu8kz791LvNC3sRKVMKegmrjz33yy/516eLGZnViMTEhAOz+LzVrRMmHPyECROgXTvV3kXKAQV9rAujj/3naXO4e9xXfNvoJM74eT7DbziT405pfPDr9O9/YCYP0LGj+uNFygkFfawroo89x6/Fj5qxgwq1m/DElCe5dPnnWMcEOLP7wa+TV6bJq9Fv26b+eJFyQkEf6wrpY/95617uencBc3/ZwZn1KzL80b9Sf8emonvdBw48EPipqeqPFyknFPSxLl8LZE7Xbrw2+yce/3QFFeLjeOLyDlzauRHWbbx3Tu3a3m3ec8N8Xc3mRcqOgl7+1wL505ZfufulVOb+soPe9RN5dN8C6mceC9b4QFAfTt1drZUi5YKCXsjJdbz2xc+MmrqCiglxPNm5Kpf8ua+3R82wkEDXvjQiUSmurAcgZeunLb/y+5dS+ceHy+hxfB2m/a0nl/6Y6oV8aKDDgXp+Ydd1FZFySTP6WFBAn3zoLL5SYjxP/b4DF3dsVPAeNXmBrrq7SFRS0AddAX3yPx7fjrveXcD8NTs566R6PHpJO+odG+YeNaq7i0QdBX0Qhc7gQ+rqOVnZjJ26mMc/3EmlxHie/n1H+nVsWPC1WxXoIoGhoA+a/DP4p5+GChX4sWpd7up7G/PTG3LWSXV59JK2B8/iRSSwFPRBk68zJmfrNl59aQpPLNnrzeIv7VDwLF7XbRUJLAV90IR8kbqqfjPuiu/Ed0v2c9ZJDQqfxeu6rSKBpqAPmuRkcqZP59VPl/B4ZkOqZMTxzBXtuKhDIbV4UH+8SMAp6ANm1eZfuet7x3fpDTn75PoMu6Qt9Y4pphav67aKBJqCPiBych2vzvmJxz9ZThWXzTOnVOeiy04pfBYfSv3xIoGmoC/vwviSdNXmX7lr/AK+W7OTc378hn9MfZ562fug8WHU2tVOKRJYRxT0ZrYa2APkANnOuSQzqwW8DTQDVgO/c87tOLJhxqhiviTNyXW8Mvsnnpj2A1UqxPNM5TVc9N4wLCfH26ZAtXYRoXT2ujnDOdfROZfk3x8MzHDOtQJm+PelJAr6ktS3avMe+r/4JcM/Xk6vE+ry6R2n0++cjpj2ohGRfCJRuuk
H9PJ/fh1IAe6JwN8TfAV8SZqT63h59k88mTeLv6LjgY6a5GRvgVTeVZ40mxcRjjzoHfCpmTngJefcGKC+c24DgHNug5nVO9JBxqx8X5KuatmWQS9+yfdrd3Jum/o8cnG+jprUVLj9du8Xw+zZ3sW5FfYiMe9Ig767c269H+bTzGx5uE80s4HAQICmTZse4TACLDmZ7FzHy1OX8FTmdqomGP+s/AsXtKyJ5W+bVD+8iBTgiILeObfev91sZu8BXYBNZtbAn803ADYX8twxwBiApKQkdyTjiFphdNSsnDqHQW/OZcFxreizMpVHZr5M3d3b4LECVrCqH15EClDioDezqkCcc26P//M5wMPAZOBqYIR/O6k0Bho4xXTUZOfkMmb2Tzw9cwdVj63PPyc9xgUr5mAAzhU8Y1c/vIgU4Ehm9PWB9/wFOQnAf51zn5jZt8A7ZnYdsAa4/MiHGUBFlFlWbtrDoHcXsCBtF30aVOSRhwdSd9cWSEwEM8jOLnzGrn54EcmnxEHvnPsJ6FDA8W1A7yMZVEwooMzyv1n8tJVUrRjPP6/sxAXtG2BdJh6YpYNm7CJyWMy5si+PJyUlublz55b1MI6+kBr9yhZtGfTaFyzYmcN5DRJ55Lqe1KlWsaxHKCLlmJnNC1nDVChtgVCWkpPJ7tKVl2b9xDPPzKLa3t08N300F/z8LXTRVsEiUjoU9GXoB78WvzBtF30TdvLwa7dSZ892bV8gIqVKQV8GsnNyvVn89JVUq5TA83/ozPlffQDpeyAuTq2RIlKqFPSRUkiP/A9TZzNo1kYW5lShb7vjeLhfW+osmu+taM3J8YL+6ac1mxeRUqOgj4QCeuSzu3Tlpf+k8Mzi3VTLyOL5lCc4v98wqFbxQKtlbq7XPrltW1m/AxEJEAV9JOTrkV8xI5VBc3NYtG4/56/8ioc/fZHamXsP1OG1olVEIkhBHwl+cGdnZTM6+Xc8s7c1x2bu5vkPH+f8xSneOaGBrhWtIhJBCvpISE5mxXtTGTRrI4tyqnB+uwY8/PN0ai+b7T1uBtdee+j2BQp4EYkABX0py8rJZXTKjzw7aw/HVqrBCydXoO8PU6Fu7YPLMwMGlPVQRSRGKOhL0fKuelqUAAAIcUlEQVSNuxn07gIWr9vN+e0b8HDD/dTue/aBcH/6ae+LVpVnROQoUtAfqdRUsmamMLphV55dmc6xlRJ54Y+d6duuAQwffvDGZdu2wZAhZT1iEYkxpXHN2OBLTfVCOzX1kOPLf3cNl/xQhSeW76fPcYlM+1tPL+ThQDeNruEqImVIM/riFLJvfFZOLqM/WcqzV4ygevqvjJ40nD4DzoeqZx94rrppRKQcUNAXp4B945c1a8Nd4xewOOM4Llw1h4emj6ZWTgb0GnXo89VNIyJlTEFfnJDFTFmVKvNig67887k5VK+cyOirOtNnTy3ocqxm7CJSbinoi+OXX5bN+IpB8SexZPl+LuzQkIcuakOtqhWABt45eXV8Bb6IlDPBD/owLsBdlKycXF7YW4d/7mtNjSrxjL6qA33aNjj07yji+q8iImUp2EF/hAG8dL3XF790w24u8mfxNatWOPTEIq7/KiJS1oLdXllQAIchKyeXZ6av5KLn5rB5TwajrzqFZ5ulU/PZJw5tsQS1UYpIuRbsGf3h7Arpl3iWdurBoBWwdMNu+nVsyIMXtqHmwnlF/8tAbZQiUo4FO+jDDeDUVDLPPpcXOl3Ic9tPpkaVRF760ymc2+Y47/FwSjNqoxSRcirYQQ9hBfCSGV9x1+XDWFq/BRcvTeGBrnWp2abvgRO0X7yIRLHgB30RMrNzeX7mKp7fewI1qu1kzHvDOGft9zBkxsEnqjQjIlEsNoM+NZUl/3mfQVU6sCy+Ohd3bMiDDWpQo+GF0OuJgoNcpRkRiVIxF/SZX3zJ80NH83yX/tTYv4cxn43gnH4jIbk79Oxe1sMTESl1MRX0S9bvYtCH61mW/HsuWfwZD8wYQ42Mvep
7F5FAi4mgz8zO5bmZq3hh5ipqJlbh5Xce5ezlX3gPVqyoL1dFJNACH/SL1+1i0LsLWL5xD5d0asQDF55MjT41Ydw474QBAzSbF5FAC2zQ/28W/9lKapLFy11qcvalHb0H875YzdsHJ++YiEgARSzozawP8AwQD7zinBsRqb8rv9BZ/KXLPuf+GWOokZt58IpWbUQmIjEiIkFvZvHA88DZQBrwrZlNds4tjcTflzczzzy9J8/trc3zKT9Su2oFXqnyM2d9+KS3ojU+/uAvXbURmYjEiEjN6LsAq5xzPwGY2VtAP6D0g96fmWdl53DJVY+zpF4LLu3ciAcuaEP1BVVhRCErWrXaVURiRKSCvhGwNuR+GtA1In+TPzNPzMnh4qUp3NHMOOt353uPFbWiVatdRSRGRCrorYBj7qATzAYCAwGaNm1a8r8pZGZ+/cKP4alBBz9e1IpWrXYVkRgQqaBPA5qE3G8MrA89wTk3BhgDkJSUdNAvgcOimbmISJEiFfTfAq3MrDmwDrgC+EOE/i7NzEVEihCRoHfOZZvZzcBUvPbKsc65JZH4u0REpGgR66N3zn0EfBSp1xcRkfAE+5qxIiKioBcRCToFvYhIwCnoRUQCTkEvIhJwCnoRkYAz50q+KLXUBmG2BfillF6uDrC1lF6rPIuV9wmx815j5X1C7LzXSL/P3zjn6hZ3UrkI+tJkZnOdc0llPY5Ii5X3CbHzXmPlfULsvNfy8j5VuhERCTgFvYhIwAUx6MeU9QCOklh5nxA77zVW3ifEznstF+8zcDV6ERE5WBBn9CIiEiIwQW9mfcxshZmtMrPBZT2e0mRmTcxsppktM7MlZnabf7yWmU0zs5X+bc2yHmtpMLN4M/vOzKb495ub2df++3zbzCqU9RhLg5nVMLPxZrbc/2yTg/iZmtkd/v+3i83sTTOrFJTP1MzGmtlmM1sccqzAz9A8z/oZtdDMOh+tcQYi6M0sHngeOA84GbjSzE4u21GVqmzgTufcSUA34Cb//Q0GZjjnWgEz/PtBcBuwLOT+Y8BT/vvcAVxXJqMqfc8AnzjnTgQ64L3nQH2mZtYIuBVIcs61xbs+xRUE5zP9F9An37HCPsPzgFb+n4HAi0dpjMEIeqALsMo595NzLhN4C+hXxmMqNc65Dc65+f7Pe/ACoRHee3zdP+114OKyGWHpMbPGwPnAK/59A84ExvunBOV9HgucDrwK4JzLdM7tJICfKd51LyqbWQJQBdhAQD5T59wsYHu+w4V9hv2Acc7zFVDDzBocjXEGJegbAWtD7qf5xwLHzJoBnYCvgfrOuQ3g/TIA6pXdyErN08DdQK5/vzaw0zmX7d8PymfbAtgCvOaXqV4xs6oE7DN1zq0DHgfW4AX8LmAewfxM8xT2GZZZTgUl6K2AY4FrJzKzasAE4Hbn3O6yHk9pM7MLgM3OuXmhhws4NQifbQLQGXjROdcJ2EuUl2kK4ten+wHNgYZAVbwSRn5B+EyLU2b/Lwcl6NOAJiH3GwPry2gsEWFmiXgh/4ZzbqJ/eFPeP/38281lNb5S0h24yMxW45XfzsSb4dfw/9kPwfls04A059zX/v3xeMEftM/0LOBn59wW51wWMBE4lWB+pnkK+wzLLKeCEvTfAq38b/Ir4H3ZM7mMx1Rq/Dr1q8Ay59yTIQ9NBq72f74amHS0x1aanHNDnHONnXPN8D7Dz5xzfwRmApf5p0X9+wRwzm0E1ppZa/9Qb2ApAftM8Uo23cysiv//cd77DNxnGqKwz3AyMMDvvukG7Mor8USccy4Qf4C+wA/Aj8C9ZT2eUn5vPfD+ibcQ+N7/0xevfj0DWOnf1irrsZbie+4FTPF/bgF8A6wC3gUqlvX4Suk9dgTm+p/r+0DNIH6mwEPAcmAx8G+gYlA+U+BNvO8esvBm7NcV9hnilW6e9zNqEV4n0lEZp1bGiogEXFBKNyIiUggFvYhIwCnoRUQCTkEvIhJwCnoRkYBT0IuIBJyCXkQ
k4BT0IiIB9/9bVuUpNvxxDwAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "import numpy as np\n", "import random\n", "import matplotlib.pyplot as plt\n", "# 在0-2*pi的区间上生成100个点作为输入数据\n", "X = np.linspace(0, 100, 100,endpoint=True)\n", "a, b = 2.5, 1.0\n", "Y = a*X+b\n", "\n", "# 对输入数据加入gauss噪声\n", "# 定义gauss噪声的均值和方差\n", "mu = 0\n", "sigma = 2\n", "Nx, Ny = X.copy(), Y.copy()\n", "for i in range(X.shape[0]):\n", " Nx[i] += random.gauss(mu,sigma)\n", " Ny[i] += random.gauss(mu,sigma)\n", "\n", "# 画出这些点\n", "plt.plot(X, Y)\n", "plt.scatter(Nx, Ny, marker='.', color='r')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "接下来就是一系列计算然后更新 $a$ 和 $b$ 了。" ] }, { "cell_type": "code", "execution_count": 101, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYcAAAD8CAYAAACcjGjIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvqOYd8AAAFlJJREFUeJzt3X+snNWd3/H3d+b6+hcktmPDsthZk413EzZSCGsRaFItDV1jUFXSKpGCqmBRJHdXRJutVmpJ+wdt0pV2pXaTIu2i0MUFqjQsm2SLlZJQ10WK0k0Il90UTIDYCQRuDNhg8yMYjK/vt3/Mufas78w89vW9zPU975c0embOnGfmnPuQfHye85x5IjORJKlba9gNkCTNP4aDJGkaw0GSNI3hIEmaxnCQJE1jOEiSpjEcJEnTGA6SpGkMB0nSNCPDbsBMrV69OtevXz/sZkjSGeXhhx9+MTPXNNU7Y8Nh/fr1jI2NDbsZknRGiYifnUw9TytJkqYxHCRJ0xgOkqRpDAdJ0jSGgyRpGsNBkjSN4SBJmqa6cLjj/z7F9v+3d9jNkKR5rbpw+MqDz/DtXc8NuxmSNK9VFw6tCI5O5rCbIUnzWn3h0ArMBkkarL5wCJg0HSRpoOrCod0KjqbhIEmDVBcOrfC0kiQ1qTAcPK0kSU2qC4d2K5j0tJIkDdQYDhGxLiIeiIjHI+KxiPhsKV8VETsiYnfZrizlERG3RMSeiHgkIi7u+qwtpf7uiNjSVf6bEfFo2eeWiIi56Gz5Li9llaQGJzNymAD+IDPfD1wK3BgRFwI3ATszcwOws7wGuArYUB5bgVuhEybAzcCHgUuAm6cCpdTZ2rXf5tPvWm/tCBw4SNJgjeGQmc9l5t+U568BjwPnA9cAd5ZqdwIfL8+vAe7Kju8DKyLiPOBKYEdmHsjMg8AOYHN57x2Z+b3MTOCurs+ada0WXq0kSQ1Oac4hItYDHwIeBM7NzOegEyDAOaXa+cCzXbuNl7JB5eM9ynt9/9aIGIuIsf37959K04/pXK1kOEjSICcdDhFxFvB14Pcz89VBVXuU5QzKpxdm3paZGzNz45o1a5qa3FMrwquVJKnBSYVDRCyiEwxfycxvlOIXyikhynZfKR8H1nXtvhbY21C+tkf5nHARnCQ1O5mrlQK4HXg8M/+k
663twNQVR1uAe7vKrytXLV0KvFJOO90PbIqIlWUiehNwf3nvtYi4tHzXdV2fNes6I4e5+nRJWhhGTqLOR4BPA49GxA9L2b8B/gi4JyJuAJ4BPlneuw+4GtgDHAKuB8jMAxHxBeChUu/zmXmgPP9d4A5gKfCt8pgTrcA5B0lq0BgOmfldes8LAFzRo34CN/b5rG3Ath7lY8AHmtoyG1wEJ0nNqlsh7f0cJKlZfeHQchGcJDWpLxzCRXCS1KS6cGi7CE6SGlUXDuGlrJLUqLpwaLdwQlqSGlQYDp5WkqQm1YVDOOcgSY2qC4e295CWpEbVhUMrnHOQpCb1hYNzDpLUqL5w8H4OktSounDoXK007FZI0vxWXTiEP58hSY2qC4d2BGk4SNJA1YVDy0tZJalRheHgneAkqUl14UB4PwdJalJdOLTKDU+dd5Ck/ioMh046OO8gSf1VGA6drfMOktRfdeEQx0YOhoMk9VNdOEydVjIbJKm/CsOhs3XkIEn9VRgOTkhLUpPqwiEcOUhSowrDocw5TA65IZI0j1UXDs45SFKzCsOhjByG3A5Jms8qDIfO1pGDJPVXXTi4CE6SmlUXDi6Ck6RmFYZDZ+vIQZL6qzAcXAQnSU2qC4dji+BMB0nqq7pwcM5BkppVFw7+fIYkNWsMh4jYFhH7ImJXV9m/i4ifR8QPy+Pqrvc+FxF7IuLJiLiyq3xzKdsTETd1lV8QEQ9GxO6I+IuIGJ3NDp7IRXCS1OxkRg53AJt7lH8xMy8qj/sAIuJC4FPAb5R9/iwi2hHRBv4UuAq4ELi21AX44/JZG4CDwA2n06EmjhwkqVljOGTmd4ADJ/l51wB3Z+bhzHwK2ANcUh57MvOnmfkWcDdwTXRWpH0M+FrZ/07g46fYh1NyfM7BcJCkfk5nzuEzEfFIOe20spSdDzzbVWe8lPUrfxfwcmZOnFA+Z7yUVZKazTQcbgV+FbgIeA74T6U8etTNGZT3FBFbI2IsIsb2799/ai0uXAQnSc1mFA6Z+UJmHs3MSeC/0DltBJ1/+a/rqroW2Dug/EVgRUSMnFDe73tvy8yNmblxzZo1M2n68d9W8n4OktTXjMIhIs7revlPgKkrmbYDn4qIxRFxAbAB+AHwELChXJk0SmfSent2Tvw/AHyi7L8FuHcmbTpZjhwkqdlIU4WI+CpwObA6IsaBm4HLI+IiOqeAngb+BUBmPhYR9wA/AiaAGzPzaPmczwD3A21gW2Y+Vr7iXwN3R8R/AP4WuH3WeteDi+AkqVljOGTmtT2K+/4feGb+IfCHPcrvA+7rUf5Tjp+WmnOtMlZy5CBJ/dW3Qhrv5yBJTeoLhzLnYDRIUn/VhYOL4CSpWbXh4CI4SeqvwnDobL2fgyT1V104hCMHSWpUXThMjRycc5Ck/uoLh5YjB0lqUl84+PMZktSounA4PudgOEhSP/WFQ9maDZLUX3XhcPwe0qaDJPVTbTh4PwdJ6q+6cAgnpCWpUXXh4M9nSFKz+sKh9NhFcJLUX33h4MhBkhpVGA6drXMOktRfdeHgIjhJalZfOJSt2SBJ/VUXDi6Ck6Rm1YaDi+Akqb/qwsFFcJLUrLpwmLqfg9kgSf3VFw6OHCSpUYXh4CI4SWpSXTg45yBJzaoLh2OXshoOktRXteHgaSVJ6q+6cJhaIe1pJUnqr7pwOH5aacgNkaR5rLpwiNJjRw6S1F914eDIQZKaVRgOna0jB0nqr8Jw8GolSWpSXTi4CE6SmlUXDi6Ck6Rm1YaDp5Ukqb/GcIiIbRGxLyJ2dZWtiogdEbG7bFeW8oiIWyJiT0Q8EhEXd+2zpdTfHRFbusp/MyIeLfvcElM3eZ4jTkhLUrOTGTncAWw+oewmYGdmbgB2ltcAVwEbymMrcCt0wgS4GfgwcAlw81SglDpbu/Y78btmVThykKRGjeGQmd8BDpxQfA1wZ3l+J/DxrvK7suP7wIqIOA+4EtiRmQcy8yCwA9hc3ntHZn4vO5MAd3V9
1pyJwIUOkjTATOcczs3M5wDK9pxSfj7wbFe98VI2qHy8R/mcakU4cpCkAWZ7QrrXfEHOoLz3h0dsjYixiBjbv3//DJvYmXdwzkGS+ptpOLxQTglRtvtK+TiwrqveWmBvQ/naHuU9ZeZtmbkxMzeuWbNmhk3vzDs4cpCk/mYaDtuBqSuOtgD3dpVfV65auhR4pZx2uh/YFBEry0T0JuD+8t5rEXFpuUrpuq7PmjOtcJ2DJA0y0lQhIr4KXA6sjohxOlcd/RFwT0TcADwDfLJUvw+4GtgDHAKuB8jMAxHxBeChUu/zmTk1yf27dK6IWgp8qzzmVGfOwXCQpH4awyEzr+3z1hU96iZwY5/P2QZs61E+BnygqR2zyQlpSRqsuhXS0LmU1ZGDJPVXZzjgMgdJGqTKcGi1nHOQpEHqDAcnpCVpoGrD4ejksFshSfNXleEw0gomvVxJkvqqMhzarWDCcJCkvqoMh1bLS1klaZAqw6EdwVFHDpLUV5Xh0GoFRx05SFJfVYZDO5yQlqRB6gyHlqeVJGmQKsPBRXCSNFiV4eDIQZIGqzIcOhPSw26FJM1fVYZDO3BCWpIGqDMcPK0kSQNVGQ6tcJ2DJA1SZTi0/eE9SRqo2nDwh/ckqb9qw8F1DpLUX53h4A/vSdJAVYZDy6uVJGmgKsOh7c9nSNJAdYaDIwdJGqjKcGi1ArNBkvqrMhzagSMHSRqgynBwQlqSBqsyHEYMB0kaqM5waLeYmJwcdjMkad6qMhwWtYIj3tBBkvqqMhxG2i0mjjpykKR+Kg2H4IhzDpLUV5XhsKjlyEGSBqkyHEbanUVw3tNBknqrMhwWtTvdPuIVS5LUU5XhMNIKACa8YkmSeqoyHI6NHJx3kKSeTiscIuLpiHg0In4YEWOlbFVE7IiI3WW7spRHRNwSEXsi4pGIuLjrc7aU+rsjYsvpdanZonZn5OBaB0nqbTZGDv8gMy/KzI3l9U3AzszcAOwsrwGuAjaUx1bgVuiECXAz8GHgEuDmqUCZKyNl5OAqaUnqbS5OK10D3Fme3wl8vKv8ruz4PrAiIs4DrgR2ZOaBzDwI7AA2z0G7jnHOQZIGO91wSOB/RcTDEbG1lJ2bmc8BlO05pfx84NmufcdLWb/yaSJia0SMRcTY/v37Z9xo5xwkabCR09z/I5m5NyLOAXZExBMD6kaPshxQPr0w8zbgNoCNGzfO+J/9I2XOYcJ1DpLU02mNHDJzb9nuA/6KzpzBC+V0EWW7r1QfB9Z17b4W2DugfM5MjRzemnDkIEm9zDgcImJ5RJw99RzYBOwCtgNTVxxtAe4tz7cD15Wrli4FXimnne4HNkXEyjIRvamUzZkli9oAHDYcJKmn0zmtdC7wVxEx9Tn/PTO/HREPAfdExA3AM8AnS/37gKuBPcAh4HqAzDwQEV8AHir1Pp+ZB06jXY2WjHQy8c0jR+fyayTpjDXjcMjMnwIf7FH+EnBFj/IEbuzzWduAbTNty6laOtoZObzxluEgSb1UuUJ6aTmt9OaE4SBJvVQZDlNzDo4cJKm3qsPBOQdJ6q3KcDg252A4SFJPVYbD8auVvJRVknqpMhxG2i0WtcORgyT1UWU4QGfewQlpSeqt6nA47KWsktRTteGw1JGDJPVVdzg45yBJPVUbDktG27zh1UqS1FO94TDSchGcJPVRbTgsHXXOQZL6qTYc3rl0Ea+8cWTYzZCkeanacFi5bJSDh94adjMkaV6qNhxWLFvEa29OcOSok9KSdKJqw2HV8lEAXj7kqSVJOlG14bBi2VQ4eGpJkk5UbTisXLYIgIOOHCRpmorDoTNycFJakqarNxzKnMOB1w0HSTpRteGw5qzFtAKee+XNYTdFkuadasNhdKTFL71jCeMHDg27KZI071QbDgBrVy3j2YOGgySdqO5wWLmU8YNvDLsZkjTvVB0O7161jOdffdNfZ5WkE1QdDu/7pbPJhCeef23YTZGkeaXqcPiNX34nALt+/sqQWyJJ
80vV4bB25VJWLFvEo+OGgyR1qzocIoJL1q/iu3teJDOH3RxJmjeqDgeAv/9ra/j5y2/wk/2vD7spkjRvVB8Omy48l3Yr+PrfjA+7KZI0b1QfDue+Ywkfe985/OXYs7w14Y1/JAkMBwCuu+xXePEXb3HX954edlMkaV4wHICPvnc1l//6Gr70v3fz1IvOPUiS4UDnqqUvXPMBFrWDG+58iL0v+5MakupmOBTrVi3jy5/eyP5XD/NP/+yv+c6P9w+7SZI0NPMmHCJic0Q8GRF7IuKmYbThkgtWcc/vXMay0TbXbfsBn779QR54Yp8T1ZKqE/Nh8VdEtIEfA78NjAMPAddm5o/67bNx48YcGxubk/a8eeQod/z109z+3afY/9phzl4ywkffu5qL1q3gg+tW8J41y1lz1mIiYk6+X5LmSkQ8nJkbm+qNvB2NOQmXAHsy86cAEXE3cA3QNxzm0pJFbX7nt36V6z+ynu/ufpFv73qe7z/1Et/a9fyxOstG26xbuYzVZ4+yavli3rV8lFXLRzl7yQjLRtssHR1h+WibpaNtlo2OsGRRi5FWi0XtYKTdYlGrsx1pB4tane1IKwwcSfPCfAmH84Fnu16PAx8eUluOWTzS5or3n8sV7z8XgBd/cZjH9r7Kz156nZ+9dIhnDhzipV8c5tGDL/PSL97itcMTp/2d7VYQQCsCAloBQXS2EZRiWl31Yuq9rtetEjL9smZQBnU+6VT36fc9/Xfq+84sf4+00PzP3/soi0fac/od8yUcev0ve9r5rojYCmwFePe73z3XbZpm9VmL+a1fWwOs6fn+4YmjvH74KIfemuCNt45y6NhjgjePTDIxOcnE0WRicpIjR5OJo5NMTObfeT4xOUlmp/OT5clkJpkwmZB0nmfmsdeTybGynKo/tX8vA84k9ntr0OnH/vu8Pd8zqD/SQtTvH3Czab6Ewziwruv1WmDviZUy8zbgNujMObw9TTt5i0faLB5ps2r56LCbIkmnZb5crfQQsCEiLoiIUeBTwPYht0mSqjUvRg6ZORERnwHuB9rAtsx8bMjNkqRqzYtwAMjM+4D7ht0OSdL8Oa0kSZpHDAdJ0jSGgyRpGsNBkjSN4SBJmmZe/PDeTETEfuBnM9x9NfDiLDbnTGCf62Cf63A6ff6VzOz9Mw9dzthwOB0RMXYyv0q4kNjnOtjnOrwdffa0kiRpGsNBkjRNreFw27AbMAT2uQ72uQ5z3ucq5xwkSYPVOnKQJA1QVThExOaIeDIi9kTETcNuz2yJiHUR8UBEPB4Rj0XEZ0v5qojYERG7y3ZlKY+IuKX8HR6JiIuH24OZi4h2RPxtRHyzvL4gIh4sff6L8hPwRMTi8npPeX/9MNs9UxGxIiK+FhFPlON92UI/zhHxL8t/17si4qsRsWShHeeI2BYR+yJiV1fZKR/XiNhS6u+OiC2n06ZqwiEi2sCfAlcBFwLXRsSFw23VrJkA/iAz3w9cCtxY+nYTsDMzNwA7y2vo/A02lMdW4Na3v8mz5rPA412v/xj4YunzQeCGUn4DcDAz3wt8sdQ7E/1n4NuZ+T7gg3T6vmCPc0ScD/wesDEzP0DnJ/0/xcI7zncAm08oO6XjGhGrgJvp3GL5EuDmqUCZkc6tJRf+A7gMuL/r9eeAzw27XXPU13uB3waeBM4rZecBT5bnXwau7ap/rN6Z9KBzx8CdwMeAb9K53eyLwMiJx5zOvUIuK89HSr0Ydh9Osb/vAJ46sd0L+Thz/P7yq8px+yZw5UI8zsB6YNdMjytwLfDlrvK/U+9UH9WMHDj+H9mU8VK2oJRh9IeAB4FzM/M5gLI9p1RbKH+LLwH/Cpgsr98FvJyZE+V1d7+O9bm8/0qpfyZ5D7Af+K/lVNqfR8RyFvBxzsyfA/8ReAZ4js5xe5iFfZynnOpxndXjXVM49Loj94K6VCsizgK+Dvx+Zr46qGqPsjPqbxER/wjYl5kPdxf3qJon8d6ZYgS4GLg1Mz8EvM7x
Uw29nPF9LqdFrgEuAH4ZWE7ntMqJFtJxbtKvj7Pa95rCYRxY1/V6LbB3SG2ZdRGxiE4wfCUzv1GKX4iI88r75wH7SvlC+Ft8BPjHEfE0cDedU0tfAlZExNQdDrv7dazP5f13AgfezgbPgnFgPDMfLK+/RicsFvJx/ofAU5m5PzOPAN8A/h4L+zhPOdXjOqvHu6ZweAjYUK5yGKUzqbV9yG2aFRERwO3A45n5J11vbQemrljYQmcuYqr8unLVw6XAK1PD1zNFZn4uM9dm5no6x/L/ZOY/Ax4APlGqndjnqb/FJ0r9M+pflJn5PPBsRPx6KboC+BEL+DjTOZ10aUQsK/+dT/V5wR7nLqd6XO8HNkXEyjLi2lTKZmbYkzBv84TP1cCPgZ8A/3bY7ZnFfn2UzvDxEeCH5XE1nXOtO4HdZbuq1A86V279BHiUzpUgQ+/HafT/cuCb5fl7gB8Ae4C/BBaX8iXl9Z7y/nuG3e4Z9vUiYKwc6/8BrFzoxxn498ATwC7gvwGLF9pxBr5KZ07lCJ0RwA0zOa7APy993wNcfzptcoW0JGmamk4rSZJOkuEgSZrGcJAkTWM4SJKmMRwkSdMYDpKkaQwHSdI0hoMkaZr/Dyer0MLTX4qsAAAAAElFTkSuQmCC\n", "text/plain": [ "
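" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "X = torch.from_numpy(Nx).float()\n", "Y = torch.from_numpy(Ny).float()\n", "lr = 1e-5\n", "# parameters to fit; both need gradients\n", "iA = torch.zeros(1, requires_grad=True)\n", "iB = torch.zeros(1, requires_grad=True)\n", "\n", "# track the best result\n", "best_loss = float(\"inf\")\n", "best_a = 0.0\n", "best_b = 0.0\n", "loss_list = []\n", "max_epochs = 1000\n", "\n", "# gradient descent\n", "for _ in range(max_epochs):\n", "    # compute the loss (mean squared point-to-line distance)\n", "    loss = torch.mean((iA*X - Y + iB)**2 / (iA**2 + 1))\n", "    # back propagation fills iA.grad and iB.grad\n", "    loss.backward()\n", "    cur_loss = loss.item()\n", "    loss_list.append(cur_loss)\n", "    # remember the parameters that produced the lowest loss\n", "    if cur_loss < best_loss:\n", "        best_loss = cur_loss\n", "        best_a, best_b = iA.item(), iB.item()\n", "    # update the parameters without recording the ops in the graph\n", "    with torch.no_grad():\n", "        iA -= lr * iA.grad\n", "        iB -= lr * iB.grad\n", "        # gradients accumulate, so clear them after each update\n", "        iA.grad.zero_()\n", "        iB.grad.zero_()\n", "\n", "# plot the loss curve\n", "plt.plot(loss_list)\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# plot the fitted line over the noisy points\n", "nX = np.linspace(0, 100, 100, endpoint=True)\n", "plt.plot(nX, best_a*nX + best_b)\n", "plt.scatter(Nx, Ny, marker='.', color='r')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The fitted line matches the points well. Tuning the learning rate and the maximum number of iterations can improve the result further. The loss curve rises again at the end probably because, near the minimum, the step is too large, so the loss oscillates around the minimum instead of reaching it. See my blog post on understanding gradient descent for more on this." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With a learning rate of 1e-5 the fit is very good, whereas a learning rate of 1e-6 does not fit well unless the number of iterations is increased. **In short, for gradient descent both the initial learning rate and the use of learning rate decay matter a great deal.**" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.2" } }, "nbformat": 4, "nbformat_minor": 2 }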