{ "cells": [ { "cell_type": "markdown", "id": "a73b2bf8-3926-4861-b274-329a055ce241", "metadata": {}, "source": [ "# Convolutional Neural Networks (CNNs) with PyTorch\n", "\n", "**Authors:** Jeffrey Huang and Alex Michels\n", "\n", "In this notebook, we will use PyTorch CNNs to recognize handwritten digits (the MNIST dataset) from images. We use CNNs in this use case because the individual values of pixels don't tell us very much, but convolutions can help us extract local features (such as edges and strokes) from neighboring pixels." ] },
{ "cell_type": "code", "execution_count": 1, "id": "67b15a5d-2602-4912-b674-3b5cf1473c95", "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", "from pathlib import Path\n", "import time\n", "import torch\n", "from torchvision import datasets\n", "from torchvision.transforms import ToTensor\n", "from torch.utils.data import DataLoader\n", "from torch import optim\n", "from torch.autograd import Variable\n", "import torch.nn as nn" ] },
{ "cell_type": "markdown", "id": "4d437ff8", "metadata": {}, "source": [ "## Data Wrangling\n", "\n", "First, we need to download the built-in torchvision MNIST dataset:" ] },
{ "cell_type": "code", "execution_count": 2, "id": "5bb898ec-3b9c-41f3-9369-8919f80d71a4", "metadata": {}, "outputs": [], "source": [
"# Both splits live under ./data. `ToTensor()` turns each image into a float\n",
"# tensor of shape (1, 28, 28) with pixel values scaled to [0, 1].\n",
"train_data = datasets.MNIST(\n",
"    root='data',\n",
"    train=True,\n",
"    transform=ToTensor(),\n",
"    download=True,\n",
")\n",
"# download=True does nothing if the files are already present, but keeps this\n",
"# split loadable on its own instead of relying on the training-split download.\n",
"test_data = datasets.MNIST(\n",
"    root='data',\n",
"    train=False,\n",
"    transform=ToTensor(),\n",
"    download=True,\n",
")" ] },
{ "cell_type": "markdown", "id": "98009d0e", "metadata": {}, "source": [ "Next, we will examine the data. Each sample is an `(image, label)` pair: the image is a 1x28x28 tensor of grayscale pixel intensities scaled to [0, 1], and the label is the digit that the image shows. 
We can also plot our data to see the images they make:" ] }, { "cell_type": "code", "execution_count": 3, "id": "5bd8ea61-60d1-45ed-8a39-43b9245b7172", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 
0.0000, 0.0000, 0.0000, 0.3294, 0.7255,\n", " 0.6235, 0.5922, 0.2353, 0.1412, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.8706, 0.9961,\n", " 0.9961, 0.9961, 0.9961, 0.9451, 0.7765, 0.7765, 0.7765, 0.7765,\n", " 0.7765, 0.7765, 0.7765, 0.7765, 0.6667, 0.2039, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2627, 0.4471,\n", " 0.2824, 0.4471, 0.6392, 0.8902, 0.9961, 0.8824, 0.9961, 0.9961,\n", " 0.9961, 0.9804, 0.8980, 0.9961, 0.9961, 0.5490, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0667, 0.2588, 0.0549, 0.2627, 0.2627,\n", " 0.2627, 0.2314, 0.0824, 0.9255, 0.9961, 0.4157, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.3255, 0.9922, 0.8196, 0.0706, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0863, 0.9137, 1.0000, 0.3255, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.5059, 0.9961, 0.9333, 0.1725, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.2314, 0.9765, 0.9961, 0.2431, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 
0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.5216, 0.9961, 0.7333, 0.0196, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0353,\n", " 0.8039, 0.9725, 0.2275, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.4941,\n", " 0.9961, 0.7137, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2941, 0.9843,\n", " 0.9412, 0.2235, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0745, 0.8667, 0.9961,\n", " 0.6510, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0118, 0.7961, 0.9961, 0.8588,\n", " 0.1373, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.1490, 0.9961, 0.9961, 0.3020,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.1216, 0.8784, 0.9961, 0.4510, 0.0039,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 
0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.5216, 0.9961, 0.9961, 0.2039, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.2392, 0.9490, 0.9961, 0.9961, 0.2039, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.4745, 0.9961, 0.9961, 0.8588, 0.1569, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.4745, 0.9961, 0.8118, 0.0706, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000]]]),\n", " 7)" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test_data[0]" ] }, { "cell_type": "code", "execution_count": 4, "id": "57cb6a65-df9c-4835-8dab-aa2cc30e5244", "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
"# A 5x5 grid of randomly chosen training images, each titled with its label.\n",
"cols, rows = 5, 5\n",
"figure, axes = plt.subplots(rows, cols, figsize=(10, 8))\n",
"for ax in axes.flat:  # row-major, i.e. the same order as subplot indices 1..25\n",
"    sample_idx = torch.randint(len(train_data), size=(1,)).item()\n",
"    img, label = train_data[sample_idx]\n",
"    ax.set_title(label)\n",
"    ax.axis(\"off\")\n",
"    ax.imshow(img.squeeze(), cmap=\"gray\")\n",
"plt.show()" ] },
{ "cell_type": "markdown", "id": "832bcf0e", "metadata": {}, "source": [ "A `DataLoader` wraps a dataset and serves it in shuffled mini-batches (here 100 images each), so the model is trained and evaluated batch by batch instead of on the whole dataset at once." ] },
{ "cell_type": "code", "execution_count": 5, "id": "d99724e6-e386-45c7-af80-e3d21e2a097e", "metadata": {}, "outputs": [], "source": [
"# One loader per split, each yielding shuffled mini-batches of 100 samples.\n",
"loaders = {\n",
"    split: DataLoader(dataset, batch_size=100, shuffle=True, num_workers=1)\n",
"    for split, dataset in (('train', train_data), ('test', test_data))\n",
"}" ] },
{ "cell_type": "markdown", "id": "e470faa8", "metadata": {}, "source": [ "## Creating the Model\n", "\n", "Next, we define the model as well as its forward pass function. This is a fairly simple CNN with 2 convolutional blocks (convolution -> ReLU -> max-pooling) followed by a single linear layer that maps the extracted features to one score per digit class."
] }, { "cell_type": "code", "execution_count": 6, "id": "0d1ea6d6-e052-48dd-93a6-b521b38dadbd", "metadata": {}, "outputs": [], "source": [
"class CNN(nn.Module):\n",
"    \"\"\"Small two-block CNN for 28x28 grayscale digit images.\n",
"\n",
"    Shape flow for an input batch of shape (N, 1, 28, 28):\n",
"      conv1: 5x5 conv (padding 2 keeps 28x28) -> ReLU -> 2x2 max-pool -> (N, 16, 14, 14)\n",
"      conv2: 5x5 conv (padding 2 keeps 14x14) -> ReLU -> 2x2 max-pool -> (N, 32, 7, 7)\n",
"      flatten -> (N, 32 * 7 * 7) -> linear -> (N, 10) class scores (logits)\n",
"    \"\"\"\n",
"\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.conv1 = nn.Sequential(\n",
"            nn.Conv2d(\n",
"                in_channels=1,    # (int) number of channels in the input image: 1 because the input is grayscale\n",
"                out_channels=16,  # (int) number of channels (feature maps) produced by the convolution\n",
"                kernel_size=5,    # (int or tuple) size of the convolving kernel\n",
"                stride=1,         # (int or tuple, optional) stride of the convolution, default is 1\n",
"                padding=2,        # (int or tuple, optional) zero-padding added to both sides of the input, default is 0\n",
"            ),\n",
"            nn.ReLU(),\n",
"            nn.MaxPool2d(kernel_size=2),\n",
"        )\n",
"        self.conv2 = nn.Sequential(\n",
"            nn.Conv2d(16, 32, 5, 1, 2),  # in_channels, out_channels, kernel_size, stride, padding\n",
"            nn.ReLU(),\n",
"            nn.MaxPool2d(2),\n",
"        )\n",
"        # 32 feature maps of 7x7 remain after two 2x2 poolings of a 28x28 input; 10 digit classes.\n",
"        self.out = nn.Linear(32 * 7 * 7, 10)\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Return `(logits, features)`: logits of shape (N, 10) and the flattened\n",
"        convolutional features of shape (N, 32 * 7 * 7) that were fed to `self.out`.\"\"\"\n",
"        x = self.conv1(x)\n",
"        x = self.conv2(x)\n",
"        x = x.view(x.size(0), -1)  # flatten everything except the batch dimension\n",
"        output = self.out(x)\n",
"        return output, x" ] },
{ "cell_type": "code", "execution_count": 7, "id": "bf629382-9dc6-417d-ac6e-9bfc8d929e3f", "metadata": {}, "outputs": [], "source": [ "cnn = CNN()" ] },
{ "cell_type": "markdown", "id": "703a643c", "metadata": {}, "source": [ "Next we define the loss function and optimizer. Cross-entropy loss is the standard loss for multi-class classification problems; `nn.CrossEntropyLoss` applies a log-softmax internally, so it takes the model's raw scores (logits) directly, which is why the model has no softmax layer. Adam is a popular and powerful optimizer that extends stochastic gradient descent with per-parameter adaptive learning rates." ] },
{ "cell_type": "code", "execution_count": 8, "id": "8242ffcc-d539-4bfb-bf5c-d6a48734a9f0", "metadata": {}, "outputs": [], "source": [
"loss_func = nn.CrossEntropyLoss()\n",
"# lr=0.01 is ten times Adam's default learning rate of 0.001.\n",
"optimizer = optim.Adam(cnn.parameters(), lr=0.01)" ] },
{ "cell_type": "markdown", "id": "df753edf", "metadata": {}, "source": [ "It is worthwhile to check and make sure what device we are running on. 
" ] }, { "cell_type": "code", "execution_count": 9, "id": "3c7e42c3-a6ff-4bf2-8952-799307d39269", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using device: cpu\n", "\n" ] } ], "source": [
"# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.\n",
"# NOTE: the training and test cells below never call `.to(device)`, so the model\n",
"# and batches stay on the CPU either way; this cell only reports the hardware.\n",
"use_cuda = torch.cuda.is_available()\n",
"device = torch.device('cuda' if use_cuda else 'cpu')\n",
"\n",
"print('Using device:', device)\n",
"print()\n",
"\n",
"# Extra details when running on CUDA (memory figures converted from bytes to GB).\n",
"if device.type == 'cuda':\n",
"    print(torch.cuda.get_device_name(0))\n",
"    print('Memory Usage:')\n",
"    for stat_name, n_bytes in (('Allocated:', torch.cuda.memory_allocated(0)),\n",
"                               ('Cached: ', torch.cuda.memory_reserved(0))):\n",
"        print(stat_name, round(n_bytes / 1024**3, 1), 'GB')" ] },
{ "cell_type": "markdown", "id": "b29c3200", "metadata": {}, "source": [ "## Training the Model\n", "\n", "Next is the training step. \n", "\n", "**As this is a convolutional neural network being trained on a large image dataset, expect training time to take much longer than for a simple linear regression model. 
It is also worthwhile to time the training loop.**" ] }, { "cell_type": "code", "execution_count": 10, "id": "61975a57-f7ae-4a47-90ad-27f7b03b89c1", "metadata": { "scrolled": true, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch [1/5], Step[1/600], Loss: 2.3010\n", "Epoch [1/5], Step[2/600], Loss: 2.7873\n", "Epoch [1/5], Step[3/600], Loss: 2.2077\n", "Epoch [1/5], Step[8/600], Loss: 1.4460\n", "Epoch [1/5], Step[9/600], Loss: 1.1443\n", "Epoch [1/5], Step[10/600], Loss: 0.8667\n", "Epoch [1/5], Step[11/600], Loss: 0.9214\n", "Epoch [1/5], Step[16/600], Loss: 1.0815\n", "Epoch [1/5], Step[17/600], Loss: 0.5077\n", "Epoch [1/5], Step[18/600], Loss: 1.0005\n", "Epoch [1/5], Step[19/600], Loss: 0.6165\n", "Epoch [1/5], Step[24/600], Loss: 0.4259\n", "Epoch [1/5], Step[25/600], Loss: 0.3769\n", "Epoch [1/5], Step[26/600], Loss: 0.3689\n", "Epoch [1/5], Step[27/600], Loss: 0.3474\n", "Epoch [1/5], Step[128/600], Loss: 0.0309\n", "Epoch [1/5], Step[129/600], Loss: 0.0608\n", "Epoch [1/5], Step[130/600], Loss: 0.0554\n", "Epoch [1/5], Step[131/600], Loss: 0.2005\n", "Epoch [1/5], Step[136/600], Loss: 0.1554\n", "Epoch [1/5], Step[137/600], Loss: 0.2146\n", "Epoch [1/5], Step[138/600], Loss: 0.0350\n", "Epoch [1/5], Step[139/600], Loss: 0.0702\n", "Epoch [1/5], Step[144/600], Loss: 0.1806\n", "Epoch [1/5], Step[145/600], Loss: 0.0505\n", "Epoch [1/5], Step[146/600], Loss: 0.0413\n", "Epoch [1/5], Step[147/600], Loss: 0.0866\n", "Epoch [1/5], Step[152/600], Loss: 0.1066\n", "Epoch [1/5], Step[153/600], Loss: 0.0252\n", "Epoch [1/5], Step[154/600], Loss: 0.1037\n", "Epoch [1/5], Step[155/600], Loss: 0.0856\n", "Epoch [1/5], Step[256/600], Loss: 0.0678\n", "Epoch [1/5], Step[257/600], Loss: 0.0756\n", "Epoch [1/5], Step[258/600], Loss: 0.1582\n", "Epoch [1/5], Step[259/600], Loss: 0.1289\n", "Epoch [1/5], Step[264/600], Loss: 0.0499\n", "Epoch [1/5], Step[265/600], Loss: 0.0876\n", "Epoch [1/5], Step[266/600], Loss: 0.0833\n", "Epoch 
[1/5], Step[267/600], Loss: 0.0905\n", "Epoch [1/5], Step[272/600], Loss: 0.1186\n", "Epoch [1/5], Step[273/600], Loss: 0.0602\n", "Epoch [1/5], Step[274/600], Loss: 0.0861\n", "Epoch [1/5], Step[275/600], Loss: 0.1177\n", "Epoch [1/5], Step[280/600], Loss: 0.0497\n", "Epoch [1/5], Step[281/600], Loss: 0.0421\n", "Epoch [1/5], Step[282/600], Loss: 0.1780\n", "Epoch [1/5], Step[283/600], Loss: 0.0327\n", "Epoch [1/5], Step[384/600], Loss: 0.0285\n", "Epoch [1/5], Step[385/600], Loss: 0.0828\n", "Epoch [1/5], Step[386/600], Loss: 0.0458\n", "Epoch [1/5], Step[387/600], Loss: 0.0500\n", "Epoch [1/5], Step[392/600], Loss: 0.2066\n", "Epoch [1/5], Step[393/600], Loss: 0.1402\n", "Epoch [1/5], Step[394/600], Loss: 0.0493\n", "Epoch [1/5], Step[395/600], Loss: 0.0344\n", "Epoch [1/5], Step[400/600], Loss: 0.0806\n", "Epoch [1/5], Step[401/600], Loss: 0.0828\n", "Epoch [1/5], Step[402/600], Loss: 0.1046\n", "Epoch [1/5], Step[403/600], Loss: 0.0535\n", "Epoch [1/5], Step[408/600], Loss: 0.0878\n", "Epoch [1/5], Step[409/600], Loss: 0.0360\n", "Epoch [1/5], Step[410/600], Loss: 0.0552\n", "Epoch [1/5], Step[411/600], Loss: 0.0530\n", "Epoch [1/5], Step[512/600], Loss: 0.0371\n", "Epoch [1/5], Step[513/600], Loss: 0.0273\n", "Epoch [1/5], Step[514/600], Loss: 0.0427\n", "Epoch [1/5], Step[515/600], Loss: 0.0510\n", "Epoch [1/5], Step[520/600], Loss: 0.0423\n", "Epoch [1/5], Step[521/600], Loss: 0.0098\n", "Epoch [1/5], Step[522/600], Loss: 0.0454\n", "Epoch [1/5], Step[523/600], Loss: 0.0378\n", "Epoch [1/5], Step[528/600], Loss: 0.0534\n", "Epoch [1/5], Step[529/600], Loss: 0.1308\n", "Epoch [1/5], Step[530/600], Loss: 0.0389\n", "Epoch [1/5], Step[531/600], Loss: 0.0375\n", "Epoch [1/5], Step[536/600], Loss: 0.0539\n", "Epoch [1/5], Step[537/600], Loss: 0.0870\n", "Epoch [1/5], Step[538/600], Loss: 0.0444\n", "Epoch [1/5], Step[539/600], Loss: 0.0144\n", "Epoch [2/5], Step[1/600], Loss: 0.1345\n", "Epoch [2/5], Step[2/600], Loss: 0.1259\n", "Epoch [2/5], Step[3/600], Loss: 
0.0329\n", "Epoch [2/5], Step[8/600], Loss: 0.0653\n", "Epoch [2/5], Step[9/600], Loss: 0.0308\n", "Epoch [2/5], Step[10/600], Loss: 0.0378\n", "Epoch [2/5], Step[11/600], Loss: 0.0493\n", "Epoch [2/5], Step[16/600], Loss: 0.1019\n", "Epoch [2/5], Step[17/600], Loss: 0.0834\n", "Epoch [2/5], Step[18/600], Loss: 0.0103\n", "Epoch [2/5], Step[19/600], Loss: 0.0963\n", "Epoch [2/5], Step[24/600], Loss: 0.0314\n", "Epoch [2/5], Step[25/600], Loss: 0.0304\n", "Epoch [2/5], Step[26/600], Loss: 0.0737\n", "Epoch [2/5], Step[27/600], Loss: 0.0337\n", "Epoch [2/5], Step[128/600], Loss: 0.0157\n", "Epoch [2/5], Step[129/600], Loss: 0.0345\n", "Epoch [2/5], Step[130/600], Loss: 0.0390\n", "Epoch [2/5], Step[131/600], Loss: 0.0328\n", "Epoch [2/5], Step[136/600], Loss: 0.0401\n", "Epoch [2/5], Step[137/600], Loss: 0.0749\n", "Epoch [2/5], Step[138/600], Loss: 0.0254\n", "Epoch [2/5], Step[139/600], Loss: 0.0194\n", "Epoch [2/5], Step[144/600], Loss: 0.0575\n", "Epoch [2/5], Step[145/600], Loss: 0.1244\n", "Epoch [2/5], Step[146/600], Loss: 0.0719\n", "Epoch [2/5], Step[147/600], Loss: 0.2173\n", "Epoch [2/5], Step[152/600], Loss: 0.0419\n", "Epoch [2/5], Step[153/600], Loss: 0.0076\n", "Epoch [2/5], Step[154/600], Loss: 0.0256\n", "Epoch [2/5], Step[155/600], Loss: 0.0540\n", "Epoch [2/5], Step[256/600], Loss: 0.0578\n", "Epoch [2/5], Step[257/600], Loss: 0.0705\n", "Epoch [2/5], Step[258/600], Loss: 0.0657\n", "Epoch [2/5], Step[259/600], Loss: 0.0462\n", "Epoch [2/5], Step[264/600], Loss: 0.0519\n", "Epoch [2/5], Step[265/600], Loss: 0.0296\n", "Epoch [2/5], Step[266/600], Loss: 0.0405\n", "Epoch [2/5], Step[267/600], Loss: 0.0738\n", "Epoch [2/5], Step[272/600], Loss: 0.0106\n", "Epoch [2/5], Step[273/600], Loss: 0.0147\n", "Epoch [2/5], Step[274/600], Loss: 0.0693\n", "Epoch [2/5], Step[275/600], Loss: 0.1410\n", "Epoch [2/5], Step[280/600], Loss: 0.0135\n", "Epoch [2/5], Step[281/600], Loss: 0.0027\n", "Epoch [2/5], Step[282/600], Loss: 0.0105\n", "Epoch [2/5], 
Step[283/600], Loss: 0.0218\n", "Epoch [2/5], Step[384/600], Loss: 0.1432\n", "Epoch [2/5], Step[385/600], Loss: 0.0150\n", "Epoch [2/5], Step[386/600], Loss: 0.1172\n", "Epoch [2/5], Step[387/600], Loss: 0.0214\n", "Epoch [2/5], Step[392/600], Loss: 0.0736\n", "Epoch [2/5], Step[393/600], Loss: 0.0790\n", "Epoch [2/5], Step[394/600], Loss: 0.0813\n", "Epoch [2/5], Step[395/600], Loss: 0.0265\n", "Epoch [2/5], Step[400/600], Loss: 0.0591\n", "Epoch [2/5], Step[401/600], Loss: 0.0764\n", "Epoch [2/5], Step[402/600], Loss: 0.1025\n", "Epoch [2/5], Step[403/600], Loss: 0.0495\n", "Epoch [2/5], Step[408/600], Loss: 0.0133\n", "Epoch [2/5], Step[409/600], Loss: 0.0040\n", "Epoch [2/5], Step[410/600], Loss: 0.0600\n", "Epoch [2/5], Step[411/600], Loss: 0.0564\n", "Epoch [2/5], Step[512/600], Loss: 0.0248\n", "Epoch [2/5], Step[513/600], Loss: 0.1097\n", "Epoch [2/5], Step[514/600], Loss: 0.0120\n", "Epoch [2/5], Step[515/600], Loss: 0.0348\n", "Epoch [2/5], Step[520/600], Loss: 0.0252\n", "Epoch [2/5], Step[521/600], Loss: 0.0141\n", "Epoch [2/5], Step[522/600], Loss: 0.0725\n", "Epoch [2/5], Step[523/600], Loss: 0.0303\n", "Epoch [2/5], Step[528/600], Loss: 0.0250\n", "Epoch [2/5], Step[529/600], Loss: 0.0133\n", "Epoch [2/5], Step[530/600], Loss: 0.0673\n", "Epoch [2/5], Step[531/600], Loss: 0.1008\n", "Epoch [2/5], Step[536/600], Loss: 0.0387\n", "Epoch [2/5], Step[537/600], Loss: 0.0968\n", "Epoch [2/5], Step[538/600], Loss: 0.0580\n", "Epoch [2/5], Step[539/600], Loss: 0.0297\n", "Epoch [3/5], Step[1/600], Loss: 0.0949\n", "Epoch [3/5], Step[2/600], Loss: 0.0547\n", "Epoch [3/5], Step[3/600], Loss: 0.0498\n", "Epoch [3/5], Step[8/600], Loss: 0.0582\n", "Epoch [3/5], Step[9/600], Loss: 0.0140\n", "Epoch [3/5], Step[10/600], Loss: 0.0189\n", "Epoch [3/5], Step[11/600], Loss: 0.0566\n", "Epoch [3/5], Step[16/600], Loss: 0.0660\n", "Epoch [3/5], Step[17/600], Loss: 0.0192\n", "Epoch [3/5], Step[18/600], Loss: 0.0157\n", "Epoch [3/5], Step[19/600], Loss: 0.0537\n", 
"Epoch [3/5], Step[24/600], Loss: 0.0737\n", "Epoch [3/5], Step[25/600], Loss: 0.0106\n", "Epoch [3/5], Step[26/600], Loss: 0.0207\n", "Epoch [3/5], Step[27/600], Loss: 0.0169\n", "Epoch [3/5], Step[128/600], Loss: 0.0203\n", "Epoch [3/5], Step[129/600], Loss: 0.0098\n", "Epoch [3/5], Step[130/600], Loss: 0.0613\n", "Epoch [3/5], Step[131/600], Loss: 0.0507\n", "Epoch [3/5], Step[136/600], Loss: 0.1746\n", "Epoch [3/5], Step[137/600], Loss: 0.0173\n", "Epoch [3/5], Step[138/600], Loss: 0.0164\n", "Epoch [3/5], Step[139/600], Loss: 0.0727\n", "Epoch [3/5], Step[144/600], Loss: 0.0210\n", "Epoch [3/5], Step[145/600], Loss: 0.0289\n", "Epoch [3/5], Step[146/600], Loss: 0.0423\n", "Epoch [3/5], Step[147/600], Loss: 0.1073\n", "Epoch [3/5], Step[152/600], Loss: 0.0219\n", "Epoch [3/5], Step[153/600], Loss: 0.0165\n", "Epoch [3/5], Step[154/600], Loss: 0.0172\n", "Epoch [3/5], Step[155/600], Loss: 0.0520\n", "Epoch [3/5], Step[256/600], Loss: 0.0422\n", "Epoch [3/5], Step[257/600], Loss: 0.0447\n", "Epoch [3/5], Step[258/600], Loss: 0.0892\n", "Epoch [3/5], Step[259/600], Loss: 0.0150\n", "Epoch [3/5], Step[264/600], Loss: 0.0835\n", "Epoch [3/5], Step[265/600], Loss: 0.0137\n", "Epoch [3/5], Step[266/600], Loss: 0.1238\n", "Epoch [3/5], Step[267/600], Loss: 0.0161\n", "Epoch [3/5], Step[272/600], Loss: 0.0355\n", "Epoch [3/5], Step[273/600], Loss: 0.1550\n", "Epoch [3/5], Step[274/600], Loss: 0.0348\n", "Epoch [3/5], Step[275/600], Loss: 0.0518\n", "Epoch [3/5], Step[280/600], Loss: 0.0117\n", "Epoch [3/5], Step[281/600], Loss: 0.0162\n", "Epoch [3/5], Step[282/600], Loss: 0.0951\n", "Epoch [3/5], Step[283/600], Loss: 0.0658\n", "Epoch [3/5], Step[384/600], Loss: 0.1450\n", "Epoch [3/5], Step[385/600], Loss: 0.1073\n", "Epoch [3/5], Step[386/600], Loss: 0.0034\n", "Epoch [3/5], Step[387/600], Loss: 0.1261\n", "Epoch [3/5], Step[392/600], Loss: 0.0082\n", "Epoch [3/5], Step[393/600], Loss: 0.0144\n", "Epoch [3/5], Step[394/600], Loss: 0.0348\n", "Epoch [3/5], 
Step[395/600], Loss: 0.0034\n", "Epoch [3/5], Step[400/600], Loss: 0.0244\n", "Epoch [3/5], Step[401/600], Loss: 0.0274\n", "Epoch [3/5], Step[402/600], Loss: 0.0612\n", "Epoch [3/5], Step[403/600], Loss: 0.0084\n", "Epoch [3/5], Step[408/600], Loss: 0.0263\n", "Epoch [3/5], Step[409/600], Loss: 0.1690\n", "Epoch [3/5], Step[410/600], Loss: 0.0497\n", "Epoch [3/5], Step[411/600], Loss: 0.0342\n", "Epoch [3/5], Step[512/600], Loss: 0.0184\n", "Epoch [3/5], Step[513/600], Loss: 0.0869\n", "Epoch [3/5], Step[514/600], Loss: 0.0501\n", "Epoch [3/5], Step[515/600], Loss: 0.0334\n", "Epoch [3/5], Step[520/600], Loss: 0.0065\n", "Epoch [3/5], Step[521/600], Loss: 0.0680\n", "Epoch [3/5], Step[522/600], Loss: 0.0803\n", "Epoch [3/5], Step[523/600], Loss: 0.1726\n", "Epoch [3/5], Step[528/600], Loss: 0.0060\n", "Epoch [3/5], Step[529/600], Loss: 0.0189\n", "Epoch [3/5], Step[530/600], Loss: 0.1532\n", "Epoch [3/5], Step[531/600], Loss: 0.0268\n", "Epoch [3/5], Step[536/600], Loss: 0.0517\n", "Epoch [3/5], Step[537/600], Loss: 0.0222\n", "Epoch [3/5], Step[538/600], Loss: 0.0674\n", "Epoch [3/5], Step[539/600], Loss: 0.0449\n", "Epoch [4/5], Step[1/600], Loss: 0.0334\n", "Epoch [4/5], Step[2/600], Loss: 0.0218\n", "Epoch [4/5], Step[3/600], Loss: 0.0115\n", "Epoch [4/5], Step[8/600], Loss: 0.0506\n", "Epoch [4/5], Step[9/600], Loss: 0.0566\n", "Epoch [4/5], Step[10/600], Loss: 0.0143\n", "Epoch [4/5], Step[11/600], Loss: 0.0344\n", "Epoch [4/5], Step[16/600], Loss: 0.0552\n", "Epoch [4/5], Step[17/600], Loss: 0.0691\n", "Epoch [4/5], Step[18/600], Loss: 0.0226\n", "Epoch [4/5], Step[19/600], Loss: 0.0551\n", "Epoch [4/5], Step[24/600], Loss: 0.0192\n", "Epoch [4/5], Step[25/600], Loss: 0.0141\n", "Epoch [4/5], Step[26/600], Loss: 0.0177\n", "Epoch [4/5], Step[27/600], Loss: 0.0807\n", "Epoch [4/5], Step[128/600], Loss: 0.0812\n", "Epoch [4/5], Step[129/600], Loss: 0.0597\n", "Epoch [4/5], Step[130/600], Loss: 0.0027\n", "Epoch [4/5], Step[131/600], Loss: 0.0721\n", "Epoch 
[4/5], Step[136/600], Loss: 0.0038\n", "Epoch [4/5], Step[137/600], Loss: 0.0641\n", "Epoch [4/5], Step[138/600], Loss: 0.0209\n", "Epoch [4/5], Step[139/600], Loss: 0.0039\n", "Epoch [4/5], Step[144/600], Loss: 0.0399\n", "Epoch [4/5], Step[145/600], Loss: 0.0050\n", "Epoch [4/5], Step[146/600], Loss: 0.0550\n", "Epoch [4/5], Step[147/600], Loss: 0.0069\n", "Epoch [4/5], Step[152/600], Loss: 0.0056\n", "Epoch [4/5], Step[153/600], Loss: 0.0557\n", "Epoch [4/5], Step[154/600], Loss: 0.1290\n", "Epoch [4/5], Step[155/600], Loss: 0.0248\n", "Epoch [4/5], Step[256/600], Loss: 0.0065\n", "Epoch [4/5], Step[257/600], Loss: 0.0109\n", "Epoch [4/5], Step[258/600], Loss: 0.0144\n", "Epoch [4/5], Step[259/600], Loss: 0.0839\n", "Epoch [4/5], Step[264/600], Loss: 0.0247\n", "Epoch [4/5], Step[265/600], Loss: 0.0234\n", "Epoch [4/5], Step[266/600], Loss: 0.0582\n", "Epoch [4/5], Step[267/600], Loss: 0.0041\n", "Epoch [4/5], Step[272/600], Loss: 0.1262\n", "Epoch [4/5], Step[273/600], Loss: 0.0061\n", "Epoch [4/5], Step[274/600], Loss: 0.0117\n", "Epoch [4/5], Step[275/600], Loss: 0.0508\n", "Epoch [4/5], Step[280/600], Loss: 0.0058\n", "Epoch [4/5], Step[281/600], Loss: 0.1394\n", "Epoch [4/5], Step[282/600], Loss: 0.1032\n", "Epoch [4/5], Step[283/600], Loss: 0.0093\n", "Epoch [4/5], Step[384/600], Loss: 0.0099\n", "Epoch [4/5], Step[385/600], Loss: 0.0244\n", "Epoch [4/5], Step[386/600], Loss: 0.0634\n", "Epoch [4/5], Step[387/600], Loss: 0.0861\n", "Epoch [4/5], Step[392/600], Loss: 0.0415\n", "Epoch [4/5], Step[393/600], Loss: 0.0244\n", "Epoch [4/5], Step[394/600], Loss: 0.0073\n", "Epoch [4/5], Step[395/600], Loss: 0.0727\n", "Epoch [4/5], Step[400/600], Loss: 0.0344\n", "Epoch [4/5], Step[401/600], Loss: 0.0526\n", "Epoch [4/5], Step[402/600], Loss: 0.0469\n", "Epoch [4/5], Step[403/600], Loss: 0.0225\n", "Epoch [4/5], Step[408/600], Loss: 0.0051\n", "Epoch [4/5], Step[409/600], Loss: 0.0374\n", "Epoch [4/5], Step[410/600], Loss: 0.1101\n", "Epoch [4/5], Step[411/600], 
Loss: 0.0204\n", "Epoch [4/5], Step[512/600], Loss: 0.0102\n", "Epoch [4/5], Step[513/600], Loss: 0.0353\n", "Epoch [4/5], Step[514/600], Loss: 0.0593\n", "Epoch [4/5], Step[515/600], Loss: 0.0230\n", "Epoch [4/5], Step[520/600], Loss: 0.1317\n", "Epoch [4/5], Step[521/600], Loss: 0.0152\n", "Epoch [4/5], Step[522/600], Loss: 0.1528\n", "Epoch [4/5], Step[523/600], Loss: 0.0796\n", "Epoch [4/5], Step[528/600], Loss: 0.0619\n", "Epoch [4/5], Step[529/600], Loss: 0.0280\n", "Epoch [4/5], Step[530/600], Loss: 0.0453\n", "Epoch [4/5], Step[531/600], Loss: 0.0276\n", "Epoch [4/5], Step[536/600], Loss: 0.0162\n", "Epoch [4/5], Step[537/600], Loss: 0.0576\n", "Epoch [4/5], Step[538/600], Loss: 0.0131\n", "Epoch [4/5], Step[539/600], Loss: 0.0536\n", "Epoch [5/5], Step[1/600], Loss: 0.0300\n", "Epoch [5/5], Step[2/600], Loss: 0.0062\n", "Epoch [5/5], Step[3/600], Loss: 0.0746\n", "Epoch [5/5], Step[8/600], Loss: 0.0050\n", "Epoch [5/5], Step[9/600], Loss: 0.0073\n", "Epoch [5/5], Step[10/600], Loss: 0.0374\n", "Epoch [5/5], Step[11/600], Loss: 0.0194\n", "Epoch [5/5], Step[16/600], Loss: 0.0032\n", "Epoch [5/5], Step[17/600], Loss: 0.0114\n", "Epoch [5/5], Step[18/600], Loss: 0.0017\n", "Epoch [5/5], Step[19/600], Loss: 0.0296\n", "Epoch [5/5], Step[24/600], Loss: 0.0107\n", "Epoch [5/5], Step[25/600], Loss: 0.1005\n", "Epoch [5/5], Step[26/600], Loss: 0.0535\n", "Epoch [5/5], Step[27/600], Loss: 0.0150\n", "Epoch [5/5], Step[128/600], Loss: 0.0950\n", "Epoch [5/5], Step[129/600], Loss: 0.0121\n", "Epoch [5/5], Step[130/600], Loss: 0.0068\n", "Epoch [5/5], Step[131/600], Loss: 0.0046\n", "Epoch [5/5], Step[136/600], Loss: 0.0561\n", "Epoch [5/5], Step[137/600], Loss: 0.0932\n", "Epoch [5/5], Step[138/600], Loss: 0.0313\n", "Epoch [5/5], Step[139/600], Loss: 0.0569\n", "Epoch [5/5], Step[144/600], Loss: 0.0050\n", "Epoch [5/5], Step[145/600], Loss: 0.0543\n", "Epoch [5/5], Step[146/600], Loss: 0.1123\n", "Epoch [5/5], Step[147/600], Loss: 0.0038\n", "Epoch [5/5], 
Step[152/600], Loss: 0.0188\n", "Epoch [5/5], Step[153/600], Loss: 0.0004\n", "Epoch [5/5], Step[154/600], Loss: 0.0089\n", "Epoch [5/5], Step[155/600], Loss: 0.0101\n", "Epoch [5/5], Step[256/600], Loss: 0.0143\n", "Epoch [5/5], Step[257/600], Loss: 0.0368\n", "Epoch [5/5], Step[258/600], Loss: 0.1298\n", "Epoch [5/5], Step[259/600], Loss: 0.0296\n", "Epoch [5/5], Step[264/600], Loss: 0.0167\n", "Epoch [5/5], Step[265/600], Loss: 0.0628\n", "Epoch [5/5], Step[266/600], Loss: 0.0583\n", "Epoch [5/5], Step[267/600], Loss: 0.0081\n", "Epoch [5/5], Step[272/600], Loss: 0.1128\n", "Epoch [5/5], Step[273/600], Loss: 0.0334\n", "Epoch [5/5], Step[274/600], Loss: 0.0126\n", "Epoch [5/5], Step[275/600], Loss: 0.0020\n", "Epoch [5/5], Step[280/600], Loss: 0.0065\n", "Epoch [5/5], Step[281/600], Loss: 0.0387\n", "Epoch [5/5], Step[282/600], Loss: 0.0177\n", "Epoch [5/5], Step[283/600], Loss: 0.0286\n", "Epoch [5/5], Step[384/600], Loss: 0.0409\n", "Epoch [5/5], Step[385/600], Loss: 0.0167\n", "Epoch [5/5], Step[386/600], Loss: 0.0175\n", "Epoch [5/5], Step[387/600], Loss: 0.0345\n", "Epoch [5/5], Step[392/600], Loss: 0.0106\n", "Epoch [5/5], Step[393/600], Loss: 0.0108\n", "Epoch [5/5], Step[394/600], Loss: 0.0212\n", "Epoch [5/5], Step[395/600], Loss: 0.0150\n", "Epoch [5/5], Step[400/600], Loss: 0.0364\n", "Epoch [5/5], Step[401/600], Loss: 0.0001\n", "Epoch [5/5], Step[402/600], Loss: 0.0165\n", "Epoch [5/5], Step[403/600], Loss: 0.0735\n", "Epoch [5/5], Step[408/600], Loss: 0.0078\n", "Epoch [5/5], Step[409/600], Loss: 0.0587\n", "Epoch [5/5], Step[410/600], Loss: 0.0046\n", "Epoch [5/5], Step[411/600], Loss: 0.0513\n", "Epoch [5/5], Step[512/600], Loss: 0.0328\n", "Epoch [5/5], Step[513/600], Loss: 0.0730\n", "Epoch [5/5], Step[514/600], Loss: 0.0896\n", "Epoch [5/5], Step[515/600], Loss: 0.0244\n", "Epoch [5/5], Step[520/600], Loss: 0.1046\n", "Epoch [5/5], Step[521/600], Loss: 0.0129\n", "Epoch [5/5], Step[522/600], Loss: 0.0958\n", "Epoch [5/5], Step[523/600], Loss: 
0.1078\n", "Epoch [5/5], Step[528/600], Loss: 0.0319\n", "Epoch [5/5], Step[529/600], Loss: 0.0606\n", "Epoch [5/5], Step[530/600], Loss: 0.0835\n", "Epoch [5/5], Step[531/600], Loss: 0.0390\n", "Epoch [5/5], Step[536/600], Loss: 0.0489\n", "Epoch [5/5], Step[537/600], Loss: 0.0168\n", "Epoch [5/5], Step[538/600], Loss: 0.0122\n", "Epoch [5/5], Step[539/600], Loss: 0.0445\n", "Training took 690.5935745239258 seconds\n" ] } ], "source": [
"start_time = time.time()\n",
"\n",
"num_epochs = 5\n",
"\n",
"def train(num_epochs, cnn, loaders, log_every=100):\n",
"    \"\"\"Train `cnn` in place for `num_epochs` passes over `loaders['train']`.\n",
"\n",
"    Relies on the globals `loss_func` and `optimizer` defined earlier in the notebook.\n",
"    Every `log_every` steps the loss of the current mini-batch is printed;\n",
"    pass `log_every=0` (or None) to disable logging.\n",
"    \"\"\"\n",
"    cnn.train()  # training mode (matters for layers such as dropout / batch-norm)\n",
"    total_step = len(loaders['train'])\n",
"    for epoch in range(num_epochs):\n",
"        for i, (images, labels) in enumerate(loaders['train']):\n",
"            # Forward pass: the model returns (logits, features); only the logits feed the loss.\n",
"            output = cnn(images)[0]\n",
"            loss = loss_func(output, labels)\n",
"\n",
"            # Clear gradients accumulated by the previous training step.\n",
"            optimizer.zero_grad()\n",
"            # Backpropagation: compute gradients of the loss w.r.t. the parameters.\n",
"            loss.backward()\n",
"            # Apply the gradients.\n",
"            optimizer.step()\n",
"\n",
"            # Modulo (`%`), not bitwise AND (`&`): log exactly once every `log_every` steps.\n",
"            if log_every and (i + 1) % log_every == 0:\n",
"                print('Epoch [{}/{}], Step[{}/{}], Loss: {:.4f}'\n",
"                      .format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))\n",
"\n",
"train(num_epochs, cnn, loaders)\n",
"\n",
"curr_time = time.time()\n",
"print(f\"Training took {curr_time - start_time} seconds\")" ] },
{ "cell_type": "markdown", "id": "ccb6166e", "metadata": {}, "source": [ "## Testing the Model\n", "\n", "Let's define a testing function for model evaluation:" ] },
{ "cell_type": "code", "execution_count": 11, "id": "05c7199d-5500-4673-b8c0-488514d84edf", "metadata": {}, "outputs": [], "source": [
"def test():\n",
"    \"\"\"Print the classification accuracy of `cnn` over every batch in `loaders['test']`.\"\"\"\n",
"    cnn.eval()  # set to evaluation mode\n",
"    correct = 0\n",
"    total = 0\n",
"    with torch.no_grad():  # no autograd graph is needed for evaluation\n",
"        for images, labels in loaders['test']:\n",
"            test_output = cnn(images)[0]\n",
"            pred_y = test_output.argmax(dim=1)\n",
"            # Accumulate over all batches so the result covers the whole test set,\n",
"            # not just the last batch.\n",
"            correct += (pred_y == labels).sum().item()\n",
"            total += labels.size(0)\n",
"    accuracy = correct / total if total else 0.0\n",
"    print('Test accuracy of model on %d test images: %.4f' % (total, accuracy))" ] },
{ "cell_type": "code", "execution_count": 12, "id": "ecf1dbc9-82e7-4ffc-b6e0-a4a4e452b727", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Test accuracy of model on 10000 test images: 1.00\n" ] } ], "source": [ "test() # this takes a minute" ] },
{ "cell_type": "markdown", "id": "658188de-b88a-4877-9fed-dcf1a705a6d3", "metadata": {}, "source": [ "Let's grab a few samples from our test set, plot them, and then manually check our accuracy:" ] },
{ "cell_type": "code", "execution_count": 13, "id": "a1f7e3f7-6d92-4a70-8b45-359b8e06edcb", "metadata": {}, "outputs": [], "source": [ "sample = next(iter(loaders['test']))\n", "imgs, lbls = sample\n", "actual_number = lbls[:10].numpy()" ] },
{ "cell_type": "markdown", "id": "b78692b8-1f68-401e-bde2-f5bca7ea0d2d", "metadata": {}, "source": [ "Plotting the samples:" ] },
{ "cell_type": "code", "execution_count": 14, "id": "ccfeb183-5de2-4082-ac2d-1526ca66d219", "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "figure = plt.figure(figsize=(10, 4))\n", "cols, rows = 5, 2\n", "for i in range(1, cols * rows + 1):\n", " img, label = imgs[i], lbls[i]\n", " figure.add_subplot(rows, cols, i)\n", " plt.title(label)\n", " plt.axis(\"off\")\n", " plt.imshow(img.squeeze(), cmap=\"gray\")\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "02e45c02-e6cf-4f63-8fd2-64a633b5a8e5", "metadata": {}, "source": [ "Checking our prediction abilities on our 10 samples:" ] }, { "cell_type": "code", "execution_count": 15, "id": "fb2b657e-c7c4-40e8-be58-d32b87baa0f0", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Prediction number: [0 9 2 9 8 8 7 3 8 1]\n", "Actual number: [0 9 2 4 8 8 7 3 8 1]\n" ] } ], "source": [ "test_output, last_layer = cnn(imgs[:10])\n", "pred_y = torch.max(test_output, 1)[1].data.numpy().squeeze()\n", "print(f'Prediction number: {pred_y}')\n", "print(f'Actual number: {actual_number}')" ] }, { "cell_type": "markdown", "id": "a703075e", "metadata": {}, "source": [ "## Saving The Trained Model\n", "\n", "Next, we can save the trained model's state using the lines below:" ] }, { "cell_type": "code", "execution_count": 16, "id": "48653cab-eb2b-4a17-85c2-57de0b4732b8", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Saving model to: models/01_pytorch_mnist_cnn.pth\n" ] } ], "source": [ "# 1. Create models directory \n", "MODEL_PATH = Path(\"models\")\n", "MODEL_PATH.mkdir(parents=True, exist_ok=True)\n", "\n", "# 2. Create model save path \n", "MODEL_NAME = \"01_pytorch_mnist_cnn.pth\"\n", "MODEL_SAVE_PATH = MODEL_PATH / MODEL_NAME\n", "\n", "# 3. 
Save the model state dict \n", "print(f\"Saving model to: {MODEL_SAVE_PATH}\")\n", "torch.save(obj=cnn.state_dict(), # only saving the state_dict() only saves the models learned parameters\n", " f=MODEL_SAVE_PATH)\n" ] }, { "cell_type": "code", "execution_count": 17, "id": "a3b27366-1acb-4e2a-a4e3-03667a66ffab", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Instantiate a new instance of our model (this will be instantiated with random weights)\n", "loaded_model_0 = CNN()\n", "\n", "# Load the state_dict of our saved model (this will update the new instance of our model with trained weights)\n", "loaded_model_0.load_state_dict(torch.load(f=MODEL_SAVE_PATH))" ] }, { "cell_type": "code", "execution_count": 18, "id": "ad626a94-4f29-4e3a-a774-5ea528848fc9", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Test accuracy of model on 10000 test images: 0.99\n" ] } ], "source": [ "loaded_model_0.eval()\n", "with torch.no_grad():\n", " correct = 0\n", " total = 0\n", " for images, labels in loaders['test']:\n", " test_output, last_layer = loaded_model_0(images)\n", " pred_y = torch.max(test_output, 1)[1].data.squeeze()\n", " accuracy = (pred_y == labels).sum().item() / float(labels.size(0))\n", " pass\n", "print('Test accuracy of model on 10000 test images: %.2f' % accuracy)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3-0.9.4", "language": "python", "name": "python3-0.9.4" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.12" } }, "nbformat": 4, "nbformat_minor": 5 }