{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "toc": true
   },
   "source": [
    "<h1>Table of Contents<span class=\"tocSkip\"></span></h1>\n",
    "<div class=\"toc\"><ul class=\"toc-item\"><li><span><a href=\"#Model-Zoo----Convolutional-Neural-Network-(VGG16)\" data-toc-modified-id=\"Model-Zoo----Convolutional-Neural-Network-(VGG16)-1\"><span class=\"toc-item-num\">1&nbsp;&nbsp;</span>Model Zoo -- Convolutional Neural Network (VGG16)</a></span><ul class=\"toc-item\"><li><span><a href=\"#Imports\" data-toc-modified-id=\"Imports-1.1\"><span class=\"toc-item-num\">1.1&nbsp;&nbsp;</span>Imports</a></span></li><li><span><a href=\"#Settings-and-Dataset\" data-toc-modified-id=\"Settings-and-Dataset-1.2\"><span class=\"toc-item-num\">1.2&nbsp;&nbsp;</span>Settings and Dataset</a></span></li><li><span><a href=\"#Model\" data-toc-modified-id=\"Model-1.3\"><span class=\"toc-item-num\">1.3&nbsp;&nbsp;</span>Model</a></span></li><li><span><a href=\"#Training\" data-toc-modified-id=\"Training-1.4\"><span class=\"toc-item-num\">1.4&nbsp;&nbsp;</span>Training</a></span></li><li><span><a href=\"#Evaluation\" data-toc-modified-id=\"Evaluation-1.5\"><span class=\"toc-item-num\">1.5&nbsp;&nbsp;</span>Evaluation</a></span></li></ul></li></ul></div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "UEBilEjLj5wY"
   },
   "source": [
    "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n",
    "- Author: Sebastian Raschka\n",
    "- GitHub Repository: https://github.com/rasbt/deeplearning-models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     },
     "base_uri": "https://localhost:8080/",
     "height": 119
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 536,
     "status": "ok",
     "timestamp": 1524974472601,
     "user": {
      "displayName": "Sebastian Raschka",
      "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg",
      "userId": "118404394130788869227"
     },
     "user_tz": 240
    },
    "id": "GOzuY8Yvj5wb",
    "outputId": "c19362ce-f87a-4cc2-84cc-8d7b4b9e6007"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sebastian Raschka \n",
      "\n",
      "CPython 3.7.3\n",
      "IPython 7.6.1\n",
      "\n",
      "torch 1.3.0\n"
     ]
    }
   ],
   "source": [
    "%load_ext watermark\n",
    "%watermark -a 'Sebastian Raschka' -v -p torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "MEu9MiOxj5wk"
   },
   "source": [
    "- Runs on CPU (not recommended here) or GPU (if available)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "rH4XmErYj5wm"
   },
   "source": [
    "# Model Zoo -- Convolutional Neural Network (VGG16)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "MkoGLH_Tj5wn"
   },
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     }
    },
    "colab_type": "code",
    "id": "ORj09gnrj5wp"
   },
   "outputs": [],
   "source": [
    "import time\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torchvision import datasets\n",
    "from torchvision import transforms\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "\n",
    "# Ask cuDNN to use deterministic algorithms so that GPU runs are repeatable;\n",
    "# the flag is only touched when CUDA is actually available.\n",
    "if torch.cuda.is_available():\n",
    "    torch.backends.cudnn.deterministic = True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "PvgJ_0i7j5wt"
   },
   "source": [
    "## Settings and Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     },
     "base_uri": "https://localhost:8080/",
     "height": 85
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 23936,
     "status": "ok",
     "timestamp": 1524974497505,
     "user": {
      "displayName": "Sebastian Raschka",
      "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg",
      "userId": "118404394130788869227"
     },
     "user_tz": 240
    },
    "id": "NnT0sZIwj5wu",
    "outputId": "55aed925-d17e-4c6a-8c71-0d9b3bde5637"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Device: cuda:3\n",
      "Files already downloaded and verified\n",
      "Image batch dimensions: torch.Size([128, 3, 32, 32])\n",
      "Image label dimensions: torch.Size([128])\n"
     ]
    }
   ],
   "source": [
    "##########################\n",
    "### SETTINGS\n",
    "##########################\n",
    "\n",
    "# Device: the recorded run used GPU index 3. Fall back to the first GPU when\n",
    "# fewer than four are visible, and to the CPU when CUDA is unavailable, so\n",
    "# the notebook also runs on single-GPU machines.\n",
    "if torch.cuda.device_count() > 3:\n",
    "    DEVICE = torch.device(\"cuda:3\")\n",
    "elif torch.cuda.is_available():\n",
    "    DEVICE = torch.device(\"cuda:0\")\n",
    "else:\n",
    "    DEVICE = torch.device(\"cpu\")\n",
    "print('Device:', DEVICE)\n",
    "\n",
    "# Hyperparameters\n",
    "random_seed = 1\n",
    "learning_rate = 0.001\n",
    "num_epochs = 10\n",
    "batch_size = 128\n",
    "\n",
    "# Architecture\n",
    "# NOTE: 784 (= 28*28) is not the size of the 3x32x32 images loaded below;\n",
    "# VGG16 accepts this value but never reads it.\n",
    "num_features = 784\n",
    "num_classes = 10\n",
    "\n",
    "\n",
    "##########################\n",
    "### CIFAR-10 DATASET\n",
    "##########################\n",
    "\n",
    "# Note transforms.ToTensor() scales input images\n",
    "# to 0-1 range\n",
    "train_dataset = datasets.CIFAR10(root='data',\n",
    "                                 train=True,\n",
    "                                 transform=transforms.ToTensor(),\n",
    "                                 download=True)\n",
    "\n",
    "test_dataset = datasets.CIFAR10(root='data',\n",
    "                                train=False,\n",
    "                                transform=transforms.ToTensor())\n",
    "\n",
    "\n",
    "# Only the training data is shuffled; the test order stays fixed.\n",
    "train_loader = DataLoader(dataset=train_dataset,\n",
    "                          batch_size=batch_size,\n",
    "                          shuffle=True)\n",
    "\n",
    "test_loader = DataLoader(dataset=test_dataset,\n",
    "                         batch_size=batch_size,\n",
    "                         shuffle=False)\n",
    "\n",
    "# Checking the dataset: inspect the shapes of a single training batch\n",
    "images, labels = next(iter(train_loader))\n",
    "print('Image batch dimensions:', images.shape)\n",
    "print('Image label dimensions:', labels.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "I6hghKPxj5w0"
   },
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     }
    },
    "colab_type": "code",
    "id": "_lza9t_uj5w1"
   },
   "outputs": [],
   "source": [
    "##########################\n",
    "### MODEL\n",
    "##########################\n",
    "\n",
    "\n",
    "class VGG16(torch.nn.Module):\n",
    "\n",
    "    def __init__(self, num_features, num_classes):\n",
    "        super(VGG16, self).__init__()\n",
    "        \n",
    "        # calculate same padding:\n",
    "        # (w - k + 2*p)/s + 1 = o\n",
    "        # => p = (s(o-1) - w + k)/2\n",
    "        \n",
    "        self.block_1 = nn.Sequential(\n",
    "                nn.Conv2d(in_channels=3,\n",
    "                          out_channels=64,\n",
    "                          kernel_size=(3, 3),\n",
    "                          stride=(1, 1),\n",
    "                          # (1(32-1)- 32 + 3)/2 = 1\n",
    "                          padding=1), \n",
    "                nn.ReLU(),\n",
    "                nn.Conv2d(in_channels=64,\n",
    "                          out_channels=64,\n",
    "                          kernel_size=(3, 3),\n",
    "                          stride=(1, 1),\n",
    "                          padding=1),\n",
    "                nn.ReLU(),\n",
    "                nn.MaxPool2d(kernel_size=(2, 2),\n",
    "                             stride=(2, 2))\n",
    "        )\n",
    "        \n",
    "        self.block_2 = nn.Sequential(\n",
    "                nn.Conv2d(in_channels=64,\n",
    "                          out_channels=128,\n",
    "                          kernel_size=(3, 3),\n",
    "                          stride=(1, 1),\n",
    "                          padding=1),\n",
    "                nn.ReLU(),\n",
    "                nn.Conv2d(in_channels=128,\n",
    "                          out_channels=128,\n",
    "                          kernel_size=(3, 3),\n",
    "                          stride=(1, 1),\n",
    "                          padding=1),\n",
    "                nn.ReLU(),\n",
    "                nn.MaxPool2d(kernel_size=(2, 2),\n",
    "                             stride=(2, 2))\n",
    "        )\n",
    "        \n",
    "        self.block_3 = nn.Sequential(        \n",
    "                nn.Conv2d(in_channels=128,\n",
    "                          out_channels=256,\n",
    "                          kernel_size=(3, 3),\n",
    "                          stride=(1, 1),\n",
    "                          padding=1),\n",
    "                nn.ReLU(),\n",
    "                nn.Conv2d(in_channels=256,\n",
    "                          out_channels=256,\n",
    "                          kernel_size=(3, 3),\n",
    "                          stride=(1, 1),\n",
    "                          padding=1),\n",
    "                nn.ReLU(),        \n",
    "                nn.Conv2d(in_channels=256,\n",
    "                          out_channels=256,\n",
    "                          kernel_size=(3, 3),\n",
    "                          stride=(1, 1),\n",
    "                          padding=1),\n",
    "                nn.ReLU(),\n",
    "                nn.MaxPool2d(kernel_size=(2, 2),\n",
    "                             stride=(2, 2))\n",
    "        )\n",
    "        \n",
    "          \n",
    "        self.block_4 = nn.Sequential(   \n",
    "                nn.Conv2d(in_channels=256,\n",
    "                          out_channels=512,\n",
    "                          kernel_size=(3, 3),\n",
    "                          stride=(1, 1),\n",
    "                          padding=1),\n",
    "                nn.ReLU(),        \n",
    "                nn.Conv2d(in_channels=512,\n",
    "                          out_channels=512,\n",
    "                          kernel_size=(3, 3),\n",
    "                          stride=(1, 1),\n",
    "                          padding=1),\n",
    "                nn.ReLU(),        \n",
    "                nn.Conv2d(in_channels=512,\n",
    "                          out_channels=512,\n",
    "                          kernel_size=(3, 3),\n",
    "                          stride=(1, 1),\n",
    "                          padding=1),\n",
    "                nn.ReLU(),            \n",
    "                nn.MaxPool2d(kernel_size=(2, 2),\n",
    "                             stride=(2, 2))\n",
    "        )\n",
    "        \n",
    "        self.block_5 = nn.Sequential(\n",
    "                nn.Conv2d(in_channels=512,\n",
    "                          out_channels=512,\n",
    "                          kernel_size=(3, 3),\n",
    "                          stride=(1, 1),\n",
    "                          padding=1),\n",
    "                nn.ReLU(),            \n",
    "                nn.Conv2d(in_channels=512,\n",
    "                          out_channels=512,\n",
    "                          kernel_size=(3, 3),\n",
    "                          stride=(1, 1),\n",
    "                          padding=1),\n",
    "                nn.ReLU(),            \n",
    "                nn.Conv2d(in_channels=512,\n",
    "                          out_channels=512,\n",
    "                          kernel_size=(3, 3),\n",
    "                          stride=(1, 1),\n",
    "                          padding=1),\n",
    "                nn.ReLU(),    \n",
    "                nn.MaxPool2d(kernel_size=(2, 2),\n",
    "                             stride=(2, 2))             \n",
    "        )\n",
    "            \n",
    "        self.classifier = nn.Sequential(\n",
    "            nn.Linear(512, 4096),\n",
    "            nn.ReLU(True),\n",
    "            #nn.Dropout(p=0.5),\n",
    "            nn.Linear(4096, 4096),\n",
    "            nn.ReLU(True),\n",
    "            #nn.Dropout(p=0.5),\n",
    "            nn.Linear(4096, num_classes),\n",
    "        )\n",
    "            \n",
    "        for m in self.modules():\n",
    "            if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.Linear):\n",
    "                nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='relu')\n",
    "                if m.bias is not None:\n",
    "                    m.bias.detach().zero_()\n",
    "                    \n",
    "        #self.avgpool = nn.AdaptiveAvgPool2d((7, 7))\n",
    "        \n",
    "        \n",
    "    def forward(self, x):\n",
    "\n",
    "        x = self.block_1(x)\n",
    "        x = self.block_2(x)\n",
    "        x = self.block_3(x)\n",
    "        x = self.block_4(x)\n",
    "        x = self.block_5(x)\n",
    "        #x = self.avgpool(x)\n",
    "        x = x.view(x.size(0), -1)\n",
    "        logits = self.classifier(x)\n",
    "        probas = F.softmax(logits, dim=1)\n",
    "\n",
    "        return logits, probas\n",
    "\n",
    "    \n",
    "# Seed the RNG before building the network so its random weight\n",
    "# initialization is repeatable.\n",
    "torch.manual_seed(random_seed)\n",
    "\n",
    "model = VGG16(num_features=num_features, num_classes=num_classes).to(DEVICE)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "RAodboScj5w6"
   },
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     },
     "base_uri": "https://localhost:8080/",
     "height": 1547
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 2384585,
     "status": "ok",
     "timestamp": 1524976888520,
     "user": {
      "displayName": "Sebastian Raschka",
      "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg",
      "userId": "118404394130788869227"
     },
     "user_tz": 240
    },
    "id": "Dzh3ROmRj5w7",
    "outputId": "5f8fd8c9-b076-403a-b0b7-fd2d498b48d7"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 001/010 | Batch 0000/0391 | Cost: 2.4443\n",
      "Epoch: 001/010 | Batch 0050/0391 | Cost: 2.2800\n",
      "Epoch: 001/010 | Batch 0100/0391 | Cost: 2.1115\n",
      "Epoch: 001/010 | Batch 0150/0391 | Cost: 1.9596\n",
      "Epoch: 001/010 | Batch 0200/0391 | Cost: 2.0510\n",
      "Epoch: 001/010 | Batch 0250/0391 | Cost: 1.8116\n",
      "Epoch: 001/010 | Batch 0300/0391 | Cost: 1.8209\n",
      "Epoch: 001/010 | Batch 0350/0391 | Cost: 1.7161\n",
      "Epoch: 001/010 | Train: 36.232% |  Loss: 1.600\n",
      "Time elapsed: 0.59 min\n",
      "Epoch: 002/010 | Batch 0000/0391 | Cost: 1.6484\n",
      "Epoch: 002/010 | Batch 0050/0391 | Cost: 1.8022\n",
      "Epoch: 002/010 | Batch 0100/0391 | Cost: 1.6488\n",
      "Epoch: 002/010 | Batch 0150/0391 | Cost: 1.5652\n",
      "Epoch: 002/010 | Batch 0200/0391 | Cost: 1.3576\n",
      "Epoch: 002/010 | Batch 0250/0391 | Cost: 1.4434\n",
      "Epoch: 002/010 | Batch 0300/0391 | Cost: 1.3092\n",
      "Epoch: 002/010 | Batch 0350/0391 | Cost: 1.1899\n",
      "Epoch: 002/010 | Train: 55.148% |  Loss: 1.229\n",
      "Time elapsed: 1.17 min\n",
      "Epoch: 003/010 | Batch 0000/0391 | Cost: 1.1967\n",
      "Epoch: 003/010 | Batch 0050/0391 | Cost: 1.1714\n",
      "Epoch: 003/010 | Batch 0100/0391 | Cost: 1.3094\n",
      "Epoch: 003/010 | Batch 0150/0391 | Cost: 1.2088\n",
      "Epoch: 003/010 | Batch 0200/0391 | Cost: 1.2108\n",
      "Epoch: 003/010 | Batch 0250/0391 | Cost: 0.9571\n",
      "Epoch: 003/010 | Batch 0300/0391 | Cost: 1.0708\n",
      "Epoch: 003/010 | Batch 0350/0391 | Cost: 0.9150\n",
      "Epoch: 003/010 | Train: 67.218% |  Loss: 0.938\n",
      "Time elapsed: 1.77 min\n",
      "Epoch: 004/010 | Batch 0000/0391 | Cost: 0.8263\n",
      "Epoch: 004/010 | Batch 0050/0391 | Cost: 0.9029\n",
      "Epoch: 004/010 | Batch 0100/0391 | Cost: 1.1502\n",
      "Epoch: 004/010 | Batch 0150/0391 | Cost: 0.9787\n",
      "Epoch: 004/010 | Batch 0200/0391 | Cost: 0.8300\n",
      "Epoch: 004/010 | Batch 0250/0391 | Cost: 1.0111\n",
      "Epoch: 004/010 | Batch 0300/0391 | Cost: 0.9322\n",
      "Epoch: 004/010 | Batch 0350/0391 | Cost: 0.9228\n",
      "Epoch: 004/010 | Train: 70.676% |  Loss: 0.852\n",
      "Time elapsed: 2.36 min\n",
      "Epoch: 005/010 | Batch 0000/0391 | Cost: 0.8467\n",
      "Epoch: 005/010 | Batch 0050/0391 | Cost: 0.6927\n",
      "Epoch: 005/010 | Batch 0100/0391 | Cost: 0.8755\n",
      "Epoch: 005/010 | Batch 0150/0391 | Cost: 0.8654\n",
      "Epoch: 005/010 | Batch 0200/0391 | Cost: 0.8372\n",
      "Epoch: 005/010 | Batch 0250/0391 | Cost: 0.6731\n",
      "Epoch: 005/010 | Batch 0300/0391 | Cost: 0.6184\n",
      "Epoch: 005/010 | Batch 0350/0391 | Cost: 0.6990\n",
      "Epoch: 005/010 | Train: 73.508% |  Loss: 0.764\n",
      "Time elapsed: 2.96 min\n",
      "Epoch: 006/010 | Batch 0000/0391 | Cost: 0.7755\n",
      "Epoch: 006/010 | Batch 0050/0391 | Cost: 0.5677\n",
      "Epoch: 006/010 | Batch 0100/0391 | Cost: 0.8109\n",
      "Epoch: 006/010 | Batch 0150/0391 | Cost: 0.5656\n",
      "Epoch: 006/010 | Batch 0200/0391 | Cost: 0.5641\n",
      "Epoch: 006/010 | Batch 0250/0391 | Cost: 0.8524\n",
      "Epoch: 006/010 | Batch 0300/0391 | Cost: 0.8494\n",
      "Epoch: 006/010 | Batch 0350/0391 | Cost: 0.6887\n",
      "Epoch: 006/010 | Train: 81.112% |  Loss: 0.559\n",
      "Time elapsed: 3.56 min\n",
      "Epoch: 007/010 | Batch 0000/0391 | Cost: 0.5640\n",
      "Epoch: 007/010 | Batch 0050/0391 | Cost: 0.5809\n",
      "Epoch: 007/010 | Batch 0100/0391 | Cost: 0.7346\n",
      "Epoch: 007/010 | Batch 0150/0391 | Cost: 0.6881\n",
      "Epoch: 007/010 | Batch 0200/0391 | Cost: 0.4644\n",
      "Epoch: 007/010 | Batch 0250/0391 | Cost: 0.5216\n",
      "Epoch: 007/010 | Batch 0300/0391 | Cost: 0.6380\n",
      "Epoch: 007/010 | Batch 0350/0391 | Cost: 0.5521\n",
      "Epoch: 007/010 | Train: 84.438% |  Loss: 0.460\n",
      "Time elapsed: 4.15 min\n",
      "Epoch: 008/010 | Batch 0000/0391 | Cost: 0.4843\n",
      "Epoch: 008/010 | Batch 0050/0391 | Cost: 0.5226\n",
      "Epoch: 008/010 | Batch 0100/0391 | Cost: 0.3682\n",
      "Epoch: 008/010 | Batch 0150/0391 | Cost: 0.4941\n",
      "Epoch: 008/010 | Batch 0200/0391 | Cost: 0.5376\n",
      "Epoch: 008/010 | Batch 0250/0391 | Cost: 0.4707\n",
      "Epoch: 008/010 | Batch 0300/0391 | Cost: 0.5882\n",
      "Epoch: 008/010 | Batch 0350/0391 | Cost: 0.5355\n",
      "Epoch: 008/010 | Train: 85.186% |  Loss: 0.442\n",
      "Time elapsed: 4.76 min\n",
      "Epoch: 009/010 | Batch 0000/0391 | Cost: 0.4174\n",
      "Epoch: 009/010 | Batch 0050/0391 | Cost: 0.4007\n",
      "Epoch: 009/010 | Batch 0100/0391 | Cost: 0.4133\n",
      "Epoch: 009/010 | Batch 0150/0391 | Cost: 0.6204\n",
      "Epoch: 009/010 | Batch 0200/0391 | Cost: 0.3827\n",
      "Epoch: 009/010 | Batch 0250/0391 | Cost: 0.4509\n",
      "Epoch: 009/010 | Batch 0300/0391 | Cost: 0.4461\n",
      "Epoch: 009/010 | Batch 0350/0391 | Cost: 0.4692\n",
      "Epoch: 009/010 | Train: 89.038% |  Loss: 0.328\n",
      "Time elapsed: 5.36 min\n",
      "Epoch: 010/010 | Batch 0000/0391 | Cost: 0.2964\n",
      "Epoch: 010/010 | Batch 0050/0391 | Cost: 0.2739\n",
      "Epoch: 010/010 | Batch 0100/0391 | Cost: 0.3803\n",
      "Epoch: 010/010 | Batch 0150/0391 | Cost: 0.3981\n",
      "Epoch: 010/010 | Batch 0200/0391 | Cost: 0.3877\n",
      "Epoch: 010/010 | Batch 0250/0391 | Cost: 0.4213\n",
      "Epoch: 010/010 | Batch 0300/0391 | Cost: 0.5088\n",
      "Epoch: 010/010 | Batch 0350/0391 | Cost: 0.4598\n",
      "Epoch: 010/010 | Train: 89.668% |  Loss: 0.321\n",
      "Time elapsed: 5.96 min\n",
      "Total Training Time: 5.96 min\n"
     ]
    }
   ],
   "source": [
    "def compute_accuracy(model, data_loader):\n",
    "    \"\"\"Return the classification accuracy of `model` on `data_loader` in percent.\n",
    "\n",
    "    The model is switched to eval mode (and left in it). The pass runs under\n",
    "    `torch.no_grad()`, so no autograd graph is built even if the caller has\n",
    "    not disabled gradient tracking.\n",
    "\n",
    "    Args:\n",
    "        model: Module whose forward pass returns a `(logits, probas)` pair.\n",
    "        data_loader: Iterable yielding `(features, targets)` batches.\n",
    "\n",
    "    Returns:\n",
    "        0-dim float tensor with the percentage of correctly predicted labels.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    correct_pred, num_examples = 0, 0\n",
    "    with torch.no_grad():\n",
    "        for features, targets in data_loader:\n",
    "\n",
    "            features = features.to(DEVICE)\n",
    "            targets = targets.to(DEVICE)\n",
    "\n",
    "            _, probas = model(features)\n",
    "            # predicted class = index of the largest probability\n",
    "            _, predicted_labels = torch.max(probas, 1)\n",
    "            num_examples += targets.size(0)\n",
    "            correct_pred += (predicted_labels == targets).sum()\n",
    "    return correct_pred.float()/num_examples * 100\n",
    "\n",
    "\n",
    "def compute_epoch_loss(model, data_loader):\n",
    "    \"\"\"Return the mean per-example cross-entropy of `model` over `data_loader`.\n",
    "\n",
    "    Runs in eval mode without gradient tracking. Batch losses are summed and\n",
    "    divided by the total number of examples, so a smaller final batch does\n",
    "    not skew the average.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    total_loss, seen = 0., 0\n",
    "    with torch.no_grad():\n",
    "        for batch_x, batch_y in data_loader:\n",
    "            batch_x, batch_y = batch_x.to(DEVICE), batch_y.to(DEVICE)\n",
    "            batch_logits, _ = model(batch_x)\n",
    "            total_loss += F.cross_entropy(batch_logits, batch_y, reduction='sum')\n",
    "            seen += batch_y.size(0)\n",
    "\n",
    "    return total_loss / seen\n",
    "    \n",
    "    \n",
    "\n",
    "start_time = time.time()\n",
    "num_batches = len(train_loader)\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "\n",
    "    model.train()\n",
    "    for batch_idx, (features, targets) in enumerate(train_loader):\n",
    "\n",
    "        features, targets = features.to(DEVICE), targets.to(DEVICE)\n",
    "\n",
    "        # Forward pass; the softmax output is not needed for the loss\n",
    "        logits, _ = model(features)\n",
    "        cost = F.cross_entropy(logits, targets)\n",
    "\n",
    "        # Backward pass and parameter update\n",
    "        optimizer.zero_grad()\n",
    "        cost.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # Log every 50th mini-batch\n",
    "        if batch_idx % 50 == 0:\n",
    "            print('Epoch: %03d/%03d | Batch %04d/%04d | Cost: %.4f'\n",
    "                  % (epoch+1, num_epochs, batch_idx, num_batches, cost))\n",
    "\n",
    "    # End-of-epoch metrics over the whole training set, without gradients\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        train_acc = compute_accuracy(model, train_loader)\n",
    "        train_loss = compute_epoch_loss(model, train_loader)\n",
    "    print('Epoch: %03d/%03d | Train: %.3f%% |  Loss: %.3f' % (\n",
    "          epoch+1, num_epochs, train_acc, train_loss))\n",
    "\n",
    "    print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))\n",
    "\n",
    "print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "paaeEQHQj5xC"
   },
   "source": [
    "## Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     },
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 6514,
     "status": "ok",
     "timestamp": 1524976895054,
     "user": {
      "displayName": "Sebastian Raschka",
      "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg",
      "userId": "118404394130788869227"
     },
     "user_tz": 240
    },
    "id": "gzQMWKq5j5xE",
    "outputId": "de7dc005-5eeb-4177-9f9f-d9b5d1358db9"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test accuracy: 76.31%\n"
     ]
    }
   ],
   "source": [
    "# Gradients are not needed for evaluation, so skip tracking them\n",
    "with torch.no_grad():\n",
    "    test_acc = compute_accuracy(model, test_loader)\n",
    "print('Test accuracy: %.2f%%' % test_acc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torchvision 0.4.1a0+d94043a\n",
      "numpy       1.16.4\n",
      "torch       1.3.0\n",
      "\n"
     ]
    }
   ],
   "source": [
    "%watermark -iv"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "default_view": {},
   "name": "convnet-vgg16.ipynb",
   "provenance": [],
   "version": "0.3.2",
   "views": {}
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  },
  "toc": {
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": true,
   "toc_position": {
    "height": "calc(100% - 180px)",
    "left": "10px",
    "top": "150px",
    "width": "371px"
   },
   "toc_section_display": true,
   "toc_window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}