{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Programming Exercise 3 - Multi-class Classification and Neural Networks" ] }, { "cell_type": "code", "execution_count": 36, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# %load ../../standard_import.txt\n", "import pandas as pd\n", "import numpy as np\n", "import matplotlib as mpl\n", "import matplotlib.pyplot as plt\n", "\n", "# load MATLAB files\n", "from scipy.io import loadmat\n", "from scipy.optimize import minimize\n", "\n", "from sklearn.linear_model import LogisticRegression\n", "\n", "pd.set_option('display.notebook_repr_html', False)\n", "pd.set_option('display.max_columns', None)\n", "pd.set_option('display.max_rows', 150)\n", "pd.set_option('display.max_seq_items', None)\n", " \n", "#%config InlineBackend.figure_formats = {'pdf',}\n", "%matplotlib inline\n", "\n", "import seaborn as sns\n", "sns.set_context('notebook')\n", "sns.set_style('white')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Load MATLAB datafiles" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "dict_keys(['__version__', '__globals__', '__header__', 'y', 'X'])" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data = loadmat('data/ex3data1.mat')\n", "data.keys()" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "dict_keys(['__version__', '__globals__', '__header__', 'Theta1', 'Theta2'])" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "weights = loadmat('data/ex3weights.mat')\n", "weights.keys()" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "X: (5000, 401) (with intercept)\n", "y: (5000, 1)\n" ] } ], "source": [ "y = data['y']\n", "# Add constant for intercept\n", "X = np.c_[np.ones((data['X'].shape[0],1)), data['X']]\n", "\n", "print('X: {} (with intercept)'.format(X.shape))\n", "print('y: {}'.format(y.shape))" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "theta1: (25, 401)\n", "theta2: (10, 26)\n" ] } ], "source": [ "theta1, theta2 = weights['Theta1'], weights['Theta2']\n", "\n", "print('theta1: {}'.format(theta1.shape))\n", "print('theta2: {}'.format(theta2.shape))" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": false }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAeAAAAA6CAYAAABoBopoAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJztnVlwVFd6+H+9L9q6W2pJrV0CbWgHJDAgNol9tc3Y2PE4\ndsZ2yhnPQ5JK5SE1qUrNQ1JJpVKVqqQmGbsmMV4HG2NcGASYHSSBWC0hgaQWaN9barXU6lZ33/8D\ndW/AxtDdki2S//29UAh07rnnnnO+73zbUQiCICAjIyMjIyPzk6Kc7w7IyMjIyMj8/4gsgGVkZGRk\nZOYBWQDLyMjIyMjMA7IAlpGRkZGRmQdkASwjIyMjIzMPyAJYRkZGRkZmHpAFsIyMjIyMzDwgC2AZ\nGRkZGZl5QBbAMjIyMjIy84AsgGVkZGRkZOYB9U/5MLfbHfLvKBQKAOSKmXOPOLYi/7+OsdfrxeFw\n4PV6sdlsqNU/6bKQkflRGBgY4M6dO5hMJgoLC7+33p8mfqx9/mmRHwaD4ZE/fyp3GoVCQSAQIBAI\n4PV6EQQBtVqNWq1GqQzv0P7g5JvvjzEbFAoFCoUCQRBm9R4ej4eRkRHu3buH3W5neHiY9PR0qqur\niYiImNViFfsYCATCbkMQBAKBADMzMygUClQqFQBKpRKlUjknC0uhUOByubhy5Qp1dXVERkby4osv\nEhcX9796jjyNiHMCmPXcfRTivjCbOfdTI+5x4pyeC8R119/fzzfffMPFixeprq6mqKhoTtr/sXC7\n3fj9fjQaDRqNZk6UhUAgQFtbG21tbZjNZioqKqR9JBwEQcDn8yEIAlqtdtb9g6dQAAcCAcbHx+np\n6cHpdDI2Nsbk5CQ6nQ6bzUZ6ejrx8fFhtSkIAnq9Hq1Wi0qlQqlUztlG8FMIeEEQcLvduN1u9Ho9\nRqMxrHZGRka4efMm165do6Ojg5GREQKBAD6fj9WrVxMRERF2H/1+P5OTk3i9XiwWS9gbi7iJ1NbW\nIggC0dHRwH1NMi4uDovFQlRUFEajMazvKAgC4+PjXLp0iS+//JKrV69SVFTErl27JAVHZvaIm5bH\n48HlcgEQGRkpfbe5eoY4hyMjI9HpdGG186g1PNu5IK4r4KFNW6FQ4HQ6aWpqIjMzk8TExDkZD7fb\njd1up66ujsuXL6NQKELeL38qxLlht9ulvpaXl5OVlTVrK5RCoWB4eJhvvvmGpqYm1qxZQ0VFxaza\n9Pl8tLW1MTIywpIlS37wVBsKT5UAFgQBp9PJmTNnaGtrQ6FQ4Ha7GRwcZHJykvj4eCorK9m8eXNI\nk7W3t5f6+npmZmZISEjAZDKh1+sxGAyYzWaio6PD1rgEQcDr9eJ0OpmcnJTaDEeLEwQBv9/P9PQ0\ngUCAqKioh9pwu93cunWLtrY2bDYbRUVFmEymoMdCoVAwOTlJXV0dn332Ga2trcTHx5OdnU12dja5\nubno9fqQ+vxdJiYmuH79OhqNhvLy8rA1xUAgQF9fHx988AFDQ0Pk5+cTHR2N0WgkLi4Oq9WKzWYj\nIyODhIQEjEZjSOM9NTVFfX09n3zyCdevXyc6Opr09PQ5WVTB4vf78fv9qFSqoDTzBwXBT2lOFAQh\nrOfNzMwwMjJCV1cXw8PDdHZ2olAoSE1NJTU1laSkJGJiYmZ1KlEoFIyMjHD+/HmUSiVLliwhOTk5\naKEpCAIzMzP4/X5JSTAYDGi1Wrxer7QOtVptWAJyZmaGrq4uAoEAmZmZaDQaqd/j4+P84Q9/YOvW\nrVgsllmvPZfLRWNjI8ePH6e5uZnU1FR27tw5a8HzY+Hz+bh79y779++nrq5OOmBlZmbOum1BEGhu\nbuby5cuo1WoyMzNnPc8CgQCtra00NDRgNpspLCyctaL+VAlg0WTwz//8zyxevJgXXniBnJwcxsbG\nOHXqFCdOnMDn87Fhw4agF4Pb7ebQoUN88MEHjI+Pk5CQgNlsxmAwYLVa2bZtG+vWrQv747jdbu7e\nvUt9fT137twhLS2N6upq0tPTwxI+DoeDxsZGpqamqKysJDIyUnrXjo4OPv74Y44cOUJubi6vvvoq\n1dXVREZGBrVBBgIBbt++zYcffkhjYyObNm3i1VdfJSMj46HnhDup/H4/3d3dHDx4kA0bNsxqcqrV\namJjY0lPT+fatWtUV1ezfPlyHA4Ht27dora2Fp/PR2FhIevWraOkpISYmJig5oVSqaSzs5N9+/ZR\nX19PSUkJzz33HCtWrMBkMs2pGTMQCCAIwkNm80AggNvtxuFwMDk5idVqJTY29rHjJf7O1NQUcH98\nNBoNKpVK+lNkLoSzKJimpqaYmZlBq9USERER9MnE7/fT19dHTU0NX331FaOjo4yOjqJQKDCbzWRk\nZLBjxw4qKyuJj48P+8Tj9/u5fPky7777LjabjdTUVJKTk4P+fZ/PR1dXF0NDQ3R2dnLjxg2ys7Ox\n2Wz09fXh9XopLy8nNTUVs9n80HcMBrfbzblz55icnOSll156yL2hUqlob2+npaWFkpISbDZb2GvG\n4/HQ0NDAu+++S2dnJ5s3b+aVV14hPT39qbTmBAIBhoeHOXDgAAcOHODNN9+kurqa1NTUOYnBcLlc\n3Lx5E6PRSFVVFSUlJbNqTxAEdDod6enpNDY2cuzYMXJyctBqtbPb52bVqzlE1DA6OzuJjo7m5Zdf\npqysDK1WS2xsLCMjI1y9elXyCQfD9PQ0X375JZ9//jmDg4PMzMwwPj4O3DcHZWVlUVBQwMzMTMgC\nWDxNnjt3jiNHjmC1WikvL+fo0aPY7XbefPNNsrOzg16s4mn/6tWr/P3f/z3T09O8+eabPPvss1gs\nFlwuF9988w3nzp1jamqKW7du8eGHH7J06VIiIyODan9gYID/+q//oru7m+eff569e/eSnZ09a8Er\ntj8yMsK1a9dQqVRUVVWFbQqE+0LSYDAQGxtLfHw8zz33HLm5uSgUCnbs2MHg4CA3btzg1KlT/OY3\nv+GZZ57hV7/6FfHx8U8cc0EQaGpqorOzk5KSEl5//XXWrl2LwWCYs5OlQqFgenqazs5OHA4HeXl5\nqNVqXC4Xvb29nDlzhqNHj6LRaPiLv/gL1q1b99j+Dg0N8cEHH3DgwAG8Xi/Z2dlkZGSQkZHBwoUL\nSUxMBO5v6rGxscTGxob9LjMzM/T391NfX09NTQ1tbW2Ulpby+uuvs2jRoqA2yJaWFj799FO++uor\nfD4fGzdu5JlnnmFiYoJ9+/Zx8uRJrl+/TnV1NXv37mX58uUh91OlUtHW1sYf/vAHTCYTu3btIi8v\n7yHz8ZNiEZxOJx9//DGff/45k5OTuN1udDodWq1Wij9ISkpi8eLF7Nmzh9TUVBISEoJWrgOBAFNT\nU5Li9CCiabqjo4OhoSFsNlvIYyBy4MAB9u3bh8Fg4K233qK6upqEhIQfTfjOJhZFoVDgcDg4cuQI\nhw8f5le/+hU7duwgNjaWiYkJJicnMRqNYVujFAoFly9f5syZM6SlpZGfn49Op5u1Yq1QKEhLSyMt\nLY0jR47Q3Nw8a8H+1AhgEb/fT0xMDGazWTL7iAFD4+PjFBQUBCUsp6enOXHiBL///e/p7u7G7/eT\nnJxMWVkZCxcuJCYmhvT0dEmLCRXRhHn79m3WrFlDaWkpJpOJhIQE3nvvPXp6ekhNTQ16Eomb7I0b\nN+jp6UEQBGpra1m/fj0xMTFcuXKFhoYGHA4HSqWS6OhoKioqiImJCWqjVSgUNDQ0cP36dTIyMli7\ndi1ZWVlz6odzOByMjY1RVVWFwWCY1eL3+XwMDg5it9vJy8sjJSUFnU6HQqFAp9Oh1+sxm82kpKRw\n5MgRjh07RkJCAq+88gomk+mxY+L1erHb7RiNRtatW8fSpUvD9qc/ikAgQHd3NzU1NdTU1DA1NSWZ\n4+12O0NDQ0RHR1NQUEBKSgo5OTmPbW94eJgjR47wxRdfEBUVRVZWFrm5uVgsFvr7+/nggw/o7+/H\n7/ejVqspLy/nl7/8JXFxcSEL4a6uLs6fP09DQwNut5v09HSysrJobGxkdHRUesbjmJ6e5uLFi5w8\neZKYmBheeuklqquriYuLo7e3l5aWFlasWMHExIRk0aioqAh5Lvb09PDZZ5/R2dlJdXU1BQUFkhnX\n7/fjcDhwOp2kpKT84BpXq9VYrVZSU1NJS0uT5pgoXMbHx2lvb+fMmTNcvXqVzMxMtm7dKp3cg+mz\nSqVCp9NJ5meRiYkJvF4vHR0dDA8Ph/TuIh6Ph0uXLvH+++9jMpn42c9+xqpVq0JWwMSAV6/XK727\neNp/MDDK7/fjdrsZGxtjenqaxMTEkIM2vV4vd+/epaGhgS1btrB582YsFgudnZ189tln3L17l5Ur\nV7Jt2zaio6ND3kcCgQC9vb2kpaVRWVlJRkbGnCgigiAQFRVFSkoKarWa2tra/zsCWBwgjUZDX18f\n7e3tJCcn4/f7uXDhAl988QU+n49nnnkmKAGsUChQKpUMDg7i9XrR6/VUVlaya9cuyTwcERGBXq8P\neeErFAr6+vro7e3FarWycuVKrFYrCoUCm83G5OQkw8PDeDwejEZjUB/f5/PR0dHB5cuXJc3YYDBg\nMBjweDw0NTVx7949fD4fUVFRlJaWsm3btqADpiYnJ6mvr2dqaoqioiKys7O/pxXOJrLY6/UyNDTE\nxMQEBQUFswpeUSgUzMzMSGbB6urq751ONRoNFouFoqIiVCoVDoeDmpoaqqqqiIqK+kEhIZ7Uv/32\nWwoLC1m+fDmxsbFh9fNRBAIB2tvbOXLkCKdPn2ZgYACv10traysWi4XY2FgWLlzIwoULycrKwmQy\nYbVaHztWTqeTxsZGBgYG2LFjB6tWrSIxMRGtVsvQ0BBKpRK73U57ezsWi4Xs7Gz8fn9IvlsxaOX0\n6dNcuHCBrKwsCgsLiY+P5/Lly1gsFiIjI5+49hQKBUNDQzQ3N6PVatm9ezc7d+7EZrNJ0b579uwh\nJiaG3t5ezp07F/IpShAEybpVU1NDSkoK5eXlJCQkoFAo8Pl89Pb2cvjwYVQqFS+++OIPCmCDwUBl\nZeUjT7aCIDA1NUVnZyetra3U1tZy5coVRkZGGBwcZOvWrU/0VyoUCiYmJuju7qarqwuTySS97/j4\nOB6PR/I/h4rf76e/v599+/ahVquluSHuRaEwPj7O0aNH6evrQ6VSoVAoUKvVksJrNBqlbz8xMUFX\nVxd6vZ5du3aRnZ0dtAVR9H03NTUxNDTEL37xC2w2G1NTU9TU1NDY2IjdbsdgMLB27VpiYmJC3kf8\nfj+9vb2o1Wp0Oh0Oh4Px8XGMRiNRUVEhtfVd1Go1ZrOZmJgYOjs7/2/5gBUKBRaLhYmJCY4fP45G\no2FqaorDhw/T0tLCmjVrKCsrC2pzV6vV5OXlkZeXR0NDA3q9XtLmExISUKlUs0qH6OnpYWZmhoKC\nAhITEyWBUVdXh8PhIBAIhLQIBgcHuXbtGnfu3JF+z2QyoVKp6Ojo4NatW5IPLSoqioKCAnJzc4Oe\n+GLEZWxsLIsWLcJisTz07j6fj+npaZRKZchKiWjebm9vR6vVfs/05fP5JIUomDERo737+/ul0+N3\nTw8iBoOBRYsWsXPnTn7961/T2tpKamrqYxfa6OgoIyMjbNy4kaysrFkFZzyIKMTOnj3LkSNHMBgM\n7N69G61WS3R0NCaTiZSUFJKSkh4KunnSmHi9XiYmJlAoFBQUFLBkyRIp6CwuLo6Ojg4iIyNJT09n\n/fr1LFu2LOi4ABGPx8P169e5efMmsbGxVFdXk5iYSFNTE5cvX6a4uJjExMSgzM9utxun00lMTAyF\nhYWkpqZKCgGAXq9Hr9cjCAKRkZEhn9ZmZmb49ttv+fLLL+nr62Pbtm3k5uZiNBoJBAIMDg5y8uRJ\nampqWLdu3WNNj1qtltzcXLKzsx8ZOBkIBCgtLWVoaIjc3Fz27dtHY2MjWq2WwsJCMjIyHtt3lUqF\nx+OhpaWFuro6srKypG9nMBjQaDRh70Mej4fW1lZOnz7NO++8w/r168MSvuJ7Tk5O4nK5pAwRjUYj\nxR643W60Wq3Ud7vdjk6nY3p6OqxnTU9PS20qFAru3LnDuXPncDgc0vPDsc4pFAqmpqbo6+tjYGCA\nhoYG2tra8Pv9pKamUlVVhdlsDrndB9s3GAzExMTgcrmkNLJweaoEsFKpxGq1EhUVxenTpxkdHWVy\ncpLBwUEKCgrYtGnT9wTHD6FSqcjIyKC0tJQbN27g9/tpb2+ntraWrKwsUlJSpKCKcHA6nSgUCkmj\nFbXuI0eO4HQ6MZlM6HS6oE+/LS0tXL58WdpoAYaGhrh58yaNjY00NzfjdrsfMo8NDw9Lvr/HIQgC\nY2Nj9PX1SRuiXq/H6/XicrkYHx9ndHSU/v5+dDodaWlpxMfHhxR0c/fuXZqbm8nPz5dO/WKebU9P\nDzExMcTFxQUt7MbHx+no6ECj0VBSUvLYfuh0OnJyckhMTOTOnTtUVFT8oAAWx81kMpGWlvaQ/1zM\nO/Z6vWi1WmlzCAWHw0FHRwdjY2NkZmaybNkySktLH3l6fDAP9HHPeVB5EU8nolDp7e3l5s2bjI2N\nsWLFCt566y3S0tJC9neNjo5SW1vL2NgYGzZsIBAIcPz4cb766it6e3vZu3cvFoslqLYiIyMxm810\ndXVx+/ZtCgoKMJvNzMzMcPXqVU6dOoXVasXhcBAREUFqampI69Dj8XDx4kVu3bpFcnIy+fn5WCyW\nh/K6Dx06JFm9nhRd/GAU+nfXq7jhpqWlYbPZaG9vp62tjaGhIUZGRvB6vY+NddBoNMTFxeF2uzl7\n9iyLFy+mrKwMpVJJYmIikZGRuFyusAWw3W5nfHyctWvXYrVaw06tjI6OZsuWLXi9XgBJCPr9fmku\nabVayeXV1dWF0WjEYrGgVqtDijoX3ShqtZrTp0+j0+k4evQoHR0deL1e4uPjpcDQcN6lt7eXvr4+\nPB4P/f392O12ent7pXm5fv36WcV6qNVq9Hq9ZOWcTebEUyWARf+CUqmUTIpKpZLly5eze/duVq5c\nGVJ7SqVSSiEYHR3l0KFDnD9/nvz8fDZu3MjmzZsxm81hfQy1Wi35TQDJVCMuSNGH/STEYK6WlhZa\nW1ulnwFcvnyZe/fuSf5v8d9GR0c5ffo0qamp7Nmz54nPEHNqRZ9NZGQk09PT9PT0cOPGDRoaGrh9\n+zbd3d0YDAYWL17MunXrqKioCCqoCe4LzJGREUloiYu2ubmZI0eOsGzZMtasWRP0YnU4HLS1taHT\n6SS/3OPGUKVSYTQaGRgYeKw5LxAI0NXVRUxMjJSLKvZ3cnKSzs5Ouru7SU5OJjExkZiYGMkU9yQE\nQSA1NZUtW7YwNTXFpUuXGBgY4K/+6q+kTfdRfQ8VUbnxeDxcuXKF+vp6oqOj2bBhA6mpqSELX9Es\n39PTw9DQENeuXePzzz/n5s2beDwe3nrrLUpLS4PykwuCQHx8PIWFhTQ0NLB//378fj9r1qxhbGyM\nf/zHf2RwcJCEhARWr15NdXU1+fn5IfVX/FbiqUaMXBdPhF9//TU9PT38zd/8DYsXL5b6FS4PRi1H\nRkai0WhwOBz09fUxOTkpneYfhV6vp6SkhLq6Oq5evcr+/ftZsGDB905h381BDna+idalB5U7sS9i\nDrLP50OtVkv70aPa1mg0pKSkPPI54v8X18nY2Bjd3d1s374dk8n0xH5+F4PBwMKFCykoKOCLL76g\nubmZS5cuSRaRiIgIIiIiJCtlqFy7do2+vj62b9/Oc889h06n48KFC3z00Ud8+eWXLF++POS0xQcR\nlZyhoSHGx8fR6/Vht/XUCGCPx0NzczO//e1v6ezsxO12EwgEMBqN5OfnU1RUFJKmJbJ69Wqampqo\nra3F5XJJfq6Wlhba2tp4++23sVqtIbWpUCjIzMykr6+PxsZGkpKSUKlUtLS0MDg4KAm5YLR60Wfl\ndrsfEhyiKW14eFjSRuF//FJ9fX0MDQ0F1d9AIMDQ0BAej4fY2Fg8Hg+ff/45Bw8epK+vj4iICJRK\npZSTefz4cRoaGvjzP/9ztm3bFpQiIQqspqYm+vr6iI+PZ3h4mHfffZehoSEWL14ctF9YzK0OBALk\n5OQEHU0tCIJ0QvwhVCoV+fn57N+/n+HhYXw+HyqVitHRUY4fP857773H4OAgWq2W0tJS/uRP/oSS\nkpKgg7SMRiPLli0jNzeXmzdv8t577/Ev//Iv/N3f/R0LFy783kk+mIWrVqslq8KD73f37l3OnTvH\n0NAQW7ZsYcmSJWFtWIIgkJSUREVFBTdu3KCvr4+mpiYUCgWvvfYar7zySkh+crVazfr16/F6vRw6\ndIjf/e53vP/++/j9fkZGRnj++ef5oz/6I7KzszEYDGG5AMSxuHnzJv/5n/9JRUUFsbGxXLlyhYsX\nL5KVlUVJSclDyqD4p2hFChWXy8X58+dxuVxSHvOTTmlKpZL8/HxWr15Nc3MzNTU1REdH8/Of/5yJ\niQkp3kM0y/p8vpADAsWAKZVKRSAQkFK+hoeHsdvtdHV1kZmZSVVVlWQp+KF2HkUgEJBy1t1uN42N\njdy9exeHw8Hg4CDT09PodLqQLGYpKSm88cYbKBQKPvvsM/Lz83nttdfIzc3lq6++oqmpCbvdzqJF\ni0Ke0y6XS7I8WK1WYmJiKCsr4/r167S2tkpR1uEgugtmZmYYGBhgeHg46EPKo3gqBLDX6+X27dv8\n7ne/48yZMyQmJrJr1y70ej1NTU3MzMwwOTkZVmDPggUL+NM//VNWr16N2+1mZGSES5cucfPmTY4f\nPy5FzoaiEQmCQEpKCgsXLqS+vp5/+7d/w2QyMTk5SWVlJUqlMugPLAgCRqORtLQ0EhMTcTgcGAwG\nkpOTWbBgASMjI9jtdqamplAqlaSmplJRUUF5eTlLliwJ6hkqlYqkpCSMRiMtLS3MzMxw7do1fD4f\nL7/88kOpJU6nk3/4h39Ar9ej0+mCUiIEQSA7O5uqqiqOHz/Ob3/7WxITE+np6aG5uZn169ezYMGC\nkHJIx8bGGBkZIT09PagNWjQf5+fnPzEtSzQfnThxgrS0NPLy8vj222/55JNPuHv3LgqFgrGxMc6d\nO4fVasVqtZKdnR303FOr1VgsFun7/Ou//ivXr18nMTHxoRNDsPMtJiaGhQsX4vV66erqkjZtcZM1\nmUzk5+cTExMTVHuPIjo6mm3btpGZmUlNTQ3x8fGsX7+eV155RYqZCIXExESpvRMnTvDxxx9L/6ZS\nqYiOjg47IEan01FaWkp0dDROp5MrV65w584dNBoNLpcLn89HXFzcQwGQYnDY9PQ0JpMp5LHyeDyc\nOnWKzs5OZmZmSE9PJykp6QdjE0REk+uqVasYHR3liy++4PPPP6elpQWr1crg4CCCIHDq1Cm6urrI\nyMhg48aNQc0N0V1kMBjo7+/HYDBw8+ZN6urqmJycJC4uDqVSyd27d7lw4QJTU1Ps2bMn5HGfnp6m\nqamJs2fPSvWlBwYGOH36NNevX2dqaoply5axbds2bDZbUBYYtVr9UADi888/z4oVK4iOjiY1NZXB\nwUE6OjpCto7A/f3DZrNhtVqlw4MY6PZgLEK4TE1N0dvbi8PhoLu7m9zc3LBdmfMugH0+H3fu3OGT\nTz6htraWwsJCdu3aJQVbvf/++4yPjzM0NMSiRYtCbt9gMJCfn09ycjKBQACXy0VBQQEnT57k5MmT\nHDx4kNWrV5OdnR20gBAEAYPBQF5eHkqlku7ubrRaLRaLhW+//RaHwxGSMNfpdFLhkYaGBhQKBenp\n6WRnZ3P48GHu3r0L3BcyWq2WBQsWUFlZGfTJXRTcaWlpNDc3097ezsjICEuXLqWyspLc3FzgviJ0\n8OBBBEFg6dKlIVWPiYyMpLi4GI1GI20qtbW1WK1WnnnmmZD8fH6/H5fLhcPhCMoPOz09TUtLC+Pj\n4+Tk5DwxMtxsNrNgwQKuXr2K2Wxm8+bNdHd34/P52Lt3L1arlXv37nHy5Em6u7ulEoo/hGgKfLBO\ntZgqJlYra29vx+VyhWWyi4iIICMjg6ioKOrr69m6dSsGg4Guri4cDgcpKSnk5+eHZSESEc19ostj\nw4YNPPvss6Snp4el3Ws0Gmw2GwaDgeHhYT766COpTOSVK1f49NNP2b59O3l5eSHni2s0GsrKynjp\npZfo6uqis7OTe/fu4XQ6peCobdu2PXQ6FZUql8uFTqcLSgCL1qnu7m5OnDhBTU0NQ0NDFBUVsWnT\nJrKysoIaG5VKRVpaGtu3b5fiW1pbW2ltbcXpdCIIAhcvXqS/v18quRoMYuxDeXk5+/fvl/KXExIS\nKCsrw2q1otfrGRgY4Msvv6SxsZEtW7aELIBVKhVms5nc3FwSEhKA++Vsy8vLiY2NZXp6mqysrMea\n4r/LzMwMvb29NDY2snLlSpYtW0ZsbKwUFyHu+eFWYTOZTERHR0tKWVtbG3fu3CE9PX3Wte7j4uJY\nsmQJFy5coL+/H5/P90RF7IeYVwGsUNwvPtHc3Mz58+dJSUnhtddeY/Xq1ZjNZlwuF3FxcZIPI1xE\nc4RCcb8uamJioqQttrS0MDQ0xIIFC0Lue1xcHNHR0eTl5eH3+zEYDFy4cCHkSaNUKsnIyGDbtm0U\nFhbi8/kwm81ERkZy9uzZh9oSA6fE4iHBTnir1cq6dev49NNPsdvtwH1NbmxsDJVKhd/vp6mpicOH\nDxMdHc2KFStITU0N6T0sFgvLly+XomDPnTuHyWQiNTU1rIAKn8/H+Pj4Y7+/y+Xixo0bHD58mJSU\nFFJTUx9rMhcEAbPZzM6dO/n00085d+4cw8PDBAIBYmJiWLp0KVFRUcTExNDQ0CCZ3n4IMa3ixo0b\nxMXFkZmZKQVliKYq0SQWrpas1WrJyMigsrJSejfRrx8IBMjKyiItLS2stsV36Ovr48SJE1y6dAmb\nzcb27dulwifhIiok4lwtKSmhuLiY2tpajh49is/nY8+ePeTn54c0NkqlEpvNxosvvsjg4CA1NTWM\njo4yNTVFVlYWmzdvZs2aNd8LvhJ9/sEIfNEN0t3dzdGjR/n000/p7u6msLCQZ599ljVr1hAXFxd0\nn/V6PdkXDdbnAAAMyElEQVTZ2VgsFtLS0rh+/Tp9fX1cv36d3t5eAoEAaWlp5OTkBD0WKpWK5ORk\nXnrpJd577z1aW1ulUr3l5eWSItrf38+FCxeCjmX4LhqNRkrT8nq9WK1W7ty5w6pVqyQlRKVShVQV\nyul00tDQQEdHB7/85S8lN96DNR+e5E76IUSTueiua2pq4tSpUzidTsrLy2ed8x8XF0dZWRk6nY6x\nsbFZFfiY9xOweJyfnp5m586dbN26VQow6evrQxAEaYOZ7e06omAU88PEMPvZmCTEICG4bxLs7u6m\nrKws5LquKpUKm832UDWcycnJh2rlqlQqKUVI1JyDxWg0sm3bNux2O2fOnGF0dJSWlha+/vprKdDo\n/fff59atWzz//PMUFhYGVWHrUe8RFRUl1cQWTeeh9FWs5pSYmIjdbqevr4+0tLSHTuN+vx+n08m3\n337LoUOHuHXrFm+//TaxsbFP3MD0ej2bN2/G5/Px8ccfc/ToUWkzq6urY2pqikAggNPpfGI6hEJx\nv6rPwYMHSU5OlsyuarVaMo/29fWRn58ftslVrVaTmprKyy+/TCAQwGw24/f7GR8fJyYmhqysrLAK\nFohMT09TV1fHqVOniIuL49lnn52T6+sEQcDj8TA1NYXJZGLx4sVShbjf//73fPPNN1itVhITE0O6\ngUr04S5cuBCr1cqZM2fQaDRkZGSwadMmtm3bhtlsfqi9QCAglXp80vwQT75dXV0cO3aMzz77jJGR\nEfLz8/njP/5jVq9eHValMZVKRUJCAps3b2bFihUMDAzw7//+7wwPD7NgwQK2bt1KeXl5SOMQERHB\njh07sNvtUlR2T0+PdFmJz+ejs7OTkZERNm3aFNZFK2KAqZg5Ie4XBoNBih+B4APdxACm+vp6kpOT\nWblyJREREQiCwODgIL29vZJ5PRyMRiNOp5P29na8Xi+nTp3i2rVrlJSUsGrVqlkXHxJdKCkpKVKe\nf7gX2My7ABarqiiVShYsWMDU1BROpxOHw8GlS5ek/LukpKSQNxjx/4tXSInBTH19fVy8eJHW1tZZ\nXXH44DNEv+XY2BhLliwJy9T4qPcT0wBE5SEjI4OqqqqwIkczMzN544038Hg8nDlzhqGhIY4dO4bD\n4cBms/H1119TVlbG7t27fzAiMtj3EAQh7GvF1Go12dnZrF27ln379nH8+HF27Ngh5bb6fD7p0oeD\nBw/S0dHByy+/zKZNm4JWfKKioti+fTszMzPs37+f5uZmmpubuXHjhiR0xWpbT0ozEE+lFy9eRKVS\nSdXL2tra+PDDDyU/82w078jISFatWiX93el0MjIyQlRUFMnJyWGbwADsdjsnT54kJSWFF154gYKC\ngrDb+i4Wi4Xy8nJqa2sZGRnB7XazdetWhoaG+Oijjzh79iz5+fmPLcX5QwjC/ZKiFy5cIDIykj17\n9rB169YfLMEYihVtfHycs2fPSjXkd+7cye7duykpKZlViVVxHZvNZqKiosjOzub8+fPodDqp2Eyo\n+5xWq+XVV1/F4/Fw8uRJ/uM//oPk5GTi4+Px+Xy4XC4SExNZsWJF2Ckz4pp+MHXuwZ+HghjA1Nvb\nywsvvIDFYmF6epqpqSnOnDlDf38/RUVFYbkcAXJycjh+/DgHDhzA4/HgcDgoLS3l7bffZuHChXNS\nFctisbBq1Sq+/vprJiYmgq5D/13mXQBHRERIOYF/+7d/y6JFi+jv75c2l6qqKhYtWjSrykoNDQ0M\nDAxgs9mIjo6mvr6ew4cP4/V6ycnJkfLnZoNY7Uj0l8xFQfHR0VHa29ulZHdRiRCvUwzHIpCbm8uf\n/dmfkZSUxNGjR+nu7ubUqVPYbDaWLFnCX//1X5Obmzvr/ms0GhYtWsTVq1fD6mdCQgIbN26kpaWF\nf/qnf+Lrr78mNzcXg8FAb28vbW1tuN1uiouL+fWvf83y5ctDsjqIEfa7d++muLiYixcvcu3aNamA\nSmRkJPn5+axatYqMx5SyCwQCJCYm8s4773Ds2DEuXrzIkSNH8Hg8CIJAZmYmf/mXfxl0MFmwiMU5\nxBiCcBQdMW/2wIEDKJVK1q5dK8UDzBUajYasrCyqqqr4zW9+Q0tLCzt27JBySpuamrh06RKrV68O\neXwE4f6NNz6fj127drFz586gc5UfhxjtfOzYMZRKJc899xyvv/76nF0UAPfnjUajITs7G5PJJAmk\ncOrSC4KA1WrlnXfeoaqqiosXL2K325mZmSEpKYl169ZJ1oK5Kj0r+lfDTaPz+XxMTk5it9tpbGyk\nvr6eCxcuMDg4yKZNm3jhhRfCrlC3ePFiNm3axCeffML4+Djr1q3jpZdeemLJ11Dx+/3Y7XbcbnfY\nsmneBbDBYCAlJYX09HTJjGI2m1m3bh3Lly9n0aJFYZ1+4X9MSSdPnuTs2bP4fD60Wi1Op5OBgQHJ\nF5icnDzrzXF6eprbt28THx8/q7ywBxkbG5PC/P1+v1TGMNzxgPvmk4ULF7J3717Kysro6upCoVBI\naUR5eXlzdtl0Xl4edXV1Uq50KIhXiP3iF78gJiaG/v5+bt++jVarJTIykqVLl5Kdnc3ixYvJyckJ\nS7MXTXi5ubkkJiayceNGaVzF1J/o6OgnjocY5LZr1y5KS0ul/FCNRkNOTo5UZWkuEQThoXKT4Sqo\nYhGPyspKiouL50zAPNjPyMhIysvLKS8v59KlS/T09KBSqSTT64IFC8ISDJOTkzQ0NJCTk0NZWdkT\na4AH29/Lly/zxRdf4HK5+PnPf05VVZVU/3cuEQSB3NxcioqKcDqds5ojCoWC6OhoiouLSU9PZ2pq\nSlLORCvOXCmAgnD/2tiZmZmwgqS0Wi3p6ekUFxdz6NAhTp06hdFoJDMzk7Vr17Jy5UqprGg4fdPr\n9axfv16qUJiZmSm5hebi9AtIaZsVFRWzCuqaVwEsTpDi4mLefvttBgYG8Pl8mEwmaSIFU3/2cSiV\nSioqKhgfH+fmzZuMjIygUqkoLS2lqqqKDRs2zDoqDu5vBo2NjVLe6lwI4MjISNLS0qT7gZctW8by\n5cslf0m46HQ6MjIyiI+Pl25pMRqN6PX6sAM1HoXNZpM2LvHe21AwGo0UFRVhMBgYHByUXBXR0dHE\nxsaSkJBAXFzcrBQG0b+VkJAgRXg+SLDjLEb9WiwW8vLypKjoqKioORe+cH9sUlJS6O3tZWxsTMpn\nDgW/309dXR2JiYlStaofA5VKRWpqKm+88QZlZWXY7XZGR0cpKipi2bJlLFmyJOy8XLVazfLly2d9\n36uIWJWupaUFnU6H3++XqqLN1eYtEggEiI+PZ8+ePbjd7oduJgsHhUKB0Wj8nj9yLvutVCqlKnLh\nmuIVCgWJiYn87Gc/Y8GCBVLQaVZWFhkZGVit1lkrOzabTYp01+v1UsnPuSIyMpLVq1eTk5Pz2Nzq\nJzHvJ2AxwMRms0kavVi1Jdyyag+iVCqprKwkISGB5uZmSQAnJyezfPlyLBbLrM0yomk4Pj6e4uLi\nWV+sLWK1Wnnuuefo7e1FEAQKCgpYtGjRnJiRxEACcZKG48t5HIIgEBERwTPPPBNStOiDiBtKcXHx\nQ9HID9aKnas+z1U7er3+odP4XN4t/CAGg4HS0lK6u7u5cuUKixYtIjMzM6RAmLGxMQYGBli/fj1J\nSUlzZp78LqKlobKykqKiIu7du8fo6KhUv/pRis+TEAQBrVbLhg0b5lR5EP2zhYWFGI3GOb2o41Ho\n9XoqKioAwo76/S5zrSg8iBiBvmXLlllZHCIiIigvL6e4uPh7rrW56L+4v8GPMx56vZ6CggLy8/Nn\n9d0Uwo/5tb6D2+1+7L+LL/FjdCnU6i+hIAj3Lw+4ffs26enpmEymOdvMvtvvn/BzyTzFKBQKWltb\n+e///m+cTid79+5lxYoVQc8P8e7tkydPsmnTJpKSkubM8vE4HvWM2czp2cSGPApBEOjq6qK7uxu9\nXk9GRsZDmQgy95mrcf8x9/yniR9ykT1VAvh/O2Jg1P/1ySTzdOD1ehkcHMTj8WCxWEI6BQqCgMvl\neiitROY+YiEVQF7PMnPCUyGAZWRkZGRkZO7z4zh9ZGRkZGRkZB6LLIBlZGRkZGTmAVkAy8jIyMjI\nzAOyAJaRkZGRkZkHZAEsIyMjIyMzD8gCWEZGRkZGZh6QBbCMjIyMjMw8IAtgGRkZGRmZeUAWwDIy\nMjIyMvOALIBlZGRkZGTmAVkAy8jIyMjIzAOyAJaRkZGRkZkHZAEsIyMjIyMzD8gCWEZGRkZGZh6Q\nBbCMjIyMjMw8IAtgGRkZGRmZeUAWwDIyMjIyMvOALIBlZGRkZGTmAVkAy8jIyMjIzAOyAJaRkZGR\nkZkHZAEsIyMjIyMzD8gCWEZGRkZGZh6QBbCMjIyMjMw88P8AekW4kiHLlR8AAAAASUVORK5CYII=\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "sample = np.random.choice(X.shape[0], 20)\n", "plt.imshow(X[sample,1:].reshape(-1,20).T)\n", "plt.axis('off');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Multiclass Classification" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Logistic regression hypothesis\n", "#### $$ h_{\\theta}(x) = g(\\theta^{T}x)$$\n", "#### $$ g(z)=\\frac{1}{1+e^{−z}} $$" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def sigmoid(z):\n", " return(1 / (1 + np.exp(-z)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Regularized Cost Function \n", "#### $$ J(\\theta) = \\frac{1}{m}\\sum_{i=1}^{m}\\big[-y^{(i)}\\, log\\,( h_\\theta\\,(x^{(i)}))-(1-y^{(i)})\\,log\\,(1-h_\\theta(x^{(i)}))\\big] + \\frac{\\lambda}{2m}\\sum_{j=1}^{n}\\theta_{j}^{2}$$\n", "#### Vectorized Cost Function\n", "#### $$ J(\\theta) = \\frac{1}{m}\\big((\\,log\\,(g(X\\theta))^Ty+(\\,log\\,(1-g(X\\theta))^T(1-y)\\big) + \\frac{\\lambda}{2m}\\sum_{j=1}^{n}\\theta_{j}^{2}$$" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def lrcostFunctionReg(theta, reg, X, y):\n", " m = y.size\n", " h = sigmoid(X.dot(theta))\n", " \n", " J = -1*(1/m)*(np.log(h).T.dot(y)+np.log(1-h).T.dot(1-y)) + (reg/(2*m))*np.sum(np.square(theta[1:]))\n", " \n", " if np.isnan(J[0]):\n", " return(np.inf)\n", " return(J[0]) " ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def lrgradientReg(theta, reg, X,y):\n", " m = y.size\n", " h = sigmoid(X.dot(theta.reshape(-1,1)))\n", " \n", " grad = (1/m)*X.T.dot(h-y) + (reg/m)*np.r_[[[0]],theta[1:].reshape(-1,1)]\n", " \n", " return(grad.flatten())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### One-vs-all Classification" ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def oneVsAll(features, classes, n_labels, reg):\n", " initial_theta = np.zeros((X.shape[1],1)) # 401x1\n", " all_theta = np.zeros((n_labels, X.shape[1])) #10x401\n", "\n", " for c in np.arange(1, n_labels+1):\n", " res = minimize(lrcostFunctionReg, initial_theta, args=(reg, features, (classes == c)*1), method=None,\n", " jac=lrgradientReg, options={'maxiter':50})\n", " all_theta[c-1] = res.x\n", " return(all_theta)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": false }, "outputs": [], "source": [ "theta = oneVsAll(X, y, 10, 0.1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### One-vs-all Prediction" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def predictOneVsAll(all_theta, features):\n", " probs = sigmoid(X.dot(all_theta.T))\n", " \n", " # Adding one because Python uses zero based indexing for the 10 columns (0-9),\n", " # while the 10 classes are numbered from 1 to 10.\n", " return(np.argmax(probs, axis=1)+1)" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training set accuracy: 93.17999999999999 %\n" ] } ], "source": [ "pred = predictOneVsAll(theta, X)\n", "print('Training set accuracy: {} %'.format(np.mean(pred == y.ravel())*100))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Multiclass Logistic Regression with scikit-learn" ] }, { "cell_type": "code", "execution_count": 43, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "LogisticRegression(C=10, class_weight=None, dual=False, fit_intercept=True,\n", " intercept_scaling=1, max_iter=100, multi_class='ovr', n_jobs=1,\n", " penalty='l2', random_state=None, solver='liblinear', tol=0.0001,\n", " verbose=0, warm_start=False)" ] }, "execution_count": 43, "metadata": {}, "output_type": "execute_result" } ], "source": [ "clf = LogisticRegression(C=10, penalty='l2', solver='liblinear')\n", "# Scikit-learn fits intercept automatically, so we exclude first column with 'ones' from X when fitting.\n", "clf.fit(X[:,1:],y.ravel())" ] }, { "cell_type": "code", "execution_count": 44, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training set accuracy: 96.5 %\n" ] } ], "source": [ "pred2 = clf.predict(X[:,1:])\n", "print('Training set accuracy: {} %'.format(np.mean(pred2 == y.ravel())*100))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Neural Networks" ] }, { "cell_type": "code", "execution_count": 45, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def predict(theta_1, theta_2, features):\n", " z2 = theta_1.dot(features.T)\n", " a2 = np.c_[np.ones((data['X'].shape[0],1)), sigmoid(z2).T]\n", " \n", " z3 = a2.dot(theta_2.T)\n", " a3 = sigmoid(z3)\n", " \n", " return(np.argmax(a3, axis=1)+1) " ] }, { "cell_type": "code", "execution_count": 46, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training set accuracy: 97.52 %\n" ] } ], "source": [ "pred = predict(theta1, theta2, X)\n", "print('Training set accuracy: {} %'.format(np.mean(pred == y.ravel())*100))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.4.3" } }, "nbformat": 4, "nbformat_minor": 0 }