{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "DJsUjs19_v63" }, "source": [ "# MNIST with SciKit-Learn and skorch\n", "\n", "This notebook shows how to define and train a simple neural network with PyTorch and use it via skorch with SciKit-Learn.\n", "\n", "
\n", "\n", " Run in Google Colab \n", "\n", "View source on GitHub
" ] }, { "cell_type": "markdown", "metadata": { "id": "-zmIlvxI_v68" }, "source": [ "**Note**: If you are running this in [a colab notebook](https://colab.research.google.com/github/skorch-dev/skorch/blob/master/notebooks/MNIST.ipynb), we recommend you enable a free GPU by going:\n", "\n", "> **Runtime**   →   **Change runtime type**   →   **Hardware Accelerator: GPU**\n", "\n", "If you are running in colab, you should install the dependencies and download the dataset by running the following cell:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "id": "8qYGNO2S_v6_" }, "outputs": [], "source": [ "import subprocess\n", "\n", "# Installation on Google Colab\n", "try:\n", " import google.colab\n", " subprocess.run(['python', '-m', 'pip', 'install', 'skorch' , 'torch'])\n", "except ImportError:\n", " pass" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "Gj0pvjxT_v7G" }, "outputs": [], "source": [ "from sklearn.datasets import fetch_openml\n", "from sklearn.model_selection import train_test_split\n", "import numpy as np\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "markdown", "metadata": { "id": "mPz6Bjqw_v7H" }, "source": [ "## Loading Data\n", "Using SciKit-Learns ```fetch_openml``` to load MNIST data." 
] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "mwpfASvc_v7J" }, "outputs": [], "source": [ "mnist = fetch_openml('mnist_784', as_frame=False, cache=False)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "id": "9Pt2JKyb_v7K", "outputId": "5a96aa80-e889-4553-c289-9534ed68d708", "colab": { "base_uri": "https://localhost:8080/" } }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(70000, 784)" ] }, "metadata": {}, "execution_count": 4 } ], "source": [ "mnist.data.shape" ] }, { "cell_type": "markdown", "metadata": { "id": "sV0ehb52_v7L" }, "source": [ "## Preprocessing Data\n", "\n", "Each image of the MNIST dataset is encoded as a 784-dimensional vector, representing a 28 x 28 pixel image. Each pixel has a value between 0 and 255, corresponding to its grey value.
\n", "The ```fetch_openml``` call above does not return ```data``` and ```target``` in the dtypes PyTorch expects (the labels, for instance, typically come back as strings), so we convert them to ```float32``` and ```int64``` respectively." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "id": "F2v_Fwne_v7M" }, "outputs": [], "source": [ "X = mnist.data.astype('float32')\n", "y = mnist.target.astype('int64')" ] }, { "cell_type": "markdown", "metadata": { "id": "C_yHJarZ_v7N" }, "source": [ "To keep the network from having to cope with large input values in the range [0, 255], we scale `X` down. A commonly used range is [0, 1]." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "id": "8RirCOTr_v7O" }, "outputs": [], "source": [ "X /= 255.0" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "id": "rohyp3d1_v7P", "outputId": "8f4d25e7-a175-4abb-a3e1-fed0fc4d3b83", "colab": { "base_uri": "https://localhost:8080/" } }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(0.0, 1.0)" ] }, "metadata": {}, "execution_count": 7 } ], "source": [ "X.min(), X.max()" ] }, { "cell_type": "markdown", "metadata": { "id": "tyUlsu0V_v7Q" }, "source": [ "Note: the data is only rescaled to [0, 1]; it is not standardized to zero mean and unit variance." 
] }, { "cell_type": "code", "execution_count": 8, "metadata": { "id": "gILlsHJS_v7R" }, "outputs": [], "source": [ "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "id": "jGpA4v4u_v7R" }, "outputs": [], "source": [ "assert X_train.shape[0] + X_test.shape[0] == mnist.data.shape[0]" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "id": "plXmcsp2_v7b", "outputId": "eb16e182-ac11-4a8e-b5da-c6395b73006e", "colab": { "base_uri": "https://localhost:8080/" } }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "((52500, 784), (52500,))" ] }, "metadata": {}, "execution_count": 10 } ], "source": [ "X_train.shape, y_train.shape" ] }, { "cell_type": "markdown", "metadata": { "id": "6EKEvbuP_v7c" }, "source": [ "### Print a selection of training images and their labels" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "id": "C9muXJPC_v7d" }, "outputs": [], "source": [ "def plot_example(X, y):\n", "    \"\"\"Plot the first 5 images and their labels in a row.\"\"\"\n", "    # Use a separate loop variable for the label so the argument `y` is not shadowed.\n", "    for i, (img, label) in enumerate(zip(X[:5].reshape(5, 28, 28), y[:5])):\n", "        plt.subplot(151 + i)\n", "        plt.imshow(img)\n", "        plt.xticks([])\n", "        plt.yticks([])\n", "        plt.title(label)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "id": "h2-R1-Df_v7e", "outputId": "619a14f4-7a23-4a09-a872-e646cf5c5900", "colab": { "base_uri": "https://localhost:8080/", "height": 108 } }, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "
" ], "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWAAAABbCAYAAABNq1+WAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO29eZRc133f+blvq/dq33rvru7GvhIgwH0TJcqWSS22ZNmW7FiZWF4mtnNmPD4eZxyP7WQ8x5M5iU8Sa5xEsWRrbMcZSZYjUwtlUZRIiaQgAsRG7ECjN/TeXfv2tjt/VGMjAAokm6zu5vucgwOgq+rVfbfv+957f/e3CCklAQEBAQFvP0q7GxAQEBDwTiUQ4ICAgIA2EQhwQEBAQJsIBDggICCgTQQCHBAQENAmAgEOCAgIaBOBAAcEBAS0iVUrwEKIvxJCTAshSkKIs0KIX2x3m9qJECIkhPiMEGJMCFEWQhwRQjze7na1GyHEkBDia0KIvBBiRgjxKSGE1u52tYtgnNya1agpq1aAgT8ChqSUceBDwB8KIfa3uU3tRAMmgHcBCeB3gc8LIYba2KbVwJ8Cc0APsJdW//xqW1vUXoJxcmtWnaasWgGWUp6QUjYv/3f5z8Y2NqmtSCmrUso/kFKOSil9KeVXgIvAO3lSAhgGPi+lbEgpZ4CngJ1tblPbCMbJrVmNmrJqBRhACPGnQogacBqYBr7W5iatGoQQXcAW4ES729Jm/h3wMSFEWAjRBzxOS4QDCMbJq1ltmrKqBVhK+atADHgY+BLQfO1PvDMQQujAXwOfk1Kebnd72sxztFa8JWASOAj897a2aJUQjJMbWW2asqoFGEBK6Ukpvwf0A/+03e1pN0IIBfhLwAZ+vc3NaSvLffEUrQcpAmSBFPCv29mu1UAwTm7NatKUVS/A16DxDrYBAwghBPAZoAv4SSml0+YmtZs0kAM+JaVsSikXgT8Hnmhvs9pLME5um7ZryqoUYCFEpxDiY0KIqBBCFUK8D/g48K12t63N/EdgO/BBKWW93Y1pN1LKBVoHTP9UCKEJIZLAPwaOtbdlbScYJ69itWqKWI35gIUQHcAXgT20Jokx4D9IKf9LWxvWRoQQg8AoLZuVe81LvyKl/Ou2NGoVIITYS+sgbg/gAc8A/0xKOdvWhrWJYJzcnNWqKatSgAMCAgLeCaxKE0RAQEDAO4FAgAMCAgLaRCDAAQEBAW0iEOCAgICANhEIcEBAQECbeF1p+wwRkiaRt6otq4IGVWzZFLf7/ndCnwCUyS9IKTtu571Bn9ycd0K/BM/PzbnVWHldAmwS4V7x2Mq1ahVyQL4+v+x3Qp8APC2/OHa77w365Oa8E/oleH5uzq3GSmCCCAgICGgTgQAHBAQEtIl3bOmWNYMQCE0HRSAMg1aelWWU1r/9ah3peeB7bWpkQMAqQyw/L6qKCIVaz4qigvTBdpCeh7QdkD7SdX/49d4iAgFerSwLr5JO4mzupZk1WNil4YVaoeNSBSfpodgKA9/wCI/kkZPT+NVqmxseENB+tN4eynf1U+lRKdzfxIo26U2WKDZMCsezhJYEmeMOocUG6ukxvFKpPe1sy7feDCFAKAhVba32NA0UBXwfpGz9DVzJXeF5SF+uv1Xf5Zlb0xDhMKQTVAZMal0K7q4KEcvGkwJd9diemaPkmFw6OYxei6Ev5iEQ4JujqIjL40pVW2PrWhwHf3lFxHrKj7K8YxKajlCVG+/98nPVbLZ1JfimEKK10tWuypmfilMa1CgP+/zinc+zy5rgXdYiky78ovKPmJ1NYhQM/JBFbDqOYts3XFZKCb58S3eXq0KAlUgE0duFm42xuDNMIyto3FFD0zyaBRNhK2gVBcUGoyhQ65A+08RYrCMuzeOXSkjHXRdirOzcysyjaeod4G2uEY002J49Q9KosyU8gy8VztS6cH2VrlAJLLA/pjJZSNLx6
Q1YLzj4tdrafZjeDEKgxmLLIiPAl/jlMtKXKHdspdEVZuZ+A3tjHTNsEzFbD52UgvLLGfqfaWJMl/DOnG/zjawMQtNQu7vwkzHm709Rzwqau2uk4rUr78mXwnhVnf6vKcS+N4Isl/EbjTa2+vWhJhOIdIrKjk7m9mlXTrUa3S7vvfMom8JzPB47TlJx0THoUh1+e9M3GBno5POd+5gpRph+YACtlrv+whKsWYlZkMTPlVFGppD1+or3TXsF+PKqNxrB7YxT7TPJ75RovVX+7K6/IqnUeaa6nVknzqlSN4WGxaX5JH5ZR7UNoiGFcLWB4tj4NJAOa34F42TDFHa5ZPsL/IutX6NXy7NZc/CQzHuCJd+k4etUvBAhxSWsNvmtwaco9Ef4190/RzhsIWz7HSnAQlURkTAYemtsSYlwXYTjUO+JUBrUSN47y+9u/iq7jQVyWhQAT/o8ZnyE0oUeEm4U5axY02MIaK0KQyH8TJxGd4SlXZLwQIk/2/M3PGy2xoaP5LmGwclGP58+837iRyPg2LCGBFiYJl4qQmlII3bvPKrSWtHfkZni/+59hrAwAB0fFYCYMHg8nKdszpIYrjFuZzjQNUSxaV53XSkFcxczhGZV9GqE6EK0pS3rSYCVnVuZfk+aeqdE3VYmai2yP7FENlRlxk2wKKLowqPfyLO/5yIqkqn+FGXP5OVdAyzUo5wa7UBb6iI6LgjPe8RPFZHnLq7ZFbFWsbEmTUoLGX7z7D8CQEgQnkBpgtoUhGclwgUvBHZCcPYjR7kvfoHKgMC6e5DokSn8ick238nbyOWVb1eWiz/dTaPLQ5o+KBKl3IfiCESuSkdykZ/PHWC3sUBCUfGkf+US/+Pgs/y//8P9XHh+kE2XcshyBW9hsY039QZRVNRUgubeYao9OnPvdchkS3ys7zQbzTk26SUgfOXtW/UiGaXGf9j3biZkL90vJhAvHls7E1DIwI2FqPZKfmfjs+ii9cz3aXlMcWt5M4XKA9YIu80J7gyPUfVDN7znfF8X83aMb+7cijfeT/cBn/jTp5GN5oqthNsqwM3eKJV762zqmeN3hr6Kis+Mm6QhdRa9KL5s7Sdiap2HzUtkVQuf1kOhZE7iSI/P9G3m5XKO75zcSmPcwChFscZD+IBsrj0BVmo25qJEsSG86KM0JXrFRbg+as1G1G38sUmkbbdEp7+bVx7t5o7oJM2MT2lAI3LeavdtvK0IVUWELZxsDP3uPO/PnWazNUtEaXKq3kvFC7E3Mk6vlmebkadHbfWPz1WR+VBklvdt+iKPFH4RLxNDlRIWl9aOEMFVW2gsSmGTQXkIfu++J3nYGqFHNdCFCljX3XeXapFWPPYPjvMDZ4jaqElUVVt2z7Vw76qKZ6o4SY8PR8duEF0fH+8m96GgsEFvrYp36UtXL3eNl5ESnQPgr2MXeXpgB4eXdpB4MQK+XLGVcFsF2LUUerOLND2N3zjx0xSKEUKnLVQbFBsujxNfhz/s8fEsHxF1CVkOv7D9Be4NX2CHOcmG0Cy5/Utc3J7hu/1biOzdReehJsZ3X1l7K+G5RTpfMhCuj6g1Ea4HtgNSttxmXLd1T1Li1xuo5RozEz18SduLWrvtCNB1hZJKMff4Bio5wYcHD/Oj8VdIqzV0fIb0eRpSJ6PUCAmPI81OvuakOFzJMVZJ88HuY/xM7DS6UAgJjbt6JvjBj+4meT5KfG4Bv9FcM+NH6+1h6ZEclX6F+GMz7EvNcrc5RlZVrxOWV6MLlZ/t/D7bozN8/vSjJPp7kYUiXqH4Nrb+zaE0FA404ow6HTyX38J8PcrYYgq7oSOWDMRNfoVSAAJ8ywfdZ3Bggf5ogR3RaXqNPA9YFxnWTO42x0h3VHihZyteTxpVUWCFvCbaKsBeSGFjfJELxSyVIxmSE9D99xfwS2X8ev3KDCx0A2XzEF7cpNpn0UiF+GpiN4mBOo+GzzGsmfyYdRyAP45d4qu53eSrvXQfMIC1tRL2FhZhY
ZHbWXtIx0bWahizGmOhDGb9nSnAIhpmaY8kPpznJ5KHuNNQAH35VQnYOFLQlPCVSg8vLm3g6Hg/yiWTb9zn8UTkFDHFJ6wY3Bsf4Qf7cxRkgkQohHBc5BoRYD8TZ/4uMIeL/Pn2v2SjZgFXt9b+LUaVguD94QqPh4/zub6H8bJxVMeFNSTAqi042ezj+4UNHDi5EW1RI3kWzIJP7EwenFuciagKTjaKE9OYvr+Hse4ORgfSbEgsMKAvMqzZbNENNulF1GyTZtbCqq5cJfu2CnD8TJEjX9iFVpP0jLmE8k38cgVp29dtf6TnwXwerWQQL0eJRAzmQr38cedP8EcdLiLq8sCmEd6dOk1CrfOx/pf44wffy5S6m8yJJsZ3jq6dLdXrxfOw5gV+KES9x6XRLcgei7wjQhzVZILmvk3kcwadW+e4t3OMjNIErppgpr06S57OZxcf4ni+l4mjPcRGFRKeRPhwcUua8qCGKVsP6ICxyPbOWY6k46BpCF1DOje6KK0ahEDdNMzcu7qo5GDvPefYm5gkfa2nGZJpr868Z/C5xQcZqWQZy6dwHJXf2v1NPhq7iCk0FBSIO1RzEWJ1By6177ZuF5kvYo2o9Kpp/kvxCfQq9Mz6GBUXc66BUrNhbunWuxhFxag30Q2dXj9NI6UyfU8n+ZzFu5IdPGpOXX7jW9L+tgqwf/wsvae1lq/d8iD3b/pGD29+vvXv5bOl7hMRhGFAbydu0uLFj26jcKfFJ3pf4MORJdjzNF/o2M+c0U//91q3uR49A6TjEp7z8QwFb0ONvlQRO9WD+cM/uuYRyQSzd4eoDXj8+sBh7g2fJ6uqV1738ZlyLS44nXzlxG6sMyYbny7DD46jbtmI3Zfgwt1hyr5BTCwLsFbgofR5DqWGEIaObKi3+vr2s2zzbQyn8T6Y58GuSf6w9ynSagjlmhHgSI8p1+JEs48nj+4hNGmQOuOj13z+7n++kyci59EVFUVANF6n0mtizVqshf2Ul89DPk/o7AX6v3r9a5JWldYfyrK0GGcvENINnOh+8iLG5KY0qpi57rB2pWmvG5r0rwZUvN6P2g54HspiAb3epPMH3VycG+Y/vscgt/FLpLUKH+w5xv/T0YeIxWCd+sYKXaPcr1Ad9FBtjfGFFH31tbFlfqMoponS00Vtayf+3SX2dc2wzxqlW62h0zI71aTNkufxn+ce59h8D7EjJqkzDtpsAVcICvs6mL0H7txygV61RkQRONLjtN3F12d2oc/pyFq9tRtbpWj9fVTv6GF+j84Hcqe5IzxBWFFbK1nAxaPo25xzLH712M9RmomROagSnncJT9XB9xlZyPBCby/7QlP0axZJq8F8BtyIdsWIs94RuoEwdMRQP042zNJeny3bJ7krMoInfaa9Gguejls0CC3UEJX6in13mwVYvmFRlI6NdMCfaZ1GJi5NkzQMLiR28UzXDh6KnuEnI3k+1deARBQh/fUZJRYKUR326N04z9RMCr9gotVWboCsRkQ0Qn1jlqXtOv9u71/wgFkmJHQumx58fIq+x4QX5TsntxI5azDwTB7/6KlWnXYhmN8n+P0nvsBec5J+zcKRHg4ex2o5zl7oIXlJ4Feqq9r8YA9mmXhMJb1lgd/KPk9CMWF5AoLWynfC1Xm2sh39ySRbTlRbYbeFIhJQwmEaU7v59uB2ujMF+jVJd6TERFcHdkx9ZwiwEAgzhBKNkN+dojyg8NhdR/nfe75BQlHxMRhzwxxpDGIsqKjTS8hyZcW+flVEwq0EQtMQIQMJOFJddmHzEUKCoV8Xpriu8CWiKSg3QkhHQfHgtk7w1iBKLIYY6KHRG2P2nhC1IYdOtbLsXtUS3nmvyaxn8M9HPs7F2QzJQwaxSRdRa6LEYrh7NlLtNzE2lhgyFpZNDzrnXZ+XG4M8ObqL9EGNxKjdOjdYxTgxHXOwzPbMDLpQUJaNBpdXvhOuzn+aezcHZwZITLtoc6WWV8ctUBAoYnnwrAX7w5tBU
dF6upDxCEv7M9SzCqXtDpGOMvfFLxBe9hqpSZsv5R/kG2PbiU6ArFZXdFe0blRJmCFENNJKUiNVPAQ+EkX18S0dNWT88IusRXwPrSqoVFoh28IViPV42Ago2TTzd2coDwq2PnaBPclJ+jX3ir3TkR4jbpQXq5vJ/00/mw8WEWNjeMUSZNIo2TTj77PoumeGXx04yDa9SnjZb/Sl+hB/PvYA7oEUPZ87dMNB8GqkkVb5+OZD3Bkevc7/1ZEeo67BgdomvvXSLiLjKpGTk7ij47e8lorkiuqud/EFFDOEvaGLSn8I8XPz/Gz/MT4QO8agJpYndIOa71CWPk+e3U38W2Eyxyor7pq3ZgX4clIVJR5HmCEqe3sp96lYm4rssiZpSJ3nGxI7b6KUF5GvMfOvRYSmoQ704XQnsfscBjvzTM50Y84LlJp988PMNYrQDZSIhZ1Ls7hHovbUuD89wqbQLLq4ejrdkB5H64McKuYwyrIVtLKpH9/UKOZCNFMK/sYamxPz9Ol5QkJBQcHH50Kjk0uX0qQX5RU/61WPAAW5LJ5XV77H7Tj/x4UPMrmQJHlCITLrISuvbX7z1rnqKqaJGOjFj1uUN0SxowrlYbBTPh/uvMgua4K04qGLltueJyV/XtzFc4ubUU9HSFxsoi5Vbu9Q73WwdgXYslpx4Bt7aGRNxj/i8dE9L/FjiWM8ZDb4/8o9/OXiHVgTGszM4zfXlwAr4TDFfd2UcyqP736ZR+Jn+Fff/TiZUw7KUnldCbASsaCvm4WdFr/2o0+xy5zgXrO07Dp1dWdT9iVfn9vJmUtdDM07iKbN9IOdVIZ8tu4Z56HMBbr0Ikm1xg5jlrAwW7Zf6fHS0iCJIwbxMbt1OLwWkOAj8BB4UlKTDiftGP91/n6af97NhrE62smT+JUq3q3OWqTAv0Z8fSlaJqw1MP+8HpRkgoX7uqj0C3Z/4DT7E2PcbV0krdboVT3CQkcVV32mHTw+dehRMt8JMXikiDx6+i3xhlhbAqyoaAO9yIhFbShBM6FSGVBopiTbhybZHxll3o3zd5U4nx1/iPHTXXSOLidcXuX2vNdNKERpSKUy6GEoLjNuAqMA5lwDWV9nh3ChEG7CxInBhtAcA1oRU+hXTvsvYwrYmZjG8xXGH85h7uindIdNR1eRe9Oj7AmPERE2pnDwpGDWq3Ow2c3B6jDnL3QzMOESmquvOe2Zd+M815BcsDv5wuR+xsazbJpqos2XW8n6bya+itpys9N9YloDXXj4KDQ8HbWuoLhrZBK6FZfzacejyP4uqn1RFvdK6KlzX3KEHaFLDGglYorAFCo+PhcdyaJv8cWluzlT7CJ80iQ+3kTNl3HXczrK20WxTPL39VHpVxAP57mvd5QPpQ+z21ggthxK+htT7+Lps9tIP2Oy7ctnkPUGfq32wy++xhCxCDyc5ydzp2n6Gs8tbiZ53oHDp/DW2WQjwhbVPpNGh88eY4ZeLXSD+AJkVYvf7XwBv0NS2OTjA2EBumg9ZArKlZDcM47Cy80s//LkB7APpMkdc7G+eXTtmB+u4VS9l7+Yf4Cx8SyDfyvYOlNDnLmI9xph1IplIiIRRLiV7CoiXMBgsR7GWFLQqs7bexMrjNB0lEQMb0Mv40/EaA41+ezDn2GrXiKmaOhCRVn2mqlJm7Lv8Xelu3i5MMDEX2yi47kZcoun8CtVXPet64vVKcCK2hoghgHZFDJkYHdFsOMai3cI7E6H9/aMc0/sIt1qCQV4pt7N+WYXz5zfSui0RWzSxssX18528jYRmoba1Yndl6I3PkdfKM+Xp/YwOZ9iuOSsP19nIfBjFtVuFS9pYwpuKr6XCQsDBERfI3DJx+fpyg6+PruT8rkk2TEfa7aOXGNmKnPJ47+evQvXUXHnLCLTCtZUHmWpjGfbNxdfRUUxdPydG6j2WGSyBQaMRUzRMsXUmgZaHVRnjTw3lxMQGQYibCEiYbxsHC9sUM0aVHpUnM11NnYvsEEvkVWvR
kkqCFw8XrFDTDgZnp7ZxsRsioFZFxaWWruHt9gNcdUJsNC0ln13oAcnE2HmPot6p+Thh1/hrvgod1sjdChNkoqCIgRHbYtnar38/rMfJnVYY/h4He3kqVby5DUSw/96UBJxlh7JURpS+GTnKTaHZph+8XE6Tkj08UnWlfwubyMbvTHK99TZnZvGFG8uJNSRHjXp8CcvPsbgl2DLeB45OtkK7FljhF88y9BYF7georaAbNr4hSLurSo4CIESCaMkE5z7mQi9u2b5Fxu+yXutBRpSsOTblEoWnTM+WrG5Js4RhGGgWCZ0Zmn2JyluNFi8xyWcrvGu3CmGrXk+ED1OTPFJK9d7QvlIir7Nv5/6ECdmu9G+m6B33CN8dh65XDfuraZ9Anw5h2tnBgwdP2y0Vju6gmdpVPoMmglBZaOLla3x7uQp9pqTbNBAFyFebIQYd9I8W9jGSDlDeFQnPuagTxdw8/m23dZbjdB1ap0KzayPI1UuOWlCi4LIzPqx/QpNaz1Y3Z00hjIsbdfJdc+xPT5zndfD66EpHRzp82Q1x9FqDmtUx5pYgoX8mq2j51frKNPz4Hl4zWYrqvRmOyAhUMJhhGXibu6n2hlCz1W5t2OUAW2JkNAZcWHUSSPzBuaii6iujaTswjAQsRiNgSQLu0NU+yQDuQWG4os8Ej9Dr55nQFPQRSusxF+eVi7vohQgZdRIRurMd8VBqgg/S6grjlZqIOo2FMpvWdWdtghwK2gihLNnI+PvM7EzHt1Di5iai6U5dJsVPpJ9mbRaIak0MIVHl6pcSRiy4NX5hed/BeuUSfqMhzXTYOjSBP7CUmsgrmNkLEJpb5O+njxHS/3M1mNkjzfRDpzCW4OruJuhpFLIngyjH0jziY9/k02hWfaFpogogpB4Y1kuJlyfCTfB//mFn2LoyQrDU2N4s/P4a9heLh0br7D8O38Nu7UwDOS2IWo9EWZ+vsEjQyf4V9kX2WGUiQodH5/PLT7Aty9tpuOggvnd4/irOAT7WpRMisaGDsZ+zOBTH/4MabVCRmkSEhBTVFTElUCdmxFTDP6g+5vUumBxa4iCF+bv8/s4X85y+mIP+pxO58EOEi/PIAulVu6JFaR9K2BFwYlq2F0usY4Kj3RfIKY2iKoN0mqFPcYMpriaTKPmezSFT0Ix8ADpKGgNMBdt9EtL+AtLa3YlczsITUPJpHG64iTTVXKxPLP1GFP5BLmKvabqeN0SRUWoKrInQ3FHktqgwwdix0gr3nW2uzeCKiSq8FE8gXB8pOOs6jDj2+YWwis07cqYkfEI+a0xqt0Kd/ZP8iOpV9ikl8goYZrSpeY7nCp1k59KkFvy1tahtaLgGwpe1ON+s7AckBLCkR4F36UpoeArN/g5G7ikVQdTCBKKQVoI+gFPVmkkjzNo9dD0NCatFKX5MEYxizmqwLoRYKCZVLlz6wX2JSf4RPIg5vIJ9aIn+EZ1K2XfpOi2yqeYikNUbfCR2CuEhWDT0CznRRfxMQNjtI038Tahdncx9aFBKkOSf77lKWJqnd/+zk8THtVR8zMr7iDeDtREHJFKcPFDaf7xz3yTPdY4w5qK+hqlZW6XflWnQynT//AEZzO95L4eJvS1uRVo9epDaBpqNoPflWb0iRT1Po+ffeh57o+eu1KOKSxaE9q0Z3PJi3L2UI6hb7qEz86vi3OEERf+W/4BRmsZjs704rrXr4JDIYcfGTjDZmuWJ6Jn6FJbPsCqEDxsLnB3aI4PRI9T26Txnza+mxceHCLyZDep8xdXtJ3tE2DPQ3Uks7UYF4wOXjJ7MRUbTyrMuEmemt9F1TWoOTq+FIRUj3iowd3WCINajVwkz2I2jBtKte0W3k6kFaKak8i+BqbiUPYs9EUNa04iGmt8JbdcnJVMkuZAinq/y0fjh0koAv1VJofLSXPKvnfdIZEjwZYKRT+EhyCpNJdNVxohoaOLVlWIoegSYx0pnIjFjVXA1jiKimKGELEobq6Teq9FbdAl3
Vfgw4lD7DU0LteDK/kNqtLnpcYApxq9WHMK5qXKiiaaeVvwPNSmh14I8eXKAKbSMsmcqvfxzPQW8uUw3mQY4V6/Aq6GJM9pmzgT7aIhdXr1PKZwMIRHr1YkJny6VQgrgocTZ3ClwoH+3WQ3DUOxgre4tCIFgNsiwNJ1kZ5H/LsX8Sa7mTCT/PvkduTyCli1fYy8jer4xPzWY+YbGsVsJ3/0m+/nl/ue5Rc6n6Octfjtzk+SVhRQ1ncKcqcrzkff9zzbrCk+O/Ego9MZNjzVwDg7jbew9MMvsIpRLAthmUz/WA/K+xb4RO4EPapxQxmdmrSZciXnnCyfn7+Hmts6WHGlynwtQqlm4p6KozYF9vYa3ekS/3Lzl3nEvDpBdRhl0vEaXijMekNNJ2nuGaKUM6i+v8xQZpJ/0/c9hvQFtupXp6u6tPlcaQdHywM8/+1dJM5C/+ECjIyvGdvvZbypWYx8kU0TGf7yqQ9e+blwJcmaQ8qzEfXKjUKpKMhwCFfP8tXQo0hdodZpYMcESw832TE4zcd7fsB7rDEetka5zxrj8z89z9OPbGPhW5vJ/dnpFYkxaN8KWEq8uXlEoYhuGBhhC7H8wEnXxS+WW24gy6eOSihEtK+HsXyKic4Mu405TK2GZ4HUNYS6TgV42Se6Gdd5MHaWPrXIpaUEypSJPjWLOz3T7ha+aUTYgmScWo/kY7kT3B85f93BiY9PzXeY9yXH7V6O1XK8PNWP6yxnQfMVvLKOWlHInAGt6TOTMpkFyr4FXBUVR6o4noKxFnysfhjLPr3oOsI0kV0ZSoMGlZzgscFz3BW7yHusGeJKaxfhI2lKhyXf5WBxiKOzvcQvQPpEBTG9gLeWbL/LSMfGc2wolVDOv+o1bi+iWgFQVFKD/XiZGNXeOKfUbo7Ec2zU5xjU6vSrJh+IHWU4NM/vdf9UK0bhVmWOXgftzwds20jHRVybLEf6N5QQkraNrNaoX+jmc8Z9bNs2xR6jRDMtqW9IE27aa6qG1e2i9XSRfzhHfmtrgjnW7EM9HCN90YdCuc2tWxlqd29g4Q6d7P4ZPpE8QEIRcE1FhzOOx1dKd/LC0gZOvTyINavQ90IdtbH8AEgP4TYRjocolJERi4XdnXiAJ69OzJ8S5cIAABOPSURBVJ6UPDu9icqRDAPTa99bRuvpor6zl0qvzuJeieho8uPbDjBkLvKeyGkSikdUuXp42ZQOX6n2cLg2yLEv7qDzSBNzdBq5sIRfXweHuG8G38OfnkVZzDNc7sRLWPz9h+7jW3ds4ZObXuCXE6MMah5ZdQw/6SAjFrhv/uSl/YEYUoL0bq/woeeh1gWlqknVb1nwfEPihhVQV3HpmDeKEMiIRXlAodnpsehGudDswpqXhGdsWA8ud0LQyKjUBjzuT82S0270dphxYxws5Dgz1UXinCA65aK9dOo6z4/LU7UIhVDJAKAoPqq4utT18SlWLMwFgbqGQ21FKIQSDuN1pigN6FRygoEd0+xKTfPr2edIq2orInAZF48lr8mSr/JSZZiXlwZInXMJvXwe722I9npDKCpCEVeyHl7O5/JW1nb0G41WuflSCRSVyL57yHfHmM4l8fEJKzphQDX8lt4obz6DXPsF+HYQAjWZhM4Mdp/DvX2X6NaK+IBWFYSWbGiuwkH0JhChEGo6RWVLhsi750ipHn/w/R9HnTHY9FIBMbWwttyFXoNqj8LG7RPcHb/+hLkpHQq+y5/N/Chjf72J7jmf2Ik5RLWOexOfZyUWo/rYdio9KvrOEvt6JhjQloBWshVH+tilEOk5H7VqrznPEREKoVgmlUe3MvG4JNJR4z25g/SFCtwVHiGj1OhQtSvmG0d6jLkuZ5xO/teDP4k3bZE5KgjPuUSPTeJVqqsz6byiog324yci5HfFqWcUEqOtMkraxBzuzGy7W7hirAkBFqqKiIRxoyGseINNkXlM4WFLiWILtOr6y4GgWCZ+JkmtQ+XR7gss2FHmn+8hOikRl
+Za5evXCa4FW+Nz9OnXHyY60qfsK1zIZ8gerqDNl1pJxW+yAhKahhIJU8ppVPsk2zOLbIvMElMcQMWTEgeJaCoYZX9NTtjCMBDRKKUBjcf3vcz+6CgfjV1ER10WXQ0fH09KGrQSzIy4WQ5Wh9FORkiNSrLfm8KbmsFdxQnnhSLwklGanRbFDQqNXhchNRAW0UocsZi/aqaEFb+Py4FiUhMITV63i/KkbH3dCn3n2hBgy6J6Ry+lnMajg4f5UOJlRpwsz9hZIlMSZXwWWSy1u5krgtA0lHCY+r2bGf95j57sLLnQEieKPXQfsLEu5vHXmqvQG2TMVXmmup2lmQTdExP4pfKNA18I1ESc+n1bKPdpWI/P8mjnGD+WOM6AVqBr+XB2zHWZcJOEJ1SiR8fx1+B5gb9zmKl7o1TvqfOJzPN0qHVMcTUz3IJX59n6AOebXXxlchcLhSjmsTDmomTgaBl1qYI/t7Dqq30ITWNxT5zSsGDgoQne332c43v7uVRLcPbwAMlTWaJTLtZ4GWWxsKIH0ULTqPzEfgobVYwHF/nFoSO8J3oSgDHXZsRJIwsGotZYkeRNa0OADZ1qt0atF/ZHR9mhe7xcT3C83E+o4OPNL66brGdC0xCRMJVenf9l3zfo1ooUvDClpkni7DzuxbF2N3FlEUqrsoO48fdX8C3O1LpQSyp+vnB9tN+y77DQNUQsRnFYozII/yx3kCeiJ5b9f1t2UB+feT/MqNOBuSRxJybfrrtbUepdJqXtLncOTLI/BMqyT69LK5PZvK9xqDrE8UIvCyezWHMK/f9QQJkv4M0ttLwFXs3l6EPPWxG/1hVBVal3Cpq5Jj/Z8zL/JDFKOX6CmpR8Un6cEdGHZ+oIL4olJWIpD75srYp9+cbyNVweT5ZFYZOKvbfCzw4d4eOJQ8QUARjMeBFONvpQqwqy2XIeeLOsbgFWVNRUAtnfxcIDDluGZxjQFxlzJf/2yI+gnQ0zeLGEXC0DZwVQujsp3dlDcQvsN0c51Bji3zz7ONGLGvHymXY3b0VRYjFE2MJO+OwIT9GhluGaVJMnmn08N7YJc05BStnaHUQjiEScxqZO6h068/sEbsrlzq3n2Rhd4D2R02RV9To3Nkd6/NnsY/xgYpCumVVo8/whqKkUIhlnYZfGLz3wLfaFR1EQqELBkz5frmb5nZc+gpwLkXpFYFQkQ7M2WsVGTM7i1+rIV+W0FXori9jMz+2kuFmSOCuITXpETs/jrXC01xtBcYCmiiNb+V/CQkcXHr+W+zYHM8OM3pfhUjXBbCNEpbYVZ8EiPK4SveSTPjCHqDfx8wXwvNcO01dUtL4eZDzC7ENpqr2CwYfG+XDPYe6zRkirKguex4gL/9Pxj+E9l6b/hIMsl1fE7LmqBVioKiIawU5bbN84xSd6X6RTrTDvRdDOhuk86KLO5HHXifgC+PEwpUEVp7vJBr3Bs1WT9GGV+LiDbNqtmfr1skr7R4QMRCSMH/YZMBZJKjbXup/NOXEa8xapYqv9wjAQkQheNk5ho0ElB+9/7CX2RUZ5IjJ2Q1n2yzjS5+hsH/75KEZ+7WWMu5zjtj7g8luZk8uTS2uiUoXCoeow8e+ZxMddrG+fwF/Oiidp2SxvvKBAGDoiEiZ/t8Mn7/oen4k9jBvRCS3GEedv/MjbipQIF4QrruRw0IWKjsr7w0XeHz5ypQI0tPyb/7rcw5+ce5TFUxlio8uZzJo2ODbYzvU75Mt9IgRC1/DTMZpdEZbuctm5eZLfGPgHHjUdoBXoc9F3GLE7qZxOseVLl5ClMt4K5V5ZUQG+7DLyZt1FhKahJBOQTjL3UCeVAcEvdZxmszHLF4t3caTQT+KcT+T0PP46sf2qqRR0Z5m7O0Xi8Wl+tGOUsFC5P3KOv/uJPUyVwzgf2AGeQLxGt0rBldeFI1Bc6H3OxXr2RMuXehUdVgpdR1ohMHwyShXzVTf27ugpFu+O8p3uTVyK7seNQCNno4cdhjsn2Bsp8
MHkYbrVCuHldIPX0pQOX6zkOFLN4R5I0feyjTG2sOZyHfjJGNWBMEq0ZULwpI+6nJbTkz6PxE7zzce3MrYQI75lL4oNwpMoLoRKEuFd369SgWq3ip2AB7ed5JHoaf4ifD/IG/uwHUjbpvNQlei0yae7H6K2JcT7Yse5w7g2OEeiLFc+B9hnjvNLG5/ncFeO5zdtoFYJo8ymUBsQygsuW7gUB6xFHyTUMwpuGKrDHkrS5se3Hee+6AU2aEV8LKa9Okuezh/P/AjfHxsicRZkvrCiBX5XXICFYeA3mwh4wyIsNA0Rj2H3xFnaLTFzJR6KnGFAcziwOMT5sS42j9RXxVZppRDxKPWBBMXN8Nmtf0O36hESJnuMOp/Z8Zc0pEpDanjLKx9f3hj5d60d1ZcKl9wUS26UT+ffz8APTPD9VSXAaBpSV1F0j5hiX0nGdJm7QjV2dD/D/fHz/OfII2xOzPM7PU+RUMTyavcyN656obXyfSa/jUNTA3QcdQg9c+wtLS/zVuHFQtSyClb4qg332gKR+4wF/q8dX+JoI8d/69hP09GxbRXX1lBmQyjO9f0qVdA3luhLFflE5/PsMxroxuoZF9J1UY+cI34hyvydG/ladCcbhufYZVxNntSqZt3SFh+frbrK1sQo/yQxCn3PMeI4fKWym4lGmkMLAzheS7xrTYPSaAzhQ2i4SCZa45P9h9kWmmaPsUhaDQEhfHxmvBAjdiffHxvCOBIlcbG5OsvSX65ikf/QTsqDCl0HbaxTM/iFIn75h0drXT75F5kUlV1d1NMq+e3gpl0e3HWanLXEyWYfz1ZjjL/YT8c50Kfn1txK5rVwetPM3G2gbiiTVlzCyxnAdFQ6VAdPejjLIbUNKXCkQsEP4cjW+zwE43aWsm9RdMM0pUZUbRVbtGMSf6gHdWoRfxWFLstiCcV1MU9u4Fc6fo6P517ilxOjV17XhUoY2B26xM8MHKRPz5NYLqJ4K5rSoSY9nqoOcqLez4vP7iRxHsIXF1q18lapOea1UBwPrQEV5+b3HVM0Nup5IkoTdVjS8HWaUqPh61zalMT2r/+cIiS7YlN06iU26nlWoyVS2jZUqnT9wKM0083/dsdH+dTAIo92n+PdsZMMacWbBu1cJqNKHgifo2CG2WzN4chlAfYNTnd3A7AtOkNKq7LfHCWtNggrLXfF5xoxzjW7+dOTj+COREmcheRIg9DY0oprzsoIsGGgRMLMvsvjo3cd4Ove/fQvJlAd57YFWMRj2Lk00w+qeH0Nfmv/P7DDvMQeo05T+vzbhQd5cW6Y3u+5WC9dwFsnpofL1HtM5N4yDw5cJKm0MnhBS4RSQr3O5lWRTZrSpyq9qwIsFcbtLDPNBFP1ODXXYGdimgFzCTcmqQ1EiNYdWEUC7JVKUCqRfSXHvNbN181d1wmwgkJIKGzXYfuVn986IbuPT0N6zHsK/33uTk5M9zDwtE3oxdP4zVsXqFztCMdDa0g8T8HHR+FVqRWFTk7TyWmSe0LnrvxceVUOXP+mmREsmnL17Qqk6yJdl8jTJ4haJvGJjRSHuvnC/VHUbT5q9BQ57dbtTikm94QkUAVr5EolDAAyx2+oK+gv76Ic6fHt8nZenBsm8o0oXd+aQuYLeIXiW7Lge/MCrKiI3i6cjjixzgoPxs7xpd17uSQTWHNxrMVBoGV3erWnkR1TqGcVPBPsmMRJ+WzZMU5fuEhMbbRmoekdTJSTzB3rwpwXDEwsIuuNlrvJOkIvudiXIpyLd+D1S6a9Oi/UB5hw0ry4tAHb1zBVh4anc6mYoN7UsefCKM3lh0yCVlVQHFBtwIex8CC+Iel6RRIZKSGWVqfvqznfIDYa4ex0J88P6Axopddc3bwaR3pMeg4jTprfO/MhFhZiWGdMoosSc2Ie/22q7/VWoSwUiekqS2fi/Obmh7grdpEnIhcJCeW6kOP1yOVafZGREno5wpId5W8uPMJf5e5hc+8cWbNKj3l1XG+zpvlg9MJt9
01TOnyvkeCSk2LSTrPoRHjypTuJjGr0nKsjS+UVtfm+mjctwEJVcfqSlAdC7O8e4fFwnlfueJFvd29hYj7FfD7EdRPxNboZ6S7zUxsPk9XLbDZmyKhVtutQkQ7fb3RwpJrjyNPbiE5KNj87B3ML+JXq6rJjrhBGvkH0YohL2SS+lIy6Uf5q+j7Oz2XRD8YQLvgGqE2Ij3pkCg7GkbN4hcLNLyjlVY8JKfFh1RZZ1CYWSNcd8juSfH3bHTwYPUtOu/1EQ03pctLu4h8Ku1D+JsOWE2WUC+fxSqU1F258M9ypacTsHJ09+/haZi+vbOth35bxlqlqHaZAuRbp2K1cFcdOowHd348gLJPG3iFmtg8ymoVmt3NFYwaH5rl7y+ht901Devx9/k5O5ruZmEvhl3QGn5SED55DllfO2+FWvGkBlp6HPlchBrwwNsynY5to+Do7k9P0hEuUeltbRgWJ/6otUY9VpMco4EiVH9Q2UnQtJuop5utRRiY6UJd0ul7xsBYcKLZmorW8knkt1MUy6TNhFNfiLn4Nv6phTeqESpAccRGeRKoC4UrMuRpKpdkqwvlaNs01Yu+U1SqKopA5nuBvtfv52947+Wz/FGHNJmnU6Q0VuCd8gZoMMWFnrrgm5d0Ih/I58g2LqekU6oLB8Ggddam0vmoDSon0PMITFbKHEswu9vLThV+iJ1Xisa4z9BtLPGBdJCbkbZdumvbqlH2VS26cea8TezJCZtxDK9RW9aQlbQekxJwqk1IF9oxKc/KqjM1f7OXHJ34dofkoektxbvUUCMD3BOqEiV4WxEug1SXWpTyyVntbFnpv3gThe3hnRtDGTfTDe/gT/1Ge2HKCH0+9zJBWpF9rZS1rnVpevwab95qcd+IcaeR4amoH0wsJoi9ZWAs+2w/MQrGCXygiXefm/ozrCHd0HGNsgk6h0PVp9YrfolyO8LmWW/p3rlG8QhEKRRJ/O0/qKyHcncNM7tiAExXYCWj0OZze3c1SM8y56U6kbAmwXzDo+L6CteSy/dwSolzFW1jCXY3Zvd4sUuIfPUXqmEJHZxavL0thew+feaiLVG8RZYvPkLFAQrFfswgltEw2J+0Mo3aWF4obGS+nSZ0QJA5NIxdXd0XxKyvik2cxTgkMIPqqStlCWa6wctsX9a8zafpv41nByhx/+q1y2MlzHkU3wt/P3cWT6d0YpotptGw4QsgrD85lmo5Gs6HjV3TMGY1oEVLnHfSiA4VSy9brOmtmJfemeT2pOdch0nHxfYk2XyIxquGZCnZUpb6gc6C4veXDuSSuLGn0qiQ+2kArN1s7pHpj3e6QgCvjw6/WUBfLxMYMmnGLxkSG35/7CRTTIxJtoL6WozjgSUGlaCEbKlpRQ60LekebyHKl5X2wVrisC/L63/laykqwYv4nstkk+uVDRFW1VdlCUW4vamu5E1vVL67Gc3vvUBF6R+O3Jh/vwijayDi6IjCBuFDoXk6qI6+djP3W1txfR6Hot4NfLuNXKijjl+j6fisvrbicD/t2IyUv95fvI6VEOm7wzLWBFXUAlK4LrntbZUACAm7JlZ3ANT9afZ5S7eVVu6XgmVubrNNCagEBAQGrn0CAAwICAtpEIMABAQEBbSIQ4ICAgIA2EQhwQEBAQJsIBDggICCgTQj5OvwnhRDzwDorSnYDg1LKjtt98zukT+B19EvQJzfnHdIvQZ/cnJv2y+sS4ICAgICAlSMwQQQEBAS0iUCAAwICAtpEIMABAQEBbSIQ4ICAgIA2EQhwQEBAQJsIBDggICCgTQQCHBAQENAmAgEOCAgIaBOBAAcEBAS0if8f9sEGrksMge0AAAAASUVORK5CYII=\n" }, "metadata": {} } ], "source": [ "plot_example(X_train, y_train)" ] }, { "cell_type": "markdown", "metadata": { "id": "YQvC-rWf_v7f" }, 
"source": [ "## Build Neural Network with PyTorch\n", "Simple, fully connected neural network with one hidden layer. Input layer has 784 dimensions (28x28), hidden layer has 98 (= 784 / 8) and output layer 10 neurons, representing digits 0 - 9." ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "id": "5PG7R0W8_v7f" }, "outputs": [], "source": [ "import torch\n", "from torch import nn\n", "import torch.nn.functional as F" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "id": "NUjWrGBP_v7g" }, "outputs": [], "source": [ "device = 'cuda' if torch.cuda.is_available() else 'cpu'" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "id": "gVCW8F3N_v7g" }, "outputs": [], "source": [ "mnist_dim = X.shape[1]\n", "hidden_dim = int(mnist_dim/8)\n", "output_dim = len(np.unique(mnist.target))" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "id": "ShYicHlv_v7h", "outputId": "95071ead-b292-4702-9093-4d6cc1f0f94a", "colab": { "base_uri": "https://localhost:8080/" } }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(784, 98, 10)" ] }, "metadata": {}, "execution_count": 16 } ], "source": [ "mnist_dim, hidden_dim, output_dim" ] }, { "cell_type": "markdown", "metadata": { "id": "OeVnFhBS_v7i" }, "source": [ "A Neural network in PyTorch's framework." 
] }, { "cell_type": "code", "execution_count": 17, "metadata": { "id": "Xxtli0l__v7j" }, "outputs": [], "source": [ "class ClassifierModule(nn.Module):\n", "    def __init__(\n", "            self,\n", "            input_dim=mnist_dim,\n", "            hidden_dim=hidden_dim,\n", "            output_dim=output_dim,\n", "            dropout=0.5,\n", "    ):\n", "        super().__init__()\n", "        self.dropout = nn.Dropout(dropout)\n", "\n", "        self.hidden = nn.Linear(input_dim, hidden_dim)\n", "        self.output = nn.Linear(hidden_dim, output_dim)\n", "\n", "    def forward(self, X, **kwargs):\n", "        X = F.relu(self.hidden(X))\n", "        X = self.dropout(X)\n", "        # Return class probabilities, which NeuralNetClassifier expects by default.\n", "        X = F.softmax(self.output(X), dim=-1)\n", "        return X" ] }, { "cell_type": "markdown", "metadata": { "id": "LlEHSwjt_v7k" }, "source": [ "skorch lets you use PyTorch networks in the SciKit-Learn setting:" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "id": "s0aDatqN_v7l" }, "outputs": [], "source": [ "from skorch import NeuralNetClassifier" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "id": "dOrCbBjk_v7m" }, "outputs": [], "source": [ "torch.manual_seed(0)\n", "\n", "net = NeuralNetClassifier(\n", "    ClassifierModule,\n", "    max_epochs=20,\n", "    lr=0.1,\n", "    device=device,\n", ")" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "id": "8i_gnvPi_v7m", "outputId": "002bd91f-bc4f-4e24-af69-f2e2ef9a4771", "colab": { "base_uri": "https://localhost:8080/" } }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ " epoch train_loss valid_acc valid_loss dur\n", "------- ------------ ----------- ------------ ------\n", " 1 \u001b[36m0.8387\u001b[0m \u001b[32m0.8800\u001b[0m \u001b[35m0.4174\u001b[0m 3.8169\n", " 2 \u001b[36m0.4332\u001b[0m \u001b[32m0.9103\u001b[0m \u001b[35m0.3133\u001b[0m 0.8510\n", " 3 \u001b[36m0.3612\u001b[0m \u001b[32m0.9233\u001b[0m \u001b[35m0.2684\u001b[0m 0.8208\n", " 4 \u001b[36m0.3233\u001b[0m \u001b[32m0.9309\u001b[0m \u001b[35m0.2317\u001b[0m 0.8079\n", " 5 
\u001b[36m0.2938\u001b[0m \u001b[32m0.9353\u001b[0m \u001b[35m0.2173\u001b[0m 0.8074\n", " 6 \u001b[36m0.2738\u001b[0m \u001b[32m0.9390\u001b[0m \u001b[35m0.2039\u001b[0m 0.8277\n", " 7 \u001b[36m0.2600\u001b[0m \u001b[32m0.9454\u001b[0m \u001b[35m0.1868\u001b[0m 0.8224\n", " 8 \u001b[36m0.2427\u001b[0m \u001b[32m0.9484\u001b[0m \u001b[35m0.1757\u001b[0m 0.8623\n", " 9 \u001b[36m0.2362\u001b[0m \u001b[32m0.9503\u001b[0m \u001b[35m0.1683\u001b[0m 0.8312\n", " 10 \u001b[36m0.2226\u001b[0m \u001b[32m0.9512\u001b[0m \u001b[35m0.1621\u001b[0m 0.8221\n", " 11 \u001b[36m0.2184\u001b[0m \u001b[32m0.9529\u001b[0m \u001b[35m0.1565\u001b[0m 0.8158\n", " 12 \u001b[36m0.2090\u001b[0m \u001b[32m0.9541\u001b[0m \u001b[35m0.1508\u001b[0m 0.7974\n", " 13 \u001b[36m0.2067\u001b[0m \u001b[32m0.9570\u001b[0m \u001b[35m0.1446\u001b[0m 0.8123\n", " 14 \u001b[36m0.1978\u001b[0m \u001b[32m0.9570\u001b[0m \u001b[35m0.1412\u001b[0m 0.8304\n", " 15 \u001b[36m0.1923\u001b[0m \u001b[32m0.9582\u001b[0m \u001b[35m0.1392\u001b[0m 0.8421\n", " 16 \u001b[36m0.1889\u001b[0m 0.9582 \u001b[35m0.1342\u001b[0m 0.8153\n", " 17 \u001b[36m0.1855\u001b[0m \u001b[32m0.9612\u001b[0m \u001b[35m0.1297\u001b[0m 0.8458\n", " 18 \u001b[36m0.1786\u001b[0m \u001b[32m0.9613\u001b[0m \u001b[35m0.1266\u001b[0m 0.8827\n", " 19 \u001b[36m0.1728\u001b[0m \u001b[32m0.9615\u001b[0m \u001b[35m0.1250\u001b[0m 0.8335\n", " 20 \u001b[36m0.1698\u001b[0m 0.9613 \u001b[35m0.1248\u001b[0m 0.8112\n" ] } ], "source": [ "net.fit(X_train, y_train);" ] }, { "cell_type": "markdown", "metadata": { "id": "5c3iyCKu_v7m" }, "source": [ "## Prediction" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "id": "D7rdey0s_v7n" }, "outputs": [], "source": [ "from sklearn.metrics import accuracy_score" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "id": "b9B8Zd6e_v7n" }, "outputs": [], "source": [ "y_pred = net.predict(X_test)" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "id": "LCWachuM_v7o", 
"outputId": "c4785fc4-2ab1-4717-c024-c80f1aa34ef3", "colab": { "base_uri": "https://localhost:8080/" } }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "0.9631428571428572" ] }, "metadata": {}, "execution_count": 23 } ], "source": [ "accuracy_score(y_test, y_pred)" ] }, { "cell_type": "markdown", "metadata": { "id": "0eRga6AV_v7o" }, "source": [ "An accuracy of about 96% for a network with only one hidden layer is not too bad.\n", "\n", "Let's take a look at some predictions that went wrong:" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "id": "XddXR1_a_v7p" }, "outputs": [], "source": [ "error_mask = y_pred != y_test" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "id": "hXlNTrlt_v7q", "outputId": "11223953-1853-41ea-e43e-c384061351ca", "colab": { "base_uri": "https://localhost:8080/", "height": 108 } }, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "
" ], "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWAAAABbCAYAAABNq1+WAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO29d5Rd13Wn+Z2bXw6VI1AFFAJBEIwiRYlBlGgquGlJlmXZHtuj7tXTlt3tbi/b7ele7dVjWx6PZqaDuzXubvfIkiVbstzK8igyiqJIEDnnUDm9qpfjvfec+eMVAimABGQW3ivgfmthFYB69d6+p+79nX322XsfoZQiICAgIODGo7XagICAgIBblUCAAwICAlpEIMABAQEBLSIQ4ICAgIAWEQhwQEBAQIsIBDggICCgRQQCHBAQENAi2lKAhRBbhRDPCCHyQojTQogPtNqmdkAIsV4I8S0hRFYIMSeE+KQQwmi1Xa1CCFF6zR9fCPGfW21XqxFCPCeEqF02LidabVO7IIT4iBDimBCiLIQ4I4R4qJX2tJ0ArwjK14G/A9LA/wL8lRBiU0sNaw/+DFgA+oA7gUeAX2+pRS1EKRW98AfoBarA/2ixWe3CP71sfDa32ph2QAjxOPAJ4KNADHgYONtKm9pOgIEtQD/wH5RSvlLqGeBF4Jdba1ZbMAL8rVKqppSaA74DbGuxTe3Cz9KcnF5otSEBbcsfAH+olHpZKSWVUtNKqelWGtSOAnwlBHB7q41oA/4j8BEhRFgIMQC8h6YIB8CvAp9VQW39Bf5ECJERQrwohHi01ca0GiGEDtwLdK2ENadWQnihVtrVjgJ8gqYn87tCCFMI8VM0l9rh1prVFvyApsdbAKaA3cDXWmpRGyCEWEfzHvnLVtvSJvweMAoMAH8OfFMIsaG1JrWcHsAEPgQ8RDOEdxfwb1ppVNsJsFLKBd4PvA+YA34b+FuagnPLIoTQaHq7XwEiQCeQohnTutX5ZeCHSqlzrTakHVBK7VRKFZVSdaXUX9IM4b231Xa1mOrK1/+slJpVSmWAf0+Lx6XtBBhAKXVQKfWIUqpDKfUEzdn8lVbb1WLSwDDwyZUHawn4NMGDBfArBN7v66FohvFuWZRSWZpO3OUhqpaHq9pSgIUQdwghnJVY5+/Q3PX/TIvNaikrM/Y54GNCCEMIkaQZ9zzYWstaixDiQZpL7SD7ARBCJIUQT6w8P4YQ4pdo7vYHewVNh+WfCSG6hRAp4LdoZlu1jLYUYJpLylmaseB3Ao8rpeqtNakt+CDwbmAROA24NG+iW5lfBb6ilCq22pA2wQQ+TvMeyQD/DHi/UupkS61qD/4I2AWcBI4B+4A/bqVBItg0DggICGgN7eoBBwQEBNz0BAIcEBAQ0CICAQ4ICAhoEYEABwQEBLSIQIADAgICWsR1tTK0hK0cIqtlS1tQo0xD1a85af1WGBOAItmMUqrrWl4bjMmVuRXGJXh+rszV7pXrEmCHCPeLd755VrUhO9XT1/X6W2FMAJ5SXxq/1tcGY3JlboVxCZ6fK3O1eyUIQQQEBAS0iECAAwICAlrELXucTUBAwC2IEGi2DaaJsC0AVKWKcj2U58INrgwOBDggIOCWQe9Ik3nfJsr9AuuBZWzTw/vyJlIna5jHJvAzSzfUniAEERAQcGsgBCISpjAqqG6t8cntn+fTt32W4nqo9lgIx7nhJgUecEBAwE2PMC30ni7qI12E71rifQOnGDIqFKXeUrsCAW4nhADRXJQIbeXv2jWmVEoFSqJWvgI3PJ51wxECoa88QCvjpny/ef03+7VfK2Ll/hGXFrvi8nvqtffYhfvI92+qMRSWiUzGqHWaPNA3zhOJQ8SERg6BkAIhacn1BgLcYoRpIRwbMdRHeUOSRkyj2qFR7wBvUwXd8F/1eqWuLMjS1zBOhHGWIHW8jj1XRswu4C8t34jLuKFojoPW243Xn
WDm7TEaSUVjfR3lavR9zyAyXcM8Mo6fzbba1JYgbBu9vxflWMiog2/r1LpsfFtQ6dLwHfDCIC1Fo9fFijYu/qy7GMJa1uj7kYf1nV0tvIo3B2Hb6J0dNEZ7OPMRi+hAgfen9rDezPF8rYeXShtJnpLEDmeQufwNt+/mEuALs/0amrmFZaJFwtR6YyxvNainFXKoym2Ds3xuw1dJaJcObfUveLZXoCBr/FLfhzg+2YvwbRK6IJwvwU0owMKy8DvjlNZHcN9a5PbeWT4x/DWm/Sj/eOZjSMOh41wIbjUBXrn/NdvG64rjR0xqHSZuSKM0KPCiCm9dlUi0xrp4gbRd4df6nuVt9qX76lOFQb6XuY3Ti5vouQnO0BCGgUzFKQ3ZPPnAbh6LH+V+p4Cr4Gh1gAPZAaLTDeS5yWYWxA3m5hBgTaf8gXtZ3qLjLCnsvCJ5NI88cKzVll0VffNGKqMpCusMiqMg+2rsWHeKDrvCSCjDRmcORxhUZIN5v0FF6cz5UXz16n1TXUgiogHYvL93P7PpJN/p2srEUpyOp4bo3BWB+cwN391dDfR4HDXcT3kkzuS7BSJepyNcI1ON8lvnP8RSNYxVFLTBUV+riuY4iEQcYZpNLzcZobQ+Qj2mUVoncCMKMVDFsuvEw3lMoYgrgVICX2pIBWfmujjpaZzOddIdKfFg+iy3hybpNXL8T70v869Sm1p9mW8KWmea+ftTFMbgwdgptlkLOMJmxnP59OG3IsZDDOoNtE0jVNbHccMayV2zeOcnboh9N4UAC11n4R6NB95xmJ0T68nPhDGrMUIHWm3Z1akPJMjcYVLdXuUf3fEj7g6f552hymtepVFUNSb9KDk/zIl6H3VpvuYVik6zSFyr8t7ISXriIT6Q2Mukl+Sfz36U8HycSLUON4EAi0iYyroYS7cZfOyR7+IIj2/M3UGmFGH8bDd6SSNWBnFz6y8iFIJkHBmy8KIW5QGHzJ0Ct8flgzv2Mhpa5N2RYyQ0QUyzqCmP71Z6mWp08NzSJuZKMfyMjZnXyM3YLBsdZG8PUeqxeXfiIG+1fX4nfvXV1lpCxiLkN4E+UuIee5r1RhiAZaljHguTOC1RuqA2EGPxDpNGUhGdTEIgwK+DpmMMDyCTUaYfS1Ja73Pv3Sf5mY79nMx2s0i41Ra+IctbbLreNc29HRO8PXqCXr0MXEqDOea6/HX2AU4UejhwZghR1gnN6ohXh4RROlQHfESiwf/5li/zZCRLr+7jiCU23TvOsXQ/XT/oo1MTkM2vyZiw3tmBHOknsyXK8nuqpBI5TpZ7mSilmHhpECsn6JuSmGWJvVRFr7jIwk1wRNzKpqzR34vflaS4MUpuo44XVbhJiTIlwvEx7QpdiRIAP1oY4Vl3jH+XfRwlBUJXKE/DnDMxKgI7B3pd0Z+VaJ7P0laDWo+iK1RmU2iO/zT9Ln59sYf04ZvjEGUZtTBGS9zVP0VYNEN136/28Vx+K86iwsn6aJ7CdzRq26psG5pl6vwI3e42tPG5VV85rkkBFpqgMZimPOhwz4cP8enhF4BmjPRzkSKLdLbYwjemOCr5/pbP4wgdW5hcLr4Axxs9fPnYnYipEOue97AXK4gDJ1H1V59NqjkO9bdvozjs8Mq2UZ6M7CGlOaQ0+D9GvsLkUJLfLP7PRKfTzU9YgwJMOsnytiiZeyUvvv2TTPo2v3/2A5ya7GHdDz2cmSKcPI+s1QC4OXw3QGgI08DvS5PbFGHhQcU/fOgZhq0MdzpTRIRHWtNYlIpD9T5+VNzIVw7cjTlvsu55D6PUjGlqDRd9fAJVKiOrNZA+wrbRQg7FgW1UBxU9oQJbrFn+4OST9H3fIHEsd1OMoxcxeXDoBI8kjxPWdJZ9n68t3s3B+X6SixJruYYfMfFtjXdtOs7v936fh2//baxijHRp9VeOLRVgPR6H3i7c7hjZLSE0D0IZHyvXwNh9HEwT954xvFAz1Ujaguwmg0ZC0ejyMGNVfjN1uJWXc
N3oHWlEPIYflegItJVamHm/ylkvyjdzd/GVI3ciFmySxwV2ThI+X0CUKvi+/wbv/mq6NA/dyCLjHrUOE3vBXo1LWjW0cBgtESe/vZOld9Xo78zzldJWXsiOMfnsMOk5RWhiCZEt4Hve1d9ICIz1w/jJaPOfSiFmM/jzCzfoSq4PLRxG2DbutnWUhh2yWwRyU5ktPRmieo2TtT6ezW1lvhpjKp+gXHLQZhyMkqBrQmEXJKGJPKK2kt3g+8hSGeV56Ik4IhJm+eEhSv0a1buqjPUu8tLMep4+vZnkHov46QJaJr+mBViPx2Goj6X1FttjU2yy5jHRWZQGL53YgDNuEZksYiwWEDKGb+ss1qLM+DZEXcq9Dsno6j8vLRVgkUpQ2tzB8laDkfedpVB3mDjRQ2QyzNDxCMJxmHm7Qz3dvBVkxOc/PPZZ/kG4cPE9dLHGivm6O6gOJTDiDTRxaZk35YX4Tv4OvvTKfdz2J7OoQglZLKKkwpfXJ7wX6NRDpHVJLF2m0p0iOuWwlhaWIhbFH+xi6XadL7/tv3LS7ea/jD/K+PFetv71NHIhg1+pvGHWi9B1qmNdFAcvxc879yhoUwEWkQgkY8zfH6Z2b5kPb93L73ft5bTrsac2xA+XN7LnyCihaYPuvS79CzXYv+dVvQyudMcIw0CkEri9CZZ/psLPbd7HWyJn6NBL/OqXf4Ph77uETk3jnRvndaazNYFIJchuT1HYAA+FT7LZlNjCYtJNktplkjpRRz85gV8qo+lDGI7BfCXGpNtBJF6jMmDhxW1Wu0yjpQLcWN/JxHsg1p/l8c5j5P0QL1t1JoaSnHM2o3SI3J1hKNzcnIpZNUbNDLr48ZLBk26ZSS/OiYVuQnMaVr7xY69pC1wPveIh/VdL4clGD0/NbMaZM1DlCqpavaZkeGFZZDfZFEclI/biq753pOFx3uugOBNjcMrHyFau+GC2K6onTebOKLWhBgnN5Uh1kLkfDtA5rlDFEqrx+s1ThG1T/akdFAd08psVdNWQdR1cDScXJ3oqjGo0UK/nPbeCRBSvK0alT/KW4Qkmqil+ffId7FsYIDeRxM7odEwpQss+oakSWr6Md4VGMsK20YYHkPEQhbEYjaig2iXwQwqvUeMb52/nC6V7kRWDngMKZ7qIKpZadNFvLjIVY2m7gNEySa2BKa7szSrfh3wJ07ZYzMY4WB2iN15kdp2gkYgQuuJPvXm0VICXbnN44af/L3r0ENqKbyY7jjS/3tf0erXL2lVoiCuKr68kL1Q28HxuE96ZKJ2nfKyFcluKjajWMYp1ZCOCVIoLLunu0gi5g52kzihkvohyr3ECCTnk7nJ527ZT3B06z4X2HhLJ98u38XxmE4njBvF9U6gWJJr/faisi5N9uMZ9IxN06QYvZ0YY+a+n8ReXrmlVoEUjTH/E5Z/e+T0eixxn1IDDDZMzbjd/PPXzxHcmkYUiqthGG3ZC4HfGKA2FiI3l+N8G/45fPforFJ/vofOgS/f3972q4lEqddVQgRYOU9jRRXFQp/Onp3ioY4JtoSkkGn/w3PtRB0OMPZWHQ4f+XiutdqTaH2HbQ6d5a/osPbqBcTVfVin8xUV0t4E/t5WXu0e4Oz1JZ0+RL3b81M0twEoIHCEwxaXBudLf3giJ4nOTDzB1oI/0MQhPVRD59pzJVaWKljcwMkm+Xl7PbfY0dzW74iFkMz7J6xRcXETTMfp68Ps7iHaW2R6bJqk1AIfTbp0ZP8bnz9xH6USKvnEPVSo3Pca1hgBtJa/M1H1kPIpereEXi1f0foVto4XDNO4cId9vsWVgnO3OJF2ahylsvl28g6dnNxOeV6haDdz2GxOt6mIVfebGE/xh1/tYONZF3ymf0EwZebWJWYhmzLynCxkPUxqJUUtqLG9XyGSDPs3nXLmDZ2fGKJQdEkcNYpM++nIBr91WAG8CShOkrCoJvQo0NaIk6yz5UfQ66HW/WXZ9AalAgVzxiPQblEu+JrMgXourf
DJP9zP2p3tRfrOO3bsWEWsBfjYLuRyJ0wP86YnHeHL9Ie7qPHjd76OFHKpb+ygOWzwyuJcPxPfTo2u4yue75dt4OTeK+HaKTd+eQmXz+IXCG79pmyKVwEcRN2vMD/VjaxqiUvnx0IGmoyUTqJ40Z35BZ9umcf7l8Le533YBm5ry+KtDbyH5A4euvYX2TMlTCm2pQMhX9L6Y5MDEbYzsqWG8cBB5tU1YIRCGidaZJvuWXopDGv3vnuCe5BzvT+6loXT+3fgT7J8aIP3NMMNnqxinTyGzWbzr3NhdK0hT0Ofk6TVyaGjUlctZz+BMrRurpNBL9bYIPbVUgMMZye9Ov5sBJ8eIvYgumqLpCJdRawGpNM643ZjC4/HQLCm9md/rK8mEV2FZWjxf3sKJSg+RGXUxDelyjN4e/IFLaWn6cqlZ5dLKcmWliE57zB5K8WX/ToatJXYtDmOUBEbtDSYOIRCWhZZKsninTXmdz+2RKWJCYaIjkUzW0owXUlhFhcoXkK9JXVsr2MsNrFNhDlr9nBgw6LRL7HvIxlnqJrk+ifAVXkRHSNCrPkKCq0EjadDZn+W+1Diu0jnWqPOF3P0cyvXjHAkRn3DRs+W23WhS1SqarhGZCaF5VjOcdoXJRk/Eoa8LP2pT7Q5R7dDJbgMv5WIbHsuNMJ9bfJCleoRThwdxFjSi0zWMxQKqXG4LAVpNdCEvakpN+ZxxezhX7sAq+GiF688qWg1aKsCJvfMc/9Nt7OvSKIz5KL0piiLi8ZaN52n4OvtOr0OzfAYe/BRvW4lKePh8vXQ7u/Pr2f3UVlLHFR27F64Y863eMcTUY+bFBUXqWJz01Oy1x1hXidBzR9jwSojcO8f4+DuexF7Qic0p7GWvGd+7CkLX0VNJGiPdPPjhffxm9zMMGhAWzWhVRfkczA0wN97ByKKHv8bivpdjHDnHyEIHsz/Vyxc33M+OyCT//B8+y/56P5849QS6JrmvawJP6eyaH6ZUtanPh1GG4uNjT/NEeIK/K4/wrfIOnvmLB+h7bol1C6eQuXxbe37+chaWs+jTs0R1Hd+9gqdvmciNg0w/EqPaJ9l41yTbYxk+3LGTRS/Of596iN3Tw+ivxAnNK7b+YAa5kEHV6/hSwU0U770WliV8d3k7B6YH2HAu23onbIXWhiCKZaKTNcyKhRLGxfbwXshgZ3UjCIWVMXCjBmVpA814na8UZ6tdnMl3EJoXRCdriGIZAC0SQYTD0JXC7YhQWGfi2wqtITCqAt2V1xZjXWVkrY5wPSIzdWKnQlg5RWTew8zWUK9jn7BtvPU9FNfZbAwvMGiAI5q/RomkIn2mlpM4MwbGypisVVStjsgVic528Y0Td3Ckt48N6xZwlc5wPItUAqk0Cq5DvuTgVk3QFJiSiNb0+r+xsIODUwP0T3kwn0EWSy2ffN+QFWFQngeXeanCMBC2jejvoTqaprDepLTZxU7UiJp1cm6Izy2+jelKghMnBrCWdZLjklDGRWVzyPKl+0GPx8G2m8fy6FpzpVSutuRYnhtFXRpIXwP5+u1KTeFjCh9pNp835XqrNmG1VID9hUX0XJ6IrhM1LjNFE2BaiGiY0u095EcMFv040KxKqSuPl+ZGWD6VZvRgFf2Vo3gXvISRISrDMWbfZtB33yyFpQTMhbBygsRZSXSq9roe5g1D+ijpo79ylMFDDsqX4LrNr69zc2ipJOPvjFIdaXB36DxhYV38Xl42mPdNOBxj6Kky5rm5tl1mXwuyVoN6ndjzPtHT3SzdNcy/+sgHGIrluDMxxUIjxrOTY5SyYZJ7LDRXURgFr7M5gS36gjNfHWPs6WXE5Dh+LremxUVLJqC7g+nHO9nws6d4NDHNhxO72Vsf4r+df5jJqQ76v2vgZFxuOzGFqtVQ1RrK9/EblyYdYRjITcPUO0NUO3W8kKDjUBn9zCyqWLxiKO9WQUMR02ukjRJuTKB3dyFzeeQqZcq01gNWClWvX3W/UfdSSLMX3wL9N
ck2tuEhHUmt0yI21I+fjuLFLArDFtUegTtYYzi2zNRiitC8RmhBEcq46Pkasg084Auoeh3/emK0ho6bUDiJOo64tIMvkZx1HY43+rByYCyVm7v8ax2lkKUy2uwSsa4I5051MhdLcqazg2rdojEexclrRGd9lIDCRg3N8TGFh0RgFhRidgl5DQUbbcdKLwg9lUDEYzSG0+TXOxTHfB7uOEVaL3HWS3O61sNyOQx1Dd8CN6ZjDHUiPHnxfaStgy5QmkCaGvkRk0Zc0EgqfAukESHSN4qddTHydfR8GVUooqq1tTl2r8FXgopn4nvaG15LWGvQoZfwIuB3JtAaDbgpBfgNEOEQ2TGd8ohHr3EplmkLg0d7T7HXrnFMDGLd08vmh87xWwPPEtaawvSjyhh7C8OYJ0MMf+oUuA1Uw20uJ9bwzaRMA7+7wdbuRZJaHWh6wBXp8mfz72bv7BCdp1zk2YmW9DddDVS9jr+0jPVSma1H46DrYOigqlBfbGa+1GqIvm6mH+9gfc8yHVpzuW1WVDPbYQ3GPDXbRjg2lQc2sLjDRLsvx3+847P06iUGDfjz3O382+c/gKjpGBWBbinm3+li2D4dyQqG1hTgkOlyX3qcbqtAv5klotWJiAaOcDGFj46ioGwq0uaLmbewf36A6v5+uvd6hCfKaMfOoDxvTW/aFZXJ+VwashbCu/q9IIRizJ7jfjtLfbRG5p44XWr1qibbUoCFaaF3pvEGOqj2SaI9JWJajQtiowvBiL1IKW4z1x8jHw3zRNcR3h2uU5I1itKj4tucynVh5cBfXHz9D1wDaI6DGBmiMpIkkSoyGM7hXNYazUcxWUpRzoTpK3vtH+e8XqSPrFSa3thr0XT0eBTlmBjxBkPRLDkZZsGPobtrd8NJODYiFqPUa1AZcXm0b5JHHRd9ZcM1oVfBks0+v0pDOpJkuoxjuaSc6sX86bDRoNMskdZLdOiliysnH4GvmhJg4WNpFW6LzgDw3HCEbNHGc2Ik5Ah6roQ/M9+sHGvj8RSGgRYO44UEYa05yehCUJE2+XwYKy/gdQQYwBQetjAwTB/fFihz9QqS21KA9f4epj44RGmd5F8/8TUeDJ1l1LxUx28Lkw/FzvHeyGn+SccPcNEY0iUQ5plqmheKm/nSnnvp2GnQc+wKD+xaZMso478v2NF7mn/d80MGjAI9+qVfX00pzp/tJnnIxFzMt2UV4Gqhx6NU3raJ/IjJR29/mvfGDvL/Zh5m39IAzvLa9dro76EynGD5oQZfePjP6TeqQBhfSXSh8cHoKe58dJxJt4NXyqN4srmLfa7cwb7DI2i1lV1tAbuNjaCBEpeqL1+FrkBX3DU2zpPdB3jP2w4Sf6jGV5fv5rlzY1h7hhj+W4Eqltozf3oFvbeHyrY+cmMa94fPsNEsoBHiZKOXxIsOiXMuqo1ala6qAAvTQpgGwjKbm2phBxV2Lh0ddBWq/TGKIz7hwRIPhs6y1frx/r4JLURCg76Vf59xSxx1Jd/J3cGu+WGcKZPYlIux3J4lydeFEPhhk3v6z/FE+ghjZpakpmGuZD+UZJ1530IvGNhZeakL1q2CrlNL6tQTMGgtEdNcjuR6mZlOM1Zdu2EYLxGi0m2QTC9zjw0Sm7ryqCuPspIUpUZF2hSlw0I9RtU3Kbk20/kEzpyBcR1bAFIHZcDxRA+D4XXcFp5hwMlye2SG4qDDztnNuP0pjEUDsvn29YJFM8YtTdX0fmm2MKhLE6uoMAvulUMpug5as+rSVQZ15eG5OnpdIVx/1eriVlWAtdFh6gMJikMWlT5B9fYqv7J95xv+XEyvcbszSYdefpXnezXqyuXJ3f8E7eUEyTM+nWeLdOenUbkCqlp9My6ldWg6WiRMPWoyHMqywVwgres4wkBDo6IafLE4xs7CKMkTkDqcR2XXbu7vT4KwLCq9GvUuyVSjg2/5EbLfGGDjvgrGycm1OQELwdL2MPlHavzs0AkkkkW/OdHuqW3kmeUt7
J0YIrQzgp1TxCYaaA0fve7T6/po+elmutV1fB5C4H0/waHknfxo4F6qPQJ1d4H/fcfX0N6ieDE8RuJQL/1/mUFeqRKxDVClMuGJAs6GNOfdTsLaLJ3668un5jiIcAjpSNJ2hTkvwV7hYZ236TxQQptbWrXWnKsqwF5nlMJ6i9KQoDbg8ujG0/zbrqPX8Q7WVb+T8cvUlKKiBDlpURuPMXTYJXJiEe/s+TXdy/RyhGmgJRM0Yjppo0xMa2CiXxTfjO+zszDK/vkBossSbbm4ZivffiKEAMvEjYGMedSVQcWNE53xsU7P4pfWbi60NAW206DqmxxowKTbz+l6D/sKQ+ybGoRzYToP1TGXa2hnJlGNBrJWQ/GTN6XXFiKEwmHspV7Kg2GmhsI0lM5waJmJdRnmF3sRkTDC99tTgBsNtGIVo6ooyBA11awOvSorlaXYFtg+aatMQxks+VGMskBfKq5qNtHqCbAQnH9fiH/85PfoMoqk9RJjZgbehOOC8rLKr53/GQ7N9ONPhbGWNda/VMc+NLFq+XqtQh/oY+of9FPc6HN/+DSDBphCpyTr/I/SRnYVRtj3+e2kjzcIHZ3CX8ygGrdGCEIYzcnJHewg/cAcD3afwxQ+Z8pdWHkPfym7djNBlKJnZ4FCJsYLHffxVPo+wvOK2ISLXpesr7jopSWYz6AaLn658qYUGMlqDdFw0U7UiU049Fmj/J76edZtXODjY1/lX/ofonTPMKHZMuw73nahCFmtwew8dr6PjBsjZ4aBq+8DCV2HoT4qgzF2jE7ysc7nOeV2Mul2YBUUcmYOuYpNrFZRgDXcHpffTZ+5+F++cqgrF1f5uJfdLKbQiGo/3mbycvKyiqskRamY8cMcnBqAc2FSpyCy4OKcXcS7CbIdXoUQyFiI0jpJeKBEl14lvNKOs6YkJyq9HF7uJXXSxdl7Dv962ljeDOg6Ih6jnrLY0THDfdGz7CqNMl+Nodf8NT8W+lyWhOsTiTu4cYPw+QL+kRMXv78q0rdSIOS7DSgUiE71Ez0fZrk/xL1Wg9HEEmeTnVhFC10T7VBU+mqkj6z56A1FXRm46PhKoQnZ3Hx87faT0JBhCy+iMxxZZqNpM+m51KWJ1mDVi1JuaBbE/obHM3y13acAAA6SSURBVOWtHCoOcGihH6mao3F37xT/z+DThLUrhxymvBK/dvbnODXfhfNSlMicZN10HaOYR8uXoVZH5tdut68rIWwbPZ0ivzHOow8e5qHkSbq05njVlcu4F+Krh+/EPucwMrnU7CG8Vr29nxC9r4eJD/VTWu/x28mDRLQ6X9pzL+FzJuvm59Zm7PcyZGYJUShiWiaWaSJbEE4xFoukTlpMboze8M9+MzGFjxsWeGEDXV/tcy6undUVYFdjwb900xyqj/Di0kZOLnbROBe7eHz4biWoD3iEXxPz9ZWkpOpMeSGOjvdhn3MY+PYC/onTQDPO1W4T8JuFsCxUMkYtpfHe9EHuc2YIazYSiaskORlGn7WJTClEobzmvb2fBBWyKa336BzOscFcIidtnGmTxFmJKK399ENZq0GLqxlFw8XKe4jGG2+GtzOW8PAdgR/SEfplx5hpAmloSENczJu+kayeAEufjV9wec++37n0YRWwSpKeqsTMX3pApt+RpHyPJPWat9jf8Pi5F34DY8pm+EUPZ74As+15jtebjejvYebxLvJbPcasBdJaM+vBVT5zPhytDdC1X5Hcl0G2cV7mjWTJj5I6Lkm9MofM5lptzk1BbbSL6UdsOtZnqCuPimdh1BRafW2tL7ZYs3iP5pleF2PzkRSi3MyO0mJRFu8IUxyBsdD8DbdrddPQnt9H5/Nv9CKd6NhbcC+bfHwlqSuP440+Eq84JE81CL1yBj+bXfPLymtFxkMURyTxviJpzcNeqX6SSPLSZrqeJDpRvbgauCURAkSzfNRVGjk/TGS2jnf2fKstu2mopw0aIzU2pDKUlaTqmZdyY9uhqdU10qNXeWz4J
M+KMWTUQZhGMwPCcaj0CvzB6sV2BxINV92YMEVrj6UfG2XpgR4W75VEtEvR8W9W4vzeF3+Z6CT0vpxFWy6u6XSinwQvYmIOlNncuYB1WeHKsvT4q6VH+MH0KAPlxk0bgrkWZMikazjLptQi//fsExzN9NCTr9/SY/JmoadSzZOFN+n84h0/4my5kyf3/yPKB9JsODYH+RJ+2+3AXZ20rvOL6ZfRhGL3pnuIWqNIS6eStkk+MM8vDu/mbnsGV1l8I3sXP5weIbG8+tfX0jPdve44SzsUqZEs9mXHy79SGmXDZxfo/G8vIQ8cw5ucuuVinL6t0ZcqMBJZwrxs67YodfYtDVCciyEa7ZeHeSORlsGm1CLrw0vsnxsgN55EVG+t+2S1ENEIXnecap/Pr6ReBqCyP036mEKen2z2V2nzpla+0vBVU1eiwuY+W/BQ7CSVLo1qb5jygENxUOfnhvbx68lzDBthJJKj2V5KU3Gswuo/Xy31gMsDDo+8/TD3J84SFhYZv8wz1X5+tDhK5LWnANwiGEODlHb0s3inwUf7DrPdmcIWBllZ45nKIE/nbqP4vV4GJnxYvLVjv8rQ2B6bZqM9zze121ttzk2BHo8jYlHmnxhi+aE6w33zfL14By+fGWF4p4czU2k25GlzrLzH/ze+jfneOA/1fx97pWx/wMiSvcujsEFH9dZJxCvcGz6LRDHhVZn04kyc6KFzj4Y9W1z1kGdLBbjSqfGJge/SqUcAnWUJz+W3MpVJstnLtNK0luF3J1jcYVDfUuXxyFH6DQ9TOOR8l2fzW3hxcoTh5/NopyZvubDMa5GmxlZnmvXmMobe/qKwFhCxKLIjzvKdkv/+9r9kd2WUl5ZHMSdswq+cRlUqqDb3fAGMikt+KsEh06PWJ4mv/H+XXmXbpinKrsXPD+xmzJ5ju1lBYjLtRzle7yN6Xqdjfw4WllbfzlX/hCugd6Shp5NaJ+iXLa8X/RA754ZRc86rjmK5FTD6enFHepi/J0LPo9Pc2zFBj+4SXpm5Z7wYT53YgnHeQSsuoJRqbiTIq28WXPBUhNZs7N3urQSvF73u81R+Gzsik2xOL3LQNfHSEfRYDFmu3FTXulpcaJiltm2g2hticYdBdUOdnp4l/ibzAM+c2kR0T4jeU17zIM81sjLVM0VSByIs1js5tCnFbVaWHj1EWtP4cO8uasriHuc8aa2BLSyKssEfnfsgp8/3MHzSQ5tfRlZWv49MSwRYJOIUNyWpd/po4nIBjrM8nSQyr6HcW6uoQPakyewIk7+7zjc2/zVpzbiY+QAw4aZxjoaIj0tEuYqSEmG8zq9PqebRTlI18x51Heqy/SqX/h5odY9di8N4SueuxAQAM6mNROIxVKOBWmOpUq1AWCYiHGZpW5TsFnjssX382cCL/MnSbXzh1D3Edobo+8yhZp+JtdRjJLNM164wiDhH6wPEtRppzSeuOfxS7EIqq8EFCZz3G5zbP0DvHogemcObnbshZrZEgP1EhOKggUhX0REcaVT5zPKDPD21ic6dOtFZF1W9CY7TuQ78iEW1SxCO13CEwBSv9my32TPEH55nKRdl8e5htMbrt/QECM8K7Jyi0iNopBTpw4rUnkXIFm6KJvXC9VnKR5gJJ3hn4ih6QvGpHVtxI0Mkfyhv2EPUtgiBFo02D/MMh8DQUdEw0jGoDEVoRDXKfRpuTOGO1OjuLDBeSvML5x5n1+ENpA7opI/XmpPZGvF8L6BqdfTFHMlTDv/pmSf4i6ECn7zjC4wZJTr1ENrKyruk6nw6v5Xd+fUkTgoSp0qoG1hV2xIBdlMOxRFJb2ceDY1dtUG+9PJ9xE4b9H71CH4uf8ulErkxk1qfx6ZknrBodju7nDssnR/c8bcAr9/d6cL7KZ//dfYRXpge5aMbd/LRxGHu/+HHMCsdRM6ZcDMIcMPDXY4xHU0wtm6BLdY8n7n/fua6E8SPJeEWF2BhmGjpJMq28NIRfMeg3GfRiAmW7/VI9
Wb56Ogu3h45gVQaDXR+48AvMrF7HSM765hPvQKwar1wVxNZqyGnprELRTZlh1i+I8kPRrdgRY+Q0PyLm3JF6fMXp95K5UyCjbsLqD1HbmitQVudiCEUsIaSu99M7PkKiaMJTkW6yY15JLXmyR9XQkNDIln261QU7KkPsOxFcZWBJiSPR46zzrB4a/w0tuZyuzPZbHgUqVHpDmMvOa3NP3yTEHUXZ84gY8c5tambIXOJhwfOciTSy8zjA8Q33U/s2DJkcshCAbWWltCsxGctE9Hfg9sTxyjU0JYKzbMNa7VmaGnl9HC3L4layaX3HZ3SgIXvQC0tkCa4cYk0QcU8dNunvyNPwq7xcm6E3fl1HMv0UCiECB0O0XHSx5lZ/QyAG4Gs19EzeRJnTD7zvUf5VPwh9IiHWCk79l2N6EGHznmJninc8FPE20qAb2W0M5P0Fyt44T7mH7KwhHtVAYamhzvuhTjvdvJn5x9lIR9F+hqaLknuqDASXeDD0QU+FG16gb7S6I8XOD2UJrxgEbrqO68dVLlC4oxEcy323L2eWKzKH/U+i9ur+Bfhn+bgfD+Nr3aSOmqhnfev7/TpNkCLRhDRCLk7u1jarhGdDJE6GUIvNdCXSyjLREZtqv0RFu4yUEZTVBopyeP372c0lGGrM02HXmLMrBJeCWu5SvL18nqOVfv50pG70KYcendK+o5kYGkcfzmLvwYyHa4FVa/jTU6hTU2z8eWVsJ54tftxYbPaa8GmbSDAbYJqNBDFMpFZxR9Pvo/tiRk+EN9LQnPpN+yLIYm8rPFUZZDxRiffnt3GYjGCdzzePGyQZlLEvw+9k739p+kwy8T0Gqbw0ITiyLl+ek4qQrNr/JSQC1RrxCbq6K7NF3Y+wN/1buMPt32TUTPDXYkJQrrLD7Ztx43E6XF9tHKl2dt1rWRHaAJMg0ZUo9EhycYEhQ0OWsNBryZQOkhL4cYVoeE8+sopyN2hGqOhDFG9xuHqEBLBs4CrdKaqKfKuw77zQ6isRfS8TiijiExVIN88hr7dCyx+IpRqywbygQC3CRc6X6UPdHDmq2Ps3zCK+bDPVmeGLn3hYqXgjK/zieNPkJuN0/OCRu+ci3PkLH42h2bbYNtkxzfwnXUPUO+U+DEf4fgYpk/XcxbJv9m1pmr4Xw+/UEB78SBxyyT5Sjf1kU7+y795lF8YeIUPxg4Qjh/gT95VZ9fiMPlcF6mlNGRzVz5ZuR0xDJRtUusQpIazPDF4nH/R+dKrwkeSZjmreZlXV5E+J9w4ZxrdfH5mG8vFCPXZMEZZI3Ye7Lxi8+551NRsM6dXKpTv46+ViekmoiUCbC1XiZ+2mO1KIpGM1zuJnjWITbbnMSc3Ei1fJn4+jtbQ+VT47ViRBp9M5zFWvJulcpjq/jSJJYhNVTEzFWSxhKrXkb6P8Dyi03WEb1NfEngRE2maSANiU/Wbb3ylj2yAyOWx5h1O7h3i4/OddKWKmJpkcrIDPWuybqGBqtXWRBXXRep1RKlKbFKSOdzBFzL3sWt4HQAaCnnF442h7hlkKyFqVQsxEcIoCxJLYFQVkTkPq+BCNr/qzcYD3piWCLA8eIKekza+dSe1d/jsyq5j+G/G8RczayvXcBXwxieJzM4T1XX6/tpsxqsu61/aq8qoxgz4PqrRwJfq4pJaeR7K89BfPERc1xFCgKZdPIVaNdw1uaP9hkgfP5eHQomxP5xt5kevbEht9ZdRvkTV6vieu6aW134uD/kC8a8tkviW1dx0uyz3+2olOKZSRMk3PVvPax7OqRRIifIlKIm/liaim5jWhCCkj6xUiE37/Nr5n+HI0SG2Fo6vuV3qVUEpVL3eFMqfcKWsPA887+YU29dD+jfdmYAX74fg2bgpaWkMOPatQ1RfTrG1fg7/ZntwAgICAt6AlgqwrFTWzoZIQEBAwJvMzZCPHxAQELAmCQQ4ICAgoEWI6+ntKYRYBMZXz5y2YJ1SqutaX3yLjAlcx7gEY3JlbpFxC
cbkylxxXK5LgAMCAgIC3jyCEERAQEBAiwgEOCAgIKBFBAIcEBAQ0CICAQ4ICAhoEYEABwQEBLSIQIADAgICWkQgwAEBAQEtIhDggICAgBYRCHBAQEBAi/j/AdWTsiy7hBF+AAAAAElFTkSuQmCC\n" }, "metadata": {} } ], "source": [ "plot_example(X_test[error_mask], y_pred[error_mask])" ] }, { "cell_type": "markdown", "metadata": { "id": "I2GsBaxH_v7r" }, "source": [ "# Convolutional Network\n", "PyTorch expects a 4-dimensional tensor as input for its 2D convolution layer. The dimensions represent:\n", "* Batch size\n", "* Number of channels\n", "* Height\n", "* Width\n", "\n", "The batch dimension holds the number of examples. MNIST data has only one channel. As stated above, each MNIST vector represents a 28x28 pixel image. Hence, the resulting shape for the PyTorch tensor needs to be (N, 1, 28, 28), where N is the number of examples." ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "id": "Bwsaz88X_v7r" }, "outputs": [], "source": [ "XCnn = X.reshape(-1, 1, 28, 28)" ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "id": "Go0Yz8xl_v7s", "outputId": "e66d3f48-5f64-4a03-995e-b8b021efdf4f", "colab": { "base_uri": "https://localhost:8080/" } }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(70000, 1, 28, 28)" ] }, "metadata": {}, "execution_count": 27 } ], "source": [ "XCnn.shape" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "id": "Rt8ETXGa_v7t" }, "outputs": [], "source": [ "XCnn_train, XCnn_test, y_train, y_test = train_test_split(XCnn, y, test_size=0.25, random_state=42)" ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "id": "-jQm85il_v7t", "outputId": "bdfcf1d4-fd4b-4c5b-a887-edd7af625602", "colab": { "base_uri": "https://localhost:8080/" } }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "((52500, 1, 28, 28), (52500,))" ] }, "metadata": {}, "execution_count": 29 } ], "source": [ "XCnn_train.shape, y_train.shape" ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "id": "YdQ-ISvb_v7u" }, 
"outputs": [], "source": [ "class Cnn(nn.Module):\n", " def __init__(self, dropout=0.5):\n", " super(Cnn, self).__init__()\n", " self.conv1 = nn.Conv2d(1, 32, kernel_size=3)\n", " self.conv2 = nn.Conv2d(32, 64, kernel_size=3)\n", " self.conv2_drop = nn.Dropout2d(p=dropout)\n", " self.fc1 = nn.Linear(1600, 100) # 1600 = number channels * width * height\n", " self.fc2 = nn.Linear(100, 10)\n", " self.fc1_drop = nn.Dropout(p=dropout)\n", "\n", " def forward(self, x):\n", " x = torch.relu(F.max_pool2d(self.conv1(x), 2))\n", " x = torch.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))\n", " \n", " # flatten over channel, height and width = 1600\n", " x = x.view(-1, x.size(1) * x.size(2) * x.size(3))\n", " \n", " x = torch.relu(self.fc1_drop(self.fc1(x)))\n", " x = torch.softmax(self.fc2(x), dim=-1)\n", " return x" ] }, { "cell_type": "code", "execution_count": 31, "metadata": { "id": "WJrWyFfb_v7u" }, "outputs": [], "source": [ "torch.manual_seed(0)\n", "\n", "cnn = NeuralNetClassifier(\n", " Cnn,\n", " max_epochs=10,\n", " lr=0.002,\n", " optimizer=torch.optim.Adam,\n", " device=device,\n", ")" ] }, { "cell_type": "code", "execution_count": 32, "metadata": { "id": "KRtHfgg8_v7u", "outputId": "071223dd-1d8a-4ad5-a748-7bac10d59ab3", "colab": { "base_uri": "https://localhost:8080/" } }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ " epoch train_loss valid_acc valid_loss dur\n", "------- ------------ ----------- ------------ ------\n", " 1 \u001b[36m0.4319\u001b[0m \u001b[32m0.9721\u001b[0m \u001b[35m0.0891\u001b[0m 5.8088\n", " 2 \u001b[36m0.1628\u001b[0m \u001b[32m0.9794\u001b[0m \u001b[35m0.0641\u001b[0m 2.1617\n", " 3 \u001b[36m0.1349\u001b[0m \u001b[32m0.9815\u001b[0m \u001b[35m0.0568\u001b[0m 1.8369\n", " 4 \u001b[36m0.1153\u001b[0m \u001b[32m0.9844\u001b[0m \u001b[35m0.0507\u001b[0m 1.4844\n", " 5 \u001b[36m0.1006\u001b[0m \u001b[32m0.9863\u001b[0m \u001b[35m0.0441\u001b[0m 1.4542\n", " 6 \u001b[36m0.0962\u001b[0m \u001b[32m0.9881\u001b[0m 
\u001b[35m0.0397\u001b[0m 1.4394\n", " 7 \u001b[36m0.0861\u001b[0m 0.9872 0.0423 1.4464\n", " 8 \u001b[36m0.0853\u001b[0m 0.9863 0.0410 1.4599\n", " 9 \u001b[36m0.0805\u001b[0m 0.9880 \u001b[35m0.0384\u001b[0m 1.4535\n", " 10 \u001b[36m0.0753\u001b[0m \u001b[32m0.9888\u001b[0m 0.0392 1.4857\n" ] } ], "source": [ "cnn.fit(XCnn_train, y_train);" ] }, { "cell_type": "code", "execution_count": 33, "metadata": { "id": "p6V7wyzb_v7v" }, "outputs": [], "source": [ "y_pred_cnn = cnn.predict(XCnn_test)" ] }, { "cell_type": "code", "execution_count": 34, "metadata": { "id": "ZSyiZ3p6_v7v", "outputId": "124fdbe6-8747-4218-d1c1-a60908632eff", "colab": { "base_uri": "https://localhost:8080/" } }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "0.9883428571428572" ] }, "metadata": {}, "execution_count": 34 } ], "source": [ "accuracy_score(y_test, y_pred_cnn)" ] }, { "cell_type": "markdown", "metadata": { "id": "npZlem33_v7v" }, "source": [ "An accuracy of >98% should suffice for this example!\n", "\n", "Let's see how the CNN fares on the test examples that the previous network misclassified:" ] }, { "cell_type": "code", "execution_count": 35, "metadata": { "id": "2oWL0xGC_v7w", "outputId": "449de735-c1a9-4301-d758-e07ac678d18f", "colab": { "base_uri": "https://localhost:8080/" } }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "0.7705426356589147" ] }, "metadata": {}, "execution_count": 35 } ], "source": [ "accuracy_score(y_test[error_mask], y_pred_cnn[error_mask])" ] }, { "cell_type": "markdown", "metadata": { "id": "8239U9fF_v7w" }, "source": [ "About 77% of the previously misclassified images are now correctly identified." ] }, { "cell_type": "code", "execution_count": 36, "metadata": { "id": "5vU1SXeV_v7x", "outputId": "332726a3-6c15-49c2-860d-b48eee5c85e2", "colab": { "base_uri": "https://localhost:8080/", "height": 108 } }, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "
" ], "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWAAAABbCAYAAABNq1+WAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO29d5Rd13Wn+Z2bXw6VI1AFFAJBEIwiRYlBgaaCTVmyLMvytD3qnl625F5ue3Wa7tVa3e1Wj0cz3TPtHrXcYWTZkiVZbiVLHkUGURRFgkTOQCFVTq/q5XjvPWf+eIVACiABmYX3CrjfWrWAqnr1at9T9/7OPvvsvY9QShEQEBAQcOPRWm1AQEBAwK1KIMABAQEBLSIQ4ICAgIAWEQhwQEBAQIsIBDggICCgRQQCHBAQENAiAgEOCAgIaBFtKcBCiNKrPnwhxP/TartaiRDCFkJ8VggxIYQoCiEOCCHe3Wq7Wo0QYqMQ4jtCiKwQYl4I8WkhhNFqu1qJEOIvhBBzQoiCEOKUEOJ/abVN7YAQIi2E+IYQorz6HH2k1Ta1pQArpaIXPoBeoAr8jxab1WoMYAp4BEgA/xL4KyHExhba1A58BlgE+oA7aY7Px1tqUev5I2CjUioOPAF8UghxT4ttagf+M9AAeoDfAP5ECLGjlQa1pQC/il+h+YA912pDWolSqqyU+tdKqfNKKamU+hvgHHCrP1gjwF8ppWpKqXnge0BLH6pWo5Q6qpSqX/h09WNTC01qOUKICE0t+YRSqqSU+gnwLeDvtNKu9SDAvwV8XgU1069ACNEDbAGOttqWFvMfgQ8LIcJCiAHg3TRF+JZGCPEZIUQFOAHMAd9psUmtZgvgKaVOXfa1g7R4sm5rARZCbKC5pPzzVtvSTgghTOCLwJ8rpU602p4W82OaD1EBmAb2AN9sqUVtgFLq40AMeAj4OlB/7Z+46YnSvEcuJ09zjFpGWwswzeXBT5RS51ptSLsghNCAL9CMZf2DFpvTUlbH4ns0BSYCdAIp4FOttKtdUEr5q0vtQeBjrbanxZSA+Ku+FgeKLbDlIu0uwL9J4P1eRAghgM/S3ET4FaWU22KTWk0aGAY+rZSqK6WWgc8B72mtWW2HwS0eAwZOAYYQYuyyr+2ixSG8thVgIcSDwABB9sPl/AmwHfglpVS11ca0GqVUhuZG5MeEEIYQIklzz+BQay1rHUKIbiHEh4UQUSGELoR4HPh14KlW29ZKlFJlmiulPxRCRIQQbwHeR3M12TLaVoBpPkhfV0q1dInQLqzGw3+bZqrV/GU50r/RYtNazQeAdwFLwGnABf6gpRa1FkUz3DANZIF/D/y+UupbLbWqPfg4EKKZVfVl4GNKqZZ6wCJILggICAhoDe3sAQcEBATc1AQCHBAQENAiAgEOCAgIaBGBAAcEBAS0iECAAwICAlrEdbXts4StHCJrZUtbUKNMQ9XFtb7+VhgTgCLZjFKq61peG4zJlbkVxiV4fq7M1e6V6xJghwj3i3e8cVa1IbvV9eWr3wpjAvCk+urEtb42GJMrcyuMS/D8XJmr3StBCCIgICCgRQQCHBAQENAibumjWwICAm4xhECzbTBNhG0BoCpVlOuhPBducGVwIMABAQG3DHpHmsx7t1DuF1gPrGCbHt7XtpA6VcM8PomfWb6h9gQhiICAgFsDIRCRMIVRQXV7jU/v/BKfu+3zFDdCtcdCOM4NNynwgAMCAm56hGmh93RRH+kifNcy7x0YZ8ioUJR6S+0KBLidEAJEc1EitNX/a9eYUikVKIla/Re44fGsG44QCH31AVodN+X7zeu/2a/9WhGr94+4tNgVl99Tr77HLtxHvn9TjaGwTGQyRq3T5IG+CR5PHCYmNHIIhBQISUuuNxDgFiNMC+HYiKE+ypuSNGIa1Q6Negd4Wyrohv+K1yt1ZUGWvoZxMoyzDKkTdez5MmJuEX955UZcx
g1Fcxy03m687gSzb43RSCoaG+soV6PvBwaRmRrm0Qn8bLbVprYEYdvo/b0ox0JGHXxbp9Zl49uCSpeG74AXBmkpGr0uVrRx8WfdpRDWikbfTz2s773cwqt4YxC2jd7ZQWO0hzMftogOFPjl1F42mjmerfXwQmkzyXFJ7EgGmcvfcPtuLgG+MNuvo5lbWCZaJEytN8bKdoN6WiGHqtw2OMcXNn2DhBa6+Fr/gmd7BQqyxm/0fZATU70I3yahC8L5EtyEAiwsC78zTmljBPfNRW7vneNTw99kxo/y92c/hjQcOs6F4FYT4NX7X7NtvK44fsSk1mHihjRKgwIvqvA2VIlEa2yIF0jbFX6n7xneYl+6rz5bGOQHmds4vbSFnpvgbGlhGMhUnNKQzRMP7OHt8WPc7xRwFRyrDnAwO0B0poE8N9XMgrjB3BwCrOmU338vK9t0nGWFnVckj+WRB4+32rKrom/dTGU0RWGDQXEUZF+NXRvG6bArjIQybHbmcYRBRTZY8BtUlM68H8VXr9w31YUkIhqAzS/3HmAuneR7XduZXI7T8eQQnS9HYCFzw3d31wI9HkcN91MeiTP1LoGI1+kI18hUo/zB+Q+yXA1jFQXNQyFuXjTHQSTiCNNsernJCKWNEeoxjdIGgRtRiIEqll0nHs5jCkVcCZQS+FJDKjgz38UpT+N0rpPuSIkH02e5PTRFr5Hjf+p9kX+e2tLqy3xD0DrTLNyfojAGD8bG2WEt4gibWc/lc0fejJgIMag30LaMUNkYxw1rJF+ewzs/eUPsuykEWOg6i/doPPC2I+ye3Eh+NoxZjRE62GrLrk59IEHmDpPqzip/746fcnf4PO8IVV71Ko2iqjHlR8n5YU7W+6hL81WvUHSaReJalfdETtETD/H+xD6mvCT/cO6jhBfiRKp1uAkEWETCVDbEWL7N4GOPfB9HeHxr/g4ypQgTZ7vRSxqxMoibW38RoRAk48iQhRe1KA84ZO4UuD0uH9i1j9HQEu+KHCehCWKaRU15fL/Sy3Sjgx8tb2G+FMPP2Jh5jdyszYrRQfb2EKUem3clDvFm2+cfx6++2lpPyFiE/BbQR0rcY8+w0QgDsCJ1zONhEqclShfUBmIs3WHSSCqiU0kIBPg10HSM4QFkMsrM25OUNvrce/cp3tdxgFPZbpYIt9rC12Vlm03XO2e4t2OSt0ZP0quXgUtpMMddly9mH+BkoYeDZ4YQZZ3QnI54ZUgYpUN1wEckGvwfb/oaT0Sy9Oo+jlhmy70THE/30/XjPjo1Adn8uowJ650dyJF+MtuirLy7SiqR41S5l8lSiskXBrFygr5piVmW2MtV9IqLLNwERwmubsoa/b34XUmKm6PkNut4UYWblChTIhwf067QlSgB8NPFEZ5xx/gP2cdQUiB0hfI0zHkToyKwc6DXFf1Zieb5LG83qPUoukJltoTm+U8z7+TjSz2kj1xzP522RkYtjNESd/VPExbNUN0Pq338KL8dZ0nhZH00T+E7GrUdVXYMzTF9foRudwfaxPyarxzXpQALTdAYTFMedLjnQ4f53PBzQDNG+oVIkSU6W2zh61Mclfxw25dwhI4tTC4XX4ATjR6+dvxOxHSIDc962EsVxMFTqHr9Fa/THIf6W3dQHHZ4accoT0T2ktIcUhr87yNfZ2ooye8V/2eiM+nmb1iHAkw6ycqOKJl7Jc+/9dNM+TafOPt+xqd62PATD2e2CKfOI2s1AG4O3w0QGsI08PvS5LZEWHxQ8XcfepphK8OdzjQR4ZHWNJak4nC9j58WN/P1g3djLphseNbDKDVjmlrDRZ+YRJXKyGoNpI+wbbSQQ3FgB9VBRU+owDZrjn9z6gn6fmiQOJ67KcbRi5g8OHSSR5InCGs6K77PN5fu5tBCP8klibVSw4+Y+LbGO7ec4BO9P+Th2/8RVjFGurT2K8eWCrAej0NvF253jOy2EJoHoYyPlWtg7DkBpol7zxheqJlqJG1BdotBI6Fod
HmYsSq/lzrSyku4bvSONCIew49KdATaai3Mgl/lrBfl27m7+PrROxGLNskTAjsnCZ8vIEoVfN9/nXd/JV2ah25kkXGPWoeJvWivxSWtGVo4jJaIk9/ZyfI7a/R35vl6aTvPZceYemaY9LwiNLmMyBbwPe/qbyQExsZh/GS0+alSiLkM/sLiDbqS60MLhxG2jbtjA6Vhh+w2gdxSZltPhqhe41Stj2dy21moxpjOJyiXHLRZB6Mk6JpU2AVJaDKPqK1mN/g+slRGeR56Io6IhFl5eIhSv0b1ripjvUu8MLuRp05vJbnXIn66gJbJr2sB1uNxGOpjeaPFztg0W6wFTHSWpMELJzfhTFhEpooYSwWEjOHbOku1KLO+DVGXcq9DMrr2z0tLBVikEpS2drCy3WDkvWcp1B0mT/YQmQozdCKCcBxm3+pQTzdvBRnx+b/f/nl+KVy4+B66WGfFfN0dVIcSGPEGmri0zJv2Qnwvfwdffek+bvujOVShhCwWUVLhy+sT3gt06iHSuiSWLlPpThGddlhPC0sRi+IPdrF8u87X3vJfOOV28ycTjzJxopftX5xBLmbwK5XXzXoRuk51rIvi4KX4eedeBW0qwCISgWSMhfvD1O4t86Ht+/hE1z5Oux57a0P8ZGUze4+OEpox6N7n0r9YgwN7X9HL4Ep3jDAMRCqB25tg5X0VfnXrft4UOUOHXuK3vva7DP/QJTQ+g3dugteYztYFIpUguzNFYRM8FD7FVlNiC4spN0nqZZPUyTr6qUn8UhlNH8JwDBYqMabcDiLxGpUBCy9us9ZlGi0V4MbGTibfDbH+LI91Hifvh3jRqjM5lOScsxWlQ+TuDEPh5uZUzKoxambQxc+WDJ5yy0x5cU4udhOa17DyjZ95TVvgeugVD+m/UgpPNXp4cnYrzryBKldQ1eo1JcMLyyK7xaY4Khmxl17xvaMNj/NeB8XZGIPTPka2csUHs11RPWkyd0apDTVIaC5Hq4PM/2SAzgmFKpZQjdduniJsm+ov7KI4oJPfqqCrhqzr4Go4uTjR8TCq0UC9lvfcChJRvK4YlT7Jm4Ynmaym+PjU29i/OEBuMomd0emYVoRWfELTJbR8Ge8KjWSEbaMNDyDjIQpjMRpRQbVL4IcUXqPGt87fzpdL9yIrBj0HFc5MEVUsteii31hkKsbyTgGjZZJaA1Nc2ZtVvg/5EqZtsZSNcag6RG+8yNwGQSMRIXTFn3rjaKkAL9/m8Nwv/p/06CG0Vd9Mdhxt/ntf0+vVLmtXoSGuKL6+kjxX2cSzuS14Z6J0jvtYi+W2FBtRrWMU68hGBKkUF1zSPaURcoc6SZ1RyHwR5V7jBBJyyN3l8pYd49wdOs+F9h4SyQ/Lt/FsZguJEwbx/dOoFiSa/22obIiTfbjGfSOTdOkGL2ZGGPkvp/GXlq9pVaBFI8x82OUf3PkD3h45wagBRxomZ9xu/t30rxHfnUQWiqhiG23YCYHfGaM0FCI2luNfD/4Nv3XsNyk+20PnIZfuH+5/RcWjVOqqoQItHKawq4vioE7nL07zUMckO0LTSDT+zY9+GXUoxNiTeTh8+G+10mpHqv0Rdjx0mjenz9KjGxhX82WVwl9aQncb+PPbebF7hLvTU3T2FPlKxy/c3AKshMARAlNcGpwr/e/1kCi+MPUA0wf7SB+H8HQFkW/PmVxVqmh5AyOT5K/LG7nNnuGuZlc8hGzGJ3mNgouLaDpGXw9+fwfRzjI7YzMktQbgcNqtM+vH+NKZ+yidTNE34aFK5abHuN4QoK3mlZm6j4xH0as1/GLxit6vsG20cJjGnSPk+y22DUyw05miS/Mwhc13i3fw1NxWwgsKVauB235jolVdrKLP/ESCP+x6L4vHu+gb9wnNlpFXm5iFaMbMe7qQ8TClkRi1pMbKToVMNujTfM6VO3hmdoxC2SFxzCA25aOvFPDabQXwBqA0QcqqktCrQFMjSrLOsh9Fr4Ne95tl1xeQChTIVY9Iv
0G55OsyC+LVuMon81Q/Y3+8D+U369i9axGxFuBns5DLkTg9wB+ffDtPbDzMXZ2Hrvt9tJBDdXsfxWGLRwb38f74AXp0DVf5fL98Gy/mRhHfTbHlu9OobB6/UHj9N21TpBL4KOJmjYWhfmxNQ1QqPxs60HS0ZALVk+bMr+vs2DLBPx3+LvfbLmBTUx5/cfhNJH/s0LWv0J4peUqhLRcI+Yre55McnLyNkb01jOcOIa+2CSsEwjDROtNk39RLcUij/12T3JOc55eT+2gonf8w8TgHpgdIfzvM8NkqxulxZDaLd50bu+sFaQr6nDy9Rg4NjbpyOesZnKl1Y5UUeqneFqGnlgpwOCP5JzPvYsDJMWIvoYumaDrCZdRaRCqNM243pvB4LDRHSm/m9/pKMulVWJEWz5a3cbLSQ2RWXUxDuhyjtwd/4FJamr5Sala5tLJcWSmiMx5zh1N8zb+TYWuZl5eGMUoCo/Y6E4cQCMtCSyVZutOmvMHn9sg0MaEw0ZFIpmppJgoprKJC5QvIV6WurRfslQbWeJhDVj8nBww67RL7H7JxlrtJbkwifIUX0RES9KqPkOBq0EgadPZnuS81gat0jjfqfDl3P4dz/ThHQ8QnXfRsuW03mlS1iqZrRGZDaJ7VDKddYbLRE3Ho68KP2lS7Q1Q7dLI7wEu52IbHSiPMF5YeZLkeYfzIIM6iRnSmhrFUQJXLbSFAa4ku5EVNqSmfM24P58odWAUfrXD9WUVrQUsFOLFvgRN/vIP9XRqFMR+lN0VRRDzetPk8DV9n/+kNaJbPwIOf5S2rUQkPn78u3c6e/Eb2PLmd1AlFx57FK8Z8q3cMMf128+KCInU8Tnp67tpjrGtE6EdH2fRSiNw7xvjk257AXtSJzSvsFa8Z37sKQtfRU0kaI908+KH9/F730wwaEBbNaFVF+RzKDTA/0cHIkoe/zuK+l2McPcfIYgdzv9DLVzbdz67IFP/w7z7DgXo/nxp/HF2T3Nc1iad0Xl4YplS1qS+EUYbik2NP8Xh4kr8pj/Cd8i6e/tMH6PvRMhsWx5G5fFt7fv5KFlay6DNzRHUd372Cp2+ZyM2DzDwSo9on2XzXFDtjGT7UsZslL85/n36IPTPD6C/FCS0otv94FrmYQdXr+FLBTRTvvRZWJHx/ZScHZwbYdC7beidsldaGIIplolM1zIqFEsbF9vBeyGB3dTMIhZUxcKMGZWkDzXidrxRnq12cyXcQWhBEp2qIYhkALRJBhMPQlcLtiFDYYOLbCq0hMKoC3ZXXFmNdY2StjnA9IrN1YuMhrJwisuBhZmuo17BP2Dbexh6KG2w2hxcZNMARzT+jRFKRPtMrSZxZA2N1TNYrqlZH5IpE57r41sk7ONrbx6YNi7hKZzieRSqBVBoF1yFfcnCrJmgKTElEa3r931rcxaHpAfqnPVjIIIullk++r8uqMCjPg8u8VGEYCNtG9PdQHU1T2GhS2upiJ2pEzTo5N8QXlt7CTCXByZMDWCs6yQlJKOOisjlk+dL9oMfjYNvNY3l0rblSKldbcizPjaIuDaSvgXztdqWm8DGFjzSbz5tyvTWbsFoqwP7iEnouT0TXiRqXmaIJMC1ENEzp9h7yIwZLfhxoVqXUlccL8yOsjKcZPVRFf+kY3gUvYWSIynCMubcY9N03R2E5AfMhrJwgcVYSna69pod5w5A+SvroLx1j8LCD8iW4bvPf17g5tFSSiXdEqY40uDt0nrCwLn4vLxss+CYciTH0ZBnz3HzbLrOvBVmrQb1O7Fmf6Olulu8a5p9/+P0MxXLcmZhmsRHjmakxStkwyb0WmqsojILX2ZzAlnzBmW+MMfbUCmJqAj+XW9fioiUT0N3BzGOdbPqVcR5NzPChxB721Yf4r+cfZmq6g/7vGzgZl9tOTqNqNVS1hvJ9/MalSUcYBnLLMPXOENVOHS8k6DhcRj8zhyoWrxjKu1XQUMT0GmmjhBsT6N1dyFweuUaZMq31gJVC1etX3
W/UvRTS7MW3QH9Vso1teEhHUuu0iA3146ejeDGLwrBFtUfgDtYYjq0wvZQitKARWlSEMi56voZsAw/4Aqpex7+eGK2h4yYUTqKOIy7t4EskZ12HE40+rBwYy+XmLv96RylkqYw2t0ysK8K58U7mY0nOdHZQrVs0JqI4eY3onI8SUNisoTk+pvCQCMyCQswtI6+hYKPtWO0FoacSiHiMxnCa/EaH4pjPwx3jpPUSZ700p2s9rJTDUNfwLXBjOsZQJ8KTF99H2jroAqUJpKmRHzFpxAWNpMK3QBoRIn2j2FkXI19Hz5dRhSKqWlufY/cqfCWoeCa+p73utYS1Bh16CS8CfmcCrdGAm1KAXwcRDpEd0ymPePQal2KZtjB4tHecfXaN42IQ655etj50jj8YeIaw1hSmn1bG2FcYxjwVYviz4+A2UA23uZxYxzeTMg387gbbu5dIanWg6QFXpMtnFt7FvrkhOsdd5NnJlvQ3XQtUvY6/vIL1Qpntx+Kg62DooKpQX2pmvtRqiL5uZh7rYGPPCh1ac7ltVlQz22Edxjw120Y4NpUHNrG0y0S7L8d/vOPz9OolBg34b7nb+VfPvh9R0zEqAt1SLLzDxbB9OpIVDK0pwCHT5b70BN1WgX4zS0SrExENHOFiCh8dRUHZVKTNVzJv4sDCANUD/XTv8whPltGOn0F53rretCsqk/O5NGQthHf1e0EIxZg9z/12lvpojcw9cbrU2lVNtqUAC9NC70zjDXRQ7ZNEe0rEtBoXxEYXghF7iVLcZr4/Rj4a5vGuo7wrXKckaxSlR8W3Gc91YeXAX1p67V+4DtAcBzEyRGUkSSJVZDCcw7msNZqPYqqUopwJ01f22j/Oeb1IH1mpNL2xV6Pp6PEoyjEx4g2GollyMsyiH0N31++Gk3BsRCxGqdegMuLyaN8Ujzou+uqGa0KvgiWbfX6VhnQkyXQZx3JJOdWL+dNho0GnWSKtl+jQSxdXTj4CXzUlwMLH0ircFp0F4EfDEbJFG8+JkZAj6LkS/uxCs3KsjcdTGAZaOIwXEoS15iSjC0FF2uTzYay8gNcQYABTeNjCwDB9fFugzLUrSG5LAdb7e5j+wBClDZJ/8fg3eTB0llHzUh2/LUw+GDvHeyKn+e2OH+OiMaRLIMzT1TTPFbfy1b330rHboOf4FR7Y9ci2USY+IdjVe5p/0fMTBowCPfqlP19NKc6f7SZ52MRcyrdlFeBaocejVN6yhfyIyUdvf4r3xA7x/2YeZv/yAM7K+vXa6O+hMpxg5aEGX374v9FvVIEwvpLoQuMD0XHufHSCKbeDl8qjeLK5i32u3MH+IyNotdVdbQF7jM2ggRKXqi9fga5AV9w1NsET3Qd591sOEX+oxjdW7uZH58aw9g4x/FcCVSy1Z/70KnpvD5UdfeTGNO4Pn2GzWUAjxKlGL4nnHRLnXFQbtSpdUwEWpoUwDYRlNjfVwg4q7Fw6OugqVPtjFEd8woMlHgydZbv1s/19E1qIhAZ9q5+fcUsccyXfy93BywvDONMmsWkXY6U9S5KvCyHwwyb39J/j8fRRxswsSU3DXM1+KMk6C76FXjCws/JSF6xbBV2nltSpJ2DQWiamuRzN9TI7k2asun7DMF4iRKXbIJle4R4bJDZ15VFXHmUlKUqNirQpSofFeoyqb1JybWbyCZx5A+M6tgCkDsqAE4keBsMbuC08y4CT5fbILMVBh91zW3H7UxhLBmTz7esFi2aMW5qq6f3SbGFQlyZWUWEW3CuHUnQdtGbVpasM6srDc3X0ukK4/prVxa2pAGujw9QHEhSHLCp9gurtVX5z5+7X/bmYXuN2Z4oOvfwKz/dq1JXLE3t+G+3FBMkzPp1ni3TnZ1C5AqpafSMupXVoOlokTD1qMhzKsslcJK3rOMJAQ6OiGnylOMbuwijJk5A6kkdl12/u78+DsCwqvRr1Lsl0o4Pv+BGy3xpg8/4Kxqmp9TkBC8HyzjD5R2r8ytBJJJIlv
znR7q1t5umVbeybHCK0O4KdU8QmG2gNH73u0+v6aPmZZrrVdfw+hMD7YYLDyTv56cC9VHsE6u4C/9uub6K9SfF8eIzE4V76/zyDvFIlYhugSmXCkwWcTWnOu52EtTk69deWT81xEOEQ0pGk7QrzXoJ9wsM6b9N5sIQ2v7xmrTnXVIC9ziiFjRalIUFtwOXRzaf5V13HruMdrKt+J+OXqSlFRQly0qI2EWPoiEvk5BLe2fPrupfp5QjTQEsmaMR00kaZmNbARL8ovhnfZ3dhlAMLA0RXJNpKcd1Wvv1cCAGWiRsDGfOoK4OKGyc662OdnsMvrd9caGkKbKdB1Tc52IApt5/T9R72F4bYPz0I58J0Hq5jrtTQzkyhGg1krYbi529Kry1GCIXD2Mu9lAfDTA+FaSid4dAKkxsyLCz1IiJhhO+3pwA3GmjFKkZVUZAhaqpZHXpVVitLsS2wfdJWmYYyWPajGGWBvlxc02yitRNgITj/3hB//4kf0GUUSeslxswMvAHHBeVlld85/z4Oz/bjT4exVjQ2vlDHPjy5Zvl6rUIf6GP6l/opbva5P3yaQQNMoVOSdf5HaTMvF0bY/6WdpE80CB2bxl/KoBq3RghCGM3JyR3sIP3APA92n8MUPmfKXVh5D385u34zQZSiZ3eBQibGcx338WT6PsILitiki16XbKy46KVlWMigGi5+ufKGFBjJag3RcNFO1olNOvRZo/wz9Wts2LzIJ8e+wT/1P0jpnmFCc2XYf6LtQhGyWoO5Bex8Hxk3Rs4MA1ffBxK6DkN9VAZj7Bqd4mOdzzLudjLldmAVFHJ2HrmGTazWUIA13B6Xf5I+c/FLvnKoKxdX+biX3Sym0IhqP9tm8nLysoqrJEWpmPXDHJoegHNhUuMQWXRxzi7h3QTZDq9ACGQsRGmDJDxQokuvEl5tx1lTkpOVXo6s9JI65eLsO4d/PW0sbwZ0HRGPUU9Z7OqY5b7oWV4ujbJQjaHX/HU/Fvp8loTrE4k7uHGD8PkC/tGTF7+/JtK3WiDkuw0oFIhO9xM9H2alP8S9VoPRxDJnk51YRQtdE+1QVPpKpI+s+egNRV0ZuOj4SqEJ2dx8fPX2k9CQYQsvojMcWWGzaTPludSlidZgzRInJY0AAA6cSURBVItSbmgWxIGGx9Pl7RwuDnB4sR+pmqNxd+80/3nwKcLalUMO016J3zn7q4wvdOG8ECUyL9kwU8co5tHyZajVkfn12+3rSgjbRk+nyG+O8+iDR3goeYourTledeUy4YX4xpE7sc85jEwtN3sIr1dv7+dE7+th8oP9lDZ6/KPkISJana/uvZfwOZMNC/PrM/Z7GTKzjCgUMS0TyzSRLQinGEtFUqcspjZHb/jvfiMxhY8bFnhhA11f63Murp21FWBXY9G/dNMcro/w/PJmTi110TgXu3h8+B4lqA94hF8V8/WVpKTqTHshjk30YZ9zGPjuIv7J00AzztVuE/AbhbAsVDJGLaXxnvQh7nNmCWs2EomrJDkZRp+ziUwrRKG87r29nwcVsilt9OgczrHJXCYnbZwZk8RZiSit//RDWatBi6sZRcPFynuIxutvhrczlvDwHYEf0hH6ZceYaQJpaEhDXMybvpGsnQBLn81fdnn3/n986ZdVwCpJeqoSM3/pAZl5W5LyPZLUq97iQMPjV5/7XYxpm+HnPZyFAsy15zlebzSiv4fZx7rIb/cYsxZJa82sB1f5zPtwrDZA1wFFcn8G2cZ5mTeSZT9K6oQk9dI8MptrtTk3BbXRLmYesenYmKGuPCqehVFTaPX1tb7YZs3hPZpnZkOMrUdTiHIzO0qLRVm6I0xxBMZCCzfcrrVNQ3t2P53Pvt6LdKJjb8K9bPLxlaSuPE40+ki85JAcbxB66Qx+Nrvul5XXioyHKI5I4n1F0pqHvVr9JJHkpc1MPUl0snpxNXBLIgSIZvmoqzRyfpjIXB3v7PlWW3bTUE8bNEZqbEplKCtJ1TMv5ca2Q
1Ora6RHr/L24VM8I8aQUQdhGs0MCMeh0ivwB6sX2x1INFx1Y8IUrT2WfmyU5Qd6WLpXEtEuRce/XYnzz77yd4hOQe+LWbSV4rpOJ/p58CIm5kCZrZ2LWJcVrqxIj79YfoQfz4wyUG7ctCGYa0GGTLqGs2xJLfHv5x7nWKaHnnz9lh6TNwo9lWqeLLxF5yN3/JSz5U6eOPD3KB9Ms+n4PORL+G23A3d10rrOR9IvognFni33ELVGkZZOJW2TfGCBjwzv4W57FldZfCt7Fz+ZGSGxsvbX19Iz3b3uOMu7FKmRLPZlx8u/VBpl0+cX6fyvLyAPHsebmr7lYpy+rdGXKjASWca8bOu2KHX2Lw9QnI8hGu2Xh3kjkZbBltQSG8PLHJgfIDeRRFRvrftkrRDRCF53nGqfz2+mXgSgciBN+rhCnp9q9ldp86ZWvtLwVVNXosLmPlvwUOwUlS6Nam+Y8oBDcVDnV4f28/HkOYaNMBLJsWwvpek4VmHtn6+WesDlAYdH3nqE+xNnCQuLjF/m6Wo/P10aJfLqUwBuEYyhQUq7+lm60+CjfUfY6UxjC4OsrPF0ZZCncrdR/EEvA5M+LN3asV9laOyMzbDZXuDb2u2tNuemQI/HEbEoC48PsfJQneG+Bf66eAcvnhlheLeHM1tpNuRpc6y8x/83sYOF3jgP9f8Qe7Vsf8DIkr3Lo7BJR/XWScQr3Bs+i0Qx6VWZ8uJMnuyhc6+GPVdc85BnSwW40qnxqYHv06lHAJ0VCT/Kb2c6k2Srl2mlaS3D706wtMugvq3KY5Fj9BsepnDI+S7P5Lfx/NQIw8/m0canbrmwzKuRpsZ2Z4aN5gqG3v6isB4QsSiyI87KnZL//tY/Z09llBdWRjEnbcIvnUZVKqg293wBjIpLfjrBYdOj1ieJr369S6+yY8s0Zdfi1wb2MGbPs9OsIDGZ8aOcqPcRPa/TcSAHi8trb+ea/4YroHekoaeTWifoly2vl/wQu+eHUfPOK45iuRUw+npxR3pYuCdCz6Mz3NsxSY/uEl6duWe9GE+e3IZx3kErLqKUam4kyKtvFlzwVITWbOzd7q0Erxe97vNkfge7IlNsTS9xyDXx0hH0WAxZrtxU17pWXGiYpXZsotobYmmXQXVTnZ6eZf4y8wBPj28hujdE77jXPMhznaxM9UyR1MEIS/VODm9JcZuVpUcPkdY0PtT7MjVlcY9znrTWwBYWRdng3577AKfP9zB8ykNbWEFW1r6PTEsEWCTiFLckqXf6aOJyAY6zMpMksqCh3FurqED2pMnsCpO/u863tn6RtGZczHwAmHTTOMdCxCckolxFSYkwXuPPp1TzaCepmnmPug512X6VS38LtLrHy0vDeErnrsQkALOpzUTiMVSjgVpnqVKtQFgmIhxmeUeU7DZ4+9v385mB5/mj5dv48vg9xHaH6Puzw80+E+upx0hmha6XwyDiHKsPENdqpDWfuObwG7ELqawGFyRwwW9w7sAAvXshenQeb27+hpjZEgH2ExGKgwYiXUVHcLRR5c9WHuSp6S107taJzrmo6k1wnM514Ecsql2CcLyGIwSmeKVnu8OeJf7wAsu5KEt3D6M1XrulJ0B4TmDnFJUeQSOlSB9RpPYuQbZwUzSpF67Pcj7CbDjBOxLH0BOKz+7ajhsZIvkTecMeorZFCLRotHmYZzgEho6KhpGOQWUoQiOqUe7TcGMKd6RGd2eBiVKaXz/3GC8f2UTqoE76RK05ma0Tz/cCqlZHX8qRHHf4T08/zp8OFfj0HV9mzCjRqYfQVlfeJVXnc/nt7MlvJHFKkBgvoW5gVW1LBNhNORRHJL2deTQ0Xq4N8tUX7yN22qD3G0fxc/lbLpXIjZnU+jy2JPOERbPb2eXcYen8+I6/Anjt7k4X3k/5/K9zj/DczCgf3bybjyaOcP9PPoZZ6SByzoSbQYAbHu5KjJlogrENi2yzFviz++9nvjtB/HgSbnEBFoaJlk6ibAsvHcF3D
Mp9Fo2YYOVej1Rvlo+OvsxbIyeRSqOBzu8e/AiTezYwsruO+eRLAGvWC3ctkbUacnoGu1BkS3aIlTuS/Hh0G1b0KAnNv7gpV5Q+fzr+ZipnEmzeU0DtPXpDaw3a6kQMoYB1lNz9RmIvVEgcSzAe6SY35pHUmid/XAkNDYlkxa9TUbC3PsCKF8VVBpqQPBY5wQbD4s3x09iay+3OVLPhUaRGpTuMvey0Nv/wDULUXZx5g4wdZ3xLN0PmMg8PnOVopJfZxwaIb7mf2PEVyOSQhQJqPS2hWY3PWiaivwe3J45RqKEtF5pnG9ZqzdDS6unhbl8StZpL7zs6pQEL34FaWiBNcOMSaYKKeei2T39HnoRd48XcCHvyGzie6aFQCBE6EqLjlI8zu/YZADcCWa+jZ/Ikzpj82Q8e5bPxh9AjHmK17Nh3NaKHHDoXJHqmcMNPEW8rAb6V0c5M0V+s4IX7WHjIwhLuVQUYmh7uhBfivNvJZ84/ymI+ivQ1NF2S3FVhJLrIh6KLfDDa9AJ9pdEfL3B6KE140SJ01XdeP6hyhcQZieZa7L17I7FYlX/b+wxur+L3w7/IoYV+Gt/oJHXMQjvvX9/p022AFo0gohFyd3axvFMjOhUidSqEXmqgr5RQlomM2lT7IyzeZaCMpqg0UpLH7j/AaCjDdmeGDr3EmFklvBrWcpXkr8sbOV7t56tH70KbdujdLek7moHlCfyVLP46yHS4FlS9jjc1jTY9w+YXV8N64pXux4XNaq8Fm7aBALcJqtFAFMtE5hT/buq97EzM8v74PhKaS79hXwxJ5GWNJyuDTDQ6+e7cDpaKEbwT8eZhgzSTIv6v0DvY13+aDrNMTK9hCg9NKI6e66fnlCI0t85PCblAtUZsso7u2nx59wP8Te8O/nDHtxk1M9yVmCSku/x4x07cSJwe10crV5q9XddLdoQmwDRoRDUaHZJsTFDY5KA1HPRqAqWDtBRuXBEazqOvnoLcHaoxGsoQ1WscqQ4hETwDuEpnupoi7zrsPz+EylpEz+uEMorIdAXyzWPo273A4udCqbZsIB8IcJtwofNV+mAHZ74xxoFNo5gP+2x3ZunSFy9WCs76Op868Ti5uTg9z2n0zrs4R8/iZ3Notg22TXZiE9/b8AD1Tokf8xGOj2H6dP3IIvmXL6+rGv7Xwi8U0J4/RNwySb7UTX2kkz/5l4/y6wMv8YHYQcLxg/zRO+u8vDRMPtdFajkN2dyVT1ZuRwwDZZvUOgSp4SyPD57g9ztfeEX4SNIsZzUv8+oq0uekG+dMo5svze5gpRihPhfGKGvEzoOdV2zds4Canmvm9EqF8n389TIx3US0RICtlSrx0xZzXUkkkol6J9GzBrGp9jzm5Eai5cvEz8fRGjqfDb8VK9Lg0+k8xqp3s1wOUz2QJrEMsekqZqaCLJZQ9TrS9xGeR3SmjvBt6ssCL2IiTRNpQGy6fvONr/SRDRC5PNaCw6l9Q3xyoZOuVBFTk0xNdaBnTTYsNlC12rqo4rpIvY4oVYlNSTJHOvhy5j5eHt4AgIZCXvF4Y6h7BtlKiFrVQkyGMMqCxDIYVUVk3sMquJDNr3mz8YDXpyUCLA+dpOeUjW/dSe1tPi9nNzD8lxP4S5n1lWu4BngTU0TmFojqOn1fNJvxqsv6l/aqMqoxC76PajTwpbq4pFaeh/I89OcPE9d1hBCgaRdPoVYNd13uaL8u0sfP5aFQYuwP55r50asbUtv9FZQvUbU6vueuq+W1n8tDvkD8m0skvmM1N90uy/2+WgmOqRRR8k3P1vOah3MqBVKifAlK4q+niegmpjUhCOkjKxViMz6/c/59HD02xPbCiXW3S70mKIWq15tC+XOulJXngefdnGL7Wkj/pjsT8OL9EDwbNyUtjQHHvnOY6ospttfP4d9sD05AQEDA69BSAZaVyvrZEAkICAh4g7kZ8vEDAgIC1iWBAAcEBAS0CHE9vT2FEEvAxNqZ0xZsUEp1XeuLb5Exg
esYl2BMrswtMi7BmFyZK47LdQlwQEBAQMAbRxCCCAgICGgRgQAHBAQEtIhAgAMCAgJaRCDAAQEBAS0iEOCAgICAFhEIcEBAQECLCAQ4ICAgoEUEAhwQEBDQIgIBDggICGgR/z9h+ak5K2QFrgAAAABJRU5ErkJggg==\n" }, "metadata": {} } ], "source": [ "plot_example(X_test[error_mask], y_pred_cnn[error_mask])" ] }, { "cell_type": "code", "execution_count": 36, "metadata": { "id": "h-tIl3el_v7x" }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "base", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13 (default, Mar 28 2022, 08:03:21) [MSC v.1916 64 bit (AMD64)]" }, "vscode": { "interpreter": { "hash": "bd97b8bffa4d3737e84826bc3d37be3046061822757ce35137ab82ad4c5a2016" } }, "colab": { "provenance": [] }, "accelerator": "GPU", "gpuClass": "standard" }, "nbformat": 4, "nbformat_minor": 0 }