{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Tuning Learning Rates with *ktrain*\n", "\n", "Neural networks have many hyperparameters that must be set before training begins. Many of them have reasonable defaults in practice (e.g., ReLU activation, Xavier initialization, a kernel size of 3 in convolutional neural networks), but some do not and should be tuned. One of these is the learning rate, which governs how much the weights are adjusted at each update during training. Moreover, even with a good initial learning rate, varying the learning rate during training has been shown to help minimize loss and improve generalization. *ktrain* provides a number of built-in methods that make it easy to tune and adjust learning rates so that loss is minimized more effectively during training.\n", "\n", "To demonstrate these capabilities, we will begin by loading some text data into NumPy arrays and defining a simple text classification model."
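, "\n",
"\n",
"Before turning to *ktrain*, it may help to see why the learning rate matters. Here is a minimal sketch in plain Python (not part of *ktrain*) of gradient descent on the toy loss f(w) = w², whose gradient is 2w. Each update is w ← w − lr · 2w, so the learning rate directly scales every step: a very small value makes slow progress, a moderate value converges quickly, and, for this particular loss, any value above 1.0 makes the weight oscillate with growing magnitude. The function name *gradient_descent* and the learning rates tried are illustrative assumptions.\n",
"\n",
"```python\n",
"def gradient_descent(lr, w0=1.0, steps=20):\n",
"    # minimize the toy loss f(w) = w**2, whose gradient is 2*w\n",
"    w = w0\n",
"    for _ in range(steps):\n",
"        grad = 2 * w       # gradient of the loss at the current weight\n",
"        w = w - lr * grad  # the learning rate scales the size of each update\n",
"    return w\n",
"\n",
"# too small: slow progress; moderate: fast convergence; too large: divergence\n",
"for lr in (0.01, 0.4, 1.1):\n",
"    print(f'lr={lr}: w after 20 steps = {gradient_descent(lr):.4g}')\n",
"```"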
] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using TensorFlow backend.\n" ] } ], "source": [ "%reload_ext autoreload\n", "%autoreload 2\n", "%matplotlib inline\n", "import os\n", "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"  # expose only GPU 0 (in PCI bus order) to TensorFlow\n", "import ktrain" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# load and prepare data as you normally would in Keras\n", "from tensorflow.keras.preprocessing import sequence\n", "from tensorflow.keras.datasets import imdb\n", "NUM_WORDS = 20000\n", "MAXLEN = 400\n", "def load_data():\n", "    (x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=NUM_WORDS)\n", "    x_train = sequence.pad_sequences(x_train, maxlen=MAXLEN)\n", "    x_test = sequence.pad_sequences(x_test, maxlen=MAXLEN)\n", "    return (x_train, y_train), (x_test, y_test)\n", "(x_train, y_train), (x_test, y_test) = load_data()" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# build a model as you normally would in Keras\n", "from tensorflow.keras.models import Sequential\n", "from tensorflow.keras.layers import Dense, Embedding, GlobalAveragePooling1D\n", "def get_model():\n", "    model = Sequential()\n", "    model.add(Embedding(NUM_WORDS, 50, input_length=MAXLEN))\n", "    model.add(GlobalAveragePooling1D())\n", "    model.add(Dense(1, activation='sigmoid'))\n", "    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])\n", "    return model\n", "model = get_model()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To use *ktrain*, simply wrap the model and the data in a Learner object using the ```get_learner``` function. This Learner object will be used to help tune and train our network."
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "learner = ktrain.get_learner(model, train_data=(x_train, y_train), val_data=(x_test, y_test))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The wrapped model and data are both directly accessible. For instance, the model can be saved and loaded as usual in Keras (e.g., ```learner.model.save('my_model.h5')```)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### A Learning Rate Finder\n", "\n", "The Learner object can be used to find a good learning rate for your model. First, we use ```lr_find``` to track the loss as the learning rate is increased, and then we use ```lr_plot``` to identify the maximal learning rate associated with a falling loss (both methods were adapted from the [fastai library](https://github.com/fastai/fastai))." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "simulating training for different learning rates... this may take a few moments...\n", "Epoch 1/5\n", "25000/25000 [==============================] - 5s 197us/step - loss: 0.6932 - acc: 0.4902\n", "Epoch 2/5\n", "25000/25000 [==============================] - 5s 183us/step - loss: 0.6926 - acc: 0.5456\n", "Epoch 3/5\n", "25000/25000 [==============================] - 5s 189us/step - loss: 0.5968 - acc: 0.7288\n", "Epoch 4/5\n", "25000/25000 [==============================] - 5s 182us/step - loss: 0.3167 - acc: 0.8695\n", "Epoch 5/5\n", " 8032/25000 [========>.....................] 
- ETA: 3s - loss: 0.5864 - acc: 0.8430\n", "\n", "done.\n", "Please invoke the Learner.lr_plot() method to visually inspect the loss plot to help identify the maximal learning rate associated with falling loss.\n" ] } ], "source": [ "learner.lr_find()" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAEOCAYAAABmVAtTAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvqOYd8AAAIABJREFUeJzt3XmYXGWZ9/HvXdV7p5cs3dlDJ5AAIexhRxZlR0EFEUYdEBVRUQfRV1BHFtHRYXBGWVQEFVlERGUCQQEFJmyBJBAISUhIQhISsnX2dCe9VD3vH+dUdXWneu9Tp6rr97mups9Wde4uKnXXs5tzDhEREYBI2AGIiEj2UFIQEZEkJQUREUlSUhARkSQlBRERSVJSEBGRJCUFERFJUlIQEZEkJQUREUlSUhARkaSCsAPorREjRri6urqwwxARySnz5s2rd87VdHddziWFuro65s6dG3YYIiI5xcxW9eS6QKuPzOwsM1tiZsvM7No05//bzOb7P0vNbFuQ8YiISNcCKymYWRS4AzgdWAPMMbMZzrlFiWucc1enXP9V4PCg4hERke4FWVI4GljmnFvhnGsGHgLO7+L6S4A/BBiPiIh0I8ikMBZ4L2V/jX9sL2a2DzAReCbAeEREpBvZ0iX1YuAR51ws3Ukzu8LM5prZ3E2bNmU4NBGR/BFkUlgLjE/ZH+cfS+diuqg6cs7d5Zyb7pybXlPTbY8qERHpoyC7pM4BJpvZRLxkcDHwLx0vMrMDgKHAywHGQv2uJtZs3U0sHqe51RH3lyG15H/AMMy83da4w4DCgggFESNi3kWJxUsTy5jGHcSdIxZ3RMyIO+9xZt5zJZ4+sZ24WftzlnIm9Vz7azs719n1gdw35TFd3ZeuzvnP09l9sfTXQ9vrbmYURLyfaMSS50WkfwJLCs65VjO7CngSiAK/cc4tNLObgLnOuRn+pRcDD7mAF4t+ZN4afvy3t4O8hYSoMOonB4yCqFEUjVAQNQqjEYqiEQqjEQoLvP3CiLddFI1QXBClpDBCSWGU4oK238X+7/LiAqpLC6kqLaSytJDqMm97SHGBEpEMShbwZ/GAmz59uuvL4LWV9Q0s37SLgmjE+wAxS/nWDy6x57zSQEHEO98Si9Mac23n6fCNFiiIRIiYV2qImPd4l3w6l7KduJ9r20+5jtTHpTym3fUp50j73N3ft+N1Pb0vXT535+dIE39nf1PqfdPFmijNxZ2jNe5oaXXE4nFa4l5pzSWOx+K0tPq/446W1jgtsTjNMf93a5yWmKOpNcaelnjy956WGE2tcbpTGDWqSosYWVlM3fByyoujFBVEGFpWxJjqUsZWlxKNGBUlBew/qoLigmi3zykSJDOb55yb3t11OTeiua/qRpRTN6I87DAkBzjnaGqN09Qap7G5lW2NLWzf3cK2xhZ27G5h2+5mtjS0sH13M+u272Hh+9tpaI7RGouzY08rsXj7pBYxGD+sjH1rhjB+aClH7DOU4yYNp7ayJKS/UKRzeZMURHrKzCgpjFJSGKWqtJDRVaU9fmxrLM6GnU2s3bqbWNyxpaGZJRt2snzTLpZv3MUrKzZz78vebAP71pRz8pRaLjhyLPvVDlFpQrJC3lQfiWSDWNyx6P
0dvLyinuff8X4AyoqiXHjkOP71uDr2qx0ScpQyGPW0+khJQSREa7ft5h+LNvD0og3MXrGZ1rjj7GmjuPmj0xg+pDjs8GQQUVIQyTEbd+7hgdmruf3ZZUQjxulTR3L9R6ZSW6G2B+m/niaFbBnRLJL3aitKuPr0KTx21YlcctR4nl60gaN/+E9eWl4fdmiSR5QURLLM1DGV3Hj+NP70xeMA+OJ985izckvIUUm+UFIQyVKHjq9mxlUn0BpzfPJXL/PfTy/dawyHyEBTUhDJYoeMq+af15zMqfvX8rN/vsOV989j++6WsMOSQUxJQSTLjaku5df/Op3vnHMATy7cwDUPz6cl1v2oa5G+UFIQyQGRiHHFSfvy3XMO5B+LN3Lns8vDDkkGKSUFkRzy+Q9M5Oxpo/jVrOVsbWgOOxwZhJQURHKImfFvp02hsTnGb158N+xwZBBSUhDJMfuPquDcQ0ZzzwvvsnOPGp1lYCkpiOSgy46vo7E5xv2zV4cdigwySgoiOejICUM5aEwltz61hC1qW5ABpKQgkoMiEeMnFxxCa9zxxIJ1YYcjg4iSgkiOOmhMJQeOruSBV1ZrpLMMGCUFkRxlZnz62AksXreD11ZvCzscGSSUFERy2EcPG8uQ4gJ1T5UBo6QgksPKiwu48MhxzHxzHcs37Qo7HAnItsZmzvqfWcx8M/j2IyUFkRz3qWMmAPDs2xtDjkSC0tAc4+31O2loag38XkoKIjkusabzzTMX09gc/IeGZN7u5hgAJUXRwO+lpCCS48yMjx8+FoDLfjsn5GgkCHtavKRQWqikICI9cOtFhwKwsr4h5EhkIK3bvpvv/nUBO/w1NEoKg//IVlIQGQTMjJvOP4iNO5tYvbkx7HBkgNw4YxEPvLKaJxeuB6BEJQUR6anj9x0BwIvL60OORAZKxP+E3uk3MBdELPh7Bn4HEcmIfWvKGVZexOurt4YdigwQMy8JxOLeiPWCiKqPRKSHzIyDx1bx5prtYYciAyRRLkisyx1VSUFEeuPQ8dW8vX4nNz++KOxQZAC8t8VrH3puySZASUFEeumy4+sAuPuFdzVJ3iBw5D7DABhZWQwoKYhILw0rL+KGj0wFYNG6HSFHI/2VmNOqvLgAUEOziPTBiZO9XkizlqoX0mCxYpM3/kQlBRHptf1qKxhWXsTqLRrIlsvi8b2r/5QURKRPJo4oT367lNwUS9MmpOojEemTSSPKWb5plxqbc1gsTUkhoqQgIn0xvW4o9buaWbJhZ9ihSB/FVVIQkYFyzMThAMzXMp05a1CWFMzsLDNbYmbLzOzaTq65yMwWmdlCM3swyHhE8sWEYWUAXPuXBSFHIn0Vj3u/U0sHOV1SMLMocAdwNjAVuMTMpna4ZjJwHXCCc+4g4N+Cikckn6R+o3xHVUg5KdHQPGJIcfJYSUFuz5J6NLDMObfCOdcMPASc3+GaLwB3OOe2AjjntJ6gyAD536+cAMAzWqYzJ7X6RYXUJThzvfpoLPBeyv4a/1iqKcAUM3vRzGab2VkBxiOSVw4dX83EEeXMW6VZU3PRnc8uB9qmzc6UgozeLf39JwOnAOOAWWZ2sHOuXeuYmV0BXAEwYcKETMcokrP2H1nB4vWa7iIXzV6xOZT7BllSWAuMT9kf5x9LtQaY4Zxrcc69CyzFSxLtOOfucs5Nd85Nr6mpCSxgkcHmgNEVrNrcmFzjV3JHWENMgkwKc4DJZjbRzIqAi4EZHa55FK+UgJmNwKtOWhFgTCJ5pW54OdA2BbPkjkRDc2kGluBMFVhScM61AlcBTwKLgYedcwvN7CYzO8+/7Elgs5ktAp4FvuWcC6fMJDII7TPc65r6l9c7FtIl2yXmPvrQgbUZvW+gbQrOuSeAJzoc+37KtgO+4f+IyAA7cHQlAOu37wk5EumtxIjmxCR4x0wclpH7ht3QLCIBKimMMmlEOU2talPINfW7mpPbC244g+IMjFEATXMhMugdMLqCF96pZ3ezEkMu2e
V3Rf3WmftTUVJIUUFmPq6VFEQGuTMPGsWOPa28/p7GK+SicUPLMno/JQWRQe6ICUMB9UCSnlFSEBnkRleVAHD/7NUhRyK5QElBZJAriHr/zBes3R5yJNJbdcMzW3UESgoieeETR46jtqK4+wsla0QjxrmHjM74fZUURPLAhGFlbNzZ1G7GTclesbgjFncURTM7mhmUFETywkFjvUFsb65RFVIuaG71ps3OVDfUVEoKInng4LHVALytGVNzgpKCiARqxJAiSgojrNm6O+xQpAeaYt5AQyUFEQmEmRGPwz0vvBt2KNIDiZJCcVRJQUQC0hzzPmg27tDkeNmuJeZNhqeSgogE5pYLDwFg8fqdIUci3XlxWT0ADc2Z7y2mpCCSJ06fOhKAt9epsTnbNfnVR4eOq874vZUURPJEdVkRw8qLeFMjm7NeRYm3qkF1WWHG762kIJJHtjQ0M/PNdWGHId1o9dsUCiJqUxCRAE3zB7FJdovFveqjxKprmaSkIJJHTj9wFIBWYstyrf76zIVRJQURCdCoKm9SvI07mkKORLryyootgEoKIhKwkZXe2gobNFYha/1twTr+vnA94K2xnWlKCiJ5ZHRVKQDrlRSy1pceeC25XaCSgogEaZRfUli/XUkhF5gpKYhIgCpLCygrirJ2mybGk/SUFETyiJmxb80Qlm3cFXYokqWUFETyzOSRQ3hng5KCpKekIJJnJtdWsH7HHnbsaQk7FEljbLXXGeCeS6eHcn8lBZE8M7l2CIBKC1nqQwfW+r9HhnJ/JQWRPDNlZAUAF/ziJVr8NRYke8SdY3h5UWj3V1IQyTNjh5Ymt59YoMnxsk0sHk5X1AQlBZE8E40Y3zxjCgDz39sWcjTSkXOOEFbhTFJSEMlDV31wMqcdOJLfvrhSk+NlmbhzRFRSEJFMmzrGm0Z79ebGkCORVLE4SgoiknnHTBwGwOaG5pAjkVTOOULMCUoKIvkqsdTjJb+eHXIkkiruXChTZicoKYjkqQNHedVHzqG5kLJI3Kn6SERCEIkYh46rAuCbD78RcjSSEFP1kYiE5feXHwPA7hb1QMoWbjD3PjKzs8xsiZktM7Nr05y/zMw2mdl8/+fzQcYjIu1VlRVyVN3QUNYClvTicYgOxqRgZlHgDuBsYCpwiZlNTXPpH51zh/k/dwcVj4ikN3lkBUs37MI5F3YogtfQPFirj44GljnnVjjnmoGHgPMDvJ+I9MGU2iFs393Cxp1NYYciDO7Ba2OB91L21/jHOrrAzN40s0fMbHyA8YhIGokJ8pZu2BlyJAJ+76M8nubiMaDOOXcI8DRwb7qLzOwKM5trZnM3bdqU0QBFBrspo7ykcONji0KORABaYnEKQ5z8qEd3NrOvm1mlee4xs9fM7IxuHrYWSP3mP84/luSc2+ycS5RZ7waOTPdEzrm7nHPTnXPTa2pqehKyiPTQiCHFAFqiM0s0tcYpLsjypABc7pzbAZwBDAU+A/y4m8fMASab2UQzKwIuBmakXmBmo1N2zwMW9zAeEZFByUsK0dDu39OkkGj1OAe4zzm3MOVYWs65VuAq4Em8D/uHnXMLzewmMzvPv+xrZrbQzN4AvgZc1ts/QET6LzGV9h6NVwhdU0ss1JJCQQ+vm2dmTwETgevMrALodskm59wTwBMdjn0/Zfs64LqehysiQUhUIW1uaE6uESzhaG6NU5QDSeFzwGHACudco5kNAz4bXFgikklj/ESwZkujkkLIcqX66DhgiXNum5l9GvgesD24sEQkk+qGlwPwzUc0B1LYmlrjFBdmf0PzL4BGMzsUuAZYDvw+sKhEJKPGVJcA8N4WzZYatqbWcNsUenrnVueNgT8fuN05dwdQEVxYIpJJBdEIB/jjFWJxTXcRpqaQ2xR6euedZnYdXlfUmWYWAQqDC0tEMu3saV4PcY1sDldrLE5hiEOae3rnTwJNeOMV1uMNRLslsKhEJOPKi73GzVueXBJyJPnLOedPc5Hlcx/5ie
ABoMrMPgzscc6pTUFkEPnMcfsAMGFYWciR5K9EzV3WT51tZhcBrwKfAC4CXjGzC4MMTEQyK9EN8ncvrQw3kDyWaM8pCHF9i56OU/gucJRzbiOAmdUA/wAeCSowEQmPcw4Lc1L/PJVICrkwdXYkkRB8m3vxWBHJEV88eRIA9buaQ44kP8X8hY5CnCS1xx/sfzezJ/3lMy8DZtJh+goRyX3HThoOwMrNDSFHkp9ypqTgnPsWcBdwiP9zl3Pu20EGJiKZN9Ef2byyXkkhDPF4oqSQ/W0KOOf+DPw5wFhEJGS1ld7EeOu37wk5kvzUVn2UpSUFM9tpZjvS/Ow0sx2ZClJEMqO00OuBdOvTS0OOJD/Fs6D6qMuSgnNOU1mI5BH1OApX1pcURCT/fOmUfQHNgRSG1piSgohkmQL/A+mFZfUhR5J/4omSQrb3PhKR/HHK/rUAPPbG+yFHkn9iWdD7SElBRNo5cp+hADwyb03IkeSfREkh6yfEE5H81NDUGnYIeSUW936r+khEssovPnUEAO9qEFtGtVUfhReDkoKI7GVUlbc854YdGsSWScnqI5UURCSbjKkuBWDtNq3ZnEnZMHW2koKI7KW2opghxQUs27gr7FDySmsWjGhWUhCRvZgZ+9UO4Z0NSgqZFNeIZhHJVlNGDuGdjTvDDiOvJBuaVVIQkWwzubaC+l3NbG3QgjuZonEKIpK19q311lZYUa8qpEzxc4LaFEQk+0waMQSA5Zs0ViFT2lZeCy8GJQURSWvcUK9b6u3PLAs5kvyh6iMRyVoF/rDa1VsaWaYG54xQ9ZGIZLXESmzvbdEgtkxQ9ZGIZLVnvnkyAKs2q10hEzTNhYhktVGV3hxINzy2KLl+sAQnruojEclmqWs2Xz9joZboDFhbQ3N4MSgpiEiXfnLBwQDcN3sVL2qJzkBpOU4RyXrnHTo2ub25oSnESAa/REHMlBREJFuVFkX55hlTAHh/m9ZXCFJ8sPc+MrOzzGyJmS0zs2u7uO4CM3NmNj3IeESkb6764GSGlxexZqu6pgZpUM+SamZR4A7gbGAqcImZTU1zXQXwdeCVoGIRkf4bXV3C+u1KCkEa7L2PjgaWOedWOOeagYeA89Nc9wPgJ4DKpSJZbFRlCeu2659pkBLVRyHmhECTwljgvZT9Nf6xJDM7AhjvnJvZ1ROZ2RVmNtfM5m7atGngIxWRbo2qKuHt9Zruois79rRw70srca5vXXcHdfVRd8wsAvwUuKa7a51zdznnpjvnptfU1AQfnIjsJdHIfNs/3wk5kux18+OLuH7GQv40b02fHj/Yq4/WAuNT9sf5xxIqgGnAc2a2EjgWmKHGZpHstGN3CwC3Pr2Uhe9vDzma7LRpp9dl9/898iYr63s/NUjMDe7qoznAZDObaGZFwMXAjMRJ59x259wI51ydc64OmA2c55ybG2BMItJHP/jotOT2wrU7Qowke52w34jk9hd+3/uPspsfXwQM0pKCc64VuAp4ElgMPOycW2hmN5nZeUHdV0SCccCoiuT2jj0tIUaSvSpLCpPb72zs3Yp1/7d0E02tcSDcEc0FQT65c+4J4IkOx77fybWnBBmLiPRP6ijbm2cu5vMfmBRiNNnJ0fe5oZ59e2NyW4vsiEhOWP6jc8IOIav1Z77ACcPKktslheF9NCspiEiPRSPGuYeMBtBU2mnE+9gVFbzpRBKKokoKIpIjFr3vNTL/ce573VyZf/qTJ1O7+mpCPBHJGVeduh8Q7qRtWasfJYWqsqIBDKTvlBREpFdOnOx1u2yJqfqoo65KCs2tcT5/7xzeWpt+jMeBKb27wqSkICK9UlXqdbt8/M33Q44kO/x61grqrvVm6umqTWH1lgb+sXgjX3nwtbTnEwPXbjzvoIEPsheUFESkV0oKvQbR2Su2hBxJdvjhE4sBWFnf0GVJocBfY3NdJ2tSxOKOSSPKufT4uoEOsVeUFESk106eUkNtRXHYYWSVnXtakxPhfeqYCXv1IL
p+xkIAmmNxZr65bq/Hx50LdXxCgpKCiPTafrVD2NXUGnYYWeU3L76brD4qK4rSHIsT84sO8bjj/5a2zfD8lQdfo+7amcx/b1vy2Kyl9Szr5SjoICgpiEiv1VYU09gcY1tjc9ihBOb11Vv5xC9f4uRbnk1+uHflr6+vTTa+J6a7aGz2Euf3Z7yV9jGJuY6ArEmySgoi0muJxuZLfj14F0z82J0vMWflVlZtbuTTd6f/O3/ktycADCsv4pYnlwAwblgpAEs3eOtP3D97ddrHr92WfSvZKSmISK+dNMVb12TxuvyYLfXlFZvTHr9r1ork9paGtlLT2Gpvyoru1rROrGS3anMDw8qL+PgRY7u8PhOUFESk10ZWliS3B+OMqf2dwmPiiHKgbQ2KcUNLk+cKo8Y9l7YtG+Oc4+RbnmNLQ3OyZ1eYlBREpNeiEeMLH5gIdN7FMpdd+ttX2+1XlHQ9ofQ3Tp+S3H7um6ckq9e2NrYQj7t2JYZXvnMaHzpwZHJ/4fttpa2SAiUFEclRZx40CoD3t2dfvXh/dVwjubE5lnbd5eHlRXzqmAlMG1uZPFZZWkhRQYRh5UWsrG9g5oL23U/Litp/8Dc2x1K2w29sVlIQkT4ZXe1ViQzGkkJVaSEThpVx7+VHc8J+w4nFHXta4ntdt6clRklhlEPGVSePDSn2ShXlxVH+8vpavvqH15PnfnbxYckqoj9deRzQNn4B4KE54U8yqKQgIn0ysqKYiMH7WdiDpj+cc/zv/PfZsGMPJ0+p4axp3lThO5ta9rpud0uM0sIow8vbJrMrKvA+VhubYu2u/8bpUzj/sLaG5ERZJLWxviYLBgQGuvKaiAxeBVGviqR+V1PYoQyo3S3eh3liacwK/5v/rj2t1KbMWdcSc8Sdtw5CuqmuNze0H8Nx4OjKdvv71Q5pt//vH57KRdPH9Tv+/lJJQUT6rH5Xc1ZUeQykJr+a6LxDxwBQVeY1GnfslppIHsV+ySDxuzOHjKtqt19dVsS5B49O7n/uxIlUpKzxHBaVFESk3xJ164NBo/9hf9y+wwEYU+W1nXz3r2/x4UPGJHsWNfnXJVZMe+HbH+xy5HNqN96E2y45nMtPrKMomj2vnUoKItJnXzx5EtD9IK1csmmnVx02YohXvz9lZFs1z6E3PkVLzCtJJEoKpX4yrKkoZlTV3h/8qdd0FIkYR+4zjIM7lCLCpKQgIn12xlSvv/3qLQ0hRzJwvv+/3jxFo/xv9h3bC/7Ln8ri+XfqATotId3/uWOS2z+/5PABjzMoqj4SkT7bZ7g3cnfV5saQI+k/5xwn/uTZ5HxEIyvT9wT61awVHLnPUL73qJc8dnYyojtR/QRw+tSRaa/JRkoKItJnw8uLKC+KZm1SOPIHT3P5iRP5ir+udFd2t8TaTVCXqD5K59klG5PbE4aVp70mGjH+8uXj91pXIdvlVrQiklXMjAnDy3ln486wQ9lLSyzO5obm5MylCd/96wJe8Kt+Um3f3f4bf1cL3vzh1bYeV6klgo6OmDCUaWOzp72gJ5QURKRfDhpTybxVW2lqjXV/cQalll4a/LUKmlvjPPDKaj59z95TYe/Y3TbFxE8uODj4ALOUkoKI9MuZB41iT0uceSu3dnttLO54dsnGtPMIDbTUqawTo65ffbfzdaUTJYUDRlXwyaMmtDv3zDUnt5vZNOGpq08aiFCzipKCiPTL8fsOpyga4bmU5SY7c9/LK/nsb+fwt7fWBx5XalJYv8Obn2nh+9s7vX7uKi9h3HrRoXudm1QzhA8dOJIj9xna7viUkRV7XZvrlBREpF/Kiws4eFwVb6SsN9yZdf6H87v1wXdh3ZqyVOh6fzGbeavaSjPLN7VfDzkxsd9BYzpvA/jNpUfx9Q9NHsgws46Sgoj029jq0uS38a48/oY3jXTHxt8gtCspbN/DnpYYTy3akDx2y9/bx7CifhdTO8xP1FFVWSFXp6ydMBipS6qI9Nvo6hL+vnAPzrm0k8MlZH
JN4m2NzZQWRtndEuPWp5fuNcis4wR1qzY37lU91JkXvn1q2qm0BwOVFESk30ZXltDcGt9rZtCOjp00DIDqsuAnfnv13S0UpHQrbfanpxhaVkhFSQFrtrYfW7G1obnLsQmpxg0t22uW08FCSUFE+m2UP2lcou4+Heccs1d4jbnbGlvY0xJsF9Y31mxnZ1Mrpx1YC7RVWb3yndPYt2YIq7a0JYU9LTEammMMS1kXIV8pKYhIv42p9uYJ6mrBndRxAACzO0xFPZBS13g4dlL7wWVFBRHGDytjQ0obyLZGrzvq0DIlBSUFEem3xOygV9w3j/tmr0p7zcad7UsRl/12Ds2t6evl731pJXXXzmRlH3sp/XDmYgCmja1MznqaqraimFWbG5OllURPpaEZqNbKdkoKItJvI8rb6uL/3Z8orqNE76SfpowDmHb9k2mvTaxbfH8nCaY744Z61Vm//tfpyUn7oK0t4zl/7qK7n18BeO0JAENVfaSkICL913GeoHQlgA07vG/sh45vW+S+ORbfa3RzY3NbNdPdL7zbp3i2NjZTXVbI6KpSLjl6PCOGeB/2N50/DYDPnjARgP96aikAW/ySgtoUAk4KZnaWmS0xs2Vmdm2a81ea2QIzm29mL5jZ1CDjEZHMWLZx117HEnX4Y6pKufNTRySP1+9q32Ppk7+a3W6/L1Ni/G3Besr8LqhmxslTvMbmPc1eddEp+9e0uz5RUshEr6hsF1hSMLMocAdwNjAVuCTNh/6DzrmDnXOHAf8J/DSoeEQkWL/89BHJRe5ffXczz7y9gY/d+SKtflfQDTv2UFlSQGlRlHMOHs3nT/S+rSdGN29vbOGoH/6DBWu9qSj2rfGqff7YyzWgl23cyeaGZt5P6Ql1+ASvdDJumFetVFna9uH/0rJ6ZvmzpqqhOdiSwtHAMufcCudcM/AQcH7qBc65HSm75UDws2SJSCDOmjaa+defAcANjy3i8t/N5fXV2/j9y167wJqtuxlTXZq8/spT9gXg+Xc28eaabXzrkTfaNQp/8STv/NbG9IvYdOarf5i/17FPHTOBx646keP8nkhDitrG7f7L3a/wtD/SuTDH1j4IQpCvwFggNcWv8Y+1Y2ZfMbPleCWFrwUYj4gELJpmDYJEz553Nu5sN+ArMVDstmeWcd7tL7abgmJYeRHnHDIaoMfjGbY0NPO7F9+lMOrF8L1zD0yeMzMOHleVHG0diRjXnX1Ab/60vBF6WnTO3eGc2xf4NvC9dNeY2RVmNtfM5m7a1P1MjCISnsuOr2u3/7e31nPf7FWs2bqbybU9m1X0sPHVyTaBn/3znbTXbN7VxII1bbOeXvPwfG54bBHVfhXQ5X5jcmeuOGkSJ02p6fKafBRkUlgLjE/ZH+cf68xDwEfTnXDO3eWcm+6cm15To/+JItnshvMOare/bOMu/v3Rt3CubZBbwrfPSv9t/fSpI7tc+Qzg4794iY/c/gI1exr6AAAL8ElEQVTXPPwG8bhLzqs0a+kmItb1ymnglR5+9LFp3f05eSfIpDAHmGxmE82sCLgYmJF6gZmlzkF7LpD+K4GIDAoVJe1793S2WltiXeNr/BlJU7upAsx/b1tyZbU/v7aGSd95gqUb2no8xXvYOjk2pY3j1e98qGcPGuQCSwrOuVbgKuBJYDHwsHNuoZndZGbn+ZddZWYLzWw+8A3g0qDiEZHMmey3HVxy9Ph2xytL2k/MPHFE+0XvP3641+yY+EyfPNJ7nuUb249s/ugdLw5InGbGrG+dyl+/fDy1lSXdPyAPBDp1tnPuCeCJDse+n7L99SDvLyLheOrqk2iOxZm3amu7Re47zix63qFjACguiLKtsZnF67wOidv8xunR/kR7Nz62kK98cD9O3b+W+T1YzKc3JgwvY8LwsgF9zlym9RREZMCZGcUFUY6bNJyfXXwYHzyglmjEKCsq2Ou68w9r65T4j0UbuPflVdT5U1MkBpPNXbWVz/52Dit/fC71aeYySv
jVZ47khhkL+elFhwXwV+UHJQURCUzHD/3unDZ1JLO+dWrym3t1msFkhQVerXc0Ytx2yeF8+YHXAFj543MBOPOgUf0NO68pKYhIVkmtykmMkE71zGJvPMNjV53I1DGVLLrpzL2m5Za+C32cgohIZyIR46sf3K/dsXv9EdJlRVH/d0Fy6m7pPyUFEclq15yxPxceOQ6g3cI4qfMXycBRUhCRrHfQmEoArvvLguQxLYgTDCUFEcl6lx1fx9CyQp55e2PyWGIeIxlYSgoikvXMrN1sqSMri7u4WvpDSUFEcs4fvnBs2CEMWkoKIpITbrvkcGoriln+o3OYVDOk+wdIn2icgojkhI8cOoaP+NNiSHBUUhARkSQlBRERSVJSEBGRJCUFERFJUlIQEZEkJQUREUlSUhARkSQlBRERSTLnXPdXZREz2wRsA7anHK5K2U9sd/w9Aqjvwy1Tn7un5zoe72q/Y5ypx/oSc1fxdna+u2PdbSve7s/39D0x0PF2F/NgfQ/3JN50sQ/m9/A+zrmabq9yzuXcD3BXZ/uJ7TS/5w7EvXpyrqv4uou3vzF3FW9n57s71t224h2498RAx9tdzIP1PdyTeMN8T4T5Hu7uJ1erjx7rYv+xTn4P1L16cq6r+Drup4uzPzF399h057s71t224u3+fE/fEwMdb3ePH6zv4Z7Em7oddrwdjwUZb5dyrvqor8xsrnNuethx9Eauxax4g5Vr8ULuxax486uh+a6wA+iDXItZ8QYr1+KF3Is57+PNm5KCiIh0L59KCiIi0g0lBRERSVJSEBGRJCUFwMw+YGa/NLO7zeylsOPpjplFzOyHZnabmV0adjw9YWanmNnz/ut8Stjx9ISZlZvZXDP7cNixdMfMDvRf20fM7Ethx9MdM/uomf3azP5oZmeEHU9PmNkkM7vHzB4JO5bO+O/Ze/3X9lN9eY6cTwpm9hsz22hmb3U4fpaZLTGzZWZ2bVfP4Zx73jl3JfA4cG+2xwucD4wDWoA1QcWaEttAxOyAXUAJAcc8QPECfBt4OJgo28U1EO/hxf57+CLghByI91Hn3BeAK4FPBhmvH9tAxLzCOfe5YCPdWy9j/zjwiP/antenGw70aLhM/wAnAUcAb6UciwLLgUlAEfAGMBU4GO+DP/WnNuVxDwMV2R4vcC3wRf+xj+TCawxE/MeNBB7IgXhPBy4GLgM+nO3x+o85D/gb8C+5EK//uFuBI3LhPZzyuMD/zfUj9uuAw/xrHuzL/QrIcc65WWZW1+Hw0cAy59wKADN7CDjfOfcfQNqqADObAGx3zu0MMNwBidfM1gDN/m4suGg9A/Ua+7YCxUHEmTBAr/EpQDneP7TdZvaEcy6erfH6zzMDmGFmM4EHg4h1oOI1MwN+DPzNOfdaULEmDPB7OKN6EzteKXwcMJ8+1gTlfFLoxFjgvZT9NcAx3Tzmc8BvA4uoa72N9y/AbWb2AWBWkIF1oVcxm9nHgTOBauD2YENLq1fxOue+C2BmlwH1QSWELvT29T0Fr+qgGHgi0MjS6+17+KvAaUCVme3nnPtlkMF1orev8XDgh8DhZnadnzzC0lnsPwduN7Nz6eNUGIM1KfSac+76sGPoKedcI14SyxnOub/gJbOc4pz7Xdgx9IRz7jnguZDD6DHn3M/xPsByhnNuM14bSNZyzjUAn+3Pc+R8Q3Mn1gLjU/bH+ceyVa7FC7kXs+INVq7FC7kZc0JgsQ/WpDAHmGxmE82sCK/BcEbIMXUl1+KF3ItZ8QYr1+KF3Iw5IbjYM9mKHlDL/B+AdbR1z/ycf/wcYCleC/13w44zV+PNxZgVr+IdDDGHFbsmxBMRkaTBWn0kIiJ9oKQgIiJJSgoiIpKkpCAiIklKCiIikqSkICIiSUoKEjgz25WBe5zXw+mwB/Kep5jZ8X143OFmdo+/fZmZhTEX1F7MrK7j9Mxprqkxs79nKibJPCUFyRlmFu3snHNuhnPuxwHcs6v5wU4Bep0UgO+QY/
P+JDjnNgHrzCzQNRskPEoKklFm9i0zm2Nmb5rZjSnHHzWzeWa20MyuSDm+y8xuNbM3gOPMbKWZ3Whmr5nZAjM7wL8u+Y3bzH5nZj83s5fMbIWZXegfj5jZnWb2tpk9bWZPJM51iPE5M/sfM5sLfN3MPmJmr5jZ62b2DzMb6U9lfCVwtZnNN2/1vhoz+7P/981J98FpZhXAIc65N9KcqzOzZ/zX5p/+dO6Y2b5mNtv/e29OV/Iyb8WtmWb2hpm9ZWaf9I8f5b8Ob5jZq2ZW4d/nef81fC1dacfMomZ2S8r/qy+mnH4U6NOqXpIDwh7CrZ/B/wPs8n+fAdwFGN4XkseBk/xzw/zfpcBbwHB/3wEXpTzXSuCr/vaXgbv97cuA2/3t3wF/8u8xFW/eeYAL8aaVjgCj8NZ2uDBNvM8Bd6bsD4Xk6P/PA7f62zcA30y57kHgRH97ArA4zXOfCvw5ZT817seAS/3ty4FH/e3HgUv87SsTr2eH570A+HXKfhXe4isrgKP8Y5V4MyOXASX+scnAXH+7Dn8hF+AK4Hv+djEwF5jo748FFoT9vtJPMD+aOlsy6Qz/53V/fwjeh9Is4Gtm9jH/+Hj/+Ga8RYT+3OF5ElNwz8NbQyCdR523BsIiMxvpHzsR+JN/fL2ZPdtFrH9M2R4H/NHMRuN90L7byWNOA6aaWWK/0syGOOdSv9mPBjZ18vjjUv6e+4D/TDn+UX/7QeC/0jx2AXCrmf0EeNw597yZHQysc87NAXDO7QCvVIE35/5heK/vlDTPdwZwSEpJqgrv/8m7wEZgTCd/g+Q4JQXJJAP+wzn3q3YHvQViTgOOc841mtlzeGs5A+xxznVcXa7J/x2j8/dwU8q2dXJNVxpStm8Dfuqcm+HHekMnj4kAxzrn9nTxvLtp+9sGjHNuqZkdgTdJ2s1m9k/gr51cfjWwATgUL+Z08RpeiezJNOdK8P4OGYTUpiCZ9CRwuZkNATCzsWZWi/ctdKufEA4Ajg3o/i8CF/htCyPxGop7ooq2ueovTTm+E6hI2X8Kb0UxAPxv4h0tBvbr5D4v4U2BDF6d/fP+9my86iFSzrdjZmOARufc/cAteGv6LgFGm9lR/jUVfsN5FV4JIg58Bm+9346eBL5kZoX+Y6f4JQzwShZd9lKS3KWkIBnjnHsKr/rjZTNbADyC96H6d6DAzBbjrds7O6AQ/ow39fAi4H7gNWB7Dx53A/AnM5sH1Kccfwz4WKKhGfgaMN1vmF1EmlW6nHNv4y1BWdHxHF5C+ayZvYn3Yf11//i/Ad/wj+/XScwHA6+a2XzgeuBm51wz8Em8pVvfAJ7G+5Z/J3Cpf+wA2peKEu7Ge51e87up/oq2UtmpwMw0j5FBQFNnS15J1PGbt97uq8AJzrn1GY7hamCnc+7uHl5fBux2zjkzuxiv0fn8QIPsOp5ZeAvcbw0rBgmO2hQk3zxuZtV4DcY/yHRC8P0C+EQvrj8Sr2HYgG14PZNCYWY1eO0rSgiDlEoKIiKSpDYFERFJUlIQEZEkJQUREUlSUhARkSQlBRERSVJSEBGRpP8Pv247cf9mtdAAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learner.lr_plot()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We would like the maximal learning rate associated with a still-falling loss (prior the loss diverging). Based on the plot, we will start with a learning rate of 0.005." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Interactive Training\n", "\n", "It is sometimes advantageous to train interactively. For instance, one can train a model for one or two epochs using one learning rate. Then, based on the results, a higher or lower learning rate can be used for subsequent epochs. *ktrain* makes such interactive training easy. Here, using the fit method of the Learner object, we train a single epoch at the learning rate found previously and a second epoch at a slightly lower learning rate. The first argument is the learning rate and the second argument is the number of epochs." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train on 25000 samples, validate on 25000 samples\n", "Epoch 1/1\n", "25000/25000 [==============================] - 5s 197us/step - loss: 0.4010 - acc: 0.8293 - val_loss: 0.2984 - val_acc: 0.8777\n", "Train on 25000 samples, validate on 25000 samples\n", "Epoch 1/1\n", "25000/25000 [==============================] - 5s 183us/step - loss: 0.2105 - acc: 0.9283 - val_loss: 0.2869 - val_acc: 0.8860\n" ] } ], "source": [ "# reinitialize the model to train from scratch \n", "learner.set_model(get_model())\n", "\n", "hist = learner.fit(0.005, 1)\n", "hist = learner.fit(0.0005, 1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Using Learning Rate Schedules\n", "\n", "In the example above, a static learning rate is used throughout each epoch. 
It is sometimes beneficial to use learning rate schedules that *automatically* adjust the learning rate during the course of training to minimize loss more effectively. Such adjustments can help the optimizer escape suboptimal regions of the loss landscape and reach \"sweet spots\" with low loss that generalize well. *ktrain* makes it easy to employ a variety of empirically effective learning rate policies during training. These include:\n", "\n", "* a [triangular learning rate policy](https://arxiv.org/abs/1506.01186) available via the ```autofit``` method\n", "* a [1cycle policy](https://arxiv.org/abs/1803.09820) available via the ```fit_onecycle``` method\n", "* an [SGDR](https://arxiv.org/abs/1608.03983) (Stochastic Gradient Descent with Warm Restarts) schedule available via the ```fit``` method by supplying a *cycle_len* argument.\n", "\n", "\n", "### SGDR\n", "We will begin with SGDR, which *ktrain* supports in a style similar to that of the *fastai* library. The first relevant parameter is *cycle_len*.\n", "\n", "**cycle_len:** When *cycle_len* is not None, the second argument of ```fit``` is interpreted as the number of cycles instead of the number of epochs. For instance, the following call runs 2 cycles, each 2 epochs long, for a total of 4 (2 × 2) epochs. Within each cycle, the learning rate starts at 5e-3 and gradually decreases along a cosine curve (cosine annealing); it then restarts at 5e-3 at the beginning of the next cycle. More information can be found in the original [SGDR paper](https://arxiv.org/abs/1608.03983)."
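, "\n",
"\n",
"The cosine annealing with warm restarts described above can be sketched in plain Python as follows. This is a minimal illustration of the formula from the SGDR paper, lr = lr_min + ½ · (lr_max − lr_min) · (1 + cos(π · t_cur / t_i)), where *t_cur* counts iterations since the last restart and *t_i* is the length of the current cycle in iterations; it is not *ktrain*'s internal implementation. The function name *sgdr_schedule*, the choice of 100 batches per epoch, and *lr_min=0* are assumptions for illustration, and the optional (integer) *cycle_mult* argument multiplies the cycle length after each restart.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"def sgdr_schedule(lr_max, n_cycles, cycle_len, iters_per_epoch, cycle_mult=1, lr_min=0.0):\n",
"    # returns one learning rate per training iteration (batch)\n",
"    lrs = []\n",
"    t_i = cycle_len * iters_per_epoch  # length of the current cycle in iterations\n",
"    for _ in range(n_cycles):\n",
"        for t_cur in range(t_i):\n",
"            # cosine annealing from lr_max toward lr_min within the cycle\n",
"            lrs.append(lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t_cur / t_i)))\n",
"        t_i *= cycle_mult  # warm restart: the next cycle is cycle_mult times longer\n",
"    return lrs\n",
"\n",
"# 2 cycles of 2 epochs each, mirroring learner.fit(5e-3, 2, cycle_len=2)\n",
"lrs = sgdr_schedule(5e-3, n_cycles=2, cycle_len=2, iters_per_epoch=100)\n",
"print(len(lrs), lrs[0], lrs[199], lrs[200])  # start, end of cycle 1, restart\n",
"```\n",
"\n",
"With *cycle_len=1* and *cycle_mult=2*, three cycles span 1 + 2 + 4 = 7 epochs, which matches the *cycle_mult* example later in this notebook."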
] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train on 25000 samples, validate on 25000 samples\n", "Epoch 1/4\n", "25000/25000 [==============================] - 6s 230us/step - loss: 0.4063 - acc: 0.8253 - val_loss: 0.3004 - val_acc: 0.8841\n", "Epoch 2/4\n", "25000/25000 [==============================] - 6s 223us/step - loss: 0.2265 - acc: 0.9209 - val_loss: 0.2874 - val_acc: 0.8872\n", "Epoch 3/4\n", "25000/25000 [==============================] - 6s 222us/step - loss: 0.2062 - acc: 0.9227 - val_loss: 0.2840 - val_acc: 0.8880\n", "Epoch 4/4\n", "25000/25000 [==============================] - 6s 227us/step - loss: 0.1397 - acc: 0.9555 - val_loss: 0.2812 - val_acc: 0.8894\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# reinitialize the model to train from scratch \n", "learner.set_model(get_model())\n", "\n", "# training using cycle_len \n", "learner.fit(5e-3, 2, cycle_len=2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The ```learner.plot``` method can plot the training and validation loss (with 'loss' as the argument), the learning rate schedule (with 'lr' as the argument), and, where applicable, the momentum schedule (with 'momentum' as the argument). Here, we plot the learning rate schedule employed by the previous call to ```learner.fit```." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learner.plot('lr')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**cycle_mult:** The *cycle_mult* parameter allows you to increase the cycle length as training progresses. For instance, cycle_mult=2 doubles the length of each successive cycle. In the example below, three cycles totaling seven epochs are run:\n", "- the first cycle is one epoch long\n", "- the second cycle is two epochs long\n", "- the third cycle is four epochs long\n", "\n", "Each cycle begins at a learning rate of 5e-3, which gradually decreases until it resets at the beginning of the next cycle.\n", "\n", "Note that the example below overfits (validation loss rises after the third epoch); it is shown merely to illustrate the *cycle_mult* parameter." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train on 25000 samples, validate on 25000 samples\n", "Epoch 1/7\n", "25000/25000 [==============================] - 6s 235us/step - loss: 0.4295 - acc: 0.8241 - val_loss: 0.3478 - val_acc: 0.8704\n", "Epoch 2/7\n", "25000/25000 [==============================] - 6s 226us/step - loss: 0.2622 - acc: 0.9010 - val_loss: 0.2829 - val_acc: 0.8879\n", "Epoch 3/7\n", "25000/25000 [==============================] - 6s 225us/step - loss: 0.1782 - acc: 0.9408 - val_loss: 0.2776 - val_acc: 0.8900\n", "Epoch 4/7\n", "25000/25000 [==============================] - 6s 228us/step - loss: 0.1701 - acc: 0.9390 - val_loss: 0.2962 - val_acc: 0.8862\n", "Epoch 5/7\n", "25000/25000 [==============================] - 6s 227us/step - loss: 0.1170 - acc: 0.9613 - val_loss: 0.3159 - val_acc: 0.8832\n", "Epoch 6/7\n", "25000/25000 [==============================] - 6s 223us/step - loss: 0.0848 - acc: 0.9745 - val_loss: 0.3328 - val_acc: 0.8800\n", "Epoch 7/7\n", "25000/25000 [==============================] - 6s 222us/step - loss: 0.0704 - acc: 0.9816 - val_loss: 0.3355 - val_acc: 0.8809\n" ] }, { "data": { 
"text/plain": [ "" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# rebuild the model to train from scratch \n", "learner.set_model(get_model())\n", "\n", "# training using cycle_len and cycle_mult\n", "learner.fit(5e-3, 3, cycle_len=1, cycle_mult=2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here is what the learning rate schedule looks like when using the **cycle_mult** parameter." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learner.plot('lr')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Triangular Learning Rate Policy via ```autofit```\n", "\n", "The ```autofit``` method in *ktrain* employs a default cyclical learning rate schedule that tends to work well in practice. This default is currently the [triangular learning rate policy](https://arxiv.org/abs/1506.01186), with some slight modifications.\n", "\n", "The ```autofit``` method accepts two primary arguments. The first (required) is the learning rate (**lr**) to be used, which can be found using the learning rate finder above. The second is optional and indicates the number of epochs (**epochs**) to train. If **epochs** is not supplied, ```autofit``` trains until the validation loss stops improving; the number of epochs without improvement to tolerate is set with the **early_stopping** argument. When **early_stopping** is enabled, the weights that produced the lowest validation loss are automatically loaded into the model at the end of training. The ```autofit``` method can also automatically reduce the maximum (and base) learning rates in the triangular policy when validation loss stops improving. This behavior is configured with the **reduce_on_plateau** and **reduce_factor** arguments to ```autofit```.\n", "\n", "Example:\n", "```\n", "learner.autofit(0.001, 20, reduce_on_plateau=2, reduce_factor=10)\n", "```\n", "\n", "This call reduces the maximum and base learning rates in the triangular policy by a factor of 10 after two consecutive epochs with no improvement in validation loss. Validation loss (i.e., val_loss) is the default criterion for both **early_stopping** and **reduce_on_plateau**. 
To use validation accuracy instead, invoke ```autofit``` with ```monitor='val_acc'```.\n", "\n", "Here, we will use the ```autofit``` method and run the main training phase for two epochs." ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "begin training using triangular learning rate policy with max lr of 0.005...\n", "Train on 25000 samples, validate on 25000 samples\n", "Epoch 1/2\n", "25000/25000 [==============================] - 7s 264us/step - loss: 0.4752 - acc: 0.7769 - val_loss: 0.3279 - val_acc: 0.8764\n", "Epoch 2/2\n", "25000/25000 [==============================] - 6s 255us/step - loss: 0.2541 - acc: 0.9061 - val_loss: 0.2851 - val_acc: 0.8880\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# rebuild the model to train from scratch \n", "learner.set_model(get_model())\n", "\n", "# training using autofit\n", "learner.autofit(0.005, 2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The ```autofit``` method runs a triangular learning rate schedule with two modifications. First, it annihilates the learning rate at the end of each cycle:" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learner.plot('lr')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Second, if using the Adam, Nadam, or Adamax optimizers, it cycles the momentum between 0.85 and 0.95 in such a way that higher learning rates have lower momentum and lower learning rates have higher momentum, as suggested in [this paper](https://arxiv.org/pdf/1803.09820.pdf)." ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learner.plot('momentum')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Additional Cooldowns\n", "Since we are not [overfitting](https://en.wikipedia.org/wiki/Overfitting#Machine_learning) yet (i.e., validation loss is not increasing while training loss decreases), let's do a few more \"cooldowns\" starting at a smaller learning rate to improve the accuracy score further using the regular ```fit``` method that employs SGDR. These \"cooldown\" epochs will start the learning rate at 0.005/10 and gradually decrease it to a very small value. We will use the **checkpoint_folder** argument covered earlier, so that we can restore the weights from any epoch in case we train too much and overfit. If you are not using Linux, you should set this to your folder path of choice." ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train on 25000 samples, validate on 25000 samples\n", "Epoch 1/1\n", "25000/25000 [==============================] - 6s 220us/step - loss: 0.1927 - acc: 0.9304 - val_loss: 0.2804 - val_acc: 0.8901\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learner.fit(0.005, 1, cycle_len=1, checkpoint_folder='/tmp')" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train on 25000 samples, validate on 25000 samples\n", "Epoch 1/1\n", "25000/25000 [==============================] - 6s 222us/step - loss: 0.1535 - acc: 0.9518 - val_loss: 0.2797 - val_acc: 0.8904\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learner.fit(0.005/10, 1, cycle_len=1, checkpoint_folder='/tmp')" ] }, { "cell_type": "markdown", "metadata": 
{}, "source": [ "Note that we are running multiple short cooldown phases here: cycles of only one epoch each. This essentially amounts to SGDR (stochastic gradient descent with warm restarts). Although we do not do it here, we could instead run one longer cooldown by calling ```fit``` with a larger value for ```cycle_len``` and leaving the number of cycles at 1.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## The 1cycle Policy\n", "\n", "The [1cycle policy](https://arxiv.org/pdf/1803.09820.pdf) was proposed by Leslie Smith (as was the triangular learning rate policy). The 1cycle policy runs a single triangular learning rate cycle over the course of training and then annihilates the learning rate to a near-zero value towards the end." ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "begin training using onecycle policy with max lr of 0.005...\n", "Train on 25000 samples, validate on 25000 samples\n", "Epoch 1/3\n", "25000/25000 [==============================] - 7s 266us/step - loss: 0.5566 - acc: 0.7379 - val_loss: 0.3626 - val_acc: 0.8598\n", "Epoch 2/3\n", "25000/25000 [==============================] - 6s 252us/step - loss: 0.2684 - acc: 0.8990 - val_loss: 0.2864 - val_acc: 0.8870\n", "Epoch 3/3\n", "25000/25000 [==============================] - 6s 251us/step - loss: 0.1737 - acc: 0.9410 - val_loss: 0.2774 - val_acc: 0.8902\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# rebuild the model to train from scratch\n", "learner.set_model(get_model())\n", "\n", "# train using the 1cycle policy\n", "learner.fit_onecycle(0.005, 3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The 1cycle policy runs a single triangular cycle over all the epochs and also cycles the momentum in the shape of a *V*: momentum decreases while the learning rate rises and increases while the learning rate falls."
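, "\n", "\n", "As a rough illustration of these two curves, the following cell is a minimal pure-Python sketch of a 1cycle-style schedule. It is **not** *ktrain*'s implementation: the hypothetical `one_cycle` function, the starting learning rate of `max_lr/10`, the final learning rate of `max_lr/1000`, the 45%/45%/10% split between the rising, falling, and annihilation phases, the momentum range of 0.95 to 0.85, and the batch size of 32 used to count iterations are all assumptions chosen for illustration." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# Hypothetical sketch of a 1cycle learning-rate/momentum schedule.\n",
"# This is NOT ktrain's implementation; all default values below are assumptions.\n",
"def one_cycle(step, total_steps, max_lr, div=10.0, final_div=1000.0,\n",
"              pct_cycle=0.9, max_mom=0.95, min_mom=0.85):\n",
"    # returns (learning_rate, momentum) for a given training iteration\n",
"    base_lr = max_lr / div\n",
"    cycle_steps = int(total_steps * pct_cycle)  # length of the triangular cycle\n",
"    half = cycle_steps // 2\n",
"    if step < half:\n",
"        # rising half of the triangle: lr goes up, momentum goes down\n",
"        frac = step / half\n",
"        lr = base_lr + frac * (max_lr - base_lr)\n",
"        mom = max_mom - frac * (max_mom - min_mom)\n",
"    elif step < cycle_steps:\n",
"        # falling half of the triangle: lr goes down, momentum goes back up\n",
"        frac = (step - half) / (cycle_steps - half)\n",
"        lr = max_lr - frac * (max_lr - base_lr)\n",
"        mom = min_mom + frac * (max_mom - min_mom)\n",
"    else:\n",
"        # annihilation phase: lr decays linearly toward max_lr / final_div\n",
"        frac = (step - cycle_steps) / max(1, total_steps - cycle_steps - 1)\n",
"        lr = base_lr - frac * (base_lr - max_lr / final_div)\n",
"        mom = max_mom\n",
"    return lr, mom\n",
"\n",
"# 3 epochs over 25,000 samples, assuming a batch size of 32 (782 batches per epoch)\n",
"total_steps = 3 * ((25000 + 31) // 32)\n",
"schedule = [one_cycle(s, total_steps, 0.005) for s in range(total_steps)]\n",
"lrs = [lr for lr, _ in schedule]\n",
"moms = [m for _, m in schedule]\n",
"print('peak lr:', max(lrs), 'at step', lrs.index(max(lrs)))\n",
"print('final lr:', lrs[-1])\n",
"print('momentum range:', min(moms), '-', max(moms))"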
] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learner.plot('lr')" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learner.plot('momentum')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The final accuracy here is **~89%** using only unigram features and a simple model. In the [text classification notebook](https://github.com/amaiya/ktrain/blob/master/tutorials/tutorial-04-text-classification.ipynb), we show that an accuracy of **~92.3%** can be achieved on this dataset in mere seconds using built-in convenience methods in *ktrain*." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.8" } }, "nbformat": 4, "nbformat_minor": 2 }