{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n",
    "- Author: Sebastian Raschka\n",
    "- GitHub Repository: https://github.com/rasbt/deeplearning-models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sebastian Raschka \n",
      "\n",
      "CPython 3.6.8\n",
      "IPython 7.2.0\n",
      "\n",
      "torch 1.0.1.post2\n"
     ]
    }
   ],
   "source": [
    "%load_ext watermark\n",
    "%watermark -a 'Sebastian Raschka' -v -p torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Runs on CPU or GPU (if available)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model Zoo -- All-Convolutional Neural Network"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Simple convolutional neural network that uses stride=2 every 2nd convolutional layer, instead of max pooling, to reduce the feature maps. Loosely based on\n",
    "\n",
    "- Springenberg, Jost Tobias, Alexey Dosovitskiy, Thomas Brox, and Martin Riedmiller. \"Striving for simplicity: The all convolutional net.\" arXiv preprint arXiv:1412.6806 (2014)."
   ]
  },
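  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a standalone illustration (a minimal sketch, not part of the architecture below), a 3x3 convolution with stride=2 and padding=1 halves the spatial dimensions of a feature map just like a 2x2 max-pooling layer, while keeping learnable parameters in the downsampling step:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustration: stride-2 convolution vs. max pooling\n",
    "# (dummy input; channel/size values match the MNIST model below)\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "x = torch.randn(1, 4, 28, 28)\n",
    "\n",
    "conv = torch.nn.Conv2d(in_channels=4, out_channels=4,\n",
    "                       kernel_size=(3, 3), stride=(2, 2), padding=1)\n",
    "\n",
    "print(conv(x).shape)                                   # torch.Size([1, 4, 14, 14])\n",
    "print(F.max_pool2d(x, kernel_size=2, stride=2).shape)  # torch.Size([1, 4, 14, 14])"
   ]
  },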
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torchvision import datasets\n",
    "from torchvision import transforms\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    torch.backends.cudnn.deterministic = True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Settings and Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Image batch dimensions: torch.Size([256, 1, 28, 28])\n",
      "Image label dimensions: torch.Size([256])\n"
     ]
    }
   ],
   "source": [
    "##########################\n",
    "### SETTINGS\n",
    "##########################\n",
    "\n",
    "# Device\n",
    "device = torch.device(\"cuda:1\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# Hyperparameters\n",
    "random_seed = 1\n",
    "learning_rate = 0.001\n",
    "num_epochs = 15\n",
    "batch_size = 256\n",
    "\n",
    "# Architecture\n",
    "num_classes = 10\n",
    "\n",
    "\n",
    "##########################\n",
    "### MNIST DATASET\n",
    "##########################\n",
    "\n",
    "# Note transforms.ToTensor() scales input images\n",
    "# to 0-1 range\n",
    "train_dataset = datasets.MNIST(root='data', \n",
    "                               train=True, \n",
    "                               transform=transforms.ToTensor(),\n",
    "                               download=True)\n",
    "\n",
    "test_dataset = datasets.MNIST(root='data', \n",
    "                              train=False, \n",
    "                              transform=transforms.ToTensor())\n",
    "\n",
    "\n",
    "train_loader = DataLoader(dataset=train_dataset, \n",
    "                          batch_size=batch_size, \n",
    "                          shuffle=True)\n",
    "\n",
    "test_loader = DataLoader(dataset=test_dataset, \n",
    "                         batch_size=batch_size, \n",
    "                         shuffle=False)\n",
    "\n",
    "# Checking the dataset\n",
    "for images, labels in train_loader:  \n",
    "    print('Image batch dimensions:', images.shape)\n",
    "    print('Image label dimensions:', labels.shape)\n",
    "    break"
   ]
  },
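  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check, the `images` batch drawn above can be used to verify the 0-1 scaling performed by `transforms.ToTensor()`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The `images` batch from the loop above is still in scope;\n",
    "# ToTensor() should have scaled all pixel values into [0, 1]\n",
    "print('Min pixel value: %.2f' % images.min().item())\n",
    "print('Max pixel value: %.2f' % images.max().item())"
   ]
  },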
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model"
   ]
  },
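  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The stride-1 convolutions below use \"same\" padding so that the spatial dimensions are preserved. For an input of width $w$, kernel size $k$, stride $s$, and padding $p$, the output width is\n",
    "\n",
    "$$o = \\frac{w - k + 2p}{s} + 1,$$\n",
    "\n",
    "and solving for the padding gives $p = \\frac{s(o-1) - w + k}{2}$. For the first layer, with $w = o = 28$, $k = 3$, and $s = 1$, this yields $p = (1 \\cdot 27 - 28 + 3)/2 = 1$; the per-layer comments in the code below repeat this calculation."
   ]
  },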
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "##########################\n",
    "### MODEL\n",
    "##########################\n",
    "\n",
    "\n",
    "class ConvNet(torch.nn.Module):\n",
    "\n",
    "    def __init__(self, num_classes):\n",
    "        super(ConvNet, self).__init__()\n",
    "        \n",
    "        self.num_classes = num_classes\n",
    "        # calculate same padding:\n",
    "        # (w - k + 2*p)/s + 1 = o\n",
    "        # => p = (s(o-1) - w + k)/2\n",
    "        \n",
    "        # 28x28x1 => 28x28x4\n",
    "        self.conv_1 = torch.nn.Conv2d(in_channels=1,\n",
    "                                      out_channels=4,\n",
    "                                      kernel_size=(3, 3),\n",
    "                                      stride=(1, 1),\n",
    "                                      padding=1) # (1(28-1) - 28 + 3) / 2 = 1\n",
    "        # 28x28x4 => 14x14x4\n",
    "        self.conv_2 = torch.nn.Conv2d(in_channels=4,\n",
    "                                      out_channels=4,\n",
    "                                      kernel_size=(3, 3),\n",
    "                                      stride=(2, 2),\n",
    "                                      padding=1)                             \n",
    "        # 14x14x4 => 14x14x8\n",
    "        self.conv_3 = torch.nn.Conv2d(in_channels=4,\n",
    "                                      out_channels=8,\n",
    "                                      kernel_size=(3, 3),\n",
    "                                      stride=(1, 1),\n",
    "                                      padding=1) # (1(14-1) - 14 + 3) / 2 = 1                 \n",
    "        # 14x14x8 => 7x7x8                             \n",
    "        self.conv_4 = torch.nn.Conv2d(in_channels=8,\n",
    "                                      out_channels=8,\n",
    "                                      kernel_size=(3, 3),\n",
    "                                      stride=(2, 2),\n",
    "                                      padding=1)      \n",
    "        \n",
    "        # 7x7x8 => 7x7x16                             \n",
    "        self.conv_5 = torch.nn.Conv2d(in_channels=8,\n",
    "                                      out_channels=16,\n",
    "                                      kernel_size=(3, 3),\n",
    "                                      stride=(1, 1),\n",
    "                                      padding=1) # (1(7-1) - 7 + 3) / 2 = 1          \n",
    "        # 7x7x16 => 4x4x16                             \n",
    "        self.conv_6 = torch.nn.Conv2d(in_channels=16,\n",
    "                                      out_channels=16,\n",
    "                                      kernel_size=(3, 3),\n",
    "                                      stride=(2, 2),\n",
    "                                      padding=1)      \n",
    "        \n",
    "        # 4x4x16 => 4x4xnum_classes                             \n",
    "        self.conv_7 = torch.nn.Conv2d(in_channels=16,\n",
    "                                      out_channels=self.num_classes,\n",
    "                                      kernel_size=(3, 3),\n",
    "                                      stride=(1, 1),\n",
    "                                      padding=1) # (1(7-1) - 7 + 3) / 2 = 1       \n",
    "\n",
    "\n",
    "        \n",
    "    def forward(self, x):\n",
    "        out = self.conv_1(x)\n",
    "        out = F.relu(out)\n",
    "        \n",
    "        out = self.conv_2(out)\n",
    "        out = F.relu(out)\n",
    "\n",
    "        out = self.conv_3(out)\n",
    "        out = F.relu(out)\n",
    "\n",
    "        out = self.conv_4(out)\n",
    "        out = F.relu(out)\n",
    "        \n",
    "        out = self.conv_5(out)\n",
    "        out = F.relu(out)\n",
    "        \n",
    "        out = self.conv_6(out)\n",
    "        out = F.relu(out)\n",
    "        \n",
    "        out = self.conv_7(out)\n",
    "        out = F.relu(out)\n",
    "        \n",
    "        logits = F.adaptive_avg_pool2d(out, 1)\n",
    "        # drop width\n",
    "        logits.squeeze_(-1)\n",
    "        # drop height\n",
    "        logits.squeeze_(-1)\n",
    "        probas = torch.softmax(logits, dim=1)\n",
    "        return logits, probas\n",
    "\n",
    "    \n",
    "torch.manual_seed(random_seed)\n",
    "model = ConvNet(num_classes=num_classes)\n",
    "\n",
    "model = model.to(device)\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)  "
   ]
  },
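  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before training, a quick forward pass with a dummy batch (an illustrative sanity check) confirms the downsampling path 28x28 => 14x14 => 7x7 => 4x4 and the shape of the pooled logits:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shape check with a random dummy batch (no gradients needed)\n",
    "with torch.no_grad():\n",
    "    dummy = torch.randn(2, 1, 28, 28).to(device)\n",
    "    dummy_logits, dummy_probas = model(dummy)\n",
    "    print('Logits shape:', dummy_logits.shape)          # torch.Size([2, 10])\n",
    "    print('Probas row sums:', dummy_probas.sum(dim=1))  # each row sums to ~1"
   ]
  },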
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 001/015 | Batch 000/235 | Cost: 2.3051\n",
      "Epoch: 001/015 | Batch 050/235 | Cost: 2.2926\n",
      "Epoch: 001/015 | Batch 100/235 | Cost: 2.0812\n",
      "Epoch: 001/015 | Batch 150/235 | Cost: 1.4435\n",
      "Epoch: 001/015 | Batch 200/235 | Cost: 0.9232\n",
      "Epoch: 001/015 training accuracy: 76.06%\n",
      "Time elapsed: 0.23 min\n",
      "Epoch: 002/015 | Batch 000/235 | Cost: 0.7001\n",
      "Epoch: 002/015 | Batch 050/235 | Cost: 0.5710\n",
      "Epoch: 002/015 | Batch 100/235 | Cost: 0.5925\n",
      "Epoch: 002/015 | Batch 150/235 | Cost: 0.4022\n",
      "Epoch: 002/015 | Batch 200/235 | Cost: 0.4663\n",
      "Epoch: 002/015 training accuracy: 85.68%\n",
      "Time elapsed: 0.45 min\n",
      "Epoch: 003/015 | Batch 000/235 | Cost: 0.4332\n",
      "Epoch: 003/015 | Batch 050/235 | Cost: 0.3523\n",
      "Epoch: 003/015 | Batch 100/235 | Cost: 0.4114\n",
      "Epoch: 003/015 | Batch 150/235 | Cost: 0.4587\n",
      "Epoch: 003/015 | Batch 200/235 | Cost: 0.4517\n",
      "Epoch: 003/015 training accuracy: 89.33%\n",
      "Time elapsed: 0.68 min\n",
      "Epoch: 004/015 | Batch 000/235 | Cost: 0.4083\n",
      "Epoch: 004/015 | Batch 050/235 | Cost: 0.3158\n",
      "Epoch: 004/015 | Batch 100/235 | Cost: 0.2728\n",
      "Epoch: 004/015 | Batch 150/235 | Cost: 0.3023\n",
      "Epoch: 004/015 | Batch 200/235 | Cost: 0.2709\n",
      "Epoch: 004/015 training accuracy: 90.40%\n",
      "Time elapsed: 0.90 min\n",
      "Epoch: 005/015 | Batch 000/235 | Cost: 0.2514\n",
      "Epoch: 005/015 | Batch 050/235 | Cost: 0.3704\n",
      "Epoch: 005/015 | Batch 100/235 | Cost: 0.2972\n",
      "Epoch: 005/015 | Batch 150/235 | Cost: 0.2335\n",
      "Epoch: 005/015 | Batch 200/235 | Cost: 0.3242\n",
      "Epoch: 005/015 training accuracy: 91.36%\n",
      "Time elapsed: 1.13 min\n",
      "Epoch: 006/015 | Batch 000/235 | Cost: 0.3255\n",
      "Epoch: 006/015 | Batch 050/235 | Cost: 0.2985\n",
      "Epoch: 006/015 | Batch 100/235 | Cost: 0.3501\n",
      "Epoch: 006/015 | Batch 150/235 | Cost: 0.2415\n",
      "Epoch: 006/015 | Batch 200/235 | Cost: 0.1978\n",
      "Epoch: 006/015 training accuracy: 92.82%\n",
      "Time elapsed: 1.35 min\n",
      "Epoch: 007/015 | Batch 000/235 | Cost: 0.1925\n",
      "Epoch: 007/015 | Batch 050/235 | Cost: 0.2179\n",
      "Epoch: 007/015 | Batch 100/235 | Cost: 0.3337\n",
      "Epoch: 007/015 | Batch 150/235 | Cost: 0.1856\n",
      "Epoch: 007/015 | Batch 200/235 | Cost: 0.1333\n",
      "Epoch: 007/015 training accuracy: 93.68%\n",
      "Time elapsed: 1.58 min\n",
      "Epoch: 008/015 | Batch 000/235 | Cost: 0.1776\n",
      "Epoch: 008/015 | Batch 050/235 | Cost: 0.2973\n",
      "Epoch: 008/015 | Batch 100/235 | Cost: 0.1685\n",
      "Epoch: 008/015 | Batch 150/235 | Cost: 0.2062\n",
      "Epoch: 008/015 | Batch 200/235 | Cost: 0.2165\n",
      "Epoch: 008/015 training accuracy: 94.42%\n",
      "Time elapsed: 1.80 min\n",
      "Epoch: 009/015 | Batch 000/235 | Cost: 0.2038\n",
      "Epoch: 009/015 | Batch 050/235 | Cost: 0.1301\n",
      "Epoch: 009/015 | Batch 100/235 | Cost: 0.1977\n",
      "Epoch: 009/015 | Batch 150/235 | Cost: 0.2160\n",
      "Epoch: 009/015 | Batch 200/235 | Cost: 0.1772\n",
      "Epoch: 009/015 training accuracy: 94.61%\n",
      "Time elapsed: 2.02 min\n",
      "Epoch: 010/015 | Batch 000/235 | Cost: 0.1709\n",
      "Epoch: 010/015 | Batch 050/235 | Cost: 0.1695\n",
      "Epoch: 010/015 | Batch 100/235 | Cost: 0.2144\n",
      "Epoch: 010/015 | Batch 150/235 | Cost: 0.1548\n",
      "Epoch: 010/015 | Batch 200/235 | Cost: 0.1033\n",
      "Epoch: 010/015 training accuracy: 94.90%\n",
      "Time elapsed: 2.25 min\n",
      "Epoch: 011/015 | Batch 000/235 | Cost: 0.1651\n",
      "Epoch: 011/015 | Batch 050/235 | Cost: 0.1899\n",
      "Epoch: 011/015 | Batch 100/235 | Cost: 0.1727\n",
      "Epoch: 011/015 | Batch 150/235 | Cost: 0.1216\n",
      "Epoch: 011/015 | Batch 200/235 | Cost: 0.1859\n",
      "Epoch: 011/015 training accuracy: 94.82%\n",
      "Time elapsed: 2.47 min\n",
      "Epoch: 012/015 | Batch 000/235 | Cost: 0.2490\n",
      "Epoch: 012/015 | Batch 050/235 | Cost: 0.1022\n",
      "Epoch: 012/015 | Batch 100/235 | Cost: 0.0793\n",
      "Epoch: 012/015 | Batch 150/235 | Cost: 0.2258\n",
      "Epoch: 012/015 | Batch 200/235 | Cost: 0.1356\n",
      "Epoch: 012/015 training accuracy: 95.35%\n",
      "Time elapsed: 2.70 min\n",
      "Epoch: 013/015 | Batch 000/235 | Cost: 0.1512\n",
      "Epoch: 013/015 | Batch 050/235 | Cost: 0.1758\n",
      "Epoch: 013/015 | Batch 100/235 | Cost: 0.1349\n",
      "Epoch: 013/015 | Batch 150/235 | Cost: 0.1838\n",
      "Epoch: 013/015 | Batch 200/235 | Cost: 0.1166\n",
      "Epoch: 013/015 training accuracy: 95.61%\n",
      "Time elapsed: 2.92 min\n",
      "Epoch: 014/015 | Batch 000/235 | Cost: 0.1210\n",
      "Epoch: 014/015 | Batch 050/235 | Cost: 0.1511\n",
      "Epoch: 014/015 | Batch 100/235 | Cost: 0.1331\n",
      "Epoch: 014/015 | Batch 150/235 | Cost: 0.1058\n",
      "Epoch: 014/015 | Batch 200/235 | Cost: 0.1340\n",
      "Epoch: 014/015 training accuracy: 95.53%\n",
      "Time elapsed: 3.15 min\n",
      "Epoch: 015/015 | Batch 000/235 | Cost: 0.2342\n",
      "Epoch: 015/015 | Batch 050/235 | Cost: 0.1371\n",
      "Epoch: 015/015 | Batch 100/235 | Cost: 0.0944\n",
      "Epoch: 015/015 | Batch 150/235 | Cost: 0.1102\n",
      "Epoch: 015/015 | Batch 200/235 | Cost: 0.1259\n",
      "Epoch: 015/015 training accuracy: 96.36%\n",
      "Time elapsed: 3.37 min\n",
      "Total Training Time: 3.37 min\n"
     ]
    }
   ],
   "source": [
    "def compute_accuracy(model, data_loader):\n",
    "    correct_pred, num_examples = 0, 0\n",
    "    for features, targets in data_loader:\n",
    "        features = features.to(device)\n",
    "        targets = targets.to(device)\n",
    "        logits, probas = model(features)\n",
    "        _, predicted_labels = torch.max(probas, 1)\n",
    "        num_examples += targets.size(0)\n",
    "        correct_pred += (predicted_labels == targets).sum()\n",
    "    return correct_pred.float()/num_examples * 100\n",
    "    \n",
    "    \n",
    "start_time = time.time()\n",
    "for epoch in range(num_epochs):\n",
    "    model = model.train()\n",
    "    for batch_idx, (features, targets) in enumerate(train_loader):\n",
    "        \n",
    "        features = features.to(device)\n",
    "        targets = targets.to(device)\n",
    "\n",
    "        ### FORWARD AND BACK PROP\n",
    "        logits, probas = model(features)\n",
    "        cost = F.cross_entropy(logits, targets)\n",
    "        optimizer.zero_grad()\n",
    "        \n",
    "        cost.backward()\n",
    "        \n",
    "        ### UPDATE MODEL PARAMETERS\n",
    "        optimizer.step()\n",
    "        \n",
    "        ### LOGGING\n",
    "        if not batch_idx % 50:\n",
    "            print ('Epoch: %03d/%03d | Batch %03d/%03d | Cost: %.4f' \n",
    "                   %(epoch+1, num_epochs, batch_idx, \n",
    "                     len(train_loader), cost))\n",
    "    \n",
    "    model = model.eval()\n",
    "    print('Epoch: %03d/%03d training accuracy: %.2f%%' % (\n",
    "          epoch+1, num_epochs, \n",
    "          compute_accuracy(model, train_loader)))\n",
    "    \n",
    "    print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))\n",
    "    \n",
    "print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test accuracy: 96.42%\n"
     ]
    }
   ],
   "source": [
    "print('Test accuracy: %.2f%%' % (compute_accuracy(model, test_loader)))"
   ]
  },
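  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a usage sketch, the trained model can classify a single test image as follows (using the first image of `test_dataset` as an example):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative single-image inference\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    image, label = test_dataset[0]  # ToTensor() already applied\n",
    "    logits, probas = model(image.unsqueeze(0).to(device))\n",
    "    print('True label: %d | Predicted: %d'\n",
    "          % (label, probas.argmax(dim=1).item()))"
   ]
  },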
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "numpy       1.15.4\n",
      "torch       1.0.1.post2\n",
      "\n"
     ]
    }
   ],
   "source": [
    "%watermark -iv"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  },
  "toc": {
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}