{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n", "- Author: Sebastian Raschka\n", "- GitHub Repository: https://github.com/rasbt/deeplearning-models" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sebastian Raschka \n", "\n", "CPython 3.7.3\n", "IPython 7.6.1\n", "\n", "torch 1.2.0\n" ] } ], "source": [ "%load_ext watermark\n", "%watermark -a 'Sebastian Raschka' -v -p torch" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Runs on CPU or GPU (if available)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Basic Graph Neural Network with Spectral Graph Convolution on MNIST" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Implementing a very basic graph neural network (GNN) using a spectral graph convolution. \n", "\n", "Here, the 28x28 image of a digit in MNIST represents the graph, where each pixel (i.e., cell in the grid) represents a particular node. The feature of that node is simply the pixel intensity in range [0, 1]. \n", "\n", "Here, the adjacency matrix of the pixels is basically just determined by their neighborhood pixels. Using a Gaussian filter, we connect pixels based on their Euclidean distance in the grid.\n", "\n", "In the related notebook, [./gnn-basic-1.ipynb](./gnn-basic-1.ipynb), we used this adjacency matrix $A$ to compute the output of a layer as \n", "\n", "$$X^{(l+1)}=A X^{(l)} W^{(l)}.$$\n", "\n", "Here, $A$ is the $N \\times N$ adjacency matrix, and $X$ is the $N \\times C$ feature matrix (a 2D coordinate array, where $N$ is the total number of pixels -- $28 \\times 28 = 784$ in MNIST). 
$W$ is the weight matrix of shape $C \\times P$, where $P$ would represent the number of classes if we have only a single hidden layer.\n", "\n", "In this notebook, we modify this code using spectral graph convolution, i.e.,\n", "\n", "$$X^{(l+1)}=V\\left(V^{T} X^{(l)} \\odot V^{T} W_{\\text{spectral}}^{(l)}\\right).$$\n", "\n", "Here, the columns of $V$ are eigenvectors of the graph Laplacian $L$, which we can compute from the adjacency matrix $A$; $\\odot$ denotes element-wise multiplication; and $W_{\\text{spectral}}$ represents the trainable weights (filters). The implementation below keeps only the 20 eigenvectors with the smallest-magnitude eigenvalues.\n", "\n", "- Inspired by and based on Boris Knyazev's tutorial at https://towardsdatascience.com/tutorial-on-graph-neural-networks-for-computer-vision-and-beyond-part-2-be6d71d70f49." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import time\n", "import numpy as np\n", "from scipy.spatial.distance import cdist\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from torchvision import datasets\n", "from torchvision import transforms\n", "from torch.utils.data import DataLoader\n", "from torch.utils.data.dataset import Subset\n", "\n", "\n", "if torch.cuda.is_available():\n", " torch.backends.cudnn.deterministic = True" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Settings and Dataset" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "##########################\n", "### SETTINGS\n", "##########################\n", "\n", "# Device (adjust the GPU index in \"cuda:3\" to match your machine)\n", "DEVICE = torch.device(\"cuda:3\" if torch.cuda.is_available() else \"cpu\")\n", "\n", "# Hyperparameters\n", "RANDOM_SEED = 1\n", "LEARNING_RATE = 0.05\n", "NUM_EPOCHS = 50\n", "BATCH_SIZE = 128\n", "IMG_SIZE = 28\n", "\n", "# Architecture\n", "NUM_CLASSES = 10" 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## MNIST Dataset" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Image batch dimensions: torch.Size([128, 1, 28, 28])\n", "Image label dimensions: torch.Size([128])\n" ] } ], "source": [ "train_indices = torch.arange(0, 59000)\n", "valid_indices = torch.arange(59000, 60000)\n", "\n", "custom_transform = transforms.Compose([transforms.ToTensor()])\n", "\n", "\n", "train_and_valid = datasets.MNIST(root='data', \n", " train=True, \n", " transform=custom_transform,\n", " download=True)\n", "\n", "test_dataset = datasets.MNIST(root='data', \n", " train=False, \n", " transform=custom_transform,\n", " download=True)\n", "\n", "train_dataset = Subset(train_and_valid, train_indices)\n", "valid_dataset = Subset(train_and_valid, valid_indices)\n", "\n", "train_loader = DataLoader(dataset=train_dataset, \n", " batch_size=BATCH_SIZE,\n", " num_workers=4,\n", " shuffle=True)\n", "\n", "valid_loader = DataLoader(dataset=valid_dataset, \n", " batch_size=BATCH_SIZE,\n", " num_workers=4,\n", " shuffle=False)\n", "\n", "test_loader = DataLoader(dataset=test_dataset, \n", " batch_size=BATCH_SIZE,\n", " num_workers=4,\n", " shuffle=False)\n", "\n", "# Checking the dataset\n", "for images, labels in train_loader: \n", " print('Image batch dimensions:', images.shape)\n", " print('Image label dimensions:', labels.shape)\n", " break" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "def precompute_adjacency_matrix(img_size):\n", " col, row = np.meshgrid(np.arange(img_size), np.arange(img_size))\n", " \n", " # N = img_size^2\n", " # construct 2D coordinate array (shape N x 2) and normalize\n", " # in range [0, 1]\n", " coord = np.stack((col, row), axis=2).reshape(-1, 2) / img_size\n", "\n", " # compute pairwise distance matrix (N x 
N)\n", " dist = cdist(coord, coord, metric='euclidean')\n", " \n", " # Apply Gaussian filter\n", " sigma = 0.05 * np.pi\n", " A = np.exp(- dist / sigma ** 2)\n", " A[A < 0.01] = 0\n", " A = torch.from_numpy(A).float()\n", " \n", " return A\n", "\n", " \"\"\"\n", " # Normalization as per (Kipf & Welling, ICLR 2017)\n", " D = A.sum(1) # nodes degree (N,)\n", " D_hat = (D + 1e-5) ** (-0.5)\n", " A_hat = D_hat.view(-1, 1) * A * D_hat.view(1, -1) # N,N\n", " \n", " return A_hat\n", " \"\"\"\n", "\n", "\n", "def get_graph_laplacian(A):\n", " # From https://towardsdatascience.com/spectral-graph-convolution-\n", " # explained-and-implemented-step-by-step-2e495b57f801\n", " #\n", " # Computing the graph Laplacian\n", " # A is an adjacency matrix of some graph G\n", " N = A.shape[0] # number of nodes in a graph\n", " D = np.sum(A, 0) # node degrees\n", " D_hat = np.diag((D + 1e-5)**(-0.5)) # normalized node degrees\n", " L = np.identity(N) - np.dot(D_hat, A).dot(D_hat) # Laplacian\n", " return torch.from_numpy(L).float()" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "A = precompute_adjacency_matrix(28)\n", "plt.imshow(A, vmin=0., vmax=1.)\n", "plt.colorbar()\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "L = get_graph_laplacian(A.numpy())\n", "plt.imshow(L, vmin=0., vmax=1.)\n", "plt.colorbar()\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "##########################\n", "### MODEL\n", "##########################\n", "\n", "from scipy.sparse.linalg import eigsh\n", " \n", "\n", "class GraphNet(nn.Module):\n", " def __init__(self, img_size=28, num_filters=2, num_classes=10):\n", " super(GraphNet, self).__init__()\n", " \n", " n_rows = img_size**2\n", " self.fc = nn.Linear(n_rows*num_filters, num_classes, bias=False)\n", "\n", " A = precompute_adjacency_matrix(img_size)\n", " L = get_graph_laplacian(A.numpy())\n", " Λ,V = eigsh(L.numpy(), k=20, which='SM') # eigen-decomposition (i.e. find Λ,V)\n", "\n", " V = torch.from_numpy(V)\n", " \n", " # Weight matrix\n", " W_spectral = nn.Parameter(torch.ones((img_size**2, num_filters))).float()\n", " torch.nn.init.kaiming_uniform_(W_spectral)\n", " \n", " self.register_buffer('A', A)\n", " self.register_buffer('L', L)\n", " self.register_buffer('V', V)\n", " self.register_buffer('W_spectral', W_spectral)\n", "\n", " \n", "\n", " def forward(self, x):\n", " \n", " B = x.size(0) # Batch size\n", "\n", " ### Reshape eigenvectors\n", " # from [H*W, 20] to [B, H*W, 20]\n", " V_tensor = self.V.unsqueeze(0)\n", " V_tensor = self.V.expand(B, -1, -1)\n", " # from [H*W, 20] to [B, 20, H*W]\n", " V_tensor_T = self.V.T.unsqueeze(0)\n", " V_tensor_T = self.V.T.expand(B, -1, -1)\n", " \n", " ### Reshape inputs\n", " # [B, C, H, W] => [B, H*W, 1]\n", " x_reshape = x.view(B, -1, 1)\n", " \n", " ### Reshape spectral weights\n", " # to size [128, H*W, F]\n", " W_spectral_tensor = self.W_spectral.unsqueeze(0)\n", " W_spectral_tensor = self.W_spectral.expand(B, -1, -1)\n", " \n", " ### Spectral convolution on graphs\n", " # [B, 20, H*W] . 
[B, H*W, 1] ==> [B, 20, 1]\n", " X_hat = V_tensor_T.bmm(x_reshape) # 20×1 node features in the \"spectral\" domain\n", " W_hat = V_tensor_T.bmm(W_spectral_tensor) # 20×F filters in the \"spectral\" domain\n", " Y = V_tensor.bmm(X_hat * W_hat) # N×F result of convolution\n", "\n", " ### Fully connected\n", " logits = self.fc(Y.reshape(B, -1))\n", " probas = F.softmax(logits, dim=1)\n", " return logits, probas" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "torch.manual_seed(RANDOM_SEED)\n", "model = GraphNet(img_size=IMG_SIZE, num_classes=NUM_CLASSES)\n", "\n", "model = model.to(DEVICE)\n", "\n", "optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE) " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 001/050 | Batch 000/461 | Cost: 2.3133\n", "Epoch: 001/050 | Batch 150/461 | Cost: 1.1899\n", "Epoch: 001/050 | Batch 300/461 | Cost: 1.0481\n", "Epoch: 001/050 | Batch 450/461 | Cost: 0.9287\n", "Epoch: 001/050\n", "Train ACC: 73.79 | Validation ACC: 78.10\n", "Time elapsed: 0.07 min\n", "Epoch: 002/050 | Batch 000/461 | Cost: 0.8224\n", "Epoch: 002/050 | Batch 150/461 | Cost: 0.9684\n", "Epoch: 002/050 | Batch 300/461 | Cost: 0.6952\n", "Epoch: 002/050 | Batch 450/461 | Cost: 0.8158\n", "Epoch: 002/050\n", "Train ACC: 77.48 | Validation ACC: 82.20\n", "Time elapsed: 0.14 min\n", "Epoch: 003/050 | Batch 000/461 | Cost: 0.8203\n", "Epoch: 003/050 | Batch 150/461 | Cost: 0.8409\n", "Epoch: 003/050 | Batch 300/461 | Cost: 0.8602\n", "Epoch: 003/050 | Batch 450/461 | Cost: 0.7012\n", "Epoch: 003/050\n", "Train ACC: 78.55 | Validation ACC: 83.40\n", "Time elapsed: 0.21 min\n", "Epoch: 004/050 | Batch 000/461 | Cost: 0.7919\n", "Epoch: 004/050 | Batch 150/461 | Cost: 0.9010\n", "Epoch: 004/050 | Batch 300/461 | Cost: 0.6895\n", "Epoch: 004/050 | Batch 
450/461 | Cost: 0.6981\n", "Epoch: 004/050\n", "Train ACC: 79.30 | Validation ACC: 84.10\n", "Time elapsed: 0.28 min\n", "Epoch: 005/050 | Batch 000/461 | Cost: 0.6080\n", "Epoch: 005/050 | Batch 150/461 | Cost: 0.6627\n", "Epoch: 005/050 | Batch 300/461 | Cost: 0.7620\n", "Epoch: 005/050 | Batch 450/461 | Cost: 0.8047\n", "Epoch: 005/050\n", "Train ACC: 79.66 | Validation ACC: 84.50\n", "Time elapsed: 0.35 min\n", "Epoch: 006/050 | Batch 000/461 | Cost: 0.5992\n", "Epoch: 006/050 | Batch 150/461 | Cost: 0.5546\n", "Epoch: 006/050 | Batch 300/461 | Cost: 0.6459\n", "Epoch: 006/050 | Batch 450/461 | Cost: 0.5968\n", "Epoch: 006/050\n", "Train ACC: 79.91 | Validation ACC: 85.10\n", "Time elapsed: 0.42 min\n", "Epoch: 007/050 | Batch 000/461 | Cost: 0.7909\n", "Epoch: 007/050 | Batch 150/461 | Cost: 0.6488\n", "Epoch: 007/050 | Batch 300/461 | Cost: 0.7580\n", "Epoch: 007/050 | Batch 450/461 | Cost: 0.5646\n", "Epoch: 007/050\n", "Train ACC: 80.50 | Validation ACC: 85.00\n", "Time elapsed: 0.48 min\n", "Epoch: 008/050 | Batch 000/461 | Cost: 0.6147\n", "Epoch: 008/050 | Batch 150/461 | Cost: 0.6998\n", "Epoch: 008/050 | Batch 300/461 | Cost: 0.5563\n", "Epoch: 008/050 | Batch 450/461 | Cost: 0.5611\n", "Epoch: 008/050\n", "Train ACC: 80.73 | Validation ACC: 85.60\n", "Time elapsed: 0.56 min\n", "Epoch: 009/050 | Batch 000/461 | Cost: 0.5629\n", "Epoch: 009/050 | Batch 150/461 | Cost: 0.6245\n", "Epoch: 009/050 | Batch 300/461 | Cost: 0.7393\n", "Epoch: 009/050 | Batch 450/461 | Cost: 0.6670\n", "Epoch: 009/050\n", "Train ACC: 81.09 | Validation ACC: 85.70\n", "Time elapsed: 0.62 min\n", "Epoch: 010/050 | Batch 000/461 | Cost: 0.6582\n", "Epoch: 010/050 | Batch 150/461 | Cost: 0.7550\n", "Epoch: 010/050 | Batch 300/461 | Cost: 0.7028\n", "Epoch: 010/050 | Batch 450/461 | Cost: 0.6558\n", "Epoch: 010/050\n", "Train ACC: 81.00 | Validation ACC: 85.70\n", "Time elapsed: 0.69 min\n", "Epoch: 011/050 | Batch 000/461 | Cost: 0.5472\n", "Epoch: 011/050 | Batch 150/461 | Cost: 
0.6051\n", "Epoch: 011/050 | Batch 300/461 | Cost: 0.5875\n", "Epoch: 011/050 | Batch 450/461 | Cost: 0.4688\n", "Epoch: 011/050\n", "Train ACC: 81.50 | Validation ACC: 85.90\n", "Time elapsed: 0.76 min\n", "Epoch: 012/050 | Batch 000/461 | Cost: 0.5227\n", "Epoch: 012/050 | Batch 150/461 | Cost: 0.6252\n", "Epoch: 012/050 | Batch 300/461 | Cost: 0.6359\n", "Epoch: 012/050 | Batch 450/461 | Cost: 0.8590\n", "Epoch: 012/050\n", "Train ACC: 81.61 | Validation ACC: 86.50\n", "Time elapsed: 0.83 min\n", "Epoch: 013/050 | Batch 000/461 | Cost: 0.4933\n", "Epoch: 013/050 | Batch 150/461 | Cost: 0.5844\n", "Epoch: 013/050 | Batch 300/461 | Cost: 0.4684\n", "Epoch: 013/050 | Batch 450/461 | Cost: 0.5275\n", "Epoch: 013/050\n", "Train ACC: 81.79 | Validation ACC: 86.50\n", "Time elapsed: 0.90 min\n", "Epoch: 014/050 | Batch 000/461 | Cost: 0.6382\n", "Epoch: 014/050 | Batch 150/461 | Cost: 0.7612\n", "Epoch: 014/050 | Batch 300/461 | Cost: 0.5378\n", "Epoch: 014/050 | Batch 450/461 | Cost: 0.5651\n", "Epoch: 014/050\n", "Train ACC: 81.94 | Validation ACC: 86.50\n", "Time elapsed: 0.97 min\n", "Epoch: 015/050 | Batch 000/461 | Cost: 0.5122\n", "Epoch: 015/050 | Batch 150/461 | Cost: 0.6347\n", "Epoch: 015/050 | Batch 300/461 | Cost: 0.6239\n", "Epoch: 015/050 | Batch 450/461 | Cost: 0.6026\n", "Epoch: 015/050\n", "Train ACC: 82.01 | Validation ACC: 87.00\n", "Time elapsed: 1.03 min\n", "Epoch: 016/050 | Batch 000/461 | Cost: 0.6380\n", "Epoch: 016/050 | Batch 150/461 | Cost: 0.5865\n", "Epoch: 016/050 | Batch 300/461 | Cost: 0.3510\n", "Epoch: 016/050 | Batch 450/461 | Cost: 0.5859\n", "Epoch: 016/050\n", "Train ACC: 82.06 | Validation ACC: 86.50\n", "Time elapsed: 1.10 min\n", "Epoch: 017/050 | Batch 000/461 | Cost: 0.6827\n", "Epoch: 017/050 | Batch 150/461 | Cost: 0.6415\n", "Epoch: 017/050 | Batch 300/461 | Cost: 0.7186\n", "Epoch: 017/050 | Batch 450/461 | Cost: 0.6067\n", "Epoch: 017/050\n", "Train ACC: 82.41 | Validation ACC: 87.70\n", "Time elapsed: 1.17 min\n", 
"Epoch: 018/050 | Batch 000/461 | Cost: 0.7209\n", "Epoch: 018/050 | Batch 150/461 | Cost: 0.6981\n", "Epoch: 018/050 | Batch 300/461 | Cost: 0.6810\n", "Epoch: 018/050 | Batch 450/461 | Cost: 0.6180\n", "Epoch: 018/050\n", "Train ACC: 82.55 | Validation ACC: 87.50\n", "Time elapsed: 1.24 min\n", "Epoch: 019/050 | Batch 000/461 | Cost: 0.7285\n", "Epoch: 019/050 | Batch 150/461 | Cost: 0.7734\n", "Epoch: 019/050 | Batch 300/461 | Cost: 0.7189\n", "Epoch: 019/050 | Batch 450/461 | Cost: 0.5652\n", "Epoch: 019/050\n", "Train ACC: 82.46 | Validation ACC: 87.30\n", "Time elapsed: 1.31 min\n", "Epoch: 020/050 | Batch 000/461 | Cost: 0.7076\n", "Epoch: 020/050 | Batch 150/461 | Cost: 0.4096\n", "Epoch: 020/050 | Batch 300/461 | Cost: 0.7485\n", "Epoch: 020/050 | Batch 450/461 | Cost: 0.7334\n", "Epoch: 020/050\n", "Train ACC: 82.48 | Validation ACC: 87.30\n", "Time elapsed: 1.38 min\n", "Epoch: 021/050 | Batch 000/461 | Cost: 0.4686\n", "Epoch: 021/050 | Batch 150/461 | Cost: 0.6241\n", "Epoch: 021/050 | Batch 300/461 | Cost: 0.5736\n", "Epoch: 021/050 | Batch 450/461 | Cost: 0.4948\n", "Epoch: 021/050\n", "Train ACC: 82.67 | Validation ACC: 88.00\n", "Time elapsed: 1.45 min\n", "Epoch: 022/050 | Batch 000/461 | Cost: 0.4657\n", "Epoch: 022/050 | Batch 150/461 | Cost: 0.6718\n", "Epoch: 022/050 | Batch 300/461 | Cost: 0.6647\n", "Epoch: 022/050 | Batch 450/461 | Cost: 0.4913\n", "Epoch: 022/050\n", "Train ACC: 82.87 | Validation ACC: 87.90\n", "Time elapsed: 1.52 min\n", "Epoch: 023/050 | Batch 000/461 | Cost: 0.5567\n", "Epoch: 023/050 | Batch 150/461 | Cost: 0.4976\n", "Epoch: 023/050 | Batch 300/461 | Cost: 0.5911\n", "Epoch: 023/050 | Batch 450/461 | Cost: 0.4014\n", "Epoch: 023/050\n", "Train ACC: 82.91 | Validation ACC: 87.80\n", "Time elapsed: 1.59 min\n", "Epoch: 024/050 | Batch 000/461 | Cost: 0.5728\n", "Epoch: 024/050 | Batch 150/461 | Cost: 0.6313\n", "Epoch: 024/050 | Batch 300/461 | Cost: 0.5825\n", "Epoch: 024/050 | Batch 450/461 | Cost: 0.4720\n", "Epoch: 
024/050\n", "Train ACC: 83.00 | Validation ACC: 87.90\n", "Time elapsed: 1.66 min\n", "Epoch: 025/050 | Batch 000/461 | Cost: 0.5128\n", "Epoch: 025/050 | Batch 150/461 | Cost: 0.4793\n", "Epoch: 025/050 | Batch 300/461 | Cost: 0.7191\n", "Epoch: 025/050 | Batch 450/461 | Cost: 0.5402\n", "Epoch: 025/050\n", "Train ACC: 83.12 | Validation ACC: 88.30\n", "Time elapsed: 1.72 min\n", "Epoch: 026/050 | Batch 000/461 | Cost: 0.4961\n", "Epoch: 026/050 | Batch 150/461 | Cost: 0.4546\n", "Epoch: 026/050 | Batch 300/461 | Cost: 0.5333\n", "Epoch: 026/050 | Batch 450/461 | Cost: 0.5073\n", "Epoch: 026/050\n", "Train ACC: 82.98 | Validation ACC: 87.90\n", "Time elapsed: 1.79 min\n", "Epoch: 027/050 | Batch 000/461 | Cost: 0.7034\n", "Epoch: 027/050 | Batch 150/461 | Cost: 0.5373\n", "Epoch: 027/050 | Batch 300/461 | Cost: 0.5158\n", "Epoch: 027/050 | Batch 450/461 | Cost: 0.5705\n", "Epoch: 027/050\n", "Train ACC: 83.15 | Validation ACC: 88.00\n", "Time elapsed: 1.86 min\n", "Epoch: 028/050 | Batch 000/461 | Cost: 0.4614\n", "Epoch: 028/050 | Batch 150/461 | Cost: 0.4124\n", "Epoch: 028/050 | Batch 300/461 | Cost: 0.7368\n", "Epoch: 028/050 | Batch 450/461 | Cost: 0.5744\n", "Epoch: 028/050\n", "Train ACC: 82.85 | Validation ACC: 87.60\n", "Time elapsed: 1.93 min\n", "Epoch: 029/050 | Batch 000/461 | Cost: 0.5026\n", "Epoch: 029/050 | Batch 150/461 | Cost: 0.6048\n", "Epoch: 029/050 | Batch 300/461 | Cost: 0.6400\n", "Epoch: 029/050 | Batch 450/461 | Cost: 0.4906\n", "Epoch: 029/050\n", "Train ACC: 83.26 | Validation ACC: 88.10\n", "Time elapsed: 2.00 min\n", "Epoch: 030/050 | Batch 000/461 | Cost: 0.6298\n", "Epoch: 030/050 | Batch 150/461 | Cost: 0.5472\n", "Epoch: 030/050 | Batch 300/461 | Cost: 0.5469\n", "Epoch: 030/050 | Batch 450/461 | Cost: 0.4819\n", "Epoch: 030/050\n", "Train ACC: 83.30 | Validation ACC: 88.70\n", "Time elapsed: 2.07 min\n", "Epoch: 031/050 | Batch 000/461 | Cost: 0.6101\n", "Epoch: 031/050 | Batch 150/461 | Cost: 0.5150\n", "Epoch: 031/050 | Batch 
300/461 | Cost: 0.5505\n", "Epoch: 031/050 | Batch 450/461 | Cost: 0.5634\n", "Epoch: 031/050\n", "Train ACC: 83.28 | Validation ACC: 88.60\n", "Time elapsed: 2.13 min\n", "Epoch: 032/050 | Batch 000/461 | Cost: 0.5655\n", "Epoch: 032/050 | Batch 150/461 | Cost: 0.6567\n", "Epoch: 032/050 | Batch 300/461 | Cost: 0.5758\n", "Epoch: 032/050 | Batch 450/461 | Cost: 0.5306\n", "Epoch: 032/050\n", "Train ACC: 83.31 | Validation ACC: 88.20\n", "Time elapsed: 2.20 min\n", "Epoch: 033/050 | Batch 000/461 | Cost: 0.6677\n", "Epoch: 033/050 | Batch 150/461 | Cost: 0.7450\n", "Epoch: 033/050 | Batch 300/461 | Cost: 0.5538\n", "Epoch: 033/050 | Batch 450/461 | Cost: 0.5642\n", "Epoch: 033/050\n", "Train ACC: 83.33 | Validation ACC: 88.40\n", "Time elapsed: 2.27 min\n", "Epoch: 034/050 | Batch 000/461 | Cost: 0.6287\n", "Epoch: 034/050 | Batch 150/461 | Cost: 0.4752\n", "Epoch: 034/050 | Batch 300/461 | Cost: 0.5957\n", "Epoch: 034/050 | Batch 450/461 | Cost: 0.4531\n", "Epoch: 034/050\n", "Train ACC: 83.50 | Validation ACC: 88.70\n", "Time elapsed: 2.34 min\n", "Epoch: 035/050 | Batch 000/461 | Cost: 0.5368\n", "Epoch: 035/050 | Batch 150/461 | Cost: 0.5658\n", "Epoch: 035/050 | Batch 300/461 | Cost: 0.6598\n", "Epoch: 035/050 | Batch 450/461 | Cost: 0.5858\n", "Epoch: 035/050\n", "Train ACC: 83.59 | Validation ACC: 88.50\n", "Time elapsed: 2.41 min\n", "Epoch: 036/050 | Batch 000/461 | Cost: 0.5557\n", "Epoch: 036/050 | Batch 150/461 | Cost: 0.4680\n", "Epoch: 036/050 | Batch 300/461 | Cost: 0.4905\n", "Epoch: 036/050 | Batch 450/461 | Cost: 0.9074\n", "Epoch: 036/050\n", "Train ACC: 83.67 | Validation ACC: 88.50\n", "Time elapsed: 2.48 min\n", "Epoch: 037/050 | Batch 000/461 | Cost: 0.6120\n", "Epoch: 037/050 | Batch 150/461 | Cost: 0.4668\n", "Epoch: 037/050 | Batch 300/461 | Cost: 0.5836\n", "Epoch: 037/050 | Batch 450/461 | Cost: 0.4536\n", "Epoch: 037/050\n", "Train ACC: 83.35 | Validation ACC: 88.80\n", "Time elapsed: 2.55 min\n", "Epoch: 038/050 | Batch 000/461 | Cost: 
0.5380\n", "Epoch: 038/050 | Batch 150/461 | Cost: 0.4491\n", "Epoch: 038/050 | Batch 300/461 | Cost: 0.4500\n", "Epoch: 038/050 | Batch 450/461 | Cost: 0.6041\n", "Epoch: 038/050\n", "Train ACC: 83.69 | Validation ACC: 88.80\n", "Time elapsed: 2.61 min\n", "Epoch: 039/050 | Batch 000/461 | Cost: 0.4863\n", "Epoch: 039/050 | Batch 150/461 | Cost: 0.5673\n", "Epoch: 039/050 | Batch 300/461 | Cost: 0.4037\n", "Epoch: 039/050 | Batch 450/461 | Cost: 0.6392\n", "Epoch: 039/050\n", "Train ACC: 83.71 | Validation ACC: 88.70\n", "Time elapsed: 2.68 min\n", "Epoch: 040/050 | Batch 000/461 | Cost: 0.6707\n", "Epoch: 040/050 | Batch 150/461 | Cost: 0.5601\n", "Epoch: 040/050 | Batch 300/461 | Cost: 0.5265\n", "Epoch: 040/050 | Batch 450/461 | Cost: 0.4867\n", "Epoch: 040/050\n", "Train ACC: 83.76 | Validation ACC: 88.90\n", "Time elapsed: 2.75 min\n", "Epoch: 041/050 | Batch 000/461 | Cost: 0.5379\n", "Epoch: 041/050 | Batch 150/461 | Cost: 0.4588\n", "Epoch: 041/050 | Batch 300/461 | Cost: 0.5684\n", "Epoch: 041/050 | Batch 450/461 | Cost: 0.5547\n", "Epoch: 041/050\n", "Train ACC: 83.75 | Validation ACC: 88.60\n", "Time elapsed: 2.82 min\n", "Epoch: 042/050 | Batch 000/461 | Cost: 0.5714\n", "Epoch: 042/050 | Batch 150/461 | Cost: 0.3863\n", "Epoch: 042/050 | Batch 300/461 | Cost: 0.5142\n", "Epoch: 042/050 | Batch 450/461 | Cost: 0.6219\n", "Epoch: 042/050\n", "Train ACC: 83.79 | Validation ACC: 89.20\n", "Time elapsed: 2.89 min\n", "Epoch: 043/050 | Batch 000/461 | Cost: 0.5385\n", "Epoch: 043/050 | Batch 150/461 | Cost: 0.4801\n", "Epoch: 043/050 | Batch 300/461 | Cost: 0.6064\n", "Epoch: 043/050 | Batch 450/461 | Cost: 0.4959\n", "Epoch: 043/050\n", "Train ACC: 83.89 | Validation ACC: 88.80\n", "Time elapsed: 2.96 min\n", "Epoch: 044/050 | Batch 000/461 | Cost: 0.6742\n", "Epoch: 044/050 | Batch 150/461 | Cost: 0.5746\n", "Epoch: 044/050 | Batch 300/461 | Cost: 0.6846\n", "Epoch: 044/050 | Batch 450/461 | Cost: 0.6283\n", "Epoch: 044/050\n", "Train ACC: 83.91 | 
Validation ACC: 89.00\n", "Time elapsed: 3.03 min\n", "Epoch: 045/050 | Batch 000/461 | Cost: 0.5646\n", "Epoch: 045/050 | Batch 150/461 | Cost: 0.3776\n", "Epoch: 045/050 | Batch 300/461 | Cost: 0.5457\n", "Epoch: 045/050 | Batch 450/461 | Cost: 0.4897\n", "Epoch: 045/050\n", "Train ACC: 83.87 | Validation ACC: 89.10\n", "Time elapsed: 3.10 min\n", "Epoch: 046/050 | Batch 000/461 | Cost: 0.5300\n", "Epoch: 046/050 | Batch 150/461 | Cost: 0.6787\n", "Epoch: 046/050 | Batch 300/461 | Cost: 0.4310\n", "Epoch: 046/050 | Batch 450/461 | Cost: 0.5758\n", "Epoch: 046/050\n", "Train ACC: 84.01 | Validation ACC: 89.10\n", "Time elapsed: 3.17 min\n", "Epoch: 047/050 | Batch 000/461 | Cost: 0.6111\n", "Epoch: 047/050 | Batch 150/461 | Cost: 0.5679\n", "Epoch: 047/050 | Batch 300/461 | Cost: 0.6306\n", "Epoch: 047/050 | Batch 450/461 | Cost: 0.7292\n", "Epoch: 047/050\n", "Train ACC: 84.03 | Validation ACC: 89.20\n", "Time elapsed: 3.24 min\n", "Epoch: 048/050 | Batch 000/461 | Cost: 0.5925\n", "Epoch: 048/050 | Batch 150/461 | Cost: 0.6623\n", "Epoch: 048/050 | Batch 300/461 | Cost: 0.4188\n", "Epoch: 048/050 | Batch 450/461 | Cost: 0.3433\n", "Epoch: 048/050\n", "Train ACC: 83.89 | Validation ACC: 89.10\n", "Time elapsed: 3.31 min\n", "Epoch: 049/050 | Batch 000/461 | Cost: 0.4881\n", "Epoch: 049/050 | Batch 150/461 | Cost: 0.5040\n", "Epoch: 049/050 | Batch 300/461 | Cost: 0.5655\n", "Epoch: 049/050 | Batch 450/461 | Cost: 0.5264\n", "Epoch: 049/050\n", "Train ACC: 83.83 | Validation ACC: 88.60\n", "Time elapsed: 3.38 min\n", "Epoch: 050/050 | Batch 000/461 | Cost: 0.5284\n", "Epoch: 050/050 | Batch 150/461 | Cost: 0.6253\n", "Epoch: 050/050 | Batch 300/461 | Cost: 0.3891\n", "Epoch: 050/050 | Batch 450/461 | Cost: 0.4316\n", "Epoch: 050/050\n", "Train ACC: 83.90 | Validation ACC: 88.70\n", "Time elapsed: 3.45 min\n", "Total Training Time: 3.45 min\n" ] } ], "source": [ "def compute_acc(model, data_loader, device):\n", " correct_pred, num_examples = 0, 0\n", " for 
features, targets in data_loader:\n", " features = features.to(device)\n", " targets = targets.to(device)\n", " logits, probas = model(features)\n", " _, predicted_labels = torch.max(probas, 1)\n", " num_examples += targets.size(0)\n", " correct_pred += (predicted_labels == targets).sum()\n", " return correct_pred.float()/num_examples * 100\n", " \n", "\n", "start_time = time.time()\n", "\n", "cost_list = []\n", "train_acc_list, valid_acc_list = [], []\n", "\n", "\n", "for epoch in range(NUM_EPOCHS):\n", " \n", " model.train()\n", " for batch_idx, (features, targets) in enumerate(train_loader):\n", " \n", " features = features.to(DEVICE)\n", " targets = targets.to(DEVICE)\n", " \n", " ### FORWARD AND BACK PROP\n", " logits, probas = model(features)\n", " cost = F.cross_entropy(logits, targets)\n", " optimizer.zero_grad()\n", " \n", " cost.backward()\n", " \n", " ### UPDATE MODEL PARAMETERS\n", " optimizer.step()\n", " \n", " #################################################\n", " ### CODE ONLY FOR LOGGING BEYOND THIS POINT\n", " ################################################\n", " cost_list.append(cost.item())\n", " if not batch_idx % 150:\n", " print (f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d} | '\n", " f'Batch {batch_idx:03d}/{len(train_loader):03d} |' \n", " f' Cost: {cost:.4f}')\n", "\n", " \n", "\n", " model.eval()\n", " with torch.set_grad_enabled(False): # save memory during inference\n", " \n", " train_acc = compute_acc(model, train_loader, device=DEVICE)\n", " valid_acc = compute_acc(model, valid_loader, device=DEVICE)\n", " \n", " print(f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d}\\n'\n", " f'Train ACC: {train_acc:.2f} | Validation ACC: {valid_acc:.2f}')\n", " \n", " train_acc_list.append(train_acc)\n", " valid_acc_list.append(valid_acc)\n", " \n", " elapsed = (time.time() - start_time)/60\n", " print(f'Time elapsed: {elapsed:.2f} min')\n", " \n", "elapsed = (time.time() - start_time)/60\n", "print(f'Total Training Time: {elapsed:.2f} min')" ] }, { "cell_type": 
"markdown", "metadata": {}, "source": [ "## Evaluation" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.plot(cost_list, label='Minibatch cost')\n", "plt.plot(np.convolve(cost_list, \n", " np.ones(200,)/200, mode='valid'), \n", " label='Running average')\n", "\n", "plt.ylabel('Cross Entropy')\n", "plt.xlabel('Iteration')\n", "plt.legend()\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.plot(np.arange(1, NUM_EPOCHS+1), train_acc_list, label='Training')\n", "plt.plot(np.arange(1, NUM_EPOCHS+1), valid_acc_list, label='Validation')\n", "\n", "plt.xlabel('Epoch')\n", "plt.ylabel('Accuracy')\n", "plt.legend()\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Validation ACC: 88.70%\n", "Test ACC: 84.55%\n" ] } ], "source": [ "with torch.set_grad_enabled(False):\n", " test_acc = compute_acc(model=model,\n", " data_loader=test_loader,\n", " device=DEVICE)\n", " \n", " valid_acc = compute_acc(model=model,\n", " data_loader=valid_loader,\n", " device=DEVICE)\n", " \n", "\n", "print(f'Validation ACC: {valid_acc:.2f}%')\n", "print(f'Test ACC: {test_acc:.2f}%')" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "numpy 1.16.4\n", "torch 1.2.0\n", "matplotlib 3.1.0\n", "torchvision 0.4.0a0+6b959ee\n", "\n" ] } ], "source": [ "%watermark -iv" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" }, "toc": { "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 4 }