{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%reload_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline\n",
    "import os\n",
    "\n",
    "# Pin the GPU before ktrain is imported: number devices by PCI bus ID and expose only device 0.\n",
    "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import ktrain\n",
    "from ktrain import text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Word Counts: 88582\n",
      "Nrows: 25000\n",
      "25000 train sequences\n",
      "Average train sequence length: 231\n",
      "x_train shape: (25000,400)\n",
      "y_train shape: (25000,2)\n",
      "25000 test sequences\n",
      "Average test sequence length: 224\n",
      "x_test shape: (25000,400)\n",
      "y_test shape: (25000,2)\n"
     ]
    }
   ],
   "source": [
    "# load training and validation data from a folder\n",
    "# download and unzip IMDb dataset: https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz\n",
    "DATADIR = 'data/aclImdb'\n",
    "\n",
    "# Shared by the preprocessor and the model so the two cannot drift apart.\n",
    "MAX_FEATURES = 20000  # vocabulary size\n",
    "MAXLEN = 400          # length of every encoded sequence\n",
    "\n",
    "(x_train, y_train), (x_test, y_test), preproc = text.texts_from_folder(\n",
    "    DATADIR,\n",
    "    max_features=MAX_FEATURES,\n",
    "    maxlen=MAXLEN,\n",
    "    ngram_range=1,\n",
    "    classes=['pos', 'neg'],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.keras.models import Sequential\n",
    "from tensorflow.keras.layers import Dense, Embedding, GlobalAveragePooling1D\n",
    "\n",
    "\n",
    "def get_model(max_features=20000, maxlen=400, embedding_dim=50):\n",
    "    \"\"\"Build and compile the text classifier used in this notebook.\n",
    "\n",
    "    The network is an embedding layer, a global average pooling layer,\n",
    "    and a 2-unit softmax output, compiled with categorical cross-entropy\n",
    "    and the Adam optimizer.\n",
    "\n",
    "    Args:\n",
    "        max_features: vocabulary size; should match the `max_features`\n",
    "            passed to `texts_from_folder`.\n",
    "        maxlen: length of each input sequence; should match the `maxlen`\n",
    "            passed to `texts_from_folder`.\n",
    "        embedding_dim: size of each embedding vector.\n",
    "\n",
    "    Returns:\n",
    "        The compiled Keras `Sequential` model.\n",
    "    \"\"\"\n",
    "    model = Sequential()\n",
    "    # add 1 for padding token\n",
    "    model.add(Embedding(max_features + 1, embedding_dim, input_length=maxlen))\n",
    "    model.add(GlobalAveragePooling1D())\n",
    "    model.add(Dense(2, activation='softmax'))\n",
    "    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])\n",
    "    return model\n",
    "\n",
    "\n",
    "model = get_model(max_features=MAX_FEATURES, maxlen=MAXLEN)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
"source": [ "learner = ktrain.get_learner(model, train_data=(x_train, y_train), val_data=(x_test, y_test))" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "simulating training for different learning rates... this may take a few moments...\n", "Epoch 1/1024\n", "25000/25000 [==============================] - 5s 187us/step - loss: 0.6942 - acc: 0.4932\n", "Epoch 2/1024\n", "25000/25000 [==============================] - 4s 172us/step - loss: 0.5448 - acc: 0.7606\n", "Epoch 3/1024\n", " 32/25000 [..............................] - ETA: 4s - loss: 1.3311 - acc: 0.8750\n", "\n", "done.\n", "Please invoke the Learner.lr_plot() method to visually inspect the loss plot to help identify the maximal learning rate associated with falling loss.\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYYAAAEKCAYAAAAW8vJGAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3deZhcZZn+8e9TW2/pdLqTzp6QBBKSIIQlbLIYVgEVVFCB0QFHhnFmHB1nXEBHZFRGZpSZcYFRQHEZFTHwY6JEwS2AsiUsSQghELKQlXTW7nSnu2t5fn+c053qTifpJn2qujr357rq4pxTp6ruLir11Hvec97X3B0REZEOsWIHEBGRgUWFQUREulBhEBGRLlQYRESkCxUGERHpQoVBRES6SBQ7QF+NGDHCJ02aVOwYIiIl5dlnn93q7vW92bfkCsOkSZNYtGhRsWOIiJQUM1vb2311KElERLpQYRARkS5UGEREpIvICoOZfd/MtpjZi/u538zsm2a20syWmNmJUWUREZHei7LF8APgogPcfzEwNbxdD/xPhFlERKSXIisM7v4YsP0Au1wG/MgDTwHDzGxMVHlERKR3itnHMA5Yl7e+PtwmIiJ5WtNZHl62mXXbWwryeiXR+Wxm15vZIjNb1NDQUOw4IiIF1bgnzd/8+Fkee7Uw33/FLAwbgAl56+PDbftw9zvdfba7z66v79WFeyIig0Y6F0yolohZQV6vmIVhHvCX4dlJpwG73H1TEfOIiAxI2WxQGOKxwnxlRzYkhpn9DJgDjDCz9cAXgSSAu38HmA9cAqwEWoAPR5VFRKSUZXI5AJLxwrQYIisM7n7VQe534O+jen0RkcEim+toMQz+Q0kiItIL6ezh08cgIiK9kO3sfC7MV7YKg4jIANfRxxAvUB+DCoOIyACXOYxOVxURkV7IZNX5LCIieTr6GJJx9TGIiAiQ7uhjUItBRERg75XP6mMQEREgv/NZh5JERIS9p6smdLqqiIiAhsQQEZFuOk5XTepQkoiIQF6LQYeSREQE9p6uqrOSREQEyB9ET4VBRETY28eg01VFRATQ6KoiItKNRlcVEZEu0hkVBhERybNrT5qqVJyERlcVERGAnS3tD
KtMFez1VBhERAa4prYM1eWJgr1epIXBzC4ysxVmttLMbujh/iPM7PdmtsTMFpjZ+CjziIiUokw2V7BJeiDCwmBmceB24GJgJnCVmc3sttvXgR+5+3HAl4CvRpVHRKRUZXJesAH0INoWwynASndf5e7twL3AZd32mQn8IVz+Yw/3i4gc9rI5J1mgaxgg2sIwDliXt74+3JZvMfDecPk9QLWZDe/+RGZ2vZktMrNFDQ0NkYQVERmoMtnB02LojU8BbzOz54G3ARuAbPed3P1Od5/t7rPr6+sLnVFEpKgyucL2MUTZzb0BmJC3Pj7c1sndNxK2GMxsCHC5u++MMJOISMkZTH0MC4GpZjbZzFLAlcC8/B3MbISZdWS4Efh+hHlEREpSJusFu+oZIiwM7p4BPgY8DCwH7nP3ZWb2JTO7NNxtDrDCzF4BRgG3RJVHRKRUZXNesJFVIdpDSbj7fGB+t2035S3PBeZGmUFEpNSlc7mCjawKxe98FhGRg8jmnORgOJQkIiL9IzhddRBc+SwiIv0jk8sNjs5nERHpH9mck1Afg4iIdEgPltNVRUSkfwQtBvUxiIhIKJ1VH4OIiOTJDqIhMURE5BC5OxkdShIRkQ7ZnAPoUJKIiAQyHYVBp6uKiAjkFQa1GEREBCCbDQqDhsQQEREAmtrSAAwpixfsNVUYREQGsJ0tQWEYVpkq2GuqMIiIDGC79gSFoaYiWbDXVGEQERnA0tkcAEldxyAiIgAe9D1TwJOSVBhERAayjgvcNCSGiIgAkA2bDDFTYRARESA32FoMZnaRma0ws5VmdkMP9080sz+a2fNmtsTMLokyj4hIqcl19jEMgsJgZnHgduBiYCZwlZnN7LbbvwD3ufsJwJXAHVHlEREpRR2Hkgp4UlKkLYZTgJXuvsrd24F7gcu67ePA0HC5BtgYYR4RkZLTcSipkC2GRITPPQ5Yl7e+Hji12z43A4+Y2T8AVcD5EeYRESk5h+NZSVcBP3D38cAlwI/NbJ9MZna9mS0ys0UNDQ0FDykiUiy5QXZW0gZgQt76+HBbvo8A9wG4+5NAOTCi+xO5+53uPtvdZ9fX10cUV0Rk4OksDIOkxbAQmGpmk80sRdC5PK/bPq8D5wGY2QyCwqAmgYhIKBwRg/hgaDG4ewb4GPAwsJzg7KNlZvYlM7s03O2fgb82s8XAz4Br3TsuABcRkc4L3Ap44D/KzmfcfT4wv9u2m/KWXwLOiDKDiEgp80HWxyAiIoeo86wkFQYREYG9hWGwdD6LiMghyvnhdx2DiIgcQE7zMYiISL5sEYbEUGEQERnABt2w2yIicmg6R1dVi0FERGBvH0MB64IKg4jIQJbLOTEDU4tBREQgOJRUyP4FUGEQERnQcu4FPSMJVBhERAa04FCSCoOIiISyucKeqgoqDCIiA1pwKKmwr6nCICIygOXU+SwiIvmy6mMQEZF8OfeCDrkNKgwiIgNaNucFHQ4DVBhERAa0nBf+rKRI53wWEZFDM/fZ9QV/TbUYRESkCxUGERHpItLCYGYXmdkKM1tpZjf0cP9/mdkL4e0VM9sZZR4RkVJUnizsb/jI+hjMLA7cDlwArAcWmtk8d3+pYx93/2Te/v8AnBBVHhGRUuHufPLnL3DO9JFUpeJcdcrEgr5+lGXoFGClu69y93bgXuCyA+x/FfCzCPOweVcrmWwuypcQETlk6azz4Asb+cS9L9CezZFKDJIWAzAOWJe3vh44tacdzewIYDLwh/3cfz1wPcDEiW+ucv74qbXc9H8vckRdJW9/y2g+evaR1Fal9tkv28P8qrmcYwZtmRyNrWl2tqRpS+eoKouTyTk1FUka96RZtHYHzW0Zpo2qJhEzWtqzVKbiNLVlqEolKE/GyDmMq61gzNByYjGjLZNl6+524ma0prMMrUgCkMnlaG7L0tKeIRWP0ZbJUVORZExNOZmcs7stw/CqVEEn7yi29kyOZNw6/+bWdJbWdJam1gzusLstw5amVmorU
0yoqySViFGZjBf84iCRQ9WayXYup7M+qApDX1wJzHX3bE93uvudwJ0As2fP9jfzAidMGMa4YRWs2dbCdx9dxXcfXcXRo6qpry5jS1MrwypSrN3ezBuNbVSl4kypH8K6HS1kss6edJaKZJzdbZk3/xd2U56MUZGMs6MlfUjPMaG2kgl1lVSm4qSzORqa2khnnZb2DLv2ZGhpzzB2WAW1lUnaMzma2jLsbs0wpqacRDxGImadf1d1eYJELMaedJYhZQlGDS2jpiLJxp2trN3eTNyMnAdTDNZUJBk1tJydLWk27NxD4540o4aW0ZbJkYjHiBm0pXNkc05ze4aYGYm4Mb62ksnDKxlSnqCmIsnmXW2s2dZMU2ua2soU2ZyztbmdhsZWtjS1UVeVYmhYeLc0tXUZTCzXy0/CiCFljKkpJ5tzdrS0k4zHaGnPUFWWYGR1Wfj+pIJbVZKyRFDAh5QlKEvEqCoL/pm4QyoRIxk3UokYZYkY7lBZlmBUdRnxmOHh+3M4FWzpf23prkc2kvEBWBjM7BPAPUATcDdBX8AN7v7IAR62AZiQtz4+3NaTK4G/702WN+st42r402fPJZtzHl62mTsWrOTFDY2kszlWbW1mWGWSKSOqOG78MOqry3h9WwsTh48gbsbI6jKaWjNMqKtgaEWSylSCeCz4BVuWCApGKhFjxuihjB1WzvJNTaSzOYZWJGjPOO5OOufkct7ZQljV0ExbJkttZYq6qhSJuJGKB1/KMTNiMSMRM4ZVJEnnnGTMaGxNs2prM6l4jOryBG80trFuewvLNjaSyeUoT8YZMaSMEUOSVKQqqKlIEjNj2+52du1JU1mZYHxtJUPKEry6pQl3Z086aImUJ+M07knTlsmQiBkbd+7hpU2N7GxpZ2R1OeNqKyhPxolb8IW8o6WdV95oorYyxeiaco6sr2LzrlbqwlaMu1NXGSMeMypScYzgl89rDbt5/vUdtLRnyeaCX0JTRlQxrDLJ6q3NJOMxhg9JceSI4VSWxWnPBMWudnwNo4eWEzPrnPvWCL6o66vLiJlRXZ5gWGWKHc3tbG5sJZ3Nsbs1w6ZdrWxrbgdg6qghGFCRStC4J83mxlaee30HO5vTNB1C4e9oYXo4cXt1eRIzSMVjlCVjVJclKU/GKEvEw+ISoyIVZ3hY+IZVJKmrSlFblaK2MkltZYr66jLKk/FD+NRLqWrLdP2NXDZAWwx/5e7fMLO3A7XAh4AfAwcqDAuBqWY2maAgXAlc3X0nM5sePueTfQn+ZsVjxiXHjuHCmaM6v4DdvV9/4Z05tazfnmuwyuWcptYMQysSA+bXdTqbY2dLmtZ0lljM2N2aIZ3N0RwWDDOjPZMjnc3RlsnRns1hBIewNuzYg+MYRs6d5rYMTvDjoTWdZXdbcNirLZOlpSVDWybHnnSWbbvbD9gSHV6VYlxtBeNrKxhTU8HooeVMH1PNqZOHF/zwghROW6Zri2GgHkrq+Jd7CfBjd19mB/nX7O4ZM/sY8DAQB74fPu5LwCJ3nxfueiVwr3f81CqQRF7TbKB8MR1OYjGjpjJZ7BhdJONB66PQcjmnsTXN9uZ2drS0s6M5WN7S1MqGna1s2LmHlzc38ceXG9iTDn5JJmJGbVWKOdPqOW/GKE6bUsewyn37zKQ0taa7thhSA/FQEvCsmT1C0EF8o5lVAwc9vcfd5wPzu227qdv6zb3MIDIoxWLGsMrUQb/Y3Z3GPRmeWbOdRWu3s2lnK/OXbuIXz64nZjB7Uh3XnD6JS44drR87Ja5UWgwfAY4HVrl7i5nVAR+OLpaIdGcWtLIumDmKC2aOAoJDVYvX7+R3L73Br5Zs4u9/+hynTanjhIm1XHHSeI6sH1Lk1PJmdO98TgzQFsPpwAvu3mxmHwROBL4RXSwR6Y1UIsbJk+o4eVIdn71oOj95ei23//E1nl69ne//aTV/ceoRXHLsaE46olatiBLSvfO5uR/PiOyN3pah/wFazGwW8M/Aa8CPI
kslIn0WixkfOn0ST33uPH73T2/j/JmjuOeJ1VzxnSf54rxl+xy3loGrNWwxnDq5Dhi4hSETdg5fBnzb3W8HqqOLJSKH4sj6Idx+9Yk8feN5fOTMyfzoybVc9u0/s7OlvdjRpBc6WgzHjqsBKPiJBb0tDE1mdiPBaaoPmVkMGFinlIjIPkYOLecL75zJPdeezOqtzVx919Os3dZc7FhyEB2dz9e8dRLfu2Y27z1hXEFfv7eF4QNAG8H1DJsJLlb7WmSpRKRfnTN9JHddM5vXt7dw7m2P8h+/ebnYkeQAOgpDeTLOeTNGDcw5n8Ni8BOgxszeCbS6u/oYRErI26bV8/t/fhunTanjjgWv8eOn1hY7kuxHW9gfVFbg4bY79OpVzez9wDPA+4D3A0+b2RVRBhOR/jdqaDnfu+Zkzps+ki88+CLLNzUWO9Jha9OuPbz/O0/ywyfW7HNfZ4shUZwhUXpbjj4PnOzu17j7XxIMqf2F6GKJSFTKk3G+/r5ZDC1PcPE3Hu/XwSHl4Dbu3MNtj6zg9K/+gWfWbOeL85bts09bOosZJOPFOcW4t4Uh5u5b8ta39eGxIjLA1Fal+PRF0wF6/MUq0fnqr1/mW39Y2WXbrm6jLLdlcpQlYkW79qS3X+6/MbOHzexaM7sWeIhuQ12ISGn50GlHcP6Mkdzxx5U6U6mAhpbve13xgy90HXi6NZ0t6si6ve18/jTBfAjHhbc73f2zUQYTkeh98V3H4MBn719CgcexPGyNHlreuXzq5DqqyxKsatjdZZ+OFkOx9PqV3f1+d/+n8Pb/ogwlIoUxoa6SGy+ezlOrtvOjJ3WWUjFMHF7J2u0tXba1hXO9FMsBC4OZNZlZYw+3JjPT6Qwig8AHTzuCUyfX8Z1HX+uc2laik86bd/7p1duZWFfJ6/sUhuzAbTG4e7W7D+3hVu3uQwsVUkSiY2Zc89ZJbNrVyu+Xv1HsOINeW15h+MaVxzNxeCWrGpq5ed6yzsN5rencwO9jEJHB7dzpIxleleKW+cs12F7E0pngy//f3nMsl84a2zke0g+eWNM5Be2AbjGIyOGhPBnnpnfNZO22FuYv3VTsOINaezZLbWWSq0+diJkxY8zegy8vbtjFpBse4s8rtxXtqmdQYRCR0LuOG8uUEVV87eEVNDS1FTvOoNXUmqG6fO8YpGNrKjqXf/L0653LC9fsKGiufCoMIgIE8zl848oT2NLUxp2PvVbsOIPW9uZ26qr2DqNdkdrbl/Dbl/b28bRnDjp7cmRUGESk07Hja7hw5igefGGjrmuIyKZdrYysLuuybUJdxX72Lg4VBhHp4sypI2hoamPx+l3FjjLorN7azMotu5k2qus8Z49/5lyuOGl8l22XHT+2kNG6iLQwmNlFZrbCzFaa2Q372ef9ZvaSmS0zs59GmUdEDu7SWWNJJWI8+PyGg+8sffKnVxsAeNvR9fvcN3lEVefyCzddwNffN6tgubqLrDCYWRy4HbgYmAlcZWYzu+0zFbgROMPdjwH+Mao8ItI71eVJzps+knmLN9LSrpFX+9P25mCwvOMnDNvnvvyhMoZVpkjGB+dZSacAK919lbu3A/cSzBmd76+B2919B0C3EVxFpEiuO2sK25vb+V9N5tOvdrS0U12e6PFLf1jlwJktOcrCMA5Yl7e+PtyWbxowzcz+bGZPmdlFEeYRkV466YhaTp8ynB8+sZachsnoN9ua2xmed0ZSvrOm1vPB0ybyp8+eU+BU+yp253MCmArMAa4C7jKzfdpYZna9mS0ys0UNDQ0FjihyePrAyRPYsHMPz6/bWewog8Ktv36ZXy7eyLDKngtDKhHjK+8+lvG1lQVOtq8oC8MGYELe+vhwW771wDx3T7v7auAVgkLRhbvf6e6z3X12ff2+nTYi0v/OnTGSimScB55bX+wog0LHhEhNrekD7zgARFkYFgJTzWyymaWAK4F53fZ5kKC1gJmNIDi0tCrCTCLSS0PLk5xx1HCefG1bs
aMMCjPHBkNfrNuxp8hJDi6ywuDuGeBjwMPAcuA+d19mZl8ys0vD3R4GtpnZS8AfgU+7uz6FIgPEaVOGs2pr8z4TyUjfHTE8OET0zmPHFDnJwUXax+Du8919mrsf6e63hNtucvd54bKHE//MdPdj3f3eKPOISN9cOmss8Zjxi2d1OOlQuUN1WYJbLz+u2FEOqtidzyIygI0cWs45R9cz99n1RR27ZzBIZ3PUDy0jVcThtHtr4CcUkaK66pSJNDS17TNhvfRNOpsjVcSL1vqiNFKKSNGcO30kI4aU8adXtxY7SklLZ51E3Iodo1dUGETkgMyM86aP5JGXNrMjnGFM+i6dzRV1mIu+KI2UIlJU7z1xHK3pHL9asrHYUUqSu7Nx5x7KE8Wbx7kvVBhE5KCOGz+MqlScb/9xpeaEfhMeWrqJ1xqa2bq7NGbGU2EQkYOqSMX58rvfwhuNbSzRPA199sobwXUgU+qrDrLnwKDCICK9csZRIwBYtlGFoa+OqAsubrvh4hlFTtI7Kgwi0isjq8sYMSTFso2NxY5ScjK54BqQUriGAVQYRKSXzIyZY2t4cYNaDH2VzgZDlydjOl1VRAaZ06cM5+XNTby8Wa2GvsiGc1okdLqqiAw2Hzh5Asm48YtFGjupL9LZ4FBSXC0GERls6qpSXDBzFHOfXc/OFl3s1luZsMWQ1JXPIjIYffy8qTS2pvnhE5oPurcyYYshESuNr9zSSCkiA8b00UM5YcIwfrf8jWJHKRkbd7UCkNChJBEZrM6fOYqlG3axpam12FFKwk+ffh2AmAqDiAxWs8YPA2DlG5rZbTBSYRCRPjtq5BAAluqahl45fcpwTp5UW+wYvabCICJ9NrK6jOmjq/n98i3FjlIScu7ErDQOI4EKg4i8CWbG26bV88ya7exqSRc7zoCnwiAih4VjxtUA8LkHlxY5ycCX89K5uA1UGETkTXr7MaMAWLe9pchJBr5szimhBkO0hcHMLjKzFWa20sxu6OH+a82swcxeCG/XRZlHRPpPWSLO3805kmUbG3UV9EG4u1oMAGYWB24HLgZmAleZ2cwedv25ux8f3u6OKo+I9L8LjxlNNuc88pIudjuQrPoYOp0CrHT3Ve7eDtwLXBbh64lIgR03roYRQ1J8Zu4SHn2lodhxBqxcDhWG0DhgXd76+nBbd5eb2RIzm2tmE3p6IjO73swWmdmihgZ9+EQGiljM+NoVswC467FVRU4zcAVnJRU7Re8Vu/P5l8Akdz8O+C3ww552cvc73X22u8+ur68vaEARObBzpo/komNGs26HOqH3J6c+hk4bgPwWwPhwWyd33+bubeHq3cBJEeYRkYjMnlTL2m0trNyiITJ6ks2pj6HDQmCqmU02sxRwJTAvfwczG5O3eimwPMI8IhKRUycPB+C/f/dKkZMMTO6lM4AeRFgY3D0DfAx4mOAL/z53X2ZmXzKzS8PdPm5my8xsMfBx4Nqo8ohIdKaOCsZO+tWSTTz52rYipxl4supj2Mvd57v7NHc/0t1vCbfd5O7zwuUb3f0Yd5/l7ue4+8tR5hGRaJQn4zz492cA8KMn1xQ1y0CUcyeuQ0kicrg5fsIwLp01ViOu9iCXC8aXKhUqDCLSb46sH8L6HXvYurvt4DsfRoKzkoqdovdKKKqIDHTHjQ8G1vv10k1FTjKw6KwkETlszTm6nrqqFM+v21nsKANKzkvrUFKi2AFEZPAwM+ZMq+eB54JLlm5736yS+kKMSiaXI1FCpyWpxSAi/epj5x4FwAPPbeAVzQlNOptjZ0uauqpUsaP0mgqDiPSrKfVDuOKk8YDmagBYv2MPACOHlhU5Se+pMIhIv/v8JTMAWL21uchJiu+cry8AYGR1eXGD9IEKg4j0u9rwsMkt85fTms4WOc3AMG5YRbEj9JoKg4hE6p4/ryl2hKLJZHOdyzPGVBcxSd+oMIhIJBZ8ag4Av3nx8L2mYdOu1s7lUjo7S4VBRCIxaUQVHz9vKovX7
2JLY+vBHzAItbQHh9G+/r5ZRU7SNyoMIhKZ06bUAfDfv3+1yEmK45nVwUizw0voVFVQYRCRCJ0+ZThHDK9k0ZrtxY5SFF/4v2UAtGVyB9lzYFFhEJHImBnvOm4srzU005Y5/M5OuvzE4HqOc6ePLHKSvlFhEJFIHT26mmzOeW3L4XdNQ86d8bUVpBKl9VVbWmlFpOQcPTo4TXPFG41FTlJ4jXvSDC1PFjtGn6kwiEikJo+oAuCTP19MLudFTlMYL29u5Kb/e5Gde9JUl5feWKWll1hESkoyHmNsTTkbd7Xy1KptvPWoEcWOFJnF63ZywwNLWb11N63p0upwzqcWg4hE7uFPng3A4yu3FjlJV+u2tzDphof4y+8/c0jPs6WxlbsfX8Vlt/+Z5ZsaS7oogFoMIlIA1eVJpo+uZvmmgdXP0DFcx2OvNPDihl28ZVzNm3qe6360iCXre57r+o6/OPHNxiuaSFsMZnaRma0ws5VmdsMB9rvczNzMZkeZR0SKZ+aYoSxY0TAgBtXb057lgv98lO//eXXnts/ev6TPz/Pihl28vq2FVQ17z7j68BmTADjzqBEkYsYxY4cect5Ci6zFYGZx4HbgAmA9sNDM5rn7S932qwY+ATwdVRYRKb5jx9fwwPMbWLRmB2dOLW4/w6fmLubVLV0nEVq2sZHnX9/BCRNre/Ucreks7/zWn7ps+5d3zOC6s6bwuUtmEDdj557SmqCnQ5QthlOAle6+yt3bgXuBy3rY78vAvwOH52AqIoeJd80aC8CLG3s+5FJIDy3pOrDfO48b07l9T3uWrbvbDvocHdOXdjh7Wj3XnTUFCDrcYzEryaIA0fYxjAPW5a2vB07N38HMTgQmuPtDZvbpCLOISJGNGFLGmJpybv31y7yyuYl/e++xlCfjBc/RffKg737oJN5+zGi2Nz/F3X9azd1/Cg4vff6SGVx31uT9jor62CsNAHz8vKm0Z3JcfcrEaIMXUNHOSjKzGPCfwD/3Yt/rzWyRmS1qaGiIPpyIROIjZ04G4IHnNzB/aXGG43781eA7ZN7HzuBv3jaFc44OhqvoGL6iwy3zl/OZuUu6zKmQ75U3mjhlUh3/dME0brh4OhOHV0YbvICiLAwbgAl56+PDbR2qgbcAC8xsDXAaMK+nDmh3v9PdZ7v77Pr6+ggji0iUOgoDwPJNjbz6RhPH3PQbrr3nGdwLc/Hbk69tY2R1GceOq+HGi2d0Dlfx3hPHceXJE7rs+4tn13PU53/N86/v6LJ9/tJNrNrazDODdHDAKAvDQmCqmU02sxRwJTCv40533+XuI9x9krtPAp4CLnX3RRFmEpEiMjOe+dx51FYmuevx1Vx2+59pbs+yYEUD196zsMfHNDS1saYf545+evV2zp5Wv88hIjPj1suPY82t72DNre9g9hF7O6G/8tDyzuXfvvQGf/eT5wD2KSSDRWSFwd0zwMeAh4HlwH3uvszMvmRml0b1uiIysI0cWs7pRw4HgolsOr6fH32lgRsfWMITr+29CK6xNc2F//Uoc76+gAUrthzya89bvJHtze0cUXfwwz5z//atnBeOipqMByHbMzk+PXdx5z43X3rMIWcaiCLtY3D3+e4+zd2PdPdbwm03ufu8Hvado9aCyOHhq+85rnP5Uxce3XkR2M+eWcfVdz3NSxsbWbp+F8fd/Ag7WtIAXP+jZw/5dT87N7hWYVgvzxb63rUnc/3ZU1i4Zget6Sx/95Pn2NmS5sNnTOK5L1xQlM7zQtCVzyJScDWVSeYcXc+CFQ2cfuRwjh1XQ11Viu3N7QBc8s3H93lMXVUKdz+kuZNnjh3Ks2t38L6Txh9859DxE4aRzTmvvrGbJ8PWzGfePp2K1OAsCqCxkkSkSO659mRWfOUiTpxYSzIe44kbzuXXnzhrn/0unTWWr7z7LWxubGXttpaDPu99C9dx9V1PsSMsMvkyOefsafV9+qV/bDhMxqK122luz5XFOCMAAAtxSURBVDJrw
rBBXRRAhUFEisTMKEvs/YItT8aZMWYod/9l1xMT/+sDxzNr/DAArr3nwIPd7WpJ85n7l/DEa9s44cu/5f9e2MDmXa1sCy9Ya3wTw2BPCPsj/vWXwaANJ0wY1qfHlyIVBhEZUM6fOYqlN1/IzDFDee+J44jHjEkjgi/nNdtaeKOx50ESHn2lgVlfeqTLtk/c+wKnffX3nPSV39HclmH11mYmD6/qe6YZe6fmPGGiCoOISMFVlyeZ/4mz+M/3H9+5/p0PBh3US/NGMf3Dy2/w0R8/y3m3LeCavKGzf3Ldqcw5uus1T8d88WEApo4a0uc83/ngSZw8qZaaiiSnTK7r8+NLjTqfRaQknD2tnpjB0g27OH/mKAC+8btXWZxXKP7lHTP4yJnBMBZnHDWCJ17bytV3dR2fc87RI+mrRDzGLz761kP7A0qIWgwiUhIqUwmmjqxm6Ya9haAts3e4iqtPnch1Z03pctbSW48cwZpb39HZ2rj8xPHUVJTeHMyFphaDiJSMt4yr4dFXtvAfv3mZEyfWsm57cJbShTNH8Y/nTd3v495+zGhue98szp6mIXV6Q4VBRErGjDHV3P/ceu5Y8Frnts9fMoO/PnvKAR9nZlzeh2sXDnc6lCQiJWP66H1nQ5tZgjOkDXRqMYhIyTjjqOF87pLp3LtwHadMqiMWM06eNPjPEio0FQYRKRlmxvVnH8n1Zx9Z7CiDmg4liYhIFyoMIiLShQqDiIh0ocIgIiJdqDCIiEgXKgwiItKFCoOIiHShwiAiIl2Yuxc7Q5+YWQOwNlytAXYdYLnjvyOArW/i5fKfsy/3d99+oPXuWfO3vZnch3Pm/eXc33IxPh+lmLmn7aX4+ehN5vzlwfb5OMLdezeKoLuX7A2480DLef9ddKjP35f7u28/0Hr3rIea+3DO3JvPRLE/H6WYebB8PnqTudjvdbE/Hx23Uj+U9MuDLOdvO9Tn78v93bcfaL2nrIeS+3DO3H3bQPx8lGLmnraX4uejN5nzlw/HzwdQgoeS3gwzW+Tusw++58BSirmVuTBKMTOUZu7DMXOptxh6685iB3iTSjG3MhdGKWaG0sx92GU+LFoMIiLSe4dLi0FERHpJhUFERLpQYRARkS4O+8JgZmeZ2XfM7G4ze6LYeXrDzGJmdouZfcvMril2nt4yszlm9nj4fs8pdp7eMrMqM1tkZu8sdpbeMLMZ4Xs818z+tth5esPM3m1md5nZz83swmLn6S0zm2Jm3zOzucXOciDhZ/iH4Xv8Fwfbv6QLg5l938y2mNmL3bZfZGYrzGylmd1woOdw98fd/aPAr4AfRpk3zHbImYHLgPFAGlgfVdZ8/ZTbgd1AOQXI3U+ZAT4L3BdNyq766TO9PPxMvx84I8q8Ybb+yPygu/818FHgA1HmzcvXH7lXuftHok3asz7mfy8wN3yPLz3okx/K1XHFvgFnAycCL+ZtiwOvAVOAFLAYmAkcS/Dln38bmfe4+4DqUsgM3AD8TfjYuaXyXgOx8HGjgJ+USOYLgCuBa4F3lkLm8DGXAr8Gri6VzOHjbgNOLJXPdN7jCvLv8BDy3wgcH+7z04M9d4IS5u6PmdmkbptPAVa6+yoAM7sXuMzdvwr0eCjAzCYCu9y9KcK4QP9kNrP1QHu4mo0u7V799V6HdgBlUeTM10/v9RygiuAf1x4zm+/uuYGcOXyeecA8M3sI+GlUecPX6o/32YBbgV+7+3NR5u3Qz5/pgutLfoIW+njgBXpxpKikC8N+jAPW5a2vB049yGM+AtwTWaKD62vmB4BvmdlZwGNRBjuIPuU2s/cCbweGAd+ONtp+9Smzu38ewMyuBbZGWRQOoK/v8xyCQwdlwPxIk+1fXz/T/wCcD9SY2VHu/p0owx1AX9/r4cAtwAlmdmNYQIppf/m/CXzbzN5BL4bNGIyFoc/c/YvFztAX7t5CUMxKirs/QFDUSo67/6DYGXrL3RcAC4oco0/c/ZsEX14lxd23EfSLDGju3
gx8uLf7l3Tn835sACbkrY8Ptw1kpZgZSjO3MhdGKWaG0s3doV/yD8bCsBCYamaTzSxF0HE4r8iZDqYUM0Np5lbmwijFzFC6uTv0T/5C96T3c6/8z4BN7D1t8yPh9kuAVwh65z9f7JylnrlUcyuzMg/G3IXIr0H0RESki8F4KElERA6BCoOIiHShwiAiIl2oMIiISBcqDCIi0oUKg4iIdKHCIJEzs90FeI1LezmEdn++5hwze+ubeNwJZva9cPlaMyvWuFFdmNmk7kM497BPvZn9plCZpDhUGKRkmFl8f/e5+zx3vzWC1zzQeGJzgD4XBuBzlOC4QADu3gBsMrPI53mQ4lFhkIIys0+b2UIzW2Jm/5q3/UEze9bMlpnZ9Xnbd5vZbWa2GDjdzNaY2b+a2XNmttTMpof7df7yNrMfmNk3zewJM1tlZleE22NmdoeZvWxmvzWz+R33dcu4wMz+28wWAZ8ws3eZ2dNm9ryZ/c7MRoXDHX8U+KSZvWDBTID1ZnZ/+Pct7OnL08yqgePcfXEP900ysz+E783vw+HgMbMjzeyp8O/9Sk8tMAtm6HrIzBab2Ytm9oFw+8nh+7DYzJ4xs+rwdR4P38Pnemr1mFnczL6W9//qb/LufhA46CxgUsKKfVm3boP/BuwO/3shcCdgBD9KfgWcHd5XF/63AngRGB6uO/D+vOdaA/xDuPx3wN3h8rXAt8PlHwC/CF9jJsH49ABXEAxDHQNGE8wLcUUPeRcAd+St10LnKAHXAbeFyzcDn8rb76fAmeHyRGB5D899DnB/3np+7l8C14TLfwU8GC7/CrgqXP5ox/vZ7XkvB+7KW68hmKhlFXByuG0owYjKlUB5uG0qsChcnkQ46QtwPfAv4XIZsAiYHK6PA5YW+3OlW3Q3DbsthXRheHs+XB9C8MX0GPBxM3tPuH1CuH0bwURE93d7no6hu58lmHegJw96MHfCS2Y2Ktx2JvCLcPtmM/vjAbL+PG95PPBzMxtD8GW7ej+POR+YaWYd60PNbIi75//CHwM07Ofxp+f9PT8G/iNv+7vD5Z8CX+/hsUuB28zs34FfufvjZnYssMndFwK4eyMErQuCsfmPJ3h/p/XwfBcCx+W1qGoI/p+sBrYAY/fzN8ggoMIghWTAV939u102BhPLnA+c7u4tZraAYF5ogFZ37z5LXVv43yz7/wy35S3bfvY5kOa85W8B/+nu88KsN+/nMTHgNHdvPcDz7mHv39Zv3P0VMzuRYAC1r5jZ74H/t5/dPwm8AcwiyNxTXiNomT3cw33lBH+HDFLqY5BCehj4KzMbAmBm48xsJMGv0R1hUZgOnBbR6/8ZuDzsaxhF0HncGzXsHdP+mrztTUB13vojBDORARD+Iu9uOXDUfl7nCYJhkiE4hv94uPwUwaEi8u7vwszGAi3u/r/A1wjmAl4BjDGzk8N9qsPO9BqClkQO+BDBPMHdPQz8rZklw8dOC1saELQwDnj2kpQ2FQYpGHd/hOBQyJNmthSYS/DF+hsgYWbLCeb9fSqiCPcTDE/8EvC/wHPArl487mbgF2b2LLA1b/svgfd0dD4DHwdmh521L9HDzF7u/jLB9JXV3e8jKCofNrMlBF/Ynwi3/yPwT+H2o/aT+VjgGTN7Afgi8BV3bwc+QDAN7GLgtwS/9u8Argm3Tadr66jD3QTv03PhKazfZW/r7BzgoR4eI4OEht2Ww0rHMX8L5up9BjjD3TcXOMMngSZ3v7uX+1cCe9zdzexKgo7oyyINeeA8jwGXufuOYmWQaKmPQQ43vzKzYQSdyF8udFEI/Q/wvj7sfxJBZ7EBOwnOWCoKM6sn6G9RURjE1GIQEZEu1McgIiJdqDCIiEgXKgwiItKFCoOIiHShwiAiIl2oMIiISBf/HyqJdEDecgYQAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learner.lr_find()\n", "learner.lr_plot()" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "begin training using triangular learning rate policy with max lr of 0.005...\n", "Train on 25000 samples, validate on 25000 samples\n", "Epoch 1/2\n", "25000/25000 [==============================] - 6s 241us/step - loss: 0.5540 - acc: 0.7531 - val_loss: 0.4021 - val_acc: 0.8615\n", "Epoch 2/2\n", "25000/25000 [==============================] - 6s 240us/step - loss: 0.3062 - acc: 0.8922 - val_loss: 0.2988 - val_acc: 0.8853\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learner.autofit(0.005, 2)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "predictor = ktrain.get_predictor(learner.model, preproc)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "data = [ 'This movie was horrible! The plot was boring. Acting was okay, though.',\n", " 'The film really sucked. I want my money back.',\n", " 'What a beautiful romantic comedy. 
10/10 would see again!']" ] },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['neg', 'neg', 'pos']"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predictor.predict(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tempfile\n",
    "\n",
    "# One definition of the storage location, built from the platform's temp\n",
    "# directory instead of a hard-coded POSIX path; the next cell loads from it.\n",
    "PREDICTOR_PATH = os.path.join(tempfile.gettempdir(), 'mypred')\n",
    "predictor.save(PREDICTOR_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload from disk to show that the saved predictor works on its own.\n",
    "predictor = ktrain.load_predictor(PREDICTOR_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['neg']"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predictor.predict(['The plot had lots of holes and did not make sense'])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}