{ "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "UEBilEjLj5wY" }, "source": [ "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n", "- Author: Sebastian Raschka\n", "- GitHub Repository: https://github.com/rasbt/deeplearning-models" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, "base_uri": "https://localhost:8080/", "height": 119 }, "colab_type": "code", "executionInfo": { "elapsed": 536, "status": "ok", "timestamp": 1524974472601, "user": { "displayName": "Sebastian Raschka", "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg", "userId": "118404394130788869227" }, "user_tz": 240 }, "id": "GOzuY8Yvj5wb", "outputId": "c19362ce-f87a-4cc2-84cc-8d7b4b9e6007" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sebastian Raschka \n", "\n", "CPython 3.7.3\n", "IPython 7.9.0\n", "\n", "torch 1.3.0\n" ] } ], "source": [ "%load_ext watermark\n", "%watermark -a 'Sebastian Raschka' -v -p torch" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "rH4XmErYj5wm" }, "source": [ "# BatchNorm before and after Activation for Network-in-Network CIFAR-10 Classifier" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The CNN architecture is based on \n", "\n", "- Lin, Min, Qiang Chen, and Shuicheng Yan. \"[Network in network](https://arxiv.org/abs/1312.4400).\" arXiv preprint arXiv:1312.4400 (2013).\n", "\n", "This paper compares using BatchNorm before the activation function as suggested in\n", "\n", "- Ioffe, Sergey, and Christian Szegedy. \"[Batch normalization: Accelerating deep network training by reducing internal covariate shift.](https://arxiv.org/abs/1502.03167)\" arXiv preprint arXiv:1502.03167 (2015)\n", "\n", "and after the activation function as it is nowadays common practice." 
] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "MkoGLH_Tj5wn" }, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 } }, "colab_type": "code", "id": "ORj09gnrj5wp" }, "outputs": [], "source": [ "import os\n", "import time\n", "\n", "import numpy as np\n", "import pandas as pd\n", "\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from torch.utils.data import DataLoader\n", "from torch.utils.data import Subset\n", "\n", "from torchvision import datasets\n", "from torchvision import transforms\n", "\n", "import matplotlib.pyplot as plt\n", "from PIL import Image\n", "\n", "\n", "if torch.cuda.is_available():\n", " torch.backends.cudnn.deterministic = True" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "I6hghKPxj5w0" }, "source": [ "## Model Settings" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, "base_uri": "https://localhost:8080/", "height": 85 }, "colab_type": "code", "executionInfo": { "elapsed": 23936, "status": "ok", "timestamp": 1524974497505, "user": { "displayName": "Sebastian Raschka", "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg", "userId": "118404394130788869227" }, "user_tz": 240 }, "id": "NnT0sZIwj5wu", "outputId": "55aed925-d17e-4c6a-8c71-0d9b3bde5637" }, "outputs": [], "source": [ "##########################\n", "### SETTINGS\n", "##########################\n", "\n", "# Hyperparameters\n", "RANDOM_SEED = 1\n", "LEARNING_RATE = 0.0005\n", "BATCH_SIZE = 256\n", "NUM_EPOCHS = 100\n", "\n", "# Architecture\n", "NUM_CLASSES = 10\n", "\n", "# Other\n", "DEVICE = \"cuda:2\"\n", "GRAYSCALE = False" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Files already downloaded and verified\n", "Image batch dimensions: torch.Size([256, 3, 32, 32])\n", "Image label dimensions: torch.Size([256])\n", "Image batch dimensions: torch.Size([256, 3, 32, 32])\n", "Image label dimensions: torch.Size([256])\n", "Image batch dimensions: torch.Size([256, 3, 32, 32])\n", "Image label dimensions: torch.Size([256])\n" ] } ], "source": [ "##########################\n", "### CIFAR-10 Dataset\n", "##########################\n", "\n", "\n", "# Note transforms.ToTensor() scales input images\n", "# to 0-1 range\n", "\n", "\n", "train_indices = torch.arange(0, 49000)\n", "valid_indices = torch.arange(49000, 50000)\n", "\n", "\n", "train_and_valid = datasets.CIFAR10(root='data', \n", " train=True, \n", " transform=transforms.ToTensor(),\n", " download=True)\n", "\n", "train_dataset = Subset(train_and_valid, train_indices)\n", "valid_dataset = Subset(train_and_valid, valid_indices)\n", "\n", "\n", "test_dataset = datasets.CIFAR10(root='data', \n", " train=False, \n", " transform=transforms.ToTensor())\n", "\n", "\n", "#####################################################\n", "### Data Loaders\n", "#####################################################\n", "\n", "train_loader = DataLoader(dataset=train_dataset, \n", " batch_size=BATCH_SIZE,\n", " num_workers=8,\n", " shuffle=True)\n", "\n", "valid_loader = DataLoader(dataset=valid_dataset, \n", " batch_size=BATCH_SIZE,\n", " num_workers=8,\n", " shuffle=False)\n", "\n", "test_loader = DataLoader(dataset=test_dataset, \n", " 
batch_size=BATCH_SIZE,\n", " num_workers=8,\n", " shuffle=False)\n", "\n", "#####################################################\n", "\n", "# Checking the dataset\n", "for images, labels in train_loader: \n", " print('Image batch dimensions:', images.shape)\n", " print('Image label dimensions:', labels.shape)\n", " break\n", "\n", "for images, labels in test_loader: \n", " print('Image batch dimensions:', images.shape)\n", " print('Image label dimensions:', labels.shape)\n", " break\n", " \n", "for images, labels in valid_loader: \n", " print('Image batch dimensions:', images.shape)\n", " print('Image label dimensions:', labels.shape)\n", " break" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Without BatchNorm" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "##########################\n", "### MODEL\n", "##########################\n", "\n", "\n", "class NiN(nn.Module):\n", " def __init__(self, num_classes):\n", " super(NiN, self).__init__()\n", " self.num_classes = num_classes\n", " self.classifier = nn.Sequential(\n", " nn.Conv2d(3, 192, kernel_size=5, stride=1, padding=2),\n", " nn.ReLU(inplace=True),\n", " nn.Conv2d(192, 160, kernel_size=1, stride=1, padding=0),\n", " nn.ReLU(inplace=True),\n", " nn.Conv2d(160, 96, kernel_size=1, stride=1, padding=0),\n", " nn.ReLU(inplace=True),\n", " nn.MaxPool2d(kernel_size=3, stride=2, padding=1),\n", " nn.Dropout(0.5),\n", "\n", " nn.Conv2d(96, 192, kernel_size=5, stride=1, padding=2),\n", " nn.ReLU(inplace=True),\n", " nn.Conv2d(192, 192, kernel_size=1, stride=1, padding=0),\n", " nn.ReLU(inplace=True),\n", " nn.Conv2d(192, 192, kernel_size=1, stride=1, padding=0),\n", " nn.ReLU(inplace=True),\n", " nn.AvgPool2d(kernel_size=3, stride=2, padding=1),\n", " nn.Dropout(0.5),\n", "\n", " nn.Conv2d(192, 192, kernel_size=3, stride=1, padding=1),\n", " nn.ReLU(inplace=True),\n", " nn.Conv2d(192, 192, kernel_size=1, stride=1, padding=0),\n", " nn.ReLU(inplace=True),\n", " nn.Conv2d(192, 10, kernel_size=1, stride=1, padding=0),\n", " nn.ReLU(inplace=True),\n", " nn.AvgPool2d(kernel_size=8, stride=1, padding=0),\n", "\n", " )\n", "\n", " def forward(self, x):\n", " x = self.classifier(x)\n", " logits = x.view(x.size(0), self.num_classes)\n", " probas = torch.softmax(logits, dim=1)\n", " return logits, probas" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "torch.manual_seed(RANDOM_SEED)\n", "\n", "model = NiN(NUM_CLASSES)\n", "model.to(DEVICE)\n", "\n", "optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE) " ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 001/100 | Batch 000/192 | Cost: 2.3043\n", "Epoch: 001/100 | Batch 120/192 | Cost: 2.0653\n", "Epoch: 001/100 Train Acc.: 24.69% | Validation Acc.: 24.50%\n", "Time elapsed: 0.33 min\n", "Epoch: 002/100 | Batch 000/192 | Cost: 1.8584\n", "Epoch: 002/100 | Batch 120/192 | Cost: 1.7447\n", "Epoch: 002/100 Train Acc.: 36.51% | Validation Acc.: 36.90%\n", "Time elapsed: 0.65 min\n", "Epoch: 003/100 | Batch 000/192 | Cost: 1.6050\n", "Epoch: 003/100 | Batch 120/192 | Cost: 1.5591\n", "Epoch: 003/100 Train Acc.: 40.50% | Validation Acc.: 37.50%\n", "Time elapsed: 0.97 min\n", "Epoch: 004/100 | Batch 000/192 | Cost: 1.5428\n", "Epoch: 004/100 | Batch 120/192 | Cost: 1.4454\n", "Epoch: 004/100 Train Acc.: 46.12% | Validation Acc.: 45.80%\n", "Time elapsed: 1.30 min\n", "Epoch: 005/100 | Batch 000/192 
| Cost: 1.4038\n", "Epoch: 005/100 | Batch 120/192 | Cost: 1.4141\n", "Epoch: 005/100 Train Acc.: 50.21% | Validation Acc.: 49.90%\n", "Time elapsed: 1.63 min\n", "Epoch: 006/100 | Batch 000/192 | Cost: 1.3475\n", "Epoch: 006/100 | Batch 120/192 | Cost: 1.2627\n", "Epoch: 006/100 Train Acc.: 52.66% | Validation Acc.: 54.40%\n", "Time elapsed: 1.97 min\n", "Epoch: 007/100 | Batch 000/192 | Cost: 1.3238\n", "Epoch: 007/100 | Batch 120/192 | Cost: 1.2220\n", "Epoch: 007/100 Train Acc.: 54.42% | Validation Acc.: 54.40%\n", "Time elapsed: 2.31 min\n", "Epoch: 008/100 | Batch 000/192 | Cost: 1.2009\n", "Epoch: 008/100 | Batch 120/192 | Cost: 1.2045\n", "Epoch: 008/100 Train Acc.: 55.81% | Validation Acc.: 55.30%\n", "Time elapsed: 2.65 min\n", "Epoch: 009/100 | Batch 000/192 | Cost: 1.2797\n", "Epoch: 009/100 | Batch 120/192 | Cost: 1.1397\n", "Epoch: 009/100 Train Acc.: 59.10% | Validation Acc.: 60.60%\n", "Time elapsed: 2.98 min\n", "Epoch: 010/100 | Batch 000/192 | Cost: 1.0562\n", "Epoch: 010/100 | Batch 120/192 | Cost: 1.1625\n", "Epoch: 010/100 Train Acc.: 59.79% | Validation Acc.: 60.90%\n", "Time elapsed: 3.32 min\n", "Epoch: 011/100 | Batch 000/192 | Cost: 1.0868\n", "Epoch: 011/100 | Batch 120/192 | Cost: 1.0636\n", "Epoch: 011/100 Train Acc.: 60.43% | Validation Acc.: 61.00%\n", "Time elapsed: 3.66 min\n", "Epoch: 012/100 | Batch 000/192 | Cost: 1.0049\n", "Epoch: 012/100 | Batch 120/192 | Cost: 1.2247\n", "Epoch: 012/100 Train Acc.: 62.14% | Validation Acc.: 62.70%\n", "Time elapsed: 4.00 min\n", "Epoch: 013/100 | Batch 000/192 | Cost: 0.9232\n", "Epoch: 013/100 | Batch 120/192 | Cost: 1.0345\n", "Epoch: 013/100 Train Acc.: 61.42% | Validation Acc.: 61.70%\n", "Time elapsed: 4.34 min\n", "Epoch: 014/100 | Batch 000/192 | Cost: 0.9256\n", "Epoch: 014/100 | Batch 120/192 | Cost: 1.1639\n", "Epoch: 014/100 Train Acc.: 63.82% | Validation Acc.: 65.80%\n", "Time elapsed: 4.68 min\n", "Epoch: 015/100 | Batch 000/192 | Cost: 0.9600\n", "Epoch: 015/100 | Batch 120/192 | Cost: 1.0263\n", "Epoch: 015/100 Train Acc.: 63.94% | Validation Acc.: 64.00%\n", "Time elapsed: 5.02 min\n", "Epoch: 016/100 | Batch 000/192 | Cost: 0.8859\n", "Epoch: 016/100 | Batch 120/192 | Cost: 1.0307\n", "Epoch: 016/100 Train Acc.: 65.79% | Validation Acc.: 66.40%\n", "Time elapsed: 5.36 min\n", "Epoch: 017/100 | Batch 000/192 | Cost: 1.0020\n", "Epoch: 017/100 | Batch 120/192 | Cost: 0.9755\n", "Epoch: 017/100 Train Acc.: 66.95% | Validation Acc.: 66.60%\n", "Time elapsed: 5.70 min\n", "Epoch: 018/100 | Batch 000/192 | Cost: 0.9551\n", "Epoch: 018/100 | Batch 120/192 | Cost: 0.8429\n", "Epoch: 018/100 Train Acc.: 67.56% | Validation Acc.: 66.30%\n", "Time elapsed: 6.04 min\n", "Epoch: 019/100 | Batch 000/192 | Cost: 1.0420\n", "Epoch: 019/100 | Batch 120/192 | Cost: 0.9771\n", "Epoch: 019/100 Train Acc.: 69.44% | Validation Acc.: 68.20%\n", "Time elapsed: 6.38 min\n", "Epoch: 020/100 | Batch 000/192 | Cost: 0.8471\n", "Epoch: 020/100 | Batch 120/192 | Cost: 0.8322\n", "Epoch: 020/100 Train Acc.: 69.99% | Validation Acc.: 70.20%\n", "Time elapsed: 6.72 min\n", "Epoch: 021/100 | Batch 000/192 | Cost: 0.8974\n", "Epoch: 021/100 | Batch 120/192 | Cost: 0.8585\n", "Epoch: 021/100 Train Acc.: 69.52% | Validation Acc.: 69.30%\n", "Time elapsed: 7.07 min\n", "Epoch: 022/100 | Batch 000/192 | Cost: 0.8691\n", "Epoch: 022/100 | Batch 120/192 | Cost: 0.6618\n", "Epoch: 022/100 Train Acc.: 68.26% | Validation Acc.: 65.90%\n", "Time elapsed: 7.41 min\n", "Epoch: 023/100 | Batch 000/192 | Cost: 0.9277\n", "Epoch: 023/100 | Batch 
120/192 | Cost: 0.9011\n", "Epoch: 023/100 Train Acc.: 71.66% | Validation Acc.: 72.10%\n", "Time elapsed: 7.75 min\n", "Epoch: 024/100 | Batch 000/192 | Cost: 0.7764\n", "Epoch: 024/100 | Batch 120/192 | Cost: 0.7561\n", "Epoch: 024/100 Train Acc.: 71.70% | Validation Acc.: 68.80%\n", "Time elapsed: 8.09 min\n", "Epoch: 025/100 | Batch 000/192 | Cost: 0.8113\n", "Epoch: 025/100 | Batch 120/192 | Cost: 0.7186\n", "Epoch: 025/100 Train Acc.: 73.62% | Validation Acc.: 73.00%\n", "Time elapsed: 8.44 min\n", "Epoch: 026/100 | Batch 000/192 | Cost: 0.6515\n", "Epoch: 026/100 | Batch 120/192 | Cost: 0.6954\n", "Epoch: 026/100 Train Acc.: 72.22% | Validation Acc.: 70.20%\n", "Time elapsed: 8.78 min\n", "Epoch: 027/100 | Batch 000/192 | Cost: 0.7278\n", "Epoch: 027/100 | Batch 120/192 | Cost: 0.7117\n", "Epoch: 027/100 Train Acc.: 74.82% | Validation Acc.: 72.30%\n", "Time elapsed: 9.12 min\n", "Epoch: 028/100 | Batch 000/192 | Cost: 0.6732\n", "Epoch: 028/100 | Batch 120/192 | Cost: 0.6591\n", "Epoch: 028/100 Train Acc.: 74.93% | Validation Acc.: 72.60%\n", "Time elapsed: 9.46 min\n", "Epoch: 029/100 | Batch 000/192 | Cost: 0.7438\n", "Epoch: 029/100 | Batch 120/192 | Cost: 0.6429\n", "Epoch: 029/100 Train Acc.: 75.44% | Validation Acc.: 72.80%\n", "Time elapsed: 9.80 min\n", "Epoch: 030/100 | Batch 000/192 | Cost: 0.7306\n", "Epoch: 030/100 | Batch 120/192 | Cost: 0.6643\n", "Epoch: 030/100 Train Acc.: 76.34% | Validation Acc.: 74.40%\n", "Time elapsed: 10.15 min\n", "Epoch: 031/100 | Batch 000/192 | Cost: 0.5957\n", "Epoch: 031/100 | Batch 120/192 | Cost: 0.5574\n", "Epoch: 031/100 Train Acc.: 76.60% | Validation Acc.: 75.90%\n", "Time elapsed: 10.49 min\n", "Epoch: 032/100 | Batch 000/192 | Cost: 0.6414\n", "Epoch: 032/100 | Batch 120/192 | Cost: 0.6951\n", "Epoch: 032/100 Train Acc.: 77.15% | Validation Acc.: 76.10%\n", "Time elapsed: 10.83 min\n", "Epoch: 033/100 | Batch 000/192 | Cost: 0.6898\n", "Epoch: 033/100 | Batch 120/192 | Cost: 0.7784\n", "Epoch: 033/100 Train Acc.: 77.15% | Validation Acc.: 74.70%\n", "Time elapsed: 11.17 min\n", "Epoch: 034/100 | Batch 000/192 | Cost: 0.5633\n", "Epoch: 034/100 | Batch 120/192 | Cost: 0.6176\n", "Epoch: 034/100 Train Acc.: 77.53% | Validation Acc.: 74.30%\n", "Time elapsed: 11.52 min\n", "Epoch: 035/100 | Batch 000/192 | Cost: 0.6300\n", "Epoch: 035/100 | Batch 120/192 | Cost: 0.6720\n", "Epoch: 035/100 Train Acc.: 78.39% | Validation Acc.: 76.10%\n", "Time elapsed: 11.86 min\n", "Epoch: 036/100 | Batch 000/192 | Cost: 0.7154\n", "Epoch: 036/100 | Batch 120/192 | Cost: 0.6519\n", "Epoch: 036/100 Train Acc.: 78.49% | Validation Acc.: 75.40%\n", "Time elapsed: 12.20 min\n", "Epoch: 037/100 | Batch 000/192 | Cost: 0.6381\n", "Epoch: 037/100 | Batch 120/192 | Cost: 0.6618\n", "Epoch: 037/100 Train Acc.: 79.58% | Validation Acc.: 75.80%\n", "Time elapsed: 12.54 min\n", "Epoch: 038/100 | Batch 000/192 | Cost: 0.6078\n", "Epoch: 038/100 | Batch 120/192 | Cost: 0.5283\n", "Epoch: 038/100 Train Acc.: 79.17% | Validation Acc.: 76.00%\n", "Time elapsed: 12.88 min\n", "Epoch: 039/100 | Batch 000/192 | Cost: 0.5576\n", "Epoch: 039/100 | Batch 120/192 | Cost: 0.6219\n", "Epoch: 039/100 Train Acc.: 79.91% | Validation Acc.: 76.70%\n", "Time elapsed: 13.22 min\n", "Epoch: 040/100 | Batch 000/192 | Cost: 0.5660\n", "Epoch: 040/100 | Batch 120/192 | Cost: 0.5577\n", "Epoch: 040/100 Train Acc.: 80.49% | Validation Acc.: 76.50%\n", "Time elapsed: 13.56 min\n", "Epoch: 041/100 | Batch 000/192 | Cost: 0.5098\n", "Epoch: 041/100 | Batch 120/192 | Cost: 0.6621\n", 
"Epoch: 041/100 Train Acc.: 80.86% | Validation Acc.: 75.70%\n", "Time elapsed: 13.90 min\n", "Epoch: 042/100 | Batch 000/192 | Cost: 0.4589\n", "Epoch: 042/100 | Batch 120/192 | Cost: 0.5637\n", "Epoch: 042/100 Train Acc.: 81.11% | Validation Acc.: 77.00%\n", "Time elapsed: 14.24 min\n", "Epoch: 043/100 | Batch 000/192 | Cost: 0.4507\n", "Epoch: 043/100 | Batch 120/192 | Cost: 0.4865\n", "Epoch: 043/100 Train Acc.: 82.07% | Validation Acc.: 78.10%\n", "Time elapsed: 14.58 min\n", "Epoch: 044/100 | Batch 000/192 | Cost: 0.4427\n", "Epoch: 044/100 | Batch 120/192 | Cost: 0.5242\n", "Epoch: 044/100 Train Acc.: 82.61% | Validation Acc.: 79.10%\n", "Time elapsed: 14.92 min\n", "Epoch: 045/100 | Batch 000/192 | Cost: 0.4989\n", "Epoch: 045/100 | Batch 120/192 | Cost: 0.5811\n", "Epoch: 045/100 Train Acc.: 82.55% | Validation Acc.: 79.30%\n", "Time elapsed: 15.26 min\n", "Epoch: 046/100 | Batch 000/192 | Cost: 0.5303\n", "Epoch: 046/100 | Batch 120/192 | Cost: 0.4242\n", "Epoch: 046/100 Train Acc.: 81.80% | Validation Acc.: 76.80%\n", "Time elapsed: 15.60 min\n", "Epoch: 047/100 | Batch 000/192 | Cost: 0.4491\n", "Epoch: 047/100 | Batch 120/192 | Cost: 0.4902\n", "Epoch: 047/100 Train Acc.: 82.54% | Validation Acc.: 77.90%\n", "Time elapsed: 15.94 min\n", "Epoch: 048/100 | Batch 000/192 | Cost: 0.4913\n", "Epoch: 048/100 | Batch 120/192 | Cost: 0.6474\n", "Epoch: 048/100 Train Acc.: 83.31% | Validation Acc.: 79.20%\n", "Time elapsed: 16.28 min\n", "Epoch: 049/100 | Batch 000/192 | Cost: 0.4585\n", "Epoch: 049/100 | Batch 120/192 | Cost: 0.4845\n", "Epoch: 049/100 Train Acc.: 83.53% | Validation Acc.: 78.40%\n", "Time elapsed: 16.62 min\n", "Epoch: 050/100 | Batch 000/192 | Cost: 0.6038\n", "Epoch: 050/100 | Batch 120/192 | Cost: 0.5446\n", "Epoch: 050/100 Train Acc.: 83.86% | Validation Acc.: 80.50%\n", "Time elapsed: 16.96 min\n", "Epoch: 051/100 | Batch 000/192 | Cost: 0.3793\n", "Epoch: 051/100 | Batch 120/192 | Cost: 0.4499\n", "Epoch: 051/100 Train Acc.: 83.11% | Validation Acc.: 76.80%\n", "Time elapsed: 17.29 min\n", "Epoch: 052/100 | Batch 000/192 | Cost: 0.5527\n", "Epoch: 052/100 | Batch 120/192 | Cost: 0.4610\n", "Epoch: 052/100 Train Acc.: 84.63% | Validation Acc.: 79.30%\n", "Time elapsed: 17.63 min\n", "Epoch: 053/100 | Batch 000/192 | Cost: 0.5015\n", "Epoch: 053/100 | Batch 120/192 | Cost: 0.4079\n", "Epoch: 053/100 Train Acc.: 84.18% | Validation Acc.: 77.60%\n", "Time elapsed: 17.97 min\n", "Epoch: 054/100 | Batch 000/192 | Cost: 0.5012\n", "Epoch: 054/100 | Batch 120/192 | Cost: 0.4912\n", "Epoch: 054/100 Train Acc.: 84.41% | Validation Acc.: 77.20%\n", "Time elapsed: 18.30 min\n", "Epoch: 055/100 | Batch 000/192 | Cost: 0.4015\n", "Epoch: 055/100 | Batch 120/192 | Cost: 0.4919\n", "Epoch: 055/100 Train Acc.: 85.16% | Validation Acc.: 80.20%\n", "Time elapsed: 18.64 min\n", "Epoch: 056/100 | Batch 000/192 | Cost: 0.3976\n", "Epoch: 056/100 | Batch 120/192 | Cost: 0.4252\n", "Epoch: 056/100 Train Acc.: 85.28% | Validation Acc.: 80.30%\n", "Time elapsed: 18.97 min\n", "Epoch: 057/100 | Batch 000/192 | Cost: 0.3372\n", "Epoch: 057/100 | Batch 120/192 | Cost: 0.4634\n", "Epoch: 057/100 Train Acc.: 84.29% | Validation Acc.: 78.60%\n", "Time elapsed: 19.32 min\n", "Epoch: 058/100 | Batch 000/192 | Cost: 0.4438\n", "Epoch: 058/100 | Batch 120/192 | Cost: 0.3490\n", "Epoch: 058/100 Train Acc.: 85.93% | Validation Acc.: 77.50%\n", "Time elapsed: 19.66 min\n", "Epoch: 059/100 | Batch 000/192 | Cost: 0.4541\n", "Epoch: 059/100 | Batch 120/192 | Cost: 0.4415\n", "Epoch: 059/100 Train 
Acc.: 84.34% | Validation Acc.: 78.40%\n", "Time elapsed: 19.99 min\n", "Epoch: 060/100 | Batch 000/192 | Cost: 0.3766\n", "Epoch: 060/100 | Batch 120/192 | Cost: 0.4851\n", "Epoch: 060/100 Train Acc.: 86.02% | Validation Acc.: 80.00%\n", "Time elapsed: 20.33 min\n", "Epoch: 061/100 | Batch 000/192 | Cost: 0.4967\n", "Epoch: 061/100 | Batch 120/192 | Cost: 0.3708\n", "Epoch: 061/100 Train Acc.: 85.57% | Validation Acc.: 79.50%\n", "Time elapsed: 20.67 min\n", "Epoch: 062/100 | Batch 000/192 | Cost: 0.4197\n", "Epoch: 062/100 | Batch 120/192 | Cost: 0.3054\n", "Epoch: 062/100 Train Acc.: 86.23% | Validation Acc.: 78.40%\n", "Time elapsed: 21.01 min\n", "Epoch: 063/100 | Batch 000/192 | Cost: 0.4595\n", "Epoch: 063/100 | Batch 120/192 | Cost: 0.4200\n", "Epoch: 063/100 Train Acc.: 86.52% | Validation Acc.: 79.80%\n", "Time elapsed: 21.35 min\n", "Epoch: 064/100 | Batch 000/192 | Cost: 0.3806\n", "Epoch: 064/100 | Batch 120/192 | Cost: 0.3670\n", "Epoch: 064/100 Train Acc.: 86.81% | Validation Acc.: 80.20%\n", "Time elapsed: 21.69 min\n", "Epoch: 065/100 | Batch 000/192 | Cost: 0.3922\n", "Epoch: 065/100 | Batch 120/192 | Cost: 0.3698\n", "Epoch: 065/100 Train Acc.: 86.30% | Validation Acc.: 77.90%\n", "Time elapsed: 22.03 min\n", "Epoch: 066/100 | Batch 000/192 | Cost: 0.3608\n", "Epoch: 066/100 | Batch 120/192 | Cost: 0.4444\n", "Epoch: 066/100 Train Acc.: 88.01% | Validation Acc.: 80.10%\n", "Time elapsed: 22.37 min\n", "Epoch: 067/100 | Batch 000/192 | Cost: 0.3374\n", "Epoch: 067/100 | Batch 120/192 | Cost: 0.3158\n", "Epoch: 067/100 Train Acc.: 87.94% | Validation Acc.: 80.40%\n", "Time elapsed: 22.70 min\n", "Epoch: 068/100 | Batch 000/192 | Cost: 0.3959\n", "Epoch: 068/100 | Batch 120/192 | Cost: 0.2217\n", "Epoch: 068/100 Train Acc.: 87.74% | Validation Acc.: 79.70%\n", "Time elapsed: 23.04 min\n", "Epoch: 069/100 | Batch 000/192 | Cost: 0.3795\n", "Epoch: 069/100 | Batch 120/192 | Cost: 0.3398\n", "Epoch: 069/100 Train Acc.: 88.28% | Validation Acc.: 79.70%\n", "Time elapsed: 23.39 min\n", "Epoch: 070/100 | Batch 000/192 | Cost: 0.3098\n", "Epoch: 070/100 | Batch 120/192 | Cost: 0.3012\n", "Epoch: 070/100 Train Acc.: 87.96% | Validation Acc.: 80.80%\n", "Time elapsed: 23.73 min\n", "Epoch: 071/100 | Batch 000/192 | Cost: 0.3705\n", "Epoch: 071/100 | Batch 120/192 | Cost: 0.2943\n", "Epoch: 071/100 Train Acc.: 88.02% | Validation Acc.: 79.90%\n", "Time elapsed: 24.06 min\n", "Epoch: 072/100 | Batch 000/192 | Cost: 0.3353\n", "Epoch: 072/100 | Batch 120/192 | Cost: 0.3237\n", "Epoch: 072/100 Train Acc.: 88.34% | Validation Acc.: 80.60%\n", "Time elapsed: 24.40 min\n", "Epoch: 073/100 | Batch 000/192 | Cost: 0.3683\n", "Epoch: 073/100 | Batch 120/192 | Cost: 0.4178\n", "Epoch: 073/100 Train Acc.: 88.93% | Validation Acc.: 80.10%\n", "Time elapsed: 24.74 min\n", "Epoch: 074/100 | Batch 000/192 | Cost: 0.2282\n", "Epoch: 074/100 | Batch 120/192 | Cost: 0.1967\n", "Epoch: 074/100 Train Acc.: 88.58% | Validation Acc.: 81.40%\n", "Time elapsed: 25.08 min\n", "Epoch: 075/100 | Batch 000/192 | Cost: 0.2701\n", "Epoch: 075/100 | Batch 120/192 | Cost: 0.3722\n", "Epoch: 075/100 Train Acc.: 87.93% | Validation Acc.: 79.70%\n", "Time elapsed: 25.42 min\n", "Epoch: 076/100 | Batch 000/192 | Cost: 0.2850\n", "Epoch: 076/100 | Batch 120/192 | Cost: 0.2874\n", "Epoch: 076/100 Train Acc.: 88.92% | Validation Acc.: 81.10%\n", "Time elapsed: 25.75 min\n", "Epoch: 077/100 | Batch 000/192 | Cost: 0.2686\n", "Epoch: 077/100 | Batch 120/192 | Cost: 0.4312\n", "Epoch: 077/100 Train Acc.: 89.39% | Validation 
Acc.: 81.60%\n", "Time elapsed: 26.10 min\n", "Epoch: 078/100 | Batch 000/192 | Cost: 0.2282\n", "Epoch: 078/100 | Batch 120/192 | Cost: 0.3395\n", "Epoch: 078/100 Train Acc.: 88.67% | Validation Acc.: 78.90%\n", "Time elapsed: 26.43 min\n", "Epoch: 079/100 | Batch 000/192 | Cost: 0.3127\n", "Epoch: 079/100 | Batch 120/192 | Cost: 0.2906\n", "Epoch: 079/100 Train Acc.: 90.77% | Validation Acc.: 81.20%\n", "Time elapsed: 26.77 min\n", "Epoch: 080/100 | Batch 000/192 | Cost: 0.2468\n", "Epoch: 080/100 | Batch 120/192 | Cost: 0.3638\n", "Epoch: 080/100 Train Acc.: 89.99% | Validation Acc.: 80.40%\n", "Time elapsed: 27.11 min\n", "Epoch: 081/100 | Batch 000/192 | Cost: 0.2936\n", "Epoch: 081/100 | Batch 120/192 | Cost: 0.3772\n", "Epoch: 081/100 Train Acc.: 90.76% | Validation Acc.: 80.50%\n", "Time elapsed: 27.45 min\n", "Epoch: 082/100 | Batch 000/192 | Cost: 0.2584\n", "Epoch: 082/100 | Batch 120/192 | Cost: 0.2718\n", "Epoch: 082/100 Train Acc.: 91.01% | Validation Acc.: 81.20%\n", "Time elapsed: 27.79 min\n", "Epoch: 083/100 | Batch 000/192 | Cost: 0.1904\n", "Epoch: 083/100 | Batch 120/192 | Cost: 0.3090\n", "Epoch: 083/100 Train Acc.: 90.68% | Validation Acc.: 81.30%\n", "Time elapsed: 28.14 min\n", "Epoch: 084/100 | Batch 000/192 | Cost: 0.2506\n", "Epoch: 084/100 | Batch 120/192 | Cost: 0.2825\n", "Epoch: 084/100 Train Acc.: 90.43% | Validation Acc.: 80.40%\n", "Time elapsed: 28.47 min\n", "Epoch: 085/100 | Batch 000/192 | Cost: 0.2307\n", "Epoch: 085/100 | Batch 120/192 | Cost: 0.2441\n", "Epoch: 085/100 Train Acc.: 90.88% | Validation Acc.: 81.30%\n", "Time elapsed: 28.82 min\n", "Epoch: 086/100 | Batch 000/192 | Cost: 0.3149\n", "Epoch: 086/100 | Batch 120/192 | Cost: 0.3129\n", "Epoch: 086/100 Train Acc.: 90.13% | Validation Acc.: 82.40%\n", "Time elapsed: 29.16 min\n", "Epoch: 087/100 | Batch 000/192 | Cost: 0.3487\n", "Epoch: 087/100 | Batch 120/192 | Cost: 0.2559\n", "Epoch: 087/100 Train Acc.: 90.74% | Validation Acc.: 81.40%\n", "Time elapsed: 29.50 min\n", "Epoch: 088/100 | Batch 000/192 | Cost: 0.2412\n", "Epoch: 088/100 | Batch 120/192 | Cost: 0.1828\n", "Epoch: 088/100 Train Acc.: 91.08% | Validation Acc.: 80.20%\n", "Time elapsed: 29.84 min\n", "Epoch: 089/100 | Batch 000/192 | Cost: 0.2957\n", "Epoch: 089/100 | Batch 120/192 | Cost: 0.2939\n", "Epoch: 089/100 Train Acc.: 90.67% | Validation Acc.: 80.30%\n", "Time elapsed: 30.19 min\n", "Epoch: 090/100 | Batch 000/192 | Cost: 0.2298\n", "Epoch: 090/100 | Batch 120/192 | Cost: 0.2900\n", "Epoch: 090/100 Train Acc.: 91.63% | Validation Acc.: 79.00%\n", "Time elapsed: 30.53 min\n", "Epoch: 091/100 | Batch 000/192 | Cost: 0.2558\n", "Epoch: 091/100 | Batch 120/192 | Cost: 0.2915\n", "Epoch: 091/100 Train Acc.: 91.36% | Validation Acc.: 81.00%\n", "Time elapsed: 30.88 min\n", "Epoch: 092/100 | Batch 000/192 | Cost: 0.1510\n", "Epoch: 092/100 | Batch 120/192 | Cost: 0.1974\n", "Epoch: 092/100 Train Acc.: 91.84% | Validation Acc.: 82.20%\n", "Time elapsed: 31.22 min\n", "Epoch: 093/100 | Batch 000/192 | Cost: 0.2308\n", "Epoch: 093/100 | Batch 120/192 | Cost: 0.2247\n", "Epoch: 093/100 Train Acc.: 91.50% | Validation Acc.: 80.50%\n", "Time elapsed: 31.56 min\n", "Epoch: 094/100 | Batch 000/192 | Cost: 0.2712\n", "Epoch: 094/100 | Batch 120/192 | Cost: 0.3268\n", "Epoch: 094/100 Train Acc.: 91.74% | Validation Acc.: 81.30%\n", "Time elapsed: 31.91 min\n", "Epoch: 095/100 | Batch 000/192 | Cost: 0.2417\n", "Epoch: 095/100 | Batch 120/192 | Cost: 0.2162\n", "Epoch: 095/100 Train Acc.: 91.53% | Validation Acc.: 79.00%\n", "Time 
elapsed: 32.26 min\n", "Epoch: 096/100 | Batch 000/192 | Cost: 0.2523\n", "Epoch: 096/100 | Batch 120/192 | Cost: 0.2598\n", "Epoch: 096/100 Train Acc.: 91.56% | Validation Acc.: 81.00%\n", "Time elapsed: 32.60 min\n", "Epoch: 097/100 | Batch 000/192 | Cost: 0.2027\n", "Epoch: 097/100 | Batch 120/192 | Cost: 0.2432\n", "Epoch: 097/100 Train Acc.: 92.53% | Validation Acc.: 80.80%\n", "Time elapsed: 32.94 min\n", "Epoch: 098/100 | Batch 000/192 | Cost: 0.2115\n", "Epoch: 098/100 | Batch 120/192 | Cost: 0.2746\n", "Epoch: 098/100 Train Acc.: 92.30% | Validation Acc.: 81.10%\n", "Time elapsed: 33.28 min\n", "Epoch: 099/100 | Batch 000/192 | Cost: 0.1611\n", "Epoch: 099/100 | Batch 120/192 | Cost: 0.2142\n", "Epoch: 099/100 Train Acc.: 92.66% | Validation Acc.: 80.90%\n", "Time elapsed: 33.62 min\n", "Epoch: 100/100 | Batch 000/192 | Cost: 0.1935\n", "Epoch: 100/100 | Batch 120/192 | Cost: 0.2488\n", "Epoch: 100/100 Train Acc.: 92.68% | Validation Acc.: 80.20%\n", "Time elapsed: 33.97 min\n", "Total Training Time: 33.97 min\n" ] } ], "source": [ "def compute_accuracy(model, data_loader, device):\n", " correct_pred, num_examples = 0, 0\n", " for i, (features, targets) in enumerate(data_loader):\n", " \n", " features = features.to(device)\n", " targets = targets.to(device)\n", "\n", " logits, probas = model(features)\n", " _, predicted_labels = torch.max(probas, 1)\n", " num_examples += targets.size(0)\n", " correct_pred += (predicted_labels == targets).sum()\n", " return correct_pred.float()/num_examples * 100\n", " \n", "\n", "start_time = time.time()\n", "for epoch in range(NUM_EPOCHS):\n", " \n", " model.train()\n", " \n", " for batch_idx, (features, targets) in enumerate(train_loader):\n", " \n", " ### PREPARE MINIBATCH\n", " features = features.to(DEVICE)\n", " targets = targets.to(DEVICE)\n", " \n", " ### FORWARD AND BACK PROP\n", " logits, probas = model(features)\n", " cost = F.cross_entropy(logits, targets)\n", " optimizer.zero_grad()\n", " \n", " cost.backward()\n", " \n", " ### UPDATE MODEL PARAMETERS\n", " optimizer.step()\n", " \n", " ### LOGGING\n", " if not batch_idx % 120:\n", " print (f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d} | '\n", " f'Batch {batch_idx:03d}/{len(train_loader):03d} |' \n", " f' Cost: {cost:.4f}')\n", "\n", " # no need to build the computation graph for backprop when computing accuracy\n", " with torch.set_grad_enabled(False):\n", " train_acc = compute_accuracy(model, train_loader, device=DEVICE)\n", " valid_acc = compute_accuracy(model, valid_loader, device=DEVICE)\n", " print(f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d} Train Acc.: {train_acc:.2f}%'\n", " f' | Validation Acc.: {valid_acc:.2f}%')\n", " \n", " elapsed = (time.time() - start_time)/60\n", " print(f'Time elapsed: {elapsed:.2f} min')\n", " \n", "elapsed = (time.time() - start_time)/60\n", "print(f'Total Training Time: {elapsed:.2f} min')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# BatchNorm before Activation" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "##########################\n", "### MODEL\n", "##########################\n", "\n", "\n", "class NiN(nn.Module):\n", " def __init__(self, num_classes):\n", " super(NiN, self).__init__()\n", " self.num_classes = num_classes\n", " self.classifier = nn.Sequential(\n", " nn.Conv2d(3, 192, kernel_size=5, stride=1, padding=2, bias=False),\n", " nn.BatchNorm2d(192),\n", " nn.ReLU(inplace=True),\n", " nn.Conv2d(192, 160, kernel_size=1, stride=1, padding=0, bias=False),\n", " 
nn.BatchNorm2d(160),\n", " nn.ReLU(inplace=True),\n", " nn.Conv2d(160, 96, kernel_size=1, stride=1, padding=0, bias=False),\n", " nn.BatchNorm2d(96),\n", " nn.ReLU(inplace=True),\n", " nn.MaxPool2d(kernel_size=3, stride=2, padding=1),\n", " nn.Dropout(0.5),\n", "\n", " nn.Conv2d(96, 192, kernel_size=5, stride=1, padding=2, bias=False),\n", " nn.BatchNorm2d(192),\n", " nn.ReLU(inplace=True),\n", " nn.Conv2d(192, 192, kernel_size=1, stride=1, padding=0, bias=False),\n", " nn.BatchNorm2d(192),\n", " nn.ReLU(inplace=True),\n", " nn.Conv2d(192, 192, kernel_size=1, stride=1, padding=0, bias=False),\n", " nn.BatchNorm2d(192),\n", " nn.ReLU(inplace=True),\n", " nn.AvgPool2d(kernel_size=3, stride=2, padding=1),\n", " nn.Dropout(0.5),\n", "\n", " nn.Conv2d(192, 192, kernel_size=3, stride=1, padding=1, bias=False),\n", " nn.BatchNorm2d(192),\n", " nn.ReLU(inplace=True),\n", " nn.Conv2d(192, 192, kernel_size=1, stride=1, padding=0, bias=False),\n", " nn.BatchNorm2d(192),\n", " nn.ReLU(inplace=True),\n", " nn.Conv2d(192, 10, kernel_size=1, stride=1, padding=0),\n", " nn.ReLU(inplace=True),\n", " nn.AvgPool2d(kernel_size=8, stride=1, padding=0),\n", "\n", " )\n", "\n", " def forward(self, x):\n", " x = self.classifier(x)\n", " logits = x.view(x.size(0), self.num_classes)\n", " probas = torch.softmax(logits, dim=1)\n", " return logits, probas" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 } }, "colab_type": "code", "id": "_lza9t_uj5w1" }, "outputs": [], "source": [ "torch.manual_seed(RANDOM_SEED)\n", "\n", "model = NiN(NUM_CLASSES)\n", "model.to(DEVICE)\n", "\n", "optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE) " ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, "base_uri": "https://localhost:8080/", "height": 1547 }, "colab_type": "code", "executionInfo": { "elapsed": 2384585, "status": "ok", "timestamp": 1524976888520, "user": { "displayName": "Sebastian Raschka", "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg", "userId": "118404394130788869227" }, "user_tz": 240 }, "id": "Dzh3ROmRj5w7", "outputId": "5f8fd8c9-b076-403a-b0b7-fd2d498b48d7" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 001/100 | Batch 000/192 | Cost: 2.3003\n", "Epoch: 001/100 | Batch 120/192 | Cost: 1.1791\n", "Epoch: 001/100 Train Acc.: 61.28% | Validation Acc.: 61.40%\n", "Time elapsed: 0.37 min\n", "Epoch: 002/100 | Batch 000/192 | Cost: 1.2742\n", "Epoch: 002/100 | Batch 120/192 | Cost: 0.9198\n", "Epoch: 002/100 Train Acc.: 69.36% | Validation Acc.: 66.70%\n", "Time elapsed: 0.74 min\n", "Epoch: 003/100 | Batch 000/192 | Cost: 0.7803\n", "Epoch: 003/100 | Batch 120/192 | Cost: 0.8857\n", "Epoch: 003/100 Train Acc.: 74.03% | Validation Acc.: 71.70%\n", "Time elapsed: 1.11 min\n", "Epoch: 004/100 | Batch 000/192 | Cost: 0.7233\n", "Epoch: 004/100 | Batch 120/192 | Cost: 0.7254\n", "Epoch: 004/100 Train Acc.: 76.76% | Validation Acc.: 75.80%\n", "Time elapsed: 1.48 min\n", "Epoch: 005/100 | Batch 000/192 | Cost: 0.6941\n", "Epoch: 005/100 | Batch 120/192 | Cost: 0.7137\n", "Epoch: 005/100 Train Acc.: 79.56% | Validation Acc.: 77.70%\n", "Time elapsed: 1.84 min\n", "Epoch: 006/100 | Batch 000/192 | Cost: 0.7098\n", "Epoch: 006/100 | Batch 120/192 | Cost: 0.5519\n", "Epoch: 006/100 Train Acc.: 80.33% | Validation Acc.: 78.80%\n", "Time elapsed: 2.22 min\n", "Epoch: 
007/100 | Batch 000/192 | Cost: 0.6615\n", "Epoch: 007/100 | Batch 120/192 | Cost: 0.5217\n", "Epoch: 007/100 Train Acc.: 81.49% | Validation Acc.: 79.40%\n", "Time elapsed: 2.58 min\n", "Epoch: 008/100 | Batch 000/192 | Cost: 0.5005\n", "Epoch: 008/100 | Batch 120/192 | Cost: 0.5437\n", "Epoch: 008/100 Train Acc.: 83.25% | Validation Acc.: 80.10%\n", "Time elapsed: 2.94 min\n", "Epoch: 009/100 | Batch 000/192 | Cost: 0.4481\n", "Epoch: 009/100 | Batch 120/192 | Cost: 0.5191\n", "Epoch: 009/100 Train Acc.: 83.73% | Validation Acc.: 80.50%\n", "Time elapsed: 3.32 min\n", "Epoch: 010/100 | Batch 000/192 | Cost: 0.5392\n", "Epoch: 010/100 | Batch 120/192 | Cost: 0.4766\n", "Epoch: 010/100 Train Acc.: 84.86% | Validation Acc.: 80.20%\n", "Time elapsed: 3.68 min\n", "Epoch: 011/100 | Batch 000/192 | Cost: 0.4486\n", "Epoch: 011/100 | Batch 120/192 | Cost: 0.5472\n", "Epoch: 011/100 Train Acc.: 86.29% | Validation Acc.: 82.30%\n", "Time elapsed: 4.05 min\n", "Epoch: 012/100 | Batch 000/192 | Cost: 0.4129\n", "Epoch: 012/100 | Batch 120/192 | Cost: 0.3839\n", "Epoch: 012/100 Train Acc.: 87.13% | Validation Acc.: 82.60%\n", "Time elapsed: 4.42 min\n", "Epoch: 013/100 | Batch 000/192 | Cost: 0.3117\n", "Epoch: 013/100 | Batch 120/192 | Cost: 0.3525\n", "Epoch: 013/100 Train Acc.: 87.16% | Validation Acc.: 83.50%\n", "Time elapsed: 4.78 min\n", "Epoch: 014/100 | Batch 000/192 | Cost: 0.3939\n", "Epoch: 014/100 | Batch 120/192 | Cost: 0.3900\n", "Epoch: 014/100 Train Acc.: 87.78% | Validation Acc.: 83.30%\n", "Time elapsed: 5.15 min\n", "Epoch: 015/100 | Batch 000/192 | Cost: 0.4223\n", "Epoch: 015/100 | Batch 120/192 | Cost: 0.3745\n", "Epoch: 015/100 Train Acc.: 88.49% | Validation Acc.: 82.40%\n", "Time elapsed: 5.52 min\n", "Epoch: 016/100 | Batch 000/192 | Cost: 0.3464\n", "Epoch: 016/100 | Batch 120/192 | Cost: 0.3434\n", "Epoch: 016/100 Train Acc.: 88.83% | Validation Acc.: 83.10%\n", "Time elapsed: 5.88 min\n", "Epoch: 017/100 | Batch 000/192 | Cost: 0.2876\n", "Epoch: 017/100 | Batch 120/192 | Cost: 0.2826\n", "Epoch: 017/100 Train Acc.: 89.34% | Validation Acc.: 82.40%\n", "Time elapsed: 6.25 min\n", "Epoch: 018/100 | Batch 000/192 | Cost: 0.3779\n", "Epoch: 018/100 | Batch 120/192 | Cost: 0.2662\n", "Epoch: 018/100 Train Acc.: 90.05% | Validation Acc.: 84.50%\n", "Time elapsed: 6.62 min\n", "Epoch: 019/100 | Batch 000/192 | Cost: 0.3824\n", "Epoch: 019/100 | Batch 120/192 | Cost: 0.2750\n", "Epoch: 019/100 Train Acc.: 90.35% | Validation Acc.: 82.60%\n", "Time elapsed: 6.99 min\n", "Epoch: 020/100 | Batch 000/192 | Cost: 0.2361\n", "Epoch: 020/100 | Batch 120/192 | Cost: 0.2459\n", "Epoch: 020/100 Train Acc.: 91.09% | Validation Acc.: 83.90%\n", "Time elapsed: 7.35 min\n", "Epoch: 021/100 | Batch 000/192 | Cost: 0.2592\n", "Epoch: 021/100 | Batch 120/192 | Cost: 0.2218\n", "Epoch: 021/100 Train Acc.: 91.12% | Validation Acc.: 82.40%\n", "Time elapsed: 7.72 min\n", "Epoch: 022/100 | Batch 000/192 | Cost: 0.2464\n", "Epoch: 022/100 | Batch 120/192 | Cost: 0.2699\n", "Epoch: 022/100 Train Acc.: 91.39% | Validation Acc.: 84.20%\n", "Time elapsed: 8.14 min\n", "Epoch: 023/100 | Batch 000/192 | Cost: 0.1852\n", "Epoch: 023/100 | Batch 120/192 | Cost: 0.2371\n", "Epoch: 023/100 Train Acc.: 91.83% | Validation Acc.: 85.00%\n", "Time elapsed: 8.92 min\n", "Epoch: 024/100 | Batch 000/192 | Cost: 0.2384\n", "Epoch: 024/100 | Batch 120/192 | Cost: 0.2285\n", "Epoch: 024/100 Train Acc.: 91.90% | Validation Acc.: 83.90%\n", "Time elapsed: 9.48 min\n", "Epoch: 025/100 | Batch 000/192 | Cost: 0.1705\n", 
"Epoch: 025/100 | Batch 120/192 | Cost: 0.2497\n", "Epoch: 025/100 Train Acc.: 92.28% | Validation Acc.: 85.40%\n", "Time elapsed: 9.99 min\n", "Epoch: 026/100 | Batch 000/192 | Cost: 0.2336\n", "Epoch: 026/100 | Batch 120/192 | Cost: 0.2631\n", "Epoch: 026/100 Train Acc.: 93.21% | Validation Acc.: 85.80%\n", "Time elapsed: 10.36 min\n", "Epoch: 027/100 | Batch 000/192 | Cost: 0.1927\n", "Epoch: 027/100 | Batch 120/192 | Cost: 0.1936\n", "Epoch: 027/100 Train Acc.: 93.37% | Validation Acc.: 84.90%\n", "Time elapsed: 10.85 min\n", "Epoch: 028/100 | Batch 000/192 | Cost: 0.1647\n", "Epoch: 028/100 | Batch 120/192 | Cost: 0.1183\n", "Epoch: 028/100 Train Acc.: 93.40% | Validation Acc.: 84.50%\n", "Time elapsed: 11.21 min\n", "Epoch: 029/100 | Batch 000/192 | Cost: 0.1562\n", "Epoch: 029/100 | Batch 120/192 | Cost: 0.1956\n", "Epoch: 029/100 Train Acc.: 93.58% | Validation Acc.: 84.40%\n", "Time elapsed: 11.58 min\n", "Epoch: 030/100 | Batch 000/192 | Cost: 0.1309\n", "Epoch: 030/100 | Batch 120/192 | Cost: 0.2334\n", "Epoch: 030/100 Train Acc.: 93.98% | Validation Acc.: 86.40%\n", "Time elapsed: 11.95 min\n", "Epoch: 031/100 | Batch 000/192 | Cost: 0.1280\n", "Epoch: 031/100 | Batch 120/192 | Cost: 0.1637\n", "Epoch: 031/100 Train Acc.: 94.12% | Validation Acc.: 85.80%\n", "Time elapsed: 12.31 min\n", "Epoch: 032/100 | Batch 000/192 | Cost: 0.1509\n", "Epoch: 032/100 | Batch 120/192 | Cost: 0.2148\n", "Epoch: 032/100 Train Acc.: 93.88% | Validation Acc.: 84.20%\n", "Time elapsed: 12.68 min\n", "Epoch: 033/100 | Batch 000/192 | Cost: 0.1845\n", "Epoch: 033/100 | Batch 120/192 | Cost: 0.1109\n", "Epoch: 033/100 Train Acc.: 94.30% | Validation Acc.: 83.90%\n", "Time elapsed: 13.05 min\n", "Epoch: 034/100 | Batch 000/192 | Cost: 0.1668\n", "Epoch: 034/100 | Batch 120/192 | Cost: 0.1756\n", "Epoch: 034/100 Train Acc.: 94.75% | Validation Acc.: 83.80%\n", "Time elapsed: 13.42 min\n", "Epoch: 035/100 | Batch 000/192 | Cost: 0.1348\n", "Epoch: 035/100 | Batch 120/192 | Cost: 0.1297\n", "Epoch: 035/100 Train Acc.: 94.46% | Validation Acc.: 85.10%\n", "Time elapsed: 13.79 min\n", "Epoch: 036/100 | Batch 000/192 | Cost: 0.1827\n", "Epoch: 036/100 | Batch 120/192 | Cost: 0.2066\n", "Epoch: 036/100 Train Acc.: 95.07% | Validation Acc.: 83.90%\n", "Time elapsed: 14.15 min\n", "Epoch: 037/100 | Batch 000/192 | Cost: 0.1531\n", "Epoch: 037/100 | Batch 120/192 | Cost: 0.1473\n", "Epoch: 037/100 Train Acc.: 95.05% | Validation Acc.: 85.80%\n", "Time elapsed: 14.52 min\n", "Epoch: 038/100 | Batch 000/192 | Cost: 0.0932\n", "Epoch: 038/100 | Batch 120/192 | Cost: 0.1691\n", "Epoch: 038/100 Train Acc.: 95.08% | Validation Acc.: 84.40%\n", "Time elapsed: 14.89 min\n", "Epoch: 039/100 | Batch 000/192 | Cost: 0.1114\n", "Epoch: 039/100 | Batch 120/192 | Cost: 0.2167\n", "Epoch: 039/100 Train Acc.: 95.74% | Validation Acc.: 86.20%\n", "Time elapsed: 15.25 min\n", "Epoch: 040/100 | Batch 000/192 | Cost: 0.1177\n", "Epoch: 040/100 | Batch 120/192 | Cost: 0.1948\n", "Epoch: 040/100 Train Acc.: 95.52% | Validation Acc.: 85.60%\n", "Time elapsed: 15.62 min\n", "Epoch: 041/100 | Batch 000/192 | Cost: 0.1345\n", "Epoch: 041/100 | Batch 120/192 | Cost: 0.1891\n", "Epoch: 041/100 Train Acc.: 96.02% | Validation Acc.: 85.60%\n", "Time elapsed: 15.99 min\n", "Epoch: 042/100 | Batch 000/192 | Cost: 0.1558\n", "Epoch: 042/100 | Batch 120/192 | Cost: 0.1588\n", "Epoch: 042/100 Train Acc.: 95.50% | Validation Acc.: 85.60%\n", "Time elapsed: 16.35 min\n", "Epoch: 043/100 | Batch 000/192 | Cost: 0.0832\n", "Epoch: 043/100 | Batch 
120/192 | Cost: 0.1639\n", "Epoch: 043/100 Train Acc.: 96.13% | Validation Acc.: 83.70%\n", "Time elapsed: 16.72 min\n", "Epoch: 044/100 | Batch 000/192 | Cost: 0.0768\n", "Epoch: 044/100 | Batch 120/192 | Cost: 0.1157\n", "Epoch: 044/100 Train Acc.: 95.69% | Validation Acc.: 85.80%\n", "Time elapsed: 17.09 min\n", "Epoch: 045/100 | Batch 000/192 | Cost: 0.1428\n", "Epoch: 045/100 | Batch 120/192 | Cost: 0.1093\n", "Epoch: 045/100 Train Acc.: 95.85% | Validation Acc.: 84.70%\n", "Time elapsed: 17.45 min\n", "Epoch: 046/100 | Batch 000/192 | Cost: 0.1009\n", "Epoch: 046/100 | Batch 120/192 | Cost: 0.1148\n", "Epoch: 046/100 Train Acc.: 96.11% | Validation Acc.: 82.90%\n", "Time elapsed: 17.82 min\n", "Epoch: 047/100 | Batch 000/192 | Cost: 0.1023\n", "Epoch: 047/100 | Batch 120/192 | Cost: 0.1426\n", "Epoch: 047/100 Train Acc.: 96.00% | Validation Acc.: 84.50%\n", "Time elapsed: 18.19 min\n", "Epoch: 048/100 | Batch 000/192 | Cost: 0.1000\n", "Epoch: 048/100 | Batch 120/192 | Cost: 0.1366\n", "Epoch: 048/100 Train Acc.: 96.49% | Validation Acc.: 85.20%\n", "Time elapsed: 18.56 min\n", "Epoch: 049/100 | Batch 000/192 | Cost: 0.0983\n", "Epoch: 049/100 | Batch 120/192 | Cost: 0.1003\n", "Epoch: 049/100 Train Acc.: 96.57% | Validation Acc.: 85.10%\n", "Time elapsed: 18.93 min\n", "Epoch: 050/100 | Batch 000/192 | Cost: 0.0748\n", "Epoch: 050/100 | Batch 120/192 | Cost: 0.1001\n", "Epoch: 050/100 Train Acc.: 96.27% | Validation Acc.: 85.30%\n", "Time elapsed: 19.29 min\n", "Epoch: 051/100 | Batch 000/192 | Cost: 0.1418\n", "Epoch: 051/100 | Batch 120/192 | Cost: 0.0902\n", "Epoch: 051/100 Train Acc.: 96.55% | Validation Acc.: 85.70%\n", "Time elapsed: 19.66 min\n", "Epoch: 052/100 | Batch 000/192 | Cost: 0.0924\n", "Epoch: 052/100 | Batch 120/192 | Cost: 0.1003\n", "Epoch: 052/100 Train Acc.: 96.74% | Validation Acc.: 86.00%\n", "Time elapsed: 20.03 min\n", "Epoch: 053/100 | Batch 000/192 | Cost: 0.1101\n", "Epoch: 053/100 | Batch 120/192 | Cost: 0.1555\n", "Epoch: 053/100 Train Acc.: 96.44% | Validation Acc.: 84.90%\n", "Time elapsed: 20.39 min\n", "Epoch: 054/100 | Batch 000/192 | Cost: 0.0853\n", "Epoch: 054/100 | Batch 120/192 | Cost: 0.0984\n", "Epoch: 054/100 Train Acc.: 96.78% | Validation Acc.: 85.10%\n", "Time elapsed: 20.76 min\n", "Epoch: 055/100 | Batch 000/192 | Cost: 0.0503\n", "Epoch: 055/100 | Batch 120/192 | Cost: 0.0870\n", "Epoch: 055/100 Train Acc.: 96.91% | Validation Acc.: 84.80%\n", "Time elapsed: 21.13 min\n", "Epoch: 056/100 | Batch 000/192 | Cost: 0.0659\n", "Epoch: 056/100 | Batch 120/192 | Cost: 0.0849\n", "Epoch: 056/100 Train Acc.: 96.95% | Validation Acc.: 86.60%\n", "Time elapsed: 21.50 min\n", "Epoch: 057/100 | Batch 000/192 | Cost: 0.1177\n", "Epoch: 057/100 | Batch 120/192 | Cost: 0.1281\n", "Epoch: 057/100 Train Acc.: 97.02% | Validation Acc.: 86.70%\n", "Time elapsed: 21.87 min\n", "Epoch: 058/100 | Batch 000/192 | Cost: 0.0996\n", "Epoch: 058/100 | Batch 120/192 | Cost: 0.1410\n", "Epoch: 058/100 Train Acc.: 96.55% | Validation Acc.: 85.00%\n", "Time elapsed: 22.24 min\n", "Epoch: 059/100 | Batch 000/192 | Cost: 0.0621\n", "Epoch: 059/100 | Batch 120/192 | Cost: 0.0648\n", "Epoch: 059/100 Train Acc.: 97.04% | Validation Acc.: 85.50%\n", "Time elapsed: 22.61 min\n", "Epoch: 060/100 | Batch 000/192 | Cost: 0.0626\n", "Epoch: 060/100 | Batch 120/192 | Cost: 0.0791\n", "Epoch: 060/100 Train Acc.: 96.42% | Validation Acc.: 84.30%\n", "Time elapsed: 22.98 min\n", "Epoch: 061/100 | Batch 000/192 | Cost: 0.1322\n", "Epoch: 061/100 | Batch 120/192 | Cost: 0.0991\n", 
"Epoch: 061/100 Train Acc.: 97.13% | Validation Acc.: 85.80%\n", "Time elapsed: 23.35 min\n", "Epoch: 062/100 | Batch 000/192 | Cost: 0.0598\n", "Epoch: 062/100 | Batch 120/192 | Cost: 0.1386\n", "Epoch: 062/100 Train Acc.: 97.04% | Validation Acc.: 84.30%\n", "Time elapsed: 23.71 min\n", "Epoch: 063/100 | Batch 000/192 | Cost: 0.0402\n", "Epoch: 063/100 | Batch 120/192 | Cost: 0.1163\n", "Epoch: 063/100 Train Acc.: 97.16% | Validation Acc.: 84.80%\n", "Time elapsed: 24.19 min\n", "Epoch: 064/100 | Batch 000/192 | Cost: 0.0672\n", "Epoch: 064/100 | Batch 120/192 | Cost: 0.0687\n", "Epoch: 064/100 Train Acc.: 97.28% | Validation Acc.: 85.20%\n", "Time elapsed: 24.70 min\n", "Epoch: 065/100 | Batch 000/192 | Cost: 0.0783\n", "Epoch: 065/100 | Batch 120/192 | Cost: 0.1035\n", "Epoch: 065/100 Train Acc.: 97.17% | Validation Acc.: 85.70%\n", "Time elapsed: 25.46 min\n", "Epoch: 066/100 | Batch 000/192 | Cost: 0.0331\n", "Epoch: 066/100 | Batch 120/192 | Cost: 0.0829\n", "Epoch: 066/100 Train Acc.: 97.63% | Validation Acc.: 86.80%\n", "Time elapsed: 26.24 min\n", "Epoch: 067/100 | Batch 000/192 | Cost: 0.0836\n", "Epoch: 067/100 | Batch 120/192 | Cost: 0.0810\n", "Epoch: 067/100 Train Acc.: 97.38% | Validation Acc.: 84.20%\n", "Time elapsed: 27.03 min\n", "Epoch: 068/100 | Batch 000/192 | Cost: 0.0746\n", "Epoch: 068/100 | Batch 120/192 | Cost: 0.1084\n", "Epoch: 068/100 Train Acc.: 97.64% | Validation Acc.: 85.60%\n", "Time elapsed: 27.79 min\n", "Epoch: 069/100 | Batch 000/192 | Cost: 0.0548\n", "Epoch: 069/100 | Batch 120/192 | Cost: 0.0487\n", "Epoch: 069/100 Train Acc.: 97.65% | Validation Acc.: 86.00%\n", "Time elapsed: 28.57 min\n", "Epoch: 070/100 | Batch 000/192 | Cost: 0.0811\n", "Epoch: 070/100 | Batch 120/192 | Cost: 0.0865\n", "Epoch: 070/100 Train Acc.: 97.45% | Validation Acc.: 86.60%\n", "Time elapsed: 29.34 min\n", "Epoch: 071/100 | Batch 000/192 | Cost: 0.0757\n", "Epoch: 071/100 | Batch 120/192 | Cost: 0.1505\n", "Epoch: 071/100 Train Acc.: 97.52% | Validation Acc.: 84.50%\n", "Time elapsed: 30.12 min\n", "Epoch: 072/100 | Batch 000/192 | Cost: 0.1299\n", "Epoch: 072/100 | Batch 120/192 | Cost: 0.0503\n", "Epoch: 072/100 Train Acc.: 97.53% | Validation Acc.: 86.70%\n", "Time elapsed: 30.91 min\n", "Epoch: 073/100 | Batch 000/192 | Cost: 0.0463\n", "Epoch: 073/100 | Batch 120/192 | Cost: 0.0583\n", "Epoch: 073/100 Train Acc.: 97.75% | Validation Acc.: 85.00%\n", "Time elapsed: 31.67 min\n", "Epoch: 074/100 | Batch 000/192 | Cost: 0.0454\n", "Epoch: 074/100 | Batch 120/192 | Cost: 0.0507\n", "Epoch: 074/100 Train Acc.: 97.64% | Validation Acc.: 86.50%\n", "Time elapsed: 32.45 min\n", "Epoch: 075/100 | Batch 000/192 | Cost: 0.0686\n", "Epoch: 075/100 | Batch 120/192 | Cost: 0.0734\n", "Epoch: 075/100 Train Acc.: 97.79% | Validation Acc.: 86.60%\n", "Time elapsed: 33.22 min\n", "Epoch: 076/100 | Batch 000/192 | Cost: 0.1011\n", "Epoch: 076/100 | Batch 120/192 | Cost: 0.0856\n", "Epoch: 076/100 Train Acc.: 97.77% | Validation Acc.: 85.90%\n", "Time elapsed: 34.00 min\n", "Epoch: 077/100 | Batch 000/192 | Cost: 0.0494\n", "Epoch: 077/100 | Batch 120/192 | Cost: 0.0623\n", "Epoch: 077/100 Train Acc.: 97.74% | Validation Acc.: 86.90%\n", "Time elapsed: 34.78 min\n", "Epoch: 078/100 | Batch 000/192 | Cost: 0.0519\n", "Epoch: 078/100 | Batch 120/192 | Cost: 0.0740\n", "Epoch: 078/100 Train Acc.: 97.52% | Validation Acc.: 86.30%\n", "Time elapsed: 35.55 min\n", "Epoch: 079/100 | Batch 000/192 | Cost: 0.0502\n", "Epoch: 079/100 | Batch 120/192 | Cost: 0.0762\n", "Epoch: 079/100 Train 
Acc.: 97.44% | Validation Acc.: 86.00%\n", "Time elapsed: 36.33 min\n", "Epoch: 080/100 | Batch 000/192 | Cost: 0.0973\n", "Epoch: 080/100 | Batch 120/192 | Cost: 0.0414\n", "Epoch: 080/100 Train Acc.: 98.03% | Validation Acc.: 86.70%\n", "Time elapsed: 37.10 min\n", "Epoch: 081/100 | Batch 000/192 | Cost: 0.0882\n", "Epoch: 081/100 | Batch 120/192 | Cost: 0.1327\n", "Epoch: 081/100 Train Acc.: 97.92% | Validation Acc.: 86.20%\n", "Time elapsed: 37.88 min\n", "Epoch: 082/100 | Batch 000/192 | Cost: 0.0425\n", "Epoch: 082/100 | Batch 120/192 | Cost: 0.0632\n", "Epoch: 082/100 Train Acc.: 97.72% | Validation Acc.: 85.00%\n", "Time elapsed: 38.66 min\n", "Epoch: 083/100 | Batch 000/192 | Cost: 0.0676\n", "Epoch: 083/100 | Batch 120/192 | Cost: 0.0444\n", "Epoch: 083/100 Train Acc.: 98.06% | Validation Acc.: 87.10%\n", "Time elapsed: 39.43 min\n", "Epoch: 084/100 | Batch 000/192 | Cost: 0.0565\n", "Epoch: 084/100 | Batch 120/192 | Cost: 0.0478\n", "Epoch: 084/100 Train Acc.: 97.96% | Validation Acc.: 86.80%\n", "Time elapsed: 40.22 min\n", "Epoch: 085/100 | Batch 000/192 | Cost: 0.1038\n", "Epoch: 085/100 | Batch 120/192 | Cost: 0.0502\n", "Epoch: 085/100 Train Acc.: 98.02% | Validation Acc.: 87.20%\n", "Time elapsed: 41.00 min\n", "Epoch: 086/100 | Batch 000/192 | Cost: 0.1114\n", "Epoch: 086/100 | Batch 120/192 | Cost: 0.0419\n", "Epoch: 086/100 Train Acc.: 97.93% | Validation Acc.: 86.10%\n", "Time elapsed: 41.77 min\n", "Epoch: 087/100 | Batch 000/192 | Cost: 0.0485\n", "Epoch: 087/100 | Batch 120/192 | Cost: 0.0526\n", "Epoch: 087/100 Train Acc.: 97.99% | Validation Acc.: 87.00%\n", "Time elapsed: 42.56 min\n", "Epoch: 088/100 | Batch 000/192 | Cost: 0.0429\n", "Epoch: 088/100 | Batch 120/192 | Cost: 0.0542\n", "Epoch: 088/100 Train Acc.: 97.95% | Validation Acc.: 87.10%\n", "Time elapsed: 43.34 min\n", "Epoch: 089/100 | Batch 000/192 | Cost: 0.0533\n", "Epoch: 089/100 | Batch 120/192 | Cost: 0.0241\n", "Epoch: 089/100 Train Acc.: 98.05% | Validation Acc.: 86.60%\n", "Time elapsed: 44.13 min\n", "Epoch: 090/100 | Batch 000/192 | Cost: 0.0738\n", "Epoch: 090/100 | Batch 120/192 | Cost: 0.0324\n", "Epoch: 090/100 Train Acc.: 97.87% | Validation Acc.: 86.10%\n", "Time elapsed: 44.91 min\n", "Epoch: 091/100 | Batch 000/192 | Cost: 0.0778\n", "Epoch: 091/100 | Batch 120/192 | Cost: 0.0754\n", "Epoch: 091/100 Train Acc.: 98.22% | Validation Acc.: 86.40%\n", "Time elapsed: 45.68 min\n", "Epoch: 092/100 | Batch 000/192 | Cost: 0.0695\n", "Epoch: 092/100 | Batch 120/192 | Cost: 0.0946\n", "Epoch: 092/100 Train Acc.: 97.94% | Validation Acc.: 86.40%\n", "Time elapsed: 46.47 min\n", "Epoch: 093/100 | Batch 000/192 | Cost: 0.0322\n", "Epoch: 093/100 | Batch 120/192 | Cost: 0.0522\n", "Epoch: 093/100 Train Acc.: 98.28% | Validation Acc.: 86.40%\n", "Time elapsed: 47.26 min\n", "Epoch: 094/100 | Batch 000/192 | Cost: 0.0442\n", "Epoch: 094/100 | Batch 120/192 | Cost: 0.0545\n", "Epoch: 094/100 Train Acc.: 98.22% | Validation Acc.: 86.70%\n", "Time elapsed: 48.04 min\n", "Epoch: 095/100 | Batch 000/192 | Cost: 0.0355\n", "Epoch: 095/100 | Batch 120/192 | Cost: 0.0459\n", "Epoch: 095/100 Train Acc.: 98.13% | Validation Acc.: 87.40%\n", "Time elapsed: 48.84 min\n", "Epoch: 096/100 | Batch 000/192 | Cost: 0.0448\n", "Epoch: 096/100 | Batch 120/192 | Cost: 0.0468\n", "Epoch: 096/100 Train Acc.: 98.19% | Validation Acc.: 85.90%\n", "Time elapsed: 49.60 min\n", "Epoch: 097/100 | Batch 000/192 | Cost: 0.0175\n", "Epoch: 097/100 | Batch 120/192 | Cost: 0.0409\n", "Epoch: 097/100 Train Acc.: 98.17% | Validation 
Acc.: 87.10%\n", "Time elapsed: 50.39 min\n", "Epoch: 098/100 | Batch 000/192 | Cost: 0.0374\n", "Epoch: 098/100 | Batch 120/192 | Cost: 0.0465\n", "Epoch: 098/100 Train Acc.: 98.27% | Validation Acc.: 86.20%\n", "Time elapsed: 51.15 min\n", "Epoch: 099/100 | Batch 000/192 | Cost: 0.0628\n", "Epoch: 099/100 | Batch 120/192 | Cost: 0.0555\n", "Epoch: 099/100 Train Acc.: 98.00% | Validation Acc.: 85.90%\n", "Time elapsed: 51.94 min\n", "Epoch: 100/100 | Batch 000/192 | Cost: 0.0570\n", "Epoch: 100/100 | Batch 120/192 | Cost: 0.0494\n", "Epoch: 100/100 Train Acc.: 98.31% | Validation Acc.: 85.80%\n", "Time elapsed: 52.72 min\n", "Total Training Time: 52.72 min\n" ] } ], "source": [ "start_time = time.time()\n", "for epoch in range(NUM_EPOCHS):\n", " \n", " model.train()\n", " \n", " for batch_idx, (features, targets) in enumerate(train_loader):\n", " \n", " ### PREPARE MINIBATCH\n", " features = features.to(DEVICE)\n", " targets = targets.to(DEVICE)\n", " \n", " ### FORWARD AND BACK PROP\n", " logits, probas = model(features)\n", " cost = F.cross_entropy(logits, targets)\n", " optimizer.zero_grad()\n", " \n", " cost.backward()\n", " \n", " ### UPDATE MODEL PARAMETERS\n", " optimizer.step()\n", " \n", " ### LOGGING\n", " if not batch_idx % 120:\n", " print (f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d} | '\n", " f'Batch {batch_idx:03d}/{len(train_loader):03d} |' \n", " f' Cost: {cost:.4f}')\n", "\n", " # no need to build the computation graph for backprop when computing accuracy\n", " with torch.set_grad_enabled(False):\n", " train_acc = compute_accuracy(model, train_loader, device=DEVICE)\n", " valid_acc = compute_accuracy(model, valid_loader, device=DEVICE)\n", " print(f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d} Train Acc.: {train_acc:.2f}%'\n", " f' | Validation Acc.: {valid_acc:.2f}%')\n", " \n", " elapsed = (time.time() - start_time)/60\n", " print(f'Time elapsed: {elapsed:.2f} min')\n", " \n", "elapsed = (time.time() - start_time)/60\n", "print(f'Total Training Time: {elapsed:.2f} min')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# BatchNorm after Activation" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "##########################\n", "### MODEL\n", "##########################\n", "\n", "\n", "class NiN(nn.Module):\n", " def __init__(self, num_classes):\n", " super(NiN, self).__init__()\n", " self.num_classes = num_classes\n", " self.classifier = nn.Sequential(\n", " nn.Conv2d(3, 192, kernel_size=5, stride=1, padding=2, bias=False),\n", " nn.ReLU(inplace=True),\n", " nn.BatchNorm2d(192),\n", " nn.Conv2d(192, 160, kernel_size=1, stride=1, padding=0, bias=False),\n", " nn.ReLU(inplace=True),\n", " nn.BatchNorm2d(160),\n", " nn.Conv2d(160, 96, kernel_size=1, stride=1, padding=0, bias=False),\n", " nn.ReLU(inplace=True),\n", " nn.BatchNorm2d(96),\n", " nn.MaxPool2d(kernel_size=3, stride=2, padding=1),\n", " nn.Dropout(0.5),\n", "\n", " nn.Conv2d(96, 192, kernel_size=5, stride=1, padding=2, bias=False),\n", " nn.ReLU(inplace=True),\n", " nn.BatchNorm2d(192),\n", " nn.Conv2d(192, 192, kernel_size=1, stride=1, padding=0, bias=False),\n", " nn.ReLU(inplace=True),\n", " nn.BatchNorm2d(192),\n", " nn.Conv2d(192, 192, kernel_size=1, stride=1, padding=0, bias=False),\n", " nn.ReLU(inplace=True),\n", " nn.BatchNorm2d(192),\n", " nn.AvgPool2d(kernel_size=3, stride=2, padding=1),\n", " nn.Dropout(0.5),\n", "\n", " nn.Conv2d(192, 192, kernel_size=3, stride=1, padding=1, bias=False),\n", " nn.ReLU(inplace=True),\n", " nn.BatchNorm2d(192),\n", " 
nn.Conv2d(192, 192, kernel_size=1, stride=1, padding=0, bias=False),\n", " nn.ReLU(inplace=True),\n", " nn.BatchNorm2d(192),\n", " nn.Conv2d(192, 10, kernel_size=1, stride=1, padding=0),\n", " nn.ReLU(inplace=True),\n", " nn.AvgPool2d(kernel_size=8, stride=1, padding=0),\n", "\n", " )\n", "\n", " def forward(self, x):\n", " x = self.classifier(x)\n", " logits = x.view(x.size(0), self.num_classes)\n", " probas = torch.softmax(logits, dim=1)\n", " return logits, probas" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 } }, "colab_type": "code", "id": "_lza9t_uj5w1" }, "outputs": [], "source": [ "torch.manual_seed(RANDOM_SEED)\n", "\n", "model = NiN(NUM_CLASSES)\n", "model.to(DEVICE)\n", "\n", "optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE) " ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, "base_uri": "https://localhost:8080/", "height": 1547 }, "colab_type": "code", "executionInfo": { "elapsed": 2384585, "status": "ok", "timestamp": 1524976888520, "user": { "displayName": "Sebastian Raschka", "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg", "userId": "118404394130788869227" }, "user_tz": 240 }, "id": "Dzh3ROmRj5w7", "outputId": "5f8fd8c9-b076-403a-b0b7-fd2d498b48d7" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 001/100 | Batch 000/192 | Cost: 2.3059\n", "Epoch: 001/100 | Batch 120/192 | Cost: 1.0759\n", "Epoch: 001/100 Train Acc.: 64.08% | Validation Acc.: 64.80%\n", "Time elapsed: 0.77 min\n", "Epoch: 002/100 | Batch 000/192 | Cost: 1.1736\n", "Epoch: 002/100 | Batch 120/192 | Cost: 0.8403\n", "Epoch: 002/100 Train Acc.: 72.13% | Validation Acc.: 69.60%\n", "Time elapsed: 1.55 min\n", "Epoch: 003/100 | Batch 000/192 | Cost: 0.7607\n", "Epoch: 003/100 | Batch 120/192 | Cost: 0.7570\n", "Epoch: 003/100 Train Acc.: 76.62% | Validation Acc.: 73.90%\n", "Time elapsed: 2.32 min\n", "Epoch: 004/100 | Batch 000/192 | Cost: 0.6554\n", "Epoch: 004/100 | Batch 120/192 | Cost: 0.6539\n", "Epoch: 004/100 Train Acc.: 78.93% | Validation Acc.: 76.70%\n", "Time elapsed: 3.10 min\n", "Epoch: 005/100 | Batch 000/192 | Cost: 0.5906\n", "Epoch: 005/100 | Batch 120/192 | Cost: 0.7284\n", "Epoch: 005/100 Train Acc.: 81.88% | Validation Acc.: 79.70%\n", "Time elapsed: 3.87 min\n", "Epoch: 006/100 | Batch 000/192 | Cost: 0.5847\n", "Epoch: 006/100 | Batch 120/192 | Cost: 0.5115\n", "Epoch: 006/100 Train Acc.: 83.57% | Validation Acc.: 79.90%\n", "Time elapsed: 4.65 min\n", "Epoch: 007/100 | Batch 000/192 | Cost: 0.5185\n", "Epoch: 007/100 | Batch 120/192 | Cost: 0.4879\n", "Epoch: 007/100 Train Acc.: 84.50% | Validation Acc.: 80.30%\n", "Time elapsed: 5.42 min\n", "Epoch: 008/100 | Batch 000/192 | Cost: 0.4134\n", "Epoch: 008/100 | Batch 120/192 | Cost: 0.4843\n", "Epoch: 008/100 Train Acc.: 85.61% | Validation Acc.: 80.80%\n", "Time elapsed: 6.19 min\n", "Epoch: 009/100 | Batch 000/192 | Cost: 0.3521\n", "Epoch: 009/100 | Batch 120/192 | Cost: 0.5180\n", "Epoch: 009/100 Train Acc.: 87.21% | Validation Acc.: 80.00%\n", "Time elapsed: 6.96 min\n", "Epoch: 010/100 | Batch 000/192 | Cost: 0.4342\n", "Epoch: 010/100 | Batch 120/192 | Cost: 0.4116\n", "Epoch: 010/100 Train Acc.: 87.58% | Validation Acc.: 80.20%\n", "Time elapsed: 7.74 min\n", "Epoch: 011/100 | Batch 000/192 | Cost: 0.4375\n", "Epoch: 011/100 | Batch 120/192 | Cost: 0.4573\n", 
"Epoch: 011/100 Train Acc.: 88.85% | Validation Acc.: 82.40%\n", "Time elapsed: 8.50 min\n", "Epoch: 012/100 | Batch 000/192 | Cost: 0.3115\n", "Epoch: 012/100 | Batch 120/192 | Cost: 0.3661\n", "Epoch: 012/100 Train Acc.: 89.30% | Validation Acc.: 81.80%\n", "Time elapsed: 9.27 min\n", "Epoch: 013/100 | Batch 000/192 | Cost: 0.2318\n", "Epoch: 013/100 | Batch 120/192 | Cost: 0.2555\n", "Epoch: 013/100 Train Acc.: 89.73% | Validation Acc.: 81.90%\n", "Time elapsed: 10.05 min\n", "Epoch: 014/100 | Batch 000/192 | Cost: 0.3029\n", "Epoch: 014/100 | Batch 120/192 | Cost: 0.3206\n", "Epoch: 014/100 Train Acc.: 90.71% | Validation Acc.: 84.40%\n", "Time elapsed: 10.81 min\n", "Epoch: 015/100 | Batch 000/192 | Cost: 0.3103\n", "Epoch: 015/100 | Batch 120/192 | Cost: 0.3303\n", "Epoch: 015/100 Train Acc.: 91.45% | Validation Acc.: 81.90%\n", "Time elapsed: 11.59 min\n", "Epoch: 016/100 | Batch 000/192 | Cost: 0.3105\n", "Epoch: 016/100 | Batch 120/192 | Cost: 0.2497\n", "Epoch: 016/100 Train Acc.: 91.92% | Validation Acc.: 82.60%\n", "Time elapsed: 12.36 min\n", "Epoch: 017/100 | Batch 000/192 | Cost: 0.1741\n", "Epoch: 017/100 | Batch 120/192 | Cost: 0.2539\n", "Epoch: 017/100 Train Acc.: 92.74% | Validation Acc.: 83.10%\n", "Time elapsed: 13.13 min\n", "Epoch: 018/100 | Batch 000/192 | Cost: 0.2569\n", "Epoch: 018/100 | Batch 120/192 | Cost: 0.2318\n", "Epoch: 018/100 Train Acc.: 93.14% | Validation Acc.: 83.60%\n", "Time elapsed: 13.91 min\n", "Epoch: 019/100 | Batch 000/192 | Cost: 0.2926\n", "Epoch: 019/100 | Batch 120/192 | Cost: 0.1889\n", "Epoch: 019/100 Train Acc.: 92.98% | Validation Acc.: 84.50%\n", "Time elapsed: 14.67 min\n", "Epoch: 020/100 | Batch 000/192 | Cost: 0.1761\n", "Epoch: 020/100 | Batch 120/192 | Cost: 0.1828\n", "Epoch: 020/100 Train Acc.: 93.68% | Validation Acc.: 83.10%\n", "Time elapsed: 15.44 min\n", "Epoch: 021/100 | Batch 000/192 | Cost: 0.1238\n", "Epoch: 021/100 | Batch 120/192 | Cost: 0.1776\n", "Epoch: 021/100 Train Acc.: 93.89% | Validation Acc.: 84.90%\n", "Time elapsed: 16.20 min\n", "Epoch: 022/100 | Batch 000/192 | Cost: 0.2031\n", "Epoch: 022/100 | Batch 120/192 | Cost: 0.1599\n", "Epoch: 022/100 Train Acc.: 93.94% | Validation Acc.: 84.50%\n", "Time elapsed: 16.97 min\n", "Epoch: 023/100 | Batch 000/192 | Cost: 0.1342\n", "Epoch: 023/100 | Batch 120/192 | Cost: 0.1964\n", "Epoch: 023/100 Train Acc.: 94.68% | Validation Acc.: 84.70%\n", "Time elapsed: 17.74 min\n", "Epoch: 024/100 | Batch 000/192 | Cost: 0.1671\n", "Epoch: 024/100 | Batch 120/192 | Cost: 0.1648\n", "Epoch: 024/100 Train Acc.: 94.51% | Validation Acc.: 85.20%\n", "Time elapsed: 18.50 min\n", "Epoch: 025/100 | Batch 000/192 | Cost: 0.1436\n", "Epoch: 025/100 | Batch 120/192 | Cost: 0.1684\n", "Epoch: 025/100 Train Acc.: 94.94% | Validation Acc.: 84.90%\n", "Time elapsed: 19.27 min\n", "Epoch: 026/100 | Batch 000/192 | Cost: 0.1587\n", "Epoch: 026/100 | Batch 120/192 | Cost: 0.1912\n", "Epoch: 026/100 Train Acc.: 95.05% | Validation Acc.: 83.30%\n", "Time elapsed: 20.04 min\n", "Epoch: 027/100 | Batch 000/192 | Cost: 0.1599\n", "Epoch: 027/100 | Batch 120/192 | Cost: 0.1704\n", "Epoch: 027/100 Train Acc.: 95.52% | Validation Acc.: 83.70%\n", "Time elapsed: 20.81 min\n", "Epoch: 028/100 | Batch 000/192 | Cost: 0.1275\n", "Epoch: 028/100 | Batch 120/192 | Cost: 0.1232\n", "Epoch: 028/100 Train Acc.: 95.63% | Validation Acc.: 84.70%\n", "Time elapsed: 21.60 min\n", "Epoch: 029/100 | Batch 000/192 | Cost: 0.1452\n", "Epoch: 029/100 | Batch 120/192 | Cost: 0.1621\n", "Epoch: 029/100 Train Acc.: 
95.83% | Validation Acc.: 84.00%\n", "Time elapsed: 22.36 min\n", "Epoch: 030/100 | Batch 000/192 | Cost: 0.0822\n", "Epoch: 030/100 | Batch 120/192 | Cost: 0.1508\n", "Epoch: 030/100 Train Acc.: 95.72% | Validation Acc.: 85.00%\n", "Time elapsed: 23.14 min\n", "Epoch: 031/100 | Batch 000/192 | Cost: 0.1148\n", "Epoch: 031/100 | Batch 120/192 | Cost: 0.0952\n", "Epoch: 031/100 Train Acc.: 95.70% | Validation Acc.: 84.50%\n", "Time elapsed: 23.92 min\n", "Epoch: 032/100 | Batch 000/192 | Cost: 0.1098\n", "Epoch: 032/100 | Batch 120/192 | Cost: 0.1265\n", "Epoch: 032/100 Train Acc.: 95.58% | Validation Acc.: 84.50%\n", "Time elapsed: 24.69 min\n", "Epoch: 033/100 | Batch 000/192 | Cost: 0.0968\n", "Epoch: 033/100 | Batch 120/192 | Cost: 0.1536\n", "Epoch: 033/100 Train Acc.: 96.37% | Validation Acc.: 85.80%\n", "Time elapsed: 25.47 min\n", "Epoch: 034/100 | Batch 000/192 | Cost: 0.1380\n", "Epoch: 034/100 | Batch 120/192 | Cost: 0.1361\n", "Epoch: 034/100 Train Acc.: 96.40% | Validation Acc.: 84.50%\n", "Time elapsed: 26.24 min\n", "Epoch: 035/100 | Batch 000/192 | Cost: 0.1400\n", "Epoch: 035/100 | Batch 120/192 | Cost: 0.1103\n", "Epoch: 035/100 Train Acc.: 96.68% | Validation Acc.: 86.40%\n", "Time elapsed: 27.02 min\n", "Epoch: 036/100 | Batch 000/192 | Cost: 0.0920\n", "Epoch: 036/100 | Batch 120/192 | Cost: 0.1332\n", "Epoch: 036/100 Train Acc.: 96.61% | Validation Acc.: 84.30%\n", "Time elapsed: 27.78 min\n", "Epoch: 037/100 | Batch 000/192 | Cost: 0.0682\n", "Epoch: 037/100 | Batch 120/192 | Cost: 0.1231\n", "Epoch: 037/100 Train Acc.: 96.53% | Validation Acc.: 84.80%\n", "Time elapsed: 28.55 min\n", "Epoch: 038/100 | Batch 000/192 | Cost: 0.1042\n", "Epoch: 038/100 | Batch 120/192 | Cost: 0.1283\n", "Epoch: 038/100 Train Acc.: 96.54% | Validation Acc.: 84.70%\n", "Time elapsed: 29.33 min\n", "Epoch: 039/100 | Batch 000/192 | Cost: 0.1099\n", "Epoch: 039/100 | Batch 120/192 | Cost: 0.0976\n", "Epoch: 039/100 Train Acc.: 97.05% | Validation Acc.: 84.80%\n", "Time elapsed: 30.09 min\n", "Epoch: 040/100 | Batch 000/192 | Cost: 0.0670\n", "Epoch: 040/100 | Batch 120/192 | Cost: 0.1400\n", "Epoch: 040/100 Train Acc.: 96.85% | Validation Acc.: 84.90%\n", "Time elapsed: 30.87 min\n", "Epoch: 041/100 | Batch 000/192 | Cost: 0.1038\n", "Epoch: 041/100 | Batch 120/192 | Cost: 0.1502\n", "Epoch: 041/100 Train Acc.: 97.14% | Validation Acc.: 83.80%\n", "Time elapsed: 31.64 min\n", "Epoch: 042/100 | Batch 000/192 | Cost: 0.0742\n", "Epoch: 042/100 | Batch 120/192 | Cost: 0.1515\n", "Epoch: 042/100 Train Acc.: 97.21% | Validation Acc.: 86.00%\n", "Time elapsed: 32.41 min\n", "Epoch: 043/100 | Batch 000/192 | Cost: 0.1119\n", "Epoch: 043/100 | Batch 120/192 | Cost: 0.1353\n", "Epoch: 043/100 Train Acc.: 97.21% | Validation Acc.: 84.70%\n", "Time elapsed: 33.19 min\n", "Epoch: 044/100 | Batch 000/192 | Cost: 0.0806\n", "Epoch: 044/100 | Batch 120/192 | Cost: 0.0663\n", "Epoch: 044/100 Train Acc.: 97.22% | Validation Acc.: 85.50%\n", "Time elapsed: 33.96 min\n", "Epoch: 045/100 | Batch 000/192 | Cost: 0.0712\n", "Epoch: 045/100 | Batch 120/192 | Cost: 0.0965\n", "Epoch: 045/100 Train Acc.: 97.40% | Validation Acc.: 85.40%\n", "Time elapsed: 34.73 min\n", "Epoch: 046/100 | Batch 000/192 | Cost: 0.0878\n", "Epoch: 046/100 | Batch 120/192 | Cost: 0.0740\n", "Epoch: 046/100 Train Acc.: 97.51% | Validation Acc.: 84.40%\n", "Time elapsed: 35.51 min\n", "Epoch: 047/100 | Batch 000/192 | Cost: 0.1174\n", "Epoch: 047/100 | Batch 120/192 | Cost: 0.0488\n", "Epoch: 047/100 Train Acc.: 97.63% | Validation Acc.: 
84.30%\n", "Time elapsed: 36.28 min\n", "Epoch: 048/100 | Batch 000/192 | Cost: 0.0605\n", "Epoch: 048/100 | Batch 120/192 | Cost: 0.1052\n", "Epoch: 048/100 Train Acc.: 97.45% | Validation Acc.: 84.70%\n", "Time elapsed: 37.06 min\n", "Epoch: 049/100 | Batch 000/192 | Cost: 0.0446\n", "Epoch: 049/100 | Batch 120/192 | Cost: 0.0897\n", "Epoch: 049/100 Train Acc.: 97.74% | Validation Acc.: 85.30%\n", "Time elapsed: 37.82 min\n", "Epoch: 050/100 | Batch 000/192 | Cost: 0.0623\n", "Epoch: 050/100 | Batch 120/192 | Cost: 0.0904\n", "Epoch: 050/100 Train Acc.: 97.39% | Validation Acc.: 83.80%\n", "Time elapsed: 38.60 min\n", "Epoch: 051/100 | Batch 000/192 | Cost: 0.0641\n", "Epoch: 051/100 | Batch 120/192 | Cost: 0.0890\n", "Epoch: 051/100 Train Acc.: 97.44% | Validation Acc.: 85.60%\n", "Time elapsed: 39.38 min\n", "Epoch: 052/100 | Batch 000/192 | Cost: 0.0482\n", "Epoch: 052/100 | Batch 120/192 | Cost: 0.0669\n", "Epoch: 052/100 Train Acc.: 97.49% | Validation Acc.: 85.40%\n", "Time elapsed: 40.14 min\n", "Epoch: 053/100 | Batch 000/192 | Cost: 0.0710\n", "Epoch: 053/100 | Batch 120/192 | Cost: 0.1376\n", "Epoch: 053/100 Train Acc.: 97.81% | Validation Acc.: 85.70%\n", "Time elapsed: 40.91 min\n", "Epoch: 054/100 | Batch 000/192 | Cost: 0.0518\n", "Epoch: 054/100 | Batch 120/192 | Cost: 0.0818\n", "Epoch: 054/100 Train Acc.: 97.23% | Validation Acc.: 83.10%\n", "Time elapsed: 41.68 min\n", "Epoch: 055/100 | Batch 000/192 | Cost: 0.0913\n", "Epoch: 055/100 | Batch 120/192 | Cost: 0.1024\n", "Epoch: 055/100 Train Acc.: 97.34% | Validation Acc.: 84.50%\n", "Time elapsed: 42.45 min\n", "Epoch: 056/100 | Batch 000/192 | Cost: 0.0641\n", "Epoch: 056/100 | Batch 120/192 | Cost: 0.1011\n", "Epoch: 056/100 Train Acc.: 97.61% | Validation Acc.: 84.50%\n", "Time elapsed: 43.23 min\n", "Epoch: 057/100 | Batch 000/192 | Cost: 0.0562\n", "Epoch: 057/100 | Batch 120/192 | Cost: 0.0859\n", "Epoch: 057/100 Train Acc.: 98.03% | Validation Acc.: 84.30%\n", "Time elapsed: 44.00 min\n", "Epoch: 058/100 | Batch 000/192 | Cost: 0.0774\n", "Epoch: 058/100 | Batch 120/192 | Cost: 0.0956\n", "Epoch: 058/100 Train Acc.: 97.79% | Validation Acc.: 84.80%\n", "Time elapsed: 44.77 min\n", "Epoch: 059/100 | Batch 000/192 | Cost: 0.0640\n", "Epoch: 059/100 | Batch 120/192 | Cost: 0.0551\n", "Epoch: 059/100 Train Acc.: 97.86% | Validation Acc.: 84.80%\n", "Time elapsed: 45.54 min\n", "Epoch: 060/100 | Batch 000/192 | Cost: 0.0810\n", "Epoch: 060/100 | Batch 120/192 | Cost: 0.0322\n", "Epoch: 060/100 Train Acc.: 97.87% | Validation Acc.: 84.50%\n", "Time elapsed: 46.32 min\n", "Epoch: 061/100 | Batch 000/192 | Cost: 0.0813\n", "Epoch: 061/100 | Batch 120/192 | Cost: 0.0924\n", "Epoch: 061/100 Train Acc.: 97.86% | Validation Acc.: 84.30%\n", "Time elapsed: 47.10 min\n", "Epoch: 062/100 | Batch 000/192 | Cost: 0.0727\n", "Epoch: 062/100 | Batch 120/192 | Cost: 0.0776\n", "Epoch: 062/100 Train Acc.: 97.73% | Validation Acc.: 84.60%\n", "Time elapsed: 47.86 min\n", "Epoch: 063/100 | Batch 000/192 | Cost: 0.0436\n", "Epoch: 063/100 | Batch 120/192 | Cost: 0.0313\n", "Epoch: 063/100 Train Acc.: 98.00% | Validation Acc.: 86.40%\n", "Time elapsed: 48.63 min\n", "Epoch: 064/100 | Batch 000/192 | Cost: 0.0491\n", "Epoch: 064/100 | Batch 120/192 | Cost: 0.0530\n", "Epoch: 064/100 Train Acc.: 98.26% | Validation Acc.: 85.40%\n", "Time elapsed: 49.40 min\n", "Epoch: 065/100 | Batch 000/192 | Cost: 0.0721\n", "Epoch: 065/100 | Batch 120/192 | Cost: 0.0621\n", "Epoch: 065/100 Train Acc.: 97.99% | Validation Acc.: 85.20%\n", "Time elapsed: 
50.17 min\n", "Epoch: 066/100 | Batch 000/192 | Cost: 0.0697\n", "Epoch: 066/100 | Batch 120/192 | Cost: 0.0426\n", "Epoch: 066/100 Train Acc.: 98.02% | Validation Acc.: 84.80%\n", "Time elapsed: 50.96 min\n", "Epoch: 067/100 | Batch 000/192 | Cost: 0.0613\n", "Epoch: 067/100 | Batch 120/192 | Cost: 0.0714\n", "Epoch: 067/100 Train Acc.: 97.90% | Validation Acc.: 84.00%\n", "Time elapsed: 51.72 min\n", "Epoch: 068/100 | Batch 000/192 | Cost: 0.0676\n", "Epoch: 068/100 | Batch 120/192 | Cost: 0.0286\n", "Epoch: 068/100 Train Acc.: 98.15% | Validation Acc.: 84.10%\n", "Time elapsed: 52.49 min\n", "Epoch: 069/100 | Batch 000/192 | Cost: 0.0482\n", "Epoch: 069/100 | Batch 120/192 | Cost: 0.0609\n", "Epoch: 069/100 Train Acc.: 97.92% | Validation Acc.: 83.80%\n", "Time elapsed: 53.25 min\n", "Epoch: 070/100 | Batch 000/192 | Cost: 0.0462\n", "Epoch: 070/100 | Batch 120/192 | Cost: 0.0434\n", "Epoch: 070/100 Train Acc.: 98.15% | Validation Acc.: 84.90%\n", "Time elapsed: 54.02 min\n", "Epoch: 071/100 | Batch 000/192 | Cost: 0.0306\n", "Epoch: 071/100 | Batch 120/192 | Cost: 0.1153\n", "Epoch: 071/100 Train Acc.: 98.24% | Validation Acc.: 86.20%\n", "Time elapsed: 54.80 min\n", "Epoch: 072/100 | Batch 000/192 | Cost: 0.0465\n", "Epoch: 072/100 | Batch 120/192 | Cost: 0.0603\n", "Epoch: 072/100 Train Acc.: 98.17% | Validation Acc.: 85.70%\n", "Time elapsed: 55.57 min\n", "Epoch: 073/100 | Batch 000/192 | Cost: 0.0943\n", "Epoch: 073/100 | Batch 120/192 | Cost: 0.0509\n", "Epoch: 073/100 Train Acc.: 98.30% | Validation Acc.: 84.70%\n", "Time elapsed: 56.35 min\n", "Epoch: 074/100 | Batch 000/192 | Cost: 0.0651\n", "Epoch: 074/100 | Batch 120/192 | Cost: 0.0559\n", "Epoch: 074/100 Train Acc.: 98.24% | Validation Acc.: 86.00%\n", "Time elapsed: 57.12 min\n", "Epoch: 075/100 | Batch 000/192 | Cost: 0.0400\n", "Epoch: 075/100 | Batch 120/192 | Cost: 0.0258\n", "Epoch: 075/100 Train Acc.: 98.37% | Validation Acc.: 85.30%\n", "Time elapsed: 57.90 min\n", "Epoch: 076/100 | Batch 000/192 | Cost: 0.0398\n", "Epoch: 076/100 | Batch 120/192 | Cost: 0.0495\n", "Epoch: 076/100 Train Acc.: 98.30% | Validation Acc.: 86.00%\n", "Time elapsed: 58.68 min\n", "Epoch: 077/100 | Batch 000/192 | Cost: 0.0373\n", "Epoch: 077/100 | Batch 120/192 | Cost: 0.0597\n", "Epoch: 077/100 Train Acc.: 98.31% | Validation Acc.: 84.90%\n", "Time elapsed: 59.44 min\n", "Epoch: 078/100 | Batch 000/192 | Cost: 0.0468\n", "Epoch: 078/100 | Batch 120/192 | Cost: 0.0494\n", "Epoch: 078/100 Train Acc.: 98.31% | Validation Acc.: 85.60%\n", "Time elapsed: 60.22 min\n", "Epoch: 079/100 | Batch 000/192 | Cost: 0.0481\n", "Epoch: 079/100 | Batch 120/192 | Cost: 0.0493\n", "Epoch: 079/100 Train Acc.: 98.44% | Validation Acc.: 85.10%\n", "Time elapsed: 60.99 min\n", "Epoch: 080/100 | Batch 000/192 | Cost: 0.0282\n", "Epoch: 080/100 | Batch 120/192 | Cost: 0.0537\n", "Epoch: 080/100 Train Acc.: 98.48% | Validation Acc.: 86.80%\n", "Time elapsed: 61.75 min\n", "Epoch: 081/100 | Batch 000/192 | Cost: 0.0496\n", "Epoch: 081/100 | Batch 120/192 | Cost: 0.0403\n", "Epoch: 081/100 Train Acc.: 98.14% | Validation Acc.: 86.40%\n", "Time elapsed: 62.52 min\n", "Epoch: 082/100 | Batch 000/192 | Cost: 0.1032\n", "Epoch: 082/100 | Batch 120/192 | Cost: 0.0374\n", "Epoch: 082/100 Train Acc.: 98.17% | Validation Acc.: 86.00%\n", "Time elapsed: 63.28 min\n", "Epoch: 083/100 | Batch 000/192 | Cost: 0.0847\n", "Epoch: 083/100 | Batch 120/192 | Cost: 0.0557\n", "Epoch: 083/100 Train Acc.: 98.59% | Validation Acc.: 86.30%\n", "Time elapsed: 64.04 min\n", "Epoch: 
084/100 | Batch 000/192 | Cost: 0.0786\n", "Epoch: 084/100 | Batch 120/192 | Cost: 0.0694\n", "Epoch: 084/100 Train Acc.: 98.49% | Validation Acc.: 83.90%\n", "Time elapsed: 64.81 min\n", "Epoch: 085/100 | Batch 000/192 | Cost: 0.0483\n", "Epoch: 085/100 | Batch 120/192 | Cost: 0.0588\n", "Epoch: 085/100 Train Acc.: 98.14% | Validation Acc.: 85.20%\n", "Time elapsed: 65.58 min\n", "Epoch: 086/100 | Batch 000/192 | Cost: 0.0279\n", "Epoch: 086/100 | Batch 120/192 | Cost: 0.0710\n", "Epoch: 086/100 Train Acc.: 98.48% | Validation Acc.: 86.60%\n", "Time elapsed: 66.35 min\n", "Epoch: 087/100 | Batch 000/192 | Cost: 0.0264\n", "Epoch: 087/100 | Batch 120/192 | Cost: 0.0266\n", "Epoch: 087/100 Train Acc.: 98.54% | Validation Acc.: 85.20%\n", "Time elapsed: 67.11 min\n", "Epoch: 088/100 | Batch 000/192 | Cost: 0.0273\n", "Epoch: 088/100 | Batch 120/192 | Cost: 0.0402\n", "Epoch: 088/100 Train Acc.: 98.46% | Validation Acc.: 85.50%\n", "Time elapsed: 67.89 min\n", "Epoch: 089/100 | Batch 000/192 | Cost: 0.0601\n", "Epoch: 089/100 | Batch 120/192 | Cost: 0.0424\n", "Epoch: 089/100 Train Acc.: 98.42% | Validation Acc.: 86.00%\n", "Time elapsed: 68.67 min\n", "Epoch: 090/100 | Batch 000/192 | Cost: 0.0147\n", "Epoch: 090/100 | Batch 120/192 | Cost: 0.0417\n", "Epoch: 090/100 Train Acc.: 98.60% | Validation Acc.: 85.50%\n", "Time elapsed: 69.44 min\n", "Epoch: 091/100 | Batch 000/192 | Cost: 0.0825\n", "Epoch: 091/100 | Batch 120/192 | Cost: 0.1113\n", "Epoch: 091/100 Train Acc.: 98.61% | Validation Acc.: 86.00%\n", "Time elapsed: 70.22 min\n", "Epoch: 092/100 | Batch 000/192 | Cost: 0.0482\n", "Epoch: 092/100 | Batch 120/192 | Cost: 0.0664\n", "Epoch: 092/100 Train Acc.: 98.61% | Validation Acc.: 85.30%\n", "Time elapsed: 70.98 min\n", "Epoch: 093/100 | Batch 000/192 | Cost: 0.0298\n", "Epoch: 093/100 | Batch 120/192 | Cost: 0.0673\n", "Epoch: 093/100 Train Acc.: 98.62% | Validation Acc.: 86.60%\n", "Time elapsed: 71.74 min\n", "Epoch: 094/100 | Batch 000/192 | Cost: 0.0173\n", "Epoch: 094/100 | Batch 120/192 | Cost: 0.0699\n", "Epoch: 094/100 Train Acc.: 98.55% | Validation Acc.: 85.40%\n", "Time elapsed: 72.51 min\n", "Epoch: 095/100 | Batch 000/192 | Cost: 0.0298\n", "Epoch: 095/100 | Batch 120/192 | Cost: 0.0382\n", "Epoch: 095/100 Train Acc.: 98.61% | Validation Acc.: 87.00%\n", "Time elapsed: 73.27 min\n", "Epoch: 096/100 | Batch 000/192 | Cost: 0.0715\n", "Epoch: 096/100 | Batch 120/192 | Cost: 0.0298\n", "Epoch: 096/100 Train Acc.: 98.83% | Validation Acc.: 85.50%\n", "Time elapsed: 74.05 min\n", "Epoch: 097/100 | Batch 000/192 | Cost: 0.0645\n", "Epoch: 097/100 | Batch 120/192 | Cost: 0.0374\n", "Epoch: 097/100 Train Acc.: 98.66% | Validation Acc.: 86.30%\n", "Time elapsed: 74.81 min\n", "Epoch: 098/100 | Batch 000/192 | Cost: 0.0257\n", "Epoch: 098/100 | Batch 120/192 | Cost: 0.0492\n", "Epoch: 098/100 Train Acc.: 98.62% | Validation Acc.: 87.20%\n", "Time elapsed: 75.59 min\n", "Epoch: 099/100 | Batch 000/192 | Cost: 0.0785\n", "Epoch: 099/100 | Batch 120/192 | Cost: 0.0587\n", "Epoch: 099/100 Train Acc.: 98.55% | Validation Acc.: 85.80%\n", "Time elapsed: 76.37 min\n", "Epoch: 100/100 | Batch 000/192 | Cost: 0.0470\n", "Epoch: 100/100 | Batch 120/192 | Cost: 0.0452\n", "Epoch: 100/100 Train Acc.: 98.75% | Validation Acc.: 86.20%\n", "Time elapsed: 77.12 min\n", "Total Training Time: 77.12 min\n" ] } ], "source": [ "start_time = time.time()\n", "for epoch in range(NUM_EPOCHS):\n", " \n", " model.train()\n", " \n", " for batch_idx, (features, targets) in enumerate(train_loader):\n", " \n", 
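"        # one parameter update per minibatch: move the batch to DEVICE, compute the\n", "        # cross-entropy cost from the logits, backpropagate, and take an Adam step\n",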
" ### PREPARE MINIBATCH\n", " features = features.to(DEVICE)\n", " targets = targets.to(DEVICE)\n", " \n", " ### FORWARD AND BACK PROP\n", " logits, probas = model(features)\n", " cost = F.cross_entropy(logits, targets)\n", " optimizer.zero_grad()\n", " \n", " cost.backward()\n", " \n", " ### UPDATE MODEL PARAMETERS\n", " optimizer.step()\n", " \n", " ### LOGGING\n", " if not batch_idx % 120:\n", " print (f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d} | '\n", " f'Batch {batch_idx:03d}/{len(train_loader):03d} |' \n", " f' Cost: {cost:.4f}')\n", "\n", " # no need to build the computation graph for backprop when computing accuracy\n", " with torch.set_grad_enabled(False):\n", " train_acc = compute_accuracy(model, train_loader, device=DEVICE)\n", " valid_acc = compute_accuracy(model, valid_loader, device=DEVICE)\n", " print(f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d} Train Acc.: {train_acc:.2f}%'\n", " f' | Validation Acc.: {valid_acc:.2f}%')\n", " \n", " elapsed = (time.time() - start_time)/60\n", " print(f'Time elapsed: {elapsed:.2f} min')\n", " \n", "elapsed = (time.time() - start_time)/60\n", "print(f'Total Training Time: {elapsed:.2f} min')" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "numpy 1.17.4\n", "torchvision 0.4.1a0+d94043a\n", "matplotlib 3.1.0\n", "torch 1.3.0\n", "PIL.Image 6.2.1\n", "pandas 0.24.2\n", "\n" ] } ], "source": [ "%watermark -iv" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "default_view": {}, "name": "convnet-vgg16.ipynb", "provenance": [], "version": "0.3.2", "views": {} }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" }, "toc": { "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": true, "toc_position": { "height": "calc(100% - 180px)", "left": "10px", "top": "150px", "width": "371px" }, "toc_section_display": true, "toc_window_display": true } }, "nbformat": 4, "nbformat_minor": 4 }