{ "cells": [ { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "
\n", "
Gradient Descent, Demystified
\n", "
" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "`whoami` " ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "`stu` \n", "Machine Learning Engineer @Opendoor \n", "@mstewart141 \n", "[comment]: <> (we are *so* hiring) " ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "# Goal: Puzzle through the gradient descent algorithm towards a working prototype for linear and logistic regression." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "Per Wikipedia:\n", "> Gradient descent is a first-order iterative optimization algorithm for finding the __minimum__ of a function. To find a local minimum of a function using gradient descent, one takes steps proportional to the negative of the gradient (or of the approximate gradient) of the function at the current point" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## Ok, but what is a gradient?" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "Per Khan Academy:\n", "> The gradient stores all the partial derivative information of a multivariable function.\n", " \n", "The gradient is a vector-valued function: a vector of partial derivatives." 
] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "![title](gd.png)\n", "source: [Wikipedia](https://en.wikipedia.org/wiki/Gradient_descent)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": true, "slideshow": { "slide_type": "slide" } }, "outputs": [], "source": [ "%matplotlib inline\n", "\n", "import numpy as np\n", "import sympy\n", "\n", "from numpy.linalg import inv\n", "from scipy.special import expit\n", "from sklearn.linear_model import LogisticRegression, LinearRegression\n", "from sklearn.datasets import make_classification\n", "from sympy import diff\n", "from sympy.solvers import solve\n", "from sympy.plotting import plot\n", "from toolz import compose, pipe\n", "\n", "from IPython.core.interactiveshell import InteractiveShell\n", "InteractiveShell.ast_node_interactivity = \"all\"\n", "\n", "RS = 47 # random state, reused by make_classification below\n", "np.random.seed(RS) # call the seeding function rather than overwriting it" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## Let's look at an example" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "slideshow": { "slide_type": "fragment" } }, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXcAAAEFCAYAAAAYKqc0AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4wLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvpW3flQAAIABJREFUeJzt3XlcFXXf//HXsIPsyiYgi6ACoqgo\n7kummbmkmWabu2Z3V9vVYnVdXVdd3Ub1a18vy4zKNCvTMtNcMvcUcVcQFJRNRBZBdjjz+wPzJkNF\nPefMWT7Px4OHcpaZt8PxzffMmfmOoqoqQgghLIuN1gGEEELon5S7EEJYICl3IYSwQFLuQghhgaTc\nhRDCAkm5CyGEBZJyF0IICyTlLoQQFshO6wDCsimK4gv0A9oCVcAhIFlVVZ2mwYSwcIqcoSoMQVGU\nIcA8wBvYC5wBnIAOQHvgW+B1VVXLNAsphAWTchcGoSjKa8C7qqqeauY+O2AUYKuq6ndGDyeEFZBy\nF0IICyQfqAqDUhTlC0VRPJp8H6ooygYtMwlhDaTchaFtBX5XFGWkoiizgF+AtzTOJITFM+RuGdnf\nIwDYunUrQ4YMoU2bNuzduxd/f3+tIwlhyhR9LERG7sKgvvjiC6ZPn87nn3/O1KlTGTlyJPv379c6\nlhAWT0buwqBuv/12FixYgK+vLwC7du1izpw57N27V+NkQpgsvYzcpdyF0dXW1uLg4KB1DCFMleyW\nEabrpZdeori4uNn7HBwc2LhxI6tWrTJyKiGsh0w/IAwiNjaW0aNH4+TkRPfu3fHx8aG6upr09HT2\n7dvHzTffzLPPPqt1TCEsluyWEQZx33338cUXX/Dqq6/i6+tLfn4+zs7OREVFMXDgQJydnbWOKISp\n0stuGRm5C4PYs2cPeXl5LF68mF9//fVP91VVVV2x3KdPn86qVavw9fXl0KFDABQXFzNp0iSysrII\nDQ1l2bJleHl5oaoqjzzyCKtXr8bFxYXPPvuM7t27G/TfJoQ5kH3uwiAeeOABhg4dSmpqKvHx8Re/\nevToQXx8/BWfO3XqVNasWfOn2xITExk6dCjp6ekMHTqUxMREAH7++WfS09NJT09nwYIFzJ0712D/\nJiHMicF2y9Q36NQGVcXRztYgyxfmYe7cuXz44YfX/LysrCxGjRp1ceTesWNHNm3aREBAAPn5+Qwe\nPJi0tDTmzJnD4MGDmTx58l8eJ4S5qaytx8XBzrSPlhn46q8s+f0vEwIKK3M9xd6cgoKCi4Xt7+9P\nQUEBALm5uQQHB198XFBQELm5uc0uY8GCBRffQcTExOgllxD6NPvzPXpblsHK3c/DiaQdJ9Hp5HNV\noV+KoqAo1z64mT17NsnJySQnJ8sHusLkpBeUszXjrN6WZ7Byn9o3lMyzFfyWXmioVQgr4ufnR35+\nPgD5+fkXz3gNDAwkOzv74uNycnIIDAzUJKMQN2LR9iwc7fRXyQYr91s7B+Dr5shn27IMtQphRcaM\nGUNSUhIASUlJjB079uLtn3/+OaqqsnPnTjw8PGR/uzA75yrrWJ6Sw+1x+huYGKzcHexsuLd3CL8d\nK+R44XlDrUZYoMmTJ9OnTx/S0tIICgpi4cKFzJs3j3Xr1hEZGcn69euZN28eACNHjiQ8PJyIiAhm\nzZrFBx98oHF6Ia7d0t2nqK7TMbVfqN6WadCTmArLa+iXuJHJvYJ5YWxnQ61HiGsWHx9PcnKy1jGE\noL5Bx6DXNhHs7czS2X3AHOaW8XFzZFTXAL7dk0NZdZ0hVyWEEGZp/dECckurmNYvTK/LNfhJTNP6\nhlFR28C3yTmGXpUQQpidRduyCPJy5uYoP70u1+DlHhvkQY8QL5J2ZMlhkUII0cThvHP8nlnMlD6h\n2NroZW/MRUaZfmBq31BOFlWy6dgZY6xOCCHMQtL2LJztbZkYH
3z1B18jo5T7iM7++Lk7skgOixRC\nCACKztewYl8ed/QIxMPFXu/LN0q529vaMLN/ODX1OtILyo2xSiGEMGlLd2dTW69jat9QgyzfaLNC\n3tEjiP3ZpXwqo3chhJWra9DxxY6TDIhsQ4Svm0HWYbRy927lwPjuQSxPyaG4otZYqxVCCJOz5tBp\nTpdVM02PJy1dyqjzuc/oH0pNvY7FO08ac7VCCGFSNh4tINLXlcEdfA22DqOWe4SvG4M7+pC04yQ1\n9Q3GXLUQQpiEPSdL+H5fHvf0DsFGz4c/NmX0KzHN6B/G2fM1/Lg/39irFkIIzX2y5QQezvZMjA8y\n6HqMXu79I9rQ0c+NT7acwIDz2gghhMk5VVTJ2sOnuTuhHS4Ohr2EtdHLXVEUZvQPI/V0OTuOFxl7\n9UIIoZlPt2Via6MY7PDHpjS5QPaYuLa0cXXgk62ZWqxeCCGM7lxlHcuSsxndtS1+7k4GX58m5e5k\nb8u9vUPYmHpG5noXQliFJbtPUVnbwMz+4UZZnyblDnBv7xAc7Gz4VEbvQggLV1uv47NtWfSLaE10\nW3ejrFOzcm/j6si4uEC+3ZNNUUWNVjGEEMLgfjqYx+myamYOMM6oHTQsd4AZA0IJbePK4p2ntIwh\nhBAGo6oqn2zJJMLXlUGRPkZbr6bl3sHPnQAPJz7fkUV1nZzUJISwPDtOFHE4r4yZ/cMMetLSpTQt\nd4DZA8M5e76W71LkSk1CCMvzyZZMWrdy4PZugUZdr+bl3ie8NV2CPPhkSyYNcqUmIYQFyThzno2p\nZ7ivTwhO9rZGXbfm5a4oCrMHhpN5toJ1Rwq0jiOEEHrzfUoOkb6u3Nc7xOjr1rzcAUbE+BPs7cx/\nNx+XKQmEEBbhTHk1H2/JpFeYN61dHY2+fpModztbG2YNCGfvqVKST5ZoHUcIIW7Yom1Z1Ot0zDLi\n4Y9NmUS5A9zZIxgvF3v++9sJraMIIcQNKa+u48udJ7m1cwChbVppksFkyt3ZwZb7+oSy/mgBGWdk\nSgIhhPn66vdTlFfX88Cg9pplMJlyB5jSJwRHOxs+2SKjdyGEeaqpb+DTbZn0i2hNbJCHZjlMqtxb\nuzpyZ3wQy1NyOVNWrXUcIYS4Ziv35lFQVqPpqB1MrNwBZvYPp3uIJ4t/lykJhBDmRadT+WjzcWLa\nutM/oo2mWUyu3EPbtKK1qyOfbs2krLpO6zhCCNFi644WcKKwgjmD2qMoxptqoDkmV+4Acwe1p7ym\nni92nNQ6ijAxb775JjExMXTu3JnJkydTXV1NZmYmCQkJREREMGnSJGpra7WOKayQqqp89Ntxgr2d\nGdnZX+s4plnunQM9GNTBh0XbMmVCMXFRbm4u77zzDsnJyRw6dIiGhgaWLl3K008/zWOPPUZGRgZe\nXl4sXLhQ66jCCu3OKmHvqVJmDwjHzlb7atU+wWU8OLg9Z8/Xsiw5W+sowoTU19dTVVVFfX09lZWV\nBAQEsHHjRiZMmADAlClTWLFihcYphTX6cudJBnXwYUKPYK2jACZc7r3CvOkR4sV/fztBXYNO6zjC\nBAQGBvLEE0/Qrl07AgIC8PDwoEePHnh6emJn13gl+aCgIHJzczVOKqzNodxz/LA/j15h3jg7GHeC\nsMsx2XJXFIUHB7cnt7SKH/fnaR1HmICSkhJWrlxJZmYmeXl5VFRUsGbNmhY/f8GCBcTHxxMfH09h\nYaEBkwpr88GmDNwc7bivj/EnCLscky13gJs6+dLJ340PNx1HJ9MBW73169cTFhaGj48P9vb2jB8/\nnm3btlFaWkp9fT0AOTk5BAY2P2/27NmzSU5OJjk5GR8f410RR1i2jDPl/HzoNPf3DcHdyV7rOBeZ\ndLkrisLcwe1JP3Oe9UdlOmBr165dO3bu3EllZSWqqrJhwwaio6MZMmQI3377LQBJSUmMHTtW46TC\nmnyw6ThOdrZM7xemdZQ/M
elyB7gtNoB23i58uydbpgO2cgkJCUyYMIHu3bsTGxuLTqdj9uzZvPLK\nK7zxxhtERERQVFTEjBkztI4qrER2cSUr9+UxuVc7Tab1vRLFgIWptwV/uyebJ745wOfTezGwg7yd\nFjcuPj6e5ORkrWMIM/fc9wf5JjmHzU8Nwd/DSV+L1cvZTyY/cgcY3bUtAR5OvLMhXUbvQgiTUFBW\nzTfJOdzRI0ifxa43ZlHujna2PDCoPcknS9hxokjrOEIIwSdbTtCgqszVeIKwyzGLcgeY1DMYXzdH\n3t2QoXUUIYSVK6moZfHvpxjTtS3tWrtoHadZZlPuTva2zBnUnh0nitidVax1HCGEFUvakUVlbQMP\nDjbNUTuYUbkD3N2rHW1cHXhnQ7rWUYQQVupcVR1J27OY2jeESD83reNcllmVu7ODLbMGhLMl/Sx7\nT8mFtIUQxvfp1kxKKuu4M9405pC5HLMqd4B7e4fg5WLPuxtl37sQwrjOVdXx6bZMbonxI6atdpfQ\nawmzK/dWjnbMHBDOxtQzHMw5p3UcIYQVWbQtk/Lqeh4eGql1lKsyu3IHuL9PCO5Odry7Ufa9CyGM\n41xVHQu3ZjI82vRH7WCm5e7mZM/0/mH8cqSAo/llWscRQliBz7Zlmc2oHcy03AGm9Q3DzdGO92Tf\nuxDCwMqq61i49QTDov3oHGj6o3Yw43L3cLFnSt9QMs6cJ/W0jN6FEIbz2bYsyqrrecRMRu1gxuUO\nML1/KLmlVXLcuxDCYMqq6/hki3mN2sHMy927lSPT+4Wy+uBpDufJkTNCCP1LMsNRO5h5uQPMGBCO\nu5Mdb66T0bsQQr/Kq+v4ZGsmN0eZ16gdLKDcPZztmTUgnPVHC9ifXap1HCGEBfnq91Ocq6rj0ZvN\na9QOFlDuAFP7heLpYs8b645pHUUIYSHOVdbx3sZ07u3dzuxG7WAh5e7mZM+cge357Vghe07KjJFC\niBu3YMtxymsauLtXiNZRrotFlDvAlL4htHF1kNG7EOKGFZbXsGhbFqO7tiW6rbvWca6LxZS7i4Md\nDwxqz7aMInbK1ZqEEDfgg00Z1NTreMwM97X/wWLKHRpnjPR1c+SNX47JtVaFENclr7SKxTtPcUf3\nQMJ9XLWOc90sqtyd7G35nyER7MoqZkt6odZxhBBm6N2N6aioZjOHzOVYVLkD3NUrmGHRvny46YSM\n3oUQ1yTrbAXLknO4JyGEIC/TvDZqS1lcuTva2TIs2p8dJ4pYe/i01nGEEGbkzfXHsLdVeHCI6V4b\ntaUsrtwBxncLJMLXldfWplHfoNM6jhDCDKSdLueH/XlM7RuGr5uT1nFumEWWu52tDU8M78jxwgqW\np+RqHUcIYQZe/yUNVwc7HhgUrnUUvbDIcge4JcaPrsGevLX+GNV1DVrHEUKYsP3ZpfxypIBZA8Px\ndHHQOo5eWGy5K4rC0yM6kneumi93ntQ6jhDChC3ZdYpAT2em9w/TOoreWGy5A/Rt34YBkW14/9cM\nyqvrtI4jhDBBm48VsnR3NjMHhOHqaKd1HL2x6HIHePKWjpRU1vHxlkytowghTEyDTmX+6qMEeztz\nd0I7rePolcWXe5cgT26LDeCTLSc4e75G6zhCCBPy/d5cUk+X8+QtnXC0s9U6jl5ZfLkDPD68A+Ft\nWrFoW5bWUcQNKi0tZcKECXTq1ImoqCh27NhBcXExw4YNIzIykmHDhlFSUqJ1TGEGqusaeP2XNLoE\neTAqNkDrOHpnFeXe3seVLsGe/Pe342SerdA6jrgBjzzyCCNGjCA1NZX9+/cTFRVFYmIiQ4cOJT09\nnaFDh5KYmKh1TGEGFm3LIv9cNc/cGoWNjaJ1HL2zinIHePTmSBzsbHh1TarWUcR1OnfuHJs3b2bG\njBkAODg44OnpycqVK5kyZQoAU6ZMYcWKFVrGFGaguKKWD37NYGgnX/q0b611HIOwmnL3dXN
izsD2\n/HzotFzQw0xlZmbi4+PDtGnT6NatGzNnzqSiooKCggICAhrfVvv7+1NQUNDs8xcsWEB8fDzx8fEU\nFsrEctbsvY0ZVNTW8/StnbSOYjBWU+4AswaG4evmyP/+dFQmFTND9fX1pKSkMHfuXPbu3UurVq3+\nsgtGURQUpfm32LNnzyY5OZnk5GR8fHyMEVmYoJNFFXyxM4uJ8cF08HPTOo7BWFW5uzjY8fiwDqSc\nKmXNIZlUzNwEBQURFBREQkICABMmTCAlJQU/Pz/y8/MByM/Px9fXV8uYwsS9tjYNWxuFx4Z10DqK\nQVlVuQPcGR9MBz9XXlmTSm29TCpmTvz9/QkODiYtLQ2ADRs2EB0dzZgxY0hKSgIgKSmJsWPHahlT\nmLB92aWsOpDPrAHh+Lmb/+RgV2I5p2O1kK2NwjO3RjHts9189ftJpvaznNONrcG7777LPffcQ21t\nLeHh4SxatAidTsfEiRNZuHAhISEhLFu2TOuYwgSpqsrLq4/SupUDcwaZ/5S+V2N15Q4wuKMPfdu3\n5qPfjjOuWyAeFjJRkDWIi4sjOTn5L7dv2LBBgzTCnGxMPcPvmcX8Z2yMRU0zcDlWt1sGGj90e25k\nFKDw/qbjWscRQhhYXX0Dr/9yjLA2rbirl2VNM3A5VlnuADGBHgyIbMOibZlyYpMQFm5pcg5nz9fw\nj9uisLe1jtqzjn/lZTx5S0ccbG2Yv/qo1lGEEAZSWlnL67+k0d7HlZs6Wc+RVFZd7r7uTjw4JIJ1\nRwrYlnFW6zhCCAN4a306ZVV1PD86+rLnQFgiqy53gBn9wwjycubFH4/I9VaFsDDHCsr5YudJ7k5o\nR1SAu9ZxjMrqy93J3pZnR0aRVlDO0t3ZWscRQuiJqqr8Z9URXB3t+PuwjlrHMTqrL3eAWzv70yvM\nmzfWHeNclVyxSQhLsO5IAVvSz/LYzZF4tbK+w52l3Gk8NPL5UdGUVNby7oZ0reMIIW5QTX0DL/10\nlEhfV+7pHaJ1HE1IuV/QOdCDiT2CWb43hxOF57WOI4S4AZ9uzeJUcSXPj462mkMfL2Wd/+rLeOKW\nDrg42PHCj0dk1kghzNSZsmre25jOsGg/BkRa7+yfUu5N+Lg5Ma1fGL8dK+SXI83PCS6EMG2vrEmj\nrkG9cBa69ZJyv8T9fULo4OfKiz8eoaq2Qes4QohrsC+7lO9ScpjeP4zQNq20jqMpKfdL2Nva8OLY\nzuSWVvHhpgyt4wghWkinU0nankWAhxMP3RShdRzNSbk3o3d4a8Z0bctHm09wskjmnRHCHHydnM33\ne3OZN6KTVcz6eDVS7pfx3G1R2NsovPjjEa2jCCGuouh8DYk/p9IrzJsxcW21jmMSpNwvw8/diUdu\njmRD6hk2HJUPV4UwZYk/p1JRU89Lt3e2qvljrkTK/Qqm9QsjwteVF348QnWdfLgqhCnanVXMN3ty\nmDkg3KIveH2tpNyvwN7WhhfGxODtYs+nWzO1jiOEuERdg45/fH+Ith5OPDxUPkRtSsr9KvpFtKFd\n61a8tT5dzlwVwsQkbc8iraCcf42JwcVBPkRtSsq9Bf4xKgpHexv+seKQnLkqhInIP1fFm+uOcVMn\nX4ZH+2kdx+RIubeAr5sTT43oxPbjRazYl6t1HCEE8J9VR6jXqfx7dIx8iNoMKfcWuqdXO+KCPXlp\n1VFKK2u1jiOEVduUeobVB0/zt5siaNfaRes4JknKvYVsbBTmj4ultKqOxJ9TtY4jhNWqqm0gcW0q\nt3b2Z9bAcK3jmCwp92sQ3dad6f1CWZaczZ6TxVrHEcIqvbEujdT8cqb0DcXRzlbrOCZLyv0aPXpz\nB3qHt+bp7w5SUy/HvgthTAdySlm4NZPJvdrRO7y11nFMmpT7NWrlaMfsgeFknDnP+xtlYjEhjKWu\nQcdT3x6gjasjz4zspHUckyflfh0Gd/RlXLdAPth0nNT
TZVrHEcIqLNh8gtTT5fzn9s64O9lrHcfk\nSblfp3+Oisbd2Z6nvztIg06OfRfCkI4XnuftDemMjPXnlhh/reOYBSn36+TdyoF/j4lhf3Ypi7bJ\n1ARCGIpOp/LM8oM42dnw7zExWscxG1LuN2B0lwCGdvLl9V+OkV1cqXUcISzSkt2n2JVZzD9ui8bX\nzUnrOGZDyv0GKIrCS+M642in8P6vGTI1gZE0NDTQrVs3Ro0aBUBmZiYJCQlEREQwadIkamvlJDNL\nkV9aReLqVPq2b82d8UFaxzErUu43KMDDmWdvi2bp7my+2nVK6zhW4e233yYq6v8ufvz000/z2GOP\nkZGRgZeXFwsXLtQwndAXVVV5bsUhugZ78PL4WJli4BpJuevBnT2C6B/Rhv/96ajsnjGwnJwcfvrp\nJ2bOnAk0FsDGjRuZMGECAFOmTGHFihVaRhR68vXubDamnmFolB8hra37YtfXQ8pdDxRF4ZUJXbBV\nFJ78dj86OXrGYB599FFeffVVbGwaX7pFRUV4enpiZ9c43WtQUBC5uTK5m7nLLq7kP6uO0Ce8NVP6\nhGodxyxJuetJoKcz/xwVzc4TxXyx86TWcSzSqlWr8PX1pUePHtf1/AULFhAfH098fDyFhYV6Tif0\nRadTefLb/SiKwqsTumBjI7tjrofMbq9Hd8YHsfpQPok/pzKogw+hbeStpD5t27aNH374gdWrV1Nd\nXU1ZWRmPPPIIpaWl1NfXY2dnR05ODoGBgc0+f/bs2cyePRuA+Ph4Y0YX1+Cz7VnsPFHMK3fEEuwt\nMz5eLxm565GiKCSO74KdbePuGTm5Sb9efvllcnJyyMrKYunSpdx0000sXryYIUOG8O233wKQlJTE\n2LFjNU4qrtfxwvO8siaVmzr5MjE+WOs4Zk3KXc/8PZz49+gY9mWXslSOnjGKV155hTfeeIOIiAiK\nioqYMWOG1pHEdahv0PH3ZftxsrclUY6OuWGKAY/Nttphq6qqPPXtAVbuz+OHh/rRyd9d60jiEvHx\n8SQnJ2sdQzTx/q8ZvLY2jXcmd2NM17Zax9GSXn6rycjdABRFYd6tnXB3sufRpfuorpOpgYW4kiN5\nZby1/hi3dQmw9mLXGyl3A2nt6shrd3Yh9XQ5r61N0zqOECaruq6Bv3+zDw9nB/4ztrPWcSyGlLsB\nDenoy/19Qli4NZMt6XLonRDNef2XNGxtFF65IxbvVg5ax7EYUu4G9sytUUT4uvLEN/spqZA5T4Ro\nalPaGT7ekkm3YC+GRvlpHceiSLkbmLODLW9NiqO4opZnvz8ok4sJcUFheQ1PfLOfjn5uPHdb1NWf\nIK6JlLsRdA704O/DO7I1vZAV++TUeCF0OpW/f7Of8up63r27G072cqFrfZNyN5JZA8LpF+nDs8sP\nkXGmXOs4Qmhq4dZMNh8r5PnR0XTwc9M6jkWScjcSWxuFF8bE4OJgy0Nf7ZXDI4XVOpBTyqtrU7m1\nsz9392qndRyLJeVuRH7uTrw+sWvjRX5XHdE6jhBGd76mnoeX7MXH1ZHE8V3kLFQDknI3ssEdfZkz\nMJzFv5/ipwP5WscRwqieX3mIU8WVvD25Gx4u9lrHsWhS7hp44paOxAV7Mu+7A3JxD2E1Vu7LZVPa\nGR4Z2oGeod5ax7F4Uu4asLe14d3J3QjwdOLFH49QUy/734VlSztdztPfHaB7iBcP3RShdRyrIOWu\nkWBvF54Y3pF1Rwt48UfZ/y4sV3l1HXO/3IObkz3zx8ViKxffMAopdw0Nj/HngUHtWfz7KZan5Ggd\nRwi9U1WVed8d5GRxJe9N7oavm5PWkayGlLvGnhjegd7h3jz7/UFST5dpHUcIvVq0LYufDubz1C0d\nSQhvrXUcqyLlrjE7Wxvendwddyd75n6ZQll1ndaRhNCLPSeLmb/6KMOi/Zg9MFzrOFZHyt0E+Lg5\n8v493TlVXMm872T
+GWH+zpRVk7g6lS5BHvy/O7vK8ewakHI3ET1DvXl+VDRpp8t4b2OG1nGEuG41\n9Q088OUeDuWV8b/jYvFwluPZtSDlbkLu7xNClyBPXl93jLWHT2sdR4hrpqoqz684TMqpUl6f2JWo\nALnEpFak3E2Ioii8PD6WrsGePP71PvmAVZidL3ae5OvkbB4aEsHI2ACt41g1KXcT42Rvy4L7etDK\n0Y5ZnydTLBf4EGZi54kiXvzxCDdH+fL4sA5ax7F6Uu4myM/diQX3x1NQVsODi/dQ16DTOpIQV5RT\nUsmDi1MIae3Cm5PisJETlTQn5W6i4oI9SRwfy84TxTKDpDBp52vqSfw5lQadysf3x+PmJB+gmgI7\nrQOIyxvfPYi00+X8d/MJOvi5cm/vUK0jCfEn9Q06/vZVCpvTz/LF9F6E+7hqHUlcIOVu4p4a0Yns\n4kq+3p1NSOtWDIj00TqSEEDjkTH/+uEwv6YVMn9cLH0j2mgdSTQhu2VMnK2NwisTulDXoDL3yxSO\n5ssRNMI0fLzlBIt/P8UDg9pzd4JcUcnUSLmbATcnexZN64mrox3TFu0m/1yV1pGElVt9MJ/5q1O5\nrUsAT93SUes4ohlS7mYiwMOZRdN6cr6mnmmLdsscNEIzKadKeOzrffQI8eL1O7vKkTEmSsrdjEQF\nuPPhvd3JOHOeB79MobZeDpEUxnWyqIJZScn4ezjx8f3xONnbah1JXIaUu5kZEOlD4h1d2JpxlnnL\nD8gkY8JoiitqeXb5QRpUlUVTe+LdykHrSOIKpNzN0IQeQTx2cwcOZJfyxrpjWscRVqC8uo6pi3Zx\n9HQ5C6f0lEMezYCUu5l6eGgE/SLa8O7GDD7cdFzrOEaRnZ3NkCFDiI6OJiYmhrfffhuA4uJihg0b\nRmRkJMOGDaOkpETjpJaluq6BWZ8nczivjNcmdKFHiJfWkUQLSLmbKUVReH50DGO6tuWVNaks/v2k\n1pEMzs7Ojtdff50jR46wc+dO3n//fY4cOUJiYiJDhw4lPT2doUOHkpiYqHVUi1HfoOOhr/ay80Qx\nr9/ZlaFRflpHEi0k5W7GbG0UXp/YlZs6+fKPFYdYuS9X60gGFRAQQPfu3QFwc3MjKiqK3NxcVq5c\nyZQpUwCYMmUKK1as0DKmxdDpVJ767gDrjxbwwpgYbu8WqHUkcQ2k3M2cva0NH9zTnV6h3jy+bD/r\njxRoHckosrKy2Lt3LwkJCRQUFBAQ0Di9rL+/PwUFzW+DBQsWEB8fT3x8PIWFhcaMa3ZUVeXFVUdY\nnpLL48M6MKVvqNaRxDVSDHi0hRzGYUTl1XXc88nvpJ4uJ2laT/q0t9xTwc+fP8+gQYN47rnnGD9+\nPJ6enpSWll6838vL66r73eMzJWbgAAASS0lEQVTj40lOTjZ0VLOkqir/b20aW4+fJT7Em3/cFiWX\nyTMuvWxsGblbCDcne5Km9WJIRx8e+DKFHceLtI5kEHV1ddxxxx3cc889jB8/HgA/Pz/y8/MByM/P\nx9fXV8uIZk1VVV7/5RjvbzpObKAHz42UYjdXUu4WxKuVAy/dHouvmyPTP9ttcQWvqiozZswgKiqK\nxx9//OLtY8aMISkpCYCkpCTGjh2rVUSzpqoqb6w7xnu/ZjC5VzAvjuksZ5+aMdktY4EKy2u4++Od\n5JRU8enUnvRp31rrSHqxdetWBgwYQGxsLDY2jeOS+fPnk5CQwMSJEzl16hQhISEsW7YMb2/vKy5L\ndsv81RvrjvHOhnTu6hnM/HGxUuza0cuGl3K3UH8UfHZJJZ9O7UlfC94Hfz2k3P+Pqqp89NtxXlmT\nxsT4IBLHd5Fi15bscxeX5+PmyJLZvWnn7cJLq47yW9oZrSMJE6SqKvNXH+W9jRnMGhAmxW5BpNwt\nWBtXR76a1Rs3Jztmfp7MTwfytY4kTEiDTuWZ5Qf5eEsmd8YH88ytUVLsFkTK3cK1c
XVkwf3xxAV7\n8rclKSzddUrrSMIE1NbreHjpXpbuzuZvN0Xwr9HRUuwWRsrdCng42/P59AQGdvBh3vKDfPSbdcxF\nI5pXVdvA7C8a38k9O7ITfx/eUQ53tEBS7lbC2cGWBffFM7prWxJ/TuWVNakyXbAVKqmo5W9LUjiU\ne46Xx8cye2B7rSMJA5ELZFsRBzsb3poUh5eLPZuPFZJXWsWrE7rgaCcXXLAGJ4sqmLpoN7mlVbwz\nuRsjYvy1jiQMSEbuVsbWRuGFMTGMjA1g5b487vtkFyUVtVrHEgaWcqqEcR9sp7Sylq9mJkixWwEp\ndyukKAr/MySCt++KY192KeM/3E7W2QqtYwkDWXMon8kLduLmZMfyB/sRH3rlE7yEZZByt2Jj4wJZ\nPCuB0spaxn2wjd2ZljVdgbVTVZXPd2Qxd3EK0W3dWT63L2FtWmkdSxiJlLuV6xnqzfIH++HpYk/i\nmjSWyKGSFqGytp6Hl+7j+ZWHmdY3lCWzetPa1VHrWMKIpNwFYW1a8f2D/XBxsOWZ5QeZ990Bqusa\ntI4lrtPJogrGf7Cdnw7k8dSIjvxzVDRO9vKhubWRchcAeLo48Nm0XvzPkPYs3Z3NpP/uIK+0SutY\n4hr9mnaG0e9u5XRZNZ9N68WDgyPkGHYrJeUuLrK1UXjylk58dG8PjhdWMPrdrWw/flbrWKIFdDqV\ndzakM/2z3QR5ufDjQ/0Z2MFH61hCQ1Lu4i9GdPZnxf807oef/9NR3l5/jAadnPBkqooranl+5SHe\nWHeM2+MC+W5uX4K9XbSOJTQm5S6aFeHrysqH+hMb5Mmb69O5a8EOckoqtY4lLrEt4ywj3trM17uz\neW1CF96Y2BVnB9m/LqTcxRW4Otoxf1xn3pzUlaP55dz69hZ+2J+ndSxB48RfL/98lHsX/o67sz0r\nH+rPnfHBsn9dXCQX6xAtcqqokke/3kvKqVLGdwvkhbExuDnZax3rupnzxTqOFZTzzoZ0Vh3I556E\ndvzjtmgZrVsWvfyGlrllRIu0a+3Csjl9eHdjBgs2HyerqIKHh0YyuKNcjNpY6hp0fLTpOO9sTCfA\n3YkF9/VguEwjIC5DRu7imu07VcLfv9nP8cIKxnUL5J+jovFu5aB1rGtibiP3I3llPPntfg7nlTGq\nSwAvjImRk5Isl1xDVWinpr6B9zdm8MGm47g72/PPUVHcHhdoNvt8zaXcz9fU8/b6YxzMPUfGmQpe\nur0zIzrLaN3CSbkL7aWeLmPedwext1Wws7Hh+dHRRAW4ax3rqky93FVV5Yf9efzvT0c5U17DvQnt\n+PvwjniZ2TskcV2k3IVpaNCpLEvO5pU1qZRV1XFPQgiPD+tg0kVkyuV+NL+Mf/1wmF2ZxXQJ8uDF\nsZ2JC/bUOpYwHil3YVpKK2t5a306X+w8iaujHU/c0oE7ewSb5LwmpljuOSWVvLU+nYM55zhTXs2T\nt3RiUs9gbOXaptZGyl2YpmMF5Xy06TjL9+bS1sOJR26O5I7uQdjZms5pFaZU7kXna3j/1+N8ufMk\nKDClTwgPDmmPl4t8YGqlpNyFaduWcZbX1qaxL7uU3mHe3NEjiLFxgTjYaV/yplDuBWXVJG3PYlly\nNsUVtUzoEcSjN3egraezprmE5qTchelTVZX1R8+wcl8uqw7kE+DhxMwB4dzVM5hWjtqdZqFluR/J\nK+OTrSf4cX8eDTqVGf3DmdQziAhfN03yCJMj5S7Mh6qq/HaskA83Hef3zGJat3Lgrp7B3NWrnSaT\nXBm73FVVZdOxQj7ZcoJtGUW4ONgyMT6Y6f3CaNdaJvkSfyLlLsxTyqkSftyfR9L2LFRgQKQP9yS0\nY2gnX6PtlzdWuRdX1LL6YB5J20+SfuY8fu6OTO0bxt292uHhYr7TNwiDknIX5i3/XBVLd2Xz9e5s\nTpdV07d9a6IC3LmtSwDdgj0NekKUIcu9qraBX
46cZuW+PDYfK6Rep3JzlC+3dQngtti2JvGZgzBp\nUu7CMtQ36NiYeoYt6Wf5enc2tQ062no4MTI2gFtj/ekW7IWNng8H1He5V9U2sCW9kJ8PneaXw6ep\nqG3A392JsXFtGRsXSFSAm9mcvSs0J+UuLE9ZdR3rjxTw04F8NqcX0t7HlbPnaxnYoQ03dfKlR4gX\nAR5/PZpkzZo1PPLIIzQ0NDBz5kzmzZt3xfXcaLnrdCqpp8vZfvwsR0+XsWp/PjX1OnqEeBHh48rt\n3QJJCPPW+y8lYRWk3IVlO1dVx/bjZ1lz6DSbjxXi5eLAibMVBHs7MzDShwhfV7oGedLex4XusdGs\nW7eOoKAgevbsyZIlS4iOjr7ssq+l3FVV5UxZDamny9mXXcre7BKq63TsPFEEwMhYf/zcnRjayY9e\nYd6y20XcKCl3YT0adCqp+WX8nllM8sli9p4qJf9cNd3beXKurJyc3Fx6xXUmpq07B37fQiubWmbf\nNxmvVg54ONvTysEWZwe7i7tG/ij3hgYdFTUNVNTVUVJRT0lVLQXnqskpqaK4opai8zVszThLsJcL\nR/LPYWdrQ4h3K/q09yY20JM+7VvLcelC36TchXU7fa6aI3nnWLlxOynHTuEZ0YPSqloaKsuoPl9O\nu7BwHO1sqKnXUV5dR1tPZ3Q6FSd7W77/1/0MfWYhOhVaOdhSr1Oprmugpk5HdkklZdX1DIhsg5O9\nLW1cHekW7EFI61ZEB7jj5ixHuQiDMu1yj4mJUZ2dTX9EU1hYiI+P6V8lXnJeXklJCWVlZYSEhDRm\nOFvE+cpqfPz80OlUGnQqDapKxfkKzldUAFBXcY7AsEgUBWwUBRsFbGwUbBUFO1sbHGxtMIXPP+Xn\nrj/mkBFgz549h1VV7XzDC1JV1SBfPXr0UM2B5NQvLXJu375dHT58+MXv58+fr86fP/+Kz5HtqV/m\nkNMcMqqqqgLJqh46WD75EWavZ8+epKenk5mZSW1tLUuXLmXMmDFaxxJCU3INVWH27OzseO+997jl\nlltoaGhg+vTpxMTEaB1LCE0ZrNxnz55tqEXrleTUL61yjhw5kpEjR7b48bI99csccppDxgsW6GMh\ncrSMEEKYFr18lC/73IUQwgLdULkrinKnoiiHFUXRKYoS3/S+l19+mYiICDp27MjatWubfX5mZiYJ\nCQlEREQwadIkamtrbyROi0yaNIm4uDji4uIIDQ0lLi6u2ceFhoYSGxtLXFwc8fHxzT7GkP79738T\nGBh4Mevq1aubfdyaNWvo2LEjERERJCYmGjklPPnkk3Tq1IkuXbowbtw4SktLm32cVtvzatunpqaG\nSZMmERERQUJCAllZWUbLBpCdnc2QIUOIjo4mJiaGt99++y+P2bRpEx4eHhdfCy+++KJRM/7haj9D\nVVV5+OGHiYiIoEuXLqSkpBg9Y1pa2sXtFBcXh7u7O2+99dafHqPV9pw+fTq+vr507vx/RzkWFxcz\nbNgwIiMjGTZsGCUlJc0+V1GUKYqipF/4mtKiFd7IoTZAFNAR2ATEN7k9ukuXLmp1dbV64sQJNTw8\nXK2vr//LIT933nmnumTJElVVVXXOnDnqBx98oM8jiq7q8ccfV1944YVm7wsJCVELCwuNmqepf/3r\nX+prr712xcfU19er4eHh6vHjx9Wamhq1S5cu6uHDh42UsNHatWvVuro6VVVV9amnnlKfeuqpZh+n\nxfZsyfZ5//331Tlz5qiqqqpLlixRJ06caNSMeXl56p49e1RVVdWysjI1MjLyLxl//fVX9bbbbjNq\nruZc7Wf4008/qSNGjFB1Op26Y8cOtVevXkZM91f19fWqn5+fmpWV9afbtdqev/32m7pnzx41Jibm\n4m1PPvmk+vLLL6uqqqovv/zyH/9/Lu1Zb+DEhT+9Lvzd69LHXfp1QyN3VVWPqqqa1sxdY++66y4c\nHR0JCwsjI
iKCXbt2/eWXysaNG5kwYQIAU6ZMYcWKFTcS55qoqsqyZcuYPHmy0dapb7t27SIiIoLw\n8HAcHBy46667WLlypVEzDB8+HDu7xs/le/fuTU5OjlHXfyUt2T4rV65kypTGgdCECRPYsGHDH/+h\njCIgIIDu3bsD4ObmRlRUFLm5uUZbvz6tXLmS+++/H0VR6N27N6WlpeTn52uWZ8OGDbRv3/7iyW1a\nGzhwIN7e3n+6renr7wodeAuwTlXVYlVVS4B1wIirrvBq7d+SL/46cn8PuLfJ9wuBCZc8pw2Q0eT7\nYOCQPvK0MPNArnCyAJAJpAB7gNnGytVk/f8GsoADwKc085samAB80uT7+4D3jJ21yfp/bPpz13p7\ntmT7AIeAoCbfHwfaaLT9QoFTgPsltw8GioD9wM9AjEb5rvgzBFYB/Zt8v6FpL2iQ91PgoWZu12x7\nXvgZH2ryfWmTvytNv29y+xPAP5p8/0/giaut66qHQiqKsh7wb+au51RVNe4wsYVamHkysOQKi+mv\nqmquoii+wDpFUVJVVd1srJzAh8B/aDzq6D/A68B0fa6/pVqyPRVFeQ6oBxZfZjEG357mTFEUV+A7\n4FFVVcsuuTsFCFFV9byiKCOBFUCksTNiRj9DRVEcgDHAM83cbSrb809UVVUVRdHb28arlruqqjdf\nx3JzaRyJ/yHowm1NFQGeiqLYqapaf5nHXJerZVYUxQ4YD/S4wjJyL/x5RlGU74FegF5fyC3dtoqi\nfEzjqOhSLdnON6wF23MqMAoYql4YWjSzDINvz2a0ZPv88ZicC68LDxpfm0ajKIo9jcW+WFXV5Zfe\n37TsVVVdrSjKB4qitFFV9awxc7bgZ2iU12ML3QqkqKpacOkdprI9LyhQFCVAVdV8RVECgDPNPCaX\nxncbfwiicW/JFRnqUMgfgLsURXFUFCWMxt+Kf9rpfqEEfqXxrTPAFMBY7wRuBlJVVW12B7GiKK0U\nRXH74+/AcBrfvhvNhR/0H8ZdZv27gUhFUcIujFTuonHbG42iKCOAp4AxqqpWXuYxWm3PlmyfH2h8\n7UHja3Hj5X5BGYLSOAfxQuCoqqpvXOYx/hceh6IovWj8f2vsX0At+Rn+ANyvNOoNnFNVVaud7pd9\nZ24K27OJpq+/y3XgWmC4oiheiqJ40bjtmz8Esakb3H80DsgBaoACYG2T+56jcf9lGnBrk9tXA20v\n/D2cxtLPAL4BHI203+sz4IFLbmsLrG6Sa/+Fr8M07n4w9v7CL4CDNO5z/wEIuDTnhe9HAscubGst\ncmYA2cC+C18fmdL2bG77AC/S+MsIwOnCay/jwmsx3Mjbrz+Nu94ONNmGI4EH/niNAg9d2G77gZ1A\nXw1+zs3+DC/JqQDvX9jWB9FofzvQisay9mhym+bbk8ZfNvlA3YXenAG0pvGziXRgPeB94bHx/Pnz\noukXXqMZwLSWrM+QZ6gKIYTQiJyhKoQQFkjKXQghLJCUuxBCWCApdyGEsEBS7kIIYYGk3IUQwgJJ\nuQshhAWSchdCCBOgKEpPRVEOKIridOGM4MOKonS++jMvszw5iUkIIUyDoigv0XjWtDOQo6rqy9e9\nLCl3IYQwDRfmQNoNVNM4LULD9S5LdssIIYTpaA24Am40juCvm4zchRDCRCiK8gOwFAijcbLAh653\nWVedz10IIYThKYpyP1CnqupXiqLYAtsVRblJVdWN17U8GbkLIYTlkX3uQghhgaTchRDCAkm5CyGE\nBZJyF0IICyTlLoQQFkjKXQghLJCUuxBCWCApdyGEsED/H3ZRIJS9tKTmAAAAAElFTkSuQmCC\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "e = 'x*x'\n", "plot(e);" ] }, { "cell_type": 
"code", "execution_count": 4, "metadata": { "slideshow": { "slide_type": "fragment" } }, "outputs": [ { "data": { "text/plain": [ "2*x" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "diff(e)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": true, "slideshow": { "slide_type": "fragment" } }, "outputs": [], "source": [ "gradient = np.array(['2x'])" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "slideshow": { "slide_type": "fragment" } }, "outputs": [ { "data": { "text/plain": [ "[0]" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "solve(e)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### You told us to \"step proportional to the negative of the gradient\"?" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "what if `x > 0`, `x < 0`?" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "# Now, we'll do gradient descent live" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "## Starting with linear regression" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": true, "slideshow": { "slide_type": "fragment" } }, "outputs": [], "source": [ "X, Y = make_classification(n_samples=1000, n_features=3, n_informative=3,\n", " n_redundant=0, n_repeated=0, n_classes=2, \n", " random_state=RS)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "#### We need to introduce a `bias` term:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": true, "slideshow": { "slide_type": "fragment" } }, "outputs": [], "source": [ "bias_column = np.ones(X.shape[0])\n", "\n", "X = np.c_[bias_column, X]\n", "Y = Y.ravel()" ] }, { "cell_type": "markdown", 
"metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Aside: do we even need gradient descent?" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "# $$X_{m \\times n}, Y_{m \\times 1}$$" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Meet the `'normal'` equation:" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "# $$\\begin{equation}y_{mx1} = X_{m \\times n}\\space\\beta_{n \\times 1} + \\epsilon_{m \\times 1}\\end{equation}$$\n", "# $$\\begin{equation}\\beta_{n \\times 1} = (X^{T}X)^{-1}_{n \\times n}\\space X^{T}_{n \\times m}\\space Y_{m \\times 1}\\end{equation}$$" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "#### This gives us a way to compute our `beta` vector:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": true, "slideshow": { "slide_type": "fragment" } }, "outputs": [], "source": [ "betas_normal_eq = inv(X.T @ X) @ X.T @ Y" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## Sanity check: Scikit-learn" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "collapsed": true, "slideshow": { "slide_type": "fragment" } }, "outputs": [], "source": [ "linreg = LinearRegression(fit_intercept=False)\n", "linreg.fit(X,Y);" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": true, "slideshow": { "slide_type": "fragment" } }, "outputs": [], "source": [ "betas_sklearn = linreg.coef_" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "slideshow": { "slide_type": "fragment" } }, "outputs": [ { "data": { "text/plain": [ "array([ 0.45801758, 0.27764882, -0.10227603, -0.04232889])" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" }, { "data": { "text/plain": [ "array([ 
0.45801758, 0.27764882, -0.10227603, -0.04232889])" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "betas_normal_eq\n", "betas_sklearn" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## But where are the gradients???" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "### Spoiler: linear and logistic regression aren't so different to optimize" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "#### To implement gradient descent for linear regression, we will use the identity function. \n", " \n", "#### For logistic regression, we will use the `sigmoid` function:" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "# $$\\begin{equation} F(z) = z \\end{equation}$$" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "# $$\\begin{equation} F(z) = \\dfrac{1}{1+e^{-z}} \\end{equation}$$" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "### These two functions, let's write 'em up" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "collapsed": true, "slideshow": { "slide_type": "fragment" } }, "outputs": [], "source": [ "identity = lambda z: z\n", "\n", "sigmoid = expit\n", "\n", "def sigmoid(z):\n", " return 1 / (1 + np.exp(-z))" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Why do these $F(z)$ equations matter? What is our hypothesis?" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "#### Our two hypotheses will be identical, except for the aforementioned functions!" 
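, "\n", "\n", "In symbols (a sketch that reuses the $X$, $\\beta$ and $F$ notation from the earlier slides; the name $h_{\\beta}$ is introduced here only for illustration), both hypotheses can be written as:\n", "\n", "$$h_{\\beta}(X) = F(X\\beta)$$\n", "\n", "with $F(z) = z$ for linear regression and $F(z) = \\dfrac{1}{1+e^{-z}}$ for logistic regression." 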
] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "#### Recall the normal equation?" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "slideshow": { "slide_type": "fragment" } }, "outputs": [ { "data": { "text/plain": [ "(1000,)" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "hypothesis = X @ betas_sklearn\n", "hypothesis_shape = hypothesis.shape\n", "\n", "hypothesis_shape" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## Gradient Descent!" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "> Gradient descent is a first-order iterative optimization algorithm" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "We must define an update step that moves us closer to the solution each iteration." ] }, { "cell_type": "code", "execution_count": 40, "metadata": { "collapsed": true, "slideshow": { "slide_type": "subslide" } }, "outputs": [], "source": [ "def gradient_descent(X=X, Y=Y, kind='linear', learning_rate=0.01, iterations=int(1e5)):\n", " betas = np.zeros(X.shape[1])\n", " assert kind in {'linear', 'logistic'}, f\"Whoops! 
Don't support kind: {kind}.\"\n", " fn = identity if kind == 'linear' else sigmoid\n", " \n", " def update_step(betas):\n", " hypothesis = fn(X @ betas)\n", " loss = hypothesis - Y\n", " gradient = X.T @ loss * (1 / X.shape[0])\n", " return betas - learning_rate * gradient\n", " \n", " return pipe(betas, *([update_step] * iterations))" ] }, { "cell_type": "code", "execution_count": 40, "metadata": { "collapsed": true, "slideshow": { "slide_type": "subslide" } }, "outputs": [], "source": [ "# (tweetable version)\n", "def gradient_descent(X=X, Y=Y, kind='linear', lr=0.01, n=int(1e5)):\n", " return pipe(np.zeros(X.shape[1]), *([\n", " lambda betas: betas - lr * \n", " X.T @ (((identity if kind == 'linear' else sigmoid)(X @ betas)) - Y) * (1 / X.shape[0])\n", " ] * n))" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "### Does it work for linear regression?" ] }, { "cell_type": "code", "execution_count": 41, "metadata": { "collapsed": true, "slideshow": { "slide_type": "fragment" } }, "outputs": [], "source": [ "betas_gd = gradient_descent()" ] }, { "cell_type": "code", "execution_count": 42, "metadata": { "slideshow": { "slide_type": "fragment" } }, "outputs": [ { "data": { "text/plain": [ "array([ 0.45801758, 0.27764882, -0.10227603, -0.04232889])" ] }, "execution_count": 42, "metadata": {}, "output_type": "execute_result" }, { "data": { "text/plain": [ "array([ 0.45801758, 0.27764882, -0.10227603, -0.04232889])" ] }, "execution_count": 42, "metadata": {}, "output_type": "execute_result" }, { "data": { "text/plain": [ "array([ 0.45801758, 0.27764882, -0.10227603, -0.04232889])" ] }, "execution_count": 42, "metadata": {}, "output_type": "execute_result" } ], "source": [ "betas_gd\n", "betas_normal_eq\n", "betas_sklearn" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "### Great! How about for logistic regression?" 
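, "\n", "\n", "Why can the same update step be reused? A sketch, assuming mean squared error as the linear loss and mean log loss as the logistic loss (the notebook never states its loss functions, so both are assumptions), with $h = F(X\\beta)$:\n", "\n", "$$J_{lin}(\\beta) = \\dfrac{1}{2m}\\sum_{i=1}^{m}(h_i - Y_i)^2 \\qquad J_{log}(\\beta) = -\\dfrac{1}{m}\\sum_{i=1}^{m}\\left[Y_i\\log h_i + (1 - Y_i)\\log(1 - h_i)\\right]$$\n", "\n", "Differentiating either one with respect to $\\beta$ gives the same expression:\n", "\n", "$$\\nabla_{\\beta} J = \\dfrac{1}{m} X^{T}\\left(F(X\\beta) - Y\\right)$$\n", "\n", "which is what `X.T @ loss * (1 / X.shape[0])` computes in `gradient_descent`; only $F$ changes." 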
] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "### What saith Scikit?" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "collapsed": true, "slideshow": { "slide_type": "fragment" } }, "outputs": [], "source": [ "logr = LogisticRegression(fit_intercept=False, C=1e18)\n", "logr.fit(X,Y);" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "collapsed": true, "slideshow": { "slide_type": "fragment" } }, "outputs": [], "source": [ "betas_logr = logr.coef_[0]\n", "betas_gdl = gradient_descent(kind='logistic')" ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "slideshow": { "slide_type": "fragment" } }, "outputs": [ { "data": { "text/plain": [ "array([-0.13, 3.08, -1.04, -0.41])" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" }, { "data": { "text/plain": [ "array([-0.13, 3.08, -1.04, -0.41])" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.round(betas_logr, 2)\n", "np.round(betas_gdl, 2)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "### Bold claim: we're 90% of the way to a fully formed sklearn estimator!" 
] }, { "cell_type": "code", "execution_count": 32, "metadata": { "collapsed": true, "slideshow": { "slide_type": "subslide" } }, "outputs": [], "source": [ "class GradientDescentClassifierRegressor:\n", " def __init__(self, kind='linear', learning_rate=0.01, iterations=int(1e5)):\n", " self.kind = kind\n", " self.learning_rate = learning_rate\n", " self.iterations = iterations\n", " \n", " def fit(self, X, Y):\n", " self.coef_ = gradient_descent(\n", " X,\n", " Y,\n", " self.kind,\n", " self.learning_rate,\n", " self.iterations,\n", " )\n", " return self\n", " def predict(self, X):\n", " try:\n", " getattr(self, 'coef_') # illustrative\n", " fn = identity if self.kind == 'linear' else lambda arr: np.round(sigmoid(arr)).astype(int)\n", " return fn(X @ self.coef_)\n", " except AttributeError:\n", " raise RuntimeError('Must fit first!!!')" ] }, { "cell_type": "code", "execution_count": 35, "metadata": { "collapsed": true }, "outputs": [], "source": [ "gdcr = GradientDescentClassifierRegressor(kind='linear')\n", "gdcr.fit(X,Y);" ] }, { "cell_type": "code", "execution_count": 36, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "data": { "text/plain": [ "array([ 0.61911337, 0.3343834 , 0.49442232, 0.5050518 , -0.20501244,\n", " 0.00771493, -0.06387987, 1.02075149, 0.51484657, 0.77049524])" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" }, { "data": { "text/plain": [ "array([ 0.61911337, 0.3343834 , 0.49442232, 0.5050518 , -0.20501244,\n", " 0.00771493, -0.06387987, 1.02075149, 0.51484657, 0.77049524])" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ "gdcr.predict(X[:10])\n", "linreg.predict(X[:10])" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "
\n", "
Questions?
\n", "
" ] } ], "metadata": { "celltoolbar": "Slideshow", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.1" } }, "nbformat": 4, "nbformat_minor": 2 }