{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n", "- Author: Sebastian Raschka\n", "- GitHub Repository: https://github.com/rasbt/deeplearning-models" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sebastian Raschka \n", "\n", "CPython 3.7.3\n", "IPython 7.6.1\n", "\n", "torch 1.2.0\n" ] } ], "source": [ "%load_ext watermark\n", "%watermark -a 'Sebastian Raschka' -v -p torch" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Runs on CPU or GPU (if available)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Basic Graph Neural Network with Edge Prediction on MNIST" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This notebook implements a very basic graph neural network (GNN) that uses a subnetwork for edge prediction. \n", "\n", "Here, the 28x28 image of a digit in MNIST represents the graph, where each pixel (i.e., cell in the grid) represents a particular node. The feature of that node is simply the pixel intensity in the range [0, 1]. \n", "\n", "In the related notebook, [gnn-basic-1.ipynb], the adjacency matrix of the pixels was determined by each pixel's spatial neighborhood: using a Gaussian filter, pixels were connected based on their Euclidean distance in the grid. 
In **this notebook**, the edges are instead predicted by a separate neural network model that takes the coordinates of each pair of pixels as input: \n", "\n", "\n", "```python\n", " self.pred_edge_fc = nn.Sequential(nn.Linear(coord_features, 32),\n", " nn.ReLU(),\n", " nn.Linear(32, 1),\n", " nn.Tanh())\n", "```\n", "\n", "\n", "Using the resulting adjacency matrix $A$, we can compute the output of a layer as \n", "\n", "$$X^{(l+1)}=A X^{(l)} W^{(l)}.$$\n", "\n", "Here, $A$ is the $N \\times N$ adjacency matrix, and $X$ is the $N \\times C$ feature matrix (a 2D array, where $N$ is the total number of pixels -- $28 \\times 28 = 784$ in MNIST -- and $C$ is the number of features per node; here, $C = 1$ for the pixel intensity). $W$ is the weight matrix of shape $C \\times P$, where $P$ is the number of output features. In the implementation below, the aggregated features $AX$ are flattened into a vector of length $N$, and a single fully connected layer of shape $N \\times P$ maps them to the class logits, so $P$ is the number of classes.\n", "\n", "\n", "- Inspired by and based on Boris Knyazev's tutorial at https://medium.com/@BorisAKnyazev/tutorial-on-graph-neural-networks-for-computer-vision-and-beyond-part-1-3d9fada3b80d." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import time\n", "import numpy as np\n", "from scipy.spatial.distance import cdist\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from torchvision import datasets\n", "from torchvision import transforms\n", "from torch.utils.data import DataLoader\n", "from torch.utils.data.dataset import Subset\n", "\n", "\n", "if torch.cuda.is_available():\n", " torch.backends.cudnn.deterministic = True" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Settings and Dataset" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "##########################\n", "### SETTINGS\n", "##########################\n", "\n", "# Device\n", "DEVICE = torch.device(\"cuda:0\" if 
torch.cuda.is_available() else \"cpu\")\n", "\n", "# Hyperparameters\n", "RANDOM_SEED = 1\n", "LEARNING_RATE = 0.0005\n", "NUM_EPOCHS = 50\n", "BATCH_SIZE = 128\n", "IMG_SIZE = 28\n", "\n", "# Architecture\n", "NUM_CLASSES = 10" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## MNIST Dataset" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Image batch dimensions: torch.Size([128, 1, 28, 28])\n", "Image label dimensions: torch.Size([128])\n" ] } ], "source": [ "train_indices = torch.arange(0, 59000)\n", "valid_indices = torch.arange(59000, 60000)\n", "\n", "custom_transform = transforms.Compose([transforms.ToTensor()])\n", "\n", "\n", "train_and_valid = datasets.MNIST(root='data', \n", " train=True, \n", " transform=custom_transform,\n", " download=True)\n", "\n", "test_dataset = datasets.MNIST(root='data', \n", " train=False, \n", " transform=custom_transform,\n", " download=True)\n", "\n", "train_dataset = Subset(train_and_valid, train_indices)\n", "valid_dataset = Subset(train_and_valid, valid_indices)\n", "\n", "train_loader = DataLoader(dataset=train_dataset, \n", " batch_size=BATCH_SIZE,\n", " num_workers=4,\n", " shuffle=True)\n", "\n", "valid_loader = DataLoader(dataset=valid_dataset, \n", " batch_size=BATCH_SIZE,\n", " num_workers=4,\n", " shuffle=False)\n", "\n", "test_loader = DataLoader(dataset=test_dataset, \n", " batch_size=BATCH_SIZE,\n", " num_workers=4,\n", " shuffle=False)\n", "\n", "# Checking the dataset\n", "for images, labels in train_loader: \n", " print('Image batch dimensions:', images.shape)\n", " print('Image label dimensions:', labels.shape)\n", " break" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "##########################\n", "### MODEL\n", "##########################\n", "\n", "\n", "def 
make_coordinate_array(img_size, out_size=4):\n", " \n", " ### Make 2D coordinate array (for MNIST: 784x2)\n", " n_rows = img_size * img_size\n", " col, row = np.meshgrid(np.arange(img_size), np.arange(img_size))\n", " coord = np.stack((col, row), axis=2).reshape(-1, 2)\n", " coord = (coord - np.mean(coord, axis=0)) / (np.std(coord, axis=0) + 1e-5)\n", " coord = torch.from_numpy(coord).float()\n", " \n", " ### Reshape to [N, N, out_size]\n", " coord = torch.cat((coord.unsqueeze(0).repeat(n_rows, 1, int(out_size/2-1)),\n", " coord.unsqueeze(1).repeat(1, n_rows, 1)), dim=2)\n", " \n", " \n", " return coord\n", "\n", " \n", "\n", "class GraphNet(nn.Module):\n", " def __init__(self, img_size=28, coord_features=4, num_classes=10):\n", " super(GraphNet, self).__init__()\n", " \n", " n_rows = img_size**2\n", " self.fc = nn.Linear(n_rows, num_classes, bias=False)\n", "\n", " coord = make_coordinate_array(img_size, coord_features)\n", " self.register_buffer('coord', coord)\n", " \n", " ##########\n", " # Edge Predictor\n", " self.pred_edge_fc = nn.Sequential(nn.Linear(coord_features, 32), # coord -> hidden\n", " nn.ReLU(),\n", " nn.Linear(32, 1), # hidden -> edge\n", " nn.Tanh())\n", " \n", "\n", " \n", "\n", " def forward(self, x):\n", " B = x.size(0)\n", " \n", " ### Predict edges\n", " self.A = self.pred_edge_fc(self.coord).squeeze()\n", "\n", " ### Reshape Adjacency Matrix\n", " # [N, N] Adj. 
matrix -> [1, N, N] Adj tensor where N = HxW\n", " A_tensor = self.A.unsqueeze(0)\n", " # [1, N, N] Adj tensor -> [B, N, N] tensor (expand shares memory, no copy)\n", " A_tensor = A_tensor.expand(B, -1, -1)\n", " \n", " ### Reshape inputs\n", " # [B, C, H, W] => [B, H*W, 1]\n", " x_reshape = x.view(B, -1, 1)\n", " \n", " # bmm = batch matrix product to sum the neighbor features\n", " # Input: [B, N, N] x [B, N, 1]\n", " # Output: [B, N, 1], viewed as [B, N]\n", " avg_neighbor_features = (torch.bmm(A_tensor, x_reshape).view(B, -1))\n", "\n", " logits = self.fc(avg_neighbor_features)\n", " probas = F.softmax(logits, dim=1)\n", " return logits, probas" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "torch.manual_seed(RANDOM_SEED)\n", "model = GraphNet(img_size=IMG_SIZE, num_classes=NUM_CLASSES)\n", "\n", "model = model.to(DEVICE)\n", "\n", "optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE) " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 001/050 | Batch 000/461 | Cost: 24.2727\n", "Epoch: 001/050 | Batch 150/461 | Cost: 2.2706\n", "Epoch: 001/050 | Batch 300/461 | Cost: 1.8713\n", "Epoch: 001/050 | Batch 450/461 | Cost: 1.5048\n", "Epoch: 001/050\n", "Train ACC: 50.39 | Validation ACC: 54.80\n", "Time elapsed: 0.25 min\n", "Epoch: 002/050 | Batch 000/461 | Cost: 1.4445\n", "Epoch: 002/050 | Batch 150/461 | Cost: 1.3288\n", "Epoch: 002/050 | Batch 300/461 | Cost: 1.1868\n", "Epoch: 002/050 | Batch 450/461 | Cost: 1.2040\n", "Epoch: 002/050\n", "Train ACC: 67.68 | Validation ACC: 71.40\n", "Time elapsed: 0.49 min\n", "Epoch: 003/050 | Batch 000/461 | Cost: 1.2128\n", "Epoch: 003/050 | Batch 150/461 | Cost: 0.9953\n", "Epoch: 003/050 | Batch 300/461 | Cost: 0.9818\n", "Epoch: 003/050 | Batch 450/461 | Cost: 1.0487\n", "Epoch: 003/050\n", "Train ACC: 68.09 | Validation ACC: 73.40\n", "Time elapsed: 
0.74 min\n", "Epoch: 004/050 | Batch 000/461 | Cost: 1.0444\n", "Epoch: 004/050 | Batch 150/461 | Cost: 0.9064\n", "Epoch: 004/050 | Batch 300/461 | Cost: 0.9152\n", "Epoch: 004/050 | Batch 450/461 | Cost: 0.7396\n", "Epoch: 004/050\n", "Train ACC: 76.10 | Validation ACC: 80.20\n", "Time elapsed: 0.98 min\n", "Epoch: 005/050 | Batch 000/461 | Cost: 0.7698\n", "Epoch: 005/050 | Batch 150/461 | Cost: 0.8356\n", "Epoch: 005/050 | Batch 300/461 | Cost: 0.6544\n", "Epoch: 005/050 | Batch 450/461 | Cost: 0.8700\n", "Epoch: 005/050\n", "Train ACC: 80.30 | Validation ACC: 84.10\n", "Time elapsed: 1.22 min\n", "Epoch: 006/050 | Batch 000/461 | Cost: 0.6292\n", "Epoch: 006/050 | Batch 150/461 | Cost: 0.7779\n", "Epoch: 006/050 | Batch 300/461 | Cost: 0.5978\n", "Epoch: 006/050 | Batch 450/461 | Cost: 0.6260\n", "Epoch: 006/050\n", "Train ACC: 82.00 | Validation ACC: 85.60\n", "Time elapsed: 1.46 min\n", "Epoch: 007/050 | Batch 000/461 | Cost: 0.7172\n", "Epoch: 007/050 | Batch 150/461 | Cost: 0.6444\n", "Epoch: 007/050 | Batch 300/461 | Cost: 0.5620\n", "Epoch: 007/050 | Batch 450/461 | Cost: 0.5314\n", "Epoch: 007/050\n", "Train ACC: 82.90 | Validation ACC: 86.20\n", "Time elapsed: 1.71 min\n", "Epoch: 008/050 | Batch 000/461 | Cost: 0.6211\n", "Epoch: 008/050 | Batch 150/461 | Cost: 0.5004\n", "Epoch: 008/050 | Batch 300/461 | Cost: 0.5274\n", "Epoch: 008/050 | Batch 450/461 | Cost: 0.6611\n", "Epoch: 008/050\n", "Train ACC: 84.86 | Validation ACC: 87.90\n", "Time elapsed: 1.95 min\n", "Epoch: 009/050 | Batch 000/461 | Cost: 0.4017\n", "Epoch: 009/050 | Batch 150/461 | Cost: 0.7080\n", "Epoch: 009/050 | Batch 300/461 | Cost: 0.4298\n", "Epoch: 009/050 | Batch 450/461 | Cost: 0.4516\n", "Epoch: 009/050\n", "Train ACC: 85.89 | Validation ACC: 90.20\n", "Time elapsed: 2.19 min\n", "Epoch: 010/050 | Batch 000/461 | Cost: 0.4571\n", "Epoch: 010/050 | Batch 150/461 | Cost: 0.4976\n", "Epoch: 010/050 | Batch 300/461 | Cost: 0.6208\n", "Epoch: 010/050 | Batch 450/461 | Cost: 
0.3780\n", "Epoch: 010/050\n", "Train ACC: 85.84 | Validation ACC: 89.40\n", "Time elapsed: 2.43 min\n", "Epoch: 011/050 | Batch 000/461 | Cost: 0.5262\n", "Epoch: 011/050 | Batch 150/461 | Cost: 0.4255\n", "Epoch: 011/050 | Batch 300/461 | Cost: 0.3840\n", "Epoch: 011/050 | Batch 450/461 | Cost: 0.4941\n", "Epoch: 011/050\n", "Train ACC: 85.73 | Validation ACC: 90.40\n", "Time elapsed: 2.68 min\n", "Epoch: 012/050 | Batch 000/461 | Cost: 0.3425\n", "Epoch: 012/050 | Batch 150/461 | Cost: 0.5059\n", "Epoch: 012/050 | Batch 300/461 | Cost: 0.6590\n", "Epoch: 012/050 | Batch 450/461 | Cost: 0.5481\n", "Epoch: 012/050\n", "Train ACC: 87.50 | Validation ACC: 91.30\n", "Time elapsed: 2.92 min\n", "Epoch: 013/050 | Batch 000/461 | Cost: 0.6081\n", "Epoch: 013/050 | Batch 150/461 | Cost: 0.4584\n", "Epoch: 013/050 | Batch 300/461 | Cost: 0.2856\n", "Epoch: 013/050 | Batch 450/461 | Cost: 0.4324\n", "Epoch: 013/050\n", "Train ACC: 87.35 | Validation ACC: 91.20\n", "Time elapsed: 3.17 min\n", "Epoch: 014/050 | Batch 000/461 | Cost: 0.4685\n", "Epoch: 014/050 | Batch 150/461 | Cost: 0.4492\n", "Epoch: 014/050 | Batch 300/461 | Cost: 0.3913\n", "Epoch: 014/050 | Batch 450/461 | Cost: 0.5154\n", "Epoch: 014/050\n", "Train ACC: 86.71 | Validation ACC: 91.20\n", "Time elapsed: 3.41 min\n", "Epoch: 015/050 | Batch 000/461 | Cost: 0.4526\n", "Epoch: 015/050 | Batch 150/461 | Cost: 0.4834\n", "Epoch: 015/050 | Batch 300/461 | Cost: 0.5208\n", "Epoch: 015/050 | Batch 450/461 | Cost: 0.3536\n", "Epoch: 015/050\n", "Train ACC: 85.21 | Validation ACC: 89.50\n", "Time elapsed: 3.66 min\n", "Epoch: 016/050 | Batch 000/461 | Cost: 0.6614\n", "Epoch: 016/050 | Batch 150/461 | Cost: 0.3036\n", "Epoch: 016/050 | Batch 300/461 | Cost: 0.3766\n", "Epoch: 016/050 | Batch 450/461 | Cost: 0.4550\n", "Epoch: 016/050\n", "Train ACC: 86.97 | Validation ACC: 92.10\n", "Time elapsed: 3.91 min\n", "Epoch: 017/050 | Batch 000/461 | Cost: 0.6241\n", "Epoch: 017/050 | Batch 150/461 | Cost: 0.3934\n", 
"Epoch: 017/050 | Batch 300/461 | Cost: 0.4330\n", "Epoch: 017/050 | Batch 450/461 | Cost: 0.5914\n", "Epoch: 017/050\n", "Train ACC: 88.12 | Validation ACC: 91.60\n", "Time elapsed: 4.15 min\n", "Epoch: 018/050 | Batch 000/461 | Cost: 0.3769\n", "Epoch: 018/050 | Batch 150/461 | Cost: 0.4817\n", "Epoch: 018/050 | Batch 300/461 | Cost: 0.4103\n", "Epoch: 018/050 | Batch 450/461 | Cost: 0.3727\n", "Epoch: 018/050\n", "Train ACC: 86.58 | Validation ACC: 90.90\n", "Time elapsed: 4.40 min\n", "Epoch: 019/050 | Batch 000/461 | Cost: 0.4098\n", "Epoch: 019/050 | Batch 150/461 | Cost: 0.4435\n", "Epoch: 019/050 | Batch 300/461 | Cost: 0.2952\n", "Epoch: 019/050 | Batch 450/461 | Cost: 0.3328\n", "Epoch: 019/050\n", "Train ACC: 88.65 | Validation ACC: 92.00\n", "Time elapsed: 4.64 min\n", "Epoch: 020/050 | Batch 000/461 | Cost: 0.5363\n", "Epoch: 020/050 | Batch 150/461 | Cost: 0.3143\n", "Epoch: 020/050 | Batch 300/461 | Cost: 0.5186\n", "Epoch: 020/050 | Batch 450/461 | Cost: 0.3806\n", "Epoch: 020/050\n", "Train ACC: 88.95 | Validation ACC: 92.70\n", "Time elapsed: 4.89 min\n", "Epoch: 021/050 | Batch 000/461 | Cost: 0.3810\n", "Epoch: 021/050 | Batch 150/461 | Cost: 0.2470\n", "Epoch: 021/050 | Batch 300/461 | Cost: 0.6154\n", "Epoch: 021/050 | Batch 450/461 | Cost: 0.3651\n", "Epoch: 021/050\n", "Train ACC: 88.31 | Validation ACC: 92.40\n", "Time elapsed: 5.13 min\n", "Epoch: 022/050 | Batch 000/461 | Cost: 0.3704\n", "Epoch: 022/050 | Batch 150/461 | Cost: 0.4338\n", "Epoch: 022/050 | Batch 300/461 | Cost: 0.4197\n", "Epoch: 022/050 | Batch 450/461 | Cost: 0.3304\n", "Epoch: 022/050\n", "Train ACC: 88.62 | Validation ACC: 91.90\n", "Time elapsed: 5.31 min\n", "Epoch: 023/050 | Batch 000/461 | Cost: 0.2825\n", "Epoch: 023/050 | Batch 150/461 | Cost: 0.4302\n", "Epoch: 023/050 | Batch 300/461 | Cost: 0.4738\n", "Epoch: 023/050 | Batch 450/461 | Cost: 0.4362\n", "Epoch: 023/050\n", "Train ACC: 89.02 | Validation ACC: 92.80\n", "Time elapsed: 5.44 min\n", "Epoch: 024/050 
| Batch 000/461 | Cost: 0.2097\n", "Epoch: 024/050 | Batch 150/461 | Cost: 0.4440\n", "Epoch: 024/050 | Batch 300/461 | Cost: 0.4467\n", "Epoch: 024/050 | Batch 450/461 | Cost: 0.2744\n", "Epoch: 024/050\n", "Train ACC: 88.82 | Validation ACC: 92.40\n", "Time elapsed: 5.57 min\n", "Epoch: 025/050 | Batch 000/461 | Cost: 0.2734\n", "Epoch: 025/050 | Batch 150/461 | Cost: 0.3980\n", "Epoch: 025/050 | Batch 300/461 | Cost: 0.4395\n", "Epoch: 025/050 | Batch 450/461 | Cost: 0.2336\n", "Epoch: 025/050\n", "Train ACC: 89.59 | Validation ACC: 93.90\n", "Time elapsed: 5.70 min\n", "Epoch: 026/050 | Batch 000/461 | Cost: 0.3138\n", "Epoch: 026/050 | Batch 150/461 | Cost: 0.3772\n", "Epoch: 026/050 | Batch 300/461 | Cost: 0.2955\n", "Epoch: 026/050 | Batch 450/461 | Cost: 0.3747\n", "Epoch: 026/050\n", "Train ACC: 88.71 | Validation ACC: 92.70\n", "Time elapsed: 5.82 min\n", "Epoch: 027/050 | Batch 000/461 | Cost: 0.4107\n", "Epoch: 027/050 | Batch 150/461 | Cost: 0.4375\n", "Epoch: 027/050 | Batch 300/461 | Cost: 0.3802\n", "Epoch: 027/050 | Batch 450/461 | Cost: 0.3240\n", "Epoch: 027/050\n", "Train ACC: 87.90 | Validation ACC: 91.60\n", "Time elapsed: 5.95 min\n", "Epoch: 028/050 | Batch 000/461 | Cost: 0.5124\n", "Epoch: 028/050 | Batch 150/461 | Cost: 0.4980\n", "Epoch: 028/050 | Batch 300/461 | Cost: 0.3937\n", "Epoch: 028/050 | Batch 450/461 | Cost: 0.2704\n", "Epoch: 028/050\n", "Train ACC: 89.08 | Validation ACC: 92.30\n", "Time elapsed: 6.08 min\n", "Epoch: 029/050 | Batch 000/461 | Cost: 0.3328\n", "Epoch: 029/050 | Batch 150/461 | Cost: 0.3022\n", "Epoch: 029/050 | Batch 300/461 | Cost: 0.3222\n", "Epoch: 029/050 | Batch 450/461 | Cost: 0.3084\n", "Epoch: 029/050\n", "Train ACC: 89.30 | Validation ACC: 93.90\n", "Time elapsed: 6.21 min\n", "Epoch: 030/050 | Batch 000/461 | Cost: 0.4667\n", "Epoch: 030/050 | Batch 150/461 | Cost: 0.3290\n", "Epoch: 030/050 | Batch 300/461 | Cost: 0.3261\n", "Epoch: 030/050 | Batch 450/461 | Cost: 0.3347\n", "Epoch: 030/050\n", 
"Train ACC: 89.17 | Validation ACC: 93.60\n", "Time elapsed: 6.33 min\n", "Epoch: 031/050 | Batch 000/461 | Cost: 0.3486\n", "Epoch: 031/050 | Batch 150/461 | Cost: 0.2426\n", "Epoch: 031/050 | Batch 300/461 | Cost: 0.2748\n", "Epoch: 031/050 | Batch 450/461 | Cost: 0.2072\n", "Epoch: 031/050\n", "Train ACC: 89.17 | Validation ACC: 93.20\n", "Time elapsed: 6.46 min\n", "Epoch: 032/050 | Batch 000/461 | Cost: 0.3423\n", "Epoch: 032/050 | Batch 150/461 | Cost: 0.4924\n", "Epoch: 032/050 | Batch 300/461 | Cost: 0.4072\n", "Epoch: 032/050 | Batch 450/461 | Cost: 0.3611\n", "Epoch: 032/050\n", "Train ACC: 89.83 | Validation ACC: 94.30\n", "Time elapsed: 6.59 min\n", "Epoch: 033/050 | Batch 000/461 | Cost: 0.2461\n", "Epoch: 033/050 | Batch 150/461 | Cost: 0.2343\n", "Epoch: 033/050 | Batch 300/461 | Cost: 0.2891\n", "Epoch: 033/050 | Batch 450/461 | Cost: 0.3772\n", "Epoch: 033/050\n", "Train ACC: 88.81 | Validation ACC: 92.40\n", "Time elapsed: 6.72 min\n", "Epoch: 034/050 | Batch 000/461 | Cost: 0.3052\n", "Epoch: 034/050 | Batch 150/461 | Cost: 0.5129\n", "Epoch: 034/050 | Batch 300/461 | Cost: 0.3810\n", "Epoch: 034/050 | Batch 450/461 | Cost: 0.2906\n", "Epoch: 034/050\n", "Train ACC: 89.34 | Validation ACC: 93.10\n", "Time elapsed: 6.85 min\n", "Epoch: 035/050 | Batch 000/461 | Cost: 0.3604\n", "Epoch: 035/050 | Batch 150/461 | Cost: 0.3832\n", "Epoch: 035/050 | Batch 300/461 | Cost: 0.3632\n", "Epoch: 035/050 | Batch 450/461 | Cost: 0.3345\n", "Epoch: 035/050\n", "Train ACC: 89.74 | Validation ACC: 93.10\n", "Time elapsed: 6.98 min\n", "Epoch: 036/050 | Batch 000/461 | Cost: 0.3382\n", "Epoch: 036/050 | Batch 150/461 | Cost: 0.3754\n", "Epoch: 036/050 | Batch 300/461 | Cost: 0.4120\n", "Epoch: 036/050 | Batch 450/461 | Cost: 0.4710\n", "Epoch: 036/050\n", "Train ACC: 89.10 | Validation ACC: 93.90\n", "Time elapsed: 7.10 min\n", "Epoch: 037/050 | Batch 000/461 | Cost: 0.4466\n", "Epoch: 037/050 | Batch 150/461 | Cost: 0.3427\n", "Epoch: 037/050 | Batch 300/461 | 
Cost: 0.3301\n", "Epoch: 037/050 | Batch 450/461 | Cost: 0.4110\n", "Epoch: 037/050\n", "Train ACC: 89.95 | Validation ACC: 93.90\n", "Time elapsed: 7.23 min\n", "Epoch: 038/050 | Batch 000/461 | Cost: 0.2470\n", "Epoch: 038/050 | Batch 150/461 | Cost: 0.4719\n", "Epoch: 038/050 | Batch 300/461 | Cost: 0.3253\n", "Epoch: 038/050 | Batch 450/461 | Cost: 0.4324\n", "Epoch: 038/050\n", "Train ACC: 89.35 | Validation ACC: 93.50\n", "Time elapsed: 7.36 min\n", "Epoch: 039/050 | Batch 000/461 | Cost: 0.3058\n", "Epoch: 039/050 | Batch 150/461 | Cost: 0.4755\n", "Epoch: 039/050 | Batch 300/461 | Cost: 0.2981\n", "Epoch: 039/050 | Batch 450/461 | Cost: 0.4293\n", "Epoch: 039/050\n", "Train ACC: 89.51 | Validation ACC: 92.90\n", "Time elapsed: 7.48 min\n", "Epoch: 040/050 | Batch 000/461 | Cost: 0.3378\n", "Epoch: 040/050 | Batch 150/461 | Cost: 0.5137\n", "Epoch: 040/050 | Batch 300/461 | Cost: 0.2680\n", "Epoch: 040/050 | Batch 450/461 | Cost: 0.3397\n", "Epoch: 040/050\n", "Train ACC: 90.01 | Validation ACC: 93.70\n", "Time elapsed: 7.61 min\n", "Epoch: 041/050 | Batch 000/461 | Cost: 0.2766\n", "Epoch: 041/050 | Batch 150/461 | Cost: 0.2959\n", "Epoch: 041/050 | Batch 300/461 | Cost: 0.1930\n", "Epoch: 041/050 | Batch 450/461 | Cost: 0.3735\n", "Epoch: 041/050\n", "Train ACC: 89.45 | Validation ACC: 93.60\n", "Time elapsed: 7.74 min\n", "Epoch: 042/050 | Batch 000/461 | Cost: 0.2694\n", "Epoch: 042/050 | Batch 150/461 | Cost: 0.3575\n", "Epoch: 042/050 | Batch 300/461 | Cost: 0.4267\n", "Epoch: 042/050 | Batch 450/461 | Cost: 0.3332\n", "Epoch: 042/050\n", "Train ACC: 89.96 | Validation ACC: 93.30\n", "Time elapsed: 7.86 min\n", "Epoch: 043/050 | Batch 000/461 | Cost: 0.2288\n", "Epoch: 043/050 | Batch 150/461 | Cost: 0.4260\n", "Epoch: 043/050 | Batch 300/461 | Cost: 0.2835\n", "Epoch: 043/050 | Batch 450/461 | Cost: 0.2882\n", "Epoch: 043/050\n", "Train ACC: 89.91 | Validation ACC: 93.40\n", "Time elapsed: 7.99 min\n", "Epoch: 044/050 | Batch 000/461 | Cost: 
0.3211\n", "Epoch: 044/050 | Batch 150/461 | Cost: 0.3061\n", "Epoch: 044/050 | Batch 300/461 | Cost: 0.3137\n", "Epoch: 044/050 | Batch 450/461 | Cost: 0.2978\n", "Epoch: 044/050\n", "Train ACC: 89.63 | Validation ACC: 94.30\n", "Time elapsed: 8.12 min\n", "Epoch: 045/050 | Batch 000/461 | Cost: 0.2325\n", "Epoch: 045/050 | Batch 150/461 | Cost: 0.3013\n", "Epoch: 045/050 | Batch 300/461 | Cost: 0.3732\n", "Epoch: 045/050 | Batch 450/461 | Cost: 0.3229\n", "Epoch: 045/050\n", "Train ACC: 90.00 | Validation ACC: 93.80\n", "Time elapsed: 8.25 min\n", "Epoch: 046/050 | Batch 000/461 | Cost: 0.2521\n", "Epoch: 046/050 | Batch 150/461 | Cost: 0.4440\n", "Epoch: 046/050 | Batch 300/461 | Cost: 0.3420\n", "Epoch: 046/050 | Batch 450/461 | Cost: 0.4288\n", "Epoch: 046/050\n", "Train ACC: 89.97 | Validation ACC: 93.40\n", "Time elapsed: 8.38 min\n", "Epoch: 047/050 | Batch 000/461 | Cost: 0.4605\n", "Epoch: 047/050 | Batch 150/461 | Cost: 0.3261\n", "Epoch: 047/050 | Batch 300/461 | Cost: 0.4493\n", "Epoch: 047/050 | Batch 450/461 | Cost: 0.4902\n", "Epoch: 047/050\n", "Train ACC: 89.60 | Validation ACC: 93.70\n", "Time elapsed: 8.50 min\n", "Epoch: 048/050 | Batch 000/461 | Cost: 0.4136\n", "Epoch: 048/050 | Batch 150/461 | Cost: 0.2952\n", "Epoch: 048/050 | Batch 300/461 | Cost: 0.4784\n", "Epoch: 048/050 | Batch 450/461 | Cost: 0.3044\n", "Epoch: 048/050\n", "Train ACC: 90.15 | Validation ACC: 94.60\n", "Time elapsed: 8.63 min\n", "Epoch: 049/050 | Batch 000/461 | Cost: 0.3802\n", "Epoch: 049/050 | Batch 150/461 | Cost: 0.4018\n", "Epoch: 049/050 | Batch 300/461 | Cost: 0.3197\n", "Epoch: 049/050 | Batch 450/461 | Cost: 0.4157\n", "Epoch: 049/050\n", "Train ACC: 89.91 | Validation ACC: 93.70\n", "Time elapsed: 8.76 min\n", "Epoch: 050/050 | Batch 000/461 | Cost: 0.4057\n", "Epoch: 050/050 | Batch 150/461 | Cost: 0.3687\n", "Epoch: 050/050 | Batch 300/461 | Cost: 0.3552\n", "Epoch: 050/050 | Batch 450/461 | Cost: 0.2707\n", "Epoch: 050/050\n", "Train ACC: 89.91 | 
Validation ACC: 93.00\n", "Time elapsed: 8.88 min\n", "Total Training Time: 8.88 min\n" ] } ], "source": [ "def compute_acc(model, data_loader, device):\n", " correct_pred, num_examples = 0, 0\n", " for features, targets in data_loader:\n", " features = features.to(device)\n", " targets = targets.to(device)\n", " logits, probas = model(features)\n", " _, predicted_labels = torch.max(probas, 1)\n", " num_examples += targets.size(0)\n", " correct_pred += (predicted_labels == targets).sum()\n", " return correct_pred.float()/num_examples * 100\n", " \n", "\n", "start_time = time.time()\n", "\n", "cost_list = []\n", "train_acc_list, valid_acc_list = [], []\n", "\n", "\n", "for epoch in range(NUM_EPOCHS):\n", " \n", " model.train()\n", " for batch_idx, (features, targets) in enumerate(train_loader):\n", " \n", " features = features.to(DEVICE)\n", " targets = targets.to(DEVICE)\n", " \n", " ### FORWARD AND BACK PROP\n", " logits, probas = model(features)\n", " cost = F.cross_entropy(logits, targets)\n", " optimizer.zero_grad()\n", " \n", " cost.backward()\n", " \n", " ### UPDATE MODEL PARAMETERS\n", " optimizer.step()\n", " \n", " #################################################\n", " ### CODE ONLY FOR LOGGING BEYOND THIS POINT\n", " ################################################\n", " cost_list.append(cost.item())\n", " if not batch_idx % 150:\n", " print(f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d} | '\n", " f'Batch {batch_idx:03d}/{len(train_loader):03d} |' \n", " f' Cost: {cost:.4f}')\n", "\n", " \n", "\n", " model.eval()\n", " with torch.set_grad_enabled(False): # save memory during inference\n", " \n", " train_acc = compute_acc(model, train_loader, device=DEVICE)\n", " valid_acc = compute_acc(model, valid_loader, device=DEVICE)\n", " \n", " print(f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d}\\n'\n", " f'Train ACC: {train_acc:.2f} | Validation ACC: {valid_acc:.2f}')\n", " \n", " train_acc_list.append(train_acc)\n", " valid_acc_list.append(valid_acc)\n", " \n", " elapsed = 
(time.time() - start_time)/60\n", " print(f'Time elapsed: {elapsed:.2f} min')\n", " \n", "elapsed = (time.time() - start_time)/60\n", "print(f'Total Training Time: {elapsed:.2f} min')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluation" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# last adjacency matrix\n", "\n", "plt.imshow(model.A.to('cpu'));" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.plot(cost_list, label='Minibatch cost')\n", "plt.plot(np.convolve(cost_list, \n", " np.ones(200,)/200, mode='valid'), \n", " label='Running average')\n", "\n", "plt.ylabel('Cross Entropy')\n", "plt.xlabel('Iteration')\n", "plt.legend()\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.plot(np.arange(1, NUM_EPOCHS+1), train_acc_list, label='Training')\n", "plt.plot(np.arange(1, NUM_EPOCHS+1), valid_acc_list, label='Validation')\n", "\n", "plt.xlabel('Epoch')\n", "plt.ylabel('Accuracy')\n", "plt.legend()\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Validation ACC: 93.00%\n", "Test ACC: 90.36%\n" ] } ], "source": [ "with torch.set_grad_enabled(False):\n", " test_acc = compute_acc(model=model,\n", " data_loader=test_loader,\n", " device=DEVICE)\n", " \n", " valid_acc = compute_acc(model=model,\n", " 
data_loader=valid_loader,\n", " device=DEVICE)\n", " \n", "\n", "print(f'Validation ACC: {valid_acc:.2f}%')\n", "print(f'Test ACC: {test_acc:.2f}%')" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torchvision 0.4.0a0+6b959ee\n", "matplotlib 3.1.0\n", "torch 1.2.0\n", "numpy 1.16.4\n", "\n" ] } ], "source": [ "%watermark -iv" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" }, "toc": { "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 4 }