{ "cells": [
 { "cell_type": "markdown", "metadata": {}, "source": [ "# Gradient descent" ] },
 { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np" ] },
 { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [
  "# From cost_function.ipynb\n",
  "def costFunction(X, y, theta):\n",
  "    \"\"\"Return the squared-error cost sum((X @ theta - y) ** 2) / (2 * m).\n",
  "\n",
  "    X is the (m, n) design matrix, where m = len(y). y and theta may each\n",
  "    be a 1-D vector or a column vector. The result is a plain Python float.\n",
  "    \"\"\"\n",
  "    m = len(y)\n",
  "    # Flatten both sides first: subtracting a 1-D y from a column-shaped\n",
  "    # X @ theta would otherwise broadcast into an (m, m) matrix.\n",
  "    err = np.reshape(X @ theta, -1) - np.reshape(y, -1)\n",
  "    return (1 / (2 * m)) * float(err @ err)\n",
  "\n",
  "\n",
  "def gradientDescent(X, y, theta, alpha, num_iters):\n",
  "    \"\"\"Run batch gradient descent for linear regression.\n",
  "\n",
  "    Takes num_iters steps of size alpha starting from theta.\n",
  "\n",
  "    Returns (theta, J_history): the fitted parameters, in the same shape\n",
  "    as the theta passed in, and a (num_iters, 1) array holding the cost\n",
  "    measured after each step.\n",
  "    \"\"\"\n",
  "    theta_shape = np.shape(theta)\n",
  "    # Work on float column vectors so the update below cannot broadcast\n",
  "    # a 1-D y or theta into the wrong shape.\n",
  "    theta = np.asarray(theta, dtype=float).reshape(-1, 1)\n",
  "    y = np.asarray(y, dtype=float).reshape(-1, 1)\n",
  "    m = len(y)\n",
  "    J_history = np.zeros((num_iters, 1))\n",
  "\n",
  "    for step in range(num_iters):\n",
  "        err = X @ theta - y\n",
  "        # Gradient of the cost is (1 / m) * X^T @ err.\n",
  "        theta = theta - alpha * (1 / m) * np.transpose(np.transpose(err) @ X)\n",
  "        J_history[step, 0] = costFunction(X, y, theta)\n",
  "\n",
  "    return (theta.reshape(theta_shape), J_history)"
 ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "Let's run it against some data:" ] },
 { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Actual theta_0: 100 Gradient descent theta_0: 99.99811396616462\n", "Actual theta_1: 40 Gradient descent theta_1: 40.00108261853554\n" ] } ], "source": [
  "actual_theta = np.array([\n",
  "    [100],\n",
  "    [40],\n",
  "])\n",
  "\n",
  "X = np.array([\n",
  "    [1, 0.8],\n",
  "    [1, 2.3],\n",
  "    [1, 1.6],\n",
  "])\n",
  "\n",
  "y = X @ actual_theta\n",
  "\n",
  "initial_theta = np.zeros((2, 1))\n",
  "\n",
  "alpha = 0.5  # learning rate\n",
  "num_iters = 200  # number of gradient descent steps\n",
  "\n",
  "(theta, history) = gradientDescent(X, y, initial_theta, alpha, num_iters)\n",
  "\n",
  "# Sanity check: the fit should recover the parameters that generated y.\n",
  "assert np.allclose(theta, actual_theta, atol=1e-2)\n",
  "\n",
  "print(\"Actual theta_0:\", actual_theta.item(0, 0), \" Gradient descent theta_0:\", theta.item(0, 0))\n",
  "print(\"Actual theta_1:\", actual_theta.item(1, 0), \" Gradient descent theta_1:\", theta.item(1, 0))"
 ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "Just to ensure that the error really is 
decreasing with each run, let's plot the cost history across runs." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYQAAAD4CAYAAADsKpHdAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAXL0lEQVR4nO3dfbBc9X3f8feXqyeuQDyIiyIkgQTIsXlIa6OhFAjplNjIOEa0rht5GqNpmdGUwS207rQwZHD+UcbuQ2LzB8pQ4yJSCqY2DkonpFaUuAkZDL48hQch6xoI3CAjGT9IBiGQ+PaP87vy6mpXD3fv3bPivF8zZ87Z355z9quzq/3c3+/s7onMRJKkY+ouQJLUHwwESRJgIEiSCgNBkgQYCJKkYlrdBUzUKaeckosXL667DEk6qjz++OM/ysyhdvcdtYGwePFihoeH6y5Dko4qEfG3ne5zyEiSBBgIkqTCQJAkAQaCJKkwECRJgIEgSSoMBEkS0MRAePhh+O3fhj176q5EkvrKIQMhIr4WEdsi4tmWtpMjYkNEbCnzk1ruuzkiRiJic0Rc0dJ+QUQ8U+67LSKitM+MiK+X9kcjYvHk/hPH+e53Yc0a2LVrSh9Gko42h9NDuAtYPq7tJmBjZi4FNpbbRMQ5wErg3LLN7RExULZZC6wGlpZpbJ/XAj/JzLOB3we+NNF/zGGZNauav/32lD6MJB1tDhkImfmXwI/HNa8A1pXldcDVLe33ZebuzHwJGAEujIj5wJzMfCSrS7TdPW6bsX19A7h8rPcwJQwESWproucQ5mXmVoAyP7W0LwBebVlvtLQtKMvj2/fbJjP3AD8D5rZ70IhYHRHDETG8ffv2iVU+Fgi7d09se0l6n5rsk8rt/rLPg7QfbJsDGzPvyMxlmblsaKjtj/Udmj0ESWprooHwehkGosy3lfZRYFHLeguB10r7wjbt+20TEdOAEzhwiGryGAiS1NZEA2E9sKosrwIebGlfWT45tITq5PFjZVhpZ0RcVM4PXDNum7F9/TPgz8t5hqlhIEhSW4e8HkJE3Av8I+CUiBgFvgB8Ebg/Iq4FXgE+DZCZz0XE/cDzwB7g+szcW3Z1HdUnlo4FHioTwJ3AH0bECFXPYOWk/Ms6MRAkqa1DBkJmfqbDXZd3WH8NsKZN+zBwXpv2tymB0hMGgiS11bxvKhsIktSWgSBJAgwESVJhIEiSAANBklQYCJIkoImBMG0aDAwYCJI0TvMCAapegoEgSfsxECRJQFMDYeZMA0GSxmlmIMya5fUQJGmc5gaCPQRJ2o+BIEkCDARJUmEgSJIAA0GSVBgIkiTAQJAkFQaCJAkwECRJhYEgSQIMBElS0dxA2LOnmiRJQJMDAfyBO0lq0exAcNhIkvYxECRJQNMDwSEjSdqn2YFgD0GS9jEQJEmAgVBvHZLUR7oKhIj4dxHxXEQ8GxH3RsSsiDg5IjZExJYyP6ll/ZsjYiQiNkfEFS3tF0TEM+W+2yIiuqnrkAwESTrAhAMhIhYA/xZYlpnnAQPASuAmYGNmLgU2lttExDnl/nOB5cDtETFQdrcWWA0sLdPyidZ1WGbOrOYGgiTt0+2Q0TTg2IiYBgwCrwErgHXl/nXA1WV5BXBfZu7OzJeAEeDCiJgPzMnMRzIzgbtbtpka9hAk6QATDoTM/DvgvwKvAFuBn2Xmt4F5mbm1rLMVOLVssgB4tWUXo6VtQVke336AiFgdEcMRMbx9+/aJlm4gSFIb3QwZnUT1V/8S4DRgdkT81sE2adOWB2k/sDHzjsxclpnLhoaGjrTkXzAQJOk
A3QwZ/TrwUmZuz8x3gQeAi4HXyzAQZb6trD8KLGrZfiHVENNoWR7fPnUMBEk6QDeB8ApwUUQMlk8FXQ5sAtYDq8o6q4AHy/J6YGVEzIyIJVQnjx8rw0o7I+Kisp9rWraZGgaCJB1g2kQ3zMxHI+IbwBPAHuBJ4A7gOOD+iLiWKjQ+XdZ/LiLuB54v61+fmXvL7q4D7gKOBR4q09QxECTpABMOBIDM/ALwhXHNu6l6C+3WXwOsadM+DJzXTS1HZPp0iDAQJKlFM7+pHOFV0yRpnGYGAlSBsGtX3VVIUt9obiAMDhoIktSiuYEweza8+WbdVUhS32huIAwOGgiS1KK5gTB7Nrz1Vt1VSFLfaHYg2EOQpH2aGwgOGUnSfpobCA4ZSdJ+mhsI9hAkaT/NDQR7CJK0n2YHwptvQra99IIkNU5zA2FwEPbuhXfeqbsSSeoLzQ2E2bOrucNGkgQ0ORAGB6u5J5YlCWhyIIz1EAwESQIMBIeMJKlobiA4ZCRJ+2luINhDkKT9NDcQ7CFI0n6aGwieVJak/RgIDhlJEtDkQHDISJL209xAcMhIkvbT3ECYPh2mTXPISJKK5gYCeBlNSWrR7EAYHLSHIElFswPBHoIk7WMgGAiSBDQ9EBwykqR9ugqEiDgxIr4RES9ExKaI+IcRcXJEbIiILWV+Usv6N0fESERsjogrWtoviIhnyn23RUR0U9dhs4cgSft020P4CvCnmflB4O8Bm4CbgI2ZuRTYWG4TEecAK4FzgeXA7RExUPazFlgNLC3T8i7rOjyDgwaCJBUTDoSImANcBtwJkJnvZOZPgRXAurLaOuDqsrwCuC8zd2fmS8AIcGFEzAfmZOYjmZnA3S3bTK3Zsx0ykqSimx7CmcB24H9ExJMR8dWImA3My8ytAGV+all/AfBqy/ajpW1BWR7ffoCIWB0RwxExvH379i5KLxwykqR9ugmEacBHgLWZ+WHgTcrwUAftzgvkQdoPbMy8IzOXZeayoaGhI633QJ5UlqR9ugmEUWA0Mx8tt79BFRCvl2Egynxby/qLWrZfCLxW2he2aZ969hAkaZ8JB0Jm/hB4NSJ+uTRdDjwPrAdWlbZVwINleT2wMiJmRsQSqpPHj5VhpZ0RcVH5dNE1LdtMrcFB2LMH3nmnJw8nSf1sWpfb/xvgnoiYAbwI/EuqkLk/Iq4FXgE+DZCZz0XE/VShsQe4PjP3lv1cB9wFHAs8VKap13pNhBkzevKQktSvugqEzHwKWNbmrss7rL8GWNOmfRg4r5taJmQsEH7+czjxxJ4/vCT1k2Z/U3nOnGq+c2e9dUhSHzAQAHbsqLcOSeoDBgLAz35Wbx2S1AeaHQgnnFDN7SFIUsMDwSEjSdrHQAADQZJoeiAcf3w1NxAkqeGBMG1a9W1lA0GSGh4IUA0bGQiSZCBwwgl+7FSSMBDsIUhSYSAYCJIEGAgGgiQVBoKBIEmAgVCdVDYQJMlA2NdDyLaXcZakxjAQ5syB997z2sqSGs9A8PeMJAkwEAwESSoMBANBkgADwYvkSFJhINhDkCTAQPC6ypJUGAj2ECQJMBC8apokFQbC9Olw7LEGgqTGMxDA3zOSJAyEygknwE9+UncVklQrAwFg7lx44426q5CkWhkIYCBIEpMQCBExEBFPRsT/KbdPjogNEbGlzE9qWffmiBiJiM0RcUVL+wUR8Uy577aIiG7rOiIGgiRNSg/hBmBTy+2bgI2ZuRTYWG4TEecAK4FzgeXA7RExULZZC6wGlpZp+STUdfgMBEnqLhAiYiHwCeCrLc0rgHVleR1wdUv7fZm5OzNfAkaACyNiPjAnMx/JzATubtmmN+bOhV27qkmSGqrbHsKXgf8IvNfSNi8ztwKU+amlfQHwast6o6VtQVke336AiFgdEcMRMbx9+/YuS28xd241t5cgqcEmHAgR8RvAtsx8/HA3adOWB2k/sDHzjsxclpnLhoaGDvNhD4OBIElM62LbS4CrIuJKYBYwJyL
+J/B6RMzPzK1lOGhbWX8UWNSy/ULgtdK+sE177xgIkjTxHkJm3pyZCzNzMdXJ4j/PzN8C1gOrymqrgAfL8npgZUTMjIglVCePHyvDSjsj4qLy6aJrWrbpjZNPruY//nFPH1aS+kk3PYROvgjcHxHXAq8AnwbIzOci4n7geWAPcH1m7i3bXAfcBRwLPFSm3rGHIEmTEwiZ+R3gO2X5DeDyDuutAda0aR8GzpuMWibEQJAkv6kMwKxZMDhoIEhqNANhjF9Ok9RwBsIYA0FSwxkIYwwESQ1nIIwxECQ1nIEwxkCQ1HAGwpi5c6urpr333qHXlaT3IQNhzNy5VRj89Kd1VyJJtTAQxvjlNEkNZyCMmTevmv/wh/XWIUk1MRDGzJ9fzbdurbcOSaqJgTDGQJDUcAbCmLlzYfp0A0FSYxkIYyLgl37JQJDUWAZCq9NOg9d6e7E2SeoXBkKr+fPtIUhqLAOhlYEgqcEMhFbz51fXVd69u+5KJKnnDIRWYx899ctpkhrIQGh12mnV3BPLkhrIQGjll9MkNZiB0MpAkNRgBkKroSE45hgDQVIjGQitBgaqXz01ECQ1kIEw3oIFMDpadxWS1HMGwnhLlsBLL9VdhST1nIEw3plnwssvw969dVciST1lIIy3ZAm8+67fRZDUOAbCeGeeWc1ffLHeOiSpxwyE8ZYsqeYGgqSGmXAgRMSiiPiLiNgUEc9FxA2l/eSI2BARW8r8pJZtbo6IkYjYHBFXtLRfEBHPlPtui4jo7p/VhdNPr76L4IllSQ3TTQ9hD/D5zPwQcBFwfUScA9wEbMzMpcDGcpty30rgXGA5cHtEDJR9rQVWA0vLtLyLurozYwYsWmQPQVLjTDgQMnNrZj5RlncCm4AFwApgXVltHXB1WV4B3JeZuzPzJWAEuDAi5gNzMvORzEzg7pZt6uFHTyU10KScQ4iIxcCHgUeBeZm5FarQAE4tqy0AXm3ZbLS0LSjL49vbPc7qiBiOiOHt27dPRuntnXmmPQRJjdN1IETEccA3gRszc8fBVm3TlgdpP7Ax847MXJaZy4aGho682MO1ZEl1TYS33pq6x5CkPtNVIETEdKowuCczHyjNr5dhIMp8W2kfBRa1bL4QeK20L2zTXp+xj546bCSpQbr5lFEAdwKbMvP3Wu5aD6wqy6uAB1vaV0bEzIhYQnXy+LEyrLQzIi4q+7ymZZt6fOAD1fyFF2otQ5J6qZsewiXAZ4F/HBFPlelK4IvARyNiC/DRcpvMfA64H3ge+FPg+swc+32I64CvUp1o/gHwUBd1de+ccyACnnmm1jIkqZemTXTDzHyY9uP/AJd32GYNsKZN+zBw3kRrmXSDg3D22fDss3VXIkk94zeVOzn/fHsIkhrFQOjkvPNgZAR27aq7EknqCQOhk/PPh/feg+efr7sSSeoJA6GT88+v5g4bSWoIA6GTs8+GWbM8sSypMQyETgYGqo+fPv103ZVIUk8YCAdz4YXw6KNeTlNSIxgIB3PppbBzp8NGkhrBQDiYSy6p5g8/XG8dktQDBsLBnHEGLFgAf/3XdVciSVPOQDiYiKqXYA9BUgMYCIdy6aXw6qvwyit1VyJJU8pAOJTLLqvmf/Zn9dYhSVPMQDiUX/kVOP10eLDeSzRI0lQzEA4lAq66CjZs8JKakt7XDITDcfXV1a+ebthQdyWSNGUMhMNx2WVw4onwR39UdyWSNGUMhMMxfXo1bPTAA/Dzn9ddjSRNCQPhcK1eDTt2wL331l2JJE0JA+FwXXxxdY2EtWshs+5qJGnSGQiHKwKuuw6efBIeeaTuaiRp0hkIR+Kzn4WhIbjlFnsJkt53DIQjcdxxcOut8J3vwEMP1V2NJE0qA+FIrV4NZ50Fn/989d0ESXqfMBCO1IwZcPvt8MILVShI0vuEgTARH/tYFQZr18I999RdjSRNCgNhon73d+HXfg1WrfKH7yS9LxgIEzVjBvzxH8MFF8CnPgVf+YqfPJJ0VDMQunH
88dUP3n3yk3DjjfCJT8DISN1VSdKEGAjdmjMHvvlN+PKXq0ttfuhDcM011ZfX7DFIOor0TSBExPKI2BwRIxFxU931HJFjjoEbbqg+efS5z1U/gnfxxbB4cXWO4c47YdMmePfduiuVpI4i++Cv2IgYAL4PfBQYBb4HfCYzn++0zbJly3J4eLhHFR6hHTvgW9+qTjb/1V/Bj35UtQ8MVCFx1lkwb171redTTqnmJ54Ig4PVdOyxv1ieNQumTTtwGhiopoha/6mSji4R8XhmLmt337ReF9PBhcBIZr4IEBH3ASuAjoHQ1+bMqXoGq1ZVw0abN8Njj8H3vw9btsCLL1Zt27d3fxW2gYFfhMQxx1QBMRYSY8sTnTrtoxuTEWD9sI/3Sw06Ot16K/zmb076bvslEBYAr7bcHgX+wfiVImI1sBrg9NNP701l3YqAD36wmtp5660qGHbsqJZ37armY9Pbb8PevbBnT/tp7L5334X33vvFeYvM7qZO++jGZPRG+2Ef75cadPQ66aQp2W2/BEK7P3UOeMVn5h3AHVANGU11UT0xOAhnnFF3FZLUNyeVR4FFLbcXAq/VVIskNVK/BML3gKURsSQiZgArgfU11yRJjdIXQ0aZuSciPgf8X2AA+FpmPldzWZLUKH0RCACZ+SfAn9RdhyQ1Vb8MGUmSamYgSJIAA0GSVBgIkiSgT37LaCIiYjvwtxPc/BTgR5NYzmTq19qs68hY15Hr19reb3WdkZlD7e44agOhGxEx3OnHnerWr7VZ15GxriPXr7U1qS6HjCRJgIEgSSqaGgh31F3AQfRrbdZ1ZKzryPVrbY2pq5HnECRJB2pqD0GSNI6BIEkCGhgIEbE8IjZHxEhE3FRjHYsi4i8iYlNEPBcRN5T234mIv4uIp8p0ZQ21vRwRz5THHy5tJ0fEhojYUuZTc8mmzjX9cssxeSoidkTEjXUdr4j4WkRsi4hnW9o6HqOIuLm85jZHxBU9ruu/RMQLEfE3EfGtiDixtC+OiF0tx+4PelxXx+euV8frILV9vaWulyPiqdLek2N2kPeHqX2NZWZjJqqf1v4BcCYwA3gaOKemWuYDHynLxwPfB84Bfgf4DzUfp5eBU8a1/WfgprJ8E/Clmp/HHwJn1HW8gMuAjwDPHuoYlef1aWAmsKS8Bgd6WNfHgGll+UstdS1uXa+G49X2uevl8epU27j7/xtway+P2UHeH6b0Nda0HsKFwEhmvpiZ7wD3ASvqKCQzt2bmE2V5J7CJ6trS/WoFsK4srwOurrGWy4EfZOZEv6netcz8S+DH45o7HaMVwH2ZuTszXwJGqF6LPakrM7+dmXvKze9SXZGwpzocr056drwOVVtEBPDPgXun6vE71NTp/WFKX2NNC4QFwKstt0fpgzfhiFgMfBh4tDR9rnTvv9broZkigW9HxOMRsbq0zcvMrVC9WIFTa6hrzEr2/w9a9/Ea0+kY9dPr7l8BD7XcXhIRT0bE/4uIX62hnnbPXT8dr18FXs/MLS1tPT1m494fpvQ11rRAiDZttX7uNiKOA74J3JiZO4C1wFnA3we2UnVXe+2SzPwI8HHg+oi4rIYa2orqEqtXAf+7NPXD8TqUvnjdRcQtwB7gntK0FTg9Mz8M/Hvgf0XEnB6W1Om564vjVXyG/f/46Okxa/P+0HHVNm1HfMyaFgijwKKW2wuB12qqhYiYTvVk35OZDwBk5uuZuTcz3wP+O1PYVe4kM18r823At0oNr0fE/FL3fGBbr+sqPg48kZmvlxprP14tOh2j2l93EbEK+A3gX2QZdC7DC2+U5cepxp0/0KuaDvLc1X68ACJiGvBPga+PtfXymLV7f2CKX2NNC4TvAUsjYkn5S3MlsL6OQsrY5J3Apsz8vZb2+S2r/RPg2fHbTnFdsyPi+LFlqhOSz1Idp1VltVXAg72sq8V+f7HVfbzG6XSM1gMrI2JmRCwBlgKP9aqoiFgO/Cfgqsx8q6V9KCIGyvKZpa4Xe1h
Xp+eu1uPV4teBFzJzdKyhV8es0/sDU/0am+qz5f02AVdSnbH/AXBLjXVcStWl+xvgqTJdCfwh8ExpXw/M73FdZ1J9WuFp4LmxYwTMBTYCW8r85BqO2SDwBnBCS1stx4sqlLYC71L9dXbtwY4RcEt5zW0GPt7jukaoxpfHXmd/UNb9VHmOnwaeAD7Z47o6Pne9Ol6daivtdwH/ety6PTlmB3l/mNLXmD9dIUkCmjdkJEnqwECQJAEGgiSpMBAkSYCBIEkqDARJEmAgSJKK/w/jBMS5gJDFsgAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [
  "from matplotlib import pyplot as plt\n",
  "\n",
  "# Cost recorded after each gradient descent step, plotted against the step index.\n",
  "fig, ax = plt.subplots()\n",
  "ax.plot(np.arange(num_iters), history, 'r')\n",
  "ax.set(xlabel='Iteration', ylabel='Cost J(theta)', title='Gradient descent cost history')\n",
  "plt.show()"
 ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.3" } }, "nbformat": 4, "nbformat_minor": 4 }