{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Multivariate Logistic Regression Demo\n", "\n", "_Source: 🤖[Homemade Machine Learning](https://github.com/trekhleb/homemade-machine-learning) repository_\n", "\n", "> ☝Before moving on with this demo you might want to take a look at:\n", "> - 📗[Math behind the Logistic Regression](https://github.com/trekhleb/homemade-machine-learning/tree/master/homemade/logistic_regression)\n", "> - ⚙️[Logistic Regression Source Code](https://github.com/trekhleb/homemade-machine-learning/blob/master/homemade/logistic_regression/logistic_regression.py)\n", "\n", "**Logistic regression** is the appropriate regression analysis to conduct when the dependent variable is dichotomous (binary). Like all regression analyses, logistic regression is a predictive analysis. It is used to describe data and to explain the relationship between one dependent binary variable and one or more nominal, ordinal, interval or ratio-level independent variables.\n", "\n", "Logistic regression is used when the dependent variable (target) is categorical.\n", "\n", "For example, to predict:\n", "\n", "- Whether an email is spam (`1`) or not (`0`).\n", "- Whether an online transaction is fraudulent (`1`) or not (`0`).\n", "- Whether a tumor is malignant (`1`) or not (`0`).\n", "\n", "> **Demo Project:** In this example we will train a clothes classifier that recognizes clothing types (10 categories) from `28x28` pixel images." 
] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# To make debugging of the logistic_regression module easier, we enable autoreloading of imported modules.\n", "# This way you may change the code of the logistic_regression library and all changes will be available here.\n", "%load_ext autoreload\n", "%autoreload 2\n", "\n", "# Add project root folder to module loading paths.\n", "import sys\n", "sys.path.append('../..')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Import Dependencies\n", "\n", "- [pandas](https://pandas.pydata.org/) - library that we will use for loading and displaying the data in a table\n", "- [numpy](http://www.numpy.org/) - library that we will use for linear algebra operations\n", "- [matplotlib](https://matplotlib.org/) - library that we will use for plotting the data\n", "- [math](https://docs.python.org/3/library/math.html) - math library that we will use to calculate square roots etc.\n", "- [logistic_regression](https://github.com/trekhleb/homemade-machine-learning/blob/master/homemade/logistic_regression/logistic_regression.py) - custom implementation of logistic regression" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Import 3rd party dependencies.\n", "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import matplotlib.image as mpimg\n", "import math\n", "\n", "# Import custom logistic regression implementation.\n", "from homemade.logistic_regression import LogisticRegression" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Load the Data\n", "\n", "In this demo we will use a sample of the [Fashion MNIST dataset in a CSV format](https://www.kaggle.com/zalando-research/fashionmnist).\n", "\n", "Fashion-MNIST is a dataset of Zalando's article images, consisting of a training set of 60,000 examples and a test set of 10,000 examples. Each example is a 28x28 grayscale image, associated with a label from 10 classes. 
Zalando intends Fashion-MNIST to serve as a direct drop-in replacement for the original MNIST dataset for benchmarking machine learning algorithms. It shares the same image size and structure of training and testing splits.\n", "\n", "Instead of using the full dataset with 60000 training examples, we will use a reduced dataset of just 5000 examples, which we will also split into training and testing sets.\n", "\n", "Each row in the dataset consists of 785 values: the first value is the label (a category from 0 to 9) and the remaining 784 values (a 28x28 pixel image) are the pixel values (numbers from 0 to 255).\n", "\n", "Each training and test example is assigned to one of the following labels:\n", "\n", "- 0 T-shirt/top\n", "- 1 Trouser\n", "- 2 Pullover\n", "- 3 Dress\n", "- 4 Coat\n", "- 5 Sandal\n", "- 6 Shirt\n", "- 7 Sneaker\n", "- 8 Bag\n", "- 9 Ankle boot" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
|   | label | pixel1 | pixel2 | pixel3 | pixel4 | pixel5 | pixel6 | pixel7 | pixel8 | pixel9 | ... | pixel775 | pixel776 | pixel777 | pixel778 | pixel779 | pixel780 | pixel781 | pixel782 | pixel783 | pixel784 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 2 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 1 | 9 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 2 | 6 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 5 | 0 | ... | 0 | 0 | 0 | 30 | 43 | 0 | 0 | 0 | 0 | 0 |
| 3 | 0 | 0 | 0 | 0 | 1 | 2 | 0 | 0 | 0 | 0 | ... | 3 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 |
| 4 | 3 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 5 | 4 | 0 | 0 | 0 | 5 | 4 | 5 | 5 | 3 | 5 | ... | 7 | 8 | 7 | 4 | 3 | 7 | 5 | 0 | 0 | 0 |
| 6 | 4 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 14 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 7 | 5 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 8 | 4 | 0 | 0 | 0 | 0 | 0 | 0 | 3 | 2 | 0 | ... | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 9 | 8 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 203 | 214 | 166 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |

10 rows × 785 columns
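
The row layout described above (one label followed by 784 pixel values) and the split into training and testing sets can be sketched as follows. This is a minimal illustration and not the notebook's own loading code: the data frame is synthetic, the 80/20 split ratio is an assumption, and `split_features_labels` is a hypothetical helper name.

```python
import numpy as np
import pandas as pd

# Synthetic stand-in for the CSV: 10 rows, a label column plus 784 pixel columns.
rng = np.random.default_rng(0)
columns = ['label'] + [f'pixel{i}' for i in range(1, 785)]
values = np.hstack([
    rng.integers(0, 10, size=(10, 1)),     # labels 0..9
    rng.integers(0, 256, size=(10, 784)),  # pixel intensities 0..255
])
data = pd.DataFrame(values, columns=columns)

# Hold out 20% of the rows for testing (the 80/20 ratio is an assumption).
train = data.sample(frac=0.8, random_state=0)
test = data.drop(train.index)


def split_features_labels(frame):
    """Hypothetical helper: return pixels of shape (m, 784) and labels of shape (m,)."""
    y = frame['label'].to_numpy()
    x = frame.drop(columns='label').to_numpy()
    return x, y


x_train, y_train = split_features_labels(train)
x_test, y_test = split_features_labels(test)

# Recover the 28x28 image of the first training example from its flat pixel row.
first_image = x_train[0].reshape(28, 28)
print(x_train.shape, x_test.shape, first_image.shape)
```

Dropping the sampled index from the original frame guarantees that no example ends up in both sets.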
|   | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ... | 775 | 776 | 777 | 778 | 779 | 780 | 781 | 782 | 783 | 784 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | -6.188828 | -0.098811 | -0.087328 | -0.015862 | 0.024045 | 0.121492 | 0.058710 | -0.042991 | 0.026648 | -0.019745 | ... | -0.051891 | -0.039698 | 0.264672 | 0.098661 | 0.124377 | 0.036166 | -0.053816 | -0.046484 | -0.026023 | -0.013350 |
| 1 | -6.474232 | -0.002895 | -0.002513 | -0.009509 | -0.011743 | 0.117668 | -0.048227 | 0.062439 | 0.022960 | 0.149320 | ... | 0.153883 | 0.079365 | -0.033011 | -0.019675 | 0.017837 | -0.025347 | -0.026005 | -0.003478 | -0.002177 | -0.009066 |
| 2 | -5.334694 | -0.003845 | -0.037747 | -0.110560 | -0.115551 | -0.028090 | -0.106846 | -0.041750 | -0.220657 | -0.049796 | ... | -0.002940 | 0.159790 | 0.163778 | 0.086334 | 0.027404 | -0.098383 | -0.014745 | -0.087777 | 0.167204 | 0.146124 |
| 3 | -6.609060 | -0.010289 | -0.009757 | -0.029398 | -0.002841 | -0.013853 | 0.065444 | 0.152862 | 0.039919 | -0.109553 | ... | 0.070574 | -0.061689 | 0.028969 | -0.204451 | -0.200882 | -0.112483 | -0.078020 | -0.024253 | -0.008831 | -0.001139 |
| 4 | -5.965323 | 0.000036 | -0.007320 | 0.010527 | 0.044382 | -0.019542 | -0.009529 | -0.023622 | 0.027513 | -0.055903 | ... | 0.013763 | 0.171588 | -0.147751 | -0.045985 | 0.086505 | 0.058849 | 0.084423 | -0.070254 | -0.055217 | -0.002108 |
| 5 | -11.772900 | -0.000465 | -0.001262 | -0.037732 | -0.001665 | -0.004291 | -0.001232 | -0.010755 | -0.021838 | -0.059554 | ... | -0.061065 | -0.070279 | -0.038202 | -0.073953 | -0.083174 | -0.073595 | -0.013555 | -0.012098 | -0.091695 | -0.005742 |
| 6 | -4.922826 | 0.142705 | -0.033088 | -0.002464 | -0.113538 | -0.093484 | -0.035009 | -0.045016 | 0.047347 | 0.026366 | ... | -0.096220 | 0.025678 | -0.108178 | -0.043293 | -0.179027 | -0.050755 | 0.083261 | 0.115971 | -0.025531 | -0.170215 |
| 7 | -9.278711 | -0.000367 | -0.000285 | -0.000301 | -0.000410 | -0.000456 | -0.000550 | -0.001178 | -0.003780 | -0.007860 | ... | -0.050218 | -0.051179 | -0.013316 | -0.013148 | -0.005308 | -0.005978 | -0.005408 | -0.006201 | -0.003152 | -0.000013 |
| 8 | -5.977974 | -0.000269 | -0.040986 | -0.061915 | -0.019404 | -0.065937 | 0.122967 | 0.033192 | 0.107072 | 0.096812 | ... | 0.065897 | 0.037445 | -0.030971 | 0.014127 | -0.048779 | -0.025388 | -0.033395 | -0.028944 | -0.036300 | 0.002200 |
| 9 | -7.507192 | -0.000174 | -0.000325 | -0.001354 | -0.000743 | -0.001619 | -0.001899 | -0.005294 | -0.011630 | -0.013474 | ... | -0.033911 | -0.029371 | -0.015877 | -0.003029 | -0.004274 | 0.012986 | -0.001100 | 0.016288 | 0.117543 | -0.000238 |

10 rows × 785 columns
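
A 10 × 785 parameter table like the one above can be read as one row of parameters per class, which suggests a one-vs-all prediction step. The sketch below rests on assumptions: that column 0 is the intercept term, and that no feature normalization is applied before the parameters are used (the library may well normalize). `sigmoid` and `predict_one_vs_all` are illustrative names, not the `LogisticRegression` API.

```python
import numpy as np


def sigmoid(z):
    """Logistic function, mapping any real number into (0, 1)."""
    return 1.0 / (1.0 + np.exp(-z))


def predict_one_vs_all(x, thetas):
    """Pick, for every example, the class whose classifier is most confident.

    x:      (m, n) feature matrix without the intercept column.
    thetas: (k, n + 1) parameters, one row per class, column 0 assumed to be the intercept.
    """
    m = x.shape[0]
    x_with_bias = np.hstack([np.ones((m, 1)), x])    # (m, n + 1)
    probabilities = sigmoid(x_with_bias @ thetas.T)  # (m, k)
    return np.argmax(probabilities, axis=1)


# Toy parameters with the same shape as the table above (values are made up).
thetas = np.zeros((10, 785))
thetas[:, 0] = -5.0   # strongly negative intercepts for every class
thetas[3, 1] = 1.0    # class 3 responds to the first pixel
thetas[7, 2] = 1.0    # class 7 responds to the second pixel

x = np.zeros((2, 784))
x[0, 0] = 10.0        # first example lights up the first pixel
x[1, 1] = 10.0        # second example lights up the second pixel

predictions = predict_one_vs_all(x, thetas)
print(predictions)
```

Each class gets its own binary classifier, and taking the `argmax` over the ten probabilities turns them into a single multi-class decision.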
\n", "