{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "**Chapter 3 – Classification**\n", "\n", "_This notebook contains all the sample code and solutions to the exercises in chapter 3._" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", " \n", "
 \n", " Run in Google Colab\n", "
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Setup" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, let's import a few common modules, ensure Matplotlib plots figures inline and prepare a function to save the figures. We also check that Python 3.5 or later is installed (although Python 2.x may work, it is deprecated so we strongly recommend you use Python 3 instead), as well as Scikit-Learn ≥0.20." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# Python ≥3.5 is required\n", "import sys\n", "assert sys.version_info >= (3, 5)\n", "\n", "# Scikit-Learn ≥0.20 is required\n", "import re\n", "import sklearn\n", "# Compare (major, minor) numerically: as plain strings, \"0.9\" >= \"0.20\" is True\n", "# even though release 0.9 predates 0.20.\n", "assert tuple(map(int, re.match(r\"(\\d+)\\.(\\d+)\", sklearn.__version__).groups())) >= (0, 20)\n", "\n", "# Common imports\n", "import numpy as np\n", "import os\n", "\n", "# to make this notebook's output stable across runs\n", "np.random.seed(42)\n", "\n", "# To plot pretty figures\n", "%matplotlib inline\n", "import matplotlib as mpl\n", "import matplotlib.pyplot as plt\n", "mpl.rc('axes', labelsize=14)\n", "mpl.rc('xtick', labelsize=12)\n", "mpl.rc('ytick', labelsize=12)\n", "\n", "# Where to save the figures\n", "PROJECT_ROOT_DIR = \".\"\n", "CHAPTER_ID = \"classification\"\n", "IMAGES_PATH = os.path.join(PROJECT_ROOT_DIR, \"images\", CHAPTER_ID)\n", "os.makedirs(IMAGES_PATH, exist_ok=True)\n", "\n", "def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n", "    \"\"\"Save the current Matplotlib figure as IMAGES_PATH/<fig_id>.<fig_extension>.\n", "\n", "    fig_id: file name without extension; also printed as a progress message.\n", "    tight_layout: when true, call plt.tight_layout() before saving.\n", "    fig_extension: file extension, also passed to plt.savefig as the format.\n", "    resolution: passed to plt.savefig as dpi.\n", "    \"\"\"\n", "    path = os.path.join(IMAGES_PATH, fig_id + \".\" + fig_extension)\n", "    print(\"Saving figure\", fig_id)\n", "    if tight_layout:\n", "        plt.tight_layout()\n", "    plt.savefig(path, format=fig_extension, dpi=resolution)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# MNIST" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Warning:** since Scikit-Learn 0.24, fetch_openml() returns a Pandas DataFrame by default. To avoid this and keep the same code as in the book, we use as_frame=False." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "dict_keys(['data', 'target', 'frame', 'categories', 'feature_names', 'target_names', 'DESCR', 'details', 'url'])" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.datasets import fetch_openml\n", "\n", "# as_frame=False keeps the result as arrays instead of a Pandas DataFrame\n", "mnist = fetch_openml(name='mnist_784', version=1, as_frame=False)\n", "mnist.keys()" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(70000, 784)" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# X is the \"data\" entry of the dataset, y its \"target\" entry\n", "X = mnist[\"data\"]\n", "y = mnist[\"target\"]\n", "X.shape" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(70000,)" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y.shape" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "784" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 784 columns in X = 28 x 28\n", "28 * 28" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Saving figure some_digit_plot\n" ] }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAARgAAAEYCAYAAACHjumMAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/d3fzzAAAACXBIWXMAAAsTAAALEwEAmpwYAAAHcElEQVR4nO3dPUjWawPHcU17FcvaLJoDl14oHIJeoSZrjYaoyaByUSJwaAxqK9uiKWqRHFyKhBoiCIeiF8hBiGioRUyooQif5SF4eKTrOujvVo+fz+j94zr/zolvfzgXd82zs7NNAAmrFvsBgH8vgQFiBAaIERggRmCAmNbC5/4XE1Cjea4feoMBYgQGiBEYIEZggBiBAWIEBogRGCBGYIAYgQFiBAaIERggRmCAGIEBYgQGiBEYIEZggBiBAWIEBogRGCBGYIAYgQFiBAaIERggRmCAGIEBYgQGiBEYIEZggBiBAWIEBogRGCBGYIAYgQFiBAaIERggRmCAGIEBYgQGiBEYIEZggBiBAWIEBogRGCBGYIAYgQFiBAaIERggpnWxH4C8379/V+2+ffsWfpL/NzQ0VNz8+PGj6qyJiYni5vbt21VnDQwMFDcPHjyoOmvdunXFzZUrV6rOunr1atVuqfAGA8QIDBAjMECMwAAxAgPECAwQIzBAjMAAMS7aLaBPnz5V7X7+/FncvHjxouqs58+fFzfT09NVZw0PD1ftlqrt27cXN5cuXao6a2RkpLhpb2+vOmvnzp3FzcGDB6vOWm68wQAxAgPECAwQIzBAjMAAMQIDxAgMECMwQIzAADHNs7Ozf/v8rx+uJK9evSpujhw5UnXWYnw15XLW0tJStbt7925x09bWNt/H+WPr1q1Vu82bNxc3O3bsmO/jLLbmuX7oDQaIERggRmCAGIEBYgQGiBEYIEZggBiBAWJctKs0NTVV3HR3d1edNTk5Od/HWTS1v8aay2VNTU1NT58+LW7WrFlTdZYLjIvKRTugsQQGiBEYIEZggBiBAWIEBogRGCBGYIAYgQFiWhf7AZaLLVu2FDc3btyoOmt0dLS42b17d9VZfX19Vbsau3btKm7Gxsaqzqr9asp3794VNzdv3qw6i6XHGwwQIzBAjMAAMQIDxAgMECMwQIzAADECA8T4ysxFMDMzU9y0t7dXndXb21vc3Llzp+qse/fuFTenT5+uOosVx1dmAo0lMECMwAAxAgPECAwQIzBAjMAAMQIDxAgMEOMrMxfBxo0bF+ysTZs2LdhZNTd+T506VXXWqlX+7MIbDBAkMECMwAAxAgPECAwQIzBAjMAAMQIDxPjKzGXu+/fvxU1PT0/VWc+ePStuHj16VHXWsWPHqnb8a/jKTKCxBAaIERggRmCAGIEBYgQGiBEYIEZggBiBAWLc5F0BJicnq3Z79uwpbjo6OqrOOnz4cNVu7969xc2FCxeqzmpunvMyKY3hJi/QWAIDxAgMECMwQIzAADECA8QIDBAjMECMi3b8MTIyUtycO3eu6qyZmZn5Ps4f165dq9qdOXOmuOns7Jzv4zA3F+2AxhIYIEZggBiBAWIEBogRGCBGYIAYgQFiBAaIcZOXf+Tt27dVu/7+/qrd2NjYfB7nf5w/f764GRwcrDpr27Zt832clcZNXqCxBAaIERggRmCAGIEBYgQGiBEYIEZggBiBAWLc5CVienq6ajc6OlrcnD17tuqswu/lpqampqajR49WnfXkyZOqHX+4yQs0lsAAMQIDxAgMECMwQIzAADECA8QIDBDjoh1L3tq1a6t2v379Km5Wr15dddbjx4+Lm0OHDlWdtUK4aAc0lsAAMQIDxAgMECMwQIzAADECA8QIDBAjMEBM62I/AMvLmzdvqnbDw8NVu/Hx8eKm5oZura6urqrdgQMHFuyfuZJ5gwFiBAaIERggRmCAGIEBYgQGiBEYIEZggBgX7VaAiYmJqt2tW7eKm4cPH1ad9eXLl6rdQmptLf927uzsrDpr1Sp/9i4
E/xaBGIEBYgQGiBEYIEZggBiBAWIEBogRGCBGYIAYN3mXqNqbsPfv3y9uhoaGqs76+PFj1a7R9u3bV7UbHBwsbk6cODHfx+Ef8AYDxAgMECMwQIzAADECA8QIDBAjMECMwAAxLtotoK9fv1bt3r9/X9xcvHix6qwPHz5U7Rqtu7u7anf58uXi5uTJk1Vn+ZrLpcd/ESBGYIAYgQFiBAaIERggRmCAGIEBYgQGiBEYIGbF3+Sdmpqq2vX29hY3r1+/rjprcnKyatdo+/fvL276+/urzjp+/HjVbv369VU7lidvMECMwAAxAgPECAwQIzBAjMAAMQIDxAgMELMsL9q9fPmyanf9+vXiZnx8vOqsz58/V+0abcOGDcVNX19f1Vk1f7dzW1tb1VnQ1OQNBggSGCBGYIAYgQFiBAaIERggRmCAGIEBYgQGiFmWN3lHRkYWdLdQurq6qnY9PT3FTUtLS9VZAwMDxU1HR0fVWbDQvMEAMQIDxAgMECMwQIzAADECA8QIDBAjMEBM8+zs7N8+/+uHAP/VPNcPvcEAMQIDxAgMECMwQIzAADECA8QIDBAjMECMwAAxAgPECAwQIzBAjMAAMQIDxAgMECMwQIzAADECA8QIDBAjMECMwAAxAgPECAwQIzBAjMAAMQIDxAgMECMwQExr4fM5/0JrgBreYIAYgQFiBAaIERggRmCAGIEBYv4D1/YD6c25+gcAAAAASUVORK5CYII=\n", "text/plain": [ "