{ "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "UEBilEjLj5wY" }, "source": [ "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n", "- Author: Sebastian Raschka\n", "- GitHub Repository: https://github.com/rasbt/deeplearning-models" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, "base_uri": "https://localhost:8080/", "height": 119 }, "colab_type": "code", "executionInfo": { "elapsed": 536, "status": "ok", "timestamp": 1524974472601, "user": { "displayName": "Sebastian Raschka", "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg", "userId": "118404394130788869227" }, "user_tz": 240 }, "id": "GOzuY8Yvj5wb", "outputId": "c19362ce-f87a-4cc2-84cc-8d7b4b9e6007" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sebastian Raschka \n", "\n", "CPython 3.7.3\n", "IPython 7.9.0\n", "\n", "torch 1.3.0\n" ] } ], "source": [ "%load_ext watermark\n", "%watermark -a 'Sebastian Raschka' -v -p torch" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "rH4XmErYj5wm" }, "source": [ "# Network in Network CIFAR-10 Classifier" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "based on \n", "\n", "- Lin, Min, Qiang Chen, and Shuicheng Yan. \"Network in network.\" arXiv preprint arXiv:1312.4400 (2013)." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "MkoGLH_Tj5wn" }, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 } }, "colab_type": "code", "id": "ORj09gnrj5wp" }, "outputs": [], "source": [ "import os\n", "import time\n", "\n", "import numpy as np\n", "import pandas as pd\n", "\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from torch.utils.data import DataLoader\n", "from torch.utils.data import Subset\n", "\n", "from torchvision import datasets\n", "from torchvision import transforms\n", "\n", "import matplotlib.pyplot as plt\n", "from PIL import Image\n", "\n", "\n", "if torch.cuda.is_available():\n", " torch.backends.cudnn.deterministic = True" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "device(type='cpu')" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "torch.tensor([1]).device" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "I6hghKPxj5w0" }, "source": [ "## Model Settings" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, "base_uri": "https://localhost:8080/", "height": 85 }, "colab_type": "code", "executionInfo": { "elapsed": 23936, "status": "ok", "timestamp": 1524974497505, "user": { "displayName": "Sebastian Raschka", "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg", "userId": "118404394130788869227" }, "user_tz": 240 }, "id": "NnT0sZIwj5wu", "outputId": "55aed925-d17e-4c6a-8c71-0d9b3bde5637" }, "outputs": [], "source": [ "##########################\n", "### SETTINGS\n", "##########################\n", "\n", "# Hyperparameters\n", "RANDOM_SEED = 1\n", "LEARNING_RATE = 0.0005\n", "BATCH_SIZE = 256\n", "NUM_EPOCHS = 100\n", "\n", "# 
Architecture\n", "NUM_CLASSES = 10\n", "\n", "# Other\n", "DEVICE = \"cuda:3\"\n", "GRAYSCALE = False" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Files already downloaded and verified\n", "Image batch dimensions: torch.Size([256, 3, 32, 32])\n", "Image label dimensions: torch.Size([256])\n", "Image batch dimensions: torch.Size([256, 3, 32, 32])\n", "Image label dimensions: torch.Size([256])\n", "Image batch dimensions: torch.Size([256, 3, 32, 32])\n", "Image label dimensions: torch.Size([256])\n" ] } ], "source": [ "##########################\n", "### CIFAR-10 Dataset\n", "##########################\n", "\n", "\n", "# Note transforms.ToTensor() scales input images\n", "# to 0-1 range\n", "\n", "\n", "train_indices = torch.arange(0, 49000)\n", "valid_indices = torch.arange(49000, 50000)\n", "\n", "\n", "train_and_valid = datasets.CIFAR10(root='data', \n", " train=True, \n", " transform=transforms.ToTensor(),\n", " download=True)\n", "\n", "train_dataset = Subset(train_and_valid, train_indices)\n", "valid_dataset = Subset(train_and_valid, valid_indices)\n", "\n", "\n", "test_dataset = datasets.CIFAR10(root='data', \n", " train=False, \n", " transform=transforms.ToTensor())\n", "\n", "\n", "#####################################################\n", "### Data Loaders\n", "#####################################################\n", "\n", "train_loader = DataLoader(dataset=train_dataset, \n", " batch_size=BATCH_SIZE,\n", " num_workers=8,\n", " shuffle=True)\n", "\n", "valid_loader = DataLoader(dataset=valid_dataset, \n", " batch_size=BATCH_SIZE,\n", " num_workers=8,\n", " shuffle=False)\n", "\n", "test_loader = DataLoader(dataset=test_dataset, \n", " batch_size=BATCH_SIZE,\n", " num_workers=8,\n", " shuffle=False)\n", "\n", "#####################################################\n", "\n", "# Checking the dataset\n", "for images, labels in train_loader: \n", " print('Image batch dimensions:', images.shape)\n", " print('Image label dimensions:', labels.shape)\n", " break\n", "\n", "for images, labels in test_loader: \n", " print('Image batch dimensions:', images.shape)\n", " print('Image label dimensions:', labels.shape)\n", " break\n", " \n", "for images, labels in valid_loader: \n", " print('Image batch dimensions:', images.shape)\n", " print('Image label dimensions:', labels.shape)\n", " break" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "##########################\n", "### MODEL\n", "##########################\n", "\n", "\n", "class NiN(nn.Module):\n", " def __init__(self, num_classes):\n", " super(NiN, self).__init__()\n", " self.num_classes = num_classes\n", " self.classifier = nn.Sequential(\n", " nn.Conv2d(3, 192, kernel_size=5, stride=1, padding=2),\n", " nn.ReLU(inplace=True),\n", " nn.Conv2d(192, 160, kernel_size=1, stride=1, padding=0),\n", " nn.ReLU(inplace=True),\n", " nn.Conv2d(160, 96, kernel_size=1, stride=1, padding=0),\n", " nn.ReLU(inplace=True),\n", " nn.MaxPool2d(kernel_size=3, stride=2, padding=1),\n", " nn.Dropout(0.5),\n", "\n", " nn.Conv2d(96, 192, kernel_size=5, stride=1, padding=2),\n", " nn.ReLU(inplace=True),\n", " nn.Conv2d(192, 192, kernel_size=1, stride=1, padding=0),\n", " nn.ReLU(inplace=True),\n", " nn.Conv2d(192, 192, kernel_size=1, stride=1, padding=0),\n", " nn.ReLU(inplace=True),\n", " nn.AvgPool2d(kernel_size=3, stride=2, padding=1),\n", " nn.Dropout(0.5),\n", "\n", " nn.Conv2d(192, 192, kernel_size=3, stride=1, 
padding=1),\n", " nn.ReLU(inplace=True),\n", " nn.Conv2d(192, 192, kernel_size=1, stride=1, padding=0),\n", " nn.ReLU(inplace=True),\n", " nn.Conv2d(192, 10, kernel_size=1, stride=1, padding=0),\n", " nn.ReLU(inplace=True),\n", " nn.AvgPool2d(kernel_size=8, stride=1, padding=0),\n", "\n", " )\n", "\n", " def forward(self, x):\n", " x = self.classifier(x)\n", " logits = x.view(x.size(0), self.num_classes)\n", " probas = torch.softmax(logits, dim=1)\n", " return logits, probas" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "RAodboScj5w6" }, "source": [ "## Training without Pinned Memory" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 } }, "colab_type": "code", "id": "_lza9t_uj5w1" }, "outputs": [], "source": [ "torch.manual_seed(RANDOM_SEED)\n", "\n", "model = NiN(NUM_CLASSES)\n", "model.to(DEVICE)\n", "\n", "optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE) " ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, "base_uri": "https://localhost:8080/", "height": 1547 }, "colab_type": "code", "executionInfo": { "elapsed": 2384585, "status": "ok", "timestamp": 1524976888520, "user": { "displayName": "Sebastian Raschka", "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg", "userId": "118404394130788869227" }, "user_tz": 240 }, "id": "Dzh3ROmRj5w7", "outputId": "5f8fd8c9-b076-403a-b0b7-fd2d498b48d7" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 001/100 | Batch 000/192 | Cost: 2.3043\n", "Epoch: 001/100 | Batch 120/192 | Cost: 2.0653\n", "Epoch: 001/100 Train Acc.: 24.69% | Validation Acc.: 24.50%\n", "Time elapsed: 0.32 min\n", "Epoch: 002/100 | Batch 000/192 | Cost: 1.8584\n", "Epoch: 002/100 | Batch 120/192 | Cost: 1.7447\n", "Epoch: 002/100 Train Acc.: 36.51% | Validation Acc.: 36.90%\n", "Time elapsed: 0.64 min\n", "Epoch: 003/100 | Batch 000/192 | Cost: 1.6050\n", "Epoch: 003/100 | Batch 120/192 | Cost: 1.5591\n", "Epoch: 003/100 Train Acc.: 40.50% | Validation Acc.: 37.50%\n", "Time elapsed: 0.97 min\n", "Epoch: 004/100 | Batch 000/192 | Cost: 1.5428\n", "Epoch: 004/100 | Batch 120/192 | Cost: 1.4454\n", "Epoch: 004/100 Train Acc.: 46.12% | Validation Acc.: 45.80%\n", "Time elapsed: 1.30 min\n", "Epoch: 005/100 | Batch 000/192 | Cost: 1.4038\n", "Epoch: 005/100 | Batch 120/192 | Cost: 1.4141\n", "Epoch: 005/100 Train Acc.: 50.21% | Validation Acc.: 49.90%\n", "Time elapsed: 1.63 min\n", "Epoch: 006/100 | Batch 000/192 | Cost: 1.3475\n", "Epoch: 006/100 | Batch 120/192 | Cost: 1.2627\n", "Epoch: 006/100 Train Acc.: 52.66% | Validation Acc.: 54.40%\n", "Time elapsed: 1.96 min\n", "Epoch: 007/100 | Batch 000/192 | Cost: 1.3238\n", "Epoch: 007/100 | Batch 120/192 | Cost: 1.2220\n", "Epoch: 007/100 Train Acc.: 54.42% | Validation Acc.: 54.40%\n", "Time elapsed: 2.29 min\n", "Epoch: 008/100 | Batch 000/192 | Cost: 1.2009\n", "Epoch: 008/100 | Batch 120/192 | Cost: 1.2045\n", "Epoch: 008/100 Train Acc.: 55.81% | Validation Acc.: 55.30%\n", "Time elapsed: 2.63 min\n", "Epoch: 009/100 | Batch 000/192 | Cost: 1.2797\n", "Epoch: 009/100 | Batch 120/192 | Cost: 1.1397\n", "Epoch: 009/100 Train Acc.: 59.10% | Validation Acc.: 60.60%\n", "Time elapsed: 2.97 min\n", "Epoch: 010/100 | Batch 000/192 | Cost: 1.0562\n", "Epoch: 010/100 | Batch 120/192 | Cost: 1.1625\n", "Epoch: 010/100 Train Acc.: 59.79% | Validation Acc.: 
60.90%\n", "Time elapsed: 3.32 min\n", "Epoch: 011/100 | Batch 000/192 | Cost: 1.0868\n", "Epoch: 011/100 | Batch 120/192 | Cost: 1.0636\n", "Epoch: 011/100 Train Acc.: 60.43% | Validation Acc.: 61.00%\n", "Time elapsed: 3.66 min\n", "Epoch: 012/100 | Batch 000/192 | Cost: 1.0049\n", "Epoch: 012/100 | Batch 120/192 | Cost: 1.2247\n", "Epoch: 012/100 Train Acc.: 62.14% | Validation Acc.: 62.70%\n", "Time elapsed: 4.00 min\n", "Epoch: 013/100 | Batch 000/192 | Cost: 0.9232\n", "Epoch: 013/100 | Batch 120/192 | Cost: 1.0345\n", "Epoch: 013/100 Train Acc.: 61.42% | Validation Acc.: 61.70%\n", "Time elapsed: 4.35 min\n", "Epoch: 014/100 | Batch 000/192 | Cost: 0.9256\n", "Epoch: 014/100 | Batch 120/192 | Cost: 1.1639\n", "Epoch: 014/100 Train Acc.: 63.82% | Validation Acc.: 65.80%\n", "Time elapsed: 4.69 min\n", "Epoch: 015/100 | Batch 000/192 | Cost: 0.9600\n", "Epoch: 015/100 | Batch 120/192 | Cost: 1.0263\n", "Epoch: 015/100 Train Acc.: 63.94% | Validation Acc.: 64.00%\n", "Time elapsed: 5.04 min\n", "Epoch: 016/100 | Batch 000/192 | Cost: 0.8859\n", "Epoch: 016/100 | Batch 120/192 | Cost: 1.0307\n", "Epoch: 016/100 Train Acc.: 65.79% | Validation Acc.: 66.40%\n", "Time elapsed: 5.39 min\n", "Epoch: 017/100 | Batch 000/192 | Cost: 1.0020\n", "Epoch: 017/100 | Batch 120/192 | Cost: 0.9755\n", "Epoch: 017/100 Train Acc.: 66.95% | Validation Acc.: 66.60%\n", "Time elapsed: 5.73 min\n", "Epoch: 018/100 | Batch 000/192 | Cost: 0.9551\n", "Epoch: 018/100 | Batch 120/192 | Cost: 0.8429\n", "Epoch: 018/100 Train Acc.: 67.56% | Validation Acc.: 66.30%\n", "Time elapsed: 6.08 min\n", "Epoch: 019/100 | Batch 000/192 | Cost: 1.0420\n", "Epoch: 019/100 | Batch 120/192 | Cost: 0.9771\n", "Epoch: 019/100 Train Acc.: 69.44% | Validation Acc.: 68.20%\n", "Time elapsed: 6.43 min\n", "Epoch: 020/100 | Batch 000/192 | Cost: 0.8471\n", "Epoch: 020/100 | Batch 120/192 | Cost: 0.8322\n", "Epoch: 020/100 Train Acc.: 69.99% | Validation Acc.: 70.20%\n", "Time elapsed: 6.77 min\n", "Epoch: 021/100 | Batch 000/192 | Cost: 0.8974\n", "Epoch: 021/100 | Batch 120/192 | Cost: 0.8585\n", "Epoch: 021/100 Train Acc.: 69.52% | Validation Acc.: 69.30%\n", "Time elapsed: 7.12 min\n", "Epoch: 022/100 | Batch 000/192 | Cost: 0.8691\n", "Epoch: 022/100 | Batch 120/192 | Cost: 0.6618\n", "Epoch: 022/100 Train Acc.: 68.26% | Validation Acc.: 65.90%\n", "Time elapsed: 7.47 min\n", "Epoch: 023/100 | Batch 000/192 | Cost: 0.9277\n", "Epoch: 023/100 | Batch 120/192 | Cost: 0.9011\n", "Epoch: 023/100 Train Acc.: 71.66% | Validation Acc.: 72.10%\n", "Time elapsed: 7.81 min\n", "Epoch: 024/100 | Batch 000/192 | Cost: 0.7764\n", "Epoch: 024/100 | Batch 120/192 | Cost: 0.7561\n", "Epoch: 024/100 Train Acc.: 71.70% | Validation Acc.: 68.80%\n", "Time elapsed: 8.16 min\n", "Epoch: 025/100 | Batch 000/192 | Cost: 0.8113\n", "Epoch: 025/100 | Batch 120/192 | Cost: 0.7186\n", "Epoch: 025/100 Train Acc.: 73.62% | Validation Acc.: 73.00%\n", "Time elapsed: 8.50 min\n", "Epoch: 026/100 | Batch 000/192 | Cost: 0.6515\n", "Epoch: 026/100 | Batch 120/192 | Cost: 0.6954\n", "Epoch: 026/100 Train Acc.: 72.22% | Validation Acc.: 70.20%\n", "Time elapsed: 8.85 min\n", "Epoch: 027/100 | Batch 000/192 | Cost: 0.7278\n", "Epoch: 027/100 | Batch 120/192 | Cost: 0.7117\n", "Epoch: 027/100 Train Acc.: 74.82% | Validation Acc.: 72.30%\n", "Time elapsed: 9.19 min\n", "Epoch: 028/100 | Batch 000/192 | Cost: 0.6732\n", "Epoch: 028/100 | Batch 120/192 | Cost: 0.6591\n", "Epoch: 028/100 Train Acc.: 74.93% | Validation Acc.: 72.60%\n", "Time elapsed: 9.54 min\n", 
"Epoch: 029/100 | Batch 000/192 | Cost: 0.7438\n", "Epoch: 029/100 | Batch 120/192 | Cost: 0.6429\n", "Epoch: 029/100 Train Acc.: 75.44% | Validation Acc.: 72.80%\n", "Time elapsed: 9.88 min\n", "Epoch: 030/100 | Batch 000/192 | Cost: 0.7306\n", "Epoch: 030/100 | Batch 120/192 | Cost: 0.6643\n", "Epoch: 030/100 Train Acc.: 76.34% | Validation Acc.: 74.40%\n", "Time elapsed: 10.22 min\n", "Epoch: 031/100 | Batch 000/192 | Cost: 0.5957\n", "Epoch: 031/100 | Batch 120/192 | Cost: 0.5574\n", "Epoch: 031/100 Train Acc.: 76.60% | Validation Acc.: 75.90%\n", "Time elapsed: 10.57 min\n", "Epoch: 032/100 | Batch 000/192 | Cost: 0.6414\n", "Epoch: 032/100 | Batch 120/192 | Cost: 0.6951\n", "Epoch: 032/100 Train Acc.: 77.15% | Validation Acc.: 76.10%\n", "Time elapsed: 10.91 min\n", "Epoch: 033/100 | Batch 000/192 | Cost: 0.6898\n", "Epoch: 033/100 | Batch 120/192 | Cost: 0.7784\n", "Epoch: 033/100 Train Acc.: 77.15% | Validation Acc.: 74.70%\n", "Time elapsed: 11.26 min\n", "Epoch: 034/100 | Batch 000/192 | Cost: 0.5633\n", "Epoch: 034/100 | Batch 120/192 | Cost: 0.6176\n", "Epoch: 034/100 Train Acc.: 77.53% | Validation Acc.: 74.30%\n", "Time elapsed: 11.60 min\n", "Epoch: 035/100 | Batch 000/192 | Cost: 0.6300\n", "Epoch: 035/100 | Batch 120/192 | Cost: 0.6720\n", "Epoch: 035/100 Train Acc.: 78.39% | Validation Acc.: 76.10%\n", "Time elapsed: 11.94 min\n", "Epoch: 036/100 | Batch 000/192 | Cost: 0.7154\n", "Epoch: 036/100 | Batch 120/192 | Cost: 0.6519\n", "Epoch: 036/100 Train Acc.: 78.49% | Validation Acc.: 75.40%\n", "Time elapsed: 12.29 min\n", "Epoch: 037/100 | Batch 000/192 | Cost: 0.6381\n", "Epoch: 037/100 | Batch 120/192 | Cost: 0.6618\n", "Epoch: 037/100 Train Acc.: 79.58% | Validation Acc.: 75.80%\n", "Time elapsed: 12.63 min\n", "Epoch: 038/100 | Batch 000/192 | Cost: 0.6078\n", "Epoch: 038/100 | Batch 120/192 | Cost: 0.5283\n", "Epoch: 038/100 Train Acc.: 79.17% | Validation Acc.: 76.00%\n", "Time elapsed: 12.97 min\n", "Epoch: 039/100 | Batch 000/192 | Cost: 0.5576\n", "Epoch: 039/100 | Batch 120/192 | Cost: 0.6219\n", "Epoch: 039/100 Train Acc.: 79.91% | Validation Acc.: 76.70%\n", "Time elapsed: 13.32 min\n", "Epoch: 040/100 | Batch 000/192 | Cost: 0.5660\n", "Epoch: 040/100 | Batch 120/192 | Cost: 0.5577\n", "Epoch: 040/100 Train Acc.: 80.49% | Validation Acc.: 76.50%\n", "Time elapsed: 13.66 min\n", "Epoch: 041/100 | Batch 000/192 | Cost: 0.5098\n", "Epoch: 041/100 | Batch 120/192 | Cost: 0.6621\n", "Epoch: 041/100 Train Acc.: 80.86% | Validation Acc.: 75.70%\n", "Time elapsed: 14.00 min\n", "Epoch: 042/100 | Batch 000/192 | Cost: 0.4589\n", "Epoch: 042/100 | Batch 120/192 | Cost: 0.5637\n", "Epoch: 042/100 Train Acc.: 81.11% | Validation Acc.: 77.00%\n", "Time elapsed: 14.34 min\n", "Epoch: 043/100 | Batch 000/192 | Cost: 0.4507\n", "Epoch: 043/100 | Batch 120/192 | Cost: 0.4865\n", "Epoch: 043/100 Train Acc.: 82.07% | Validation Acc.: 78.10%\n", "Time elapsed: 14.68 min\n", "Epoch: 044/100 | Batch 000/192 | Cost: 0.4427\n", "Epoch: 044/100 | Batch 120/192 | Cost: 0.5242\n", "Epoch: 044/100 Train Acc.: 82.61% | Validation Acc.: 79.10%\n", "Time elapsed: 15.02 min\n", "Epoch: 045/100 | Batch 000/192 | Cost: 0.4989\n", "Epoch: 045/100 | Batch 120/192 | Cost: 0.5811\n", "Epoch: 045/100 Train Acc.: 82.55% | Validation Acc.: 79.30%\n", "Time elapsed: 15.36 min\n", "Epoch: 046/100 | Batch 000/192 | Cost: 0.5303\n", "Epoch: 046/100 | Batch 120/192 | Cost: 0.4242\n", "Epoch: 046/100 Train Acc.: 81.80% | Validation Acc.: 76.80%\n", "Time elapsed: 15.71 min\n", "Epoch: 047/100 | Batch 
000/192 | Cost: 0.4491\n", "Epoch: 047/100 | Batch 120/192 | Cost: 0.4902\n", "Epoch: 047/100 Train Acc.: 82.54% | Validation Acc.: 77.90%\n", "Time elapsed: 16.05 min\n", "Epoch: 048/100 | Batch 000/192 | Cost: 0.4913\n", "Epoch: 048/100 | Batch 120/192 | Cost: 0.6474\n", "Epoch: 048/100 Train Acc.: 83.31% | Validation Acc.: 79.20%\n", "Time elapsed: 16.39 min\n", "Epoch: 049/100 | Batch 000/192 | Cost: 0.4585\n", "Epoch: 049/100 | Batch 120/192 | Cost: 0.4845\n", "Epoch: 049/100 Train Acc.: 83.53% | Validation Acc.: 78.40%\n", "Time elapsed: 16.73 min\n", "Epoch: 050/100 | Batch 000/192 | Cost: 0.6038\n", "Epoch: 050/100 | Batch 120/192 | Cost: 0.5446\n", "Epoch: 050/100 Train Acc.: 83.86% | Validation Acc.: 80.50%\n", "Time elapsed: 17.08 min\n", "Epoch: 051/100 | Batch 000/192 | Cost: 0.3793\n", "Epoch: 051/100 | Batch 120/192 | Cost: 0.4499\n", "Epoch: 051/100 Train Acc.: 83.11% | Validation Acc.: 76.80%\n", "Time elapsed: 17.42 min\n", "Epoch: 052/100 | Batch 000/192 | Cost: 0.5527\n", "Epoch: 052/100 | Batch 120/192 | Cost: 0.4610\n", "Epoch: 052/100 Train Acc.: 84.63% | Validation Acc.: 79.30%\n", "Time elapsed: 17.76 min\n", "Epoch: 053/100 | Batch 000/192 | Cost: 0.5015\n", "Epoch: 053/100 | Batch 120/192 | Cost: 0.4079\n", "Epoch: 053/100 Train Acc.: 84.18% | Validation Acc.: 77.60%\n", "Time elapsed: 18.11 min\n", "Epoch: 054/100 | Batch 000/192 | Cost: 0.5012\n", "Epoch: 054/100 | Batch 120/192 | Cost: 0.4912\n", "Epoch: 054/100 Train Acc.: 84.41% | Validation Acc.: 77.20%\n", "Time elapsed: 18.45 min\n", "Epoch: 055/100 | Batch 000/192 | Cost: 0.4015\n", "Epoch: 055/100 | Batch 120/192 | Cost: 0.4919\n", "Epoch: 055/100 Train Acc.: 85.16% | Validation Acc.: 80.20%\n", "Time elapsed: 18.79 min\n", "Epoch: 056/100 | Batch 000/192 | Cost: 0.3976\n", "Epoch: 056/100 | Batch 120/192 | Cost: 0.4252\n", "Epoch: 056/100 Train Acc.: 85.28% | Validation Acc.: 80.30%\n", "Time elapsed: 19.14 min\n", "Epoch: 057/100 | Batch 000/192 | Cost: 0.3372\n", "Epoch: 057/100 | Batch 120/192 | Cost: 0.4634\n", "Epoch: 057/100 Train Acc.: 84.29% | Validation Acc.: 78.60%\n", "Time elapsed: 19.48 min\n", "Epoch: 058/100 | Batch 000/192 | Cost: 0.4438\n", "Epoch: 058/100 | Batch 120/192 | Cost: 0.3490\n", "Epoch: 058/100 Train Acc.: 85.93% | Validation Acc.: 77.50%\n", "Time elapsed: 19.82 min\n", "Epoch: 059/100 | Batch 000/192 | Cost: 0.4541\n", "Epoch: 059/100 | Batch 120/192 | Cost: 0.4415\n", "Epoch: 059/100 Train Acc.: 84.34% | Validation Acc.: 78.40%\n", "Time elapsed: 20.16 min\n", "Epoch: 060/100 | Batch 000/192 | Cost: 0.3766\n", "Epoch: 060/100 | Batch 120/192 | Cost: 0.4851\n", "Epoch: 060/100 Train Acc.: 86.02% | Validation Acc.: 80.00%\n", "Time elapsed: 20.51 min\n", "Epoch: 061/100 | Batch 000/192 | Cost: 0.4967\n", "Epoch: 061/100 | Batch 120/192 | Cost: 0.3708\n", "Epoch: 061/100 Train Acc.: 85.57% | Validation Acc.: 79.50%\n", "Time elapsed: 20.85 min\n", "Epoch: 062/100 | Batch 000/192 | Cost: 0.4197\n", "Epoch: 062/100 | Batch 120/192 | Cost: 0.3054\n", "Epoch: 062/100 Train Acc.: 86.23% | Validation Acc.: 78.40%\n", "Time elapsed: 21.19 min\n", "Epoch: 063/100 | Batch 000/192 | Cost: 0.4595\n", "Epoch: 063/100 | Batch 120/192 | Cost: 0.4200\n", "Epoch: 063/100 Train Acc.: 86.52% | Validation Acc.: 79.80%\n", "Time elapsed: 21.54 min\n", "Epoch: 064/100 | Batch 000/192 | Cost: 0.3806\n", "Epoch: 064/100 | Batch 120/192 | Cost: 0.3670\n", "Epoch: 064/100 Train Acc.: 86.81% | Validation Acc.: 80.20%\n", "Time elapsed: 21.88 min\n", "Epoch: 065/100 | Batch 000/192 | Cost: 0.3922\n", 
"Epoch: 065/100 | Batch 120/192 | Cost: 0.3698\n", "Epoch: 065/100 Train Acc.: 86.30% | Validation Acc.: 77.90%\n", "Time elapsed: 22.22 min\n", "Epoch: 066/100 | Batch 000/192 | Cost: 0.3608\n", "Epoch: 066/100 | Batch 120/192 | Cost: 0.4444\n", "Epoch: 066/100 Train Acc.: 88.01% | Validation Acc.: 80.10%\n", "Time elapsed: 22.56 min\n", "Epoch: 067/100 | Batch 000/192 | Cost: 0.3374\n", "Epoch: 067/100 | Batch 120/192 | Cost: 0.3158\n", "Epoch: 067/100 Train Acc.: 87.94% | Validation Acc.: 80.40%\n", "Time elapsed: 22.91 min\n", "Epoch: 068/100 | Batch 000/192 | Cost: 0.3959\n", "Epoch: 068/100 | Batch 120/192 | Cost: 0.2217\n", "Epoch: 068/100 Train Acc.: 87.74% | Validation Acc.: 79.70%\n", "Time elapsed: 23.25 min\n", "Epoch: 069/100 | Batch 000/192 | Cost: 0.3795\n", "Epoch: 069/100 | Batch 120/192 | Cost: 0.3398\n", "Epoch: 069/100 Train Acc.: 88.28% | Validation Acc.: 79.70%\n", "Time elapsed: 23.59 min\n", "Epoch: 070/100 | Batch 000/192 | Cost: 0.3098\n", "Epoch: 070/100 | Batch 120/192 | Cost: 0.3012\n", "Epoch: 070/100 Train Acc.: 87.96% | Validation Acc.: 80.80%\n", "Time elapsed: 23.93 min\n", "Epoch: 071/100 | Batch 000/192 | Cost: 0.3705\n", "Epoch: 071/100 | Batch 120/192 | Cost: 0.2943\n", "Epoch: 071/100 Train Acc.: 88.02% | Validation Acc.: 79.90%\n", "Time elapsed: 24.27 min\n", "Epoch: 072/100 | Batch 000/192 | Cost: 0.3353\n", "Epoch: 072/100 | Batch 120/192 | Cost: 0.3237\n", "Epoch: 072/100 Train Acc.: 88.34% | Validation Acc.: 80.60%\n", "Time elapsed: 24.62 min\n", "Epoch: 073/100 | Batch 000/192 | Cost: 0.3683\n", "Epoch: 073/100 | Batch 120/192 | Cost: 0.4178\n", "Epoch: 073/100 Train Acc.: 88.93% | Validation Acc.: 80.10%\n", "Time elapsed: 24.96 min\n", "Epoch: 074/100 | Batch 000/192 | Cost: 0.2282\n", "Epoch: 074/100 | Batch 120/192 | Cost: 0.1967\n", "Epoch: 074/100 Train Acc.: 88.58% | Validation Acc.: 81.40%\n", "Time elapsed: 25.30 min\n", "Epoch: 075/100 | Batch 000/192 | Cost: 0.2701\n", "Epoch: 075/100 | Batch 120/192 | Cost: 0.3722\n", "Epoch: 075/100 Train Acc.: 87.93% | Validation Acc.: 79.70%\n", "Time elapsed: 25.64 min\n", "Epoch: 076/100 | Batch 000/192 | Cost: 0.2850\n", "Epoch: 076/100 | Batch 120/192 | Cost: 0.2874\n", "Epoch: 076/100 Train Acc.: 88.92% | Validation Acc.: 81.10%\n", "Time elapsed: 25.98 min\n", "Epoch: 077/100 | Batch 000/192 | Cost: 0.2686\n", "Epoch: 077/100 | Batch 120/192 | Cost: 0.4312\n", "Epoch: 077/100 Train Acc.: 89.39% | Validation Acc.: 81.60%\n", "Time elapsed: 26.33 min\n", "Epoch: 078/100 | Batch 000/192 | Cost: 0.2282\n", "Epoch: 078/100 | Batch 120/192 | Cost: 0.3395\n", "Epoch: 078/100 Train Acc.: 88.67% | Validation Acc.: 78.90%\n", "Time elapsed: 26.67 min\n", "Epoch: 079/100 | Batch 000/192 | Cost: 0.3127\n", "Epoch: 079/100 | Batch 120/192 | Cost: 0.2906\n", "Epoch: 079/100 Train Acc.: 90.77% | Validation Acc.: 81.20%\n", "Time elapsed: 27.01 min\n", "Epoch: 080/100 | Batch 000/192 | Cost: 0.2468\n", "Epoch: 080/100 | Batch 120/192 | Cost: 0.3638\n", "Epoch: 080/100 Train Acc.: 89.99% | Validation Acc.: 80.40%\n", "Time elapsed: 27.35 min\n", "Epoch: 081/100 | Batch 000/192 | Cost: 0.2936\n", "Epoch: 081/100 | Batch 120/192 | Cost: 0.3772\n", "Epoch: 081/100 Train Acc.: 90.76% | Validation Acc.: 80.50%\n", "Time elapsed: 27.69 min\n", "Epoch: 082/100 | Batch 000/192 | Cost: 0.2584\n", "Epoch: 082/100 | Batch 120/192 | Cost: 0.2718\n", "Epoch: 082/100 Train Acc.: 91.01% | Validation Acc.: 81.20%\n", "Time elapsed: 28.04 min\n", "Epoch: 083/100 | Batch 000/192 | Cost: 0.1904\n", "Epoch: 083/100 | Batch 
120/192 | Cost: 0.3090\n", "Epoch: 083/100 Train Acc.: 90.68% | Validation Acc.: 81.30%\n", "Time elapsed: 28.38 min\n", "Epoch: 084/100 | Batch 000/192 | Cost: 0.2506\n", "Epoch: 084/100 | Batch 120/192 | Cost: 0.2825\n", "Epoch: 084/100 Train Acc.: 90.43% | Validation Acc.: 80.40%\n", "Time elapsed: 28.73 min\n", "Epoch: 085/100 | Batch 000/192 | Cost: 0.2307\n", "Epoch: 085/100 | Batch 120/192 | Cost: 0.2441\n", "Epoch: 085/100 Train Acc.: 90.88% | Validation Acc.: 81.30%\n", "Time elapsed: 29.07 min\n", "Epoch: 086/100 | Batch 000/192 | Cost: 0.3149\n", "Epoch: 086/100 | Batch 120/192 | Cost: 0.3129\n", "Epoch: 086/100 Train Acc.: 90.13% | Validation Acc.: 82.40%\n", "Time elapsed: 29.41 min\n", "Epoch: 087/100 | Batch 000/192 | Cost: 0.3487\n", "Epoch: 087/100 | Batch 120/192 | Cost: 0.2559\n", "Epoch: 087/100 Train Acc.: 90.74% | Validation Acc.: 81.40%\n", "Time elapsed: 29.75 min\n", "Epoch: 088/100 | Batch 000/192 | Cost: 0.2412\n", "Epoch: 088/100 | Batch 120/192 | Cost: 0.1828\n", "Epoch: 088/100 Train Acc.: 91.08% | Validation Acc.: 80.20%\n", "Time elapsed: 30.10 min\n", "Epoch: 089/100 | Batch 000/192 | Cost: 0.2957\n", "Epoch: 089/100 | Batch 120/192 | Cost: 0.2939\n", "Epoch: 089/100 Train Acc.: 90.67% | Validation Acc.: 80.30%\n", "Time elapsed: 30.44 min\n", "Epoch: 090/100 | Batch 000/192 | Cost: 0.2298\n", "Epoch: 090/100 | Batch 120/192 | Cost: 0.2900\n", "Epoch: 090/100 Train Acc.: 91.63% | Validation Acc.: 79.00%\n", "Time elapsed: 30.78 min\n", "Epoch: 091/100 | Batch 000/192 | Cost: 0.2558\n", "Epoch: 091/100 | Batch 120/192 | Cost: 0.2915\n", "Epoch: 091/100 Train Acc.: 91.36% | Validation Acc.: 81.00%\n", "Time elapsed: 31.12 min\n", "Epoch: 092/100 | Batch 000/192 | Cost: 0.1510\n", "Epoch: 092/100 | Batch 120/192 | Cost: 0.1974\n", "Epoch: 092/100 Train Acc.: 91.84% | Validation Acc.: 82.20%\n", "Time elapsed: 31.47 min\n", "Epoch: 093/100 | Batch 000/192 | Cost: 0.2308\n", "Epoch: 093/100 | Batch 120/192 | Cost: 0.2247\n", "Epoch: 093/100 Train Acc.: 91.50% | Validation Acc.: 80.50%\n", "Time elapsed: 31.81 min\n", "Epoch: 094/100 | Batch 000/192 | Cost: 0.2712\n", "Epoch: 094/100 | Batch 120/192 | Cost: 0.3268\n", "Epoch: 094/100 Train Acc.: 91.74% | Validation Acc.: 81.30%\n", "Time elapsed: 32.15 min\n", "Epoch: 095/100 | Batch 000/192 | Cost: 0.2417\n", "Epoch: 095/100 | Batch 120/192 | Cost: 0.2162\n", "Epoch: 095/100 Train Acc.: 91.53% | Validation Acc.: 79.00%\n", "Time elapsed: 32.49 min\n", "Epoch: 096/100 | Batch 000/192 | Cost: 0.2523\n", "Epoch: 096/100 | Batch 120/192 | Cost: 0.2598\n", "Epoch: 096/100 Train Acc.: 91.56% | Validation Acc.: 81.00%\n", "Time elapsed: 32.84 min\n", "Epoch: 097/100 | Batch 000/192 | Cost: 0.2027\n", "Epoch: 097/100 | Batch 120/192 | Cost: 0.2432\n", "Epoch: 097/100 Train Acc.: 92.53% | Validation Acc.: 80.80%\n", "Time elapsed: 33.18 min\n", "Epoch: 098/100 | Batch 000/192 | Cost: 0.2115\n", "Epoch: 098/100 | Batch 120/192 | Cost: 0.2746\n", "Epoch: 098/100 Train Acc.: 92.30% | Validation Acc.: 81.10%\n", "Time elapsed: 33.52 min\n", "Epoch: 099/100 | Batch 000/192 | Cost: 0.1611\n", "Epoch: 099/100 | Batch 120/192 | Cost: 0.2142\n", "Epoch: 099/100 Train Acc.: 92.66% | Validation Acc.: 80.90%\n", "Time elapsed: 33.86 min\n", "Epoch: 100/100 | Batch 000/192 | Cost: 0.1935\n", "Epoch: 100/100 | Batch 120/192 | Cost: 0.2488\n", "Epoch: 100/100 Train Acc.: 92.68% | Validation Acc.: 80.20%\n", "Time elapsed: 34.20 min\n", "Total Training Time: 34.20 min\n" ] } ], "source": [ "def compute_accuracy(model, data_loader, 
device):\n", " correct_pred, num_examples = 0, 0\n", " for i, (features, targets) in enumerate(data_loader):\n", " \n", " features = features.to(device)\n", " targets = targets.to(device)\n", "\n", " logits, probas = model(features)\n", " _, predicted_labels = torch.max(probas, 1)\n", " num_examples += targets.size(0)\n", " correct_pred += (predicted_labels == targets).sum()\n", " return correct_pred.float()/num_examples * 100\n", " \n", "\n", "start_time = time.time()\n", "for epoch in range(NUM_EPOCHS):\n", " \n", " model.train()\n", " \n", " for batch_idx, (features, targets) in enumerate(train_loader):\n", " \n", " ### PREPARE MINIBATCH\n", " features = features.to(DEVICE)\n", " targets = targets.to(DEVICE)\n", " \n", " ### FORWARD AND BACK PROP\n", " logits, probas = model(features)\n", " cost = F.cross_entropy(logits, targets)\n", " optimizer.zero_grad()\n", " \n", " cost.backward()\n", " \n", " ### UPDATE MODEL PARAMETERS\n", " optimizer.step()\n", " \n", " ### LOGGING\n", " if not batch_idx % 120:\n", " print (f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d} | '\n", " f'Batch {batch_idx:03d}/{len(train_loader):03d} |' \n", " f' Cost: {cost:.4f}')\n", "\n", " # no need to build the computation graph for backprop when computing accuracy\n", " with torch.set_grad_enabled(False):\n", " train_acc = compute_accuracy(model, train_loader, device=DEVICE)\n", " valid_acc = compute_accuracy(model, valid_loader, device=DEVICE)\n", " print(f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d} Train Acc.: {train_acc:.2f}%'\n", " f' | Validation Acc.: {valid_acc:.2f}%')\n", " \n", " elapsed = (time.time() - start_time)/60\n", " print(f'Time elapsed: {elapsed:.2f} min')\n", " \n", "elapsed = (time.time() - start_time)/60\n", "print(f'Total Training Time: {elapsed:.2f} min')" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "PIL.Image 6.2.1\n", "torch 1.3.0\n", "numpy 1.17.4\n", "matplotlib 3.1.0\n", "pandas 0.24.2\n", "torchvision 0.4.1a0+d94043a\n", "\n" ] } ], "source": [ "%watermark -iv" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "default_view": {}, "name": "convnet-vgg16.ipynb", "provenance": [], "version": "0.3.2", "views": {} }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" }, "toc": { "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": true, "toc_position": { "height": "calc(100% - 180px)", "left": "10px", "top": "150px", "width": "371px" }, "toc_section_display": true, "toc_window_display": true } }, "nbformat": 4, "nbformat_minor": 4 }