{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "UEBilEjLj5wY"
   },
   "source": [
    "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n",
    "- Author: Sebastian Raschka\n",
    "- GitHub Repository: https://github.com/rasbt/deeplearning-models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     },
     "base_uri": "https://localhost:8080/",
     "height": 119
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 536,
     "status": "ok",
     "timestamp": 1524974472601,
     "user": {
      "displayName": "Sebastian Raschka",
      "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg",
      "userId": "118404394130788869227"
     },
     "user_tz": 240
    },
    "id": "GOzuY8Yvj5wb",
    "outputId": "c19362ce-f87a-4cc2-84cc-8d7b4b9e6007"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sebastian Raschka \n",
      "\n",
      "CPython 3.7.3\n",
      "IPython 7.6.1\n",
      "\n",
      "torch 1.1.0\n"
     ]
    }
   ],
   "source": [
    "%load_ext watermark\n",
    "%watermark -a 'Sebastian Raschka' -v -p torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "rH4XmErYj5wm"
   },
   "source": [
    "# Increase the Batch Size (AlexNet CIFAR-10 Classifier)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook experiments with increasing the batch size during training, inspired by the following paper:\n",
    "\n",
    "- Smith, S. L., Kindermans, P. J., Ying, C., & Le, Q. V. (2017). Don't decay the learning rate, increase the batch size. arXiv preprint arXiv:1711.00489.\n",
    "\n",
    "To summarize the main points of the paper:\n",
    "\n",
    "- Stochastic gradient descent (SGD) injects noise into the optimization because each parameter update is computed from a random minibatch rather than from the full training set. During the early training epochs, this noise helps with exploring the loss landscape and, more generally, with escaping sharp minima, which have been associated with poorer generalization.\n",
    "\n",
    "- Later in training, however, one typically wants to reduce this noise by gradually decaying the learning rate (similar to simulated annealing) to help the optimizer converge.\n",
    "\n",
    "- Because the noise depends on the learning rate, the batch size, and the momentum, one can alternatively increase the batch size instead of decreasing the learning rate to reduce the noise. This way, more training examples are used in each update, and fewer parameter updates may be required overall to converge.\n",
    "\n",
    "The paper quantifies this via the noise scale $g$, which relates the learning rate and the batch size as follows:\n",
    "\n",
    "\n",
    "$$g=\\epsilon\\left(\\frac{N}{B}-1\\right),$$\n",
    "\n",
    "where $\\epsilon$ is the learning rate, $B$ is the batch size, and $N$ is the number of training examples.\n",
    "\n",
    "With a momentum coefficient $m$, this becomes:\n",
    "\n",
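    "$$\\begin{aligned} g &=\\frac{\\epsilon}{1-m}\\left(\\frac{N}{B}-1\\right) \\\\ & \\approx \\frac{\\epsilon N}{B(1-m)}, \\end{aligned}$$\n",
    "\n",
    "where the approximation holds when $B \\ll N$. Hence, multiplying the batch size $B$ by a factor $k$ reduces $g$ by roughly the same factor as dividing the learning rate $\\epsilon$ by $k$."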
   ]
  },
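  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough numerical illustration of the formulas above, the following sketch evaluates the noise scale $g$ for a baseline setting, for a 5-fold smaller learning rate, and for a 5-fold larger batch size. The helper name `noise_scale` and the example values (learning rate 0.1, momentum $m=0.9$, $N=48{,}000$) are assumptions chosen for illustration only. Note also that this notebook trains with Adam rather than SGD with momentum, so the formula serves as intuition rather than as an exact description of the runs below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical helper (not part of the original notebook): SGD noise scale\n",
    "# g = lr / (1 - m) * (N / B - 1), following Smith et al. (2017).\n",
    "def noise_scale(lr, n, b, momentum=0.0):\n",
    "    return lr / (1.0 - momentum) * (n / b - 1.0)\n",
    "\n",
    "n_train = 48000  # size of the training subset used in this notebook\n",
    "\n",
    "# Illustrative values only: baseline, 5x smaller learning rate, 5x larger batch\n",
    "g_base = noise_scale(0.1, n_train, 256, momentum=0.9)\n",
    "g_decay_lr = noise_scale(0.1 / 5, n_train, 256, momentum=0.9)\n",
    "g_grow_batch = noise_scale(0.1, n_train, 256 * 5, momentum=0.9)\n",
    "\n",
    "print(f'baseline:        g = {g_base:.2f}')\n",
    "print(f'lr / 5:          g = {g_decay_lr:.2f}')\n",
    "print(f'batch size x 5:  g = {g_grow_batch:.2f}')"
   ]
  },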
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Network Architecture"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this notebook, the CIFAR-10 dataset is used for training a classic AlexNet network [1] for classification:\n",
    "    \n",
    "- [1] Krizhevsky, Alex, Ilya Sutskever, and Geoffrey E. Hinton. \"[ImageNet classification with deep convolutional neural networks.](https://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf)\" In Advances in Neural Information Processing Systems, pp. 1097-1105. 2012.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "MkoGLH_Tj5wn"
   },
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     }
    },
    "colab_type": "code",
    "id": "ORj09gnrj5wp"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import time\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import DataLoader\n",
    "from torch.utils.data.dataset import Subset\n",
    "\n",
    "from torchvision import datasets\n",
    "from torchvision import transforms\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from PIL import Image\n",
    "\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    torch.backends.cudnn.deterministic = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "I6hghKPxj5w0"
   },
   "source": [
    "## Model Settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     },
     "base_uri": "https://localhost:8080/",
     "height": 85
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 23936,
     "status": "ok",
     "timestamp": 1524974497505,
     "user": {
      "displayName": "Sebastian Raschka",
      "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg",
      "userId": "118404394130788869227"
     },
     "user_tz": 240
    },
    "id": "NnT0sZIwj5wu",
    "outputId": "55aed925-d17e-4c6a-8c71-0d9b3bde5637"
   },
   "outputs": [],
   "source": [
    "##########################\n",
    "### SETTINGS\n",
    "##########################\n",
    "\n",
    "# Hyperparameters\n",
    "RANDOM_SEED = 1\n",
    "LEARNING_RATE = 0.0001\n",
    "BATCH_SIZE = 256\n",
    "NUM_EPOCHS = 40\n",
    "\n",
    "# Architecture\n",
    "NUM_CLASSES = 10\n",
    "\n",
    "# Other\n",
    "DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
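    "# Hold out the last 2,000 of the 50,000 CIFAR-10 training images for validation.\n",
    "# Note: because valid_dataset is a Subset of train_and_valid (below), the\n",
    "# validation images also go through the random-crop train_transform.\n",
    "train_indices = torch.arange(0, 48000)\n",
    "valid_indices = torch.arange(48000, 50000)\n",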
    "\n",
    "\n",
    "train_transform = transforms.Compose([transforms.Resize((70, 70)),\n",
    "                                      transforms.RandomCrop((64, 64)),\n",
    "                                      transforms.ToTensor()])\n",
    "\n",
    "test_transform = transforms.Compose([transforms.Resize((70, 70)),\n",
    "                                     transforms.CenterCrop((64, 64)),\n",
    "                                     transforms.ToTensor()])\n",
    "\n",
    "train_and_valid = datasets.CIFAR10(root='data', \n",
    "                                   train=True, \n",
    "                                   transform=train_transform,\n",
    "                                   download=True)\n",
    "\n",
    "train_dataset = Subset(train_and_valid, train_indices)\n",
    "valid_dataset = Subset(train_and_valid, valid_indices)\n",
    "test_dataset = datasets.CIFAR10(root='data', \n",
    "                                train=False, \n",
    "                                transform=test_transform,\n",
    "                                download=False)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "train_loader = DataLoader(dataset=train_dataset, \n",
    "                          batch_size=BATCH_SIZE,\n",
    "                          num_workers=4,\n",
    "                          shuffle=True)\n",
    "\n",
    "valid_loader = DataLoader(dataset=valid_dataset, \n",
    "                          batch_size=BATCH_SIZE,\n",
    "                          num_workers=4,\n",
    "                          shuffle=False)\n",
    "\n",
    "test_loader = DataLoader(dataset=test_dataset, \n",
    "                         batch_size=BATCH_SIZE,\n",
    "                         num_workers=4,\n",
    "                         shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training Set:\n",
      "\n",
      "Image batch dimensions: torch.Size([256, 3, 64, 64])\n",
      "Image label dimensions: torch.Size([256])\n",
      "\n",
      "Validation Set:\n",
      "Image batch dimensions: torch.Size([256, 3, 64, 64])\n",
      "Image label dimensions: torch.Size([256])\n",
      "\n",
      "Testing Set:\n",
      "Image batch dimensions: torch.Size([256, 3, 64, 64])\n",
      "Image label dimensions: torch.Size([256])\n"
     ]
    }
   ],
   "source": [
    "# Checking the dataset\n",
    "print('Training Set:\\n')\n",
    "for images, labels in train_loader:  \n",
    "    print('Image batch dimensions:', images.size())\n",
    "    print('Image label dimensions:', labels.size())\n",
    "    break\n",
    "    \n",
    "# Checking the dataset\n",
    "print('\\nValidation Set:')\n",
    "for images, labels in valid_loader:  \n",
    "    print('Image batch dimensions:', images.size())\n",
    "    print('Image label dimensions:', labels.size())\n",
    "    break\n",
    "\n",
    "# Checking the dataset\n",
    "print('\\nTesting Set:')\n",
    "for images, labels in test_loader:  \n",
    "    print('Image batch dimensions:', images.size())\n",
    "    print('Image label dimensions:', labels.size())\n",
    "    break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "##########################\n",
    "### MODEL\n",
    "##########################\n",
    "\n",
    "class AlexNet(nn.Module):\n",
    "\n",
    "    def __init__(self, num_classes):\n",
    "        super(AlexNet, self).__init__()\n",
    "        self.features = nn.Sequential(\n",
    "            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.MaxPool2d(kernel_size=3, stride=2),\n",
    "            nn.Conv2d(64, 192, kernel_size=5, padding=2),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.MaxPool2d(kernel_size=3, stride=2),\n",
    "            nn.Conv2d(192, 384, kernel_size=3, padding=1),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.Conv2d(384, 256, kernel_size=3, padding=1),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.Conv2d(256, 256, kernel_size=3, padding=1),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.MaxPool2d(kernel_size=3, stride=2),\n",
    "        )\n",
    "        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))\n",
    "        self.classifier = nn.Sequential(\n",
    "            nn.Dropout(0.5),\n",
    "            nn.Linear(256 * 6 * 6, 4096),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.Dropout(0.5),\n",
    "            nn.Linear(4096, 4096),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.Linear(4096, num_classes)\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.features(x)\n",
    "        x = self.avgpool(x)\n",
    "        x = x.view(x.size(0), 256 * 6 * 6)\n",
    "        logits = self.classifier(x)\n",
    "        probas = F.softmax(logits, dim=1)\n",
    "        return logits, probas\n"
   ]
  },
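  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The sketch below traces the spatial size of a 64x64 input (the crop size used in this notebook) through the convolution and max-pooling layers of `features`, using the standard output-size formula $\\lfloor (W - K + 2P)/S \\rfloor + 1$ and assuming the PyTorch defaults of dilation 1 and `ceil_mode=False`. The helper `out_size` is hypothetical and only for illustration. For 64x64 inputs, the last feature map is just 1x1, so `nn.AdaptiveAvgPool2d((6, 6))` effectively replicates each channel value into a 6x6 grid, which is why the classifier still receives `256 * 6 * 6` inputs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: trace the spatial size of a 64x64 input through the conv and\n",
    "# max-pooling layers of AlexNet.features, using floor((W - K + 2P) / S) + 1.\n",
    "def out_size(w, kernel, stride=1, padding=0):\n",
    "    return (w - kernel + 2 * padding) // stride + 1\n",
    "\n",
    "# (description, kernel size, stride, padding) for each layer, in order\n",
    "layer_specs = [\n",
    "    ('conv 11x11, stride 4, pad 2', 11, 4, 2),\n",
    "    ('maxpool 3x3, stride 2', 3, 2, 0),\n",
    "    ('conv 5x5, pad 2', 5, 1, 2),\n",
    "    ('maxpool 3x3, stride 2', 3, 2, 0),\n",
    "    ('conv 3x3, pad 1', 3, 1, 1),\n",
    "    ('conv 3x3, pad 1', 3, 1, 1),\n",
    "    ('conv 3x3, pad 1', 3, 1, 1),\n",
    "    ('maxpool 3x3, stride 2', 3, 2, 0),\n",
    "]\n",
    "\n",
    "width = 64\n",
    "spatial_sizes = [width]\n",
    "for name, k, s, p in layer_specs:\n",
    "    width = out_size(width, k, s, p)\n",
    "    spatial_sizes.append(width)\n",
    "    print(f'{name:<28s} -> {width}x{width}')"
   ]
  },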
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_acc(model, data_loader, device):\n",
    "    correct_pred, num_examples = 0, 0\n",
    "    model.eval()\n",
    "    for features, targets in data_loader:\n",
    "        features = features.to(device)\n",
    "        targets = targets.to(device)\n",
    "\n",
    "        logits, probas = model(features)\n",
    "        _, predicted_labels = torch.max(probas, 1)\n",
    "        num_examples += targets.size(0)\n",
    "        assert predicted_labels.size() == targets.size()\n",
    "        correct_pred += (predicted_labels == targets).sum()\n",
    "    return correct_pred.float()/num_examples * 100"
   ]
  },
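  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The toy example below (made-up logits and labels, using NumPy arrays instead of PyTorch tensors) mirrors the accuracy computation in `compute_acc`: it counts how often the highest-scoring class matches the target and reports the result as a percentage. Since softmax preserves the ordering of the logits within each row, the argmax of `probas` equals the argmax of `logits`, so the softmax is not strictly needed for computing accuracy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Toy logits for 4 examples and 3 classes (made-up values)\n",
    "toy_logits = np.array([[2.0, 0.5, -1.0],\n",
    "                       [0.1, 0.3, 0.2],\n",
    "                       [-0.5, 1.5, 3.0],\n",
    "                       [1.0, 2.0, 0.0]])\n",
    "toy_targets = np.array([0, 2, 2, 1])\n",
    "\n",
    "# Numerically stable softmax along the class axis\n",
    "shifted_exp = np.exp(toy_logits - toy_logits.max(axis=1, keepdims=True))\n",
    "toy_probas = shifted_exp / shifted_exp.sum(axis=1, keepdims=True)\n",
    "\n",
    "pred_from_probas = toy_probas.argmax(axis=1)\n",
    "pred_from_logits = toy_logits.argmax(axis=1)\n",
    "\n",
    "# Same computation as compute_acc: correct predictions / examples * 100\n",
    "toy_acc = (pred_from_probas == toy_targets).sum() / toy_targets.shape[0] * 100\n",
    "print('Same predictions:', bool((pred_from_probas == pred_from_logits).all()))\n",
    "print(f'Accuracy: {toy_acc:.2f}%')"
   ]
  },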
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "RAodboScj5w6"
   },
   "source": [
    "# Training 1: Constant Batch Size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     }
    },
    "colab_type": "code",
    "id": "_lza9t_uj5w1"
   },
   "outputs": [],
   "source": [
    "torch.manual_seed(RANDOM_SEED)\n",
    "\n",
    "model = AlexNet(NUM_CLASSES)\n",
    "model.to(DEVICE)\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     },
     "base_uri": "https://localhost:8080/",
     "height": 1547
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 2384585,
     "status": "ok",
     "timestamp": 1524976888520,
     "user": {
      "displayName": "Sebastian Raschka",
      "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg",
      "userId": "118404394130788869227"
     },
     "user_tz": 240
    },
    "id": "Dzh3ROmRj5w7",
    "outputId": "5f8fd8c9-b076-403a-b0b7-fd2d498b48d7"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 001/040 | Batch 000/188 | Cost: 2.3029\n",
      "Epoch: 001/040 | Batch 150/188 | Cost: 1.7122\n",
      "Epoch: 001/040\n",
      "Train ACC: 32.64 | Validation ACC: 31.45\n",
      "Time elapsed: 0.21 min\n",
      "Epoch: 002/040 | Batch 000/188 | Cost: 1.7477\n",
      "Epoch: 002/040 | Batch 150/188 | Cost: 1.6831\n",
      "Epoch: 002/040\n",
      "Train ACC: 44.09 | Validation ACC: 43.40\n",
      "Time elapsed: 0.42 min\n",
      "Epoch: 003/040 | Batch 000/188 | Cost: 1.5064\n",
      "Epoch: 003/040 | Batch 150/188 | Cost: 1.4504\n",
      "Epoch: 003/040\n",
      "Train ACC: 51.37 | Validation ACC: 51.00\n",
      "Time elapsed: 0.63 min\n",
      "Epoch: 004/040 | Batch 000/188 | Cost: 1.4089\n",
      "Epoch: 004/040 | Batch 150/188 | Cost: 1.2423\n",
      "Epoch: 004/040\n",
      "Train ACC: 58.56 | Validation ACC: 57.75\n",
      "Time elapsed: 0.84 min\n",
      "Epoch: 005/040 | Batch 000/188 | Cost: 1.0506\n",
      "Epoch: 005/040 | Batch 150/188 | Cost: 1.1601\n",
      "Epoch: 005/040\n",
      "Train ACC: 59.14 | Validation ACC: 58.55\n",
      "Time elapsed: 1.06 min\n",
      "Epoch: 006/040 | Batch 000/188 | Cost: 1.0774\n",
      "Epoch: 006/040 | Batch 150/188 | Cost: 1.1084\n",
      "Epoch: 006/040\n",
      "Train ACC: 62.25 | Validation ACC: 60.95\n",
      "Time elapsed: 1.27 min\n",
      "Epoch: 007/040 | Batch 000/188 | Cost: 1.0387\n",
      "Epoch: 007/040 | Batch 150/188 | Cost: 1.0570\n",
      "Epoch: 007/040\n",
      "Train ACC: 65.41 | Validation ACC: 63.20\n",
      "Time elapsed: 1.49 min\n",
      "Epoch: 008/040 | Batch 000/188 | Cost: 1.0650\n",
      "Epoch: 008/040 | Batch 150/188 | Cost: 0.9280\n",
      "Epoch: 008/040\n",
      "Train ACC: 64.37 | Validation ACC: 63.95\n",
      "Time elapsed: 1.70 min\n",
      "Epoch: 009/040 | Batch 000/188 | Cost: 1.0195\n",
      "Epoch: 009/040 | Batch 150/188 | Cost: 0.7793\n",
      "Epoch: 009/040\n",
      "Train ACC: 69.71 | Validation ACC: 67.30\n",
      "Time elapsed: 1.91 min\n",
      "Epoch: 010/040 | Batch 000/188 | Cost: 0.7986\n",
      "Epoch: 010/040 | Batch 150/188 | Cost: 0.7988\n",
      "Epoch: 010/040\n",
      "Train ACC: 69.41 | Validation ACC: 65.45\n",
      "Time elapsed: 2.12 min\n",
      "Epoch: 011/040 | Batch 000/188 | Cost: 0.8688\n",
      "Epoch: 011/040 | Batch 150/188 | Cost: 0.7943\n",
      "Epoch: 011/040\n",
      "Train ACC: 70.95 | Validation ACC: 67.35\n",
      "Time elapsed: 2.34 min\n",
      "Epoch: 012/040 | Batch 000/188 | Cost: 0.7696\n",
      "Epoch: 012/040 | Batch 150/188 | Cost: 0.8943\n",
      "Epoch: 012/040\n",
      "Train ACC: 75.26 | Validation ACC: 67.95\n",
      "Time elapsed: 2.55 min\n",
      "Epoch: 013/040 | Batch 000/188 | Cost: 0.6622\n",
      "Epoch: 013/040 | Batch 150/188 | Cost: 0.7226\n",
      "Epoch: 013/040\n",
      "Train ACC: 77.99 | Validation ACC: 72.45\n",
      "Time elapsed: 2.76 min\n",
      "Epoch: 014/040 | Batch 000/188 | Cost: 0.6180\n",
      "Epoch: 014/040 | Batch 150/188 | Cost: 0.6502\n",
      "Epoch: 014/040\n",
      "Train ACC: 77.82 | Validation ACC: 70.85\n",
      "Time elapsed: 2.97 min\n",
      "Epoch: 015/040 | Batch 000/188 | Cost: 0.6359\n",
      "Epoch: 015/040 | Batch 150/188 | Cost: 0.8206\n",
      "Epoch: 015/040\n",
      "Train ACC: 79.41 | Validation ACC: 71.35\n",
      "Time elapsed: 3.18 min\n",
      "Epoch: 016/040 | Batch 000/188 | Cost: 0.6694\n",
      "Epoch: 016/040 | Batch 150/188 | Cost: 0.5700\n",
      "Epoch: 016/040\n",
      "Train ACC: 79.59 | Validation ACC: 70.75\n",
      "Time elapsed: 3.39 min\n",
      "Epoch: 017/040 | Batch 000/188 | Cost: 0.6395\n",
      "Epoch: 017/040 | Batch 150/188 | Cost: 0.5564\n",
      "Epoch: 017/040\n",
      "Train ACC: 82.24 | Validation ACC: 72.75\n",
      "Time elapsed: 3.61 min\n",
      "Epoch: 018/040 | Batch 000/188 | Cost: 0.5724\n",
      "Epoch: 018/040 | Batch 150/188 | Cost: 0.4650\n",
      "Epoch: 018/040\n",
      "Train ACC: 83.02 | Validation ACC: 71.55\n",
      "Time elapsed: 3.82 min\n",
      "Epoch: 019/040 | Batch 000/188 | Cost: 0.4790\n",
      "Epoch: 019/040 | Batch 150/188 | Cost: 0.4548\n",
      "Epoch: 019/040\n",
      "Train ACC: 84.87 | Validation ACC: 73.35\n",
      "Time elapsed: 4.03 min\n",
      "Epoch: 020/040 | Batch 000/188 | Cost: 0.4254\n",
      "Epoch: 020/040 | Batch 150/188 | Cost: 0.4183\n",
      "Epoch: 020/040\n",
      "Train ACC: 85.73 | Validation ACC: 72.55\n",
      "Time elapsed: 4.25 min\n",
      "Epoch: 021/040 | Batch 000/188 | Cost: 0.5254\n",
      "Epoch: 021/040 | Batch 150/188 | Cost: 0.4328\n",
      "Epoch: 021/040\n",
      "Train ACC: 85.22 | Validation ACC: 72.25\n",
      "Time elapsed: 4.46 min\n",
      "Epoch: 022/040 | Batch 000/188 | Cost: 0.4798\n",
      "Epoch: 022/040 | Batch 150/188 | Cost: 0.4075\n",
      "Epoch: 022/040\n",
      "Train ACC: 88.92 | Validation ACC: 73.90\n",
      "Time elapsed: 4.68 min\n",
      "Epoch: 023/040 | Batch 000/188 | Cost: 0.2946\n",
      "Epoch: 023/040 | Batch 150/188 | Cost: 0.3808\n",
      "Epoch: 023/040\n",
      "Train ACC: 89.33 | Validation ACC: 73.80\n",
      "Time elapsed: 4.88 min\n",
      "Epoch: 024/040 | Batch 000/188 | Cost: 0.2511\n",
      "Epoch: 024/040 | Batch 150/188 | Cost: 0.3758\n",
      "Epoch: 024/040\n",
      "Train ACC: 89.94 | Validation ACC: 74.20\n",
      "Time elapsed: 5.10 min\n",
      "Epoch: 025/040 | Batch 000/188 | Cost: 0.2348\n",
      "Epoch: 025/040 | Batch 150/188 | Cost: 0.4043\n",
      "Epoch: 025/040\n",
      "Train ACC: 90.37 | Validation ACC: 74.10\n",
      "Time elapsed: 5.31 min\n",
      "Epoch: 026/040 | Batch 000/188 | Cost: 0.2663\n",
      "Epoch: 026/040 | Batch 150/188 | Cost: 0.2651\n",
      "Epoch: 026/040\n",
      "Train ACC: 91.69 | Validation ACC: 72.55\n",
      "Time elapsed: 5.52 min\n",
      "Epoch: 027/040 | Batch 000/188 | Cost: 0.2907\n",
      "Epoch: 027/040 | Batch 150/188 | Cost: 0.2981\n",
      "Epoch: 027/040\n",
      "Train ACC: 92.33 | Validation ACC: 73.10\n",
      "Time elapsed: 5.74 min\n",
      "Epoch: 028/040 | Batch 000/188 | Cost: 0.2318\n",
      "Epoch: 028/040 | Batch 150/188 | Cost: 0.2904\n",
      "Epoch: 028/040\n",
      "Train ACC: 91.91 | Validation ACC: 72.10\n",
      "Time elapsed: 5.95 min\n",
      "Epoch: 029/040 | Batch 000/188 | Cost: 0.1949\n",
      "Epoch: 029/040 | Batch 150/188 | Cost: 0.1721\n",
      "Epoch: 029/040\n",
      "Train ACC: 93.64 | Validation ACC: 73.15\n",
      "Time elapsed: 6.16 min\n",
      "Epoch: 030/040 | Batch 000/188 | Cost: 0.1504\n",
      "Epoch: 030/040 | Batch 150/188 | Cost: 0.2986\n",
      "Epoch: 030/040\n",
      "Train ACC: 94.12 | Validation ACC: 73.50\n",
      "Time elapsed: 6.37 min\n",
      "Epoch: 031/040 | Batch 000/188 | Cost: 0.1666\n",
      "Epoch: 031/040 | Batch 150/188 | Cost: 0.1380\n",
      "Epoch: 031/040\n",
      "Train ACC: 92.82 | Validation ACC: 72.90\n",
      "Time elapsed: 6.59 min\n",
      "Epoch: 032/040 | Batch 000/188 | Cost: 0.2123\n",
      "Epoch: 032/040 | Batch 150/188 | Cost: 0.2601\n",
      "Epoch: 032/040\n",
      "Train ACC: 94.51 | Validation ACC: 72.80\n",
      "Time elapsed: 6.80 min\n",
      "Epoch: 033/040 | Batch 000/188 | Cost: 0.1769\n",
      "Epoch: 033/040 | Batch 150/188 | Cost: 0.1912\n",
      "Epoch: 033/040\n",
      "Train ACC: 94.81 | Validation ACC: 72.15\n",
      "Time elapsed: 7.01 min\n",
      "Epoch: 034/040 | Batch 000/188 | Cost: 0.2098\n",
      "Epoch: 034/040 | Batch 150/188 | Cost: 0.2454\n",
      "Epoch: 034/040\n",
      "Train ACC: 95.87 | Validation ACC: 73.25\n",
      "Time elapsed: 7.22 min\n",
      "Epoch: 035/040 | Batch 000/188 | Cost: 0.1446\n",
      "Epoch: 035/040 | Batch 150/188 | Cost: 0.1103\n",
      "Epoch: 035/040\n",
      "Train ACC: 94.59 | Validation ACC: 72.05\n",
      "Time elapsed: 7.43 min\n",
      "Epoch: 036/040 | Batch 000/188 | Cost: 0.1118\n",
      "Epoch: 036/040 | Batch 150/188 | Cost: 0.1148\n",
      "Epoch: 036/040\n",
      "Train ACC: 96.36 | Validation ACC: 74.30\n",
      "Time elapsed: 7.65 min\n",
      "Epoch: 037/040 | Batch 000/188 | Cost: 0.1138\n",
      "Epoch: 037/040 | Batch 150/188 | Cost: 0.2091\n",
      "Epoch: 037/040\n",
      "Train ACC: 95.63 | Validation ACC: 73.35\n",
      "Time elapsed: 7.85 min\n",
      "Epoch: 038/040 | Batch 000/188 | Cost: 0.1720\n",
      "Epoch: 038/040 | Batch 150/188 | Cost: 0.0837\n",
      "Epoch: 038/040\n",
      "Train ACC: 95.77 | Validation ACC: 74.10\n",
      "Time elapsed: 8.07 min\n",
      "Epoch: 039/040 | Batch 000/188 | Cost: 0.1058\n",
      "Epoch: 039/040 | Batch 150/188 | Cost: 0.0731\n",
      "Epoch: 039/040\n",
      "Train ACC: 97.03 | Validation ACC: 73.55\n",
      "Time elapsed: 8.28 min\n",
      "Epoch: 040/040 | Batch 000/188 | Cost: 0.1014\n",
      "Epoch: 040/040 | Batch 150/188 | Cost: 0.1611\n",
      "Epoch: 040/040\n",
      "Train ACC: 96.68 | Validation ACC: 72.30\n",
      "Time elapsed: 8.49 min\n",
      "Total Training Time: 8.49 min\n"
     ]
    }
   ],
   "source": [
    "start_time = time.time()\n",
    "\n",
    "cost_list = []\n",
    "train_acc_list, valid_acc_list = [], []\n",
    "\n",
    "\n",
    "for epoch in range(NUM_EPOCHS):\n",
    "    \n",
    "    model.train()\n",
    "    for batch_idx, (features, targets) in enumerate(train_loader):\n",
    "        \n",
    "        features = features.to(DEVICE)\n",
    "        targets = targets.to(DEVICE)\n",
    "            \n",
    "        ### FORWARD AND BACK PROP\n",
    "        logits, probas = model(features)\n",
    "        cost = F.cross_entropy(logits, targets)\n",
    "        optimizer.zero_grad()\n",
    "        \n",
    "        cost.backward()\n",
    "        \n",
    "        ### UPDATE MODEL PARAMETERS\n",
    "        optimizer.step()\n",
    "        \n",
    "        #################################################\n",
    "        ### CODE ONLY FOR LOGGING BEYOND THIS POINT\n",
    "        ################################################\n",
    "        cost_list.append(cost.item())\n",
    "        if not batch_idx % 150:\n",
    "            print (f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d} | '\n",
    "                   f'Batch {batch_idx:03d}/{len(train_loader):03d} |' \n",
    "                   f' Cost: {cost:.4f}')\n",
    "\n",
    "        \n",
    "\n",
    "    model.eval()\n",
    "    with torch.set_grad_enabled(False): # save memory during inference\n",
    "        \n",
    "        train_acc = compute_acc(model, train_loader, device=DEVICE)\n",
    "        valid_acc = compute_acc(model, valid_loader, device=DEVICE)\n",
    "        \n",
    "        print(f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d}\\n'\n",
    "              f'Train ACC: {train_acc:.2f} | Validation ACC: {valid_acc:.2f}')\n",
    "        \n",
    "        train_acc_list.append(train_acc)\n",
    "        valid_acc_list.append(valid_acc)\n",
    "        \n",
    "    elapsed = (time.time() - start_time)/60\n",
    "    print(f'Time elapsed: {elapsed:.2f} min')\n",
    "  \n",
    "elapsed = (time.time() - start_time)/60\n",
    "print(f'Total Training Time: {elapsed:.2f} min')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.plot(cost_list, label='Minibatch cost')\n",
    "plt.plot(np.convolve(cost_list, \n",
    "                     np.ones(200,)/200, mode='valid'), \n",
    "         label='Running average')\n",
    "\n",
    "plt.ylabel('Cross Entropy')\n",
    "plt.xlabel('Iteration')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
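  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The running-average curve above uses `np.convolve` with a uniform kernel and `mode='valid'`. The following toy sketch (a window of 3 instead of 200, with made-up cost values) shows that each output value averages `window` consecutive minibatch costs and that the result is `window - 1` elements shorter than the input; in the plot above, each averaged value is therefore drawn at the first index of its window."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Toy minibatch costs and a small window (the plot above uses 200)\n",
    "toy_costs = np.array([2.0, 1.0, 3.0, 2.0, 4.0, 0.0])\n",
    "window = 3\n",
    "\n",
    "# Uniform kernel -> moving average; 'valid' keeps only full windows\n",
    "running_avg = np.convolve(toy_costs, np.ones(window) / window, mode='valid')\n",
    "print(running_avg)\n",
    "print(len(toy_costs), '->', len(running_avg))"
   ]
  },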
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.plot(np.arange(1, NUM_EPOCHS+1), train_acc_list, label='Training')\n",
    "plt.plot(np.arange(1, NUM_EPOCHS+1), valid_acc_list, label='Validation')\n",
    "\n",
    "plt.xlabel('Epoch')\n",
    "plt.ylabel('Accuracy')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation ACC: 71.60%\n",
      "Test ACC: 72.37%\n"
     ]
    }
   ],
   "source": [
    "with torch.set_grad_enabled(False):\n",
    "    test_acc = compute_acc(model=model,\n",
    "                           data_loader=test_loader,\n",
    "                           device=DEVICE)\n",
    "    \n",
    "    valid_acc = compute_acc(model=model,\n",
    "                            data_loader=valid_loader,\n",
    "                            device=DEVICE)\n",
    "    \n",
    "\n",
    "print(f'Validation ACC: {valid_acc:.2f}%')\n",
    "print(f'Test ACC: {test_acc:.2f}%')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training 2: Increasing Batch Size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(RANDOM_SEED)\n",
    "\n",
    "model = AlexNet(NUM_CLASSES)\n",
    "model.to(DEVICE)\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_sizes = np.arange(256, 5121, 512)\n",
    "batch_size_index = 0"
   ]
  },
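  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the sketch below lists the candidate batch sizes produced by `np.arange(256, 5121, 512)` and the number of parameter updates per epoch each one implies for the 48,000-image training subset, assuming the `DataLoader` default `drop_last=False` (so a final partial batch still counts as an update). The largest value is 4,864 rather than 5,120, because the values step by 512 starting from 256. The 188 updates for a batch size of 256 match the `Batch 000/188` entries in the training logs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import numpy as np\n",
    "\n",
    "candidate_batch_sizes = np.arange(256, 5121, 512)\n",
    "n_train = 48000  # size of the training subset (train_indices)\n",
    "\n",
    "# Updates per epoch with drop_last=False: the last partial batch still counts\n",
    "updates_per_epoch = [math.ceil(n_train / int(b)) for b in candidate_batch_sizes]\n",
    "\n",
    "for b, u in zip(candidate_batch_sizes, updates_per_epoch):\n",
    "    print(f'batch size {int(b):5d} -> {u:3d} updates per epoch')"
   ]
  },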
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 001/040 | Batch 000/188 | Cost: 2.3029 | Batchsize: 256\n",
      "Epoch: 001/040 | Batch 150/188 | Cost: 1.7115 | Batchsize: 256\n",
      "Epoch: 001/040\n",
      "Train ACC: 33.36 | Validation ACC: 33.00\n",
      "Time elapsed: 0.21 min\n",
      "Epoch: 002/040 | Batch 000/188 | Cost: 1.7286 | Batchsize: 256\n",
      "Epoch: 002/040 | Batch 150/188 | Cost: 1.6143 | Batchsize: 256\n",
      "Epoch: 002/040\n",
      "Train ACC: 44.84 | Validation ACC: 45.55\n",
      "Time elapsed: 0.42 min\n",
      "Epoch: 003/040 | Batch 000/188 | Cost: 1.5018 | Batchsize: 256\n",
      "Epoch: 003/040 | Batch 150/188 | Cost: 1.4893 | Batchsize: 256\n",
      "Epoch: 003/040\n",
      "Train ACC: 50.73 | Validation ACC: 50.55\n",
      "Time elapsed: 0.64 min\n",
      "Epoch: 004/040 | Batch 000/188 | Cost: 1.4247 | Batchsize: 256\n",
      "Epoch: 004/040 | Batch 150/188 | Cost: 1.2653 | Batchsize: 256\n",
      "Epoch: 004/040\n",
      "Train ACC: 56.70 | Validation ACC: 57.35\n",
      "Time elapsed: 0.85 min\n",
      "Epoch: 005/040 | Batch 000/188 | Cost: 1.0885 | Batchsize: 256\n",
      "Epoch: 005/040 | Batch 150/188 | Cost: 1.1472 | Batchsize: 256\n",
      "Epoch: 005/040\n",
      "Train ACC: 60.30 | Validation ACC: 58.45\n",
      "Time elapsed: 1.06 min\n",
      "Epoch: 006/040 | Batch 000/188 | Cost: 1.0394 | Batchsize: 256\n",
      "Epoch: 006/040 | Batch 150/188 | Cost: 1.0907 | Batchsize: 256\n",
      "Epoch: 006/040\n",
      "Train ACC: 62.62 | Validation ACC: 60.85\n",
      "Time elapsed: 1.27 min\n",
      "Epoch: 007/040 | Batch 000/188 | Cost: 1.0348 | Batchsize: 256\n",
      "Epoch: 007/040 | Batch 150/188 | Cost: 1.0401 | Batchsize: 256\n",
      "Epoch: 007/040\n",
      "Train ACC: 66.19 | Validation ACC: 65.60\n",
      "Time elapsed: 1.48 min\n",
      "Epoch: 008/040 | Batch 000/188 | Cost: 1.0627 | Batchsize: 256\n",
      "Epoch: 008/040 | Batch 150/188 | Cost: 0.9297 | Batchsize: 256\n",
      "Epoch: 008/040\n",
      "Train ACC: 64.11 | Validation ACC: 63.90\n",
      "Time elapsed: 1.70 min\n",
      "Epoch: 009/040 | Batch 000/188 | Cost: 1.0361 | Batchsize: 256\n",
      "Epoch: 009/040 | Batch 150/188 | Cost: 0.8127 | Batchsize: 256\n",
      "Epoch: 009/040\n",
      "Train ACC: 69.89 | Validation ACC: 65.45\n",
      "Time elapsed: 1.91 min\n",
      "Epoch: 010/040 | Batch 000/188 | Cost: 0.7913 | Batchsize: 256\n",
      "Epoch: 010/040 | Batch 150/188 | Cost: 0.7620 | Batchsize: 256\n",
      "Epoch: 010/040\n",
      "Train ACC: 69.22 | Validation ACC: 66.50\n",
      "Time elapsed: 2.12 min\n",
      "Epoch: 011/040 | Batch 000/188 | Cost: 0.8304 | Batchsize: 256\n",
      "Epoch: 011/040 | Batch 150/188 | Cost: 0.8406 | Batchsize: 256\n",
      "Epoch: 011/040\n",
      "Train ACC: 71.92 | Validation ACC: 68.50\n",
      "Time elapsed: 2.33 min\n",
      "Epoch: 012/040 | Batch 000/188 | Cost: 0.6939 | Batchsize: 256\n",
      "Epoch: 012/040 | Batch 150/188 | Cost: 0.9586 | Batchsize: 256\n",
      "Epoch: 012/040\n",
      "Train ACC: 73.86 | Validation ACC: 67.45\n",
      "Time elapsed: 2.54 min\n",
      "Epoch: 013/040 | Batch 000/188 | Cost: 0.7050 | Batchsize: 256\n",
      "Epoch: 013/040 | Batch 150/188 | Cost: 0.6281 | Batchsize: 256\n",
      "Epoch: 013/040\n",
      "Train ACC: 77.54 | Validation ACC: 70.90\n",
      "Time elapsed: 2.76 min\n",
      "Epoch: 014/040 | Batch 000/188 | Cost: 0.6453 | Batchsize: 256\n",
      "Epoch: 014/040 | Batch 150/188 | Cost: 0.6312 | Batchsize: 256\n",
      "Epoch: 014/040\n",
      "Train ACC: 76.89 | Validation ACC: 69.80\n",
      "Time elapsed: 2.97 min\n",
      "Epoch: 015/040 | Batch 000/188 | Cost: 0.6457 | Batchsize: 256\n",
      "Epoch: 015/040 | Batch 150/188 | Cost: 0.7908 | Batchsize: 256\n",
      "Epoch: 015/040\n",
      "Train ACC: 78.62 | Validation ACC: 71.50\n",
      "Time elapsed: 3.18 min\n",
      "Epoch: 016/040 | Batch 000/188 | Cost: 0.7273 | Batchsize: 256\n",
      "Epoch: 016/040 | Batch 150/188 | Cost: 0.5583 | Batchsize: 256\n",
      "Epoch: 016/040\n",
      "Train ACC: 80.89 | Validation ACC: 70.75\n",
      "Time elapsed: 3.39 min\n",
      "Epoch: 017/040 | Batch 000/188 | Cost: 0.5611 | Batchsize: 256\n",
      "Epoch: 017/040 | Batch 150/188 | Cost: 0.5131 | Batchsize: 256\n",
      "Epoch: 017/040\n",
      "Train ACC: 83.01 | Validation ACC: 71.25\n",
      "Time elapsed: 3.60 min\n",
      "Epoch: 018/040 | Batch 000/188 | Cost: 0.5365 | Batchsize: 256\n",
      "Epoch: 018/040 | Batch 150/188 | Cost: 0.4436 | Batchsize: 256\n",
      "Epoch: 018/040\n",
      "Train ACC: 81.85 | Validation ACC: 71.55\n",
      "Time elapsed: 3.81 min\n",
      "Epoch: 019/040 | Batch 000/188 | Cost: 0.4803 | Batchsize: 256\n",
      "Epoch: 019/040 | Batch 150/188 | Cost: 0.4372 | Batchsize: 256\n",
      "Epoch: 019/040\n",
      "Train ACC: 84.76 | Validation ACC: 73.60\n",
      "Time elapsed: 4.03 min\n",
      "Epoch: 020/040 | Batch 000/188 | Cost: 0.4905 | Batchsize: 256\n",
      "Epoch: 020/040 | Batch 150/188 | Cost: 0.4021 | Batchsize: 256\n",
      "Epoch: 020/040\n",
      "Train ACC: 85.10 | Validation ACC: 71.25\n",
      "Time elapsed: 4.24 min\n",
      "Epoch: 021/040 | Batch 000/188 | Cost: 0.4978 | Batchsize: 256\n",
      "Epoch: 021/040 | Batch 150/188 | Cost: 0.4828 | Batchsize: 256\n",
      "Epoch: 021/040\n",
      "Train ACC: 87.19 | Validation ACC: 72.75\n",
      "Time elapsed: 4.45 min\n",
      "Epoch: 022/040 | Batch 000/188 | Cost: 0.3978 | Batchsize: 256\n",
      "Epoch: 022/040 | Batch 150/188 | Cost: 0.4588 | Batchsize: 256\n",
      "Epoch: 022/040\n",
      "Train ACC: 87.93 | Validation ACC: 72.20\n",
      "Time elapsed: 4.66 min\n",
      "Epoch: 023/040 | Batch 000/188 | Cost: 0.3476 | Batchsize: 256\n",
      "Epoch: 023/040 | Batch 150/188 | Cost: 0.3774 | Batchsize: 256\n",
      "Epoch: 023/040\n",
      "Train ACC: 90.10 | Validation ACC: 72.35\n",
      "Time elapsed: 4.87 min\n",
      "Epoch: 024/040 | Batch 000/188 | Cost: 0.3039 | Batchsize: 256\n",
      "Epoch: 024/040 | Batch 150/188 | Cost: 0.4589 | Batchsize: 256\n",
      "Epoch: 024/040\n",
      "Train ACC: 89.20 | Validation ACC: 72.00\n",
      "Time elapsed: 5.09 min\n",
      "Epoch: 025/040 | Batch 000/188 | Cost: 0.2648 | Batchsize: 768\n",
      "Epoch: 025/040 | Batch 150/188 | Cost: 0.3186 | Batchsize: 768\n",
      "Epoch: 025/040\n",
      "Train ACC: 91.24 | Validation ACC: 72.55\n",
      "Time elapsed: 5.30 min\n",
      "Epoch: 026/040 | Batch 000/188 | Cost: 0.2093 | Batchsize: 768\n",
      "Epoch: 026/040 | Batch 150/188 | Cost: 0.3252 | Batchsize: 768\n",
      "Epoch: 026/040\n",
      "Train ACC: 90.77 | Validation ACC: 71.80\n",
      "Time elapsed: 5.51 min\n",
      "Epoch: 027/040 | Batch 000/188 | Cost: 0.3375 | Batchsize: 768\n",
      "Epoch: 027/040 | Batch 150/188 | Cost: 0.2307 | Batchsize: 768\n",
      "Epoch: 027/040\n",
      "Train ACC: 92.61 | Validation ACC: 73.15\n",
      "Time elapsed: 5.72 min\n",
      "Epoch: 028/040 | Batch 000/188 | Cost: 0.2307 | Batchsize: 768\n",
      "Epoch: 028/040 | Batch 150/188 | Cost: 0.2596 | Batchsize: 768\n",
      "Epoch: 028/040\n",
      "Train ACC: 90.78 | Validation ACC: 70.25\n",
      "Time elapsed: 5.94 min\n",
      "Epoch: 029/040 | Batch 000/063 | Cost: 0.2773 | Batchsize: 1280\n",
      "Epoch: 029/040\n",
      "Train ACC: 96.33 | Validation ACC: 75.60\n",
      "Time elapsed: 6.11 min\n",
      "Epoch: 030/040 | Batch 000/063 | Cost: 0.0958 | Batchsize: 1280\n",
      "Epoch: 030/040\n",
      "Train ACC: 96.87 | Validation ACC: 74.95\n",
      "Time elapsed: 6.28 min\n",
      "Epoch: 031/040 | Batch 000/063 | Cost: 0.1020 | Batchsize: 1280\n",
      "Epoch: 031/040\n",
      "Train ACC: 97.30 | Validation ACC: 74.40\n",
      "Time elapsed: 6.44 min\n",
      "Epoch: 032/040 | Batch 000/063 | Cost: 0.0750 | Batchsize: 1280\n",
      "Epoch: 032/040\n",
      "Train ACC: 97.54 | Validation ACC: 75.00\n",
      "Time elapsed: 6.61 min\n",
      "Epoch: 033/040 | Batch 000/038 | Cost: 0.0687 | Batchsize: 1792\n",
      "Epoch: 033/040\n",
      "Train ACC: 98.05 | Validation ACC: 76.20\n",
      "Time elapsed: 6.79 min\n",
      "Epoch: 034/040 | Batch 000/038 | Cost: 0.0607 | Batchsize: 1792\n",
      "Epoch: 034/040\n",
      "Train ACC: 98.19 | Validation ACC: 75.25\n",
      "Time elapsed: 6.96 min\n",
      "Epoch: 035/040 | Batch 000/038 | Cost: 0.0577 | Batchsize: 1792\n",
      "Epoch: 035/040\n",
      "Train ACC: 98.34 | Validation ACC: 75.00\n",
      "Time elapsed: 7.13 min\n",
      "Epoch: 036/040 | Batch 000/038 | Cost: 0.0546 | Batchsize: 1792\n",
      "Epoch: 036/040\n",
      "Train ACC: 98.30 | Validation ACC: 75.35\n",
      "Time elapsed: 7.30 min\n",
      "Epoch: 037/040 | Batch 000/027 | Cost: 0.0610 | Batchsize: 2304\n",
      "Epoch: 037/040\n",
      "Train ACC: 98.56 | Validation ACC: 75.15\n",
      "Time elapsed: 7.47 min\n",
      "Epoch: 038/040 | Batch 000/027 | Cost: 0.0544 | Batchsize: 2304\n",
      "Epoch: 038/040\n",
      "Train ACC: 98.78 | Validation ACC: 75.30\n",
      "Time elapsed: 7.64 min\n",
      "Epoch: 039/040 | Batch 000/027 | Cost: 0.0431 | Batchsize: 2304\n",
      "Epoch: 039/040\n",
      "Train ACC: 98.84 | Validation ACC: 76.75\n",
      "Time elapsed: 7.81 min\n",
      "Epoch: 040/040 | Batch 000/027 | Cost: 0.0455 | Batchsize: 2304\n",
      "Epoch: 040/040\n",
      "Train ACC: 98.84 | Validation ACC: 74.80\n",
      "Time elapsed: 7.98 min\n",
      "Total Training Time: 7.98 min\n"
     ]
    }
   ],
   "source": [
    "start_time = time.time()\n",
    "\n",
    "cost_list = []\n",
    "train_acc_list, valid_acc_list = [], []\n",
    "\n",
    "\n",
    "for epoch in range(NUM_EPOCHS):\n",
    "    \n",
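    "    ### INCREASE BATCH SIZE\n",
    "    # During the second half of training, rebuild the training loader every\n",
    "    # NUM_EPOCHS//len(batch_sizes) epochs using batch_sizes[batch_size_index],\n",
    "    # then advance the index for the next rebuild\n",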
    "    if epoch > (NUM_EPOCHS//2) and not epoch % (NUM_EPOCHS//len(batch_sizes)):\n",
    "        train_loader = DataLoader(dataset=train_dataset, \n",
    "                                  batch_size=int(batch_sizes[batch_size_index]),\n",
    "                                  num_workers=4,\n",
    "                                  shuffle=True)\n",
    "\n",
    "        batch_size_index += 1\n",
    "    \n",
    "    \n",
    "    model.train()\n",
    "    for batch_idx, (features, targets) in enumerate(train_loader):\n",
    "        \n",
    "        features = features.to(DEVICE)\n",
    "        targets = targets.to(DEVICE)\n",
    "            \n",
    "        ### FORWARD AND BACK PROP\n",
    "        logits, probas = model(features)\n",
    "        cost = F.cross_entropy(logits, targets)\n",
    "        optimizer.zero_grad()\n",
    "        \n",
    "        cost.backward()\n",
    "        \n",
    "        ### UPDATE MODEL PARAMETERS\n",
    "        optimizer.step()\n",
    "        \n",
    "        #################################################\n",
    "        ### CODE ONLY FOR LOGGING BEYOND THIS POINT\n",
    "        ################################################\n",
    "        cost_list.append(cost.item())\n",
    "        if not batch_idx % 150:\n",
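    "            # after a rebuild, batch_size_index already points to the *next*\n",
    "            # size, so read the size in use from the loader itself\n",
    "            print(f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d} | '\n",
    "                  f'Batch {batch_idx:03d}/{len(train_loader):03d} |'\n",
    "                  f' Cost: {cost.item():.4f} | Batchsize: {train_loader.batch_size}')\n",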
    "\n",
    "        \n",
    "\n",
    "    model.eval()\n",
    "    with torch.set_grad_enabled(False): # save memory during inference\n",
    "        \n",
    "        train_acc = compute_acc(model, train_loader, device=DEVICE)\n",
    "        valid_acc = compute_acc(model, valid_loader, device=DEVICE)\n",
    "        \n",
    "        print(f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d}\\n'\n",
    "              f'Train ACC: {train_acc:.2f} | Validation ACC: {valid_acc:.2f}')\n",
    "        \n",
    "        train_acc_list.append(train_acc)\n",
    "        valid_acc_list.append(valid_acc)\n",
    "        \n",
    "    elapsed = (time.time() - start_time)/60\n",
    "    print(f'Time elapsed: {elapsed:.2f} min')\n",
    "  \n",
    "elapsed = (time.time() - start_time)/60\n",
    "print(f'Total Training Time: {elapsed:.2f} min')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(cost_list, label='Minibatch cost')\n",
    "plt.plot(np.convolve(cost_list, \n",
    "                     np.ones(200,)/200, mode='valid'), \n",
    "         label='Running average')\n",
    "\n",
    "plt.ylabel('Cross Entropy')\n",
    "plt.xlabel('Iteration')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAEGCAYAAACKB4k+AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjAsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+17YcXAAAgAElEQVR4nO3dd3hUZdr48e+dThqBhIROqFJiaAGpimLBCiqK2EBR7PXnruyur6776i67r20tq4IIuCrFgmBDXYQVpIcSunQICakkpNfn98cZQoA0ksmcSXJ/rmuumTlz5pw7Rzn3PM95nvuIMQallFIKwMPuAJRSSrkPTQpKKaXKaFJQSilVRpOCUkqpMpoUlFJKlfGyO4C6CAsLM5GRkXaHoZRSDUpsbGyqMaZVRZ816KQQGRnJxo0b7Q5DKaUaFBE5XNln2n2klFKqjCYFpZRSZeotKYjIhyKSLCLbyy1rKSI/ichex3MLx3IRkTdFZJ+IxInIgPqKSymlVOXq85rCHOBt4KNyy6YBy4wx00VkmuP9s8DVQHfH4yLgXcfzeSsqKiI+Pp78/Pw6hK7K8/Pzo3379nh7e9sdilKqntVbUjDG/CIikWctHguMcryeC6zASgpjgY+MVYhprYiEiEgbY0zi+e43Pj6eoKAgIiMjEZHahq8cjDGkpaURHx9P586d7Q5HKVXPXH1NIeLUid7xHO5Y3g44Wm69eMeyc4jIVBHZKCIbU1JSzvk8Pz+f0NBQTQhOIiKEhoZqy0upJsJdLjRXdAavsHyrMWaGMSbGGBPTqlWFw2w1ITiZHk+lmg5Xz1NIOtUtJCJtgGTH8nigQ7n12gMJLo5NKaXqTX5RCWk5haRlF5CWXUhWQTFFxaUUl5ZSVGIoLimluNSUvS41YBy/jY0p9yvZcbuD0b0i6NshxOlxujopLAEmAdMdz4vLLX9UROZjXWDOrM31BHeQlpbG6NGjATh+/Dienp6catGsX78eHx+fardxzz33MG3aNC644IJK13nnnXcICQnhjjvucE7gSqnzVlJqSMsuIOlkAclZ+Wc8p2QVkJZjJYC07AJyCkuctl8RCA/2a1hJQUTmYV1UDhOReOAFrGSwUESmAEeAWxyrfwdcA+wDcoF76iuu+hYaGsqWLVsA+POf/0xgYCDPPPPMGesYYzDG4OFRce/d7Nmzq93PI488UvdglWrEsguKOZ6ZR1Z+MdkFxWTnF5PleM4usB5Z+cVk5ReVvc52vM9yfC6Ap4fg6SF4eXiUvfb0EIwxpOcUUlpBR3dogA+tgnwJC/SlY0d/QgN8CQ30ISzQp+x1kJ8XXh4eeHkKPp4eeHlar70dyzwd3banem9d1Y1bn6OPJlby0egK1jVAoz7L7du3j3HjxjFixAjWrVvHN998w4svvsimTZvIy8tjwoQJPP/88wCMGDGCt99+m6ioKMLCwnjwwQf5/vvv8ff3Z/HixYSHh/Pcc88RFhbGk08+yYgRIxgxYgQ///wzmZmZzJ49m2HDhpGTk8Pdd9/Nvn376N27N3v37uWDDz6gX79+Nh8NpZzrZH4R249lsuPYSbYdy2T7sUwOpuVQ1Y0lfbw8CPbzIsjPmyA/LwJ9vQgL8yfIz5tAX+u9CBSXGkpLDcWlhhLHo9iRCcICfQgP9iM8yJfwIF8igv0IC/TFx8tdLteevwZd+6g6L369g50JJ526zd5tg3nh+j61+u7OnTuZPXs27733HgDTp0+nZcuWFBcXc+mllzJ+/Hh69+59xncyMzO55JJLmD59Ok8//TQffvgh06ZNO2fbxhjWr1/PkiVL+Mtf/sLSpUt56623aN26NV988QVbt25lwACdE6gaHmMMmXlFpGQVkJJdQGp2IamO10fTc9l+LJNDabll67dp7kdUu+aM69+OTqH+BPt5E+g46Z96BPh6NegTd31q1EnB3XTt2pVBgwaVvZ83bx6zZs2i
uLiYhIQEdu7ceU5SaNasGVdffTUAAwcOZOXKlRVu+6abbipb59ChQwCsWrWKZ599FoC+ffvSp0/tkplSrnAip5DfkrLYm5zNvuRsfkvK4mBqDqnZBRSVnPuT38tDaN3cj6i2zRk/sD1R7ZoT1a45YYG+NkTfeDTqpFDbX/T1JSAgoOz13r17+ec//8n69esJCQnhzjvvrHAuQPkL056enhQXF1e4bV9f33PWMVW1nZWySUmpYX9KNnHxmWyLz+C3pGz2JmeRml1Ytk6AjyfdIoIY2jW0rEsmLNCHVoG+ZX31zZt54+Ghw6WdrVEnBXd28uRJgoKCCA4OJjExkR9++IExY8Y4dR8jRoxg4cKFjBw5km3btrFz506nbl+p6hhjOJKey9b4TOKOZhB3LJMdxzLLRuL4+3jSIyKIy3qG0z08iO4RgXSPCKJtcz+dH2MTTQo2GTBgAL179yYqKoouXbowfPhwp+/jscce4+677yY6OpoBAwYQFRVF8+bNnb4fpSpSVFLK+HdXszU+E7Au7PZuE8z4ge2Jbh9CdPvmdGkViKf+2ncr0pC7GGJiYszZN9nZtWsXvXr1siki91JcXExxcTF+fn7s3buXK6+8kr179+Lldf6/BfS4qvO1aHM8Ty3YyhOju3NF7wh6RATpxV03ISKxxpiYij7TlkIjlp2dzejRoykuLsYYw/vvv1+rhKDU+TLG8N6KA/SICOSJ0d21778B0TNEIxYSEkJsbKzdYagmaPmeZPYkZfHarX01ITQw2pZTSjnduyv20y6kGdf3bWt3KOo8aVJQSjnVhkPpbDh0gvtHdsbbU08xDY3+F1NKOdV7K/bTwt+bWwd1qH5l5XY0KSilnGbP8SyW7U5m8rDO+PvoJcuGSJOCk40aNYoffvjhjGVvvPEGDz/8cKXfCQwMBCAhIYHx48dXut2zh9+e7Y033iA393QNmGuuuYaMjIyahq5Unb3/3/34+3hy99BOdoeiakmTgpNNnDiR+fPnn7Fs/vz5TJxYWdHY09q2bcvnn39e632fnRS+++47QkKcX29dqYrEn8hl8dYEJg7uSIuA6u8botyTJgUnGz9+PN988w0FBQUAHDp0iISEBPr168fo0aMZMGAAF154IYsXLz7nu4cOHSIqKgqAvLw8brvtNqKjo5kwYQJ5eXll6z300EPExMTQp08fXnjhBQDefPNNEhISuPTSS7n00ksBiIyMJDU1FYDXXnuNqKgooqKieOONN8r216tXL+6//3769OnDlVdeecZ+lDofH6w8iIfAfSM72x2KqoPG3en3/TQ4vs2522x9IVw9vdKPQ0NDGTx4MEuXLmXs2LHMnz+fCRMm0KxZMxYtWkRwcDCpqakMGTKEG264odL6Lu+++y7+/v7ExcURFxd3Rtnrl19+mZYtW1JSUsLo0aOJi4vj8ccf57XXXmP58uWEhYWdsa3Y2Fhmz57NunXrMMZw0UUXcckll9CiRQv27t3LvHnzmDlzJrfeeitffPEFd955p3OOlWrQSkoNM1ceYFt8Jn+8thftQppVum5adgHzNxxhXL92tGle+XrK/WlLoR6U70I61XVkjOGPf/wj0dHRXH755Rw7doykpKRKt/HLL7+UnZyjo6OJjo4u+2zhwoUMGDCA/v37s2PHjmoL3a1atYobb7yRgIAAAgMDuemmm8pKcHfu3Lnspjvly26rpi3+RC4TZ65l+ve7+WHHcca88QuLtxyrdP25qw+RX1TKA5d0cWGUqj407pZCFb/o69O4ceN4+umny+6qNmDAAObMmUNKSgqxsbF4e3sTGRlZYans8ipqRRw8eJBXXnmFDRs20KJFCyZPnlztdqqqb3Wq5DZYZbe1+0gt3nKM577ajjHw6i19iYlswVMLtvDE/C0s353Mi2OjaN7Mu2z9nIJi5q45zJW9I+gWHmRj5MoZtKVQDwIDAxk1ahT33ntv2QXmzMxMwsPD8fb2Zvny5Rw+fLjKbVx88cV88skn
AGzfvp24uDjAKrkdEBBA8+bNSUpK4vvvvy/7TlBQEFlZWRVu66uvviI3N5ecnBwWLVrEyJEjnfXnqkYiM6+IJ+Zv5on5W+gREcT3T4zk5oHt6RQawMIHhvL0FT34Oi6Ra/65krUH0sq+N2/9ETLzinhwVFcbo1fO0rhbCjaaOHEiN910U1k30h133MH1119PTEwM/fr1o2fPnlV+/6GHHuKee+4hOjqafv36MXjwYMC6g1r//v3p06fPOSW3p06dytVXX02bNm1Yvnx52fIBAwYwefLksm3cd9999O/fX7uKmoDEzDzeWb6PYD9verYJpmfrIDqHBZwz03jdgTSeXriV4yfzefqKHjw8qite5dbx8vTg8dHdGdk9jKcWbGHizLU8cHFXHrusGx+sPMiQLi0Z0LGFq/88VQ+0dLaqET2uDc/q/ak89ulmsgqKy248D+Dj6UG38EB6tgmiZ+sgUrMLmbnyAB1b+vPGhH70r+bknlNQzEvf7mTe+qOEBfqSml3A3HsHc0mPVq74s5QTuF3pbBF5ArgfEGCmMeYNEWkJLAAigUPArcaYE3bEp1RDZoxhxi8H+PvS3XRpFciCB4bSsaU/+1Oy2X38JLuPZ7E7MYtf96Xy5Sbr4vGEmA48f31vAnyrPyUE+Hrxt5uiufSCcKZ9uY2+7Ztzcfewar+nGgaXJwURicJKCIOBQmCpiHzrWLbMGDNdRKYB04BnXR2fUu7mRE4hOYXFtG/hX+26WflF/O6zOJbuOM61F7bh7+OjCXSc6Hu1CaZXm+Bztp2ZV0RkWEBFm6vSlX1aM7xbGIaKB0WohsmOlkIvYK0xJhdARP4L3AiMBUY51pkLrKCWScEYo/+TOlFD7mJs6FbuTeHhTzaRlV9M3w4hXB/dhuui29K6ud856+5NyuKBj2M5nJbLc9f2YsqIztX+O2gR4FOn2cc1aVmohsWO0UfbgYtFJFRE/IFrgA5AhDEmEcDxHF7Rl0VkqohsFJGNKSkp53zu5+dHWlqansicxBhDWloafn7nnoRU/fpozSEmz95A2+bN+N1VF1BcUspL3+5i6PRlTHh/DR+vPUxatjVz/pu4BMa+8ysn84r45L6LuG9kF/1hpGrFlgvNIjIFeATIBnYCecA9xpiQcuucMMZUecWrogvNRUVFxMfHVzt2X9Wcn58f7du3x9vbu/qVVZ0VlZTy4tc7+HjtES7vFc4bt/Uv6wLan5LNN1sTWbL1GPtTcvD0EKLaBrM1PpMBHUP41x0DK2xFKFVeVReabR99JCJ/BeKBJ4BRxphEEWkDrDDGXFDVdytKCko1ZJm5RTz8aSy/7kvjgUu68PureuJZwe0sjTHsPp7F11sT+Hl3MsO7hfHsmJ74eOnUI1U9dxx9FG6MSRaRjsBNwFCgMzAJmO54PrdinFKN2IGUbO6bu5GjJ3L5v/HR3BJT+U1qRKTswvHvx1Q950Wp82HXVaIvRCQUKAIeMcacEJHpwEJH19IR4BabYlPK5VbtTeXhT2Lx8vTg0/uHMCiypd0hqSbKlqRgjDmnxoIxJg0YbUM4Stkip6CY1fvT+Hl3Mgs3HqVbq0A+mBRDh5bVDz1Vqr7oeDKlXMQYw/6UbFbsSWHFnhTWH0ynsKSUAB9PxvVrx4tj+5RdUFbKLvp/oFL1LCEjj3dX7Gf5nmTiT1hVaHtEBDJ5eCSjerQiJrKlXiBWbkOTglL1aO2BNB75ZBM5hcWM7N6Kh0Z1ZdQF4VXesEYpO2lSUKoeGGOYs/oQL327i8hQfxY+OJSurQLtDkupamlSUMrJ8otK+NOi7XyxKZ7Le0Xw+oS+BPnpxD/VMGhSUMqJEjLyePDjWOLiM3ny8u48fll3PCqYfKaUu9KkoJSTnLp+UFBcysy7Y7iid4TdISl13jQpKFUHRSWlJGbk8+PO40z/fjcdQ/2ZcVcM3cL1+oFqmDQpKFUDR9Nz2RqfwZH0XI6m53LE8UjIyKfEcUez0T3Def22fgTr9QPV
gGlSUKoaa/anMenD9RSWlAIQGuBDh5b+9O/QgrF9/enY0p/IsABiOrXQ6weqwdOkoFQVdiWeZOpHG+kYat2/ODIsQGcdq0ZN/+9WqhLHMvKYPHs9Ab5ezL13sE44U02CJgWlKpCRW8ikD9eTW1jCZw8O1YSgmgwtuKLUWfKLSrhv7kaOpOUy464YerYOrv5LSjUSmhRUo5dXWMLTC7Zwyf8t54OVBziZX1TpuiWlhifmbyb2yAlem9CXoV1DXRipUvbTpKAatYSMPG55fzWLthwjyM+Ll77dxbC//cz/frOTo+m5Z6xrjOHPS3bww44k/ufa3lwX3damqJWyj15TUI3WxkPpPPjxJvKLSvjg7hhG94pgW3wms1YdYO7qQ8z+9SBXR7VhysjODOjYgn+t2M+/1x7mgYu7cO+IznaHr5QtxBhjdwy1FhMTYzZu3Gh3GMoNzV9/hP9ZvJ12Ic34YFIM3cKDzvg8MTOPuasP8+m6w5zML6Zn6yB2H89iXL+2vHZrP51voBo1EYk1xsRU+JkmBdWYFJWU8vK3u5iz+hAju4fx9sQBNPevfIZxTkExn8fGM2f1ISJD/Xn/rhi94Y1q9KpKCtp9pBqNEzmFPPLpJlbvT+O+EZ2ZdnVPvDyrPsEH+HoxaVgkk4ZFuiZIpdycJgXVKPyWlMWUuRtIyizglVv6Mn5ge7tDUqpBsqWdLCJPicgOEdkuIvNExE9EOovIOhHZKyILRMTHjthUw7N6Xyo3v7ua/KJS5j8wRBOCUnXg8qQgIu2Ax4EYY0wU4AncBvwdeN0Y0x04AUxxdWyq4flyUzyTZq+ndbAfix4exoCOLewOSakGza4ral5AMxHxAvyBROAy4HPH53OBcTbFphoAYwxvLtvL0wu3EtOpJZ8/NIz2LfztDkupBs/l1xSMMcdE5BXgCJAH/AjEAhnGmGLHavFAu4q+LyJTgakAHTt2rP+AldspKinlj19u47PYeG7q347pN0friCGlnMSO7qMWwFigM9AWCACurmDVCsfKGmNmGGNijDExrVq1qr9AlVs6mV/EPbM38FlsPE+M7s6rt/bVhKCUE9kx+uhy4KAxJgVARL4EhgEhIuLlaC20BxJsiE25sYSMPO6ds4F9ydn83/hobonpYHdISjU6dvzEOgIMERF/ERFgNLATWA6Md6wzCVhsQ2zKTf26L5Ub//Urx07kMeeewZoQlKondlxTWCcinwObgGJgMzAD+BaYLyIvOZbNcnVsyv1kFxTz1+928em6I3QJC2DuvYO1lLVS9ciWyWvGmBeAF85afAAYbEM4yk2t2pvKs1/EkZCZx9SLu/D0FT3w8/a0OyylGjWd0axc5vtticSfyGNIl1B6tw3Gs5Kic1n5Rfz1u93MW3+ELq0C+PzBYQzspPMPlHIFTQrKJTYdOcGj8zZTUmoNKgvy9WJw55YM6RJ6RpJYuTeFaV9sI1FbB0rZQpOCqnc5BcU8tWALrYP9mHvvYHYkZLL2QDrrDqSxbHcyYCWJHq2DiD18gi6tAvhMWweqKctJA99A8PJ1+a41Kah695evd3IkPZf59w+hW3gg3cIDGdvPmpuYfDKftQfTWXsgjS1HMnjwkq48eXl3bR2opikvA1a+CuvegxaRcMtciOjt0hA0Kah6tXT7cRZsPMrDo7pyUZdz73ccHuzHDX3bckNfvfWlcqKSIti5GApzoNf14N/S7oiqVlIEGz+EFdMh7wRE3QSHVsHMy+Ca/4P+d4K45sZPmhRUvUk6mc+0L+OIahfMk5f3sDsc1RQUZMGmj2DNv+BkvLXsu2egx1UQfRt0vxK83KgAszGw5zv46XlI2wedL4YrX4I2fSE7Gb64D5Y8aiWIa1+1upTqmSYFVS9KSw3PfLaV/KIS3pjQX0tRNEXGwI/PQXE+RI6ATiMgsJ5K02QnW10uGz6A/ExrX9e9DkERsHUBbFsIu76GZi0h6mboexu0G+iyX98VStgMPzwHh1dBWA+YuMBK
XqdiCgyHuxbBL6/Air9BwiaXdCfp7ThVvfhw1UH+8s1O/ndcFHcN6WR3OMoOB/4LH90AHt5QWmQta9XTShDOShJp+2H1W7DlUygptLqKhj8B7c+602RJMez/GeLmw+5vrUQV2g06DLFOvkGtrefA1qff+wRA/kk4cRBOHLIe6adeH4SsJPALBv9Q69GsxenX/qHg7QcF2VBw0tpOwUmrJZOfCfkZcHybtd6oP8DAyeBZ+W1jOfiL1WrIP+mU7iS9R7NyqT3Hs7j+7VWM6BbGrEkxiJ2/xpR95lxndYk8uhFSdsOhlVY3yOE1UJRjrdOqJ7QfZP1qbx8DrXqBZxUdGFlJcHQtHFlnPR/bBJ4+0G8iDH0MwrpVH1d+JuxcAts+s+LLToLS4nPX8/Kzkkd5zVpaF4BbdoagNtZJPjcNctOt57x067UpKfclAd9gK4GUPQdBm34w7FHwa159zHC6O+ngf62usDp0J2lSUC5TUFzC2Ld/JSWrgKVPXkyrINcPqVNu4PAamD0GrvobDH34zM9KiiBxqyNJ/ArHNloXVwG8mkHbflaSaDfAOgEnxsHRdXBkrfULHcDT11qnyyUw8B6rm6i2Skutk3l2EmQdP/2cmwYBrU4ngRaRNTuBl5ZCQSYU5Vknf59A53VTlZac7k66/M8w4slabUaTgnKZl77ZyQerDjJrUgyje9XhH6pq2D6+GRK2wJPbwKeamx8ZY53sj22CY7HWI3Hrmb/S/cOg4xDocJH13KavLWP43cbRDVbyrKrLqQpVJQW90KycZvnuZD5YdZA7h3TUhNCUHYuFff+xfslWlxDA+hXdsov1uNBRKLmkCJJ3Wn34rS+0PtNuyNM6DKq3TWtSUHVWWmp475f9vPrjb/SICORP17h2sk29y02HZX+BbqOtC5mqar+8Cn4hMOi+2m/D09tqDbTp67y4VI3oOEFVJ2nZBdwzZwP/WLqHMVGt+fyhYTTzaUSzkY9vgxmjIHY2LLgL1r1vd0SudeKwdVG3po5vhz3fwpCHrf501eBoS0HV2roDaTw+fzMncot4aVwUd1zUsXGNNIpbCEset4YaTvoa1r4H3//euhB52f/UvDujtBSyEqzRKh4NJGGeTIBf/s+aCFZaArfMhj43Vv+9la+ATxBcNLX+Y1T1QpOCOm8lpYZ/Ld/H6//5jU6hAXw4eRB92tZwWF1DUFJkTbpa9x50Gg63zLHGrnccBt8+bdWmyU6C6/5Z9fBJgPiN8P2z1ggbn6ByI2scj+bt6udvKC2F9ANW/37KLoiIgq6XVV/uIScVVr1uTQIrLYEBkyBpB3w51brY23lk5d9N+Q12fAUjnrISqWqQqk0KIvIo8Ikx5oQL4lFuLiWrgKcWbGHVvlTG9mvLyzdeSKBvI/ptkZUEn02GI6thyCNwxYunR3h4esH1/4TACPjlH9YJdPzsii+mnkyA/7xoTZYKjIDRz1vLjsXCmndOT+YKamMlh5BOUFh+olPWma+9/awhkS0ioUXnM4dJBraGnJTTI3eOxVqzX/MzHcEIYEA8oP1g6H6FVe6h9YWnWzt5GbDmbVj7LhTlQt+JcMnvre3npsOHY2D+7XDP99A6quJjt/JV8G4GQx9xzn8LZYtqh6Q6bo95G9btMz8EfjBuMo5Vh6S61pajGdz/0UZO5hXx4g19mDCoQ+PqLjq6HhbebZ0gx759eiRMRdbPhO9+Z028un3B6V/gRfnWyXXla9aJf+ijMPLpM/vXi/IhafuZJ/GTCdY6vo6JTWdMdAq2CrudmlmbGQ+m9PT2ys8YFk+rDEK7gdAuxnoO7WYN8dz3E+z90SqvAFYy6X65lZjWz7CSSJ8bYdQfodVZtaoyjsKsKwEDU36CkLPukZ1+AN6KgSEPwVUv1+boKxeq8zwFsf7lXwncA8QAC4FZxpj9zgz0fGlScJ3txzK5feZamvt7M/PuGNffJ7koD/Z8b9WIiehTt+GJxlgnwJPHrBNsZrw1s3X9TGje
HiZ8XPmv4fJ2fAVf3m/9cr/rS+vk/uNzkHEEel5nFTZr2bn2cVamuBAyj54ut3DisNUaaR8DraOrHwaanWwNGd37I+z72Zpo1WMMXPonaBNd+feSdsCHV1slIO5demZX1JLHYet8eDLO+ly5NadMXhORvlhJYQywHBgC/GSM+b2zAj1fmhRc47ekLCa8v4Zm3p4sfHAo7VvUYOy5Mx1eA4sfgXTHb5Cgtqe7QLpcUvkol5JiSN1jzYg9HgepvzmSwDEozDpzXQ8v68Q49u3z6w8/uNLqVikugJICCO8NY6ZbcTUEJcVW11Nwm5qtf2gV/PtGaNsf7l5sdRdlHIU3+1v1e659pV7DVc5Rp6QgIo8Dk4BU4APgK2NMkYh4AHuNMV2dHXBNaVKofwdSsrn1/bV4CCx8YCiRYQG139iWT61yBv1ur9mJtyAblr1o/YIP6WCdbPNOWL9w9y+3+tw9vKHTMCtBtOlrnfiPx1mJIHnn6VmxXs2sLpHmHazWQPP2ENzO8b6d9Uu7tiODjm+DpX+APuNgwOTqLz43dDu+sq67XHAN3PoR/PAH2DgbHt98breSckt1TQp/weoqOlzBZ72MMbvOM5gLgAXlFnUBngc+ciyPBA4Bt1Z3cVuTQv06mp7Lre+vobC4lAUPDKFbeB3Gne/9CT5x9NF7NYPoW2Hw1Mq7afb/DEuesLpJLnrAGgJavvhXSZFVD2fvj9a2k3ee/swvxOoGaR1tJYrW0Va/emM/WbvSuhnw/e8garxVkrrvBLjhLbujUjVU16QwBNhhjMlyvA8CehtjzmNGS6Xb9gSOARcBjwDpxpjpIjINaGGMebaq72tSqD+JmXnc+v4aTuYVM+/+IfRuW4drCJnx8N5ICG5rjd7ZNBfiPoPiPGuY50VTrT54T2/rIu+Pf4LNH0Nod6s7p+OQ6veRcdSqxBnWA0I6akkEV/jPi7DqNWtU02OxVikK1SDUNSlsBgacGnHk6DbaaIwZ4ITArgReMMYMF5E9wChjTKKItAFWGGMuqOr7mhTOjzGGWasO4uvlwdCuYXRtFVDh6KGUrAImvL+G5KwCPrnvIvp2CKn9TkuKYM611kXKqf89Xdo4Nx22fGJ1DWUctkbARN0M2z63+riHP4VIxgYAABldSURBVA6XTLOGYir3ZIzVvecdAJf8zu5o1Hmoa0E8KT8E1RhTKiLOaoffBsxzvI4wxiQ69pEoIuEVBiMyFZgK0LFjRyeF0TSsPZDOS9+e7u2LCPZlWNcwhnYNZVjXUNq38Cc9p5A7P1hHYmY+H00ZXLeEAFbNoKPr4OZZZ9a6928Jwx6zyiHs/ckaErnmbWuS1e3zrQuZyr2JWEXvVKNSk5bCl8AK4F3HooeBS40x4+q0YxEfIAHoY4xJEpEMY0xIuc9PGGOqvBqpLYXzM2XOBjYfzWDe/UOIPXyC1ftTWbM/jbScQgA6tvTHQyAhM585kwcxrFtY3Xa4ZynMm2DVu7/+jerXz023xuRr379S9aquLYUHgTeB5wADLMPxS72OrgY2GWOSHO+TRKRNue6jZCfsQznsS85i2e5knhjdnQtaB3FB6yBuv6gjxhh+S8pm9f5UVu9P41BqDu/fNbDuCSHjKCx6wJo1O2Z6zb5TXQkGpVS9qzYpGGOSsbp5nG0ip7uOAJZgDX2d7nheXA/7bLJmrTqIj5cHdw09837JIlKWJO4ZXs1Eq5w06xaIkSOtmbaVKS6Ez+9xFFKbq9cFlGpAalL7yA+YAvQByv51G2Pure1ORcQfuAJ4oNzi6cBCEZkCHAFuqe321ZlSswv4YtMxbh7QnrDAWt6t6uh6WDjJqvbpHWCVgIi5p+K+/2UvQvwGq5BcqG3TWJRStVCT+yn8G2gNXAX8F2gPZFX5jWoYY3KNMaHGmMxyy9KMMaONMd0dz+l12Yc67d9rDlNYXMqUEbUouWCMVSRt9tXg5WOd6KNutMpKzxhlPTZ9ZNXmAdj9rXXBeND9NSu1rJRy
KzUakmqM6S8iccaYaBHxxiqKd5lrQqycXmiuXn5RCcOm/0z/DiHMmnyet/AryIIlj8GORdbs1XHvQjPHWIC8DCsxxM62Jo75Bluth+1fWLWApvzYtO+hq5Qbq+uFZkf5RTJEJAo4jjXrWDUAX2yKJz2nkPtGnufEouRd1p3G0vfD5S/C8CfOnBDWLMSadDb4fmvI6cbZsPkT8PKzWhOaEJRqkGqSFGaISAus0UdLgEDgf+o1KuUUpaWGWSsPcmG75gzpch4je+IWwtdPgE8g3L2k6huriFgzjjsOgTF/swrD1bS4mlLK7VSZFByzl086ahD9glWnSDUQP+9O5kBqDv+8rV/N7nuQkwrLX4aNH1rlJ26ZfX5lkHVIqVINXpVJwTF7+VGs+ycoN7DneBZH0nO5ondEtevOXHmAts39uObCKn65GwOHf7USwa6voaTQmmk8+oXTdxxTSjUZNek++klEnsGqYJpzaqGODnK91OwC7py1jpSsAm6/qCMvXN8bX6+Kyz3HxWew7mA6z13bC2/PCgaZ5abD1nkQO8cqN+3XHGKmWDXxw3vW69+hlHJfNUkKp+YjlL/xqkG7klyqtNTwzGdbiclfw20djvHoutHsTjzJe3cOJDz43MlhM1ceJMjXiwmDzqpvH7/RKkK3Y5F1U5j2g6xRRb3HVX/HLqVUo1eTGc31cD9Bdb5mrz7Epj0H2RA4A9+Uk2wIWcZTiZO57q083r1zIAM7nS4TdSwjj++2JXLv8EiC/BxdQCcTrFtFbv8CfIKg/53W5LPWF9r0Fyml3FFNZjTfXdFyY8xHzg9HVWT7sUymf7+LtyKW4ZOZBePexW/127yb/w9+Kr2Yh2bcwZM3DOP2i6yqsbNXHQRg8vDOVsmJtf+C//4DSoutctTDHjvzhjVKKeVQk+6j8jOe/IDRwCasO6WpepZTUMzj8zbT2/8kV2V/hfSdaN3OMmo8rHqNy395hYt84/jD4rvZFn8Lz1x1AfM3HOW66Da0S1sD//49pO2FHldbQ0br40bySqlGoybdR4+Vfy8izbFKXygXeGHJDg6m5bC+1w/IYeCyP1kfePnAqGlIr+sJWvwI7yS8yQ9bVjNxxwMEF+Txl4LZ8O/vrdnFty+EHlfZ+ncopRqG2hSuzwW6OzsQda7FW47xeWw8/3tRKa22LrJmFTdvf+ZKEX2QKf+Bte9w+bKXGVLyJH7NivE96gGXPmd1FWmVUqVUDdXkmsLXWKONwCqg1xudt1DvjqTl8qdF2xnYqQV3ZP3DKisx4qmKV/b0guFP4HnBtTT7+v8hzZrDmJesexUrpdR5qElL4ZVyr4uBw8aY+HqKRwFFJaU8Nn8zIvDe0Aw8vloOV/3tdDG6yoR1w+cevQ2FUqr2apIUjgCJxph8ABFpJiKRxphD9RpZE/baT7+x9WgG70zsS6vVt0FIJxg0xe6wlFJNQE3up/AZUFrufYljmaoHP+w4znv/3c9tgzpwrVkJSdtg9PNadVQp5RI1SQpexpjCU28cr33qL6Sma976Izz0cSzR7Zrz/JjO8PNL1p3N+txkd2hKqSaiJkkhRURuOPVGRMYCqfUXUtNjjOH1n37jD19u4+Ierfj0/iH4b54FJ+Phir+AR03+MymlVN3V5JrCg8AnIvK24308UOEsZ3X+iktKee6r7czfcJTxA9vzt5suxLsgA1a+Bt2vgs4X2x2iUqoJqcnktf3AEBEJxLp9Z53uz6xOyy0s5rFPN7NsdzKPXtqN/3dlD+u+B7+8AoVZcPmf7Q5RKdXEVNsvISJ/FZEQY0y2MSZLRFqIyEuuCK4xS88p5PaZ61i+J5n/HRfFM1ddYCWE9AOwfgb0uwMietsdplKqialJZ/XVxpiMU28cd2G7pi47FZEQEflcRHaLyC4RGSoiLUXkJxHZ63huUf2WGqaj6bnc/O5qdiWe5N07B3LXkE7WB2n74aOx1n2OL/2jvUEqpZqkmiQFTxEpGw8pIs2Auo6P/Cew
1BjTE+gL7AKmAcuMMd2BZY73jc6RtFxu/Ndq0nMK+eS+i7iqj+N2l8e3wYdjoCAbJi2G4Lb2BqqUapJqcqH5Y2CZiMx2vL8HmFvbHYpIMHAxMBnKhrgWOkY1jXKsNhdYATxb2/24o1MzlQuLS/jioWF0jwiyPji8Gj69zSpnPflbaNXD3kCVUk1WTS40/0NE4oDLAQGWAp3qsM8uQAowW0T6ArHAE0CEMSbRsc9EEQmv6MsiMhWYCtCxY8Oq7XNqpvK/7hhwOiHsWQqfTYLmHeCuRRDSoeqNKKVUParpAPjjWLOab8a6n8KuOuzTCxgAvGuM6Y913+cadxUZY2YYY2KMMTGtWrWqQxiutXpfatlM5WsubGMt3LoA5t8O4b3g3qWaEJRStqu0pSAiPYDbgIlAGrAAa0jqpXXcZzwQb4xZ53j/OVZSSBKRNo5WQhsguY77cRvpOYU8tXALXcICeP56x4iite/B0mchciRMnAe+QfYGqZRSVN1S2I3VKrjeGDPCGPMWVt2jOjHGHAeOisgFjkWjgZ3AEmCSY9kkoFGU+zTG8PvP4ziRU8SbE/vj7yXw88tWQuh5HdzxuSYEpZTbqOqaws1YLYXlIrIUmI91TcEZHsOaJe0DHMC6eO0BLBSRKViVWW9x0r5s9fHaw/xnVxIvXdmWPgfnwMIPIOMI9L8LrnvDuheCUkq5iUrPSMaYRcAiEQkAxgFPAREi8i6wyBjzY213aozZAsRU8NHo2m7THe05nsXn337P3LAVXLx6BRTnW91FV74EvW4AcVaOVUop56jJ6KMc4BOsX/YtsX7BTwNqnRQavZIiCrd/RdGSV1nstQuT74/0nQiD74eIPnZHp5RSlTqvvgtjTDrwvuOhKlKYC++PxCdtH4GlEeyL+SPdrpgKzRrtBG2lVCOiHdrOtmsJpO3jd0VTaT5kEs9dH2V3REopVWOaFJysOPbfJNKaneHX8+XVvewORymlzovevcWZ0g/gdWQV84su5uWbovH18rQ7IqWUOi+aFJzIbP6EUoTf2lxHvw4hdoejlFLnTbuPnKW0hIKNH7O2JJqxFw+yOxqllKoVbSk4y4Hl+OUd5yffKxhzqhy2Uko1MNpScJKTa+ZQYgLpOGw8Xp6aa5VSDZOevZwhNx3/A0v5xoxgwpCudkejlFK1pknBCXJi5+NlisjoOYEQfx+7w1FKqVrT7iMnyF03h4OlkYwZfYXdoSilVJ1oS6GOiuK30Cp7D7Etrz19NzWllGqgNCnU0dGfZ1BgvOhy6WS7Q1FKqTrTpFAXxQWEHVzCr95DGH5hd7ujUUqpOtOkUAcHfv2MYJNFSd878PDQeyMopRo+TQp1kLduLomEMuzym+0ORSmlnEKTQi0lxe+jV84G9rW5noBmvnaHo5RSTqFJoZb2/DATDzF0veIBu0NRSimn0aRQC3kFRXQ+uog9fn1p26W33eEopZTTaFKohV+XLaEDSXgNvNvuUJRSyqlsmdEsIoeALKAEKDbGxIhIS2ABEAkcAm41xpywI76qGGMo3fxvcmlGl0sm2h2OUko5lZ0thUuNMf2MMTGO99OAZcaY7sAyx3u3s25LHJcUriSx0w2IT4Dd4SillFO5U/fRWGCu4/VcYJyNsVQq/z9/RQTa3/Anu0NRSimnsyspGOBHEYkVkamOZRHGmEQAx3N4RV8UkakislFENqakpLgoXMuhPVsZkf0jO9uOxze0k0v3rZRSrmBXldThxpgEEQkHfhKR3TX9ojFmBjADICYmxtRXgBXJ/O5FwvGm49jnXLlbpZRyGVtaCsaYBMdzMrAIGAwkiUgbAMdzsh2xVSbz4Gb6Zi5jffh4WkZ0sDscpZSqFy5PCiISICJBp14DVwLbgSXAJMdqk4DFro6tKmlfP89J04wO17nl9W+llHIKO7qPIoBFInJq/58aY5aKyAZgoYhMAY4At9gQW4UKD6+jS/ovfB4ymfGdOtodjlJK1RuXJwVjzAGgbwXL04DR
ro6nJk58/QJeJojWVz5pdyhKKVWv3GlIqlsyB38hInUNn/ndwvDekXaHo5RS9Urv0VwVY8j67s/kmha0HPUQji4vpZRqtLSlUJW9PxGcEsssj/HcENPV7miUUqreaUuhMqWlFPz0IkmmFf5DJuPn7Wl3REopVe+0pVCZXUvwTdnOWyXjuWNYN7ujUUopl9CWQkVKSyj5+WUOmnaUXngL4cF+dkeklFIuoS2FisQtxDPtN14tGs+9I7WVoJRqOrSlUIHSuAUclXZkdLqKPm2b2x2OUkq5jLYUzmYMxUc3sbqoB1O0laCUamI0KZwt/QA+RZkcadaTy3pWWL1bKaUaLU0KZyk9tgkA346D8PDQyWpKqaZFk8JZTu5fR77xpl2P/naHopRSLqcXms9SdCSW/SaS/p2160gp1fRoS6G8kmKaZ+xkj0c3uoQF2B2NUkq5nCaF8lL34GPyyQ7rq9cTlFJNkiaFcnIPrgcgoPMgmyNRSil76DWFck7sW0exaUaXC865B5BSSjUJ2lIoxzNxM9tNF/p2bGF3KEopZQtNCqcUFxCWs48E/174+2gDSinVNGlScChOiMOLYkra6PwEpVTTpUnBIXn3GgBa9hhicyRKKWUf25KCiHiKyGYR+cbxvrOIrBORvSKyQER8XBlP7qENpJhgevfs7crdKqWUW7GzpfAEsKvc+78DrxtjugMngCmuDCYgNY49Ht1pG9LMlbtVSim3YktSEJH2wLXAB473AlwGfO5YZS4wzmUBFWQRUXiYzBYXYoWilFJNk10thTeA3wOljvehQIYxptjxPh5oV9EXRWSqiGwUkY0pKSlOCSZt3wY8MHh1HOiU7SmlVEPl8qQgItcBycaY2PKLK1jVVPR9Y8wMY0yMMSamVatWTokpefdqANr2Hu6U7SmlVENlx4D84cANInIN4AcEY7UcQkTEy9FaaA8kuCqg0vhNxJswLujS2VW7VEopt+TyloIx5g/GmPbGmEjgNuBnY8wdwHJgvGO1ScBiV8UUmrmDo3498fHSEbpKqabNnc6CzwJPi8g+rGsMs1yx0/yMJFqXHic/vJ8rdqeUUm7N1noOxpgVwArH6wPAYFfHcGT7r/QAgrq6fNdKKeV23KmlYIvMfesoNUKXC/Uis1JKNfmk4JO8haMe7WgZGmZ3KEopZbsmnRRMaSntcneTEtzH7lCUUsotNOmkEH94H2FkQFutjKqUUtDUk8KOXwFo1XOozZEopZR7aNJJoeDIRoqMJx16XWR3KEop5RaadFJonr6NeJ8uePhoZVSllIImnBRO5hXQtWgv2aEX2h2KUkq5jSabFPbs2Eqw5OLXaZDdoSillNtoskkhdY91+822fYbZHIlSSrmPJpsUJHEz+fgS0C7K7lCUUsptNMmkUFJqiMjaSVJAD/C0tfyTUkq5lSaZFH5LPEFPDlIUoZVRlVKqvCaZFPbv2EgzKSSk2xC7Q1FKKbfSJJNCz5K9AIT20KSglFLlNckO9W6dIyHjWiS0q92hKKWUW2mSSYGe11oPpZRSZ2iS3UdKKaUqpklBKaVUGU0KSimlymhSUEopVcblSUFE/ERkvYhsFZEdIvKiY3lnEVknIntFZIGI+Lg6NqWUaursaCkUAJcZY/oC/YAxIjIE+DvwujGmO3ACmGJDbEop1aS5PCkYS7bjrbfjYYDLgM8dy+cC41wdm1JKNXW2XFMQEU8R2QIkAz8B+4EMY0yxY5V4oF0l350qIhtFZGNKSoprAlZKqSbClslrxpgSoJ+IhACLgF4VrVbJd2cAMwBEJEVEDleymzAg1Qnh1hd3jk9jqx2NrXY0ttqpS2ydKvvA1hnNxpgMEVkBDAFCRMTL0VpoDyTU4PutKvtMRDYaY2KcFqyTuXN8GlvtaGy1o7HVTn3FZsfoo1aOFgIi0gy4HNgFLAfGO1abBCx2dWxKKdXU2dFSaAPMFRFPrKS00BjzjYjsBOaLyEvAZmCWDbEppVST5vKkYIyJA/pX
sPwAMNiJu5rhxG3VB3eOT2OrHY2tdjS22qmX2MSYCq/nKqWUaoK0zIVSSqkymhSUUkqVaZRJQUTGiMgeEdknItPsjqc8ETkkIttEZIuIbLQ5lg9FJFlEtpdb1lJEfnLUoPpJRFq4UWx/FpFjjmO3RUSusSm2DiKyXER2Oep3PeFYbvuxqyI224+dO9c9qyK2OSJysNxx6+fq2MrF6Ckim0XkG8f7+jluxphG9QA8sWZIdwF8gK1Ab7vjKhffISDM7jgcsVwMDAC2l1v2D2Ca4/U04O9uFNufgWfc4Li1AQY4XgcBvwG93eHYVRGb7ccOECDQ8dobWIc1R2khcJtj+XvAQ24U2xxgvN3/zzniehr4FPjG8b5ejltjbCkMBvYZYw4YYwqB+cBYm2NyS8aYX4D0sxaPxao9BTbWoKokNrdgjEk0xmxyvM7CmmfTDjc4dlXEZjtjccu6Z1XE5hZEpD1wLfCB471QT8etMSaFdsDRcu8rraNkEwP8KCKxIjLV7mAqEGGMSQTrBAOE2xzP2R4VkThH95ItXVvliUgk1hDrdbjZsTsrNnCDY1eXumeujs0Yc+q4vew4bq+LiK8dsQFvAL8HSh3vQ6mn49YYk4JUsMxtMj4w3BgzALgaeERELrY7oAbkXaArVsn1ROBVO4MRkUDgC+BJY8xJO2M5WwWxucWxM8aUGGP6YZWyGcx51D2rb2fHJiJRwB+AnsAgoCXwrKvjEpHrgGRjTGz5xRWs6pTj1hiTQjzQodz7GtVRchVjTILjORmrGKAzJ+w5Q5KItAFwPCfbHE8ZY0yS4x9uKTATG4+diHhjnXQ/McZ86VjsFseuotjc6dg54skAVlCu7pnjI9v/vZaLbYyjO84YYwqA2dhz3IYDN4jIIazu8MuwWg71ctwaY1LYAHR3XJn3AW4DltgcEwAiEiAiQadeA1cC26v+lsstwao9BW5Wg+rUCdfhRmw6do7+3FnALmPMa+U+sv3YVRabOxw7d657Vklsu8slecHqs3f5cTPG/MEY094YE4l1PvvZGHMH9XXc7L6iXh8P4BqsURf7gT/ZHU+5uLpgjYbaCuywOzZgHlZXQhFWC2sKVl/lMmCv47mlG8X2b2AbEId1Am5jU2wjsJrqccAWx+Madzh2VcRm+7EDorHqmsVhnVyfdyzvAqwH9gGfAb5uFNvPjuO2HfgYxwglux7AKE6PPqqX46ZlLpRSSpVpjN1HSimlakmTglJKqTKaFJRSSpXRpKCUUqqMJgWllFJlNCkoVQURKSlXIXOLOLHqrohElq8Cq5Q7sOMezUo1JHnGKn2gVJOgLQWlakGs+2L83VGDf72IdHMs7yQiyxwF1JaJSEfH8ggRWeSo179VRIY5NuUpIjMdNfx/dMymVco2mhSUqlqzs7qPJpT77KQxZjDwNlYtGhyvPzLGRAOfAG86lr8J/NcY0xfrPhE7HMu7A+8YY/oAGcDN9fz3KFUlndGsVBVEJNsYE1jB8kPAZcaYA44CdMeNMaEikopVQqLIsTzRGBMmIilAe2MVVju1jUisEs3dHe+fBbyNMS/V/1+mVMW0paBU7ZlKXle2TkUKyr0uQa/zKZtpUlCq9iaUe17jeL0aq5IlwB3AKsfrZcBDUHYzl2BXBanU+dBfJUpVrZnjblynLDXGnBqW6isi67B+XE10LHsc+FBEfgekAPc4lj8BzBCRKVgtgoewqsAq5Vb0moJSteC4phBjjEm1OxalnEm7j5RSSpXRloJSSqky2lJQSilVRpOCUkqpMpoUlFJKldGkoJRSqowmBaWUUmX+P27nAQuwl9IYAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.plot(np.arange(1, NUM_EPOCHS+1), train_acc_list, label='Training')\n",
    "plt.plot(np.arange(1, NUM_EPOCHS+1), valid_acc_list, label='Validation')\n",
    "\n",
    "plt.xlabel('Epoch')\n",
    "plt.ylabel('Accuracy')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation ACC: 75.25%\n",
      "Test ACC: 73.93%\n"
     ]
    }
   ],
   "source": [
    "with torch.set_grad_enabled(False):\n",
    "    test_acc = compute_acc(model=model,\n",
    "                           data_loader=test_loader,\n",
    "                           device=DEVICE)\n",
    "    \n",
    "    valid_acc = compute_acc(model=model,\n",
    "                            data_loader=valid_loader,\n",
    "                            device=DEVICE)\n",
    "    \n",
    "\n",
    "print(f'Validation ACC: {valid_acc:.2f}%')\n",
    "print(f'Test ACC: {test_acc:.2f}%')"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "default_view": {},
   "name": "convnet-vgg16.ipynb",
   "provenance": [],
   "version": "0.3.2",
   "views": {}
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  },
  "toc": {
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": true,
   "toc_position": {
    "height": "calc(100% - 180px)",
    "left": "10px",
    "top": "150px",
    "width": "371px"
   },
   "toc_section_display": true,
   "toc_window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}