{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n", "- Author: Sebastian Raschka\n", "- GitHub Repository: https://github.com/rasbt/deeplearning-models" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sebastian Raschka \n", "\n", "CPython 3.7.3\n", "IPython 7.6.1\n", "\n", "torch 1.2.0\n" ] } ], "source": [ "%load_ext watermark\n", "%watermark -a 'Sebastian Raschka' -v -p torch" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Runs on CPU or GPU (if available)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Most Basic Graph Neural Network with Gaussian Filter on MNIST" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This notebook implements a very basic graph neural network (GNN) using a Gaussian filter.\n", "\n", "Here, the 28x28 image of an MNIST digit represents the graph, where each pixel (i.e., each cell in the grid) is a node. The feature of a node is simply its pixel intensity in the range [0, 1].\n", "\n", "The adjacency matrix is determined by the spatial neighborhood of each pixel: using a Gaussian filter, we connect pixels based on their Euclidean distance in the grid, so that nearby pixels receive large edge weights, while pixels that are far apart are not connected at all (edge weights below 0.01 are set to zero).\n", "\n", "Using this adjacency matrix $A$, we can compute the output of a layer as\n", "\n", "$$X^{(l+1)}=A X^{(l)} W^{(l)}.$$\n", "\n", "Here, $A$ is the $N \\times N$ adjacency matrix, and $X$ is the $N \\times C$ feature matrix, where $N$ is the total number of pixels ($28 \\times 28 = 784$ in MNIST) and $C$ is the number of features per node ($C = 1$ here, the grayscale intensity). 
Strictly, in the layer formula above $W^{(l)}$ would be a $C \\times P$ matrix. In this minimal implementation, however, we compute $AX$ (an $N \\times 1$ matrix, since $C = 1$), flatten it, and pass it to a single fully connected output layer whose weight matrix $W$ has shape $N \\times P$, where $P$ is the number of classes (10 for MNIST).\n", "\n", "\n", "\n", "- Inspired by and based on Boris Knyazev's tutorial at https://medium.com/@BorisAKnyazev/tutorial-on-graph-neural-networks-for-computer-vision-and-beyond-part-1-3d9fada3b80d." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import time\n", "import numpy as np\n", "from scipy.spatial.distance import cdist\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from torchvision import datasets\n", "from torchvision import transforms\n", "from torch.utils.data import DataLoader\n", "from torch.utils.data.dataset import Subset\n", "\n", "\n", "if torch.cuda.is_available():\n", "    torch.backends.cudnn.deterministic = True" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Settings and Dataset" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "##########################\n", "### SETTINGS\n", "##########################\n", "\n", "# Device\n", "DEVICE = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", "\n", "# Hyperparameters\n", "RANDOM_SEED = 1\n", "LEARNING_RATE = 0.05\n", "NUM_EPOCHS = 20\n", "BATCH_SIZE = 128\n", "IMG_SIZE = 28\n", "\n", "# Architecture\n", "NUM_CLASSES = 10" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## MNIST Dataset" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Image batch dimensions: torch.Size([128, 1, 28, 28])\n", "Image label dimensions: torch.Size([128])\n" ] } ], "source": [ "train_indices = torch.arange(0, 
59000)\n", "valid_indices = torch.arange(59000, 60000)\n", "\n", "custom_transform = transforms.Compose([transforms.ToTensor()])\n", "\n", "\n", "train_and_valid = datasets.MNIST(root='data', \n", "                                 train=True, \n", "                                 transform=custom_transform,\n", "                                 download=True)\n", "\n", "test_dataset = datasets.MNIST(root='data', \n", "                              train=False, \n", "                              transform=custom_transform,\n", "                              download=True)\n", "\n", "train_dataset = Subset(train_and_valid, train_indices)\n", "valid_dataset = Subset(train_and_valid, valid_indices)\n", "\n", "train_loader = DataLoader(dataset=train_dataset, \n", "                          batch_size=BATCH_SIZE,\n", "                          num_workers=4,\n", "                          shuffle=True)\n", "\n", "valid_loader = DataLoader(dataset=valid_dataset, \n", "                          batch_size=BATCH_SIZE,\n", "                          num_workers=4,\n", "                          shuffle=False)\n", "\n", "test_loader = DataLoader(dataset=test_dataset, \n", "                         batch_size=BATCH_SIZE,\n", "                         num_workers=4,\n", "                         shuffle=False)\n", "\n", "# Checking the dataset\n", "for images, labels in train_loader:  \n", "    print('Image batch dimensions:', images.shape)\n", "    print('Image label dimensions:', labels.shape)\n", "    break" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "def precompute_adjacency_matrix(img_size):\n", "    col, row = np.meshgrid(np.arange(img_size), np.arange(img_size))\n", "    \n", "    # N = img_size^2\n", "    # construct 2D coordinate array (shape N x 2) and scale the\n", "    # coordinates to the range [0, 1)\n", "    coord = np.stack((col, row), axis=2).reshape(-1, 2) / img_size\n", "\n", "    # compute pairwise distance matrix (N x N)\n", "    dist = cdist(coord, coord, metric='euclidean')\n", "    \n", "    # Apply Gaussian filter; note that the kernel exp(-d / sigma^2) is\n", "    # applied to the unsquared Euclidean distance d\n", "    sigma = 0.05 * np.pi\n", "    A = np.exp(- dist / sigma ** 2)\n", "    # drop weak connections so that distant pixels are not connected\n", "    A[A < 0.01] = 0\n", "    A = torch.from_numpy(A).float()\n", "\n", "    # Symmetric normalization D^-1/2 A D^-1/2 as per (Kipf & Welling, ICLR 2017);\n", "    # A already contains self-loops because exp(0) = 1 on the diagonal\n", "    D = A.sum(1)  # node degrees (N,)\n", "    D_hat = (D + 1e-5) ** (-0.5)\n", "    A_hat = 
D_hat.view(-1, 1) * A * D_hat.view(1, -1)  # (N, N)\n", "    \n", "    return A_hat" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.imshow(precompute_adjacency_matrix(28));" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "##########################\n", "### MODEL\n", "##########################\n", "\n", "\n", "class GraphNet(nn.Module):\n", "    def __init__(self, img_size=28, num_classes=10):\n", "        super(GraphNet, self).__init__()\n", "        \n", "        # output layer: maps the N = img_size**2 aggregated node features to class logits\n", "        n_rows = img_size**2\n", "        self.fc = nn.Linear(n_rows, num_classes, bias=False)\n", "\n", "        # fixed (non-trainable) adjacency matrix; registered as a buffer so that\n", "        # it is moved to the device together with the model\n", "        A = precompute_adjacency_matrix(img_size)\n", "        self.register_buffer('A', A)\n", "\n", "    def forward(self, x):\n", "        \n", "        B = x.size(0)  # Batch size\n", "\n", "        ### Reshape Adjacency Matrix\n", "        # [N, N] Adj. 
matrix -> [1, N, N] Adj tensor where N = HxW\n", "        A_tensor = self.A.unsqueeze(0)\n", "        # [1, N, N] Adj tensor -> [B, N, N] tensor (expand returns a view, no copy)\n", "        A_tensor = A_tensor.expand(B, -1, -1)\n", "        \n", "        ### Reshape inputs\n", "        # [B, C, H, W] => [B, H*W, 1] (valid because C = 1 for MNIST)\n", "        x_reshape = x.view(B, -1, 1)\n", "        \n", "        # bmm = batch matrix product; computes the weighted sum of each node's neighbor features\n", "        # Input: [B, N, N] x [B, N, 1]\n", "        # Output: [B, N, 1], reshaped to [B, N]\n", "        avg_neighbor_features = torch.bmm(A_tensor, x_reshape).view(B, -1)\n", "        \n", "        logits = self.fc(avg_neighbor_features)\n", "        probas = F.softmax(logits, dim=1)\n", "        return logits, probas" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "torch.manual_seed(RANDOM_SEED)\n", "model = GraphNet(img_size=IMG_SIZE, num_classes=NUM_CLASSES)\n", "\n", "model = model.to(DEVICE)\n", "\n", "optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 001/020 | Batch 000/461 | Cost: 2.2677\n", "Epoch: 001/020 | Batch 150/461 | Cost: 0.8999\n", "Epoch: 001/020 | Batch 300/461 | Cost: 0.6701\n", "Epoch: 001/020 | Batch 450/461 | Cost: 0.4905\n", "Epoch: 001/020\n", "Train ACC: 87.02 | Validation ACC: 92.40\n", "Time elapsed: 0.08 min\n", "Epoch: 002/020 | Batch 000/461 | Cost: 0.5868\n", "Epoch: 002/020 | Batch 150/461 | Cost: 0.4526\n", "Epoch: 002/020 | Batch 300/461 | Cost: 0.4192\n", "Epoch: 002/020 | Batch 450/461 | Cost: 0.3647\n", "Epoch: 002/020\n", "Train ACC: 88.25 | Validation ACC: 92.60\n", "Time elapsed: 0.14 min\n", "Epoch: 003/020 | Batch 000/461 | Cost: 0.4316\n", "Epoch: 003/020 | Batch 150/461 | Cost: 0.4165\n", "Epoch: 003/020 | Batch 300/461 | Cost: 0.4130\n", "Epoch: 003/020 | Batch 450/461 | Cost: 0.3991\n", "Epoch: 003/020\n", "Train ACC: 88.80 | Validation ACC: 93.30\n", "Time elapsed: 
0.22 min\n", "Epoch: 004/020 | Batch 000/461 | Cost: 0.3537\n", "Epoch: 004/020 | Batch 150/461 | Cost: 0.3460\n", "Epoch: 004/020 | Batch 300/461 | Cost: 0.4011\n", "Epoch: 004/020 | Batch 450/461 | Cost: 0.4666\n", "Epoch: 004/020\n", "Train ACC: 89.34 | Validation ACC: 93.40\n", "Time elapsed: 0.29 min\n", "Epoch: 005/020 | Batch 000/461 | Cost: 0.4523\n", "Epoch: 005/020 | Batch 150/461 | Cost: 0.4006\n", "Epoch: 005/020 | Batch 300/461 | Cost: 0.4396\n", "Epoch: 005/020 | Batch 450/461 | Cost: 0.4509\n", "Epoch: 005/020\n", "Train ACC: 89.65 | Validation ACC: 93.40\n", "Time elapsed: 0.36 min\n", "Epoch: 006/020 | Batch 000/461 | Cost: 0.3381\n", "Epoch: 006/020 | Batch 150/461 | Cost: 0.3627\n", "Epoch: 006/020 | Batch 300/461 | Cost: 0.2736\n", "Epoch: 006/020 | Batch 450/461 | Cost: 0.3932\n", "Epoch: 006/020\n", "Train ACC: 89.85 | Validation ACC: 93.50\n", "Time elapsed: 0.42 min\n", "Epoch: 007/020 | Batch 000/461 | Cost: 0.4984\n", "Epoch: 007/020 | Batch 150/461 | Cost: 0.3718\n", "Epoch: 007/020 | Batch 300/461 | Cost: 0.2903\n", "Epoch: 007/020 | Batch 450/461 | Cost: 0.4040\n", "Epoch: 007/020\n", "Train ACC: 90.02 | Validation ACC: 93.50\n", "Time elapsed: 0.50 min\n", "Epoch: 008/020 | Batch 000/461 | Cost: 0.5250\n", "Epoch: 008/020 | Batch 150/461 | Cost: 0.3481\n", "Epoch: 008/020 | Batch 300/461 | Cost: 0.3838\n", "Epoch: 008/020 | Batch 450/461 | Cost: 0.4789\n", "Epoch: 008/020\n", "Train ACC: 90.14 | Validation ACC: 93.90\n", "Time elapsed: 0.57 min\n", "Epoch: 009/020 | Batch 000/461 | Cost: 0.3028\n", "Epoch: 009/020 | Batch 150/461 | Cost: 0.3982\n", "Epoch: 009/020 | Batch 300/461 | Cost: 0.4042\n", "Epoch: 009/020 | Batch 450/461 | Cost: 0.5471\n", "Epoch: 009/020\n", "Train ACC: 90.26 | Validation ACC: 93.90\n", "Time elapsed: 0.64 min\n", "Epoch: 010/020 | Batch 000/461 | Cost: 0.2279\n", "Epoch: 010/020 | Batch 150/461 | Cost: 0.2992\n", "Epoch: 010/020 | Batch 300/461 | Cost: 0.4507\n", "Epoch: 010/020 | Batch 450/461 | Cost: 
0.2165\n", "Epoch: 010/020\n", "Train ACC: 90.40 | Validation ACC: 93.90\n", "Time elapsed: 0.71 min\n", "Epoch: 011/020 | Batch 000/461 | Cost: 0.5089\n", "Epoch: 011/020 | Batch 150/461 | Cost: 0.2480\n", "Epoch: 011/020 | Batch 300/461 | Cost: 0.3782\n", "Epoch: 011/020 | Batch 450/461 | Cost: 0.3228\n", "Epoch: 011/020\n", "Train ACC: 90.47 | Validation ACC: 93.40\n", "Time elapsed: 0.78 min\n", "Epoch: 012/020 | Batch 000/461 | Cost: 0.2597\n", "Epoch: 012/020 | Batch 150/461 | Cost: 0.2955\n", "Epoch: 012/020 | Batch 300/461 | Cost: 0.2243\n", "Epoch: 012/020 | Batch 450/461 | Cost: 0.2967\n", "Epoch: 012/020\n", "Train ACC: 90.58 | Validation ACC: 93.60\n", "Time elapsed: 0.85 min\n", "Epoch: 013/020 | Batch 000/461 | Cost: 0.3367\n", "Epoch: 013/020 | Batch 150/461 | Cost: 0.3696\n", "Epoch: 013/020 | Batch 300/461 | Cost: 0.2744\n", "Epoch: 013/020 | Batch 450/461 | Cost: 0.4097\n", "Epoch: 013/020\n", "Train ACC: 90.65 | Validation ACC: 93.80\n", "Time elapsed: 0.92 min\n", "Epoch: 014/020 | Batch 000/461 | Cost: 0.2629\n", "Epoch: 014/020 | Batch 150/461 | Cost: 0.3282\n", "Epoch: 014/020 | Batch 300/461 | Cost: 0.2407\n", "Epoch: 014/020 | Batch 450/461 | Cost: 0.2714\n", "Epoch: 014/020\n", "Train ACC: 90.66 | Validation ACC: 93.80\n", "Time elapsed: 0.99 min\n", "Epoch: 015/020 | Batch 000/461 | Cost: 0.2497\n", "Epoch: 015/020 | Batch 150/461 | Cost: 0.3774\n", "Epoch: 015/020 | Batch 300/461 | Cost: 0.3405\n", "Epoch: 015/020 | Batch 450/461 | Cost: 0.4727\n", "Epoch: 015/020\n", "Train ACC: 90.81 | Validation ACC: 93.90\n", "Time elapsed: 1.06 min\n", "Epoch: 016/020 | Batch 000/461 | Cost: 0.4100\n", "Epoch: 016/020 | Batch 150/461 | Cost: 0.3284\n", "Epoch: 016/020 | Batch 300/461 | Cost: 0.3974\n", "Epoch: 016/020 | Batch 450/461 | Cost: 0.2978\n", "Epoch: 016/020\n", "Train ACC: 90.86 | Validation ACC: 93.90\n", "Time elapsed: 1.13 min\n", "Epoch: 017/020 | Batch 000/461 | Cost: 0.2101\n", "Epoch: 017/020 | Batch 150/461 | Cost: 0.3024\n", 
"Epoch: 017/020 | Batch 300/461 | Cost: 0.2714\n", "Epoch: 017/020 | Batch 450/461 | Cost: 0.2259\n", "Epoch: 017/020\n", "Train ACC: 90.91 | Validation ACC: 93.90\n", "Time elapsed: 1.20 min\n", "Epoch: 018/020 | Batch 000/461 | Cost: 0.3154\n", "Epoch: 018/020 | Batch 150/461 | Cost: 0.2534\n", "Epoch: 018/020 | Batch 300/461 | Cost: 0.3008\n", "Epoch: 018/020 | Batch 450/461 | Cost: 0.2815\n", "Epoch: 018/020\n", "Train ACC: 90.98 | Validation ACC: 93.90\n", "Time elapsed: 1.27 min\n", "Epoch: 019/020 | Batch 000/461 | Cost: 0.2850\n", "Epoch: 019/020 | Batch 150/461 | Cost: 0.2086\n", "Epoch: 019/020 | Batch 300/461 | Cost: 0.4104\n", "Epoch: 019/020 | Batch 450/461 | Cost: 0.2749\n", "Epoch: 019/020\n", "Train ACC: 90.94 | Validation ACC: 94.00\n", "Time elapsed: 1.35 min\n", "Epoch: 020/020 | Batch 000/461 | Cost: 0.4211\n", "Epoch: 020/020 | Batch 150/461 | Cost: 0.2129\n", "Epoch: 020/020 | Batch 300/461 | Cost: 0.2256\n", "Epoch: 020/020 | Batch 450/461 | Cost: 0.5096\n", "Epoch: 020/020\n", "Train ACC: 91.02 | Validation ACC: 94.30\n", "Time elapsed: 1.42 min\n", "Total Training Time: 1.42 min\n" ] } ], "source": [ "def compute_acc(model, data_loader, device):\n", "    correct_pred, num_examples = 0, 0\n", "    for features, targets in data_loader:\n", "        features = features.to(device)\n", "        targets = targets.to(device)\n", "        logits, probas = model(features)\n", "        _, predicted_labels = torch.max(probas, 1)\n", "        num_examples += targets.size(0)\n", "        correct_pred += (predicted_labels == targets).sum()\n", "    # return a Python float so that the accuracy lists can be plotted\n", "    # even if the accuracies were computed on the GPU\n", "    return (correct_pred.float() / num_examples * 100).item()\n", "    \n", "\n", "start_time = time.time()\n", "\n", "cost_list = []\n", "train_acc_list, valid_acc_list = [], []\n", "\n", "\n", "for epoch in range(NUM_EPOCHS):\n", "    \n", "    model.train()\n", "    for batch_idx, (features, targets) in enumerate(train_loader):\n", "        \n", "        features = features.to(DEVICE)\n", "        targets = targets.to(DEVICE)\n", "        \n", "        ### FORWARD AND BACK PROP\n", "        logits, probas = 
model(features)\n", "        cost = F.cross_entropy(logits, targets)\n", "        optimizer.zero_grad()\n", "        \n", "        cost.backward()\n", "        \n", "        ### UPDATE MODEL PARAMETERS\n", "        optimizer.step()\n", "        \n", "        #################################################\n", "        ### CODE ONLY FOR LOGGING BEYOND THIS POINT\n", "        ################################################\n", "        cost_list.append(cost.item())\n", "        if not batch_idx % 150:\n", "            print(f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d} | '\n", "                  f'Batch {batch_idx:03d}/{len(train_loader):03d} |' \n", "                  f' Cost: {cost:.4f}')\n", "\n", "    model.eval()\n", "    with torch.set_grad_enabled(False):  # save memory during inference\n", "        \n", "        train_acc = compute_acc(model, train_loader, device=DEVICE)\n", "        valid_acc = compute_acc(model, valid_loader, device=DEVICE)\n", "        \n", "        print(f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d}\\n'\n", "              f'Train ACC: {train_acc:.2f} | Validation ACC: {valid_acc:.2f}')\n", "        \n", "        train_acc_list.append(train_acc)\n", "        valid_acc_list.append(valid_acc)\n", "        \n", "    elapsed = (time.time() - start_time)/60\n", "    print(f'Time elapsed: {elapsed:.2f} min')\n", "    \n", "elapsed = (time.time() - start_time)/60\n", "print(f'Total Training Time: {elapsed:.2f} min')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluation" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.plot(cost_list, label='Minibatch cost')\n", "# running average of the minibatch cost over a window of 200 iterations\n", "plt.plot(np.convolve(cost_list, \n", "                     np.ones(200,)/200, mode='valid'), \n", "         label='Running average')\n", "\n", "plt.ylabel('Cross Entropy')\n", "plt.xlabel('Iteration')\n", "plt.legend()\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, 
"metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.plot(np.arange(1, NUM_EPOCHS+1), train_acc_list, label='Training')\n", "plt.plot(np.arange(1, NUM_EPOCHS+1), valid_acc_list, label='Validation')\n", "\n", "plt.xlabel('Epoch')\n", "plt.ylabel('Accuracy (%)')\n", "plt.legend()\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Validation ACC: 94.30%\n", "Test ACC: 91.63%\n" ] } ], "source": [ "with torch.set_grad_enabled(False):\n", "    test_acc = compute_acc(model=model,\n", "                           data_loader=test_loader,\n", "                           device=DEVICE)\n", "    \n", "    valid_acc = compute_acc(model=model,\n", "                            data_loader=valid_loader,\n", "                            device=DEVICE)\n", "    \n", "\n", "print(f'Validation ACC: {valid_acc:.2f}%')\n", "print(f'Test ACC: {test_acc:.2f}%')" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torchvision 0.4.0a0+6b959ee\n", "torch 1.2.0\n", "numpy 1.16.4\n", "matplotlib 3.1.0\n", "\n" ] } ], "source": [ "%watermark -iv" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" }, "toc": { "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 4 }