{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n",
    "- Author: Sebastian Raschka\n",
    "- GitHub Repository: https://github.com/rasbt/deeplearning-models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sebastian Raschka \n",
      "\n",
      "CPython 3.6.8\n",
      "IPython 7.2.0\n",
      "\n",
      "torch 1.0.0\n"
     ]
    }
   ],
   "source": [
    "%load_ext watermark\n",
    "%watermark -a 'Sebastian Raschka' -v -p torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Runs on CPU or GPU (if available)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model Zoo -- Multilayer Perceptron with BatchNorm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import numpy as np\n",
    "from torchvision import datasets\n",
    "from torchvision import transforms\n",
    "from torch.utils.data import DataLoader\n",
    "import torch.nn.functional as F\n",
    "import torch\n",
    "\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    torch.backends.cudnn.deterministic = True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Settings and Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Image batch dimensions: torch.Size([64, 1, 28, 28])\n",
      "Image label dimensions: torch.Size([64])\n"
     ]
    }
   ],
   "source": [
    "##########################\n",
    "### SETTINGS\n",
    "##########################\n",
    "\n",
    "# Device\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# Hyperparameters\n",
    "random_seed = 1\n",
    "learning_rate = 0.1\n",
    "num_epochs = 10\n",
    "batch_size = 64\n",
    "\n",
    "# Architecture\n",
    "num_features = 784\n",
    "num_hidden_1 = 128\n",
    "num_hidden_2 = 256\n",
    "num_classes = 10\n",
    "\n",
    "\n",
    "##########################\n",
    "### MNIST DATASET\n",
    "##########################\n",
    "\n",
    "# Note transforms.ToTensor() scales input images\n",
    "# to 0-1 range\n",
    "train_dataset = datasets.MNIST(root='data', \n",
    "                               train=True, \n",
    "                               transform=transforms.ToTensor(),\n",
    "                               download=True)\n",
    "\n",
    "test_dataset = datasets.MNIST(root='data', \n",
    "                              train=False, \n",
    "                              transform=transforms.ToTensor())\n",
    "\n",
    "\n",
    "train_loader = DataLoader(dataset=train_dataset, \n",
    "                          batch_size=batch_size, \n",
    "                          shuffle=True)\n",
    "\n",
    "test_loader = DataLoader(dataset=test_dataset, \n",
    "                         batch_size=batch_size, \n",
    "                         shuffle=False)\n",
    "\n",
    "# Checking the dataset\n",
    "for images, labels in train_loader:  \n",
    "    print('Image batch dimensions:', images.shape)\n",
    "    print('Image label dimensions:', labels.shape)\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "##########################\n",
    "### MODEL\n",
    "##########################\n",
    "\n",
    "class MultilayerPerceptron(torch.nn.Module):\n",
    "\n",
    "    def __init__(self, num_features, num_classes):\n",
    "        super(MultilayerPerceptron, self).__init__()\n",
    "        \n",
    "        ### 1st hidden layer\n",
    "        self.linear_1 = torch.nn.Linear(num_features, num_hidden_1)\n",
    "        # The following to lones are not necessary, \n",
    "        # but used here to demonstrate how to access the weights\n",
    "        # and use a different weight initialization.\n",
    "        # By default, PyTorch uses Xavier/Glorot initialization, which\n",
    "        # should usually be preferred.\n",
    "        self.linear_1.weight.detach().normal_(0.0, 0.1)\n",
    "        self.linear_1.bias.detach().zero_()\n",
    "        self.linear_1_bn = torch.nn.BatchNorm1d(num_hidden_1)\n",
    "        \n",
    "        ### 2nd hidden layer\n",
    "        self.linear_2 = torch.nn.Linear(num_hidden_1, num_hidden_2)\n",
    "        self.linear_2.weight.detach().normal_(0.0, 0.1)\n",
    "        self.linear_2.bias.detach().zero_()\n",
    "        self.linear_2_bn = torch.nn.BatchNorm1d(num_hidden_2)\n",
    "        \n",
    "        ### Output layer\n",
    "        self.linear_out = torch.nn.Linear(num_hidden_2, num_classes)\n",
    "        self.linear_out.weight.detach().normal_(0.0, 0.1)\n",
    "        self.linear_out.bias.detach().zero_()\n",
    "        \n",
    "    def forward(self, x):\n",
    "        out = self.linear_1(x)\n",
    "        # note that batchnorm is in the classic\n",
    "        # sense placed before the activation\n",
    "        out = self.linear_1_bn(out)\n",
    "        out = F.relu(out)\n",
    "        \n",
    "        out = self.linear_2(out)\n",
    "        out = self.linear_2_bn(out)\n",
    "        out = F.relu(out)\n",
    "        \n",
    "        logits = self.linear_out(out)\n",
    "        probas = F.softmax(logits, dim=1)\n",
    "        return logits, probas\n",
    "\n",
    "    \n",
    "torch.manual_seed(random_seed)\n",
    "model = MultilayerPerceptron(num_features=num_features,\n",
    "                             num_classes=num_classes)\n",
    "\n",
    "model = model.to(device)\n",
    "\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 001/010 | Batch 000/938 | Cost: 2.6465\n",
      "Epoch: 001/010 | Batch 050/938 | Cost: 1.0305\n",
      "Epoch: 001/010 | Batch 100/938 | Cost: 0.5404\n",
      "Epoch: 001/010 | Batch 150/938 | Cost: 0.4430\n",
      "Epoch: 001/010 | Batch 200/938 | Cost: 0.3235\n",
      "Epoch: 001/010 | Batch 250/938 | Cost: 0.1927\n",
      "Epoch: 001/010 | Batch 300/938 | Cost: 0.4007\n",
      "Epoch: 001/010 | Batch 350/938 | Cost: 0.3802\n",
      "Epoch: 001/010 | Batch 400/938 | Cost: 0.2528\n",
      "Epoch: 001/010 | Batch 450/938 | Cost: 0.2257\n",
      "Epoch: 001/010 | Batch 500/938 | Cost: 0.1454\n",
      "Epoch: 001/010 | Batch 550/938 | Cost: 0.2160\n",
      "Epoch: 001/010 | Batch 600/938 | Cost: 0.3425\n",
      "Epoch: 001/010 | Batch 650/938 | Cost: 0.2175\n",
      "Epoch: 001/010 | Batch 700/938 | Cost: 0.2307\n",
      "Epoch: 001/010 | Batch 750/938 | Cost: 0.3723\n",
      "Epoch: 001/010 | Batch 800/938 | Cost: 0.2452\n",
      "Epoch: 001/010 | Batch 850/938 | Cost: 0.1285\n",
      "Epoch: 001/010 | Batch 900/938 | Cost: 0.1302\n",
      "Epoch: 001/010 training accuracy: 95.63%\n",
      "Time elapsed: 0.22 min\n",
      "Epoch: 002/010 | Batch 000/938 | Cost: 0.2137\n",
      "Epoch: 002/010 | Batch 050/938 | Cost: 0.1923\n",
      "Epoch: 002/010 | Batch 100/938 | Cost: 0.1739\n",
      "Epoch: 002/010 | Batch 150/938 | Cost: 0.0742\n",
      "Epoch: 002/010 | Batch 200/938 | Cost: 0.2186\n",
      "Epoch: 002/010 | Batch 250/938 | Cost: 0.1424\n",
      "Epoch: 002/010 | Batch 300/938 | Cost: 0.1131\n",
      "Epoch: 002/010 | Batch 350/938 | Cost: 0.0575\n",
      "Epoch: 002/010 | Batch 400/938 | Cost: 0.1232\n",
      "Epoch: 002/010 | Batch 450/938 | Cost: 0.2385\n",
      "Epoch: 002/010 | Batch 500/938 | Cost: 0.1344\n",
      "Epoch: 002/010 | Batch 550/938 | Cost: 0.0950\n",
      "Epoch: 002/010 | Batch 600/938 | Cost: 0.1565\n",
      "Epoch: 002/010 | Batch 650/938 | Cost: 0.1312\n",
      "Epoch: 002/010 | Batch 700/938 | Cost: 0.0859\n",
      "Epoch: 002/010 | Batch 750/938 | Cost: 0.1722\n",
      "Epoch: 002/010 | Batch 800/938 | Cost: 0.0630\n",
      "Epoch: 002/010 | Batch 850/938 | Cost: 0.2606\n",
      "Epoch: 002/010 | Batch 900/938 | Cost: 0.1681\n",
      "Epoch: 002/010 training accuracy: 96.94%\n",
      "Time elapsed: 0.45 min\n",
      "Epoch: 003/010 | Batch 000/938 | Cost: 0.0676\n",
      "Epoch: 003/010 | Batch 050/938 | Cost: 0.1975\n",
      "Epoch: 003/010 | Batch 100/938 | Cost: 0.1241\n",
      "Epoch: 003/010 | Batch 150/938 | Cost: 0.1723\n",
      "Epoch: 003/010 | Batch 200/938 | Cost: 0.2233\n",
      "Epoch: 003/010 | Batch 250/938 | Cost: 0.2249\n",
      "Epoch: 003/010 | Batch 300/938 | Cost: 0.1027\n",
      "Epoch: 003/010 | Batch 350/938 | Cost: 0.0369\n",
      "Epoch: 003/010 | Batch 400/938 | Cost: 0.1460\n",
      "Epoch: 003/010 | Batch 450/938 | Cost: 0.0430\n",
      "Epoch: 003/010 | Batch 500/938 | Cost: 0.0821\n",
      "Epoch: 003/010 | Batch 550/938 | Cost: 0.1188\n",
      "Epoch: 003/010 | Batch 600/938 | Cost: 0.0424\n",
      "Epoch: 003/010 | Batch 650/938 | Cost: 0.2548\n",
      "Epoch: 003/010 | Batch 700/938 | Cost: 0.1219\n",
      "Epoch: 003/010 | Batch 750/938 | Cost: 0.0623\n",
      "Epoch: 003/010 | Batch 800/938 | Cost: 0.0557\n",
      "Epoch: 003/010 | Batch 850/938 | Cost: 0.0999\n",
      "Epoch: 003/010 | Batch 900/938 | Cost: 0.0595\n",
      "Epoch: 003/010 training accuracy: 97.93%\n",
      "Time elapsed: 0.66 min\n",
      "Epoch: 004/010 | Batch 000/938 | Cost: 0.1017\n",
      "Epoch: 004/010 | Batch 050/938 | Cost: 0.0885\n",
      "Epoch: 004/010 | Batch 100/938 | Cost: 0.0252\n",
      "Epoch: 004/010 | Batch 150/938 | Cost: 0.1987\n",
      "Epoch: 004/010 | Batch 200/938 | Cost: 0.0377\n",
      "Epoch: 004/010 | Batch 250/938 | Cost: 0.1986\n",
      "Epoch: 004/010 | Batch 300/938 | Cost: 0.1076\n",
      "Epoch: 004/010 | Batch 350/938 | Cost: 0.0270\n",
      "Epoch: 004/010 | Batch 400/938 | Cost: 0.1977\n",
      "Epoch: 004/010 | Batch 450/938 | Cost: 0.0623\n",
      "Epoch: 004/010 | Batch 500/938 | Cost: 0.1706\n",
      "Epoch: 004/010 | Batch 550/938 | Cost: 0.0296\n",
      "Epoch: 004/010 | Batch 600/938 | Cost: 0.0899\n",
      "Epoch: 004/010 | Batch 650/938 | Cost: 0.0479\n",
      "Epoch: 004/010 | Batch 700/938 | Cost: 0.0615\n",
      "Epoch: 004/010 | Batch 750/938 | Cost: 0.0633\n",
      "Epoch: 004/010 | Batch 800/938 | Cost: 0.0348\n",
      "Epoch: 004/010 | Batch 850/938 | Cost: 0.0710\n",
      "Epoch: 004/010 | Batch 900/938 | Cost: 0.1097\n",
      "Epoch: 004/010 training accuracy: 98.49%\n",
      "Time elapsed: 0.88 min\n",
      "Epoch: 005/010 | Batch 000/938 | Cost: 0.0251\n",
      "Epoch: 005/010 | Batch 050/938 | Cost: 0.0213\n",
      "Epoch: 005/010 | Batch 100/938 | Cost: 0.0694\n",
      "Epoch: 005/010 | Batch 150/938 | Cost: 0.1481\n",
      "Epoch: 005/010 | Batch 200/938 | Cost: 0.1333\n",
      "Epoch: 005/010 | Batch 250/938 | Cost: 0.0117\n",
      "Epoch: 005/010 | Batch 300/938 | Cost: 0.0978\n",
      "Epoch: 005/010 | Batch 350/938 | Cost: 0.0204\n",
      "Epoch: 005/010 | Batch 400/938 | Cost: 0.0517\n",
      "Epoch: 005/010 | Batch 450/938 | Cost: 0.0371\n",
      "Epoch: 005/010 | Batch 500/938 | Cost: 0.0337\n",
      "Epoch: 005/010 | Batch 550/938 | Cost: 0.1566\n",
      "Epoch: 005/010 | Batch 600/938 | Cost: 0.1280\n",
      "Epoch: 005/010 | Batch 650/938 | Cost: 0.1210\n",
      "Epoch: 005/010 | Batch 700/938 | Cost: 0.1570\n",
      "Epoch: 005/010 | Batch 750/938 | Cost: 0.0531\n",
      "Epoch: 005/010 | Batch 800/938 | Cost: 0.0136\n",
      "Epoch: 005/010 | Batch 850/938 | Cost: 0.1199\n",
      "Epoch: 005/010 | Batch 900/938 | Cost: 0.0485\n",
      "Epoch: 005/010 training accuracy: 98.75%\n",
      "Time elapsed: 1.10 min\n",
      "Epoch: 006/010 | Batch 000/938 | Cost: 0.0548\n",
      "Epoch: 006/010 | Batch 050/938 | Cost: 0.0178\n",
      "Epoch: 006/010 | Batch 100/938 | Cost: 0.0137\n",
      "Epoch: 006/010 | Batch 150/938 | Cost: 0.0555\n",
      "Epoch: 006/010 | Batch 200/938 | Cost: 0.1317\n",
      "Epoch: 006/010 | Batch 250/938 | Cost: 0.0326\n",
      "Epoch: 006/010 | Batch 300/938 | Cost: 0.0615\n",
      "Epoch: 006/010 | Batch 350/938 | Cost: 0.0594\n",
      "Epoch: 006/010 | Batch 400/938 | Cost: 0.0780\n",
      "Epoch: 006/010 | Batch 450/938 | Cost: 0.0451\n",
      "Epoch: 006/010 | Batch 500/938 | Cost: 0.1128\n",
      "Epoch: 006/010 | Batch 550/938 | Cost: 0.0465\n",
      "Epoch: 006/010 | Batch 600/938 | Cost: 0.0719\n",
      "Epoch: 006/010 | Batch 650/938 | Cost: 0.0286\n",
      "Epoch: 006/010 | Batch 700/938 | Cost: 0.0323\n",
      "Epoch: 006/010 | Batch 750/938 | Cost: 0.0246\n",
      "Epoch: 006/010 | Batch 800/938 | Cost: 0.0303\n",
      "Epoch: 006/010 | Batch 850/938 | Cost: 0.0532\n",
      "Epoch: 006/010 | Batch 900/938 | Cost: 0.0584\n",
      "Epoch: 006/010 training accuracy: 98.99%\n",
      "Time elapsed: 1.33 min\n",
      "Epoch: 007/010 | Batch 000/938 | Cost: 0.0348\n",
      "Epoch: 007/010 | Batch 050/938 | Cost: 0.0086\n",
      "Epoch: 007/010 | Batch 100/938 | Cost: 0.0448\n",
      "Epoch: 007/010 | Batch 150/938 | Cost: 0.0301\n",
      "Epoch: 007/010 | Batch 200/938 | Cost: 0.0218\n",
      "Epoch: 007/010 | Batch 250/938 | Cost: 0.0705\n",
      "Epoch: 007/010 | Batch 300/938 | Cost: 0.0957\n",
      "Epoch: 007/010 | Batch 350/938 | Cost: 0.0849\n",
      "Epoch: 007/010 | Batch 400/938 | Cost: 0.0368\n",
      "Epoch: 007/010 | Batch 450/938 | Cost: 0.0423\n",
      "Epoch: 007/010 | Batch 500/938 | Cost: 0.0450\n",
      "Epoch: 007/010 | Batch 550/938 | Cost: 0.0101\n",
      "Epoch: 007/010 | Batch 600/938 | Cost: 0.0460\n",
      "Epoch: 007/010 | Batch 650/938 | Cost: 0.0290\n",
      "Epoch: 007/010 | Batch 700/938 | Cost: 0.0351\n",
      "Epoch: 007/010 | Batch 750/938 | Cost: 0.0317\n",
      "Epoch: 007/010 | Batch 800/938 | Cost: 0.0574\n",
      "Epoch: 007/010 | Batch 850/938 | Cost: 0.0758\n",
      "Epoch: 007/010 | Batch 900/938 | Cost: 0.0172\n",
      "Epoch: 007/010 training accuracy: 99.31%\n",
      "Time elapsed: 1.55 min\n",
      "Epoch: 008/010 | Batch 000/938 | Cost: 0.0331\n",
      "Epoch: 008/010 | Batch 050/938 | Cost: 0.0113\n",
      "Epoch: 008/010 | Batch 100/938 | Cost: 0.0890\n",
      "Epoch: 008/010 | Batch 150/938 | Cost: 0.0309\n",
      "Epoch: 008/010 | Batch 200/938 | Cost: 0.0391\n",
      "Epoch: 008/010 | Batch 250/938 | Cost: 0.0567\n",
      "Epoch: 008/010 | Batch 300/938 | Cost: 0.0330\n",
      "Epoch: 008/010 | Batch 350/938 | Cost: 0.0342\n",
      "Epoch: 008/010 | Batch 400/938 | Cost: 0.0904\n",
      "Epoch: 008/010 | Batch 450/938 | Cost: 0.0247\n",
      "Epoch: 008/010 | Batch 500/938 | Cost: 0.0359\n",
      "Epoch: 008/010 | Batch 550/938 | Cost: 0.0544\n",
      "Epoch: 008/010 | Batch 600/938 | Cost: 0.0428\n",
      "Epoch: 008/010 | Batch 650/938 | Cost: 0.0105\n",
      "Epoch: 008/010 | Batch 700/938 | Cost: 0.0986\n",
      "Epoch: 008/010 | Batch 750/938 | Cost: 0.0188\n",
      "Epoch: 008/010 | Batch 800/938 | Cost: 0.0153\n",
      "Epoch: 008/010 | Batch 850/938 | Cost: 0.0095\n",
      "Epoch: 008/010 | Batch 900/938 | Cost: 0.0464\n",
      "Epoch: 008/010 training accuracy: 99.36%\n",
      "Time elapsed: 1.76 min\n",
      "Epoch: 009/010 | Batch 000/938 | Cost: 0.0491\n",
      "Epoch: 009/010 | Batch 050/938 | Cost: 0.0390\n",
      "Epoch: 009/010 | Batch 100/938 | Cost: 0.1674\n",
      "Epoch: 009/010 | Batch 150/938 | Cost: 0.0409\n",
      "Epoch: 009/010 | Batch 200/938 | Cost: 0.0664\n",
      "Epoch: 009/010 | Batch 250/938 | Cost: 0.0775\n",
      "Epoch: 009/010 | Batch 300/938 | Cost: 0.0383\n",
      "Epoch: 009/010 | Batch 350/938 | Cost: 0.0214\n",
      "Epoch: 009/010 | Batch 400/938 | Cost: 0.0217\n",
      "Epoch: 009/010 | Batch 450/938 | Cost: 0.0254\n",
      "Epoch: 009/010 | Batch 500/938 | Cost: 0.0369\n",
      "Epoch: 009/010 | Batch 550/938 | Cost: 0.0154\n",
      "Epoch: 009/010 | Batch 600/938 | Cost: 0.0524\n",
      "Epoch: 009/010 | Batch 650/938 | Cost: 0.0727\n",
      "Epoch: 009/010 | Batch 700/938 | Cost: 0.0718\n",
      "Epoch: 009/010 | Batch 750/938 | Cost: 0.0279\n",
      "Epoch: 009/010 | Batch 800/938 | Cost: 0.0238\n",
      "Epoch: 009/010 | Batch 850/938 | Cost: 0.0236\n",
      "Epoch: 009/010 | Batch 900/938 | Cost: 0.0147\n",
      "Epoch: 009/010 training accuracy: 99.46%\n",
      "Time elapsed: 1.98 min\n",
      "Epoch: 010/010 | Batch 000/938 | Cost: 0.0172\n",
      "Epoch: 010/010 | Batch 050/938 | Cost: 0.0071\n",
      "Epoch: 010/010 | Batch 100/938 | Cost: 0.0308\n",
      "Epoch: 010/010 | Batch 150/938 | Cost: 0.0047\n",
      "Epoch: 010/010 | Batch 200/938 | Cost: 0.0716\n",
      "Epoch: 010/010 | Batch 250/938 | Cost: 0.0162\n",
      "Epoch: 010/010 | Batch 300/938 | Cost: 0.0614\n",
      "Epoch: 010/010 | Batch 350/938 | Cost: 0.0308\n",
      "Epoch: 010/010 | Batch 400/938 | Cost: 0.0571\n",
      "Epoch: 010/010 | Batch 450/938 | Cost: 0.0050\n",
      "Epoch: 010/010 | Batch 500/938 | Cost: 0.0548\n",
      "Epoch: 010/010 | Batch 550/938 | Cost: 0.0269\n",
      "Epoch: 010/010 | Batch 600/938 | Cost: 0.0378\n",
      "Epoch: 010/010 | Batch 650/938 | Cost: 0.0120\n",
      "Epoch: 010/010 | Batch 700/938 | Cost: 0.0298\n",
      "Epoch: 010/010 | Batch 750/938 | Cost: 0.0781\n",
      "Epoch: 010/010 | Batch 800/938 | Cost: 0.0251\n",
      "Epoch: 010/010 | Batch 850/938 | Cost: 0.0693\n",
      "Epoch: 010/010 | Batch 900/938 | Cost: 0.0499\n",
      "Epoch: 010/010 training accuracy: 99.61%\n",
      "Time elapsed: 2.20 min\n",
      "Total Training Time: 2.20 min\n"
     ]
    }
   ],
   "source": [
    "def compute_accuracy(net, data_loader):\n",
    "    net.eval()\n",
    "    correct_pred, num_examples = 0, 0\n",
    "    with torch.no_grad():\n",
    "        for features, targets in data_loader:\n",
    "            features = features.view(-1, 28*28).to(device)\n",
    "            targets = targets.to(device)\n",
    "            logits, probas = net(features)\n",
    "            _, predicted_labels = torch.max(probas, 1)\n",
    "            num_examples += targets.size(0)\n",
    "            correct_pred += (predicted_labels == targets).sum()\n",
    "        return correct_pred.float()/num_examples * 100\n",
    "\n",
    "    \n",
    "start_time = time.time()\n",
    "for epoch in range(num_epochs):\n",
    "    model.train()\n",
    "    for batch_idx, (features, targets) in enumerate(train_loader):\n",
    "        \n",
    "        features = features.view(-1, 28*28).to(device)\n",
    "        targets = targets.to(device)\n",
    "            \n",
    "        ### FORWARD AND BACK PROP\n",
    "        logits, probas = model(features)\n",
    "        cost = F.cross_entropy(logits, targets)\n",
    "        optimizer.zero_grad()\n",
    "        \n",
    "        cost.backward()\n",
    "        \n",
    "        ### UPDATE MODEL PARAMETERS\n",
    "        optimizer.step()\n",
    "        \n",
    "        ### LOGGING\n",
    "        if not batch_idx % 50:\n",
    "            print ('Epoch: %03d/%03d | Batch %03d/%03d | Cost: %.4f' \n",
    "                   %(epoch+1, num_epochs, batch_idx, \n",
    "                     len(train_loader), cost))\n",
    "\n",
    "    print('Epoch: %03d/%03d training accuracy: %.2f%%' % (\n",
    "          epoch+1, num_epochs, \n",
    "          compute_accuracy(model, train_loader)))\n",
    "\n",
    "    print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))\n",
    "    \n",
    "print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test accuracy: 97.82%\n"
     ]
    }
   ],
   "source": [
    "print('Test accuracy: %.2f%%' % (compute_accuracy(model, test_loader)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "numpy       1.15.4\n",
      "torch       1.0.0\n",
      "\n"
     ]
    }
   ],
   "source": [
    "%watermark -iv"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  },
  "toc": {
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}