{ "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "UEBilEjLj5wY" }, "source": [ "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n", "- Author: Sebastian Raschka\n", "- GitHub Repository: https://github.com/rasbt/deeplearning-models" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, "base_uri": "https://localhost:8080/", "height": 119 }, "colab_type": "code", "executionInfo": { "elapsed": 536, "status": "ok", "timestamp": 1524974472601, "user": { "displayName": "Sebastian Raschka", "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg", "userId": "118404394130788869227" }, "user_tz": 240 }, "id": "GOzuY8Yvj5wb", "outputId": "c19362ce-f87a-4cc2-84cc-8d7b4b9e6007" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sebastian Raschka \n", "\n", "CPython 3.6.8\n", "IPython 7.2.0\n", "\n", "torch 1.0.0\n" ] } ], "source": [ "%load_ext watermark\n", "%watermark -a 'Sebastian Raschka' -v -p torch" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "rH4XmErYj5wm" }, "source": [ "# Model Zoo -- CNN Gender Classifier (ResNet-152 Architecture, CelebA) with Data Parallelism" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Network Architecture" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The network in this notebook is an implementation of the ResNet-152 [1] architecture on the CelebA face dataset [2] to train a gender classifier. \n", "\n", "\n", "References\n", " \n", "- [1] He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition (pp. 770-778). 
([CVPR Link](https://www.cv-foundation.org/openaccess/content_cvpr_2016/html/He_Deep_Residual_Learning_CVPR_2016_paper.html))\n", "\n", "- [2] Zhang, K., Tan, L., Li, Z., & Qiao, Y. (2016). Gender and smile classification using deep convolutional neural networks. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition Workshops (pp. 34-38).\n", "\n", "The ResNet-152 architecture is similar to the ResNet-50 architecture, which is in turn similar to the ResNet-34 architecture shown below, except that ResNet-50 and ResNet-152 use a Bottleneck block (in contrast to ResNet-34) and ResNet-152 has more layers than ResNet-50 (the figure is a screenshot from [1]):\n", "\n", "\n", "![](../images/resnets/resnet152/resnet152-arch-1.png)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The following figure illustrates residual blocks with skip connections such that the input passed via the shortcut matches the dimensions of the main path's output, which allows the network to learn identity functions.\n", "\n", "![](../images/resnets/resnet-ex-1-1.png)\n", "\n", "\n", "The ResNet-34 architecture actually uses residual blocks with modified skip connections such that the input passed via the shortcut is resized to match the dimensions of the main path's output. Such a residual block is illustrated below:\n", "\n", "![](../images/resnets/resnet-ex-1-2.png)\n", "\n", "ResNet-50/101/152 then use a bottleneck block as shown below:\n", "\n", "![](../images/resnets/resnet-ex-1-3.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For a more detailed explanation, see the other notebook, [resnet-ex-1.ipynb](resnet-ex-1.ipynb)."
] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "MkoGLH_Tj5wn" }, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 } }, "colab_type": "code", "id": "ORj09gnrj5wp" }, "outputs": [], "source": [ "import os\n", "import time\n", "\n", "import numpy as np\n", "import pandas as pd\n", "\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "\n", "from torch.utils.data import Dataset\n", "from torch.utils.data import DataLoader\n", "\n", "from torchvision import datasets\n", "from torchvision import transforms\n", "\n", "import matplotlib.pyplot as plt\n", "from PIL import Image\n", "\n", "\n", "if torch.cuda.is_available():\n", " torch.backends.cudnn.deterministic = True" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Settings" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, "base_uri": "https://localhost:8080/", "height": 85 }, "colab_type": "code", "executionInfo": { "elapsed": 23936, "status": "ok", "timestamp": 1524974497505, "user": { "displayName": "Sebastian Raschka", "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg", "userId": "118404394130788869227" }, "user_tz": 240 }, "id": "NnT0sZIwj5wu", "outputId": "55aed925-d17e-4c6a-8c71-0d9b3bde5637" }, "outputs": [], "source": [ "##########################\n", "### SETTINGS\n", "##########################\n", "\n", "# Hyperparameters\n", "RANDOM_SEED = 1\n", "LEARNING_RATE = 0.001\n", "NUM_EPOCHS = 10\n", "\n", "# Architecture\n", "NUM_FEATURES = 128*128\n", "NUM_CLASSES = 2\n", "BATCH_SIZE = 128\n", "DEVICE = 'cuda:2' # default GPU device\n", "GRAYSCALE = False" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Dataset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 
Downloading the Dataset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that the CelebA dataset, with ~200,000 face images, is relatively large (~1.3 GB). The files listed below can be obtained via the download links that the dataset authors provide on the official CelebA website at http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "1) Download and unzip the file `img_align_celeba.zip`, which contains the images in JPEG format.\n", "\n", "2) Download the `list_attr_celeba.txt` file, which contains the class labels.\n", "\n", "3) Download the `list_eval_partition.txt` file, which contains the training/validation/test partitioning info." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Preparing the Dataset" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<table>\n",
"<thead><tr><th></th><th>Male</th></tr></thead>\n",
"<tbody>\n",
"<tr><th>000001.jpg</th><td>0</td></tr>\n",
"<tr><th>000002.jpg</th><td>0</td></tr>\n",
"<tr><th>000003.jpg</th><td>1</td></tr>\n",
"<tr><th>000004.jpg</th><td>0</td></tr>\n",
"<tr><th>000005.jpg</th><td>0</td></tr>\n",
"</tbody>\n",
"</table>
" ], "text/plain": [ " Male\n", "000001.jpg 0\n", "000002.jpg 0\n", "000003.jpg 1\n", "000004.jpg 0\n", "000005.jpg 0" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df1 = pd.read_csv('list_attr_celeba.txt', sep=\"\\s+\", skiprows=1, usecols=['Male'])\n", "\n", "# Make 0 (female) & 1 (male) labels instead of -1 & 1\n", "df1.loc[df1['Male'] == -1, 'Male'] = 0\n", "\n", "df1.head()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<table>\n",
"<thead>\n",
"<tr><th></th><th>Partition</th></tr>\n",
"<tr><th>Filename</th><th></th></tr>\n",
"</thead>\n",
"<tbody>\n",
"<tr><th>000001.jpg</th><td>0</td></tr>\n",
"<tr><th>000002.jpg</th><td>0</td></tr>\n",
"<tr><th>000003.jpg</th><td>0</td></tr>\n",
"<tr><th>000004.jpg</th><td>0</td></tr>\n",
"<tr><th>000005.jpg</th><td>0</td></tr>\n",
"</tbody>\n",
"</table>
" ], "text/plain": [ " Partition\n", "Filename \n", "000001.jpg 0\n", "000002.jpg 0\n", "000003.jpg 0\n", "000004.jpg 0\n", "000005.jpg 0" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df2 = pd.read_csv('list_eval_partition.txt', sep=\"\\s+\", skiprows=0, header=None)\n", "df2.columns = ['Filename', 'Partition']\n", "df2 = df2.set_index('Filename')\n", "\n", "df2.head()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<table>\n",
"<thead><tr><th></th><th>Male</th><th>Partition</th></tr></thead>\n",
"<tbody>\n",
"<tr><th>000001.jpg</th><td>0</td><td>0</td></tr>\n",
"<tr><th>000002.jpg</th><td>0</td><td>0</td></tr>\n",
"<tr><th>000003.jpg</th><td>1</td><td>0</td></tr>\n",
"<tr><th>000004.jpg</th><td>0</td><td>0</td></tr>\n",
"<tr><th>000005.jpg</th><td>0</td><td>0</td></tr>\n",
"</tbody>\n",
"</table>
" ], "text/plain": [ " Male Partition\n", "000001.jpg 0 0\n", "000002.jpg 0 0\n", "000003.jpg 1 0\n", "000004.jpg 0 0\n", "000005.jpg 0 0" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df3 = df1.merge(df2, left_index=True, right_index=True)\n", "df3.head()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<table>\n",
"<thead><tr><th></th><th>Male</th><th>Partition</th></tr></thead>\n",
"<tbody>\n",
"<tr><th>000001.jpg</th><td>0</td><td>0</td></tr>\n",
"<tr><th>000002.jpg</th><td>0</td><td>0</td></tr>\n",
"<tr><th>000003.jpg</th><td>1</td><td>0</td></tr>\n",
"<tr><th>000004.jpg</th><td>0</td><td>0</td></tr>\n",
"<tr><th>000005.jpg</th><td>0</td><td>0</td></tr>\n",
"</tbody>\n",
"</table>
" ], "text/plain": [ " Male Partition\n", "000001.jpg 0 0\n", "000002.jpg 0 0\n", "000003.jpg 1 0\n", "000004.jpg 0 0\n", "000005.jpg 0 0" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df3.to_csv('celeba-gender-partitions.csv')\n", "df4 = pd.read_csv('celeba-gender-partitions.csv', index_col=0)\n", "df4.head()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "df4.loc[df4['Partition'] == 0].to_csv('celeba-gender-train.csv')\n", "df4.loc[df4['Partition'] == 1].to_csv('celeba-gender-valid.csv')\n", "df4.loc[df4['Partition'] == 2].to_csv('celeba-gender-test.csv')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(218, 178, 3)\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "img = Image.open('img_align_celeba/000001.jpg')\n", "print(np.asarray(img, dtype=np.uint8).shape)\n", "plt.imshow(img);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Implementing a Custom DataLoader Class" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class CelebaDataset(Dataset):\n", " \"\"\"Custom Dataset for loading CelebA face images\"\"\"\n", "\n", " def __init__(self, csv_path, img_dir, transform=None):\n", " \n", " df = pd.read_csv(csv_path, index_col=0)\n", " self.img_dir = img_dir\n", " self.csv_path = csv_path\n", " self.img_names = df.index.values\n", " self.y = df['Male'].values\n", " self.transform = transform\n", "\n", " def __getitem__(self, index):\n", " img = Image.open(os.path.join(self.img_dir,\n", " self.img_names[index]))\n", " \n", " if self.transform is not None:\n", " img = self.transform(img)\n", " \n", " label = self.y[index]\n", " return img, label\n", "\n", " def __len__(self):\n", " return self.y.shape[0]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Note that transforms.ToTensor()\n", "# already divides pixels by 255. 
internally\n",
"\n",
"custom_transform = transforms.Compose([transforms.CenterCrop((178, 178)),\n",
"                                       transforms.Resize((128, 128)),\n",
"                                       #transforms.Grayscale(),\n",
"                                       #transforms.Lambda(lambda x: x/255.),\n",
"                                       transforms.ToTensor()])\n",
"\n",
"train_dataset = CelebaDataset(csv_path='celeba-gender-train.csv',\n",
"                              img_dir='img_align_celeba/',\n",
"                              transform=custom_transform)\n",
"\n",
"valid_dataset = CelebaDataset(csv_path='celeba-gender-valid.csv',\n",
"                              img_dir='img_align_celeba/',\n",
"                              transform=custom_transform)\n",
"\n",
"test_dataset = CelebaDataset(csv_path='celeba-gender-test.csv',\n",
"                             img_dir='img_align_celeba/',\n",
"                             transform=custom_transform)\n",
"\n",
"\n",
"train_loader = DataLoader(dataset=train_dataset,\n",
"                          batch_size=BATCH_SIZE,\n",
"                          shuffle=True,\n",
"                          num_workers=4)\n",
"\n",
"valid_loader = DataLoader(dataset=valid_dataset,\n",
"                          batch_size=BATCH_SIZE,\n",
"                          shuffle=False,\n",
"                          num_workers=4)\n",
"\n",
"test_loader = DataLoader(dataset=test_dataset,\n",
"                         batch_size=BATCH_SIZE,\n",
"                         shuffle=False,\n",
"                         num_workers=4)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 1 | Batch index: 0 | Batch size: 128\n", "Epoch: 2 | Batch index: 0 | Batch size: 128\n" ] } ], "source": [
"torch.manual_seed(0)\n",
"\n",
"for epoch in range(2):\n",
"\n",
"    for batch_idx, (x, y) in enumerate(train_loader):\n",
"\n",
"        print('Epoch:', epoch+1, end='')\n",
"        print(' | Batch index:', batch_idx, end='')\n",
"        print(' | Batch size:', y.size()[0])\n",
"\n",
"        x = x.to(DEVICE)\n",
"        y = y.to(DEVICE)\n",
"        time.sleep(1)\n",
"        break" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "I6hghKPxj5w0" }, "source": [ "## Model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The following code cell, which implements the ResNet-152 architecture, is a derivative of the code provided at 
https://pytorch.org/docs/0.4.0/_modules/torchvision/models/resnet.html." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "##########################\n", "### MODEL\n", "##########################\n", "\n", "\n", "def conv3x3(in_planes, out_planes, stride=1):\n", " \"\"\"3x3 convolution with padding\"\"\"\n", " return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,\n", " padding=1, bias=False)\n", "\n", "\n", "class Bottleneck(nn.Module):\n", " expansion = 4\n", "\n", " def __init__(self, inplanes, planes, stride=1, downsample=None):\n", " super(Bottleneck, self).__init__()\n", " self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)\n", " self.bn1 = nn.BatchNorm2d(planes)\n", " self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,\n", " padding=1, bias=False)\n", " self.bn2 = nn.BatchNorm2d(planes)\n", " self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)\n", " self.bn3 = nn.BatchNorm2d(planes * 4)\n", " self.relu = nn.ReLU(inplace=True)\n", " self.downsample = downsample\n", " self.stride = stride\n", "\n", " def forward(self, x):\n", " residual = x\n", "\n", " out = self.conv1(x)\n", " out = self.bn1(out)\n", " out = self.relu(out)\n", "\n", " out = self.conv2(out)\n", " out = self.bn2(out)\n", " out = self.relu(out)\n", "\n", " out = self.conv3(out)\n", " out = self.bn3(out)\n", "\n", " if self.downsample is not None:\n", " residual = self.downsample(x)\n", "\n", " out += residual\n", " out = self.relu(out)\n", "\n", " return out\n", "\n", "\n", "\n", "\n", "class ResNet(nn.Module):\n", "\n", " def __init__(self, block, layers, num_classes, grayscale):\n", " self.inplanes = 64\n", " if grayscale:\n", " in_dim = 1\n", " else:\n", " in_dim = 3\n", " super(ResNet, self).__init__()\n", " self.conv1 = nn.Conv2d(in_dim, 64, kernel_size=7, stride=2, padding=3,\n", " bias=False)\n", " self.bn1 = nn.BatchNorm2d(64)\n", " self.relu = nn.ReLU(inplace=True)\n", " 
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n",
"        self.layer1 = self._make_layer(block, 64, layers[0])\n",
"        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)\n",
"        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)\n",
"        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)\n",
"        self.avgpool = nn.AvgPool2d(7, stride=1, padding=2)\n",
"        self.fc = nn.Linear(2048 * block.expansion, num_classes)\n",
"\n",
"        for m in self.modules():\n",
"            if isinstance(m, nn.Conv2d):\n",
"                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n",
"                m.weight.data.normal_(0, (2. / n)**.5)\n",
"            elif isinstance(m, nn.BatchNorm2d):\n",
"                m.weight.data.fill_(1)\n",
"                m.bias.data.zero_()\n",
"\n",
"    def _make_layer(self, block, planes, blocks, stride=1):\n",
"        downsample = None\n",
"        if stride != 1 or self.inplanes != planes * block.expansion:\n",
"            downsample = nn.Sequential(\n",
"                nn.Conv2d(self.inplanes, planes * block.expansion,\n",
"                          kernel_size=1, stride=stride, bias=False),\n",
"                nn.BatchNorm2d(planes * block.expansion),\n",
"            )\n",
"\n",
"        layers = []\n",
"        layers.append(block(self.inplanes, planes, stride, downsample))\n",
"        self.inplanes = planes * block.expansion\n",
"        for i in range(1, blocks):\n",
"            layers.append(block(self.inplanes, planes))\n",
"\n",
"        return nn.Sequential(*layers)\n",
"\n",
"    def forward(self, x):\n",
"        x = self.conv1(x)\n",
"        x = self.bn1(x)\n",
"        x = self.relu(x)\n",
"        x = self.maxpool(x)\n",
"\n",
"        x = self.layer1(x)\n",
"        x = self.layer2(x)\n",
"        x = self.layer3(x)\n",
"        x = self.layer4(x)\n",
"\n",
"        x = self.avgpool(x)\n",
"        x = x.view(x.size(0), -1)\n",
"        logits = self.fc(x)\n",
"        probas = F.softmax(logits, dim=1)\n",
"        return logits, probas\n",
"\n",
"\n",
"\n",
"def resnet152(num_classes, grayscale):\n",
"    \"\"\"Constructs a ResNet-152 model.\"\"\"\n",
"    # ResNet-152 stacks 3, 8, 36, and 3 bottleneck blocks in its four stages [1]\n",
"    model = ResNet(block=Bottleneck,\n",
"                   layers=[3, 8, 36, 3],\n",
"                   num_classes=num_classes,  # use the argument, not the global\n",
"                   grayscale=grayscale)\n",
"    return 
model" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 } }, "colab_type": "code", "id": "_lza9t_uj5w1" }, "outputs": [], "source": [ "torch.manual_seed(RANDOM_SEED)\n", "\n", "##########################\n", "### COST AND OPTIMIZER\n", "##########################\n", "\n", "model = resnet152(NUM_CLASSES, GRAYSCALE)\n", "model.to(DEVICE)\n", "\n", "optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE) " ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "RAodboScj5w6" }, "source": [ "## Training" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, "base_uri": "https://localhost:8080/", "height": 1547 }, "colab_type": "code", "executionInfo": { "elapsed": 2384585, "status": "ok", "timestamp": 1524976888520, "user": { "displayName": "Sebastian Raschka", "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg", "userId": "118404394130788869227" }, "user_tz": 240 }, "id": "Dzh3ROmRj5w7", "outputId": "5f8fd8c9-b076-403a-b0b7-fd2d498b48d7", "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 001/010 | Batch 0000/1272 | Cost: 0.7148\n", "Epoch: 001/010 | Batch 0050/1272 | Cost: 0.6455\n", "Epoch: 001/010 | Batch 0100/1272 | Cost: 0.4099\n", "Epoch: 001/010 | Batch 0150/1272 | Cost: 0.2189\n", "Epoch: 001/010 | Batch 0200/1272 | Cost: 0.2228\n", "Epoch: 001/010 | Batch 0250/1272 | Cost: 0.2147\n", "Epoch: 001/010 | Batch 0300/1272 | Cost: 0.1621\n", "Epoch: 001/010 | Batch 0350/1272 | Cost: 0.1987\n", "Epoch: 001/010 | Batch 0400/1272 | Cost: 0.1688\n", "Epoch: 001/010 | Batch 0450/1272 | Cost: 0.2529\n", "Epoch: 001/010 | Batch 0500/1272 | Cost: 0.2114\n", "Epoch: 001/010 | Batch 0550/1272 | Cost: 0.1637\n", "Epoch: 001/010 | Batch 0600/1272 | Cost: 0.1147\n", "Epoch: 001/010 | 
Batch 0650/1272 | Cost: 0.2357\n", "Epoch: 001/010 | Batch 0700/1272 | Cost: 0.1656\n", "Epoch: 001/010 | Batch 0750/1272 | Cost: 0.0716\n", "Epoch: 001/010 | Batch 0800/1272 | Cost: 0.0936\n", "Epoch: 001/010 | Batch 0850/1272 | Cost: 0.2091\n", "Epoch: 001/010 | Batch 0900/1272 | Cost: 0.0778\n", "Epoch: 001/010 | Batch 0950/1272 | Cost: 0.1051\n", "Epoch: 001/010 | Batch 1000/1272 | Cost: 0.2065\n", "Epoch: 001/010 | Batch 1050/1272 | Cost: 0.1207\n", "Epoch: 001/010 | Batch 1100/1272 | Cost: 0.1522\n", "Epoch: 001/010 | Batch 1150/1272 | Cost: 0.1350\n", "Epoch: 001/010 | Batch 1200/1272 | Cost: 0.0743\n", "Epoch: 001/010 | Batch 1250/1272 | Cost: 0.1881\n", "Epoch: 001/010 | Train: 95.689% | Valid: 96.139%\n", "Time elapsed: 17.02 min\n", "Epoch: 002/010 | Batch 0000/1272 | Cost: 0.1336\n", "Epoch: 002/010 | Batch 0050/1272 | Cost: 0.0961\n", "Epoch: 002/010 | Batch 0100/1272 | Cost: 0.0340\n", "Epoch: 002/010 | Batch 0150/1272 | Cost: 0.1562\n", "Epoch: 002/010 | Batch 0200/1272 | Cost: 0.0968\n", "Epoch: 002/010 | Batch 0250/1272 | Cost: 0.0795\n", "Epoch: 002/010 | Batch 0300/1272 | Cost: 0.0635\n", "Epoch: 002/010 | Batch 0350/1272 | Cost: 0.1017\n", "Epoch: 002/010 | Batch 0400/1272 | Cost: 0.0406\n", "Epoch: 002/010 | Batch 0450/1272 | Cost: 0.0972\n", "Epoch: 002/010 | Batch 0500/1272 | Cost: 0.1341\n", "Epoch: 002/010 | Batch 0550/1272 | Cost: 0.1311\n", "Epoch: 002/010 | Batch 0600/1272 | Cost: 0.0875\n", "Epoch: 002/010 | Batch 0650/1272 | Cost: 0.0430\n", "Epoch: 002/010 | Batch 0700/1272 | Cost: 0.0517\n", "Epoch: 002/010 | Batch 0750/1272 | Cost: 0.0735\n", "Epoch: 002/010 | Batch 0800/1272 | Cost: 0.0446\n", "Epoch: 002/010 | Batch 0850/1272 | Cost: 0.1644\n", "Epoch: 002/010 | Batch 0900/1272 | Cost: 0.0575\n", "Epoch: 002/010 | Batch 0950/1272 | Cost: 0.0547\n", "Epoch: 002/010 | Batch 1000/1272 | Cost: 0.1019\n", "Epoch: 002/010 | Batch 1050/1272 | Cost: 0.1229\n", "Epoch: 002/010 | Batch 1100/1272 | Cost: 0.1009\n", "Epoch: 002/010 | Batch 
1150/1272 | Cost: 0.1092\n", "Epoch: 002/010 | Batch 1200/1272 | Cost: 0.0293\n", "Epoch: 002/010 | Batch 1250/1272 | Cost: 0.1025\n", "Epoch: 002/010 | Train: 96.899% | Valid: 96.920%\n", "Time elapsed: 36.16 min\n", "Epoch: 003/010 | Batch 0000/1272 | Cost: 0.0454\n", "Epoch: 003/010 | Batch 0050/1272 | Cost: 0.0702\n", "Epoch: 003/010 | Batch 0100/1272 | Cost: 0.0256\n", "Epoch: 003/010 | Batch 0150/1272 | Cost: 0.1387\n", "Epoch: 003/010 | Batch 0200/1272 | Cost: 0.0935\n", "Epoch: 003/010 | Batch 0250/1272 | Cost: 0.1291\n", "Epoch: 003/010 | Batch 0300/1272 | Cost: 0.0718\n", "Epoch: 003/010 | Batch 0350/1272 | Cost: 0.0668\n", "Epoch: 003/010 | Batch 0400/1272 | Cost: 0.0440\n", "Epoch: 003/010 | Batch 0450/1272 | Cost: 0.0551\n", "Epoch: 003/010 | Batch 0500/1272 | Cost: 0.0620\n", "Epoch: 003/010 | Batch 0550/1272 | Cost: 0.0191\n", "Epoch: 003/010 | Batch 0600/1272 | Cost: 0.0869\n", "Epoch: 003/010 | Batch 0650/1272 | Cost: 0.0524\n", "Epoch: 003/010 | Batch 0700/1272 | Cost: 0.0461\n", "Epoch: 003/010 | Batch 0750/1272 | Cost: 0.1172\n", "Epoch: 003/010 | Batch 0800/1272 | Cost: 0.0409\n", "Epoch: 003/010 | Batch 0850/1272 | Cost: 0.0294\n", "Epoch: 003/010 | Batch 0900/1272 | Cost: 0.0899\n", "Epoch: 003/010 | Batch 0950/1272 | Cost: 0.1365\n", "Epoch: 003/010 | Batch 1000/1272 | Cost: 0.0700\n", "Epoch: 003/010 | Batch 1050/1272 | Cost: 0.0687\n", "Epoch: 003/010 | Batch 1100/1272 | Cost: 0.0645\n", "Epoch: 003/010 | Batch 1150/1272 | Cost: 0.0878\n", "Epoch: 003/010 | Batch 1200/1272 | Cost: 0.0473\n", "Epoch: 003/010 | Batch 1250/1272 | Cost: 0.1231\n", "Epoch: 003/010 | Train: 97.895% | Valid: 97.770%\n", "Time elapsed: 51.91 min\n", "Epoch: 004/010 | Batch 0000/1272 | Cost: 0.0516\n", "Epoch: 004/010 | Batch 0050/1272 | Cost: 0.0579\n", "Epoch: 004/010 | Batch 0100/1272 | Cost: 0.0487\n", "Epoch: 004/010 | Batch 0150/1272 | Cost: 0.0394\n", "Epoch: 004/010 | Batch 0200/1272 | Cost: 0.0205\n", "Epoch: 004/010 | Batch 0250/1272 | Cost: 0.0628\n", 
"Epoch: 004/010 | Batch 0300/1272 | Cost: 0.0522\n", "Epoch: 004/010 | Batch 0350/1272 | Cost: 0.0456\n", "Epoch: 004/010 | Batch 0400/1272 | Cost: 0.0370\n", "Epoch: 004/010 | Batch 0450/1272 | Cost: 0.0460\n", "Epoch: 004/010 | Batch 0500/1272 | Cost: 0.0784\n", "Epoch: 004/010 | Batch 0550/1272 | Cost: 0.0632\n", "Epoch: 004/010 | Batch 0600/1272 | Cost: 0.0721\n", "Epoch: 004/010 | Batch 0650/1272 | Cost: 0.1943\n", "Epoch: 004/010 | Batch 0700/1272 | Cost: 0.0365\n", "Epoch: 004/010 | Batch 0750/1272 | Cost: 0.0437\n", "Epoch: 004/010 | Batch 0800/1272 | Cost: 0.0335\n", "Epoch: 004/010 | Batch 0850/1272 | Cost: 0.0897\n", "Epoch: 004/010 | Batch 0900/1272 | Cost: 0.0661\n", "Epoch: 004/010 | Batch 0950/1272 | Cost: 0.1020\n", "Epoch: 004/010 | Batch 1000/1272 | Cost: 0.0935\n", "Epoch: 004/010 | Batch 1050/1272 | Cost: 0.1341\n", "Epoch: 004/010 | Batch 1100/1272 | Cost: 0.0694\n", "Epoch: 004/010 | Batch 1150/1272 | Cost: 0.0634\n", "Epoch: 004/010 | Batch 1200/1272 | Cost: 0.0721\n", "Epoch: 004/010 | Batch 1250/1272 | Cost: 0.0504\n", "Epoch: 004/010 | Train: 97.629% | Valid: 97.634%\n", "Time elapsed: 67.70 min\n", "Epoch: 005/010 | Batch 0000/1272 | Cost: 0.0560\n", "Epoch: 005/010 | Batch 0050/1272 | Cost: 0.0277\n", "Epoch: 005/010 | Batch 0100/1272 | Cost: 0.0239\n", "Epoch: 005/010 | Batch 0150/1272 | Cost: 0.0721\n", "Epoch: 005/010 | Batch 0200/1272 | Cost: 0.0570\n", "Epoch: 005/010 | Batch 0250/1272 | Cost: 0.0258\n", "Epoch: 005/010 | Batch 0300/1272 | Cost: 0.0349\n", "Epoch: 005/010 | Batch 0350/1272 | Cost: 0.0479\n", "Epoch: 005/010 | Batch 0400/1272 | Cost: 0.0406\n", "Epoch: 005/010 | Batch 0450/1272 | Cost: 0.0580\n", "Epoch: 005/010 | Batch 0500/1272 | Cost: 0.0167\n", "Epoch: 005/010 | Batch 0550/1272 | Cost: 0.0593\n", "Epoch: 005/010 | Batch 0600/1272 | Cost: 0.0273\n", "Epoch: 005/010 | Batch 0650/1272 | Cost: 0.0446\n", "Epoch: 005/010 | Batch 0700/1272 | Cost: 0.0171\n", "Epoch: 005/010 | Batch 0750/1272 | Cost: 0.1026\n", "Epoch: 
005/010 | Batch 0800/1272 | Cost: 0.0624\n", "Epoch: 005/010 | Batch 0850/1272 | Cost: 0.0731\n", "Epoch: 005/010 | Batch 0900/1272 | Cost: 0.0480\n", "Epoch: 005/010 | Batch 0950/1272 | Cost: 0.0968\n", "Epoch: 005/010 | Batch 1000/1272 | Cost: 0.0164\n", "Epoch: 005/010 | Batch 1050/1272 | Cost: 0.0946\n", "Epoch: 005/010 | Batch 1100/1272 | Cost: 0.0524\n", "Epoch: 005/010 | Batch 1150/1272 | Cost: 0.0421\n", "Epoch: 005/010 | Batch 1200/1272 | Cost: 0.0779\n", "Epoch: 005/010 | Batch 1250/1272 | Cost: 0.0367\n", "Epoch: 005/010 | Train: 97.482% | Valid: 97.327%\n", "Time elapsed: 83.43 min\n", "Epoch: 006/010 | Batch 0000/1272 | Cost: 0.0753\n", "Epoch: 006/010 | Batch 0050/1272 | Cost: 0.0498\n", "Epoch: 006/010 | Batch 0100/1272 | Cost: 0.0319\n", "Epoch: 006/010 | Batch 0150/1272 | Cost: 0.0550\n", "Epoch: 006/010 | Batch 0200/1272 | Cost: 0.0922\n", "Epoch: 006/010 | Batch 0250/1272 | Cost: 0.0564\n", "Epoch: 006/010 | Batch 0300/1272 | Cost: 0.0505\n", "Epoch: 006/010 | Batch 0350/1272 | Cost: 0.0697\n", "Epoch: 006/010 | Batch 0400/1272 | Cost: 0.0434\n", "Epoch: 006/010 | Batch 0450/1272 | Cost: 0.0854\n", "Epoch: 006/010 | Batch 0500/1272 | Cost: 0.0356\n", "Epoch: 006/010 | Batch 0550/1272 | Cost: 0.0565\n", "Epoch: 006/010 | Batch 0600/1272 | Cost: 0.0969\n", "Epoch: 006/010 | Batch 0650/1272 | Cost: 0.0479\n", "Epoch: 006/010 | Batch 0700/1272 | Cost: 0.0556\n", "Epoch: 006/010 | Batch 0750/1272 | Cost: 0.0409\n", "Epoch: 006/010 | Batch 0800/1272 | Cost: 0.0493\n", "Epoch: 006/010 | Batch 0850/1272 | Cost: 0.0604\n", "Epoch: 006/010 | Batch 0900/1272 | Cost: 0.0386\n", "Epoch: 006/010 | Batch 0950/1272 | Cost: 0.0465\n", "Epoch: 006/010 | Batch 1000/1272 | Cost: 0.0526\n", "Epoch: 006/010 | Batch 1050/1272 | Cost: 0.0192\n", "Epoch: 006/010 | Batch 1100/1272 | Cost: 0.0300\n", "Epoch: 006/010 | Batch 1150/1272 | Cost: 0.0607\n", "Epoch: 006/010 | Batch 1200/1272 | Cost: 0.1048\n", "Epoch: 006/010 | Batch 1250/1272 | Cost: 0.0237\n", "Epoch: 006/010 
| Train: 98.128% | Valid: 97.851%\n", "Time elapsed: 99.20 min\n", "Epoch: 007/010 | Batch 0000/1272 | Cost: 0.0240\n", "Epoch: 007/010 | Batch 0050/1272 | Cost: 0.0638\n", "Epoch: 007/010 | Batch 0100/1272 | Cost: 0.0192\n", "Epoch: 007/010 | Batch 0150/1272 | Cost: 0.0800\n", "Epoch: 007/010 | Batch 0200/1272 | Cost: 0.0562\n", "Epoch: 007/010 | Batch 0250/1272 | Cost: 0.0293\n", "Epoch: 007/010 | Batch 0300/1272 | Cost: 0.0558\n", "Epoch: 007/010 | Batch 0350/1272 | Cost: 0.0206\n", "Epoch: 007/010 | Batch 0400/1272 | Cost: 0.0315\n", "Epoch: 007/010 | Batch 0450/1272 | Cost: 0.0339\n", "Epoch: 007/010 | Batch 0500/1272 | Cost: 0.0311\n", "Epoch: 007/010 | Batch 0550/1272 | Cost: 0.0366\n", "Epoch: 007/010 | Batch 0600/1272 | Cost: 0.0638\n", "Epoch: 007/010 | Batch 0650/1272 | Cost: 0.0610\n", "Epoch: 007/010 | Batch 0700/1272 | Cost: 0.0597\n", "Epoch: 007/010 | Batch 0750/1272 | Cost: 0.0489\n", "Epoch: 007/010 | Batch 0800/1272 | Cost: 0.0512\n", "Epoch: 007/010 | Batch 0850/1272 | Cost: 0.0995\n", "Epoch: 007/010 | Batch 0900/1272 | Cost: 0.0364\n", "Epoch: 007/010 | Batch 0950/1272 | Cost: 0.1224\n", "Epoch: 007/010 | Batch 1000/1272 | Cost: 0.0514\n", "Epoch: 007/010 | Batch 1050/1272 | Cost: 0.0663\n", "Epoch: 007/010 | Batch 1100/1272 | Cost: 0.0514\n", "Epoch: 007/010 | Batch 1150/1272 | Cost: 0.0148\n", "Epoch: 007/010 | Batch 1200/1272 | Cost: 0.0304\n", "Epoch: 007/010 | Batch 1250/1272 | Cost: 0.0482\n" ] } ], "source": [ "def compute_accuracy(model, data_loader, device):\n", " correct_pred, num_examples = 0, 0\n", " for i, (features, targets) in enumerate(data_loader):\n", " \n", " features = features.to(device)\n", " targets = targets.to(device)\n", "\n", " logits, probas = model(features)\n", " _, predicted_labels = torch.max(probas, 1)\n", " num_examples += targets.size(0)\n", " correct_pred += (predicted_labels == targets).sum()\n", " return correct_pred.float()/num_examples * 100\n", " \n", "\n", "start_time = time.time()\n", "for epoch in 
range(NUM_EPOCHS):\n", " \n", " model.train()\n", " for batch_idx, (features, targets) in enumerate(train_loader):\n", " \n", " features = features.to(DEVICE)\n", " targets = targets.to(DEVICE)\n", " \n", " ### FORWARD AND BACK PROP\n", " logits, probas = model(features)\n", " cost = F.cross_entropy(logits, targets)\n", " optimizer.zero_grad()\n", " \n", " cost.backward()\n", " \n", " ### UPDATE MODEL PARAMETERS\n", " optimizer.step()\n", " \n", " ### LOGGING\n", " if not batch_idx % 50:\n", " print ('Epoch: %03d/%03d | Batch %04d/%04d | Cost: %.4f' \n", " %(epoch+1, NUM_EPOCHS, batch_idx, \n", " len(train_loader), cost))\n", "\n", " \n", "\n", " model.eval()\n", " with torch.set_grad_enabled(False): # save memory during inference\n", " print('Epoch: %03d/%03d | Train: %.3f%% | Valid: %.3f%%' % (\n", " epoch+1, NUM_EPOCHS, \n", " compute_accuracy(model, train_loader, device=DEVICE),\n", " compute_accuracy(model, valid_loader, device=DEVICE)))\n", " \n", " print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))\n", " \n", "print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "paaeEQHQj5xC" }, "source": [ "## Evaluation" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, "base_uri": "https://localhost:8080/", "height": 34 }, "colab_type": "code", "executionInfo": { "elapsed": 6514, "status": "ok", "timestamp": 1524976895054, "user": { "displayName": "Sebastian Raschka", "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg", "userId": "118404394130788869227" }, "user_tz": 240 }, "id": "gzQMWKq5j5xE", "outputId": "de7dc005-5eeb-4177-9f9f-d9b5d1358db9" }, "outputs": [], "source": [ "with torch.set_grad_enabled(False): # save memory during inference\n", " print('Test accuracy: %.2f%%' % (compute_accuracy(model, test_loader, 
device=DEVICE)))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "for batch_idx, (features, targets) in enumerate(test_loader):\n", "\n", " features = features\n", " targets = targets\n", " break\n", " \n", "plt.imshow(np.transpose(features[0], (1, 2, 0)))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model.eval()\n", "logits, probas = model(features.to(DEVICE)[0, None])\n", "print('Probability Female %.2f%%' % (probas[0][0]*100))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%watermark -iv" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "default_view": {}, "name": "convnet-vgg16.ipynb", "provenance": [], "version": "0.3.2", "views": {} }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.1" }, "toc": { "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": true, "toc_position": { "height": "calc(100% - 180px)", "left": "10px", "top": "150px", "width": "371px" }, "toc_section_display": true, "toc_window_display": true } }, "nbformat": 4, "nbformat_minor": 2 }