{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n", "- Author: Sebastian Raschka\n", "- GitHub Repository: https://github.com/rasbt/deeplearning-models" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sebastian Raschka \n", "\n", "CPython 3.7.3\n", "IPython 7.6.1\n", "\n", "torch 1.2.0\n" ] } ], "source": [ "%load_ext watermark\n", "%watermark -a 'Sebastian Raschka' -v -p torch" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Runs on CPU or GPU (if available)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Gradient Clipping" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Certain types of deep neural networks, especially, simple ones without any other type regularization and a relatively large number of layers, can suffer from exploding gradient problems. The exploding gradient problem is a scenario where large loss gradients accumulate during backpropagation, which will eventually result in very large weight updates during training. As a consequence, the updates will be very unstable and fluctuate a lot, which often causes severe problems during training. This is also a particular problem for unbounded activation functions such as ReLU.\n", "\n", "One common, classic technique for avoiding exploding gradient problems is the so-called gradient clipping approach. Here, we simply set gradient values above or below a certain threshold to a user-specified min or max value. In PyTorch, there are several ways for performing gradient clipping. \n", "\n", "**1 - Basic Clipping**\n", "\n", "The simplest approach to gradient clipping in PyTorch is by using the [`torch.nn.utils.clip_grad_value_`](https://pytorch.org/docs/stable/nn.html?highlight=clip#torch.nn.utils.clip_grad_value_) function. For example, if we have instantiated a PyTorch model from a model class based on `torch.nn.Module` (as usual), we can add the following line of code in order to clip the gradients to [-1, 1] range:\n", "\n", "```python\n", "torch.nn.utils.clip_grad_value_(parameters=model.parameters(), \n", " clip_value=1.)\n", "\n", "```\n", "\n", "However, notice that via this approach, we can only specify a single clip value, which will be used for both the upper and lower bound such that gradients will be clipped to the range [-`clip_value`, `clip_value`].\n", "\n", "\n", "**2 - Custom Lower and Upper Bounds**\n", "\n", "If we want to clip the gradients to an unsymmetric interval around zero, say [-0.1, 1.0], we can take a different approach by defining a backwards hook:\n", "\n", "```python\n", "for param in model.parameters():\n", " param.register_hook(lambda gradient: torch.clamp(gradient, -0.1, 1.0))\n", "```\n", "\n", "This backward hook only needs to be defined once after instantiating the model. Then, each time after calling the `backward` method, it will clip the gradients before running the `model.step()` method.\n", "\n", "**3 - Norm-clipping**\n", "\n", "Lastly, there's a third clipping option, [`torch.nn.utils.clip_grad_norm_`](https://pytorch.org/docs/stable/nn.html?highlight=clip#torch.nn.utils.clip_grad_norm_), which clips the gradients using a vector norm as follows:\n", "\n", "\n", "> `torch.nn.utils.clip_grad_norm_(parameters, max_norm, norm_type=2)`\n", "\n", ">Clips gradient norm of an iterable of parameters. 
The norm is computed over all gradients together, as if they were concatenated into a single vector. Gradients are modified in-place.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import time\n", "import numpy as np\n", "from torchvision import datasets\n", "from torchvision import transforms\n", "from torch.utils.data import DataLoader\n", "import torch.nn.functional as F\n", "import torch\n", "\n", "\n", "if torch.cuda.is_available():\n", " torch.backends.cudnn.deterministic = True" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Settings and Dataset" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Image batch dimensions: torch.Size([64, 1, 28, 28])\n", "Image label dimensions: torch.Size([64])\n" ] } ], "source": [ "##########################\n", "### SETTINGS\n", "##########################\n", "\n", "# Device\n", "device = torch.device(\"cuda:2\" if torch.cuda.is_available() else \"cpu\")\n", "\n", "# Hyperparameters\n", "random_seed = 1\n", "learning_rate = 0.01\n", "num_epochs = 10\n", "batch_size = 64\n", "\n", "# Architecture\n", "num_features = 784\n", "num_hidden_1 = 256\n", "num_hidden_2 = 128\n", "num_hidden_3 = 64\n", "num_hidden_4 = 32\n", "num_classes = 10\n", "\n", "\n", "##########################\n", "### MNIST DATASET\n", "##########################\n", "\n", "# Note transforms.ToTensor() scales input images\n", "# to 0-1 range\n", "train_dataset = datasets.MNIST(root='data', \n", " train=True, \n", " transform=transforms.ToTensor(),\n", " download=True)\n", "\n", "test_dataset = datasets.MNIST(root='data', \n", " train=False, \n", " transform=transforms.ToTensor())\n", "\n", "\n", "train_loader = DataLoader(dataset=train_dataset, \n", " batch_size=batch_size, \n", " shuffle=True)\n", "\n", "test_loader = DataLoader(dataset=test_dataset, \n", " batch_size=batch_size, \n", " shuffle=False)\n", "\n", "# Checking the dataset\n", "for images, labels in train_loader: \n", " print('Image batch dimensions:', images.shape)\n", " print('Image label dimensions:', labels.shape)\n", " break" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "def compute_accuracy(net, data_loader):\n", " net.eval()\n", " correct_pred, num_examples = 0, 0\n", " with torch.no_grad():\n", " for features, targets in data_loader:\n", " features = features.view(-1, 28*28).to(device)\n", " targets = targets.to(device)\n", " logits, probas = net(features)\n", " _, predicted_labels = torch.max(probas, 1)\n", " num_examples += targets.size(0)\n", " correct_pred += (predicted_labels == targets).sum()\n", " return correct_pred.float()/num_examples * 100\n", " " ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "##########################\n", "### MODEL\n", "##########################\n", "\n", "class MultilayerPerceptron(torch.nn.Module):\n", "\n", " def __init__(self, num_features, num_classes):\n", " super(MultilayerPerceptron, self).__init__()\n", " \n", " ### 1st hidden layer\n", " self.linear_1 = torch.nn.Linear(num_features, num_hidden_1)\n", "\n", " ### 2nd hidden layer\n", " self.linear_2 = torch.nn.Linear(num_hidden_1, num_hidden_2)\n", "\n", " ### 3rd hidden layer\n", " self.linear_3 = torch.nn.Linear(num_hidden_2, num_hidden_3)\n", " \n", " ### 4th hidden layer\n", " self.linear_4 = 
torch.nn.Linear(num_hidden_3, num_hidden_4)\n", " \n", " \n", " ### Output layer\n", " self.linear_out = torch.nn.Linear(num_hidden_4, num_classes)\n", "\n", " \n", " def forward(self, x):\n", " out = self.linear_1(x)\n", " out = F.relu(out)\n", " out = self.linear_2(out)\n", " out = F.relu(out)\n", " out = self.linear_3(out)\n", " out = F.relu(out)\n", " out = self.linear_4(out)\n", " out = F.relu(out)\n", " logits = self.linear_out(out)\n", " probas = F.log_softmax(logits, dim=1)\n", " return logits, probas" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1 - Basic Clipping" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 001/010 | Batch 000/938 | Cost: 2.3054\n", "Epoch: 001/010 | Batch 050/938 | Cost: 0.6427\n", "Epoch: 001/010 | Batch 100/938 | Cost: 0.3220\n", "Epoch: 001/010 | Batch 150/938 | Cost: 0.3492\n", "Epoch: 001/010 | Batch 200/938 | Cost: 0.4505\n", "Epoch: 001/010 | Batch 250/938 | Cost: 0.1510\n", "Epoch: 001/010 | Batch 300/938 | Cost: 0.2062\n", "Epoch: 001/010 | Batch 350/938 | Cost: 0.1287\n", "Epoch: 001/010 | Batch 400/938 | Cost: 0.1714\n", "Epoch: 001/010 | Batch 450/938 | Cost: 0.3522\n", "Epoch: 001/010 | Batch 500/938 | Cost: 0.4268\n", "Epoch: 001/010 | Batch 550/938 | Cost: 0.0133\n", "Epoch: 001/010 | Batch 600/938 | Cost: 0.1868\n", "Epoch: 001/010 | Batch 650/938 | Cost: 0.2312\n", "Epoch: 001/010 | Batch 700/938 | Cost: 0.1471\n", "Epoch: 001/010 | Batch 750/938 | Cost: 0.1321\n", "Epoch: 001/010 | Batch 800/938 | Cost: 0.2776\n", "Epoch: 001/010 | Batch 850/938 | Cost: 0.2223\n", "Epoch: 001/010 | Batch 900/938 | Cost: 0.1812\n", "Epoch: 001/010 training accuracy: 94.72%\n", "Time elapsed: 0.25 min\n", "Epoch: 002/010 | Batch 000/938 | Cost: 0.2080\n", "Epoch: 002/010 | Batch 050/938 | Cost: 0.2177\n", "Epoch: 002/010 | Batch 100/938 | Cost: 0.1090\n", "Epoch: 002/010 | Batch 150/938 | Cost: 0.1225\n", "Epoch: 002/010 | Batch 200/938 | Cost: 0.2514\n", "Epoch: 002/010 | Batch 250/938 | Cost: 0.1093\n", "Epoch: 002/010 | Batch 300/938 | Cost: 0.0626\n", "Epoch: 002/010 | Batch 350/938 | Cost: 0.1242\n", "Epoch: 002/010 | Batch 400/938 | Cost: 0.0168\n", "Epoch: 002/010 | Batch 450/938 | Cost: 0.2678\n", "Epoch: 002/010 | Batch 500/938 | Cost: 0.1761\n", "Epoch: 002/010 | Batch 550/938 | Cost: 0.2607\n", "Epoch: 002/010 | Batch 600/938 | Cost: 0.1324\n", "Epoch: 002/010 | Batch 650/938 | Cost: 0.2334\n", "Epoch: 002/010 | Batch 700/938 | Cost: 0.1510\n", "Epoch: 002/010 | Batch 750/938 | Cost: 0.1456\n", "Epoch: 002/010 | Batch 800/938 | Cost: 0.2882\n", "Epoch: 002/010 | Batch 850/938 | Cost: 0.1485\n", "Epoch: 002/010 | Batch 900/938 | Cost: 0.2007\n", "Epoch: 002/010 training accuracy: 96.83%\n", "Time elapsed: 0.49 min\n", "Epoch: 003/010 | Batch 000/938 | Cost: 0.0550\n", "Epoch: 003/010 | Batch 050/938 | Cost: 0.0555\n", "Epoch: 003/010 | Batch 100/938 | Cost: 0.1040\n", "Epoch: 003/010 | Batch 150/938 | Cost: 0.2290\n", "Epoch: 003/010 | Batch 200/938 | Cost: 0.0506\n", "Epoch: 003/010 | Batch 250/938 | Cost: 0.1028\n", "Epoch: 003/010 | Batch 300/938 | Cost: 0.0381\n", "Epoch: 003/010 | Batch 350/938 | Cost: 0.1593\n", "Epoch: 003/010 | Batch 400/938 | Cost: 0.0637\n", "Epoch: 003/010 | Batch 450/938 | Cost: 0.0127\n", "Epoch: 003/010 | Batch 500/938 | Cost: 0.4391\n", "Epoch: 003/010 | Batch 550/938 | Cost: 0.0110\n", "Epoch: 003/010 | Batch 600/938 | Cost: 0.1959\n", "Epoch: 003/010 | Batch 650/938 | Cost: 0.1020\n", "Epoch: 003/010 | 
Batch 700/938 | Cost: 0.0206\n", "Epoch: 003/010 | Batch 750/938 | Cost: 0.2747\n", "Epoch: 003/010 | Batch 800/938 | Cost: 0.1192\n", "Epoch: 003/010 | Batch 850/938 | Cost: 0.0115\n", "Epoch: 003/010 | Batch 900/938 | Cost: 0.2476\n", "Epoch: 003/010 training accuracy: 97.65%\n", "Time elapsed: 0.74 min\n", "Epoch: 004/010 | Batch 000/938 | Cost: 0.0875\n", "Epoch: 004/010 | Batch 050/938 | Cost: 0.0335\n", "Epoch: 004/010 | Batch 100/938 | Cost: 0.0530\n", "Epoch: 004/010 | Batch 150/938 | Cost: 0.4291\n", "Epoch: 004/010 | Batch 200/938 | Cost: 0.0634\n", "Epoch: 004/010 | Batch 250/938 | Cost: 0.0437\n", "Epoch: 004/010 | Batch 300/938 | Cost: 0.0547\n", "Epoch: 004/010 | Batch 350/938 | Cost: 0.1602\n", "Epoch: 004/010 | Batch 400/938 | Cost: 0.1071\n", "Epoch: 004/010 | Batch 450/938 | Cost: 0.0351\n", "Epoch: 004/010 | Batch 500/938 | Cost: 0.0712\n", "Epoch: 004/010 | Batch 550/938 | Cost: 0.1261\n", "Epoch: 004/010 | Batch 600/938 | Cost: 0.1212\n", "Epoch: 004/010 | Batch 650/938 | Cost: 0.0802\n", "Epoch: 004/010 | Batch 700/938 | Cost: 0.0844\n", "Epoch: 004/010 | Batch 750/938 | Cost: 0.1496\n", "Epoch: 004/010 | Batch 800/938 | Cost: 0.1543\n", "Epoch: 004/010 | Batch 850/938 | Cost: 0.0182\n", "Epoch: 004/010 | Batch 900/938 | Cost: 0.0433\n", "Epoch: 004/010 training accuracy: 97.08%\n", "Time elapsed: 0.98 min\n", "Epoch: 005/010 | Batch 000/938 | Cost: 0.1570\n", "Epoch: 005/010 | Batch 050/938 | Cost: 0.0291\n", "Epoch: 005/010 | Batch 100/938 | Cost: 0.0363\n", "Epoch: 005/010 | Batch 150/938 | Cost: 0.0320\n", "Epoch: 005/010 | Batch 200/938 | Cost: 0.0322\n", "Epoch: 005/010 | Batch 250/938 | Cost: 0.0720\n", "Epoch: 005/010 | Batch 300/938 | Cost: 0.0497\n", "Epoch: 005/010 | Batch 350/938 | Cost: 0.1058\n", "Epoch: 005/010 | Batch 400/938 | Cost: 0.2139\n", "Epoch: 005/010 | Batch 450/938 | Cost: 0.0602\n", "Epoch: 005/010 | Batch 500/938 | Cost: 0.0689\n", "Epoch: 005/010 | Batch 550/938 | Cost: 0.1355\n", "Epoch: 005/010 | Batch 600/938 | Cost: 0.1659\n", "Epoch: 005/010 | Batch 650/938 | Cost: 0.1504\n", "Epoch: 005/010 | Batch 700/938 | Cost: 0.0403\n", "Epoch: 005/010 | Batch 750/938 | Cost: 0.3422\n", "Epoch: 005/010 | Batch 800/938 | Cost: 0.3299\n", "Epoch: 005/010 | Batch 850/938 | Cost: 0.2327\n", "Epoch: 005/010 | Batch 900/938 | Cost: 0.0171\n", "Epoch: 005/010 training accuracy: 97.51%\n", "Time elapsed: 1.23 min\n", "Epoch: 006/010 | Batch 000/938 | Cost: 0.0548\n", "Epoch: 006/010 | Batch 050/938 | Cost: 0.2781\n", "Epoch: 006/010 | Batch 100/938 | Cost: 0.0657\n", "Epoch: 006/010 | Batch 150/938 | Cost: 0.0444\n", "Epoch: 006/010 | Batch 200/938 | Cost: 0.0057\n", "Epoch: 006/010 | Batch 250/938 | Cost: 0.1058\n", "Epoch: 006/010 | Batch 300/938 | Cost: 0.1610\n", "Epoch: 006/010 | Batch 350/938 | Cost: 0.0353\n", "Epoch: 006/010 | Batch 400/938 | Cost: 0.2474\n", "Epoch: 006/010 | Batch 450/938 | Cost: 0.1038\n", "Epoch: 006/010 | Batch 500/938 | Cost: 0.2918\n", "Epoch: 006/010 | Batch 550/938 | Cost: 0.1360\n", "Epoch: 006/010 | Batch 600/938 | Cost: 0.1977\n", "Epoch: 006/010 | Batch 650/938 | Cost: 0.0314\n", "Epoch: 006/010 | Batch 700/938 | Cost: 0.0968\n", "Epoch: 006/010 | Batch 750/938 | Cost: 0.2215\n", "Epoch: 006/010 | Batch 800/938 | Cost: 0.0328\n", "Epoch: 006/010 | Batch 850/938 | Cost: 0.2423\n", "Epoch: 006/010 | Batch 900/938 | Cost: 0.1192\n", "Epoch: 006/010 training accuracy: 97.47%\n", "Time elapsed: 1.48 min\n", "Epoch: 007/010 | Batch 000/938 | Cost: 0.0126\n", "Epoch: 007/010 | Batch 050/938 | Cost: 0.0735\n", "Epoch: 
007/010 | Batch 100/938 | Cost: 0.2426\n", "Epoch: 007/010 | Batch 150/938 | Cost: 0.0736\n", "Epoch: 007/010 | Batch 200/938 | Cost: 0.1387\n", "Epoch: 007/010 | Batch 250/938 | Cost: 0.2173\n", "Epoch: 007/010 | Batch 300/938 | Cost: 0.0127\n", "Epoch: 007/010 | Batch 350/938 | Cost: 0.1131\n", "Epoch: 007/010 | Batch 400/938 | Cost: 0.2219\n", "Epoch: 007/010 | Batch 450/938 | Cost: 0.0127\n", "Epoch: 007/010 | Batch 500/938 | Cost: 0.0905\n", "Epoch: 007/010 | Batch 550/938 | Cost: 0.2466\n", "Epoch: 007/010 | Batch 600/938 | Cost: 0.0065\n", "Epoch: 007/010 | Batch 650/938 | Cost: 0.1477\n", "Epoch: 007/010 | Batch 700/938 | Cost: 0.0183\n", "Epoch: 007/010 | Batch 750/938 | Cost: 0.0534\n", "Epoch: 007/010 | Batch 800/938 | Cost: 0.1139\n", "Epoch: 007/010 | Batch 850/938 | Cost: 0.1177\n", "Epoch: 007/010 | Batch 900/938 | Cost: 0.0662\n", "Epoch: 007/010 training accuracy: 97.74%\n", "Time elapsed: 1.72 min\n", "Epoch: 008/010 | Batch 000/938 | Cost: 0.0276\n", "Epoch: 008/010 | Batch 050/938 | Cost: 0.1275\n", "Epoch: 008/010 | Batch 100/938 | Cost: 0.2151\n", "Epoch: 008/010 | Batch 150/938 | Cost: 0.0204\n", "Epoch: 008/010 | Batch 200/938 | Cost: 0.2154\n", "Epoch: 008/010 | Batch 250/938 | Cost: 0.0271\n", "Epoch: 008/010 | Batch 300/938 | Cost: 0.0523\n", "Epoch: 008/010 | Batch 350/938 | Cost: 0.1604\n", "Epoch: 008/010 | Batch 400/938 | Cost: 0.0888\n", "Epoch: 008/010 | Batch 450/938 | Cost: 0.0045\n", "Epoch: 008/010 | Batch 500/938 | Cost: 0.0288\n", "Epoch: 008/010 | Batch 550/938 | Cost: 0.1140\n", "Epoch: 008/010 | Batch 600/938 | Cost: 0.0849\n", "Epoch: 008/010 | Batch 650/938 | Cost: 0.0216\n", "Epoch: 008/010 | Batch 700/938 | Cost: 0.0294\n", "Epoch: 008/010 | Batch 750/938 | Cost: 0.0995\n", "Epoch: 008/010 | Batch 800/938 | Cost: 0.1159\n", "Epoch: 008/010 | Batch 850/938 | Cost: 0.1599\n", "Epoch: 008/010 | Batch 900/938 | Cost: 0.1317\n", "Epoch: 008/010 training accuracy: 98.29%\n", "Time elapsed: 1.97 min\n", "Epoch: 009/010 | Batch 000/938 | Cost: 0.1071\n", "Epoch: 009/010 | Batch 050/938 | Cost: 0.0580\n", "Epoch: 009/010 | Batch 100/938 | Cost: 0.1777\n", "Epoch: 009/010 | Batch 150/938 | Cost: 0.2850\n", "Epoch: 009/010 | Batch 200/938 | Cost: 0.1229\n", "Epoch: 009/010 | Batch 250/938 | Cost: 0.0672\n", "Epoch: 009/010 | Batch 300/938 | Cost: 0.2009\n", "Epoch: 009/010 | Batch 350/938 | Cost: 0.0110\n", "Epoch: 009/010 | Batch 400/938 | Cost: 0.2604\n", "Epoch: 009/010 | Batch 450/938 | Cost: 0.0801\n", "Epoch: 009/010 | Batch 500/938 | Cost: 0.0092\n", "Epoch: 009/010 | Batch 550/938 | Cost: 0.1360\n", "Epoch: 009/010 | Batch 600/938 | Cost: 0.0664\n", "Epoch: 009/010 | Batch 650/938 | Cost: 0.0886\n", "Epoch: 009/010 | Batch 700/938 | Cost: 0.0630\n", "Epoch: 009/010 | Batch 750/938 | Cost: 0.0784\n", "Epoch: 009/010 | Batch 800/938 | Cost: 0.1736\n", "Epoch: 009/010 | Batch 850/938 | Cost: 0.0855\n", "Epoch: 009/010 | Batch 900/938 | Cost: 0.2815\n", "Epoch: 009/010 training accuracy: 97.74%\n", "Time elapsed: 2.21 min\n", "Epoch: 010/010 | Batch 000/938 | Cost: 0.0024\n", "Epoch: 010/010 | Batch 050/938 | Cost: 0.0497\n", "Epoch: 010/010 | Batch 100/938 | Cost: 0.0888\n", "Epoch: 010/010 | Batch 150/938 | Cost: 0.1719\n", "Epoch: 010/010 | Batch 200/938 | Cost: 0.1729\n", "Epoch: 010/010 | Batch 250/938 | Cost: 0.0543\n", "Epoch: 010/010 | Batch 300/938 | Cost: 0.3770\n", "Epoch: 010/010 | Batch 350/938 | Cost: 0.0270\n", "Epoch: 010/010 | Batch 400/938 | Cost: 0.1400\n", "Epoch: 010/010 | Batch 450/938 | Cost: 0.0526\n", "Epoch: 010/010 | Batch 
500/938 | Cost: 0.1984\n", "Epoch: 010/010 | Batch 550/938 | Cost: 0.1677\n", "Epoch: 010/010 | Batch 600/938 | Cost: 0.0550\n", "Epoch: 010/010 | Batch 650/938 | Cost: 0.0294\n", "Epoch: 010/010 | Batch 700/938 | Cost: 0.0465\n", "Epoch: 010/010 | Batch 750/938 | Cost: 0.1103\n", "Epoch: 010/010 | Batch 800/938 | Cost: 0.0272\n", "Epoch: 010/010 | Batch 850/938 | Cost: 0.1376\n", "Epoch: 010/010 | Batch 900/938 | Cost: 0.0279\n", "Epoch: 010/010 training accuracy: 98.09%\n", "Time elapsed: 2.46 min\n", "Total Training Time: 2.46 min\n" ] } ], "source": [ "torch.manual_seed(random_seed)\n", "model = MultilayerPerceptron(num_features=num_features,\n", " num_classes=num_classes)\n", "\n", "model = model.to(device)\n", "\n", "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) \n", "\n", "###################################################################\n", "\n", "start_time = time.time()\n", "for epoch in range(num_epochs):\n", " model.train()\n", " for batch_idx, (features, targets) in enumerate(train_loader):\n", " \n", " features = features.view(-1, 28*28).to(device)\n", " targets = targets.to(device)\n", " \n", " ### FORWARD AND BACK PROP\n", " logits, probas = model(features)\n", " cost = F.cross_entropy(logits, targets)\n", " optimizer.zero_grad()\n", " \n", " cost.backward()\n", " \n", " ### UPDATE MODEL PARAMETERS\n", " \n", " #########################################################\n", " #########################################################\n", " ### GRADIENT CLIPPING\n", " torch.nn.utils.clip_grad_value_(model.parameters(), 1.)\n", " #########################################################\n", " #########################################################\n", " \n", " optimizer.step()\n", " \n", " ### LOGGING\n", " if not batch_idx % 50:\n", " print ('Epoch: %03d/%03d | Batch %03d/%03d | Cost: %.4f' \n", " %(epoch+1, num_epochs, batch_idx, \n", " len(train_loader), cost))\n", "\n", " with torch.set_grad_enabled(False):\n", " print('Epoch: %03d/%03d training accuracy: %.2f%%' % (\n", " epoch+1, num_epochs, \n", " compute_accuracy(model, train_loader)))\n", " \n", " print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))\n", " \n", "print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Test accuracy: 96.80%\n" ] } ], "source": [ "print('Test accuracy: %.2f%%' % (compute_accuracy(model, test_loader)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2 - Custom Lower and Upper Bounds" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 001/010 | Batch 000/938 | Cost: 2.3054\n", "Epoch: 001/010 | Batch 050/938 | Cost: 0.5977\n", "Epoch: 001/010 | Batch 100/938 | Cost: 0.4369\n", "Epoch: 001/010 | Batch 150/938 | Cost: 0.3053\n", "Epoch: 001/010 | Batch 200/938 | Cost: 0.3661\n", "Epoch: 001/010 | Batch 250/938 | Cost: 0.1908\n", "Epoch: 001/010 | Batch 300/938 | Cost: 0.2845\n", "Epoch: 001/010 | Batch 350/938 | Cost: 0.1928\n", "Epoch: 001/010 | Batch 400/938 | Cost: 0.2715\n", "Epoch: 001/010 | Batch 450/938 | Cost: 0.2338\n", "Epoch: 001/010 | Batch 500/938 | Cost: 0.3923\n", "Epoch: 001/010 | Batch 550/938 | Cost: 0.0973\n", "Epoch: 001/010 | Batch 600/938 | Cost: 0.3142\n", "Epoch: 001/010 | Batch 650/938 | Cost: 0.5024\n", "Epoch: 001/010 | Batch 700/938 | Cost: 0.1549\n", "Epoch: 001/010 | 
Batch 750/938 | Cost: 0.1906\n", "Epoch: 001/010 | Batch 800/938 | Cost: 0.3325\n", "Epoch: 001/010 | Batch 850/938 | Cost: 0.2060\n", "Epoch: 001/010 | Batch 900/938 | Cost: 0.1301\n", "Epoch: 001/010 training accuracy: 94.76%\n", "Time elapsed: 0.24 min\n", "Epoch: 002/010 | Batch 000/938 | Cost: 0.2553\n", "Epoch: 002/010 | Batch 050/938 | Cost: 0.1858\n", "Epoch: 002/010 | Batch 100/938 | Cost: 0.2514\n", "Epoch: 002/010 | Batch 150/938 | Cost: 0.1413\n", "Epoch: 002/010 | Batch 200/938 | Cost: 0.3071\n", "Epoch: 002/010 | Batch 250/938 | Cost: 0.6133\n", "Epoch: 002/010 | Batch 300/938 | Cost: 0.1657\n", "Epoch: 002/010 | Batch 350/938 | Cost: 0.0828\n", "Epoch: 002/010 | Batch 400/938 | Cost: 0.0733\n", "Epoch: 002/010 | Batch 450/938 | Cost: 0.3012\n", "Epoch: 002/010 | Batch 500/938 | Cost: 0.1857\n", "Epoch: 002/010 | Batch 550/938 | Cost: 0.3618\n", "Epoch: 002/010 | Batch 600/938 | Cost: 0.0777\n", "Epoch: 002/010 | Batch 650/938 | Cost: 0.2648\n", "Epoch: 002/010 | Batch 700/938 | Cost: 0.0242\n", "Epoch: 002/010 | Batch 750/938 | Cost: 0.1050\n", "Epoch: 002/010 | Batch 800/938 | Cost: 0.2148\n", "Epoch: 002/010 | Batch 850/938 | Cost: 0.0817\n", "Epoch: 002/010 | Batch 900/938 | Cost: 0.1354\n", "Epoch: 002/010 training accuracy: 97.04%\n", "Time elapsed: 0.49 min\n", "Epoch: 003/010 | Batch 000/938 | Cost: 0.1346\n", "Epoch: 003/010 | Batch 050/938 | Cost: 0.0825\n", "Epoch: 003/010 | Batch 100/938 | Cost: 0.0771\n", "Epoch: 003/010 | Batch 150/938 | Cost: 0.2360\n", "Epoch: 003/010 | Batch 200/938 | Cost: 0.0730\n", "Epoch: 003/010 | Batch 250/938 | Cost: 0.1499\n", "Epoch: 003/010 | Batch 300/938 | Cost: 0.0410\n", "Epoch: 003/010 | Batch 350/938 | Cost: 0.2091\n", "Epoch: 003/010 | Batch 400/938 | Cost: 0.0738\n", "Epoch: 003/010 | Batch 450/938 | Cost: 0.0889\n", "Epoch: 003/010 | Batch 500/938 | Cost: 0.3630\n", "Epoch: 003/010 | Batch 550/938 | Cost: 0.0312\n", "Epoch: 003/010 | Batch 600/938 | Cost: 0.0782\n", "Epoch: 003/010 | Batch 650/938 | Cost: 0.1753\n", "Epoch: 003/010 | Batch 700/938 | Cost: 0.0286\n", "Epoch: 003/010 | Batch 750/938 | Cost: 0.2166\n", "Epoch: 003/010 | Batch 800/938 | Cost: 0.0627\n", "Epoch: 003/010 | Batch 850/938 | Cost: 0.0204\n", "Epoch: 003/010 | Batch 900/938 | Cost: 0.2867\n", "Epoch: 003/010 training accuracy: 96.72%\n", "Time elapsed: 0.73 min\n", "Epoch: 004/010 | Batch 000/938 | Cost: 0.0207\n", "Epoch: 004/010 | Batch 050/938 | Cost: 0.0499\n", "Epoch: 004/010 | Batch 100/938 | Cost: 0.1858\n", "Epoch: 004/010 | Batch 150/938 | Cost: 0.2015\n", "Epoch: 004/010 | Batch 200/938 | Cost: 0.0285\n", "Epoch: 004/010 | Batch 250/938 | Cost: 0.0029\n", "Epoch: 004/010 | Batch 300/938 | Cost: 0.1746\n", "Epoch: 004/010 | Batch 350/938 | Cost: 0.3149\n", "Epoch: 004/010 | Batch 400/938 | Cost: 0.1773\n", "Epoch: 004/010 | Batch 450/938 | Cost: 0.1013\n", "Epoch: 004/010 | Batch 500/938 | Cost: 0.1665\n", "Epoch: 004/010 | Batch 550/938 | Cost: 0.1540\n", "Epoch: 004/010 | Batch 600/938 | Cost: 0.1822\n", "Epoch: 004/010 | Batch 650/938 | Cost: 0.1506\n", "Epoch: 004/010 | Batch 700/938 | Cost: 0.0224\n", "Epoch: 004/010 | Batch 750/938 | Cost: 0.1400\n", "Epoch: 004/010 | Batch 800/938 | Cost: 0.2262\n", "Epoch: 004/010 | Batch 850/938 | Cost: 0.0679\n", "Epoch: 004/010 | Batch 900/938 | Cost: 0.0020\n", "Epoch: 004/010 training accuracy: 97.63%\n", "Time elapsed: 0.98 min\n", "Epoch: 005/010 | Batch 000/938 | Cost: 0.0508\n", "Epoch: 005/010 | Batch 050/938 | Cost: 0.0585\n", "Epoch: 005/010 | Batch 100/938 | Cost: 0.1441\n", "Epoch: 
005/010 | Batch 150/938 | Cost: 0.0862\n", "Epoch: 005/010 | Batch 200/938 | Cost: 0.0284\n", "Epoch: 005/010 | Batch 250/938 | Cost: 0.0977\n", "Epoch: 005/010 | Batch 300/938 | Cost: 0.0565\n", "Epoch: 005/010 | Batch 350/938 | Cost: 0.0272\n", "Epoch: 005/010 | Batch 400/938 | Cost: 0.2603\n", "Epoch: 005/010 | Batch 450/938 | Cost: 0.1202\n", "Epoch: 005/010 | Batch 500/938 | Cost: 0.0612\n", "Epoch: 005/010 | Batch 550/938 | Cost: 0.0833\n", "Epoch: 005/010 | Batch 600/938 | Cost: 0.1666\n", "Epoch: 005/010 | Batch 650/938 | Cost: 0.2642\n", "Epoch: 005/010 | Batch 700/938 | Cost: 0.1884\n", "Epoch: 005/010 | Batch 750/938 | Cost: 0.1608\n", "Epoch: 005/010 | Batch 800/938 | Cost: 0.1029\n", "Epoch: 005/010 | Batch 850/938 | Cost: 0.1178\n", "Epoch: 005/010 | Batch 900/938 | Cost: 0.0709\n", "Epoch: 005/010 training accuracy: 97.58%\n", "Time elapsed: 1.23 min\n", "Epoch: 006/010 | Batch 000/938 | Cost: 0.0642\n", "Epoch: 006/010 | Batch 050/938 | Cost: 0.3518\n", "Epoch: 006/010 | Batch 100/938 | Cost: 0.1134\n", "Epoch: 006/010 | Batch 150/938 | Cost: 0.0821\n", "Epoch: 006/010 | Batch 200/938 | Cost: 0.0645\n", "Epoch: 006/010 | Batch 250/938 | Cost: 0.0486\n", "Epoch: 006/010 | Batch 300/938 | Cost: 0.0972\n", "Epoch: 006/010 | Batch 350/938 | Cost: 0.2861\n", "Epoch: 006/010 | Batch 400/938 | Cost: 0.1126\n", "Epoch: 006/010 | Batch 450/938 | Cost: 0.1479\n", "Epoch: 006/010 | Batch 500/938 | Cost: 0.2181\n", "Epoch: 006/010 | Batch 550/938 | Cost: 0.0674\n", "Epoch: 006/010 | Batch 600/938 | Cost: 0.0705\n", "Epoch: 006/010 | Batch 650/938 | Cost: 0.1032\n", "Epoch: 006/010 | Batch 700/938 | Cost: 0.1529\n", "Epoch: 006/010 | Batch 750/938 | Cost: 0.2484\n", "Epoch: 006/010 | Batch 800/938 | Cost: 0.0432\n", "Epoch: 006/010 | Batch 850/938 | Cost: 0.0821\n", "Epoch: 006/010 | Batch 900/938 | Cost: 0.1152\n", "Epoch: 006/010 training accuracy: 97.09%\n", "Time elapsed: 1.47 min\n", "Epoch: 007/010 | Batch 000/938 | Cost: 0.0418\n", "Epoch: 007/010 | Batch 050/938 | Cost: 0.0527\n", "Epoch: 007/010 | Batch 100/938 | Cost: 0.3778\n", "Epoch: 007/010 | Batch 150/938 | Cost: 0.1742\n", "Epoch: 007/010 | Batch 200/938 | Cost: 0.0725\n", "Epoch: 007/010 | Batch 250/938 | Cost: 0.1187\n", "Epoch: 007/010 | Batch 300/938 | Cost: 0.0980\n", "Epoch: 007/010 | Batch 350/938 | Cost: 0.0077\n", "Epoch: 007/010 | Batch 400/938 | Cost: 0.1274\n", "Epoch: 007/010 | Batch 450/938 | Cost: 0.1387\n", "Epoch: 007/010 | Batch 500/938 | Cost: 0.1959\n", "Epoch: 007/010 | Batch 550/938 | Cost: 0.0874\n", "Epoch: 007/010 | Batch 600/938 | Cost: 0.2559\n", "Epoch: 007/010 | Batch 650/938 | Cost: 0.1413\n", "Epoch: 007/010 | Batch 700/938 | Cost: 0.1285\n", "Epoch: 007/010 | Batch 750/938 | Cost: 0.1931\n", "Epoch: 007/010 | Batch 800/938 | Cost: 0.1151\n", "Epoch: 007/010 | Batch 850/938 | Cost: 0.1889\n", "Epoch: 007/010 | Batch 900/938 | Cost: 0.5518\n", "Epoch: 007/010 training accuracy: 86.62%\n", "Time elapsed: 1.72 min\n", "Epoch: 008/010 | Batch 000/938 | Cost: 0.3283\n", "Epoch: 008/010 | Batch 050/938 | Cost: 0.1818\n", "Epoch: 008/010 | Batch 100/938 | Cost: 0.1827\n", "Epoch: 008/010 | Batch 150/938 | Cost: 0.0844\n", "Epoch: 008/010 | Batch 200/938 | Cost: 0.4017\n", "Epoch: 008/010 | Batch 250/938 | Cost: 0.0129\n", "Epoch: 008/010 | Batch 300/938 | Cost: 0.0155\n", "Epoch: 008/010 | Batch 350/938 | Cost: 0.1844\n", "Epoch: 008/010 | Batch 400/938 | Cost: 0.1146\n", "Epoch: 008/010 | Batch 450/938 | Cost: 0.0566\n", "Epoch: 008/010 | Batch 500/938 | Cost: 0.0895\n", "Epoch: 008/010 | Batch 
550/938 | Cost: 0.1851\n", "Epoch: 008/010 | Batch 600/938 | Cost: 0.1134\n", "Epoch: 008/010 | Batch 650/938 | Cost: 0.0838\n", "Epoch: 008/010 | Batch 700/938 | Cost: 0.1157\n", "Epoch: 008/010 | Batch 750/938 | Cost: 0.2275\n", "Epoch: 008/010 | Batch 800/938 | Cost: 0.5753\n", "Epoch: 008/010 | Batch 850/938 | Cost: 0.8735\n", "Epoch: 008/010 | Batch 900/938 | Cost: 0.7114\n", "Epoch: 008/010 training accuracy: 85.51%\n", "Time elapsed: 1.97 min\n", "Epoch: 009/010 | Batch 000/938 | Cost: 0.4851\n", "Epoch: 009/010 | Batch 050/938 | Cost: 0.4595\n", "Epoch: 009/010 | Batch 100/938 | Cost: 0.1939\n", "Epoch: 009/010 | Batch 150/938 | Cost: 0.1813\n", "Epoch: 009/010 | Batch 200/938 | Cost: 0.4969\n", "Epoch: 009/010 | Batch 250/938 | Cost: 0.4874\n", "Epoch: 009/010 | Batch 300/938 | Cost: 0.1605\n", "Epoch: 009/010 | Batch 350/938 | Cost: 0.0899\n", "Epoch: 009/010 | Batch 400/938 | Cost: 0.3318\n", "Epoch: 009/010 | Batch 450/938 | Cost: 0.0524\n", "Epoch: 009/010 | Batch 500/938 | Cost: 0.0215\n", "Epoch: 009/010 | Batch 550/938 | Cost: 0.0997\n", "Epoch: 009/010 | Batch 600/938 | Cost: 0.0541\n", "Epoch: 009/010 | Batch 650/938 | Cost: 0.3480\n", "Epoch: 009/010 | Batch 700/938 | Cost: 0.0736\n", "Epoch: 009/010 | Batch 750/938 | Cost: 0.1682\n", "Epoch: 009/010 | Batch 800/938 | Cost: 0.2877\n", "Epoch: 009/010 | Batch 850/938 | Cost: 0.0539\n", "Epoch: 009/010 | Batch 900/938 | Cost: 0.2708\n", "Epoch: 009/010 training accuracy: 95.67%\n", "Time elapsed: 2.21 min\n", "Epoch: 010/010 | Batch 000/938 | Cost: 0.0531\n", "Epoch: 010/010 | Batch 050/938 | Cost: 0.0453\n", "Epoch: 010/010 | Batch 100/938 | Cost: 1.8852\n", "Epoch: 010/010 | Batch 150/938 | Cost: 0.1455\n", "Epoch: 010/010 | Batch 200/938 | Cost: 0.2089\n", "Epoch: 010/010 | Batch 250/938 | Cost: 0.0155\n", "Epoch: 010/010 | Batch 300/938 | Cost: 0.9183\n", "Epoch: 010/010 | Batch 350/938 | Cost: 0.2231\n", "Epoch: 010/010 | Batch 400/938 | Cost: 0.3704\n", "Epoch: 010/010 | Batch 450/938 | Cost: 0.1086\n", "Epoch: 010/010 | Batch 500/938 | Cost: 0.3775\n", "Epoch: 010/010 | Batch 550/938 | Cost: 0.4196\n", "Epoch: 010/010 | Batch 600/938 | Cost: 0.2836\n", "Epoch: 010/010 | Batch 650/938 | Cost: 0.1170\n", "Epoch: 010/010 | Batch 700/938 | Cost: 0.2631\n", "Epoch: 010/010 | Batch 750/938 | Cost: 0.1400\n", "Epoch: 010/010 | Batch 800/938 | Cost: 0.1048\n", "Epoch: 010/010 | Batch 850/938 | Cost: 0.7937\n", "Epoch: 010/010 | Batch 900/938 | Cost: 0.2107\n", "Epoch: 010/010 training accuracy: 87.98%\n", "Time elapsed: 2.46 min\n", "Total Training Time: 2.46 min\n" ] } ], "source": [ "torch.manual_seed(random_seed)\n", "model = MultilayerPerceptron(num_features=num_features,\n", " num_classes=num_classes)\n", "\n", "#########################################################\n", "#########################################################\n", "### GRADIENT CLIPPING\n", "for p in model.parameters():\n", " p.register_hook(lambda grad: torch.clamp(grad, -0.1, 1.0))\n", "#########################################################\n", "#########################################################\n", " \n", "model = model.to(device)\n", "\n", "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) \n", "\n", "###################################################################\n", "\n", "start_time = time.time()\n", "for epoch in range(num_epochs):\n", " model.train()\n", " for batch_idx, (features, targets) in enumerate(train_loader):\n", " \n", " features = features.view(-1, 28*28).to(device)\n", " targets = 
targets.to(device)\n", " \n", " ### FORWARD AND BACK PROP\n", " logits, probas = model(features)\n", " cost = F.cross_entropy(logits, targets)\n", " optimizer.zero_grad()\n", " \n", " cost.backward()\n", " \n", " ### UPDATE MODEL PARAMETERS\n", " optimizer.step()\n", " \n", " ### LOGGING\n", " if not batch_idx % 50:\n", " print ('Epoch: %03d/%03d | Batch %03d/%03d | Cost: %.4f' \n", " %(epoch+1, num_epochs, batch_idx, \n", " len(train_loader), cost))\n", "\n", " with torch.set_grad_enabled(False):\n", " print('Epoch: %03d/%03d training accuracy: %.2f%%' % (\n", " epoch+1, num_epochs, \n", " compute_accuracy(model, train_loader)))\n", " \n", " print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))\n", " \n", "print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Test accuracy: 86.94%\n" ] } ], "source": [ "print('Test accuracy: %.2f%%' % (compute_accuracy(model, test_loader)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3 - Norm-clipping" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 001/010 | Batch 000/938 | Cost: 2.3054\n", "Epoch: 001/010 | Batch 050/938 | Cost: 0.5121\n", "Epoch: 001/010 | Batch 100/938 | Cost: 0.3424\n", "Epoch: 001/010 | Batch 150/938 | Cost: 0.2765\n", "Epoch: 001/010 | Batch 200/938 | Cost: 0.5126\n", "Epoch: 001/010 | Batch 250/938 | Cost: 0.1481\n", "Epoch: 001/010 | Batch 300/938 | Cost: 0.2240\n", "Epoch: 001/010 | Batch 350/938 | Cost: 0.1948\n", "Epoch: 001/010 | Batch 400/938 | Cost: 0.0655\n", "Epoch: 001/010 | Batch 450/938 | Cost: 0.1893\n", "Epoch: 001/010 | Batch 500/938 | Cost: 0.4133\n", "Epoch: 001/010 | Batch 550/938 | Cost: 0.0375\n", "Epoch: 001/010 | Batch 600/938 | Cost: 0.2691\n", "Epoch: 001/010 | Batch 650/938 | Cost: 0.3342\n", "Epoch: 001/010 | Batch 700/938 | Cost: 0.1662\n", "Epoch: 001/010 | Batch 750/938 | Cost: 0.0702\n", "Epoch: 001/010 | Batch 800/938 | Cost: 0.4246\n", "Epoch: 001/010 | Batch 850/938 | Cost: 0.2282\n", "Epoch: 001/010 | Batch 900/938 | Cost: 0.0459\n", "Epoch: 001/010 training accuracy: 94.99%\n", "Time elapsed: 0.25 min\n", "Epoch: 002/010 | Batch 000/938 | Cost: 0.2188\n", "Epoch: 002/010 | Batch 050/938 | Cost: 0.3042\n", "Epoch: 002/010 | Batch 100/938 | Cost: 0.1391\n", "Epoch: 002/010 | Batch 150/938 | Cost: 0.1453\n", "Epoch: 002/010 | Batch 200/938 | Cost: 0.3031\n", "Epoch: 002/010 | Batch 250/938 | Cost: 0.1398\n", "Epoch: 002/010 | Batch 300/938 | Cost: 0.0868\n", "Epoch: 002/010 | Batch 350/938 | Cost: 0.1679\n", "Epoch: 002/010 | Batch 400/938 | Cost: 0.0480\n", "Epoch: 002/010 | Batch 450/938 | Cost: 0.2823\n", "Epoch: 002/010 | Batch 500/938 | Cost: 0.2307\n", "Epoch: 002/010 | Batch 550/938 | Cost: 0.1610\n", "Epoch: 002/010 | Batch 600/938 | Cost: 0.0972\n", "Epoch: 002/010 | Batch 650/938 | Cost: 0.3210\n", "Epoch: 002/010 | Batch 700/938 | Cost: 0.0697\n", "Epoch: 002/010 | Batch 750/938 | Cost: 0.0879\n", "Epoch: 002/010 | Batch 800/938 | Cost: 0.2113\n", "Epoch: 002/010 | Batch 850/938 | Cost: 0.2496\n", "Epoch: 002/010 | Batch 900/938 | Cost: 0.2453\n", "Epoch: 002/010 training accuracy: 96.15%\n", "Time elapsed: 0.49 min\n", "Epoch: 003/010 | Batch 000/938 | Cost: 0.1779\n", "Epoch: 003/010 | Batch 050/938 | Cost: 0.0618\n", "Epoch: 003/010 | Batch 100/938 | Cost: 0.0570\n", "Epoch: 003/010 | Batch 150/938 | Cost: 
0.2510\n", "Epoch: 003/010 | Batch 200/938 | Cost: 0.1193\n", "Epoch: 003/010 | Batch 250/938 | Cost: 0.2530\n", "Epoch: 003/010 | Batch 300/938 | Cost: 0.1220\n", "Epoch: 003/010 | Batch 350/938 | Cost: 0.2401\n", "Epoch: 003/010 | Batch 400/938 | Cost: 0.0520\n", "Epoch: 003/010 | Batch 450/938 | Cost: 0.0262\n", "Epoch: 003/010 | Batch 500/938 | Cost: 0.2961\n", "Epoch: 003/010 | Batch 550/938 | Cost: 0.0030\n", "Epoch: 003/010 | Batch 600/938 | Cost: 0.1998\n", "Epoch: 003/010 | Batch 650/938 | Cost: 0.1968\n", "Epoch: 003/010 | Batch 700/938 | Cost: 0.0499\n", "Epoch: 003/010 | Batch 750/938 | Cost: 0.1742\n", "Epoch: 003/010 | Batch 800/938 | Cost: 0.1034\n", "Epoch: 003/010 | Batch 850/938 | Cost: 0.0437\n", "Epoch: 003/010 | Batch 900/938 | Cost: 0.1414\n", "Epoch: 003/010 training accuracy: 97.30%\n", "Time elapsed: 0.74 min\n", "Epoch: 004/010 | Batch 000/938 | Cost: 0.1098\n", "Epoch: 004/010 | Batch 050/938 | Cost: 0.0060\n", "Epoch: 004/010 | Batch 100/938 | Cost: 0.3551\n", "Epoch: 004/010 | Batch 150/938 | Cost: 0.3143\n", "Epoch: 004/010 | Batch 200/938 | Cost: 0.0527\n", "Epoch: 004/010 | Batch 250/938 | Cost: 0.0204\n", "Epoch: 004/010 | Batch 300/938 | Cost: 0.0289\n", "Epoch: 004/010 | Batch 350/938 | Cost: 0.2386\n", "Epoch: 004/010 | Batch 400/938 | Cost: 0.0694\n", "Epoch: 004/010 | Batch 450/938 | Cost: 0.1200\n", "Epoch: 004/010 | Batch 500/938 | Cost: 0.0797\n", "Epoch: 004/010 | Batch 550/938 | Cost: 0.0891\n", "Epoch: 004/010 | Batch 600/938 | Cost: 0.3322\n", "Epoch: 004/010 | Batch 650/938 | Cost: 0.1640\n", "Epoch: 004/010 | Batch 700/938 | Cost: 0.1170\n", "Epoch: 004/010 | Batch 750/938 | Cost: 0.2028\n", "Epoch: 004/010 | Batch 800/938 | Cost: 0.2188\n", "Epoch: 004/010 | Batch 850/938 | Cost: 0.0575\n", "Epoch: 004/010 | Batch 900/938 | Cost: 0.0180\n", "Epoch: 004/010 training accuracy: 96.86%\n", "Time elapsed: 0.98 min\n", "Epoch: 005/010 | Batch 000/938 | Cost: 0.0779\n", "Epoch: 005/010 | Batch 050/938 | Cost: 0.1183\n", "Epoch: 005/010 | Batch 100/938 | Cost: 0.1184\n", "Epoch: 005/010 | Batch 150/938 | Cost: 0.0815\n", "Epoch: 005/010 | Batch 200/938 | Cost: 0.0691\n", "Epoch: 005/010 | Batch 250/938 | Cost: 0.0784\n", "Epoch: 005/010 | Batch 300/938 | Cost: 0.1464\n", "Epoch: 005/010 | Batch 350/938 | Cost: 0.1488\n", "Epoch: 005/010 | Batch 400/938 | Cost: 0.2636\n", "Epoch: 005/010 | Batch 450/938 | Cost: 0.0839\n", "Epoch: 005/010 | Batch 500/938 | Cost: 0.1343\n", "Epoch: 005/010 | Batch 550/938 | Cost: 0.0514\n", "Epoch: 005/010 | Batch 600/938 | Cost: 0.1802\n", "Epoch: 005/010 | Batch 650/938 | Cost: 0.0681\n", "Epoch: 005/010 | Batch 700/938 | Cost: 0.0986\n", "Epoch: 005/010 | Batch 750/938 | Cost: 0.0930\n", "Epoch: 005/010 | Batch 800/938 | Cost: 0.1829\n", "Epoch: 005/010 | Batch 850/938 | Cost: 0.1694\n", "Epoch: 005/010 | Batch 900/938 | Cost: 0.0440\n", "Epoch: 005/010 training accuracy: 97.22%\n", "Time elapsed: 1.22 min\n", "Epoch: 006/010 | Batch 000/938 | Cost: 0.0142\n", "Epoch: 006/010 | Batch 050/938 | Cost: 0.3528\n", "Epoch: 006/010 | Batch 100/938 | Cost: 0.0710\n", "Epoch: 006/010 | Batch 150/938 | Cost: 0.0553\n", "Epoch: 006/010 | Batch 200/938 | Cost: 0.0084\n", "Epoch: 006/010 | Batch 250/938 | Cost: 0.1178\n", "Epoch: 006/010 | Batch 300/938 | Cost: 0.1271\n", "Epoch: 006/010 | Batch 350/938 | Cost: 0.0404\n", "Epoch: 006/010 | Batch 400/938 | Cost: 0.1435\n", "Epoch: 006/010 | Batch 450/938 | Cost: 0.1568\n", "Epoch: 006/010 | Batch 500/938 | Cost: 0.2100\n", "Epoch: 006/010 | Batch 550/938 | Cost: 0.0019\n", 
"Epoch: 006/010 | Batch 600/938 | Cost: 0.1721\n", "Epoch: 006/010 | Batch 650/938 | Cost: 0.0943\n", "Epoch: 006/010 | Batch 700/938 | Cost: 0.0913\n", "Epoch: 006/010 | Batch 750/938 | Cost: 0.1211\n", "Epoch: 006/010 | Batch 800/938 | Cost: 0.0890\n", "Epoch: 006/010 | Batch 850/938 | Cost: 0.0390\n", "Epoch: 006/010 | Batch 900/938 | Cost: 0.0521\n", "Epoch: 006/010 training accuracy: 97.79%\n", "Time elapsed: 1.47 min\n", "Epoch: 007/010 | Batch 000/938 | Cost: 0.0059\n", "Epoch: 007/010 | Batch 050/938 | Cost: 0.0371\n", "Epoch: 007/010 | Batch 100/938 | Cost: 0.2702\n", "Epoch: 007/010 | Batch 150/938 | Cost: 0.1142\n", "Epoch: 007/010 | Batch 200/938 | Cost: 0.0900\n", "Epoch: 007/010 | Batch 250/938 | Cost: 0.1922\n", "Epoch: 007/010 | Batch 300/938 | Cost: 0.0062\n", "Epoch: 007/010 | Batch 350/938 | Cost: 0.0435\n", "Epoch: 007/010 | Batch 400/938 | Cost: 0.0503\n", "Epoch: 007/010 | Batch 450/938 | Cost: 0.1411\n", "Epoch: 007/010 | Batch 500/938 | Cost: 0.1547\n", "Epoch: 007/010 | Batch 550/938 | Cost: 0.1858\n", "Epoch: 007/010 | Batch 600/938 | Cost: 0.0108\n", "Epoch: 007/010 | Batch 650/938 | Cost: 0.0569\n", "Epoch: 007/010 | Batch 700/938 | Cost: 0.0254\n", "Epoch: 007/010 | Batch 750/938 | Cost: 0.0635\n", "Epoch: 007/010 | Batch 800/938 | Cost: 0.2539\n", "Epoch: 007/010 | Batch 850/938 | Cost: 0.1338\n", "Epoch: 007/010 | Batch 900/938 | Cost: 0.3336\n", "Epoch: 007/010 training accuracy: 98.25%\n", "Time elapsed: 1.71 min\n", "Epoch: 008/010 | Batch 000/938 | Cost: 0.0215\n", "Epoch: 008/010 | Batch 050/938 | Cost: 0.2800\n", "Epoch: 008/010 | Batch 100/938 | Cost: 0.2627\n", "Epoch: 008/010 | Batch 150/938 | Cost: 0.0538\n", "Epoch: 008/010 | Batch 200/938 | Cost: 0.2164\n", "Epoch: 008/010 | Batch 250/938 | Cost: 0.0025\n", "Epoch: 008/010 | Batch 300/938 | Cost: 0.0021\n", "Epoch: 008/010 | Batch 350/938 | Cost: 0.1489\n", "Epoch: 008/010 | Batch 400/938 | Cost: 0.0997\n", "Epoch: 008/010 | Batch 450/938 | Cost: 0.0055\n", "Epoch: 008/010 | Batch 500/938 | Cost: 0.0181\n", "Epoch: 008/010 | Batch 550/938 | Cost: 0.1672\n", "Epoch: 008/010 | Batch 600/938 | Cost: 0.0538\n", "Epoch: 008/010 | Batch 650/938 | Cost: 0.0842\n", "Epoch: 008/010 | Batch 700/938 | Cost: 0.0941\n", "Epoch: 008/010 | Batch 750/938 | Cost: 0.0171\n", "Epoch: 008/010 | Batch 800/938 | Cost: 0.0638\n", "Epoch: 008/010 | Batch 850/938 | Cost: 0.2507\n", "Epoch: 008/010 | Batch 900/938 | Cost: 0.0568\n", "Epoch: 008/010 training accuracy: 98.31%\n", "Time elapsed: 1.96 min\n", "Epoch: 009/010 | Batch 000/938 | Cost: 0.0844\n", "Epoch: 009/010 | Batch 050/938 | Cost: 0.1087\n", "Epoch: 009/010 | Batch 100/938 | Cost: 0.0584\n", "Epoch: 009/010 | Batch 150/938 | Cost: 0.0544\n", "Epoch: 009/010 | Batch 200/938 | Cost: 0.0352\n", "Epoch: 009/010 | Batch 250/938 | Cost: 0.0189\n", "Epoch: 009/010 | Batch 300/938 | Cost: 0.0356\n", "Epoch: 009/010 | Batch 350/938 | Cost: 0.1357\n", "Epoch: 009/010 | Batch 400/938 | Cost: 0.2133\n", "Epoch: 009/010 | Batch 450/938 | Cost: 0.0081\n", "Epoch: 009/010 | Batch 500/938 | Cost: 0.0710\n", "Epoch: 009/010 | Batch 550/938 | Cost: 0.0652\n", "Epoch: 009/010 | Batch 600/938 | Cost: 0.0136\n", "Epoch: 009/010 | Batch 650/938 | Cost: 0.0772\n", "Epoch: 009/010 | Batch 700/938 | Cost: 0.0744\n", "Epoch: 009/010 | Batch 750/938 | Cost: 0.0388\n", "Epoch: 009/010 | Batch 800/938 | Cost: 0.0208\n", "Epoch: 009/010 | Batch 850/938 | Cost: 0.0114\n", "Epoch: 009/010 | Batch 900/938 | Cost: 0.0706\n", "Epoch: 009/010 training accuracy: 97.76%\n", "Time elapsed: 2.20 
min\n", "Epoch: 010/010 | Batch 000/938 | Cost: 0.0773\n", "Epoch: 010/010 | Batch 050/938 | Cost: 0.0362\n", "Epoch: 010/010 | Batch 100/938 | Cost: 0.0406\n", "Epoch: 010/010 | Batch 150/938 | Cost: 0.0900\n", "Epoch: 010/010 | Batch 200/938 | Cost: 0.3629\n", "Epoch: 010/010 | Batch 250/938 | Cost: 0.0016\n", "Epoch: 010/010 | Batch 300/938 | Cost: 0.0314\n", "Epoch: 010/010 | Batch 350/938 | Cost: 0.0677\n", "Epoch: 010/010 | Batch 400/938 | Cost: 0.0821\n", "Epoch: 010/010 | Batch 450/938 | Cost: 0.0717\n", "Epoch: 010/010 | Batch 500/938 | Cost: 0.2704\n", "Epoch: 010/010 | Batch 550/938 | Cost: 0.1784\n", "Epoch: 010/010 | Batch 600/938 | Cost: 0.0899\n", "Epoch: 010/010 | Batch 650/938 | Cost: 0.0578\n", "Epoch: 010/010 | Batch 700/938 | Cost: 0.1572\n", "Epoch: 010/010 | Batch 750/938 | Cost: 0.0106\n", "Epoch: 010/010 | Batch 800/938 | Cost: 0.0714\n", "Epoch: 010/010 | Batch 850/938 | Cost: 0.0125\n", "Epoch: 010/010 | Batch 900/938 | Cost: 0.0235\n", "Epoch: 010/010 training accuracy: 98.38%\n", "Time elapsed: 2.45 min\n", "Total Training Time: 2.45 min\n" ] } ], "source": [ "torch.manual_seed(random_seed)\n", "model = MultilayerPerceptron(num_features=num_features,\n", " num_classes=num_classes)\n", "\n", "model = model.to(device)\n", "\n", "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) \n", "\n", "###################################################################\n", "\n", "start_time = time.time()\n", "for epoch in range(num_epochs):\n", " model.train()\n", " for batch_idx, (features, targets) in enumerate(train_loader):\n", " \n", " features = features.view(-1, 28*28).to(device)\n", " targets = targets.to(device)\n", " \n", " ### FORWARD AND BACK PROP\n", " logits, probas = model(features)\n", " cost = F.cross_entropy(logits, targets)\n", " optimizer.zero_grad()\n", " \n", " cost.backward()\n", " \n", " ### UPDATE MODEL PARAMETERS\n", " \n", " #########################################################\n", " #########################################################\n", " ### GRADIENT CLIPPING\n", " torch.nn.utils.clip_grad_norm_(model.parameters(), 1., norm_type=2)\n", " #########################################################\n", " #########################################################\n", " \n", " optimizer.step()\n", " \n", " ### LOGGING\n", " if not batch_idx % 50:\n", " print ('Epoch: %03d/%03d | Batch %03d/%03d | Cost: %.4f' \n", " %(epoch+1, num_epochs, batch_idx, \n", " len(train_loader), cost))\n", "\n", " with torch.set_grad_enabled(False):\n", " print('Epoch: %03d/%03d training accuracy: %.2f%%' % (\n", " epoch+1, num_epochs, \n", " compute_accuracy(model, train_loader)))\n", " \n", " print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))\n", " \n", "print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Test accuracy: 96.89%\n" ] } ], "source": [ "print('Test accuracy: %.2f%%' % (compute_accuracy(model, test_loader)))" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "numpy 1.16.4\n", "torch 1.2.0\n", "torchvision 0.4.0a0+6b959ee\n", "\n" ] } ], "source": [ "%watermark -iv" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": 
"text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" }, "toc": { "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 4 }