{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "from fastai import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this part of the lecture we explain Stochastic Gradient Descent (SGD), an **optimization** method commonly used to train neural networks. We will illustrate the concepts with concrete examples." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Linear Regression problem" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The goal of linear regression is to fit a line to a set of points." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "n=100" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[ 0.1695, 1.0000],\n", " [-0.3731, 1.0000],\n", " [ 0.4746, 1.0000],\n", " [ 0.7718, 1.0000],\n", " [ 0.5793, 1.0000]])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Column 0 holds random inputs in [-1, 1); column 1 stays at 1 so that a[1] acts as the intercept\n", "x = torch.ones(n,2)\n", "x[:,0].uniform_(-1.,1)\n", "x[:5]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([3., 2.])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a = tensor(3.,2); a" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Targets: points on the line x@a plus uniform noise in [0, 1)\n", "y = x@a + torch.rand(n)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXYAAAD8CAYAAABjAo9vAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAHHFJREFUeJzt3W+QHWd15/Hf0WhkjzAwdjxF7LGFRC2Rg6HQwJTXG6XI2iTIEGxPMIlNQgIsW8r/WhJWG3l5YbOVFEpcicNuqCQqLyEkLmNsgyJiEgWQKCoUNoyQhC1sgXFi8EDiIfaQxR7skXT2xe1r99zpvt19++l7b/f9fqpUmunp2/epntGZR6fPcx5zdwEAmmPdoAcAAAiLwA4ADUNgB4CGIbADQMMQ2AGgYQjsANAwBHYAaBgCOwA0TJDAbmaTZnanmT1oZg+Y2X8KcV0AQHHrA13n/ZL+3t3fbGYbJG3sdvK5557rmzdvDvTWADAaDh8+/F13n8o6r3RgN7MXSnqNpLdLkrs/I+mZbq/ZvHmz5ufny741AIwUM3skz3khUjFbJC1K+gszO2Jmt5jZ8wJcFwDQgxCBfb2kV0n6U3efkfSkpN2dJ5nZTjObN7P5xcXFAG8LAEgSIrA/KulRd783+vxOtQL9Ku6+191n3X12aiozRQQA6FHpwO7u/yLpW2a2NTr0WklfLXtdAEBvQlXF/KakW6OKmIclvSPQdQEABQUJ7O5+VNJsiGsBAMoJNWMHAET2HVnQTQdO6NtLyzp/ckK7dmzV3Mx0396fwA4AAe07sqDrP3aflldOSZIWlpZ1/cfuk6S+BXd6xQBAQDcdOPFsUG9bXjmlmw6c6NsYCOwAENC3l5YLHa8CgR0AAjp/cqLQ8SoQ2AEgoF07tmpifGzVsYnxMe3asTXlFeHx8BQAAmo/IKUqBgAaZG5muq+BvBOpGABoGAI7ADQMgR0AGoYcOwB0GHRLgLII7AAQMwwtAcoiFQMAMcPQEqAsAjsAxAxDS4CySMUAGBl5cufnT05oISGI97MlQFkEdgAjISt33g76C0vLMkkee22/WwKURWAHMBKycufxoO/Ss8F9mqoYABhO3XLnSUG/HdQ/v/vyPowuLAI7gEZJy6N3y5034YFpHFUxABqjnUdfWFqW67k8+r4jC4ntdCXpqWdOanLjeOL16vTANC5IYDezfzaz+8zsqJnNh7gmABTVLY8+NzOt973pFZqcWB3En3hqRd//wUmNj9mq43V7YBoXMhVzmbt/N+D1ACBRWrolK6UyNzOtmw6c0NLyyqqvr5x2TU6M63lnrK9tG4E4cuwAaqVb2WKeGvS04P+95RUdveF1lYy3331nQuXYXdI/mNlhM9sZ6JoAsEa3dEuebelC70m678iCtu85qC2779b2PQe178jCqq+l5fyrFCqw/7i7v0rS6yX9upm9pvMEM9tpZvNmNr+4uBjobQGMmm7plnYefXpyQqZWueL73vSKVTPkpOBvagXdzsCcJStwD6rvTJBUjLsvRH8/ZmYfl3SJpM91nLNX0l5Jmp2d9TUXAYAcstItWdvSxfck7VxlWrSTY9bD2kGVUZaesZvZ88zs+e2PJb1O0v1lrwsASfKkW7LMzUzr87sv1/TkhDpnmUVm1FmBO3TaJ68QqZgXSfpHMzsm6YuS7nb3vw9wXQBYI0+6Ja+yM+qswB3il1AvSqdi3P1hSa8MMBYAyCUr3ZJX0U6OnRUul100pbsOL6xKx8QDdzzt08+qGModAYysXTu2riqdlNJn1ElllncdXtA1r57WoQcXUwN3qF9CRRDYAYysIjPqtAelhx5cHLpGYQR2ACMt74y6To3CaAIGADkMqsKlFwR2AMhhUBUuvSAVA6DRQvVqGVSFSy8I7ACGRq9BOO11WfucFjWICpdeENgB9F1SIJaUOwjHX//CiXE9+cxJrZzyNa/LWvLfVAR2AH2VNos+c3xdriDc+fr
O3urx1+WtZBlEa90q8fAUQF+lzaKfeGptgJbWBuGk16e9Lk8ly6Ba61aJwA6gr4rWfXcG5yJ9XPJUsgyqtW6VCOwA+iptFj05MZ6rT3qeuvF28M7TMKxOC4/yIscOoK/S+rPceNXFkrL7pCe9fnyd6awz12vpqZU1OfKsSpaijcDqgMAOoK+y6sHnZqa1fc/BNcG2nR5p92UJ9bCzSCOwuiCwA+i7rFl0VnokZD15nRYe5UVgBzB0+p0eqcvCo7x4eApA+44saPueg9qy++7CGzpXoU59WYYRM3ZgxIVedh9CE9Mj/URgB0bcsC67b1p6pJ9IxQAjrol13KOOwA6MuDptIIF8COzAiONBZfMEy7Gb2ZikeUkL7v7GUNcFEEZaB0MeVDZPyIen/03SA5JeEPCaAALIqnzhQWWzBEnFmNkFkn5a0i0hrgcgrDwdDIetlh29CzVj/2NJ/0PS89NOMLOdknZK0qZNmwK9LYA80ipc2p0TuzXdYiZfP6Vn7Gb2RkmPufvhbue5+153n3X32ampqbJvC6CAtAqXdltc6bmg3hayJ3m3/w3wP4XwQszYt0u6yszeIOlMSS8ws79297cGuDaAAJI6GMZn6GlC1LJ3y+9L+fc5RX6lZ+zufr27X+DumyVdJ+kgQR0YLkkbTmQFdSlMLXtafv/dHz2md91+tHG7Fw0DWgoAI6Kz8iWp53lcWi170Y2f02b9pzz9VwurXssJukDJ3T9LDTswHLJy10kLkyz6O2kLufY1i2783Musn1Wv5TBjBxooT8fGXhYm9dIwLCm/3w2rXssjsAMNlDcAF12Y1EvDsLmZac0/8rhuveebmXn9aVa9BkFgBxqoqo6Nve5sdOjBxa5B3STdfO02AnogNAEDGqiqjo29Ngzr9gvFJP3CpZsI6gER2IEGqqpjY1LZZNJD1k5pv1DGzHTztdv0u3OvKDUurEZgBxqo1wBclbRfNH/4c69kpl4BcuxAQ1XRsbHX/VFpDdxfBHYAuZXZH5XWwP1DYAdGQNHVomnYH7UeyLEDDdfLatE07I9aDwR2oOHybLKRF/uj1gOpGKDhQqZPeAhaDwR2oGHi+fTJjeOp5/WaPuEh6PAjsAMN0lmO+MRTK4nnkT5pNgI7MKR6qWRJyqd3GjMb6GIlVI/ADgyhXhcC5cmbn3YnqDccVTHAEOq1kiVP3pzSxOZjxg70WZ4US6+VLFmbWsRz66EWLWH4ENiBgsoExLwpll77nneWI05uHJe79L3llVVj7TXVg3ow77KhbFVmZ2d9fn6+7+8LlNUZECVpfJ3prDPXa+mplcxAn7aB9PTkhD6/+/Ku79PuWx6ixW3ecWC4mNlhd5/NOo8cO1BAUu575bTriadWci3Xz5timZuZ1jWvnn52c2lJckl3HV7oqRVAr+NAPZUO7GZ2ppl90cyOmdlxM3tviIEBwyhP4Ov2kLNIr5Wk7eR6bQVQZhyonxAz9qclXe7ur5S0TdIVZnZpgOsCQydv4Ev7BVCk10qVs2p6vjRb6cDuLd+PPh2P/vQ/cQ+UtO/IgrbvOagtu+/W9j0HE1MeSQExSdovgCI7G1U5qx62HZYQVpCqGDMbk3RY0n+Q9AF3vzfEdYF+yVslMjczrflHHtet93wzdfaSNfPN22slqXQx5Kyani/NFeThqbufcvdtki6QdImZvbzzHDPbaWbzZja/uLgY4m2BYIosCErKfbeFnPkyq0avgtaxu/uSmR2SdIWk+zu+tlfSXqlV7hjyfYGysvLZ8dr1tB9ek4KXCjKrRi9CVMVMmdlk9PGEpJ+S9GDZ6wL91C2f3bkDUdFrAP0WIhVznqRDZvYVSV+S9Cl3/9sA1wX6Ju2h6JNPn9R7P3E8s2MiFSUYJqVTMe7+FUkzAcYCVCarDUD74/d+4viqHuZLy8n9zOPO3jiuG668eNX1QvZhoacLimLlKRov72bOczPT2rih+Fxn44b1a4J6qM2jQ14Lo4PAjsYrUvHSy+K
fzteE3Dw65LUwOgjsaKz2gqOkZldSchBPewA6OTGuMbPEr3W+JuSKUXq6oBe07UWtpeWfk7ojdkoK4mmLgm686mJJyrVgKE/L3fi4XzgxLjMldofstX0vRhszdtRWt/xznr0/n3z6ZGKePW1RUN4FQ1l9WDrHvbS8ktodkp4u6AX92FFbaWmWMTOdyvlzPTE+Vslqzm6VLN3SQ23xvuhUxaAtbz92UjGorbQ88yl3mfJ1oms/iAwdKLutGM2TH4+fw+pTFEUqBrXVLc/skpIfda7V7weRbDiNqhHYUVtZLXRdejYfPjkxnnpeiCCap+VvW9a4yaGjLFIxqK12euLdHz2WmFOP56m37zmYuIrUpNJBtOjG0J0bTnerigF6QWBHrbUDYFYZYlq6xZUcfIvotogo7drkzVElUjGovTxliGnplukAaRgWEWHYMGNHI2TNgKvcjYhFRBg2zNgxEqrcjYhFRBg2zNjRCHkW8VSV1+58GMoDUAwagR1DK++Ky6JVKVXgYSiGCYEdfVNkaXyRYN1LVQrQZAR2VKazg+GTz5zUyqlWvXnWrLpIsKYqBViNh6eoRFIHw3ZQb+u2YUSRYN1tI2pgFBHYUYk8bXOl9ABeJFhTlQKsRmBHJfKmQdICeJFgXWUpI1BHpXPsZnahpA9LepFaK7T3uvv7y14X9Za2aCeu26y6aAkhVSnAc0pvtGFm50k6z92/bGbPl3RY0py7fzXtNWy00XxJW9ONrzOddeb6oWx2xWYWqIO8G22UTsW4+3fc/cvRx/9P0gOS+Bcx4pLSI9decqE2bhi+QqxuW+wBdRT0X5mZbZY0I+nekNdFPcXTI8OwiCgNdfBommAPT83sLEl3SXqXu/97wtd3mtm8mc0vLi6GelvURLfgOWjUwaNpggR2MxtXK6jf6u4fSzrH3fe6+6y7z05NTYV4W9TIMAdP6uDRNKUDu5mZpP8r6QF3/6PyQ0ITTW5M3pquyuCZd7s66uDRNCFy7Nsl/aKk+8zsaHTsf7r7JwNcGw2w78iCvv+Dk2uOj49ZZcGzSE6f7oxomtKB3d3/Ufk3hMcIuunACa2cXltW+7wN6ysLnkUfiFIHjyZh5Skql5ZHX1peqaykcJhz+kDVhq+oGI3TbRVqUnokxGIhtqvDKGPGjsolPZxs6yx5DLVYiAeiGGUEdlSuvQo1TTw9EqrencZgGGWkYtAXczPTuunAicz0SMjcOA9EMaqYsaOrvLXgeeRJj7BYCCiPwI5UoZtj5UmPkBsHyiMVg1QhmmMlVbh8fvflqeezWAgoj8COVGXz3b12dCQ3DpRDKgapyua7h7mjI9BkzNjxrM60yWUXTemuwwurgnORfDerP4HBILA3TK+rNpPSJncdXtA1r57WoQcXe8p3s/oTGAwCe4PsO7KgXXcce7bh1sLSsnbdcUxS9i5FaWmTQw8udn3Y2c2uHVvX7HtKhQtQPXLsDXLj/uNruiiunHbduP945murSJuw+hMYDGbsDbK0vJJ6fPueg13TKVWlTahwAfqPGfuIyFpktGvHVo2vW91Wf3xd940wQq5KBRAOgb1Bzk7Zfq5Taslh53YpXbZPCb0qFUA4BPYGueHKi3Ofu7C0vGqWfdOBE1o51ZGfP+WpNefUqAPDi8DeIHMz07ln7dLqWXbRh6fUqAPDi8DeMDdceXHqphZJ2rPsoqtM6cIIDC8Ce8MklRi+9dJNmu4ScL+9tFy4qyJdGIHhRbljA6WVGG7fczC1pLFoV0W6MALDy9w9+6ysi5h9UNIbJT3m7i/POn92dtbn5+dLvy+K6WwbILVm2SwaAurBzA67+2zWeaFm7B+S9CeSPhzoeqhA0Vl2r31nAAxWkMDu7p8zs80hrtU0VQfHvNfvPO/ma7d1HUevvdQBDF7fHp6a2U4zmzez+cXFxX697UBVvYgn7/V7GQd16kB99S2wu/ted59199mpqal+ve1AVR0c067/rtuPrll8VHQc1KkD9UW5Y4W
qDo7drrOwtKzfuv2oNu++O7ESJuv11KkD9UVgr1DVwTHrOln1Tt1eT506UF9BAruZ3SbpC5K2mtmjZvbOENetu6qDY9L188oaB73UgfoKVRXzlhDXaZqqF/HEr5+WbulkUu5x0EsdqKcgC5SKYoFSeEmLjzpNT070vM0dgMHr9wIlDFjn7N20OsdOfhwYHQT2BomnTlg1CowuAntDkR8HRhfljgDQMMzY+4j0CIB+oCqmoF6Dc1LVSvsB5zRBHkAOVMVUoEzHw6R+Le1fqXROBBASOfYCyjT1yuoPQ+dEAKEwY89p35GFws204mmbdWY6lZH2onMigBBGPrDnyZm3UzBpkpppdaZtsoJ62nUAoKiRDux5c+ZJKZi28TFLXNGZ9pqxaObOylAAVRnpHHvenHnXFEnKRDztNafd9c97flo3X7uNzokAKjFyM/Z46iUtOdIZlM+fnEjNr6+cdt104MSaoJz2mna6hZWhAKoyUjP2zr0/03TmurP6nifNztN6sV920ZS27zmoLbvvXrV9HQCEMlIz9m658rakXHd7Zv3ujx5LfAia9NAzqRf7ZRdN6a7DCz3VwQNAXiMV2LvlyrM2oGgf61w92u2hZ2e6Zfueg6k5fQI7gFBGKrCn5b3zbkBRdkekqje3BgBpxAL7rh1bC824k5R56Jn1QBUAQhiph6eD3qC56s2tAUBq8Ix9GFvkVr25NQBIgdr2mtkVkt4vaUzSLe6+p9v5VbftTWqROzE+pmtePb2qKqV9nMVBAOogb9ve0qkYMxuT9AFJr5f0MklvMbOXlb1uGWkrSm+791s9d2cEgLoIkWO/RNJD7v6wuz8j6SOSrg5w3Z6lVZmkNeKiKgVAk4QI7NOSvhX7/NHo2CpmttPM5s1sfnFxMcDbpkurMhkzK3Q+ANRR36pi3H2vu8+6++zU1FSl75VWffKW/3ghVSkAGi9EVcyCpAtjn18QHatUt6qXbtUnsy8+h6oUAI1WuirGzNZL+pqk16oV0L8k6efd/Xjaa8pWxaRVvVDdAqDJ+raZtbufNLPfkHRArXLHD3YL6iF066OetPtRmRn6MNbDA0A3QRYoufsnJX0yxLXyyNtzJe8OSWnKvh4ABqGWLQXSqlg6j+fdISlN2dcDwCDUMrDn7blStpsi3RgB1FEtA3veZl55Z/Zpyr4eAAahtk3A8rTPLdumN0SbXwDot9oG9jzKdlOkGyOAOgrS3bGokN0dKUcEMCr6Vsc+SGnliPOPPK5DDy5qYWlZY2Y65a5pgj6AEVHrwJ5WjnjrPd9U+/8h7Y6O1KADGBW1rIppSys7TEsuUYMOYBTUOrD3UnZIDTqApqt1YE9aqJTccf051KADaLpaB/akhUq/cOmmNcG+jRp0AKOg1g9PpeSFSu2e61TFABhFtQvseerW86xKBYCmqlVg33dkQbvuOKaV08+VMO6645gkShgBoK1WOfYb9x9/Nqi3rZx23bi/0n09AKBWahXYl5ZXCh0HgFFUq8AOAMhWq8B+9sbxQscBYBTVKrDfcOXFGh9bvQRpfMx0w5UXD2hEADB8alUVQ390AMhWKrCb2c9KulHSj0q6xN3DNFnvghp1AOiubCrmfklvkvS5AGMBAARQasbu7g9IkllW6y0AQL/U6uEpACBb5ozdzD4t6YcTvvQed/+bvG9kZjsl7ZSkTZs25R4gAKCYzMDu7j8Z4o3cfa+kvVJrM+sQ1wQArEUqBgAaxtx7nzyb2c9I+j+SpiQtSTrq7jtyvG5R0iM53uJcSd/teYDVGuaxScM9vmEemzTc42NsvRvm8eUd24vdfSrrpFKBvWpmNu/us4MeR5JhHps03OMb5rFJwz0+xta7YR5f6LGRigGAhiGwA0DDDHtg3zvoAXQxzGOThnt8wzw2abjHx9h6N8zjCzq2oc6xAwCKG/YZOwCgoIEHdjP7WTM7bmanzSz1qbCZXWFmJ8zsITPbHTu+xczujY7fbmYbAo7tHDP7lJl9Pfr
77IRzLjOzo7E/PzCzuehrHzKzf4p9bVuoseUdX3TeqdgY9seOD/rebTOzL0Tf/6+Y2bWxrwW/d2k/Q7GvnxHdh4ei+7I59rXro+MnzCyzpLei8f22mX01ulefMbMXx76W+D3u49jebmaLsTH819jX3hb9HHzdzN42gLHdHBvX18xsKfa1qu/bB83sMTO7P+XrZmb/Oxr7V8zsVbGv9X7f3H2gf9Rq+btV0mclzaacMybpG5JeImmDpGOSXhZ97aOSros+/jNJvxpwbH8gaXf08W5Jv59x/jmSHpe0Mfr8Q5LeXOG9yzU+Sd9POT7QeyfpRyS9NPr4fEnfkTRZxb3r9jMUO+fXJP1Z9PF1km6PPn5ZdP4ZkrZE1xkL/L3MM77LYj9bv9oeX7fvcR/H9nZJf5Lw2nMkPRz9fXb08dn9HFvH+b8p6YP9uG/R9V8j6VWS7k/5+hsk/Z0kk3SppHtD3LeBz9jd/QF3P5Fx2iWSHnL3h939GUkfkXS1mZmkyyXdGZ33l5LmAg7v6uiaea/9Zkl/5+5PBRxDN0XH96xhuHfu/jV3/3r08bclPabWYrcqJP4MdRnznZJeG92nqyV9xN2fdvd/kvRQdL2+js/dD8V+tu6RdEHgMfQ8ti52SPqUuz/u7k9I+pSkKwY4trdIui3g+3fl7p9Ta7KX5mpJH/aWeyRNmtl5KnnfBh7Yc5qW9K3Y549Gx35I0pK7n+w4HsqL3P070cf/IulFGedfp7U/NL8X/RfrZjM7I+DYiozvTDObN7N72mkiDdm9M7NL1JpxfSN2OOS9S/sZSjwnui/fU+s+5XltWUXf451qzfTakr7H/R7bNdH3604zu7Dga6sem6LU1RZJB2OHq7xveaSNv9R968vWeBaoQ2QVuo0t/om7u5mllhBFv2VfIelA7PD1agW1DWqVM/2OpP81gPG92N0XzOwlkg6a2X1qBa1SAt+7v5L0Nnc/HR0ufe+ayszeKmlW0k/EDq/5Hrv7N5KvUIlPSLrN3Z82s19W638+l/fx/fO4TtKd7n4qdmzQ960SfQnsXr5D5IKkC2OfXxAd+ze1/uuyPpphtY8HGZuZ/auZnefu34mCz2NdLvVzkj7u7iuxa7dnrE+b2V9I+u9FxhZqfO6+EP39sJl9VtKMpLs0BPfOzF4g6W61fsnfE7t26XvXIe1nKOmcR81svaQXqvUzlue1ZeV6DzP7SbV+cf6Euz/dPp7yPQ4VoDLH5u7/Fvv0FrWesbRf+587XvvZQOPKNbaY6yT9evxAxfctj7Txl7pvdUnFfEnSS61VxbFBrW/Qfm89ZTikVm5bkt4mKeT/APZH18xz7TW5uyigtfPZc2ptJRhS5vjM7Ox2GsPMzpW0XdJXh+HeRd/Lj6uVY7yz42uh713iz1CXMb9Z0sHoPu2XdJ21qma2SHqppC+WHE/h8ZnZjKQ/l3SVuz8WO574Pe7z2M6LfXqVpAeijw9Iel00xrMlvU6r/1db+dii8V2k1kPIL8SOVX3f8tgv6Zei6phLJX0vmtSUu29VPhHO80fSz6iVP3pa0r9KOhAdP1/SJ2PnvUHS19T6bfqe2PGXqPWP7CFJd0g6I+DYfkjSZyR9XdKnJZ0THZ+VdEvsvM1q/YZd1/H6g5LuUyso/bWkswLfu8zxSfqxaAzHor/fOSz3TtJbJa1IOhr7s62qe5f0M6RWeueq6OMzo/vwUHRfXhJ77Xui152Q9PqK/i1kje/T0b+R9r3an/U97uPY3ifpeDSGQ5Iuir32v0T39CFJ7+j32KLPb5S0p+N1/bhvt6lV7bWiVpx7p6RfkfQr0ddN0geisd+nWGVgmfvGylMAaJi6pGIAADkR2AGgYQjsANAwBHYAaBgCOwA0DIEdABqGwA4ADUNgB4CG+f8l8sVFpmTRMQAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.scatter(x[:,0], y);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You want to find **parameters** (weights) `a` such that you minimize the *error* between the points and the line `x@a`. Note that here `a` is unknown. For a regression problem the most common *error function* or *loss function* is the **mean squared error**. " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def mse(y_hat, y): return ((y_hat-y)**2).mean()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Suppose we believe `a = (-1.0,1.0)` then we can compute `y_hat` which is our *prediction* and then compute our error." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "a = tensor(-1.,1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(7.0485)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_hat = x@a\n", "mse(y_hat, y)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXYAAAD8CAYAAABjAo9vAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3X+UXWV97/H3N5MJDKgZIlkWBkJCy4UWDURmcbk3LFtAATVAijZQtVeUa7R6axVLDctVCa7ba5BVqN7a26aKVPFXxBgDaFMksa66RJ3cEH5HEYxk1DIaJlfJmEwmz/1j753Z58z+dc559plz9vm81srKzPmxz8PJ8D3PfJ/v833MOYeIiFTHnNkegIiI+KXALiJSMQrsIiIVo8AuIlIxCuwiIhWjwC4iUjEK7CIiFaPALiJSMV4Cu5kNmtldZvaEmT1uZv/Fx3VFRKRxcz1d56PAvzjnXm9m84Bjsh58/PHHu8WLF3t6aRGR3rB9+/ZfOOcW5j2u5cBuZvOBVwDXADjnDgIHs56zePFiRkZGWn1pEZGeYma7izzORypmCTAGfMrMdpjZJ8zsWA/XFRGRJvgI7HOBlwP/xzm3DHgeWFP/IDNbbWYjZjYyNjbm4WVFRCSJj8C+B9jjnPtu+P1dBIG+hnNuvXNu2Dk3vHBhbopIRESa1HJgd879HHjGzE4Pb7oIeKzV64qISHN8VcX8GfDZsCLmKeAtnq4rIiIN8hLYnXMPAsM+riUiIq3xNWMXERFg045Rbtmyi5+OT3Di4ADXX3I6K5cNtXUMCuwiIp5s2jHKDRsfZmJyCoDR8Qlu2PgwQFuDu3rFiIh4csuWXUeCemRicopbtuxq6zgU2EVEPPnp+ERDt5dFgV1ExJMTBwcaur0sCuwiIp5cf8npDPT31dw20N/H9ZecnvKMcmjxVETEk2iBVFUxIiIVsnLZUNsDeT2lYkREKkaBXUSkYhTYRUQqRjl2EZFQJ7QD8EGBXUSEzmkH4INSMSIidE47AB8U2EVE6Jx2AD4oFSMiPSEvf37i4ACjCUG83e0AfFBgF5HKy8ufb9oxyv6Dh2Y8bzbaAfigwC4ilZeXP48H/cjgQD9rLz+z6xZOQTl2EekBWfnzpKAPcOxRc7syqINm7CJSIWl59Kz8eZUWTSOasYtIJUR59NHxCRzTefRNO0YT2+kC7D94iMFj+hOv142LphEvgd3MfmxmD5vZg2Y24uOaIiKNyMqjr1w2xIevfBmDA7VB/Ln9k/z6N4fo77Oa27t10TTiMxVzgXPuFx6vJyIyQ1q6JS+lsnLZELds2cX4xGTN/ZOHHYMD/Rx71NyubyUQUY5dRLpGVtlikTr0tOC/b2KSB2+8uJTxzkbvGV85dgf8q5ltN7PVnq4pIlIjK91S5Fg632eSbtoxyvJ1W1my5l6Wr9vKph2jNfel5fzL5iuwn++ceznwauBdZvaK+geY2WozGzGzkbGxMU8vKyK9JCvdEuXRhwYHMGBocIAPX/mymhlyUvA3gqBbH5jz5AXu2ew94yUV45wbDf9+1sy+ApwLfKvuMeuB9QDDw8POx+uKSG/JS7fkHUsXP5N0dHwCI0g3QOPdHPMWa2ezjLLlGbuZHWtmL4y+Bi4GHmn1uiIi9YqkW/KsXDbEt9dcyNDgAPUzzEZm1HmB23fapxE+UjEvAf7dzHYC3wPudc79i4friojUKJJuKarVGXVe4PbxIdSsllMxzrmngLM8jEVEJFdeuqWoRrs51le4XHDGQr68fbQmHRMP3PG0T7urYlTuKCI96fpLTp/R/CttRp1UZvnl7aO87pwhtj0xlhq4fX0INUqBXUR6UiMz6rSF0m1PjPHtNRe2ZbyNUGAXkZ5VdEbdbY3C1ARMRCTHbFa4NEOBXUQkx2xWuDRDqRgRqSxfvVpms8KlGQrsItIRmgnCWc/JO+e0UbNV4dIMBXYRaaukYAwUDsLR8/NaAuRt+a8yBXY
RaZu0WfTR/XMKBeH656e1BGikV8tstdYtkxZPRaRt0mbRz+2fTHx8fRBOO3g66TlFKllms7VumRTYRaRtGq37rg/ORZ7fSK+W2WytWyYFdhFpm7RZ9OBAf6E+6Xl14/W9WvIahnXbxqOilGMXkbZJ68+y9vIzgfw+6UnPjx471ESvlkYbgXULBXYRaZu8evCVy4ZYvm7rjGAbpUeiviy+FjsbaQTWTRTYRaSt8mbReekRn/Xk3bbxqCgFdhHpKO1Oj3TTxqOitHgqImzaMcrydVtZsubehg919q3b+rJ0Is3YRXqc7633rapqeqSdFNhFelwnbr2vYnqknZSKEelxVa3l7mUK7CI9rtsOkZB8CuwiPU6LldXjLcduZn3ACDDqnFvh67oi4kdaF0MtVlaPz8XTPwceB17k8Zoi4kFe5YsWK6vFSyrGzE4CXgt8wsf1RMSvvC6GnVTHLq3zNWP/W+AvgRemPcDMVgOrARYtWuTpZUWkiLQKl9HxCRavuTfzJCLpPi3P2M1sBfCsc2571uOcc+udc8POueGFCxe2+rIi0oC8Cpe0k4h8yPptQL8plMPHjH05cLmZvQY4GniRmd3pnHuTh2uLiAdJXQzz+Khjz8rtQ/FzTqUxLQd259wNwA0AZvYHwF8oqIt0lvrKl/oZehIfdexpuf33bdjJlJs5itne8VoVaikg0iPilS9JPc/jsurYGzn8OW3WnxTU854jxXndoOSc+6Zq2EVmX17uOmlTkoV/Jx0hF79uI4c/NzPr147X1mnGLlIxRbo1NrspqdGGYY3m9rXj1Q8FdpGKKRp8m9mU1GjDsJXLhhjZvZfPPvCT3Lx+0pml0hwFdpGKKbNbYzOnG217YiwzqBtw21VnK6B7pCZgIhVTZrfGZhqGZX2gGPDG8xYpqHumwC5SMWV2a1y5bIgPX/kyhgYHMLIXWiNpHyh9Ztx21dn8z5Uva3lcUkuBXaRimgm+RTVS6hhJ+6D5m1VnaaZeEuXYRSqojG6NzZ6NqrbA7afALiKFtHI2qtoCt5cCu0gPaCaFUk9no3YP5dhFKq7R3aJpdDZq91BgF6m4vEM2itLZqN1DqRiRivOVQtEiaPdQYBepmCifPjo+QZ9Z6q7PZlIoWgTtDgrsIhVSX5KY1h5XKZRqU2AX6UDNVrEk5dPrqdlW9Smwi3SYZjcCQX7e3IBvr7nQyzilc6kqRqTDtFLFkpc3V2lib9CMXaTN8tIsrVSxZB1sEc+r+9iwJJ1LgV2kQa0ExSJplmZ6nkfiJYlRVcyUczV59VZSPdIdzGUcKluW4eFhNzIy0vbXFWlVfVAE6J9jvODouYzvn8wN9GmHSA8NDhzJfSe9BsBxx/Rz42Vnthx8i4xBOpOZbXfODec9Tjl2kQYk5b8nDzue2z9ZaLt+kTRL1HZ3cKC/5jHP7Z9sqhVAM2OQ7tZyYDezo83se2a208weNbObfAxMpBMVCX5ZC51F+62sXDbEsUfNzJQ20wqg2TFI9/IxYz8AXOicOws4G7jUzM7zcF2RjlM0+KV9ADTSb6WsmbV6vlRfy4HdBX4dftsf/ml/4l6kRZt2jLJ83VaWrLmX5eu2JqY8koJikrQPgEZONyprZl3mCUvSGbxUxZhZH7Ad+B3g48657/q4rki7FK0UWblsiJHde/nsAz9Jnb3kzX6L9ltJKl30eXapAnl1eVk8dc5NOefOBk4CzjWzl9Y/xsxWm9mImY2MjY35eFkRbxrZFLTtibHUoO5z9quZtTTLax27c27czLYBlwKP1N23HlgPQbmjz9cVaVVePjteu572w1vGdn3NrKUZPqpiFprZYPj1APAq4IlWryvSTln57PoTiBq9hki7+UjFnABsM7OHgO8D9znn7vFwXZG2SVsUff7AIW66+9HcjomqKpFO0nIqxjn3ELDMw1hESpPXBiD6+qa7H+W5/ZNHbh+fmJxxrXpJO0J99WJRTxdphnaeSuUVPcx55bIhjpnX+FznmHl
zZwR1H4dH+7qO9B4Fdqm8Ripemtn8U/8cX4dH+7qO9B4FdqmsaMNRUsMrSA7iaQuggwP99Jkl3lf/HF87RtXTRZqltr3S1dJy0GkdEuOSgnjapqC1l58JUGjDUJG2u/Fxzx/ox4wZ3SFbad8rvU2BXbpW1m7RImd/Pn/gEJt2jCYuomYtWOYtZubtGK0fd3yBNv7fUObOU6k29WOXrpWWZokOlyhioL+vlN2cWdUsWemhSNQbXVUxEle0H7tm7NK10nLNU85hFOtEFy1G+g6WWTtGi+TIo8do56k0Q4un0rWycs2OYIt/Ee1ejCySI1ceXVqhwC5dK6+FroMjDbTqTyOK8xFEi7T8jeSNW3l0aZVSMdK1ohTF+zbsTMypx8/wXL5ua+IuUoOWg2ijh0PXL9CmVcWINEuBXbpaFADzqkfS0i2O5ODbiKyNRGnXVu5cyqRUjHS9In3L09ItQx7SMNpIJJ1GM3aphLwZcJk14dpIJJ1GM3bpCWWeRqTDoaXTaMYulVBkI09Zee0iu1VF2kmBXTpW0V2XjVallEGLodJJFNilrcoI1s1UpYhUmQK7lKq+i+HzBw8xORXUnPsK1qpKEamlxVMpTf0JQOMTk0eCeqTRAy8a6aGuqhTpVQrsUpoirXOh9WCtqhSRWgrsUpqiqZBWg3WZpYwi3ajlHLuZnQx8GngJwQ7t9c65j7Z6Xel+aRt34rKCNRQvIVRVisi0lg/aMLMTgBOcc//XzF4IbAdWOuceS3uODtroDUnH0/XPMV5w9NyOa3ilAy2kG7TtoA3n3M+An4Vf/8rMHgeGgNTA3rJ7roPtd4CbAuuDc66BFbeW9nLSnKRZ9wVnLGTbE2OM75/ZaXG2dEIdvIhPXo/GM7PFwLeAlzrn/l/a41qasd9zHYx8Mv3+gQXw6pth6armri+lSZrBl3U0XSPSjqqLt/0V6QRtPxrPzF4AfBl4T1JQN7PVwGqARYsWNf9C2+/Ivn9iL2x6J/zkAXj0K8H3EQX9WdWpG4lUBy9V4yWwm1k/QVD/rHNuY9JjnHPrgfUQzNibfjGXXz7H4cnkWf3EXtj4NthxJ/z84emgr4DfFp0aQNWdUaqm5XJHMzPgk8DjzrnyE92WfqRYYU//W+1MfmIvbHoH3LwE1g7CbS+Fhza0/jpyxKYdo8yx5FNIywqgRY+rUx28VI2POvblwJ8AF5rZg+Gf13i4brJzrinnuoenwmDvYN8zcPe7Fdw9iXLrScfXlRVA63e9RguiScFddfBSNV4XT4tqudwxXhVTpvknw3sfCQL8/R8KAr71TVfjuKngMRd9UGmcDGmLk31m/M2qs0oJoFoQlSoqunjanTtPV9wKN+6Ftfvgyn8KcuSRgQUwfC30zWv9dfbtCYL63e8OgjpMf5hEf+97BjauDj5sJFFaDn3KOW7Zsis1RVLGa852Pl+kHbq/u+PSVcmz5UXnwdffX5tLb9T8k4KZ+mReMHAwcnvwmpq5z5C1AzWpZtzHZiEtiEov684ZexFLV8H7n56e1c8/GbDg7+Fra2f5846FOf21z+8fCFIs+/YUfEEXfAjUe2hDuCg7P/jz1yfC/zpx+vubl1Q+l5+0OBkX7/DYSG680dfUgqj0iu7MsZfhSB59TzBTj/Lmt710Og2Ty2DteO01N70zKL/MU/GSy2gWnjZzN+Dpda/1mhtXmwCpmrZvUOp6aSmdiz4Y5Nhz0zEEHwhx93+oWFCHIGV097unx1IxUZOutMAdpUh85sbVGEx6VXVTMb4sXQWXfSxM5ZBeRx+lbuIKp3FCkxPJ6ZxZVLQWvKi8FIkOzRBpnWbsRSTN5tNSN3HzT2ogjRNq9MOg6FiaUEZzrLx2vNdfcnpiPxnlxkWKU469TI3k2CNR7Xwjr1EkVdREDt9HvruZPLdy4yLJlGPvBFEQjZdd9h8brBQefH7m45PSOXkKlWMy3Sdn49uC7wsE+lbz3c3O+JUbF2mNAnvZ0hZlwU8KpZnUDUw
H+q+/PzXAt1oL3qndHEWqToF9NmUF/aKayePH1VXj/OhTb+eU3Rvoc4f5N5vD5/sv4q8m33Lk4Y3ku7X7U2R2KLB3u0bKMdOE1Tg/2n4/p/74C5gBBnM5zJvm3Mcbj74PczBlc9h9yip+e9mlhS6r3Z8is0Pljt2urhyz6aXwfXs4ZfcG6jvrmgU/JBYG+t/e/YXavjgPbQg2cSW0O9buT5HZocBeBUtXwXsfYdMVj/GZQ69kyhnOUfsn7xrzT6LPHS72etEpVjUN0ma2O1Y7XJHZoXLHCkkrTwS4fM6/86F5n2G++xVYUJhzRP8AXPYxDm18O3MpGNzX7ktvtxAv2UxqsaxWxyJNqXbbXkmUtSi5+fD5nP2bf2TJgc/x5wffyag7Hhc1RbvsY7B0FbtPWUWhz/lo921aRU50e3TweH3f/H3PBBU5UVO0ijdBE2k3BfYKKbooufnw+Sw/8DHOP3pjMLMOZ84Pn/1B7jz8Kg65OTgHh8M0zgzhKVb7B34r8fpHbs87eBxg8nn4yjtqg3tG3l5E8imwV0hee9x6o+MTNf1fbtmyi7+afAu/c+BOlhz4HKce+Byfnnolh6IfE+sLWh6vCI62/cjkVex3tQea7Hfz+MjkVcE3RU+4clPTPXKS8vYb39YT7Y1FfFG5Y4VEi5Lv27Az8XzRJPHdoEmpnBsPvZW1h97K0+teO+O+f/71ueydc5C/nLuBE+2X/NS9mI8cWsXdB85lLUwfH1hElL5J20lb8e6XIj4psFdMFNzrG2lliXaDNlp3fuLgAJvHz2fzwfNrbh+KHn/ONUGOvYio5XHWTtqo+2U8sMfPo41YX/Da4W8WIr1GqZgKSiozfNN5i6YDboKfjk80XHee+/gVtwapm7RWxxHrm+6RU9/Tvl488NefRxtxU8EHyt/9Z7hpQbBIe9MCnUsrPUPljj0mr2Njo50VG+7EeM91sP1TENXM9x8Ll/3t9Cw8r1tlvJSyodOtElT81CqpnqLljl4Cu5ndDqwAnnXOvTTv8Qrss6e+4yIEs+yO2jj00Ibkg8jDevsjgXjtIC3stQ3M6YeVfx98HX9NBX3pQO0O7K8Afg18WoG98zUyy57V3uh53S9bnbFHBhbAgV8l981XgJcO0tbAHr7gYuAeBfaZyg6ORa/fTJqlo2f3RQ8ZaVX/AJz1BvjhvwYfJDZnOpUECv7SNh0X2M1sNbAaYNGiRefs3r3by+t2urKDY9HrNzMOHycolS6tKubFp8EvnvD4QkbhtI8CvZSk4wJ7XC/N2MsOjln9YYZis/JmxrFkzb2Jocwgsa694yT1qak3px+OeuHMfH6r5vTBUfNh4jmv59BKb9PReB2i7MMmsq4zOj7Be7/4IO/54oNNPb/r+6mvuHW6lj1pQTaaWUPjZ9PmOTw1/VpR10sIgnv9WDTDF88U2EtWdnBMu34k7/exrHFcf8npiembruynXuS0qqRKHKChNEyaaHMVzPwQiZ9Hq86X4oGXDUpm9nngO8DpZrbHzK71cd0qKPuwiUb7wzQyjp7qp750Fbz/6aAd8ZX/FB5cEna/HH5rsIDaqn17guCe9ZtBNLu/57qwEdr86U1WaogmBWmDUhu0qyoma+YeZ9D+0sVuF1+kra+KKWr+yeHO2UK9kZMf1zcPrvi4ZvQ9qu2Lp43otcDeLkmVL/U6qqKlKupz5vOOhUMHa2fm0eaq+gqeZgwsCH67kJ6jxdMeFM2+o9l7/Zyva/PjnS4pf5+1uarVhdq8Cp6shWLN9HuCZuwVNqu7RiVdWsuEIwos1q7dl37tr74Lpg5mP1+BvispFSPSLepn96ddDDs/l76jNisV02ybhXnHQt9RqrvvcErFiHSLpFTOovOSZ/Vz+qdr75Nk9bPPcvB54PnwGs8E6aJobNJ11I+9zTbtGGX5uq0sWXNvzbF0IjWi8sv60suVf58dbPP62Rd1eDL4YKmn82i7glIxTWomf51UtRJlU4eUAxcfiubYi4r
n8hObrllQ56/TqtqiaCpGM/YmRAF6dHwCx/S5oXmz71u27JpRihh9rBa9hkimpauCOveBBf6vnXgerYOR2zVz7zDKsTchKUBH54Zmzbjz+sMUuYZIrrTyy8xKnAT1Hw6p+Xs38yzarNdWRU7pFNgbkLfDMylwx1M2c8yYykl9+WoOJlKjPtjnBfq+eTMXaeeflF5xkxb0k1JD8d44EFTkHNyvahyPFNhDeTnzIrs66xtq1T8nL6gnXUOkFEmBPuu0Kghu27iaxBr7tEXb+z+Un+8/GKvG2bg6CPjRbwsqv2yKAjszA3CU74ba3ZxZQT1pV2fac/rCmbt2hkrHKNL9cukq+MkDQU49/pPbPxAE3iQNl1+G143/JlHf9lhyafGU7Jx5JCtFYsDrzhmakRtPe85h5/jxutdy21Vn90bnRKmOFbfCletrSzDjB4zX81V+GW97fM910x0vb1oQfC81enLGXp92KZIzz3qcA7Y9MTbj9rxe7CuXzfwwEOl4RWb3kYs+6K/8ct+eIIiPfHL6NjcVfD/yyeBD5rSLw7NpM1JKPaDnZuxJpYqW8th4vjuv73nS7DytF/sFZyzUJiXpDT7LL+efFBx1mGbfM0GA3/cM4IK/v/ouuHlJz22o6rkZe1oteV6+O5pZv2/DzsRF0KRFz3h+Pvrt4IIzFvLl7aOZ+XyRSonP8B/aAPe8Z3rBFMKqmOfJbH4W5fGjSpqipg7WHlEYLc5aXzDbr+iJVT238zTtgGYI8tx5O0mTqmMG+vsK58fLPtxapGvFK3MGjgtuq6+KuWlB9uHkPkQHqXRg0FcTsBRpee+igTVpFt5IK4CyD7cW6VpFcvfnXFObYy9DdDpWF1fj9Fxg93FAcyuLnmUfbi1SaVFPmu13lD9zh+lqnKWritX6d4ieWzyd7QOayz7cWqTyVtwKN+5NOXj82unvBxYEbY5btW/PdAO0+MJszaHjnbU4W+kce9Zu0tk8XUgnG4m0SfwQ8iInUyWZf3Lwd2I7hbpr9g/AWW+YLrlMWytoUltPUDKzS4GPAn3AJ5xz67Ie347AnrXICbS0ACoiXSge5KOqmDzRIeRprRQS5VT3ZG3oyrtyuwK7mfUBPwBeBewBvg/8sXPusbTntCOwZ1WfAKpMEZHkZmhJVTHNHjmYZP7J8N5HmnpqO6tizgWedM49Fb7wF4ArgNTA3g7NVJ+oMkWkxxTdRXvRB5MPGWkmtdPs8YUN8LF4OgTEP8r2hLfVMLPVZjZiZiNjYzO33/uWVmVy4uBA5n0iIjMsXRWkUGoWat8apFZqpO1jj/HVPydD26pinHPrnXPDzrnhhQsXlv56WdUnqkwRkYYtXRWkUNaOB3+vuLVgsI/J6oTpkY9UzChwcuz7k8Lb2iKtwqTIRiJVpohIS5JSOYvOy99BWzIfi6dzCRZPLyII6N8H3uCcezTtOb4WT1vd3i8i0k3atnjqnDtkZv8D2EJQ7nh7VlD3qejZo63WjavuXES6iZeWAs65rwFf83GtRhSpfClyOlKWVp8vItJuXd1SoEh1S5HTkbK0+nwRkXbr6sBepLql1W6K6sYoIt2mqwN7kYZerdasq+ZdRLpN17ftzWuh22qbXh9tfkVE2qnrA3ueVg/GaPX5IiLtVom2vSpHFJFe0DNH46WVI47s3su2J8YYHZ+gz4wp5xhS0BeRHtD1gT2tHPGzD/zkSN+1qfC3EtWgi0gv6OqqGEgvO0xLMKkGXUSqrusDezNlh6pBF5Eq6/rAnrRJKa8jsmrQRaTKuj6wJ21SeuN5i2YE+4hq0EWk6rp+8RSSNykNn7KAW7bsUlWMiPScrgzsRerW83akiohUVdcF9k07Rrn+SzuZPDxdwnj9l3YCKmEUEYEuzLGv3fzokaAemTzsWLu5LWd7iIh0vK4L7OMTkw3dLiLSa7ousIuISLauC+zHHdPf0O0iIr2m6wL7jZedSX9f7Rak/j7jxsvOnKURiYh0lq6rilF/dBGRbC0
FdjP7I2At8LvAuc45f03WM6hGXUQkXaupmEeAK4FveRiLiIh40NKM3Tn3OIBZXtstERFpl65bPBURkWy5M3Yz+wbwWwl3fcA599WiL2Rmq4HVAIsWLSo8QBERaUxuYHfOvdLHCznn1gPrITjM2sc1RURkJqViREQqxpxrfvJsZn8I/G9gITAOPOicu6TA88aA3QVf5njgF00PslwaW/M6eXwaW3M0tuY0MrZTnHML8x7UUmBvBzMbcc4Nz/Y4kmhszevk8WlszdHYmlPG2JSKERGpGAV2EZGK6YbAvn62B5BBY2teJ49PY2uOxtYc72Pr+By7iIg0phtm7CIi0oCOCOxm9kdm9qiZHTaz1NVhM7vUzHaZ2ZNmtiZ2+xIz+254+xfNbJ7HsS0ws/vM7Ifh38clPOYCM3sw9uc3ZrYyvO8OM3s6dt/Z7Rxb+Lip2Otvjt0+2+/b2Wb2nfDf/iEzuyp2n/f3Le3nJ3b/UeH78GT4viyO3XdDePsuM8st6S1hbNeZ2WPh+3S/mZ0Suy/x37eNY7vGzMZiY/jvsfveHP4M/NDM3jwLY7stNq4fmNl47L6y37fbzexZM3sk5X4zs4+FY3/IzF4eu6+19805N+t/CNr+ng58ExhOeUwf8CPgVGAesBP4vfC+DcDV4df/APypx7F9BFgTfr0GuDnn8QuAvcAx4fd3AK8v6X0rNDbg1ym3z+r7Bvwn4LTw6xOBnwGDZbxvWT8/sce8E/iH8OurgS+GX/9e+PijgCXhdfraPLYLYj9TfxqNLevft41juwb4u4TnLgCeCv8+Lvz6uHaOre7xfwbc3o73Lbz+K4CXA4+k3P8a4OuAAecB3/X1vnXEjN0597hzblfOw84FnnTOPeWcOwh8AbjCzAy4ELgrfNw/Ays9Du+K8JpFr/164OvOuf0ex5Cm0bEd0Qnvm3PuB865H4Zf/xR4lmCzWxkSf34yxnwXcFH4Pl0BfME5d8A59zTwZHi9to3NObct9jP1AHCSx9dvaWwZLgHuc87tdc49B9wHXDqLY/tj4PMeXz+Tc+5bBJO8NFcAn3aBB4BBMzsBD+9bRwT2goaAZ2Lf7wlvezEw7pw7VHe7Ly9xzv0s/PrnwEtszIm/AAADT0lEQVRyHn81M394/jr8Ves2MztqFsZ2tJmNmNkDUYqIDnvfzOxcglnXj2I3+3zf0n5+Eh8Tvi/7CN6nIs8te2xx1xLM9CJJ/77tHtvrwn+ru8zs5AafW/bYCFNXS4CtsZvLfN+KSBt/y+9b247GM09dIsuQNbb4N845Z2apZUThp+3LgC2xm28gCGzzCMqa3g98qM1jO8U5N2pmpwJbzexhgqDVEs/v22eANzvnDoc3t/S+VZWZvQkYBn4/dvOMf1/n3I+Sr1CKu4HPO+cOmNnbCX7rubCNr1/E1cBdzrmp2G2z/b6Vpm2B3bXeJXIUODn2/Unhbb8k+BVmbjjLim73MjYz+w8zO8E597MwAD2bcalVwFecc5Oxa0ez1gNm9ingL9o9NufcaPj3U2b2TWAZ8GU64H0zsxcB9xJ8wD8Qu3ZL71uCtJ+fpMfsMbO5wHyCn68izy17bJjZKwk+NH/fOXcguj3l39dXgModm3Pul7FvP0GwvhI99w/qnvtNT+MqNLaYq4F3xW8o+X0rIm38Lb9v3ZSK+T5wmgWVHPMI/qE2u2C1YRtBbhvgzYDP3wA2h9cscu0ZObwwqEU57ZUExwm2bWxmdlyUxjCz44HlwGOd8L6F/45fIcgz3lV3n+/3LfHnJ2PMrwe2hu/TZuBqC6pmlgCnAd9rcTwNjc3MlgH/CFzunHs2dnviv2+bx3ZC7NvLgcfDr7cAF4djPA64mNrfZksfWzi+MwgWIb8Tu63s962IzcB/C6tjzgP2hROa1t+3MleFi/4B/pAgj3QA+A9gS3j7icDXYo97DfADgk/VD8RuP5Xgf7QngS8BR3kc24uB+4EfAt8AFoS3DwOfiD1uMcEn7Zy6528
FHiYITHcCL2jn2ID/Gr7+zvDvazvlfQPeBEwCD8b+nF3W+5b080OQ3rk8/Pro8H14MnxfTo099wPh83YBry7h/4G8sX0j/H8jep825/37tnFsHwYeDcewDTgj9ty3hu/nk8Bb2j228Pu1wLq657Xjffs8QaXXJEF8uxZ4B/CO8H4DPh6O/WFiFYGtvm/aeSoiUjHdlIoREZECFNhFRCpGgV1EpGIU2EVEKkaBXUSkYhTYRUQqRoFdRKRiFNhFRCrm/wNkuJpRv23/fgAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.scatter(x[:,0],y)\n", "plt.scatter(x[:,0],y_hat);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "So far we have specified the *model* (linear regression) and the *evaluation criteria* (or *loss function*). Now we need to handle *optimization*; that is, how do we find the best values for `a`? How do we find the best *fitting* linear regression." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Gradient Descent" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We would like to find the values of `a` that minimize `mse_loss`.\n", "\n", "**Gradient descent** is an algorithm that minimizes functions. Given a function defined by a set of parameters, gradient descent starts with an initial set of parameter values and iteratively moves toward a set of parameter values that minimize the function. This iterative minimization is achieved by taking steps in the negative direction of the function gradient.\n", "\n", "Here is gradient descent implemented in [PyTorch](http://pytorch.org/)." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Parameter containing:\n", "tensor([-1., 1.], requires_grad=True)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a = nn.Parameter(a); a" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def update():\n", "    # Forward pass: predictions and loss for the current parameters\n", "    y_hat = x@a\n", "    loss = mse(y_hat, y)\n", "    if t % 10 == 0: print(loss)\n", "    # Backward pass: compute the gradient of the loss and store it in a.grad\n", "    loss.backward()\n", "    with torch.no_grad():\n", "        # Step against the gradient, then reset it for the next iteration\n", "        a.sub_(lr * a.grad)\n", "        a.grad.zero_()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor(7.0485, grad_fn=)\n", "tensor(1.5014, grad_fn=)\n", "tensor(0.4738, grad_fn=)\n", "tensor(0.1954, grad_fn=)\n", "tensor(0.1185, grad_fn=)\n", "tensor(0.0972, grad_fn=)\n", "tensor(0.0913, grad_fn=)\n", "tensor(0.0897, grad_fn=)\n", "tensor(0.0892, grad_fn=)\n", "tensor(0.0891, grad_fn=)\n" ] } ], "source": [ "lr = 1e-1\n", "for t in range(100): update()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXYAAAD8CAYAAABjAo9vAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3X+UHHWZ7/H3Mz09kwGWGUJycMkPiFwuLgJLZOB6NznrSkB0hRARA4uILGhw4RIRzRLWe0KIuzeBqPzYhbvEGJGVVQeEEGAxGxLUm6woEwPDzyy/xMygkhgzbMyQ+dHf+0d3T2p6qrqqu6t7uns+r3NyMlNdXfWlZni68tTzfb7mnENEROpHw1gPQERE4qXALiJSZxTYRUTqjAK7iEidUWAXEakzCuwiInVGgV1EpM4osIuI1JlYAruZtZnZ/Wb2kpm9aGb/M47jiohI4RpjOs5twA+dc+ebWRNwUL6dJ02a5I4++uiYTi0iMj5s3bp1l3Nucth+JQd2M2sF/hy4FMA51w/053vP0UcfTWdnZ6mnFhEZV8zsjSj7xZGKmQHsBL5lZtvMbLWZHRzDcUVEpAhxBPZG4H3A/3XOzQT+ACzO3cnMFphZp5l17ty5M4bTioiInzgCezfQ7Zz7Web7+0kH+hGcc6ucc+3OufbJk0NTRCIiUqSSA7tz7jfADjM7LrNpDvBCqccVEZHixFUVczVwb6Yi5jXgr2M6roiIFCiWwO6cexpoj+NYIiJSmrju2EVEBFi7rYeV67fz5p4+jmxrYdFZxzFv5pSKjkGBXUQkJmu39XD9A8/SNzAEQM+ePq5/4FmAigZ39YoREYnJyvXbh4N6Vt/AECvXb6/oOBTYRURi8uaevoK2l4sCu4hITI5sayloe7kosIuIxGTRWcfRkkyM2NaSTLDorOMC3lEeengqIhKT7ANSVcWIiNSReTOnVDyQ51IqRkSkziiwi4jUGQV2EZE6oxy7iEhG7O0Aujpg4zLo7YbWqTBnCZw0P74BB1BgFxGhDO0Aujrg4YUwkJmc1Lsj/T2UPbgrFSMiQoztALo64JYT4IHPHgjqWQN96Tv4MtMdu4gIMbUDyL1L99PbXeDICqfALiLjQlj+/Mi2Fnp8gnhB7QA2Lssf1CGday8zBXYRqXth+fO123rY1z846n2h7QByH4727sg/kGRL+gFqmSmwi0jdC8ufe4N+VltLkqVz3xv84NTv4SgGOP/9W6epKkZEJC758ud+QR/g4ObG/NUwvmkXx6jgnmyBc26vSEDPUmAXkboRlEfPlz+P/NA0ctrFpe/OK1y77qXALiJ1IV8efdFZx/mmW/b1D9J2UJLf7xsYdbwRD00LSbu0ToMvPBfHf1LRYqljN7NfmtmzZva0mXXGcUwRkULky6PPmzmF5eedSFtLcsTrv983wN53BkkmbMT2UQ9N86ZdPCr0cDRMnHfsH3TO7YrxeCIiowSlW8JSKvNmTmHl+u3s6Rt5dz6QcrS1JDm4uZH2tzdwfdN9HMEu7EdTIZFJowTWno992sWPUjEiUjPypVui1KEHBf/evgGe/tgeePhb/i0AgnLqIWmX2HvPRBRXSwEH/LuZbTWzBTEdU0RkhHzplijL0vlNNprbsJltE67I3wJgzpJ0msUr2cJTx1zNrBWbmLH4UWat2MTabT3DL2c/hHr29OE48CHk3adc4grss51z7wM+AlxlZn+eu4OZLTCzTjPr3LlzZ0ynFZHxJF+6JZtHn9LWggFT2lpYft6JI+6Qc4P/jY1ruDV5J238V/BJe7vT6ZVzbk/foWPQOo2nTryRS546KjBwx9Z7pgixpGKccz2Zv98ysweB04Cf5OyzClgF0N7eHlDBLyISLCzdErYsnXdN0lPe3sCnEo/TYIG7p2VbAJw0f0T+/JoVm+jLucP3PqyNpfdMkUq+Yzezg83sj7JfAx8CxrbWR0TqUpR0S5h5iS1saV7IbU13hgf
1PFUuYYE7qMdMQb1nihRHKuYIYLOZPQP8HHjUOffDGI4rIjJClHRLoK4OuGlGOpfeuyO3UHE0S+SdMRoWuOP4ECpWyakY59xrwJ/GMBYRkVBh6ZYRhmeLhvRxyZVognPvGBHUcytcPvieyfxga8+IPLo3cHvTPpWuijHnKp/ubm9vd52dmsckImUUpTd6DudgsPEgkufeNiqo585cbUkm+PgpU3jipZ0VC9xmttU51x62n+rYRaQ+RemNnuEc/NYms+OURZw694pRrwdVuDzx0k62LD49luHGSYFdROqHt1FX1LRLsgU753beddJ83hWwy1hWuBRDa56KSH3Ipl56dxA5qLdMjNRSdywrXIqhO3YRqV3eO3RrADe6r/pImQeoBS564dcdslIVLsVQYBeR2pT7cDQgqDvAOeMtm8SO9/nn0MOMZYVLMRTYRaQqRG6Y9ci1sPXuCHfnaT2pSczuvx2AlqcSLJ/WU1RALqjMcowpsItIRfkFcCDvYtPDHrkW1/nN8MlFGftcEzcPHki3eKf81zMFdhGpmKC2uxOSDXkXychKbf1WaMXHIA004uhOHc7Ng/NZl5o94vXcSpaxaq1bTqqKEZGKCaoH91uaDkYHYXOpvMff55r4Yv/nYOkeLjjoG6OCOoysZBnL1rrlpMAuIhUTte57bsNmNjct5NUJn4RbTkg/KAWGnH/Icg66U5NYPPAZOg89E4jWq2UsW+uWk1IxIlIxQW1321qS7B9M0TcwxNyGzaxIruYg60+/2LuDwYeuphF4qOFDnJf6IeZJsjsH9wydwQ2Dl9GSTLC8gF4ttTbxKCoFdhGpmKB68KVz38uUHY8w7RcrOcLtHBG4ARqH3mHfY0tIzN3Ivz44xAVsJEGKIRr416HTuWHwMqb4BO6wSpYoy+nVIgV2EamYwLvoxBZ49gagj6CSlwl9v2HezCms5TY+kPP+Xxb5sLPWJh5Fpe6OIjJ2RrTVza87NYmpy16NfQi1VBWj7o4iUp0KnGAE6WqX1U0Xs7QMw6mliUdRKbCLSOXuWr89F17/ceTdnYMeN4lbuZDZH10Q/3jqlAK7yDgXNGkIiCe4d3XAY9dB3+6C3tZHM4sHLqfz0DOrOj1SjRTYRca5fLXcJQfTrg546CoY6i/sfa3TaJmzhNsidl+UkRTYRca5stZyb1xWWFBPtkTqjy75aeapyDhX1kUkeruj79s6TUE9JgrsIuNclKn3RWudGr7PjA/A0l74wnMK6jGJLRVjZgmgE+hxzp0d13FFJB5BlS9lXURizpLgHHvLRPjITQrmZRBnjv3zwIvAoTEeU0RiEFb5UrZa7mzQ9lbFKKCXXSyB3cymAh8F/gG4No5jikh8wipfylrHftJ8BfEKi+uO/Vbgb4E/CtrBzBYACwCmT58e02lFJIqgCpeePX0cvfjR7BLPw9sC69hza9J1912VSn54amZnA28557bm2885t8o51+6ca588eXKppxWRAoRVuOR2jPLtSZ6tSfdONOrbDWuvHO6X7mftth5mrdjEjMWPMmvFphGLWOR7TYoXR1XMLGCumf0S+B5wupl9J4bjikhM/Cpfwoy6yw+qSU8NpF/zkW+FonpdvagalJyKcc5dD1wPYGZ/AXzJOXdxqccVkfjkVr5E6ek66i4/X016wGtBuf0vdjzDkE9n2fGy2HS5aeapyDjhrXyZtWKT7wITWb517K1Tg9vrBtSrB+X2/YJ62HskulgnKDnnfqQadpGxF5a79kvNZNe3mNLWwvLzThx91zxnCUOWHHWufpfgqWOu9h1HMbNXa331omqgmacidSZK7nrezCksP+9EprS1YMDXWu7h1ZaL+eWEi9iy/3zm9Xxt9IFPms9XElex2x2Cc+mWur9LHcKXBq7gmheO9R1Lobn9eli9qBooFSNSZ6J2a5yX2MK85mUwYcfIshg3BJ3fTH999tdHHOfbe0/jbk4bdU4LSJ/MmzmFzjd2c++TvwrN6/utWSrFUWA
\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Data points, overlaid with the predictions of the fitted parameters\n", "plt.scatter(x[:,0],y)\n", "# detach() lets matplotlib convert the prediction even if `a` tracks gradients\n", "plt.scatter(x[:,0],(x@a).detach());" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Animate it!" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from matplotlib import animation, rc\n", "# Render animations as interactive JavaScript/HTML players in the notebook\n", "rc('animation', html='jshtml')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "\n", "\n", "\n", "
\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a = nn.Parameter(tensor(-1.,1))\n", "\n", "fig = plt.figure()\n", "plt.scatter(x[:,0], y, c='orange')\n", "line, = plt.plot(x[:,0], x@a)\n", "plt.close()\n", "\n", "def animate(i):\n", " update()\n", " line.set_ydata(x@a)\n", " return line,\n", "\n", "animation.FuncAnimation(fig, animate, np.arange(0, 100), interval=20)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In practice, we don't calculate on the whole file at once, but we use *mini-batches*." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Vocab" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Learning rate\n", "- Epoch\n", "- Minibatch\n", "- SGD\n", "- Model / Architecture\n", "- Parameters\n", "- Loss function\n", "\n", "For classification problems, we use *cross entropy loss*, also known as *negative log likelihood loss*. This penalizes incorrect confident predictions, and correct unconfident predictions." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 1 }