{ "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "UEBilEjLj5wY" }, "source": [ "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n", "- Author: Sebastian Raschka\n", "- GitHub Repository: https://github.com/rasbt/deeplearning-models" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, "base_uri": "https://localhost:8080/", "height": 119 }, "colab_type": "code", "executionInfo": { "elapsed": 536, "status": "ok", "timestamp": 1524974472601, "user": { "displayName": "Sebastian Raschka", "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg", "userId": "118404394130788869227" }, "user_tz": 240 }, "id": "GOzuY8Yvj5wb", "outputId": "c19362ce-f87a-4cc2-84cc-8d7b4b9e6007" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sebastian Raschka \n", "\n", "CPython 3.6.8\n", "IPython 7.2.0\n", "\n", "torch 1.0.1.post2\n" ] } ], "source": [ "%load_ext watermark\n", "%watermark -a 'Sebastian Raschka' -v -p torch" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "MEu9MiOxj5wk" }, "source": [ "- Runs on CPU (not recommended here) or GPU (if available)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "rH4XmErYj5wm" }, "source": [ "# Model Zoo -- Convolutional Neural Network (VGG19 Architecture)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Implementation of the VGG-19 architecture on Cifar10. \n", "\n", "\n", "Reference for VGG-19:\n", " \n", "- Simonyan, K., & Zisserman, A. (2014). Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556.\n", "\n", "\n", "The following table (taken from Simonyan & Zisserman referenced above) summarizes the VGG19 architecture:\n", "\n", "![](../images/vgg19/vgg19-arch-table.png)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "MkoGLH_Tj5wn" }, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 } }, "colab_type": "code", "id": "ORj09gnrj5wp" }, "outputs": [], "source": [ "import numpy as np\n", "import time\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from torchvision import datasets\n", "from torchvision import transforms\n", "from torch.utils.data import DataLoader" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "PvgJ_0i7j5wt" }, "source": [ "## Settings and Dataset" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, "base_uri": "https://localhost:8080/", "height": 85 }, "colab_type": "code", "executionInfo": { "elapsed": 23936, "status": "ok", "timestamp": 1524974497505, "user": { "displayName": "Sebastian Raschka", "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg", "userId": "118404394130788869227" }, "user_tz": 240 }, "id": "NnT0sZIwj5wu", "outputId": "55aed925-d17e-4c6a-8c71-0d9b3bde5637" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Device: cuda:0\n", "Files already downloaded and verified\n", "Image batch dimensions: torch.Size([128, 3, 32, 32])\n", "Image label dimensions: torch.Size([128])\n" ] } ], "source": [ "##########################\n", "### SETTINGS\n", "##########################\n", "\n", "# Device\n", "DEVICE = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", "print('Device:', DEVICE)\n", "\n", "# Hyperparameters\n", "random_seed = 1\n", "learning_rate = 0.001\n", "num_epochs = 20\n", "batch_size = 128\n", "\n", "# Architecture\n", "num_features = 784\n", "num_classes = 10\n", "\n", "\n", "##########################\n", "### MNIST DATASET\n", "##########################\n", "\n", "# Note transforms.ToTensor() scales input images\n", "# to 0-1 range\n", "train_dataset = datasets.CIFAR10(root='data', \n", " train=True, \n", " transform=transforms.ToTensor(),\n", " download=True)\n", "\n", "test_dataset = datasets.CIFAR10(root='data', \n", " train=False, \n", " transform=transforms.ToTensor())\n", "\n", "\n", "train_loader = DataLoader(dataset=train_dataset, \n", " batch_size=batch_size, \n", " shuffle=True)\n", "\n", "test_loader = DataLoader(dataset=test_dataset, \n", " batch_size=batch_size, \n", " shuffle=False)\n", "\n", "# Checking the dataset\n", "for images, labels in train_loader: \n", " print('Image batch dimensions:', images.shape)\n", " print('Image label dimensions:', labels.shape)\n", " break" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "I6hghKPxj5w0" }, "source": [ "## Model" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 } }, "colab_type": "code", "id": "_lza9t_uj5w1" }, "outputs": [], "source": [ "##########################\n", "### MODEL\n", "##########################\n", "\n", "\n", "class VGG16(torch.nn.Module):\n", "\n", " def __init__(self, num_features, num_classes):\n", " super(VGG16, self).__init__()\n", " \n", " # calculate same padding:\n", " # (w - k + 2*p)/s + 1 = o\n", " # => p = (s(o-1) - w + k)/2\n", " \n", " self.block_1 = nn.Sequential(\n", " nn.Conv2d(in_channels=3,\n", " out_channels=64,\n", " kernel_size=(3, 3),\n", " stride=(1, 1),\n", " # (1(32-1)- 32 + 3)/2 = 1\n", " padding=1), \n", " nn.ReLU(),\n", " nn.Conv2d(in_channels=64,\n", " out_channels=64,\n", " kernel_size=(3, 3),\n", " stride=(1, 1),\n", " padding=1),\n", " nn.ReLU(),\n", " nn.MaxPool2d(kernel_size=(2, 2),\n", " stride=(2, 2))\n", " )\n", " \n", " self.block_2 = nn.Sequential(\n", " nn.Conv2d(in_channels=64,\n", " out_channels=128,\n", " kernel_size=(3, 3),\n", " stride=(1, 1),\n", " padding=1),\n", " nn.ReLU(),\n", " nn.Conv2d(in_channels=128,\n", " out_channels=128,\n", " kernel_size=(3, 3),\n", " stride=(1, 1),\n", " padding=1),\n", " nn.ReLU(),\n", " nn.MaxPool2d(kernel_size=(2, 2),\n", " stride=(2, 2))\n", " )\n", " \n", " self.block_3 = nn.Sequential( \n", " nn.Conv2d(in_channels=128,\n", " out_channels=256,\n", " kernel_size=(3, 3),\n", " stride=(1, 1),\n", " padding=1),\n", " nn.ReLU(),\n", " nn.Conv2d(in_channels=256,\n", " out_channels=256,\n", " kernel_size=(3, 3),\n", " stride=(1, 1),\n", " padding=1),\n", " nn.ReLU(), \n", " nn.Conv2d(in_channels=256,\n", " out_channels=256,\n", " kernel_size=(3, 3),\n", " stride=(1, 1),\n", " padding=1),\n", " nn.ReLU(),\n", " nn.Conv2d(in_channels=256,\n", " out_channels=256,\n", " kernel_size=(3, 3),\n", " stride=(1, 1),\n", " padding=1),\n", " nn.ReLU(),\n", " nn.MaxPool2d(kernel_size=(2, 2),\n", " stride=(2, 2))\n", " )\n", " \n", " \n", " self.block_4 = nn.Sequential( \n", " nn.Conv2d(in_channels=256,\n", " out_channels=512,\n", " kernel_size=(3, 3),\n", " stride=(1, 1),\n", " padding=1),\n", " nn.ReLU(), \n", " nn.Conv2d(in_channels=512,\n", " out_channels=512,\n", " kernel_size=(3, 3),\n", " stride=(1, 1),\n", " padding=1),\n", " nn.ReLU(), \n", " nn.Conv2d(in_channels=512,\n", " out_channels=512,\n", " kernel_size=(3, 3),\n", " stride=(1, 1),\n", " padding=1),\n", " nn.ReLU(),\n", " nn.Conv2d(in_channels=512,\n", " out_channels=512,\n", " kernel_size=(3, 3),\n", " stride=(1, 1),\n", " padding=1),\n", " nn.ReLU(), \n", " nn.MaxPool2d(kernel_size=(2, 2),\n", " stride=(2, 2))\n", " )\n", " \n", " self.block_5 = nn.Sequential(\n", " nn.Conv2d(in_channels=512,\n", " out_channels=512,\n", " kernel_size=(3, 3),\n", " stride=(1, 1),\n", " padding=1),\n", " nn.ReLU(), \n", " nn.Conv2d(in_channels=512,\n", " out_channels=512,\n", " kernel_size=(3, 3),\n", " stride=(1, 1),\n", " padding=1),\n", " nn.ReLU(), \n", " nn.Conv2d(in_channels=512,\n", " out_channels=512,\n", " kernel_size=(3, 3),\n", " stride=(1, 1),\n", " padding=1),\n", " nn.ReLU(),\n", " nn.Conv2d(in_channels=512,\n", " out_channels=512,\n", " kernel_size=(3, 3),\n", " stride=(1, 1),\n", " padding=1),\n", " nn.ReLU(), \n", " nn.MaxPool2d(kernel_size=(2, 2),\n", " stride=(2, 2)) \n", " )\n", " \n", " self.classifier = nn.Sequential(\n", " nn.Linear(512, 4096),\n", " nn.ReLU(True),\n", " nn.Linear(4096, 4096),\n", " nn.ReLU(True),\n", " nn.Linear(4096, num_classes)\n", " )\n", " \n", " \n", " for m in self.modules():\n", " if isinstance(m, torch.nn.Conv2d):\n", " #n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n", " #m.weight.data.normal_(0, np.sqrt(2. / n))\n", " m.weight.detach().normal_(0, 0.05)\n", " if m.bias is not None:\n", " m.bias.detach().zero_()\n", " elif isinstance(m, torch.nn.Linear):\n", " m.weight.detach().normal_(0, 0.05)\n", " m.bias.detach().detach().zero_()\n", " \n", " \n", " def forward(self, x):\n", "\n", " x = self.block_1(x)\n", " x = self.block_2(x)\n", " x = self.block_3(x)\n", " x = self.block_4(x)\n", " x = self.block_5(x)\n", " logits = self.classifier(x.view(-1, 512))\n", " probas = F.softmax(logits, dim=1)\n", "\n", " return logits, probas\n", "\n", " \n", "torch.manual_seed(random_seed)\n", "model = VGG16(num_features=num_features,\n", " num_classes=num_classes)\n", "\n", "model = model.to(DEVICE)\n", "\n", "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) " ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "RAodboScj5w6" }, "source": [ "## Training" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, "base_uri": "https://localhost:8080/", "height": 1547 }, "colab_type": "code", "executionInfo": { "elapsed": 2384585, "status": "ok", "timestamp": 1524976888520, "user": { "displayName": "Sebastian Raschka", "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg", "userId": "118404394130788869227" }, "user_tz": 240 }, "id": "Dzh3ROmRj5w7", "outputId": "5f8fd8c9-b076-403a-b0b7-fd2d498b48d7" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 001/020 | Batch 0000/0391 | Cost: 1061.4152\n", "Epoch: 001/020 | Batch 0050/0391 | Cost: 2.3018\n", "Epoch: 001/020 | Batch 0100/0391 | Cost: 2.0600\n", "Epoch: 001/020 | Batch 0150/0391 | Cost: 1.9973\n", "Epoch: 001/020 | Batch 0200/0391 | Cost: 1.8176\n", "Epoch: 001/020 | Batch 0250/0391 | Cost: 1.8368\n", "Epoch: 001/020 | Batch 0300/0391 | Cost: 1.7213\n", "Epoch: 001/020 | Batch 0350/0391 | Cost: 1.7154\n", "Epoch: 001/020 | Train: 35.478% | Loss: 1.685\n", "Time elapsed: 1.02 min\n", "Epoch: 002/020 | Batch 0000/0391 | Cost: 1.7648\n", "Epoch: 002/020 | Batch 0050/0391 | Cost: 1.7050\n", "Epoch: 002/020 | Batch 0100/0391 | Cost: 1.5464\n", "Epoch: 002/020 | Batch 0150/0391 | Cost: 1.6054\n", "Epoch: 002/020 | Batch 0200/0391 | Cost: 1.4430\n", "Epoch: 002/020 | Batch 0250/0391 | Cost: 1.4253\n", "Epoch: 002/020 | Batch 0300/0391 | Cost: 1.5701\n", "Epoch: 002/020 | Batch 0350/0391 | Cost: 1.4163\n", "Epoch: 002/020 | Train: 44.042% | Loss: 1.531\n", "Time elapsed: 2.07 min\n", "Epoch: 003/020 | Batch 0000/0391 | Cost: 1.5172\n", "Epoch: 003/020 | Batch 0050/0391 | Cost: 1.1992\n", "Epoch: 003/020 | Batch 0100/0391 | Cost: 1.2846\n", "Epoch: 003/020 | Batch 0150/0391 | Cost: 1.4088\n", "Epoch: 003/020 | Batch 0200/0391 | Cost: 1.4853\n", "Epoch: 003/020 | Batch 0250/0391 | Cost: 1.3923\n", "Epoch: 003/020 | Batch 0300/0391 | Cost: 1.3268\n", "Epoch: 003/020 | Batch 0350/0391 | Cost: 1.3162\n", "Epoch: 003/020 | Train: 55.596% | Loss: 1.223\n", "Time elapsed: 3.10 min\n", "Epoch: 004/020 | Batch 0000/0391 | Cost: 1.2210\n", "Epoch: 004/020 | Batch 0050/0391 | Cost: 1.2594\n", "Epoch: 004/020 | Batch 0100/0391 | Cost: 1.2881\n", "Epoch: 004/020 | Batch 0150/0391 | Cost: 1.0182\n", "Epoch: 004/020 | Batch 0200/0391 | Cost: 1.1256\n", "Epoch: 004/020 | Batch 0250/0391 | Cost: 1.1048\n", "Epoch: 004/020 | Batch 0300/0391 | Cost: 1.1812\n", "Epoch: 004/020 | Batch 0350/0391 | Cost: 1.1685\n", "Epoch: 004/020 | Train: 57.594% | Loss: 1.178\n", "Time elapsed: 4.13 min\n", "Epoch: 005/020 | Batch 0000/0391 | Cost: 1.1298\n", "Epoch: 005/020 | Batch 0050/0391 | Cost: 0.9705\n", "Epoch: 005/020 | Batch 0100/0391 | Cost: 0.9255\n", "Epoch: 005/020 | Batch 0150/0391 | Cost: 1.3610\n", "Epoch: 005/020 | Batch 0200/0391 | Cost: 0.9720\n", "Epoch: 005/020 | Batch 0250/0391 | Cost: 1.0088\n", "Epoch: 005/020 | Batch 0300/0391 | Cost: 0.9998\n", "Epoch: 005/020 | Batch 0350/0391 | Cost: 1.1961\n", "Epoch: 005/020 | Train: 63.570% | Loss: 1.003\n", "Time elapsed: 5.17 min\n", "Epoch: 006/020 | Batch 0000/0391 | Cost: 0.8837\n", "Epoch: 006/020 | Batch 0050/0391 | Cost: 0.9184\n", "Epoch: 006/020 | Batch 0100/0391 | Cost: 0.8568\n", "Epoch: 006/020 | Batch 0150/0391 | Cost: 1.0788\n", "Epoch: 006/020 | Batch 0200/0391 | Cost: 1.0365\n", "Epoch: 006/020 | Batch 0250/0391 | Cost: 0.8714\n", "Epoch: 006/020 | Batch 0300/0391 | Cost: 1.0370\n", "Epoch: 006/020 | Batch 0350/0391 | Cost: 1.0536\n", "Epoch: 006/020 | Train: 68.390% | Loss: 0.880\n", "Time elapsed: 6.20 min\n", "Epoch: 007/020 | Batch 0000/0391 | Cost: 1.0297\n", "Epoch: 007/020 | Batch 0050/0391 | Cost: 0.8801\n", "Epoch: 007/020 | Batch 0100/0391 | Cost: 0.9652\n", "Epoch: 007/020 | Batch 0150/0391 | Cost: 1.1417\n", "Epoch: 007/020 | Batch 0200/0391 | Cost: 0.8851\n", "Epoch: 007/020 | Batch 0250/0391 | Cost: 0.9499\n", "Epoch: 007/020 | Batch 0300/0391 | Cost: 0.9416\n", "Epoch: 007/020 | Batch 0350/0391 | Cost: 0.9220\n", "Epoch: 007/020 | Train: 68.740% | Loss: 0.872\n", "Time elapsed: 7.24 min\n", "Epoch: 008/020 | Batch 0000/0391 | Cost: 1.0054\n", "Epoch: 008/020 | Batch 0050/0391 | Cost: 0.8184\n", "Epoch: 008/020 | Batch 0100/0391 | Cost: 0.8955\n", "Epoch: 008/020 | Batch 0150/0391 | Cost: 0.9319\n", "Epoch: 008/020 | Batch 0200/0391 | Cost: 1.0566\n", "Epoch: 008/020 | Batch 0250/0391 | Cost: 1.0591\n", "Epoch: 008/020 | Batch 0300/0391 | Cost: 0.7914\n", "Epoch: 008/020 | Batch 0350/0391 | Cost: 0.9090\n", "Epoch: 008/020 | Train: 72.846% | Loss: 0.770\n", "Time elapsed: 8.27 min\n", "Epoch: 009/020 | Batch 0000/0391 | Cost: 0.6672\n", "Epoch: 009/020 | Batch 0050/0391 | Cost: 0.7192\n", "Epoch: 009/020 | Batch 0100/0391 | Cost: 0.8586\n", "Epoch: 009/020 | Batch 0150/0391 | Cost: 0.7310\n", "Epoch: 009/020 | Batch 0200/0391 | Cost: 0.8406\n", "Epoch: 009/020 | Batch 0250/0391 | Cost: 0.7620\n", "Epoch: 009/020 | Batch 0300/0391 | Cost: 0.6692\n", "Epoch: 009/020 | Batch 0350/0391 | Cost: 0.6407\n", "Epoch: 009/020 | Train: 73.702% | Loss: 0.748\n", "Time elapsed: 9.30 min\n", "Epoch: 010/020 | Batch 0000/0391 | Cost: 0.6539\n", "Epoch: 010/020 | Batch 0050/0391 | Cost: 1.0382\n", "Epoch: 010/020 | Batch 0100/0391 | Cost: 0.5921\n", "Epoch: 010/020 | Batch 0150/0391 | Cost: 0.4933\n", "Epoch: 010/020 | Batch 0200/0391 | Cost: 0.7485\n", "Epoch: 010/020 | Batch 0250/0391 | Cost: 0.6779\n", "Epoch: 010/020 | Batch 0300/0391 | Cost: 0.6787\n", "Epoch: 010/020 | Batch 0350/0391 | Cost: 0.6977\n", "Epoch: 010/020 | Train: 75.708% | Loss: 0.703\n", "Time elapsed: 10.34 min\n", "Epoch: 011/020 | Batch 0000/0391 | Cost: 0.6866\n", "Epoch: 011/020 | Batch 0050/0391 | Cost: 0.7203\n", "Epoch: 011/020 | Batch 0100/0391 | Cost: 0.5730\n", "Epoch: 011/020 | Batch 0150/0391 | Cost: 0.5762\n", "Epoch: 011/020 | Batch 0200/0391 | Cost: 0.6571\n", "Epoch: 011/020 | Batch 0250/0391 | Cost: 0.7582\n", "Epoch: 011/020 | Batch 0300/0391 | Cost: 0.7366\n", "Epoch: 011/020 | Batch 0350/0391 | Cost: 0.6810\n", "Epoch: 011/020 | Train: 79.044% | Loss: 0.606\n", "Time elapsed: 11.37 min\n", "Epoch: 012/020 | Batch 0000/0391 | Cost: 0.5665\n", "Epoch: 012/020 | Batch 0050/0391 | Cost: 0.7081\n", "Epoch: 012/020 | Batch 0100/0391 | Cost: 0.6823\n", "Epoch: 012/020 | Batch 0150/0391 | Cost: 0.8297\n", "Epoch: 012/020 | Batch 0200/0391 | Cost: 0.6470\n", "Epoch: 012/020 | Batch 0250/0391 | Cost: 0.7293\n", "Epoch: 012/020 | Batch 0300/0391 | Cost: 0.9127\n", "Epoch: 012/020 | Batch 0350/0391 | Cost: 0.8419\n", "Epoch: 012/020 | Train: 79.474% | Loss: 0.585\n", "Time elapsed: 12.40 min\n", "Epoch: 013/020 | Batch 0000/0391 | Cost: 0.4087\n", "Epoch: 013/020 | Batch 0050/0391 | Cost: 0.4224\n", "Epoch: 013/020 | Batch 0100/0391 | Cost: 0.4336\n", "Epoch: 013/020 | Batch 0150/0391 | Cost: 0.6586\n", "Epoch: 013/020 | Batch 0200/0391 | Cost: 0.7107\n", "Epoch: 013/020 | Batch 0250/0391 | Cost: 0.7359\n", "Epoch: 013/020 | Batch 0300/0391 | Cost: 0.4860\n", "Epoch: 013/020 | Batch 0350/0391 | Cost: 0.7271\n", "Epoch: 013/020 | Train: 80.746% | Loss: 0.549\n", "Time elapsed: 13.44 min\n", "Epoch: 014/020 | Batch 0000/0391 | Cost: 0.5500\n", "Epoch: 014/020 | Batch 0050/0391 | Cost: 0.5108\n", "Epoch: 014/020 | Batch 0100/0391 | Cost: 0.5186\n", "Epoch: 014/020 | Batch 0150/0391 | Cost: 0.4737\n", "Epoch: 014/020 | Batch 0200/0391 | Cost: 0.7015\n", "Epoch: 014/020 | Batch 0250/0391 | Cost: 0.6069\n", "Epoch: 014/020 | Batch 0300/0391 | Cost: 0.7080\n", "Epoch: 014/020 | Batch 0350/0391 | Cost: 0.6460\n", "Epoch: 014/020 | Train: 81.596% | Loss: 0.553\n", "Time elapsed: 14.47 min\n", "Epoch: 015/020 | Batch 0000/0391 | Cost: 0.5398\n", "Epoch: 015/020 | Batch 0050/0391 | Cost: 0.5269\n", "Epoch: 015/020 | Batch 0100/0391 | Cost: 0.5048\n", "Epoch: 015/020 | Batch 0150/0391 | Cost: 0.5873\n", "Epoch: 015/020 | Batch 0200/0391 | Cost: 0.5320\n", "Epoch: 015/020 | Batch 0250/0391 | Cost: 0.4743\n", "Epoch: 015/020 | Batch 0300/0391 | Cost: 0.6124\n", "Epoch: 015/020 | Batch 0350/0391 | Cost: 0.7204\n", "Epoch: 015/020 | Train: 85.276% | Loss: 0.439\n", "Time elapsed: 15.51 min\n", "Epoch: 016/020 | Batch 0000/0391 | Cost: 0.4387\n", "Epoch: 016/020 | Batch 0050/0391 | Cost: 0.3777\n", "Epoch: 016/020 | Batch 0100/0391 | Cost: 0.3430\n", "Epoch: 016/020 | Batch 0150/0391 | Cost: 0.5901\n", "Epoch: 016/020 | Batch 0200/0391 | Cost: 0.6303\n", "Epoch: 016/020 | Batch 0250/0391 | Cost: 0.4983\n", "Epoch: 016/020 | Batch 0300/0391 | Cost: 0.6507\n", "Epoch: 016/020 | Batch 0350/0391 | Cost: 0.4663\n", "Epoch: 016/020 | Train: 86.440% | Loss: 0.406\n", "Time elapsed: 16.55 min\n", "Epoch: 017/020 | Batch 0000/0391 | Cost: 0.4675\n", "Epoch: 017/020 | Batch 0050/0391 | Cost: 0.6440\n", "Epoch: 017/020 | Batch 0100/0391 | Cost: 0.3536\n", "Epoch: 017/020 | Batch 0150/0391 | Cost: 0.5421\n", "Epoch: 017/020 | Batch 0200/0391 | Cost: 0.4504\n", "Epoch: 017/020 | Batch 0250/0391 | Cost: 0.4169\n", "Epoch: 017/020 | Batch 0300/0391 | Cost: 0.4617\n", "Epoch: 017/020 | Batch 0350/0391 | Cost: 0.4092\n", "Epoch: 017/020 | Train: 84.636% | Loss: 0.459\n", "Time elapsed: 17.59 min\n", "Epoch: 018/020 | Batch 0000/0391 | Cost: 0.4267\n", "Epoch: 018/020 | Batch 0050/0391 | Cost: 0.6478\n", "Epoch: 018/020 | Batch 0100/0391 | Cost: 0.5806\n", "Epoch: 018/020 | Batch 0150/0391 | Cost: 0.5453\n", "Epoch: 018/020 | Batch 0200/0391 | Cost: 0.4984\n", "Epoch: 018/020 | Batch 0250/0391 | Cost: 0.2517\n", "Epoch: 018/020 | Batch 0300/0391 | Cost: 0.5219\n", "Epoch: 018/020 | Batch 0350/0391 | Cost: 0.5217\n", "Epoch: 018/020 | Train: 86.094% | Loss: 0.413\n", "Time elapsed: 18.63 min\n", "Epoch: 019/020 | Batch 0000/0391 | Cost: 0.3849\n", "Epoch: 019/020 | Batch 0050/0391 | Cost: 0.2890\n", "Epoch: 019/020 | Batch 0100/0391 | Cost: 0.5058\n", "Epoch: 019/020 | Batch 0150/0391 | Cost: 0.5718\n", "Epoch: 019/020 | Batch 0200/0391 | Cost: 0.4053\n", "Epoch: 019/020 | Batch 0250/0391 | Cost: 0.5241\n", "Epoch: 019/020 | Batch 0300/0391 | Cost: 0.7110\n", "Epoch: 019/020 | Batch 0350/0391 | Cost: 0.4572\n", "Epoch: 019/020 | Train: 87.586% | Loss: 0.365\n", "Time elapsed: 19.67 min\n", "Epoch: 020/020 | Batch 0000/0391 | Cost: 0.3576\n", "Epoch: 020/020 | Batch 0050/0391 | Cost: 0.3466\n", "Epoch: 020/020 | Batch 0100/0391 | Cost: 0.3427\n", "Epoch: 020/020 | Batch 0150/0391 | Cost: 0.3117\n", "Epoch: 020/020 | Batch 0200/0391 | Cost: 0.4912\n", "Epoch: 020/020 | Batch 0250/0391 | Cost: 0.4481\n", "Epoch: 020/020 | Batch 0300/0391 | Cost: 0.6303\n", "Epoch: 020/020 | Batch 0350/0391 | Cost: 0.4274\n", "Epoch: 020/020 | Train: 88.024% | Loss: 0.361\n", "Time elapsed: 20.71 min\n", "Total Training Time: 20.71 min\n" ] } ], "source": [ "def compute_accuracy(model, data_loader):\n", " model.eval()\n", " correct_pred, num_examples = 0, 0\n", " for i, (features, targets) in enumerate(data_loader):\n", " \n", " features = features.to(DEVICE)\n", " targets = targets.to(DEVICE)\n", "\n", " logits, probas = model(features)\n", " _, predicted_labels = torch.max(probas, 1)\n", " num_examples += targets.size(0)\n", " correct_pred += (predicted_labels == targets).sum()\n", " return correct_pred.float()/num_examples * 100\n", "\n", "\n", "def compute_epoch_loss(model, data_loader):\n", " model.eval()\n", " curr_loss, num_examples = 0., 0\n", " with torch.no_grad():\n", " for features, targets in data_loader:\n", " features = features.to(DEVICE)\n", " targets = targets.to(DEVICE)\n", " logits, probas = model(features)\n", " loss = F.cross_entropy(logits, targets, reduction='sum')\n", " num_examples += targets.size(0)\n", " curr_loss += loss\n", "\n", " curr_loss = curr_loss / num_examples\n", " return curr_loss\n", " \n", " \n", "\n", "start_time = time.time()\n", "for epoch in range(num_epochs):\n", " \n", " model.train()\n", " for batch_idx, (features, targets) in enumerate(train_loader):\n", " \n", " features = features.to(DEVICE)\n", " targets = targets.to(DEVICE)\n", " \n", " ### FORWARD AND BACK PROP\n", " logits, probas = model(features)\n", " cost = F.cross_entropy(logits, targets)\n", " optimizer.zero_grad()\n", " \n", " cost.backward()\n", " \n", " ### UPDATE MODEL PARAMETERS\n", " optimizer.step()\n", " \n", " ### LOGGING\n", " if not batch_idx % 50:\n", " print ('Epoch: %03d/%03d | Batch %04d/%04d | Cost: %.4f' \n", " %(epoch+1, num_epochs, batch_idx, \n", " len(train_loader), cost))\n", "\n", " model.eval()\n", " with torch.set_grad_enabled(False): # save memory during inference\n", " print('Epoch: %03d/%03d | Train: %.3f%% | Loss: %.3f' % (\n", " epoch+1, num_epochs, \n", " compute_accuracy(model, train_loader),\n", " compute_epoch_loss(model, train_loader)))\n", "\n", "\n", " print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))\n", " \n", "print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "paaeEQHQj5xC" }, "source": [ "## Evaluation" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, "base_uri": "https://localhost:8080/", "height": 34 }, "colab_type": "code", "executionInfo": { "elapsed": 6514, "status": "ok", "timestamp": 1524976895054, "user": { "displayName": "Sebastian Raschka", "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg", "userId": "118404394130788869227" }, "user_tz": 240 }, "id": "gzQMWKq5j5xE", "outputId": "de7dc005-5eeb-4177-9f9f-d9b5d1358db9" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Test accuracy: 74.56%\n" ] } ], "source": [ "with torch.set_grad_enabled(False): # save memory during inference\n", " print('Test accuracy: %.2f%%' % (compute_accuracy(model, test_loader)))" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "numpy 1.15.4\n", "torch 1.0.1.post2\n", "\n" ] } ], "source": [ "%watermark -iv" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "default_view": {}, "name": "convnet-vgg16.ipynb", "provenance": [], "version": "0.3.2", "views": {} }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.1" }, "toc": { "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": true, "toc_position": { "height": "calc(100% - 180px)", "left": "10px", "top": "150px", "width": "371px" }, "toc_section_display": true, "toc_window_display": true } }, "nbformat": 4, "nbformat_minor": 2 }