{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Here we review the optimizers used in machine learning.\n", "# Gradient Descent" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from matplotlib import pyplot as plt\n", "%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data\n", "Let's use a simple dataset: the salaries of developers and machine learning engineers in five Chinese cities in 2019." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "\n", "# Developer salaries in Beijing, Shanghai, Hangzhou, Shenzhen and Guangzhou in 2019\n", "x = [13854,12213,11009,10655,9503] \n", "# Column vector of shape (5, 1), scaled down by 10,000\n", "x = np.reshape(x, (5,1)) / 10000.0\n", "\n", "# Machine learning engineer salaries in the same five cities\n", "y = [21332, 20162, 19138, 18621, 18016] \n", "y = np.reshape(y, (5,1)) / 10000.0" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Functions\n", "Model (hypothesis), where $\\epsilon$ is the noise term:\n", "$$y=ax+b+\\epsilon$$\n", "Cost function, where $\\hat{y}_i = ax_i+b$ is the prediction for sample $i$ and $n$ is the number of samples:\n", "$$J(a,b)=\\frac{1}{2n}\\sum_{i=1}^{n}(y_i-\\hat{y}_i)^2$$\n", "Optimizer (update rule), with learning rate $\\alpha$:\n", "$$\\theta = \\theta - \\alpha \\frac{\\partial J}{\\partial \\theta}$$\n", "For univariate linear regression this becomes:\n", "$$a = a - \\alpha \\frac{\\partial J}{\\partial a}$$\n", "$$b = b - \\alpha \\frac{\\partial J}{\\partial b}$$\n", "\n", "where $\\frac{\\partial J}{\\partial a}$ and $\\frac{\\partial J}{\\partial b}$ are:\n", "\n", "$$ \\frac{\\partial J}{\\partial a} = \\frac{1}{n}\\sum_{i=1}^{n}x_i(\\hat{y}_i-y_i)$$\n", "\n", "\n", "$$ \\frac{\\partial J}{\\partial b} = \\frac{1}{n}\\sum_{i=1}^{n}(\\hat{y}_i-y_i)$$" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "def model(a, b, x):\n", " # Linear prediction y_hat = a*x + b\n", " return a*x + b\n", "\n", "def cost_function(a, b, x, y):\n", " # Half mean squared error J(a, b) over the n = 5 samples\n", " n = 5\n", " return 0.5/n * (np.square(y-a*x-b)).sum()\n", "\n", "# Note: despite its name, sgd takes one full-batch gradient descent step over all samples.\n", "def sgd(a,b,x,y):\n", " n = 
5\n", " alpha = 1e-1\n", " y_hat = model(a,b,x)\n", " da = (1.0/n) * ((y_hat-y)*x).sum()\n", " db = (1.0/n) * ((y_hat-y).sum())\n", " a = a - alpha*da\n", " b = b - alpha*db\n", " return a, b\n" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "def iterate_sgd(a,b,x,y,times):\n", " for i in range(times):\n", " a,b = sgd(a,b,x,y)\n", "\n", " y_hat=model(a,b,x)\n", " cost = cost_function(a, b, x, y)\n", " print(a,b,cost)\n", " plt.scatter(x,y)\n", " plt.plot(x,y_hat)\n", " plt.show()\n", " return a,b, cost" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.950768563083351 0.8552812669346652 0.00035532090622957674\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nO3deXhU5f3+8fdDFhJCIEDYEghhDUtI2BFc6o64ooKtVq2iYv3Wn7VCRNS6FJcqWpdqpbRFal1aEhZxpS4origIkwRI2LcECFsSCFlnnt8foCINEGBmziz367q4JDOHOff1kNx+OHPmHGOtRUREgl8jpwOIiIh3qNBFREKECl1EJESo0EVEQoQKXUQkREQ6tePExESbmprq1O5FRILSkiVLdlprW9f3nGOFnpqayuLFi53avYhIUDLGbDzSczrkIiISIlToIiIhQoUuIhIiVOgiIiFChS4iEiJU6CIiIUKFLiISIlToIiJeVuv28OePVrO8uMyv+3Xsg0UiIqFo1fa9jJ/pIq+ojBq3hz5Jzf22bxW6iIgXuD2Wv322jj/9dxVNYyL5yy8HcGHf9n7NoEIXETlJ63bsY0K2i+82lTKiT1sevbwviU0b+z2HCl1E5AR5PJYZX27gyfkFREc04tmf9+OyfkkYYxzJo0IXETkBm3fvZ0K2i0Xrd3NWWmv+eGUGbZvFOJpJhS4ichystby2aBOPvbuSRsbw5JUZjBnUwbGp/FAqdBGRBiourWTirFw+W72T07ol8sToDJITYp2O9QMVuojIMVhryVmyhT+8tQK3tUwelc61Q1MCYio/lApdROQoSsqrmDQ7j48KShiS2pIpYzLo1CrO6Vj1UqGLiNTDWss8VzEPzltOZY2b+y/qxdhTO9OoUWBN5YdSoYuIHGbXvmrun5vPe/nb6NcxgaevyqRr66ZOxzomFbqIyCHez9/KfXPy2VtVx90XpDHu9C5ERgTHZa9U6CIiQOn+Gh6at5y5y4rpk9SM12/pR1q7eKdjHRcVuoiEvQUFJUyclcvuihruPLc7vzmrG1FBMpUfSoUuImGrvKqWR95ewczFW0hrG8/0GwaTnuy/qyN6mwpdRMLS56t3cneOi23lVdx2ZlfuPLc7jSMjnI51UlToIhJWKqrrePy9lbz69Sa6tI4j57bhDEhp4dV9zF1axJT5hRSXVpKUEEvW
iDRG9U/26j7qo0IXkbCxaN0usnJy2bxnPzed1pmsEWnERHl3Kp+7tIhJs/OorHUDUFRayaTZeQA+L3UVuoiEvKpaN1PmFzL9i/V0bNGE/4wbxpDOLX2yrynzC38o8+9VHty/Cl1E5CQs3bSH8dku1u2o4LpTOnHPyJ7ENfZd9RWXVh7X496kQheRkFRd5+bZD1fz10/X0q5ZDK/eNJTTuif6fL9JCbEU1VPeSX64KmPwnWgpInIM+UVlXPrnL3jpk7WMHtiB9393hl/KHCBrRBqxhx2Xj42KIGtEms/3rQldREJGrdvDiwvW8MLHa2gZF830GwZxds+2fs3w/XFyneUiInKCCrftZXz2MvKLyhnVL4mHLu1DQpNoR7KM6p/slwI/nApdRIJandvDXxeu47kPVxMfE8nUawdyQXo7p2M5QoUuIkFr7Y59jJ/pYtnmUkamt+ORUem0atrY6ViOUaGLSNDxeCzTv1jPlPmFxEZH8PzV/bkko33A3RLO31ToIhJUNu6qICs7l2827Oacnm14/Iq+tGkW43SsgKBCF5Gg4PFYXlu0kcffKyDCGKaMzmD0wA5hP5UfSoUuIgGvqLSSu3NcfLFmF6d3T+SJKzP88kGdYKNCF5GAZa0le/EW/vD2CjzW8ujl6VwzJEVT+RGo0EUkIG0vr+KeWbksKNzB0M4teWpMJh1bNnE6VkA75kf/jTEdjTELjDErjTHLjTG/rWebnsaYr4wx1caYCb6JKiLhwFrL3KVFnP/MQr5at4sHL+nNG7ecojJvgIZM6HXAeGvtd8aYeGCJMeYDa+2KQ7bZDdwBjPJFSBEJDzv3VXPfnDzmL9/OgJQEnhqTSZfWTZ2OFTSOWejW2q3A1oO/32uMWQkkAysO2aYEKDHGXOSroCIS2t7N28r9c/PZV1XHpJE9ufn0LkQ00rHy43Fcx9CNMalAf2DRiezMGDMOGAeQkpJyIi8hIiFmT0UND85bzjxXMX2Tm/P0VZn0aBvvdKyg1OBCN8Y0BWYBd1pry09kZ9baacA0gEGDBtkTeQ0RCR0frtjOpDl57Kmo4a7zenDbmV2JitBVvU9UgwrdGBPFgTJ/zVo727eRRCTUlVfV8oe3VpCzZAs928Uz48bB9Elq7nSsoHfMQjcHTvj8B7DSWvsn30cSkVC2cNUOJs7KpWRvNbef1Y07zulOdKSmcm9oyIR+KnAdkGeMWXbwsXuBFABr7VRjTDtgMdAM8Bhj7gR6n+ihGREJPfuq63js3ZW8vmgT3do0Zfa1A8nsmOB0rJDSkLNcPgeO+laztXYb0MFboUQktHy1dhdZOS6KSisZd0YX7jqvBzGH3aZNTp4+KSoiPlNZ4+bJ+QW8/MUGOrVqwsxbhzE4taXTsUKWCl1EfGLJxj1MyHaxfmcFvxrWiYkje9IkWpXjS1pdEfGqqlo3z3y4ir8tXEf75rG8fvNQhndLdDpWWFChi4jX5G0p466Zy1hdso+rh3Tk3gt7ER8T5XSssKFCF5GTVlPn4YWPV/PiJ2tp3bQxM24czJlpbZyOFXZU6CJyUlZuLWf8TBcrtpZzxYBkHry4D82baCp3ggpdRE5IndvD1E/X8txHq2keG8W06wZyfp92TscKayp0ETlua0r2Mn6mC9eWMi7KaM/ky9JpGRftdKywp0IXkQZzeyzTP1/PlP8WEhcdwQvX9OfijCSnY8lBKnQRaZANOyuYkO1i8cY9nNe7LY9enk6b+BinY8khVOgiclQej+VfX2/kj+8VEBlh+NNVmVzeP1k3ag5AKnQROaLNu/czcVYuX67dxc96tOaJKzNo11xTeaBSoYvI/7DW8p9vNzP57QN3mnz8ir78YnBHTeUBToUuIj+xrayKe2bn8knhDoZ1acWTozPo2LKJ07GkAVToIgIcmMrnLC3ioXnLqXVbHr60D9ed0olGulFz0FChiwg79lZz75w8PlixnUGdWvDUmExSE+OcjiXHSYUuEubezi3m93Pzqahx
c9+FvRh7WmciNJUHJRW6SJjaXVHD79/M553crWR2aM7TV2XSrU2807HkJKjQRcLQf5dv4945+ZRV1pA1Io1bz+hCZIRu1BzsVOgiYaSsspaH31rO7O+K6NW+Ga+MHULvpGZOxxIvUaGLhIlPCku4Z1YeO/ZVc8fZ3bj97O5ER2oqDyUqdJEQt6+6jkffWcEb32yme5umTLt+IBkdEpyOJT6gQhcJYV+u3UlWdi5byyq59Wdd+N25PYiJinA6lviICl0kBO2vqeOJ9wr451cb6ZwYR/avhzOwUwunY4mPqdBFQsziDbuZkO1iw6793DA8lYkX9CQ2WlN5OFChi4SIqlo3f/pgFX/7bB0dWsTy73GncEqXVk7HEj9SoYuEANfmUsZnu1hTso9rhqZw74W9aNpYP97hRn/jIkGsps7D8x+t5qVP19ImvjGvjB3CGT1aOx1LHKJCFwlSy4vLGD/TRcG2vYwe2IHfX9yb5rFRTscSB6nQRYJMrdvDS5+s5fmPVtMiLpq/Xz+Ic3u3dTqWBAAVukgQWbV9L+NnusgrKuPSzCQevrQPLeKinY4lAUKFLhLA5i4tYsr8QopKK2kWE0lFjZvmsVH85ZcDuLBve6fjSYBRoYsEqLlLi5g0O4/KWjcA5VV1NDLwu3O7q8ylXroyj0iAevL9gh/K/HseC1M/XedQIgl0KnSRALR5936Ky6rqfa64tNLPaSRY6JCLSACx1vL6N5t47J2VGMDWs01SQqy/Y0mQ0IQuEiCKSyu5fvo33Dcnn/4pLXjg4t7EHnZlxNioCLJGpDmUUAKdJnQRh1lryVmyhT+8vQK3xzJ5VDrXDk3BGEOLuGimzC+kuLSSpIRYskakMap/stORJUAds9CNMR2BV4B2gAeYZq197rBtDPAccCGwH7jBWvud9+OKhJaSvVXcOzuPD1eWMCS1JVPGZNCpVdwPz4/qn6wClwZryIReB4y31n5njIkHlhhjPrDWrjhkm5FA94O/hgIvHfyviNTDWstbuVt54M18Kmvc3H9RL8ae2plGjYzT0SSIHbPQrbVbga0Hf7/XGLMSSAYOLfTLgFestRb42hiTYIxpf/DPisghdu2r5vdv5vNu3jb6dUzg6asy6dq6qdOxJAQc1zF0Y0wq0B9YdNhTycDmQ77ecvAxFbrIId7P38Z9c/LYW1XH3RekMe70LkRG6NwE8Y4GF7oxpikwC7jTWlt++NP1/JH/OePKGDMOGAeQkpJyHDFFglvZ/loenJfP3GXF9Elqxuu39COtXbzTsSTENKjQjTFRHCjz16y1s+vZZAvQ8ZCvOwDFh29krZ0GTAMYNGhQfafYioScBQUlTJyVy+6KGn57TnduP7sbUZrKxQcacpaLAf4BrLTW/ukIm80DbjfG/JsDb4aW6fi5hLu9VbU88vZK/rN4M2lt45l+w2DSk5s7HUtCWEMm9FOB64A8Y8yyg4/dC6QAWGunAu9y4JTFNRw4bfFG70cVCR5frNnJ3Tm5bC2r5LYzu3Lnud1pHKkbNYtvNeQsl8+p/xj5odtY4DfeCiUSrCqq6/jjewX86+uNdGkdR85twxmQ0sLpWBIm9ElRES/5Zv1uJmS72LxnPzed1pmsEWnERGkqF/9RoYucpKpaN1PmFzL9i/V0bNGEf99yCkO7tHI6loQhFbrISVi6aQ/js12s21HBdad04p6RPYlrrB8rcYa+80ROQHWdm+c+XM3UT9fSrlkMr940lNO6JzodS8KcCl3kOOUXlTEh20XBtr1cNagD91/cm2YxUU7HElGhizRUrdvDiwvW8MLHa2gZF830GwZxds+2TscS+YEKXaQBCrftZXz2MvKLyhnVL4mHLu1DQpNop2OJ/IQKXeQo6twepn22jmc/WE18TCRTrx3IBentnI4lUi8VusgRrN2xj/EzXSzbXMrI9HY8MiqdVk0bOx1L5IhU6CKH8XgsL3+5gSffLyA2OoLnr+7PJRntOXBZI5HApUIXOcSmXfuZ
kOPim/W7OadnGx6/oi9tmsU4HUukQVToIhy4Jdyrizbx+LsriTCGKaMzGD2wg6ZyCSoqdAl7RaWVTMzJ5fM1Ozm9eyJPXJlBUkKs07FEjpsKXcKWtZbsxVuY/PYK3Nby6OXpXDMkRVO5BC0VuoSl7eVVTJqdx8cFJQzt3JIpozNJadXE6VgiJ0WFLmHFWss8VzEPvLmc6jo3D1zcmxuGp9KokaZyCX4qdAkbO/dVc/+cfN5fvo0BKQk8NSaTLq2bOh1LxGtU6BIW3svbyn1z89lXVcc9I3tyy+ldiNBULiFGhS4hrXR/DQ+8uZx5rmL6Jjfn6asy6dE23ulYIj6hQpeQ9dHK7dwzO489FTXcdV4PbjuzK1ERjZyOJeIzKnQJOeVVtfzhrRXkLNlCz3bxvHzDYNKTmzsdS8TnVOgSUhau2sHEWbmU7K3m9rO6ccc53YmO1FQu4UGFLiGhorqOx95dyWuLNtG1dRyzbhtOv44JTscS8SsVugS9r9ftIivHxZY9ldxyemfGn59GTFSE07FE/E6FLkGrssbNk/MLePmLDXRq1YSZtw5jcGpLp2OJOEaFLkFpycY9TMh2sX5nBb8a1omJI3vSJFrfzhLe9BMgQaWq1s0zH67ibwvX0b55LK/fPJTh3RKdjiUSEFToEjTytpRx18xlrC7Zxy8Gd+S+i3oRHxPldCyRgKFCl4BXU+fhhQVreHHBGhKbRvPyjYM5K62N07FEAo4KXQJawbZy7vqPixVby7mifzIPXtKH5k00lYvUR4UuAanO7eGvC9fx7IeraB4bxbTrBnJ+n3ZOxxIJaCp0CThrSvYyPjsX1+ZSLspoz+TL0mkZF+10LJGAp0KXgOH2WKZ/vp4p/y0kLjqCF67pz8UZSU7HEgkaKnTxmblLi5gyv5Di0kqSEmLJGpHGqP7J9W67YWcFWTkuvt2wh3N7teWxK9JpEx/j58QiwU2FLj4xd2kRk2bnUVnrBqCotJJJs/MAflLqHo/l1UUbefzdAiIjDE+PyeSKAcm6UbPICVChi09MmV/4Q5l/r7LWzZT5hT8U+pY9+7k7J5cv1+7ijB6teeLKvrRvHutEXJGQoEIXnygurTzi49Za/vPtZh55ZyXWWh6/oi+/GNxRU7nISVKhi08kJcRSVE+pt20Ww40zvuWTwh0M69KKJ0dn0LFlEwcSioQeXflffCJrRBqxh13CNirCUFZZy9frdvHwpX147eahKnMRLzrmhG6MmQ5cDJRYa9Preb4FMB3oClQBY621+d4OKsHl++PkU+YXUlRaSUxkI6rqPGR0aMZTYzLpnBjncEKR0NOQCX0GcMFRnr8XWGatzQCuB57zQi4JAaP6J3Pvhb1o0SQKD3Dfhb2YeeswlbmIjxxzQrfWLjTGpB5lk97A4we3LTDGpBpj2lprt3snogSjPRU1/P7NfN7O3Upmh+Y8fVUm3drEOx1LJKR5401RF3AF8LkxZgjQCegA/E+hG2PGAeMAUlJSvLBrCUQfrNjOpNl5lFXWMOH8Hvz6Z12JjNDbNSK+5o1C/yPwnDFmGZAHLAXq6tvQWjsNmAYwaNAg64V9SwApq6zl4beWM/u7Inq1b8YrY4fQO6mZ07FEwsZJF7q1thy4EcAcOJF4/cFfEkY+XbWDiTm57NhXzR1nd+P2s7sTHampXMSfTrrQjTEJwH5rbQ1wM7DwYMlLGNhXXcej76zgjW82061NU6ZdP5CMDglOxxIJSw05bfEN4Ewg0RizBXgQiAKw1k4FegGvGGPcwArgJp+llYDy5dqdZGXnUlxWya1ndOF35/Ug5rBzz0XEfxpylsvVx3j+K6C71xJJwNtfU8eT7xcy48sNpLZqQs6vhzGwU0unY4mEPX30X47L4g27mZDtYsOu/dwwPJWJF/QkNlpTuUggUKFLg1TVuvnTB6v422frSE6I5Y1bTmFY11ZOxxKRQ6jQ5Zhcm0sZn+1iTck+rhmawr0X9qJpY33riAQa/VTKEdXU
eXj+o9W89OlaWjdtzD/HDuFnPVo7HUtEjkCFLvVaUVzOXTOXUbBtL1cO6MADl/SmeWyU07FE5ChU6PITtW4PL32yluc/Wk1Ck2j+dv0gzuvd1ulYItIAKnT5wartexk/00VeURmXZCbxh0v70CIu2ulYItJAKnTB7bH8/bN1PP3fVTSNieTFawZwUUZ7p2OJyHFSoYe59TsrmJDtYsnGPYzo05ZHRvWldXxjp2OJyAlQoYcpj8fyz6828MT7BURHNOLZn/fjsn5JulGzSBBToYehzbv3k5Xj4ut1u2kc2YjyqjqmzC8Efrx1nIgEHxV6GLHW8sY3m3n0nRXUeSxREYbqOg8ARaWVTJqdB6jURYKVLlgdJraWVfKrl7/l3jl59EtJIKFJFLXun95jpLLW/cOkLiLBR4Ue4qy1ZC/ezPnPLOTb9buZfFkf/jV2KCXl1fVuX1xa6eeEIuItOuQSwkr2VnHv7Dw+XFnC4NQWTBmdSWpiHABJCbEU1VPeSQmx/o4pIl6iCT0EWWuZ5yrm/GcW8tnqndx/US/+PW7YD2UOkDUijdjDbkYRGxVB1og0f8cVES/RhB5idu2r5vdv5vNu3jYyOybw9JhMurVp+j/bff/G55T5hRSXVpKUEEvWiDS9ISoSxFToIeT9/G3cPzePsspaskakcesZXYiMOPI/wkb1T1aBi4QQFXoIKNtfy0NvLWfO0iL6JDXj1ZuH0rNdM6djiYifqdCD3ILCEu6ZlcuufTX89pzu3H52N6KOMpWLSOhSoQepvVW1PPL2Sv6zeDM92jbl79cPpm+H5k7HEhEHqdCD0BdrdnJ3Ti5byyr59c+68rvzutM4UjdqFgl3KvQgUlFdxx/fK+BfX2+kS2IcObcNZ0BKC6djiUiAUKEHiW/W72ZCtovNe/Yz9tTOB84jj9ZULiI/UqEHuKpaN0/NL+QfX6ynQ4tY/n3LKQzt0srpWCISgFToAWzppj2Mz3axbkcF156SwqSRvYhrrL8yEamf2iEAVde5ee7D1Uz9dC3tmsXwr5uGcHr31k7HEpEAp0IPMPlFZUzIdlGwbS9XDerA/Rf3pllMlNOxRCQIqNADRK3bw4sL1vDCx2toERfN9BsGcXbPtk7HEpEgokIPAIXb9jI+exn5ReVc1i+Jhy/tQ0KTaKdjiUiQUaE7yO2xTFu4jmc+WEV8TCRTrx3ABentnY4lIkFKhe6QtTv2MSHbxdJNpYxMb8fkUekkNm3sdCwRCWIqdD/zeCwvf7mBJ98vICYqgud+0Y9LM5MwxjgdTUSCnArdjzbt2s+EHBffrN/N2T3b8PgVfWnbLMbpWCISIlTofmCt5bVFm3js3ZVEGMOTozMYM7CDpnIR8SoVuo8Vl1YycVYun63eyendE3niygzdiFlEfEKF7iPWWrKXbGHyWytwW8sjo9L55dAUTeUi4jMqdB8oKa9i0uw8PiooYUjnljw1OpOUVk2cjiUiIS4kCn3u0qKAuHu9tZZ5rmIeeHM5VbVuHri4NzcMT6VRI03lIuJ7xyx0Y8x04GKgxFqbXs/zzYFXgZSDr/eUtfZlbwc9krlLi5g0O4/KWjcARaWVTJqdB+DXUt+5r5r75+Tz/vJt9E9J4KkxmXRt3dRv+xcRacjdhGcAFxzl+d8AK6y1mcCZwNPGGL99bn3K/MIfyvx7lbVupswv9FcE3svbyohnFvJxQQkTL+hJzq+Hq8xFxO+OOaFbaxcaY1KPtgkQbw6829cU2A3UeSVdAxSXVh7X495Uur+GB+ct581lxaQnN+P1Mf1Iaxfv8/2KiNTHG8fQXwDmAcVAPPBza62nvg2NMeOAcQApKSle2DUkJcRSVE95+/rUwI8LtnPPrDx2V9Twu3N78H9ndSUqoiH/4BER8Q1vNNAIYBmQBPQDXjDGNKtvQ2vtNGvtIGvtoNatvXPDhqwRacRG/fTemrFREWSNSPPK6x+uvKqWrGwXY2cspmVcNHN/cyq/Pbe7ylxEHOeN
Cf1G4I/WWgusMcasB3oC33jhtY/p+zc+/XGWy2erdzAxJ5dt5VX85qyu3HFOdxpH6kbNIhIYvFHom4BzgM+MMW2BNGCdF163wUb1T/bpGS0V1XU8/t5KXv16E11axzHrtuH0T2nhs/2JiJyIhpy2+AYHzl5JNMZsAR4EogCstVOBycAMY0weYICJ1tqdPkvsZ4vW7SIrJ5fNe/Zzy+mdGX9+GjFRmspFJPA05CyXq4/xfDFwvtcSBYjKmgOnPr785XpSWjZh5q3DGJza0ulYIiJHFBKfFPW27zbtYcJMF+t2VnD9sE7cM7InTaK1VCIS2NRSh6iuc/PMB6uZtnAt7ZvH8trNQzm1W6LTsUREGkSFflDeljLGZy9j1fZ9/GJwR+67qBfxMVFOxxIRabCwL/SaOg8vLFjDiwvWkNg0mpdvHMxZaW2cjiUictzCutALtpUzfqaL5cXlXN4/mYcu6UPzJprKRSQ4hWWh17k9/HXhOp79cBXNY6P463UDGdGnndOxREROStgV+pqSfYzPduHaXMpFfdszeVQ6LeP8dnFIERGfCZtCd3ssL3+xninzC4mNjuDPV/fnkswkp2OJiHhNWBT6xl0VTMh28e2GPZzbqy2PXZFOm/gYp2OJiHhVSBe6x2N5bdFGHnu3gMgIw1NjMrlyQLJu1CwiISlkC33Lnv1MnJXLF2t2cXr3RJ4cnUH75r69RrqIiJNCrtCttcxcvJnJb6/EWstjl/fl6iEdNZWLSMgLqULfXl7FPbNyWVC4g1O6tGTK6Ew6tmzidCwREb8ImUL/bPUOfvPad9S4PTx0SW+uH5ZKo0aaykUkfIRMoae2iqNfSgsevrQPnRPjnI4jIuJ3IVPoHVs24ZWxQ5yOISLiGN3ZWEQkRKjQRURChApdRCREqNBFREKECl1EJESo0EVEQoQKXUQkRKjQRURChLHWOrNjY3YAGx3Zue8lAjudDhEgtBY/0lr8SGvxU8ezHp2sta3re8KxQg9lxpjF1tpBTucIBFqLH2ktfqS1+ClvrYcOuYiIhAgVuohIiFCh+8Y0pwMEEK3Fj7QWP9Ja/JRX1kPH0EVEQoQmdBGREKFCFxEJESr0E2SMmW6MKTHG5B/heWOMed4Ys8YYk2uMGeDvjP7SgLXoaYz5yhhTbYyZ4O98/tSAtfjlwe+HXGPMl8aYTH9n9KcGrMdlB9dimTFmsTHmNH9n9JdjrcUh2w02xriNMaOPdx8q9BM3A7jgKM+PBLof/DUOeMkPmZwyg6OvxW7gDuApv6Rx1gyOvhbrgZ9ZazOAyYT+m4MzOPp6fARkWmv7AWOBv/sjlENmcPS1wBgTATwBzD+RHajQT5C1diEHiupILgNesQd8DSQYY9r7J51/HWstrLUl1tpvgVr/pXJGA9biS2vtnoNffg108EswhzRgPfbZH8/MiANC9iyNBnQGwP8DZgElJ7IPFbrvJAObD/l6y8HHRL53E/Ce0yGcZoy53BhTALzDgSk9LBljkoHLgakn+hoqdN8x9TwWstOHHB9jzFkcKPSJTmdxmrV2jrW2JzCKA4ehwtWzwERrrftEXyDSi2Hkp7YAHQ/5ugNQ7FAWCSDGmAwOHCseaa3d5XSeQGGtXWiM6WqMSbTWhuOFuwYB/zbGwIGLdV1ojKmz1s5t6AtoQvedecD1B892OQUos9ZudTqUOMsYkwLMBq6z1q5yOo/TjDHdzMEGO3gmWDQQlv+Ts9Z2ttamWmtTgRzg/46nzEET+gkzxrwBnAkkGmO2AA8CUQDW2qnAu8CFwBpgP3CjM0l971hrYYxpBywGmgEeY8ydQG9rbblDkX2mAd8XDwCtgL8c7LG6UL7qYAPW40oODD61QCXw80PeJA0pDViLk99HiK6diEjY0SEXEZEQoUIXEXwiXsEAAAAkSURBVAkRKnQRkRChQhcRCREqdBGREKFCFxEJESp0EZEQ8f8BY85ggZiYX+QAAAAASUVORK5CYII=\n", 
"text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "a=0\n", "b=0\n", "_, _, sgd_cost = iterate_sgd(a,b,x,y,100)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.00035532090622957674" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sgd_cost" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "After 100 iterations, the fit has nearly converged. We record the final cost so that, when we explore other optimizers below, we can compare how many iterations each one needs to reach the same cost." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "def iterate(a, b, x, y, target_cost, func):\n", " i=0\n", " for i in range(1000):\n", " a,b = func(a,b,x,y)\n", " cost = cost_function(a, b, x, y)\n", " if cost