{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "DJsUjs19_v63" }, "source": [ "# MNIST with SciKit-Learn and skorch\n", "\n", "This notebook shows how to define and train a simple neural network with PyTorch and use it with SciKit-Learn via skorch.
" ] }, { "cell_type": "markdown", "metadata": { "id": "-zmIlvxI_v68" }, "source": [ "**Note**: If you are running this in [a Colab notebook](https://colab.research.google.com/github/skorch-dev/skorch/blob/master/notebooks/MNIST.ipynb), we recommend enabling a free GPU via:\n", "\n", "> **Runtime**   →   **Change runtime type**   →   **Hardware Accelerator: GPU**\n", "\n", "If you are running in Colab, install the dependencies by running the following cell (the dataset itself is downloaded later by `fetch_openml`):" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "id": "8qYGNO2S_v6_" }, "outputs": [], "source": [ "import subprocess\n", "import sys\n", "\n", "# Installation on Google Colab\n", "try:\n", "    import google.colab\n", "    subprocess.run([sys.executable, '-m', 'pip', 'install', 'skorch', 'torch'])\n", "except ImportError:\n", "    pass" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "Gj0pvjxT_v7G" }, "outputs": [], "source": [ "from sklearn.datasets import fetch_openml\n", "from sklearn.model_selection import train_test_split\n", "import numpy as np\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "markdown", "metadata": { "id": "mPz6Bjqw_v7H" }, "source": [ "## Loading Data\n", "We use SciKit-Learn's `fetch_openml` to load the MNIST data." 
] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "mwpfASvc_v7J" }, "outputs": [], "source": [ "mnist = fetch_openml('mnist_784', as_frame=False, cache=False)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "id": "9Pt2JKyb_v7K", "outputId": "5a96aa80-e889-4553-c289-9534ed68d708", "colab": { "base_uri": "https://localhost:8080/" } }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(70000, 784)" ] }, "metadata": {}, "execution_count": 4 } ], "source": [ "mnist.data.shape" ] }, { "cell_type": "markdown", "metadata": { "id": "sV0ehb52_v7L" }, "source": [ "## Preprocessing Data\n", "\n", "Each image of the MNIST dataset is encoded as a 784-dimensional vector representing a 28 x 28 pixel image. Each pixel has a value between 0 and 255 corresponding to its grey value.\n", "\n", "The `fetch_openml` call above returns `data` as floating-point values and `target` as string labels, which we convert to `float32` and `int64`, respectively (`float32` is PyTorch's default floating-point type, and `int64` is what its classification losses expect for class indices)." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "id": "F2v_Fwne_v7M" }, "outputs": [], "source": [ "X = mnist.data.astype('float32')\n", "y = mnist.target.astype('int64')" ] }, { "cell_type": "markdown", "metadata": { "id": "C_yHJarZ_v7N" }, "source": [ "Large input values in [0, 255] can lead to large activations and unstable training, so we scale `X` down to the commonly used range [0, 1]." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "id": "8RirCOTr_v7O" }, "outputs": [], "source": [ "X /= 255.0" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "id": "rohyp3d1_v7P", "outputId": "8f4d25e7-a175-4abb-a3e1-fed0fc4d3b83", "colab": { "base_uri": "https://localhost:8080/" } }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(0.0, 1.0)" ] }, "metadata": {}, "execution_count": 7 } ], "source": [ "X.min(), X.max()" ] }, { "cell_type": "markdown", "metadata": { "id": "tyUlsu0V_v7Q" }, "source": [ "Note: the data is only rescaled to [0, 1]; it is not normalized to zero mean and unit variance." 
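] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For comparison, here is a minimal sketch of what normalizing to zero mean and unit variance could look like. It is not used anywhere else in this notebook, which keeps working with the [0, 1]-scaled `X`. The names `X_mean`, `X_std` and `X_norm` are introduced only for this illustration, and computing the statistics on the full dataset is a simplification: to avoid leaking information from the test set, you would normally compute them on the training split only." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Illustrative sketch only (not used below): standardize with the global mean and\n", "# standard deviation of X; X_mean, X_std and X_norm are hypothetical names\n", "X_mean, X_std = X.mean(), X.std()\n", "X_norm = (X - X_mean) / X_std\n", "\n", "# Should be close to 0 and 1, respectively\n", "X_norm.mean(), X_norm.std()"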
] }, { "cell_type": "code", "execution_count": 8, "metadata": { "id": "gILlsHJS_v7R" }, "outputs": [], "source": [ "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "id": "jGpA4v4u_v7R" }, "outputs": [], "source": [ "assert X_train.shape[0] + X_test.shape[0] == mnist.data.shape[0]" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "id": "plXmcsp2_v7b", "outputId": "eb16e182-ac11-4a8e-b5da-c6395b73006e", "colab": { "base_uri": "https://localhost:8080/" } }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "((52500, 784), (52500,))" ] }, "metadata": {}, "execution_count": 10 } ], "source": [ "X_train.shape, y_train.shape" ] }, { "cell_type": "markdown", "metadata": { "id": "6EKEvbuP_v7c" }, "source": [ "### Plot a selection of training images and their labels" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "id": "C9muXJPC_v7d" }, "outputs": [], "source": [ "def plot_example(X, y):\n", "    \"\"\"Plot the first 5 images and their labels in a row.\"\"\"\n", "    for i, (img, label) in enumerate(zip(X[:5].reshape(5, 28, 28), y[:5])):\n", "        plt.subplot(151 + i)\n", "        plt.imshow(img)\n", "        plt.xticks([])\n", "        plt.yticks([])\n", "        plt.title(label)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "id": "h2-R1-Df_v7e", "outputId": "619a14f4-7a23-4a09-a872-e646cf5c5900", "colab": { "base_uri": "https://localhost:8080/", "height": 108 } }, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "
" ], "image/png": "\n" }, "metadata": {} } ], "source": [ "plot_example(X_train, y_train)" ] }, { "cell_type": "markdown", "metadata": { "id": "YQvC-rWf_v7f" }, "source": [ "## Build Neural Network with PyTorch\n", "A simple, fully connected neural network with one hidden layer. The input layer has 784 dimensions (28 x 28), the hidden layer has 98 (= 784 / 8), and the output layer has 10 neurons, one for each of the digits 0-9." ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "id": "5PG7R0W8_v7f" }, "outputs": [], "source": [ "import torch\n", "from torch import nn\n", "import torch.nn.functional as F" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "id": "NUjWrGBP_v7g" }, "outputs": [], "source": [ "device = 'cuda' if torch.cuda.is_available() else 'cpu'" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "id": "gVCW8F3N_v7g" }, "outputs": [], "source": [ "mnist_dim = X.shape[1]\n", "hidden_dim = mnist_dim // 8\n", "output_dim = len(np.unique(mnist.target))" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "id": "ShYicHlv_v7h", "outputId": "95071ead-b292-4702-9093-4d6cc1f0f94a", "colab": { "base_uri": "https://localhost:8080/" } }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(784, 98, 10)" ] }, "metadata": {}, "execution_count": 16 } ], "source": [ "mnist_dim, hidden_dim, output_dim" ] }, { "cell_type": "markdown", "metadata": { "id": "OeVnFhBS_v7i" }, "source": [ "Next, we define the network as a PyTorch `nn.Module`." 
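] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Before writing the module, a quick sanity check of its size: each linear layer has one weight per input-output pair plus one bias per output, so the 784-98-10 network described above has (784 * 98 + 98) + (98 * 10 + 10) = 76,930 + 990 = 77,920 trainable parameters (dropout adds none). The cell below is a small sketch that computes the same count from the dimensions defined earlier; `n_params` is a name introduced only for this illustration." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Sketch: (weights + biases) of the hidden layer plus (weights + biases) of the output layer;\n", "# n_params is a hypothetical name used only for this check\n", "n_params = (mnist_dim * hidden_dim + hidden_dim) + (hidden_dim * output_dim + output_dim)\n", "n_params"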
] }, { "cell_type": "code", "execution_count": 17, "metadata": { "id": "Xxtli0l__v7j" }, "outputs": [], "source": [ "class ClassifierModule(nn.Module):\n", "    def __init__(\n", "            self,\n", "            input_dim=mnist_dim,\n", "            hidden_dim=hidden_dim,\n", "            output_dim=output_dim,\n", "            dropout=0.5,\n", "    ):\n", "        super(ClassifierModule, self).__init__()\n", "        self.dropout = nn.Dropout(dropout)\n", "\n", "        self.hidden = nn.Linear(input_dim, hidden_dim)\n", "        self.output = nn.Linear(hidden_dim, output_dim)\n", "\n", "    def forward(self, X, **kwargs):\n", "        X = F.relu(self.hidden(X))\n", "        X = self.dropout(X)\n", "        # Return class probabilities; NeuralNetClassifier takes their log for its default NLLLoss\n", "        X = F.softmax(self.output(X), dim=-1)\n", "        return X" ] }, { "cell_type": "markdown", "metadata": { "id": "LlEHSwjt_v7k" }, "source": [ "skorch lets us use PyTorch networks in the SciKit-Learn setting:" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "id": "s0aDatqN_v7l" }, "outputs": [], "source": [ "from skorch import NeuralNetClassifier" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "id": "dOrCbBjk_v7m" }, "outputs": [], "source": [ "torch.manual_seed(0)\n", "\n", "net = NeuralNetClassifier(\n", "    ClassifierModule,\n", "    max_epochs=20,\n", "    lr=0.1,\n", "    device=device,\n", ")" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "id": "8i_gnvPi_v7m", "outputId": "002bd91f-bc4f-4e24-af69-f2e2ef9a4771", "colab": { "base_uri": "https://localhost:8080/" } }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ " epoch train_loss valid_acc valid_loss dur\n", "------- ------------ ----------- ------------ ------\n", " 1 \u001b[36m0.8387\u001b[0m \u001b[32m0.8800\u001b[0m \u001b[35m0.4174\u001b[0m 3.8169\n", " 2 \u001b[36m0.4332\u001b[0m \u001b[32m0.9103\u001b[0m \u001b[35m0.3133\u001b[0m 0.8510\n", " 3 \u001b[36m0.3612\u001b[0m \u001b[32m0.9233\u001b[0m \u001b[35m0.2684\u001b[0m 0.8208\n", " 4 \u001b[36m0.3233\u001b[0m \u001b[32m0.9309\u001b[0m \u001b[35m0.2317\u001b[0m 0.8079\n", " 5 
\u001b[36m0.2938\u001b[0m \u001b[32m0.9353\u001b[0m \u001b[35m0.2173\u001b[0m 0.8074\n", " 6 \u001b[36m0.2738\u001b[0m \u001b[32m0.9390\u001b[0m \u001b[35m0.2039\u001b[0m 0.8277\n", " 7 \u001b[36m0.2600\u001b[0m \u001b[32m0.9454\u001b[0m \u001b[35m0.1868\u001b[0m 0.8224\n", " 8 \u001b[36m0.2427\u001b[0m \u001b[32m0.9484\u001b[0m \u001b[35m0.1757\u001b[0m 0.8623\n", " 9 \u001b[36m0.2362\u001b[0m \u001b[32m0.9503\u001b[0m \u001b[35m0.1683\u001b[0m 0.8312\n", " 10 \u001b[36m0.2226\u001b[0m \u001b[32m0.9512\u001b[0m \u001b[35m0.1621\u001b[0m 0.8221\n", " 11 \u001b[36m0.2184\u001b[0m \u001b[32m0.9529\u001b[0m \u001b[35m0.1565\u001b[0m 0.8158\n", " 12 \u001b[36m0.2090\u001b[0m \u001b[32m0.9541\u001b[0m \u001b[35m0.1508\u001b[0m 0.7974\n", " 13 \u001b[36m0.2067\u001b[0m \u001b[32m0.9570\u001b[0m \u001b[35m0.1446\u001b[0m 0.8123\n", " 14 \u001b[36m0.1978\u001b[0m \u001b[32m0.9570\u001b[0m \u001b[35m0.1412\u001b[0m 0.8304\n", " 15 \u001b[36m0.1923\u001b[0m \u001b[32m0.9582\u001b[0m \u001b[35m0.1392\u001b[0m 0.8421\n", " 16 \u001b[36m0.1889\u001b[0m 0.9582 \u001b[35m0.1342\u001b[0m 0.8153\n", " 17 \u001b[36m0.1855\u001b[0m \u001b[32m0.9612\u001b[0m \u001b[35m0.1297\u001b[0m 0.8458\n", " 18 \u001b[36m0.1786\u001b[0m \u001b[32m0.9613\u001b[0m \u001b[35m0.1266\u001b[0m 0.8827\n", " 19 \u001b[36m0.1728\u001b[0m \u001b[32m0.9615\u001b[0m \u001b[35m0.1250\u001b[0m 0.8335\n", " 20 \u001b[36m0.1698\u001b[0m 0.9613 \u001b[35m0.1248\u001b[0m 0.8112\n" ] } ], "source": [ "net.fit(X_train, y_train);" ] }, { "cell_type": "markdown", "metadata": { "id": "5c3iyCKu_v7m" }, "source": [ "## Prediction" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "id": "D7rdey0s_v7n" }, "outputs": [], "source": [ "from sklearn.metrics import accuracy_score" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "id": "b9B8Zd6e_v7n" }, "outputs": [], "source": [ "y_pred = net.predict(X_test)" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "id": "LCWachuM_v7o", 
"outputId": "c4785fc4-2ab1-4717-c024-c80f1aa34ef3", "colab": { "base_uri": "https://localhost:8080/" } }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "0.9631428571428572" ] }, "metadata": {}, "execution_count": 23 } ], "source": [ "accuracy_score(y_test, y_pred)" ] }, { "cell_type": "markdown", "metadata": { "id": "0eRga6AV_v7o" }, "source": [ "An accuracy of about 96% for a network with only one hidden layer is not too bad.\n", "\n", "Let's take a look at some predictions that went wrong:" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "id": "XddXR1_a_v7p" }, "outputs": [], "source": [ "error_mask = y_pred != y_test" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "id": "hXlNTrlt_v7q", "outputId": "11223953-1853-41ea-e43e-c384061351ca", "colab": { "base_uri": "https://localhost:8080/", "height": 108 } }, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "
" ], "image/png": "\n" }, "metadata": {} } ], "source": [ "plot_example(X_test[error_mask], y_pred[error_mask])" ] }, { "cell_type": "markdown", "metadata": { "id": "I2GsBaxH_v7r" }, "source": [ "## Convolutional Network\n", "PyTorch expects a 4-dimensional tensor as input for its 2D convolution layer. The dimensions represent:\n", "* Batch size\n", "* Number of channels\n", "* Height\n", "* Width\n", "\n", "The first dimension is the number of examples; passing `-1` to `reshape` lets NumPy infer it. MNIST data has only one channel, and, as stated above, each MNIST vector represents a 28 x 28 pixel image. Hence, the array passed to PyTorch needs the shape (n_samples, 1, 28, 28)." ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "id": "Bwsaz88X_v7r" }, "outputs": [], "source": [ "XCnn = X.reshape(-1, 1, 28, 28)" ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "id": "Go0Yz8xl_v7s", "outputId": "e66d3f48-5f64-4a03-995e-b8b021efdf4f", "colab": { "base_uri": "https://localhost:8080/" } }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(70000, 1, 28, 28)" ] }, "metadata": {}, "execution_count": 27 } ], "source": [ "XCnn.shape" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "id": "Rt8ETXGa_v7t" }, "outputs": [], "source": [ "XCnn_train, XCnn_test, y_train, y_test = train_test_split(XCnn, y, test_size=0.25, random_state=42)" ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "id": "-jQm85il_v7t", "outputId": "bdfcf1d4-fd4b-4c5b-a887-edd7af625602", "colab": { "base_uri": "https://localhost:8080/" } }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "((52500, 1, 28, 28), (52500,))" ] }, "metadata": {}, "execution_count": 29 } ], "source": [ "XCnn_train.shape, y_train.shape" ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "id": "YdQ-ISvb_v7u" }, "outputs": [], "source": [ "class Cnn(nn.Module):\n", "    def __init__(self, dropout=0.5):\n", "        super(Cnn, self).__init__()\n", "        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)\n", "        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)\n", "        self.conv2_drop = nn.Dropout2d(p=dropout)\n", "        # 1600 = 64 channels * 5 * 5: a 28 x 28 image shrinks to 26 -> 13 -> 11 -> 5 through conv1, pool, conv2, pool\n", "        self.fc1 = nn.Linear(1600, 100)\n", "        self.fc2 = nn.Linear(100, 10)\n", "        self.fc1_drop = nn.Dropout(p=dropout)\n", "\n", "    def forward(self, x):\n", "        x = torch.relu(F.max_pool2d(self.conv1(x), 2))\n", "        x = torch.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))\n", "\n", "        # flatten over channel, height and width = 1600\n", "        x = x.view(-1, x.size(1) * x.size(2) * x.size(3))\n", "\n", "        x = torch.relu(self.fc1_drop(self.fc1(x)))\n", "        x = torch.softmax(self.fc2(x), dim=-1)\n", "        return x" ] }, { "cell_type": "code", "execution_count": 31, "metadata": { "id": "WJrWyFfb_v7u" }, "outputs": [], "source": [ "torch.manual_seed(0)\n", "\n", "cnn = NeuralNetClassifier(\n", "    Cnn,\n", "    max_epochs=10,\n", "    lr=0.002,\n", "    optimizer=torch.optim.Adam,\n", "    device=device,\n", ")" ] }, { "cell_type": "code", "execution_count": 32, "metadata": { "id": "KRtHfgg8_v7u", "outputId": "071223dd-1d8a-4ad5-a748-7bac10d59ab3", "colab": { "base_uri": "https://localhost:8080/" } }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ " epoch train_loss valid_acc valid_loss dur\n", "------- ------------ ----------- ------------ ------\n", " 1 \u001b[36m0.4319\u001b[0m \u001b[32m0.9721\u001b[0m \u001b[35m0.0891\u001b[0m 5.8088\n", " 2 \u001b[36m0.1628\u001b[0m \u001b[32m0.9794\u001b[0m \u001b[35m0.0641\u001b[0m 2.1617\n", " 3 \u001b[36m0.1349\u001b[0m \u001b[32m0.9815\u001b[0m \u001b[35m0.0568\u001b[0m 1.8369\n", " 4 \u001b[36m0.1153\u001b[0m \u001b[32m0.9844\u001b[0m \u001b[35m0.0507\u001b[0m 1.4844\n", " 5 \u001b[36m0.1006\u001b[0m \u001b[32m0.9863\u001b[0m \u001b[35m0.0441\u001b[0m 1.4542\n", " 6 \u001b[36m0.0962\u001b[0m \u001b[32m0.9881\u001b[0m \u001b[35m0.0397\u001b[0m 1.4394\n", " 7 \u001b[36m0.0861\u001b[0m 0.9872 0.0423 1.4464\n", " 8 \u001b[36m0.0853\u001b[0m 0.9863
0.0410 1.4599\n", " 9 \u001b[36m0.0805\u001b[0m 0.9880 \u001b[35m0.0384\u001b[0m 1.4535\n", " 10 \u001b[36m0.0753\u001b[0m \u001b[32m0.9888\u001b[0m 0.0392 1.4857\n" ] } ], "source": [ "cnn.fit(XCnn_train, y_train);" ] }, { "cell_type": "code", "execution_count": 33, "metadata": { "id": "p6V7wyzb_v7v" }, "outputs": [], "source": [ "y_pred_cnn = cnn.predict(XCnn_test)" ] }, { "cell_type": "code", "execution_count": 34, "metadata": { "id": "ZSyiZ3p6_v7v", "outputId": "124fdbe6-8747-4218-d1c1-a60908632eff", "colab": { "base_uri": "https://localhost:8080/" } }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "0.9883428571428572" ] }, "metadata": {}, "execution_count": 34 } ], "source": [ "accuracy_score(y_test, y_pred_cnn)" ] }, { "cell_type": "markdown", "metadata": { "id": "npZlem33_v7v" }, "source": [ "An accuracy of >98% should suffice for this example!\n", "\n", "Let's see how we fare on the examples that the first network got wrong. Both splits use `random_state=42` on the same 70,000 rows, so `error_mask` still selects the same test images:" ] }, { "cell_type": "code", "execution_count": 35, "metadata": { "id": "2oWL0xGC_v7w", "outputId": "449de735-c1a9-4301-d758-e07ac678d18f", "colab": { "base_uri": "https://localhost:8080/" } }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "0.7705426356589147" ] }, "metadata": {}, "execution_count": 35 } ], "source": [ "accuracy_score(y_test[error_mask], y_pred_cnn[error_mask])" ] }, { "cell_type": "markdown", "metadata": { "id": "8239U9fF_v7w" }, "source": [ "About 77% of the images misclassified by the fully connected network are now classified correctly." ] }, { "cell_type": "code", "execution_count": 36, "metadata": { "id": "5vU1SXeV_v7x", "outputId": "332726a3-6c15-49c2-860d-b48eee5c85e2", "colab": { "base_uri": "https://localhost:8080/", "height": 108 } }, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "
" ], "image/png": "\n" }, "metadata": {} } ], "source": [ "plot_example(X_test[error_mask], y_pred_cnn[error_mask])" ] } ], "metadata": { "kernelspec": { "display_name": "base", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13 (default, Mar 28 2022, 08:03:21) [MSC v.1916 64 bit (AMD64)]" }, "vscode": { "interpreter": { "hash": "bd97b8bffa4d3737e84826bc3d37be3046061822757ce35137ab82ad4c5a2016" } }, "colab": { "provenance": [] }, "accelerator": "GPU", "gpuClass": "standard" }, "nbformat": 4, "nbformat_minor": 0 }