{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## List of callbacks" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "hide_input": true }, "outputs": [], "source": [ "from fastai.gen_doc.nbdoc import *\n", "from fastai.vision import *\n", "from fastai.text import *\n", "from fastai.callbacks import * \n", "from fastai.basic_train import * \n", "from fastai.train import * \n", "from fastai import callbacks" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "fastai's training loop is highly extensible, with a rich *callback* system. See the [`callback`](/callback.html#callback) docs if you're interested in writing your own callback. See below for a list of callbacks that are provided with fastai, grouped by the module they're defined in.\n", "\n", "Every callback that is passed to [`Learner`](/basic_train.html#Learner) with the `callback_fns` parameter will be automatically stored as an attribute. The attribute name is snake-cased, so for instance [`ActivationStats`](/callbacks.hooks.html#ActivationStats) will appear as `learn.activation_stats` (assuming your object is named `learn`)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## [`Callback`](/callback.html#Callback)\n", "\n", "This sub-package contains more sophisticated callbacks, each defined in its own module. They are listed below (click a link for more details):" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [`LRFinder`](/callbacks.lr_finder.html#LRFinder)\n", "\n", "Use Leslie Smith's [learning rate finder](https://www.jeremyjordan.me/nn-learning-rate/) to find a good learning rate for training your model. Let's see an example on the MNIST dataset with a simple CNN." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "path = untar_data(URLs.MNIST_SAMPLE)\n", "data = ImageDataBunch.from_folder(path)\n", "def simple_learner(): return Learner(data, simple_cnn((3,16,16,2)), metrics=[accuracy])\n", "learn = simple_learner()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The fastai library already has a `Learner` method called [`lr_find`](/train.html#lr_find) that uses [`LRFinder`](/callbacks.lr_finder.html#LRFinder) to plot the loss as a function of the learning rate." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.\n" ] } ], "source": [ "learn.lr_find()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": 
"\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learn.recorder.plot()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this example, a learning rate around 2e-2 seems like the right fit." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "lr = 2e-2" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [`OneCycleScheduler`](/callbacks.one_cycle.html#OneCycleScheduler)\n", "\n", "Train with Leslie Smith's [1cycle annealing](https://sgugger.github.io/the-1cycle-policy.html) method. Let's train our simple learner using the one cycle policy." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "Total time: 00:07

\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epoch  train_loss  valid_loss  accuracy  time
0      0.109439    0.059349    0.980864  00:02
1      0.039582    0.023152    0.992149  00:02
2      0.019009    0.021239    0.991659  00:02
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.fit_one_cycle(3, lr)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The learning rate and the momentum were changed during the epochs as follows (more info on the [dedicated documentation page](https://docs.fast.ai/callbacks.one_cycle.html))." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "

" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learn.recorder.plot_lr(show_moms=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [`MixUpCallback`](/callbacks.mixup.html#MixUpCallback)\n", "\n", "Data augmentation using the method from [mixup: Beyond Empirical Risk Minimization](https://arxiv.org/abs/1710.09412). Adding mixup in fastai is very simple:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = Learner(data, simple_cnn((3, 16, 16, 2)), metrics=[accuracy]).mixup()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [`CSVLogger`](/callbacks.csv_logger.html#CSVLogger)\n", "\n", "Log the results of training in a CSV file. Simply pass the `CSVLogger` callback to the `Learner`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = Learner(data, simple_cnn((3, 16, 16, 2)), metrics=[accuracy, error_rate], callback_fns=[CSVLogger])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "Total time: 00:07

\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epoch  train_loss  valid_loss  accuracy  error_rate  time
0      0.127259    0.098069    0.969578  0.030422    00:02
1      0.084601    0.068024    0.974975  0.025025    00:02
2      0.055074    0.047266    0.983317  0.016683    00:02
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.fit(3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can then read the csv." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "

\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
   epoch  train_loss  valid_loss  accuracy  error_rate
0  0      0.127259    0.098069    0.969578  0.030422
1  1      0.084601    0.068024    0.974975  0.025025
2  2      0.055074    0.047266    0.983317  0.016683
\n", "
" ], "text/plain": [ " epoch train_loss valid_loss accuracy error_rate\n", "0 0 0.127259 0.098069 0.969578 0.030422\n", "1 1 0.084601 0.068024 0.974975 0.025025\n", "2 2 0.055074 0.047266 0.983317 0.016683" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learn.csv_logger.read_logged_file()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [`GeneralScheduler`](/callbacks.general_sched.html#GeneralScheduler)\n", "\n", "Create your own multi-stage annealing schemes with a convenient API. To illustrate, let's implement a two-phase schedule." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def fit_odd_shedule(learn, lr):\n", "    n = len(learn.data.train_dl)\n", "    phases = [TrainingPhase(n).schedule_hp('lr', lr, anneal=annealing_cos), \n", "              TrainingPhase(n*2).schedule_hp('lr', lr, anneal=annealing_poly(2))]\n", "    sched = GeneralScheduler(learn, phases)\n", "    learn.callbacks.append(sched)\n", "    total_epochs = 3\n", "    learn.fit(total_epochs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "Total time: 00:07

\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epoch  train_loss  valid_loss  accuracy  time
0      0.176607    0.157229    0.946025  00:02
1      0.140903    0.133690    0.954367  00:02
2      0.130910    0.131156    0.956820  00:02
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn = Learner(data, simple_cnn((3,16,16,2)), metrics=accuracy)\n", "fit_odd_shedule(learn, 1e-3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "

" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learn.recorder.plot_lr()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [`MixedPrecision`](/callbacks.fp16.html#MixedPrecision)\n", "\n", "Use fp16 to [take advantage of tensor cores](https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html) on recent NVIDIA GPUs for a 200% or more speedup." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [`HookCallback`](/callbacks.hooks.html#HookCallback)\n", "\n", "Convenient wrapper for registering and automatically deregistering [PyTorch hooks](https://pytorch.org/tutorials/beginner/former_torchies/nn_tutorial.html#forward-and-backward-function-hooks). Also contains a predefined hook callback: [`ActivationStats`](/callbacks.hooks.html#ActivationStats)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [`RNNTrainer`](/callbacks.rnn.html#RNNTrainer)\n", "\n", "Callback that takes care of all the tweaks needed to train an RNN." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [`TerminateOnNaNCallback`](/callbacks.tracker.html#TerminateOnNaNCallback)\n", "\n", "Stop training if the loss becomes NaN." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [`EarlyStoppingCallback`](/callbacks.tracker.html#EarlyStoppingCallback)\n", "\n", "Stop training if a given metric or the validation loss doesn't improve." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [`SaveModelCallback`](/callbacks.tracker.html#SaveModelCallback)\n", "\n", "Save the model at every epoch, or the best model for a given metric or the validation loss." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "Total time: 00:07

\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epoch  train_loss  valid_loss  accuracy  time
0      0.679189    0.646599    0.804220  00:02
1      0.527475    0.497290    0.908243  00:02
2      0.464756    0.462471    0.917076  00:02
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn = Learner(data, simple_cnn((3,16,16,2)), metrics=accuracy)\n", "learn.fit_one_cycle(3,1e-4, callbacks=[SaveModelCallback(learn, every='epoch', monitor='accuracy')])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "best.pth\t bestmodel_2.pth model_1.pth model_4.pth stage-1.pth\r\n", "bestmodel_0.pth bestmodel_3.pth model_2.pth model_5.pth tmp.pth\r\n", "bestmodel_1.pth model_0.pth\t model_3.pth one_epoch.pth trained_model.pth\r\n" ] } ], "source": [ "!ls ~/.fastai/data/mnist_sample/models" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [`ReduceLROnPlateauCallback`](/callbacks.tracker.html#ReduceLROnPlateauCallback)\n", "\n", "Reduce the learning rate by a given factor each time a given metric or the validation loss stops improving." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [`PeakMemMetric`](/callbacks.mem.html#PeakMemMetric)\n", "\n", "GPU and general RAM profiling callback." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [`StopAfterNBatches`](/callbacks.misc.html#StopAfterNBatches)\n", "\n", "Stop training after n batches of the first epoch." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [`LearnerTensorboardWriter`](/callbacks.tensorboard.html#LearnerTensorboardWriter)\n", "\n", "Broadly useful callback for Learners that writes to Tensorboard. Writes model histograms, losses/metrics, embedding projector and gradient stats. " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## [`train`](/train.html#train) and [`basic_train`](/basic_train.html#basic_train)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [`Recorder`](/basic_train.html#Recorder)\n", "\n", "Track per-batch and per-epoch smoothed losses and metrics." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [`ShowGraph`](/train.html#ShowGraph)\n", "\n", "Dynamically display a learning chart during training." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [`BnFreeze`](/train.html#BnFreeze)\n", "\n", "Freeze batchnorm layer moving average statistics for non-trainable layers." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [`GradientClipping`](/train.html#GradientClipping)\n", "\n", "Clip gradients during training." ] } ], "metadata": { "jekyll": { "keywords": "fastai", "summary": "Callbacks implemented in the fastai library", "title": "callbacks" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.2" } }, "nbformat": 4, "nbformat_minor": 2 }