{ "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "UEBilEjLj5wY" }, "source": [ "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n", "- Author: Sebastian Raschka\n", "- GitHub Repository: https://github.com/rasbt/deeplearning-models" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, "base_uri": "https://localhost:8080/", "height": 119 }, "colab_type": "code", "executionInfo": { "elapsed": 536, "status": "ok", "timestamp": 1524974472601, "user": { "displayName": "Sebastian Raschka", "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg", "userId": "118404394130788869227" }, "user_tz": 240 }, "id": "GOzuY8Yvj5wb", "outputId": "c19362ce-f87a-4cc2-84cc-8d7b4b9e6007" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sebastian Raschka \n", "\n", "CPython 3.7.3\n", "IPython 7.9.0\n", "\n", "torch 1.3.1\n" ] } ], "source": [ "%load_ext watermark\n", "%watermark -a 'Sebastian Raschka' -v -p torch" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "rH4XmErYj5wm" }, "source": [ "# LeNet-5 QuickDraw Classifier" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This notebook implements the classic LeNet-5 convolutional network [1] and applies it to classifying sketches from Google's QuickDraw dataset. The basic architecture is shown in the figure below:\n", "\n", "![](../images/lenet/lenet-5_1.jpg)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "\n", "LeNet-5 is commonly regarded as one of the pioneering convolutional neural networks and has a very simple architecture (by modern standards). In total, LeNet-5 consists of only seven layers. Three of these are convolutional layers (C1, C3, C5), which are connected by two average pooling layers (S2 & S4). 
The penultimate layer is a fully connected layer (F6), which is followed by the final output layer. The additional details are summarized below:\n", "\n", "- All convolutional layers use 5x5 kernels with stride 1.\n", "- The two average pooling (subsampling) layers are 2x2 pixels wide with stride 2. (**In this notebook, we use max pooling instead**)\n", "- Throughout the network, tanh activation functions are used. (**This notebook keeps the tanh activations**)\n", "- The output layer uses 10 custom Euclidean Radial Basis Function neurons. (**In this notebook, we replace these with a linear layer followed by softmax activations**)\n", "- The original input size is 32x32. Here, we use the 28x28 QuickDraw images as they are and instead reduce the number of input features of the first fully connected layer (`16*4*4` instead of `16*5*5`), which takes the place of C5.\n", "- In the original paper, LeNet-5 achieved an error rate below 1% on the MNIST data set, which was very close to the state of the art at the time (produced by a boosted ensemble of three LeNet-4 networks).\n", "\n", "\n", "### References\n", "\n", "- [1] Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner. Gradient-based learning applied to document recognition. Proceedings of the IEEE, November 1998."
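 ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A quick way to check the input-size reasoning above is to apply the usual output-size formula `(size - kernel + 2*padding) // stride + 1` to each convolution and pooling layer. The following minimal sketch does this with `conv_out` and `lenet_feature_size`, which are hypothetical helpers and not part of the original notebook. It assumes 5x5 convolutions with stride 1 and no padding, plus 2x2 pooling with stride 2. The stride matches `nn.MaxPool2d(kernel_size=2)` in the `LeNet5` model below, because that layer's stride defaults to its kernel size. With 28x28 inputs, the 16 final feature maps are 4x4, which explains the `16*4*4` input features of the first fully connected layer. With the original 32x32 inputs, they would be 5x5." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def conv_out(size, kernel, stride=1, padding=0):\n", "    # spatial output size of a convolution or pooling layer (square inputs)\n", "    return (size - kernel + 2 * padding) // stride + 1\n", "\n", "\n", "def lenet_feature_size(input_size):\n", "    # hypothetical helper: follows the two conv/pool stages of the LeNet5 model\n", "    size = conv_out(input_size, kernel=5)          # conv1, 5x5, stride 1\n", "    size = conv_out(size, kernel=2, stride=2)      # pool1, 2x2, stride 2\n", "    size = conv_out(size, kernel=5)                # conv2, 5x5, stride 1\n", "    size = conv_out(size, kernel=2, stride=2)      # pool2, 2x2, stride 2\n", "    return size\n", "\n", "\n", "for input_size in (28, 32):\n", "    side = lenet_feature_size(input_size)\n", "    # 16 feature maps of side x side pixels are flattened for the classifier\n", "    print(input_size, '->', side, '->', 16 * side * side)\n", "# 28 -> 4 -> 256   (matches nn.Linear(16*4*4, 120))\n", "# 32 -> 5 -> 400   (the original 32x32 LeNet-5 input)"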
] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "MkoGLH_Tj5wn" }, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 } }, "colab_type": "code", "id": "ORj09gnrj5wp" }, "outputs": [], "source": [ "import os\n", "import time\n", "\n", "import numpy as np\n", "import pandas as pd\n", "\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from torch.utils.data import DataLoader\n", "from torch.utils.data import Dataset\n", "\n", "from torchvision import transforms\n", "\n", "import matplotlib.pyplot as plt\n", "from PIL import Image\n", "\n", "\n", "if torch.cuda.is_available():\n", "    torch.backends.cudnn.deterministic = True" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "I6hghKPxj5w0" }, "source": [ "## Model Settings" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, "base_uri": "https://localhost:8080/", "height": 85 }, "colab_type": "code", "executionInfo": { "elapsed": 23936, "status": "ok", "timestamp": 1524974497505, "user": { "displayName": "Sebastian Raschka", "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg", "userId": "118404394130788869227" }, "user_tz": 240 }, "id": "NnT0sZIwj5wu", "outputId": "55aed925-d17e-4c6a-8c71-0d9b3bde5637" }, "outputs": [], "source": [ "##########################\n", "### SETTINGS\n", "##########################\n", "\n", "# Hyperparameters\n", "RANDOM_SEED = 1\n", "LEARNING_RATE = 0.001\n", "BATCH_SIZE = 128\n", "NUM_EPOCHS = 10\n", "\n", "# Architecture\n", "NUM_FEATURES = 28*28\n", "NUM_CLASSES = 10\n", "\n", "# Other\n", "# use a GPU if one is available, otherwise fall back to the CPU\n", "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "GRAYSCALE = True" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Dataset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This 
notebook is based on Google's Quickdraw dataset (https://quickdraw.withgoogle.com). In particular we will be working with an arbitrary subset of 10 categories in png format:\n", "\n", " label_dict = {\n", " \"lollipop\": 0,\n", " \"binoculars\": 1,\n", " \"mouse\": 2,\n", " \"basket\": 3,\n", " \"penguin\": 4,\n", " \"washing machine\": 5,\n", " \"canoe\": 6,\n", " \"eyeglasses\": 7,\n", " \"beach\": 8,\n", " \"screwdriver\": 9,\n", " }\n", " \n", "(The class labels 0-9 can be ignored in this notebook). \n", "\n", "For more details on obtaining and preparing the dataset, please see the\n", "\n", "- [custom-data-loader-quickdraw.ipynb](custom-data-loader-quickdraw.ipynb)\n", "\n", "notebook." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(28, 28)\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjAsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+17YcXAAAQnUlEQVR4nO3da4xUdZoG8OcRGkFaAixt21HcRiG6BNxmKFB0NeAoCjHBMY4MH4D1skwMJDNmYryFDNEPELMzkyGuGmYhNGYWmeigfDDsEKPRMeHSEARadHWRnQFaugkXucj93Q99mDTa53+KOqfq1PT7/JJOd9dTp+pN6cOprn+dOjQziEjPd1neA4hIZajsIk6o7CJOqOwiTqjsIk70ruSdDRkyxBobGyt5lyKu7N69GwcOHGB3Waqyk7wPwG8B9ALwn2a2KHT9xsZGtLS0pLlLEQkoFAqxWclP40n2AvAfAKYAGAlgBsmRpd6eiJRXmr/ZxwP40sx2mdlpAG8AmJbNWCKStTRlvwbAX7v8vie67CIk55BsIdnS0dGR4u5EJI00Ze/uRYDvvffWzJaYWcHMCnV1dSnuTkTSSFP2PQCGdvn9WgD70o0jIuWSpuybAIwgOYxkHwA/AbAmm7FEJGslL72Z2VmS8wD8NzqX3paZWWtmk/Ugn3zySTDfsGFDMD9y5Egwv+yy+H+zR44ML5CElmoAQH969Ryp1tnN7F0A72Y0i4iUkd4uK+KEyi7ihMou4oTKLuKEyi7ihMou4kRFj2fvqZLWye++++5gfuzYsSzHyVTS5w8krdOH8qRtx44dG8wHDhwYzOVi2rOLOKGyizihsos4obKLOKGyizihsos4oaW3Iu3YsSM2mzJlSnDba6+9NpivXbs2mNfX1wfz0Mk5t27dGtw26dN+k/JNmzYF87feeis2S3tS0REjRgTz0NLeuHHjSt4WABoaGoL58OHDg3ketGcXcUJlF3FCZRdxQmUXcUJlF3FCZRdxQmUXcULr7JFdu3YF88mTJ8dmAwYMCG67bt26YJ60Dp/GhAkTUuVphQ7f3bJlS3DbtO8BCOUrV64MbpvWG2+8EcynT59e1vvvjvbsI
k6o7CJOqOwiTqjsIk6o7CJOqOwiTqjsIk64WWdva2sL5vfcc08wP3/+fGyW5zp6tautrY3N7rzzzuC2Sfk333wTzBcsWBCbvfzyy8Fte/cOV+Pbb78N5v369QvmeUhVdpK7ARwFcA7AWTMLH/EvIrnJYs8+ycwOZHA7IlJG+ptdxIm0ZTcAfyK5meSc7q5Acg7JFpItHR0dKe9OREqVtuy3m9kPAEwBMJfk915RMbMlZlYws0JdXV3KuxORUqUqu5nti763A1gNYHwWQ4lI9kouO8n+JK+88DOAyQDiP29ZRHKV5tX4egCrSV64nf8ys/AHoJfRoUOHgvm9994bzA8ePBjM33///dgs6fPLpXtJnxvf3NwczJ999tlgfuBA/CLRE088Edz20UcfDeZjxowJ5sePHw/meSi57Ga2C8A/ZziLiJSRlt5EnFDZRZxQ2UWcUNlFnFDZRZzoMYe4Tp06NZiHTrkMhA/FBIBJkybFZtHyY6yBAwcG87RCs9fU1KS67T59+gTzpqamYD527NjYbOnSpcFt169fH8wnTpwYzBcvXhybjR49OrjtiRMngnmS/fv3p9q+HLRnF3FCZRdxQmUXcUJlF3FCZRdxQmUXcUJlF3Gix6yzf/3118E86VNyHnnkkZLvO2lN9tSpUyXfNgCcPXs2mB89ejTV7YckHTq8fPnyYP7aa6/FZkn/TVatWhXMH3744WCexhVXXBHM+/fvH8zb29uzHCcT2rOLOKGyizihsos4obKLOKGyizihsos4obKLONFj1tmTPs456fS+ixYtynIcN2bNmhXMX3/99dhs3rx5wW3LuY6eVtJ7BKrxVGfas4s4obKLOKGyizihsos4obKLOKGyizihsos40WPW2ZMkHZ+cp5MnTwbzN998M5jv2bMnNkv6bPVbb701mCdZsWJFMP/ggw9is71796a67zzV19cH87/L49lJLiPZTnJHl8sGk1xH8ovo+6DyjikiaRXzNH45gPu+c9kzAN4zsxEA3ot+F5Eqllh2M/sQwMHvXDwNQHP0czOABzKeS0QyVuoLdPVm1gYA0fer4q5Icg7JFpIt1fh+YREvyv5qvJktMbOCmRWSDh4QkfIptez7STYAQPS9+l56FJGLlFr2NQBmRz/PBvBONuOISLkkrrOTXAlgIoAhJPcA+CWARQD+QPIxAH8B8ONyDvn37qWXXgrmCxcuDOaHDx/OcpyLNDc3B/Ok49WTDBs2LDb76quvUt12npLW2avx/OyJZTezGTHRDzOeRUTKSG+XFXFCZRdxQmUXcUJlF3FCZRdxws0hruW0evXqYP7000+nuv2ZM2cG81dffTU2e+ihh4Lbzp07N5gnbZ906HBo6e3jjz8OblvNhgwZEsy3bdtWoUmKpz27iBMqu4gTKruIEyq7iBMqu4gTKruIEyq7iBNaZ8/AunXrgvnVV18dzJM+Uvmyy0r/N/n5558P5nfccUcwX7t2bTB/8MEHg/n1118fm61cuTK47ZkzZ4J5TU1NMC+nQ4cOBfPa2toKTVI87dlFnFDZRZxQ2UWcUNlFnFDZRZxQ2UWcUNlFnNA6ewaOHz8ezAcMGBDM06yjJ7ntttuCedJ7AN5+++1gnrTOPnr06Njs9OnTwW1bW1uDeVNTUzAvp40bNwbzyZMnV2iS4mnPLuKEyi7ihMou4oTKLuKEyi7ihMou4oTKLuKE1tkz0NjYGMxXrVoVzM+dOxfMe/Xqdakj/U3SGn7oc90BoKOjo+T7BoBCoVDyti0tLcG8nOvsbW1twTzpMwhuueWWLMfJROKeneQyku0kd3S5bAHJvSS3Rl9TyzumiKRVzNP45QDu6+by35hZU/T1brZjiUjWEstuZh8COFiBWUSkjNK8QDeP5Lboaf6guCuRnEOyhWRL2r//RKR0pZb9VQA3AGgC0AbgV3FXNLMlZlYws0JdXV2JdyciaZVUdjPbb2bnzOw8gN8BGJ/tWCKStZLKTrKhy68/ArAj7
roiUh0S19lJrgQwEcAQknsA/BLARJJNAAzAbgA/LeOMVW/cuHHB/NSpU8H8ySefDOaLFy++5JmK9dlnnwXz2bNnp7r9oUOHxmZJx9Jv3rw5mD/++OMlzVSM9evXp9p+/Pjqe7KbWHYzm9HNxUvLMIuIlJHeLivihMou4oTKLuKEyi7ihMou4oQOcc3A/fffH8znz58fzF988cVgnrQMNH369NjsyiuvDG6bdOrhsWPHBvM0km476RDXckr6qOi+ffsG81GjRmU5Tia0ZxdxQmUXcUJlF3FCZRdxQmUXcUJlF3FCZRdxQuvsFfDCCy8E8+uuuy6YL1++PJg/9dRTsZmZBbe96aabgnloDT+tpI+ZXrhwYTBPOnT48ssvv+SZLkhaZx8zZkwwr6mpKfm+y0V7dhEnVHYRJ1R2ESdUdhEnVHYRJ1R2ESdUdhEntM5eBZI+EjkpD51eeN++fcFtk07ZXM714qTj2U+fPh3Mt2/fHszTnC76xhtvDOblPF10uWjPLuKEyi7ihMou4oTKLuKEyi7ihMou4oTKLuKE1tl7gIaGhpKyvN18882ptv/000+DeZp19ldeeaXkbatV4p6d5FCS75PcSbKV5M+iyweTXEfyi+j7oPKPKyKlKuZp/FkAvzCzfwJwK4C5JEcCeAbAe2Y2AsB70e8iUqUSy25mbWa2Jfr5KICdAK4BMA1Ac3S1ZgAPlGtIEUnvkl6gI9kIYAyADQDqzawN6PwHAcBVMdvMIdlCsqWjoyPdtCJSsqLLTrIWwFsAfm5m3xS7nZktMbOCmRXq6upKmVFEMlBU2UnWoLPovzezP0YX7yfZEOUNANrLM6KIZKGYV+MJYCmAnWb26y7RGgCzo59nA3gn+/GkJ+vdu3fwK4mZBb/kYsWss98OYCaA7SS3Rpc9B2ARgD+QfAzAXwD8uDwjikgWEstuZn8GwJj4h9mOIyLlorfLijihsos4obKLOKGyizihsos40WMOce3bt28wb2/Xe36qzdmzZ1Nt36tXr4wm8UF7dhEnVHYRJ1R2ESdUdhEnVHYRJ1R2ESdUdhEnesw6+/Dhw4P5Rx99FMwPHTpU8n0PGqQP1i3FmTNnUm3fp0+fjCbxQXt2ESdUdhEnVHYRJ1R2ESdUdhEnVHYRJ1R2ESd6zDr7qFGjgvnhw4eD+eDBg7Mc55IkHYuflIekfQ9A0n3fcMMNwXzChAmxWdrTSdfU1KTa3hvt2UWcUNlFnFDZRZxQ2UWcUNlFnFDZRZxQ2UWcSFxnJzkUwAoAVwM4D2CJmf2W5AIA/wagI7rqc2b2brkGTTJjxoxgnnS+75MnTwbzU6dOxWYnTpwIbpsk6T0ASecaL+dsR44cCeaff/55MJ8/f35sdv78+ZJmukDr7JemmDfVnAXwCzPbQvJKAJtJrouy35jZv5dvPBHJSjHnZ28D0Bb9fJTkTgDXlHswEcnWJf3NTrIRwBgAG6KL5pHcRnIZyW7fl0lyDskWki0dHR3dXUVEKqDospOsBfAWgJ+b2TcAXgVwA4AmdO75f9Xddma2xMwKZlaoq6vLYGQRKUVRZSdZg86i/97M/ggAZrbfzM6Z2XkAvwMwvnxjikhaiWUnSQBLAew0s193ubzrIUs/ArAj+/FEJCvFvBp/O4CZALaT3Bpd9hyAGSSbABiA3QB+WpYJi9SvX79gPmvWrApNIl0dO3YsNtu4cWNw29bW1mA+adKkkmbyqphX4/8MgN1Eua2pi8il0zvoRJxQ2UWcUNlFnFDZRZxQ2UWcUNlFnOgxHyUt1am2tjY2u+uuu4LbJuVyabRnF3FCZRdxQmUXcUJlF3FCZRdxQmUXcUJlF3GCSR9TnOmdkR0A/q/LRUMAHKjYAJemWmer1rkAzVaqLGf7RzPr9vPfKlr279052WJmhdwGCKjW2ap1LkCzlapSs+lpvIgTKruIE3mXfUnO9x9SrbNV61yAZitVRWbL9W92EamcvPfsIlIhKruIE7mUneR9JD8n+
SXJZ/KYIQ7J3SS3k9xKsiXnWZaRbCe5o8tlg0muI/lF9L3bc+zlNNsCknujx24ryak5zTaU5Pskd5JsJfmz6PJcH7vAXBV53Cr+NzvJXgD+B8A9APYA2ARghpl9WtFBYpDcDaBgZrm/AYPknQCOAVhhZqOiy14CcNDMFkX/UA4ys6erZLYFAI7lfRrv6GxFDV1PMw7gAQD/ihwfu8BcD6MCj1see/bxAL40s11mdhrAGwCm5TBH1TOzDwEc/M7F0wA0Rz83o/N/loqLma0qmFmbmW2Jfj4K4MJpxnN97AJzVUQeZb8GwF+7/L4H1XW+dwPwJ5KbSc7Je5hu1JtZG9D5Pw+Aq3Ke57sST+NdSd85zXjVPHalnP48rTzK3t2ppKpp/e92M/sBgCkA5kZPV6U4RZ3Gu1K6Oc14VSj19Odp5VH2PQCGdvn9WgD7cpijW2a2L/reDmA1qu9U1PsvnEE3+t6e8zx/U02n8e7uNOOogscuz9Of51H2TQBGkBxGsg+AnwBYk8Mc30Oyf/TCCUj2BzAZ1Xcq6jUAZkc/zwbwTo6zXKRaTuMdd5px5PzY5X76czOr+BeAqeh8Rf5/ATyfxwwxc10P4JPoqzXv2QCsROfTujPofEb0GIB/APAegC+i74OraLbXAWwHsA2dxWrIabZ/QeefhtsAbI2+pub92AXmqsjjprfLijihd9CJOKGyizihsos4obKLOKGyizihsos4obKLOPH/R9r9mQpsSDgAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "df = pd.read_csv('quickdraw_png_set1_train.csv', index_col=0)\n", "df.head()\n", "\n", "main_dir = 'quickdraw-png_set1/'\n", "\n", "img = Image.open(os.path.join(main_dir, df.index[99]))\n", "img = np.asarray(img, dtype=np.uint8)\n", "print(img.shape)\n", "plt.imshow(np.array(img), cmap='binary')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create a Custom Data Loader" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "class QuickdrawDataset(Dataset):\n", " \"\"\"Custom Dataset for loading Quickdraw images\"\"\"\n", "\n", " def __init__(self, txt_path, img_dir, transform=None):\n", " \n", " df = pd.read_csv(txt_path, sep=\",\", index_col=0)\n", " self.img_dir = img_dir\n", " self.txt_path = txt_path\n", " self.img_names = df.index.values\n", " self.y = df['Label'].values\n", " self.transform = transform\n", "\n", " def __getitem__(self, index):\n", " img = Image.open(os.path.join(self.img_dir,\n", " self.img_names[index]))\n", " \n", " if self.transform is not None:\n", " img = self.transform(img)\n", " \n", " label = self.y[index]\n", " return img, label\n", "\n", " def __len__(self):\n", " return self.y.shape[0]" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# Note that transforms.ToTensor()\n", "# already divides pixels by 255. 
internally\n", "\n", "\n", "BATCH_SIZE = 128\n", "\n", "custom_transform = transforms.Compose([#transforms.Lambda(lambda x: x/255.),\n", "                                       transforms.ToTensor()])\n", "\n", "train_dataset = QuickdrawDataset(txt_path='quickdraw_png_set1_train.csv',\n", "                                 img_dir='quickdraw-png_set1/',\n", "                                 transform=custom_transform)\n", "\n", "train_loader = DataLoader(dataset=train_dataset,\n", "                          batch_size=BATCH_SIZE,\n", "                          shuffle=True,\n", "                          num_workers=4)\n", "\n", "\n", "valid_dataset = QuickdrawDataset(txt_path='quickdraw_png_set1_valid.csv',\n", "                                 img_dir='quickdraw-png_set1/',\n", "                                 transform=custom_transform)\n", "\n", "valid_loader = DataLoader(dataset=valid_dataset,\n", "                          batch_size=BATCH_SIZE,\n", "                          shuffle=False,\n", "                          num_workers=4)\n", "\n", "\n", "# held-out test split; assumes the quickdraw_png_set1_test.csv file\n", "# created in the custom-data-loader-quickdraw.ipynb notebook\n", "test_dataset = QuickdrawDataset(txt_path='quickdraw_png_set1_test.csv',\n", "                                img_dir='quickdraw-png_set1/',\n", "                                transform=custom_transform)\n", "\n", "test_loader = DataLoader(dataset=test_dataset,\n", "                         batch_size=BATCH_SIZE,\n", "                         shuffle=False,\n", "                         num_workers=4)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 1 | Batch index: 0 | Batch size: 128\n", "Epoch: 2 | Batch index: 0 | Batch size: 128\n" ] } ], "source": [ "device = torch.device(DEVICE if torch.cuda.is_available() else \"cpu\")\n", "torch.manual_seed(0)\n", "\n", "num_epochs = 2\n", "for epoch in range(num_epochs):\n", "\n", "    for batch_idx, (x, y) in enumerate(train_loader):\n", "\n", "        print('Epoch:', epoch+1, end='')\n", "        print(' | Batch index:', batch_idx, end='')\n", "        print(' | Batch size:', y.size()[0])\n", "\n", "        x = x.to(device)\n", "        y = y.to(device)\n", "        break" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "##########################\n", "### MODEL\n", "##########################\n", "\n", "\n", "class LeNet5(nn.Module):\n", "\n", "    def __init__(self, num_classes, grayscale=False):\n", "
super(LeNet5, self).__init__()\n", " \n", " self.grayscale = grayscale\n", " self.num_classes = num_classes\n", "\n", " if self.grayscale:\n", " in_channels = 1\n", " else:\n", " in_channels = 3\n", "\n", " self.features = nn.Sequential(\n", " \n", " nn.Conv2d(in_channels, 6, kernel_size=5),\n", " nn.Tanh(),\n", " nn.MaxPool2d(kernel_size=2),\n", " nn.Conv2d(6, 16, kernel_size=5),\n", " nn.Tanh(),\n", " nn.MaxPool2d(kernel_size=2)\n", " )\n", "\n", " self.classifier = nn.Sequential(\n", " nn.Linear(16*4*4, 120),\n", " nn.Tanh(),\n", " nn.Linear(120, 84),\n", " nn.Tanh(),\n", " nn.Linear(84, num_classes),\n", " )\n", "\n", "\n", " def forward(self, x):\n", " x = self.features(x)\n", " x = torch.flatten(x, 1)\n", " logits = self.classifier(x)\n", " probas = F.softmax(logits, dim=1)\n", " return logits, probas" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 } }, "colab_type": "code", "id": "_lza9t_uj5w1" }, "outputs": [], "source": [ "torch.manual_seed(RANDOM_SEED)\n", "\n", "model = LeNet5(NUM_CLASSES, GRAYSCALE)\n", "model = model.to(DEVICE)\n", "\n", "optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE) " ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'\\nmodel.features[0].register_forward_hook(print_sizes)\\nmodel.features[1].register_forward_hook(print_sizes)\\nmodel.features[2].register_forward_hook(print_sizes)\\nmodel.features[3].register_forward_hook(print_sizes)\\n\\nmodel.classifier[0].register_forward_hook(print_sizes)\\nmodel.classifier[1].register_forward_hook(print_sizes)\\nmodel.classifier[2].register_forward_hook(print_sizes)\\n'" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def print_sizes(self, input, output):\n", "\n", " print('Inside ' + self.__class__.__name__ + ' forward')\n", " print('input size:', input[0].size())\n", " print('output size:', 
output.data.size())\n", "\n", " \n", "## Debugging\n", "\n", "\"\"\"\n", "model.features[0].register_forward_hook(print_sizes)\n", "model.features[1].register_forward_hook(print_sizes)\n", "model.features[2].register_forward_hook(print_sizes)\n", "model.features[3].register_forward_hook(print_sizes)\n", "\n", "model.classifier[0].register_forward_hook(print_sizes)\n", "model.classifier[1].register_forward_hook(print_sizes)\n", "model.classifier[2].register_forward_hook(print_sizes)\n", "\"\"\"" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "RAodboScj5w6" }, "source": [ "## Training" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, "base_uri": "https://localhost:8080/", "height": 1547 }, "colab_type": "code", "executionInfo": { "elapsed": 2384585, "status": "ok", "timestamp": 1524976888520, "user": { "displayName": "Sebastian Raschka", "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg", "userId": "118404394130788869227" }, "user_tz": 240 }, "id": "Dzh3ROmRj5w7", "outputId": "5f8fd8c9-b076-403a-b0b7-fd2d498b48d7" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 001/010 | Batch 0000/8290 | Cost: 2.3096\n", "Epoch: 001/010 | Batch 0500/8290 | Cost: 0.6812\n", "Epoch: 001/010 | Batch 1000/8290 | Cost: 0.4123\n", "Epoch: 001/010 | Batch 1500/8290 | Cost: 0.2931\n", "Epoch: 001/010 | Batch 2000/8290 | Cost: 0.3878\n", "Epoch: 001/010 | Batch 2500/8290 | Cost: 0.2494\n", "Epoch: 001/010 | Batch 3000/8290 | Cost: 0.3749\n", "Epoch: 001/010 | Batch 3500/8290 | Cost: 0.3393\n", "Epoch: 001/010 | Batch 4000/8290 | Cost: 0.4072\n", "Epoch: 001/010 | Batch 4500/8290 | Cost: 0.2639\n", "Epoch: 001/010 | Batch 5000/8290 | Cost: 0.4709\n", "Epoch: 001/010 | Batch 5500/8290 | Cost: 0.3594\n", "Epoch: 001/010 | Batch 6000/8290 | Cost: 0.4542\n", "Epoch: 001/010 | Batch 6500/8290 | Cost: 
0.2887\n", "Epoch: 001/010 | Batch 7000/8290 | Cost: 0.3441\n", "Epoch: 001/010 | Batch 7500/8290 | Cost: 0.2771\n", "Epoch: 001/010 | Batch 8000/8290 | Cost: 0.4163\n", "Epoch: 001/010 | Train: 90.619% | Validation: 90.339%\n", "Time elapsed: 3.36 min\n", "Epoch: 002/010 | Batch 0000/8290 | Cost: 0.3799\n", "Epoch: 002/010 | Batch 0500/8290 | Cost: 0.2720\n", "Epoch: 002/010 | Batch 1000/8290 | Cost: 0.3350\n", "Epoch: 002/010 | Batch 1500/8290 | Cost: 0.3859\n", "Epoch: 002/010 | Batch 2000/8290 | Cost: 0.2861\n", "Epoch: 002/010 | Batch 2500/8290 | Cost: 0.4202\n", "Epoch: 002/010 | Batch 3000/8290 | Cost: 0.3077\n", "Epoch: 002/010 | Batch 3500/8290 | Cost: 0.3045\n", "Epoch: 002/010 | Batch 4000/8290 | Cost: 0.1604\n", "Epoch: 002/010 | Batch 4500/8290 | Cost: 0.2022\n", "Epoch: 002/010 | Batch 5000/8290 | Cost: 0.2315\n", "Epoch: 002/010 | Batch 5500/8290 | Cost: 0.2880\n", "Epoch: 002/010 | Batch 6000/8290 | Cost: 0.2055\n", "Epoch: 002/010 | Batch 6500/8290 | Cost: 0.5247\n", "Epoch: 002/010 | Batch 7000/8290 | Cost: 0.4131\n", "Epoch: 002/010 | Batch 7500/8290 | Cost: 0.2302\n", "Epoch: 002/010 | Batch 8000/8290 | Cost: 0.2234\n", "Epoch: 002/010 | Train: 91.823% | Validation: 91.486%\n", "Time elapsed: 5.58 min\n", "Epoch: 003/010 | Batch 0000/8290 | Cost: 0.3333\n", "Epoch: 003/010 | Batch 0500/8290 | Cost: 0.3250\n", "Epoch: 003/010 | Batch 1000/8290 | Cost: 0.2323\n", "Epoch: 003/010 | Batch 1500/8290 | Cost: 0.2834\n", "Epoch: 003/010 | Batch 2000/8290 | Cost: 0.3315\n", "Epoch: 003/010 | Batch 2500/8290 | Cost: 0.3029\n", "Epoch: 003/010 | Batch 3000/8290 | Cost: 0.2193\n", "Epoch: 003/010 | Batch 3500/8290 | Cost: 0.1904\n", "Epoch: 003/010 | Batch 4000/8290 | Cost: 0.2865\n", "Epoch: 003/010 | Batch 4500/8290 | Cost: 0.2746\n", "Epoch: 003/010 | Batch 5000/8290 | Cost: 0.3442\n", "Epoch: 003/010 | Batch 5500/8290 | Cost: 0.2003\n", "Epoch: 003/010 | Batch 6000/8290 | Cost: 0.3828\n", "Epoch: 003/010 | Batch 6500/8290 | Cost: 0.2139\n", "Epoch: 
003/010 | Batch 7000/8290 | Cost: 0.2914\n", "Epoch: 003/010 | Batch 7500/8290 | Cost: 0.2799\n", "Epoch: 003/010 | Batch 8000/8290 | Cost: 0.2144\n", "Epoch: 003/010 | Train: 92.152% | Validation: 91.699%\n", "Time elapsed: 7.79 min\n", "Epoch: 004/010 | Batch 0000/8290 | Cost: 0.1746\n", "Epoch: 004/010 | Batch 0500/8290 | Cost: 0.3684\n", "Epoch: 004/010 | Batch 1000/8290 | Cost: 0.3992\n", "Epoch: 004/010 | Batch 1500/8290 | Cost: 0.3352\n", "Epoch: 004/010 | Batch 2000/8290 | Cost: 0.2877\n", "Epoch: 004/010 | Batch 2500/8290 | Cost: 0.2366\n", "Epoch: 004/010 | Batch 3000/8290 | Cost: 0.3215\n", "Epoch: 004/010 | Batch 3500/8290 | Cost: 0.1784\n", "Epoch: 004/010 | Batch 4000/8290 | Cost: 0.3136\n", "Epoch: 004/010 | Batch 4500/8290 | Cost: 0.3379\n", "Epoch: 004/010 | Batch 5000/8290 | Cost: 0.3069\n", "Epoch: 004/010 | Batch 5500/8290 | Cost: 0.1735\n", "Epoch: 004/010 | Batch 6000/8290 | Cost: 0.1910\n", "Epoch: 004/010 | Batch 6500/8290 | Cost: 0.3131\n", "Epoch: 004/010 | Batch 7000/8290 | Cost: 0.2566\n", "Epoch: 004/010 | Batch 7500/8290 | Cost: 0.2888\n", "Epoch: 004/010 | Batch 8000/8290 | Cost: 0.3298\n", "Epoch: 004/010 | Train: 92.251% | Validation: 91.693%\n", "Time elapsed: 10.01 min\n", "Epoch: 005/010 | Batch 0000/8290 | Cost: 0.2621\n", "Epoch: 005/010 | Batch 0500/8290 | Cost: 0.1341\n", "Epoch: 005/010 | Batch 1000/8290 | Cost: 0.2740\n", "Epoch: 005/010 | Batch 1500/8290 | Cost: 0.2190\n", "Epoch: 005/010 | Batch 2000/8290 | Cost: 0.2355\n", "Epoch: 005/010 | Batch 2500/8290 | Cost: 0.2771\n", "Epoch: 005/010 | Batch 3000/8290 | Cost: 0.3470\n", "Epoch: 005/010 | Batch 3500/8290 | Cost: 0.1613\n", "Epoch: 005/010 | Batch 4000/8290 | Cost: 0.3326\n", "Epoch: 005/010 | Batch 4500/8290 | Cost: 0.2114\n", "Epoch: 005/010 | Batch 5000/8290 | Cost: 0.3249\n", "Epoch: 005/010 | Batch 5500/8290 | Cost: 0.2614\n", "Epoch: 005/010 | Batch 6000/8290 | Cost: 0.2974\n", "Epoch: 005/010 | Batch 6500/8290 | Cost: 0.2653\n", "Epoch: 005/010 | Batch 
7000/8290 | Cost: 0.1659\n", "Epoch: 005/010 | Batch 7500/8290 | Cost: 0.3587\n", "Epoch: 005/010 | Batch 8000/8290 | Cost: 0.1271\n", "Epoch: 005/010 | Train: 92.575% | Validation: 91.995%\n", "Time elapsed: 12.21 min\n", "Epoch: 006/010 | Batch 0000/8290 | Cost: 0.1457\n", "Epoch: 006/010 | Batch 0500/8290 | Cost: 0.2908\n", "Epoch: 006/010 | Batch 1000/8290 | Cost: 0.3151\n", "Epoch: 006/010 | Batch 1500/8290 | Cost: 0.3322\n", "Epoch: 006/010 | Batch 2000/8290 | Cost: 0.2056\n", "Epoch: 006/010 | Batch 2500/8290 | Cost: 0.2625\n", "Epoch: 006/010 | Batch 3000/8290 | Cost: 0.2600\n", "Epoch: 006/010 | Batch 3500/8290 | Cost: 0.3253\n", "Epoch: 006/010 | Batch 4000/8290 | Cost: 0.1884\n", "Epoch: 006/010 | Batch 4500/8290 | Cost: 0.2553\n", "Epoch: 006/010 | Batch 5000/8290 | Cost: 0.3106\n", "Epoch: 006/010 | Batch 5500/8290 | Cost: 0.1887\n", "Epoch: 006/010 | Batch 6000/8290 | Cost: 0.2765\n", "Epoch: 006/010 | Batch 6500/8290 | Cost: 0.1896\n", "Epoch: 006/010 | Batch 7000/8290 | Cost: 0.2351\n", "Epoch: 006/010 | Batch 7500/8290 | Cost: 0.1942\n", "Epoch: 006/010 | Batch 8000/8290 | Cost: 0.2452\n", "Epoch: 006/010 | Train: 92.768% | Validation: 92.084%\n", "Time elapsed: 14.44 min\n", "Epoch: 007/010 | Batch 0000/8290 | Cost: 0.2731\n", "Epoch: 007/010 | Batch 0500/8290 | Cost: 0.1256\n", "Epoch: 007/010 | Batch 1000/8290 | Cost: 0.2282\n", "Epoch: 007/010 | Batch 1500/8290 | Cost: 0.2288\n", "Epoch: 007/010 | Batch 2000/8290 | Cost: 0.1315\n", "Epoch: 007/010 | Batch 2500/8290 | Cost: 0.2518\n", "Epoch: 007/010 | Batch 3000/8290 | Cost: 0.3285\n", "Epoch: 007/010 | Batch 3500/8290 | Cost: 0.2102\n", "Epoch: 007/010 | Batch 4000/8290 | Cost: 0.1955\n", "Epoch: 007/010 | Batch 4500/8290 | Cost: 0.1690\n", "Epoch: 007/010 | Batch 5000/8290 | Cost: 0.1595\n", "Epoch: 007/010 | Batch 5500/8290 | Cost: 0.2186\n", "Epoch: 007/010 | Batch 6000/8290 | Cost: 0.2465\n", "Epoch: 007/010 | Batch 6500/8290 | Cost: 0.2922\n", "Epoch: 007/010 | Batch 7000/8290 | Cost: 
0.2836\n", "Epoch: 007/010 | Batch 7500/8290 | Cost: 0.1863\n", "Epoch: 007/010 | Batch 8000/8290 | Cost: 0.1654\n", "Epoch: 007/010 | Train: 92.966% | Validation: 92.307%\n", "Time elapsed: 16.70 min\n", "Epoch: 008/010 | Batch 0000/8290 | Cost: 0.2479\n", "Epoch: 008/010 | Batch 0500/8290 | Cost: 0.2505\n", "Epoch: 008/010 | Batch 1000/8290 | Cost: 0.3280\n", "Epoch: 008/010 | Batch 1500/8290 | Cost: 0.3119\n", "Epoch: 008/010 | Batch 2000/8290 | Cost: 0.3892\n", "Epoch: 008/010 | Batch 2500/8290 | Cost: 0.3371\n", "Epoch: 008/010 | Batch 3000/8290 | Cost: 0.3909\n", "Epoch: 008/010 | Batch 3500/8290 | Cost: 0.2831\n", "Epoch: 008/010 | Batch 4000/8290 | Cost: 0.2730\n", "Epoch: 008/010 | Batch 4500/8290 | Cost: 0.1258\n", "Epoch: 008/010 | Batch 5000/8290 | Cost: 0.2155\n", "Epoch: 008/010 | Batch 5500/8290 | Cost: 0.2419\n", "Epoch: 008/010 | Batch 6000/8290 | Cost: 0.2309\n", "Epoch: 008/010 | Batch 6500/8290 | Cost: 0.2843\n", "Epoch: 008/010 | Batch 7000/8290 | Cost: 0.2820\n", "Epoch: 008/010 | Batch 7500/8290 | Cost: 0.1245\n", "Epoch: 008/010 | Batch 8000/8290 | Cost: 0.3503\n", "Epoch: 008/010 | Train: 92.978% | Validation: 92.270%\n", "Time elapsed: 18.94 min\n", "Epoch: 009/010 | Batch 0000/8290 | Cost: 0.2116\n", "Epoch: 009/010 | Batch 0500/8290 | Cost: 0.3477\n", "Epoch: 009/010 | Batch 1000/8290 | Cost: 0.1537\n", "Epoch: 009/010 | Batch 1500/8290 | Cost: 0.2932\n", "Epoch: 009/010 | Batch 2000/8290 | Cost: 0.2075\n", "Epoch: 009/010 | Batch 2500/8290 | Cost: 0.2520\n", "Epoch: 009/010 | Batch 3000/8290 | Cost: 0.1347\n", "Epoch: 009/010 | Batch 3500/8290 | Cost: 0.1800\n", "Epoch: 009/010 | Batch 4000/8290 | Cost: 0.2365\n", "Epoch: 009/010 | Batch 4500/8290 | Cost: 0.2445\n", "Epoch: 009/010 | Batch 5000/8290 | Cost: 0.1622\n", "Epoch: 009/010 | Batch 5500/8290 | Cost: 0.1989\n", "Epoch: 009/010 | Batch 6000/8290 | Cost: 0.1404\n", "Epoch: 009/010 | Batch 6500/8290 | Cost: 0.1281\n", "Epoch: 009/010 | Batch 7000/8290 | Cost: 0.3659\n", "Epoch: 
009/010 | Batch 7500/8290 | Cost: 0.2559\n", "Epoch: 009/010 | Batch 8000/8290 | Cost: 0.2351\n", "Epoch: 009/010 | Train: 93.070% | Validation: 92.308%\n", "Time elapsed: 21.15 min\n", "Epoch: 010/010 | Batch 0000/8290 | Cost: 0.1964\n", "Epoch: 010/010 | Batch 0500/8290 | Cost: 0.1686\n", "Epoch: 010/010 | Batch 1000/8290 | Cost: 0.2819\n", "Epoch: 010/010 | Batch 1500/8290 | Cost: 0.1610\n", "Epoch: 010/010 | Batch 2000/8290 | Cost: 0.1473\n", "Epoch: 010/010 | Batch 2500/8290 | Cost: 0.2996\n", "Epoch: 010/010 | Batch 3000/8290 | Cost: 0.2584\n", "Epoch: 010/010 | Batch 3500/8290 | Cost: 0.3147\n", "Epoch: 010/010 | Batch 4000/8290 | Cost: 0.1333\n", "Epoch: 010/010 | Batch 4500/8290 | Cost: 0.2588\n", "Epoch: 010/010 | Batch 5000/8290 | Cost: 0.1896\n", "Epoch: 010/010 | Batch 5500/8290 | Cost: 0.3248\n", "Epoch: 010/010 | Batch 6000/8290 | Cost: 0.3710\n", "Epoch: 010/010 | Batch 6500/8290 | Cost: 0.3223\n", "Epoch: 010/010 | Batch 7000/8290 | Cost: 0.1774\n", "Epoch: 010/010 | Batch 7500/8290 | Cost: 0.3240\n", "Epoch: 010/010 | Batch 8000/8290 | Cost: 0.2755\n", "Epoch: 010/010 | Train: 93.128% | Validation: 92.364%\n", "Time elapsed: 23.37 min\n", "Total Training Time: 23.37 min\n" ] } ], "source": [ "def compute_accuracy(model, data_loader, device):\n", " correct_pred, num_examples = 0, 0\n", " for i, (features, targets) in enumerate(data_loader):\n", " \n", " features = features.to(device)\n", " targets = targets.to(device)\n", "\n", " logits, probas = model(features)\n", " _, predicted_labels = torch.max(probas, 1)\n", " num_examples += targets.size(0)\n", " correct_pred += (predicted_labels == targets).sum()\n", " return correct_pred.float()/num_examples * 100\n", " \n", "\n", "start_time = time.time()\n", "for epoch in range(NUM_EPOCHS):\n", " \n", " model.train()\n", " for batch_idx, (features, targets) in enumerate(train_loader):\n", " \n", " features = features.to(DEVICE)\n", " targets = targets.to(DEVICE)\n", " \n", " ### FORWARD AND BACK PROP\n", 
" logits, probas = model(features)\n", " cost = F.cross_entropy(logits, targets)\n", " optimizer.zero_grad()\n", " \n", " cost.backward()\n", " \n", " ### UPDATE MODEL PARAMETERS\n", " optimizer.step()\n", " \n", " ### LOGGING\n", " if not batch_idx % 500:\n", " print ('Epoch: %03d/%03d | Batch %04d/%04d | Cost: %.4f' \n", " %(epoch+1, NUM_EPOCHS, batch_idx, \n", " len(train_loader), cost))\n", "\n", "\n", " model.eval()\n", " with torch.set_grad_enabled(False): # save memory during inference\n", " print('Epoch: %03d/%03d | Train: %.3f%% | Validation: %.3f%%' % (\n", " epoch+1, NUM_EPOCHS, \n", " compute_accuracy(model, train_loader, device=DEVICE),\n", " compute_accuracy(model, valid_loader, device=DEVICE) ))\n", " \n", " print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))\n", " \n", "print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "paaeEQHQj5xC" }, "source": [ "## Evaluation" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, "base_uri": "https://localhost:8080/", "height": 34 }, "colab_type": "code", "executionInfo": { "elapsed": 6514, "status": "ok", "timestamp": 1524976895054, "user": { "displayName": "Sebastian Raschka", "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg", "userId": "118404394130788869227" }, "user_tz": 240 }, "id": "gzQMWKq5j5xE", "outputId": "de7dc005-5eeb-4177-9f9f-d9b5d1358db9" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Test accuracy: 93.13%\n" ] } ], "source": [ "with torch.set_grad_enabled(False): # save memory during inference\n", " print('Test accuracy: %.2f%%' % (compute_accuracy(model, test_loader, device=DEVICE)))" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjAsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+17YcXAAAQIklEQVR4nO3df4xV5Z3H8c+XoYAiFdgZWQQsFSSRYBb1xh+wUUyzRtBEG+2m/mhmjS6NStJK1TUspkb/IWZbrcQY8UfEDdo0qQZjiNQQE+M/xiuyDiOsuDrbopOZISQUEBxgvvvHHN0pznnO9Z577rnj834lkztzvveZ8+Uynzl37nPPeczdBeC7b1zZDQBoDsIORIKwA5Eg7EAkCDsQifHN3Fl7e7vPnTu3mbsEotLT06N9+/bZaLVcYTezqyT9TlKbpGfcfV3o/nPnzlW1Ws2zSwABlUoltVb303gza5P0hKTlkhZKutHMFtb7/QAUK8/f7BdJ+tjdP3H3QUm/l3RtY9oC0Gh5wj5L0l9GfL032fY3zGylmVXNrDowMJBjdwDyyBP20V4E+MZ7b919g7tX3L3S0dGRY3cA8sgT9r2S5oz4erakz/O1A6AoecL+rqRzzOyHZjZB0k8lvdqYtgA0Wt1Tb+5+3MxWSdqq4am359y9u2GdAWioXPPs7r5F0pYG9QKgQLxdFogEYQciQdiBSBB2IBKEHYgEYQci0dTz2YFWMTg4GKwfPnw4WJ8yZUqwPn5860WLIzsQCcIORIKwA5Eg7EAkCDsQCcIORKL15geAxKFDh4L1N998M1jftGlTau3ll18Ojj127FiwvmjRomA96yrKEydODNaLwJEdiARhByJB2IFIEHYgEoQdiARhByJB2IFItNQ8e9a86u7du1Nrs2fPDo7t7g5f5fqLL74I1tva2lJr06ZNC441G3UF3a99//vfD9bHjQv/Tp48eXJqLdS3JJ1yyinBetb4rPno119/PbUWmgeXpK1btwbrQ0NDwfoll1ySWnv00UeDY7P+3XfccUew3tXVFayHVlstCkd2IBKEHYgEYQciQdiBSBB2IBKEHYgEYQci0VLz7Js3bw7Wb7nlliZ1gkYJvUfgyiuvDI59/vnng/Wrr746WJ8+fXqwHpI1h3/vvfcG61nn2pcxz54r7GbWI+mgpBOSjrt78/8FAGrSiCP7Fe6+rwHfB0CB+JsdiETesLukP5nZe2a2crQ7mNlKM6uaWXVgYCDn7gDUK2/Yl7r7BZKWS7rLzC47+Q7uvsHdK+5e6ejoyLk7APXKFXZ3/zy57Zf0iqSLGtEUgMarO+xmNtnMpnz1uaQrJe1sVGMAGivPq/EzJL2SnKs9XtKL7p5+8nIN5s2bV/fYF198MVhfvnx5sD5p0qRg/cSJE6m1rOV/s875/vLLL4P1rO8fGn/06NFc+846zz/rXPvQfHLWssdlyvp3XX755cH6+++/38h2GqLusLv7J5L+oYG9ACgQU29AJAg7EAnCDkSCsAORIOxAJFrqFNezzjqr7rELFiwI1qdOnVr3984SupQzvpuyLv+dNWVZBo7sQCQIOxAJwg5EgrADkSDsQCQIOxAJwg5EoqXm2dvb24P10NLHe/bsCY698MIL6+oJGM3EiROD9f379zepk9pxZAciQdiBSBB2IBKEHYgEYQciQdiBSBB2IBItNc8+YcKEYH3mzJmptd27dze6HSBV1s8q57MDKA1hByJB2IFIEHYgEoQdiARhByJB2IFItNQ8e5bzzz8/tZZ1Pvt3WVdXV2rtmmuuCY7NOu962bJlwfpll10WrN9+++2ptWnTpgXHtrKsJb6PHDnSpE5ql3lkN7PnzKzfzHaO2DbdzN4wsz3J7dj9XwMiUcvT+OclXXXStvslbXP3cyRtS74G0MIyw+7ub0k6+bnetZI2Jp9vlHRdg/sC0GD1vkA3w917JSm5PSPtjma20syqZlYdGBioc3cA8ir
81Xh33+DuFXevdHR0FL07ACnqDXufmc2UpOS2v3EtAShCvWF/VVJn8nmnpM2NaQdAUTLn2c3sJUnLJLWb2V5Jv5a0TtIfzOw2SX+W9JMim/xKaA32rVu3NqOFUjz22GPB+urVq1NrS5YsCY49cOBAsP7aa68F61u2bAnW16xZk1pbtWpVcOy6deuC9axrtxdpLJ7Pnhl2d78xpfSjBvcCoEC8XRaIBGEHIkHYgUgQdiAShB2IxJg6xTU09bZ+/frg2KGhoWB93Ljyfu9t27YtWL/77ruD9Yceeii1Fpr6kqS2trZgffv27cH6/Pnzg/XHH388tfbAAw8Ex3Z2dgbrixcvDtaLlHWK69GjR5vUSe04sgORIOxAJAg7EAnCDkSCsAORIOxAJAg7EIkxNc++cOHC1Nrx48eDY7NO5Szyssb79u0L1q+//vpg/eabbw7W165dm1ozs+DYLBdccEGu8XfeeWdqLWuevRUvx/yVrNNrW7F3juxAJAg7EAnCDkSCsAORIOxAJAg7EAnCDkRiTM2zT58+ve6xBw8eDNaLnGe/7777gvXTTjstWN+wYUOwnncuvUhZj/tYNXv27GC9r68vWD9x4kRqLesaA/XiyA5EgrADkSDsQCQIOxAJwg5EgrADkSDsQCTG1Dz7+PH1txua1yxa1rzpvHnzgvVTTz21ke001SOPPJJaa29vD47Ney59kc4+++xgPWudgtD7D6ZOnVpXT1kyj+xm9pyZ9ZvZzhHbHjSzz8xsR/KxopDuADRMLU/jn5d01SjbH3X3xcnHlsa2BaDRMsPu7m9J2t+EXgAUKM8LdKvM7IPkaX7qG8vNbKWZVc2sOjAwkGN3APKoN+xPSponabGkXkm/Sbuju29w94q7Vzo6OurcHYC86gq7u/e5+wl3H5L0tKSLGtsWgEarK+xmNnPElz+WtDPtvgBaQ+bEtZm9JGmZpHYz2yvp15KWmdliSS6pR9LPC+zxa3nm2bOuK1+kOXPmBOubNm0K1l944YVg/aabbkqt5XnMatHd3R2sP/XUU6m19evXB8dmXZu9TGeeeWau8f39/am1oubZM38S3P3GUTY/W0AvAArE22WBSBB2IBKEHYgEYQciQdiBSERzimuZU2+rV68O1nt7e4P1zs7OuutZl5nOc3luSdq/P3zaxKWXXppau/XWW3Ptu0xnnHFGrvGfffZZam3BggW5vncajuxAJAg7EAnCDkSCsAORIOxAJAg7EAnCDkQimnn2wcHBBnby7WQtyfzkk08G6/fcc0+w/vbbb6fWDh8+HBx75MiRYD3LrFmzgvUbbrghtVb06bdFyvo/nTRpUrAeOjX4iiuuqKunLBzZgUgQdiAShB2IBGEHIkHYgUgQdiAShB2IxJia6ByrSzbnlbWkc1YdjTduXPg4uWTJkmD9nXfeSa2tWrWqrp6ycGQHIkHYgUgQdiAShB2IBGEHIkHYgUgQdiASY2qeva2tre6xx44da2AnQNjSpUuD9aeffrpJnfy/zCO7mc0xszfNbJeZdZvZL5Lt083sDTPbk9xOK75dAPWq5Wn8cUm/cvdzJV0i6S4zWyjpfknb3P0cSduSrwG0qMywu3uvu29PPj8oaZekWZKulbQxudtGSdcV1SSA/L7VC3RmNlfS+ZLekTTD3Xul4V8IkkZd/MrMVppZ1cyqAwMD+boFULeaw25mp0n6o6Rfuvtfax3n7hvcveLulY6Ojnp6BNAANYXdzL6n4aBvcveXk819ZjYzqc+U1F9MiwAaIXPqzYbX/H1W0i53/+2I0quSOiWtS243F9LhCHlOcWXqDc2UNfX28MMPp9YOHDgQHHv66afX1VMt6Vkq6WeSusxsR7JtjYZD/gczu03SnyX9pK4OADRFZtjd/W1JllL+UWPbAVAU3i4LRIKwA5Eg7EAkCDsQCcIORCKaU1yfeeaZYH3y5MnB+sSJE1NrWcvzzpgxI9e+Mfacd955dY/dvXt3sH7xxRfX9X05sgO
RIOxAJAg7EAnCDkSCsAORIOxAJAg7EIkxNc8+ZcqU1NratWuDY5944olgfePGjcF6kbLOT160aFGwvnDhwtTaggULgmPLvHrQhAkTgvV6z9tuBUePHq17bE9PT7DOPDuAIMIORIKwA5Eg7EAkCDsQCcIORIKwA5Ewd2/aziqViler1abtb6Tjx48H6319fcH60NBQau3QoUPBsZ9++mmw/uGHHwbrH330UbDe3d2dWuvq6gqOPXjwYLCOYgwvxzC6rP/v+fPnp9YqlYqq1eqo35wjOxAJwg5EgrADkSDsQCQIOxAJwg5EgrADkahlffY5kl6Q9PeShiRtcPffmdmDkv5V0kBy1zXuvqWoRvPKWtt91qxZhe373HPPDdZXrFhR2L7LNjg4mFoLvXchZlnrENSrlotXHJf0K3ffbmZTJL1nZm8ktUfd/T8K6QxAQ9WyPnuvpN7k84NmtktScYdBAIX4Vn+zm9lcSedLeifZtMrMPjCz58xsWsqYlWZWNbPqwMDAaHcB0AQ1h93MTpP0R0m/dPe/SnpS0jxJizV85P/NaOPcfYO7V9y9Uub1zoDY1RR2M/uehoO+yd1fliR373P3E+4+JOlpSRcV1yaAvDLDbsOn5zwraZe7/3bE9pkj7vZjSTsb3x6ARqnl1filkn4mqcvMdiTb1ki60cwWS3JJPZJ+XkiHGNOyLheN5qnl1fi3JY12fmzLzqkD+CbeQQdEgrADkSDsQCQIOxAJwg5EgrADkSDsQCQIOxAJwg5EgrADkSDsQCQIOxAJwg5EgrADkWjqks1mNiDpf0dsape0r2kNfDut2lur9iXRW70a2dsP3H3U6781Nezf2LlZ1d0rpTUQ0Kq9tWpfEr3Vq1m98TQeiARhByJRdtg3lLz/kFbtrVX7kuitXk3prdS/2QE0T9lHdgBNQtiBSJQSdjO7ysz+28w+NrP7y+ghjZn1mFmXme0ws2rJvTxnZv1mtnPEtulm9oaZ7UluR11jr6TeHjSzz5LHboeZlbIWtZnNMbM3zWyXmXWb2S+S7aU+doG+mvK4Nf1vdjNrk/SRpH+StFfSu5JudPcPm9pICjPrkVRx99LfgGFml0k6JOkFd1+UbHtE0n53X5f8opzm7v/WIr09KOlQ2ct4J6sVzRy5zLik6yT9i0p87AJ9/bOa8LiVcWS/SNLH7v6Juw9K+r2ka0voo+W5+1uS9p+0+VpJG5PPN2r4h6XpUnprCe7e6+7bk88PSvpqmfFSH7tAX01RRthnSfrLiK/3qrXWe3dJfzKz98xsZdnNjGKGu/dKwz88ks4ouZ+TZS7j3UwnLTPeMo9dPcuf51VG2EdbSqqV5v+WuvsFkpZLuit5uora1LSMd7OMssx4S6h3+fO8ygj7XklzRnw9W9LnJfQxKnf/PLntl/SKWm8p6r6vVtBNbvtL7udrrbSM92jLjKsFHrsylz8vI+zvSjrHzH5oZhMk/VTSqyX08Q1mNjl54URmNlnSlWq9pahfldSZfN4paXOJvfyNVlnGO22ZcZX82JW+/Lm7N/1D0goNvyL/P5L+vYweUvo6W9J/JR/dZfcm6SUNP607puFnRLdJ+jtJ2yTtSW6nt1Bv/ympS9IHGg7WzJJ6+0cN/2n4gaQdyceKsh+7QF9Nedx4uywQCd5BB0SCsAORIOxAJAg7EAnCDkSCsAORIOxAJP4PaZq41BtvbUoAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Take the first batch from the test set\n", "features, targets = next(iter(test_loader))\n", "\n", "# Reorder the image at index 5 from CHW to HWC, then drop the single channel axis for plotting\n", "hwc_img = features[5].permute(1, 2, 0)\n", "hw_img = np.squeeze(hwc_img.numpy(), axis=2)\n", "plt.imshow(hw_img, cmap='Greys');" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Probability Washing Machine 99.83%\n" ] } ], "source": [ "model.eval()\n", "# Classify the first image of the batch (index 0); [0, None] keeps the batch dimension\n", "logits, probas = model(features.to(device)[0, None])\n", "# Class index 4 is reported as the washing machine class\n", "print('Probability Washing Machine %.2f%%' % (probas[0][4]*100))" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch 1.3.1\n", "numpy 1.17.4\n", "PIL.Image 6.2.1\n", "torchvision 0.4.2\n", "matplotlib 3.1.0\n", "pandas 0.24.2\n", "\n" ] } ], "source": [ "%watermark -iv" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "default_view": {}, "name": "convnet-vgg16.ipynb", "provenance": [], "version": "0.3.2", "views": {} }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" }, "toc": { "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": true, "toc_position": { "height": "calc(100% - 180px)", "left": "10px", "top": "150px", "width": "371px" }, "toc_section_display": true, "toc_window_display": true } }, "nbformat": 4, "nbformat_minor": 4 }