{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "UEBilEjLj5wY"
   },
   "source": [
    "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n",
    "- Author: Sebastian Raschka\n",
    "- GitHub Repository: https://github.com/rasbt/deeplearning-models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     },
     "base_uri": "https://localhost:8080/",
     "height": 119
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 536,
     "status": "ok",
     "timestamp": 1524974472601,
     "user": {
      "displayName": "Sebastian Raschka",
      "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg",
      "userId": "118404394130788869227"
     },
     "user_tz": 240
    },
    "id": "GOzuY8Yvj5wb",
    "outputId": "c19362ce-f87a-4cc2-84cc-8d7b4b9e6007"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sebastian Raschka \n",
      "\n",
      "CPython 3.7.3\n",
      "IPython 7.6.1\n",
      "\n",
      "torch 1.1.0\n"
     ]
    }
   ],
   "source": [
    "%load_ext watermark\n",
    "%watermark -a 'Sebastian Raschka' -v -p torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "rH4XmErYj5wm"
   },
   "source": [
    "# DenseNet-121 CIFAR-10 Image Classifier"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Network Architecture"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The network in this notebook is an implementation of the DenseNet-121 [1] architecture on the MNIST digits dataset (http://yann.lecun.com/exdb/mnist/) to train a handwritten digit classifier.  \n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following figure illustrates the main concept of DenseNet: within each \"dense\" block, each layer is connected with each previous layer -- the feature maps are concatenated.\n",
    "\n",
    "\n",
    "![](../images/densenet/densenet-fig-2.jpg)\n",
    "\n",
    "Note that this is somewhat related yet very different to ResNets. ResNets have skip connections approx. between every other layer (but don't connect all layers with each other). Also, ResNets skip connections work via addition\n",
    "\n",
    "$$\\mathbf{x}_{\\ell}=H_{\\ell}\\left(\\mathbf{X}_{\\ell-1}\\right)+\\mathbf{X}_{\\ell-1}$$,\n",
    "\n",
    "whereas $H_{\\ell}(\\cdot)$ can be a composite function  of operations such as Batch Normalization (BN), rectified linear units (ReLU), Pooling, or Convolution (Conv).\n",
    "\n",
    "In DenseNets, all the previous feature maps $\\mathbf{X}_{0}, \\dots, \\mathbf{X}_{\\ell}-1$ of a feature map $\\mathbf{X}_{\\ell}$ are concatenated:\n",
    "\n",
    "$$\\mathbf{x}_{\\ell}=H_{\\ell}\\left(\\left[\\mathbf{x}_{0}, \\mathbf{x}_{1}, \\ldots, \\mathbf{x}_{\\ell-1}\\right]\\right).$$\n",
    "\n",
    "Furthermore, in this particular notebook, we are considering the DenseNet-121, which is depicted below:\n",
    "\n",
    "\n",
    "\n",
    "![](../images/densenet/densenet-tab-1-dnet121.jpg)"
   ]
  },
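  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a small worked example of the concatenation rule above, the following calculation tracks how the number of input channels grows inside a dense block. It assumes that every $H_{\\ell}$ outputs the same number $k$ of feature maps (the growth rate) and that the block input $\\mathbf{x}_{0}$ has $k_0$ channels; the values $k_0 = 64$, $k = 32$, and $L = 6$ layers correspond to the first dense block of the implementation further below.\n",
    "\n",
    "$$\\text{channels into } H_{\\ell} = k_0 + k\\,(\\ell - 1), \\qquad \\ell = 1, \\dots, L$$\n",
    "\n",
    "$$\\text{channels into } H_{6} = 64 + 32 \\cdot 5 = 224, \\qquad \\text{channels out of the block} = 64 + 32 \\cdot 6 = 256$$\n",
    "\n",
    "By contrast, the additive ResNet update leaves the channel count unchanged, because $H_{\\ell}(\\mathbf{x}_{\\ell-1})$ must have the same shape as $\\mathbf{x}_{\\ell-1}$ for the sum to be defined."
   ]
  },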
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**References**\n",
    "    \n",
    "- [1] Huang, G., Liu, Z., Van Der Maaten, L., & Weinberger, K. Q. (2017). Densely connected convolutional networks. In Proceedings of the IEEE conference on computer vision and pattern recognition (pp. 4700-4708), http://openaccess.thecvf.com/content_cvpr_2017/html/Huang_Densely_Connected_Convolutional_CVPR_2017_paper.html\n",
    "\n",
    "- [2] http://yann.lecun.com/exdb/mnist/"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "MkoGLH_Tj5wn"
   },
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     }
    },
    "colab_type": "code",
    "id": "ORj09gnrj5wp"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import time\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import DataLoader\n",
    "from torch.utils.data.dataset import Subset\n",
    "\n",
    "from torchvision import datasets\n",
    "from torchvision import transforms\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from PIL import Image\n",
    "\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    torch.backends.cudnn.deterministic = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "I6hghKPxj5w0"
   },
   "source": [
    "## Model Settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     },
     "base_uri": "https://localhost:8080/",
     "height": 85
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 23936,
     "status": "ok",
     "timestamp": 1524974497505,
     "user": {
      "displayName": "Sebastian Raschka",
      "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg",
      "userId": "118404394130788869227"
     },
     "user_tz": 240
    },
    "id": "NnT0sZIwj5wu",
    "outputId": "55aed925-d17e-4c6a-8c71-0d9b3bde5637"
   },
   "outputs": [],
   "source": [
    "##########################\n",
    "### SETTINGS\n",
    "##########################\n",
    "\n",
    "# Hyperparameters\n",
    "RANDOM_SEED = 1\n",
    "LEARNING_RATE = 0.001\n",
    "BATCH_SIZE = 128\n",
    "NUM_EPOCHS = 20\n",
    "\n",
    "# Architecture\n",
    "NUM_CLASSES = 10\n",
    "\n",
    "# Other\n",
    "DEVICE = \"cuda:0\"\n",
    "GRAYSCALE = False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### CIFAR-10 Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "train_indices = torch.arange(0, 48000)\n",
    "valid_indices = torch.arange(48000, 50000)\n",
    "\n",
    "\n",
    "train_and_valid = datasets.CIFAR10(root='data', \n",
    "                                   train=True, \n",
    "                                   transform=transforms.ToTensor(),\n",
    "                                   download=True)\n",
    "\n",
    "train_dataset = Subset(train_and_valid, train_indices)\n",
    "valid_dataset = Subset(train_and_valid, valid_indices)\n",
    "test_dataset = datasets.CIFAR10(root='data', \n",
    "                                train=False, \n",
    "                                transform=transforms.ToTensor(),\n",
    "                                download=False)\n",
    "\n",
    "\n",
    "train_loader = DataLoader(dataset=train_dataset, \n",
    "                          batch_size=BATCH_SIZE,\n",
    "                          num_workers=4,\n",
    "                          shuffle=True)\n",
    "\n",
    "valid_loader = DataLoader(dataset=valid_dataset, \n",
    "                          batch_size=BATCH_SIZE,\n",
    "                          num_workers=4,\n",
    "                          shuffle=False)\n",
    "\n",
    "test_loader = DataLoader(dataset=test_dataset, \n",
    "                         batch_size=BATCH_SIZE,\n",
    "                         num_workers=4,\n",
    "                         shuffle=False)"
   ]
  },
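  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick check of the split defined above, the number of minibatches per epoch follows from the subset sizes (48,000 training, 2,000 validation, and 10,000 test images) and `BATCH_SIZE = 128`. This assumes that the `DataLoader` objects keep the last, smaller batch, which is the case here because `drop_last` is left at its default of `False`:\n",
    "\n",
    "$$\\left\\lceil \\frac{48000}{128} \\right\\rceil = 375, \\qquad \\left\\lceil \\frac{2000}{128} \\right\\rceil = 16, \\qquad \\left\\lceil \\frac{10000}{128} \\right\\rceil = 79$$\n",
    "\n",
    "The training subset divides evenly into 375 batches, whereas the last validation batch holds $2000 - 15 \\cdot 128 = 80$ images and the last test batch holds $10000 - 78 \\cdot 128 = 16$ images."
   ]
  },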
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 1 | Batch index: 0 | Batch size: 128\n",
      "Epoch: 2 | Batch index: 0 | Batch size: 128\n"
     ]
    }
   ],
   "source": [
    "device = torch.device(DEVICE)\n",
    "torch.manual_seed(0)\n",
    "\n",
    "for epoch in range(2):\n",
    "\n",
    "    for batch_idx, (x, y) in enumerate(train_loader):\n",
    "        \n",
    "        print('Epoch:', epoch+1, end='')\n",
    "        print(' | Batch index:', batch_idx, end='')\n",
    "        print(' | Batch size:', y.size()[0])\n",
    "        \n",
    "        x = x.to(device)\n",
    "        y = y.to(device)\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([3, 0, 1, 3, 3, 5, 0, 4, 9, 4])\n",
      "tensor([1, 0, 4, 1, 8, 2, 0, 3, 5, 3])\n"
     ]
    }
   ],
   "source": [
    "# Check that shuffling works properly\n",
    "# i.e., label indices should be in random order.\n",
    "# Also, the label order should be different in the second\n",
    "# epoch.\n",
    "\n",
    "for images, labels in train_loader:  \n",
    "    pass\n",
    "print(labels[:10])\n",
    "\n",
    "for images, labels in train_loader:  \n",
    "    pass\n",
    "print(labels[:10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([5, 0, 3, 6, 8, 7, 9, 5, 6, 6])\n",
      "tensor([7, 5, 8, 0, 8, 2, 7, 0, 3, 5])\n"
     ]
    }
   ],
   "source": [
    "# Check that validation set and test sets are diverse\n",
    "# i.e., that they contain all classes\n",
    "\n",
    "for images, labels in valid_loader:  \n",
    "    pass\n",
    "print(labels[:10])\n",
    "\n",
    "for images, labels in test_loader:  \n",
    "    pass\n",
    "print(labels[:10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "##########################\n",
    "### MODEL\n",
    "##########################\n",
    "\n",
    "# The following code cell that implements the DenseNet-121 architecture \n",
    "# is a derivative of the code provided at \n",
    "# https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py\n",
    "\n",
    "import re\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.utils.checkpoint as cp\n",
    "from collections import OrderedDict\n",
    "\n",
    "\n",
    "\n",
    "def _bn_function_factory(norm, relu, conv):\n",
    "    def bn_function(*inputs):\n",
    "        concated_features = torch.cat(inputs, 1)\n",
    "        bottleneck_output = conv(relu(norm(concated_features)))\n",
    "        return bottleneck_output\n",
    "\n",
    "    return bn_function\n",
    "\n",
    "\n",
    "class _DenseLayer(nn.Sequential):\n",
    "    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, memory_efficient=False):\n",
    "        super(_DenseLayer, self).__init__()\n",
    "        self.add_module('norm1', nn.BatchNorm2d(num_input_features)),\n",
    "        self.add_module('relu1', nn.ReLU(inplace=True)),\n",
    "        self.add_module('conv1', nn.Conv2d(num_input_features, bn_size *\n",
    "                                           growth_rate, kernel_size=1, stride=1,\n",
    "                                           bias=False)),\n",
    "        self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)),\n",
    "        self.add_module('relu2', nn.ReLU(inplace=True)),\n",
    "        self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,\n",
    "                                           kernel_size=3, stride=1, padding=1,\n",
    "                                           bias=False)),\n",
    "        self.drop_rate = drop_rate\n",
    "        self.memory_efficient = memory_efficient\n",
    "\n",
    "    def forward(self, *prev_features):\n",
    "        bn_function = _bn_function_factory(self.norm1, self.relu1, self.conv1)\n",
    "        if self.memory_efficient and any(prev_feature.requires_grad for prev_feature in prev_features):\n",
    "            bottleneck_output = cp.checkpoint(bn_function, *prev_features)\n",
    "        else:\n",
    "            bottleneck_output = bn_function(*prev_features)\n",
    "        new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))\n",
    "        if self.drop_rate > 0:\n",
    "            new_features = F.dropout(new_features, p=self.drop_rate,\n",
    "                                     training=self.training)\n",
    "        return new_features\n",
    "\n",
    "\n",
    "class _DenseBlock(nn.Module):\n",
    "    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, memory_efficient=False):\n",
    "        super(_DenseBlock, self).__init__()\n",
    "        for i in range(num_layers):\n",
    "            layer = _DenseLayer(\n",
    "                num_input_features + i * growth_rate,\n",
    "                growth_rate=growth_rate,\n",
    "                bn_size=bn_size,\n",
    "                drop_rate=drop_rate,\n",
    "                memory_efficient=memory_efficient,\n",
    "            )\n",
    "            self.add_module('denselayer%d' % (i + 1), layer)\n",
    "\n",
    "    def forward(self, init_features):\n",
    "        features = [init_features]\n",
    "        for name, layer in self.named_children():\n",
    "            new_features = layer(*features)\n",
    "            features.append(new_features)\n",
    "        return torch.cat(features, 1)\n",
    "\n",
    "\n",
    "class _Transition(nn.Sequential):\n",
    "    def __init__(self, num_input_features, num_output_features):\n",
    "        super(_Transition, self).__init__()\n",
    "        self.add_module('norm', nn.BatchNorm2d(num_input_features))\n",
    "        self.add_module('relu', nn.ReLU(inplace=True))\n",
    "        self.add_module('conv', nn.Conv2d(num_input_features, num_output_features,\n",
    "                                          kernel_size=1, stride=1, bias=False))\n",
    "        self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))\n",
    "\n",
    "\n",
    "class DenseNet121(nn.Module):\n",
    "    r\"\"\"Densenet-BC model class, based on\n",
    "    `\"Densely Connected Convolutional Networks\" <https://arxiv.org/pdf/1608.06993.pdf>`_\n",
    "\n",
    "    Args:\n",
    "        growth_rate (int) - how many filters to add each layer (`k` in paper)\n",
    "        block_config (list of 4 ints) - how many layers in each pooling block\n",
    "        num_init_featuremaps (int) - the number of filters to learn in the first convolution layer\n",
    "        bn_size (int) - multiplicative factor for number of bottle neck layers\n",
    "          (i.e. bn_size * k features in the bottleneck layer)\n",
    "        drop_rate (float) - dropout rate after each dense layer\n",
    "        num_classes (int) - number of classification classes\n",
    "        memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,\n",
    "          but slower. Default: *False*. See `\"paper\" <https://arxiv.org/pdf/1707.06990.pdf>`_\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),\n",
    "                 num_init_featuremaps=64, bn_size=4, drop_rate=0, num_classes=1000, memory_efficient=False,\n",
    "                 grayscale=False):\n",
    "\n",
    "        super(DenseNet121, self).__init__()\n",
    "\n",
    "        # First convolution\n",
    "        if grayscale:\n",
    "            in_channels=1\n",
    "        else:\n",
    "            in_channels=3\n",
    "        \n",
    "        self.features = nn.Sequential(OrderedDict([\n",
    "            ('conv0', nn.Conv2d(in_channels=in_channels, out_channels=num_init_featuremaps,\n",
    "                                kernel_size=7, stride=2,\n",
    "                                padding=3, bias=False)), # bias is redundant when using batchnorm\n",
    "            ('norm0', nn.BatchNorm2d(num_features=num_init_featuremaps)),\n",
    "            ('relu0', nn.ReLU(inplace=True)),\n",
    "            ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),\n",
    "        ]))\n",
    "\n",
    "        # Each denseblock\n",
    "        num_features = num_init_featuremaps\n",
    "        for i, num_layers in enumerate(block_config):\n",
    "            block = _DenseBlock(\n",
    "                num_layers=num_layers,\n",
    "                num_input_features=num_features,\n",
    "                bn_size=bn_size,\n",
    "                growth_rate=growth_rate,\n",
    "                drop_rate=drop_rate,\n",
    "                memory_efficient=memory_efficient\n",
    "            )\n",
    "            self.features.add_module('denseblock%d' % (i + 1), block)\n",
    "            num_features = num_features + num_layers * growth_rate\n",
    "            if i != len(block_config) - 1:\n",
    "                trans = _Transition(num_input_features=num_features,\n",
    "                                    num_output_features=num_features // 2)\n",
    "                self.features.add_module('transition%d' % (i + 1), trans)\n",
    "                num_features = num_features // 2\n",
    "\n",
    "        # Final batch norm\n",
    "        self.features.add_module('norm5', nn.BatchNorm2d(num_features))\n",
    "\n",
    "        # Linear layer\n",
    "        self.classifier = nn.Linear(num_features, num_classes)\n",
    "\n",
    "        # Official init from torch repo.\n",
    "        for m in self.modules():\n",
    "            if isinstance(m, nn.Conv2d):\n",
    "                nn.init.kaiming_normal_(m.weight)\n",
    "            elif isinstance(m, nn.BatchNorm2d):\n",
    "                nn.init.constant_(m.weight, 1)\n",
    "                nn.init.constant_(m.bias, 0)\n",
    "            elif isinstance(m, nn.Linear):\n",
    "                nn.init.constant_(m.bias, 0)\n",
    "\n",
    "    def forward(self, x):\n",
    "        features = self.features(x)\n",
    "        out = F.relu(features, inplace=True)\n",
    "        out = F.adaptive_avg_pool2d(out, (1, 1))\n",
    "        out = torch.flatten(out, 1)\n",
    "        logits = self.classifier(out)\n",
    "        probas = F.softmax(logits, dim=1)\n",
    "        return logits, probas"
   ]
  },
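  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following cell is a minimal, pure-Python sketch (not part of the original implementation) that traces the number of channels and the spatial size of the feature maps through the `DenseNet121` class defined above for a 32x32 CIFAR-10 input. The helper names `conv_out_size` and `trace_densenet121` are hypothetical. The kernel sizes, strides, and paddings are copied from the layer definitions above, and the output-size formula is the usual floor formula for convolution and pooling layers, assuming default dilation and `ceil_mode=False`. Under these assumptions, the trace ends with 1024 channels at 1x1 resolution, which agrees with the 1024 input features of the final `nn.Linear` classifier."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def conv_out_size(size, kernel_size, stride, padding=0):\n",
    "    # Output side length of a convolution or pooling layer (floor mode)\n",
    "    return (size + 2 * padding - kernel_size) // stride + 1\n",
    "\n",
    "\n",
    "def trace_densenet121(input_size=32, growth_rate=32,\n",
    "                      block_config=(6, 12, 24, 16), num_init_featuremaps=64):\n",
    "    # Returns a list of (stage name, channels, spatial size) tuples\n",
    "    trace = []\n",
    "\n",
    "    # conv0: 7x7 convolution, stride 2, padding 3\n",
    "    size = conv_out_size(input_size, kernel_size=7, stride=2, padding=3)\n",
    "    channels = num_init_featuremaps\n",
    "    trace.append(('conv0', channels, size))\n",
    "\n",
    "    # pool0: 3x3 max pooling, stride 2, padding 1\n",
    "    size = conv_out_size(size, kernel_size=3, stride=2, padding=1)\n",
    "    trace.append(('pool0', channels, size))\n",
    "\n",
    "    for i, num_layers in enumerate(block_config):\n",
    "        # Each dense layer appends growth_rate feature maps; the size is unchanged\n",
    "        channels = channels + num_layers * growth_rate\n",
    "        trace.append((f'denseblock{i + 1}', channels, size))\n",
    "\n",
    "        if i != len(block_config) - 1:\n",
    "            # Transition: 1x1 conv halves the channels, 2x2 avg pooling halves the size\n",
    "            channels = channels // 2\n",
    "            size = conv_out_size(size, kernel_size=2, stride=2)\n",
    "            trace.append((f'transition{i + 1}', channels, size))\n",
    "\n",
    "    return trace\n",
    "\n",
    "\n",
    "for name, channels, size in trace_densenet121():\n",
    "    print(f'{name:12s} | channels: {channels:4d} | size: {size}x{size}')"
   ]
  },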
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     }
    },
    "colab_type": "code",
    "id": "_lza9t_uj5w1"
   },
   "outputs": [],
   "source": [
    "torch.manual_seed(RANDOM_SEED)\n",
    "\n",
    "model = DenseNet121(num_classes=NUM_CLASSES, grayscale=GRAYSCALE)\n",
    "model.to(DEVICE)\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)  "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "RAodboScj5w6"
   },
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_acc(model, data_loader, device):\n",
    "    correct_pred, num_examples = 0, 0\n",
    "    model.eval()\n",
    "    for i, (features, targets) in enumerate(data_loader):\n",
    "            \n",
    "        features = features.to(device)\n",
    "        targets = targets.to(device)\n",
    "\n",
    "        logits, probas = model(features)\n",
    "        _, predicted_labels = torch.max(probas, 1)\n",
    "        num_examples += targets.size(0)\n",
    "        assert predicted_labels.size() == targets.size()\n",
    "        correct_pred += (predicted_labels == targets).sum()\n",
    "    return correct_pred.float()/num_examples * 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     },
     "base_uri": "https://localhost:8080/",
     "height": 1547
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 2384585,
     "status": "ok",
     "timestamp": 1524976888520,
     "user": {
      "displayName": "Sebastian Raschka",
      "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg",
      "userId": "118404394130788869227"
     },
     "user_tz": 240
    },
    "id": "Dzh3ROmRj5w7",
    "outputId": "5f8fd8c9-b076-403a-b0b7-fd2d498b48d7"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 001/020 | Batch 000/375 | Cost: 2.4002\n",
      "Epoch: 001/020 | Batch 150/375 | Cost: 1.3511\n",
      "Epoch: 001/020 | Batch 300/375 | Cost: 1.3326\n",
      "Epoch: 001/020\n",
      "Train ACC: 61.52 | Validation ACC: 60.25\n",
      "Time elapsed: 0.99 min\n",
      "Epoch: 002/020 | Batch 000/375 | Cost: 1.1005\n",
      "Epoch: 002/020 | Batch 150/375 | Cost: 1.0096\n",
      "Epoch: 002/020 | Batch 300/375 | Cost: 1.2628\n",
      "Epoch: 002/020\n",
      "Train ACC: 66.19 | Validation ACC: 63.55\n",
      "Time elapsed: 1.90 min\n",
      "Epoch: 003/020 | Batch 000/375 | Cost: 0.6471\n",
      "Epoch: 003/020 | Batch 150/375 | Cost: 0.8107\n",
      "Epoch: 003/020 | Batch 300/375 | Cost: 0.8848\n",
      "Epoch: 003/020\n",
      "Train ACC: 73.32 | Validation ACC: 67.40\n",
      "Time elapsed: 2.79 min\n",
      "Epoch: 004/020 | Batch 000/375 | Cost: 0.6951\n",
      "Epoch: 004/020 | Batch 150/375 | Cost: 0.6852\n",
      "Epoch: 004/020 | Batch 300/375 | Cost: 0.8775\n",
      "Epoch: 004/020\n",
      "Train ACC: 76.38 | Validation ACC: 67.90\n",
      "Time elapsed: 3.65 min\n",
      "Epoch: 005/020 | Batch 000/375 | Cost: 0.4929\n",
      "Epoch: 005/020 | Batch 150/375 | Cost: 0.5942\n",
      "Epoch: 005/020 | Batch 300/375 | Cost: 0.6016\n",
      "Epoch: 005/020\n",
      "Train ACC: 76.70 | Validation ACC: 67.95\n",
      "Time elapsed: 4.53 min\n",
      "Epoch: 006/020 | Batch 000/375 | Cost: 0.5888\n",
      "Epoch: 006/020 | Batch 150/375 | Cost: 0.5500\n",
      "Epoch: 006/020 | Batch 300/375 | Cost: 0.4939\n",
      "Epoch: 006/020\n",
      "Train ACC: 84.21 | Validation ACC: 74.85\n",
      "Time elapsed: 5.51 min\n",
      "Epoch: 007/020 | Batch 000/375 | Cost: 0.3982\n",
      "Epoch: 007/020 | Batch 150/375 | Cost: 0.4272\n",
      "Epoch: 007/020 | Batch 300/375 | Cost: 0.3908\n",
      "Epoch: 007/020\n",
      "Train ACC: 87.29 | Validation ACC: 74.65\n",
      "Time elapsed: 6.43 min\n",
      "Epoch: 008/020 | Batch 000/375 | Cost: 0.4332\n",
      "Epoch: 008/020 | Batch 150/375 | Cost: 0.2335\n",
      "Epoch: 008/020 | Batch 300/375 | Cost: 0.3678\n",
      "Epoch: 008/020\n",
      "Train ACC: 84.10 | Validation ACC: 71.00\n",
      "Time elapsed: 7.35 min\n",
      "Epoch: 009/020 | Batch 000/375 | Cost: 0.2343\n",
      "Epoch: 009/020 | Batch 150/375 | Cost: 0.2704\n",
      "Epoch: 009/020 | Batch 300/375 | Cost: 0.3429\n",
      "Epoch: 009/020\n",
      "Train ACC: 88.51 | Validation ACC: 74.35\n",
      "Time elapsed: 8.26 min\n",
      "Epoch: 010/020 | Batch 000/375 | Cost: 0.1757\n",
      "Epoch: 010/020 | Batch 150/375 | Cost: 0.1748\n",
      "Epoch: 010/020 | Batch 300/375 | Cost: 0.4755\n",
      "Epoch: 010/020\n",
      "Train ACC: 92.77 | Validation ACC: 74.20\n",
      "Time elapsed: 9.27 min\n",
      "Epoch: 011/020 | Batch 000/375 | Cost: 0.2347\n",
      "Epoch: 011/020 | Batch 150/375 | Cost: 0.1618\n",
      "Epoch: 011/020 | Batch 300/375 | Cost: 0.3075\n",
      "Epoch: 011/020\n",
      "Train ACC: 90.78 | Validation ACC: 74.10\n",
      "Time elapsed: 10.24 min\n",
      "Epoch: 012/020 | Batch 000/375 | Cost: 0.1233\n",
      "Epoch: 012/020 | Batch 150/375 | Cost: 0.1385\n",
      "Epoch: 012/020 | Batch 300/375 | Cost: 0.2406\n",
      "Epoch: 012/020\n",
      "Train ACC: 93.26 | Validation ACC: 76.00\n",
      "Time elapsed: 11.15 min\n",
      "Epoch: 013/020 | Batch 000/375 | Cost: 0.0757\n",
      "Epoch: 013/020 | Batch 150/375 | Cost: 0.0658\n",
      "Epoch: 013/020 | Batch 300/375 | Cost: 0.2070\n",
      "Epoch: 013/020\n",
      "Train ACC: 91.64 | Validation ACC: 75.80\n",
      "Time elapsed: 12.13 min\n",
      "Epoch: 014/020 | Batch 000/375 | Cost: 0.1285\n",
      "Epoch: 014/020 | Batch 150/375 | Cost: 0.1468\n",
      "Epoch: 014/020 | Batch 300/375 | Cost: 0.1070\n",
      "Epoch: 014/020\n",
      "Train ACC: 95.18 | Validation ACC: 77.80\n",
      "Time elapsed: 13.10 min\n",
      "Epoch: 015/020 | Batch 000/375 | Cost: 0.1204\n",
      "Epoch: 015/020 | Batch 150/375 | Cost: 0.0461\n",
      "Epoch: 015/020 | Batch 300/375 | Cost: 0.1005\n",
      "Epoch: 015/020\n",
      "Train ACC: 93.26 | Validation ACC: 74.80\n",
      "Time elapsed: 14.10 min\n",
      "Epoch: 016/020 | Batch 000/375 | Cost: 0.0591\n",
      "Epoch: 016/020 | Batch 150/375 | Cost: 0.1024\n",
      "Epoch: 016/020 | Batch 300/375 | Cost: 0.0679\n",
      "Epoch: 016/020\n",
      "Train ACC: 95.66 | Validation ACC: 76.80\n",
      "Time elapsed: 15.02 min\n",
      "Epoch: 017/020 | Batch 000/375 | Cost: 0.0359\n",
      "Epoch: 017/020 | Batch 150/375 | Cost: 0.0615\n",
      "Epoch: 017/020 | Batch 300/375 | Cost: 0.0725\n",
      "Epoch: 017/020\n",
      "Train ACC: 92.06 | Validation ACC: 75.85\n",
      "Time elapsed: 16.01 min\n",
      "Epoch: 018/020 | Batch 000/375 | Cost: 0.0831\n",
      "Epoch: 018/020 | Batch 150/375 | Cost: 0.1023\n",
      "Epoch: 018/020 | Batch 300/375 | Cost: 0.0666\n",
      "Epoch: 018/020\n",
      "Train ACC: 94.32 | Validation ACC: 75.10\n",
      "Time elapsed: 17.01 min\n",
      "Epoch: 019/020 | Batch 000/375 | Cost: 0.0499\n",
      "Epoch: 019/020 | Batch 150/375 | Cost: 0.1068\n",
      "Epoch: 019/020 | Batch 300/375 | Cost: 0.0657\n",
      "Epoch: 019/020\n",
      "Train ACC: 96.89 | Validation ACC: 77.75\n",
      "Time elapsed: 17.99 min\n",
      "Epoch: 020/020 | Batch 000/375 | Cost: 0.0751\n",
      "Epoch: 020/020 | Batch 150/375 | Cost: 0.0423\n",
      "Epoch: 020/020 | Batch 300/375 | Cost: 0.0398\n",
      "Epoch: 020/020\n",
      "Train ACC: 94.66 | Validation ACC: 76.70\n",
      "Time elapsed: 18.90 min\n",
      "Total Training Time: 18.90 min\n"
     ]
    }
   ],
   "source": [
    "start_time = time.time()\n",
    "\n",
    "cost_list = []\n",
    "train_acc_list, valid_acc_list = [], []\n",
    "\n",
    "\n",
    "for epoch in range(NUM_EPOCHS):\n",
    "    \n",
    "    model.train()\n",
    "    for batch_idx, (features, targets) in enumerate(train_loader):\n",
    "        \n",
    "        features = features.to(DEVICE)\n",
    "        targets = targets.to(DEVICE)\n",
    "            \n",
    "        ### FORWARD AND BACK PROP\n",
    "        logits, probas = model(features)\n",
    "        cost = F.cross_entropy(logits, targets)\n",
    "        optimizer.zero_grad()\n",
    "        \n",
    "        cost.backward()\n",
    "        \n",
    "        ### UPDATE MODEL PARAMETERS\n",
    "        optimizer.step()\n",
    "        \n",
    "        #################################################\n",
    "        ### CODE ONLY FOR LOGGING BEYOND THIS POINT\n",
    "        ################################################\n",
    "        cost_list.append(cost.item())\n",
    "        if not batch_idx % 150:\n",
    "            print (f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d} | '\n",
    "                   f'Batch {batch_idx:03d}/{len(train_loader):03d} |' \n",
    "                   f' Cost: {cost:.4f}')\n",
    "\n",
    "        \n",
    "\n",
    "    model.eval()\n",
    "    with torch.set_grad_enabled(False): # save memory during inference\n",
    "        \n",
    "        train_acc = compute_acc(model, train_loader, device=DEVICE)\n",
    "        valid_acc = compute_acc(model, valid_loader, device=DEVICE)\n",
    "        \n",
    "        print(f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d}\\n'\n",
    "              f'Train ACC: {train_acc:.2f} | Validation ACC: {valid_acc:.2f}')\n",
    "        \n",
    "        train_acc_list.append(train_acc)\n",
    "        valid_acc_list.append(valid_acc)\n",
    "        \n",
    "    elapsed = (time.time() - start_time)/60\n",
    "    print(f'Time elapsed: {elapsed:.2f} min')\n",
    "  \n",
    "elapsed = (time.time() - start_time)/60\n",
    "print(f'Total Training Time: {elapsed:.2f} min')"
   ]
  },
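  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick sanity check on the training log above: a 10-class classifier that assigns roughly equal probability to every class has a cross-entropy (natural logarithm, as used by `F.cross_entropy`) of about $\\ln 10$. This is a rule of thumb that assumes near-uniform predictions at initialization, not an exact property of this model:\n",
    "\n",
    "$$-\\ln \\frac{1}{10} = \\ln 10 \\approx 2.303$$\n",
    "\n",
    "The first logged cost, 2.4002, is close to this baseline, which is consistent with the loss computation and the label encoding being set up as intended."
   ]
  },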
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "paaeEQHQj5xC"
   },
   "source": [
    "## Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     },
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 6514,
     "status": "ok",
     "timestamp": 1524976895054,
     "user": {
      "displayName": "Sebastian Raschka",
      "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg",
      "userId": "118404394130788869227"
     },
     "user_tz": 240
    },
    "id": "gzQMWKq5j5xE",
    "outputId": "de7dc005-5eeb-4177-9f9f-d9b5d1358db9"
   },
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.plot(cost_list, label='Minibatch cost')\n",
    "plt.plot(np.convolve(cost_list, \n",
    "                     np.ones(200,)/200, mode='valid'), \n",
    "         label='Running average')\n",
    "\n",
    "plt.ylabel('Cross Entropy')\n",
    "plt.xlabel('Iteration')\n",
    "plt.legend()\n",
    "plt.show()\n"
   ]
  },
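  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The running average in the plot above is computed with `np.convolve(..., mode='valid')` and a uniform kernel of length 200, i.e., each plotted point is the mean of 200 consecutive minibatch costs. The following pure-Python sketch spells out that computation on a toy list; the function name `running_mean` and the toy values are illustrative only. With 20 epochs of 375 minibatches each, `cost_list` holds 7500 values, so the smoothed curve has $7500 - 200 + 1 = 7301$ points and ends earlier on the x-axis than the raw curve."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def running_mean(values, window):\n",
    "    # Mean over every full window of consecutive values; for\n",
    "    # 1 <= window <= len(values) this mirrors\n",
    "    # np.convolve(values, np.ones(window) / window, mode='valid')\n",
    "    if window < 1 or window > len(values):\n",
    "        raise ValueError('window must be between 1 and len(values)')\n",
    "    return [sum(values[i:i + window]) / window\n",
    "            for i in range(len(values) - window + 1)]\n",
    "\n",
    "\n",
    "toy_costs = [4.0, 2.0, 3.0, 1.0, 2.0]\n",
    "smoothed = running_mean(toy_costs, window=2)\n",
    "print(smoothed)                       # [3.0, 2.5, 2.0, 1.5]\n",
    "print(len(toy_costs), len(smoothed))  # 5 4"
   ]
  },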
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.plot(np.arange(1, NUM_EPOCHS+1), train_acc_list, label='Training')\n",
    "plt.plot(np.arange(1, NUM_EPOCHS+1), valid_acc_list, label='Validation')\n",
    "\n",
    "plt.xlabel('Epoch')\n",
    "plt.ylabel('Accuracy')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation ACC: 76.70%\n",
      "Test ACC: 74.97%\n"
     ]
    }
   ],
   "source": [
    "with torch.set_grad_enabled(False):\n",
    "    test_acc = compute_acc(model=model,\n",
    "                           data_loader=test_loader,\n",
    "                           device=DEVICE)\n",
    "    \n",
    "    valid_acc = compute_acc(model=model,\n",
    "                            data_loader=valid_loader,\n",
    "                            device=DEVICE)\n",
    "    \n",
    "\n",
    "print(f'Validation ACC: {valid_acc:.2f}%')\n",
    "print(f'Test ACC: {test_acc:.2f}%')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch       1.1.0\n",
      "matplotlib  3.1.0\n",
      "pandas      0.24.2\n",
      "torchvision 0.3.0\n",
      "numpy       1.16.4\n",
      "re          2.2.1\n",
      "PIL.Image   6.0.0\n",
      "\n"
     ]
    }
   ],
   "source": [
    "%watermark -iv"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "default_view": {},
   "name": "convnet-vgg16.ipynb",
   "provenance": [],
   "version": "0.3.2",
   "views": {}
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  },
  "toc": {
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": true,
   "toc_position": {
    "height": "calc(100% - 180px)",
    "left": "10px",
    "top": "150px",
    "width": "371px"
   },
   "toc_section_display": true,
   "toc_window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}