{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Multivariate Logistic Regression Demo\n", "\n", "_Source: 🤖[Homemade Machine Learning](https://github.com/trekhleb/homemade-machine-learning) repository_\n", "\n", "> ☝Before moving on with this demo you might want to take a look at:\n", "> - 📗[Math behind the Logistic Regression](https://github.com/trekhleb/homemade-machine-learning/tree/master/homemade/logistic_regression)\n", "> - ⚙️[Logistic Regression Source Code](https://github.com/trekhleb/homemade-machine-learning/blob/master/homemade/logistic_regression/logistic_regression.py)\n", "\n", "**Logistic regression** is the appropriate regression analysis to conduct when the dependent variable is dichotomous (binary). Like all regression analyses, the logistic regression is a predictive analysis. Logistic regression is used to describe data and to explain the relationship between one dependent binary variable and one or more nominal, ordinal, interval or ratio-level independent variables.\n", "\n", "Logistic Regression is used when the dependent variable (target) is categorical.\n", "\n", "For example:\n", "\n", "- To predict whether an email is spam (`1`) or (`0`).\n", "- Whether online transaction is fraudulent (`1`) or not (`0`).\n", "- Whether the tumor is malignant (`1`) or not (`0`).\n", "\n", "> **Demo Project:** In this example we will train handwritten digits (0-9) classifier." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# To make debugging of logistic_regression module easier we enable imported modules autoreloading feature.\n", "# By doing this you may change the code of logistic_regression library and all these changes will be available here.\n", "%load_ext autoreload\n", "%autoreload 2\n", "\n", "# Add project root folder to module loading paths.\n", "import sys\n", "sys.path.append('../..')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Import Dependencies\n", "\n", "- [pandas](https://pandas.pydata.org/) - library that we will use for loading and displaying the data in a table\n", "- [numpy](http://www.numpy.org/) - library that we will use for linear algebra operations\n", "- [matplotlib](https://matplotlib.org/) - library that we will use for plotting the data\n", "- [math](https://docs.python.org/3/library/math.html) - math library that we will use to calculate sqaure roots etc.\n", "- [logistic_regression](https://github.com/trekhleb/homemade-machine-learning/blob/master/homemade/logistic_regression/logistic_regression.py) - custom implementation of logistic regression" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Import 3rd party dependencies.\n", "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import matplotlib.image as mpimg\n", "import math\n", "\n", "# Import custom logistic regression implementation.\n", "from homemade.logistic_regression import LogisticRegression" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Load the Data\n", "\n", "In this demo we will be using a sample of [MNIST dataset in a CSV format](https://www.kaggle.com/oddrationale/mnist-in-csv/home). Instead of using full dataset with 60000 training examples we will use cut dataset of just 10000 examples that we will also split into training and testing sets.\n", "\n", "Each row in the dataset consists of 785 values: the first value is the label (a number from 0 to 9) and the remaining 784 values (28x28 pixels image) are the pixel values (a number from 0 to 255)." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", " | label | \n", "1x1 | \n", "1x2 | \n", "1x3 | \n", "1x4 | \n", "1x5 | \n", "1x6 | \n", "1x7 | \n", "1x8 | \n", "1x9 | \n", "... | \n", "28x19 | \n", "28x20 | \n", "28x21 | \n", "28x22 | \n", "28x23 | \n", "28x24 | \n", "28x25 | \n", "28x26 | \n", "28x27 | \n", "28x28 | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | \n", "5 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "... | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "
1 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "... | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "
2 | \n", "4 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "... | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "
3 | \n", "1 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "... | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "
4 | \n", "9 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "... | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "
5 | \n", "2 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "... | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "
6 | \n", "1 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "... | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "
7 | \n", "3 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "... | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "
8 | \n", "1 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "... | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "
9 | \n", "4 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "... | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "
10 rows × 785 columns
\n", "\n", " | 0 | \n", "1 | \n", "2 | \n", "3 | \n", "4 | \n", "5 | \n", "6 | \n", "7 | \n", "8 | \n", "9 | \n", "... | \n", "775 | \n", "776 | \n", "777 | \n", "778 | \n", "779 | \n", "780 | \n", "781 | \n", "782 | \n", "783 | \n", "784 | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | \n", "-9.772579 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "... | \n", "-0.009417 | \n", "-0.007525 | \n", "-0.019028 | \n", "-0.015672 | \n", "-0.011255 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "
1 | \n", "-12.919264 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "... | \n", "-0.007120 | \n", "-0.004079 | \n", "-0.000066 | \n", "-0.000032 | \n", "-0.000007 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "
2 | \n", "-7.262680 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "... | \n", "0.000385 | \n", "0.000191 | \n", "-0.000155 | \n", "0.000112 | \n", "0.000214 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "
3 | \n", "-7.540222 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "... | \n", "-0.037224 | \n", "-0.019561 | \n", "-0.005837 | \n", "-0.001835 | \n", "0.000126 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "
4 | \n", "-9.932585 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "... | \n", "-0.104504 | \n", "0.000216 | \n", "-0.001270 | \n", "-0.033175 | \n", "-0.032475 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "
5 | \n", "-7.635457 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "... | \n", "-0.020391 | \n", "-0.014374 | \n", "-0.005743 | \n", "-0.035149 | \n", "-0.032878 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "
6 | \n", "-9.047295 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "... | \n", "-0.003456 | \n", "0.000079 | \n", "-0.000917 | \n", "-0.022326 | \n", "-0.021854 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "
7 | \n", "-10.491723 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "... | \n", "0.058497 | \n", "-0.042011 | \n", "0.086780 | \n", "0.161491 | \n", "0.156586 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "
8 | \n", "-6.311099 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "... | \n", "-0.005436 | \n", "-0.006730 | \n", "-0.001873 | \n", "-0.025549 | \n", "-0.038580 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "
9 | \n", "-8.199128 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "... | \n", "-0.004838 | \n", "0.010311 | \n", "-0.007995 | \n", "-0.073160 | \n", "-0.056380 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "
10 rows × 785 columns
\n", "