{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "UEBilEjLj5wY"
   },
   "source": [
    "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n",
    "- Author: Sebastian Raschka\n",
    "- GitHub Repository: https://github.com/rasbt/deeplearning-models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     },
     "base_uri": "https://localhost:8080/",
     "height": 119
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 536,
     "status": "ok",
     "timestamp": 1524974472601,
     "user": {
      "displayName": "Sebastian Raschka",
      "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg",
      "userId": "118404394130788869227"
     },
     "user_tz": 240
    },
    "id": "GOzuY8Yvj5wb",
    "outputId": "c19362ce-f87a-4cc2-84cc-8d7b4b9e6007"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sebastian Raschka \n",
      "\n",
      "CPython 3.6.8\n",
      "IPython 7.2.0\n",
      "\n",
      "torch 1.0.1.post2\n"
     ]
    }
   ],
   "source": [
    "%load_ext watermark\n",
    "%watermark -a 'Sebastian Raschka' -v -p torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "MEu9MiOxj5wk"
   },
   "source": [
    "- Runs on CPU (not recommended here) or GPU (if available)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "rH4XmErYj5wm"
   },
   "source": [
    "# Model Zoo -- Convolutional Neural Network (VGG19 Architecture)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Implementation of the VGG-19 architecture on Cifar10.  \n",
    "\n",
    "\n",
    "Reference for VGG-19:\n",
    "    \n",
    "- Simonyan, K., & Zisserman, A. (2014). Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556.\n",
    "\n",
    "\n",
    "The following table (taken from Simonyan & Zisserman referenced above) summarizes the VGG19 architecture:\n",
    "\n",
    "![](../images/vgg19/vgg19-arch-table.png)"
   ]
  },
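  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a reading aid for the table above, the sketch below restates the VGG-19 layer sequence (configuration E in the paper) in a compact, purely illustrative list notation: integers are the output channels of a 3x3 convolution, and 'M' marks a 2x2 max-pooling layer. The model later in this notebook spells out each of these blocks explicitly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch only -- the model below builds these blocks explicitly.\n",
    "# Integers are 3x3-conv output channels; 'M' marks a 2x2 max-pool.\n",
    "vgg19_cfg = [64, 64, 'M',\n",
    "             128, 128, 'M',\n",
    "             256, 256, 256, 256, 'M',\n",
    "             512, 512, 512, 512, 'M',\n",
    "             512, 512, 512, 512, 'M']\n",
    "\n",
    "num_conv_layers = sum(isinstance(v, int) for v in vgg19_cfg)\n",
    "print('Conv layers:', num_conv_layers)  # 16 conv + 3 fully connected = 19 weight layers"
   ]
  },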
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "MkoGLH_Tj5wn"
   },
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     }
    },
    "colab_type": "code",
    "id": "ORj09gnrj5wp"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import time\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torchvision import datasets\n",
    "from torchvision import transforms\n",
    "from torch.utils.data import DataLoader"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "PvgJ_0i7j5wt"
   },
   "source": [
    "## Settings and Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     },
     "base_uri": "https://localhost:8080/",
     "height": 85
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 23936,
     "status": "ok",
     "timestamp": 1524974497505,
     "user": {
      "displayName": "Sebastian Raschka",
      "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg",
      "userId": "118404394130788869227"
     },
     "user_tz": 240
    },
    "id": "NnT0sZIwj5wu",
    "outputId": "55aed925-d17e-4c6a-8c71-0d9b3bde5637"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Device: cuda:0\n",
      "Files already downloaded and verified\n",
      "Image batch dimensions: torch.Size([128, 3, 32, 32])\n",
      "Image label dimensions: torch.Size([128])\n"
     ]
    }
   ],
   "source": [
    "##########################\n",
    "### SETTINGS\n",
    "##########################\n",
    "\n",
    "# Device\n",
    "DEVICE = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "print('Device:', DEVICE)\n",
    "\n",
    "# Hyperparameters\n",
    "random_seed = 1\n",
    "learning_rate = 0.001\n",
    "num_epochs = 20\n",
    "batch_size = 128\n",
    "\n",
    "# Architecture\n",
    "num_features = 784\n",
    "num_classes = 10\n",
    "\n",
    "\n",
    "##########################\n",
    "### MNIST DATASET\n",
    "##########################\n",
    "\n",
    "# Note transforms.ToTensor() scales input images\n",
    "# to 0-1 range\n",
    "train_dataset = datasets.CIFAR10(root='data', \n",
    "                                 train=True, \n",
    "                                 transform=transforms.ToTensor(),\n",
    "                                 download=True)\n",
    "\n",
    "test_dataset = datasets.CIFAR10(root='data', \n",
    "                                train=False, \n",
    "                                transform=transforms.ToTensor())\n",
    "\n",
    "\n",
    "train_loader = DataLoader(dataset=train_dataset, \n",
    "                          batch_size=batch_size, \n",
    "                          shuffle=True)\n",
    "\n",
    "test_loader = DataLoader(dataset=test_dataset, \n",
    "                         batch_size=batch_size, \n",
    "                         shuffle=False)\n",
    "\n",
    "# Checking the dataset\n",
    "for images, labels in train_loader:  \n",
    "    print('Image batch dimensions:', images.shape)\n",
    "    print('Image label dimensions:', labels.shape)\n",
    "    break"
   ]
  },
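  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check of the note above, the following minimal sketch reuses the `images` batch from the previous cell to confirm that `transforms.ToTensor()` mapped the uint8 pixel values into the 0-1 range."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal sanity check (assumes the `images` batch from the cell above):\n",
    "# ToTensor() converts uint8 pixels in [0, 255] to float32 values in [0, 1].\n",
    "print('Pixel value range: %.3f to %.3f' % (images.min().item(), images.max().item()))"
   ]
  },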
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "I6hghKPxj5w0"
   },
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     }
    },
    "colab_type": "code",
    "id": "_lza9t_uj5w1"
   },
   "outputs": [],
   "source": [
    "##########################\n",
    "### MODEL\n",
    "##########################\n",
    "\n",
    "\n",
    "class VGG16(torch.nn.Module):\n",
    "\n",
    "    def __init__(self, num_features, num_classes):\n",
    "        super(VGG16, self).__init__()\n",
    "        \n",
    "        # calculate same padding:\n",
    "        # (w - k + 2*p)/s + 1 = o\n",
    "        # => p = (s(o-1) - w + k)/2\n",
    "        \n",
    "        self.block_1 = nn.Sequential(\n",
    "                nn.Conv2d(in_channels=3,\n",
    "                          out_channels=64,\n",
    "                          kernel_size=(3, 3),\n",
    "                          stride=(1, 1),\n",
    "                          # (1(32-1)- 32 + 3)/2 = 1\n",
    "                          padding=1), \n",
    "                nn.ReLU(),\n",
    "                nn.Conv2d(in_channels=64,\n",
    "                          out_channels=64,\n",
    "                          kernel_size=(3, 3),\n",
    "                          stride=(1, 1),\n",
    "                          padding=1),\n",
    "                nn.ReLU(),\n",
    "                nn.MaxPool2d(kernel_size=(2, 2),\n",
    "                             stride=(2, 2))\n",
    "        )\n",
    "        \n",
    "        self.block_2 = nn.Sequential(\n",
    "                nn.Conv2d(in_channels=64,\n",
    "                          out_channels=128,\n",
    "                          kernel_size=(3, 3),\n",
    "                          stride=(1, 1),\n",
    "                          padding=1),\n",
    "                nn.ReLU(),\n",
    "                nn.Conv2d(in_channels=128,\n",
    "                          out_channels=128,\n",
    "                          kernel_size=(3, 3),\n",
    "                          stride=(1, 1),\n",
    "                          padding=1),\n",
    "                nn.ReLU(),\n",
    "                nn.MaxPool2d(kernel_size=(2, 2),\n",
    "                             stride=(2, 2))\n",
    "        )\n",
    "        \n",
    "        self.block_3 = nn.Sequential(        \n",
    "                nn.Conv2d(in_channels=128,\n",
    "                          out_channels=256,\n",
    "                          kernel_size=(3, 3),\n",
    "                          stride=(1, 1),\n",
    "                          padding=1),\n",
    "                nn.ReLU(),\n",
    "                nn.Conv2d(in_channels=256,\n",
    "                          out_channels=256,\n",
    "                          kernel_size=(3, 3),\n",
    "                          stride=(1, 1),\n",
    "                          padding=1),\n",
    "                nn.ReLU(),        \n",
    "                nn.Conv2d(in_channels=256,\n",
    "                          out_channels=256,\n",
    "                          kernel_size=(3, 3),\n",
    "                          stride=(1, 1),\n",
    "                          padding=1),\n",
    "                nn.ReLU(),\n",
    "                nn.Conv2d(in_channels=256,\n",
    "                          out_channels=256,\n",
    "                          kernel_size=(3, 3),\n",
    "                          stride=(1, 1),\n",
    "                          padding=1),\n",
    "                nn.ReLU(),\n",
    "                nn.MaxPool2d(kernel_size=(2, 2),\n",
    "                             stride=(2, 2))\n",
    "        )\n",
    "        \n",
    "          \n",
    "        self.block_4 = nn.Sequential(   \n",
    "                nn.Conv2d(in_channels=256,\n",
    "                          out_channels=512,\n",
    "                          kernel_size=(3, 3),\n",
    "                          stride=(1, 1),\n",
    "                          padding=1),\n",
    "                nn.ReLU(),        \n",
    "                nn.Conv2d(in_channels=512,\n",
    "                          out_channels=512,\n",
    "                          kernel_size=(3, 3),\n",
    "                          stride=(1, 1),\n",
    "                          padding=1),\n",
    "                nn.ReLU(),        \n",
    "                nn.Conv2d(in_channels=512,\n",
    "                          out_channels=512,\n",
    "                          kernel_size=(3, 3),\n",
    "                          stride=(1, 1),\n",
    "                          padding=1),\n",
    "                nn.ReLU(),\n",
    "                nn.Conv2d(in_channels=512,\n",
    "                          out_channels=512,\n",
    "                          kernel_size=(3, 3),\n",
    "                          stride=(1, 1),\n",
    "                          padding=1),\n",
    "                nn.ReLU(),   \n",
    "                nn.MaxPool2d(kernel_size=(2, 2),\n",
    "                             stride=(2, 2))\n",
    "        )\n",
    "        \n",
    "        self.block_5 = nn.Sequential(\n",
    "                nn.Conv2d(in_channels=512,\n",
    "                          out_channels=512,\n",
    "                          kernel_size=(3, 3),\n",
    "                          stride=(1, 1),\n",
    "                          padding=1),\n",
    "                nn.ReLU(),            \n",
    "                nn.Conv2d(in_channels=512,\n",
    "                          out_channels=512,\n",
    "                          kernel_size=(3, 3),\n",
    "                          stride=(1, 1),\n",
    "                          padding=1),\n",
    "                nn.ReLU(),            \n",
    "                nn.Conv2d(in_channels=512,\n",
    "                          out_channels=512,\n",
    "                          kernel_size=(3, 3),\n",
    "                          stride=(1, 1),\n",
    "                          padding=1),\n",
    "                nn.ReLU(),\n",
    "                nn.Conv2d(in_channels=512,\n",
    "                          out_channels=512,\n",
    "                          kernel_size=(3, 3),\n",
    "                          stride=(1, 1),\n",
    "                          padding=1),\n",
    "                nn.ReLU(),   \n",
    "                nn.MaxPool2d(kernel_size=(2, 2),\n",
    "                             stride=(2, 2))             \n",
    "        )\n",
    "        \n",
    "        self.classifier = nn.Sequential(\n",
    "                nn.Linear(512, 4096),\n",
    "                nn.ReLU(True),\n",
    "                nn.Linear(4096, 4096),\n",
    "                nn.ReLU(True),\n",
    "                nn.Linear(4096, num_classes)\n",
    "        )\n",
    "            \n",
    "        \n",
    "        for m in self.modules():\n",
    "            if isinstance(m, torch.nn.Conv2d):\n",
    "                #n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n",
    "                #m.weight.data.normal_(0, np.sqrt(2. / n))\n",
    "                m.weight.detach().normal_(0, 0.05)\n",
    "                if m.bias is not None:\n",
    "                    m.bias.detach().zero_()\n",
    "            elif isinstance(m, torch.nn.Linear):\n",
    "                m.weight.detach().normal_(0, 0.05)\n",
    "                m.bias.detach().detach().zero_()\n",
    "        \n",
    "        \n",
    "    def forward(self, x):\n",
    "\n",
    "        x = self.block_1(x)\n",
    "        x = self.block_2(x)\n",
    "        x = self.block_3(x)\n",
    "        x = self.block_4(x)\n",
    "        x = self.block_5(x)\n",
    "        logits = self.classifier(x.view(-1, 512))\n",
    "        probas = F.softmax(logits, dim=1)\n",
    "\n",
    "        return logits, probas\n",
    "\n",
    "    \n",
    "torch.manual_seed(random_seed)\n",
    "model = VGG16(num_features=num_features,\n",
    "              num_classes=num_classes)\n",
    "\n",
    "model = model.to(DEVICE)\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)  "
   ]
  },
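  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before training, a quick shape check helps verify the wiring between the conv blocks and the classifier: five 2x2 max-pools shrink the 32x32 input to 1x1, so the flattened feature vector has 512 entries, matching the first `nn.Linear(512, 4096)` layer. The following is a minimal sketch using a random dummy batch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal shape-check sketch with a random dummy batch (2 images, 3x32x32).\n",
    "with torch.no_grad():\n",
    "    dummy = torch.randn(2, 3, 32, 32).to(DEVICE)\n",
    "    dummy_logits, dummy_probas = model(dummy)\n",
    "print('Logits shape:', dummy_logits.shape)  # expected: torch.Size([2, 10])\n",
    "print('Row sums of softmax:', dummy_probas.sum(dim=1))  # each row should sum to 1"
   ]
  },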
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "RAodboScj5w6"
   },
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     },
     "base_uri": "https://localhost:8080/",
     "height": 1547
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 2384585,
     "status": "ok",
     "timestamp": 1524976888520,
     "user": {
      "displayName": "Sebastian Raschka",
      "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg",
      "userId": "118404394130788869227"
     },
     "user_tz": 240
    },
    "id": "Dzh3ROmRj5w7",
    "outputId": "5f8fd8c9-b076-403a-b0b7-fd2d498b48d7"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 001/020 | Batch 0000/0391 | Cost: 1061.4152\n",
      "Epoch: 001/020 | Batch 0050/0391 | Cost: 2.3018\n",
      "Epoch: 001/020 | Batch 0100/0391 | Cost: 2.0600\n",
      "Epoch: 001/020 | Batch 0150/0391 | Cost: 1.9973\n",
      "Epoch: 001/020 | Batch 0200/0391 | Cost: 1.8176\n",
      "Epoch: 001/020 | Batch 0250/0391 | Cost: 1.8368\n",
      "Epoch: 001/020 | Batch 0300/0391 | Cost: 1.7213\n",
      "Epoch: 001/020 | Batch 0350/0391 | Cost: 1.7154\n",
      "Epoch: 001/020 | Train: 35.478% | Loss: 1.685\n",
      "Time elapsed: 1.02 min\n",
      "Epoch: 002/020 | Batch 0000/0391 | Cost: 1.7648\n",
      "Epoch: 002/020 | Batch 0050/0391 | Cost: 1.7050\n",
      "Epoch: 002/020 | Batch 0100/0391 | Cost: 1.5464\n",
      "Epoch: 002/020 | Batch 0150/0391 | Cost: 1.6054\n",
      "Epoch: 002/020 | Batch 0200/0391 | Cost: 1.4430\n",
      "Epoch: 002/020 | Batch 0250/0391 | Cost: 1.4253\n",
      "Epoch: 002/020 | Batch 0300/0391 | Cost: 1.5701\n",
      "Epoch: 002/020 | Batch 0350/0391 | Cost: 1.4163\n",
      "Epoch: 002/020 | Train: 44.042% | Loss: 1.531\n",
      "Time elapsed: 2.07 min\n",
      "Epoch: 003/020 | Batch 0000/0391 | Cost: 1.5172\n",
      "Epoch: 003/020 | Batch 0050/0391 | Cost: 1.1992\n",
      "Epoch: 003/020 | Batch 0100/0391 | Cost: 1.2846\n",
      "Epoch: 003/020 | Batch 0150/0391 | Cost: 1.4088\n",
      "Epoch: 003/020 | Batch 0200/0391 | Cost: 1.4853\n",
      "Epoch: 003/020 | Batch 0250/0391 | Cost: 1.3923\n",
      "Epoch: 003/020 | Batch 0300/0391 | Cost: 1.3268\n",
      "Epoch: 003/020 | Batch 0350/0391 | Cost: 1.3162\n",
      "Epoch: 003/020 | Train: 55.596% | Loss: 1.223\n",
      "Time elapsed: 3.10 min\n",
      "Epoch: 004/020 | Batch 0000/0391 | Cost: 1.2210\n",
      "Epoch: 004/020 | Batch 0050/0391 | Cost: 1.2594\n",
      "Epoch: 004/020 | Batch 0100/0391 | Cost: 1.2881\n",
      "Epoch: 004/020 | Batch 0150/0391 | Cost: 1.0182\n",
      "Epoch: 004/020 | Batch 0200/0391 | Cost: 1.1256\n",
      "Epoch: 004/020 | Batch 0250/0391 | Cost: 1.1048\n",
      "Epoch: 004/020 | Batch 0300/0391 | Cost: 1.1812\n",
      "Epoch: 004/020 | Batch 0350/0391 | Cost: 1.1685\n",
      "Epoch: 004/020 | Train: 57.594% | Loss: 1.178\n",
      "Time elapsed: 4.13 min\n",
      "Epoch: 005/020 | Batch 0000/0391 | Cost: 1.1298\n",
      "Epoch: 005/020 | Batch 0050/0391 | Cost: 0.9705\n",
      "Epoch: 005/020 | Batch 0100/0391 | Cost: 0.9255\n",
      "Epoch: 005/020 | Batch 0150/0391 | Cost: 1.3610\n",
      "Epoch: 005/020 | Batch 0200/0391 | Cost: 0.9720\n",
      "Epoch: 005/020 | Batch 0250/0391 | Cost: 1.0088\n",
      "Epoch: 005/020 | Batch 0300/0391 | Cost: 0.9998\n",
      "Epoch: 005/020 | Batch 0350/0391 | Cost: 1.1961\n",
      "Epoch: 005/020 | Train: 63.570% | Loss: 1.003\n",
      "Time elapsed: 5.17 min\n",
      "Epoch: 006/020 | Batch 0000/0391 | Cost: 0.8837\n",
      "Epoch: 006/020 | Batch 0050/0391 | Cost: 0.9184\n",
      "Epoch: 006/020 | Batch 0100/0391 | Cost: 0.8568\n",
      "Epoch: 006/020 | Batch 0150/0391 | Cost: 1.0788\n",
      "Epoch: 006/020 | Batch 0200/0391 | Cost: 1.0365\n",
      "Epoch: 006/020 | Batch 0250/0391 | Cost: 0.8714\n",
      "Epoch: 006/020 | Batch 0300/0391 | Cost: 1.0370\n",
      "Epoch: 006/020 | Batch 0350/0391 | Cost: 1.0536\n",
      "Epoch: 006/020 | Train: 68.390% | Loss: 0.880\n",
      "Time elapsed: 6.20 min\n",
      "Epoch: 007/020 | Batch 0000/0391 | Cost: 1.0297\n",
      "Epoch: 007/020 | Batch 0050/0391 | Cost: 0.8801\n",
      "Epoch: 007/020 | Batch 0100/0391 | Cost: 0.9652\n",
      "Epoch: 007/020 | Batch 0150/0391 | Cost: 1.1417\n",
      "Epoch: 007/020 | Batch 0200/0391 | Cost: 0.8851\n",
      "Epoch: 007/020 | Batch 0250/0391 | Cost: 0.9499\n",
      "Epoch: 007/020 | Batch 0300/0391 | Cost: 0.9416\n",
      "Epoch: 007/020 | Batch 0350/0391 | Cost: 0.9220\n",
      "Epoch: 007/020 | Train: 68.740% | Loss: 0.872\n",
      "Time elapsed: 7.24 min\n",
      "Epoch: 008/020 | Batch 0000/0391 | Cost: 1.0054\n",
      "Epoch: 008/020 | Batch 0050/0391 | Cost: 0.8184\n",
      "Epoch: 008/020 | Batch 0100/0391 | Cost: 0.8955\n",
      "Epoch: 008/020 | Batch 0150/0391 | Cost: 0.9319\n",
      "Epoch: 008/020 | Batch 0200/0391 | Cost: 1.0566\n",
      "Epoch: 008/020 | Batch 0250/0391 | Cost: 1.0591\n",
      "Epoch: 008/020 | Batch 0300/0391 | Cost: 0.7914\n",
      "Epoch: 008/020 | Batch 0350/0391 | Cost: 0.9090\n",
      "Epoch: 008/020 | Train: 72.846% | Loss: 0.770\n",
      "Time elapsed: 8.27 min\n",
      "Epoch: 009/020 | Batch 0000/0391 | Cost: 0.6672\n",
      "Epoch: 009/020 | Batch 0050/0391 | Cost: 0.7192\n",
      "Epoch: 009/020 | Batch 0100/0391 | Cost: 0.8586\n",
      "Epoch: 009/020 | Batch 0150/0391 | Cost: 0.7310\n",
      "Epoch: 009/020 | Batch 0200/0391 | Cost: 0.8406\n",
      "Epoch: 009/020 | Batch 0250/0391 | Cost: 0.7620\n",
      "Epoch: 009/020 | Batch 0300/0391 | Cost: 0.6692\n",
      "Epoch: 009/020 | Batch 0350/0391 | Cost: 0.6407\n",
      "Epoch: 009/020 | Train: 73.702% | Loss: 0.748\n",
      "Time elapsed: 9.30 min\n",
      "Epoch: 010/020 | Batch 0000/0391 | Cost: 0.6539\n",
      "Epoch: 010/020 | Batch 0050/0391 | Cost: 1.0382\n",
      "Epoch: 010/020 | Batch 0100/0391 | Cost: 0.5921\n",
      "Epoch: 010/020 | Batch 0150/0391 | Cost: 0.4933\n",
      "Epoch: 010/020 | Batch 0200/0391 | Cost: 0.7485\n",
      "Epoch: 010/020 | Batch 0250/0391 | Cost: 0.6779\n",
      "Epoch: 010/020 | Batch 0300/0391 | Cost: 0.6787\n",
      "Epoch: 010/020 | Batch 0350/0391 | Cost: 0.6977\n",
      "Epoch: 010/020 | Train: 75.708% | Loss: 0.703\n",
      "Time elapsed: 10.34 min\n",
      "Epoch: 011/020 | Batch 0000/0391 | Cost: 0.6866\n",
      "Epoch: 011/020 | Batch 0050/0391 | Cost: 0.7203\n",
      "Epoch: 011/020 | Batch 0100/0391 | Cost: 0.5730\n",
      "Epoch: 011/020 | Batch 0150/0391 | Cost: 0.5762\n",
      "Epoch: 011/020 | Batch 0200/0391 | Cost: 0.6571\n",
      "Epoch: 011/020 | Batch 0250/0391 | Cost: 0.7582\n",
      "Epoch: 011/020 | Batch 0300/0391 | Cost: 0.7366\n",
      "Epoch: 011/020 | Batch 0350/0391 | Cost: 0.6810\n",
      "Epoch: 011/020 | Train: 79.044% | Loss: 0.606\n",
      "Time elapsed: 11.37 min\n",
      "Epoch: 012/020 | Batch 0000/0391 | Cost: 0.5665\n",
      "Epoch: 012/020 | Batch 0050/0391 | Cost: 0.7081\n",
      "Epoch: 012/020 | Batch 0100/0391 | Cost: 0.6823\n",
      "Epoch: 012/020 | Batch 0150/0391 | Cost: 0.8297\n",
      "Epoch: 012/020 | Batch 0200/0391 | Cost: 0.6470\n",
      "Epoch: 012/020 | Batch 0250/0391 | Cost: 0.7293\n",
      "Epoch: 012/020 | Batch 0300/0391 | Cost: 0.9127\n",
      "Epoch: 012/020 | Batch 0350/0391 | Cost: 0.8419\n",
      "Epoch: 012/020 | Train: 79.474% | Loss: 0.585\n",
      "Time elapsed: 12.40 min\n",
      "Epoch: 013/020 | Batch 0000/0391 | Cost: 0.4087\n",
      "Epoch: 013/020 | Batch 0050/0391 | Cost: 0.4224\n",
      "Epoch: 013/020 | Batch 0100/0391 | Cost: 0.4336\n",
      "Epoch: 013/020 | Batch 0150/0391 | Cost: 0.6586\n",
      "Epoch: 013/020 | Batch 0200/0391 | Cost: 0.7107\n",
      "Epoch: 013/020 | Batch 0250/0391 | Cost: 0.7359\n",
      "Epoch: 013/020 | Batch 0300/0391 | Cost: 0.4860\n",
      "Epoch: 013/020 | Batch 0350/0391 | Cost: 0.7271\n",
      "Epoch: 013/020 | Train: 80.746% | Loss: 0.549\n",
      "Time elapsed: 13.44 min\n",
      "Epoch: 014/020 | Batch 0000/0391 | Cost: 0.5500\n",
      "Epoch: 014/020 | Batch 0050/0391 | Cost: 0.5108\n",
      "Epoch: 014/020 | Batch 0100/0391 | Cost: 0.5186\n",
      "Epoch: 014/020 | Batch 0150/0391 | Cost: 0.4737\n",
      "Epoch: 014/020 | Batch 0200/0391 | Cost: 0.7015\n",
      "Epoch: 014/020 | Batch 0250/0391 | Cost: 0.6069\n",
      "Epoch: 014/020 | Batch 0300/0391 | Cost: 0.7080\n",
      "Epoch: 014/020 | Batch 0350/0391 | Cost: 0.6460\n",
      "Epoch: 014/020 | Train: 81.596% | Loss: 0.553\n",
      "Time elapsed: 14.47 min\n",
      "Epoch: 015/020 | Batch 0000/0391 | Cost: 0.5398\n",
      "Epoch: 015/020 | Batch 0050/0391 | Cost: 0.5269\n",
      "Epoch: 015/020 | Batch 0100/0391 | Cost: 0.5048\n",
      "Epoch: 015/020 | Batch 0150/0391 | Cost: 0.5873\n",
      "Epoch: 015/020 | Batch 0200/0391 | Cost: 0.5320\n",
      "Epoch: 015/020 | Batch 0250/0391 | Cost: 0.4743\n",
      "Epoch: 015/020 | Batch 0300/0391 | Cost: 0.6124\n",
      "Epoch: 015/020 | Batch 0350/0391 | Cost: 0.7204\n",
      "Epoch: 015/020 | Train: 85.276% | Loss: 0.439\n",
      "Time elapsed: 15.51 min\n",
      "Epoch: 016/020 | Batch 0000/0391 | Cost: 0.4387\n",
      "Epoch: 016/020 | Batch 0050/0391 | Cost: 0.3777\n",
      "Epoch: 016/020 | Batch 0100/0391 | Cost: 0.3430\n",
      "Epoch: 016/020 | Batch 0150/0391 | Cost: 0.5901\n",
      "Epoch: 016/020 | Batch 0200/0391 | Cost: 0.6303\n",
      "Epoch: 016/020 | Batch 0250/0391 | Cost: 0.4983\n",
      "Epoch: 016/020 | Batch 0300/0391 | Cost: 0.6507\n",
      "Epoch: 016/020 | Batch 0350/0391 | Cost: 0.4663\n",
      "Epoch: 016/020 | Train: 86.440% | Loss: 0.406\n",
      "Time elapsed: 16.55 min\n",
      "Epoch: 017/020 | Batch 0000/0391 | Cost: 0.4675\n",
      "Epoch: 017/020 | Batch 0050/0391 | Cost: 0.6440\n",
      "Epoch: 017/020 | Batch 0100/0391 | Cost: 0.3536\n",
      "Epoch: 017/020 | Batch 0150/0391 | Cost: 0.5421\n",
      "Epoch: 017/020 | Batch 0200/0391 | Cost: 0.4504\n",
      "Epoch: 017/020 | Batch 0250/0391 | Cost: 0.4169\n",
      "Epoch: 017/020 | Batch 0300/0391 | Cost: 0.4617\n",
      "Epoch: 017/020 | Batch 0350/0391 | Cost: 0.4092\n",
      "Epoch: 017/020 | Train: 84.636% | Loss: 0.459\n",
      "Time elapsed: 17.59 min\n",
      "Epoch: 018/020 | Batch 0000/0391 | Cost: 0.4267\n",
      "Epoch: 018/020 | Batch 0050/0391 | Cost: 0.6478\n",
      "Epoch: 018/020 | Batch 0100/0391 | Cost: 0.5806\n",
      "Epoch: 018/020 | Batch 0150/0391 | Cost: 0.5453\n",
      "Epoch: 018/020 | Batch 0200/0391 | Cost: 0.4984\n",
      "Epoch: 018/020 | Batch 0250/0391 | Cost: 0.2517\n",
      "Epoch: 018/020 | Batch 0300/0391 | Cost: 0.5219\n",
      "Epoch: 018/020 | Batch 0350/0391 | Cost: 0.5217\n",
      "Epoch: 018/020 | Train: 86.094% | Loss: 0.413\n",
      "Time elapsed: 18.63 min\n",
      "Epoch: 019/020 | Batch 0000/0391 | Cost: 0.3849\n",
      "Epoch: 019/020 | Batch 0050/0391 | Cost: 0.2890\n",
      "Epoch: 019/020 | Batch 0100/0391 | Cost: 0.5058\n",
      "Epoch: 019/020 | Batch 0150/0391 | Cost: 0.5718\n",
      "Epoch: 019/020 | Batch 0200/0391 | Cost: 0.4053\n",
      "Epoch: 019/020 | Batch 0250/0391 | Cost: 0.5241\n",
      "Epoch: 019/020 | Batch 0300/0391 | Cost: 0.7110\n",
      "Epoch: 019/020 | Batch 0350/0391 | Cost: 0.4572\n",
      "Epoch: 019/020 | Train: 87.586% | Loss: 0.365\n",
      "Time elapsed: 19.67 min\n",
      "Epoch: 020/020 | Batch 0000/0391 | Cost: 0.3576\n",
      "Epoch: 020/020 | Batch 0050/0391 | Cost: 0.3466\n",
      "Epoch: 020/020 | Batch 0100/0391 | Cost: 0.3427\n",
      "Epoch: 020/020 | Batch 0150/0391 | Cost: 0.3117\n",
      "Epoch: 020/020 | Batch 0200/0391 | Cost: 0.4912\n",
      "Epoch: 020/020 | Batch 0250/0391 | Cost: 0.4481\n",
      "Epoch: 020/020 | Batch 0300/0391 | Cost: 0.6303\n",
      "Epoch: 020/020 | Batch 0350/0391 | Cost: 0.4274\n",
      "Epoch: 020/020 | Train: 88.024% | Loss: 0.361\n",
      "Time elapsed: 20.71 min\n",
      "Total Training Time: 20.71 min\n"
     ]
    }
   ],
   "source": [
    "def compute_accuracy(model, data_loader):\n",
    "    model.eval()\n",
    "    correct_pred, num_examples = 0, 0\n",
    "    for i, (features, targets) in enumerate(data_loader):\n",
    "            \n",
    "        features = features.to(DEVICE)\n",
    "        targets = targets.to(DEVICE)\n",
    "\n",
    "        logits, probas = model(features)\n",
    "        _, predicted_labels = torch.max(probas, 1)\n",
    "        num_examples += targets.size(0)\n",
    "        correct_pred += (predicted_labels == targets).sum()\n",
    "    return correct_pred.float()/num_examples * 100\n",
    "\n",
    "\n",
    "def compute_epoch_loss(model, data_loader):\n",
    "    model.eval()\n",
    "    curr_loss, num_examples = 0., 0\n",
    "    with torch.no_grad():\n",
    "        for features, targets in data_loader:\n",
    "            features = features.to(DEVICE)\n",
    "            targets = targets.to(DEVICE)\n",
    "            logits, probas = model(features)\n",
    "            loss = F.cross_entropy(logits, targets, reduction='sum')\n",
    "            num_examples += targets.size(0)\n",
    "            curr_loss += loss\n",
    "\n",
    "        curr_loss = curr_loss / num_examples\n",
    "        return curr_loss\n",
    "    \n",
    "    \n",
    "\n",
    "start_time = time.time()\n",
    "for epoch in range(num_epochs):\n",
    "    \n",
    "    model.train()\n",
    "    for batch_idx, (features, targets) in enumerate(train_loader):\n",
    "        \n",
    "        features = features.to(DEVICE)\n",
    "        targets = targets.to(DEVICE)\n",
    "            \n",
    "        ### FORWARD AND BACK PROP\n",
    "        logits, probas = model(features)\n",
    "        cost = F.cross_entropy(logits, targets)\n",
    "        optimizer.zero_grad()\n",
    "        \n",
    "        cost.backward()\n",
    "        \n",
    "        ### UPDATE MODEL PARAMETERS\n",
    "        optimizer.step()\n",
    "        \n",
    "        ### LOGGING\n",
    "        if not batch_idx % 50:\n",
    "            print ('Epoch: %03d/%03d | Batch %04d/%04d | Cost: %.4f' \n",
    "                   %(epoch+1, num_epochs, batch_idx, \n",
    "                     len(train_loader), cost))\n",
    "\n",
    "    model.eval()\n",
    "    with torch.set_grad_enabled(False): # save memory during inference\n",
    "        print('Epoch: %03d/%03d | Train: %.3f%% | Loss: %.3f' % (\n",
    "              epoch+1, num_epochs, \n",
    "              compute_accuracy(model, train_loader),\n",
    "              compute_epoch_loss(model, train_loader)))\n",
    "\n",
    "\n",
    "    print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))\n",
    "    \n",
    "print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "paaeEQHQj5xC"
   },
   "source": [
    "## Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     },
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 6514,
     "status": "ok",
     "timestamp": 1524976895054,
     "user": {
      "displayName": "Sebastian Raschka",
      "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg",
      "userId": "118404394130788869227"
     },
     "user_tz": 240
    },
    "id": "gzQMWKq5j5xE",
    "outputId": "de7dc005-5eeb-4177-9f9f-d9b5d1358db9"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test accuracy: 74.56%\n"
     ]
    }
   ],
   "source": [
    "with torch.set_grad_enabled(False): # save memory during inference\n",
    "    print('Test accuracy: %.2f%%' % (compute_accuracy(model, test_loader)))"
   ]
  },
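  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Beyond the aggregate accuracy above, the trained model can be applied to individual batches. The following illustrative sketch (not part of the original run) predicts labels for one test batch and compares them with the ground truth."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative inference sketch on a single test batch.\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    features, targets = next(iter(test_loader))\n",
    "    features, targets = features.to(DEVICE), targets.to(DEVICE)\n",
    "    logits, probas = model(features)\n",
    "    predicted_labels = torch.argmax(probas, dim=1)\n",
    "print('First 10 predicted labels:', predicted_labels[:10].tolist())\n",
    "print('First 10 true labels:     ', targets[:10].tolist())"
   ]
  },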
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "numpy       1.15.4\n",
      "torch       1.0.1.post2\n",
      "\n"
     ]
    }
   ],
   "source": [
    "%watermark -iv"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "default_view": {},
   "name": "convnet-vgg16.ipynb",
   "provenance": [],
   "version": "0.3.2",
   "views": {}
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  },
  "toc": {
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": true,
   "toc_position": {
    "height": "calc(100% - 180px)",
    "left": "10px",
    "top": "150px",
    "width": "371px"
   },
   "toc_section_display": true,
   "toc_window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}