{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Handwritten digits recognition (using Convolutional Neural Network)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"> - 🤖 See [full list of Machine Learning Experiments](https://github.com/trekhleb/machine-learning-experiments) on **GitHub**
\n",
"> - ▶️ **Interactive Demo**: [try this model and other machine learning experiments in action](https://trekhleb.github.io/machine-learning-experiments/)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Experiment overview"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this experiment we will build a [Convolutional Neural Network](https://en.wikipedia.org/wiki/Convolutional_neural_network) (CNN) model using [Tensorflow](https://www.tensorflow.org/) to recognize handwritten digits.\n",
"\n",
"A **convolutional neural network** (CNN, or ConvNet) is a Deep Learning algorithm which can take in an input image, assign importance (learnable weights and biases) to various aspects/objects in the image and be able to differentiate one from the other.\n",
"\n",
"![digits_recognition_cnn.png](../../demos/src/images/digits_recognition_cnn.png)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Import dependencies\n",
"\n",
"- [tensorflow](https://www.tensorflow.org/) - for developing and training ML models.\n",
"- [matplotlib](https://matplotlib.org/) - for plotting the data.\n",
"- [seaborn](https://seaborn.pydata.org/index.html) - for plotting confusion matrix.\n",
"- [numpy](https://numpy.org/) - for linear algebra operations.\n",
"- [pandas](https://pandas.pydata.org/) - for displaying training/test data in a table.\n",
"- [math](https://docs.python.org/3/library/math.html) - for calculating square roots etc.\n",
"- [datetime](https://docs.python.org/3.8/library/datetime.html) - for generating a logs folder names."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Selecting Tensorflow version v2 (the command is relevant for Colab only).\n",
"%tensorflow_version 2.x"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Python version: 3.7.6\n",
"Tensorflow version: 2.1.0\n",
"Keras version: 2.2.4-tf\n"
]
}
],
"source": [
"import tensorflow as tf\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sn\n",
"import numpy as np\n",
"import pandas as pd\n",
"import math\n",
"import datetime\n",
"import platform\n",
"\n",
"print('Python version:', platform.python_version())\n",
"print('Tensorflow version:', tf.__version__)\n",
"print('Keras version:', tf.keras.__version__)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Configuring Tensorboard\n",
"\n",
"We will use [Tensorboard](https://www.tensorflow.org/tensorboard) to debug the model later."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Load the TensorBoard notebook extension.\n",
"# %reload_ext tensorboard\n",
"%load_ext tensorboard"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Clear any logs from previous runs.\n",
"!rm -rf ./.logs/"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load the data\n",
"\n",
"The **training** dataset consists of 60000 28x28px images of hand-written digits from `0` to `9`.\n",
"\n",
"The **test** dataset consists of 10000 28x28px images."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"mnist_dataset = tf.keras.datasets.mnist\n",
"(x_train, y_train), (x_test, y_test) = mnist_dataset.load_data()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"x_train: (60000, 28, 28)\n",
"y_train: (60000,)\n",
"x_test: (10000, 28, 28)\n",
"y_test: (10000,)\n"
]
}
],
"source": [
"print('x_train:', x_train.shape)\n",
"print('y_train:', y_train.shape)\n",
"print('x_test:', x_test.shape)\n",
"print('y_test:', y_test.shape)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"IMAGE_WIDTH: 28\n",
"IMAGE_HEIGHT: 28\n",
"IMAGE_CHANNELS: 1\n"
]
}
],
"source": [
"# Save image parameters to the constants that we will use later for data re-shaping and for model traning.\n",
"(_, IMAGE_WIDTH, IMAGE_HEIGHT) = x_train.shape\n",
"IMAGE_CHANNELS = 1\n",
"\n",
"print('IMAGE_WIDTH:', IMAGE_WIDTH);\n",
"print('IMAGE_HEIGHT:', IMAGE_HEIGHT);\n",
"print('IMAGE_CHANNELS:', IMAGE_CHANNELS);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Explore the data\n",
"\n",
"Here is how each image in the dataset looks like. It is a 28x28 matrix of integers (from `0` to `255`). Each integer represents a color of a pixel."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
\n", " | 0 | \n", "1 | \n", "2 | \n", "3 | \n", "4 | \n", "5 | \n", "6 | \n", "7 | \n", "8 | \n", "9 | \n", "... | \n", "18 | \n", "19 | \n", "20 | \n", "21 | \n", "22 | \n", "23 | \n", "24 | \n", "25 | \n", "26 | \n", "27 | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "... | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "
1 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "... | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "
2 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "... | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "
3 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "... | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "
4 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "... | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "
5 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "... | \n", "175 | \n", "26 | \n", "166 | \n", "255 | \n", "247 | \n", "127 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "
6 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "30 | \n", "36 | \n", "... | \n", "225 | \n", "172 | \n", "253 | \n", "242 | \n", "195 | \n", "64 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "
7 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "49 | \n", "238 | \n", "253 | \n", "... | \n", "93 | \n", "82 | \n", "82 | \n", "56 | \n", "39 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "
8 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "18 | \n", "219 | \n", "253 | \n", "... | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "
9 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "80 | \n", "156 | \n", "... | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "
10 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "14 | \n", "... | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "
11 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "... | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "
12 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "... | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "
13 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "... | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "
14 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "... | \n", "25 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "
15 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "... | \n", "150 | \n", "27 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "
16 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "... | \n", "253 | \n", "187 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "
17 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "... | \n", "253 | \n", "249 | \n", "64 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "
18 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "... | \n", "253 | \n", "207 | \n", "2 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "
19 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "... | \n", "250 | \n", "182 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "
20 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "... | \n", "78 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "
21 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "23 | \n", "66 | \n", "... | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "
22 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "18 | \n", "171 | \n", "219 | \n", "253 | \n", "... | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "
23 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "55 | \n", "172 | \n", "226 | \n", "253 | \n", "253 | \n", "253 | \n", "... | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "
24 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "136 | \n", "253 | \n", "253 | \n", "253 | \n", "212 | \n", "135 | \n", "... | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "
25 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "... | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "
26 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "... | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "
27 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "... | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "
28 rows × 28 columns
\n", "\n", " | 0 | \n", "1 | \n", "2 | \n", "3 | \n", "4 | \n", "5 | \n", "6 | \n", "7 | \n", "8 | \n", "9 | \n", "
---|---|---|---|---|---|---|---|---|---|---|
0 | \n", "2.174955e-15 | \n", "2.053761e-08 | \n", "1.115043e-10 | \n", "1.815501e-11 | \n", "4.382045e-09 | \n", "1.084139e-11 | \n", "7.955132e-17 | \n", "1.000000e+00 | \n", "2.871653e-11 | \n", "5.481966e-11 | \n", "
1 | \n", "8.536813e-10 | \n", "1.098458e-06 | \n", "9.999989e-01 | \n", "2.133420e-10 | \n", "1.870166e-11 | \n", "3.222684e-15 | \n", "2.280592e-08 | \n", "2.037868e-10 | \n", "9.249878e-11 | \n", "2.877182e-15 | \n", "
2 | \n", "7.979921e-12 | \n", "9.999999e-01 | \n", "1.698135e-09 | \n", "1.035064e-13 | \n", "3.937335e-08 | \n", "1.830633e-08 | \n", "1.814277e-09 | \n", "3.928236e-08 | \n", "8.205532e-08 | \n", "1.508753e-11 | \n", "
3 | \n", "9.997569e-01 | \n", "1.507927e-10 | \n", "3.036238e-08 | \n", "7.496398e-11 | \n", "1.072911e-10 | \n", "3.716224e-07 | \n", "2.426345e-04 | \n", "4.577839e-09 | \n", "1.174887e-08 | \n", "9.708089e-08 | \n", "
4 | \n", "2.499753e-09 | \n", "1.776901e-11 | \n", "7.716882e-12 | \n", "1.750144e-14 | \n", "9.988386e-01 | \n", "1.832140e-09 | \n", "1.151763e-08 | \n", "1.740052e-10 | \n", "1.392591e-09 | \n", "1.161428e-03 | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
9995 | \n", "1.009779e-14 | \n", "1.738570e-09 | \n", "9.999994e-01 | \n", "1.041259e-10 | \n", "5.272622e-16 | \n", "2.269359e-19 | \n", "2.778265e-16 | \n", "5.792643e-07 | \n", "4.054746e-12 | \n", "5.626435e-18 | \n", "
9996 | \n", "2.263675e-11 | \n", "6.903200e-08 | \n", "3.381161e-08 | \n", "9.999995e-01 | \n", "8.395995e-13 | \n", "3.699991e-07 | \n", "2.905099e-14 | \n", "5.455035e-10 | \n", "4.476401e-10 | \n", "1.262823e-09 | \n", "
9997 | \n", "9.689668e-24 | \n", "5.208208e-12 | \n", "3.422634e-17 | \n", "4.813990e-21 | \n", "1.000000e+00 | \n", "6.983922e-16 | \n", "6.201705e-17 | \n", "2.069550e-11 | \n", "2.440485e-11 | \n", "1.458596e-10 | \n", "
9998 | \n", "1.358465e-08 | \n", "9.598834e-11 | \n", "2.364747e-11 | \n", "5.815844e-09 | \n", "2.983105e-12 | \n", "9.995547e-01 | \n", "2.545367e-05 | \n", "1.387848e-11 | \n", "4.197330e-04 | \n", "1.053049e-08 | \n", "
9999 | \n", "9.229651e-12 | \n", "5.644470e-12 | \n", "4.068150e-12 | \n", "1.636682e-17 | \n", "2.020714e-12 | \n", "9.393594e-12 | \n", "1.000000e+00 | \n", "3.027714e-21 | \n", "7.427696e-13 | \n", "6.800262e-16 | \n", "
10000 rows × 10 columns
\n", "\n", " | 0 | \n", "
---|---|
0 | \n", "7 | \n", "
1 | \n", "2 | \n", "
2 | \n", "1 | \n", "
3 | \n", "0 | \n", "
4 | \n", "4 | \n", "
... | \n", "... | \n", "
9995 | \n", "2 | \n", "
9996 | \n", "3 | \n", "
9997 | \n", "4 | \n", "
9998 | \n", "5 | \n", "
9999 | \n", "6 | \n", "
10000 rows × 1 columns
\n", "