{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%reload_ext autoreload\n", "%autoreload 2\n", "%matplotlib inline\n", "import os\n", "# pin the notebook to GPU 0, numbering devices by PCI bus ID\n", "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n", "import sys" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using TensorFlow backend.\n" ] } ], "source": [ "import ktrain\n", "from ktrain import vision\n", "import tensorflow.keras.backend as K" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# load CIFAR-10 and manually standardize it\n", "from tensorflow.keras.datasets import cifar10\n", "from tensorflow.keras.utils import to_categorical\n", "(x_train, y_train), (x_test, y_test) = cifar10.load_data()\n", "x_train = x_train.astype('float32')\n", "x_test = x_test.astype('float32')\n", "# per-pixel mean/std are estimated on the training set only and reused for the\n", "# test set, so both splits share one transform and no test statistics leak in\n", "train_mean = x_train.mean(axis=0)\n", "train_std = x_train.std(axis=0)\n", "x_train = (x_train - train_mean) / train_std\n", "x_test = (x_test - train_mean) / train_std\n", "# one-hot encode the integer class labels\n", "y_train = to_categorical(y_train)\n", "y_test = to_categorical(y_test)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(32, 32, 3)" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x_train[0].shape" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "# CIFAR-10 images are channels-last, matching x_train[0].shape above\n", "input_shape = (32, 32, 3)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "# define data augmentation; turn the featurewise_* options off, since we've manually standardized above\n", "data_aug = vision.get_data_aug(featurewise_center=False, \n", " featurewise_std_normalization=False,\n", " horizontal_flip=True,\n", " width_shift_range=0.1,\n", " height_shift_range=0.1,\n", " 
zoom_range=0.0,\n", " rotation_range=10)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "# load training and validation data as generators with data augmentation\n", "(train_data, val_data, preproc) = vision.images_from_array(x_train, y_train, \n", " validation_data=(x_test, y_test),\n", " data_aug=data_aug)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "pretrained_resnet50: 50-layer Residual Network (pretrained on ImageNet)\n", "resnet50: 50-layer Resididual Network (randomly initialized)\n", "pretrained_mobilenet: MobileNet Neural Network (pretrained on ImageNet)\n", "mobilenet: MobileNet Neural Network (randomly initialized)\n", "pretrained_inception: Inception Version 3 (pretrained on ImageNet)\n", "inception: Inception Version 3 (randomly initialized)\n", "wrn22: 22-layer Wide Residual Network (randomly initialized)\n", "default_cnn: a default Convolutional Neural Network\n" ] } ], "source": [ "# let's examine the available image classifiers\n", "vision.print_image_classifiers()" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Is Multi-Label? False\n", "Wide Residual Network-22-6 created.\n" ] } ], "source": [ "# load a 22-layer Wide ResNet\n", "model = vision.image_classifier('wrn22', train_data, val_data)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "# get a Learner object to be used in training\n", "learner = ktrain.get_learner(model, train_data=train_data, val_data=val_data, \n", " workers=8, use_multiprocessing=True, batch_size=64)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "simulating training for different learning rates... 
this may take a few moments...\n", "Epoch 1/5\n", "781/781 [==============================] - 54s 69ms/step - loss: 6.9901 - acc: 0.1210\n", "Epoch 2/5\n", "781/781 [==============================] - 48s 62ms/step - loss: 5.7066 - acc: 0.3097\n", "Epoch 3/5\n", "781/781 [==============================] - 49s 63ms/step - loss: 2.2498 - acc: 0.5087\n", "Epoch 4/5\n", "781/781 [==============================] - 50s 64ms/step - loss: 2.2908 - acc: 0.3730\n", "Epoch 5/5\n", "163/781 [=====>........................] - ETA: 41s - loss: 5.8823 - acc: 0.1258\n", "\n", "done.\n", "Please invoke the Learner.lr_plot() method to visually inspect the loss plot to help identify the maximal learning rate associated with falling loss.\n" ] } ], "source": [ "# find a good learning rate\n", "learner.lr_find()" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX8AAAEOCAYAAABhOhcDAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvqOYd8AAAIABJREFUeJzt3Xd4HNW9//H3d9W7bFmyLclyBRtwR4AN2KG3EEMSSgoJIYCT3FRubnIhuSW5F5L8QshzQxo4EEISQhJMr6aDATcZVzAG3CVXWb1rtef3x66NMLIt2RrNls/refbR7GyZj8byd2fPnDnHnHOIiEhiCfgdQEREBp6Kv4hIAlLxFxFJQCr+IiIJSMVfRCQBqfiLiCQgFX8RkQSk4i8ikoBU/EVEEpCKv4hIAkr2O0B3Q4YMcaNGjfI7hohIzFi+fHm1c66wr6+LquI/atQoKioq/I4hIhIzzGzLkbxOzT4iIglIxV9EJAGp+IuIJCAVfxGRBKTiLyKSgDwr/mY23sxWdrs1mNl3vNqeiIj0nmddPZ1z64GpAGaWBFQBD3uxrVff3UNHMEQgAGZGwIyAQcAMi/zct84s/JwkM1KSAqQkGclJAZIDRmpy+GdyZH1KZL2ZeRFbRMQ3A9XP/2xgg3PuiPqjHs7cv1TQ1hny4q0BSA5EPgiSjNSkAGnJAdJSkkhLDpCekkR6SoC05PDP9JQkMlKSyM1IIS9yy8/stpyRSl5GCjnpyQQC+lARSXRvbq1lV30bF04aPqDbHaji/xngfq/e/B9zZ2IGzkHIOUIOXORn+L4jFAJHt3UhR2eXo7MrRDAU+mB5/zpHZzBEZ8gRjNzvCIbo7ArRHgzf2jq7aOvsoj0Yoq6lY/+6lo4uGto6D/mBZAa56R/+YOj+YTEoM5WS/AxKB2VSMiiDQZkp+gYiEoceqNjGC+t2x1/xN7NUYA5w00EenwvMBSgrKzuibUwZkX+k8TzV1tlFQ2sn9a2d1LV2Ut8S+dnaSX1LxwfrWzupa+mkqrZ1/7qukPvQe2WmJkU+DMIfCKWDMhhXlM2E4bkU56Xrg0EkRjW
2BclOG/jBFgZiixcCbzrndvX0oHNuHjAPoLy83PX0nFgVbhJKoig3vU+vc87R0Bqksq6FytpWqmpbqaxtpbK2haq6Vt7cWkd9a+f+5+ekJ3PcsFzGD8vhhOJcppUNYlxRNklqVhKJes3tQbLT47P4fxYPm3zikZmRl5lCXmYeJxTn9fic+tZO3tvVyLqdjazf2cA7Oxp5eEUVf1kcPq2SnZbM1BH5zBxbwGnjhjCpJE8fBiJRqKk9Do/8zSwLOBf4ipfbSUR5GSmUjxpM+ajB+9c559i8t4UVW2tZsbWOZZtruHXBem5dsJ7c9GTOnFDEnCnFzDqmkNRkXeIhEg0a24KMGJw54Nv1tPg755qBAi+3IR8wM0YPyWL0kCw+Nb0UgOqmdt7YsJeF7+7h2bd38ejK7eRnpnDhxOHMmVLMKaMHq9eRiI+a2oPkxNuRv/hvSHYac6YUM2dKMbcEQyx8bw+PrtzOIyuquH/pVoblpnPx5OFcOq2EiSU9NzGJiHfiuc1fokRqcoCzjxvK2ccNpaUjyPPrdvPYyu3cu2gzd722iZljCvjGWeM4dWyBeg+JDADnXHy2+Uv0ykxN3v+NoK6lg/nLK5n36kY+f9cSpo7I59/OG8/pxwzxO6ZIXGsPhq8xyvKh+Ousn5Cfmcp1s8bw6vfP5OZLJ7KnsZ2r7l7CF+5ewtqqer/jicStpvYgEO6uPdBU/GW/9JQkrpoxkhe++zH+4+PHsaaqnot//Ro/eHgN9S2dh38DEemT5kjx96PZR8VfPiI9JWn/N4FrTx/N35du5exfvsJjq7bjXFxdhyfiq8a2cPFXs49Eldz0FP7z4uN57BunU5yfzrfuX8F3H1hFS0fQ72gicWF/s4+Kv0SjiSV5PPwvp/Hts4/h4RVVXPrb19mwp8nvWCIxryly5O9HV08Vf+mVpIBxw7nHcu81J1Pd1MGcX7/G46u2+x1LJKY1d6jNX2LE7GMLefJbpzNheC7fvH8F//XoWtqDXX7HEolJ+9r8VfwlJgzPy+Dvc2dw3emj+fOiLVxxxyKqm9r9jiUSc/a1+avZR2JGSlKA/7j4eO646kTW72rkM/MWs6uhze9YIjGlqS1IwCAjJWnAt63iL0flgonD+NM1J7OjrpUr7lxEZW2L35FEYkZTe5CstGRfhlNR8ZejNmNMAX+57hRqmju48s7FbNnb7HckkZjg14ieoOIv/WR62SDuv34GLR1BLr9jEe/vbvQ7kkjUa2rzZ0RPUPGXfjSxJI+/z51JyMHn71pCVV2r35FEoppfI3qCir/0s/HDcrjvulNo6ejiS39cqjGBRA5hX5u/H1T8pd+NH5bDnV84kc17m7n+LxW6DkDkIJrag76M6Akq/uKRU8cO4ReXT2Hpphq++89VhEIaEE7kQE1t/jX7aDIX8cwlU0vYUd/Gz55+h7GF2dxw7rF+RxKJKuE2/xRftq0jf/HUV2aP4VPTS7j9xfdY+N4ev+OIRI1QyNHcESQ7beAv8AIVf/GYmXHzpRM5piib7/x9JTvrdRWwCEBLZxfO+TO0A6j4ywDITE3md5+fTmtnF9+8/02CXSG/I4n4bv9wzmr2kXg2riiHn35qEss21/KLZ9/1O46I75raw92gs9TsI/HukqklfO6UMu54ZQMvr9/tdxwRXzW1h7tAq6unJIT/uvh4jinK5gcPrdk/nK1IIlKzjySU9JQkfvbpyexoaOMXC9b7HUfEN/uafTS8gySME0cO4uqZo7h30WaWb6n1O46IL/ycxQtU/MUn/3b+eIrzMvj3B1dr+AdJSM0+zuIFHhd/M8s3s/lm9o6ZrTOzmV5uT2JHdloyN39yIu/vbuJ3L23wO47IgNt3zitee/v8CnjGOTcBmAKs83h7EkPOHF/EpVOL+d3L7/PuLo3/L4mlsT1IalKAtOQ4K/5mlgfMBu4GcM51OOfqvNqexKb/+sQJ5KSn8P35q+nS4G+SQPycyAW8PfIfDewB7jGzFWZ2l5l
lebg9iUGDs1L5r4uPZ+W2Oh5cXul3HJEBEx7L35+jfvC2+CcD04HfO+emAc3AjQc+yczmmlmFmVXs2aOBvxLRJVOLmVaWzy+fe5fWDp38lcTQ2BYkN92fPv7gbfGvBCqdc0si9+cT/jD4EOfcPOdcuXOuvLCw0MM4Eq3MjJsuPI6dDW388fVNfscRGRANrZ3xWfydczuBbWY2PrLqbOBtr7Ynse3k0YM59/ih3PHyBmqaO/yOI+K5hrZOcjPis80f4JvAfWa2GpgK/MTj7UkM+/cLxtPcEeTXL77ndxQRz8Vzsw/OuZWRJp3JzrlLnXO6nFMOalxRDleeVMZfF29h694Wv+OIeKqhtZPcjDgt/iJ9dcM5x5AcCHDrsxr3R+JXsCtEc0eXbyN6goq/RJmi3HSunzWax1dtZ9U2XRYi8WnfuD5x2+wjciTmfmwsBVmp/HzBO35HEfHE/uKvZh+RD2SnJfPVj43l9ff36uhf4lJDW3g451w1+4h82GdPKSM3PZk7XtGgbxJ/GlojxV9H/iIflp2WzBdmjuSZt3aycU+T33FE+tW+I3+d8BXpwZdOHU1qUoB5r270O4pIv2rQCV+RgyvMSePy8lIeerOK3Q1tfscR6Tdq9hE5jLmzxhIMhbhbY/5IHGloC2IGOT5N4Qgq/hLlygoyuWjScO5bvHV/O6lIrGto7SQ7LZlAwHzLoOIvUe+rHxtLU3uQvy/d6ncUkX7h97g+oOIvMWBiSR4zxgzmntc309kV8juOyFFraOv0tacPqPhLjLh+1hh21Lfx1JodfkcROWp+D+oGKv4SI84cX8SYwiz+sHAjzmmuX4ltDWr2EemdQMC47vQxrK1qYMmmGr/jiByVupYO8nTkL9I7n5pewuCsVO5Rt0+JYc459jZ1MCQn1dccKv4SM9JTkri8vJTn1+3WRV8SsxragnR0hSjMTvM1h4q/xJQry0fQFXI8tKLK7ygiR6S6qR2AISr+Ir03pjCb8pGDeKBim078SkyqbgwX/4JsNfuI9Mnl5aVs2NPMCo31LzGouqkD0JG/SJ99fHIxGSlJPFCxze8oIn2mZh+RI5SdlsxFk4bz+KodtHZ0+R1HpE+qm9oJGAzOUrOPSJ9dXl5KU3uQZ97SFb8SW6qbOhiclUqSj4O6gYq/xKhTRg+mbHAmD1RU+h1FpE+qm9p9b/IBFX+JUWbGZSeW8saGvWyrafE7jkivqfiLHKVPn1iKGcxfrqN/iR3VTe2+d/MEFX+JYSX5GZw+bgjzl1cSCqnPv8SG6sYOHfmLHK3LTiylqq6VRRv3+h1F5LCa24O0dnap+IscrfNPGEZOerL6/EtM+KCPf5w3+5jZZjNbY2YrzazCy21JYkpPSeKSqcU8vXan5viVqLe/+OckxpH/mc65qc658gHYliSgy08cQXswxJOr1edfotv2uvBotMV5GT4nUbOPxIHJpXmMGZLFoys10qdEt8raVgBKBsV/8XfAs2a23Mzm9vQEM5trZhVmVrFnzx6P40g8MjM+MaWYJZtq2Fmvcf4lelXWtjAoM4XsNH8nbwfvi//pzrnpwIXA181s9oFPcM7Nc86VO+fKCwsLPY4j8erSaSU4h47+Japtq22ldFCm3zEAj4u/c64q8nM38DBwspfbk8Q1ekgW08vyefDNSo3zL1GrsraF0iho8gEPi7+ZZZlZzr5l4DxgrVfbE/n0iaW8u6uJt7Y3+B1F5COcc1TVtsZ/8QeGAq+Z2SpgKfCkc+4ZD7cnCe6iicNJDhiPr97udxSRj9jT1E57MBT/zT7OuY3OuSmR2wnOuVu82pYIwKCsVGYdM4QnVu3QcA8Sdar29fTJj/8jf5EBN2dqMVV1razYVut3FJEP2RHpiTY8P93nJGEq/hJXzj1+GGnJAR5bqaYfiS7b68JH/tFwgReo+EucyU5L5uzjinhyzQ6CXSG/44jst6O+jfSUAPmZKX5HAVT8JQ7NmVJMdVMHizfW+B1FZL8d9a0
U52Vg5u/0jfuo+EvcOWN8EdlpyTy2Shd8SfTYXtcWNe39oOIvcSg9JYnzThjK02t30h7s8juOCBA+8h8eJe39oOIvcWrOlGIa24K8+m6131FEcM6xtyk6ZvDaR8Vf4tJp44YwKDOFx1ap14/4r6k9SDDkGBQlJ3tBxV/iVEpSgIsmDef5t3fR0hH0O44kuLqW8ERDg7L8n8Frn14VfzP7tpnlWtjdZvammZ3ndTiRozFnSjGtnV08v26331EkwdW2dAAwKDPGij/wZedcA+HB2QYBXwB+5lkqkX5w0qjBDMtN1wVf4rua5n3FP/aaffZ1TL0I+Itz7q1u60SiUiBgXDx5OK+8u5v6Fs3vK/6J2WYfYLmZPUu4+C+IDNWsyycl6s2ZWkxnl2PBWzv9jiIJLJabfa4FbgROcs61ACnANZ6lEuknk0ryGFmQqV4/4qvalk7MIC8j9pp9ZgLrnXN1ZnYV8B9AvXexRPqHmTFnSjFvbKhmd6Pm9xV/1LV0kJeRQlIgelrLe1v8fw+0mNkU4LvABuDPnqUS6UdzphQTcvD0GjX9iD9qmjuiqskHel/8gy48MeolwG+cc78FcryLJdJ/jhmaw4RhOWr6Ed/UtXRGzWie+/S2+Dea2U2Eu3g+aWYBwu3+IjHhE1OKWb6llsraFr+jSALaUd/K0JzoGdQNel/8rwTaCff33wmUArd6lkqkn31icjEAz6xV048MrFDIsa22lbKC6Ji7d59eFf9Iwb8PyDOzi4E255za/CVmlBVkMn5oDs+v2+V3FEkwuxvb6QiGGDE4Bou/mV0BLAUuB64AlpjZZV4GE+lv5xxfxLLNtdRF+lyLDIRtkabGslgs/sAPCffxv9o590XgZOA/vYsl0v/OOW4oXSHHCxrrRwbQ1r3h4j9iUPSM5Q+9L/4B51z3/zF7+/BakagwpTSf4XnpPL12h99RJIFsq23BDEqirPgn9/J5z5jZAuD+yP0rgae8iSTijUDAuHDicP66eAsNbZ3kpqvDmnivsjbc0yctOcnvKB/S2xO+3wPmAZMjt3nOuX/3MpiIFy6aNIyOrhAvvaOmHxkYlbUtUXfUD70/8sc59yDwoIdZRDw3rWwQBVmpvLBuN5dMLfE7jiSAqrpWpo0Y5HeMjzjkkb+ZNZpZQw+3RjNrGKiQIv0lKWCcOaGIl9fvprNLA9OKt7pCjh11bZRG4ZH/IYu/cy7HOZfbwy3HOZc7UCFF+tM5xw2loS3I0k01fkeROLeroY1gyEVls4/nPXbMLMnMVpjZE15vS6Q3zhhfSHZaMo+urPI7isS5qrpWAEryE7D4A98G1g3AdkR6JT0lifNPGMbTa3fS1tnldxyJY/vGkiodFF0XeIHHxd/MSoGPA3d5uR2RvrpkajGNbUFeXq9eP+KdqtrEPfL/P+D7aMpHiTKnji1gSHYqj2pyd/FQVV0rBVmpZKRGVx9/8LD4RwaA2+2cW36Y5801swozq9izZ49XcUQ+JDkpwMWTi3nhnd00tGlyd/FGZW1rVPb0AW+P/E8D5pjZZuDvwFlm9tcDn+Scm+ecK3fOlRcWFnoYR+TDLplaTEcwpGGexTNVda0UR2GTD3hY/J1zNznnSp1zo4DPAC86567yansifTV1RD4jCzJ5XDN8iQecC/fxH56XYMVfJNqZGRdNGs4bG/ZS26xhnqV/VTd10NrZRdngBC7+zrmXnXMXD8S2RPrioonD6Qo5nntbk7xI/9qytxmAkUOyfE7SMx35S0KbWJJL6aAMntIwz9LPttZE5yQu+6j4S0IzMz4+aTivv19NfYt6/Uj/qYk0JQ7JSvM5Sc9U/CXhXTy5mM4ux5NrdPQv/ae+tZOAQU56rwdPHlAq/pLwJpbkMq4om4dXVPodReJIXUsneRkpBALmd5QeqfhLwjMzPjmthGWba/fPtypytGpbOsjPTPU7xkGp+IsAl04LT+zy8AqN9Cn9o761k/zM6J0qVMVfhPDAWzPGDObhFZU
45/yOI3GgtqWD/AwVf5GoN2dKCZv3trB+V6PfUSQO1LV0qtlHJBacd8JQkgPGg8t14leOXrj468hfJOoNyU7jrAlFPLxiu+b3laPSEQzR1B4kP0NH/iIx4fLyEVQ3tfPKeg0vLkduT1M7AENzo/MCL1DxF/mQM8YXMiQ7lflq+pGjsLO+DYChuek+Jzk4FX+RblKSAlw6tYQX3tm1//J8kb7a3RAu/kU68heJHZeVl9LZ5XhEff7lCO2KFP9hOvIXiR0ThuUyqSRPTT9yxHY2tJOSZAxSV0+R2HJ5eSlv72jgre31fkeRGLS7oY2inPSoHdcHVPxFejRnSjGpSQEd/csR2dXYFtU9fUDFX6RH+ZmpnHv8UB5duZ2OoPr8S9/srG+L6p4+oOIvclCXlZdS09zBC+s0xaP0ze6GdhV/kVg1+5hChuel89clW/yOIjGkuT1IY3tQxV8kViUFjKtmjOT19/fy9vYGv+NIjKiqawWgOF/FXyRmff6UMnLTk/np0+v8jiIxoqo2XPxLB2X4nOTQVPxFDiE/M5VvnnUMC9+r5o0N1X7HkRiwqboZgJL8TJ+THJqKv8hhfGHmSIZkp3L3wk1+R5EY8Nr71YwqyGRYnpp9RGJaekoSnzu5jBfX72Zz5KhOpCdtnV28saGaM8YX+R3lsFT8RXrhqhkjSQkE+OPrOvqXg1u2uYa2zhAfG1/od5TDUvEX6YWi3HQumVrMPyu2UavRPuUglm2uJWBw0qjBfkc5LBV/kV66btYY2jpD3Kd+/9KNc473dzdy2e/f4PYX3mNiSR7Zacl+xzqs6E8oEiXGD8th9rGF/OmNLVw/ewxpyUl+RxKf1TR38OU/LWPltjryM1P43CllfGHGSL9j9YpnR/5mlm5mS81slZm9ZWY/9mpbIgNl7qwxVDe18+iK7X5HkSjw+5ffZ+W2OsYVZXPPl07iJ5+cxHHDc/2O1SteNvu0A2c556YAU4ELzGyGh9sT8dxp4wqYMCyHPy/e7HcU8ZlzjqfW7OTsCUU8/68fY1rZIL8j9Ylnxd+FNUXupkRuzqvtiQwEM+Nzp5SxtqqB1ZV1fscRH7387h6q6lq5YOIwv6McEU9P+JpZkpmtBHYDzznnlni5PZGBcOm0EjJSkvjbkq1+RxEf/fbF9ykdlMHHJw/3O8oR8bT4O+e6nHNTgVLgZDObeOBzzGyumVWYWcWePXu8jCPSL3LTU5gzpZhHVlaxt6nd7zjig9aOLlZV1nHx5GIyU2Oz38yAdPV0ztUBLwEX9PDYPOdcuXOuvLAw+i+MEAG4fvZo2oMh7nl9s99RxAcrttXS2eU4ZXT09+c/GC97+xSaWX5kOQM4F3jHq+2JDKRxRTlccMIw7n1jM3Utuugr0SzdVIMZTB8ZWyd5u/PyyH848JKZrQaWEW7zf8LD7YkMqK+fOY7G9iAPvVnldxQZYEs31TBhWC55GSl+RzliXvb2We2cm+acm+ycm+ic+x+vtiXih4kleUwvy+fu1zbR2tHldxwZIC0dQSo213L6uAK/oxwVDe8gchS+d/4EqupaNeBbAlm8cS8dXSFmHxvb5yhV/EWOwsyxBZxz3FB+//IG9fxJEE+u3klmahInx/DJXlDxFzlqN144npaOIL9+8X2/o4jHXnl3Dw++WckV5SNifmwnFX+RozSuKIcrTyrjviVb2LJXk73Es7sWbqR0UAY3XjjB7yhHTcVfpB/ccM4xJAcC/HzBer+jiEdaOoIs2VjDhROHkZ4S20f9oOIv0i+KctO5ftZonly9g+fe3uV3HPHAss21dHSFmHVMbJ/o3UfFX6Sf/MuZ4xhVkMn1f67QXL9xaG1VPQBTy/J9TtI/VPxF+kl6ShI/v2wKAJ/9w2JaOoI+J5L+tLaqnpEFmeSmx+6FXd2p+Iv0o5NHD+aOq05kR30b35+/2u840k+cc6zcVsfEkjy/o/QbFX+RfnbBxGGMH5rDE6t3ULG5xu840g82Vjezo76NU8f
G9lW93an4i3jgz9eeTHLA+NHjb9EV0hxGse6196oBmDUuPk72goq/iCeG5qZz2xVTWFvVwF8WbfY7jhylhe9VUzY4k7KCTL+j9BsVfxGPzJlSzIwxg/nVC++xp1FDP8Sqzq4Qizfu5fRjhvgdpV+p+It4xMz43vkTaGwLcuWdi9T7J0at2lZHU3uQWeNU/EWkl04cOYjffG46G6ub+elTmssoFi18r5qAwaljVfxFpA8umDiMq2eO5C+Lt/CTp9b5HUd60BVyrKms54cPr2HZ5hqCXaH9j7363h4mleaTlxkf/fv3ic2Zh0VizI0XHser71Uz79WN5GWk8PUzx/kdSSKa24NcdfcSVmytA+C+JVv3PzZzTAErttbxvfPH+xXPMyr+IgMgIzWJZ74zi9N+9iK3LlhPVmoSX5w5ikDA/I6WELbVtDA4K5WstGTe3dXIr154jydX7+DM8YWs2FZHXUsnZ00o4oryUrbXtfHnRZvZvLeFRRv3MrIgk8+cNMLvX6HfmXPR0we5vLzcVVRU+B1DxDPVTe1cc88y1lTV85XZY7jpouP8jhT3/lmxbf/V1iMLMtmyt+Ujz/n5pydzxQEFvrMrxLNv7WL2sUPIieIhHcxsuXOuvK+v05G/yAAakp3GY984jR88vJY7X91IUW46154+2u9YcaW6qZ3FG/cyqiCLpIDxw4fX7H9sy94WrjltFF87YyyF2WkAtAdDPQ7RnJIU4OOThw9Y7oGm4i8ywMyMmy+dyPqdDdz27Ho+MWU4RTnpfseKC41tnVx8+2vsbGjbvy4vI4UXv/sxCiLF/kDxMDb/kVBvHxEfJAWM266YSnswxC80AUy/CIUcNz20hp0NbXzv/PH8eM4JlORncMdVJx608CcyHfmL+GT0kCyumzWaO1/ZyLiibK6fNQYznQA+Uv/56FqeWL2DG845dn9vqqtPHeVvqCimI38RH33//AmcPaGInzz1Dmf84mXNAXyEXn13D/ct2crnTynjW2erG21vqPiL+CgpYMz7YjlzphSzZW8Ln79rCTXNHX7Hiil3LdzIF/+4lJL8DH5w0XH69tRLKv4iPksKGLd/dhq/vGIKlbWtTP/f53jozUq/Y0W9ts4ufvvS+9z8ZPiq6XuuOYmsNLVk95aKv0iU+NT0Ur4UaaP+13+u4jPzFvkbKIqFQo5r7lnGrQvWc/aEIt6/5UKOHZrjd6yYouIvEkV+NOcE3v6f8ynKSWPxxhp++9L7fkeKSvMWbmTRxr3MmVLM7Z+dRnKSSllfaY+JRJnM1GRe/t4ZFOelc+uC9dzy5Nt+R4oq63c2ctuz67lw4jB+9Zmpauo5Qp4VfzMbYWYvmdnbZvaWmX3bq22JxJvM1GQe/+bpAPxh4SZuf+E9nxNFh+ff3sUnfv0aOekp3HzpRJ3cPQpeHvkHge86544HZgBfN7PjPdyeSFwpyE5j3f9cwJDsVH753Ls8vWYH1U3trKms51fPv0dbZ5ffEQdMc3uQu1/bxHV/riAnPZm7ri7XhVtHacAGdjOzR4HfOOeeO9hzNLCbyEftbmjjrNteoan9wzOBDclOZdFNZ5MS5+3d22pauPLORWyvb6MgK5UHv3Yqo4Zk+R0rahzpwG4D8ldjZqOAacCSgdieSDwpyk3n2Rtmc9aEIgBmjBnM+KE5VDd1cPItz7NxT5PPCfufc455r27gFwvWM+vnL1HT0sGNF05gyQ/OVuHvJ54f+ZtZNvAKcItz7qEeHp8LzAUoKys7ccuWLZ7mEYkXty54h9++tIGkgPHsDbMZW5jtd6Q+qWvp4KaH1jCuKJtrThtNTXMH35+/ijcjk6p099drT4m7CdT7y5Ee+Xta/M0sBXgCWOCc++Xhnq9mH5G+eemd3Xz53mXkZ6Rw/9wZTBiW63ekj1i1rY7r/1zBtaePZu7s8PhFy7fU8OnfH/w6hpQkY2RBFj/55CSmjsgnNTm+m7aORtQVfwufhr8XqHHOfac3r1HxF+m7F9/
ZxTf/toLmji4mDMvhD18sZ8TgTL+3UE1eAAALeklEQVRjAXDP65v48eMfdFX9+KThpCQZj6zcDsB5xw+lIDuN+5duZcTgDH76ycmcOHIQGamJOczykYjG4n86sBBYA+ybDfkHzrmnDvYaFX+RI7OtpoVv/33F/iaTsYVZ/Ou547lo0jC6Qo7OLjcgBbW5Pcj85ZW8vb2Bt3bUs7aqgaSAMf+rM1nw1i7uWriRYMgxOCuVO646kZNHD/Y8U7yLuuJ/JFT8RY7O2qp6/veJt1myqeYjj80YM5hJJXl0djluOOdY0lICPL9uFzPGFDCkl90mnXM99q3v7Aox5zevs25Hw4fWXzRpGP935bT9zTbN7UF21LcyZki25i/uJyr+IrLfM2t38sq7e1iycS8bqz86THReRgqtnV10BMNfyr9+5lg6uxyNbUHOHF/I+p2NLNtSy6IN1cw+ppCvnzWOkvwMvnj3UtbvaiQ1OUB6coCvfGwsn55eyoyfvrD/vW+8cAJfmDGSVdvqmDm2QBdieUzFX0QOqivkWFtVT2dXiK01LfxtyVbqWzs5dmgOT67Z0eNrslKTaO746IVkackB2oOhj6z/1PQSfnHZFB3RDzAVfxE5Ijvr29iwp4mm9iDfn7+alCTj3i+fzMiCLELO8czanTxQsY1lm2v53vnj+fqZ43DOsbqynlueWsfSTTV848xx/Nv54/3+VRKSir+ISAKK6it8RUQkuqj4i4gkIBV/EZEEpOIvIpKAVPxFRBKQir+ISAJS8RcRSUAq/iIiCSiqLvIys3qg+0zVeUD9Acs9rRsCVB/BJru/V28fP9y6wy37kbmn9Ye639/7eiAzd1+O9r+PnrJ2f1z7+tCP6/9iOPdI51xhX8PjnIuaGzDvYPf3LR9kXUV/bK83jx9u3eGW/ch8uH3r9b4eyMx+7+u+/H0cJGv352pf99O+jpbMR7qv+7Pu7btFW7PP44e4//gh1vXX9nrz+OHWHW7Zj8w9rR/IfT2QmbsvR/vfR/f7sfj30X052vd1tGTuaf1A1z0gypp9jpSZVbgjGNvCT7GYGWIzdyxmhtjMrcwD52hzR9uR/5Ga53eAIxCLmSE2c8diZojN3Mo8cI4qd1wc+YuISN/Ey5G/iIj0gYq/iEgCUvEXEUlAcV/8zWyWmd1hZneZ2Rt+5+kNMwuY2S1m9mszu9rvPL1hZmeY2cLIvj7D7zx9YWZZZlZhZhf7naU3zOy4yH6eb2Zf8ztPb5nZpWb2BzP7h5md53ee3jCzMWZ2t5nN9zvLoUT+hu+N7N/P9+Y1UV38zeyPZrbbzNYesP4CM1tvZu+b2Y2Heg/n3ELn3FeBJ4B7vcwbyXbUmYFLgFKgE6j0Kmu3bP2R2QFNQDoDkBn6LTfAvwP/9Cblh/XT3/S6yN/0FcBpXubtlq8/cj/inLse+CpwpZd5I9n6I/NG59y13ibtWR/zfwqYH9m/c3q1gaO5QszrGzAbmA6s7bYuCdgAjAFSgVXA8cAkwgW++62o2+v+CeTEQmbgRuArkdfOj5HMgcjrhgL3xcrfB3Au8BngS8DFsZA58po5wNPA52JlX3d73W3A9BjL7Pn/w6PMfxMwNfKcv/Xm/ZOJYs65V81s1AGrTwbed85tBDCzvwOXOOd+CvT4td3MyoB651yjh3GB/slsZpVAR+Rul3dpw/prP0fUAmle5DxQP+3rM4Aswv+BWs3sKedcKJozR97nMeAxM3sS+JtXebttrz/2tQE/A552zr3pbeJ+/7secH3JT/jbdimwkl626ER18T+IEmBbt/uVwCmHec21wD2eJTq8vmZ+CPi1mc0CXvUy2CH0KbOZfQo4H8gHfuNttEPqU27n3A8BzOxLQLWXhf8Q+rqvzyD8NT8NeMrTZIfW17/rbwLnAHlmNs45d4eX4Q6ir/u6ALgFmGZmN0U+JPx0sPy3A78xs4/Ty+EfYrH495lz7r/9ztAXzrkWwh9YMcM59xDhD62Y5Jz7k98
Zess59zLwss8x+sw5dzvhIhUznHN7CZ+jiGrOuWbgmr68JqpP+B5EFTCi2/3SyLpopswDJxZzx2JmiM3csZi5u37LH4vFfxlwjJmNNrNUwifrHvM50+Eo88CJxdyxmBliM3csZu6u//IP9BnsPp7tvh/YwQddHq+NrL8IeJfwWe8f+p1TmZU7njPHau5YzDyQ+TWwm4hIAorFZh8RETlKKv4iIglIxV9EJAGp+IuIJCAVfxGRBKTiLyKSgFT85YiZWdMAbGNOL4dl7s9tnmFmpx7B66aZ2d2R5S+ZmZ9jHO1nZqMOHBa4h+cUmtkzA5VJ/KfiL74zs6SDPeace8w59zMPtnmoca3OAPpc/IEfEGNj1+zjnNsD7DCzAZkfQPyn4i/9wsy+Z2bLzGy1mf242/pHzGy5mb1lZnO7rW8ys9vMbBUw08w2m9mPzexNM1tjZhMiz9t/BG1mfzKz283sDTPbaGaXRdYHzOx3ZvaOmT1nZk/te+yAjC+b2f+ZWQXwbTP7hJktMbMVZva8mQ2NDKH7VeAGM1tp4ZngCs3swcjvt6ynAmlmOcBk59yqHh4bZWYvRvbNC5EhxjGzsWa2OPL73tzTNykLz9D0pJmtMrO1ZnZlZP1Jkf2wysyWmllOZDsLI/vwzZ6+vZhZkpnd2u3f6ivdHn4E6NUsUBIH/L6EWbfYvQFNkZ/nAfMAI3xA8QQwO/LY4MjPDGAtUBC574Arur3XZuCbkeV/Ae6KLH8J+E1k+U/AA5FtHE94XHOAywgPbRwAhhGeU+CyHvK+DPyu2/1BsP8q9+uA2yLLPwL+rdvz/gacHlkuA9b18N5nAg92u9899+PA1ZHlLwOPRJafAD4bWf7qvv15wPt+GvhDt/t5hCfx2AicFFmXS3iE3kwgPbLuGKAisjyKyIQgwFzgPyLLaUAFMDpyvwRY4/fflW4Dc0uIIZ3Fc+dFbisi97MJF59XgW+Z2Scj60dE1u8lPEnNgwe8z74hoZcTHq++J4+48Jj7b5vZ0Mi604EHIut3mtlLh8j6j27LpcA/zGw44YK66SCvOQc43sz23c81s2znXPcj9eHAnoO8fma33+cvwM+7rb80svw34Bc9vHYNcJuZ/T/gCefcQjObBOxwzi0DcM41QPhbAuEx3acS3r/H9vB+5wGTu30zyiP8b7IJ2A0UH+R3kDij4i/9wYCfOufu/NDK8KQj5wAznXMtZvYy4Tl+AdqccwfOUtYe+dnFwf8227st20GecyjN3ZZ/DfzSOfdYJOuPDvKaADDDOdd2iPdt5YPfrd845941s+mEB/O62cxeAB4+yNNvAHYBUwhn7imvEf6GtaCHx9IJ/x6SANTmL/1hAfBlM8sGMLMSMysifFRZGyn8E4AZHm3/deDTkbb/oYRP2PZGHh+MhX51t/WNQE63+88SnoUKgMiR9YHWAeMOsp03CA+9C+E29YWR5cWEm3Xo9viHmFkx0OKc+ytwK+E5XdcDw83spMhzciInsPMIfyMIAV8gPN/rgRYAXzOzlMhrj418Y4DwN4VD9gqS+KHiL0fNOfcs4WaLRWa2BphPuHg+AySb2TrCc7cu9ijCg4SHvH0b+CvwJlDfi9f9CHjAzJYD1d3WPw58ct8JX+BbQHnkBOnb9DCzk3PuHcLTE+Yc+BjhD45rzGw14aL87cj67wD/Glk/7iCZJwFLzWwl8N/Azc65DuBKwlN9rgKeI3zU/jvg6si6CXz4W84+dxHeT29Gun/eyQffss4EnuzhNRKHNKSzxIV9bfAWnnN1KXCac27nAGe4AWh0zt3Vy+dnAq3OOWdmnyF88vcST0MeOs+rhCczr/UrgwwctflLvHjCzPIJn7j934Eu/BG/By7vw/NPJHyC1oA6wj2BfGFmhYTPf6jwJwgd+YuIJCC1+YuIJCAVfxGRBKTiLyKSgFT8RUQSkIq/iEgCUvEXEUlA/x+SfnejIwZ9tAAAAABJRU5ErkJggg==\n", 
"text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learner.lr_plot()" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "begin training using onecycle policy with max lr of 0.001...\n", "Epoch 1/30\n", "781/781 [==============================] - 52s 67ms/step - loss: 3.6463 - acc: 0.4245 - val_loss: 2.2787 - val_acc: 0.5780\n", "Epoch 2/30\n", "\n", "781/781 [==============================] - 52s 67ms/step - loss: 1.9168 - acc: 0.6518 - val_loss: 1.7150 - val_acc: 0.6868\n", "Epoch 3/30\n", "781/781 [==============================] - 54s 69ms/step - loss: 1.4967 - acc: 0.7325 - val_loss: 1.4250 - val_acc: 0.7432\n", "Epoch 4/30\n", "781/781 [==============================] - 54s 70ms/step - loss: 1.2700 - acc: 0.7699 - val_loss: 1.2060 - val_acc: 0.7805\n", "Epoch 5/30\n", "781/781 [==============================] - 54s 69ms/step - loss: 1.1261 - acc: 0.7919 - val_loss: 1.1343 - val_acc: 0.7761\n", "Epoch 6/30\n", "\n", "781/781 [==============================] - 54s 70ms/step - loss: 1.0362 - acc: 0.8054 - val_loss: 0.9961 - val_acc: 0.8114\n", "Epoch 7/30\n", "781/781 [==============================] - 54s 69ms/step - loss: 0.9640 - acc: 0.8139 - val_loss: 1.0051 - val_acc: 0.8036\n", "Epoch 8/30\n", "781/781 [==============================] - 53s 68ms/step - loss: 0.9155 - acc: 0.8199 - val_loss: 0.8791 - val_acc: 0.8296\n", "Epoch 9/30\n", "781/781 [==============================] - 53s 68ms/step - loss: 0.8849 - acc: 0.8272 - val_loss: 0.8763 - val_acc: 0.8323\n", "Epoch 10/30\n", "781/781 [==============================] - 53s 67ms/step - loss: 0.8642 - acc: 0.8294 - val_loss: 0.9615 - val_acc: 0.7958\n", "Epoch 11/30\n", "781/781 [==============================] - 53s 68ms/step - loss: 0.8497 - acc: 0.8287 - val_loss: 0.9356 - val_acc: 0.7986\n", "Epoch 12/30\n", "781/781 [==============================] - 52s 
67ms/step - loss: 0.8397 - acc: 0.8324 - val_loss: 0.9610 - val_acc: 0.8042\n", "Epoch 13/30\n", "781/781 [==============================] - 52s 67ms/step - loss: 0.8340 - acc: 0.8328 - val_loss: 0.9022 - val_acc: 0.8174\n", "Epoch 14/30\n", "781/781 [==============================] - 52s 67ms/step - loss: 0.8346 - acc: 0.8330 - val_loss: 0.9342 - val_acc: 0.8092\n", "Epoch 15/30\n", "781/781 [==============================] - 52s 67ms/step - loss: 0.8286 - acc: 0.8332 - val_loss: 1.0035 - val_acc: 0.7759\n", "Epoch 16/30\n", "781/781 [==============================] - 52s 67ms/step - loss: 0.8149 - acc: 0.8375 - val_loss: 0.9024 - val_acc: 0.8075\n", "Epoch 17/30\n", "781/781 [==============================] - 52s 67ms/step - loss: 0.7796 - acc: 0.8446 - val_loss: 0.8858 - val_acc: 0.8180\n", "Epoch 18/30\n", "781/781 [==============================] - 52s 67ms/step - loss: 0.7447 - acc: 0.8528 - val_loss: 0.7538 - val_acc: 0.8466\n", "Epoch 19/30\n", "781/781 [==============================] - 52s 67ms/step - loss: 0.7136 - acc: 0.8579 - val_loss: 0.7735 - val_acc: 0.8397\n", "Epoch 20/30\n", "781/781 [==============================] - 52s 67ms/step - loss: 0.6818 - acc: 0.8664 - val_loss: 0.7233 - val_acc: 0.8514\n", "Epoch 21/30\n", "781/781 [==============================] - 52s 67ms/step - loss: 0.6419 - acc: 0.8742 - val_loss: 0.6842 - val_acc: 0.8594\n", "Epoch 22/30\n", "781/781 [==============================] - 52s 67ms/step - loss: 0.6179 - acc: 0.8774 - val_loss: 0.6773 - val_acc: 0.8605\n", "Epoch 23/30\n", "781/781 [==============================] - 52s 67ms/step - loss: 0.5773 - acc: 0.8888 - val_loss: 0.6550 - val_acc: 0.8654\n", "Epoch 24/30\n", "781/781 [==============================] - 52s 67ms/step - loss: 0.5462 - acc: 0.8942 - val_loss: 0.5973 - val_acc: 0.8767\n", "Epoch 25/30\n", "781/781 [==============================] - 52s 67ms/step - loss: 0.5112 - acc: 0.9036 - val_loss: 0.5859 - val_acc: 0.8794\n", "Epoch 26/30\n", "781/781 
[==============================] - 52s 67ms/step - loss: 0.4733 - acc: 0.9140 - val_loss: 0.5714 - val_acc: 0.8830\n", "Epoch 27/30\n", "781/781 [==============================] - 52s 67ms/step - loss: 0.4416 - acc: 0.9231 - val_loss: 0.5389 - val_acc: 0.8903\n", "Epoch 28/30\n", "781/781 [==============================] - 52s 67ms/step - loss: 0.4115 - acc: 0.9309 - val_loss: 0.5103 - val_acc: 0.8956\n", "Epoch 29/30\n", "781/781 [==============================] - 52s 67ms/step - loss: 0.3742 - acc: 0.9414 - val_loss: 0.5055 - val_acc: 0.9008\n", "Epoch 30/30\n", "781/781 [==============================] - 52s 67ms/step - loss: 0.3476 - acc: 0.9487 - val_loss: 0.4728 - val_acc: 0.9133\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# fit using onecycle policy\n", "learner.fit_onecycle(1e-3, 30)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" } }, "nbformat": 4, "nbformat_minor": 2 }