{ "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "gsn45i8AWUxw" }, "source": [ "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n", "- Author: Sebastian Raschka\n", "- GitHub Repository: https://github.com/rasbt/deeplearning-models" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 119 }, "colab_type": "code", "id": "jcO_pGrwWUxx", "outputId": "001a0bdd-7224-459f-a80e-9d2336f9edb1" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sebastian Raschka \n", "\n", "CPython 3.7.1\n", "IPython 7.2.0\n", "\n", "torch 1.0.0\n" ] } ], "source": [ "%load_ext watermark\n", "%watermark -a 'Sebastian Raschka' -v -p torch" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "yaPAGojpWUx2" }, "source": [ "- Runs on CPU or GPU (if available)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "lYFsqMPGWUx3" }, "source": [ "# Model Zoo -- Cyclical Learning Rate in PyTorch" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "Fa9mT_urXvMf" }, "source": [ "This notebook will go over the following topics in the order listed below:\n", "\n", "1. Briefly explain the concept behind the cyclical learning rate\n", "2. Use the \"LR range test\" to choose a good base and max learning rate for the cyclical learning rate\n", "3. Train a simple convolutional neural net on CIFAR-10 using the cyclical learning rate" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "HqJXinN5WUx3" }, "source": [ "## Cyclical Learning Rate Concept" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "8Q6XUX6GSYD-" }, "source": [ "In his paper [1], Leslie N. Smith introduced the concept of cyclical learning rates, that is, learning rates that periodically alternative between a user-specified minimum and maximum learning rate. \n", "\n", "Varying the learning rate between between specified bounds, as implemented by Smith, is cheaper to compute than the nowadays popular approach using adaptive learning rates. Note that adaptive learning rate can also be combined with the concept of cyclical learning rates.\n", "\n", "The idea behind cyclical learning rates is that while increasing the learning rate can be harmful short term it can be beneficial in the long run. Concretely, the three methods introduced by Smith (and implemented in this notebook) are\n", "\n", "- `triangular`: The base approach, varying between a lower and an upper bound, as illustrated in the figure below\n", "- `triangular2`: Same as triangular, but learning rate difference is cut in half at the end of each cycle. This means the learning rate difference drops after each cycle\n", "-- `exp_range`: The learning rate varies between the minimum and maximum boundaries and each boundary value declines by an exponential factor of $gamma^{iteration}$\n", "\n", "\n", "{insert figure}\n", "\n", "\n", "### References\n", "\n", "\n", "- [1] Smith, Leslie N. “[Cyclical learning rates for training neural networks](https://ieeexplore.ieee.org/abstract/document/7926641/).” Applications of Computer Vision (WACV), 2017 IEEE Winter Conference on. IEEE, 2017." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "wx0idSn5SWG1" }, "source": [ "Following the description in the paper, the different cyclical learning rates are very simple to implement, as shown below:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": {}, "colab_type": "code", "id": "WmHAbUMlWUx4" }, "outputs": [], "source": [ "import numpy as np\n", "\n", "\n", "def cyclical_learning_rate(batch_step,\n", " step_size,\n", " base_lr=0.001,\n", " max_lr=0.006,\n", " mode='triangular',\n", " gamma=0.999995):\n", "\n", " cycle = np.floor(1 + batch_step / (2. * step_size))\n", " x = np.abs(batch_step / float(step_size) - 2 * cycle + 1)\n", "\n", " lr_delta = (max_lr - base_lr) * np.maximum(0, (1 - x))\n", " \n", " if mode == 'triangular':\n", " pass\n", " elif mode == 'triangular2':\n", " lr_delta = lr_delta * 1 / (2. ** (cycle - 1))\n", " elif mode == 'exp_range':\n", " lr_delta = lr_delta * (gamma**(batch_step))\n", " else:\n", " raise ValueError('mode must be \"triangular\", \"triangular2\", or \"exp_range\"')\n", " \n", " lr = base_lr + lr_delta\n", " \n", " return lr" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "DwyuJm46YXQf" }, "source": [ "To ensure that the learning rate works as intended, let us plot the learning rate variation for a dry run. Note that `batch_step` is a variable that tracks the total number of times a model has been updated. For instance, if we run the training loop over 5 epochs (5 passes over the training set), where each epoch is split into 100 batches, then we have a `batch_step` count of 5 * 100 = 500 at the end of the training. " ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": {}, "colab_type": "code", "id": "CQx-HDcyYWlu" }, "outputs": [], "source": [ "num_epochs = 50\n", "num_train = 50000\n", "batch_size = 100\n", "iter_per_ep = num_train // batch_size" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "87UblsShY7Uy" }, "source": [ "**Triangular**" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 269 }, "colab_type": "code", "id": "qIHyChx8Y6Lu", "outputId": "ab9a6ae8-db6e-4846-c48c-a7ce99443df8" }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "%matplotlib inline\n", "import matplotlib.pyplot as plt\n", "\n", "\n", "batch_step = -1\n", "collect_lr = []\n", "for e in range(num_epochs):\n", " for i in range(iter_per_ep):\n", " batch_step += 1\n", " cur_lr = cyclical_learning_rate(batch_step=batch_step,\n", " step_size=iter_per_ep*5)\n", " \n", " collect_lr.append(cur_lr)\n", " \n", "plt.scatter(range(len(collect_lr)), collect_lr)\n", "plt.ylim([0.0, 0.01])\n", "plt.xlim([0, num_epochs*iter_per_ep + 5000])\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "YkKd-HdtbWqv" }, "source": [ "As we can see above, with a batchsize of 100 and for a training set of 50,000 training example, we have 50,000=500 iterations per epoch. With a cycle length of 5*iterations_per_epoch=25,000, the learning rate reaches the base_lr every 5 epochs, which is equal to 25,000 batch updates. The stepsize is defined as cycle/2, i.e., 25,000/2 = 12,500 batch updates." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "bMErqRevY-hO" }, "source": [ "**Triangular2**" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "pFBm9DnTcmiM" }, "source": [ "The `triangular2` learning rate is similar to the `triangular` learning rate but cuts the max. learning rate in halve after each cycle." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 269 }, "colab_type": "code", "id": "pATJZ3BbY9sZ", "outputId": "192f3447-7072-445d-d8dd-d48d5c19c177" }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "collect_lr = []\n", "batch_step = -1\n", "for e in range(num_epochs):\n", " for i in range(iter_per_ep):\n", " batch_step += 1\n", " cur_lr = cyclical_learning_rate(batch_step=batch_step,\n", " step_size=iter_per_ep*4,\n", " mode='triangular2')\n", " \n", " collect_lr.append(cur_lr)\n", " \n", "plt.scatter(range(len(collect_lr)), collect_lr)\n", "plt.ylim([0.0, 0.01])\n", "plt.xlim([0, num_epochs*iter_per_ep + 5000])\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "ogc0gqPTZAjz" }, "source": [ "**Exp_range**" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "470maAxLc6UC" }, "source": [ "The `exp_range` option adds an additional hyperparameter, `gamma` to decay the learning rate exponentially." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 269 }, "colab_type": "code", "id": "SBm5Q7x9Zt54", "outputId": "8f5f1662-ace9-43e8-f97b-4bed6fecb590" }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZAAAAD8CAYAAABZ/vJZAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAHB9JREFUeJzt3X/MZNV93/H3p7tm7WCDYbOJ3IVo17A1WWrX4EdrKJVVmcQsduSlElEW1TVKkZFSaE3dqFpk146RUUPVGjcKONoEWoxRdilxyaq2IajYqhLZCw8GA8tm44cfNQsU1gIDsWQ2S779Y846s8PM3Dt3Zu69587nJT16Zu6vOWfunfO959xzz1VEYGZmNqm/13QCzMwsTw4gZmZWiQOImZlV4gBiZmaVOICYmVklDiBmZlZJqQAiaaukA5JWJO0YMn+NpN1p/l5JG9L0tZK+JemvJf3+wDrvk/RIWuf3JGkWGTIzs3oUBhBJq4AbgAuBzcAlkjYPLHYZ8FJEnA5cD1yXpv8U+A/Abw/Z9JeBy4FN6W9rlQyYmVkzytRAtgArEfFERBwGdgHbBpbZBtySXt8BnC9JEfGTiPhzeoHkZyS9AzghIr4TvTsZvwJcNE1GzMysXqtLLLMeeLrv/UHg/aOWiYgjkl4G1gI/GrPNgwPbXD9sQUmX06upcPzxx7/vjDPOKJFkMzMDeOCBB34UEevmse0yAWTYtYnB8U/KLFNp+YjYCewEWFpaiuXl5TGbNTOzfpL+77y2XaYJ6yBwat/7U4BnRy0jaTVwIvBiwTZPKdimmZm1WJkAcj+wSdJGSccB24E9A8vsAS5Nry8G7o0xozRGxHPAq5LOSb2vPg786cSpNzOzxhQ2YaVrGlcCdwOrgJsjYp+ka4DliNgD3ATcKmmFXs1j+9H1JT0FnAAcJ+ki4EMR8RjwW8B/B94CfDP9mZlZJpTTcO6+BmJmNhlJD0TE0jy27TvRzcysEgcQMzOrxAHEzMwqcQAxM7NKHEDMzKwSBxAzM6vEAcTMzCpxADEzs0ocQMzMrBIHEDMzq8QBxMzMKnEAMTOzShxAzMysEgcQMzOrxAHEzMwqcQAxM7NKHEDMzKwSBxAzM6vEAcTMzCpxADEzs0ocQMzMrBIHEDMzq8QBxMzMKnEAMTOzShxAzMysEgcQMzOrxAHEzMwqcQAxM7NKHEDMzKwSBxAzM6vEAcTMzCpxADEzs0ocQMzMrBIHEDMzq8QBxMzMKikVQCRtlXRA0oqkHUPmr5G0O83fK2lD37yr0/QDki7om/5vJe2T9KikP5b05llkyMzM6lEYQCStAm4ALgQ2A5dI2jyw2GXASxFxOnA9cF1adzOwHTgT2ArcKGmVpPXAvwGWIuIfAqvScmZmlokyNZAtwEpEPBERh4FdwLaBZbYBt6TXdwDnS1KavisiXouIJ4GVtD2A1cBbJK0Gfg54drqsmJlZncoEkPXA033vD6ZpQ5eJiCPAy8DaUetGxDPAfwZ+CDwHvBwRfzbswyVdLmlZ0vKhQ4dKJNfMzOpQJoBoyLQouczQ6ZJOolc72Qj8feB4SR8b9uERsTMiliJiad26dSWSa2ZmdSgTQA4Cp/a9P4U3Njf9bJnUJHUi8OKYdX8FeDIiDkXE3wBfA/5xlQyYmVkzygSQ+4FNkjZKOo7exe49A8vsAS5Nry8G7o2ISNO3p15aG4FNwH30mq7OkfRz6VrJ+cD+6bNjZmZ1WV20QEQckXQlcDe93lI3R8Q+SdcAyxGxB7gJuFXSCr2ax/a07j5JtwOPAUeAKyLidWCvpDuA76XpDwI7Z589MzObF/UqCnlYWlqK5eXlppNhZpYNSQ9ExNI8tu070c3MrBIHEDMzq8QBxMzMKnEAMTOzShxAzMysEgcQMzOrxAHEzMwqcQAxM7NKHEDMzKwSBxAzM6vEAcTMzCpxADEzs0ocQMzMrBIHEDMzq8QBxMzMKnEAMTOzShxAzMysEgcQMzOrxAHEzMwqcQAxM7NKHEDMzKwSBxAzM6vEAcTMzCpxADEzs0ocQMzMrBIHEDMzq8QBxMzMKnEAMTOzShxAzMysEgcQMzOrxAHEzMwqcQAxM7NKHEDMzKwSBxAzM6ukVACRtFXSAUkrknYMmb9G0u40f6+kDX3zrk7TD0i6oG/62yXdIekvJe2XdO4sMmRmZvUoDCCSVgE3ABcCm4FLJG0eWOwy4KWIOB24HrgurbsZ2A6cCWwFbkzbA/ivwF0RcQbwj4D902fHzMzqUqYGsgVYiYgnIuIwsAvYNrDMNuCW9PoO4HxJStN3RcRrEfEksAJskXQC8AHgJoCIOBwRP54+O2ZmVpcyAWQ98HTf+4Np2tBlIuII8DKwdsy67wQOAf9N0oOS/kjS8cM+XNLlkpYlLR86dKhEcs3MrA5lAoiGTIuSy4yavho4G/hyRJwF/AR4w7UVgIjYGRFLEbG0bt26Esk1M7M6rC6xzEHg1L73pwDPjljmoKTVwInAi2PWPQgcjIi9afodjAgg87Rhx9ffMO2p3/1I3ckwM8tSmRrI/cAmSRslHUfvoviegWX2AJem1xcD90ZEpOnbUy+tjcAm4L6I+H/A05LeldY5H3hsyrxMZFjwGDfdzMyOVVgDiYgjkq4E7gZWATdHxD5J1wDLEbGH3sXwWyWt0Kt5bE/r7pN0O73gcAS4IiJeT5v+18BtKSg9AfzmjPM20mfufKSujzIz6yz1Kgp5WFpaiuXl5am3U1TLcDOWmXWFpAciYmke2/ad6EO4GcvMrJgDiJmZVeIAYmZmlSxcAHHzlJnZbJS5D2QhfebOR/jCRe9uOhnHGBX8fNHfzJqwcDWQsr763R82nYRjjKs5uVZlZk1wADEzs0oWKoD4BkIzs9lZqADStmYpM7OcLVQAmdT7r72n6SQA5a5x+DqImdXNAWSM51893HQSzMxay914ba48ZL5Zdy1MDcRNPPXzkPlm3bYwASRXp19dvrC988Fn5piS2XKPOLP8OYAUmKQAn4cjE4y2f9Xuh+aXkBlzjziz/DmAFJikALe/42Yqs+5zADEzs0oWIoDkejbclvtQzMyGcTfeFqtyH8qGHV/PpptsW9Pqrsdm5SxEDaRIUeGQaw2mKTl/X+56bFaeA4hZ4iBhNpnOB5CibrhumjAzq6bzASTXbrjTnA3/8z/8zgxTMplf/eK3J1o+p7P+nNJqVofOB5CyTlizqukkzMxfPP5iY5/9gxd+0thnm1m93AsrefjzWwsfG+vmru7KuXYxLO3nnXYyt33i3AZSY4vENRCzCbTt3pxRge8vHn+x0aZMWwydDiBFZ5VtrVHkmu4zPv2NSuvldPaf0zNimmzKtMXQ6QDSZW28d+Wnr+fZYyGnAGbWJg4gfT52zi+Nne8mAWuTtjWn2eJxAOnzhYvePXa+mwQMYGNLaixlmtNcu7J56mwvrKL7ETb9wvH1JGRCRT/4X3zbcTWlZDLTFlQ59XLLs6GuXTzeWDd0tgZSdD/CPZ/6p/UkZMb2fvpXf/a6KJhUvag9D20tHHyGXj+PN9YdnQ0gi6A/mAyT60VtKzZJYZvT44MnHcnAmuUAMqDoTLlNZ/XWnJzOltv0+OCi780jGeTFAWRC8zyrLwpOmtsnT6dsYbq6IAN1F8rv+dxdtX6eWdeUCiCStko6IGlF0o4h89dI2p3m75W0oW/e1Wn6AUkXDKy3StKDkv7XtBnpl+uNeEXB6cmWprvI0e975T+2K/2vvPZ600moJKfaj3VbYQCRtAq4AbgQ2AxcImnzwGKXAS9FxOnA9cB1ad3NwHbgTGArcGPa3lGfBPZPm4lF1sYbCq192nAclE1DG9Jq5ZSpgWwBViLiiYg4DOwCtg0ssw24Jb2+AzhfktL0XRHxWkQ8Cayk7SHpFOAjwB9Nnw2z+rmgs0VXJoCsB57ue38wTRu6TEQcAV4G1has+yXg3wN/O+7DJV0uaVnS8qFDh0okd3pNnNXn1FOm36y/i7oK5TKf08amzqIHpJnVqUwAGXbpc7CxftQyQ6dL+jXghYh4oOjDI2JnRCxFxNK6deuKU1ug6EJuU4p6yrSxMCtjMN255KOt6ZzmAWlNDsUz6YmBa3d5KBNADgKn9r0/BXh21DKSVgMnAi+OWfc84KOSnqLXJPZBSV+tkP43KDrw2nYhdxa+9BvvHTvfP8b5ufPBZ5pOQmkeisdmrUwAuR/YJGmjpOPoXRTfM7DMHuDS9Ppi4N6IiDR9e+qltRHYBNwXEVdHxCkRsSFt796I+NgM8rOQLjprsEXRxpmk0C8a8uaq3Q9NmxyzbBUGkHRN40rgbno9pm6PiH2SrpH00bTYTcBaSSvAp4Adad19wO3AY8BdwBURkUXfyTevamlbV4vMq2Yz7xpTUaHfv+fbNORNrjXJqunONb+LpNR9IBHxjYj4BxFxWkRcm6Z9NiL2pNc/jYhfj4jTI2JLRDzRt+61ab13RcQ3h2z72xHxa7PK0Kz85bUfHjt/lgd3rvetFBmV7rbnJ/f7bcZxoWyz1Kk70YvuLG57wTWNos4BHmNofto0vE2Xj3Frn04FkFzvLJ6Fos4BHmOop8oZeFGhXMeglbnWHKZNt4ebabdOBRCrT1HBMG136aYKzFzP4CdJd0610UU+KcyBA8gYRd1jZ/FDzPUBUkWKakS5FtRd0KbaqI+DvDmAjFHUPbaOH2LRMz+sHeZZY5r07vO2FMq5NrtZeZ0JIF3tyTQJD6w43jT5b/L4Kbr7vOsdzhf9uG2zzgQQq0/RD3pW99BsrLngyPUko0q34zYUyke/71y/d3MAaVROFzMnUXQPzVFFBYcfyFtd04VyGwKUzZ8DSIF5NgsVXUNpuhDokjoGEpxHoemCuMffQzstRABp6wi88/Cxc35p7PxF/SEWDSR4wppVY+dDOwP6NGlq8gbIXEdptmN1IoAs4gi8o3zhonfPdft1dzuu6zrIw5/fWsvn1KnoxGleN0D6mSWLoxMBxNpj0m7HXbsOMsvrWtPWFps6cZrmmSXjLGrtuc0cQEqYRxOYux3XZ5YFT9Hw7nXepFd0o2tbjTq2c71pdpE5gJRQdCbXtjOjonjX1d5fVU0SrNs0vPssngMz62P3/dfeU3ld3zSbn+wDSFFhmOtZ2jSK7guoepbc1LArbQvQdZimIO5Xd032+VcPz3X7i3gstFn2AaSoMPTT+upT9Qyya811syjk5l0QN6Vr+3rRZR9AclRUwJx32sk1paT75nHG2oZCcJZpmNV39Jk7H5l6G234bq08B5CSiu6vmOV1hds+ce7MtmXjLUKBVXThf1a++t0f1vI57ibcHg4gJRXdX9GmIbJh9nfQNz3s/CK1fc86r2258D+rYD2vbsI2uawDiLvCtse0PWjmsa9m0aRS1TwDXteP667nr0uyDiA5qmNMJuspalKZptbUtUJu2oBXdw3RXdHbwQGkZkVjMnWtYGqzNt53MK+CuOnjataf37Ym40XlADKB3B7YVHQPTNn0Fi1X10Xatn2/ReZR26zru25a0wHPylnddALmZYEG4B3porPWc9Xuh+b+ObO6SPvU735kZkGijkEYT1izildee33k/KLaZhXzvCC+YcfXKxXcuQX2o97zubvesP9OWLOqkwNrzku2NZCig7bKU9qsO4o66sziDHfWBc2dDz4z0+0NauqsvqgLfFXTBK5hwQPgldde5z2fu2uaZC2UbANIjtxrzMapo7bYhKqPGJjn72FczXHcPDuWA8iEZvW877oUjSRcdFNW0/d/DMqtuWSW6a3jBGPSh0zltj9sthxAJlT0vO+2/aCKRhKe9qasWfdkmkUhWec+6FqtcdYPmWrjDaZ1PaRsEXQygPi5AjZOGwv9uoJe3Xlv4w2mRSGyzOONrSfLAFL0Y2tj//6ipiIHPRtnkYNe3dwLq7wsA0iOipqKmgx6o+5XaNv1j6NyK7ic3vmaJL255a3tHEAqKDobbNtooUXprXq/wryC3jRn20UFxDweTzxt7aDuYTnqqs3M6nMW5ebJHDmAzIFHC22vok4FTchxWI46L0TXOZpwG5sK2yy7AFJ0tuYDwNpomkfUNnFMF9Xk2naOVKZpys1Xs1cqgEjaKumApBVJO4bMXyNpd5q/V9KGvnlXp+kHJF2Qpp0q6VuS9kvaJ+mTZROc49lajjcQFqW56fthhqVvVs8Rr6KomaVtj6idd7PQrI/peTQ92vQKA4ikVcANwIXAZuASSZsHFrsMeCkiTgeuB65L624GtgNnAluBG9P2jgD/LiJ+GTgHuGLINm2GZj0QZNH9MNOqUgAVFdJFg0tOo2ozS1NnxdM0CzUxlHodTY9tPJFruzI1kC3ASkQ8ERGHgV3AtoFltgG3pNd3AOdLUpq+KyJei4gngRVgS0Q8FxHfA4iIV4H9wPrps1Of3Ebmtd7gkrlpslAbVaNrayvAuN+cf4/zUSaArAee7nt/kDcW9j9bJiKOAC8Da8usm5q7zgL2DvtwSZdLWpa0fOjQoRLJNWun3Aqxqs1uPpNfHGUCyLDWx8FraKOWGbuupLcCfwJcFRGvDPvwiNgZEUsRsfT84fH3HbTxwO3i9Y+62qOLrrP0p7MNhfOk+7LpUV/beOyNM8/05vZdtEWZAHIQOLXv/SnAs6OWkbQaOBF4cdy6kt5EL3jcFhFfK5PYtvX8yM15p508dn7ZQriurrCzvM7SxgKiaNTXNl44bkOgHmdY+tqe5pyVCSD3A5skbZR0HL2L4nsGltkDXJpeXwzcGxGRpm9PvbQ2ApuA+9L1kZuA/RHxxVlkpAlFBXKTvYKGue0T5zadBJtAG+5ZmbTwbWOgtvkpDCDpmsaVwN30LnbfHhH7JF0j6aNpsZuAtZJWgE8BO9K6+4DbgceAu4ArIuJ14DzgXwAflPRQ+ptvt545KCqQ29Z105rXtrPh3Ar8eaQ3t++gTUo90jYivgF8Y2DaZ/te/xT49RHrXgtcOzDtz/FTZ1v5BRQVcHWnueixsW0skLt6U1suae5/NG8uac5Vdneij5LjWUQTj92d9gbAutM8i5FRczw22pTmsoXwvB5da+3VmQBi5cz7BkAr1rYHGs1qVOWqj66d1CyDa5sCdY4cQKbkA9AGFR0TbetNWGZU5dyagjbu+Hp2ac6RA4i13jTjNrWxK2yRHJ+I1/TYaIPaFqS7qhMBpG0Hbxk51lyaSvM04za1oSvspHJ8Il7dTaOzOBZz/A22TScCiNv1J+Mfjg3yMWFVdCKANM3PM7dBLpBtETiAzECTzzNfFFUK5By7leYYeJpK8zSfm+P33EYOIA3wNZt61NWttCty3MfWrOwDyLyfrDYPbbhm48Ji/nI8NhdBfqdv7ZV9AJmmh47ZPE16bObYfbfpE5Eqn9/ECBBdlX0AaYumf0iLYJLvOMf90Ybuu0UjTJv1cwCxQjkWxlZN14f8d/PVbDmA1KxNhXGb0tJVXf6O25K3SdLh5qvZyjqAfOk33tt0Esxmoi2Fsdkksg4gF521vukkHCPHi6C5KVPQujCeTle/vxzHRWu7rANI27ThIuisdbUwqVMXC662nSyVOU5zHBet7RxAatS2Hx3kebd2booKrhxvLO3iyZJNzgGkRm380flu7ea14cbSQV2reeYXovOQbQBp6wE+Kl1tTS/kl+Zx6cotzW1NL+SX5lHpEu59NS+KyOfRK0tLS7G8vNx0MszMsiHpgYhYmse2s62BmJlZsxxAzMysEgcQMzOrxAHEzMwqcQAxM7NKHEDMzKwSBxAzM6vEAcTMzCpxADEzs0ocQMzMrBIHEDMzq8QBxMzMKnEAMTOzShxAzMysklIBRNJWSQckrUjaMWT+Gkm70/y9kjb0zbs6TT8g6YKy2zQzs3YrDCCSVgE3ABcCm4FLJG0eWOwy4KWIOB24HrgurbsZ2A6cCWwFbpS0quQ2zcysxcrUQLYAKxHxREQcBnYB2waW2Qbckl7fAZwvSWn6roh4LSKeBFbS9sps08zMWmx1iWXWA0/3vT8IvH/UMhFxRNLLwNo0/bsD665Pr4u2CYCky4HL09vXJD1aIs05+nngR00nYo6cv7w5f/l617w2XCaADHse/eBzcEctM2r6sJrP0GfrRsROYCeApOV5PZqxaV3OGzh/uXP+8iVpbs8BL9OEdRA4te/9KcCzo5aRtBo4EXhxzLpltmlmZi1WJoDcD2yStFHScfQuiu8ZWGYPcGl6fTFwb0REmr499dLaCGwC7iu5TTMza7HCJqx0TeNK4G5gFXBzROyTdA2wHBF7gJuAWyWt0Kt5bE/r7pN0O/AYcAS4IiJeBxi2zRLp3TlxDvPR5byB85c75y9fc8ubehUFMzOzyfhOdDMzq8QBxMzMKskigOQ87ImkpyQ9Iumho93pJJ0s6R5JP0j/T0rTJen3Uj4flnR233YuTcv/QNKloz6vhvzcLOmF/vtxZpkfSe9L39dKWndYV/A68/Y7kp5J++8hSR/umzfRMD2p08jelOfdqQNJbSSdKulbkvZL2ifpk2l6V/bfqPxlvw8lvVnSfZK+n/L2+XHpUV3DS0VEq//oXWR/HHgncBzwfWBz0+maIP1PAT8/MO0/ATvS6x3Aden1h4Fv0rt/5hxgb5p+MvBE+n9Sen1SQ/n5AHA28Og88kOvl965aZ1vAhc2nLffAX57yLKb07G4BtiYjtFV445X4HZge3r9B8Bv1bzv3gGcnV6/DfirlI+u7L9R+ct+H6bv863p9ZuAvWmfDE0P8K+AP0ivtwO7q+Z53F8ONZAuDnvSP/TLLcBFfdO/Ej3fBd4u6R3ABcA9EfFiRLwE3ENvbLHaRcT/odfTrt9M8pPmnRAR34ne0f6Vvm3N3Yi8jTLRMD3pTPyD9Ib6gWO/p1pExHMR8b30+lVgP72RIbqy/0blb5Rs9mHaB3+d3r4p/cWY9NQyvFQOAWTYUCrjDoq2CeDPJD2g3rAsAL8YEc9B76AHfiFNH5XXtn8Hs8rP+vR6cHrTrkxNODcfbd5h8rytBX4cEUcGpjciNWmcRe9MtnP7byB/0IF9qN5AtA8BL9AL2o+PSc8xw0sB/cNLzayMySGAlBlKpc3Oi4iz6Y08fIWkD4xZdtIhYdpu0vy0MZ9fBk4D3gs8B/yXND3bvEl6K/AnwFUR8cq4RYdMa30eh+SvE/swIl6PiPfSG7ljC/DLY9JTS95yCCBZD3sSEc+m/y8A/5Pejn8+VfdJ/19Ii+c69Mus8nMwvR6c3piIeD79cP8W+EN6+w8mz9uP6DUBrR6YXitJb6JXuN4WEV9Lkzuz/4blr2v7MCJ+DHyb3jWQUempZXipHAJItsOeSDpe0tuOvgY+BDzKsUO/XAr8aXq9B/h46v1yDvByalK4G/iQpJNS9ftDaVpbzCQ/ad6rks5J7bUf79tWI44WrMk/o7f/YMJhetI1gW/RG+oHjv2eapG+05uA/RHxxb5Zndh/o/LXhX0oaZ2kt6fXbwF+hd41nlHpqWd4qXn1GpjlH73eIH9Fr83v002nZ4J0v5Neb4bvA/uOpp1eW+T/Bn6Q/p8cf9fT4oaUz0eApb5t/Ut6F7xWgN9sME9/TK8Z4G/onbVcNsv8AEv0fuCPA79PGi2hwbzdmtL+cPpBvaNv+U+ndB6gr7fRqOM1HQ/3pTz/D2BNzfvun9BrlngYeCj9fbhD+29U/rLfh8B7gAdTHh4FPjsuPcCb0/uVNP+dVfM87s9DmZiZWSU5NGGZmVkLOYCYmVklDiBmZlaJA4iZmVXiAGJmZpU4gJiZWSUOIGZmVsn/B3SxwYJVysYhAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "collect_lr = []\n", "batch_step = -1\n", "for e in range(num_epochs):\n", " for i in range(iter_per_ep):\n", " batch_step += 1\n", " cur_lr = cyclical_learning_rate(batch_step=batch_step,\n", " step_size=iter_per_ep*4,\n", " mode='exp_range',\n", " gamma=0.99998)\n", " \n", " collect_lr.append(cur_lr)\n", " \n", "plt.scatter(range(len(collect_lr)), collect_lr)\n", "plt.ylim([0.0, 0.01])\n", "plt.xlim([0, num_epochs*iter_per_ep + 5000])\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "8HhNsfbmWUx7" }, "source": [ "## Torch Imports" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "colab": {}, "colab_type": "code", "id": "ZfG-3PemWUx7" }, "outputs": [], "source": [ "import time\n", "import torch\n", "import torch.nn.functional as F\n", "from torchvision import datasets\n", "from torchvision import transforms\n", "from torch.utils.data.sampler import SubsetRandomSampler\n", "from torch.utils.data import DataLoader\n", "\n", "\n", "if torch.cuda.is_available():\n", " torch.backends.cudnn.deterministic = True" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "7NYCdJSHWUx-" }, "source": [ "## Settings and Dataset" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 119 }, "colab_type": "code", "id": "r0x1fTW0WUx_", "outputId": "7bf96884-0d96-4ac4-c77d-583bd0b6b651" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Files already downloaded and verified\n", "Image batch dimensions: torch.Size([128, 3, 32, 32])\n", "Image label dimensions: torch.Size([128])\n", "Number of training examples: 49000\n", "Number of validation instances: 1000\n", "Number of test instances: 10000\n" ] } ], "source": [ "##########################\n", "### SETTINGS\n", "##########################\n", "\n", "# Device\n", "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", "\n", "# Hyperparameters\n", "random_seed = 1\n", "batch_size = 128\n", "\n", "# Architecture\n", "num_classes = 10\n", "\n", "\n", "##########################\n", "### CIFAR-10 DATASET\n", "##########################\n", "\n", "# Note transforms.ToTensor() scales input images\n", "# to 0-1 range\n", "\n", "## Create a validation dataset\n", "np.random.seed(random_seed)\n", "idx = np.arange(50000) # the size of CIFAR10-train\n", "np.random.shuffle(idx)\n", "val_idx, train_idx = idx[:1000], idx[1000:]\n", "train_sampler = SubsetRandomSampler(train_idx)\n", "val_sampler = SubsetRandomSampler(val_idx)\n", "\n", "\n", "train_dataset = datasets.CIFAR10(root='data', \n", " train=True, \n", " transform=transforms.ToTensor(),\n", " download=True)\n", "\n", "test_dataset = datasets.CIFAR10(root='data', \n", " train=False, \n", " transform=transforms.ToTensor())\n", "\n", "\n", "train_loader = DataLoader(dataset=train_dataset, \n", " batch_size=batch_size, \n", " # shuffle=True, # Subsetsampler already shuffles\n", " sampler=train_sampler)\n", "\n", "val_loader = DataLoader(dataset=train_dataset, \n", " batch_size=batch_size, \n", " # shuffle=True,\n", " sampler=val_sampler)\n", "\n", "test_loader = DataLoader(dataset=test_dataset, \n", " batch_size=batch_size, \n", " shuffle=False)\n", "\n", "# Checking the dataset\n", "for images, labels in train_loader: \n", " print('Image batch dimensions:', images.shape)\n", " print('Image label dimensions:', labels.shape)\n", " break\n", " \n", "cnt = 0\n", "for images, labels in train_loader: \n", " cnt += images.shape[0]\n", "print('Number of training examples:', cnt)\n", "\n", "cnt = 0\n", "for images, labels in val_loader: \n", " cnt += images.shape[0]\n", "print('Number of validation instances:', cnt)\n", "\n", "cnt = 0\n", "for images, labels in test_loader: \n", " cnt += images.shape[0]\n", "print('Number of test instances:', cnt)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "agDsuvseWUyC" }, "source": [ "## Model" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "uhrX9-e0afuK" }, "source": [ "Note that this is a very simple convolutional network in this notebook, which is not geared to reach best performance on CIFAR-10 but rather to test the implementation of the cyclical learning rate concept." ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "colab": {}, "colab_type": "code", "id": "luAg0-ncWUyD" }, "outputs": [], "source": [ "##########################\n", "### MODEL\n", "##########################\n", "\n", "\n", "class ConvNet(torch.nn.Module):\n", "\n", " def __init__(self, num_classes):\n", " super(ConvNet, self).__init__()\n", " \n", " # calculate same padding:\n", " # (w - k + 2*p)/s + 1 = o\n", " # => p = (s(o-1) - w + k)/2\n", " \n", " # 32x32x3 => 32x32x6\n", " self.conv_1 = torch.nn.Conv2d(in_channels=3,\n", " out_channels=6,\n", " kernel_size=(3, 3),\n", " stride=(1, 1),\n", " padding=1) # (1(32-1) - 32 + 3) / 2) = 1\n", " # 32x32x4 => 16x16x6\n", " self.pool_1 = torch.nn.MaxPool2d(kernel_size=(2, 2),\n", " stride=(2, 2),\n", " padding=0) # (2(16-1) - 32 + 2) = 0 \n", " \n", " \n", " # 16x16x6 => 16x16x12\n", " self.conv_2 = torch.nn.Conv2d(in_channels=6,\n", " out_channels=12,\n", " kernel_size=(3, 3),\n", " stride=(1, 1),\n", " padding=1) # (1(16-1) - 16 + 3) / 2 = 1 \n", " # 16x16x12 => 8x8x12 \n", " self.pool_2 = torch.nn.MaxPool2d(kernel_size=(2, 2),\n", " stride=(2, 2),\n", " padding=0) # (2(8-1) - 16 + 2) = 0\n", " \n", " \n", " # 8x8x12 => 8x8x18\n", " self.conv_3 = torch.nn.Conv2d(in_channels=12,\n", " out_channels=18,\n", " kernel_size=(3, 3),\n", " stride=(1, 1),\n", " padding=1) # (1(8-1) - 8 + 3) / 2 = 1 \n", " # 8x8x18 => 4x4x18 \n", " self.pool_3 = torch.nn.MaxPool2d(kernel_size=(2, 2),\n", " stride=(2, 2),\n", " padding=0) # (2(4-1) - 8 + 2) = 0\n", " \n", " \n", " # 4x4x18 => 4x4x24\n", " self.conv_4 = torch.nn.Conv2d(in_channels=18,\n", " out_channels=24,\n", " kernel_size=(3, 3),\n", " stride=(1, 1),\n", " padding=1) \n", " # 4x4x24 => 2x2x24 \n", " self.pool_4 = torch.nn.MaxPool2d(kernel_size=(2, 2),\n", " stride=(2, 2),\n", " padding=0)\n", " \n", " \n", " # 2x2x24 => 2x2x30\n", " self.conv_5 = torch.nn.Conv2d(in_channels=24,\n", " out_channels=30,\n", " kernel_size=(3, 3),\n", " stride=(1, 1),\n", " padding=1) \n", " # 2x2x30 => 1x1x30 \n", " self.pool_5 = torch.nn.MaxPool2d(kernel_size=(2, 2),\n", " stride=(2, 2),\n", " padding=0)\n", " \n", " self.linear_1 = torch.nn.Linear(1*1*30, num_classes)\n", "\n", " \n", " def forward(self, x):\n", " out = self.conv_1(x)\n", " out = F.relu(out)\n", " out = self.pool_1(out)\n", "\n", " out = self.conv_2(out)\n", " out = F.relu(out)\n", " out = self.pool_2(out)\n", " \n", " out = self.conv_3(out)\n", " out = F.relu(out)\n", " out = self.pool_3(out)\n", " \n", " out = self.conv_4(out)\n", " out = F.relu(out)\n", " out = self.pool_4(out)\n", " \n", " out = self.conv_5(out)\n", " out = F.relu(out)\n", " out = self.pool_5(out)\n", " \n", " logits = self.linear_1(out.view(-1, 1*1*30))\n", " probas = F.softmax(logits, dim=1)\n", " return logits, probas" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "YwLP-0qkWUyF" }, "source": [ "## LR Range Test" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "Z_PvKZmYZ12N" }, "source": [ "The LR range test is a simple heuristic that is also described in Smith's paper. Essentially, it's a quick-and-dirty approach to find good values for the `base_lr` and `max_lr` (hyperparameters of the cyclical learning rate).\n", "\n", "It works as follows:\n", "\n", "We run the training for 5-10 epochs and increase the learning rate linearly up to an upper bound. We select the cut-off where the (train or validation) accuracy starts improving as the base_lr for the cyclical learning rate. The max_lr for the cyclical learning rate is determined in a similar manner, by choosing the cut-off value where the accuracy improvements stop, decrease, or widely fluctuate.\n", "\n", "Note that we can use the `cyclical_learning_rate` function to compute the learning rates for the increasing interval by setting `step_size=num_epochs*iter_per_ep`:" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 269 }, "colab_type": "code", "id": "Giqvby2abkA8", "outputId": "28a54819-ca46-4f90-e3aa-058a01017778" }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZAAAAD8CAYAAABZ/vJZAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAFhtJREFUeJzt3X+wX3V95/Hny5smKAUUFhkMZBIlLRNrK3qHH7XTcUqRQDuGP5ghVFa2Zc1MF7Zat9NJlt3pytRp2emg64i2qdBSigaWOjZjVJYW/WM7GLmprhAgywVcuWAFF4oss4LJvveP70G+fLk395tzv/fH9/t9Pmbu5JzP+ZzzPefkhBfvc77nc1NVSJJ0pF6z3DsgSRpOBogkqRUDRJLUigEiSWrFAJEktWKASJJa6StAkmxOciDJdJLtsyxfk+TWZvneJOub9hOSfDXJ/0nyyZ513pnk3madTyTJIA5IkrQ05g2QJBPA9cAFwCbg0iSberpdATxTVacBHwOubdp/BPxH4Pdm2fSngW3AxuZnc5sDkCQtj34qkDOB6ap6pKpeBHYBW3r6bAFuaqZvB85Nkqp6vqr+O50g+YkkJwPHVtXd1XmT8a+AixZyIJKkpbWqjz5rgce65meAs+bqU1UHkzwLnAD84DDbnOnZ5trZOibZRqdS4eijj37n6aef3scuS5IA9u3b94OqOnExtt1PgMz2bKJ3/JN++rTqX1U7gZ0Ak5OTNTU1dZjNSpK6Jflfi7Xtfm5hzQCnds2fAjwxV58kq4DjgKfn2eYp82xTkrSC9RMg9wAbk2xIshrYCuzu6bMbuLyZvhi4qw4zSmNVfQ94LsnZzbev3g/87RHvvSRp2cx7C6t5pnEVcAcwAdxYVfuTXANMVdVu4Abg5iTTdCqPrS+tn+Q7wLHA6iQXAe+pqvuB3wb+Engt8OXmR5I0JDJMw7n7DESSjkySfVU1uRjb9k10SVIrBogkqRUDRJLUigEiSWrFAJEktWKASJJaMUAkSa0YIJKkVgwQSVIrBogkqRUDRJLUigEiSWrFAJEktWKASJJaMUAkSa0YIJKkVgwQSVIrBogkqRUDRJLUigEiSWrFAJEktWKASJJaMUAkSa0YIJKkVgwQSVIrBogkqRUDRJLUigEiSWrFAJEktWKASJJaMUAkSa0YIJKkVgwQSVIrBogkqRUDRJLUSl8BkmRzkgNJppNsn2X5miS3Nsv3JlnftWxH034gyfld7b+bZH+S+5J8LslRgzggSdLSmDdAkkwA1wMXAJuAS5Ns6ul2BfBMVZ0GfAy4tll3E7AVeCuwGfhUkokka4HfASar6ueAiaafJGlI9FOBnAlMV9UjVfUisAvY0tNnC3BTM307cG6SNO27quqFqnoUmG62B7AKeG2SVcDrgCcWdiiSpKXUT4CsBR7rmp9p2mbtU1UHgWeBE+Zat6oeB/4E+C7wPeDZqvpvs314km1JppJMPfXUU33sriRpKfQTIJmlrfrsM2t7kjfQqU42AG8Cjk5y2WwfXlU7q2qyqiZPPPHEPnZXkrQU+gmQGeDUrvlTePXtpp/0aW5JHQc8fZh1fxV4tKqeqqofA58HfrHNAUiSlkc/AXIPsDHJhiSr6Tzs3t3TZzdweTN9MXBXVVXTvrX5ltYGYCPwDTq3rs5O8rrmWcm5wAMLPxxJ0lJZNV+HqjqY5CrgDjrflrqxqvYnuQaYqqrdwA3AzUmm6VQeW5t19ye5DbgfOAhcWVWHgL1Jbgf+sWn/JrBz8IcnSVos6RQKw2FycrKmpqaWezckaWgk2VdVk4uxbd9ElyS1YoBIkloxQCRJrRggkqRWDBBJUisGiCSpFQNEktSKASJJasUAkSS1YoBIkloxQCRJrRggkqRWDBBJUisGiCSpFQNEktSKASJJasUAkSS1YoBIkloxQCRJrRggkqRWDBBJUisGiCSpFQNEktSKASJJasUAkSS1YoBIkloxQCRJrRggkqRWDBBJUisGiCSpFQNEktSKASJJasUAkSS1YoBIklpZ1U+nJJuB/wJMAJ+pqj/uWb4G+CvgncD/Bi6pqu80y3YAVwCHgN+pqjua9tcDnwF+Dijgt6rq7gEck7RkTr/6S/zoUP1k/qiJ8OBHL1zGPZKWzrwVSJIJ4HrgAmATcGmSTT3drgCeqarTgI8B1zbrbgK2Am8FNgOfarYHnUD6SlWdDvwC8MDCD0daGhu272H99j2vCA+AHx0qTr/6S8u0V9LS6ucW1pnAdFU9UlUvAruALT19tgA3NdO3A+cmSdO+q6peqKpHgWngzCTHAr8M3ABQVS9W1T8v/HCkxbd++x7qMMt7Q0UaVf3cwloLPNY1PwOcNVefqjqY5FnghKb96z3rrgX+L/AU8BdJfgHYB3ywqp7v/fAk24BtAOvWretjd6XFsX77nuXeBWlF6acCySxtvf+LNVefudpXAe8APl1VZwDPA9tn+/Cq2llVk1U1eeKJJ/axu9LgGR7Sq/VTgcwAp3bNnwI8MUefmSSrgOOApw+z7gwwU1V7m/bbmSNApOVkcEhz66cCuQfYmGRDktV0Horv7umzG7i8mb4YuKuqqmnfmmRNkg3ARuAbVfVPwGNJfrZZ51zg/gUeizRQbcPjqInZCm9p9MxbgTTPNK4C7qDzNd4bq2p/kmuAqaraTedh+M1JpulUHlubdfcnuY1OOBwErqyqQ82m/y1wSxNKjwC/OeBjk1pZSNXh13g1TtIpFIbD5ORkTU1NLfduaES978/v5h8efrr1+t/5418b4N5Ig5FkX1VNLsa2+3qRUBp1C6k6Nr7xaO788LsHtzPSkDBANNbO+uidfP+5F1uvb9WhcWaAaGwtpOq47Ox1/OFFbxvg3kjDxwDR2Dltxx4OLuDRn1WH1GGAaKwspOowOKRXMkA0Fhb6QqDhIb2aAaKRZ9UhLQ4DRCPLqkNaXAaIRpJVh7T4DBCNlIUEx6rA9B8ZHlK/DBCNhC9883E+dOu3Wq9v1SEdOQNEQ28hVcexayb49kc2D3BvpPFhgGhonXfd13joyVf9Esu+WXVIC2OAaCg5+KG0/AwQDZWf/4Ov8MMXDs3fcQ5WHdLgGCAaGgupOj5+ydu56Iy1A9wbSQaIVjwHP5RWJgNEK5ovBEorlwGiFclhSKSVzwDRimPVIQ0HA0QrhlWHNFwMEK0IVh3S8DFAtKwc/FAaXgaIlsV/+MK9/PXXv9t6fasOafkZIFpyC6k6TjpmNXuvPm+AeyOpLQNES8bBD6XRYoBoSSyk6njXW47nlg+cM8C9kTQIBogW1elXf4kfHWo/DolVh7RyGSBaNA5+KI02A0QDt2H7HhYw9qFVhzQkDBANlC8ESuPDANFAOAyJNH4MEC2YVYc0ngwQtWbVIY03A0StWHVIek0/nZJsTnIgyXSS7bMsX5Pk1mb53iTru5btaNoPJDm/Z72JJN9M8sWFHoiWxvrte1qHx1ETMTykETJvBZJkArgeOA+YAe5Jsruq7u/qdgXwTFWdlmQrcC1wSZJNwFbgrcCbgL9L8jNVdahZ74PAA8CxAzsiLQoHP5TUq58K5ExguqoeqaoXgV3Alp4+W4CbmunbgXOTpGnfVVUvVNWjwHSzPZKcAvwa8JmFH4YW0/rte1qHx0nHrDY8pBHVzzOQtcBjXfMzwFlz9amqg0meBU5o2r/es+5Lrxd/HPh94JjDfXiSbcA2gHXr1vWxuxqUsz56J99/7sXW6xsc0mjrpwLJLG29LxrP1WfW9iS/DjxZVfvm+/Cq2llVk1U1eeKJJ86/txqI9dv3tA6Py85eZ3hIY6CfCmQGOLVr/hTgiTn6zCRZBRwHPH2Ydd8LvDfJhcBRwLFJ/rqqLmt1FBoYBz+U1K9+AuQeYGOSDcDjdB6K/0ZPn93A5cDdwMXAXVVVSXYDn01yHZ2H6BuBb1TV3cAOgCTvBn7P8Fh+fjVX0pGYN0CaZxpXAXcAE8CNVbU/yTXAVFXtBm4Abk4yTafy2Nqsuz/JbcD9wEHgyq5vYGmF8IVASW2kaiHjpi6tycnJmpqaWu7dGClWHdJoS7KvqiYXY9u+iT6mrDokLZQBMoasOiQNggEyRhYSHAEeNTwkdTFAxoRVh6RBM0BG3EKC46iJ8OBHLxzg3kgaJQbIiHLwQ0mLzQAZQQupOja+8Wju/PC7B7czkkaWATJCzrvuazz05POt17fqkHQkDJARsZCq411vOZ5bPnDOAPdG0jgwQIacgx9KWi4GyBDzq7mSlpMBMoQ2bN/zql/IciQMD0mDYIAMGasOSSuFATIkHPxQ0kpjgAwBqw5JK5EBsoI5+KGklcwAWaGsOiStdAbICuPgh5KGhQGyQrzvz+/mHx5+uvX6Vh2SlpoBsgIspOo46ZjV7L36vAHujST1xwBZRmd99E6+/9yLrde36pC0nAyQZbKQquOys9fxhxe9bYB7I0lHzgBZYg5+KGlUGCBLyK/mSholBsgScBgSSaPIAFlkVh2SRpUBskisOiSNOgNkEVh1SBoHBsgAOfihpHFigAzAF775OB+69Vut17fqkDSMDJAFWkjVceyaCb79kc0D3BtJWjoGSEsOfihp3BkgLSyk6tj4xqO588PvHtzOSNIyMUCOwM//wVf44QuHWq9v1SFplLymn05JNic5kGQ6yfZZlq9JcmuzfG+S9V3LdjTtB5Kc37SdmuSrSR5Isj/JBwd1QItl/fY9rcPjsrPXGR6SRs68FUiSCeB64DxgBrgnye6qur+r2xXAM1V1WpKtwLXAJUk2AVuBtwJvAv4uyc8AB4F/V1X/mOQYYF+SO3u2uSKctmMPB9uPfWhwSBpZ/dzCOhOYrqpHAJLsArYA3f+x3wL8p2b6duCTSdK076qqF4BHk0wDZ1bV3cD3AKrquSQPAGt7trnsfCFQkubWT4CsBR7rmp8BzpqrT1UdTPIscELT/vWeddd2r9jc7joD2DvbhyfZBmwDWLduXR+7u3AOQyJJ8+snQDJLW+9Nnbn6HHbdJD8N/A3woar64WwfXlU7gZ0Ak5OTC7iZ1B+rDknqTz8BMgOc2jV/CvDEHH1mkqwCjgOePty6SX6KTnjcUlWfb7X3A2TVIUlHpp8AuQfYmGQD8Didh+K/0dNnN3A5cDdwMXBXVVWS3cBnk1xH5yH6RuAbzfORG4AHquq6wRxKe1YdknTk5g2Q5pnGVcAdwARwY1XtT3INMFVVu+mEwc3NQ/Kn6YQMTb/b6DwcPwhcWVWHkvwS8C+Be5O8NIjUv6+qLw36AA9nIcGxKjD9R4aHpPGVqkV/rDAwk5OTNTU1teDtOPihpHGRZF9VTS7GtsfuTXQHP5SkwRibADnvuq/x0JPPt17fqkOSXmksAmQhVce73nI8t3zgnAHujSSNhpEOEAc/lKTFM7IBspCq4+OXvJ2Lzlg7f0dJGmMjFyAbtu951WvyR8KqQ5L6M1IB4guBkrR0RiJAHIZEkpbe0AeIVYckLY+hDRCrDklaXkMZIFYdkrT8+vqd6CvFvY8/2zo8jpqI4SFJAzSUFciRMjgkafCGqgI5Uicds9rwkKRFMrIViMEhSYtr5CqQd73leMNDkpbASFUgBockLZ2RCBCDQ5KW3tDfwjI8JGl5DG0FYnBI0vIaqgB529rjmDI4JGlFGPpbWJKk5WGASJJaMUAkSa0YIJKkVgwQSVIrBogkqRUDRJLUigEiSWrFAJEktWKASJJaMUAkSa0YIJKkVgwQSVIrBogkqZW+AiTJ5iQHkkwn2T7L8jVJbm2W702yvmvZjqb9QJLz+92mJGllmzdAkkwA1wMXAJuAS5Ns6ul2BfBMVZ0GfAy4tll3E7AVeCuwGfhUkok+tylJWsH6qUDOBKar6pGqehHYBWzp6bMFuKmZvh04N0ma9l1V9UJVPQpMN9vrZ5uSpBWsn99IuBZ4rGt+Bjhrrj5VdTDJs8AJTfvXe9Zd20zPt00AkmwDtjWzLyS5r499Hgf/AvjBcu/ECuB5eJnn4mWei5f97GJtuJ8AySxt1Wefudpnq3x6t9lprNoJ7ARIMlVVk3Pv6vjwXHR4Hl7muXiZ5+JlSaYWa9v93MKaAU7tmj8FeGKuPklWAccBTx9m3X62KUlawfoJkHuAjUk2JFlN56H47p4+u4HLm+mLgbuqqpr2rc23tDYAG4Fv9LlNSdIKNu8trOaZxlXAHcAEcGNV7U9yDTBVVbuBG4Cbk0zTqTy2NuvuT3IbcD9wELiyqg4BzLbNPvZ35xEf4ejyXHR4Hl7muXiZ5+Jli3Yu0ikUJEk6Mr6JLklqxQCRJLUyFAEyDsOeJDk1yVeTPJBkf5IPNu3HJ7kzyUPNn29o2pPkE805+XaSd3Rt6/Km/0NJLp/rM1eyZsSCbyb5YjO/oRkm56Fm2JzVTfsRD6MzbJK8PsntSR5sro9zxvi6+N3m38d9ST6X5KhxuTaS3Jjkye534QZ5HSR5Z5J7m3U+kWS21zBeqapW9A+dh+wPA28GVgP/A9i03Pu1CMd5MvCOZvoY4H/SGeblPwPbm/btwLXN9IXAl+m8a3M2sLdpPx54pPnzDc30G5b7+Fqcjw8DnwW+2MzfBmxtpv8U+O1m+t8Af9pMbwVubaY3NdfKGmBDcw1NLPdxtTwXNwH/upleDbx+HK8LOi8hPwq8tuua+Ffjcm0Avwy8A7ivq21g1wGdb8ie06zzZeCCefdpuU9KHyftHOCOrvkdwI7l3q8lOO6/Bc4DDgAnN20nAwea6T8DLu3qf6BZfinwZ13tr+g3DD903gv6e+BXgC82F/QPgFW91wSdb/Kd00yvavql9zrp7jdMP8CxzX8009M+jtfFSyNeHN/8XX8ROH+crg1gfU+ADOQ6aJY92NX+in5z/QzDLazZhlJZO0ffkdCU2mcAe4GTqup7AM2fb2y6zXVeRuF8fRz4feD/NfMnAP9cVQeb+e5jesUwOkD3MDrDfh6gU3k/BfxFc0vvM0mOZgyvi6p6HPgT4LvA9+j8Xe9jfK8NGNx1sLaZ7m0/rGEIkH6GUhkZSX4a+BvgQ1X1w8N1naXtcMPHDIUkvw48WVX7uptn6VrzLBvq89BlFZ3bFp+uqjOA5+ncqpjLyJ6P5v7+Fjq3nd4EHE1nRO9e43JtHM6RHnurczIMATI2w54k+Sk64XFLVX2+af5+kpOb5ScDTzbtozpMzLuA9yb5Dp1Rmn+FTkXy+nSGyYFXHtORDqMzbGaAmara28zfTidQxu26APhV4NGqeqqqfgx8HvhFxvfagMFdBzPNdG/7YQ1DgIzFsCfNNx5uAB6oquu6FnUPE3M5nWcjL7W/v/m2xdnAs00JewfwniRvaP6P7T1N21Coqh1VdUpVrafzd31XVb0P+CqdYXLg1efhSIbRGSpV9U/AY0leGlH1XDojO4zVddH4LnB2ktc1/15eOhdjeW00BnIdNMueS3J2c27f37WtuS33Q6E+HxxdSOdbSQ8DVy/3/izSMf4SnZLx28C3mp8L6dyz/XvgoebP45v+ofNLuR4G7gUmu7b1W3R+98o08JvLfWwLOCfv5uVvYb2Zzj/yaeC/Amua9qOa+elm+Zu71r+6OT8H6OMbJSv1B3g7MNVcG1+g8+2ZsbwugI8ADwL3ATfT+SbVWFwbwOfoPPv5MZ2K4YpBXgfAZHNeHwY+Sc8XN2b7cSgTSVIrw3ALS5K0AhkgkqRWDBBJUisGiCSpFQNEktSKASJJasUAkSS18v8BZZHMTD1Xiy4AAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "num_epochs = 10\n", "\n", "batch_step = -1\n", "collect_lr = []\n", "for e in range(num_epochs):\n", " for i in range(iter_per_ep):\n", " batch_step += 1\n", " cur_lr = cyclical_learning_rate(batch_step=batch_step,\n", " step_size=num_epochs*iter_per_ep)\n", " \n", " collect_lr.append(cur_lr)\n", " \n", "plt.scatter(range(len(collect_lr)), collect_lr)\n", "plt.ylim([0.0, 0.01])\n", "plt.xlim([0, num_epochs*iter_per_ep + 5000])\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "VUEUn70QWUyG" }, "source": [ "**Utility Functions**" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "colab": {}, "colab_type": "code", "id": "cQ46iEStWUyG" }, "outputs": [], "source": [ "def compute_accuracy(model, data_loader):\n", " correct_pred, num_examples = 0, 0\n", " for features, targets in data_loader:\n", " features = features.to(device)\n", " targets = targets.to(device)\n", " logits, probas = model(features)\n", " _, predicted_labels = torch.max(probas, 1)\n", " num_examples += targets.size(0)\n", " correct_pred += (predicted_labels == targets).sum()\n", " return correct_pred.float()/num_examples * 100" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "QL7DF29vWUyK" }, "source": [ "**Train Model/Run LR Range Test**" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 357 }, "colab_type": "code", "id": "WI-NY9u4WUyL", "outputId": "f323a1d5-0c7b-442d-8b7c-131c821b1cab" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Total batch # 0/3830 Curr. Batch Cost: 2.31266\n", "Total batch # 200/3830 Curr. Batch Cost: 2.30711\n", "Total batch # 400/3830 Curr. Batch Cost: 2.30392\n", "Total batch # 600/3830 Curr. Batch Cost: 2.30356\n", "Total batch # 800/3830 Curr. Batch Cost: 2.30203\n", "Total batch # 1000/3830 Curr. Batch Cost: 2.30223\n", "Total batch # 1200/3830 Curr. Batch Cost: 2.30101\n", "Total batch # 1400/3830 Curr. Batch Cost: 2.30159\n", "Total batch # 1600/3830 Curr. Batch Cost: 2.25974\n", "Total batch # 1800/3830 Curr. Batch Cost: 2.02467\n", "Total batch # 2000/3830 Curr. Batch Cost: 2.01952\n", "Total batch # 2200/3830 Curr. Batch Cost: 1.90831\n", "Total batch # 2400/3830 Curr. Batch Cost: 1.56817\n", "Total batch # 2600/3830 Curr. Batch Cost: 1.71451\n", "Total batch # 2800/3830 Curr. Batch Cost: 2.13523\n", "Total batch # 3000/3830 Curr. Batch Cost: 1.62590\n", "Total batch # 3200/3830 Curr. Batch Cost: 1.42501\n", "Total batch # 3400/3830 Curr. Batch Cost: 1.62436\n", "Total batch # 3600/3830 Curr. Batch Cost: 1.55984\n", "Total batch # 3800/3830 Curr. Batch Cost: 1.48068\n" ] } ], "source": [ "#################################\n", "### Setting for this run\n", "#################################\n", "\n", "num_epochs = 10\n", "iter_per_ep = len(train_loader)\n", "base_lr = 0.01\n", "max_lr = 0.2\n", "\n", "#################################\n", "### Init Model\n", "#################################\n", "\n", "torch.manual_seed(random_seed)\n", "model = ConvNet(num_classes=num_classes)\n", "model = model.to(device)\n", "\n", "##########################\n", "### COST AND OPTIMIZER\n", "##########################\n", "\n", "cost_fn = torch.nn.CrossEntropyLoss() \n", "optimizer = torch.optim.SGD(model.parameters(), lr=base_lr) \n", "\n", "########################################################################\n", "# Collect the data to be evaluated via the LR Range Test\n", "collect = {'lr': [], 'cost': [], 'train_batch_acc': [], 'val_acc': []}\n", "########################################################################\n", "\n", "\n", "batch_step = -1\n", "cur_lr = base_lr\n", "\n", "start_time = time.time()\n", "for epoch in range(num_epochs):\n", " for batch_idx, (features, targets) in enumerate(train_loader):\n", " \n", " batch_step += 1\n", " features = features.to(device)\n", " targets = targets.to(device)\n", "\n", " ### FORWARD AND BACK PROP\n", " logits, probas = model(features)\n", " cost = cost_fn(logits, targets)\n", " optimizer.zero_grad()\n", "\n", " cost.backward()\n", "\n", " ### UPDATE MODEL PARAMETERS\n", " optimizer.step()\n", " \n", " #############################################\n", " # Logging\n", " if not batch_step % 200:\n", " print('Total batch # %5d/%d' % (batch_step, \n", " iter_per_ep*num_epochs), \n", " end='')\n", " print(' Curr. Batch Cost: %.5f' % cost) \n", "\n", " #############################################\n", " # Collect stats \n", " model = model.eval()\n", " train_acc = compute_accuracy(model, [[features, targets]])\n", " val_acc = compute_accuracy(model, val_loader)\n", " collect['lr'].append(cur_lr)\n", " collect['train_batch_acc'].append(train_acc)\n", " collect['val_acc'].append(val_acc)\n", " collect['cost'].append(cost)\n", " model = model.train()\n", " #############################################\n", " # update learning rate\n", " cur_lr = cyclical_learning_rate(batch_step=batch_step,\n", " step_size=num_epochs*iter_per_ep,\n", " base_lr=base_lr,\n", " max_lr=max_lr)\n", " for g in optimizer.param_groups:\n", " g['lr'] = cur_lr\n", " ############################################\n", "\n", " print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))\n", " \n", "print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 541 }, "colab_type": "code", "id": "_hOJ7z1-WUyN", "outputId": "148987b0-093b-4f22-a3b4-acd4cb9bddee" }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.plot(collect['lr'], collect['train_batch_acc'], label='train_batch_acc')\n", "plt.plot(collect['lr'], collect['val_acc'], label='val_acc')\n", "plt.xlabel('Learning Rate')\n", "plt.ylabel('Accuracy')\n", "plt.legend()\n", "plt.show()\n", "\n", "\n", "plt.plot(collect['lr'], collect['cost'])\n", "plt.xlabel('Learning Rate')\n", "plt.ylabel('Current Batch Cost')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "d07DrlrjOHzC" }, "source": [ "Looking at the graphs above, in particular the validation accuracy, it's not immediately obvious to find the 2 points:\n", "\n", "1. Where the accuracy starts increasing\n", "2. Where the accuracy starts dropping or ceases to improve\n", "\n", "However, point 1) may be at 0.08-0.09, and point 2) may be at 0.175 or even 0.2 (or even beyond that, if we would keep increasing the learning rate).\n", "\n", "\n", "Also note that this heuristic is less \"clean\" as starting the epoch from scratch with each incremental learning rate change, which adds addtional noise to the interpretation (including questions like \"by how much did the cost drop/accuracy improve just because of going downhill on the cost surface and the gradients becoming smaller?\")" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "KcZ2vv8foEcP" }, "source": [ "## Train with Cyclical Learning Rate (`triangular`)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "FvnNfSIfPiw4" }, "source": [ "Below, the triangular (default) cyclical learning rate training procedure is run with a `base_lr=0.09` and `max_lr=0.175`. Based on the LR Range Tests graphs above, a `max_lr` >= 0.2 may even be reasonable. However, in practice (based on my experience and some trial runs with these settings), such large learning rates would increase convergence problems using a vanilla SGD optimizer (as it is done here). " ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 2567 }, "colab_type": "code", "id": "W_MyHO0HoIz8", "outputId": "ace8f950-c37b-49a7-98ff-7849185d6892" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1 | Train/Valid Acc: 10.12/11.50\n", "Epoch 2 | Train/Valid Acc: 11.56/11.80\n", "Epoch 3 | Train/Valid Acc: 25.08/23.20\n", "Epoch 4 | Train/Valid Acc: 29.49/30.20\n", "Epoch 5 | Train/Valid Acc: 39.37/38.70\n", "Epoch 6 | Train/Valid Acc: 41.30/39.30\n", "Epoch 7 | Train/Valid Acc: 39.99/36.40\n", "Epoch 8 | Train/Valid Acc: 43.76/41.30\n", "Epoch 9 | Train/Valid Acc: 49.54/47.30\n", "Epoch 10 | Train/Valid Acc: 50.19/48.70\n", "Epoch 11 | Train/Valid Acc: 52.60/48.90\n", "Epoch 12 | Train/Valid Acc: 53.14/52.00\n", "Epoch 13 | Train/Valid Acc: 55.03/52.20\n", "Epoch 14 | Train/Valid Acc: 52.50/51.50\n", "Epoch 15 | Train/Valid Acc: 56.78/54.20\n", "Epoch 16 | Train/Valid Acc: 57.11/54.60\n", "Epoch 17 | Train/Valid Acc: 53.47/51.40\n", "Epoch 18 | Train/Valid Acc: 58.84/54.80\n", "Epoch 19 | Train/Valid Acc: 60.68/57.80\n", "Epoch 20 | Train/Valid Acc: 59.67/55.20\n", "Epoch 21 | Train/Valid Acc: 60.42/55.40\n", "Epoch 22 | Train/Valid Acc: 62.05/57.10\n", "Epoch 23 | Train/Valid Acc: 58.89/53.80\n", "Epoch 24 | Train/Valid Acc: 61.83/58.90\n", "Epoch 25 | Train/Valid Acc: 64.20/57.50\n", "Epoch 26 | Train/Valid Acc: 62.74/56.80\n", "Epoch 27 | Train/Valid Acc: 60.05/56.20\n", "Epoch 28 | Train/Valid Acc: 61.31/56.20\n", "Epoch 29 | Train/Valid Acc: 65.26/59.30\n", "Epoch 30 | Train/Valid Acc: 64.57/59.40\n", "Epoch 31 | Train/Valid Acc: 59.70/55.20\n", "Epoch 32 | Train/Valid Acc: 60.00/57.10\n", "Epoch 33 | Train/Valid Acc: 64.72/59.20\n", "Epoch 34 | Train/Valid Acc: 60.15/54.80\n", "Epoch 35 | Train/Valid Acc: 60.73/55.60\n", "Epoch 36 | Train/Valid Acc: 64.78/59.90\n", "Epoch 37 | Train/Valid Acc: 64.26/60.10\n", "Epoch 38 | Train/Valid Acc: 65.21/59.50\n", "Epoch 39 | Train/Valid Acc: 64.77/56.70\n", "Epoch 40 | Train/Valid Acc: 63.09/55.20\n", "Epoch 41 | Train/Valid Acc: 66.14/60.90\n", "Epoch 42 | Train/Valid Acc: 67.17/62.50\n", "Epoch 43 | Train/Valid Acc: 60.02/54.30\n", "Epoch 44 | Train/Valid Acc: 64.91/58.80\n", "Epoch 45 | Train/Valid Acc: 67.85/60.30\n", "Epoch 46 | Train/Valid Acc: 64.30/58.10\n", "Epoch 47 | Train/Valid Acc: 64.16/58.30\n", "Epoch 48 | Train/Valid Acc: 68.18/62.10\n", "Epoch 49 | Train/Valid Acc: 62.60/58.40\n", "Epoch 50 | Train/Valid Acc: 66.92/60.50\n", "Epoch 51 | Train/Valid Acc: 63.54/57.20\n", "Epoch 52 | Train/Valid Acc: 68.29/61.20\n", "Epoch 53 | Train/Valid Acc: 67.00/61.30\n", "Epoch 54 | Train/Valid Acc: 66.01/60.70\n", "Epoch 55 | Train/Valid Acc: 66.48/61.50\n", "Epoch 56 | Train/Valid Acc: 61.47/56.10\n", "Epoch 57 | Train/Valid Acc: 66.07/60.30\n", "Epoch 58 | Train/Valid Acc: 67.73/59.40\n", "Epoch 59 | Train/Valid Acc: 62.46/58.90\n", "Epoch 60 | Train/Valid Acc: 66.94/61.40\n", "Epoch 61 | Train/Valid Acc: 69.32/64.10\n", "Epoch 62 | Train/Valid Acc: 57.43/54.50\n", "Epoch 63 | Train/Valid Acc: 67.81/60.60\n", "Epoch 64 | Train/Valid Acc: 64.76/59.30\n", "Epoch 65 | Train/Valid Acc: 69.20/62.60\n", "Epoch 66 | Train/Valid Acc: 66.62/61.40\n", "Epoch 67 | Train/Valid Acc: 64.40/60.00\n", "Epoch 68 | Train/Valid Acc: 66.40/62.20\n", "Epoch 69 | Train/Valid Acc: 68.49/62.80\n", "Epoch 70 | Train/Valid Acc: 66.81/61.30\n", "Epoch 71 | Train/Valid Acc: 67.85/62.10\n", "Epoch 72 | Train/Valid Acc: 68.71/62.10\n", "Epoch 73 | Train/Valid Acc: 66.94/61.90\n", "Epoch 74 | Train/Valid Acc: 69.00/62.30\n", "Epoch 75 | Train/Valid Acc: 65.22/61.90\n", "Epoch 76 | Train/Valid Acc: 65.86/60.70\n", "Epoch 77 | Train/Valid Acc: 70.45/62.80\n", "Epoch 78 | Train/Valid Acc: 63.32/57.10\n", "Epoch 79 | Train/Valid Acc: 68.13/59.80\n", "Epoch 80 | Train/Valid Acc: 69.84/64.40\n", "Epoch 81 | Train/Valid Acc: 69.26/63.70\n", "Epoch 82 | Train/Valid Acc: 66.01/61.60\n", "Epoch 83 | Train/Valid Acc: 70.93/65.30\n", "Epoch 84 | Train/Valid Acc: 69.66/62.10\n", "Epoch 85 | Train/Valid Acc: 65.53/61.50\n", "Epoch 86 | Train/Valid Acc: 67.92/60.20\n", "Epoch 87 | Train/Valid Acc: 67.67/63.10\n", "Epoch 88 | Train/Valid Acc: 64.33/59.40\n", "Epoch 89 | Train/Valid Acc: 66.37/58.60\n", "Epoch 90 | Train/Valid Acc: 63.32/56.20\n", "Epoch 91 | Train/Valid Acc: 67.35/61.30\n", "Epoch 92 | Train/Valid Acc: 69.12/62.00\n", "Epoch 93 | Train/Valid Acc: 69.93/62.90\n", "Epoch 94 | Train/Valid Acc: 66.52/60.70\n", "Epoch 95 | Train/Valid Acc: 69.41/61.80\n", "Epoch 96 | Train/Valid Acc: 67.85/62.50\n", "Epoch 97 | Train/Valid Acc: 70.32/63.60\n", "Epoch 98 | Train/Valid Acc: 69.32/62.90\n", "Epoch 99 | Train/Valid Acc: 68.90/60.30\n", "Epoch 100 | Train/Valid Acc: 69.61/61.80\n", "Epoch 101 | Train/Valid Acc: 67.63/62.10\n", "Epoch 102 | Train/Valid Acc: 68.18/61.60\n", "Epoch 103 | Train/Valid Acc: 71.05/62.40\n", "Epoch 104 | Train/Valid Acc: 71.27/63.30\n", "Epoch 105 | Train/Valid Acc: 67.66/62.30\n", "Epoch 106 | Train/Valid Acc: 70.37/61.60\n", "Epoch 107 | Train/Valid Acc: 65.84/63.10\n", "Epoch 108 | Train/Valid Acc: 72.44/63.80\n", "Epoch 109 | Train/Valid Acc: 71.44/62.50\n", "Epoch 110 | Train/Valid Acc: 69.01/61.90\n", "Epoch 111 | Train/Valid Acc: 69.38/60.80\n", "Epoch 112 | Train/Valid Acc: 70.99/65.00\n", "Epoch 113 | Train/Valid Acc: 67.42/59.40\n", "Epoch 114 | Train/Valid Acc: 68.88/61.40\n", "Epoch 115 | Train/Valid Acc: 70.59/61.90\n", "Epoch 116 | Train/Valid Acc: 64.71/58.50\n", "Epoch 117 | Train/Valid Acc: 67.19/61.20\n", "Epoch 118 | Train/Valid Acc: 68.88/61.70\n", "Epoch 119 | Train/Valid Acc: 69.34/62.70\n", "Epoch 120 | Train/Valid Acc: 66.37/62.10\n", "Epoch 121 | Train/Valid Acc: 66.52/59.90\n", "Epoch 122 | Train/Valid Acc: 69.42/61.30\n", "Epoch 123 | Train/Valid Acc: 51.75/47.70\n", "Epoch 124 | Train/Valid Acc: 70.67/62.90\n", "Epoch 125 | Train/Valid Acc: 71.86/63.10\n", "Epoch 126 | Train/Valid Acc: 71.21/63.20\n", "Epoch 127 | Train/Valid Acc: 72.21/63.60\n", "Epoch 128 | Train/Valid Acc: 69.04/62.30\n", "Epoch 129 | Train/Valid Acc: 67.66/59.60\n", "Epoch 130 | Train/Valid Acc: 69.09/61.50\n", "Epoch 131 | Train/Valid Acc: 64.01/57.00\n", "Epoch 132 | Train/Valid Acc: 69.79/61.50\n", "Epoch 133 | Train/Valid Acc: 66.73/60.80\n", "Epoch 134 | Train/Valid Acc: 65.47/57.60\n", "Epoch 135 | Train/Valid Acc: 68.09/59.90\n", "Epoch 136 | Train/Valid Acc: 64.38/58.50\n", "Epoch 137 | Train/Valid Acc: 70.52/61.70\n", "Epoch 138 | Train/Valid Acc: 68.28/61.60\n", "Epoch 139 | Train/Valid Acc: 67.66/60.60\n", "Epoch 140 | Train/Valid Acc: 70.20/62.20\n", "Epoch 141 | Train/Valid Acc: 71.66/63.10\n", "Epoch 142 | Train/Valid Acc: 64.78/57.40\n", "Epoch 143 | Train/Valid Acc: 63.49/56.70\n", "Epoch 144 | Train/Valid Acc: 72.68/63.30\n", "Epoch 145 | Train/Valid Acc: 70.93/62.50\n", "Epoch 146 | Train/Valid Acc: 70.51/61.90\n", "Epoch 147 | Train/Valid Acc: 72.00/61.60\n", "Epoch 148 | Train/Valid Acc: 69.59/60.70\n", "Epoch 149 | Train/Valid Acc: 71.24/60.70\n", "Epoch 150 | Train/Valid Acc: 69.91/59.60\n" ] } ], "source": [ "#################################\n", "### Setting for this run\n", "#################################\n", "\n", "num_epochs = 150\n", "iter_per_ep = len(train_loader.sampler.indices) // train_loader.batch_size\n", "base_lr = 0.09\n", "max_lr = 0.175\n", "\n", "#################################\n", "### Init Model\n", "#################################\n", "\n", "torch.manual_seed(random_seed)\n", "model = ConvNet(num_classes=num_classes)\n", "model = model.to(device)\n", "\n", "##########################\n", "### COST AND OPTIMIZER\n", "##########################\n", "\n", "cost_fn = torch.nn.CrossEntropyLoss() \n", "optimizer = torch.optim.SGD(model.parameters(), lr=base_lr) \n", "\n", "########################################################################\n", "# Collect the data to be evaluated via the LR Range Test\n", "collect = {'epoch': [], 'cost': [], 'train_acc': [], 'val_acc': []}\n", "########################################################################\n", "\n", "start_time = time.time()\n", "for epoch in range(num_epochs):\n", " epoch_avg_cost = 0.\n", " model = model.train()\n", " for batch_idx, (features, targets) in enumerate(train_loader):\n", " \n", " features = features.to(device)\n", " targets = targets.to(device)\n", "\n", " ### FORWARD AND BACK PROP\n", " logits, probas = model(features)\n", " cost = cost_fn(logits, targets)\n", " optimizer.zero_grad()\n", "\n", " cost.backward()\n", "\n", " ### UPDATE MODEL PARAMETERS\n", " optimizer.step()\n", "\n", " epoch_avg_cost += cost\n", " \n", " #############################################\n", " # Logging\n", " if not batch_step % 600:\n", " print('Batch %5d/%d' % (batch_step, iter_per_ep*num_epochs),\n", " end='')\n", " print(' Cost: %.5f' % cost) \n", "\n", "\n", " #############################################\n", " # Collect stats \n", " model = model.eval()\n", " train_acc = compute_accuracy(model, train_loader)\n", " val_acc = compute_accuracy(model, val_loader)\n", " epoch_avg_cost /= batch_idx+1\n", " collect['epoch'].append(epoch+1)\n", " collect['val_acc'].append(val_acc)\n", " collect['train_acc'].append(train_acc)\n", " collect['cost'].append(epoch_avg_cost / iter_per_ep)\n", " \n", " ################################################\n", " # Logging\n", " print('Epoch %3d' % (epoch+1), end='')\n", " print(' | Train/Valid Acc: %.2f/%.2f' % (train_acc, val_acc))\n", " \n", " \n", " #############################################\n", " # update learning rate\n", " base_lr = cyclical_learning_rate(batch_step=batch_step,\n", " step_size=num_epochs*iter_per_ep,\n", " base_lr=base_lr,\n", " max_lr=max_lr)\n", " for g in optimizer.param_groups:\n", " g['lr'] = base_lr\n", " ############################################\n", " \n", "\n", " print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))\n", " \n", "print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 541 }, "colab_type": "code", "id": "P_3X9h9CrQwD", "outputId": "011badf6-a190-4880-8377-eaa5ded3f199" }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.plot(collect['epoch'], collect['train_acc'], label='train_acc')\n", "plt.plot(collect['epoch'], collect['val_acc'], label='val_acc')\n", "plt.xlabel('Epoch')\n", "plt.ylabel('Accuracy')\n", "plt.legend()\n", "plt.show()\n", "\n", "\n", "plt.plot(collect['epoch'], collect['cost'])\n", "plt.xlabel('Epoch')\n", "plt.ylabel('Avg. Cost Per Epoch')\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 34 }, "colab_type": "code", "id": "-28cbD9hKp56", "outputId": "93d97e13-00a0-4f1f-b98b-66ca33f276f2" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Test accuracy: 61.45%\n" ] } ], "source": [ "print('Test accuracy: %.2f%%' % (compute_accuracy(model, test_loader)))" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "kTtRXkn5daq7" }, "source": [ "## Train with Cyclical Learning Rate (`triangular2`)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 3091 }, "colab_type": "code", "id": "zS-lrQx3QYAH", "outputId": "3640920f-82ac-4489-b52e-15616bb442fd" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1 | Train/Valid Acc: 10.10/11.50\n", "Epoch 2 | Train/Valid Acc: 11.65/11.70\n", "Epoch 3 | Train/Valid Acc: 25.36/23.20\n", "Epoch 4 | Train/Valid Acc: 23.71/25.20\n", "Epoch 5 | Train/Valid Acc: 39.76/39.30\n", "Epoch 6 | Train/Valid Acc: 41.19/38.80\n", "Epoch 7 | Train/Valid Acc: 40.50/37.70\n", "Epoch 8 | Train/Valid Acc: 44.18/41.50\n", "Epoch 9 | Train/Valid Acc: 49.93/47.00\n", "Epoch 10 | Train/Valid Acc: 51.51/49.00\n", "Epoch 11 | Train/Valid Acc: 52.48/49.50\n", "Epoch 12 | Train/Valid Acc: 51.53/50.80\n", "Epoch 13 | Train/Valid Acc: 54.98/52.80\n", "Epoch 14 | Train/Valid Acc: 47.98/46.50\n", "Epoch 15 | Train/Valid Acc: 57.34/54.80\n", "Epoch 16 | Train/Valid Acc: 57.44/54.20\n", "Epoch 17 | Train/Valid Acc: 55.66/53.80\n", "Epoch 18 | Train/Valid Acc: 59.47/55.10\n", "Epoch 19 | Train/Valid Acc: 58.14/55.20\n", "Epoch 20 | Train/Valid Acc: 58.70/55.30\n", "Epoch 21 | Train/Valid Acc: 58.28/54.00\n", "Epoch 22 | Train/Valid Acc: 61.68/58.00\n", "Epoch 23 | Train/Valid Acc: 59.49/56.80\n", "Epoch 24 | Train/Valid Acc: 61.95/56.70\n", "Epoch 25 | Train/Valid Acc: 58.73/54.00\n", "Epoch 26 | Train/Valid Acc: 58.40/53.50\n", "Epoch 27 | Train/Valid Acc: 59.80/56.90\n", "Epoch 28 | Train/Valid Acc: 61.06/56.60\n", "Epoch 29 | Train/Valid Acc: 59.48/55.10\n", "Epoch 30 | Train/Valid Acc: 63.96/59.20\n", "Epoch 31 | Train/Valid Acc: 61.43/56.90\n", "Epoch 32 | Train/Valid Acc: 64.08/59.30\n", "Epoch 33 | Train/Valid Acc: 64.98/59.70\n", "Epoch 34 | Train/Valid Acc: 56.01/51.10\n", "Epoch 35 | Train/Valid Acc: 64.41/59.80\n", "Epoch 36 | Train/Valid Acc: 64.44/60.50\n", "Epoch 37 | Train/Valid Acc: 64.06/59.20\n", "Epoch 38 | Train/Valid Acc: 63.14/58.60\n", "Epoch 39 | Train/Valid Acc: 63.49/61.10\n", "Epoch 40 | Train/Valid Acc: 65.99/60.90\n", "Epoch 41 | Train/Valid Acc: 64.38/57.70\n", "Epoch 42 | Train/Valid Acc: 63.68/59.10\n", "Epoch 43 | Train/Valid Acc: 63.98/58.10\n", "Epoch 44 | Train/Valid Acc: 65.07/59.50\n", "Epoch 45 | Train/Valid Acc: 64.54/60.70\n", "Epoch 46 | Train/Valid Acc: 66.89/62.10\n", "Epoch 47 | Train/Valid Acc: 63.51/59.10\n", "Epoch 48 | Train/Valid Acc: 66.54/61.20\n", "Epoch 49 | Train/Valid Acc: 66.38/60.70\n", "Epoch 50 | Train/Valid Acc: 66.54/60.20\n", "Epoch 51 | Train/Valid Acc: 65.31/58.40\n", "Epoch 52 | Train/Valid Acc: 64.59/60.60\n", "Epoch 53 | Train/Valid Acc: 66.46/60.70\n", "Epoch 54 | Train/Valid Acc: 61.76/57.50\n", "Epoch 55 | Train/Valid Acc: 67.43/62.70\n", "Epoch 56 | Train/Valid Acc: 66.12/61.10\n", "Epoch 57 | Train/Valid Acc: 66.77/61.10\n", "Epoch 58 | Train/Valid Acc: 67.22/60.30\n", "Epoch 59 | Train/Valid Acc: 66.94/62.30\n", "Epoch 60 | Train/Valid Acc: 68.32/61.60\n", "Epoch 61 | Train/Valid Acc: 67.54/62.10\n", "Epoch 62 | Train/Valid Acc: 65.06/61.10\n", "Epoch 63 | Train/Valid Acc: 63.50/59.20\n", "Epoch 64 | Train/Valid Acc: 66.59/60.80\n", "Epoch 65 | Train/Valid Acc: 69.35/60.90\n", "Epoch 66 | Train/Valid Acc: 67.44/62.80\n", "Epoch 67 | Train/Valid Acc: 67.28/62.50\n", "Epoch 68 | Train/Valid Acc: 68.13/62.00\n", "Epoch 69 | Train/Valid Acc: 65.02/60.00\n", "Epoch 70 | Train/Valid Acc: 68.35/63.00\n", "Epoch 71 | Train/Valid Acc: 61.92/57.70\n", "Epoch 72 | Train/Valid Acc: 68.02/62.10\n", "Epoch 73 | Train/Valid Acc: 67.62/61.50\n", "Epoch 74 | Train/Valid Acc: 68.39/61.50\n", "Epoch 75 | Train/Valid Acc: 66.11/60.70\n", "Epoch 76 | Train/Valid Acc: 60.31/56.90\n", "Epoch 77 | Train/Valid Acc: 67.06/61.60\n", "Epoch 78 | Train/Valid Acc: 66.69/61.00\n", "Epoch 79 | Train/Valid Acc: 68.88/62.70\n", "Epoch 80 | Train/Valid Acc: 53.01/48.80\n", "Epoch 81 | Train/Valid Acc: 70.58/63.00\n", "Epoch 82 | Train/Valid Acc: 66.57/60.00\n", "Epoch 83 | Train/Valid Acc: 62.87/57.30\n", "Epoch 84 | Train/Valid Acc: 69.49/61.50\n", "Epoch 85 | Train/Valid Acc: 66.03/60.40\n", "Epoch 86 | Train/Valid Acc: 68.34/63.10\n", "Epoch 87 | Train/Valid Acc: 69.02/60.90\n", "Epoch 88 | Train/Valid Acc: 65.63/60.30\n", "Epoch 89 | Train/Valid Acc: 62.16/56.80\n", "Epoch 90 | Train/Valid Acc: 58.92/56.50\n", "Epoch 91 | Train/Valid Acc: 70.52/63.90\n", "Epoch 92 | Train/Valid Acc: 69.29/62.90\n", "Epoch 93 | Train/Valid Acc: 69.67/62.70\n", "Epoch 94 | Train/Valid Acc: 69.38/62.00\n", "Epoch 95 | Train/Valid Acc: 68.55/62.00\n", "Epoch 96 | Train/Valid Acc: 69.87/63.00\n", "Epoch 97 | Train/Valid Acc: 67.04/60.20\n", "Epoch 98 | Train/Valid Acc: 64.95/60.00\n", "Epoch 99 | Train/Valid Acc: 67.18/61.30\n", "Epoch 100 | Train/Valid Acc: 69.53/60.40\n", "Epoch 101 | Train/Valid Acc: 68.25/62.00\n", "Epoch 102 | Train/Valid Acc: 66.14/60.40\n", "Epoch 103 | Train/Valid Acc: 70.77/63.10\n", "Epoch 104 | Train/Valid Acc: 66.72/58.20\n", "Epoch 105 | Train/Valid Acc: 67.28/61.10\n", "Epoch 106 | Train/Valid Acc: 69.26/61.80\n", "Epoch 107 | Train/Valid Acc: 70.49/63.00\n", "Epoch 108 | Train/Valid Acc: 68.44/61.50\n", "Epoch 109 | Train/Valid Acc: 69.62/62.20\n", "Epoch 110 | Train/Valid Acc: 65.51/57.40\n", "Epoch 111 | Train/Valid Acc: 67.62/60.60\n", "Epoch 112 | Train/Valid Acc: 68.98/62.60\n", "Epoch 113 | Train/Valid Acc: 67.35/61.70\n", "Epoch 114 | Train/Valid Acc: 63.91/57.20\n", "Epoch 115 | Train/Valid Acc: 69.48/59.50\n", "Epoch 116 | Train/Valid Acc: 67.54/60.90\n", "Epoch 117 | Train/Valid Acc: 64.29/58.20\n", "Epoch 118 | Train/Valid Acc: 68.95/61.50\n", "Epoch 119 | Train/Valid Acc: 69.82/60.40\n", "Epoch 120 | Train/Valid Acc: 68.28/60.70\n", "Epoch 121 | Train/Valid Acc: 67.40/58.80\n", "Epoch 122 | Train/Valid Acc: 68.32/61.40\n", "Epoch 123 | Train/Valid Acc: 71.35/61.70\n", "Epoch 124 | Train/Valid Acc: 69.96/60.80\n", "Epoch 125 | Train/Valid Acc: 69.93/61.90\n", "Epoch 126 | Train/Valid Acc: 70.48/61.20\n", "Epoch 127 | Train/Valid Acc: 65.93/58.40\n", "Epoch 128 | Train/Valid Acc: 66.86/61.10\n", "Epoch 129 | Train/Valid Acc: 69.40/60.50\n", "Epoch 130 | Train/Valid Acc: 71.33/61.00\n", "Epoch 131 | Train/Valid Acc: 70.79/61.50\n", "Epoch 132 | Train/Valid Acc: 67.92/60.80\n", "Epoch 133 | Train/Valid Acc: 68.64/61.50\n", "Epoch 134 | Train/Valid Acc: 65.79/59.10\n", "Epoch 135 | Train/Valid Acc: 69.58/62.90\n", "Epoch 136 | Train/Valid Acc: 69.36/62.00\n", "Epoch 137 | Train/Valid Acc: 65.36/61.50\n", "Epoch 138 | Train/Valid Acc: 67.90/60.50\n", "Epoch 139 | Train/Valid Acc: 66.31/58.10\n", "Epoch 140 | Train/Valid Acc: 71.86/63.60\n", "Epoch 141 | Train/Valid Acc: 63.20/58.50\n", "Epoch 142 | Train/Valid Acc: 68.61/59.60\n", "Epoch 143 | Train/Valid Acc: 68.63/60.70\n", "Epoch 144 | Train/Valid Acc: 69.86/61.70\n", "Epoch 145 | Train/Valid Acc: 65.65/60.60\n", "Epoch 146 | Train/Valid Acc: 69.74/61.10\n", "Epoch 147 | Train/Valid Acc: 68.24/59.10\n", "Epoch 148 | Train/Valid Acc: 66.41/58.70\n", "Epoch 149 | Train/Valid Acc: 63.01/57.50\n", "Epoch 150 | Train/Valid Acc: 70.22/62.70\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "#################################\n", "### Setting for this run\n", "#################################\n", "\n", "num_epochs = 150\n", "iter_per_ep = len(train_loader.sampler.indices) // train_loader.batch_size\n", "base_lr = 0.09\n", "max_lr = 0.175\n", "\n", "#################################\n", "### Init Model\n", "#################################\n", "\n", "torch.manual_seed(random_seed)\n", "model = ConvNet(num_classes=num_classes)\n", "model = model.to(device)\n", "\n", "##########################\n", "### COST AND OPTIMIZER\n", "##########################\n", "\n", "cost_fn = torch.nn.CrossEntropyLoss() \n", "optimizer = torch.optim.SGD(model.parameters(), lr=base_lr) \n", "\n", "########################################################################\n", "# Collect the data to be evaluated via the LR Range Test\n", "collect = {'epoch': [], 'cost': [], 'train_acc': [], 'val_acc': []}\n", "########################################################################\n", "\n", "start_time = time.time()\n", "for epoch in range(num_epochs):\n", " epoch_avg_cost = 0.\n", " model = model.train()\n", " for batch_idx, (features, targets) in enumerate(train_loader):\n", " \n", " features = features.to(device)\n", " targets = targets.to(device)\n", "\n", " ### FORWARD AND BACK PROP\n", " logits, probas = model(features)\n", " cost = cost_fn(logits, targets)\n", " optimizer.zero_grad()\n", "\n", " cost.backward()\n", "\n", " ### UPDATE MODEL PARAMETERS\n", " optimizer.step()\n", "\n", " epoch_avg_cost += cost\n", " \n", " #############################################\n", " # Logging\n", " if not batch_step % 600:\n", " print('Batch %5d/%d' % (batch_step, iter_per_ep*num_epochs),\n", " end='')\n", " print(' Cost: %.5f' % cost) \n", "\n", "\n", " #############################################\n", " # Collect stats \n", " model = model.eval()\n", " train_acc = compute_accuracy(model, train_loader)\n", " val_acc = compute_accuracy(model, val_loader)\n", " epoch_avg_cost /= batch_idx+1\n", " collect['epoch'].append(epoch+1)\n", " collect['val_acc'].append(val_acc)\n", " collect['train_acc'].append(train_acc)\n", " collect['cost'].append(epoch_avg_cost / iter_per_ep)\n", " \n", " ################################################\n", " # Logging\n", " print('Epoch %3d' % (epoch+1), end='')\n", " print(' | Train/Valid Acc: %.2f/%.2f' % (train_acc, val_acc))\n", " \n", " \n", " #############################################\n", " # update learning rate\n", " base_lr = cyclical_learning_rate(batch_step=batch_step,\n", " step_size=num_epochs*iter_per_ep,\n", " base_lr=base_lr,\n", " max_lr=max_lr,\n", " mode='triangular2')\n", " for g in optimizer.param_groups:\n", " g['lr'] = base_lr\n", " ############################################\n", "\n", " print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))\n", " \n", "print('Total Training Time: %.2f min' % ((time.time() - start_time)/60)) \n", " \n", "\n", "plt.plot(collect['epoch'], collect['train_acc'], label='train_acc')\n", "plt.plot(collect['epoch'], collect['val_acc'], label='val_acc')\n", "plt.xlabel('Epoch')\n", "plt.ylabel('Accuracy')\n", "plt.legend()\n", "plt.show()\n", "\n", "\n", "plt.plot(collect['epoch'], collect['cost'])\n", "plt.xlabel('Epoch')\n", "plt.ylabel('Avg. Cost Per Epoch')\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 34 }, "colab_type": "code", "id": "FAIyUHvzZUK_", "outputId": "830250cd-62eb-48e7-fdea-f3855d5d9baf" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Test accuracy: 61.69%\n" ] } ], "source": [ "print('Test accuracy: %.2f%%' % (compute_accuracy(model, test_loader)))" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "903jZNgGdcjV" }, "source": [ "## Train with Cyclical Learning Rate (`exp_range`)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 3091 }, "colab_type": "code", "id": "Y6_0-ygOdbYd", "outputId": "f72b602e-a3d5-466f-a6e5-2be69f537a13" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1 | Train/Valid Acc: 10.17/11.60\n", "Epoch 2 | Train/Valid Acc: 11.49/11.70\n", "Epoch 3 | Train/Valid Acc: 25.78/23.30\n", "Epoch 4 | Train/Valid Acc: 26.36/27.90\n", "Epoch 5 | Train/Valid Acc: 39.37/39.00\n", "Epoch 6 | Train/Valid Acc: 42.89/40.80\n", "Epoch 7 | Train/Valid Acc: 39.37/36.00\n", "Epoch 8 | Train/Valid Acc: 45.47/44.20\n", "Epoch 9 | Train/Valid Acc: 50.02/48.40\n", "Epoch 10 | Train/Valid Acc: 50.26/47.70\n", "Epoch 11 | Train/Valid Acc: 51.07/50.40\n", "Epoch 12 | Train/Valid Acc: 52.54/51.00\n", "Epoch 13 | Train/Valid Acc: 56.30/54.70\n", "Epoch 14 | Train/Valid Acc: 53.49/51.60\n", "Epoch 15 | Train/Valid Acc: 57.27/53.60\n", "Epoch 16 | Train/Valid Acc: 54.77/52.10\n", "Epoch 17 | Train/Valid Acc: 54.57/50.60\n", "Epoch 18 | Train/Valid Acc: 56.87/52.60\n", "Epoch 19 | Train/Valid Acc: 60.48/58.10\n", "Epoch 20 | Train/Valid Acc: 57.57/55.80\n", "Epoch 21 | Train/Valid Acc: 59.53/55.70\n", "Epoch 22 | Train/Valid Acc: 61.76/58.70\n", "Epoch 23 | Train/Valid Acc: 56.61/54.20\n", "Epoch 24 | Train/Valid Acc: 60.27/58.40\n", "Epoch 25 | Train/Valid Acc: 62.83/58.20\n", "Epoch 26 | Train/Valid Acc: 63.13/58.80\n", "Epoch 27 | Train/Valid Acc: 51.37/47.80\n", "Epoch 28 | Train/Valid Acc: 63.18/59.70\n", "Epoch 29 | Train/Valid Acc: 63.26/58.50\n", "Epoch 30 | Train/Valid Acc: 59.02/57.40\n", "Epoch 31 | Train/Valid Acc: 62.49/58.00\n", "Epoch 32 | Train/Valid Acc: 64.12/57.00\n", "Epoch 33 | Train/Valid Acc: 64.96/59.50\n", "Epoch 34 | Train/Valid Acc: 59.77/55.10\n", "Epoch 35 | Train/Valid Acc: 64.57/59.50\n", "Epoch 36 | Train/Valid Acc: 61.37/58.30\n", "Epoch 37 | Train/Valid Acc: 65.78/61.20\n", "Epoch 38 | Train/Valid Acc: 64.57/59.20\n", "Epoch 39 | Train/Valid Acc: 65.71/60.10\n", "Epoch 40 | Train/Valid Acc: 64.29/60.10\n", "Epoch 41 | Train/Valid Acc: 63.83/58.10\n", "Epoch 42 | Train/Valid Acc: 67.49/59.90\n", "Epoch 43 | Train/Valid Acc: 58.37/55.30\n", "Epoch 44 | Train/Valid Acc: 66.92/62.40\n", "Epoch 45 | Train/Valid Acc: 66.80/62.10\n", "Epoch 46 | Train/Valid Acc: 64.27/56.40\n", "Epoch 47 | Train/Valid Acc: 61.46/57.00\n", "Epoch 48 | Train/Valid Acc: 67.49/60.00\n", "Epoch 49 | Train/Valid Acc: 66.78/60.60\n", "Epoch 50 | Train/Valid Acc: 67.60/61.40\n", "Epoch 51 | Train/Valid Acc: 61.56/56.70\n", "Epoch 52 | Train/Valid Acc: 68.42/60.00\n", "Epoch 53 | Train/Valid Acc: 65.98/60.40\n", "Epoch 54 | Train/Valid Acc: 64.53/59.60\n", "Epoch 55 | Train/Valid Acc: 67.92/61.50\n", "Epoch 56 | Train/Valid Acc: 66.49/59.70\n", "Epoch 57 | Train/Valid Acc: 60.37/54.90\n", "Epoch 58 | Train/Valid Acc: 66.04/60.20\n", "Epoch 59 | Train/Valid Acc: 64.52/58.00\n", "Epoch 60 | Train/Valid Acc: 69.13/62.10\n", "Epoch 61 | Train/Valid Acc: 65.04/59.70\n", "Epoch 62 | Train/Valid Acc: 68.39/60.10\n", "Epoch 63 | Train/Valid Acc: 64.84/58.80\n", "Epoch 64 | Train/Valid Acc: 68.32/60.50\n", "Epoch 65 | Train/Valid Acc: 68.29/62.70\n", "Epoch 66 | Train/Valid Acc: 67.53/60.50\n", "Epoch 67 | Train/Valid Acc: 68.81/63.40\n", "Epoch 68 | Train/Valid Acc: 69.23/61.50\n", "Epoch 69 | Train/Valid Acc: 66.75/61.90\n", "Epoch 70 | Train/Valid Acc: 64.38/57.70\n", "Epoch 71 | Train/Valid Acc: 68.63/62.30\n", "Epoch 72 | Train/Valid Acc: 68.43/62.80\n", "Epoch 73 | Train/Valid Acc: 70.29/63.00\n", "Epoch 74 | Train/Valid Acc: 67.50/60.50\n", "Epoch 75 | Train/Valid Acc: 67.02/61.50\n", "Epoch 76 | Train/Valid Acc: 65.49/58.10\n", "Epoch 77 | Train/Valid Acc: 70.99/62.70\n", "Epoch 78 | Train/Valid Acc: 67.99/61.40\n", "Epoch 79 | Train/Valid Acc: 70.69/63.90\n", "Epoch 80 | Train/Valid Acc: 67.63/62.20\n", "Epoch 81 | Train/Valid Acc: 71.34/62.70\n", "Epoch 82 | Train/Valid Acc: 68.85/63.10\n", "Epoch 83 | Train/Valid Acc: 69.36/60.50\n", "Epoch 84 | Train/Valid Acc: 69.17/61.40\n", "Epoch 85 | Train/Valid Acc: 68.96/60.90\n", "Epoch 86 | Train/Valid Acc: 67.52/61.60\n", "Epoch 87 | Train/Valid Acc: 67.50/60.30\n", "Epoch 88 | Train/Valid Acc: 64.41/59.60\n", "Epoch 89 | Train/Valid Acc: 67.42/60.40\n", "Epoch 90 | Train/Valid Acc: 68.84/63.70\n", "Epoch 91 | Train/Valid Acc: 69.34/62.00\n", "Epoch 92 | Train/Valid Acc: 70.38/63.10\n", "Epoch 93 | Train/Valid Acc: 70.51/63.40\n", "Epoch 94 | Train/Valid Acc: 67.36/59.90\n", "Epoch 95 | Train/Valid Acc: 70.43/61.50\n", "Epoch 96 | Train/Valid Acc: 71.22/62.80\n", "Epoch 97 | Train/Valid Acc: 66.62/60.40\n", "Epoch 98 | Train/Valid Acc: 67.72/60.20\n", "Epoch 99 | Train/Valid Acc: 69.91/62.30\n", "Epoch 100 | Train/Valid Acc: 67.40/60.50\n", "Epoch 101 | Train/Valid Acc: 68.86/61.90\n", "Epoch 102 | Train/Valid Acc: 66.22/61.00\n", "Epoch 103 | Train/Valid Acc: 63.31/56.20\n", "Epoch 104 | Train/Valid Acc: 66.99/60.30\n", "Epoch 105 | Train/Valid Acc: 68.20/63.50\n", "Epoch 106 | Train/Valid Acc: 62.79/58.40\n", "Epoch 107 | Train/Valid Acc: 70.71/61.60\n", "Epoch 108 | Train/Valid Acc: 71.60/61.20\n", "Epoch 109 | Train/Valid Acc: 71.00/64.50\n", "Epoch 110 | Train/Valid Acc: 67.55/61.00\n", "Epoch 111 | Train/Valid Acc: 68.52/61.40\n", "Epoch 112 | Train/Valid Acc: 65.78/58.70\n", "Epoch 113 | Train/Valid Acc: 65.15/57.90\n", "Epoch 114 | Train/Valid Acc: 70.42/61.70\n", "Epoch 115 | Train/Valid Acc: 70.57/61.50\n", "Epoch 116 | Train/Valid Acc: 71.08/62.00\n", "Epoch 117 | Train/Valid Acc: 69.09/62.20\n", "Epoch 118 | Train/Valid Acc: 70.03/63.00\n", "Epoch 119 | Train/Valid Acc: 69.64/62.30\n", "Epoch 120 | Train/Valid Acc: 70.66/64.50\n", "Epoch 121 | Train/Valid Acc: 70.62/63.70\n", "Epoch 122 | Train/Valid Acc: 69.05/61.90\n", "Epoch 123 | Train/Valid Acc: 70.63/60.50\n", "Epoch 124 | Train/Valid Acc: 70.81/61.90\n", "Epoch 125 | Train/Valid Acc: 67.99/59.90\n", "Epoch 126 | Train/Valid Acc: 68.54/61.10\n", "Epoch 127 | Train/Valid Acc: 70.59/62.40\n", "Epoch 128 | Train/Valid Acc: 70.12/61.40\n", "Epoch 129 | Train/Valid Acc: 68.87/59.50\n", "Epoch 130 | Train/Valid Acc: 66.19/60.00\n", "Epoch 131 | Train/Valid Acc: 70.92/62.20\n", "Epoch 132 | Train/Valid Acc: 68.28/59.90\n", "Epoch 133 | Train/Valid Acc: 67.72/62.00\n", "Epoch 134 | Train/Valid Acc: 73.66/64.90\n", "Epoch 135 | Train/Valid Acc: 70.24/59.40\n", "Epoch 136 | Train/Valid Acc: 70.25/60.40\n", "Epoch 137 | Train/Valid Acc: 70.21/62.20\n", "Epoch 138 | Train/Valid Acc: 69.07/60.10\n", "Epoch 139 | Train/Valid Acc: 68.56/60.40\n", "Epoch 140 | Train/Valid Acc: 67.09/61.40\n", "Epoch 141 | Train/Valid Acc: 70.68/60.10\n", "Epoch 142 | Train/Valid Acc: 65.23/59.60\n", "Epoch 143 | Train/Valid Acc: 67.44/60.00\n", "Epoch 144 | Train/Valid Acc: 71.34/64.20\n", "Epoch 145 | Train/Valid Acc: 69.61/61.50\n", "Epoch 146 | Train/Valid Acc: 72.49/62.40\n", "Epoch 147 | Train/Valid Acc: 68.43/59.50\n", "Epoch 148 | Train/Valid Acc: 54.68/47.80\n", "Epoch 149 | Train/Valid Acc: 67.08/61.50\n", "Epoch 150 | Train/Valid Acc: 68.17/59.50\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "#################################\n", "### Setting for this run\n", "#################################\n", "\n", "num_epochs = 150\n", "iter_per_ep = len(train_loader.sampler.indices) // train_loader.batch_size\n", "base_lr = 0.09\n", "max_lr = 0.175\n", "\n", "#################################\n", "### Init Model\n", "#################################\n", "\n", "torch.manual_seed(random_seed)\n", "model = ConvNet(num_classes=num_classes)\n", "model = model.to(device)\n", "\n", "##########################\n", "### COST AND OPTIMIZER\n", "##########################\n", "\n", "cost_fn = torch.nn.CrossEntropyLoss() \n", "optimizer = torch.optim.SGD(model.parameters(), lr=base_lr) \n", "\n", "########################################################################\n", "# Collect the data to be evaluated via the LR Range Test\n", "collect = {'epoch': [], 'cost': [], 'train_acc': [], 'val_acc': []}\n", "########################################################################\n", "\n", "\n", "start_time = time.time()\n", "for epoch in range(num_epochs):\n", " epoch_avg_cost = 0.\n", " model = model.train()\n", " for batch_idx, (features, targets) in enumerate(train_loader):\n", " \n", " features = features.to(device)\n", " targets = targets.to(device)\n", "\n", " ### FORWARD AND BACK PROP\n", " logits, probas = model(features)\n", " cost = cost_fn(logits, targets)\n", " optimizer.zero_grad()\n", "\n", " cost.backward()\n", "\n", " ### UPDATE MODEL PARAMETERS\n", " optimizer.step()\n", "\n", " epoch_avg_cost += cost\n", " \n", " #############################################\n", " # Logging\n", " if not batch_step % 600:\n", " print('Batch %5d/%d' % (batch_step, iter_per_ep*num_epochs),\n", " end='')\n", " print(' Cost: %.5f' % cost) \n", "\n", "\n", " #############################################\n", " # Collect stats \n", " model = model.eval()\n", " train_acc = compute_accuracy(model, train_loader)\n", " val_acc = compute_accuracy(model, val_loader)\n", " epoch_avg_cost /= batch_idx+1\n", " collect['epoch'].append(epoch+1)\n", " collect['val_acc'].append(val_acc)\n", " collect['train_acc'].append(train_acc)\n", " collect['cost'].append(epoch_avg_cost / iter_per_ep)\n", " \n", " ################################################\n", " # Logging\n", " print('Epoch %3d' % (epoch+1), end='')\n", " print(' | Train/Valid Acc: %.2f/%.2f' % (train_acc, val_acc))\n", " \n", " \n", " #############################################\n", " # update learning rate\n", " base_lr = cyclical_learning_rate(batch_step=batch_step,\n", " step_size=num_epochs*iter_per_ep,\n", " base_lr=base_lr,\n", " max_lr=max_lr,\n", " mode='exp_range')\n", " for g in optimizer.param_groups:\n", " g['lr'] = base_lr\n", " ############################################\n", "\n", " print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))\n", " \n", "print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))\n", "\n", " \n", "plt.plot(collect['epoch'], collect['train_acc'], label='train_acc')\n", "plt.plot(collect['epoch'], collect['val_acc'], label='val_acc')\n", "plt.xlabel('Epoch')\n", "plt.ylabel('Accuracy')\n", "plt.legend()\n", "plt.show()\n", "\n", "\n", "plt.plot(collect['epoch'], collect['cost'])\n", "plt.xlabel('Epoch')\n", "plt.ylabel('Avg. Cost Per Epoch')\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 34 }, "colab_type": "code", "id": "CyUYaENcZWIS", "outputId": "d660813a-7fbc-4740-d586-7d586ee8d2e4" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Test accuracy: 59.81%\n" ] } ], "source": [ "print('Test accuracy: %.2f%%' % (compute_accuracy(model, test_loader)))" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "colab": {}, "colab_type": "code", "id": "REbfufQSZXSA" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch 1.0.0\n", "matplotlib 3.0.2\n", "torchvision 0.2.1\n", "numpy 1.15.4\n", "\n" ] } ], "source": [ "%watermark -iv" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "Copy of recent-learning_rate_cyclic.ipynb", "provenance": [], "version": "0.3.2" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.1" } }, "nbformat": 4, "nbformat_minor": 2 }