{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%reload_ext autoreload\n", "%autoreload 2\n", "%matplotlib inline\n", "import os\n", "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\";\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"; " ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import ktrain\n", "from ktrain import text" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "detected encoding: utf-8\n", "preprocessing train...\n", "language: en\n" ] }, { "data": { "text/html": [ "done." ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Is Multi-Label? False\n", "preprocessing test...\n", "language: en\n" ] }, { "data": { "text/html": [ "done." ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "trn, val, preproc = text.texts_from_folder('data/aclImdb', \n", " maxlen=500, \n", " preprocess_mode='bert',\n", " train_test_names=['train', \n", " 'test'],\n", " classes=['pos', 'neg'])" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Is Multi-Label? False\n", "maxlen is 500\n", "done.\n" ] } ], "source": [ "model = text.text_classifier('bert', trn , preproc=preproc)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "learner = ktrain.get_learner(model, \n", " train_data=trn, \n", " val_data=val, \n", " batch_size=6)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "simulating training for different learning rates... this may take a few moments...\n", "Epoch 1/1024\n", " 6492/25000 [======>.......................] 
- ETA: 19:19 - loss: 0.6908 - acc: 0.6155\n", "\n", "done.\n", "Please invoke the Learner.lr_plot() method to visually inspect the loss plot to help identify the maximal learning rate associated with falling loss.\n" ] } ], "source": [ "learner.lr_find()" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEKCAYAAAAfGVI8AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3dd3xV9fnA8c+TTQaBDGaAsMKSPQQBhSIo2GLVuqrWTbXVH7VqtXXWarW1tdYO96hY90SGoFYEkRX23jNAEgLZO/f7++PcXJKQhAD33JN77/N+vfLizHufe0juc77jfL9ijEEppVTwCnE6AKWUUs7SRKCUUkFOE4FSSgU5TQRKKRXkNBEopVSQ00SglFJBLszpAE5VUlKSSU1NdToMpZTyKytXrjxijEmub5/fJYLU1FTS09OdDkMppfyKiOxtaJ9WDSmlVJDTRKCUUkFOE4FSSgU5TQRKKRXkNBEopVSQ00SglFJBThOBUko5ZG9OEQWlFU6HoYlAKaWcct7TC/jpy8ucDkMTgVJKOaGiygXA+ow8hyPRRKCUUo7IL3G+SqiaJgKllHJAXo1EcKSwzMFINBEopZQjcmskgvQ9Rx2MRBOBUko5Yu3+XM/y9sxCByPxw9FHlVIqEHy/M4fOCdGUVVax72ixo7FoIlBKKQfsyymmd7s4jhWXO54ItGpIKaV87FhROceKy4mNCqNTQjT7tUSglFLBocpluPfDtXy8KgOAqPBQ2sRF8snqDMorXUSEOXNvriUCpZTykZzCMk8SANh/tJhOraMxBr7enOlYXJoIlFLKRypcptZ6WaWLXu3iAHht8W4nQgI0ESillM9UVLpqrReXV3JWx3g6JbRAEIei0kSglFI+U+mqnQgeuqgvAMNTEzhwzLkGY20sVkopHymvtKqGLh+awp8uG0BIiFUKaBMXRXZhGcYYRHxfMtASgVJK+Uh1ieDCs9p5kgBAm7hIKqoMRwrL6z0vr7iC8X9ZwGdrMurdf6Y0ESillI9UDz0dHlr7q3dgp1YADH/iq3qfKcgrqWD3kSLK67QxeIsmAqWU8pGKKqtqKCy0dvXPYHciAFiyM+eE8wrKrAHq4qLCbYlLE4FSSvlIdYkgok6JoGY10a4jRSecV1BaCUDLKHuadTURKKWUj1R6SgQnfvUuvHc8cVFh7Mw+cSTS6kSgJQKllPJz5Z42ghN7BnVOjGZ09yS2ZxacsK96gvtYLREopZR/qy4R1G0srtY/JZ49OcXkFdeexrLM3UgcFW7PV7YmAqWU8pGGeg1VS02MAeBQfkmt7ZXu88JCNBEopZRfK6usAmhwlNHW0VYbwJNztlBcXunZXuEpSdjzsJkmAqWU8pHCMisRxEbUX9ffKjoCgG+3ZfPjfy0m3902UP0gWn2NzN6giUAppXxg5tqD/GHWJgBiIkPrPSYxNsKzvC2zkGlvpgM1nj8I0RKBUkr5rX9/s8Oz3NCdfZu4yFrrS3cd5bM1GSdtZD5TmgiUUsoHkut8yddHRFjy2x9wx/genm0PfrKBSpcLEQjVEoFSSvmvqPD6q4Pqah/fgrsnpbHu0UlcOqQjBWWVZOaXEm5TjyHQRKCUUj5RVGb1Anr+miEnPVZEaBk
VzjVndwasKqK64xN5kyYCpZTygcKySs5LS2Zy//ZNPqd/x1bERISy72ixbQ3FoIlAKaV8orzSdcpPBkeEhTAgxRqZ1K6GYrAxEYjIayKSJSIbGtgvIvKciOwQkXUicvLyklJK+anyStdpfZmnJllPG9vVUAz2lgjeAC5sZP9koKf7ZxrwvI2xKKWUo8qrXCcMP90U7eOjbIimNtsSgTFmIXC0kUMuBt40lqVAKxFpeuWZUkr5kYoqV4NDSzSmlXvYieLyKm+H5OFkG0FHYH+N9QPubUopFXBOt2oovoWVCIpqjD3kbX7RWCwi00QkXUTSs7OznQ5HKaVOWUWVOaNEECL+2UZwMhlApxrrKe5tJzDGvGSMGWaMGZacnOyT4JRSypvKq1yEh536l3k7dxtBlct4OyQPJxPBTOBn7t5DI4E8Y8whB+NRSilbGGOoqHIReRolgm5Jse5/Y7wdloc9854BIvIOMA5IEpEDwCNAOIAx5gVgDjAF2AEUAzfaFYtSSjlp5tqDGHN6zwJEhIXw9i1n0y051obILLYlAmPM1SfZb4Bf2vX+SinVXEx/dw0Ap1u7c06PJC9GcyK/aCxWSqlAsD4jz+kQ6qWJQCmlbNY5IRqAn43q4nAk9dNEoJRSNhOBiwd14Ny05tnrUROBUkrZrKC0kthI25pkz5gmAqWUsllhaSVxUeFOh9EgTQRKKWWj0ooqyqtcxEVpiUAppYJSQak1RpAmAqWUClKFZZoIlFIq6FRUufj1e2vYdDCfgtIKAGIjtY1AKaWCxtr9uXy8OoO7P1jL4bxSAJLjIh2OqmGaCJRSysu2ZxUCsPlQPntzigHo4n6orDnSRKCUUl6WlV/mWX5izmZCQ8Qz01hzpIlAKaW8LKugtNZ6lcsgNk4sc6Y0ESillJdlFZTRKaGF02E0mSYCpZTysuyCMlITY+iebN9kMt4U1Ingng/W8u7yfZ71Rz7bwMerDjgYkVLK32XklrBmfy7JcZF8ffc4bhrdlfsn93Y6rEY13yccfODDlQf4cOUBLjyrHVHhofxnyV5YspdLh6R4jjlwrJgxf/oGgNUPTaR1TIRT4Sql/MCP/7UYgMiwUAAe/lFfJ8NpkqAuEVQb9NiX9H7oi3r3rdqX61ke/IcvKSmv8lVYSik/VFZhfUfcODrV2UBOQdAmAmumzJM7nFdSa/1X763GdbrzzSmlAl7bllFMPqsdaW3jnA6lyYIuEXy3/Qi5xeVUVDX8ZV49NgjA7iNFJMRE8Or1wwCYtzGTLzdnaslAKVWv8ioXEWH+9dXqX9GeoW2ZBVz76jKenLOF0sraX+Rf330ef718IGC1CwD84+vtvLN8P12TYpjQpy3LfjcBgJ/PWMnP31rp2+CVUn6hvNJFRKh/fbUGVWPx+gPWxNF5JRWU1rij79iqBd2TY8krsQaHuvDZRUzo3Yavt2QBkJpodQFrU2OskIXbsn0VtlLKj5RVuogM969E4F/RnqEPVu4HoGWLMM9YIADv/XwkACmtjj8AUp0EALq5+wKLCD8a2MGz/fudR2yNVynlf6wSQajTYZySoEkES3flsHTXUQAqqgzXvLIMgL9fNYiU1tZgUMlxkdw6tusJ57aPj/IsP3fVIFY+eD5hIcKi7ZoIlFK1lVdqG0GzlVtc7ln+ZHWGZ/kHvdt4lkWEBy46sc9vjzaxtY5JjI0kpXUL9h0ttilapZQ/Msb4ZWNx0LQRhIac+B/ztysH1juh9Df3jONQbgmxUWHERYXTNenEx8Q7JUQze90h8oqXkdK6BU9dNsCWuJVS/qO8ygVApCaC5qm4vPKEbXENzBjUNSmm3i//mnZlFwHw3Q6reuiWsV3p0cZ/+g0rpbyvvNJKBP7Wa8i/oj0DRWVWL6H7Ljw+5kfsGcwhOv38nrXWz39mYZMfUlNKBabSCneJQHsNNU+FZVbX0J+N6uLZdiaTSV8xrNMJ26a/u4bHZ23i682Zp/26Sin/VT0/cct6qpybs6C
pGhrXqw2tWkTQIvx4t642cVGNnHFyv5vSmxe/3UVOkdUQPXPtQQBe+W430RGhLLh33Bm/h1LKf1Q/ixTfwr8SQdCUCNLaxnHF8E6EhByfJSjxDEcSnXZud1Y+NJH0B88/YV9xeRUjnvjaU2eolAp8+aVWW2TLFv51jx00iaCmC/q1BaiVFM5EUmwkz18zhPQHz2fHE5PZ9vhkz760B+eSXVDWyNlKqUCRX+KfVUNBmQiev2Yo25+YfPIDT8Hk/u1Jio0kLDSEiLAQfjfleKN09RPNSqnAplVDfiQkRAi3uXvXtHO7s/vJKbRtGcnOrCJb30sp1TzkVzcWayJQ1USErkkxfLTqQK0nm5VSgSm/pJKI0BC/e6DMv6L1Q+N6WUNYDHrsS258fbnD0Sil7FLlMrzw7U7Kq1yIeKf90Vc0Edhs2thunuVvtmZTWqET2igViNYeyD35Qc2UJgKbhYQIc6eP5f7JVuPxE7M3OxyRUsoOmXmlAHx0+zkOR3LqbE0EInKhiGwVkR0icn89+zuLyDcislpE1onIFDvjcUqf9i25ZHBHAGYs3cunNUY/VUoFhiPuB0s7tW5xkiObH9sSgYiEAv8CJgN9gatFpO4Yzw8C7xtjBgNXAf+2Kx6ntW0Zxd+utKbCfHnRLh2XSKkAsyu7kBCB1mf4oKoT7CwRjAB2GGN2GWPKgXeBi+scY4CW7uV44KCN8TjuksEpPHhRHzYezOfAsRIKyyqpqNInj5UKBEt25nBO9yTbu6bbwc6IOwI1n6Q64N5W06PAtSJyAJgD3FnfC4nINBFJF5H07Gz/nit4Qp+2hAjc9tZKznpkHhf/czG7j+hzBkr5s8N5pWw5XECXxGinQzktTqeuq4E3jDEpwBRghoicEJMx5iVjzDBjzLDk5GSfB+lNXZNiePLS/mzLLABg06F8xv9lARm5JQ5HppQ6VTmFZbzw7U4m/u1bwBpuxh/ZmQgygJpjNae4t9V0M/A+gDFmCRAFJNkYU7Nw5fDOrH/0Anb9cQp921s1Y4/O3EiVS9sNlLJTfmkFmw/le61K9oOVB3hq7hYK3IPNjeyW6JXX9TU7E8EKoKeIdBWRCKzG4Jl1jtkHTAAQkT5YicC/636aKCo8lJAQYc70sUw7txtfbsrkmleWcrRIn0BW6nRVVLk8cwLU9d32Iwx4dD6T/76IN5fs9cr7FZVZCSA8VJj9f2MY1V0TQS3GmErgDmAesBmrd9BGEXlMRKa6D7sbuFVE1gLvADeYIOxOc8+kXvzmwl6s2pfLJf9erENXK3Wa/jBrE/0fnc/+o8Un7Fu0/fg95pKdOV55v0N5pSTFRrLm4Un06xDvldd0gq1tBMaYOcaYNGNMd2PME+5tDxtjZrqXNxljRhtjBhpjBhlj5tsZT3MVERbCL8b14NaxXdmbU0zag3PJzC91Oiyl/M6i7dYc4jf/Z8UJ85RnF5bRIT6KG85J5avNmeQUWsPDu1yGuesPUVZZRWWVq8k3YlkFpXy48gARoUJMpH/NP1CX043FqoY7f3B8HuTznv5GnzVQ6hTllVQQFiJszyqk78PzeGL2JgAy80v5eFUGrWMiPNU302asBGDZ7qPc/t9V9HrwC3o8MJcpzy1qUjLYdrgQgEGdW9n0aXxHE0EzEhUeyvYnJtOnfUtKK1zc+c5qVu875nRYSvmFpbtyOFpUTtuWUdw5vgcALy/azdbDBcxwtwkMSGnF2J5Wf5SVe4/xxYbD7DpSWOt1dmQV8t9lJ29DqJ4H/Zfu9/JnmgiamfDQEGbdOYbz+7Rl1rpDXPb892w+lO90WEo1e9+76/0fuKgPd01M48PbRhEXGcZ1ry7jk9UZ9Gnfkicv7U90RBjvTRtJp4QW3PbWSh74ZAPREaHMunMMu5+cwugeifz+801c/dJSth4uaLBk7pmW0s9mI6uPJoJmKDREePG6ofz3lrMREZ5fsJPSiirW7s9l8t8XMX/jYadDVKrZyS+pIDYyjCn92yM
iDEtN4IXrhpJVUEZGbgndkmM8x57dLZEv7zqPG0enAtAtOYazOsYjIvxuSh8AluzK4YJnFzLxbwu5bcZKNh2sfUNWEECJwL9bOAJYaIgwukcSUwd24JPVGcxce3z0jWkzVvLutJF+22dZKW/IKSxj7YFcftDbmoO8sKySuKjaX2mjeyTRq20cWzML+PGg2gMbRIWHcs+kXmzIyGP6hDTP9n4d4nnpuqFMm7ESEauqaEdWIYt3HqFnm1hiIsO4a2KaZ37i2Cj//xqVpjRIish04HWgAHgFGAzc70Qvn2HDhpn09HRfv61jXC7DH2Zv4vXFe2ptH9MjiRk3j/C7CTCU8pbff76R1xfvoWOrFtx7QS/+OGczraLDmX/XebWOyyupIEQg7jTu3I0x/HzGSorKK1m8o/4up3ueuui04vc1EVlpjBlW376mprKbjDF/F5ELgNbAdcAMICi7e/pSSIjw8A/7cvWIznRJjCYsJIQ3l+zh959v4vrXV3D3xDQGdvL/XgtKnarKKusmNiO3hF+9twag3ieGz2QieRHhpZ9Z351vLd3Lg59u4JLBHfkkwIaSb2oiqL7tnALMcD8YpreiPiIipLWN86z/9OzOzF53iIXbslm4LZuk2EjuvSCNCX3a+u1YJ0qdqqLySjq2asGVwzvxzJfbaNcyip+e3dm297vm7M60iYtkQp+2PHHJWfR9eB6B8i3Y1Kqh17FGDu0KDARCgQXGmKH2hneiYKsaasx/vt/DU3O3UOKe/rJTQgvmTj+XWD9/uEWpprjlPyvIyC1l1p1jOFpUTnKcb2+C9h8tJjw0hHbxUT5939Pljaqhm4FBwC5jTLGIJAA3eitAdXquPyeVq0d05q2le2kVHc49H6xl+jurObtbAv/bkgXArWO7MaFPW4cjVerM7cgqZN7Gw9x+XndCQoS8kgriW4QRGiI+TwIAnRL8c8jp+jQ1EYwC1hhjikTkWmAI8Hf7wlJNFREWwk1jugKw9XABLy7cxddbsggLESpdhqW7jjLn/8bSt0PLk7ySUs2XMYa/zNvKFxsPc7SonC2H81mx5xgXDWjvdGgBoamJ4HlgoIgMxBoo7hXgTeC8Rs9SPvXrSWmktG5Bq+gIBqTEs/9oCde+uoyb3ljBtHO7ce3ILkSE6aMjyv/8df42vnA/P/Pqd7s926v78qsz09REUGmMMSJyMfBPY8yrInKznYGpUxcZFsp1o1I9610SY3j+miH86r01PDZrE0/O3UxybCSdE6MZkZrAVSM606GV/020rYLP7PWHALh7YhrHiit4bbGVDCJCA6S11mFNbSz+FvgCuAkYC2QBa40x/e0N70TaWHzqcovLmbXuEC8t3MW+OsPzLvrN+ICq61SBZd2BXJbuyuHLTZkAfHDbOYBVVfTJ6gzG9EyiTZx/NNY6rbHG4qYmgnbAT4EVxphFItIZGGeMedO7oZ6cJoLTV1JeRVZBKbGRYUz952IycktoGRXGpUNSOK9XMuPSkvUBNdWsTP77Is9YWxN6t+HVG4Y7HJH/OuNeQ8aYwyLyX2C4iPwQWO5EElBnpkVEKF0SrfFWFt//A+ZvPMzjszfzxvd7eOP7PQCc1bElHVu14K6JaXRJiKFFRKiDEatgV1njAbHckvpnHlNnrkmJQESuAJ4GFmA9XPYPEbnXGPOhjbEpm03q145J/dpxtKicmWsy2HAwnzX7c/lmazbzNmZ6jnv/56MY0TXBwUhVIFt3IJeyShfDUxOochmyCkpZuz+PD9L3sz2rkAv6tWX9gTxuOCfV6VADVlOrhtYCE40xWe71ZOArY8xAm+M7gVYN2W9vThF/+3IbAJ+usQa7G98rmWHuP9TRPRIZkNKK8FDtgaTOjDGGrr+dA1hdoVtHh5OZX1brmDvG9+CeC3o5EV5A8cYDZSHVScAtBx3COmB1SYzh2asGA/CL8T342avL+WZrNt9steZ8feZLaB8fRf+O8Tz0w76IQEprbXBWJ1dR5aLKZYg
Kt6oct2UenxTG5TL079iKvu1dbD5UwGH3dK2FZdpF1G5NTQRfiMg8rAnmAa4E5tgTkmpO0trGsfR3E3C5DOsz8sgvrWBHViHvLN/H/E2ZzHf35ohvEU5yXCR3nZ/G5LPaERKijc7KUlRWyX0frSOvpIKdWYUczCv1POS4NbMAgBevG8qkvm09nRWMMXy7LZsbXl9BZ+3VZrsmVQ0BiMhlwGj36iJjzCe2RdUIrRpqPlbuPcq3246w+VA+y3cfJa9GY167llH8/LxuDOncWkdHDXLzNx72zA9cLTEmgrm/GsuCLdn85qN1fHff+HpLlct3H2VYl9Z6Y+EF3qgawhjzEfCR16JSfm9olwSGdrEakfOKK9h0KJ8lO4+w72gxn645yO8/tyYOn9K/HdMnpBEWKnRPjnUyZOVjlVUups1YSViI8OZNI/jNR+u4eUxXfv/5Jp6cs4VB7puEFuH1907TTgq+0WgiEJECoL4igwDGGKMD2CgA4qPDGdU9kVHdrVnTrhrRmVcW7earzZnMWX+YOeut4QEemNKHm8Z0JVTv8AKeMcbzAOOYnkmc0yOJ7+77AQCbDuYza90hEmIiADxtBsoZjSYCY0xcY/uVasjIbomM7JZIeaWLV7/bzf5jxby9bB9PzNnMpkP5/PXygVrcD3BPz9vKvxfsBODas7vU2nfViE58sPKAZ9wgTQTO0oHrla0iwkK4fVx3AG4/rzszlu7lpYW7mL/xMH+5fCCT+rXT0kEAenTmRs9DigBtW9YeBmJolwTG9kxi0fYjAPo74DDtAqp8plNCNL+d3JvrR3WhqLyK2/+7iu6/m8NLC3ey/2gxx4rKnQ5ReUFpRRUzlu5lZLcE0h88n8d/fBb96hkG/eEf9nUgOlUfTQTKp0SER6f2Y+lvJ3DbeVZJ4Y9ztjD2z98w/q8L2JVdeJJXUM3dp6szqHIZbjinK0mxkVw7sku91YA922rNc3PR5O6jzYV2Hw0sVS7DV5szeXXRbpbvOUpcVBhzp4/VB9T81NbDBVzw7ELAGs+q40mGOd+ZXcih3FLG9EzyRXhBrbHuo1oiUI4KDREu6NeO928bxZ8u609BaSVj/vQNc9zjzyv/sfFgnicJDO3S+qRJAKB7cqwmgWZAE4FqNq4Y1olLBncE4Bf/XcUv315FeaXrJGep5uK9Ffs9y4/8SOv//YkmAtVsiAjPXDGQ128cTliIMHvdIT5fe9DpsFQTbT6Uz1kdW/Lc1YPp3zHe6XDUKdBEoJoVEWF8rzYsum88AHd/sJbU+2fzzvJ9DkemTuZIYTldEmOYOrCDTnDkZzQRqGapfXwLbh7T1bP+24/X43L5V8eGYPLHOZvZfaSISB2a3C/p/5pqtn51fk8e+VFfkuMiAfhuxxGHI1L1ySks46WFuwAoKtcho/2RJgLVbMVFhXPj6K4s+s142sdH8ev317J2f67TYSm3KpfhszUZXP7iEgDG9EjisYvPcjgqdTo0EahmLyo8lN9c2IsjhWVc++oyCkp17trm4K/ztzL93TXkFJYzfUJPZtw84oShJJR/0ESg/MIlg1P47JejKSit5P6P1jsdjuJ4Vd2i+8Zz18Q0bSD2Y5oIlN8Y2KkVt4/rzuz1h3jMPdeBcsaxonI2Hszn0iEdaRkV7nQ46gzZmghE5EIR2SoiO0Tk/gaOuUJENonIRhF52854lP+7dWw3AF5bvJujOkidI1wuw/nPfEuVy9CqRYTT4SgvsC0RiEgo8C9gMtAXuFpE+tY5pifwW2C0MaYf8Cu74lGBISEmgvl3nUuIwLNfbXM6nKC0LiOPHHcSHthJHxwLBHaWCEYAO4wxu4wx5cC7wMV1jrkV+Jcx5hiAMSbLxnhUgEhrG8fUgR14c8lefdDMAfM2HiYsRFjxwPlcPKij0+EoL7AzEXQE9tdYP+DeVlMakCYii0VkqYhcWN8Licg0EUkXkfTs7GybwlX+5JIhKYD1oNmsdToMha8YY5i34TCjuid6nu9
Q/s/pxuIwoCcwDrgaeFlEWtU9yBjzkjFmmDFmWHJyso9DVM3ReWnJnoHN7nh7tQ5O5yPf7TjCriNFTOrXzulQlBfZmQgygE411lPc22o6AMw0xlQYY3YD27ASg1IndePorlw0oD0Ay3cfdTiawFde6eLhzzaSHBfJRf3bOx2O8iI7E8EKoKeIdBWRCOAqYGadYz7FKg0gIklYVUW7bIxJBZjHpvYjJiKU+z5aR5WORWSrd5bvY/eRIv582QASYrS3UCCxLREYYyqBO4B5wGbgfWPMRhF5TESmug+bB+SIyCbgG+BeY0yOXTGpwJMYG8kjP+pHRm4JK/ZoqcAu5ZUuXvx2J8NTWzOul1bPBpowO1/cGDMHmFNn28M1lg3wa/ePUqdlQp82RISG8Np3uxnZLdHpcALSzLUHOZhXyhOX9NcniAOQ043FSp2xxNhILhuawnc7jpBbrA+ZeZvLZXjh2530bhenpYEApYlABYTLhnSkuLyKz9fpXMfe9tXmTHZkFXL7uO5aGghQmghUQOjboSUAD326weFIAs+7K/bTrmWU9hQKYJoIVECIjjje3PXZmrq9lNXpysov5dtt2Vw6pCNhOvtYwNL/WRUwPrhtFAAvfKs9kL3lszUHqXIZLhua4nQoykaaCFTAGJ6awPQJPdlyOJ9DeSVOhxMQ1uzPJTUxmu7JsU6HomykiUAFlIsHdcAYmL8x0+lQAsKWw/mktI52OgxlM00EKqB0S46lQ3wUy3brc4lnaumuHHZmF9EtOcbpUJTNNBGogDOiawJfbDjMvpxip0Pxa3/7chtJsZHcc0Evp0NRNtNEoALOryf2wmVgzgZ9puB05RaXs2z3Ua4f1UWnogwCmghUwOmcaNVpPzV3C9YoJupU7XGXpnq1i3M4EuULmghUQOrurtc+mFfqcCT+aUdWIYC2DwQJTQQqIP3tykEAfLLqgMOR+Kelu3KIjQyja5J2Gw0GmghUQBqQ0or4FuH8Zf42tmcWOB2OXzlaVM7MNQeZOqgDoSE6tlAw0ESgAtZZHa3xh/67TCe4PxXrM/Ior3IxdWAHp0NRPqKJQAWsp38yEECHpj5FS3Zaz2CktdWG4mChiUAFrA6tWjC2ZxI7sgudDsWvLNmVw/DU1jodZRDRRKACWlrbOLZnFlJaUeV0KH7jSEEZnRJ0WIlgoolABbTz0pIpq3Rx/WvL+WLDYafDafaMMWQXlJEcF+l0KMqHNBGogDaquzWH8bLdR7ntrZW4XIYql+GW/6Qz7ulvdML7OjJySyivctFJB5oLKpoIVEALDw3hxtGpnvV+j8zj5UW7+GpzJntyinl5oc5dUNM3W7MBGJAS73Akypc0EaiA98CUPnx0uzVpTUlFFU/N3QLANWd3Zv6mTC56bpH2LHKbs+4QaW1j6d9RE0Ew0USgAl5YaAhDuyRwzdmdPduSYiO4cXRXADYezOfrzVlOhddslFZUsXLfMc7tmdFi24QAABILSURBVKyT1AcZTQQqaDz8o76M75UMQEFpJT3axPLmTSMAuO+jdZRVntiz6NmvtjH4sflBMXjdbz9eT3mli/G92zgdivIxTQQqaESGhfL4Jf0BKKt0AXBuWjI/P7cblS7DpoP5J5zz7FfbOVZcwY//tZjnF+xkR1YBL3y7kyOFZT6N3W4l5VV8sjqDsBDhHHcDuwoemghUUOkQH8WAlHieuWKgZ9u1I7sAsOlQ7URQWFbpWV57II8/fbGF859ZyFNzt3Dfh+t8E/Bp2JCRx9A/fMm327JxuQwr9x6loLSi0XNyS6w2kken9tNqoSCkiUAFFRFh5h1juHRIimdbSusWxEWGMWf9ISqrXJ5qoIxjJQ2+ztdbsvj1+2tsj/dUFZdXcuc7q8kpKuf615bz2KxNXPb8Evo/Or/RBvH8EivptY7Wp4mDkSYCFfREhKmDOrB4Rw49HpjLHW+vBiCvxLqLToiJYEjnVgBcPaIz8+86F4CPV2WwI6v5jGx64FgxfR+
ex+4jRVw+1Ep0b3y/x7N/5d5jDZ5b/VnjW+hsZMFIE4FSwK8npnmWZ6+3prisvoN+86YRfPyL0ax+aCKPXdyPtLZxTD6rHQDfbT/i+2AbsHDb8Vievnwg4aG1q3gW78hp8FxNBMFNE4FSQGJsJN/dN96znplfSm6dL8fWMRGEh1p/Ms9fO5QO8VGs2NPwXbavHcy1qrKW/W4CANPO7ebZN/msdryfvp8HP11P2gNzeWd57aG5qxNByxZhPopWNSf6v66UW0rraBb9ZjznPf0Nby/bR1yU9ecRH13/XfLwrgl8vzMHY0yzaGDNLiijTVwkbVtGAXD3xF78cEAHisurSIiJYO6GBby11EoAv/14PfEtwpnSvz2gJYJgpyUCpWrolBBN73YteeHbnfx53lZCQ4S4yPrvl4anJpBdUMZe90TvTssuLCMp9vhgcSEhQp/2LRnapTVdk2IYWGfYiF/8d5VnOd+dCOKiNBEEIy0RKFVH65hwyg65POsN3e2f3TUBgOW7j5Ka5Pwk7/klFbRqoPQC8M+fDmHexsPcPKYrP3lhCSv3HuPZr7aRW1zBG9/vITREdGrKIKUlAqXquHlM1yYd16NNLK2jw1neTEYwLSyrJDqi4Xu7TgnR3DK2GyLC/ZN7A9YDc9U9i24d263Bc1Vg0xKBUnX8oHdbHryoDy8v2sV5ackNHiciDEtNaDZDWReXVxEbGdqkYwemtKq1/sFtoxiemmBHWMoPaCJQqh63jO3GLU24Qx6RmsCXmzLJyi+ljbuR1ilFZZXENNCeUVdE2PHKgL9cPpBhXVrbFZbyA5oIlDoDw93tBN/vzOHHgzs6GkthWSWxTUwEAP/66RAqXS4uHuRs3Mp5mgiUOgP9OrSkXcsofvPhOqtnTqdWJz/JBqUVVZRVujxdXpviogHtbYxI+RNbG4tF5EIR2SoiO0Tk/kaOu0xEjIgMszMepbwtPDSET385mgqXiwXu2b2ccOc71rAYXRKd772k/I9tiUBEQoF/AZOBvsDVItK3nuPigOnAMrtiUcpO7eKjSE2MIX2vM43GRWWVfLkpE4A+7eMciUH5NztLBCOAHcaYXcaYcuBd4OJ6jvsD8Ceg1MZYlLLVxL5tWbIzp97Jbez2vy3HZ1frnhzr8/dX/s/ORNAR2F9j/YB7m4eIDAE6GWNmN/ZCIjJNRNJFJD0727nit1IN6d8xnkqXYeth341GejivlIzcEuZtPExMRCi7/jilWQx1ofyPY43FIhICPAPccLJjjTEvAS8BDBs2LPDnDFR+Z2S3RETg263ZDEjxTYPx2D//j4oq689hQEo8IfpUsDpNdpYIMoBONdZT3NuqxQFnAQtEZA8wEpipDcbKHyXHRdIlIZr/bc3C5bL/XsUY40kCADeNbtrT0ErVx85EsALoKSJdRSQCuAqYWb3TGJNnjEkyxqQaY1KBpcBUY0y6jTEpZZtKl2H1vlzeS7dqRNcdyGXprobnADgTh/KON6lFR4QyvpdOOK9On22JwBhTCdwBzAM2A+8bYzaKyGMiMtWu91XKKQ9M6QNAunuOgqn/XMxVLy1l3sbDXn0fl8vw8GcbAXj5Z8P45p5xDQ6VrVRT2NpGYIyZA8yps+3hBo4dZ2csStltcv/29O8YT1ZBKZVVx0cvfW/Ffi7o185r77Niz1G+2mx1Fx2e2ppWOs+wOkM6+qhSXtSrXRybDxXw0GcbPNtyCsvIKijFGO+0HWw4mA9Am7hITQLKKzQRKOVFvdvFcaSwjHeWW+0E7eOjWHsgjxFPfM3lLyw549fPKSzjH//bTruWUZ4pKZU6U5oIlPKifh1qzwJ296RenuX0vcfIKjiz5ya/2pxJbnEFV43opM8MKK/RRKCUFw1LPT6c8/QJPfnJ0BQ6J0R7ti3eceSMXv/rzVm0jArj/37Q84xeR6maNBEo5UXhocf/pMb3trp0zp0+ltUPTSQpNoL/fL/3tF87M7+U+ZsymTqogz48prxKE4FSXvb
SdUOZOrADfdu3BCAmMozWMRFcObwT6zPyyC+t4KtNmZRWNH1coqKySi56bhEAXRJ0hFHlXZoIlPKySf3a8dzVg2vNAgYwqW87qlyGf3y9nVveTGf44181uSfRjKV7OVJYzjndE/np2Z3tCFsFMU0ESvnIgJR4WkWH89Eqa6SVgrJK7nDPI9CYgtIKnpq7haTYCN6+dWSTp6NUqqk0ESjlIyLC4E6tOFpU7tk2e92hk45NtCu7CIBBDs1+pgKfJgKlfOhHAzt4lv925UAAbn0znUXbs/n95xvJL6044ZxtmdbQ1ndoTyFlE00ESvnQJYM7IgId4qMY3T0JgK+3ZHHdq8t5ffEeBjw6/4RkcO+H6wBIjNGniJU9tLJRKR8SEdY9MgmXiwYHihvw6Hw+un0UQ7skUF55fMyidvFRvgpTBRktESjlY3FR4Z4kcNmQlBrbj9+XzdtoDSr3+OxNAEzp367WMwpKeZN4ayAsXxk2bJhJT9cpC1RgcLkMR4vLWbn3GJP6tmXEH78mu6CMs7sm0Do6gi/cQ1h/8otzGNy59UleTamGichKY0y9E39p1ZBSDgoJEZJiIz3DVH/2y9G8smg3ry3eXes4TQLKTlrWVKoZ6dCqBfde0KvWtgcv6uNQNCpYaCJQqplpERHK368aRL8OLbluZBedj1jZTquGlGqGLh7UkYsHdXQ6DBUktESglFJBThOBUkoFOU0ESikV5DQRKKVUkNNEoJRSQU4TgVJKBTlNBEopFeQ0ESilVJDzu0HnRCQb2AvEA3k1dtVcr7mcBBzxYgh13/dMj29sf337Gvvcddfr7vPmtfCn61B3Xa+Dxcm/Db0OTT+2oWOach1qbutijEmu9x2MMX75A7zU0Hqd5XQ73/dMj29sf337GvvcjV0Hb18Lf7oOdv5O6HU4vWuh16HpxzZ0TFOuQ1Pfw5+rhj5vZL3uPjvf90yPb2x/ffsa+9x11/U6NP39T5deh9N7bb0OTT+2oWOach2a9B5+VzV0qkQk3TQwBnew0Wth0etg0etg0esQHI3FLzkdQDOi18Ki18Gi18ES9Nch4EsESimlGhcMJQKllFKN0ESglFJBThOBUkoFuaBOBCIyVkReEJFXROR7p+NxioiEiMgTIvIPEbne6XicIiLjRGSR+3dinNPxOElEYkQkXUR+6HQsThGRPu7fhQ9F5Han47GT3yYCEXlNRLJEZEOd7ReKyFYR2SEi9zf2GsaYRcaY24BZwH/sjNcu3rgOwMVAClABHLArVjt56ToYoBCIIrivA8B9wPv2RGk/L30/bHZ/P1wBjLYzXqf5ba8hETkX64/2TWPMWe5tocA2YCLWH/IK4GogFHiyzkvcZIzJcp/3PnCzMabAR+F7jTeug/vnmDHmRRH50BjzE1/F7y1eug5HjDEuEWkLPGOMucZX8XuLl67DQCARKyEeMcbM8k303uOt7wcRmQrcDswwxrztq/h9zW8nrzfGLBSR1DqbRwA7jDG7AETkXeBiY8yTQL1FXBHpDOT5YxIA71wHETkAlLtXq+yL1j7e+n1wOwZE2hGn3bz0+zAOiAH6AiUiMscY47Izbm/z1u+DMWYmMFNEZgOaCPxER2B/jfUDwNknOedm4HXbInLGqV6Hj4F/iMhYYKGdgfnYKV0HEbkUuABoBfzT3tB86pSugzHmAQARuQF3KcnW6HznVH8fxgGXYt0UzLE1MocFWiI4ZcaYR5yOwWnGmGKshBjUjDEfYyVFBRhj3nA6BicZYxYACxwOwyf8trG4ARlApxrrKe5twUavg0Wvg0Wvg0WvQwMCLRGsAHqKSFcRiQCuAmY6HJMT9DpY9DpY9DpY9Do0wG8TgYi8AywBeonIARG52RhTCdwBzAM2A+8bYzY6Gafd9DpY9DpY9DpY9DqcGr/tPqqUUso7/LZEoJRSyjs0ESilVJDTRKCUUkFOE4FSSgU5TQRKKRXkNBEopVSQ00SgbCcihT54j6lNHF7Zm+85TkT
OOY3zBovIq+7lG0SkWYxrJCKpdYdtrueYZBH5wlcxKd/QRKD8hnsY4XoZY2YaY56y4T0bG49rHHDKiQD4HfDcaQXkMGNMNnBIRAJ6fP5go4lA+ZSI3CsiK0RknYj8vsb2T0VkpYhsFJFpNbYXishfRWQtMEpE9ojI70VklYisF5He7uM8d9Yi8oaIPCci34vILhH5iXt7iIj8W0S2iMiXIjKnel+dGBeIyLMikg5MF5EficgyEVktIl+JSFv3EMe3AXeJyBqxZrtLFpGP3J9vRX1fliISBwwwxqytZ1+qiPzPfW2+dg+Rjoh0F5Gl7s/7eH0lLLFmFJstImtFZIOIXOnePtx9HdaKyHIRiXO/zyL3NVxVX6lGREJF5Oka/1c/r7H7U8Dv5mpQjTDG6I/+2PoDFLr/nQS8BAjWTcgs4Fz3vgT3vy2ADUCie90AV9R4rT3Ane7lXwCvuJdvAP7pXn4D+MD9Hn2xxqAH+AnWcMIhQDuseQd+Uk+8C4B/11hvzfGn8G8B/upefhS4p8ZxbwNj3Mudgc31vPZ44KMa6zXj/hy43r18E/Cpe3kWcLV7+bbq61nndS8DXq6xHg9EALuA4e5tLbFGHI4GotzbegLp7uVUYIN7eRrwoHs5EkgHurrXOwLrnf690h/v/QT9MNTKpya5f1a712OxvogWAv8nIpe4t3dyb8/BmijnozqvUz1U9Eqs8eLr86mxxtHfJNaMYwBjgA/c2w+LyDeNxPpejeUU4D0RaY/15bq7gXPOB/qKSPV6SxGJNcbUvINvD2Q3cP6oGp9nBvDnGtt/7F5+G/hLPeeuB/4qIn8CZhljFolIf+CQMWYFgDEmH6zSA/BPERmEdX3T6nm9ScCAGiWmeKz/k91AFtChgc+g/JAmAuVLAjxpjHmx1kZrApDzgVHGmGIRWYA1TSJAqTGm7qxpZe5/q2j4d7isxrI0cExjimos/wNr6sqZ7lgfbeCcEGCkMaa0kdct4fhn8xpjzDYRGQJMAR4Xka+BTxo4/C4gE2tKyhCgvngFq+Q1r559UVifQwUIbSNQvjQPuElEYgFEpKOItMG62zzmTgK9gZE2vf9i4DJ3W0FbrMbepojn+Lj119fYXgDE1VifD9xZveK+465rM9Cjgff5HmtoZLDq4Be5l5diVf1QY38tItIBKDbGvAU8DQwBtgLtRWS4+5g4d+N3PFZJwQVchzVnb13zgNtFJNx9bpq7JAFWCaLR3kXKv2giUD5jjJmPVbWxRETWAx9ifZF+AYSJyGbgKawvPjt8hDU94SbgLWAVkNeE8x4FPhCRlcCRGts/By6pbiwG/g8Y5m5c3YRVn1+LMWYLEO9uNK7rTuBGEVmH9QU93b39V8Cv3dt7NBBzf2C5iKwBHgEeN8aUA1diTUO6FvgS627+38D17m29qV36qfYK1nVa5e5S+iLHS1/jgdn1nKP8lA5DrYJKdZ29iCQCy4HRxpjDPo7hLqDAGPNKE4+PBkqMMUZErsJqOL7Y1iAbj2ch1qTvx5yKQXmXthGoYDNLRFphNfr+wddJwO154PJTOH4oVuOuALlYPYocISLJWO0lmgQCiJYIlFIqyGkbgVJKBTlNBEopFeQ0ESilVJDTRKCUUkFOE4FSSgU5TQRKKRXk/h+84GZK4AfTkwAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learner.lr_plot()" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "begin training using onecycle policy with max lr of 2e-05...\n", "Train on 25000 samples, validate on 25000 samples\n", "25000/25000 [==============================] - 2304s 92ms/sample - loss: 0.2442 - accuracy: 0.9008 - val_loss: 0.1596 - val_accuracy: 0.9394\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 2e-5 is one of the LRs recommended by Google and is consistent with the plot above.\n", "learner.fit_onecycle(2e-5, 1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### **93.94%** accuracy in a single epoch." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's make some predictions on new data." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "predictor = ktrain.get_predictor(learner.model, preproc)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "data = [ 'This movie was horrible! The plot was boring. Acting was okay, though.',\n", " 'The film really sucked. I want my money back.',\n", " 'The plot had too many holes.',\n", " 'What a beautiful romantic comedy. 
10/10 would see again!',\n", " ]" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "['neg', 'neg', 'neg', 'pos']" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predictor.predict(data)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To save and reload the predictor for later use:\n", "```python\n", "predictor.save('/tmp/my_predictor')\n", "reloaded_predictor = ktrain.load_predictor('/tmp/my_predictor')\n", "```\n", "\n", "Please see the [text classification tutorial](https://nbviewer.jupyter.org/github/amaiya/ktrain/blob/master/tutorials/tutorial-04-text-classification.ipynb) for more details." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" } }, "nbformat": 4, "nbformat_minor": 2 }