{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%reload_ext autoreload\n", "%autoreload 2\n", "%matplotlib inline\n", "import os\n", "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using TensorFlow backend.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "using Keras version: 2.2.4\n" ] } ], "source": [ "import ktrain\n", "from ktrain import text" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Building a Chinese-Language Sentiment Analyzer\n", "\n", "In this notebook, we will build a Chinese-language text classification model in 4 simple steps. More specifically, we will build a model that classifies Chinese hotel reviews as either positive or negative.\n", "\n", "The dataset can be downloaded from Chengwei Zhang's GitHub repository [here](https://github.com/Tony607/Chinese_sentiment_analysis/tree/master/data/ChnSentiCorp_htl_ba_6000).\n", "\n", "(**Disclaimer:** I don't speak Chinese. Please forgive mistakes.) " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## STEP 1: Load and Preprocess the Data\n", "\n", "First, we use the `texts_from_folder` function to load and preprocess the data. We assume that the data is in the following form:\n", "```\n", " ├── datadir\n", " │ ├── train\n", " │ │ ├── class0 # folder containing documents of class 0\n", " │ │ ├── class1 # folder containing documents of class 1\n", " │ │ ├── class2 # folder containing documents of class 2\n", " │ │ └── classN # folder containing documents of class N\n", "```\n", "We set `val_pct` to 0.1, which automatically holds out 10% of the data for validation. We specify `preprocess_mode='standard'` to employ normal text preprocessing. 
If you are using the BERT model (i.e., 'bert'), you should use `preprocess_mode='bert'`.\n", "\n", "**Notice that there is nothing special or extra we need to do here for non-English text.** *ktrain* automatically detects the language and character encoding, prepares the data, and configures the model appropriately.\n", "\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "detected encoding: GB18030\n", "Decoding with GB18030 failed 1st attempt - using GB18030 with skips\n", "skipped 104 lines (0.3%) due to character decoding errors\n", "skipped 14 lines (0.4%) due to character decoding errors\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Building prefix dict from the default dictionary ...\n", "WARNING: Logging before flag parsing goes to stderr.\n", "I1001 15:11:00.816586 140470484846400 __init__.py:111] Building prefix dict from the default dictionary ...\n", "Loading model from cache /tmp/jieba.cache\n", "I1001 15:11:00.818966 140470484846400 __init__.py:131] Loading model from cache /tmp/jieba.cache\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "language: zh-cn\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Loading model cost 0.641 seconds.\n", "I1001 15:11:01.459813 140470484846400 __init__.py:163] Loading model cost 0.641 seconds.\n", "Prefix dict has been built succesfully.\n", "I1001 15:11:01.461843 140470484846400 __init__.py:164] Prefix dict has been built succesfully.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Word Counts: 22066\n", "Nrows: 5324\n", "5324 train sequences\n", "Average train sequence length: 81\n", "x_train shape: (5324,75)\n", "y_train shape: (5324,2)\n", "592 test sequences\n", "Average test sequence length: 85\n", "x_test shape: (592,75)\n", "y_test shape: (592,2)\n" ] } ], "source": [ "(x_train, y_train), (x_test, y_test), preproc = 
text.texts_from_folder('data/ChnSentiCorp_htl_ba_6000', \n", " maxlen=75, \n", " max_features=30000,\n", " preprocess_mode='standard',\n", " train_test_names=['train'],\n", " val_pct=0.1,\n", " classes=['pos', 'neg'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## STEP 2: Create a Model and Wrap in Learner Object" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Is Multi-Label? False\n", "compiling word ID features...\n", "maxlen is 75\n", "done.\n" ] } ], "source": [ "model = text.text_classifier('fasttext', (x_train, y_train), preproc=preproc)\n", "learner = ktrain.get_learner(model, \n", " train_data=(x_train, y_train), \n", " val_data=(x_test, y_test), \n", " batch_size=32)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## STEP 3: Estimate the LR\n", "We'll use the *ktrain* learning rate finder to find a good learning rate to use with *fasttext*. From the resulting plot of loss against learning rate, we select the highest learning rate at which the loss is still falling.\n" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "simulating training for different learning rates... 
this may take a few moments...\n", "Epoch 1/1024\n", "5324/5324 [==============================] - 2s 466us/step - loss: 0.9928 - acc: 0.5173\n", "Epoch 2/1024\n", "5324/5324 [==============================] - 2s 308us/step - loss: 1.0088 - acc: 0.5011\n", "Epoch 3/1024\n", "5324/5324 [==============================] - 2s 324us/step - loss: 0.9870 - acc: 0.5066\n", "Epoch 4/1024\n", "5324/5324 [==============================] - 2s 314us/step - loss: 0.9727 - acc: 0.5116\n", "Epoch 5/1024\n", "5324/5324 [==============================] - 2s 319us/step - loss: 0.8829 - acc: 0.5406\n", "Epoch 6/1024\n", "5324/5324 [==============================] - 2s 309us/step - loss: 0.6585 - acc: 0.6597\n", "Epoch 7/1024\n", "5324/5324 [==============================] - 2s 314us/step - loss: 0.5113 - acc: 0.7607\n", "Epoch 8/1024\n", "5324/5324 [==============================] - 2s 309us/step - loss: 0.4962 - acc: 0.7746\n", "Epoch 9/1024\n", "5324/5324 [==============================] - 2s 318us/step - loss: 0.6645 - acc: 0.5920\n", "Epoch 10/1024\n", "5324/5324 [==============================] - 2s 325us/step - loss: 0.7151 - acc: 0.4985\n", "Epoch 11/1024\n", "5324/5324 [==============================] - 2s 317us/step - loss: 0.8465 - acc: 0.5015\n", "Epoch 12/1024\n", " 416/5324 [=>............................] 
- ETA: 1s - loss: 2.3385 - acc: 0.5048\n", "\n", "done.\n" ] }, { "data": { "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learner.lr_find(show_plot=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## STEP 4: Train the Model\n", "\n", "We will use the `fit_onecycle` method that employs a [1cycle learning rate policy](https://arxiv.org/pdf/1803.09820.pdf) for 10 epochs (i.e., roughly 20 seconds)." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "begin training using onecycle policy with max lr of 0.005...\n", "Train on 5324 samples, validate on 592 samples\n", "Epoch 1/10\n", "5324/5324 [==============================] - 2s 356us/step - loss: 0.7315 - acc: 0.6409 - val_loss: 0.4885 - val_acc: 0.7669\n", "Epoch 2/10\n", "5324/5324 [==============================] - 2s 352us/step - loss: 0.4666 - acc: 0.7855 - val_loss: 0.3647 - val_acc: 0.8530\n", "Epoch 3/10\n", "5324/5324 [==============================] - 2s 353us/step - loss: 0.3553 - acc: 0.8492 - val_loss: 0.3181 - val_acc: 0.8750\n", "Epoch 4/10\n", "5324/5324 [==============================] - 2s 356us/step - loss: 0.2746 - acc: 0.8875 - val_loss: 0.3126 - val_acc: 0.8699\n", "Epoch 5/10\n", "5324/5324 [==============================] - 2s 349us/step - loss: 0.2424 - acc: 0.9031 - val_loss: 0.3129 - val_acc: 0.8801\n", "Epoch 6/10\n", "5324/5324 [==============================] - 2s 353us/step - loss: 0.2130 - acc: 0.9174 - val_loss: 0.2984 - val_acc: 0.8750\n", "Epoch 7/10\n", "5324/5324 [==============================] - 2s 352us/step - loss: 0.1643 - acc: 0.9378 - val_loss: 0.2843 - val_acc: 0.9020\n", "Epoch 8/10\n", "5324/5324 [==============================] - 2s 352us/step - loss: 0.1301 - acc: 0.9517 - val_loss: 0.2865 - val_acc: 0.9037\n", "Epoch 9/10\n", "5324/5324 [==============================] - 2s 362us/step - loss: 0.1019 - acc: 0.9592 - val_loss: 0.3035 - val_acc: 0.9037\n", "Epoch 10/10\n", 
"5324/5324 [==============================] - 2s 363us/step - loss: 0.0823 - acc: 0.9728 - val_loss: 0.3098 - val_acc: 0.9037\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learner.fit_onecycle(5e-3, 10)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " precision recall f1-score support\n", "\n", " neg 0.91 0.91 0.91 315\n", " pos 0.90 0.89 0.90 277\n", "\n", " accuracy 0.90 592\n", " macro avg 0.90 0.90 0.90 592\n", "weighted avg 0.90 0.90 0.90 592\n", "\n" ] }, { "data": { "text/plain": [ "array([[288, 27],\n", " [ 30, 247]])" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learner.validate(class_names=preproc.get_classes())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Inspecting Misclassifications" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "----------\n", "id:345 | loss:10.15 | true:pos | pred:neg)\n", "\n", "所谓山景房,就是非海景房而已,没有什么山景可言,海景房确实,有条件尽量选。只是这种房的窗帘边上拉不严,早上光线进来如同亮着灯一般,可能引发。另外窗外隔音不佳,如果呼呼明显,这想必也必不了了。\n" ] } ], "source": [ "learner.view_top_losses(n=1, preproc=preproc)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Using Google Translate, the above roughly translates to:\n", "```\n", "\n", "The so-called mountain view room is just a non-sea view room, there is no mountain view at all, the sea view room is indeed, there are conditions to choose as much as possible. It’s just that the curtains in this room are not pulled up. The morning light comes in like a lit lamp, which may be triggered. In addition, the sound insulation outside the window is not good. If the whirring is obvious, it must be no longer necessary.\n", "```\n", "\n", "Mistranslations aside, this is clearly a negative review. 
It appears to have been incorrectly assigned a ground-truth label of positive." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Making Predictions on New Data" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "p = ktrain.get_predictor(learner.model, preproc)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Predicting label for the text\n", "> \"*The view and service of this hotel were terrible and our room was dirty.*\"" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'neg'" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "p.predict(\"这家酒店的看法和服务都很糟糕,我们的房间很脏。\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Predicting label for:\n", "> \"*I like the service of this hotel.*\"" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'pos'" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "p.predict('我喜欢这家酒店的服务')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Saving Predictor for Later Deployment" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "p.save('/tmp/mypred')" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "p = ktrain.load_predictor('/tmp/mypred')" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'neg'" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# still works\n", "p.predict(\"这家酒店的风景和服务都非常糟糕\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", 
"version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" } }, "nbformat": 4, "nbformat_minor": 2 }