{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Practical Deep Learning for Coders, v3" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Lesson 2_sgd" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "from fastai.basics import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this part of the lecture we will explain Stochastic Gradient Descent (SGD), which is an **optimization** method commonly used in neural networks. We will illustrate the concepts with concrete examples.
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "在这部分,我们将会解释随机梯度下降算法(SGD),它在神经网络应用中是常用的**优化**算法。我们将通过实例来解释其原理和概念。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Linear Regression problem 线性回归问题" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The goal of linear regression is to fit a line to a set of points.
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "线性回归的目标是将一条直线拟合到一组点。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "n=100" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[-0.1957, 1.0000],\n", " [ 0.1826, 1.0000],\n", " [-0.1008, 1.0000],\n", " [-0.1449, 1.0000],\n", " [ 0.7091, 1.0000]])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x = torch.ones(n,2) \n", "x[:,0].uniform_(-1.,1)\n", "x[:5]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([3., 2.])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a = tensor(3.,2); a" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "y = x@a + torch.rand(n)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAW4AAAD8CAYAAABXe05zAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAHF5JREFUeJzt3X+QHHd55/HPo/XYrAx45WSTs9deLFdx8uGo7DVbPieqSs4KhQxJ7D1jwBy+kBwpFUkqBYQoJxfcxU5BWTlVDnKVVCXKHSE5OCMjmz0bQxQnMpWKg52sbiWEMALZJOCxE4uzFw60mPHquT+mR+6d6Z7pnzM9M+9XlUq7PT0zj3vXz3z19PP9fs3dBQAYHhsGHQAAIB0SNwAMGRI3AAwZEjcADBkSNwAMGRI3AAyZRInbzKbM7ICZfcXMHjezHy87MABAtHMSnvd7kv7c3W8xs3MlbSwxJgBAF9ZrAo6ZvVLSUUmXO7N1AGDgkoy4L5d0StKfmNlVkg5Lere7fy98kpntlLRTks4///zXXnHFFUXHCgAj6/Dhw99y9+kk5yYZcc9LelTSNnd/zMx+T9J33P0/xT1nfn7el5aW0sQMAGPNzA67+3ySc5PcnHxK0lPu/ljw/QFJ12QNDgCQT8/E7e7/JOmbZrYlOPTTkr5calQAgFhJu0p+TdIngo6SJyX9YnkhAQC6SZS43f2IpES1FwBAuZg5CQBDhsQNAEMmaY0bAMbO4nJdew+e0NMrq7p4alK7dmzRwtzMoMMicQMYb3HJeXG5rtvvO6bVxpokqb6yqtvvOyZJA0/eJG4AY6tbct578MTZ4y2rjTXtPXiCxA0AgxKXnO984LieP92IfM7TK6v9CK0rbk4CGFtxSTguaUvSxVOTZYWTGIkbwNjKkoRP/+BFLS7XS4gmORI3gLG1a8cWTdYmUj3n+dMN3X7fsYEmbxI3gLG1MDeju27eqpmpSZmkmalJTU3Wej6vdZNyULg5CWCsLczNrOsSWVyu6737j6jXrjGDvEnJiBsAQhbmZvT262ZlPc4b5E1KRtwARkKSWY5JZ0J+cGGr5l91ofYePKH6yqpMWjcCn6xNaNeOLR3P6xcSN4Chl2SWY9qZkOESStWmvpO4AQy9JLMc88yEbK+DDxo1bgBDL+5GYfh4knOGBYkbwNCLu1EYPp7knGFB4gYw9KIm0rTfQExyTlqLy3Vt23NIm3c/qG17DvVtUg41bgBDL1zHjruBmOScNAa57Ku592ozT29+ft6XlpYKf10AqIptew6pHlEfn5ma1CO7t6d+PTM77O6J9valVAIAGQzyZieJGwAyiLupObWx91oneZG4ASCDXTu2qDbROTH+u98vf9nXRInbzP7BzI6Z2REzo3gNoHSD6thIamFuRrUNnYm7ccZLXzkwTVfJ9e7+rdIiAYBAlTfqbVlcrut040zkY2XXuSmVAKicbtPTq6JbLGVP6kk64nZJf2FmLumP3H1f+wlmtlPSTkmanZ0tLkIAY6esjo0iF4vqFkvZKwcmHXFvc/drJL1B0q+a2U+2n+Du+9x93t3np6enCw0SwHgpY3p6q/xSX1mV66XyS9baeWxXyWSt9HJOosTt7k8Hfz8r6dOSri0zKADjrYzp6UWXX+JivOPGKzPHmFTPxG1m55vZK1pfS3q9pC+VHRiA8RW1F+RdN2/NNZKNK23UV1Yzda6UEWNSSWrcPyrp02bWOv9/ufuflxoVgLGXZQ3sbjXsi6cmI6eoS1pXOmm9d1kxFqHniNvdn3T3q4I/V7r7h/oRGACk0auGHVXaaFe1zpU4tAMCGAm9atjtpY04w7CxAsu6AhgJSVoIw6WNuNX9hmFjBUbcAEZC2hbCMjpX+oXEDWAkpE3Eg+wKyYtSCYBCFDkrMYssO9xUbff2pEjcAHKryqJQw5qI0yJxA4iVdBTdraMjz56OgxzBVxmJG0CkNKPooheFqsoIvqq4OQk
gUpq1PYpeFGoYlnUdJBI3gLPCu87ETQ+PGkUX3Vo3yI14hwGlEgCSOssTcaJG0Vk6Onq9x7BOjukHEjcASdHliXa9+qKTJupeNx537djS8SEyLJNj+oHEDUBS9zKESYV1diS58Vj0CH7UkLgBSIovT8xMTeqR3dsLe5+krYPj0pOdBTcnAUhqlidqE+vXzatNWOHlCW485kfiBvAS7/F9AcrYT3LckLgBSGqWMBpn1mfqxhkvvHd6mFflqwpq3AAkpS9hZJ2Szo3H/BhxA5CUroTRa5uwbliDJD8SNwBJ6UoYWaekRyX89+4/og8sHssd/zghcQOQlG5jgaydIVEJ3yV94tFvJBqto4kaNzBGepUpkvZOZ52SHpfYXcq1BOy4STziNrMJM1s2s8+UGRCAcuSpS7eLK6tcf8X02UWqtu051PHa3RI7fdzJpRlxv1vS45JeWVIsAAoQNaqWpPfdc1Rrvr7dL+tmB1GdIddfMa17D9fXTWV/z/4juvOB4/qtn7tSknT6By/GviZ93MklStxmdomkn5H0IUm/XmpEADKLWgdk14GjkqsjabdkHem2l1W27TkUuUjV86cbZ2No7xNvoY87naQj7o9I+k1Jr4g7wcx2StopSbOzs/kjA5Ba1M2/xlr36Y9FjXS7fQB0i2GGlsDUeta4zexnJT3r7oe7nefu+9x93t3np6enCwsQQHJpR89FjnSzfACYpEd2bydpp5Tk5uQ2STea2T9I+qSk7Wb28VKjApBJmuQ5YRbb7pdF1A3LXqhrZ9Mzcbv77e5+ibtfJulWSYfc/bbSIwPGVHj7sKjOjG6ikmdtwlTbsH7Vv8nahH73LVcVOtJt9YFPTdY6HouLgbp2NvRxAxWSd3fzuHVAoo6VUZ5o3bCM62xhqnsxzGPuNOcxPz/vS0tLhb8uMOq27TlU6mYGrBNSXWZ22N3nk5zLiBuokDI3Gcg7mkd1sFYJUCFlbjKQdWEoVA8jbqBCitrdPKokwpZho4PEDVRIEZsMxJVELpisaWW10XE+LXnDh5uTwIiJu8G5aWNN32+cWTear20wvfxl52jldIOblQOW5uYkNW5gxMSVPlZON9attz01WZOsuZZI3tUC0V8kbmDEdLvBuTA3o0d2b9fX9/yMzj/vnI41RLhZORyocQMl6WfPdPi9LpisqTZh65Jy1A1OblYOLxI3UIJ+9ky3v9fKakO1DaZNG2tda9dZd7HB4JG4gRziRtXdeqaLSNzh991g1rHWduOMa+O552j5P78+9jWKaj1E/5G4gYy6jar7OQMy6wYJRbQeYjBI3EBG3UbVZZYhot43SpL3Sro5MKqFrhIgo26j6rjNdIsoQyQZtVPyGG2MuIGMuo2qyyxDxL1vy4SZ3vTazpE0KwOODhI3kFHUzT2puZP54nK9tDJE3Pu2rLnr3sN1zb/qwrPvz8qAo4VSCZBR3I4vz59udJ2BmGeHm/D7tmZATph1nNM+kYaVAUcLiRvIYWFuRuef1/kP19XGmt53z9GO5Nwa+dZXVnNNMw/PgDyToKuEyTajhcQN5BSX/NbcO5JzGSPfJGt4l7nON/qPxA3klCT5tZJz3pFvVJklSQdLmV0u6D8SN5BTVFKM0urmiJIk+ceVWSStq3nPTE3qrpu3rrvp2F4XjzoHw4P1uIEC9JqCLjWTZdw08yRJtOyNhDFYbBYMFCRp73O49a+99U56qSyRp7+bG4xoIXEDMbL2PvdKzln7u1nNDy09a9xm9jIz+zszO2pmx83szn4EBgxang6QVrveh996tSTpvfuPZOrZDuMGI1qSjLhfkLTd3b9rZjVJf2Nmn3P3R0uODRiouGnl9ZVVbd79YM8yx+JyXbsOHD27oUF9ZVW7DhyVlG22Iqv5oaVn4vbm3cvvBt/Wgj/F39EEKmRxuS5T/C96e1dHVPK884HjHVuDNdZcdz5wPHOyZTU/SAnbAc1swsyOSHpW0kPu/ljEOTvNbMnMlk6dOlV0nEB
f7T14ItHopFvp5PnTjdjjecsmGG+JEre7r7n71ZIukXStmf1YxDn73H3e3eenp6eLjhPoqzSdGlm6OthRHXmkmoDj7iuSPi/phlKiASoiTadG3Lnti0+1Y5EnZJWkq2TazKaCryclvU7SV8oODBikqA6O2gZTbWL9SnzdujruuPFK1TZ0rtwXRg82skjSVXKRpD81swk1E/097v6ZcsMCBiuugyPqWNzNwvBrxHWobDBL1KEChDHlHShZe1tgnKRT3zGa0kx5Z5EpoGR7D57ombQlat5IjsQNlKzsDhWMHxI3ULIiOlSAMBI3ULLIDpUJ6+g4Yd0RJMXqgEDJiuhQAcLoKgGACqCrBABGGKUSoIeku+AA/ULiBrrIugsOUCZKJUAXeXbBAcrCiBt9N0ylBzboRRWRuNFX3UoPUvXa49igF1VE4kZfxZUe7nzguL7fONOzltzv0fquHVvWfdBITJTB4JG40VdxJYaobb5aCV16aWnU8D6Q/bhRyAa9qCIm4KCvtu05FLs2dZzaBlPjTPzv6czUpB7ZvT1vaMBAMQEHlRW1bsdkbaLrNl/dkrbEjUKMH0ol6Ktu63a8Z/+RTK/JjUKMGxI3+m5hbiayRnzH/ce1stpZ6+6GG4UYR5RKUBl33HhlRxmlG5P0ptdGfwgAo4wRNyojqozyvRdejB2Fu6SHv3KqjxEC1UDiRqW0l1HaJ+y048YkxhGJG5XWSuLvu+eo1iJaV8M3JodpKj2QR88at5ldamYPm9njZnbczN7dj8CAloW5Gf3uW66KbCNs3ZhsjczrK6tyvTQ5Z3G5PoCIgXIlGXG/KOl97v5/zOwVkg6b2UPu/uWSY8MYihs195rB2G0qPaNujJqeidvdn5H0TPD1/zOzxyXNSCJxo1C91r6OayOUuk+lX1yuk7wxUlLVuM3sMklzkh6LeGynpJ2SNDs7W0BoqLq8NeX253/vhRdj177u9bpxq/hJSvR8YJgk7uM2s5dLulfSe9z9O+2Pu/s+d5939/np6ekiY0QF5a0pRz0/ru3v6ZVVLS7XtW3PIW3e/aC27TnU8T7dJuHQeYJRkyhxm1lNzaT9CXe/r9yQMAzy7gwT9fw4F0zWen5ILMzNxK53wpR4jJqepRIzM0n/Q9Lj7v5fyw8JwyDNzjDtJZHrr5hOvELgZG1CZkpUQrnjxitZOxtjIcmIe5ukfy9pu5kdCf68seS4UHFxo9j241ElkY8/+o3Y1920saaZqUmZmsu13nXzVq1ErNUtdX5ILMzN6K6bt3Y8n/o2Rk2SrpK/UXNZCIy4NDcbk+4Mk6YkMlmb0G/93JUd79naRKFd1IdHt84TYFSwyBQkpb/ZmHR0m+bGYNzoOG4Nb0ogGFdMeYek7jcb40awSUa3UxtrkduStZuZmuz6Pq0Ymc4OkLgRiLtZ2D5iTtu7nWRnvCSjZ0ogwEtI3NDicn3dJrxhrTry4nK9Y6OD+sqqdh04qjvuP65vrzYiE3mvjREmzHTN7AXae/CE3rv/CKNpIAESN7T34InIpG1q1pc/sHhMn3j0G5HnNNb8bHJun6Le7QOhZc1djzzx3Nnv+7FzOzDsuDmJ2BuIrYQbl7SjhCfhxH0gpHkNAJ1I3IjtyZ6ZmsyUfFsfBHmmmjNNHYhH4kbXdrssCbT1QZBnqjnT1IF4JO4x1+oSWW2sacKa86zCPdndEuj5506otmH93Kxwh0jUB0Jtwjqe044ebaA7EvcYC0+6kZo3CltJs3VjMCr5mqTbrpvV8d++QXvffFXsJJyoSTp7b7mq4zm3XTfLNHUgBfMkjbYpzc/P+9LSUuGvi97S9Flv23Mosn97ZmpSj+zenuk1AWRjZofdfT7JubQDjpD2tr1erXVJV/hj8gtQLZRKRsTicj2yba9ba12vFf56bV4AYDBI3COiW9te3Mi6WzcJu6YD1UXiHhHd2vbiRtbdVvjLu8MNgPJQ4x4RcZvltqatx4mrX6fZ4QZAfzHiHhFxbXt
vv242043FpDvcAOg/EveIiCp7fPitV+uDC1szvR6bFwDVRalkhBTZtsfmBUB1kbgRi/5toJoolQDAkGHEPSBMIweQFYl7AFqTW1p90ml3fcmb9PnQAIZbz1KJmX3UzJ41sy/1I6BxkGdyS94ZjcyIBIZfkhr3xyTdUHIcYyXP5Ja8MxqZEQkMv56J293/WtJzvc5Dcnkmt+Sd0ciMSGD4FdZVYmY7zWzJzJZOnTpV1MuOpDyTW/LOaGRGJDD8Ckvc7r7P3efdfX56erqolx1J3RZ36iXvjEZmRALDj66SAck6uSXvjEZmRALDL9HWZWZ2maTPuPuPJXlRti4DgHTSbF2WpB3wbklfkLTFzJ4ys3fmDRAAkF3PUom7v60fgQAAkmGtEgAYMiRuABgydJUUhPU/APQLibsAeReNAoA0KJUUgPU/APQTibsArP8BoJ9I3AVg/Q8A/UTiLgDrfwDoJ25OFoD1PwD0U6K1StIa57VKaAsEkEWatUoYcXeRNglnbQuMeh+JETyAaCTuGFmScLe2wLjnRL3Prk8dlUxqrHni9wYwPrg5GSNLb3aWtsCo92mc8bNJO+l7AxgfJO4YWZJwlrbANL3e9IUDkEjcsbIk4SxtgWl6vekLByCRuGNlScJZ9pKMep/aBlNtwlK9N4Dxwc3JCK0uj9XGmibMtOaumYSdHWn3kozrAY86xo1JABJ93B3auzyk5mg36S7sAJBFoXtOjhtW+gNQdSTuNqz0B6DqSNxtWOkPQNWRuNtEdXmYpOuvmJbUrIFv23NIm3c/qG17DmlxuT6AKAGMs0SJ28xuMLMTZnbSzHaXHdQgLczN6JrZC9Ydc0n3Hq7rA4vHdPt9x1RfWZWrORX9PfuP6Oo7/4IEDqBveiZuM5uQ9AeS3iDpNZLeZmavKTuwQVlcrutvn3iu4/hqY013P/bNjhuXkrSy2tDt9x0jeQPoiyQj7mslnXT3J939B5I+KemmcsPKL2tJY+/BE4prkFzr0jpJ5wmAfkmSuGckfTP0/VPBsXXMbKeZLZnZ0qlTp4qKL5NWL3a4pJF0RNyte2TCLPaxXs8FgKIkSdxR2apj6Onu+9x93t3np6en80eWQ55e7LjuEZP0tn99aceNyyTPBYAiJUncT0m6NPT9JZKeLiecYuTpxY7rKnn7dbP64MJW3XXzVm3aWOt4HmuJAOiXJGuV/L2kV5vZZkl1SbdK+nelRpXTxVOTqkck6SQj4l77R7bWImGLMgCDkmitEjN7o6SPSJqQ9FF3/1C38we9VgnrjQAYNoXvOenun5X02VxR9RG7rgMYZSO7rGvU8qqUNwCMgpFN3GHN0skXtdo4c/YYG/ACGFYjv1bJ4nJduz51dF3SbmHSDIBhNPKJe+/BE2qcib8By6QZAMNm5BN3r8TMpBkAw6ZSNe4ybh7G9XRLzYk1TJoBMGwqM+LOs75IN7t2bFFtQ/QaI2+/bpYbkwCGTmUSd1l7PS7MzWjvm6/S1ORL09Q3bazpI2+9Wh9c2JrrtQFgECpTKilzr8eonm4AGFaVSdx51hcJY5INgFFXmVJJ1Kp8aVfcK6tODgBVUpnEvTA3o7tu3qqZqUmZpJmpydSLQpVVJweAKqlMqUTKX4sus04OAFVRmRF3EeLq4UyyATBKRipxF1EnB4Cqq1SpJC/W4QYwDkYqcUv0bAMYfSNVKgGAcUDiBoAhQ+IGgCFD4gaAIUPiBoAhQ+IGgCFj7vH7MWZ+UbNTkv4x49N/WNK3CgynKMSVDnGlV9XYiCudrHG9yt2nk5xYSuLOw8yW3H1+0HG0I650iCu9qsZGXOn0Iy5KJQAwZEjcADBkqpi49w06gBjElQ5xpVfV2IgrndLjqlyNGwDQXRVH3ACALkjcADBkBpK4zezNZnbczM6YWWzbjJndYGYnzOykme0OHd9sZo+Z2dfMbL+ZnVtQXBea2UPB6z5kZpsizrnezI6E/nzfzBaCxz5mZl8PPXZ1v+IKzlsLvff
9oeODvF5Xm9kXgp/3F83sraHHCr1ecb8vocfPC/77TwbX47LQY7cHx0+Y2Y48cWSI69fN7MvB9fkrM3tV6LHIn2mf4voFMzsVev9fCj32juDn/jUze0ef4/pwKKavmtlK6LEyr9dHzexZM/tSzONmZv8tiPuLZnZN6LFir5e79/2PpH8laYukz0uajzlnQtITki6XdK6ko5JeEzx2j6Rbg6//UNIvFxTXf5G0O/h6t6Tf6XH+hZKek7Qx+P5jkm4p4XolikvSd2OOD+x6SfqXkl4dfH2xpGckTRV9vbr9voTO+RVJfxh8fauk/cHXrwnOP0/S5uB1JvoY1/Wh36FfbsXV7Wfap7h+QdLvRzz3QklPBn9vCr7e1K+42s7/NUkfLft6Ba/9k5KukfSlmMffKOlzkkzSdZIeK+t6DWTE7e6Pu3uvrdevlXTS3Z909x9I+qSkm8zMJG2XdCA4708lLRQU2k3B6yV93Vskfc7dTxf0/nHSxnXWoK+Xu3/V3b8WfP20pGclJZodllLk70uXeA9I+ung+twk6ZPu/oK7f13SyeD1+hKXuz8c+h16VNIlBb13rri62CHpIXd/zt2fl/SQpBsGFNfbJN1d0Ht35e5/reZALc5Nkv7Mmx6VNGVmF6mE61XlGveMpG+Gvn8qOPZDklbc/cW240X4UXd/RpKCv3+kx/m3qvOX5kPBP5M+bGbn9Tmul5nZkpk92irfqELXy8yuVXMU9UTocFHXK+73JfKc4Hp8W83rk+S5ZcYV9k41R20tUT/Tfsb1puDnc8DMLk353DLjUlBS2izpUOhwWdcribjYC79epW1dZmZ/KelfRDz0fnf/30leIuKYdzmeO66krxG8zkWStko6GDp8u6R/UjM57ZP0HyX9dh/jmnX3p83sckmHzOyYpO9EnDeo6/U/Jb3D3c8EhzNfr6i3iDjW/t9Zyu9UD4lf28xukzQv6adChzt+pu7+RNTzS4jrAUl3u/sLZvYuNf+1sj3hc8uMq+VWSQfcfS10rKzrlUTffr9KS9zu/rqcL/GUpEtD318i6Wk1F2+ZMrNzglFT63juuMzsn83sInd/Jkg0z3Z5qbdI+rS7N0Kv/Uzw5Qtm9ieSfqOfcQWlCLn7k2b2eUlzku7VgK+Xmb1S0oOSPhD8E7L12pmvV4S435eoc54ys3MkXaDmP32TPLfMuGRmr1Pzw/Cn3P2F1vGYn2kRiahnXO7+f0Pf/rGk3wk999+0PffzBcSUKK6QWyX9avhAidcribjYC79eVS6V/L2kV1uzI+JcNX9I93uz2v+wmvVlSXqHpCQj+CTuD14vyet21NaC5NWqKy9Iirz7XEZcZrapVWowsx+WtE3Slwd9vYKf3afVrP19qu2xIq9X5O9Ll3hvkXQouD73S7rVml0nmyW9WtLf5YglVVxmNifpjyTd6O7Pho5H/kz7GNdFoW9vlPR48PVBSa8P4tsk6fVa/y/PUuMKYtui5o2+L4SOlXm9krhf0s8H3SXXSfp2MDgp/nqVdQe22x9J/1bNT6EXJP2zpIPB8YslfTZ03hslfVXNT8z3h45frub/WCclfUrSeQXF9UOS/krS14K/LwyOz0v676HzLpNUl7Sh7fmHJB1TMwF9XNLL+xWXpJ8I3vto8Pc7q3C9JN0mqSHpSOjP1WVcr6jfFzVLLzcGX78s+O8/GVyPy0PPfX/wvBOS3lDw73uvuP4y+P+gdX3u7/Uz7VNcd0k6Hrz/w5KuCD33PwTX8aSkX+xnXMH3d0ja0/a8sq/X3Wp2RTXUzF/vlPQuSe8KHjdJfxDEfUyhjrmirxdT3gFgyFS5VAIAiEDiBoAhQ+IGgCFD4gaAIUPiBoAhQ+IGgCFD4gaAIfP/AXe6PY4gqGXRAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.scatter(x[:,0], y);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You want to find **parameters** (weights) `a` such that you minimize the *error* between the points and the line `x@a`. Note that here `a` is unknown. For a regression problem the most common *error function* or *loss function* is the **mean squared error**.
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "你希望找到这样的 **参数**(权重) `a`,使得数据点和直线`x@a`之间的 *误差* 尽可能小。需要注意的是这里`a`是未知的。对于回归问题最常用的 *误差函数* 或者说 *损失函数* 是 **均方误差(MSE)** 。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def mse(y_hat, y): return ((y_hat-y)**2).mean()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Suppose we believe `a = (-1.0,1.0)` then we can compute `y_hat` which is our *prediction* and then compute our error.
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "假设我们取`a = (-1.0,1.0)`,那么我们就可以计算 *预测值* `y_hat` ,随后我们可以算出误差来。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "a = tensor(-1.,1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(7.9356)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_hat = x@a\n", "mse(y_hat, y)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAW4AAAD8CAYAAABXe05zAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzt3X+QHOV5J/Dvs6MRjES8Kzl7Z1hpg3Bx0gUkI9gj5NYVR5KDsANiLRyBz1ww5iwTUzkDjoIoCAjKlESoQ47rnLLlCz8SY0ABsRbCnEyQnFQoC3v3Vj8sQEZAQFrhIFvaNUZraVg990d37/b0dPf0z5nume+naku7Pd09r3pXz7563ud9X1FVEBFRfrQ1ugFERBQOAzcRUc4wcBMR5QwDNxFRzjBwExHlDAM3EVHOBArcItIhIk+IyCsi8rKI/H7aDSMiIndTAp73NwD+r6p+WkSmApiWYpuIiMiH1JqAIyIfALALwFnK2TpERA0XpMd9FoDDAB4UkY8AGATwZVV9z36SiKwEsBIApk+ffsG8efOSbisRUdMaHBz8hap2Bjk3SI+7B8AOAL2q+qKI/A2AX6nqX3ld09PTowMDA2HaTETU0kRkUFV7gpwbZHDyIICDqvqi+fUTAM6P2jgiIoqnZuBW1Z8DOCAic81DSwC8lGqriIjIU9Cqkj8H8IhZUfI6gGvTaxIREfkJFLhVdSeAQLkXIiJKF2dOEhHlDAM3EVHOBM1xExG1jP6hYdy3dR8OjYzhjI4SVi2di76FXY1u1gQGbiJqSV7BuX9oGLdu2oOx8jgAYHhkDLdu2gMAmQneDNxE1HL8gvN9W/dNHLeMlcdx39Z9DNxERI3iFZzvenovjh4ru15zaGSsHk0LhIOTRNRyvIKwV9AGgDM6Smk1JzQGbiJqOVGC8PDIGHrXbUP/0HAKLQqHgZuIWs6qpXNRKhZCX2flwhsdvBm4iajl9C3swtrl89HVUYIA6OoooaNUDHStNVDZSBycJKKW1Lewq6JKxFlp4qfRA5XscRMRYbIXHqTn3eiBSva4iSjXgs5yDHKe1Qu3zh0eGYMAsG83UyoWsGrpXDQSAzcR5VbQWY5hZ0Pa0yhZnP7OwE1EuRV0lmOc2ZDOXHgWMMdNRLnlNUjoPB70vLxg4Cai3PIaJHQeD3peXjBwE1FuuU2kcRs8DHpeGP1Dw+hdtw1zVj9T9xmVzHETUW5Zuecg1SJBzguq0Uu/iqrWPiuknp4eHRgYSPy+RERZ0LtuG4Zd8uNdHSW8sHpxpHuKyKCqBtrbl6kSIqKQGj3YycBNRBRSowc7GbiJiEJaNK8z1PGkBRqcFJF/A/AugHEA7wfNwxARhZXFmYpOz+x+2/X49lcO1+X9w
1SVLFLVX6TWEiJqeY2u1giif2i44dubMVVCRJnhNzU9K/zaUq8cd9AetwL4gYgogG+p6gbnCSKyEsBKAOju7k6uhUTUMtKq1kgy/eLXlnqtGhi0x92rqucD+ASAG0TkD5wnqOoGVe1R1Z7Ozvok6ImouaRRrWGlX4ZHxqCIv/2YV1s6SsW6pXMCBW5VPWT++Q6ApwBcmGajiKg1pTE1Pen0i1cb1yw7J3Ibw6qZKhGR6QDaVPVd8/OLAdydesuIqOUkPTUd8E5tWLu2h32fNNoYVs0p7yJyFoxeNmAE+u+q6j1+13DKOxHVk18O22t6utvONmuXz29Y9UqiU95V9XVV/Yj5cU6toE1EVE+1cthuqQ1n0AayV73ih+WARJRrtXLY1ibAXR0lCIyFoLzyDHnZWIHLuhJRrgUpIXRuP+aVPsnLxgrscRNRrkUpIUyjeqWeGLiJKNeiBGG39EkjBybDYqqEiGJr5MJQUcvzsrh7e1AM3EQUSxYWhspzEI6CgZuIXAXtRftVdUQJpnlY1rXRGLiJqEqYXnSSC0NlofeeBxycJKIqYdb3SHJhqDws65oFDNxEBMDo7fau24Y5q59xrXEG3HvRSZbWNXoT3rxgqoSIqlIUXtx60UkuunRGRynXE2PqhYGbiFxTFE5+vegwVR1+g4+rls6t+gWSp4kx9cLATUS+qQgBEqvuqDX4mIUlU/OAgZuIPFMUXR0lvLB6cWLvE6R0sNVqsqPg4CQReS59umhestsQcvAxGQzcRIS+hV244oIuiO2YAnhycDjy3oxu0thTshUxcBMRAGD7K4dT31wg76vyZQVz3EQEIFwaI+q0dA4+JoOBm4gABK+hjjot3Rns1195HgN2REyVEBGA4GmMKNPS3faFvPHxnVh49w8SzaG3Cva4iQhA8DRGlMoQrwk+R4+VuYhUBAzcRC2kVm46SA11lGnpfkE9zhKwrSpwqkRECiIyJCJb0mwQEaXDLV1x66Y9oVMVXimVRfM6Jxap6l23reK+tcr9WMcdTpge95cBvAzgAym1hYhicutRA0aqwq2XHKW365ZSWTSvE08ODlcMWN74+E7c9fRe3HnZOVg0rxOP7HirqtzQwjrucAIFbhGZBeCPAdwD4OZUW0REkbhVe6x6YhegQPmkV8iM1tt1plR6123zzGFbbfBqAeu4wwva4/4agL8E8FteJ4jISgArAaC7uzt+y4goFLcBwPK4d8C2JNHb9Qv+fm3oYh13JDVz3CJyKYB3VHXQ7zxV3aCqPara09mZ7PoGRFRblJ5zUr3dKMFfALywejGDdgRBBid7ASwTkX8D8BiAxSLynVRbRUShhQ2eXR0lrF0+P5HA6TZgWQvz2tHVTJWo6q0AbgUAEflDAH+hqlen3C6ilhV1OrnbJgTFglTluEvFQmIB22Lda83mvRgZK1e85tUG5rWjYx03UYbE2eXcawKN27E00hPWgKVfZQvXJ0mGqNYevAirp6dHBwYGEr8vUbPrXbcttQ0NovbkqT5EZFBVe4Kcyx43UYaktdFAnJ48ZQ8XmSLKkLQ2GoiyMBRlF3vcRBmS1C7nzrSIW/oF4FTzvGLgJsqQJDYacEuLCNxnLrIkL584OEnUZLwGOJ3Bu9gmOO3UKRg5VuZgZQaEGZxkjpuoyXilPxRGdYoA6CgVATHWEomzUiA1BgM3UZPxSn9YJYVvrPtjTD9lStUaIhyszA/muIlSUs+6aft7tZeKKBakIjA7BzjTKjuk+mDgJkpBPeumne81MlZGsU0wY1rRM38dZRcbyg4GbqIYvHrVfnXTSQRu+/u2iWDcUWRQPqmYNnUKhu642PX6pMoOqTEYuIki8utVp5mKcL6vM2gHea8kyg6pcRi4iSLy61WnmYrw2jE97HsF2RiYsolVJUQR+fWqvTbUTSIVEaTXzrRHc2OPmygiv151mqkIvynsADBjWhF3XnZO1XtxdcDmwcBNFJHbAB8AHDvxPvqHhlNLRXi9r+U35
ZNVx7g6YHNhqoQoor6FXVi7fL4xC9Hm6LGy7yzE/qFh9K7bhjmrn0Hvum2hZyta79vlkcN2m0jD1QGbCwM3UQx9C7sw/ZTq/7iOlcfxlY27qoKz1fMdHhmLNdW8b2EXXli9GOLxujMPzgk3zYWBmygmr+A3rloVnJPu+QZdvzutdb6pMRi4iWIKEvys4By35+tMsyya1xmoeiXNKheqPwZuopjcgqIbq5rDTZDg75ZmeXJwGFdc0DWx6l9XR8l1B3d7XtzvPMoHVpUQxeQs/XObgg5gogQv6lRzrzTL9lcOB9pImBNumgcDN5GHMHXP9qDoLL0DJoNznPpuDjCShYGbyEWcuudawTlqz5cr+pGlZo5bRE4VkR+LyC4R2Ssid9WjYUSNFLf6wyrXW3/leQCAmx7fGalm244DjGQJ0uM+DmCxqv5aRIoA/lVEnlXVHSm3jahhvNIPwyNj6F23LVCao39oGKue2DWxocHwyBhWPbELQLTZilzRjyw1A7cauwn/2vyyaH4kv8MwUYZ0TCvi6LGy62tWuqJW+uSup/dWbQ9WHlfc9fTeyMGWA4wEBCwHFJGCiOwE8A6A51T1RZdzVorIgIgMHD58OOl2EtWVxxLXVfzSJ16B/+ixcuy0CbW2QIFbVcdV9TwAswBcKCLnupyzQVV7VLWns7Mz6XYS1dXomHvQdROlqoO7qlMcoSbgqOoIgB8CuCSV1hBlRJhKDa9znYtPOXGRJ4oqSFVJp4h0mJ+XAHwcwCtpN4yokdwqOIptgmKhclknv6qONcvOQbHNaxkoA2uwKYogVSWnA3hYRAowAv1GVd2SbrOIGsurgsPtmN+kHOt8r40P2kQwZ/UzrBChUESDjsKE0NPTowMDA4nflyiPnGWBXkrFAtcPaWEiMqiqPUHO5SJTRCm7b+u+mkEbYM6bgmPgJkpZmDw2c94UBAM3UcqSqFAhsmPgJkqZa4VKQaoqTrjuCAXF1QGJUpZEhQqRHatKiIgygFUlRERNjKkSohrC7IRDVA8M3EQ+4uyEQ5QWpkqIfMTdCYcoDexxU0PkJf3ADXopixi4qe680g8Dbx7B9lcOZyqYc4NeyiKmSqjuvNIPj+x4C8MjY1B4bzTQPzSM3nXbMGf1M3XZRYYb9FIWscdNdeeVZnDOKBgrj0/sz9g/NIw1m/dixLYzTT0GCrlBL2URAzfVnVf6wc3RY2Xc3r8HTw4OV/XSgcmBwjQDKTfopaxhqoTqzi394LdPzCMvvuUatC0cKKRWw8BNdde3sAtrl89HV0cJAqCro4TPXtTteX6tVRk4UEithqkSagi39MOWXW9X5LCD4EAhtSL2uCkz1iw7pyqFUssVFzD/TK2HPW7KDLcKjveOv+/bC9/+yuF6NY8oMxi4KVOcKRTnZB0nDkxSK2LgpkyzgvhXNu7CuMsopX1gMi/T6IniqpnjFpHZIrJdRF4Wkb0i8uV6NIzI0rewC/9rxUd8ZzBaPfNaMy+JmkGQHvf7AL6iqv9PRH4LwKCIPKeqL6XWqt0bgefvBkYPAu2zgCV3AAtWpPZ2lB1eveZaMxi9ptFbMy+JmknNwK2qbwN42/z8XRF5GUAXgHQC9+6NwNP/EyibucvRA8bXAIN3k6u19rXfDEavXPfRY2X0Dw0zeFNTCZXjFpEzASwE8KLLaysBrASA7m7vyRQ1PX/3ZNC2lMeM4wzcmRI3p+y8/r3j73uufV3rvn7T6NOeEk9Ub4EDt4icBuBJADeq6q+cr6vqBgAbAGOz4MgtGj0Y7riX3RuBZ28Bxo4YX5dmAp+4l8E/IXF3hnG73suhkbGavyRWLZ2LGx/f6Xk9UTMJNAFHRIowgvYjqrop1Ra1zwp33M3ujUD/lyaDNmB8vmklsKYDWNNufNxzhnEuhRZ3Zxi36720l4o1Bx77Fnaho1R0vZ5T4qnZBKkqEQB/B+BlVb0/9RYtuQMoOv6hFUvG8aCevxs46TZpQ1GxeGj5PeCp6xm8IwizM
4xzDe3b+/cEXh2wVCxABIF+SbjNvOSUeGpGQXrcvQD+O4DFIrLT/Phkai1asAK47OtA+2wAYvx52dfDpTjCpFV03Aj0Trs3AuvPNXro689lcHfw6sU6j7uV6X1nx1ue950xrVix+NTa5fMxcsx95qTzl4Tb4lVrl89nfpuaTpCqkn+F/6qbyVuwIl4uun2WUY0SlDPQ794IfO8GYPyE+foBYNMXjI/22U1dnhh0wHHV0rlVMxrderdhUiKlYgF3XnZO1fvdt3Vf4O3DuHY2tYLmXGRqyR1Am3u+05Uzf/7sLZNB28kqT9xyM3DvnMl8+b1zct8rDzOJJWjvNszAoFfvmNuHEVVqzinvVm/YXlVSmOoejKVQnT+3D2q6KY8BA39Xfc1T15vveTSXE4f8BhzdAmqQ3m3HtCKOeqQ67Lo6Sp734vZhRJWaM3AD7umWLTcDgw8CetL4ujgduOxryQVXHZ8M+qMHjCqWnKRX+oeGPQcM7b3msLXbtTZBAIL1npkCIZrUvIHbzaX3Gx+1lGbW7nUHYkatjM/+tFIkXqxc8u39e/DIjrcm6nKsVMrAm0ew/ZXDrsF8tMbGCAURnN/djvu27sNNj+9kb5oogNYK3EF94l4j7aFeg2qC6j3Ja3Cb/Wlfk6U0wzjWgDSL3wBiqVjAonmdWHj3D1xTHmPlcddgDhi95PZS0Xc97XFVvPDa5C/JeuzcTpR3zTk4GdeCFcCnvmn0vCeYhTXts4Gezxs587Ds1SvWmiyjBwCo0cMfO2J8bh8ArUNJot8A4hUXdOHJwWHfPLXzV5i9xloi1COFmchD1IrY4/ZSqySx+6LKwc/idGPw03Xij8leveK2JotdeQwYeAD1SLd4rfPR1VHC9lcOBy7ns7N+GXjVYAe9noiqsccd1YIVwC1vAGtGjY/bDgF9f2tOHAKqSt+dsz8DTRJy9GWtdAuQ6AQhv3K7qAHUyotHnW7OaepE3hi4k7RgBXDTT41AvnyD/+zPMGuv2I0erE6zWL3x3RuNjwj15adMmfxRmDGtOFFT7RdABUDvh2f61li7/VKohTXaRP6YKklLrVTLkjsq1x2v4jEA2j7Le+nbZ28Bjr9bma4ZO2LMArXa5OC2p+NvyicnPnebIQkAHaUi1iwzZjn6lQi61WAvmtdZUYXi/JpVJUT+RIMU2obU09OjAwMDid+36fhVlZx9MbDru5UBulgyeu6bViJ0VUv7bON/A473Pjl6EIdOfhB//f4KbD750YmXuzpKeGH1YgDcy5GoHkRkUFV7Ap3LwJ1hXlu4rT833FosABQCWTMyeV9Hb/+kTmblT0LQBoV0ZH/iEFGzYOBudi6BF8USjmMqTimPul7yc3TiQ2v2G1+ECPzjUsS7eio+oO/iHenEgfNX4b8s+2LcvwEROYQJ3ByczCOPpW/X6rU4rtUDgSd0Ctae+JPJAyGWvS1oGR14F20CfAiH0TP4l9AmWliLKI84OJlXLoOfD3/3GRxp+yLunPL3mCm/BgAcxWlYU/5TDH7gjyZPDLvsrU3FhJqxI5PL3UrBmGmag3VZiPKOgbuJnNFRwuaRj2LziY9WHBcA6+3ldTUrWiKwlgfI+LosRM2AqZIm4lYzLQA+e1F3ZRVIRaolBfaJQgDw8LLJuvI17cbXRBQZe9xNJNS61fZUy+6NldP3pQ3Qkzhe7EDxxAjaoux/ZOXRH14GvPHPla+98c9GALcrzTQW92IvnagmVpWQr9ce/CLmvPlY+P+aWXXjzgAdihgLegVZipco58JUlbDHTb4+fO23gN1LzHryAxO9cV/OdVkiU2OnoV/uBw4OAOX3jMPSBlxwLQM6tSz2uCm6iQlCB7yrSmL1uGsoTAUu/wbTK9QU2OPOidxPJa+1HgsAzPlYdY47KeMnKtdhcebqmTenJsXA3SDOxZ3C7vwSN+jX7ZfGNZvdByiTMn5isoLlezdUbgg9dgTo/5LxOYM3NZGaqRIReQDApQDeUdVzg9yUq
ZLaetdt89y8wFrcyYvbin6lYmFiKdZa4l4f25abKzeJiE38JxU5F9giyqCkp7w/BOCSWC2iKl4bFATZuMBtj8gw233FvT62S+8H1oxMbkKx/NseG1CIkWqRGut5t8/yn8YfYoo/UR7UTJWo6r+IyJnpN6W1eG0XFmTnlzhBP4nrE1crV757Y3UaxFKYagyGWoOkbqJuWrHlZmDwIWPQVQrABZ9jJQtlQmIzJ0VkpYgMiMjA4cOHk7pt0/LbLqwWr+AedLuvuNfX3YIVwF8dNnrm9g2cSzMnq0qW3OG+gXNbMVpp4pabjVJEayq/jhtf391p7jCU7gbORH4ClQOaPe4tzHEnK+oAYe5z3GlJsqrkrpmTQdtPYSow9bTJDTC4wBZFlPh63Azc2ZObqpK8ilN/XpoJnPMp4NUfVG+CQeSBgZsorqA97ii49C25SLSqREQeBfAjAHNF5KCIXBe3gUSZd8Hn0rv36AFjDfM17cYviDXtzJdTKEGqSj5Tj4YQZYpVPTL4YO21WeLgOuYUAdfjJvJy6f3AnUdtdeZi5K9LMyc/bysm937WOua7Nxo9cFaukAcuMkUUh7OSJQnFUtVG0Ljs65PrsdRa2Ityibu8NwgrNVrYREA1q0hmnhVtfRYrGDtZwTnIlnNcXCuXGLgboGlroym63RuBLTcCJ94Ldr6zp12hxnosrpeYa6ezR54LDNwNEGfRKGoRbmkOZ7rDa+p++2xzzZWo/16l+tridOCyrzGgZwTX426AzK3/QdkTZP1yoDodYu0o5LceS00uAb/8nlGWuOkL7JXnDAN3QuIsGkU0wQqc9ny5PaAGyXFHMXrAWLvcntphrjyzmCpJCHPcVBdu6ZZ6YI88dUyVNIAVnFlVQqlyS7ekUZLo5Jwg5KyiYVCvK/a4U8LSQGoI1x65y8BkVF5liYWpxmSkMtMsUbGqJGFhg3CUtInbewDswVMC7ME8tpBliQzggTFwJyhKEA5bGuj2HsU2AQQoj09+f5gzp0TZ0x2lGcDxUeBkjZx5lLLEYgn4yH8D9j5Vnc7hzkITkt5zsqVF2Z8xbGmg23uUT2pF0A7yvkShLFhhbKK8ZgS45Q2g75uVOww5WWWJYbeCK48Zuwe55eCtnYW23Bzuni2Og5M1RKnPDlsaGKbWm3XhlBrnwKffAGTSZYmDD032urnXZ00M3DVEqc9etXSua3rFaz9Jr/fwOpeoLrwmDDlrzUszgLERADGWv7XKGq29Pu3Hra8ZvCcwVVJDlE19+xZ2Ye3y+ejqKEFg5Lb9ctNu71FsExQLEup9ierGmWZZ/i2fNIt4HLefYv78Dz7k/rrX8RbFHrcHe5VHx7QiTpnShtGxcuDqjr6FXYEHEb1qwN2OcWCSMsneO3emWM6+GBj6B2D8hPf11o5DXhOK/CYaJblJdE6wqsQFZ0ESJcxrkpAzh+2116cUgDtdBjd3bzSm6p8sVx4vTAUu/0blL5OMB3eWA8bElf6IGsSZ47b0XOee415/rndNeftsI52zeyPwvRtcevwCiBhL32ZgEJRT3mPiSn9EDTKx1+dDwapKRg9638t67fm7PdI0ClgdV2sQ9Jf7gSOvZ34qPwO3C670R9RAl94fvOfrN4vTqjf3C+5O9l2LRg+Yy96uBIrTgPKxzARzVpW4cKvyEACL5nVOfN0/NIzeddswZ/Uz6F23Df1Dw3VuJRFhyR3uGzYXphqvAeEnDFVRcw0WnVxsa8vNDd3QOVDgFpFLRGSfiOwXkdVpN6rR+hZ24fzu9opjCuDJwWH0Dw1PDF4Oj4xBAQyPjOGmx3fi9v49DWkvUctasALo+9vKUsTSzMqBySV3GIE8KeUxYOABs6evkz3zh5cl9x411BycFJECgJ8B+CMABwH8BMBnVPUlr2vyPjjZPzSMmx7f6boaQ5eZLnFLpQiA9Veex8oToqxxVpW0TQVO+pQnRuU1iBpA0oOTFwLYr6qvmzd/DMDlADwDd1ZEXVr1vq37P
JfQ8RugVPNaBm6ijHGbBeqcWn/mR4GDP443ld8+dT9FQQJ3FwB79v8ggN9zniQiKwGsBIDu7u5EGheHsxZ7eGQMt24yUhm1AmutdUis+4W9logyxG0QdPdG4OkbJ9cVd+WzvnmddiQKkuN2m69a1WpV3aCqPara09nZ6XJJfUVZ1c/iVT0iMAYuVy2d6zmJl5UnRDm2YAVw2yEj5WFNw5c2oDgdxlrks4Gez3tfLwXv1xIUpMd9EMBs29ezABxKpznJiVOL7bZIlAD47EXdE731gTeP4JEdb1X8BuNaIkRNolZJ4i/3V5YOWqyp+ykL0uP+CYCzRWSOiEwFcBWAzek2Kz6vnm+QHrHbIlHrrzwPX+2bP3HOV/vmY/2V5wVeSIqImsg1mx298kKsgcmwAk15F5FPAvgagAKAB1T1Hr/zs1BVwvVGiChPEp/yrqrfB/D9WK2qM+66TkTNqqmnvLstrcrd14ko75o6cDt99ts/wguvTS4NGaZEkIgoK1pmrZLb+/dUBG0LN+AlorxpmcD96IseK4iBk2aIKF9aJnCP+1TPcNIMEeVJJnPcaQwgFkQ8gzcnzRBRnmSux+22ZOqtm/bEXu/6M7832/V474dncmCSiHIlc4E7zhojfr7aNx9XX9SNghirjBREcPVF3XjkC78f675ERPWWuVRJmvs9frVvfsW0dSKiPMpc4E5qv0dOtCGiZpW5VInbfo9hV91LK09ORJQFmQvcbivzhV0YKq08ORFRFmQuVQK4rzESRpp5ciKiRstcjzsJcdbiJiLKuqYM3EnkyYmIsiqTqZK4uBY3ETWzpgzcQPw8ORFRVjVlqoSIqJkxcBMR5QwDNxFRzjBwExHlDAM3EVHOMHATEeWMqM+WXpFvKnIYwJsxbvHbAH6RUHOSxHaFk8V2ZbFNANsVVhbbFbdNv6OqnUFOTCVwxyUiA6ra0+h2OLFd4WSxXVlsE8B2hZXFdtWzTUyVEBHlDAM3EVHOZDVwb2h0AzywXeFksV1ZbBPAdoWVxXbVrU2ZzHETEZG3rPa4iYjIAwM3EVHONCxwi8ifiMheETkpIp4lNCJyiYjsE5H9IrLadnyOiLwoIq+KyOMiMjWhds0UkefM+z4nIjNczlkkIjttH78RkT7ztYdE5A3ba+fVq13meeO2995sO5748wr4rM4TkR+Z3+vdInKl7bVEn5XXz4rt9VPMv/t+81mcaXvtVvP4PhFZGqcdEdp1s4i8ZD6f50Xkd2yvuX4/69Cmz4nIYdt7/w/ba9eY3/NXReSapNoUsF3rbW36mYiM2F5L61k9ICLviMhPPV4XEfm62ebdInK+7bV0npWqNuQDwH8GMBfADwH0eJxTAPAagLMATAWwC8Dvmq9tBHCV+fk3AfxZQu36awCrzc9XA7i3xvkzARwBMM38+iEAn07heQVqF4BfexxP/HkFaROA/wTgbPPzMwC8DaAj6Wfl97NiO+dLAL5pfn4VgMfNz3/XPP8UAHPM+xTq2K5Ftp+fP7Pa5ff9rEObPgfgf3v8vL9u/jnD/HxGvdrlOP/PATyQ5rMy7/sHAM4H8FOP1z8J4FkAAuAiAC+m/awa1uNW1ZdVtda26xcC2K+qr6vqCQCPAbhcRATAYgBPmOc9DKAvoaZdbt4v6H0/DeBZVT2W0Pt7CduuCSk+r5ptUtWfqeqr5ueHALwDINDssJBcf1Z82vsEgCXms7kcwGOqelw/8qSCAAAD4ElEQVRV3wCw37xfXdqlqtttPz87AMxK6L0jt8nHUgDPqeoRVT0K4DkAlzSoXZ8B8GhC7+1JVf8FRufMy+UA/l4NOwB0iMjpSPFZZT3H3QXggO3rg+axDwIYUdX3HceT8B9V9W0AMP/8DzXOvwrVPzz3mP9lWi8ip9S5XaeKyICI7LDSN0jveYV6ViJyIYye1Gu2w0k9K6+fFddzzGcxCuPZBLk2zXbZXQej92Zx+37Wq01XmN+bJ0Rkdshr02wXzHTSHADbbIfTe
FZBeLU7tWeV6tZlIvJPAD7k8tJtqvq9ILdwOaY+x2O3K+g9zPucDmA+gK22w7cC+DmMALUBwC0A7q5ju7pV9ZCInAVgm4jsAfArl/MCPa+En9U/ALhGVU+ahyM/K7e3cDnm/Dum8vNUQ+B7i8jVAHoAfMx2uOr7qaqvuV2fcJueBvCoqh4Xketh/E9lccBr02yX5SoAT6jquO1YGs8qiLr/XKUauFX14zFvcRDAbNvXswAcgrGQS4eITDF7Ttbx2O0SkX8XkdNV9W0z2Lzjc6sVAJ5S1bLt3m+bnx4XkQcB/EU922WmI6Cqr4vIDwEsBPAkIj6vJNokIh8A8AyA283/Slr3jvysXHj9rLidc1BEpgBoh/Ff4CDXptkuiMjHYfwy/JiqHreOe3w/4wajmm1S1V/avvw2gHtt1/6h49ofxmxP4HbZXAXgBvuBlJ5VEF7tTu1ZZT1V8hMAZ4tRETEVxjdrsxqZ/+0w8ssAcA2AID34IDab9wty36ocmxnArLxyHwDXkeg02iUiM6x0g4j8NoBeAC+l+LyCtGkqgKdg5AD/0fFaks/K9WfFp72fBrDNfDabAVwlRtXJHABnA/hxjLaEapeILATwLQDLVPUd23HX72ed2nS67ctlAF42P98K4GKzbTMAXIzK/3Gm2i6zbXNhDPb9yHYsrWcVxGYAf2pWl1wEYNTslKT3rNIYhQ3yAeBTMH4jHQfw7wC2msfPAPB923mfBPAzGL85b7MdPwvGP679AP4RwCkJteuDAJ4H8Kr550zzeA+A/2M770wAwwDaHNdvA7AHRhD6DoDT6tUuAP/VfO9d5p/Xpfm8ArbpagBlADttH+el8azcflZgpF6WmZ+fav7d95vP4izbtbeZ1+0D8ImEf9ZrteufzH8D1vPZXOv7WYc2rQWw13zv7QDm2a79vPkM9wO4tp7Pyvx6DYB1juvSfFaPwqiGKsOIWdcBuB7A9ebrAuAbZpv3wFYll9az4pR3IqKcyXqqhIiIHBi4iYhyhoGbiChnGLiJiHKGgZuIKGcYuImIcoaBm4goZ/4/Hxs2UrsTgccAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.scatter(x[:,0],y)\n", "plt.scatter(x[:,0],y_hat);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "So far we have specified the *model* (linear regression) and the *evaluation criteria* (or *loss function*). Now we need to handle *optimization*; that is, how do we find the best values for `a`? How do we find the best *fitting* linear regression.
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "到现在我们已经指定了 *模型* 的类型(线性回归),以及 *评估标准* (或者说 *损失函数* ),接下来我们需要处理 *优化* 过程;即,我们如何才能找到最优的`a`呢?我们如何才能找到 *拟合* 最好的线性回归模型呢?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Gradient Descent 梯度下降" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We would like to find the values of `a` that minimize `mse_loss`.
\n", "\n", "**Gradient descent** is an algorithm for minimizing functions. Given a function defined by a set of parameters, gradient descent starts from an initial set of parameter values and iteratively moves toward parameter values that minimize the function. Each iteration takes a step in the direction of the negative gradient, the direction in which the function decreases fastest locally.
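
For this particular loss the gradient can be written by hand. With `y_hat[i] = a[0]*x[i] + a[1]`, the partial derivatives of the MSE are `2/n * sum((y_hat[i] - y[i]) * x[i])` with respect to `a[0]` and `2/n * sum(y_hat[i] - y[i])` with respect to `a[1]`. The plain-Python sketch below (function names are hypothetical; the starting point `(-1, 1)`, learning rate 0.1 and 100 steps mirror the PyTorch cells) repeatedly steps against that gradient:

```python
import random


def mse_grad(a, xs, ys):
    # Hand-derived gradient of mean((a[0]*x + a[1] - y)**2) w.r.t. (a[0], a[1])
    n = len(xs)
    residuals = [a[0] * x + a[1] - y for x, y in zip(xs, ys)]
    g0 = 2.0 / n * sum(r * x for r, x in zip(residuals, xs))
    g1 = 2.0 / n * sum(residuals)
    return g0, g1


def gradient_descent(xs, ys, a=(-1.0, 1.0), lr=0.1, steps=100):
    a0, a1 = a
    for _ in range(steps):
        g0, g1 = mse_grad((a0, a1), xs, ys)
        # Move each parameter a small step in the negative gradient direction
        a0 -= lr * g0
        a1 -= lr * g1
    return a0, a1


random.seed(0)
xs = [random.uniform(-1.0, 1.0) for _ in range(100)]
ys = [3.0 * x + 2.0 + random.random() for x in xs]
print(gradient_descent(xs, ys))
```

With this data the result should land near `(3, 2.5)`, the least-squares fit (the noise has mean 0.5). In PyTorch, `loss.backward()` computes the same gradient automatically, which is why the notebook version never writes these formulas out.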
\n", "\n", "Here is gradient descent implemented in [PyTorch](http://pytorch.org/).
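
The PyTorch `update` function that follows evaluates the loss on all `n` points at every step, so strictly speaking it performs full-batch gradient descent. *Stochastic* gradient descent instead estimates the gradient from a small random subset (a mini-batch) of the points at each step, which keeps each step cheap on large datasets. A plain-Python sketch, where the function name `sgd`, the batch size of 10 and the 50 epochs are illustrative assumptions rather than fastai or PyTorch defaults:

```python
import random


def sgd(xs, ys, a=(-1.0, 1.0), lr=0.1, epochs=50, batch_size=10, seed=0):
    rng = random.Random(seed)
    a0, a1 = a
    order = list(range(len(xs)))
    for _ in range(epochs):
        rng.shuffle(order)  # new random order of the points each epoch
        for start in range(0, len(order), batch_size):
            batch = order[start:start + batch_size]
            m = len(batch)
            # MSE gradient estimated from the mini-batch only
            res = [a0 * xs[i] + a1 - ys[i] for i in batch]
            g0 = 2.0 / m * sum(r * xs[i] for r, i in zip(res, batch))
            g1 = 2.0 / m * sum(res)
            a0 -= lr * g0
            a1 -= lr * g1
    return a0, a1


random.seed(0)
xs = [random.uniform(-1.0, 1.0) for _ in range(100)]
ys = [3.0 * x + 2.0 + random.random() for x in xs]
print(sgd(xs, ys))
```

Each epoch makes 10 noisy updates instead of one exact one; the parameters should still settle near the least-squares fit, though they keep jittering slightly because every mini-batch gradient is only an estimate.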
" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Parameter containing:\n", "tensor([-1., 1.], requires_grad=True)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a = nn.Parameter(a); a" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def update():\n", "    y_hat = x@a  # predictions for the current a\n", "    loss = mse(y_hat, y)\n", "    if t % 10 == 0: print(loss)  # t is the loop counter from the next cell\n", "    loss.backward()  # fills a.grad with d(loss)/da\n", "    with torch.no_grad():  # update a without recording the step for autograd\n", "        a.sub_(lr * a.grad)  # step against the gradient\n", "        a.grad.zero_()  # reset, since backward() accumulates gradients" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor(7.9356, grad_fn=)\n", "tensor(1.4609, grad_fn=)\n", "tensor(0.4824, grad_fn=)\n", "tensor(0.1995, grad_fn=)\n", "tensor(0.1147, grad_fn=)\n", "tensor(0.0893, grad_fn=)\n", "tensor(0.0816, grad_fn=)\n", "tensor(0.0793, grad_fn=)\n", "tensor(0.0786, grad_fn=)\n", "tensor(0.0784, grad_fn=)\n" ] } ], "source": [ "lr = 1e-1\n", "for t in range(100): update()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAW4AAAD8CAYAAABXe05zAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzt3X98XGWZ9/HPlcmkTKkmLS1i0xaKD09ZsdVCRPZpd10oUFEpoWpFQVDYLSus/KhbKCvbX6vbYl1c2EdXakVgRSUi1FLEWqnilmf5kVIIVOzKD7VNhbZAykJDm0zu548zM53MnDNzJvMjM8n3/Xr1leTMmZkrJ+nVu9e57vs25xwiIlI76gY7ABERKYwSt4hIjVHiFhGpMUrcIiI1RolbRKTGKHGLiNSYUInbzJrM7G4z+62ZPWtmf17uwERExF99yPNuAn7mnPu4mTUAI8sYk4iI5GD5JuCY2duBp4BjnWbriIgMujAj7mOBPcB3zey9wBbgSufcm+knmdl8YD7A4YcfftLxxx9f6lhFRIasLVu27HXOjQtzbpgRdwvwCDDDOfeomd0EvO6c+8eg57S0tLj29vZCYhYRGdbMbItzriXMuWFuTu4EdjrnHk18fTdw4kCDExGR4uRN3M65l4AdZjYlcWgW8JuyRiUiIoHCdpV8Abgz0VHyAvC58oUkIiK5hErczrkngVC1FxERKS/NnBQRqTFK3CIiNSZsjVtEZNhYu7WTVRu2s6urm/FNMRbOnkLr9ObBDitFiVtEhqWg5Lx2ayfX3fM03T1xADq7urnunqcBqiZ5K3GLyLCTKzmv2rA9dTypuyfOqg3blbhFRAZLUHJedt82Xtvf4/ucXV3dlQgtFN2cFJFhJygJByVtgPFNsXKFUzAlbhEZdgaShDu7upmxchNrt3Z6B9YvgGVjYGmj93H9ghJHGUyJW0SGnYWzpxCLRgp+XrIW/vK/zYb274BLlFtc3Pu6QslbiVtEhp3W6c2smDuV5qYYBjQ3xWiKRUM994z4Qxz5yiP+D265rWQx5qKbkyIyLLVOb+7XJZLZaRLkmvo2LOhBl/u5paLELSLCoR7tpeu20dV96CblnLrNXFPfxnjbyy43lvG2N/hFrPDyy0AocYtITQs7yzHMeclRePLck17fyMroGkbaQQAm2F76cu09c9JnS/idBVPiFpGaFXaWY6GzIVs7/4XWA7dBQ3bpo84ADMjI4JM/CB+9sejvKQzdnBSRmpVrluNAzqOjDb4yvn/HiC8HjRMB8z7O/TZctK6I76QwGnGLSM0KmkiTeTzUeR1tcN8V0BNihmTjRLj6mdBxlppG3CJSs4Im0mQeD3Xeg8vDJe1oDGYtDh1jOShxi0jN8ptIE4tGWDh7St7zPt7w/9hol8HSJvj6e2Dfjvxv2DgRzr4Zps1j7dZOZqzcxORF9/efUVkBKpWISM1K3lgM0y2Sft5Fox7jH+O3EEm2/YVJ2i2XpG4+DvbSr+Zcrt6WgWlpaXHt7e0lf10RkZK4YTJ0vxruXKuDkz7Xr2NkxspNdPrUzZubYjy86LQBhWRmW5xzofb21YhbRIa+jjZ44NpwybpxIuzbCY0TvFr2tHlZp4S9KVouStwiMrR1tMFPLof4wXDnh+gWaRoZ9V0CtmlkuPVOihUqcZvZ74H/AeJAb9jhvIjIoHtwefikHRsT6rSgCnMZKs++Chlxn+qcyzFJX0SkeCXfqHffznDn1UXhrBtCnZq+lkm/two4XmpqBxSRqpHs1ujs6sZxqFujqFa7xgnBj1mE1OzH1m/61rP9YgxaHbBSu+SETdwO+LmZbTGz+eUMSESGr9BT0wsxazFEGrKP10Xh3G/B0i6vrh0iaSdj9KuIGGT1j5dL2FLJDOfcLjM7EthoZr91zv06/YREQp8PMGnSpBKHKSLDwYC6NTI7RmJjvJJHMhFPm8fjv3+N4574Jxrd/4DBwWgjI87+WuhkHSY
WR2V6uCFk4nbO7Up83G1m9wInA7/OOGc1sBq8Pu4Sxykiw8D4pphvf3RgCaKjDdZeBn1pteXuV70uEkjNcLzu8aPp7rkldUosHmFFfCqtJYyxuYKbCectlZjZ4Wb2tuTnwJnA4K2uIiJDVtgp7CkPLu+ftJPiB73HKH35peAYyyDMiPsdwL1mljz/+865n5U1KhEZlsJOYU/J1TGSeCyotJHctb3Q7pWCYyyDvInbOfcC8N4KxCIikrUXZE6NE4LXGUl0kwSVNgxSxwtda6SgGMtA7YAiUrtmLSZu2ePPg66ex9/1BcC/tOGzf03x3SsVpMQtIrVr2jz+KfJ3vNI3Cue8mYuvulH8fc98rvrNcYA3Ol4xdyrNTTEM7yZiUPdEpdYaKZbWKhGRmnb7GydzGydnHbe0JJxZ2gha3a9SE2iKpRG3iFSX9Qtg2RhY2uh9XL8g5+lhd8FJVw2dIcVQ4haR6rF+Qf+Nel3c+zpH8h5IEvYrn6yYO3VQbzgWQhspiEjRSrYw1LIx/rurWwSWBK+lXfKFqQaBNlIQkYop6TZefkk71/GEwW7PqzQlbhHxFXYUm2tmYsHJ1CLBI25JUY1bRLIUsrxqwQtD5bj5+PykT2RtRuCcd1wOUeIWkSyFrO8Ruqujow2+Mj7nzccLX/4kd8RPp9fV4Rz0ujruiJ/OhS9/svhvaghRqUREgP6lkUImqCycPaVfjRt8ujrWL4D2W8mer5iw5Tb46I3s6upmCRezpPfifg9bjUyMqRQlbhHJusEYxG90nXfRpY623EkbUiPwgpd1HaaUuEXEtzSSKVdvdM6ujgeXkzNpA31WRx0hR++ixC0iudfoMCisN7qjzUvW+3bmXr0vwTm4187kY1THkqm1QIlbRAJLFKNHRtm6+MxwL5K5hRgkkrbfWnxewu7D+F58FksPXMjHEseHW0/2QKirRERYOHsK0Uj23uVvvNUbbof19Qvgnvn9k3aKg4x90fsc3BE/nXcduJMlvRerhl0gJW4RoXV6M4c3ZP8HvKfP5V+jOszNRxw0TsRhdLqxXNVzWapzRDXswqlUIiIA7Ov22buREGtUh7j5SONEuPoZDHh8aydbNmzHVMMeMCVuEQEKa8V7fN0tTHxiFUe6PZhlFkIyRGMwa3HqS9Wwi6dSiYgA4ZdHfXzdLbxny/UcxR7q8iXt2Bg4+2aYNo+1WzuZsXITkxfdz4yVm8LVzsWXEreIAOHXqJ74xCpidjDPqxm0XALXvphK2plrn1x115NMX/5zJfABUKlERFLClDGOdHt8h9nOgZl5vduzFsO0eanHgib4vLa/Z+BLwA5jStwiw0hBGw5kTqRJJOPdNo6j2JN1+ss2jqOWPuf7UrlucA54CdhhLHTiNrMI0A50Ouc+Wr6QRKSUksm6s6u731SYnBsedLTBfVdATyLh7tvhfQ3sOHEhjVuu71cu6XYN/Proz3PTyk2+/ygE3fhMqpXd1atFISPuK4FngbeXKRYRKVLmiPrU48fx4y2dqTJFZtNev9Gu38zHdD3d8OBy3n/1MzwOia6Svey2sfz66M+z5MUT6E4k+mQNe9l921hy9gmcevw47nzkj4FNg5qAU5hQe06a2QTgduArwIJ8I27tOSlSeX4r/PlPNu/PgBc//Sb85HKIh7jpuLQr6+iMlZsCR9TRiIHzJvP4iUUjNbVRb7mUY8/JfwWuAd6W403nA/MBJk2aFPJlRaRU/G4AhtkKfHxTDB68NkTSxqt1+8hV6uiJB0fRrAk4A5I3cZvZR4HdzrktZvZXQec551YDq8EbcZcsQhEJpdA68Zy6zVwbbWP8W6/AWyH+ymZMpEmXr4btx4CHF51W0HPEE6aPewYwx8x+D/wQOM3MvlfWqESGsYFOVAmqE2d27hle0r6h4Ts0214szLi8cWJqIo0fv8k7A41X8subuJ1z1znnJjjnjgHOAzY55y4oe2Qiw1Ahm/RmCpr5eP4pk1KTaj476jG2jf4iNzd
8kxgH8gdUF4W534arnwlM2nBo8k5TLJr1WDRiROv6//OhhaWKoz5ukSqSa5PefHXgvJsQ3D4HXnwIekMGExsDZ92QM2Fnvn/r9GbfXvGccUnBQnWVFEpdJSIDM3nR/b6FCwNeXPmRgb1oRxvcdxWu583c64pAahU/qbxCukq0VolIFQmq+w64HpyaSBMiaee4+SjVRYlbpIqEXaEvr442+Pp74J6/OTT7MZc8Nx+luqjGLVJFit4sNzH70XW/mn+ETWJhqI99Wwm7xihxi1SZAW00kDFdPWzSfsym8QEl7ZqjxC1S69YvCLHn4yHJ3dW/33c6N0bn07XofnV61BglbpFaFmqjXo9z3tKrK3o+wUMjTuXN3l569nv7TOZcKVCqjm5OitSi9JuPIZL2ftfAsuhVHLX0OW765xUcPqI+aw2RZL+4VD+NuEXKpKBNC8LKt/RqBufgNUbxz+6zzPzI/NTxoHVNtC52bVDiFimDzCVWS1KKyNzcIAfnoMvextKez9D+9jOy/tEoZEd3qT5K3CJFCBpVFzN1PdCDy0Ml7T4H/xE/ndWjLufhpf6r7y2cPSVr7W6tH1I7lLhFBijXqLospYh9O3M+7Bx0urF8tXce6/pmYjneq+h+cRlUStwiA5RrVF1UKSJgk14aJ3h7P/rY7xpY1PPXrOubGfq9BtQvLlVBXSUiA5RrVD3gqevJOva+HYA7tElvR5uXwKP9k7Fz8ErfqKykrbLH0KYRt8gABY2q68ybt7hi7tTCSxF+dezEJr3JVfteuucfONLtZZc7IlUWSTd6ZJQlZ5+Q9V5l6XKRQaHELTJAfjf4AOLOcd09T7Ni7tTCt+YKqmMnj0+bxyPxGb7vm/RWT1/WsbJ0ucigUalEZICSu75ELHtlkO6eOF9se8p3+7Hk1mRX/sN1vLT0f+GWNnmTaTraAjfjTT+efN/mgBq230SaXPV4qT1K3CJFaJ3eTF/AZiRx57K2H0uOfE96fSMroms4ij3eno/JWvZxZ2bVsf3WyW6d3szDi04LXEwqs/6uCTdDi0olIkUKs8N5+uj2jPhD3Bj9FvWWUdLo6Ybf/dxbF9uvq4TsOnVjLEpXd49vTGFi1ISb2qTELVKkoFp3ppbXN7Ikegejo2/gU13x7NvpJWmfpVb96tTJjXh7+g6N+v06SjThZmhR4hYpUuZkljoz4hnlk2X1t/KZ+l/kr00G1bjxr1P3xB2jR0YZ2VCfs1tEE26GFiVukRJIn8ySPjJeVn8rF0S8hF3sno9B9eiu/T1sXXxmQTFKbVPiFgkw0L7n1unNNO9Yz9QnrmOEiweXRdJZJO+ej6pTS1Le/7mZ2WFm9piZPWVm28xsWSUCExlMyVFzZ1d3VmdIXh1tvP/JL3EYIZN2NAbnfivvvo8l20hYal6YEfcB4DTn3BtmFgU2m9kDzrlHyhybyKAJ6ntedt82/1F4+voiVgcu943KlNgYOOuGUJv1qk4tSXkTt3POAW8kvowm/oTb3E6kRgW19722v4fXMrb7at6xnvc/veTQVPVQSdug5WL46I0FxaU6tUDICThmFjGzJ4HdwEbn3KM+58w3s3Yza9+zZ0+p4xSpKL/ZkH7OiD/E9CcWhVonG7xFod7kMJi7uuCkLZIUKnE75+LOufcBE4CTzew9Puesds61OOdaxo0bV+o4RSoqs53Pz5y6zayMrqGe7LVB/DgH/9l3Aie8dSszfjo2XL1cxEdBXSXOuS4z+xXwIeCZskQkUgWaAzo45tRt5pr6NsbbXvqoy579mCbuDo2MXmMUS3suTK3kp0WepBhhukrGmVlT4vMYcDrw23IHJjKY/Do4zq1/mJXRNUyo20udkTNpE41xvV3B5APfZ/KB73PigdVZy69qkScZqDAj7ncCt5tZBC/Rtznn1pc3LJHBld7B0fL6Rq5r+BHvYE/+STSQ6sn+QHwGd//oqX7T0TN1dnUzedH96hCRgoTpKukAplcgFpGq0hp5mFa7Fhp
eDf+kaCw1kaY1cWjVhu05F6FK7xMHlU4kPy3rKuKnow1+cjl0h0jaFgEMGidmzX5snd7MwtlTiEbyj9VVOpGwNOVdxM8D10L8YP7z0kbYQVZt2E5PPNzUB62PLWFoxC3iJ8xI22eE7aeQZKx1RyQMjbhFChVilJ0uaHEoo/8UZK07ImFpxC3iJzYm4AErKGlD8OJQ558yieamGIbXN75i7lTdmJRQNOIW8XPWDbD2MuhL2xasLgqt3ywoaYMWh5LSU+IW8ZNMzgF7PxZKi0NJKSlxiwQJ2PtRZLApcYvkMdCdcETKRYlbJAe/ndU1w1EGmxK3DIpaGcUG7YSzasP2qoxXhgclbqm4WhrFBk2e0QxHGUxK3FJxBe/nmCbnSL2jzZuqnpz1WMB+jkG0s7pUI03AkYoLGq2+tr+n367qV9/1JNevfTr1eM6d1/0Whep+1evF7mgbcKzaWV2qkUbcUnFBo9hMDrjzkT/ScrQ3i/GLbU9lbSmWqjePWO6/KFRfj9eLXUT/NWjyjFQXcyH21itUS0uLa29vL/nrytCQWePOZ/TIKG/19AWeb8CLh51P/5U/Ms5Y2jWgWEUqxcy2OOdawpyrEbdUnN8o9s0DvXR19/ie/9r+Q8fT93zc5cby1d55bHn7GTBiAuzb4f+GjRNK/j2IDCYlbhkUmVPA127t5Oq7ngwcMwMsq7+Vz0R+QV1iT4IJtpcbomt45t3HwDGLvRp3ZrmkLupNVRcZQnRzUqpC6/Rmzj9lUtaejoY3yv7tiAu5MC1pJ8XsIO9//t+8GvY53+i/ql9szIAWhRKpdhpxS9X4cutUWo4e06+EctLrG7kx+u/UW46x+L6d3ketLSLDhBK3VJXMEspLS+dTn7OAgmrYMuwocUt1Wb8AttwGLg4W4R3k7jzpA+oSNexamUYvUqy8NW4zm2hmvzSzZ81sm5ldWYnAZBi6fQ60f8dL2gAunlXzTuccvHj0eTBtXu7JOSJDTJibk73AF51zfwacAlxuZu8ub1gyrHS0wQ2T4cWHQj/FAbvHnsK7PncLkHsavchQkzdxO+f+5Jx7IvH5/wDPAvr/pxQvmbDv+Zv8u6pHDz/0udVhLZfwji9sSB3KNY1eo24ZagqqcZvZMcB04NFyBCO1paia8voFuPbv5CyFJPVZHXVf2pXznFzT6LUEqww1ofu4zWwU8GPgKufc6z6PzzezdjNr37NnTyljlCpUVE359jmhk7ZzcC9n5j0v16JPWoJVhppQidvMonhJ+07n3D1+5zjnVjvnWpxzLePGjStljFKFcm0wkNP6BfDiQ6GT9n/2ncDfd1/I2q2dzFi5icmL7mfGyk1Z/0C0Tm+mKRb1fR0twSpDTd5SiZkZ8B3gWefcjeUPSWpBIRsMPP/dSzn6D21EXB8YeZO2c/Aao1jacyHr+mbSFIuG2nhh6ZwTshav0hKsMhSFGXHPAD4DnGZmTyb+fLjMcUmVCxrFZh5//ruXcuzvf0g9fViepO0cvNI3iit7LuPEA6tZ1zeTWDSCGaFG963Tm1kxdyrNTTEMaG6KsWLuVNW3ZcjJO+J2zm0m/yBJhpmFs6eEGt0e84e7sBC/Pc7B990ZPHviErb8dg+WdsPz6rue9H2O3+g+c+alyFCkmZPST9hOkbAbDNQ5l/ef/WQt+/C5N/Fln/datWG7tg8TSaPELSmFbuIbanQbkLST+3fEqePO+GmsHnU5Dwe8VtjRvchwocQtgJe0c24NNsDyw5uMYBQHfI+/58B3AS8Jr8iRhLV9mEh/StySGmlnJu2kZC157dZOlq7bltqpZvTIKEvOPoHmHeuZ+MQqjnR72G3j2HHiQt4/51IAvnTwEv4lY1nWXmf8Q88lGIROwqpdixyixC2+PdnpmkZGmb785/22EAP4i7d+yV+u/RyjecO7AWlwFHto3HI9jwPvn3MpD404lQUHSGw39gq73BF8tXce9/XNTL1O+x9e1WhapABK3JJzZmE0YrzxVi8
9ff1H43PqNrMyuoaRlr2zeswOMvGJVTDnUsxgXd9M1h2cmXUeeHX07z3yx35f56qri4i2LhOCuzMiZhzeUN8vac+p28zmhiu4KfpN36SddKTbC0DXfv8NgHMJNQNTZBhT4hYWzp5CLBrpdywWjfAv897Lvu6eVLJ+ccSnuSn6TSbU7c3bm73bxgIDb9nT+iIiwZS4BYAR9Yd+FUaPjKZmHH4tdgf/mpasw0ym6XYN7DhxIeD/j0IY6tEWCaYa9zCX2bsN8FZPn/dJRxtz3c9CJWvwerP32dv43Un/mOoqyWzlaxoZ9a2Zp1OPtkhu5gJawIrR0tLi2tvbS/66Ek4h62TPWLnJd1Zic1OMh0dcAft2hHvTxokwa3GoXdYz4zv1+HH88rd71FUiw5qZbXHOtYQ5VyPuIeb6tU9z5yN/TO2Lnq9LI7OWPKdus9e6170X3grzjgZzV4dK2EnqyRYpjmrcQ8jarZ39knZSri6NZC15Tt1mtjTMT918rAtVHjFoubigpC0ixdOIewhZtWF7VtJOCurSWDh7Cm/eeyWfso2hkrUjsfxIAaURESktJe4hJFcLXVCXRmvkYVzdxrzrZDuMXe4I/pXzmHnuZSp1iAwiJe4hJGjDXMNnT8aONnhwOezbkXex9U43lpkHb059/V/afFdkUKnGPYT49UwbcP4pk/on2o42uC9cx8h+18BXe/uXQzQ5RmRwacQ9hIRe/vTB5dCTP/l28TYW93yGdX391xnR5BiRwaXEPcSEarXbtzPPq3jdIr9q/iIb73ka+rSBgUg1UeIejhonBJdJ0rpFWhOHtOSqSHVR4h6KUjced3pJOrNtb9Zir8adXi6JxuDsm7Pa+zRZRqT6KHEPokKmpoe2fgG03wrJju59O7wkDYeScvJjruQuIlVLiXuQFLoxr9/zs5J+5OH+STupp9tL0umJedo8JWqRGpW3HdDMbjWz3Wb2TCUCGi78tgsLu4FAMul3dnXjOJT09z+wmKyknZRxQ3Lt1k5mrNzE5EX3M2PlJtZu7RzgdyIilRamj/s24ENljmPYCeqFDtMjHZT0D+t+KfhJjRNSnwYlfiVvkdqQN3E7534NvFqBWIaVoF7oMD3Su7q6U7vSvDDi02xuuII5dZvZ1XdEwDPMq2EnFDPaF5HBV7KZk2Y238zazax9z549pXrZIStou7AwPdIXjXqMldE1qVX8JtTtZWV0DY9GWrzukH6yV/ArZrQvIoOvZInbObfaOdfinGsZN25cqV52yGqd3syKuVNpbopheBsXJLcLy+ea6F1ZG/WOtIOcddhTXktf40TAvI9zV8NHb+x3bjGjfREZfOoqGUQD7ZEeGVDLHtn9UqhukYWzp2RtV6YZkSK1Q4tMVbOONvj6e2Bpk/exo807nnajsZ+g4xmKGe2LyODLO+I2sx8AfwWMNbOdwBLn3HfKHdiwl1zBLzm7MX0iTdDMx7QbkPloRqRI7cqbuJ1zn6pEIJLBbwW/5ESaq585dI5mPooMO6pxV4vM9UWCFoFKTqTRzEeRYUuJuxr4lUUwfGdBhqxji8jQpcRdQgUvGpW2fVi25La8acm7wDq2iAxNStwlUvCiUZmjbF/O68VWHVtE0ihxl0iuaeS+iTvM9mGNEw/diBQRSVAfd4kUPI083/ZhKouISAAl7hIpeBp5rpuMjRN9d6MREQEl7pIpeNGoWYuzF4SKxmDut73yiJK2iARQ4i6RzGnknx31GFtGXUXrT07oP109adq87AWhNMoWkRDMuYAdU4rQ0tLi2tvbS/66NaOjjd6ffIH6+FupQ72Rw6g/59+UmEXEl5ltcc61hDlXXSUhhO7PTvRlu307si5sffwt9j+wmJEBidvvPYDSbyYsIjVPiTuP0P3ZaX3ZFvBaQVuL+b3Hwh89BQY9cZf7fUVk2FHiziNnf3bk4UPri1gduHjAq3h29R2BXy+J33v09GWXsHL2hYvIsKGbk3kE9WFf+sY34J75ienqLm/
S3u8aWNNwQUHvUey5IjI0KXHn4deHPaduMxfU/wLfRaAyOAc7+8ay2M3nfR+ZH/o9ColHRIYXJe48/Pqzr422hbpw3Yzgqp7L+OTIbzPz3MsCSxx+7xGtM6KR/tVybS8mIqAad6D0Lo/zDnuEL0S+z1FuL7ttHO9gb/ATLQKuDxonEJu1mJtCtP8lE7q6SkQkDPVx+0h2eZwRf4il0TsYzRtYv8FvwFrZmLerunq1RaRA6uMu0qoN27mF5fxFdFtGwk7yWSsbg5aLlbRFpOyUuH1c+sY3+ItIUNJO0lrZIjI4lLh9fLp+U+AkmhStlS0ig0RdJT4i9OU+QWtli8ggCjXiNrMPATcBEWCNc25lWaMaZGYR3wk1zsHBhkZGnP011sZnsGz5z3ltfw8ATbEoS+ecoK4PESm7vCNuM4sA3wDOAt4NfMrM3l3uwAbT85M+QWazjXPwn30ncFrkNtbGZ7Dw7qdSSRugq7uHhT96irVbOyscrYgMN2FG3CcDzznnXgAwsx8C5wC/KWdgpeC7ql/6+iIBNxUvfPmTzI/v5fzIJiL0EaeOO+OnsaT3Yqyrm1UbtqcWf0rX0+e0loiIlF2YxN0M7Ej7eifwgcyTzGw+MB9g0qRJJQmuGH4r7o29dx6u7plDNx737fBW9IN+yXtXVzdLuJglvRdnve74pljO9UK0loiIlFuYm5N+DRZZw03n3GrnXItzrmXcuHHFR1akzBX3ltXfygx7Jvub6en2RuBpgtYDMbzp6bnWC9FaIiJSbmES905gYtrXE4Bd5QmndHZ1dTOnbjObG67ghRGf5sLIL4L7sjN2XPdbO8SA80+ZROv0ZhbOnpK1jgh464toLRERKbcwpZLHgePMbDLQCZwHfLqsUZXA12J3cG7fz6jL25BN1o7rQWuHJI8nPy67b5u6SkSk4vImbudcr5n9HbABrx3wVufctrJHVoyONua6n+WZ+ZjGpye7dXpzziSc73ERkXIJ1cftnPsp8NMyx1I6Dy7PP/MxafIHNVVdRGrK0Jw5mVGzTucSf3pdHT+2D7F22r9XLCwRkVIYmmuVNE5IbCnWnwOuPHgZ6/pmpo7FtAGviNSYoTninrXYW08kjQPu6D2K2tUxAAAH40lEQVS9X9KGQxvwiojUiqE54k7WrNNmSF69dw5r4zN8T9ekGRGpJUMzcYOXvNNuOq5ddH/gqZo0IyK1pCoTt+8aI0XWoCNmxAO2adOkGRGpJVVX406uMdLZ1Y3DW2PkunueLnrVvU99YKLv8RnvGqMbkyJSU6oucWeuMQKluYH45dapXHDKJCKJWTkRMy44ZRJ3/s2fF/W6IiKVVnWlkuQaI9fUtzHe9rLLjeWrvfO4r2tm/ifn8eXWqXy5dWoJohQRGTxVl7gvGvUY1/SsYaQdBGCC7WVldA1jog3AR0K/Tjnq5CIi1aDqSiXXRO9KJe2kkXaQa6J3hX6NctXJRUSqQdUl7pHdLxV03E+56uQiItWg6hJ35hKreY/7CJpQo4k2IjIUVF/i9pmuTjTmu/RqkKAJNZpoIyJDQfUl7mnz4OyboXEiYN7Hs28uaOlVvx1sYtGIJtqIyJBQdV0lQNZ09ULl28FGRKSWVWfiLgHtUCMiQ1X1lUpERCQnJW4RkRqjxC0iUmOUuEVEaowSt4hIjVHiFhGpMUrcIiI1xlzAdl5FvajZHuAPRbzEWGBvicIpJcVVmGqMqxpjAsVVqGqMq9iYjnbOjQtzYlkSd7HMrN051zLYcWRSXIWpxriqMSZQXIWqxrgqGZNKJSIiNUaJW0SkxlRr4l492AEEUFyFqca4qjEmUFyFqsa4KhZTVda4RUQkWLWOuEVEJIASt4hIjRm0xG1mnzCzbWbWZ2aBLTRm9iEz225mz5nZorTjk83sUTP7nZndZWYNJYprjJltTLzuRjMb7XPOqWb2ZNqft8ysNfHYbWb2Ytpj76tUXInz4mnvvS7teMmvV8hr9T4z+6/Ez7rDzD6Z9lh
Jr1XQ70ra4yMS3/tziWtxTNpj1yWObzez2cXEMYC4FpjZbxLX50EzOzrtMd+fZwVi+qyZ7Ul7779Oe+yixM/8d2Z2UaliChnX19Ni+m8z60p7rFzX6lYz221mzwQ8bmZ2cyLmDjM7Me2x8lwr59yg/AH+DJgC/ApoCTgnAjwPHAs0AE8B70481gacl/j8W8DnSxTXV4FFic8XATfkOX8M8CowMvH1bcDHy3C9QsUFvBFwvOTXK0xMwP8Gjkt8Ph74E9BU6muV63cl7ZzLgG8lPj8PuCvx+bsT548AJideJ1LBuE5N+/35fDKuXD/PCsT0WeD/Bvy+v5D4ODrx+ehKxZVx/heAW8t5rRKv+5fAicAzAY9/GHgAMOAU4NFyX6tBG3E75551zm3Pc9rJwHPOuReccweBHwLnmJkBpwF3J867HWgtUWjnJF4v7Ot+HHjAObe/RO8fpNC4Usp4vfLG5Jz7b+fc7xKf7wJ2A6FmhxXI93clR7x3A7MS1+Yc4IfOuQPOuReB5xKvV5G4nHO/TPv9eQSYUKL3HnBMOcwGNjrnXnXOvQZsBD40SHF9CvhBid47kHPu13iDsyDnAHc4zyNAk5m9kzJeq2qvcTcDO9K+3pk4dgTQ5ZzrzTheCu9wzv0JIPHxyDznn0f2L89XEv9l+rqZjahwXIeZWbuZPZIs31C+61XQtTKzk/FGUs+nHS7VtQr6XfE9J3Et9uFdmzDPLWdc6S7BG70l+f08KxXTxxI/m7vNbGKBzy1nXCTKSZOBTWmHy3GtwgiKu2zXqqx7TprZL4CjfB76knPuJ2FewueYy3G86LjCvkbidd4JTAU2pB2+DngJL0GtBq4FllcwrknOuV1mdiywycyeBl73OS/U9SrxtfoP4CLnXF/i8ICvld9b+BzL/B7L8vuUR+jXNrMLgBbgg2mHs36ezrnn/Z5f4pjuA37gnDtgZn+L9z+V00I+t5xxJZ0H3O2ci6cdK8e1CqPiv1dlTdzOudOLfImdwMS0rycAu/AWcmkys/rEyCl5vOi4zOxlM3unc+5PiWSzO8dLzQPudc71pL32nxKfHjCz7wJ/X8m4EuUInHMvmNmvgOnAjxng9SpFTGb2duB+4PrEfyWTrz3ga+Uj6HfF75ydZlYPNOL9FzjMc8sZF2Z2Ot4/hh90zh1IHg/4eRabjPLG5Jx7Je3LbwM3pD33rzKe+6si4wkdV5rzgMvTD5TpWoURFHfZrlW1l0oeB44zryOiAe+Htc55lf9f4tWXAS4Cwozgw1iXeL0wr5tVY0sksGRduRXwvRNdjrjMbHSy3GBmY4EZwG/KeL3CxNQA3ItXA/xRxmOlvFa+vys54v04sClxbdYB55nXdTIZOA54rIhYCorLzKYDtwBznHO70477/jwrFNM7076cAzyb+HwDcGYittHAmfT/H2dZ40rENgXvZt9/pR0r17UKYx1wYaK75BRgX2JQUr5rVY67sGH+AOfi/Yt0AHgZ2JA4Ph74adp5Hwb+G+9fzi+lHT8W7y/Xc8CPgBEliusI4EHgd4mPYxLHW4A1aecdA3QCdRnP3wQ8jZeEvgeMqlRcwP9JvPdTiY+XlPN6hYzpAqAHeDLtz/vKca38flfwSi9zEp8flvjen0tci2PTnvulxPO2A2eV+Hc9X1y/SPwdSF6fdfl+nhWIaQWwLfHevwSOT3vuxYlr+BzwuUpeq8TXS4GVGc8r57X6AV43VA9ezroE+FvgbxOPG/CNRMxPk9YlV65rpSnvIiI1ptpLJSIikkGJW0Skxihxi4jUGCVuEZEao8QtIlJjlLhFRGqMEreISI35/yJ+LE/zYkLXAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.scatter(x[:,0],y)\n", "plt.scatter(x[:,0],x@a);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Animate it!" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from matplotlib import animation, rc\n", "rc('animation', html='jshtml')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "\n", "\n", "\n", "
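\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def mse(y_hat, y): return ((y_hat-y)**2).mean()\n", "\n", "a = nn.Parameter(tensor(-1.,1))\n", "lr = 1e-1\n", "\n", "# One gradient-descent step: compute the MSE loss on all the points, backpropagate, then move a against its gradient\n", "def update():\n", "    loss = mse(x@a, y)\n", "    loss.backward()\n", "    with torch.no_grad():\n", "        a.sub_(lr * a.grad)\n", "        a.grad.zero_()\n", "\n", "fig = plt.figure()\n", "plt.scatter(x[:,0], y, c='orange')\n", "# detach() gives matplotlib a tensor that does not require grad\n", "line, = plt.plot(x[:,0], (x@a).detach())\n", "plt.close()\n", "\n", "def animate(i):\n", "    update()\n", "    line.set_ydata((x@a).detach())\n", "    return line,\n", "\n", "animation.FuncAnimation(fig, animate, np.arange(0, 100), interval=20)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In practice, we don't compute the loss and its gradient on the whole dataset at once. Instead, we use *mini-batches*, so each update step looks at only a small random subset of the data.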
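The mini-batch idea can be sketched in plain PyTorch on the same kind of synthetic data as above (slope 3, intercept 2, plus `torch.rand` noise). This is a minimal illustration, not the lesson's own code. The batch size of 10, the learning rate of 0.1, the 50 epochs and the helper `mse` are assumptions chosen for the example. Each step computes the loss on only `bs` rows, and one epoch is one pass over the shuffled data.

```python
import torch

torch.manual_seed(0)

# Synthetic data shaped like the notebook's: column 0 holds x values, column 1 is all ones (for the intercept)
n = 100
x = torch.ones(n, 2)
x[:, 0].uniform_(-1., 1)
y = x @ torch.tensor([3., 2.]) + torch.rand(n)

def mse(y_hat, y):
    # Mean squared error (hypothetical helper for this sketch)
    return ((y_hat - y) ** 2).mean()

a = torch.tensor([-1., 1.], requires_grad=True)  # initial guess for (slope, intercept)
lr, bs, epochs = 0.1, 10, 50                      # learning rate, mini-batch size, passes over the data

for epoch in range(epochs):
    perm = torch.randperm(n)                # new random order each epoch
    for i in range(0, n, bs):
        idx = perm[i:i + bs]                # row indices of one mini-batch
        loss = mse(x[idx] @ a, y[idx])      # loss on this mini-batch only
        loss.backward()                     # gradient of the mini-batch loss w.r.t. a
        with torch.no_grad():
            a.sub_(lr * a.grad)             # one SGD step
            a.grad.zero_()                  # clear the gradient before the next batch

print(a.detach())
```

With these settings, `a` should settle near `[3, 2.5]`. The slope matches the true line, and the intercept comes out at about 2 + 0.5 because `torch.rand` adds noise with mean 0.5. Setting `bs = n` turns the loop back into plain (full-batch) gradient descent.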
\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Vocab" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Learning rate\n", "- Epoch\n", "- Mini-batch\n", "- SGD (stochastic gradient descent)\n", "- Model / Architecture\n", "- Parameters\n", "- Loss function\n", "\n", "For classification problems, we use *cross-entropy loss*, also known as *negative log-likelihood loss*. It penalizes confident predictions that are wrong, and it also penalizes correct predictions made with low confidence.
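The behaviour described above can be checked with a small PyTorch sketch. `F.cross_entropy` and `F.nll_loss` are real functions in `torch.nn.functional`. The three rows of logits are made-up values for illustration, and the true class is 0 in every row. The first printed value confirms that cross-entropy on raw scores equals the negative log-likelihood of the log-softmax probabilities. The second shows the per-sample losses for the three cases.

```python
import torch
import torch.nn.functional as F

# Made-up raw scores (logits) for three samples over three classes
logits = torch.tensor([[4.0, 0.0, 0.0],   # confident and correct
                       [0.5, 0.0, 0.0],   # correct but unconfident
                       [0.0, 4.0, 0.0]])  # confident and wrong
target = torch.tensor([0, 0, 0])          # the true class is 0 for every sample

# Per-sample cross-entropy loss
ce = F.cross_entropy(logits, target, reduction='none')

# The same quantity written as negative log-likelihood of the log-softmax probabilities
nll = F.nll_loss(F.log_softmax(logits, dim=1), target, reduction='none')

print(torch.allclose(ce, nll))
print(ce)
```

The three losses compare as follows:

- The confident correct prediction has a loss close to 0.
- The unconfident correct prediction has a noticeably larger loss (about 0.79).
- The confident wrong prediction has the largest loss (about 4.04).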
\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 1 }