{ "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "UEBilEjLj5wY" }, "source": [ "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n", "- Author: Sebastian Raschka\n", "- GitHub Repository: https://github.com/rasbt/deeplearning-models" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, "base_uri": "https://localhost:8080/", "height": 119 }, "colab_type": "code", "executionInfo": { "elapsed": 536, "status": "ok", "timestamp": 1524974472601, "user": { "displayName": "Sebastian Raschka", "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg", "userId": "118404394130788869227" }, "user_tz": 240 }, "id": "GOzuY8Yvj5wb", "outputId": "c19362ce-f87a-4cc2-84cc-8d7b4b9e6007" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sebastian Raschka \n", "\n", "CPython 3.7.3\n", "IPython 7.6.1\n", "\n", "torch 1.1.0\n" ] } ], "source": [ "%load_ext watermark\n", "%watermark -a 'Sebastian Raschka' -v -p torch" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "rH4XmErYj5wm" }, "source": [ "# Model Zoo -- AlexNet CIFAR-10 Classifier" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Network Architecture" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "References\n", " \n", "- [1] Krizhevsky, Alex, Ilya Sutskever, and Geoffrey E. Hinton. \"[Imagenet classification with deep convolutional neural networks.](https://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf)\" In Advances in Neural Information Processing Systems, pp. 1097-1105. 
2012.\n" ] },
{ "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "MkoGLH_Tj5wn" }, "source": [ "## Imports" ] },
{ "cell_type": "code", "execution_count": 2, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 } }, "colab_type": "code", "id": "ORj09gnrj5wp" }, "outputs": [], "source": [
"import os\n",
"import time\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"from torch.utils.data import DataLoader\n",
"from torch.utils.data.dataset import Subset\n",
"\n",
"from torchvision import datasets\n",
"from torchvision import transforms\n",
"\n",
"import matplotlib.pyplot as plt\n",
"from PIL import Image\n",
"\n",
"\n",
"if torch.cuda.is_available():\n",
"    torch.backends.cudnn.deterministic = True" ] },
{ "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "I6hghKPxj5w0" }, "source": [ "## Model Settings" ] },
{ "cell_type": "code", "execution_count": 3, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, "base_uri": "https://localhost:8080/", "height": 85 }, "colab_type": "code", "executionInfo": { "elapsed": 23936, "status": "ok", "timestamp": 1524974497505, "user": { "displayName": "Sebastian Raschka", "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg", "userId": "118404394130788869227" }, "user_tz": 240 }, "id": "NnT0sZIwj5wu", "outputId": "55aed925-d17e-4c6a-8c71-0d9b3bde5637" }, "outputs": [], "source": [
"##########################\n",
"### SETTINGS\n",
"##########################\n",
"\n",
"# Hyperparameters\n",
"RANDOM_SEED = 1\n",
"LEARNING_RATE = 0.0001\n",
"BATCH_SIZE = 256\n",
"NUM_EPOCHS = 40\n",
"\n",
"# Architecture\n",
"NUM_CLASSES = 10\n",
"\n",
"# Other\n",
"# Use the first GPU when one is available; otherwise fall back to the CPU\n",
"# so the notebook still runs (slowly) on machines without CUDA.\n",
"DEVICE = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Dataset" ] },
{ "cell_type": "code", "execution_count": 4, "metadata": 
{}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Files already downloaded and verified\n" ] } ], "source": [
"train_indices = torch.arange(0, 48000)\n",
"valid_indices = torch.arange(48000, 50000)\n",
"\n",
"\n",
"train_transform = transforms.Compose([transforms.Resize((70, 70)),\n",
"                                      transforms.RandomCrop((64, 64)),\n",
"                                      transforms.ToTensor()])\n",
"\n",
"test_transform = transforms.Compose([transforms.Resize((70, 70)),\n",
"                                     transforms.CenterCrop((64, 64)),\n",
"                                     transforms.ToTensor()])\n",
"\n",
"train_and_valid = datasets.CIFAR10(root='data', \n",
"                                   train=True, \n",
"                                   transform=train_transform,\n",
"                                   download=True)\n",
"\n",
"# Same training split, but read through the deterministic evaluation\n",
"# transform, so validation accuracy does not vary with RandomCrop.\n",
"valid_base = datasets.CIFAR10(root='data', \n",
"                              train=True, \n",
"                              transform=test_transform,\n",
"                              download=False)\n",
"\n",
"train_dataset = Subset(train_and_valid, train_indices)\n",
"valid_dataset = Subset(valid_base, valid_indices)\n",
"test_dataset = datasets.CIFAR10(root='data', \n",
"                                train=False, \n",
"                                transform=test_transform,\n",
"                                download=False)\n",
"\n",
"\n",
"\n",
"\n",
"train_loader = DataLoader(dataset=train_dataset, \n",
"                          batch_size=BATCH_SIZE,\n",
"                          num_workers=4,\n",
"                          shuffle=True)\n",
"\n",
"valid_loader = DataLoader(dataset=valid_dataset, \n",
"                          batch_size=BATCH_SIZE,\n",
"                          num_workers=4,\n",
"                          shuffle=False)\n",
"\n",
"test_loader = DataLoader(dataset=test_dataset, \n",
"                         batch_size=BATCH_SIZE,\n",
"                         num_workers=4,\n",
"                         shuffle=False)" ] },
{ "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training Set:\n", "\n", "Image batch dimensions: torch.Size([256, 3, 64, 64])\n", "Image label dimensions: torch.Size([256])\n", "\n", "Validation Set:\n", "Image batch dimensions: torch.Size([256, 3, 64, 64])\n", "Image label dimensions: torch.Size([256])\n", "\n", "Testing Set:\n", "Image batch dimensions: torch.Size([256, 3, 64, 64])\n", "Image label dimensions: torch.Size([256])\n" ] } ], "source": [
"# Checking the dataset\n",
"print('Training Set:\\n')\n",
"for images, labels in train_loader: \n",
"    print('Image batch dimensions:', images.size())\n",
"    print('Image label dimensions:', labels.size())\n",
"    break\n",
"    \n",
"# Checking the dataset\n",
"print('\\nValidation Set:')\n",
"for images, labels in valid_loader: \n",
"    print('Image batch dimensions:', images.size())\n",
"    print('Image label dimensions:', labels.size())\n",
"    break\n",
"\n",
"# Checking the dataset\n",
"print('\\nTesting Set:')\n",
"for images, labels in test_loader: \n",
"    print('Image batch dimensions:', images.size())\n",
"    print('Image label dimensions:', labels.size())\n",
"    break" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Model" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [
"##########################\n",
"### MODEL\n",
"##########################\n",
"\n",
"class AlexNet(nn.Module):\n",
"    \"\"\"AlexNet-style CNN that returns both logits and softmax probabilities.\n",
"\n",
"    The adaptive average pooling layer resizes the feature maps to 6x6\n",
"    before they are flattened for the fully connected classifier.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, num_classes):\n",
"        super(AlexNet, self).__init__()\n",
"        self.features = nn.Sequential(\n",
"            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),\n",
"            nn.ReLU(inplace=True),\n",
"            nn.MaxPool2d(kernel_size=3, stride=2),\n",
"            nn.Conv2d(64, 192, kernel_size=5, padding=2),\n",
"            nn.ReLU(inplace=True),\n",
"            nn.MaxPool2d(kernel_size=3, stride=2),\n",
"            nn.Conv2d(192, 384, kernel_size=3, padding=1),\n",
"            nn.ReLU(inplace=True),\n",
"            nn.Conv2d(384, 256, kernel_size=3, padding=1),\n",
"            nn.ReLU(inplace=True),\n",
"            nn.Conv2d(256, 256, kernel_size=3, padding=1),\n",
"            nn.ReLU(inplace=True),\n",
"            nn.MaxPool2d(kernel_size=3, stride=2),\n",
"        )\n",
"        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))\n",
"        self.classifier = nn.Sequential(\n",
"            nn.Dropout(0.5),\n",
"            nn.Linear(256 * 6 * 6, 4096),\n",
"            nn.ReLU(inplace=True),\n",
"            nn.Dropout(0.5),\n",
"            nn.Linear(4096, 4096),\n",
"            nn.ReLU(inplace=True),\n",
"            nn.Linear(4096, num_classes)\n",
"        )\n",
"\n",
"    def forward(self, x):\n",
"        x = self.features(x)\n",
"        x = self.avgpool(x)\n",
"        # Flatten to (batch, 256*6*6) for the linear layers\n",
"        x = x.view(x.size(0), 256 * 6 * 6)\n",
"        logits = self.classifier(x)\n",
"        probas = F.softmax(logits, dim=1)\n",
"        return logits, probas\n" ] },
{ "cell_type": "code", "execution_count": 7, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 } }, "colab_type": "code", "id": "_lza9t_uj5w1" }, "outputs": [], "source": [
"torch.manual_seed(RANDOM_SEED)\n",
"\n",
"model = AlexNet(NUM_CLASSES)\n",
"model.to(DEVICE)\n",
"\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE) " ] },
{ "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "RAodboScj5w6" }, "source": [ "## Training" ] },
{ "cell_type": "code", "execution_count": 8, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, "base_uri": "https://localhost:8080/", "height": 1547 }, "colab_type": "code", "executionInfo": { "elapsed": 2384585, "status": "ok", "timestamp": 1524976888520, "user": { "displayName": "Sebastian Raschka", "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg", "userId": "118404394130788869227" }, "user_tz": 240 }, "id": "Dzh3ROmRj5w7", "outputId": "5f8fd8c9-b076-403a-b0b7-fd2d498b48d7" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [
"Epoch: 001/040 | Batch 000/188 | Cost: 2.3029\n",
"Epoch: 001/040 | Batch 150/188 | Cost: 1.7090\n",
"Epoch: 001/040\n",
"Train ACC: 31.92 | Validation ACC: 31.05\n",
"Time elapsed: 0.22 min\n",
"Epoch: 002/040 | Batch 000/188 | Cost: 1.7312\n",
"Epoch: 002/040 | Batch 150/188 | Cost: 1.6115\n",
"Epoch: 002/040\n",
"Train ACC: 43.78 | Validation ACC: 44.35\n",
"Time elapsed: 0.43 min\n",
"Epoch: 003/040 | Batch 000/188 | Cost: 1.5096\n",
"Epoch: 003/040 | Batch 150/188 | Cost: 1.4324\n",
"Epoch: 003/040\n",
"Train ACC: 53.03 | Validation ACC: 52.30\n",
"Time elapsed: 0.64 min\n",
"Epoch: 004/040 | Batch 000/188 | Cost: 1.3731\n",
"Epoch: 004/040 | Batch 150/188 | Cost: 1.2505\n",
"Epoch: 004/040\n",
"Train ACC: 56.87 | Validation ACC: 57.30\n",
"Time elapsed: 0.85 min\n",
"Epoch: 005/040 | Batch 000/188 
| Cost: 1.0734\n", "Epoch: 005/040 | Batch 150/188 | Cost: 1.1652\n", "Epoch: 005/040\n", "Train ACC: 60.97 | Validation ACC: 60.30\n", "Time elapsed: 1.07 min\n", "Epoch: 006/040 | Batch 000/188 | Cost: 1.0730\n", "Epoch: 006/040 | Batch 150/188 | Cost: 1.1333\n", "Epoch: 006/040\n", "Train ACC: 62.87 | Validation ACC: 60.90\n", "Time elapsed: 1.28 min\n", "Epoch: 007/040 | Batch 000/188 | Cost: 1.0317\n", "Epoch: 007/040 | Batch 150/188 | Cost: 1.0182\n", "Epoch: 007/040\n", "Train ACC: 67.03 | Validation ACC: 64.35\n", "Time elapsed: 1.50 min\n", "Epoch: 008/040 | Batch 000/188 | Cost: 1.0245\n", "Epoch: 008/040 | Batch 150/188 | Cost: 0.9324\n", "Epoch: 008/040\n", "Train ACC: 64.87 | Validation ACC: 64.55\n", "Time elapsed: 1.71 min\n", "Epoch: 009/040 | Batch 000/188 | Cost: 1.0012\n", "Epoch: 009/040 | Batch 150/188 | Cost: 0.8525\n", "Epoch: 009/040\n", "Train ACC: 70.84 | Validation ACC: 67.30\n", "Time elapsed: 1.93 min\n", "Epoch: 010/040 | Batch 000/188 | Cost: 0.7442\n", "Epoch: 010/040 | Batch 150/188 | Cost: 0.7908\n", "Epoch: 010/040\n", "Train ACC: 70.95 | Validation ACC: 67.10\n", "Time elapsed: 2.14 min\n", "Epoch: 011/040 | Batch 000/188 | Cost: 0.8389\n", "Epoch: 011/040 | Batch 150/188 | Cost: 0.8383\n", "Epoch: 011/040\n", "Train ACC: 74.18 | Validation ACC: 69.95\n", "Time elapsed: 2.36 min\n", "Epoch: 012/040 | Batch 000/188 | Cost: 0.7037\n", "Epoch: 012/040 | Batch 150/188 | Cost: 0.9285\n", "Epoch: 012/040\n", "Train ACC: 74.23 | Validation ACC: 66.70\n", "Time elapsed: 2.57 min\n", "Epoch: 013/040 | Batch 000/188 | Cost: 0.7205\n", "Epoch: 013/040 | Batch 150/188 | Cost: 0.7099\n", "Epoch: 013/040\n", "Train ACC: 76.88 | Validation ACC: 70.00\n", "Time elapsed: 2.78 min\n", "Epoch: 014/040 | Batch 000/188 | Cost: 0.6575\n", "Epoch: 014/040 | Batch 150/188 | Cost: 0.6311\n", "Epoch: 014/040\n", "Train ACC: 76.69 | Validation ACC: 70.00\n", "Time elapsed: 3.00 min\n", "Epoch: 015/040 | Batch 000/188 | Cost: 0.6724\n", "Epoch: 015/040 | 
Batch 150/188 | Cost: 0.8899\n", "Epoch: 015/040\n", "Train ACC: 80.01 | Validation ACC: 71.80\n", "Time elapsed: 3.22 min\n", "Epoch: 016/040 | Batch 000/188 | Cost: 0.6895\n", "Epoch: 016/040 | Batch 150/188 | Cost: 0.5913\n", "Epoch: 016/040\n", "Train ACC: 79.64 | Validation ACC: 70.75\n", "Time elapsed: 3.43 min\n", "Epoch: 017/040 | Batch 000/188 | Cost: 0.6096\n", "Epoch: 017/040 | Batch 150/188 | Cost: 0.5401\n", "Epoch: 017/040\n", "Train ACC: 82.48 | Validation ACC: 72.20\n", "Time elapsed: 3.65 min\n", "Epoch: 018/040 | Batch 000/188 | Cost: 0.5421\n", "Epoch: 018/040 | Batch 150/188 | Cost: 0.4187\n", "Epoch: 018/040\n", "Train ACC: 84.17 | Validation ACC: 73.60\n", "Time elapsed: 3.86 min\n", "Epoch: 019/040 | Batch 000/188 | Cost: 0.4490\n", "Epoch: 019/040 | Batch 150/188 | Cost: 0.4658\n", "Epoch: 019/040\n", "Train ACC: 84.06 | Validation ACC: 72.65\n", "Time elapsed: 4.08 min\n", "Epoch: 020/040 | Batch 000/188 | Cost: 0.4837\n", "Epoch: 020/040 | Batch 150/188 | Cost: 0.4519\n", "Epoch: 020/040\n", "Train ACC: 86.10 | Validation ACC: 72.90\n", "Time elapsed: 4.29 min\n", "Epoch: 021/040 | Batch 000/188 | Cost: 0.4615\n", "Epoch: 021/040 | Batch 150/188 | Cost: 0.5283\n", "Epoch: 021/040\n", "Train ACC: 85.61 | Validation ACC: 72.10\n", "Time elapsed: 4.51 min\n", "Epoch: 022/040 | Batch 000/188 | Cost: 0.4693\n", "Epoch: 022/040 | Batch 150/188 | Cost: 0.4589\n", "Epoch: 022/040\n", "Train ACC: 88.70 | Validation ACC: 73.90\n", "Time elapsed: 4.72 min\n", "Epoch: 023/040 | Batch 000/188 | Cost: 0.2818\n", "Epoch: 023/040 | Batch 150/188 | Cost: 0.4123\n", "Epoch: 023/040\n", "Train ACC: 89.58 | Validation ACC: 73.45\n", "Time elapsed: 4.94 min\n", "Epoch: 024/040 | Batch 000/188 | Cost: 0.3030\n", "Epoch: 024/040 | Batch 150/188 | Cost: 0.3685\n", "Epoch: 024/040\n", "Train ACC: 90.44 | Validation ACC: 73.60\n", "Time elapsed: 5.15 min\n", "Epoch: 025/040 | Batch 000/188 | Cost: 0.2399\n", "Epoch: 025/040 | Batch 150/188 | Cost: 0.3384\n", 
"Epoch: 025/040\n", "Train ACC: 90.85 | Validation ACC: 73.35\n", "Time elapsed: 5.37 min\n", "Epoch: 026/040 | Batch 000/188 | Cost: 0.2333\n", "Epoch: 026/040 | Batch 150/188 | Cost: 0.2852\n", "Epoch: 026/040\n", "Train ACC: 92.25 | Validation ACC: 72.20\n", "Time elapsed: 5.58 min\n", "Epoch: 027/040 | Batch 000/188 | Cost: 0.2728\n", "Epoch: 027/040 | Batch 150/188 | Cost: 0.3350\n", "Epoch: 027/040\n", "Train ACC: 91.94 | Validation ACC: 73.85\n", "Time elapsed: 5.80 min\n", "Epoch: 028/040 | Batch 000/188 | Cost: 0.2277\n", "Epoch: 028/040 | Batch 150/188 | Cost: 0.2987\n", "Epoch: 028/040\n", "Train ACC: 92.84 | Validation ACC: 72.85\n", "Time elapsed: 6.02 min\n", "Epoch: 029/040 | Batch 000/188 | Cost: 0.2115\n", "Epoch: 029/040 | Batch 150/188 | Cost: 0.2038\n", "Epoch: 029/040\n", "Train ACC: 92.28 | Validation ACC: 72.50\n", "Time elapsed: 6.23 min\n", "Epoch: 030/040 | Batch 000/188 | Cost: 0.1841\n", "Epoch: 030/040 | Batch 150/188 | Cost: 0.2074\n", "Epoch: 030/040\n", "Train ACC: 94.63 | Validation ACC: 74.00\n", "Time elapsed: 6.44 min\n", "Epoch: 031/040 | Batch 000/188 | Cost: 0.1490\n", "Epoch: 031/040 | Batch 150/188 | Cost: 0.2191\n", "Epoch: 031/040\n", "Train ACC: 94.67 | Validation ACC: 72.60\n", "Time elapsed: 6.66 min\n", "Epoch: 032/040 | Batch 000/188 | Cost: 0.1719\n", "Epoch: 032/040 | Batch 150/188 | Cost: 0.1990\n", "Epoch: 032/040\n", "Train ACC: 93.93 | Validation ACC: 71.60\n", "Time elapsed: 6.87 min\n", "Epoch: 033/040 | Batch 000/188 | Cost: 0.1839\n", "Epoch: 033/040 | Batch 150/188 | Cost: 0.1939\n", "Epoch: 033/040\n", "Train ACC: 95.61 | Validation ACC: 73.75\n", "Time elapsed: 7.09 min\n", "Epoch: 034/040 | Batch 000/188 | Cost: 0.0995\n", "Epoch: 034/040 | Batch 150/188 | Cost: 0.1726\n", "Epoch: 034/040\n", "Train ACC: 95.35 | Validation ACC: 73.85\n", "Time elapsed: 7.30 min\n", "Epoch: 035/040 | Batch 000/188 | Cost: 0.1451\n", "Epoch: 035/040 | Batch 150/188 | Cost: 0.1414\n", "Epoch: 035/040\n", "Train ACC: 95.90 | 
Validation ACC: 73.55\n",
"Time elapsed: 7.52 min\n",
"Epoch: 036/040 | Batch 000/188 | Cost: 0.0551\n",
"Epoch: 036/040 | Batch 150/188 | Cost: 0.1009\n",
"Epoch: 036/040\n",
"Train ACC: 95.40 | Validation ACC: 72.45\n",
"Time elapsed: 7.73 min\n",
"Epoch: 037/040 | Batch 000/188 | Cost: 0.1616\n",
"Epoch: 037/040 | Batch 150/188 | Cost: 0.1102\n",
"Epoch: 037/040\n",
"Train ACC: 96.08 | Validation ACC: 72.55\n",
"Time elapsed: 7.95 min\n",
"Epoch: 038/040 | Batch 000/188 | Cost: 0.1409\n",
"Epoch: 038/040 | Batch 150/188 | Cost: 0.1090\n",
"Epoch: 038/040\n",
"Train ACC: 96.69 | Validation ACC: 73.85\n",
"Time elapsed: 8.16 min\n",
"Epoch: 039/040 | Batch 000/188 | Cost: 0.1088\n",
"Epoch: 039/040 | Batch 150/188 | Cost: 0.1309\n",
"Epoch: 039/040\n",
"Train ACC: 96.11 | Validation ACC: 72.15\n",
"Time elapsed: 8.38 min\n",
"Epoch: 040/040 | Batch 000/188 | Cost: 0.1418\n",
"Epoch: 040/040 | Batch 150/188 | Cost: 0.1745\n",
"Epoch: 040/040\n",
"Train ACC: 97.45 | Validation ACC: 74.00\n",
"Time elapsed: 8.59 min\n",
"Total Training Time: 8.59 min\n" ] } ], "source": [
"def compute_acc(model, data_loader, device):\n",
"    \"\"\"Return the accuracy of `model` on `data_loader`, in percent.\n",
"\n",
"    The model is switched to eval mode and gradient tracking is disabled\n",
"    for the forward passes. The result is a 0-dim tensor.\n",
"    \"\"\"\n",
"    correct_pred, num_examples = 0, 0\n",
"    model.eval()\n",
"    with torch.no_grad():\n",
"        for features, targets in data_loader:\n",
"\n",
"            features = features.to(device)\n",
"            targets = targets.to(device)\n",
"\n",
"            logits, probas = model(features)\n",
"            _, predicted_labels = torch.max(probas, 1)\n",
"            num_examples += targets.size(0)\n",
"            assert predicted_labels.size() == targets.size()\n",
"            correct_pred += (predicted_labels == targets).sum()\n",
"    return correct_pred.float()/num_examples * 100\n",
"    \n",
"\n",
"start_time = time.time()\n",
"\n",
"cost_list = []\n",
"train_acc_list, valid_acc_list = [], []\n",
"\n",
"\n",
"for epoch in range(NUM_EPOCHS):\n",
"    \n",
"    model.train()\n",
"    for batch_idx, (features, targets) in enumerate(train_loader):\n",
"        \n",
"        features = features.to(DEVICE)\n",
"        targets = targets.to(DEVICE)\n",
"        \n",
"        ### FORWARD AND BACK PROP\n",
"        logits, probas = model(features)\n",
"        cost = F.cross_entropy(logits, targets)\n",
"        optimizer.zero_grad()\n",
"        \n",
"        cost.backward()\n",
"        \n",
"        ### UPDATE MODEL PARAMETERS\n",
"        optimizer.step()\n",
"        \n",
"        #################################################\n",
"        ### CODE ONLY FOR LOGGING BEYOND THIS POINT\n",
"        ################################################\n",
"        cost_list.append(cost.item())\n",
"        if not batch_idx % 150:\n",
"            print (f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d} | '\n",
"                   f'Batch {batch_idx:03d}/{len(train_loader):03d} |' \n",
"                   f' Cost: {cost:.4f}')\n",
"\n",
"        \n",
"\n",
"    model.eval()\n",
"    with torch.set_grad_enabled(False): # save memory during inference\n",
"        \n",
"        # .item() keeps plain Python floats in the history lists, so they\n",
"        # can be plotted later without holding on to device tensors.\n",
"        train_acc = compute_acc(model, train_loader, device=DEVICE).item()\n",
"        valid_acc = compute_acc(model, valid_loader, device=DEVICE).item()\n",
"        \n",
"        print(f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d}\\n'\n",
"              f'Train ACC: {train_acc:.2f} | Validation ACC: {valid_acc:.2f}')\n",
"        \n",
"        train_acc_list.append(train_acc)\n",
"        valid_acc_list.append(valid_acc)\n",
"        \n",
"    elapsed = (time.time() - start_time)/60\n",
"    print(f'Time elapsed: {elapsed:.2f} min')\n",
"    \n",
"elapsed = (time.time() - start_time)/60\n",
"print(f'Total Training Time: {elapsed:.2f} min')" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluation" ] },
{ "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "%matplotlib inline\n" ] },
{ "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [
"plt.plot(cost_list, label='Minibatch cost')\n",
"plt.plot(np.convolve(cost_list, \n",
"                     np.ones(200,)/200, mode='valid'), \n",
"         label='Running average')\n",
"\n",
"plt.ylabel('Cross Entropy')\n",
"plt.xlabel('Iteration')\n",
"plt.legend()\n",
"plt.show()" ] },
{ "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [
"plt.plot(np.arange(1, NUM_EPOCHS+1), train_acc_list, label='Training')\n",
"plt.plot(np.arange(1, NUM_EPOCHS+1), valid_acc_list, label='Validation')\n",
"\n",
"plt.xlabel('Epoch')\n",
"plt.ylabel('Accuracy')\n",
"plt.legend()\n",
"plt.show()" ] },
{ "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Validation ACC: 74.55%\n", "Test ACC: 73.68%\n" ] } ], "source": [
"with torch.set_grad_enabled(False):\n",
"    test_acc = compute_acc(model=model,\n",
"                           data_loader=test_loader,\n",
"                           device=DEVICE)\n",
"    \n",
"    valid_acc = compute_acc(model=model,\n",
"                            data_loader=valid_loader,\n",
"                            device=DEVICE)\n",
"    \n",
"\n",
"print(f'Validation ACC: {valid_acc:.2f}%')\n",
"print(f'Test ACC: {test_acc:.2f}%')" ] },
{ "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "matplotlib 3.1.0\n", "pandas 0.24.2\n", "torch 1.1.0\n", "numpy 1.16.4\n", "PIL.Image 6.0.0\n", "torchvision 0.3.0\n", "\n" ] } ], "source": [ "%watermark -iv" ] } ],
"metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "default_view": {}, "name": "convnet-vgg16.ipynb", "provenance": [], "version": "0.3.2", "views": {} }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" }, "toc": { "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": true, "toc_position": { "height": "calc(100% - 180px)", "left": "10px", "top": "150px", "width": "371px" }, "toc_section_display": true, 
"toc_window_display": true } }, "nbformat": 4, "nbformat_minor": 4 }