{ "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "UEBilEjLj5wY" }, "source": [ "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n", "- Author: Sebastian Raschka\n", "- GitHub Repository: https://github.com/rasbt/deeplearning-models\n", "\n", "---" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, "base_uri": "https://localhost:8080/", "height": 119 }, "colab_type": "code", "executionInfo": { "elapsed": 536, "status": "ok", "timestamp": 1524974472601, "user": { "displayName": "Sebastian Raschka", "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg", "userId": "118404394130788869227" }, "user_tz": 240 }, "id": "GOzuY8Yvj5wb", "outputId": "c19362ce-f87a-4cc2-84cc-8d7b4b9e6007" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sebastian Raschka \n", "\n", "CPython 3.7.3\n", "IPython 7.9.0\n", "\n", "torch 1.7.0\n" ] } ], "source": [ "%load_ext watermark\n", "%watermark -a 'Sebastian Raschka' -v -p torch" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "rH4XmErYj5wm" }, "source": [ "# AlexNet CIFAR-10 Classifier" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Network Architecture" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "References\n", " \n", "- [1] Krizhevsky, Alex, Ilya Sutskever, and Geoffrey E. Hinton. \"[Imagenet classification with deep convolutional neural networks.](https://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf)\" In Advances in Neural Information Processing Systems, pp. 1097-1105. 
2012.\n" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "MkoGLH_Tj5wn" }, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 } }, "colab_type": "code", "id": "ORj09gnrj5wp" }, "outputs": [], "source": [ "import os\n", "import time\n", "import random\n", "\n", "import numpy as np\n", "import pandas as pd\n", "\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from torch.utils.data import DataLoader\n", "from torch.utils.data.dataset import Subset\n", "\n", "from torchvision import datasets\n", "from torchvision import transforms\n", "\n", "import matplotlib.pyplot as plt\n", "from PIL import Image\n", "\n", "\n", "if torch.cuda.is_available():\n", " torch.backends.cudnn.deterministic = True" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "I6hghKPxj5w0" }, "source": [ "## Model Settings" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Setting a random seed" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "I recommend using a function like the following one prior to using dataset loaders and initializing a model if you want to ensure the data is shuffled in the same manner if you rerun this notebook and the model gets the same initial random weights:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "def set_all_seeds(seed):\n", " os.environ[\"PL_GLOBAL_SEED\"] = str(seed)\n", " random.seed(seed)\n", " np.random.seed(seed)\n", " torch.manual_seed(seed)\n", " torch.cuda.manual_seed_all(seed)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Setting cuDNN and PyTorch algorithmic behavior to deterministic" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Similar to the `set_all_seeds` function above, I recommend setting the behavior of PyTorch and cuDNN to deterministic (this is 
particularly relevant when using GPUs). We can also define a function for that:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "def set_deterministic():\n", "    if torch.cuda.is_available():\n", "        torch.backends.cudnn.benchmark = False\n", "        torch.backends.cudnn.deterministic = True\n", "    torch.set_deterministic(True)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, "base_uri": "https://localhost:8080/", "height": 85 }, "colab_type": "code", "executionInfo": { "elapsed": 23936, "status": "ok", "timestamp": 1524974497505, "user": { "displayName": "Sebastian Raschka", "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg", "userId": "118404394130788869227" }, "user_tz": 240 }, "id": "NnT0sZIwj5wu", "outputId": "55aed925-d17e-4c6a-8c71-0d9b3bde5637" }, "outputs": [], "source": [ "##########################\n", "### SETTINGS\n", "##########################\n", "\n", "# Hyperparameters\n", "RANDOM_SEED = 1\n", "LEARNING_RATE = 0.0001\n", "BATCH_SIZE = 256\n", "NUM_EPOCHS = 40\n", "\n", "# Architecture\n", "NUM_CLASSES = 10\n", "\n", "# Other\n", "DEVICE = \"cuda:0\"\n", "\n", "set_all_seeds(RANDOM_SEED)\n", "\n", "# Deterministic behavior not yet supported by AdaptiveAvgPool2d\n", "#set_deterministic()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Import utility functions" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "import sys\n", "\n", "sys.path.insert(0, \"..\") # to include ../helper_evaluate.py etc.\n", "\n", "from helper_evaluate import compute_accuracy\n", "from helper_data import get_dataloaders_cifar10\n", "from helper_train import train_classifier_simple_v1" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Dataset" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { 
"name": "stdout", "output_type": "stream", "text": [ "Files already downloaded and verified\n" ] } ], "source": [ "### Set random seed ###\n", "set_all_seeds(RANDOM_SEED)\n", "\n", "##########################\n", "### Dataset\n", "##########################\n", "\n", "train_transforms = transforms.Compose([transforms.Resize((70, 70)),\n", " transforms.RandomCrop((64, 64)),\n", " transforms.ToTensor()])\n", "\n", "test_transforms = transforms.Compose([transforms.Resize((70, 70)),\n", " transforms.CenterCrop((64, 64)),\n", " transforms.ToTensor()])\n", "\n", "\n", "train_loader, valid_loader, test_loader = get_dataloaders_cifar10(\n", " batch_size=BATCH_SIZE, \n", " num_workers=2, \n", " train_transforms=train_transforms,\n", " test_transforms=test_transforms,\n", " validation_fraction=0.1)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training Set:\n", "\n", "Image batch dimensions: torch.Size([256, 3, 64, 64])\n", "Image label dimensions: torch.Size([256])\n", "tensor([6, 9, 9, 4, 1, 1, 2, 7, 8, 3])\n", "\n", "Validation Set:\n", "Image batch dimensions: torch.Size([256, 3, 64, 64])\n", "Image label dimensions: torch.Size([256])\n", "tensor([7, 1, 4, 1, 0, 2, 2, 5, 9, 6])\n", "\n", "Testing Set:\n", "Image batch dimensions: torch.Size([256, 3, 64, 64])\n", "Image label dimensions: torch.Size([256])\n", "tensor([6, 9, 9, 4, 1, 1, 2, 7, 8, 3])\n" ] } ], "source": [ "# Checking the dataset\n", "print('Training Set:\\n')\n", "for images, labels in train_loader: \n", " print('Image batch dimensions:', images.size())\n", " print('Image label dimensions:', labels.size())\n", " print(labels[:10])\n", " break\n", " \n", "# Checking the dataset\n", "print('\\nValidation Set:')\n", "for images, labels in valid_loader: \n", " print('Image batch dimensions:', images.size())\n", " print('Image label dimensions:', labels.size())\n", " print(labels[:10])\n", " break\n", "\n", "# Checking the 
dataset\n", "print('\\nTesting Set:')\n", "for images, labels in test_loader:\n", "    print('Image batch dimensions:', images.size())\n", "    print('Image label dimensions:', labels.size())\n", "    print(labels[:10])\n", "    break" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "##########################\n", "### MODEL\n", "##########################\n", "\n", "class AlexNet(nn.Module):\n", "\n", "    def __init__(self, num_classes):\n", "        super(AlexNet, self).__init__()\n", "        self.features = nn.Sequential(\n", "            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),\n", "            nn.ReLU(inplace=True),\n", "            nn.MaxPool2d(kernel_size=3, stride=2),\n", "            nn.Conv2d(64, 192, kernel_size=5, padding=2),\n", "            nn.ReLU(inplace=True),\n", "            nn.MaxPool2d(kernel_size=3, stride=2),\n", "            nn.Conv2d(192, 384, kernel_size=3, padding=1),\n", "            nn.ReLU(inplace=True),\n", "            nn.Conv2d(384, 256, kernel_size=3, padding=1),\n", "            nn.ReLU(inplace=True),\n", "            nn.Conv2d(256, 256, kernel_size=3, padding=1),\n", "            nn.ReLU(inplace=True),\n", "            nn.MaxPool2d(kernel_size=3, stride=2),\n", "        )\n", "        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))\n", "        self.classifier = nn.Sequential(\n", "            nn.Dropout(0.5),\n", "            nn.Linear(256 * 6 * 6, 4096),\n", "            nn.ReLU(inplace=True),\n", "            nn.Dropout(0.5),\n", "            nn.Linear(4096, 4096),\n", "            nn.ReLU(inplace=True),\n", "            nn.Linear(4096, num_classes)\n", "        )\n", "\n", "    def forward(self, x):\n", "        x = self.features(x)\n", "        x = self.avgpool(x)\n", "        x = x.view(x.size(0), 256 * 6 * 6)\n", "        logits = self.classifier(x)\n", "        probas = F.softmax(logits, dim=1)\n", "        return logits" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 } }, "colab_type": "code", "id": "_lza9t_uj5w1" }, "outputs": [], "source": [ "torch.manual_seed(RANDOM_SEED)\n", "\n", "model = AlexNet(NUM_CLASSES)\n", 
"model.to(DEVICE)\n", "\n", "optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE) " ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "RAodboScj5w6" }, "source": [ "## Training" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, "base_uri": "https://localhost:8080/", "height": 1547 }, "colab_type": "code", "executionInfo": { "elapsed": 2384585, "status": "ok", "timestamp": 1524976888520, "user": { "displayName": "Sebastian Raschka", "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg", "userId": "118404394130788869227" }, "user_tz": 240 }, "id": "Dzh3ROmRj5w7", "outputId": "5f8fd8c9-b076-403a-b0b7-fd2d498b48d7" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 001/040 | Batch 0000/0176 | Loss: 2.3030\n", "Epoch: 001/040 | Batch 0050/0176 | Loss: 1.9805\n", "Epoch: 001/040 | Batch 0100/0176 | Loss: 1.8457\n", "Epoch: 001/040 | Batch 0150/0176 | Loss: 1.7629\n", "Epoch: 001/040 | Train Acc.: 33.087% | Loss: 1.697\n", "Epoch: 001/040 | Validation Acc.: 34.160% | Loss: 1.671\n", "Time elapsed: 0.44 min\n", "Epoch: 002/040 | Batch 0000/0176 | Loss: 1.7060\n", "Epoch: 002/040 | Batch 0050/0176 | Loss: 1.5886\n", "Epoch: 002/040 | Batch 0100/0176 | Loss: 1.6655\n", "Epoch: 002/040 | Batch 0150/0176 | Loss: 1.4965\n", "Epoch: 002/040 | Train Acc.: 41.529% | Loss: 1.521\n", "Epoch: 002/040 | Validation Acc.: 41.880% | Loss: 1.511\n", "Time elapsed: 0.88 min\n", "Epoch: 003/040 | Batch 0000/0176 | Loss: 1.4985\n", "Epoch: 003/040 | Batch 0050/0176 | Loss: 1.4119\n", "Epoch: 003/040 | Batch 0100/0176 | Loss: 1.3107\n", "Epoch: 003/040 | Batch 0150/0176 | Loss: 1.2836\n", "Epoch: 003/040 | Train Acc.: 47.098% | Loss: 1.406\n", "Epoch: 003/040 | Validation Acc.: 47.520% | Loss: 1.400\n", "Time elapsed: 1.32 min\n", "Epoch: 004/040 | Batch 0000/0176 | Loss: 1.4325\n", "Epoch: 
004/040 | Batch 0050/0176 | Loss: 1.2214\n", "Epoch: 004/040 | Batch 0100/0176 | Loss: 1.1575\n", "Epoch: 004/040 | Batch 0150/0176 | Loss: 1.2118\n", "Epoch: 004/040 | Train Acc.: 55.071% | Loss: 1.224\n", "Epoch: 004/040 | Validation Acc.: 54.980% | Loss: 1.229\n", "Time elapsed: 1.77 min\n", "Epoch: 005/040 | Batch 0000/0176 | Loss: 1.1903\n", "Epoch: 005/040 | Batch 0050/0176 | Loss: 1.1410\n", "Epoch: 005/040 | Batch 0100/0176 | Loss: 1.0611\n", "Epoch: 005/040 | Batch 0150/0176 | Loss: 1.0817\n", "Epoch: 005/040 | Train Acc.: 57.938% | Loss: 1.144\n", "Epoch: 005/040 | Validation Acc.: 57.640% | Loss: 1.147\n", "Time elapsed: 2.21 min\n", "Epoch: 006/040 | Batch 0000/0176 | Loss: 1.1344\n", "Epoch: 006/040 | Batch 0050/0176 | Loss: 1.0903\n", "Epoch: 006/040 | Batch 0100/0176 | Loss: 0.9820\n", "Epoch: 006/040 | Batch 0150/0176 | Loss: 1.0716\n", "Epoch: 006/040 | Train Acc.: 60.222% | Loss: 1.104\n", "Epoch: 006/040 | Validation Acc.: 58.700% | Loss: 1.120\n", "Time elapsed: 2.65 min\n", "Epoch: 007/040 | Batch 0000/0176 | Loss: 1.0856\n", "Epoch: 007/040 | Batch 0050/0176 | Loss: 0.9864\n", "Epoch: 007/040 | Batch 0100/0176 | Loss: 0.9559\n", "Epoch: 007/040 | Batch 0150/0176 | Loss: 0.9611\n", "Epoch: 007/040 | Train Acc.: 63.953% | Loss: 1.008\n", "Epoch: 007/040 | Validation Acc.: 61.700% | Loss: 1.055\n", "Time elapsed: 3.09 min\n", "Epoch: 008/040 | Batch 0000/0176 | Loss: 0.9608\n", "Epoch: 008/040 | Batch 0050/0176 | Loss: 0.9866\n", "Epoch: 008/040 | Batch 0100/0176 | Loss: 0.9418\n", "Epoch: 008/040 | Batch 0150/0176 | Loss: 0.9364\n", "Epoch: 008/040 | Train Acc.: 65.887% | Loss: 0.946\n", "Epoch: 008/040 | Validation Acc.: 63.220% | Loss: 1.006\n", "Time elapsed: 3.53 min\n", "Epoch: 009/040 | Batch 0000/0176 | Loss: 0.9465\n", "Epoch: 009/040 | Batch 0050/0176 | Loss: 0.9415\n", "Epoch: 009/040 | Batch 0100/0176 | Loss: 0.9320\n", "Epoch: 009/040 | Batch 0150/0176 | Loss: 0.8411\n", "Epoch: 009/040 | Train Acc.: 67.782% | Loss: 0.908\n", "Epoch: 
009/040 | Validation Acc.: 64.740% | Loss: 0.995\n", "Time elapsed: 3.97 min\n", "Epoch: 010/040 | Batch 0000/0176 | Loss: 0.8873\n", "Epoch: 010/040 | Batch 0050/0176 | Loss: 0.9821\n", "Epoch: 010/040 | Batch 0100/0176 | Loss: 0.8360\n", "Epoch: 010/040 | Batch 0150/0176 | Loss: 0.7949\n", "Epoch: 010/040 | Train Acc.: 71.416% | Loss: 0.806\n", "Epoch: 010/040 | Validation Acc.: 67.360% | Loss: 0.912\n", "Time elapsed: 4.41 min\n", "Epoch: 011/040 | Batch 0000/0176 | Loss: 0.7469\n", "Epoch: 011/040 | Batch 0050/0176 | Loss: 0.8446\n", "Epoch: 011/040 | Batch 0100/0176 | Loss: 0.7797\n", "Epoch: 011/040 | Batch 0150/0176 | Loss: 0.7813\n", "Epoch: 011/040 | Train Acc.: 72.213% | Loss: 0.779\n", "Epoch: 011/040 | Validation Acc.: 67.640% | Loss: 0.913\n", "Time elapsed: 4.85 min\n", "Epoch: 012/040 | Batch 0000/0176 | Loss: 0.7076\n", "Epoch: 012/040 | Batch 0050/0176 | Loss: 0.8669\n", "Epoch: 012/040 | Batch 0100/0176 | Loss: 0.7664\n", "Epoch: 012/040 | Batch 0150/0176 | Loss: 0.7205\n", "Epoch: 012/040 | Train Acc.: 73.053% | Loss: 0.762\n", "Epoch: 012/040 | Validation Acc.: 67.440% | Loss: 0.918\n", "Time elapsed: 5.29 min\n", "Epoch: 013/040 | Batch 0000/0176 | Loss: 0.6941\n", "Epoch: 013/040 | Batch 0050/0176 | Loss: 0.7947\n", "Epoch: 013/040 | Batch 0100/0176 | Loss: 0.7407\n", "Epoch: 013/040 | Batch 0150/0176 | Loss: 0.6998\n", "Epoch: 013/040 | Train Acc.: 75.658% | Loss: 0.686\n", "Epoch: 013/040 | Validation Acc.: 69.320% | Loss: 0.883\n", "Time elapsed: 5.74 min\n", "Epoch: 014/040 | Batch 0000/0176 | Loss: 0.6040\n", "Epoch: 014/040 | Batch 0050/0176 | Loss: 0.7107\n", "Epoch: 014/040 | Batch 0100/0176 | Loss: 0.6712\n", "Epoch: 014/040 | Batch 0150/0176 | Loss: 0.6610\n", "Epoch: 014/040 | Train Acc.: 74.813% | Loss: 0.725\n", "Epoch: 014/040 | Validation Acc.: 68.140% | Loss: 0.937\n", "Time elapsed: 6.18 min\n", "Epoch: 015/040 | Batch 0000/0176 | Loss: 0.7104\n", "Epoch: 015/040 | Batch 0050/0176 | Loss: 0.6965\n", "Epoch: 015/040 | Batch 
0100/0176 | Loss: 0.6800\n", "Epoch: 015/040 | Batch 0150/0176 | Loss: 0.6326\n", "Epoch: 015/040 | Train Acc.: 76.784% | Loss: 0.660\n", "Epoch: 015/040 | Validation Acc.: 69.560% | Loss: 0.926\n", "Time elapsed: 6.62 min\n", "Epoch: 016/040 | Batch 0000/0176 | Loss: 0.5889\n", "Epoch: 016/040 | Batch 0050/0176 | Loss: 0.6394\n", "Epoch: 016/040 | Batch 0100/0176 | Loss: 0.5344\n", "Epoch: 016/040 | Batch 0150/0176 | Loss: 0.5194\n", "Epoch: 016/040 | Train Acc.: 78.216% | Loss: 0.620\n", "Epoch: 016/040 | Validation Acc.: 69.060% | Loss: 0.925\n", "Time elapsed: 7.06 min\n", "Epoch: 017/040 | Batch 0000/0176 | Loss: 0.4660\n", "Epoch: 017/040 | Batch 0050/0176 | Loss: 0.6028\n", "Epoch: 017/040 | Batch 0100/0176 | Loss: 0.5386\n", "Epoch: 017/040 | Batch 0150/0176 | Loss: 0.5476\n", "Epoch: 017/040 | Train Acc.: 79.882% | Loss: 0.563\n", "Epoch: 017/040 | Validation Acc.: 70.940% | Loss: 0.900\n", "Time elapsed: 7.50 min\n", "Epoch: 018/040 | Batch 0000/0176 | Loss: 0.4768\n", "Epoch: 018/040 | Batch 0050/0176 | Loss: 0.6519\n", "Epoch: 018/040 | Batch 0100/0176 | Loss: 0.5177\n", "Epoch: 018/040 | Batch 0150/0176 | Loss: 0.5190\n", "Epoch: 018/040 | Train Acc.: 79.673% | Loss: 0.575\n", "Epoch: 018/040 | Validation Acc.: 69.540% | Loss: 0.961\n", "Time elapsed: 7.94 min\n", "Epoch: 019/040 | Batch 0000/0176 | Loss: 0.4714\n", "Epoch: 019/040 | Batch 0050/0176 | Loss: 0.5929\n", "Epoch: 019/040 | Batch 0100/0176 | Loss: 0.5053\n", "Epoch: 019/040 | Batch 0150/0176 | Loss: 0.5581\n", "Epoch: 019/040 | Train Acc.: 81.029% | Loss: 0.535\n", "Epoch: 019/040 | Validation Acc.: 70.120% | Loss: 0.971\n", "Time elapsed: 8.38 min\n", "Epoch: 020/040 | Batch 0000/0176 | Loss: 0.3960\n", "Epoch: 020/040 | Batch 0050/0176 | Loss: 0.5640\n", "Epoch: 020/040 | Batch 0100/0176 | Loss: 0.5285\n", "Epoch: 020/040 | Batch 0150/0176 | Loss: 0.4373\n", "Epoch: 020/040 | Train Acc.: 83.053% | Loss: 0.497\n", "Epoch: 020/040 | Validation Acc.: 71.060% | Loss: 0.936\n", "Time elapsed: 
8.82 min\n", "Epoch: 021/040 | Batch 0000/0176 | Loss: 0.3562\n", "Epoch: 021/040 | Batch 0050/0176 | Loss: 0.4917\n", "Epoch: 021/040 | Batch 0100/0176 | Loss: 0.5012\n", "Epoch: 021/040 | Batch 0150/0176 | Loss: 0.5080\n", "Epoch: 021/040 | Train Acc.: 82.573% | Loss: 0.489\n", "Epoch: 021/040 | Validation Acc.: 70.460% | Loss: 0.943\n", "Time elapsed: 9.26 min\n", "Epoch: 022/040 | Batch 0000/0176 | Loss: 0.4068\n", "Epoch: 022/040 | Batch 0050/0176 | Loss: 0.4859\n", "Epoch: 022/040 | Batch 0100/0176 | Loss: 0.4582\n", "Epoch: 022/040 | Batch 0150/0176 | Loss: 0.4259\n", "Epoch: 022/040 | Train Acc.: 83.702% | Loss: 0.468\n", "Epoch: 022/040 | Validation Acc.: 71.920% | Loss: 0.940\n", "Time elapsed: 9.70 min\n", "Epoch: 023/040 | Batch 0000/0176 | Loss: 0.4236\n", "Epoch: 023/040 | Batch 0050/0176 | Loss: 0.4781\n", "Epoch: 023/040 | Batch 0100/0176 | Loss: 0.4485\n", "Epoch: 023/040 | Batch 0150/0176 | Loss: 0.3734\n", "Epoch: 023/040 | Train Acc.: 83.813% | Loss: 0.461\n", "Epoch: 023/040 | Validation Acc.: 70.860% | Loss: 0.958\n", "Time elapsed: 10.14 min\n", "Epoch: 024/040 | Batch 0000/0176 | Loss: 0.3325\n", "Epoch: 024/040 | Batch 0050/0176 | Loss: 0.3642\n", "Epoch: 024/040 | Batch 0100/0176 | Loss: 0.3749\n", "Epoch: 024/040 | Batch 0150/0176 | Loss: 0.3309\n", "Epoch: 024/040 | Train Acc.: 85.018% | Loss: 0.425\n", "Epoch: 024/040 | Validation Acc.: 71.560% | Loss: 0.965\n", "Time elapsed: 10.58 min\n", "Epoch: 025/040 | Batch 0000/0176 | Loss: 0.3156\n", "Epoch: 025/040 | Batch 0050/0176 | Loss: 0.3344\n", "Epoch: 025/040 | Batch 0100/0176 | Loss: 0.3945\n", "Epoch: 025/040 | Batch 0150/0176 | Loss: 0.3581\n", "Epoch: 025/040 | Train Acc.: 84.291% | Loss: 0.450\n", "Epoch: 025/040 | Validation Acc.: 69.320% | Loss: 1.077\n", "Time elapsed: 11.02 min\n", "Epoch: 026/040 | Batch 0000/0176 | Loss: 0.3812\n", "Epoch: 026/040 | Batch 0050/0176 | Loss: 0.3231\n", "Epoch: 026/040 | Batch 0100/0176 | Loss: 0.4320\n", "Epoch: 026/040 | Batch 0150/0176 | 
Loss: 0.3551\n", "Epoch: 026/040 | Train Acc.: 83.344% | Loss: 0.467\n", "Epoch: 026/040 | Validation Acc.: 69.180% | Loss: 1.108\n", "Time elapsed: 11.45 min\n", "Epoch: 027/040 | Batch 0000/0176 | Loss: 0.3591\n", "Epoch: 027/040 | Batch 0050/0176 | Loss: 0.3114\n", "Epoch: 027/040 | Batch 0100/0176 | Loss: 0.2481\n", "Epoch: 027/040 | Batch 0150/0176 | Loss: 0.3428\n", "Epoch: 027/040 | Train Acc.: 84.931% | Loss: 0.433\n", "Epoch: 027/040 | Validation Acc.: 69.600% | Loss: 1.126\n", "Time elapsed: 11.90 min\n", "Epoch: 028/040 | Batch 0000/0176 | Loss: 0.3464\n", "Epoch: 028/040 | Batch 0050/0176 | Loss: 0.2866\n", "Epoch: 028/040 | Batch 0100/0176 | Loss: 0.2776\n", "Epoch: 028/040 | Batch 0150/0176 | Loss: 0.3536\n", "Epoch: 028/040 | Train Acc.: 86.289% | Loss: 0.399\n", "Epoch: 028/040 | Validation Acc.: 70.460% | Loss: 1.114\n", "Time elapsed: 12.34 min\n", "Epoch: 029/040 | Batch 0000/0176 | Loss: 0.3019\n", "Epoch: 029/040 | Batch 0050/0176 | Loss: 0.3655\n", "Epoch: 029/040 | Batch 0100/0176 | Loss: 0.2486\n", "Epoch: 029/040 | Batch 0150/0176 | Loss: 0.4025\n", "Epoch: 029/040 | Train Acc.: 89.231% | Loss: 0.308\n", "Epoch: 029/040 | Validation Acc.: 71.820% | Loss: 1.050\n", "Time elapsed: 12.78 min\n", "Epoch: 030/040 | Batch 0000/0176 | Loss: 0.2827\n", "Epoch: 030/040 | Batch 0050/0176 | Loss: 0.3269\n", "Epoch: 030/040 | Batch 0100/0176 | Loss: 0.2699\n", "Epoch: 030/040 | Batch 0150/0176 | Loss: 0.3385\n", "Epoch: 030/040 | Train Acc.: 89.320% | Loss: 0.308\n", "Epoch: 030/040 | Validation Acc.: 71.940% | Loss: 1.063\n", "Time elapsed: 13.22 min\n", "Epoch: 031/040 | Batch 0000/0176 | Loss: 0.2108\n", "Epoch: 031/040 | Batch 0050/0176 | Loss: 0.2322\n", "Epoch: 031/040 | Batch 0100/0176 | Loss: 0.2662\n", "Epoch: 031/040 | Batch 0150/0176 | Loss: 0.1860\n", "Epoch: 031/040 | Train Acc.: 88.807% | Loss: 0.316\n", "Epoch: 031/040 | Validation Acc.: 71.720% | Loss: 1.073\n", "Time elapsed: 13.66 min\n", "Epoch: 032/040 | Batch 0000/0176 | Loss: 
0.2714\n", "Epoch: 032/040 | Batch 0050/0176 | Loss: 0.3018\n", "Epoch: 032/040 | Batch 0100/0176 | Loss: 0.2323\n", "Epoch: 032/040 | Batch 0150/0176 | Loss: 0.2668\n", "Epoch: 032/040 | Train Acc.: 91.173% | Loss: 0.249\n", "Epoch: 032/040 | Validation Acc.: 71.980% | Loss: 1.025\n", "Time elapsed: 14.10 min\n", "Epoch: 033/040 | Batch 0000/0176 | Loss: 0.1913\n", "Epoch: 033/040 | Batch 0050/0176 | Loss: 0.3458\n", "Epoch: 033/040 | Batch 0100/0176 | Loss: 0.1746\n", "Epoch: 033/040 | Batch 0150/0176 | Loss: 0.2072\n", "Epoch: 033/040 | Train Acc.: 91.171% | Loss: 0.245\n", "Epoch: 033/040 | Validation Acc.: 72.680% | Loss: 1.060\n", "Time elapsed: 14.54 min\n", "Epoch: 034/040 | Batch 0000/0176 | Loss: 0.2103\n", "Epoch: 034/040 | Batch 0050/0176 | Loss: 0.2894\n", "Epoch: 034/040 | Batch 0100/0176 | Loss: 0.2032\n", "Epoch: 034/040 | Batch 0150/0176 | Loss: 0.2118\n", "Epoch: 034/040 | Train Acc.: 91.756% | Loss: 0.239\n", "Epoch: 034/040 | Validation Acc.: 73.200% | Loss: 1.038\n", "Time elapsed: 14.98 min\n", "Epoch: 035/040 | Batch 0000/0176 | Loss: 0.2010\n", "Epoch: 035/040 | Batch 0050/0176 | Loss: 0.2106\n", "Epoch: 035/040 | Batch 0100/0176 | Loss: 0.1802\n", "Epoch: 035/040 | Batch 0150/0176 | Loss: 0.2216\n", "Epoch: 035/040 | Train Acc.: 91.060% | Loss: 0.264\n", "Epoch: 035/040 | Validation Acc.: 71.280% | Loss: 1.146\n", "Time elapsed: 15.42 min\n", "Epoch: 036/040 | Batch 0000/0176 | Loss: 0.1478\n", "Epoch: 036/040 | Batch 0050/0176 | Loss: 0.1735\n", "Epoch: 036/040 | Batch 0100/0176 | Loss: 0.1186\n", "Epoch: 036/040 | Batch 0150/0176 | Loss: 0.1835\n", "Epoch: 036/040 | Train Acc.: 91.376% | Loss: 0.254\n", "Epoch: 036/040 | Validation Acc.: 72.160% | Loss: 1.181\n", "Time elapsed: 15.86 min\n", "Epoch: 037/040 | Batch 0000/0176 | Loss: 0.1154\n", "Epoch: 037/040 | Batch 0050/0176 | Loss: 0.1817\n", "Epoch: 037/040 | Batch 0100/0176 | Loss: 0.1166\n", "Epoch: 037/040 | Batch 0150/0176 | Loss: 0.1973\n", "Epoch: 037/040 | Train Acc.: 91.278% | 
Loss: 0.259\n", "Epoch: 037/040 | Validation Acc.: 71.480% | Loss: 1.241\n", "Time elapsed: 16.31 min\n", "Epoch: 038/040 | Batch 0000/0176 | Loss: 0.1402\n", "Epoch: 038/040 | Batch 0050/0176 | Loss: 0.1672\n", "Epoch: 038/040 | Batch 0100/0176 | Loss: 0.1366\n", "Epoch: 038/040 | Batch 0150/0176 | Loss: 0.1037\n", "Epoch: 038/040 | Train Acc.: 91.447% | Loss: 0.256\n", "Epoch: 038/040 | Validation Acc.: 71.260% | Loss: 1.256\n", "Time elapsed: 16.75 min\n", "Epoch: 039/040 | Batch 0000/0176 | Loss: 0.1505\n", "Epoch: 039/040 | Batch 0050/0176 | Loss: 0.1537\n", "Epoch: 039/040 | Batch 0100/0176 | Loss: 0.1592\n", "Epoch: 039/040 | Batch 0150/0176 | Loss: 0.1947\n", "Epoch: 039/040 | Train Acc.: 92.571% | Loss: 0.223\n", "Epoch: 039/040 | Validation Acc.: 72.040% | Loss: 1.266\n", "Time elapsed: 17.19 min\n", "Epoch: 040/040 | Batch 0000/0176 | Loss: 0.1707\n", "Epoch: 040/040 | Batch 0050/0176 | Loss: 0.1500\n", "Epoch: 040/040 | Batch 0100/0176 | Loss: 0.1422\n", "Epoch: 040/040 | Batch 0150/0176 | Loss: 0.1099\n", "Epoch: 040/040 | Train Acc.: 95.191% | Loss: 0.139\n", "Epoch: 040/040 | Validation Acc.: 72.740% | Loss: 1.201\n", "Time elapsed: 17.63 min\n", "Total Training Time: 17.63 min\n" ] } ], "source": [ "log_dict = train_classifier_simple_v1(num_epochs=NUM_EPOCHS, model=model, \n", " optimizer=optimizer, device=DEVICE, \n", " train_loader=train_loader, valid_loader=valid_loader, \n", " logging_interval=50)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluation" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "%matplotlib inline\n" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "ename": "NameError", "evalue": "name 'cost_list' is not defined", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mNameError\u001b[0m 
Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mplt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mplot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlog_dict\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'train_loss_per_batch'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabel\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'Minibatch cost'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m plt.plot(np.convolve(cost_list, \n\u001b[0m\u001b[1;32m 3\u001b[0m np.ones(200,)/200, mode='valid'), \n\u001b[1;32m 4\u001b[0m label='Running average')\n\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mNameError\u001b[0m: name 'cost_list' is not defined" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.plot(log_dict['train_loss_per_batch'], label='Minibatch cost')\n", "plt.plot(np.convolve(log_dict['train_loss_per_batch'],\n", "                     np.ones(200,)/200, mode='valid'),\n", "         label='Running average')\n", "\n", "plt.ylabel('Cross Entropy')\n", "plt.xlabel('Iteration')\n", "plt.legend()\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plt.plot(np.arange(1, NUM_EPOCHS+1), log_dict['train_acc_per_batch'], label='Training')\n", "plt.plot(np.arange(1, NUM_EPOCHS+1), log_dict['valid_acc_per_batch'], label='Validation')\n", "\n", "plt.xlabel('Epoch')\n", "plt.ylabel('Accuracy')\n", "plt.legend()\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "with torch.set_grad_enabled(False):\n", "\n", "    train_acc = compute_accuracy(model=model,\n", "                                 data_loader=train_loader,\n", "                                 device=DEVICE)\n", "\n", "    test_acc = compute_accuracy(model=model,\n", "                                data_loader=test_loader,\n", "                                device=DEVICE)\n", "\n", "    valid_acc = compute_accuracy(model=model,\n", "                                 data_loader=valid_loader,\n", "                                 device=DEVICE)\n", "\n", "\n", "print(f'Train ACC: {train_acc:.2f}%')\n", "print(f'Validation ACC: {valid_acc:.2f}%')\n", "print(f'Test ACC: {test_acc:.2f}%')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%watermark -iv" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "default_view": {}, "name": "convnet-vgg16.ipynb", "provenance": [], "version": "0.3.2", "views": {} }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.8" }, "toc": { "nav_menu": {}, 
"number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": true, "toc_position": { "height": "calc(100% - 180px)", "left": "10px", "top": "150px", "width": "371px" }, "toc_section_display": true, "toc_window_display": true } }, "nbformat": 4, "nbformat_minor": 4 }