{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%reload_ext autoreload\n", "%autoreload 2\n", "%matplotlib inline\n", "import os\n", "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\";\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"; " ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using TensorFlow backend.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "using Keras version: 2.2.4\n" ] } ], "source": [ "import ktrain\n", "from ktrain import text" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Building a Chinese-Language Sentiment Analyzer\n", "\n", "In this notebook, we will build a Chinese-language text classification model in 4 simple steps. More specifically, we will build a model that classifies Chinese hotel reviews as either positive or negative.\n", "\n", "The dataset can be downloaded from Chengwei Zhang's GitHub repository [here](https://github.com/Tony607/Chinese_sentiment_analysis/tree/master/data/ChnSentiCorp_htl_ba_6000).\n", "\n", "(**Disclaimer:** I don't speak Chinese. Please forgive mistakes.) " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## STEP 1: Load and Preprocess the Data\n", "\n", "First, we use the `texts_from_folder` function to load and preprocess the data. We assume that the data is in the following form:\n", "```\n", " ├── datadir\n", " │ ├── train\n", " │ │ ├── class0 # folder containing documents of class 0\n", " │ │ ├── class1 # folder containing documents of class 1\n", " │ │ ├── class2 # folder containing documents of class 2\n", " │ │ └── classN # folder containing documents of class N\n", "```\n", "We set `val_pct` to 0.1, which automatically samples 10% of the data for validation. We specify `preprocess_mode='standard'` to employ normal text preprocessing. 
If you are using the BERT model (i.e., 'bert'), you should use `preprocess_mode='bert'`.\n", "\n", "**Notice that there is nothing special or extra we need to do here for non-English text.** *ktrain* automatically detects the language and character encoding, prepares the data, and configures the model appropriately.\n", "\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "detected encoding: GB18030\n", "Decoding with GB18030 failed 1st attempt - using GB18030 with skips\n", "skipped 107 lines (0.3%) due to character decoding errors\n", "skipped 11 lines (0.3%) due to character decoding errors\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Building prefix dict from the default dictionary ...\n", "WARNING: Logging before flag parsing goes to stderr.\n", "I1001 17:33:09.975814 140013155014464 __init__.py:111] Building prefix dict from the default dictionary ...\n", "Loading model from cache /tmp/jieba.cache\n", "I1001 17:33:09.978070 140013155014464 __init__.py:131] Loading model from cache /tmp/jieba.cache\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "language: zh-cn\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Loading model cost 0.652 seconds.\n", "I1001 17:33:10.629599 140013155014464 __init__.py:163] Loading model cost 0.652 seconds.\n", "Prefix dict has been built succesfully.\n", "I1001 17:33:10.631566 140013155014464 __init__.py:164] Prefix dict has been built succesfully.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Word Counts: 22388\n", "Nrows: 5324\n", "5324 train sequences\n", "Average train sequence length: 82\n", "Adding 3-gram features\n", "max_features changed to 457800 with addition of ngrams\n", "Average train sequence length with ngrams: 245\n", "x_train shape: (5324,100)\n", "y_train shape: (5324,2)\n", "592 test sequences\n", "Average test sequence length: 75\n", "Average test sequence 
length with ngrams: 183\n", "x_test shape: (592,100)\n", "y_test shape: (592,2)\n" ] } ], "source": [ "(x_train, y_train), (x_test, y_test), preproc = text.texts_from_folder('data/ChnSentiCorp_htl_ba_6000', \n", " maxlen=100, \n", " max_features=30000,\n", " preprocess_mode='standard',\n", " train_test_names=['train'],\n", " val_pct=0.1,\n", " ngram_range=3,\n", " classes=['pos', 'neg'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## STEP 2: Create a Model and Wrap in Learner Object" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Is Multi-Label? False\n", "compiling word ID features...\n", "maxlen is 100\n", "building document-term matrix... this may take a few moments...\n", "rows: 1-5324\n", "computing log-count ratios...\n", "done.\n" ] } ], "source": [ "model = text.text_classifier('nbsvm', (x_train, y_train) , preproc=preproc)\n", "learner = ktrain.get_learner(model, \n", " train_data=(x_train, y_train), \n", " val_data=(x_test, y_test), \n", " batch_size=32)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## STEP 3: Estimate the LR\n", "We'll use the *ktrain* learning rate finder to find a good learning rate to use with *nbsvm*. We will then select the highest learning rate at which the loss is still falling.\n" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "simulating training for different learning rates... 
this may take a few moments...\n", "Epoch 1/1024\n", "5324/5324 [==============================] - 1s 245us/step - loss: 0.6923 - acc: 0.5255\n", "Epoch 2/1024\n", "5324/5324 [==============================] - 1s 169us/step - loss: 0.6915 - acc: 0.5385\n", "Epoch 3/1024\n", "5324/5324 [==============================] - 1s 161us/step - loss: 0.6872 - acc: 0.6056\n", "Epoch 4/1024\n", "5324/5324 [==============================] - 1s 167us/step - loss: 0.6653 - acc: 0.7979\n", "Epoch 5/1024\n", "5324/5324 [==============================] - 1s 174us/step - loss: 0.5744 - acc: 0.9303\n", "Epoch 6/1024\n", "5324/5324 [==============================] - 1s 173us/step - loss: 0.3491 - acc: 0.9699\n", "Epoch 7/1024\n", "5324/5324 [==============================] - 1s 170us/step - loss: 0.1122 - acc: 0.9900\n", "Epoch 8/1024\n", "5324/5324 [==============================] - 1s 172us/step - loss: 0.0244 - acc: 0.9962\n", "Epoch 9/1024\n", "5324/5324 [==============================] - 1s 171us/step - loss: 0.0106 - acc: 0.9968\n", "Epoch 10/1024\n", "5324/5324 [==============================] - 1s 169us/step - loss: 0.0072 - acc: 0.9970\n", "Epoch 11/1024\n", " 352/5324 [>.............................] 
- ETA: 0s - loss: 0.0056 - acc: 0.9972 \n", "\n", "done.\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEKCAYAAAAfGVI8AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3deXxcdb3/8ddnJmvbpGvapum+UbpAS0MBWSwKWBZbkK2gWGQTFUG9v6tw9aJyr/eiXpcrolIQRQUK4hULVAooyFILTTdKWwpt6ZKu6b5mnc/vj5niECdt0ubkzGTez8djHj3bzLwzhHnnzJlzvubuiIhI9oqEHUBERMKlIhARyXIqAhGRLKciEBHJcioCEZEspyIQEclyOWEHaKkePXr4wIEDw44hIpJR5s+fv83dS1Kty7giGDhwIBUVFWHHEBHJKGa2tql1+mhIRCTLqQhERLJcoEVgZpPMbIWZrTSz21Os/5GZLUrc3jGzXUHmERGRfxbYMQIziwL3AucClcA8M5vp7ssObePuX07a/ovAuKDyiIhIakHuEUwAVrr7anevBWYAUw6z/VXAowHmERGRFIIsgjJgfdJ8ZWLZPzGzAcAg4K9NrL/JzCrMrKKqqqrVg4qIZLN0+froVOAJd29ItdLdpwPTAcrLy4/qutkbdx2kam8N+bkR9lXX06u4gLIuhZhBQ8ypjzkxd2Ien8+LRsjLiRCNGLGYUxeLETFL3MDMDvt87o47eGI65lAfixGNGFEzohE74mOIiLSFIItgA9Avab5vYlkqU4EvBJiFJxdt4HvPrmjx/Q69V6catiEaiZcCQMwhlnjzb46IQX5OlJyoUVMXIz83ggEFuVE65EXfX1fXkFRAETASz5kopIgZRuJfi+ctyI1SXJBLcWEOxQW5dMzPoWNelK4d86ipj73/mFEzCvKiFOXnkJ8TL76caISC3AhFBbmUdMonL0dfLBNp74IsgnnAMDMbRLwApgJXN97IzEYAXYG/B5iFT4zry+AenWiIOR3yo6zbfoBt+2owM3Ii8b/QD/21bga1DTHq6p36WAwD8nOj7/9l3xDzf0wn3vmT35RJ/GsWf+M2i6/PiUZoiDkNMae2PkZNfQN1DU5+boSauhjuTk19jIN1DdTUxd+wc6MRGg7tXXh8r8WJF8+hvY7YoeWJbDv217Jm2372VNez52Ad9bGjG3zIDHoW5dO7uIDOHfLIiRi5UaNbxzy6dcwjNxqhc2EuETM6F+bSrWMeZV0LGdCtAzlRFYhIpgisCNy93sxuAWYDUeBBd19qZncBFe4+M7HpVGCGBzxUWu/OBUzq3DvIp0hLh8plb3U9e6rryItGyM+Jl0vM4UBNPftrG6iua6C2PkZ9LMbB2hh7q+vYvKeayp0H2bq3ht0H62iIxaitjzF/7U527K+lqX6JGAzq0ZET+3VhbL8unNC3C8eXFpGfE23bH15EmsUybajK8vJy1yUm0kNDzNl1oBYHdh+sY/u+WtbtOMDa7ftZvmkvi9bvYtu+GgByo8boss6M79+V8QPit57FBeH+ACJZxMzmu3t5ynUqAgmKu7NxdzVvrt/FospdLFy7i8WVu6ipjwHQt2shJ/brwvCeRYzqU8zgko6UdS3UnoNIAA5XBOnyrSFph8yMsi6FlHUp5PwxpQDU1sdYtmkP89fuZMHanSyp3M0zb256/z6dC3O55tQBnD2iJyNLiynMUymIBE17BBK6/TX1LNu0h3XbDzB76WaeW7YFgMLcKB8b1YuLx5Vx5rASohF93VbkaOmjIcko63ccYMXmvfzl7a3MWrKJ3Qfr6NEpj9FlnfnYqN5MPK6E0s6FYccUySgqAslYNfUNvLBsK88t28xf397K3up6ohHjohNKOXlgN
8b268KoPsU6OU/kCFQE0i7UNcRYu30/D7++jsfmredAbfxE9GE9O/Ev5w1n0ujSkBOKpC8VgbQ77k7lzoO88u42fvP3Nby9eS+TRvXm7kvH0KVDXtjxRNLO4YpAp39KRjIz+nXrwNWn9GfmLWfw1UnH8Ze3t/CZX8/jQG192PFEMoqKQDJeXk6Ez08cyj1XjWPx+l3c8FAF1XUpr18oIimoCKTdmDS6lB9ccSJzVm3nlkcWUtcQCzuSSEZQEUi7csm4vvzHlFG8sHwLX35sEQ1HecE9kWyiM4ul3bnmtIEcqG3gv//8NtV1MX48dSyd8vWrLtIU7RFIu/TZDw/hzotG8uKKrVz/63nsq9EBZJGmqAik3brujEH88IoTqVi7k8n3vMrSjbvDjiSSllQE0q5NGVvG764/hf219Vx531yWVKoMRBpTEUi7d9qQ7jz5hdPpXJjL1Q/MZc6qbWFHEkkrKgLJCqWdC3n85tPoWZTPtAff4JV3q8KOJJI2VASSNcq6FPLYZ09jQPeO3Pzb+by1QR8TiYCKQLJMj075PHzDKXQuzOXSn8/h639cwsFanYUs2U1FIFmnV3EBj332NM4Z2YuHX1/HV//wps5ClqwWaBGY2SQzW2FmK83s9ia2ucLMlpnZUjN7JMg8Iof069aBe68+ia9OOo6nFm/ki48sJNOuxCvSWgIrAjOLAvcC5wMjgavMbGSjbYYBdwCnu/so4EtB5RFJ5fMTh/JvF4zg2aWbmfarecR0SQrJQkHuEUwAVrr7anevBWYAUxptcyNwr7vvBHD3rQHmEUnpxjMHc+awHrz8ThUf/eHf2Lq3OuxIIm0qyCIoA9YnzVcmliUbDgw3s9fMbK6ZTUr1QGZ2k5lVmFlFVZW+9iety8z4zXUTuPUjQ3lv234+ef/ruiSFZJWwDxbnAMOAicBVwP1m1qXxRu4+3d3L3b28pKSkjSNKNjAzvnLecfzXJWNYVbWPL81YqCuXStYIsgg2AP2S5vsmliWrBGa6e527vwe8Q7wYREJx9Sn9+dbkUbywfCtX3z+XvdV1YUcSCVyQRTAPGGZmg8wsD5gKzGy0zZPE9wYwsx7EPypaHWAmkSP69GkD+dqkEcxbs4PPP7xAXy2Vdi+wInD3euAWYDawHHjc3Zea2V1mNjmx2Wxgu5ktA14E/tXdtweVSaS5PjdxCHd/4gReeXcbn/jZHH21VNo1y7Rf8PLycq+oqAg7hmSJbzy5hN/NXce/XTCCm84aEnYckaNmZvPdvTzVurAPFouktTsvGsWI3kV879kVvFm5K+w4IoFQEYgcRl5OhMduOo0enfKZOn0uLyzbEnYkkVanIhA5gs4dcnnkxlPokBflht9UMGelxjOQ9kVFINIMg0s68fANp1JckMMdf1xCTb2uWCrth4pApJmO613EvZ88ibXbD/DvT74VdhyRVqMiEGmBM4eVcN3pg3i8olKjnEm7oSIQaaEvnTuM7h3z+PJji1m7fX/YcUSOmYpApIWKC3J5+MZTqK1v4KJ7XmXuap0DKZlNRSByFEb0LuaRG0+lpFM+n37wDTbuOhh2JJGjpiIQOUqjyzrz0HUTqK2P8adFG8OOI3LUVAQix6Bftw4M7tGRijU7wo4ictRUBCLH6NQh3Zm7eju19bpKqWQmFYHIMfrw8BL21zYwf+3OsKOIHBUVgcgxOn1oD3Iixt/e0XkFkplUBCLHqFN+DuUDu/LSiq1hRxE5KioCkVYw8bievL15L1v2VIcdRaTFVAQireDDw0sAePFt7RVI5lERiLSCEb2LKOtSyHMar0AykIpApBWYGR8b1ZtX393Gvpr6sOOItIiKQKSVTBrdm9qGmD4ekowTaBGY2SQzW2FmK83s9hTrrzWzKjNblLjdEGQekSCNH9CVHp3ymL10c9hRRFokJ6gHNrMocC9wLlAJzDOzme6+rNGmj7n7LUHlEGkr0Yhx7shezFy0keq6Bgpyo2FHEmmWIPcIJgAr3X21u
9cCM4ApAT6fSOg+Nqo3+2sbmLNK4xpL5giyCMqA9UnzlYlljV1qZm+a2RNm1i/VA5nZTWZWYWYVVVU6e1PS14eG9KAoP4dn39LHQ5I5wj5Y/BQw0N1PAJ4HHkq1kbtPd/dydy8vKSlp04AiLZGXE+Ejx/fk+WVbqG/QRegkMwRZBBuA5L/w+yaWvc/dt7t7TWL2AWB8gHlE2sSkUb3ZeaCOeWt0ETrJDEEWwTxgmJkNMrM8YCowM3kDMytNmp0MLA8wj0ib+PBxJRTkRnhmiQarkcwQWBG4ez1wCzCb+Bv84+6+1MzuMrPJic1uNbOlZrYYuBW4Nqg8Im2lQ14O5xzfi2fe3ESdPh6SDBDY10cB3H0WMKvRsjuTpu8A7ggyg0gYLh5bxtNvbuLld6r46PG9wo4jclhhHywWaZfOGl5Clw65PKmxjCUDqAhEApCXE+HCMaU8v2wz1XUNYccROSwVgUhAzh3Zi+q6GC9r5DJJcyoCkYCcPrQHZV0K+cXfVuHuYccRaZKKQCQgudEIN08cwoJ1u5izanvYcUSapCIQCdDl4/vSqzifn/zl3bCjiDRJRSASoILcKJ89awivv7eDN97bEXYckZRUBCIBu2pCf3p0yuOev2qvQNKTikAkYIV5UW44czCvvLuN5zRojaQhFYFIG7jm1AF065jHTb+dz57qurDjiHyAikCkDXTMz+HfLzoegBlvrAs5jcgHqQhE2sgl4/oytl8XHq+o1HkFklZUBCJt6Iryfqzcuo9HtFcgaURFINKGpp7cj+NLi/nZi6s0gpmkDRWBSBuKRIyvnDucDbsOMkvjGkuaUBGItLGPjujJ4JKOPPDK6rCjiAAqApE2F4kY004byJuVu1m6cXfYcURUBCJhuHhsGR3yojw0Z03YUURUBCJh6Nwhl0mje/PsW5uprddBYwlXoEVgZpPMbIWZrTSz2w+z3aVm5mZWHmQekXTy8RP6sKe6nt/PXx92FMlygRWBmUWBe4HzgZHAVWY2MsV2RcBtwOtBZRFJRx8eXsLA7h34xd/0VVIJV5B7BBOAle6+2t1rgRnAlBTb/QfwXaA6wCwiaScSMa4/czDrdxzkvpf1DSIJT5BFUAYk7/NWJpa9z8xOAvq5+zMB5hBJW5+c0J/hvTrx0Jw11NRrkHsJR2gHi80sAvwQ+JdmbHuTmVWYWUVVlQYCl/YjEjG+ceFItu6t4U+LNoYdR7JUkEWwAeiXNN83seyQImA08JKZrQFOBWamOmDs7tPdvdzdy0tKSgKMLNL2zhzWg+NLi/nx8++w+6AuUS1tL8gimAcMM7NBZpYHTAVmHlrp7rvdvYe7D3T3gcBcYLK7VwSYSSTtmBlf/dhxbNxdzf/MXhF2HMlCgRWBu9cDtwCzgeXA4+6+1MzuMrPJQT2vSCY6e0RPLhxTyqNvrOO9bfvDjiNZJtBjBO4+y92Hu/sQd/9OYtmd7j4zxbYTtTcg2ezOj48kEjEefPW9sKNIltGZxSJpoldxAWcfV8Jv567VNYikTakIRNLIFz8yDICbfzdfJ5lJm1ERiKSR0WWduffqk1i/Q+MVSNtpVhGY2W1mVmxxvzSzBWZ2XtDhRLLR+aN7M7RnJ259dCHb99WEHUeyQHP3CK5z9z3AeUBX4Brg7sBSiWSxSMS4/oxBAFzyszka6F4C19wisMS/FwC/dfelSctEpJVdUd6Pc47vxbodB5i5WGccS7CaWwTzzew54kUwO3HFUB3JEglINGJMv2b8+wPda69AgtTcIrgeuB042d0PALnAZwJLJSJEIsbVp/RnxZa9LNmgr5NKcJpbBKcBK9x9l5l9CvgGoN9MkYBNGduHovwcpusy1RKg5hbBz4EDZnYi8auFrgJ+E1gqEQGguCCXT546gFlLNvH25j1hx5F2qrlFUO/xDymnAD9193uJXz1URAJ2/RmDiEaMC3/yKgdrNWaBtL7mFsFeM7uD+NdGn0mMJZAbXCwROaSkKJ9vXDiShpjz8
Otrw44j7VBzi+BKoIb4+QSbiY8t8P3AUonIB0z70EBOG9yd+19ZrZHMpNU1qwgSb/4PA53N7CKg2t11jECkDd3ykaFs2VPDg6+uCTuKtDPNvcTEFcAbwOXAFcDrZnZZkMFE5INOH9qDDw8vYfrLq9hfUx92HGlHmvvR0NeJn0Mwzd0/DUwA/j24WCKSym3nDGPngTr+5zmNZCatp7lFEHH3rUnz21twXxFpJSf178pl4/vy0Jw1bNh1MOw40k409838WTObbWbXmtm1wDPArOBiiUhTbjprMA7c/ee3w44i7URzDxb/KzAdOCFxm+7uXwsymIikNrxXEZ/50CCeWryRddsPhB1H2oFmf7zj7n9w968kbn8MMpSIHN7Vp/QHYNZbm0JOIu3BYYvAzPaa2Z4Ut71mdsTz3c1skpmtMLOVZnZ7ivU3m9kSM1tkZq+a2chj+WFEssXQnp3o27WQpRt12Qk5djmHW+nuR30ZCTOLAvcC5wKVwDwzm+nuy5I2e8Tdf5HYfjLwQ2DS0T6nSDYZ0buYtzepCOTYBfnNnwnASndf7e61wAzi1yp6X2LUs0M6ArroukgzHV9axOpt+6mu05nGcmyCLIIyYH3SfGVi2QeY2RfMbBXwPeDWAPOItCvHlxbTEHNWbN4bdhTJcKGfC+Du97r7EOBrxMc5+CdmdpOZVZhZRVVVVdsGFElT4/p3AWDBup0hJ5FMF2QRbAD6Jc33TSxrygzg4lQr3H26u5e7e3lJSUkrRhTJXKWdC+nTuYAF63aFHUUyXJBFMA8YZmaDzCwPmArMTN7AzIYlzV4IvBtgHpF2Z9yArixYqz0COTaBFYG71wO3ALOB5cDj7r7UzO5KfEMI4BYzW2pmi4CvANOCyiPSHo3v35UNuw6yeXd12FEkgx3266PHyt1n0ehSFO5+Z9L0bUE+v0h7d9KArkD8OMEFY0pDTiOZKvSDxSJy9EaWFpOfE2G+Ph6SY6AiEMlgeTkRTuzbRUUgx0RFIJLhxg3owtKNu3VimRw1FYFIhhvfvyt1Dc6blbvDjiIZSkUgkuEmDOqGGby+envYUSRDqQhEMlyXDnkc16uIue+pCOToqAhE2oFTB3dn/tqd1NbHwo4iGUhFINIOnDq4G9V1MZZs0OUmpOVUBCLtwIRB3QGYu3pHyEkkE6kIRNqBbh3zGFlazEsrtoYdRTKQikCknZg0ujcVa3eyZY+uOyQtoyIQaScuGNMbd5i9dHPYUSTDqAhE2omhPYsY3qsTz7y5KewokmFUBCLtyKRRvZm3Zgdb9+rjIWk+FYFIOzJ5bBkxhycXHm4wQJEPUhGItCNDe3ZiXP8u/L6iEncPO45kCBWBSDtzRXk/3t26j0XrdXKZNI+KQKSdueiEUgpyI/x+fmXYUSRDqAhE2pmiglzOH13KU4s3aowCaRYVgUg7dHl5X/ZW1+ucAmmWQIvAzCaZ2QozW2lmt6dY/xUzW2Zmb5rZX8xsQJB5RLLFqYO6U9alkD8s0LeH5MgCKwIziwL3AucDI4GrzGxko80WAuXufgLwBPC9oPKIZJNIxLh4XB9efbdK5xTIEQW5RzABWOnuq929FpgBTEnewN1fdPcDidm5QN8A84hklUvGxc8peGqxzjSWwwuyCMqA9UnzlYllTbke+HOqFWZ2k5lVmFlFVVVVK0YUab+G9ixiTFlnnVwmR5QWB4vN7FNAOfD9VOvdfbq7l7t7eUlJSduGE8lgU8b2YcmG3azcui/sKJLGgiyCDUC/pPm+iWUfYGbnAF8HJrt7TYB5RLLO5LF96JgX5TvPLAs7iqSxIItgHjDMzAaZWR4wFZiZvIGZjQPuI14CGlFDpJX1LCrg82cP5cUVVSzduDvsOJKmAisCd68HbgFmA8uBx919qZndZWaTE5t9H+gE/N7MFpnZzCYeTkSO0mXj+5ITMb74yEJdf0hSygnywd19FjCr0bI7k6bPCfL5RQR6FRdw/RmDuO/l1Sxav4tx/buGH
UnSTFocLBaRYN3ykaEU5efwg+fe0V6B/BMVgUgWKCrI5f997DheXbmNv6/eHnYcSTMqApEsceXJ/SjKz+HRN9YfeWPJKioCkSxRkBvlipP7MWvJJrbt0ze15R9UBCJZ5KoJ/Ym5c+ujC8OOImlERSCSRYb27MSZw0qYs2o7K7fuDTuOpAkVgUiW+cHlJ1KQG2Hag/PYro+IBBWBSNYpKcrn558cz+Y91Xz7qWX6OqmoCESy0dkjevKljw5j5uKNPPuWRjHLdioCkSz1uYlDOK5XEf/15+Ua2zjLqQhEslRONMK/XzSS9TsO8uBr74UdR0KkIhDJYmcM68FHR/Tk5y+u0pCWWUxFIJLlvnLecGobYtz4m/nsqa4LO46EQEUgkuVG9enMj64cy5uVu7jt0YU0xPQtomyjIhARLhhTyl1TRvPiiir+a9ZyfaU0y6gIRASAT53Sn2mnDeCXr77HoDtm8b8vvKtCyBIqAhEBwMz41uRRnDG0BwA/euEdfvP3tSGnkrYQ6AhlIpJZzIz7rhnPSyuq+O3cNXxz5lIaYs4VJ/ejU77eLtor7RGIyAd0zM/hwhNKmf7pcsaUdeaup5dxwf++ws79tWFHk4CoCEQkpeKCXB66bgJTT+7Huh0H+NDdf+Xq++eyZY/ON2hvAi0CM5tkZivMbKWZ3Z5i/VlmtsDM6s3ssiCziEjLdeuYx92XnsBjN53KkJ4dmbNqO1N++hprt+8PO5q0osCKwMyiwL3A+cBI4CozG9los3XAtcAjQeUQkWN3yuDuPP3FM3nkxlPYsb+WH7/wbtiRpBUFuUcwAVjp7qvdvRaYAUxJ3sDd17j7m0AswBwi0ko+NKQH0z40gD8t2sCabdoraC+CLIIyIHmU7MrEshYzs5vMrMLMKqqqqlolnIgcnRvPHExOJML0V1aHHUVaSUYcLHb36e5e7u7lJSUlYccRyWo9iwv4xEllPDG/km0a4axdCLIINgD9kub7JpaJSIa78azB1DXE+MVLq8KOIq0gyCKYBwwzs0FmlgdMBWYG+Hwi0kaGlHTiknFl/O71tRr3uB0IrAjcvR64BZgNLAced/elZnaXmU0GMLOTzawSuBy4z8yWBpVHRFrX5ycOoaY+xv2vaFCbTBfoOePuPguY1WjZnUnT84h/ZCQiGWZozyI+fkIfHpqzhuvOGEjPooKwI8lRyoiDxSKSnr58bnxQmx89/07YUeQYqAhE5KgN6tGRy8f35dE31vPQnDVhx5GjpMsJisgx+faUUWzZU803Zy6lrEsh54zsFXYkaSHtEYjIMcnPifLzT41ndFkxt81YqDOOM5CKQESOWUFulOnXlJMTjXDbjIUcqK0PO5K0gIpARFpFny6FfPfSMSzZsJtrH5xHXYMuIZYpVAQi0momjS7lh1eM5Y01O7j6/rm6BEWGUBGISKu6eFwZ37jweBas28WnHnidtzfvCTuSHIGKQERa3Q1nDuZX157Mtn01TP7pa9w2Y6FGNktjKgIRCcRZw0uYdduZnDWshD8t2sjH73mVtzbsDjuWpKAiEJHA9Cwq4IFp5fzpC6eTEzGuuO/v/PXtLWHHkkZUBCISuBP7deGPXzidwSUdue7XFVzys9f406IN1OubRWnB3D3sDC1SXl7uFRUVYccQkaNwoLae381dy2Pz1rOqaj/9uhVyybi+TDttAN075Ycdr10zs/nuXp5ynYpARNpaLOY8v3wLv/jbKhau20Vu1BhT1plrTx/ERWNKiUQs5X1eW7WN8QO60iFPV8dpKRWBiKStlVv38ctX3+PRN9YBkBs1OubncP7o3owsLWb3wTrKuhbStUMe1/5qHh3yotzykaFcf8Yg8nOiIafPHCoCEUl7sZjzp8Ub+L8FG1i+aS/VdQ3sq/nnS1UUF+Swpzq+/OKxfbjhzMGM6lOM2T/vRcg/qAhEJOO4Oxt3V7Ntbw37auqZuWgjdbEY37/sRF5duY1fvvoec1dtpzZxwPny8X25vLwfJ
UX5DOrRMeT06UdFICLtUuXOAzz71mZ+9doaNuw6+P7y8gFdmXhcCceXFlPWtZC+XTvQKT+7jyuoCESk3du6p5o31uxg3Y4DPDx33QeK4ZDighx6FRdw1vASzhpeQsRg98E6BvfoRF1DjOG9iohGjNyovf9R076aejrmRTP+oycVgYhkFXdn855qllTuZvOearbuqWHl1n1s2VtNTsRYvH73+x8ppdIxL8qA7h2prm9gdVV8fIX+3Towtl8XRvUpZlivTnTrmM/a7fvp2iGPhet2sWTDLvJzovTr1oGyLgVEIsaoPp0Z1aeYqBkN7uRG46duVdc1cLA2fgxkx/5a9lTX0bdrBwZ065DyG1Ot4XBFEOi+kplNAv4XiAIPuPvdjdbnA78BxgPbgSvdfU2QmUSk/TMzSjsXUtq5MOX6PdV1zFm5ndyo0aVDHm9t2M22fTUU5EbZsb+Wnftr2b6/loO1DZx9XE/qG2Ks3rafl9+tYubijSkfs1vHPCIGzyzZ9IHlhblRahtiRM3okB9lb3U9DbHUf4BHI8bA7h3o3imfXQdq2by7mrKuHejaIZeDdQ3cdOZgzh9TemwvTgqBFYGZRYF7gXOBSmCemc1092VJm10P7HT3oWY2FfgucGVQmUREAIoLcpk0uvf78+MHdG32fbfsqWb5pj1s2VONO5QU5ROJGBOHl2Bm1NbHWLl1H7sP1rFjfy1vvLedDbsO0qdLIdv311JckENxQS77aurp06WQkqJ89lXXE40Y723bz7odB9i46yCGcfLAbhxI7DlEzCjMC+brskHuEUwAVrr7agAzmwFMAZKLYArwrcT0E8BPzcw80z6vEpGs0au4gF7FBU2uz8uJMLJP8fvzF57Q+n/Bt7YgrzVUBqxPmq9MLEu5jbvXA7uB7o0fyMxuMrMKM6uoqqoKKK6ISHbKiIvOuft0dy939/KSkpKw44iItCtBFsEGoF/SfN/EspTbmFkO0Jn4QWMREWkjQRbBPGCYmQ0yszxgKjCz0TYzgWmJ6cuAv+r4gIhI2wrsYLG715vZLcBs4l8ffdDdl5rZXUCFu88Efgn81sxWAjuIl4WIiLShQM8jcPdZwKxGy+5Mmq4GLg8yg4iIHF5GHCwWEZHgqAhERLJcxl1ryMyqgF3EzzmA+DeNGk83XpYLbGvhUyU/RnPXN16WKtuRMvdoYdZszZkqU2vkPFLWbMYIjDsAAAjLSURBVMiZKl+YOZvKFnbOVMvTOWcXd0/9/Xt3z7gbMP1w042XET84fdTP0dz1jZcdKWeqzC3Nmq05D/ff+1hyHilrNuRsIl9oOZvKFnbO5r6G6Ziz8S1TPxp66gjTTa0/2udo7vrGy46UM3laOY+8rKn51sx5pPtmQ87k6XTI2XhZuuRMtTxTcn5Axn00dDTMrMKbuPxqusmUrMrZupSzdSlny2TqHkFLTQ87QAtkSlblbF3K2bqUswWyYo9ARESali17BCIi0gQVgYhIllMRiIhkuawvAjM708x+YWYPmNmcsPM0xcwiZvYdM7vHzKYd+R7hMLOJZvZK4jWdGHaewzGzjokBjy4KO8vhmNnxidfzCTP7XNh5mmJmF5vZ/Wb2mJmdF3aeppjZYDP7pZk9EXaWxhK/kw8lXsdPttXzZnQRmNmDZrbVzN5qtHySma0ws5VmdvvhHsPdX3H3m4GngYfSNSfxYT37AnXER3tL15wO7AMK0jwnwNeAx4PImJSpNX5Hlyd+R68ATk/jnE+6+43AzQQ09ngr5Vzt7tcHkS+VFmb+BPBE4nWc3FYZW3RGW7rdgLOAk4C3kpZFgVXAYCAPWAyMBMYQf7NPvvVMut/jQFG65gRuBz6buO8TaZwzkrhfL+DhNM55LvHLnl8LXJTuv6PE3xT+DFydzjkT9/sBcFIG5Azk/6NjzHwHMDaxzSNtkc/dg70MddDc/WUzG9ho8QRgpbuvBjCzGcAUd/9vIOVHAGbWH9jt7nvTNaeZVQK1idmGdM2ZZ
CeQn645Ex9bdST+P99BM5vl7rF0zJp4nJnATDN7BngkHXOamQF3A3929wWtnbG1cra1lmQmvhfdF1hEG35ik9FF0IQyYH3SfCVwyhHucz3wq8ASpdbSnP8H3GNmZwIvBxmskRblNLNPAB8DugA/DTbaB7Qop7t/HcDMrgW2BVECh9HS13Qi8Y8M8mk0vkfAWvo7+kXgHKCzmQ11918EGS5JS1/P7sB3gHFmdkeiMNpaU5l/AvzUzC7k2C5D0SLtsQhazN2/GXaGI3H3A8QLK625+/8RL62M4O6/DjvDkbj7S8BLIcc4Inf/CfE3srTm7tuJH8dIO+6+H/hMWz9vRh8sbsIGoF/SfN/EsnSjnK0rU3JC5mRVzuCkVeb2WATzgGFmNsjM8ogfEJwZcqZUlLN1ZUpOyJysyhmc9MrcVkelAzoa/yiwiX98pfL6xPILgHeIH5X/unIqp7IqpzI3fdNF50REslx7/GhIRERaQEUgIpLlVAQiIllORSAikuVUBCIiWU5FICKS5VQEEjgz29cGzzG5mZeebs3nnGhmHzqK+40zs18mpq81s7a8JlOTzGxg40slp9imxMyebatM0jZUBJIxzCza1Dp3n+nudwfwnIe7HtdEoMVFAPwbGXBNnlTcvQrYZGaBjIsg4VARSJsys381s3lm9qaZfTtp+ZNmNt/MlprZTUnL95nZD8xsMXCama0xs2+b2QIzW2JmIxLbvf+XtZn92sx+YmZzzGy1mV2WWB4xs5+Z2dtm9ryZzTq0rlHGl8zsx2ZWAdxmZh83s9fNbKGZvWBmvRKXFb4Z+LKZLbL4SHclZvaHxM83L9WbpZkVASe4++IU6waa2V8Tr81fEpdHx8yGmNncxM/7n6n2sCw+stUzZrbYzN4ysysTy09OvA6LzewNMytKPM8riddwQaq9GjOLmtn3k/5bfTZp9ZNAm42eJW0g7NOvdWv/N2Bf4t/zgOmAEf8j5GngrMS6bol/C4G3gO6JeQeuSHqsNcAXE9OfBx5ITF8L/DQx/Wvg94nnGEn8uu8AlxG/jHME6E18zITLUuR9CfhZ0nxXeP8s/BuAHySmvwX8v6TtHgHOSEz3B5aneOyzgT8kzSfnfgqYlpi+DngyMf00cFVi+uZDr2ejx70UuD9pvjPxAU9WAycnlhUTv+JwB6AgsWwYUJGYHkhi8BTgJuAbiel8oAIYlJgvA5aE/XulW+vddBlqaUvnJW4LE/OdiL8RvQzcamaXJJb3SyzfTnwQnj80epxDl7meT/w6/ak86fExBpaZWa/EsjOA3yeWbzazFw+T9bGk6b7AY2ZWSvzN9b0m7nMOMNLMDs0Xm1knd0/+C74UqGri/qcl/Ty/Bb6XtPzixPQjwP+kuO8S4Adm9l3gaXd/xczGAJvcfR6Au++B+N4D8WvejyX++g5P8XjnASck7TF1Jv7f5D1gK9CniZ9BMpCKQNqSAf/t7vd9YGF84JVzgNPc/YCZvUR8zGOAandvPCJbTeLfBpr+Ha5JmrYmtjmc/UnT9wA/dPeZiazfauI+EeBUd68+zOMe5B8/W6tx93fM7CTiFzL7TzP7C/DHJjb/MrAFOJF45lR5jfie1+wU6wqI/xzSTugYgbSl2cB1ZtYJwMzKzKwn8b82dyZKYARwakDP/xpwaeJYQS/iB3ubozP/uFb8tKTle4GipPnniI/SBUDiL+7GlgNDm3ieOcQvRwzxz+BfSUzPJf7RD0nrP8DM+gAH3P13wPeJj5G7Aig1s5MT2xQlDn53Jr6nEAOuIT5+bmOzgc+ZWW7ivsMTexIQ34M47LeLJLOoCKTNuPtzxD/a+LuZLQGeIP5G+iyQY2bLiY95OzegCH8gfhngZcDvgAXA7mbc71vA781sPrAtaflTwCWHDhYDtwLliYOry0gxCpa7v018KMeixuuIl8hnzOxN4m/QtyWWfwn4SmL50CYyjwHeMLNFwDeB/3T3WuBK4kOcLgaeJ/7X/M+AaYllI/jg3
s8hDxB/nRYkvlJ6H//Y+zobeCbFfSRD6TLUklUOfWZv8XFr3wBOd/fNbZzhy8Bed3+gmdt3AA66u5vZVOIHjqcEGvLweV4mPjj8zrAySOvSMQLJNk+bWRfiB33/o61LIOHnwOUt2H488YO7Buwi/o2iUJhZCfHjJSqBdkR7BCIiWU7HCEREspyKQEQky6kIRESynIpARCTLqQhERLKcikBEJMv9fwj/lpLxJt0EAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learner.lr_find(show_plot=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## STEP 4: Train the Model\n", "\n", "We will use the `autofit` method that employs a triangular learning rate policy with EarlyStopping and ReduceLROnPlateau automatically enabled, since the epochs argument is omitted. We monitor `val_acc`, so weights from the epoch with the highest validation accuracy will be automatically loaded into our model when training completes.\n", "\n", "As shown in the cell below, our final validation accuracy is **92%** with only 7 seconds of training!" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "early_stopping automatically enabled at patience=5\n", "reduce_on_plateau automatically enabled at patience=2\n", "\n", "\n", "begin training using triangular learning rate policy with max lr of 0.007...\n", "Train on 5324 samples, validate on 592 samples\n", "Epoch 1/1024\n", "5324/5324 [==============================] - 1s 219us/step - loss: 0.3265 - acc: 0.8924 - val_loss: 0.2218 - val_acc: 0.9139\n", "Epoch 2/1024\n", "5324/5324 [==============================] - 1s 208us/step - loss: 0.0274 - acc: 0.9951 - val_loss: 0.2047 - val_acc: 0.9155\n", "Epoch 3/1024\n", "5324/5324 [==============================] - 1s 204us/step - loss: 0.0166 - acc: 0.9968 - val_loss: 0.2060 - val_acc: 0.9155\n", "Epoch 4/1024\n", "5324/5324 [==============================] - 1s 206us/step - loss: 0.0137 - acc: 0.9968 - val_loss: 0.2062 - val_acc: 0.9206\n", "Epoch 5/1024\n", "5324/5324 [==============================] - 1s 213us/step - loss: 0.0120 - acc: 0.9970 - val_loss: 0.2078 - val_acc: 0.9189\n", "Epoch 6/1024\n", "5324/5324 [==============================] - 1s 204us/step - loss: 0.0111 - acc: 0.9970 - val_loss: 0.2082 - val_acc: 0.9206\n", "\n", "Epoch 00006: Reducing Max 
LR on Plateau: new max lr will be 0.0035 (if not early_stopping).\n", "Epoch 7/1024\n", "5324/5324 [==============================] - 1s 211us/step - loss: 0.0103 - acc: 0.9970 - val_loss: 0.2090 - val_acc: 0.9206\n", "Weights from best epoch have been loaded into model.\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learner.autofit(7e-3, monitor='val_acc')" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " precision recall f1-score support\n", "\n", " neg 0.91 0.94 0.92 310\n", " pos 0.93 0.89 0.91 282\n", "\n", " accuracy 0.92 592\n", " macro avg 0.92 0.91 0.92 592\n", "weighted avg 0.92 0.92 0.92 592\n", "\n" ] }, { "data": { "text/plain": [ "array([[290, 20],\n", " [ 30, 252]])" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learner.validate(class_names=preproc.get_classes())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Inspecting Misclassifications" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "----------\n", "id:294 | loss:5.13 | true:neg | pred:pos)\n", "\n", "酒店 环境 还 不错 , 装修 也 很 好 。 早餐 不怎么样 , 价格 偏高 。\n" ] } ], "source": [ "learner.view_top_losses(n=1, preproc=preproc)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Using Google Translate, the above roughly translates to:\n", "```\n", "The hotel environment is not bad, the decoration is also very good. Breakfast is not good, the price is high.\n", "```\n", "\n", "This is a mixed review, but it is labeled only as negative. Our classifier is understandably confused and predicts positive for this review." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Making Predictions on New Data" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "p = ktrain.get_predictor(learner.model, preproc)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Predicting label for the text\n", "> \"*The view and service of this hotel were terrible and our room was dirty.*\"" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'neg'" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "p.predict(\"这家酒店的看法和服务都很糟糕,我们的房间很脏。\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Predicting label for:\n", "> \"*I like the service of this hotel.*\"" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'pos'" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "p.predict('我喜欢这家酒店的服务')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" } }, "nbformat": 4, "nbformat_minor": 2 }