{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%reload_ext autoreload\n", "%autoreload 2\n", "%matplotlib inline\n", "import os\n", "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\";\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"; " ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using TensorFlow backend.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "using Keras version: 2.2.4\n" ] } ], "source": [ "import ktrain\n", "from ktrain import text" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Building an Arabic Sentiment Analyzer\n", "\n", "In this notebook, we will build a simple, fast, and accurate Arabic-language text classification model in 4 simple steps. More specifically, we will build a model that classifies Arabic hotel reviews as either positive or negative.\n", "\n", "The dataset can be downloaded from Ashraf Elnagar's GitHub repository (https://github.com/elnagara/HARD-Arabic-Dataset).\n", "\n", "Each entry in the dataset includes a review in Arabic and a rating between 1 and 5. We will convert this to a binary classification dataset by assigning reviews with a rating above 3 a positive label of 1 and reviews with a rating below 3 a negative label of 0.\n", "\n", "(**Disclaimer:** I don't speak Arabic. Please forgive mistakes.) \n", "\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr><th></th><th>text</th><th>neg</th><th>pos</th></tr>\n", " </thead>\n", " <tbody>\n",
" <tr><th>0</th><td>“ممتاز”. النظافة والطاقم متعاون.</td><td>1</td><td>0</td></tr>\n",
" <tr><th>1</th><td>استثنائي. سهولة إنهاء المعاملة في الاستقبال. ل...</td><td>0</td><td>1</td></tr>\n",
" <tr><th>2</th><td>استثنائي. انصح بأختيار الاسويت و بالاخص غرفه ر...</td><td>0</td><td>1</td></tr>\n",
" <tr><th>3</th><td>“استغرب تقييم الفندق كخمس نجوم”. لا شي. يستحق ...</td><td>1</td><td>0</td></tr>\n",
" <tr><th>4</th><td>جيد. المكان جميل وهاديء. كل شي جيد ونظيف بس كا...</td><td>0</td><td>1</td></tr>\n",
" </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " text neg pos\n", "0 “ممتاز”. النظافة والطاقم متعاون. 1 0\n", "1 استثنائي. سهولة إنهاء المعاملة في الاستقبال. ل... 0 1\n", "2 استثنائي. انصح بأختيار الاسويت و بالاخص غرفه ر... 0 1\n", "3 “استغرب تقييم الفندق كخمس نجوم”. لا شي. يستحق ... 1 0\n", "4 جيد. المكان جميل وهاديء. كل شي جيد ونظيف بس كا... 0 1" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# convert ratings to a binary format: 1=positive, 0=negative\n", "import pandas as pd\n", "df = pd.read_csv('data/arabic_hotel_reviews/balanced-reviews.txt', delimiter='\\t', encoding='utf-16')\n", "df = df[['rating', 'review']] \n", "df['rating'] = df['rating'].apply(lambda x: 'neg' if x < 3 else 'pos')\n", "df.columns = ['label', 'text']\n", "df = pd.concat([df, df.label.astype('str').str.get_dummies()], axis=1, sort=False)\n", "df = df[['text', 'neg', 'pos']]\n", "df.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## STEP 1: Load and Preprocess the Data\n", "\n", "First, we use the `texts_from_df` function to load and preprocess the data into arrays that can be directly fed into a neural network model. \n", "\n", "We set `val_pct` to 0.1, which will automatically sample 10% of the data for validation. We specify `preprocess_mode='standard'` to employ normal text preprocessing. 
If you are using the BERT model (i.e., 'bert'), you should use `preprocess_mode='bert'`.\n", "\n", "**Notice that there is nothing special or extra we need to do here for non-English text.** *ktrain* automatically detects the language and character encoding and prepares the data and configures the model appropriately.\n", "\n" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "language: ar\n", "Word Counts: 151716\n", "Nrows: 95128\n", "95128 train sequences\n", "Average train sequence length: 23\n", "Adding 3-gram features\n", "max_features changed to 2689751 with addition of ngrams\n", "Average train sequence length with ngrams: 66\n", "x_train shape: (95128,75)\n", "y_train shape: (95128,2)\n", "10570 test sequences\n", "Average test sequence length: 22\n", "Average test sequence length with ngrams: 41\n", "x_test shape: (10570,75)\n", "y_test shape: (10570,2)\n" ] } ], "source": [ "(x_train, y_train), (x_test, y_test), preproc = text.texts_from_df(df, \n", " 'text', # name of column containing review text\n", " label_columns=['neg', 'pos'],\n", " maxlen=75, \n", " max_features=100000,\n", " preprocess_mode='standard',\n", " val_pct=0.1,\n", " ngram_range=3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## STEP 2: Create a Model and Wrap in Learner Object" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will employ a neural implementation of the [NBSVM](https://www.aclweb.org/anthology/P12-2018/)." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Is Multi-Label? False\n", "compiling word ID features...\n", "maxlen is 75\n", "building document-term matrix... 
this may take a few moments...\n", "rows: 1-10000\n", "rows: 10001-20000\n", "rows: 20001-30000\n", "rows: 30001-40000\n", "rows: 40001-50000\n", "rows: 50001-60000\n", "rows: 60001-70000\n", "rows: 70001-80000\n", "rows: 80001-90000\n", "rows: 90001-95128\n", "computing log-count ratios...\n", "done.\n" ] } ], "source": [ "model = text.text_classifier('nbsvm', (x_train, y_train) , preproc=preproc)\n", "learner = ktrain.get_learner(model, \n", " train_data=(x_train, y_train), \n", " val_data=(x_test, y_test), \n", " batch_size=32)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## STEP 3: Estimate the LR\n", "We'll use the *ktrain* learning rate finder to find a good learning rate to use with *nbsvm*. We will then select the highest learning rate at which the loss is still falling.\n" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "simulating training for different learning rates... this may take a few moments...\n", "Epoch 1/1024\n", "48736/95128 [==============>...............] 
- ETA: 8s - loss: 0.5553 - acc: 0.7843\n", "\n", "done.\n" ] }, { "data": { "text/plain": [ "[Figure: loss vs. learning rate plot produced by lr_find]" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learner.lr_find(show_plot=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## STEP 4: Train the Model\n", "\n", "We will use the `autofit` method, which employs a triangular learning rate policy, and train for 1 epoch.\n", "\n", "As shown in the cell below, our final validation accuracy is **94%** with only ~20 seconds of training over a single epoch!" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "begin training using triangular learning rate policy with max lr of 0.005...\n", "Train on 95128 samples, validate on 10570 samples\n", "Epoch 1/1\n", "95128/95128 [==============================] - 21s 219us/step - loss: 0.2227 - acc: 0.9343 - val_loss: 0.1754 - val_acc: 0.9427\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learner.autofit(5e-3, 1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Inspecting Misclassifications" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "----------\n", "id:5338 | loss:13.94 | true:neg | pred:pos)\n", "\n", "“ممتاز جدا وأشكر كل العاملين في الفندق ع حسن ممتاز جدا وانصح به من جميع النواحي نظافه وتعامل راقي ومواصلات الي الحرم ه مدار الساعه وإفطار ممتاز وأشكر موظفو الاستقبال ع حسن اسقبالهم مكان مميز وسعر مميز فقطالماء للتعريف\n" ] } ], "source": [ "learner.view_top_losses(n=1, preproc=preproc)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Using Google Translate, the above roughly translates to:\n", "> *The hotel is very clean and the staff are very friendly and helpful.*\n", "\n", "This is clearly a positive review and was predicted as positive by our model. 
The \"ground truth\" rating for this review is incorrect, which explains the high loss. Perhaps the customer accidentally entered a low rating." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Making Predictions on New Data" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "p = ktrain.get_predictor(learner.model, preproc)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Predicting label for the text\n", "> \"*The room was clean, the food excellent, and I loved the view from my room.*\"" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'pos'" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "p.predict(\"الغرفة كانت نظيفة ، الطعام ممتاز ، وأنا أحب المنظر من غرفتي.\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Predicting label for:\n", "> \"*This hotel was too expensive and the staff is rude.*\"" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'neg'" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "p.predict('كان هذا الفندق باهظ الثمن والموظفين غير مهذبين.')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Save our Predictor for Later Deployment" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "# save model for later use\n", "p.save('/tmp/arabic_predictor')" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "# reload from disk\n", "p = ktrain.load_predictor('/tmp/arabic_predictor')" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'pos'" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# still works as expected after reloading from disk\n", "p.predict(\"الغرفة كانت 
نظيفة ، الطعام ممتاز ، وأنا أحب المنظر من غرفتي.\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.8" } }, "nbformat": 4, "nbformat_minor": 2 }