{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MNIST Digit Classification with ktrain\n",
    "\n",
    "Loads MNIST as NumPy arrays, builds ktrain's `default_cnn` image classifier, runs the learning-rate finder, trains with the one-cycle policy, and then saves and reloads a predictor."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%reload_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline\n",
    "import os\n",
    "\n",
    "# CUDA device ordering/visibility is configured through environment variables;\n",
    "# this cell runs ahead of the TensorFlow and ktrain imports in the cells below.\n",
    "os.environ.update({\n",
    "    \"CUDA_DEVICE_ORDER\": \"PCI_BUS_ID\",\n",
    "    \"CUDA_VISIBLE_DEVICES\": \"0\",\n",
    "})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Collect MNIST Dataset as Arrays"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.keras.datasets import mnist\n",
    "import numpy as np\n",
    "\n",
    "(x_train, y_train), (x_test, y_test) = mnist.load_data()\n",
    "\n",
    "# Cast each image array to float32 and insert a new axis at position 3.\n",
    "x_train, x_test = (\n",
    "    np.expand_dims(images.astype('float32'), axis=3)\n",
    "    for images in (x_train, x_test)\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## STEP 1: Preprocess Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ktrain\n",
    "from ktrain import vision as vis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Augmentation settings; the result is passed to vis.images_from_array as data_aug.\n",
    "data_aug = vis.get_data_aug(\n",
    "    rotation_range=15,\n",
    "    zoom_range=0.1,\n",
    "    width_shift_range=0.1,\n",
    "    height_shift_range=0.1,\n",
    ")\n",
    "\n",
    "# Class names, listed in digit order (index 0 -> 'zero', ..., index 9 -> 'nine').\n",
    "classes = [\n",
    "    'zero', 'one', 'two', 'three', 'four',\n",
    "    'five', 'six', 'seven', 'eight', 'nine',\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# validation_data=None: no separate validation arrays are supplied here;\n",
    "# val_pct and a fixed random_state are passed instead.\n",
    "trn, val, preproc = vis.images_from_array(\n",
    "    x_train,\n",
    "    y_train,\n",
    "    validation_data=None,\n",
    "    val_pct=0.1,\n",
    "    random_state=42,\n",
    "    data_aug=data_aug,\n",
    "    class_names=classes,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## STEP 2: Load Model and Wrap in `Learner`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Is Multi-Label? 
False\n",
      "default_cnn model created.\n"
     ]
    }
   ],
   "source": [
    "# Using a LeNet-style classifier\n",
    "MODEL_NAME = 'default_cnn'\n",
    "model = vis.image_classifier(MODEL_NAME, trn, val)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap the model and both data splits in a ktrain Learner.\n",
    "BATCH_SIZE = 128\n",
    "learner = ktrain.get_learner(\n",
    "    model,\n",
    "    train_data=trn,\n",
    "    val_data=val,\n",
    "    batch_size=BATCH_SIZE,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## STEP 3: Find Learning Rate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "simulating training for different learning rates... this may take a few moments...\n",
      "Train for 421 steps\n",
      "Epoch 1/3\n",
      "421/421 [==============================] - 15s 35ms/step - loss: 2.9301 - accuracy: 0.1410\n",
      "Epoch 2/3\n",
      "421/421 [==============================] - 14s 33ms/step - loss: 0.8085 - accuracy: 0.7409\n",
      "Epoch 3/3\n",
      "197/421 [=============>................] - ETA: 7s - loss: 0.4978 - accuracy: 0.8852\n",
      "\n",
      "done.\n",
      "Visually inspect loss plot and select learning rate associated with falling loss\n"
     ]
    },
    {
     "data": {
      "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYIAAAEKCAYAAAAfGVI8AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3deXhU5f3+8fdnJhsJSSAQJAlLAFkEZJGwqGhxrStotYq2dSlKqbZaa/11tVVra/tt7WK11rUutXWXuuBecQOUsO/ITljDlgSyJ8/vjxkxYBISyMmZydyv65qLmXPOzLkzwNw5yzzHnHOIiEjsCvgdQERE/KUiEBGJcSoCEZEYpyIQEYlxKgIRkRinIhARiXFxfgdors6dO7vc3Fy/Y4iIRJU5c+bscM5l1jcv6oogNzeX/Px8v2OIiEQVM1vf0DztGhIRiXEqAhGRGKciEBGJcSoCEZEYpyIQEYlxKgIRkRgXdaePHql1O/axrbic7A7t6J6R7HccERHfxVwRXPLATLaXVADwx68PJT5oLNlcTKeUBC4b3YOAGe0T43DOsXxrCUs3F3NS384EA0ZcIEB6cvwh1zF3w2527a1kUE4aWentANheXE6H5AQS4rQRJiKRJaaKYHtJ+f4SCBj86LkFB8y/6/XlAEwa24vXF21hc1H5AfPbJ8Zxx4RB9MhI5q7Xl1NSXsVFx3Xjo1U7WLq5mPTkeHbvq2R3adX+5wzomkpiXIAFBUUkBANcf8rRXDQih24dk9lbUc3stbuorKklr2dHOrVP9PgdEBH5Mou2K5Tl5eW5w/1m8ZOz1nPr1MW8+v2xdE1P4tIHZrK6cB+jcjOocY4563fvX7Zbx3acOqAL6e3i2bynHOccL87b1OBrpybFUVJeTef2iXztuBxSEuJ44IPVdE1LYs2OfWSnJx1QLKNyM1hVuJdd+yq/+Nl6dqRf11TyenbkgmE5BAK2f151TS1xQW1NiMjhMbM5zrm8eufFShHMWb+Li+6fSY+MZN6/ZRxmoQ/ZvRXVtE8MbRjtq6hmxbYS3l9RyHWn9CExLnjAa9TUOqbO28QTs9bz2wsH4xxsLSrHDE7p34X/Ld/OkG7pdElL2v8c5xyFJRV0SUuiptYxf+Nubnt5KasL9zK6Vwb9u6bRITmegt2lvDx/M8Xl1fufe/GIbpRV1fDawi0AjOmdwQXDcjh3SBapSYfeRSUi8jkVAbCwYA/j7/2YK47vyR0TBnuQ7MiVVdawt6KaN5Zs5XfTlrGvsqbe5QIGY3p3oqisirMHd+Ubo3uS1i6eZVuK6Z6RTGJcgNWFe+mekUyaCkNEUBHst6igiN6ZKaQkRsehkaWbi0lrF0d5VQ05HZLZsbeC+99fzey1u/hs+94Dlo0LGNW1B/5dJgQDHJOdRnllDYnxAaprHO0Sgnx9RDcGZqfRMTmBnA7tDtgFJSJtk4qgjaqoruGDlTuYtmgLn20voWdGCnvKKumRkQLAJ2t3Ulvr2LmvkszURHpkJDNj9U4qq2v3v8YJfTpxw2l9GdO7k18/hoi0gsaKwLNfjc0sCfgASAyv53nn3K8OWiYReAIYAewELnXOrfMqU1uTGBfkjIFHccbAo5r8nKKyKlZuK2F7cQWvLdrMtEVbmbF6J5eP7sE5g7MY3qND1GwxiUjL8GyLwEJHY1Occ3vNLB74CLjROTerzjLXAUOcc1PMbCJwoXPu0sZeV1sELauwpII7Xl3Kqws34xwkxAV48bsnUFVTS3q7eJIT4shMTSSo3UciUc33XUNmlkyoCL7rnPukzvQ3gducczPNLA7YCmS6RkKpCLyxY28Fry/eyq1TF9c7/9sn9uLak3vt/4KciEQXX3YNhVccBOYARwP31S2BsBxgI4BzrtrMioBOwA4vc8mXdW6fyLfG9GRA11Smr9hObqcUyqtqeGvpNqprHI9+vJYnZ63jnGOzGJSdRv+uaWSlJ9Ens722FkSiXGt
tEXQAXgK+75xbXGf6YuAs51xB+PFqYLRzbsdBz58MTAbo0aPHiPXrG7zimnhk1fYS/vDmCt5csu2A6alJcVxxfE/yemZwfJ9OJMUHG3gFEfGT77uGwiF+CZQ65/5YZ5p2DUWZ0spqpi3aSnVNLcGA8Wz+Rmav++Ib2WlJcYzMzeD2CYPo1lGD+olECl+KwMwygSrn3B4zawe8BfzeOfdqnWWuB46tc7D4a865Sxp7XRVBZHHOMXfDHtYU7uWNxVuZu2E3JeXVVNc6Jo7szvWnHK1RXkUigF9FMAR4HAgSuu7Bs865O8zsDiDfOfdy+BTTJ4HhwC5gonNuTWOvqyKIfIsKijj/3o/2P/7G6B7cecHg/cN6iEjri4hdQy1FRRA9Fm8q4u63VvDeikImjuzOL88fSHKCvqMg4ofGikDDWYpnBuek8+hVI7k0rztPz97IOX/98IDRVkUkMqgIxFNmxu8vHsIT3x7F+l2l/L/nFx4wxIWI+E9FIK3i5H6Z3D5+EO8s28aVj37Kks1FfkcSkTAVgbSaK47P5c4LBjNr7U4mPjCL5/I3UlFd/1DbItJ6VATSqr45pievfG8scUHjlucX8sNnF1BbG10nLIi0NSoCaXWDc9KZ84szuPG0vry2cAu36LiBiK9UBOKLQMD4wel9+dpxObwwt4Az/vw+iwp03EDEDyoC8Y2ZcffXh3Lf5cdRWFLB1x+YwYtzC/yOJRJzVATiKzPj3CFZTL9lHEO7deCHzy7gtpeX+B1LJKaoCCQidElN4slJo7l4RDcem7GOa5/Ip6S8yu9YIjFBRSARIyEuwK8nDKZHRjJvL93GxffPZNX2vX7HEmnzVAQSUdolBHnvR+P416TRbC4qY/y9HzF/4x6/Y4m0aSoCiTjBgDG2b2deuu4EOiYnMOXJOWwtKvc7lkibpSKQiHV0l1Qe+NYIdu6r4IpHP2FvRbXfkUTaJBWBRLTBOencfckwPtu+l1unLj70E0Sk2TQ4vES88UOzWVu4jz+/s5Ku6UncfEY/4oL6HUakpagIJCpcf0ofVmwr5v7pq9mws5R7Lx+uK56JtBD9WiVRIS4Y4L7Lj+Om0/vx2qItPDN7o9+RRNoMFYFEDTPj+6cezQl9OnH7K0tZu2Of35FE2gQVgUSVQMC4+5KhJMQF+MHT86iq0ailIkdKRSBRJyu9Hb+98FgWFBRxz7uf+R1HJOqpCCQqnTski4tHdOO+91Yxe90uv+OIRDUVgUSt28YPolvHZH7w9Hx276v0O45I1FIRSNRqnxjH3y4bzvaScm5/ZQnO6ZKXIodDRSBRbWj3Dlx7Um+mzt/Mfe+t8juOSFTyrAjMrLuZvWdmS81siZndWM8y48ysyMzmh2+/9CqPtF23fLU/5x6bxR/fWsl/Pt3gdxyRqOPlN4urgZudc3PNLBWYY2ZvO+eWHrTch8658zzMIW2cmfHXicMoqajm1qmL6d05hdG9O/kdSyRqeLZF4Jzb4pybG75fAiwDcrxan8S2uGCAv102nB6dkvnOv+awanuJ35FEokarHCMws1xgOPBJPbOPN7MFZva6mQ1qjTzSNqW3i+exq0YRMOOmZxboy2YiTeR5EZhZe+AF4AfOueKDZs8FejrnhgJ/A6Y28BqTzSzfzPILCwu9DSxRrUenZH5zwWAWbSrigfdX+x1HJCp4WgRmFk+oBJ5yzr148HznXLFzbm/4/jQg3sw617Pcg865POdcXmZmppeRpQ04+9gszh+azV/f/YxP1+rLZiKH4uVZQwY8Aixzzv2pgWW6hpfDzEaF8+z0KpPEjjsnDKZ7x2S+82Q+BbtL/Y4jEtG83CI4EfgWcGqd00PPMbMpZjYlvMzFwGIzWwDcA0x0+laQtID05HgeuWokZVU1fOfJOTpeINIIi7bP3by8PJefn+93DIkS/52/iRufns9lo3pw+/hBJMTpO5QSm8xsjnMur755ukKZtGnjh2azqKCIhz9aS1Z6Ejec1tfvSCI
RR78eSZtmZvzivIGc1LczT85aT2W1dhGJHExFIDHhmpN6U1hSwd+nazwikYOpCCQmnNy3M+cOyeL+6avZVlzudxyRiKIikJhgZtxyZn+qamr58QsLqamNrpMkRLykIpCYkds5hZ+fO5DpKwp5+MM1fscRiRgqAokpk8b24qS+nXn047VU67sFIoCKQGLQt8b0ZFtxBe8u3+53FJGIoCKQmHPqgC50z2jHXdOWUVpZ7XccEd+pCCTmxAUD/OHioazfVcpvXlvmdxwR36kIJCaN6d2Ja8b24qlPNjBj9Q6/44j4SkUgMevmM/vTs1MyP31xEbv2VfodR8Q3KgKJWUnxQe762rEU7C7jmsdnU1RW5XckEV+oCCSmndCnM3++dBjzNu7hnnc/8zuOiC9UBBLzxg/N5vwh2Twze6N2EUlMUhGIAJNP7k1pZTU/fXGh31FEWp2KQAQYnJPOTaf3480l25i1RldLldiiIhAJu/bk3mSlJ/Gb15bp0pYSU1QEImFJ8UF+fu4xLNpUxC3PLfA7jkirURGI1HHekGxuOK0vU+dvZu6G3X7HEWkVKgKRg3zn5N6kJAR5atYGv6OItAoVgchBUhLjuGhEN16aV8CyLcV+xxHxnIpApB43n9GflMQ4/vT2Sr+jiHhORSBSj/TkeK49qTdvL93GBysL/Y4j4ikVgUgDJp/cmz6ZKfzspUWUVdb4HUfEM54VgZl1N7P3zGypmS0xsxvrWcbM7B4zW2VmC83sOK/yiDRXUnyQ314YGpTutpeX+B1HxDNebhFUAzc75wYCY4DrzWzgQcucDfQN3yYD93uYR6TZRvfuxJSv9OGZ/I189JmuWyBtk2dF4Jzb4pybG75fAiwDcg5abALwhAuZBXQwsyyvMokcjpvO6EtmaiKPfrzW7yginmiVYwRmlgsMBz45aFYOsLHO4wK+XBYivkqMC3JpXnf+t3w7L84t8DuOSIvzvAjMrD3wAvAD59xhnZRtZpPNLN/M8gsLdQaHtL7vnXo0I3p25LfTluuC99LmeFoEZhZPqASecs69WM8im4DudR53C087gHPuQedcnnMuLzMz05uwIo1Iig/ys3OOYee+Cn7/+nK/44i0KC/PGjLgEWCZc+5PDSz2MnBF+OyhMUCRc26LV5lEjsSInh25YkxPnpy1nnkah0jaEC+3CE4EvgWcambzw7dzzGyKmU0JLzMNWAOsAh4CrvMwj8gRu/mr/TkqLYnbXlmKc87vOCItIs6rF3bOfQTYIZZxwPVeZRBpaWlJ8Vw3rg+3/ncJ+et3MzI3w+9IIkdM3ywWaaaLRnSjQ3I8D7y/xu8oIi1CRSDSTMkJcVx5fC7vLNvGiq0lfscROWIqApHDcNUJuaQkBPnLOxqdVKKfikDkMHRMSWDyyX14ffFW5qzXGUQS3VQEIofpmpN6kZmayO2vLKG8SqOTSvRSEYgcppTEOG49byALC4p45CONQyTRS0UgcgTGD83m9GOO4h/vr6aorMrvOCKHRUUgcoRuOqMvJeXV3PzsAn3JTKKSikDkCA3KTufmM/rxzrJt3PPuKr/jiDRbk4rAzG40s7TwmECPmNlcMzvT63Ai0eJ7px7N+KHZ3PfeKlZu03cLJLo0dYvg2+EhpM8EOhIaQ+h3nqUSiTJmxi/OO4bUpDguvn8GW4vK/Y4k0mRNLYLPxww6B3jSObeEQ4wjJBJruqQm8e9rx7C3opp/6mpmEkWaWgRzzOwtQkXwppmlArXexRKJTv27pnL2sVk89ckGtpdoq0CiQ1OLYBLwE2Ckc64UiAeu9iyVSBT74Rn9qKyp5dapi3UWkUSFphbB8cAK59weM/sm8AugyLtYItGrT2Z7fnhGP95cso03l2zzO47IITW1CO4HSs1sKHAzsBp4wrNUIlHumrG96H9UKne+tpSKag0/IZGtqUVQHb6IzATgXufcfUCqd7FEoltcMMCPz+5Pwe4y/jtvs99xRBrV1CIoMbOfEjpt9DUzCxA
6TiAiDfhKvy4M7ZbOX95ZSVWNzq2QyNXUIrgUqCD0fYKtQDfgD56lEmkDggHjxtP7srmonFcXaqtAIleTiiD84f8UkG5m5wHlzjkdIxA5hHH9unBMVhq/e305RaUalE4iU1OHmLgE+BT4OnAJ8ImZXexlMJG2IBAwfn/RsRSWVPD36RqHSCJTU3cN/ZzQdwiudM5dAYwCbvUulkjbMaRbB84a3JWnZ2+krFJnEEnkaWoRBJxz2+s83tmM54rEvCuPz6WorIr/zt/kdxSRL2nqh/kbZvammV1lZlcBrwHTvIsl0raM6pXBgK6pPDZjHbW1+raxRJamHiy+BXgQGBK+Peic+7GXwUTaEjPju+P6sHxrCf+ZvcHvOCIHaPLuHefcC865H4ZvLx1qeTN71My2m9niBuaPM7MiM5sfvv2yOcFFos34odmc0KcTv3t9OXtKK/2OI7Jfo0VgZiVmVlzPrcTMig/x2o8BZx1imQ+dc8PCtzuaE1wk2pgZPzl7ACXl1dz//mq/44js12gROOdSnXNp9dxSnXNph3juB8CuFk0rEuWOzUnnzIFH8ehHa9m4q9TvOCKA/2f+HG9mC8zsdTMb5HMWEc+ZGXdMGEzAjL+++5nfcUQAf4tgLtDTOTcU+BswtaEFzWyymeWbWX5hYWGrBRTxQtf0JC4b1YOp8zaxaU+Z33FE/CsC51yxc25v+P40IN7MOjew7IPOuTznXF5mZmar5hTxwrUn9wbgoQ/W+JxExMciMLOuZmbh+6PCWXb6lUekNeV0aMcFw3N4evYGdu/TGUTiL8+KwMz+A8wE+ptZgZlNMrMpZjYlvMjFwGIzWwDcA0x0uq6fxJBrT+pNeVUtT32y3u8oEuPivHph59xlh5h/L3CvV+sXiXT9u6Zycr9MHp+5nmtP7k1iXNDvSBKj/D5rSCSmTT6pN4UlFdyjM4jERyoCER+N7duZc4dk8c+P1+nbxuIbFYGIz75/6tGUVtbwxEwdKxB/qAhEfDagaxqnDejCPz9eS2lltd9xJAapCEQiwHfH9WF3aRXPzN7odxSJQSoCkQiQl5vBiJ4deXLmenQWtbQ2FYFIhLh0ZHfW7NjHp2s1VqO0LhWBSIQ4b0gWGSkJ/H26hqiW1qUiEIkQyQlxXH1CLu+vLOSzbSV+x5EYoiIQiSCXj+5BQlyAx2as8zuKxBAVgUgE6dQ+kQlDs3lx7iaKSqv8jiMxQkUgEmGuPrEXZVU1PK2L3EsrURGIRJiB2WmM6Z3B4zPWUV1T63cciQEqApEINGlsbzYXlfPeCl2RT7ynIhCJQOP6Z5KRksA/P17rdxSJASoCkQgUHwxw7Um9mbF6J+t27PM7jrRxKgKRCHXB8GzM4KV5m/yOIm2cikAkQmWlt+P43p2YOn8TtbUaf0i8oyIQiWCXjuzO+p2lvLV0q99RpA1TEYhEsHOPzaJ35xTufW+VRiUVz6gIRCJYXDDA1SfmsnhTMUu3FPsdR9ooFYFIhDtvSDbxQeO5/AK/o0gbpSIQiXAdUxI4f2g2z8zeqPGHxBMqApEoMGlsaPyhl+Zpq0BanopAJAoMyk5nSLd0np69UQeNpcV5VgRm9qiZbTezxQ3MNzO7x8xWmdlCMzvOqywibcHEkT1YvrWE+Rv3+B1F2hgvtwgeA85qZP7ZQN/wbTJwv4dZRKLe+GHZJCcE+c+nGp5aWpZnReCc+wBo7CrcE4AnXMgsoIOZZXmVRyTatU+M4/wh2byyYAsl5TpoLC3Hz2MEOcDGOo8LwtNEpAETR3WnrKqGlxds9juKtCFRcbDYzCabWb6Z5RcWanx2iV3DundgQNdUHvlwLRXVNX7HkTbCzyLYBHSv87hbeNqXOOcedM7lOefyMjMzWyWcSCQyM275an/W7NjHf+dpq0Bahp9F8DJwRfjsoTFAkXNui495RKLCqQO6MDArjX+8v5ryKm0VyJHz8vTR/wAzgf5mVmBmk8xsiplNCS8yDVgDrAIeAq7zKot
IW2Jm3HJWaKvgkY90BTM5cnFevbBz7rJDzHfA9V6tX6QtO6V/F04b0IUH3l/NN0f3JD053u9IEsWi4mCxiHzZzWf2p7i8mgc+WO13FIlyKgKRKDUwO43xQ7P558fr2F5S7ncciWIqApEodtMZ/aisqeW+/63yO4pEMRWBSBTr1TmFS/K68e9PN7BxV6nfcSRKqQhEotwNp/XFzPjT2yv9jiJRSkUgEuWy0tsxaWwvXpq3iYUFGplUmk9FINIGXDeuD51SEvj1q0t1vQJpNhWBSBuQmhTPj77an9nrdjN1fr0jtYg0SEUg0kZckted4T068Kv/LqGoTMNUS9OpCETaiGDA+PWEwRSXV/PYx+v8jiNRREUg0oYMzknnzIFH8chHayjWxWukiVQEIm3MDaf1pbi8mj++uUIHjqVJVAQibczgnHSuPjGXJ2au57n8Ar/jSBRQEYi0Qb84dyAjczvy29eXsXtfpd9xJMKpCETaoGDAuPOCYykpr+Z3ry/3O460gB88PY//enRqsIpApI3q3zWVSWN78Uz+RvLX7fI7jhyB8qoaps7fzIad3ownpSIQacNuPK0v2elJ/GLqYl3sPopt2lMGQLeMdp68vopApA1LSYzjV+MHsXxrCT9+fqHfceQwFewOF0HHZE9eX0Ug0sZ9dVBXvn/q0Uydv5k3l2z1O44choLdoV1C3Tpqi0BEDtMNp/VlYFYaP39psc4iikIFu8uIDxpdUpM8eX0VgUgMiA8G+MPXh1BcVsWkx2dTWlntdyRphsWbiji6SyrBgHny+ioCkRgxKDudey4bxvyNe7juqbnU1Opbx9GgttYxf8MejuvRwbN1qAhEYshZg7O4Y8Jgpq8o5J53P/M7jjTBp+t2UVJRTV5uR8/WEefZK4tIRPrG6B7M3bCbv777GTkd2nHJyO5+R5JGTJ23idTEOM4alOXZOlQEIjHGzPjthcdSWFLBj19cSHycceHwbn7HkgbM37iHvNyOtEsIerYO7RoSiUFJ8UEeuiKPkT0z+NFzC3ln6Ta/I0k9amsda3fs4+gu7T1dj6dFYGZnmdkKM1tlZj+pZ/5VZlZoZvPDt2u8zCMiX0iKD/Lo1SPp26U9k5/M58EPVmvY6gizcnsJFdW19O+a5ul6PCsCMwsC9wFnAwOBy8xsYD2LPuOcGxa+PexVHhH5svaJcTw75XjOGHgUv522nF+/ukxlEEE+WRMaI2p0rwxP1+PlFsEoYJVzbo1zrhJ4Gpjg4fpE5DCkJcXzj2+O4OoTc3n047Xc9vISlUGE+PCzHeR0aOfZN4o/52UR5AAb6zwuCE872EVmttDMnjezek9fMLPJZpZvZvmFhYVeZBWJaWbGL88byLUn9eLxmev5+dTF+p6Bz/aUVvL+yu2cPbgrZt58kexzfh8sfgXIdc4NAd4GHq9vIefcg865POdcXmZmZqsGFIkVZsbPzjmGKV/pw78/2cB1T82hvEojlvrltUVbqKpxXDC8vt+fW5aXRbAJqPsbfrfwtP2cczudcxXhhw8DIzzMIyKHYGb85OwB/PK8gby1dBtXPPqpysAHeyuquf2VpfTJTGFQtrcHisHbIpgN9DWzXmaWAEwEXq67gJnV/YbEeGCZh3lEpIm+PbYXf7pkKJ+u3cWFf5/BvgqNTdSaPl61g8rqWiaO7OH5biHwsAicc9XA94A3CX3AP+ucW2Jmd5jZ+PBiN5jZEjNbANwAXOVVHhFpnguHd+P/LhrCiq3FXHT/DFZtL/E7Usx44P3VpCbFcdnoHq2yPou2swPy8vJcfn6+3zFEYsZ7K7Zz87MLKK2s5vpxR/Odr/QhIc7vw4tt1/QV27nqn7P51fkDufrEXi32umY2xzmXV988/W2KSKNO6d+FN248iRP7dObut1dyyh+nM3P1Tr9jtUnOOX796lI6JsfzjdE9W229KgIROaQuaUk8ctVIHrkyj4S4AJc/PIvfv7Fc10FuYa8s3MLqwn3cdEa/Vt3qUhG
ISJOddsxRvHbDWCaO7M7901dz2t3v8+naXX7HahPmb9zDXdOW0e+o9kwc2TrHBj6nIhCRZklOiOOurw3hyUmjiA8GuPyhWfzlnZU6s+gI/OHN5Vxw38fs2lfJHRMGt/oxGB0sFpHDVlxexU1Pz+fd5dvp3D6RiSO7c9WJuXRun+h3tIhVXlXDntIqlm8tZtOeMmas3slrC7eQmhjHc989ngEeDTDX2MFiFYGIHJHaWscrCzfz57dXsm5nKWlJcVw2qgfHZKXRqX0C2R3a0SfT22GUI1VNrWP+xj1s3lPGjr0V/G/5dj5atYODP3YnDMvm9xcNISneu2sONFYEujCNiByRQMCYMCyH84ZkM3/jHh76YA0PfrjmgA+7od3S+fbYXhyTlUbvzinEBdv+XulXF27mJy8sYm+dXWYpCUEuGdGdfl1TSYgLEDA4c2BXMlP93YLSFoGItLi9FdVs2VPGxt2lLN9awlOzNrBpTxkAqUlxdEiOp0dGMqcNOIoJw7Lp1IZ2JX22rYTbXlnCx6t2kpGSwM1n9qNnRgrtEgIc16Njq3xTuD7aNSQivqqqqWVhQRGrC/cyZ91uyqpqWLmthOVbSwgGjL5d2jOgayodUxJITYpnZG5HhnXvQGpSfKtn3V5czhtLtmJmdEyOZ9PuMsqqaqh1EDRjZG5Humck0yE5nqT4IMu2FPPhZzvYVlzO/I17WFhQRHJCkAnDcrj5zH4Rc7xERSAiEWn51mKmLdzC3A17mLdhN6VVNft3KSUEA1w2qjunDOjCmN6dSAyfSfP5b9Ql5VWUVtYQHwxQWhna/dIuPsiesiqKy6rYsKuUjJQEOrdPpGNyAmVVNSTGBVixtYTE+ABdUpPYVlxOda1jb3k1e8oqmbZoCx+vatqX5eKDRlXNF5+fyQlBBuekM7xHB64Z29v33T0H0zECEYlIA7qmHXCWTGllNcVl1Xy2vYSX5m3iyVnreXzmegIGDkiOD9KjUwoZKfHMXrebyuraFs2TkhDkuB4duOG0vgzMTmPXvkoS44JkpSdRWVNLdY1j7vrdbNpTxvqdpdQ6R2ZqIhOGZZOV3o5gwJ/dPkdKWwQiErHKKmv4ZO1OZq4J/ZZeWlHDluNPpn8AAAlYSURBVKIyCvdW4pzjq4O6khQfxDlHu4QgZZU1xAWMYMAYmJ1OZXUthXsrKCqrIj5gVNc6cjq0o7Kmlu3F5ft3PXVun0j3jHZ075hMIEo/zA9FWwQiEpXaJQQZ178L4/p38TtKm9b2z+ESEZFGqQhERGKcikBEJMapCEREYpyKQEQkxqkIRERinIpARCTGqQhERGJc1H2z2MwKgfVAOlAUnlzf/c//7AzsOMzV1X3d5sw/eHpjj73IfqjcjS2j7K2fvSnTlL15uZoy/3Cz1zctGrJ3cM5l1rsW51xU3oAHG7tf58/8llhHc+YfPL2xx15kP1RuZY+s7E2ZpuyHn72p/16amr2BaVGRvaFbNO8aeuUQ9+tOa4l1NGf+wdMbe+xF9qY8X9m/fN+v7E2ZpuwNa6n/p/VNO1TeaMter6jbNdRcZpbvGhhoKdIpuz+U3R/K7p9o3iJoqgf9DnAElN0fyu4PZfdJm98iEBGRxsXCFoGIiDRCRSAiEuNUBCIiMS6mi8DMTjKzf5jZw2Y2w+88zWFmATP7jZn9zcyu9DtPc5jZODP7MPzej/M7T3OZWYqZ5ZvZeX5naQ4zOyb8nj9vZt/1O09zmNkFZvaQmT1jZmf6nac5zKy3mT1iZs/7naUhUVsEZvaomW03s8UHTT/LzFaY2Soz+0ljr+Gc+9A5NwV4FXjcy7x1tUR2YALQDagCCrzKerAWyu6AvUAS0Zcd4MfAs96krF8L/XtfFv73fglwopd562qh7FOdc9cCU4BLvcxbVwtlX+Ocm+Rt0iN0ON+Gi4QbcDJwHLC4zrQgsBroDSQAC4CBwLGEPuzr3rrUed6zQGo0ZQd+Anwn/Nznoyx7IPy8o4Cnoiz7GcB
E4CrgvGjKHn7OeOB14PJoyx5+3t3AcVGavdX+nzb3FrUXr3fOfWBmuQdNHgWscs6tATCzp4EJzrm7gHo3482sB1DknCvxMO4BWiK7mRUAleGHNd6lPVBLve9hu4FEL3LWp4Xe93FACqH/+GVmNs05V+tlbmi599059zLwspm9Bvzbu8QHrLMl3ncDfge87pyb623iL7Twv/eIFbVF0IAcYGOdxwXA6EM8ZxLwT88SNV1zs78I/M3MTgI+8DJYEzQru5l9Dfgq0AG419toh9Ss7M65nwOY2VXAjtYogUY0930fB3yNUPlO8zTZoTX33/v3gdOBdDM72jn3Dy/DHUJz3/dOwG+A4Wb203BhRJS2VgTN5pz7ld8ZDodzrpRQiUUd59yLhIosajnnHvM7Q3M556YD032OcVicc/cA9/id43A453YSOrYRsaL2YHEDNgHd6zzuFp4WDZTdH8ruD2WPIG2tCGYDfc2sl5klEDqo97LPmZpK2f2h7P5Q9kji99HqIzia/x9gC1+cPjkpPP0cYCWho/o/9zunskfOTdmVPZayN+emQedERGJcW9s1JCIizaQiEBGJcSoCEZEYpyIQEYlxKgIRkRinIhARiXEqAvGcme1thXWMb+IQ0i25znFmdsJhPG+4mT0Svn+Vmfk93hIAZpZ78HDL9SyTaWZvtFYmaR0qAokaZhZsaJ5z7mXn3O88WGdj43GNA5pdBMDPiN5xcwqBLWbWatczEO+pCKRVmdktZjbbzBaa2e11pk81szlmtsTMJteZvtfM7jazBcDxZrbOzG43s7lmtsjMBoSX2/+btZk9Zmb3mNkMM1tjZheHpwfM7O9mttzM3jazaZ/POyjjdDP7i5nlAzea2flm9omZzTOzd8zsqPDQxFOAm8xsvoWudpdpZi+Ef77Z9X1YmlkqMMQ5t6Ceeblm9r/we/NueIh0zKyPmc0K/7x31reFZaGrpr1mZgvMbLGZXRqePjL8Piwws0/NLDW8ng/D7+Hc+rZqzCxoZn+o83f1nTqzpwLfqPcvWKKT319t1q3t34C94T/PBB4EjNAvIa8CJ4fnZYT/bAcsBjqFHzvgkjqvtQ74fvj+dcDD4ftXAfeG7z8GPBdex0BCY8cDXExo+OUA0JXQ9RAurifvdODvdR53hP3fwr8GuDt8/zbgR3WW+zcwNny/B7Csntc+BXihzuO6uV8Brgzf/zYwNXz/VeCy8P0pn7+fB73uRcBDdR6nE7poyhpgZHhaGqERh5OBpPC0vkB++H4u4QuwAJOBX4TvJwL5QK/w4xxgkd//rnRruVvMD0MtrerM8G1e+HF7Qh9EHwA3mNmF4endw9N3ErrozgsHvc7nQ1jPITS+fn2mutC1Apaa2VHhaWOB58LTt5rZe41kfabO/W7AM2aWRejDdW0DzzkdGBi6hgoAaWbW3jlX9zf4LKCwgecfX+fneRL4vzrTLwjf/zfwx3qeuwi428x+D7zqnPvQzI4FtjjnZgM454ohtPUA3Gtmwwi9v/3qeb0zgSF1tpjSCf2drAW2A9kN/AwShVQE0poMuMs598ABE0MXTDkdON45V2pm0wldzxig3Dl38BXYKsJ/1tDwv+GKOvetgWUas6/O/b8Bf3LOvRzOelsDzwkAY5xz5Y28bhlf/Gwtxjm30syOIzQY2p1m9i7wUgOL3wRsA4YSylxfXiO05fVmPfOSCP0c0kboGIG0pjeBb5tZewAzyzGzLoR+29wdLoEBwBiP1v8xcFH4WMFRhA72NkU6X4w3f2Wd6SVAap3HbxG6khYA4d+4D7YMOLqB9cwgNKQxhPbBfxi+P4vQrh/qzD+AmWUDpc65fwF/IHSd3RVAlpmNDC+TGj74nU5oS6EW+Baha/Ae7E3gu2YWH35uv/CWBIS2IBo9u0iii4pAWo1z7i1CuzZmmtki4HlCH6RvAHFmtozQdWlneRThBUJDCS8F/gXMBYqa8LzbgOfMbA6wo870V4ALPz9YDNwA5IU
Pri6lnqtSOeeWE7rcYurB8wiVyNVmtpDQB/SN4ek/AH4Ynn50A5mPBT41s/nAr4A7nXOVwKWELmm6AHib0G/zfweuDE8bwIFbP597mND7NDd8SukDfLH1dQrwWj3PkSilYaglpny+z95C15H9FDjRObe1lTPcBJQ45x5u4vLJQJlzzpnZREIHjid4GrLxPB8Qulj7br8ySMvSMQKJNa+aWQdCB31/3dolEHY/8PVmLD+C0MFdA/YQOqPIF2aWSeh4iUqgDdEWgYhIjNMxAhGRGKciEBGJcSoCEZEYpyIQEYlxKgIRkRinIhARiXH/HyYX6rylA91GAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learner.lr_find(show_plot=True, max_epochs=3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## STEP 4: Train Model\n", "\n", "We only train for three epochs for demonstration purposes." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "begin training using onecycle policy with max lr of 0.001...\n", "Train for 422 steps, validate for 188 steps\n", "Epoch 1/3\n", "422/422 [==============================] - 15s 35ms/step - loss: 0.8279 - accuracy: 0.7384 - val_loss: 0.0681 - val_accuracy: 0.9802\n", "Epoch 2/3\n", "422/422 [==============================] - 15s 35ms/step - loss: 0.1559 - accuracy: 0.9532 - val_loss: 0.0654 - val_accuracy: 0.9798\n", "Epoch 3/3\n", "422/422 [==============================] - 14s 34ms/step - loss: 0.0990 - accuracy: 0.9702 - val_loss: 0.0316 - val_accuracy: 0.9905\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learner.fit_onecycle(1e-3, 3)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " precision recall f1-score support\n", "\n", " zero 1.00 0.99 0.99 624\n", " one 1.00 1.00 1.00 654\n", " two 0.99 0.99 0.99 572\n", " three 0.99 0.99 0.99 589\n", " four 0.99 0.99 0.99 580\n", " five 1.00 0.99 0.99 551\n", " six 0.99 0.99 0.99 580\n", " seven 0.99 0.99 0.99 633\n", " eight 0.98 0.98 0.98 585\n", " nine 0.99 0.99 0.99 632\n", "\n", " accuracy 0.99 6000\n", " macro avg 0.99 0.99 0.99 6000\n", "weighted avg 0.99 0.99 0.99 6000\n", "\n" ] }, { "data": { "text/plain": [ "array([[618, 0, 0, 0, 0, 0, 2, 0, 4, 0],\n", " [ 0, 652, 0, 1, 0, 0, 0, 0, 0, 1],\n", " [ 0, 1, 566, 0, 0, 0, 0, 1, 4, 0],\n", " [ 0, 0, 2, 586, 0, 0, 0, 1, 0, 0],\n", " [ 0, 0, 0, 0, 574, 0, 
0, 1, 0, 5],\n", " [ 1, 0, 0, 2, 0, 545, 1, 0, 1, 1],\n", " [ 0, 0, 1, 1, 1, 0, 576, 0, 1, 0],\n", " [ 0, 1, 3, 0, 1, 0, 0, 626, 2, 0],\n", " [ 1, 0, 2, 0, 3, 1, 0, 1, 576, 1],\n", " [ 0, 0, 0, 0, 1, 1, 0, 5, 1, 624]])" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learner.validate(class_names=preproc.get_classes())" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAEICAYAAACZA4KlAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAXlUlEQVR4nO3de7wcZX3H8c838RAgF0gMxBAu4RIKATGhEehLqqEqKIWCF2gAJdRLrBfAlpYivYBWWrAi2laEUGiC3CtEaIsWilhBhCamCIRwjYkk5CIEQhIk11//mDlhczj77Dm7e84ueb7v12tfZ8/8Znae2d3vzjMzOzuKCMxs2zeg1Q0ws/7hsJtlwmE3y4TDbpYJh90sEw67WSZaEnZJ8yRNrlKbLGlxPzep1yRdKOnCRD0k7dePTWoJSQsljW11O+rVyvZLGlu+T95Sx7SDJD0haZeeTtOSsEfEQRHx40YeQ9I4Sa9Juq5i2GRJmyWtqbhNraiPkDRL0lpJiySd2uUxd5F0g6RVkl6SdH0jbewvkn7QZZnXS3q0yrjjJc0pl+8lSf8taXxF/UJJG7o83j51tusMSffXu1x1zlOSLpH0Ynm7RJL6sw3NUr5fby6X4wVJ10saBhAR64BrgPN6+nhv5m78t4HZ3Qx/PiKGVNxmdplmPTAKOA34jqSDKuq3AcuAPYFdga/3TdObKyI+WLnMwAPAv1UZ/Xngo8AIYCRwB3BTl3Fu7vIcLuirtksa2OSHnAacCLwDOAQ4HvhMk+fxBuWHTLPz9FVgOLA3sC/F+/bCivoNwFRJg3ryYK3qxi+U9L7y/g6SZpRrmceBd/Zg+inAy8A9vZjnYOAjwF9HxJqIuJ/ijf7xsn40sAfw5xGxKiI2RMT/9Xrhup/3TpKulfTrskfxV51vDEn7SfqfsjfxgqSby+GSdJmkFZJekfSopIN7MK+xwO8C13ZXj4iXI2JhFF+dFLAJaPrmhqQDgSuA3yl7By+Xw2dI+o6kOyWtBY6S9GNJn6qYdqsegaQDJN0taaWkJyWdnJj1VODSiFgcEUuAS4Ez6lyGMyT9VNI/l6/PE5LeW1H/saSLJP0UeBXYp3ytr5a0VNISSV/t/ECTNFDS18vXeQHw+zWasDfw/Yh4JSJWAbOALSuniFgMvAQc0ZPlaYc1+wUUn1r7AsdQvFhbSLpc0uUV/w8DvgL8aZXH21XSckm/LMMyuBy+P7AxIp6qGPcXvP7kHQE8Ccwsu02zJb2n0YUr/ROwE7AP8B7gdOCPytrfAndRfILvXo4LcDTw7rLdOwEnAy8CSDpV0iNV5nU6cF9ELEw1qAzfa+X8/q5L+fgyWPMkfbaHy7iViJgP/DHws7J3sHNF+VTgImAokOzml6/f3RRrsV2BKcDlnZse3TwXB1G8rp0qX+N6HA48S9ELugC4TdKIivrHKXoTQ4FFwAxgI8UH6ESK17Hzg+zTwHHl8EkUPazKZT1P0n9UDPo2cJyk4ZKGU6ysftClffM
pejE1tUPYTwYuioiVEfEc8I+VxYj4XER8rmLQ3wJXl59qXT0BTABGA78H/DbwjbI2BHily/irKF4kKIJ2NHAv8DaKNcLtkkbWu2CwpZs6BfhSRKwuQ3gpZY8C2ADsBewWEa+VPY7O4UOBAwBFxPyIWAoQETdExCFVZnk6xRsuqQzfTsAXgMoezC3AgcAuFG/Ov5F0Sg8Xt6duj4ifRsTmiHitxrjHAQsj4l8jYmPZ27oVOKlcjq7PxRCK17XTKmBIA9vtK4Bvlj29mylWCJVr5BkRMS8iNlJsGh0LfDEi1kbECuAyitcfivf6NyPiuYhYCfx95Ywi4uKIOK5i0FxgO4oP+RcpemGXs7XVwM70QDuEfTfguYr/F1UbUdIE4H0UT+AbRMSyiHi8fBP9EjiX4tMQYA0wrMskwyieLIDfULypri5f2JvKdr2rtwvUxUigg62XaxEwprx/LkV3+n/LNeknymX5EfDPFJ/uKyRN79w5U42kIyk+qL7Xk4ZFxFqKrva1knYthz0eEc9HxKaIeAD4Fl3WQE3wXO1RttgLOFzSy503iv0tb6syftfXeRiwJuo/42tJl2kXUbxnO1Uuy14Ur/XSirZeSdEjgV6810u3AE9RfOgPo+hhXNdlnKEUm7Q1tUPYl1JsK3faMzHuZGAs8CtJy4A/Az4iaW6V8YPXl/Ep4C2SxlXU3wHMK+8/Uo7fdfpGvcDra+9OewJLYMsH1KcjYjeKHUmXqzxkFxH/GBG/DYyn6M7/eY15TQVui4g1vWjfAGBHXv/w6apz274e1Z6/rsPXlm3oVBnk54D/iYidK25DIqLa5sU8tu7WVr7G9RjTpVewJ8VOzk6Vy/IcsA4YWdHWYRHRuRnRm/c6FL3UK8tewhqKD+Zju4xzIFtvtlTVDmG/BfhSuV2yO3BmYtzpFNv2E8rbFcB/UmzrI+koSXuVO7f2AC4Gbocta7HbgK9IGizpXcAJwHfLx54FDJc0tdyR8lGKrv1PG1m4iNhULuNFkoZK2otif8N1ZZtPKpcbip0tAWyW9E5Jh0vqoAjDa8DmavORtANFN3FGqj2S3i9pYrmMwyg2c16i2PZD0gnlayFJhwFnUT6HdVgO7C5puxrjPQx8WNKO5QfdJytq/wHsL+njkjrK2zvLHYDduRb4U0ljJO0GnEMPNmsSdgXOKud7EkW47uxuxHIz6y7gUknDJA2QtG/Fvp9bysfavdwGr3XYbDbwKRU7sXeg2DewZf+EpDEUmw4P9mRB2iHsX6bozvyS4on6bmVR0hWSrgCIiFfLNeGyiFhG0WV7LSJ+XY4+keKw09ry76MUb9ZOnwN2oNgOuxH4bETMKx97JfAHFL2FVRQvxAkR8UITlvHMsk0LKHZI3UBxjBSKow8PSVpDcXTg7PJQ1zDgKoogLqLYZvuH8jk5TVLXtdWJFN25e7vOvNw8OK38d2eKZV9F0S3cF/hAxbbzFOAZis2ba4FLuhy+7I0fUaxVl0lKPY+XURwSXQ7MBLZ8vyEiVlPsS5lCsUZdBlwCDCqXretzcSXw7xSv/WMUK4Mr62w/wEPAOIoe2kXARyPixcT4p1NsZz9O8dp9j2IfEhSv539RrInnUqx8tpB0vqTKHXCfoOjJLqboCe7D1juwTwVmlsfca5J/vKI+Kr89FxEXtrYlrSVpITC51t7/dpVqv6QzgE9FxJH93KyaVBxb/wXw7nJHYE29/pqembVeuTY/oDfTOOz1+3GrG9AmvkkP9wa3qTd7+3vM3XizTLTDDjoz6wf92o3fToNiewbXHtHM6vIaa1kf67r9XkRDYZf0AYpvWA0E/iUiLk6Nvz2DOfz18wjMrMkeiurnhtXdjS+/8/1t4IMU3/A6RRXnRZtZe2lkm/0w4JmIWBAR6ynOiT6hOc0ys2ZrJOxj2PpL/Yvp5vvVkqap+GWUORvo0Rd9zKwP9Pne+IiYHhGTImJSBz36QQ0z6wONhH0JW5/Bs3s5zMzaUCNhnw2Mk7R
3eVbTFIoTOcysDdV96C0iNkr6AsVZPAOBazrPIDOz9tPQcfaIuJMq5/aaWXvx12XNMuGwm2XCYTfLhMNulgmH3SwTDrtZJhx2s0w47GaZcNjNMuGwm2XCYTfLhMNulgmH3SwTDrtZJhx2s0w47GaZcNjNMuGwm2XCYTfLhMNulgmH3SwT/XrJZnvz0cSDkvWnzt4+WT/u4Eeq1tZtTr/9Hrx+YrL+tm89kKzb1rxmN8uEw26WCYfdLBMOu1kmHHazTDjsZplw2M0y4ePsmdt01KHJ+levvipZP2L7gc1szlbWnfuTZP09L56ZrO903YPNbM6bXkNhl7QQWA1sAjZGxKRmNMrMmq8Za/ajIuKFJjyOmfUhb7ObZaLRsAdwl6SfS5rW3QiSpkmaI2nOBtY1ODszq1ej3fgjI2KJpF2BuyU9ERFb7VWJiOnAdIBhGhENzs/M6tTQmj0ilpR/VwCzgMOa0Sgza766wy5psKShnfeBo4HHmtUwM2uuRrrxo4BZkjof54aI+GFTWmVNM3D48GT9k1femqwfuN36ZP3t3zgnWd/j9mVVa8+ePio57eOf+Hayvmq/9Lpqp2Q1P3WHPSIWAO9oYlvMrA/50JtZJhx2s0w47GaZcNjNMuGwm2XCp7huAwZMGF+1NuWmu5LTHrNj9UNjAL9/9heT9d1uTf+c86ZEbexfL0hO+9IZv0nWrXe8ZjfLhMNulgmH3SwTDrtZJhx2s0w47GaZcNjNMuHj7NuAjm+trFo7fVj6t0Dffln6FNVax9Eb8copRyTrwwfMTdYH+FfOesVrdrNMOOxmmXDYzTLhsJtlwmE3y4TDbpYJh90sEz7O/iYw4JADkvVb97u+au39809MTrv75b9I1jcPSF+SecDgHZP1l48/qGqtY+ry9LxJX0Bo7HW/StY3Jqv58ZrdLBMOu1kmHHazTDjsZplw2M0y4bCbZcJhN8uEj7O3gQE7po9VHzLziWS9Q4lj4Re8NTntxonp+urz1yTrD074XrIO99WoV/e/6zYn6/GaT2jvjZprdknXSFoh6bGKYSMk3S3p6fJv+iLgZtZyPenGzwA+0GXYecA9ETEOuKf838zaWM2wR8RPgK6/e3QCMLO8PxNIfyfTzFqu3m32URGxtLy/DBhVbURJ04BpANuT3jY1s77T8N74iAiofsZCREyPiEkRMamDQY3OzszqVG/Yl0saDVD+XdG8JplZX6g37HcAU8v7U4Hbm9McM+srNbfZJd0ITAZGSloMXABcDNwi6ZPAIuDkvmzktm71B9+erF8y6spkffwDH6ta2/M3G5LTTp+Vfuw93zIkWe9Lhw3qSNbnf2XvZH3/z/66mc1506sZ9og4pUrpvU1ui5n1IX9d1iwTDrtZJhx2s0w47GaZcNjNMuFTXPvBwJHp00jP+vubGnr8sV96tWrt3B9+PzltrUNrd72aPvz1F5d9Olk/8NT5VWsPzf6t5LTHvOvh9GMfsDhZ35Ss5sdrdrNMOOxmmXDYzTLhsJtlwmE3y4TDbpYJh90sEz7O3g80ZHCyfvKQVcn6xNlTkvXt37lz1drkHdI/x3zi08ck6+s/V/2xAXad90Cyvuo/96xa+62Xqx+DB7j3qnHJ+rH7zEvWH0/8RPfmV6t/N2Fb5TW7WSYcdrNMOOxmmXDYzTLhsJtlwmE3y4TDbpYJH2fvB8uOGZOsb4r0sfAdbkwf637xIPW6TZ3W/tVuyfqAef9X92MDbFz4q4amTzl+5/T57k8M73o90tf5OLuZbbMcdrNMOOxmmXDYzTLhsJtlwmE3y4TDbpYJH2fvB+uHpo+Dv7T5N8n60JseTNZXzxrf6zZ16nhxbbLezr+9fvajf5isv21J+nz53NRcs0u6RtIKSY9VDLtQ0hJJD5e3Y/u2mWbWqJ5042cA3X0V6bKImFDe7mxus8ys2WqGPSJ+Aqzsh7aYWR9qZAfdFyQ9Unbzh1cbSdI0SXMkzdnAugZmZ2aNqDfs3wH2BSY
AS4FLq40YEdMjYlJETOpgUJ2zM7NG1RX2iFgeEZsiYjNwFXBYc5tlZs1WV9glja7490PAY9XGNbP2UPM4u6QbgcnASEmLgQuAyZImAAEsBD7Th21sewP32ztZ//6ZX0vWz3j2pBpzWJqs6r7E+e6Hpx/5ic+MSNb3Pyf9FomNG9MzaMCX33FHsn7BDaf12by3RTXDHhGndDP46j5oi5n1IX9d1iwTDrtZJhx2s0w47GaZcNjNMuFTXJsghuyQrO/bMSRZn//wXsn6fjUOve0xa0nV2skffm9y2gUnXZGsH7p3+jTSXf7gyWQ95cVP/U6y/qHBP0/Wr7zPX7/uDa/ZzTLhsJtlwmE3y4TDbpYJh90sEw67WSYcdrNM+Dh7Gxiwrv5LLgNs/OWiqrVV7x6YnPaooz+drI9+7Pn0vJNVGDhun6q1L517fXLaI+Z2d8Ll60b+aG6NuVslr9nNMuGwm2XCYTfLhMNulgmH3SwTDrtZJhx2s0z4OHsb2P3Q9LHshmxOX3R5ux/OTtZrHUdXx3bJ+pOf37Vq7YDtlienHXHJjumZR6TrthWv2c0y4bCbZcJhN8uEw26WCYfdLBMOu1kmHHazTPTkks17ANcCoygu0Tw9Ir4laQRwMzCW4rLNJ0fES33X1PY1YPnKZP2uVzuS9dsOuDFZP23cx5L1TU8vSNYboUGDkvVnrhmfrD97VPXfpd/3lj9JTrvf/Q8m69Y7PVmzbwTOiYjxwBHA5yWNB84D7omIccA95f9m1qZqhj0ilkbE3PL+amA+MAY4AZhZjjYTOLGvGmlmjevVNrukscBE4CFgVER0XpdoGUU338zaVI/DLmkIcCvwxYh4pbIWEUGxPd/ddNMkzZE0ZwO+NpdZq/Qo7JI6KIJ+fUTcVg5eLml0WR8NrOhu2oiYHhGTImJSB+mdPWbWd2qGXZKAq4H5EfGNitIdwNTy/lTg9uY3z8yaRVHjNEFJRwL3AY8Cm8vB51Nst98C7Aksojj0ljwGNUwj4nClLyG8LXr2hgnJ+jOTZyTr89e/mqwfd//nq9besmj75LQ7Hpw+WvqJ/X6WrJ85vPrPWAOcv/yQqrXZZx6anHbA/Q8n6/ZGD8U9vBIru/1t8prH2SPifqDaD5vnl1yzNyl/g84sEw67WSYcdrNMOOxmmXDYzTLhsJtlwj8l3Q/2P+u5ZP37PxuSrJ84OP34z/7ev/a2ST22KTYn62c9f3iy/uzxb61aG7DUx9H7k9fsZplw2M0y4bCbZcJhN8uEw26WCYfdLBMOu1kmap7P3ky5ns9ey4CDD0jWl/9d/a/RvYfOSNZPnD8lWV99027J+luvTp/vbv0rdT671+xmmXDYzTLhsJtlwmE3y4TDbpYJh90sEw67WSZ8nN1sG+Lj7GbmsJvlwmE3y4TDbpYJh90sEw67WSYcdrNM1Ay7pD0k3SvpcUnzJJ1dDr9Q0hJJD5e3Y/u+uWZWr55cJGIjcE5EzJU0FPi5pLvL2mUR8fW+a56ZNUvNsEfEUmBpeX+1pPnAmL5umJk1V6+22SWNBSYCD5WDviDpEUnXSBpeZZppkuZImrOBdQ011szq1+OwSxoC3Ap8MSJeAb4D7AtMoFjzX9rddBExPSImRcSkDgY1oclmVo8ehV1SB0XQr4+I2wAiYnlEbIqIzcBVwGF910wza1RP9sYLuBqYHxHfqBg+umK0DwGPNb95ZtYsPdkb/y7g48CjkjqvsXs+cIqkCUAAC4HP9EkLzawperI3/n6gu/Nj72x+c8ysr/gbdGaZcNjNMuGwm2XCYTfLhMNulgmH3SwTDrtZJhx2s0w47GaZcNjNMuGwm2XCYTfLhMNulgmH3SwT/XrJZkm/BhZVDBoJvNBvDeiddm1bu7YL3LZ6NbNte0XELt0V+jXsb5i5NCciJrWsAQnt2rZ2bRe4bfXqr7a5G2+WCYfdLBOtDvv0Fs8/pV3b1q7tAretXv3StpZus5tZ/2n1mt3M+onDbpaJloR
d0gckPSnpGUnntaIN1UhaKOnR8jLUc1rclmskrZD0WMWwEZLulvR0+bfba+y1qG1tcRnvxGXGW/rctfry5/2+zS5pIPAU8H5gMTAbOCUiHu/XhlQhaSEwKSJa/gUMSe8G1gDXRsTB5bCvASsj4uLyg3J4RPxFm7TtQmBNqy/jXV6taHTlZcaBE4EzaOFzl2jXyfTD89aKNfthwDMRsSAi1gM3ASe0oB1tLyJ+AqzsMvgEYGZ5fybFm6XfVWlbW4iIpRExt7y/Gui8zHhLn7tEu/pFK8I+Bniu4v/FtNf13gO4S9LPJU1rdWO6MSoilpb3lwGjWtmYbtS8jHd/6nKZ8bZ57uq5/HmjvIPujY6MiEOBDwKfL7urbSmKbbB2Onbao8t495duLjO+RSufu3ovf96oVoR9CbBHxf+7l8PaQkQsKf+uAGbRfpeiXt55Bd3y74oWt2eLdrqMd3eXGacNnrtWXv68FWGfDYyTtLek7YApwB0taMcbSBpc7jhB0mDgaNrvUtR3AFPL+1OB21vYlq20y2W8q11mnBY/dy2//HlE9PsNOJZij/yzwF+2og1V2rUP8IvyNq/VbQNupOjWbaDYt/FJ4K3APcDTwH8DI9qobd8FHgUeoQjW6Ba17UiKLvojwMPl7dhWP3eJdvXL8+avy5plwjvozDLhsJtlwmE3y4TDbpYJh90sEw67WSYcdrNM/D/wvaHJMqUFBQAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learner.view_top_losses(n=1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Make Predictions" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "predictor = ktrain.get_predictor(learner.model, preproc)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'seven'" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predictor.predict(x_test[0:1])[0]" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "7" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.argmax(predictor.predict(x_test[0:1], return_proba=True)[0])" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "predictor.save('/tmp/my_mnist')" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "p = ktrain.load_predictor('/tmp/my_mnist')" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'seven'" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "p.predict(x_test[0:1])[0]" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "predictions = p.predict(x_test)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
PredictedActual
0seven7
1two2
2one1
3zero0
4four4
\n", "
" ], "text/plain": [ " Predicted Actual\n", "0 seven 7\n", "1 two 2\n", "2 one 1\n", "3 zero 0\n", "4 four 4" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import pandas as pd\n", "df = pd.DataFrame(zip(predictions, y_test), columns=['Predicted', 'Actual'])\n", "df.head()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" } }, "nbformat": 4, "nbformat_minor": 2 }