{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Logistic Regression with Synthetic Data\n", "\n", "For more explanation of logistic regression, see\n", "1. [Our course notes](https://jennselby.github.io/MachineLearningCourseNotes/#binomial-logistic-regression)\n", "1. [This scikit-learn explanation](http://scikit-learn.org/stable/modules/linear_model.html#logistic-regression)\n", "1. [The full scikit-learn documentation of the LogisticRegression model class](http://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html)\n", "\n", "## Instructions\n", "0. If you haven't already, follow [the setup instructions here](https://jennselby.github.io/MachineLearningCourseNotes/#setting-up-python3) to get all necessary software installed.\n", "0. Read through the code in the following sections:\n", " * [Data Generation](#Data-Generation)\n", " * [Visualization](#Visualization)\n", " * [Model Training](#Model-Training)\n", " * [Prediction](#Prediction)\n", "0. 
Complete at least one of the exercise options:\n", " * [Exercise Option #1 - Standard Difficulty](#Exercise-Option-#1---Standard-Difficulty)\n", " * [Exercise Option #2 - Advanced Difficulty](#Exercise-Option-#2---Advanced-Difficulty)\n", " * [Exercise Option #3 - Advanced Difficulty](#Exercise-Option-#3---Advanced-Difficulty)" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy.random # for generating our dataset\n", "from sklearn import linear_model # for fitting our model\n", "\n", "# force numpy not to use scientific notation, to make it easier to read the numbers the program prints out\n", "numpy.set_printoptions(suppress=True)\n", "\n", "# to display graphs in this notebook\n", "%matplotlib inline\n", "import matplotlib.pyplot" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data Generation\n", "\n", "As we did in the [linear regression notebook](https://nbviewer.jupyter.org/github/jennselby/MachineLearningCourseNotes/blob/master/assets/ipynb/LinearRegression.ipynb), we will be generating some fake data.\n", "\n", "In this fake dataset, we have two types of plants.\n", "* Plant A tends to be taller (average 60cm) and thinner (average 8cm).\n", "* Plant B tends to be shorter (average 58cm) and wider (average 10cm).\n", "* The heights and diameters of both plants are normally distributed (they follow a bell curve).\n", "\n", "* Class 0 will represent Plant A and Class 1 will represent Plant B" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "NUM_INPUTS = 50 # inputs per class\n", "PLANT_A_AVG_HEIGHT = 60.0\n", "PLANT_A_AVG_WIDTH = 8.0\n", "PLANT_B_AVG_HEIGHT = 58.0\n", "PLANT_B_AVG_WIDTH = 10.0\n", "\n", "# Pick numbers randomly with a normal distribution centered around the averages\n", "\n", "plant_a_heights = numpy.random.normal(loc=PLANT_A_AVG_HEIGHT, size=NUM_INPUTS)\n", "plant_a_widths = numpy.random.normal(loc=PLANT_A_AVG_WIDTH, 
size=NUM_INPUTS)\n", "\n", "plant_b_heights = numpy.random.normal(loc=PLANT_B_AVG_HEIGHT, size=NUM_INPUTS)\n", "plant_b_widths = numpy.random.normal(loc=PLANT_B_AVG_WIDTH, size=NUM_INPUTS)\n", "\n", "# this creates a 2-dimensional matrix, with heights in the first column and widths in the second\n", "# the first half of rows are all plants of type a and the second half are type b\n", "plant_inputs = list(zip(numpy.append(plant_a_heights, plant_b_heights),\n", " numpy.append(plant_a_widths, plant_b_widths)))\n", "\n", "# this is a list where the first half are 0s (representing plants of type a) and the second half are 1s (type b)\n", "classes = [0]*NUM_INPUTS + [1]*NUM_INPUTS" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Visualization\n", "\n", "Let's visualize our dataset, so that we can better understand what it looks like." ] },
{ "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[Figure: scatter plot titled 'Plant Data Set' with Height on the x-axis and Width on the y-axis; plant a as red circles, plant b as blue triangles]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# create a figure and label it\n", "fig = matplotlib.pyplot.figure()\n", "fig.suptitle('Plant Data Set')\n", "matplotlib.pyplot.xlabel('Height')\n", "matplotlib.pyplot.ylabel('Width')\n", "\n", "# put the generated points on the graph\n", "a_scatter = matplotlib.pyplot.scatter(plant_a_heights, plant_a_widths, c=\"red\", marker=\"o\", label='plant a')\n", "b_scatter = matplotlib.pyplot.scatter(plant_b_heights, plant_b_widths, c=\"blue\", marker=\"^\", label='plant b')\n", "\n", "# add a legend to explain which points are which\n", "matplotlib.pyplot.legend(handles=[a_scatter, b_scatter])\n", "\n", "# show the graph\n", "matplotlib.pyplot.show()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Model Training\n", "\n", "Next, we want to fit our logistic regression model to our dataset." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Intercept: [0.4923611] Coefficients: [[-0.30765052 1.95764602]]\n" ] } ], "source": [ "model = linear_model.LogisticRegression()\n", "model.fit(plant_inputs, classes)\n", "\n", "print('Intercept: {0} Coefficients: {1}'.format(model.intercept_, model.coef_))" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Prediction\n", "\n", "Now we can make some predictions using the trained model. Note that we are generating the new data exactly the same way that we generated the training data above."
] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Plant A: 59.93207269053181 7.030073699859045\n", "Plant B: 59.619487931898846 10.240893276828622\n", "Class predictions: [0 1]\n", "Probabilities:\n", "[[0.98498204 0.01501796]\n", " [0.09989079 0.90010921]]\n" ] } ], "source": [ "# Generate some new random values for two plants, one of each class\n", "new_a_height = numpy.random.normal(loc=PLANT_A_AVG_HEIGHT)\n", "new_a_width = numpy.random.normal(loc=PLANT_A_AVG_WIDTH)\n", "new_b_height = numpy.random.normal(loc=PLANT_B_AVG_HEIGHT)\n", "new_b_width = numpy.random.normal(loc=PLANT_B_AVG_WIDTH)\n", "\n", "# Pull the values into a matrix, because that is what the predict function wants\n", "inputs = [[new_a_height, new_a_width], [new_b_height, new_b_width]]\n", "\n", "# Print out the outputs for these new inputs\n", "print('Plant A: {0} {1}'.format(new_a_height, new_a_width))\n", "print('Plant B: {0} {1}'.format(new_b_height, new_b_width))\n", "print('Class predictions: {0}'.format(model.predict(inputs))) # guess which class\n", "print('Probabilities:\\n{0}'.format(model.predict_proba(inputs))) # give probability of each class" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Exercise Option #1 - Standard Difficulty\n", "\n", "Answer the following questions. You can also use the graph below, if seeing the data visually helps you understand the data.\n", "1. What should we be expecting as the output for class predictions in the above cell? If the model is not giving the expected output, what are some of the reasons it might not be?\n", "1. How do the probabilities output by the above cell relate to the class predictions? Why do you think the model might be more or less confident in its predictions?\n", "1. If you change the averages in the data generation code (like PLANT_A_AVG_HEIGHT) and re-run the code, how do the predictions change, and why?\n", "1. 
Looking at the intercept and coefficient output further above, if a coefficient is negative, what has the model learned about this feature? In other words, if you took a datapoint and you increased the value of a feature that has a negative coefficient, what would you expect to happen to the probabilities the model gives this datapoint?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Exercise Option #2 - Advanced Difficulty\n", "\n", "The plot above is only showing the data, and not anything about what the model learned. Come up with some ideas for how to show the model fit and implement one of them in code. Remember, we are here to help if you are not sure how to write the code for your ideas!" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Exercise Option #3 - Advanced Difficulty\n", "\n", "If you have more than two classes, you can use multinomial logistic regression or the one vs. rest technique, where you use a binomial logistic regression for each class that you have and decide if it is or is not in that class. Try expanding the program with a third plant type and implementing your own one vs. rest models. To test if this is working, compare your output to running your expanded dataset through scikit-learn. Depending on the scikit-learn version and solver, `LogisticRegression` handles more than two classes with either one vs. rest or a multinomial model, so check which one your version uses (or wrap the model in `sklearn.multiclass.OneVsRestClassifier`) before comparing." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.6" } }, "nbformat": 4, "nbformat_minor": 2 }