{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%reload_ext autoreload\n", "%autoreload 2\n", "%matplotlib inline\n", "import os\n", "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\";\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"; " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Text Classification Example: Sentiment Analysis with IMDb Movie Reviews\n", "\n", "We will begin by importing some required modules for performing text classification in *ktrain*." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import ktrain\n", "from ktrain import text" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we will load and preprocess the text data for training and validation. *ktrain* can load texts and associated labels from a variety of source:\n", "\n", "- `texts_from_folder`: labels are represented as subfolders containing text files [ [example notebook] ](https://github.com/amaiya/ktrain/blob/master/examples/text/IMDb-BERT.ipynb)\n", "- `texts_from_csv`: texts and associated labels are stored in columns in a CSV file [ [example notebook](https://github.com/amaiya/ktrain/blob/master/examples/text/toxic_comments-fasttext.ipynb) ]\n", "- `texts_from_df`: texts and associated labels are stored in columns in a *pandas* DataFrame [ [example notebook](https://github.com/amaiya/ktrain/blob/master/examples/text/ArabicHotelReviews-nbsvm.ipynb) ]\n", "- `texts_from_array`: texts and labels are loaded and preprocessed from an array [ [example notebook](https://github.com/amaiya/ktrain/blob/master/examples/text/20newsgroup-distilbert.ipynb) ]\n", "\n", "For `texts_from_csv` and `texts_from_df`, labels can either be multi or one-hot-encoded with one column per class or can be a single column storing integers or strings like this:\n", "```python\n", "# my_training_data.csv\n", "TEXT,LABEL\n", "I like this movie,positive\n", "I hate this 
movie,negative\n", "```\n", "\n", "For `texts_from_array`, the labels are arrays in one of the following forms:\n", "```python\n", "# string labels\n", "y_train = ['negative', 'positive']\n", "# integer labels\n", "y_train = [0, 1] # indices must start from 0\n", "# multi- or one-hot-encoded labels (used for multi-label problems)\n", "y_train = [[1,0], [0,1]]\n", "```\n", "\n", "In the latter two cases, you must supply a `class_names` argument to `texts_from_array`, which tells *ktrain* how indices map to class names. In this case, `class_names=['negative', 'positive']` because 0=negative and 1=positive.\n", "\n", "Sample arrays for `texts_from_array` might look like this:\n", "```python\n", "x_train = ['I hate this movie.', 'I like this movie.']\n", "y_train = ['negative', 'positive']\n", "x_test = ['I despise this movie.', 'I love this movie.']\n", "y_test = ['negative', 'positive']\n", "```\n", "\n", "All of the above methods transform the texts into a sequence of word IDs in one way or another, as expected by neural network models.\n", "\n", "\n", "In this first example problem, we use the ```texts_from_folder``` function to load documents as fixed-length sequences of word IDs from a folder of raw documents. 
This function assumes a directory structure like the following:\n", "\n", "```\n", " ├── datadir\n", " │ ├── train\n", " │ │ ├── class0 # folder containing documents of class 0\n", " │ │ ├── class1 # folder containing documents of class 1\n", " │ │ ├── class2 # folder containing documents of class 2\n", " │ │ └── classN # folder containing documents of class N\n", " │ └── test \n", " │ ├── class0 # folder containing documents of class 0\n", " │ ├── class1 # folder containing documents of class 1\n", " │ ├── class2 # folder containing documents of class 2\n", " │ └── classN # folder containing documents of class N\n", "```\n", "\n", "Each subfolder will contain documents in plain text format (e.g., `.txt` files) pertaining to the class represented by the subfolder.\n", "\n", "For our text classification example, we will again classify IMDb movie reviews as either positive or negative. However, instead of using the pre-processed version of the dataset pre-packaged with Keras, we will use the original (or raw) *aclImdb* dataset. The dataset can be downloaded from [here](http://ai.stanford.edu/~amaas/data/sentiment/). Set the ```DATADIR``` variable to the location of the extracted *aclImdb* folder.\n", "\n", "In the cell below, note that we supplied `preprocess_mode='standard'` to the data-loading function (which is the default). For pretrained models like BERT and DistilBERT, the dataset must be preprocessed in a specific way. If you are planning to use BERT for text classification, you should replace this argument with `preprocess_mode='bert'`. Since we will not be using BERT in this example, we leave it as `preprocess_mode='standard'`. See [this notebook](https://github.com/amaiya/ktrain/blob/master/examples/text/IMDb-BERT.ipynb) for an example of how to use BERT for text classification in *ktrain*. There is also a [DistilBERT example notebook](https://github.com/amaiya/ktrain/blob/master/examples/text/20newsgroup-distilbert.ipynb). 
\n", "**NOTE:** If using `preprocess_mode='bert'` or `preprocess_mode='distilbert'`, an English pretrained model is used for English, a Chinese pretrained model is used for Chinese, and a multilingual pretrained model is used for all other languages. For more flexibility in choosing the model used, you can use the alternative [Transformer API for text classification](https://github.com/amaiya/ktrain/blob/master/tutorials/tutorial-A3-hugging_face_transformers.ipynb) in *ktrain*. \n", "\n", "Please also note that, when specifying `preprocess_mode='distilbert'`, the first two return values are `TransformerDataset` objects, not Numpy arrays. So, it is best to always use `trn, val, preproc` on the left-hand side of the expression (instead of `(x_train, y_train), (x_test, y_test_, preproc`) to avoid confusion, as shown below." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "detected encoding: utf-8\n", "language: en\n", "Word Counts: 88582\n", "Nrows: 25000\n", "25000 train sequences\n", "train sequence lengths:\n", "\tmean : 237\n", "\t95percentile : 608\n", "\t99percentile : 923\n", "Adding 3-gram features\n", "max_features changed to 5151281 with addition of ngrams\n", "Average train sequence length with ngrams: 709\n", "train (w/ngrams) sequence lengths:\n", "\tmean : 709\n", "\t95percentile : 1821\n", "\t99percentile : 2766\n", "x_train shape: (25000,2000)\n", "y_train shape: (25000, 2)\n", "Is Multi-Label? 
False\n", "25000 test sequences\n", "test sequence lengths:\n", "\tmean : 230\n", "\t95percentile : 584\n", "\t99percentile : 900\n", "Average test sequence length with ngrams: 523\n", "test (w/ngrams) sequence lengths:\n", "\tmean : 524\n", "\t95percentile : 1295\n", "\t99percentile : 1971\n", "x_test shape: (25000,2000)\n", "y_test shape: (25000, 2)\n" ] } ], "source": [ "# load training and validation data from a folder\n", "DATADIR = 'data/aclImdb'\n", "trn, val, preproc = text.texts_from_folder(DATADIR, \n", " max_features=80000, maxlen=2000, \n", " ngram_range=3, \n", " preprocess_mode='standard',\n", " classes=['pos', 'neg'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Having loaded the data, we will now create a text classification model. The `print_text_classifier` function prints some available models. The model selected should be consistent with the `preprocess_mode` selected above. \n", "\n", "(As mentioned above, one can also use the alternative `Transformer` API for text classification in *ktrain* to access an even larger library of Hugging Face Transformer models like RoBERTa and XLNet. See [this tutorial](https://github.com/amaiya/ktrain/blob/master/tutorials/tutorial-A3-hugging_face_transformers.ipynb) for more information on this.) \n", "\n", "In this example, the `text_classifier` function will return a [neural implementation of NBSVM](https://medium.com/@asmaiya/a-neural-implementation-of-nbsvm-in-keras-d4ef8c96cb7c), which is a strong baseline that can outperform more complex neural architectures. It may take a few moments to return as it builds a document-term matrix from the input data we provide it. The ```text_classifier``` function expects `trn` to be a preprocessed training set returned from the `texts_from*` function above. In this case where we have used `preprocess_mode='standard'`, `trn` is a numpy array with each document represented as fixed-size sequence of word IDs." 
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "fasttext: a fastText-like model [http://arxiv.org/pdf/1607.01759.pdf]\n", "logreg: logistic regression using a trainable Embedding layer\n", "nbsvm: NBSVM model [http://www.aclweb.org/anthology/P12-2018]\n", "bigru: Bidirectional GRU with pretrained fasttext word vectors [https://fasttext.cc/docs/en/crawl-vectors.html]\n", "standard_gru: simple 2-layer GRU with randomly initialized embeddings\n", "bert: Bidirectional Encoder Representations from Transformers (BERT) [https://arxiv.org/abs/1810.04805]\n", "distilbert: distilled, smaller, and faster BERT from Hugging Face [https://arxiv.org/abs/1910.01108]\n" ] } ], "source": [ "text.print_text_classifiers()" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Is Multi-Label? False\n", "compiling word ID features...\n", "building document-term matrix... this may take a few moments...\n", "rows: 1-10000\n", "rows: 10001-20000\n", "rows: 20001-25000\n", "computing log-count ratios...\n", "done.\n" ] } ], "source": [ "# load an NBSVM model\n", "model = text.text_classifier('nbsvm', trn, preproc=preproc)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we instantiate a Learner object and call the ```lr_find``` and ```lr_plot``` methods to help identify a good learning rate." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "learner = ktrain.get_learner(model, train_data=trn, val_data=val)" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "simulating training for different learning rates... 
this may take a few moments...\n", "Epoch 1/5\n", "25000/25000 [==============================] - 6s 226us/step - loss: 0.6906 - acc: 0.5797\n", "Epoch 2/5\n", "25000/25000 [==============================] - 5s 206us/step - loss: 0.6071 - acc: 0.9114\n", "Epoch 3/5\n", "25000/25000 [==============================] - 5s 205us/step - loss: 0.2151 - acc: 0.9711\n", "Epoch 4/5\n", "16032/25000 [==================>...........] - ETA: 1s - loss: 0.0252 - acc: 0.9943\n", "\n", "done.\n", "Please invoke the Learner.lr_plot() method to visually inspect the loss plot to help identify the maximal learning rate associated with falling loss.\n" ] } ], "source": [ "learner.lr_find()" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYgAAAEOCAYAAACTqoDjAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvqOYd8AAAIABJREFUeJzt3Xl8VOXZxvHfPZONJYQtrGEVkOKCSsBdcV+rrSv0bdVqi7a1ttr2rVrbqq3V1mpbt7bUWquv1lq3oqJYca2IEJAd0RhAQJCwr1nnfv+YIYY4YAI5OTOT6/txPpzznDMz98NIrpw55zyPuTsiIiINRcIuQEREUpMCQkREklJAiIhIUgoIERFJSgEhIiJJKSBERCQpBYSIiCSlgBARkaQUECIikpQCQkREksoKu4Cm6tq1q/fv3z/sMkRE0sqMGTPWuHthU56TdgHRv39/SkpKwi5DRCStmNnSpj5HXzGJiEhSCggREUkq0IAws1PNbJGZlZrZtUm2/87MZiUe75vZhiDrERGRxgvsHISZRYF7gZOA5cB0M5vg7gt27OPuV9fb/7vAwUHVIyIiTRPkEcQooNTdy9y9CngMOHs3+48F/hFgPSIi0gRBBkRvYFm99eWJts8ws37AAOCVAOsREZEmSJXLXMcAT7h7bbKNZjYOGAfQt2/fPX6TyppaNm2vYenarWyprMEdHCcnGsVxDMMMDMDYad2s/nJ8hx3rWZEI1bEYOdEI+XlZ5GVHycuKYhHIiUbIiUaIRGyP6xYRCUOQAbEC6FNvvSjRlswY4Du7eiF3Hw+MByguLt6jSbTvf7OMW194j9pYeHNwRwyiESNiRjRixNzrlrMiRk5WhOxEoGRHI4l1q1vesf+O7dmJ7TnRCFk79otGaJuTRducKHk5UXKjEdrmRmmXm0VBm2zysqK0yYnSPjerLsxERJIJMiCmA4PNbADxYBgDfKXhTmY2FOgEvB1gLRzctyPfPHogPQvy6Nu5LQVts4nEDwWoqolhRvyIwh2HuqOLxH916163Ht8Ph+raGNlZEapqYmypqKGippbK6hgxdyprYlTXxog5xGJOrXv8z5gTiRixmFOTWK+ujVFVG6Mq8Zzq2kRbTYwtlTUYEEu8X/3tdcs1MSoT+zdWu5woXdrn0qV9Dl3a5dK1fQ5d2ufQecdyu8S29jl0bptDVlRXRou
0FoEFhLvXmNmVwCQgCjzg7vPN7GagxN0nJHYdAzzm7oH+aj+iX2dG9Osc5FukjNqYs62qhorqGJU1tWytrGVLZTUbt1dTVRNjW1UtWypr2FxRw9otVazdWsm6rVWs2LCdOcs3sG5rFTW7ONLq1Dabzu1y2LdHPicN686oAV3oVZCHmb5CE8k0FvDP5WZXXFzsGmojWLGYs6mimjVbqli3tYq1WypZk/hzR6BMW7yeNVsqAcjPy+KA3gUM79OR4UUFHFjUkZ4KDZGUYmYz3L24Kc9JlZPUkkIiEaNj2xw6ts3Z5T6xmDPv443MWb6RhSs3MXv5Bv7yRlndkUfX9rkc1CceFgcWFTC8qCOd2u369UQk9SggZI9EIpb44d+xrq2iupaFKzcxZ/lGZi/bwOzlG5j83mp2HKT27tiGI/bpwoh+nRjaswN9O7els0JDJGUpIKTZ5GVHObhvJw7u26mubXNFNXNXbGT2so3MW7GRF+et4l8zltdtH9ytPSfv151jBhcyrFcH8vOywyhdRJLQOQhpUbGYs2z9Nhau3MzStVt55b3VlCxdT23MycmKcMp+PTh+aCGHDexCz4I2YZcrkjH25ByEAkJCt3F7NSVL1vHaonKen7uSdVurABjaI599e+Tz7dGD2LdHfshViqQ3BYSkvVjMWbByE298UM47ZeuYtWwDmyuqOaB3ASd8oTtjRvahW4e8sMsUSTsKCMk45ZsreXjqUqaUrqFk6XoA9ilsx6gBnTmoT0eO27ebAkOkERQQktE+LN/CpPmrmL54HSVL17O5ooa87AjfGT2IsYf2pWv73LBLFElZCghpNWpjzoflW7jthfd45b3VdGmXwzUnD2HsyL4aGFEkiT0JCA2sI2kpGjGGdM/ngUtG8vS3j2BgYTt+8vQ8TrjzdV6ct4p0+8VHJBXpCEIygrvz3JyV3P3KB7z/yRYALj9mIJcfu49uxhNBXzGJUF0b4/GSZYx/o4xl67ZR0Cabrx85gHHHDNTQ5tKqKSBE6pm3YiM3TphPydL1DO2Rzz1fOYRB3dqHXZZIKHQOQqSe/XsX8MS3juD28w7k4w3bOel3r/ObF99r0nwZIq2ZAkIy3vnFfZj8g9Gce0gR9732IZc/XEJNrUJC5PMoIKRVKMzP5bfnD+fa04by6qJyzr73LZat2xZ2WSIpTQEhrcrlxwzkF1/an0WrNnPina/z1/8uDrskkZSlgJBWxcz42mH9eGzcYYzo14lfPLeAMePfZs7yDWGXJpJyFBDSKhX378xDl47if0/dl9nLNnLWPW8xZvzbLPh4U9iliaQMBYS0WlnRCN8ePYip153ARYf3Y2rZOk6/601unDA/7NJEUkKgAWFmp5rZIjMrNbNrd7HPBWa2wMzmm9mjQdYjkkxB22xuPnt/nvzW4QA8OGUJT9Sb9U6ktQosIMwsCtwLnAYMA8aa2bAG+wwGrgOOdPf9gO8HVY/I5xnRrzMLbz6VvOwIN02YT1n5lrBLEglVkEcQo4BSdy9z9yrgMeDsBvt8E7jX3dcDuPvqAOsR+VxtcqI8Nu5wqmMxzv3jFJ2TkFYtyIDoDSyrt7480VbfEGCImb1lZlPN7NQA6xFplIP6dGTiVUfTJjvK6Xe9yV2TPwi7JJFQhH2SOgsYDIwGxgJ/MbOODXcys3FmVmJmJeXl5S1corRGAwvb849xhzGqf2fu/M/7/OTpuRpCXFqdIANiBdCn3npRoq2+5cAEd69298XA+8QDYyfuPt7di929uLCwMLCCRerr16Udj3zzUC45oj+PvPMRd00uDbskkRYVZEBMBwab2QAzywHGABMa7PMM8aMHzKwr8a+cygKsSaRJsqMRfnbmMM48sCe/e/l9Ji/8JOySRFpMYAHh7jXAlcAkYCHwuLvPN7ObzeysxG6TgLVmtgB4FfiRu68NqiaRPRGJGHdcMJx+Xdpy47PzKd9cGXZJIi1C80GINNKUD9dw8QPT2K9XAU9ccThZ0bBP4Yk
0nuaDEAnQEft05ZYvH8CsZRt4Yd6qsMsRCZwCQqQJzj2kiIFd23Htk3NYsmZr2OWIBEoBIdIE0Yhx/8XFVMeca5+aQyyWXl/RijSFAkKkiQYWtud/T9mXqWXruOHf86hVSEiGUkCI7IFLjxzAsJ4dePSdjxj3UIluopOMpIAQ2QORiPH4FYdz5XGDmPzeaq7+56ywSxJpdgoIkT3UPjeLa04awuh9C3lm1sc8827DgQJE0psCQmQvRCLG3WMPpku7HK5+fBYvzlsZdkkizUYBIbKX8vOyeeWHo+naPpcr/m8mFdW1YZck0iwUECLNoKBNNj86ZV8Arn96bsjViDQPBYRIMzl/RBHd8nN5auYK3UQnGUEBIdJMzIx/jDsMgOPueI3KGn3VJOlNASHSjPYpbM9VJwzGHQ74+Uu6P0LSmgJCpJldc9IQzjmkN1W1Ma5/ep5CQtKWAkIkALefN5z9e3fgH9M+4vZJi8IuR2SPKCBEAhCNGH+5KD70/kNvL6W6NhZyRSJNp4AQCUjPgjbcf1ExWyprmLxwddjliDSZAkIkQMcN7UbPgjwenLI47FJEmkwBIRKgaMT4xtEDmVq2jikfrgm7HJEmUUCIBOx/Du1L9w65fOeRmVTV6FyEpI9AA8LMTjWzRWZWambXJtl+iZmVm9msxOMbQdYjEoa87CjfP3EI67dV86uJC8MuR6TRAgsIM4sC9wKnAcOAsWY2LMmu/3T3gxKP+4OqRyRMY0b2Yd/u+Tw4ZQlf++s7YZcj0ihBHkGMAkrdvczdq4DHgLMDfD+RlGVmPHzZKADe/GANB9/8UsgViXy+IAOiN7Cs3vryRFtD55rZHDN7wsz6BFiPSKi6dchj0S9PBWD9tmoOvvklDQ0uKS3sk9TPAv3d/UDgP8Dfk+1kZuPMrMTMSsrLy1u0QJHmlJsVZcYNJ9IhL4v126q599XSsEsS2aUgA2IFUP+IoCjRVsfd17p7ZWL1fmBEshdy9/HuXuzuxYWFhYEUK9JSurTPZdbPTiY/N4u7XynVqK+SsoIMiOnAYDMbYGY5wBhgQv0dzKxnvdWzAF3iIa1CJGJcf8YXALj+qXkhVyOSXGAB4e41wJXAJOI/+B939/lmdrOZnZXY7Sozm29ms4GrgEuCqkck1Ywd1Zd9u+fz7OyPWbWxIuxyRD7D0m0o4uLiYi8pKQm7DJFm8dHabRxz+6t874TBXH3SkLDLkQxmZjPcvbgpzwn7JLVIq9a3S1tGDejMs3M+DrsUkc9QQIiE7Lh9u1FWvpWN26rDLkVkJwoIkZAN6tYegA/XbAm5EpGdKSBEQjawsB0AZeVbQ65EZGcKCJGQ9e3clqyIUVauIwhJLQoIkZBlRyP07dxWRxCSchQQIilgYGE7ynQOQlKMAkIkBQwsbM+StduojaXXfUmS2RQQIilgYNd2VNXEWLF+e9iliNRRQIikgH0Sl7qWlm8OuRKRTykgRFLA0B75mMHc5ZvCLkWkjgJCJAXk52UzsGs75izfEHYpInUUECIpYnhRR2Yv30i6DaApmUsBIZIiDiwqYM2WSlZq6G9JEQoIkRRxYJ+OAMxepq+ZJDUoIERSxH69OpCXHeGdxevCLkUEUECIpIzcrCjF/Trz39I1YZciAiggRFLKkYO6Urp6C6WrdT+EhE8BIZJCvji8JwAvzlsVciUiCgiRlFLUqS1De+QztUznISR8gQaEmZ1qZovMrNTMrt3NfueamZtZkybUFslEhw3sQsnSdVTVxMIuRVq5wALCzKLAvcBpwDBgrJkNS7JfPvA94J2gahFJJyP6daKiOsaDUxaHXYq0ckEeQYwCSt29zN2rgMeAs5Ps9wvg14DuDhIBTvhCNwBeeW91yJVIaxdkQPQGltVbX55oq2NmhwB93P35AOsQSSttc7K4oLiIBR9voqZWXzNJeEI7SW1mEeBO4AeN2HecmZWYWUl5eXnwxYmE7Ngh3dhUUcNsDd4nIQoyIFYAfeqtFyXadsg
H9gdeM7MlwGHAhGQnqt19vLsXu3txYWFhgCWLpIajBnUlYvDaIv1CJOEJMiCmA4PNbICZ5QBjgAk7Nrr7Rnfv6u793b0/MBU4y91LAqxJJC0UtM3mkL6deP19BYSEJ7CAcPca4EpgErAQeNzd55vZzWZ2VlDvK5Ipjh1SyJzlG1mzpTLsUqSVCvQchLtPdPch7r6Pu9+SaPuZu09Isu9oHT2IfOrYfeNfpz741pJwC5FWS3dSi6So/XsVAPD0uys0iZCEolEBYWbfM7MOFvdXM5tpZicHXZxIaxaJGDedtR8rNmxnluaIkBA09gjiUnffBJwMdAK+BtwWWFUiAsCXDu5Nl3Y5/On1D8MuRVqhxgaEJf48HXjY3efXaxORgBS0yebMA3syaf4nzF2+MexypJVpbEDMMLOXiAfEpMT4SbrFU6QFjDt2HwC+eM9/2VJZE3I10po0NiAuA64FRrr7NiAb+HpgVYlInd4d2zC0Rz4AlzwwTSespcU0NiAOBxa5+wYz+ypwA6DjXZEW8uDXR3HskEJKlq7n6n/OCrscaSUaGxB/BLaZ2XDiYyd9CDwUWFUispMeBXmMv2gEAM/M+ph1W6tCrkhag8YGRI3Hj2vPBu5x93uJj6UkIi0kNyvKf64+BoCRt7yskV4lcI0NiM1mdh3xy1ufT4zEmh1cWSKSzODu+ZyyX3dqY84Nz8wjFtP5CAlOYwPiQqCS+P0Qq4iPzHp7YFWJyC79+WvFnHNIbx6bvoyB10/UkYQEplEBkQiFR4ACMzsTqHB3nYMQCcktXzqgbvkXzy0IsRLJZI0dauMCYBpwPnAB8I6ZnRdkYSKya21yoky7/gSyIsbDU5fywSebwy5JMlBjv2L6CfF7IC5294uIzzf90+DKEpHP061DHlOuPZ687Ci3T1oUdjmSgRobEBF3rz+D+tomPFdEAtKtQx7fOGoALy34hAUfbwq7HMkwjf0h/6KZTTKzS8zsEuB5YGJwZYlIY110RH/yc7O48dn5YZciGaaxJ6l/BIwHDkw8xrv7j4MsTEQap2v7XK4YvQ/TFq/j9knvhV2OZJBGf03k7k+6+zWJx9NBFiUiTXPJEf0BuPfVDzVFqTSb3QaEmW02s01JHpvNTF94iqSIdrlZ3H9RMQA/fmJOyNVIpthtQLh7vrt3SPLId/cOLVWkiHy+E4d156zhvZj83momzl2pUV9lrwV6JZKZnWpmi8ys1MyuTbL9CjOba2azzOy/ZjYsyHpEMt0PTh6CGXz7kZn8/uUPwi5H0lxgAWFmUeBe4DRgGDA2SQA86u4HuPtBwG+AO4OqR6Q16NelHZO+Hx/Q7w+TP2BTRXXIFUk6C/IIYhRQ6u5l7l4FPEZ8NNg6iXmud2gH6JhYZC8N6Z7PI984FIDLHpwecjWSzoIMiN7AsnrryxNtOzGz75jZh8SPIK4KsB6RVuPIQV0BmL5kPcvXbwu5GklXod8N7e73uvs+wI+Jz1T3GWY2zsxKzKykvLy8ZQsUSVP/+OZhABz161dZvaki5GokHQUZECuAPvXWixJtu/IY8KVkG9x9vLsXu3txYWFhM5YokrkO36cLp+7XA4BRv5pMRXVtyBVJugkyIKYDg81sgJnlAGOACfV3MLPB9VbPAHTZhUgz+tPXRtC7YxsAhv70RVZt1JGENF5gAeHuNcCVwCRgIfC4u883s5vN7KzEblea2XwzmwVcA1wcVD0irdVb1x5Pz4I8AE75/Ru6P0IazdLtf5bi4mIvKSkJuwyRtHPOfW8x86MN3HrOAVxY3IdIxMIuSVqQmc1w9+KmPCf0k9Qi0jIev/xwcrIiXPfUXI38Ko2igBBpJbKiEe4aczAAD729lGXrdPmr7J4CQqQVOXX/Hrx17fFEDL5y/1TueGkRsVh6fc0sLUcBIdLK9O7YhiMHdWXZuu3c/UopA6+fyMKVGpxZPksBIdIK3XH+8J3WL31wOtuqakKqRlKVAkKkFerWIY8lt53BB7ecxr1
fOYSVGyu4fdKisMuSFKOAEGnFsqMRzjiwJ70K8vjbW0u4+p+zwi5JUogCQkT43YUHAfD0uyu45fkFIVcjqUIBISIcOrALz191FAB/eXMxKzduD7kiSQUKCBEBYL9eBfz0zPicXk/N3N24mtJaKCBEpM5lRw1gVP/O3D5pEf+epZBo7RQQIrKTS4/qD8D3HpvF3OUbKV29RZMOtVIarE9EPuPvU5bw8wk7j9c0/6ZTaJebFVJFsrc0WJ+INIuLj+jPuYcU7dT21Lv6yqm10RGEiOyWu3PWPW8xd8VGFv3yVHKzomGXJHtARxAi0uzMjLGj+gJw/G9fZ0ulhuRoLRQQIvK5xozsw4lf6M6KDdv5rYbkaDUUECLyuSIR4/6LixleVMCTM5dTUV0bdknSAhQQItJoPz5tKJsranSPRCuhgBCRRjt8YBf2792Bn/57vuaQaAUCDQgzO9XMFplZqZldm2T7NWa2wMzmmNlkM+sXZD0isnfMjLvHHoIB59w3hVfe+4R0uxJSGi+wgDCzKHAvcBowDBhrZsMa7PYuUOzuBwJPAL8Jqh4RaR4DurbjkW8cyvbqWi59sIQB103UOYkMFeQRxCig1N3L3L0KeAw4u/4O7v6qu++4h38qUISIpLzi/p158luH162f/oc3Q6xGghJkQPQGltVbX55o25XLgBcCrEdEmtGIfp1ZctsZAJSt2cqk+atCrkiaW0oMrGJmXwWKgWN3sX0cMA6gb9++LViZiHyeyT84lhPueJ3LH55R1/b+L08jJ0vXwKS7ID/BFUCfeutFibadmNmJwE+As9y9MtkLuft4dy929+LCwsJAihWRPbNPYXue/vYRO7UNueEF5q3YGFJF0lyCDIjpwGAzG2BmOcAYYEL9HczsYODPxMNhdYC1iEiADu7biQ9/dToXFn/6O+GZd/9Xw3KkucACwt1rgCuBScBC4HF3n29mN5vZWYndbgfaA/8ys1lmNmEXLyciKS4aMX593oHM/vnJdW2H3PyfECuSvaXRXEWk2VXW1LLvDS8C8JVD+/KrLx8QckWi0VxFJCXkZkV5/5enAfDoOx/xt7cWU765UjfVpRkFhIgEIicrwrs/PYmIwU3PLmDkLS8zZvzUsMuSJlBAiEhgOrXL4cXvH1O3/s7idbzxfnmIFUlTKCBEJFBDuuez+NbTefbKowC46IFpTCldE3JV0hgKCBEJnJlxQFEB3z1+EABfuf8d+l/7PO9+tD7kymR3FBAi0mJ+cPK+nDfi0yHXvnzfFGpqYyFWJLujgBCRFnXrOQfw8GWjOHpwVwAG/eQFNlVUh1yVJKOAEJEWlR2NcPTgQh64ZGRd24E3vqSQSEEKCBEJRXY0Quktp9Wtn3jH68Riuk8ilSggRCQ0WdEIZb86HYDVmysZeP1EHp++7HOeJS1FASEioYpEjLevO75u/X+fnEOtjiRSggJCRELXs6ANr/5wdN36O2VrwytG6iggRCQlDOjajgU3nwLE75O4a/IHGrspZAoIEUkZbXOyGDMyPqfEnf95nwHXTWT91ip95RQSDfctIimnfHMlI295eae2P331EHp1bEPX9rn06tgmpMrS154M962AEJGUFIs5o341mTVbks5EzMvXHMOgbvktXFX60nwQIpIxIhGj5IYTKfvV6fxhzEF0aZez0/YT73yDiurakKprHXQEISJp47VFq/l4QwX3vVbK8vXbAfjbJSM5bmi3kCtLffqKSURajWNvf5Wla7cBUJify1UnDOZ/RvUlErGQK0tN+opJRFqNf3/nSHZkQfnmSn76zDwu/ts0/jHtIw3Z0UwCDQgzO9XMFplZqZldm2T7MWY208xqzOy8IGsRkczSsW0OZbeeweJbT69re/ODNVz31FxufHZ+iJVljsACwsyiwL3AacAwYKyZDWuw20fAJcCjQdUhIpnNzFhy2xm8fM2nU5s+9PZSlq3bFmJVmSHII4hRQKm7l7l7FfAYcHb9Hdx9ibvPATRjiIj
slUHd8im95TRu/GL899D5H28MuaL0F2RA9AbqD8u4PNEmIhKIrGiEMaP6AvCDx2eHXE36S4uT1GY2zsxKzKykvLw87HJEJIXlZUcB2FqleyT2VpABsQLoU2+9KNHWZO4+3t2L3b24sLCwWYoTkcz1szPjXzOt2lgRciXpLciAmA4MNrMBZpYDjAEmBPh+IiIAHNy3IwDvfrQ+5ErSW2AB4e41wJXAJGAh8Li7zzezm83sLAAzG2lmy4HzgT+bma5NE5G9tl+vAnKyIsxYqoDYG1lBvri7TwQmNmj7Wb3l6cS/ehIRaTY5WREO7tORqYs18dDeSIuT1CIiTXXkoK7M/3gT67ZWhV1K2lJAiEhGOmZIIe4waf6qsEtJWwoIEclIw4sKGFjYjmfe3aOLJwUFhIhkKDPjrOG9mLZkHSs3bg+7nLSkgBCRjHX2Qb1xh+dmrwy7lLSkgBCRjDWgazuGFxVoCPA9pIAQkYx28RH9KVuzlSkf6pLXplJAiEhGO/2AnnRsm80j7ywNu5S0o4AQkYyWlx3lguI+vDBvFaWrN4ddTlpRQIhIxrv8mIEAnHjnG1TVaPqZxlJAiEjG69I+ly8d1AuAU37/Bu7pdcLa3Xnzg3LWbqls0fdVQIhIq3DnBQexX68OLF6zlefnptdlr6s3V/K1v05jYgvXrYAQkVYhEjH+/Z0jGdojnysffZd/z0qfO6x/O2kRAH06t23R9w10NFcRkVSSFY3wx6+O4Ljfvsb3HpvFzKXruens/cMuK6kT7niND8u30rldTt2Ag4f069SiNegIQkRalQFd2zHvplPo1Dabv7+9lP99Yja1KXYT3baqGj4s3wpQFw77ds+nQ152i9ahgBCRVqd9bhb//fHxHDmoC4+XLOewWyeHNl7TJ5squPiBafx71gqqamLUxpwz7vovANedNpSjB3cF4P6Li1u8Nku3s/nFxcVeUlISdhkikgFqamOMe3gGr7y3moI22dx6zgGctn8PzGyPX/Px6cuYu2IjPzx5Xwra7v43/h/9azb/mrF8l9tn/+zkz32NxjKzGe7epJRRQIhIq7dkzVa+8+hM5n+8id4d21DcvxOXH7MPX+iZ3+iwqKiu5U+vf8jvX/4AgML8XI4ZXMh+vTpwxKAuDOmWT8ydJ2cuZ93WamYsXc/LCz8B4nNXVNXUMrVsHW1zoowa0Jm7xx5MfjN+paSAEBHZQxXVtfxh8gc8OWM5qzd/er9BUac2dG2fS040wsgBneiWn8eZB/akS/vcun1+O2kRD7y1mG1VtQBcUFzEig3bmVq2ru78Ro8OeazaVPGZ953985MpaBMPgs0V1eRkRcjNijZ7/xQQIiJ7qTbmfLxhO5Pmr+Kl+Z+Qmx2hfHMl7636dJiOaMTo1DabNVuqyM2KUJm4O/v8EUWcfkBPjhvaDYBNFdUsW7eN5+es5I0PyildvYXhRR05Zb8eFHVqwzFDCsnLbv4wSEYBISISoGXrtlGydB0ffLKFeR9v4o33y9m/dwdOGNqdi4/oT+d2OWGXuEt7EhCB3gdhZqcCfwCiwP3ufluD7bnAQ8AIYC1wobsvCbImEZE91adz2xa/WS1MgV3mamZR4F7gNGAYMNbMhjXY7TJgvbsPAn4H/DqoekREpGmCvA9iFFDq7mXuXgU8BpzdYJ+zgb8nlp8ATrC9ub5MRESaTZAB0RtYVm99eaIt6T7uXgNsBLo0fCEzG2dmJWZWUl5eHlC5IiJSX1rcSe3u49292N2LCwsLwy5HRKRVCDIgVgB96q0XJdqS7mNmWUAB8ZPVIiISsiADYjow2MwGmFkOMAaY0GCfCcDFieXzgFc83a67FRHJUIFd5uruNWZ2JTCJ+GWuD7j7fDO7GShx9wnAX4GHzawUWEc8REREJAUEeh/cA6CYAAAJFElEQVSEu08EJjZo+1m95Qrg/CBrEBGRPZN2d1KbWTmwNLFaQPzKp8YsdwXW7MVb13/NPdkn2baGbbtb37Fcvy3
MPjWmPw3b9Bk1XXN/RsnaW+NnVH853fq0p/3p5+5Nu8rH3dP2AYxv7DLxr7Wa5b32ZJ9k2xq27W69Xj/qt4XWp8b0R59R6n1Gje1Dpn9G6dynoPtT/5EWl7nuxrNNXG6u99qTfZJta9i2u/Vnd7HP3tibPjWmPw3b9Bk1XXN/RsnaW+Nn1NhaGqOl+xR0f+qk3VdMe8rMSryJA1WlukzrU6b1BzKvT5nWH8i8PjVnf9L9CKIpxoddQAAyrU+Z1h/IvD5lWn8g8/rUbP1pNUcQIiLSNK3pCEJERJpAASEiIkkpIEREJCkFBGBmR5vZn8zsfjObEnY9e8vMImZ2i5ndbWYXf/4zUp+ZjTazNxOf0+iw62kuZtYuMZT9mWHXsrfM7AuJz+cJM/tW2PU0BzP7kpn9xcz+aWYnh13P3jKzgWb2VzN7ojH7p31AmNkDZrbazOY1aD/VzBaZWamZXbu713D3N939CuA5Pp3AKBTN0R/iEzEVAdXE5+EIVTP1yYEtQB6Z0yeAHwOPB1Nl4zXTv6OFiX9HFwBHBllvYzRTn55x928CVwAXBlnv52mm/pS5+2WNfs90v4rJzI4h/oPjIXffP9EWBd4HTiL+w2Q6MJb4oIG3NniJS919deJ5jwOXufvmFir/M5qjP4nHenf/s5k94e7ntVT9yTRTn9a4e8zMugN3uvv/tFT9yTRTn4YTnyArj3j/nmuZ6j+ruf4dmdlZwLeAh9390ZaqP5lm/tlwB/CIu89sofI/o5n706ifC4EO1tcS3P0NM+vfoLluulMAM3sMONvdbwWSHsqbWV9gY5jhAM3THzNbDlQlVmuDq7ZxmuszSlgP5AZRZ1M00+c0GmhHfM727WY20d1jQda9K831GXl8lOYJZvY8EGpANNNnZMBtwAthhgM0+7+jRkn7gNiFZNOdHvo5z7kM+FtgFe2dpvbnKeBuMzsaeCPIwvZCk/pkZucApwAdgXuCLW2PNalP7v4TADO7hMQRUqDVNV1TP6PRwDnEA3zirvYLWVP/LX0XOBEoMLNB7v6nIIvbA039jLoAtwAHm9l1iSDZpUwNiCZz95+HXUNzcfdtxAMvY7j7U8SDL+O4+4Nh19Ac3P014LWQy2hW7n4XcFfYdTQXd19L/HxKo6T9SepdaMx0p+kk0/oD6lM6yLT+QOb1KdD+ZGpANGa603SSaf0B9SkdZFp/IPP6FGx/mmvc8LAewD+AlXx6SedlifbTiZ/d/xD4Sdh1ttb+qE/h19oa+5OJfQqjP2l/mauIiAQjU79iEhGRvaSAEBGRpBQQIiKSlAJCRESSUkCIiEhSCggREUlKASGBM7MtLfAeZzVyeO3mfM/RZnbEHjzvYDP7a2L5EjNLibGlzKx/w6Gkk+xTaGYvtlRNEi4FhKSNxNDGSbn7BHe/LYD33N14ZaOBJgcEcD1pOr6Pu5cDK80s9PkeJHgKCGlRZvYjM5tuZnPM7KZ67c+Y2Qwzm29m4+q1bzGzO8xsNnC4mS0xs5vMbKaZzTWzoYn96n4TN7MHzewuM5tiZmVmdl6iPWJm95nZe2b2HzObuGNbgxpfM7Pfm1kJ8D0z+6KZvWNm75rZy2bWPTHs8hXA1WY2y+KzEhaa2ZOJ/k1P9kPUzPKBA919dpJt/c3slcTfzeTEEPSY2T5mNjXR318mOyKz+Mx0z5vZbDObZ2YXJtpHJv4eZpvZNDPLT7zPm4m/w5nJjoLMLGpmt9f7rC6vt/kZINT5OKSFhH37uB6Z/wC2JP48GRgPGPFfTp4Djkls65z4sw0wD+iSWHfggnqvtQT4bmL528D9ieVLgHsSyw8C/0q8xzDi4+UDnEd8GOoI0IP43BLnJan3NeC+euud+HRyrW8AdySWbwR+WG+/R4GjEst9gYVJXvs44Ml66/Xrfha4OLF8KfBMYvk5YGxi+Yodf58NXvdc4C/11guAHKAMGJlo60B8BOe2QF6ibTBQkljuD8xLLI8Dbkgs5wIlwIDEem9
gbtj/X+kR/EPDfUtLOjnxeDex3p74D6g3gKvM7MuJ9j6J9rXEJzx6ssHr7Bj2ewbx+QeSecbj8ysssPgsdABHAf9KtK8ys1d3U+s/6y0XAf80s57Ef+gu3sVzTgSGxeeYAaCDmbV39/q/8fcEynfx/MPr9edh4Df12r+UWH4U+G2S584F7jCzXwPPufubZnYAsNLdpwO4+yaIH20A95jZQcT/fockeb2TgQPrHWEVEP9MFgOrgV676INkEAWEtCQDbnX3P+/UGJ9o5kTgcHffZmavEZ+GE6DC3RvOileZ+LOWXf8/XFlv2Xaxz+5srbd8N/FpTickar1xF8+JAIe5e8VuXnc7n/at2bj7+2Z2CPGB235pZpOBp3ex+9XAJ8SnPI0Ayeo14kdqk5JsyyPeD8lwOgchLWkScKmZtQcws95m1o34b6frE+EwFDgsoPd/Czg3cS6iO/GTzI1RwKdj7F9cr30zkF9v/SXiM5ABkPgNvaGFwKBdvM8U4sM1Q/w7/jcTy1OJf4VEve07MbNewDZ3/z/gduAQYBHQ08xGJvbJT5x0LyB+ZBEDvkZ8/uKGJgHfMrPsxHOHJI48IH7EsdurnSQzKCCkxbj7S8S/InnbzOYCTxD/AfsikGVmC4nP/zs1oBKeJD5M8gLg/4CZwMZGPO9G4F9mNgNYU6/9WeDLO05SA1cBxYmTugtIMnOXu79HfPrK/IbbiIfL181sDvEf3N9LtH8fuCbRPmgXNR8ATDOzWcDPgV+6exVwIfHpZ2cD/yH+2/99wMWJtqHsfLS0w/3E/55mJi59/TOfHq0dBzyf5DmSYTTct7QqO84JWHxu3mnAke6+qoVruBrY7O73N3L/tsB2d3czG0P8hPXZgRa5+3reAM529/Vh1SAtQ+cgpLV5zsw6Ej/Z/IuWDoeEPwLnN2H/EcRPKhuwgfgVTqEws0Li52MUDq2AjiBERCQpnYMQEZGkFBAiIpKUAkJERJJSQIiISFIKCBERSUoBISIiSf0/HuRB1McbAw0AAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learner.lr_plot()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, we will fit our model using and [SGDR learning rate schedule](https://github.com/amaiya/ktrain/blob/master/example-02-tuning-learning-rates.ipynb) by invoking the ```fit``` method with the *cycle_len* parameter (along with the *cycle_mult* parameter)." ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train on 25000 samples, validate on 25000 samples\n", "Epoch 1/7\n", "25000/25000 [==============================] - 7s 263us/step - loss: 0.2105 - acc: 0.9461 - val_loss: 0.2481 - val_acc: 0.9187\n", "Epoch 2/7\n", "25000/25000 [==============================] - 7s 261us/step - loss: 0.0458 - acc: 0.9936 - val_loss: 0.2266 - val_acc: 0.9218\n", "Epoch 3/7\n", "25000/25000 [==============================] - 6s 257us/step - loss: 0.0082 - acc: 0.9999 - val_loss: 0.2236 - val_acc: 0.9228\n", "Epoch 4/7\n", "25000/25000 [==============================] - 6s 256us/step - loss: 0.0069 - acc: 0.9999 - val_loss: 0.2169 - val_acc: 0.9227\n", "Epoch 5/7\n", "25000/25000 [==============================] - 6s 259us/step - loss: 0.0029 - acc: 1.0000 - val_loss: 0.2148 - val_acc: 0.9227\n", "Epoch 6/7\n", "25000/25000 [==============================] - 7s 261us/step - loss: 0.0020 - acc: 1.0000 - val_loss: 0.2142 - val_acc: 0.9228\n", "Epoch 7/7\n", "25000/25000 [==============================] - 6s 255us/step - loss: 0.0017 - acc: 1.0000 - val_loss: 0.2141 - val_acc: 0.9227\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learner.fit(0.001, 3, cycle_len=1, cycle_mult=2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### As can be seen, our final model yields a validation accuracy of 92.27%.\n", "\n", "### Making 
Predictions\n", "\n", "Let's predict the sentiment of new movie reviews (or comments in this case) using our trained model.\n", "\n", "The ```preproc``` object (returned by ```texts_from_folder```) is important here, as it is used to preprocess data in a way our model expects." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "predictor = ktrain.get_predictor(learner.model, preproc)" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "data = [ 'This movie was horrible! The plot was boring. Acting was okay, though.',\n", " 'The film really sucked. I want my money back.',\n", " 'What a beautiful romantic comedy. 10/10 would see again!']" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['neg', 'neg', 'pos']" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predictor.predict(data)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As can be seen, our model returns predictions that appear to be correct. The predictor instance can also be used to return \"probabilities\" of our predictions with respect to each class. Let us first print the classes and their order. The class *pos* stands for positive sentiment and *neg* stands for negative sentiment. Then, we will re-run ```predictor.predict``` with *return_proba=True* to see the probabilities." 
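To make concrete why the same `preproc` object must be reused at inference time, here is a minimal, hypothetical sketch of the kind of vocabulary mapping such a preprocessor preserves (the function names are illustrative, not part of the ktrain API):

```python
# Hypothetical sketch: a model trained on word IDs only works if new text
# is encoded with the *same* vocabulary and padding used during training.
def build_vocab(texts):
    vocab = {}
    for t in texts:
        for w in t.lower().split():
            vocab.setdefault(w, len(vocab) + 1)  # ID 0 is reserved for padding/unknown
    return vocab

def encode(text, vocab, maxlen=5):
    ids = [vocab.get(w, 0) for w in text.lower().split()][:maxlen]
    return ids + [0] * (maxlen - len(ids))  # truncate to maxlen, then right-pad with 0

vocab = build_vocab(["I like this movie", "I hate this movie"])
print(encode("I like this movie", vocab))   # [1, 2, 3, 4, 0]
print(encode("I adore this movie", vocab))  # [1, 0, 3, 4, 0] -- unseen word maps to 0
```

Rebuilding the vocabulary from different data would silently scramble the word IDs, which is exactly what passing the saved `preproc` to `get_predictor` guards against.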
] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['neg', 'pos']" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predictor.get_classes()" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[0.81179327, 0.18820675],\n", " [0.7463994 , 0.25360066],\n", " [0.26558533, 0.7344147 ]], dtype=float32)" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predictor.predict(data, return_proba=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For text classifiers, there is also `predictor.predict_proba`, which is simply calls `predict` with `return_proba=True`.\n", "\n", "Our movie review sentiment predictor can be saved to disk and reloaded/re-used later as part of an application. This is illustrated below:" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "predictor.save('/tmp/my_moviereview_predictor')" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "predictor = ktrain.load_predictor('/tmp/my_moviereview_predictor')" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['pos']" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predictor.predict(['Groundhog Day is my favorite movie of all time!'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that both the `load_predictor` and `get_predictor` functions accept an optional `batch_size` argument that is set to 32 by default. The `batch_size` can also be set manually on the `Predictor` instance. 
That is, the `batch_size` used for inference and predictions can be increased with either of the following:\n", "```python\n", "# you can set the batch_size as an argument to load_predictor (or get_predictor)\n", "predictor = ktrain.load_predictor('/tmp/my_moviereview_predictor', batch_size=128)\n", "\n", "# you can also set the batch_size used for predictions this way\n", "predictor.batch_size = 128\n", "```\n", "Larger batch sizes can potentially speed up predictions when `predictor.predict` is supplied with a list of examples." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Multi-Label Text Classification: Identifying Toxic Online Comments\n", "\n", "In the previous example, the classes (or categories) were mutually exclusive. By contrast, in multi-label text classification, a document or text snippet can belong to multiple classes. Here, we will classify Wikipedia comments into one or more categories of so-called *toxic comments*. Categories of toxic online behavior include toxic, severe_toxic, obscene, threat, insult, and identity_hate. The dataset can be downloaded from the [Kaggle Toxic Comment Classification Challenge](https://www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge/data) as a CSV file (i.e., download the file ```train.csv```). We will load the data using the ```texts_from_csv``` function. This function expects one column to contain the texts of documents and one or more other columns to store the labels. Labels can be in any of the following formats:\n", "\n", "```\n", "1. one-hot-encoded arrays representing classes will have a single one in each row:\n", " Binary Classification (two classes):\n", " text|positive|negative\n", " I like this movie.|1|0\n", " I hated this movie.|0|1\n", " Multiclass Classification (more than two classes): \n", " text|negative|neutral|positive\n", " I hated this movie.|1|0|0 # negative\n", " I loved this movie.|0|0|1 # positive\n", " I saw the movie.|0|1|0 # neutral\n", "2. 
multi-hot-encoded arrays representing classes:\n", " Multi-label classification will have one or more ones in each row:\n", " text|politics|television|sports\n", " I will vote in 2020.|1|0|0 # politics\n", " I watched the debate on CNN.|1|1|0 # politics and television\n", " Did you watch the game on ESPN?|0|1|1 # sports and television\n", " I play basketball.|0|0|1 # sports \n", "3. labels are in a single column of string or integer values representing class labels\n", " Example with label_columns=['label'] and text_column='text':\n", " text|label\n", " I like this movie.|positive\n", " I hated this movie.|negative\n", "```\n", "\n", "Since the Toxic Comment Classification Challenge is a multi-label problem, we must use the second format, where labels are already multi-hot-encoded. Luckily, the `train.csv` file for this problem is already multi-hot-encoded, so no extra processing is required. \n", "\n", "Since `val_filepath` is `None`, 10% of the data will automatically be used as a validation set.\n" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Word Counts: 197340\n", "Nrows: 143613\n", "143613 train sequences\n", "Average train sequence length: 66\n", "15958 test sequences\n", "Average test sequence length: 66\n", "Pad sequences (samples x time)\n", "x_train shape: (143613,150)\n", "x_test shape: (15958,150)\n", "y_train shape: (143613,6)\n", "y_test shape: (15958,6)\n" ] } ], "source": [ "DATA_PATH = 'data/toxic-comments/train.csv'\n", "NUM_WORDS = 50000\n", "MAXLEN = 150\n", "trn, val, preproc = text.texts_from_csv(DATA_PATH,\n", " 'comment_text',\n", " label_columns = [\"toxic\", \"severe_toxic\", \"obscene\", \"threat\", \"insult\", \"identity_hate\"],\n", " val_filepath=None, # if None, 10% of data will be used for validation\n", " max_features=NUM_WORDS, maxlen=MAXLEN,\n", " ngram_range=1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, as before, 
we load a text classification model and wrap the model and data in a `Learner` object. Instead of using the NBSVM model, we will explicitly request a different model called ```fasttext``` using the ```name``` parameter of ```text_classifier```. The fastText architecture was created by [Facebook](https://arxiv.org/abs/1607.01759) in 2016. (You can call ```print_text_classifiers``` to show the available text classification models.) " ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "nbsvm: NBSVM model (http://www.aclweb.org/anthology/P12-2018)\n", "fasttext: a fastText-like model (http://arxiv.org/pdf/1607.01759.pdf)\n", "logreg: logistic regression\n" ] } ], "source": [ "text.print_text_classifiers()" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Is Multi-Label? True\n", "compiling word ID features...\n", "done.\n" ] } ], "source": [ "model = text.text_classifier('fasttext', trn, preproc=preproc)\n", "learner = ktrain.get_learner(model, train_data=trn, val_data=val)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As before, we use our learning rate finder to find a good learning rate. In this case, a learning rate of 0.0007 appears to be good." ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "simulating training for different learning rates... 
this may take a few moments...\n", "Epoch 1/5\n", "143613/143613 [==============================] - 47s 325us/step - loss: 0.7361 - acc: 0.5322\n", "Epoch 2/5\n", "143613/143613 [==============================] - 46s 323us/step - loss: 0.4683 - acc: 0.7714\n", "Epoch 3/5\n", "143613/143613 [==============================] - 46s 323us/step - loss: 0.0879 - acc: 0.9729\n", "Epoch 4/5\n", "143613/143613 [==============================] - 46s 323us/step - loss: 0.1106 - acc: 0.9686\n", "Epoch 5/5\n", "143613/143613 [==============================] - 46s 323us/step - loss: 0.1636 - acc: 0.9629\n", "\n", "\n", "done.\n", "Please invoke the Learner.lr_plot() method to visually inspect the loss plot to help identify the maximal learning rate associated with falling loss.\n" ] } ], "source": [ "learner.lr_find()" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAEOCAYAAABmVAtTAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvqOYd8AAAIABJREFUeJzt3Xl4VOX5//H3HVYRRITgwg5iERdQIqKiYkVEqGLFKrZaba1b3dpaf+Jad3Br1bpr+22rVUrdFQQVF6ioEDaR1bCobBL2nZDk/v0xk8kkTFZycmYmn9d1zZVznnPOzCcDmXvO9jzm7oiIiABkhB1ARESSh4qCiIjEqCiIiEiMioKIiMSoKIiISIyKgoiIxKgoiIhIjIqCiIjEqCiIiEiMioKIiMTUDztAVbVq1co7duwYdgwRkZQybdq0Ne6eWdF6KVcUOnbsSHZ2dtgxRERSipl9W5n1dPhIRERiVBRERCRGRUFERGJUFEREJEZFQUREYlQUREQkps4WBXdnzoqNYccQEUkqKXefQk3YujOf5yct5tEPvwFg6cjBIScSEUkOdW5PYf3WPA770/hYQQDoOHwMM75bT8fhYzj8T+OZsmQdKzZsB2DTjl3krN7Cjl0FsfUX5W7hzx8sjM27O2/NXM7gxyfRcfgYJsz7ofZ+IRGRGmTuHnaGKsnKyvLq3tE86ZtcLvrblEqv/8Ivs/jNv4pfa9gx7Xht+jJ2FVTuPXtsWE9O6prJOU9PZsmarVzQuz0jzjmiyrlFRPaUmU1z96wK16tLRaHj8DEl5p/+xdFc9e/pNRGrSqbceiq975sQm7/x9B/xxozl3D3kMI7v0qrW84hI+qtsUagzh4++WrYhNt2rQwvm3zOQM444kKUjBzPhhpMB6NG2OS9fdmy1X+OkQzLJ6tCiwvXiCwLAQ+MXkLN6Cz9//kt+889s8vILY8vmrdxEQWFqFW4RSV11Zk8hfi+hMieWX5nyHf+cvJQnfn4UB7duhrvz8PsLWLFhB786oSMtmjTk2YmLeOmL71h8/yAyMgyAbXn5zFu5GXCGPv
15lXMWObPHQbwzawUAV5/ShRtP71bt5xIR0eGjUoqKQqumjci+rX+NZHF3Cgqd+vUqv8O1dWc+R939AScd0ooWTRrSv/v+XPHitAq3e+nSY5m1bAOLVm/hz+f33JPYIlIHqSiUMm/lJs54bBKL7h9Evei3+mTx+IRv2LIzn5nfbWDK0nWV3q5BPWNXgeuSWhGpUGWLQp25T+HQA/dJ2g/P607tWmK+sNDpfMtYAF6+7Fh+/vyXCbcrugrqq2UbOKJNc8ySq9iJSOqpM3sKqWzKknWc92zF5yfm3zOQxg3q1UIiEUk1OnyUpsbPWcUVL07jipM6g8Gzny5OuF6y7hWJSDhUFOqIhT9sZnLOGu58Z26J9tMP259nLuylQ0oiAiTJfQpmNtDMFphZjpkNT7D8L2Y2M/pYaGYbEj2PlO2Q/ZtxyQmddmsfP+cHOt08lhtGzwohlYikqsCKgpnVA54EzgC6AxeYWff4ddz99+7e0917An8FXg8qT7r71697c15WW969tm+J9temL6Pj8DFs2JYXUjIRSSVB7in0BnLcfbG75wGjgCHlrH8B8EqAedLaSYdk8uC5PTi8TXO+vOXU3Zb3vPsDfvbMZN0dLSLlCrIotAG+j5tfFm3bjZl1ADoBHwWYp87Yf5/GfPzHflzQux3Hd2kZa5+6dD1dbhlbosdXEZF4ydL30TDgVXdP+GllZpebWbaZZefm5tZytNTUqdXejDjnSF6+rA8nHNyyxLLLX5xGfkFhGVuKSF0WZFFYDrSLm28bbUtkGOUcOnL359w9y92zMjMzazBi3fDv3/ThgaHFXXZPXJjLwbe+F2IiEUlWQRaFqUBXM+tkZg2JfPC/XXolM+sGtACq33ucVOj8Y9rvdu9Cx+FjSvTIKiISWFFw93zgGmA8MA8Y7e5zzOxuMzsrbtVhwChPtRsmUtRrVx1fYv6wP40LKYmIJCPdvFYHxfetBJER4ob0THgNgIikiaS4eU2SU0aGMeuOAbH560fN3G1UOhGpm1QU6qjmTRokPMcgInWbikIdp8IgIvFUFEQ9qopIjIqCACULw8IfNoeYRETCpKIgMTcN7AbAgL9M5IVJicdpEJH0pqIgMVf16xKbvnfMPOau2BRiGhEJg4qClBB/c9ugxyexetOOENOISG1TUZASenVowcQbT4nND3h0YohpRKS2qSjIbtq3bML8ewYCsGHbLr5fty3kRCJSW1QUJKHGDerFpk988OMQk4hIbVJRkDJ9cfPuI7iJSHpTUZAyHdC8cWy64/Ax7NLAPCJpT0VByjXq8j6x6a4amEck7akoSLn6dC45lKfGdxZJbyoKUqH4LjBuGD2LVBuDQ0QqT0VBKuXzm38MwJjZK+l081gVBpE0paIglXJg871KzHe6eWwZa4pIKlNRkEpbOnIwe8XdvyAi6UdFQapkXvROZ0B3OoukoUCLgpkNNLMFZpZjZsPLWOc8M5trZnPM7OUg80jNuuWN2WFHEJEaFlhRMLN6wJPAGUB34AIz615qna7AzcAJ7n4Y8Lug8kjNmXPX6QBM+mYNm3bsCjmNiNSkIPcUegM57r7Y3fOAUcCQUutcBjzp7usB3H11gHmkhuzdqH5s+l+Tl4YXRERqXJBFoQ3wfdz8smhbvEOAQ8zsMzP7wswGIimhqHvth99fGHISEalJYZ9org90BfoBFwDPm9m+pVcys8vNLNvMsnNzc2s5oiTSvmWT2PS0b9eHmEREalKQRWE50C5uvm20Ld4y4G133+XuS4CFRIpECe7+nLtnuXtWZmZmYIGleoY+PTnsCCJSQ4IsClOBrmbWycwaAsOAt0ut8yaRvQTMrBWRw0kaMT5FLL5/UGw6e+m6EJOISE0JrCi4ez5wDTAemAeMdvc5Zna3mZ0VXW08sNbM5gIfAze6+9qgMknNysiw2PS5z3weYhIRqSn1K16l+tx9LDC2VNsdcdMO/CH6kBT02LCeXD9qJgBbd+aXuDJJRFJP2CeaJc
UN6dmG3/c/BIDlG7aHnEZE9pSKguyxXh1aADDgLxNDTiIie0pFQfZYz/bFVxGv3Ki9BZFUpqIge6xpo/qcnxW5+vjal2eEnEZE9oSKgtSI+356OADZ367nxc+XhppFRKpPRUFqRP16xf+Vbn9rTohJRGRPqChIjXn9t8cDcG6vtiEnEZHqUlGQGnN0+xb0aLcvP2zaEXYUEakm3WkkNWp7Xj6zvt9AXn4hDevrO4dIqtFfrdSob9dGhug85Lb3Qk4iItWhoiA1asINJ8emJ+esCTGJiFSHioLUqLYtisdZ+PkLX4aYRESqQ0VBalz2bf0BOLVb65CTiEhVqShIjWvVtBEAE+av5s/vLwg5jYhUhYqCBCond0vYEUSkClQUJBBLRw4GYOzsVSEnEZGqUFGQwOUXFIYdQUQqSUVBArfgh81hRxCRSlJRkMB88sd+ANwwela4QUSk0gItCmY20MwWmFmOmQ1PsPwSM8s1s5nRx2+CzCO1q0PLyD0L81dt5vpRGmdBJBUEVhTMrB7wJHAG0B24wMy6J1j1P+7eM/p4Iag8UvvMLDb91swVISYRkcoKck+hN5Dj7ovdPQ8YBQwJ8PUkCS0ZMSg2vWNXQYhJRKQygiwKbYDv4+aXRdtKG2pmX5nZq2bWLsA8EgIzo1XThgCMzv6+grVFJGxhn2h+B+jo7kcCHwD/TLSSmV1uZtlmlp2bm1urAWXPjb7iOACWrd8echIRqUiQRWE5EP/Nv220Lcbd17r7zujsC0CvRE/k7s+5e5a7Z2VmZgYSVoJT1EnecxMXh5xERCoSZFGYCnQ1s05m1hAYBrwdv4KZHRg3exYwL8A8EpKG9TNoVD8jdjWSiCSvwIqCu+cD1wDjiXzYj3b3OWZ2t5mdFV3tOjObY2azgOuAS4LKI+E6/5h2rN+aF3YMEalAoMNxuvtYYGyptjvipm8Gbg4ygySHfZs0ZNOOfNZs2RnrRVVEkk/YJ5qljtgV7f9IN7GJJDcVBakVvzqhIwCf5awNN4iIlEtFQWpF62aNY9Pb8vJDTCIi5VFRkFrz235dAPhw3uqQk4hIWVQUpNYM7dUWgOte0XkFkWSloiC1pktm07AjiEgFVBQkFE9+nBN2BBFJQEVBalWPts0BeGj8gpCTiEgiKgpSq966pm/YEUSkHCoKUut+eVwHmjaqT2Ghhx1FREpRUZBad+iB+7BlZz7LN6grbZFko6Igte6Q/SNXIZ344MchJxGR0lQUpNYd3qZ52BFEpAwqClLrGtWvF5tWlxciyaVSRcHMrjezfSzib2Y23cwGBB1O0tdJh0RG0JusDvJEkkpl9xR+7e6bgAFAC+AiYGRgqSTtjTjnCABen7Es5CQiEq+yRcGiPwcBL7r7nLg2kSo7qHmk19Sxs1eFnERE4lW2KEwzs/eJFIXxZtYMKAwulqQ7s+LvFPkF+q8kkiwqWxQuBYYDx7j7NqAB8KvAUkmd8MvjOgDw/KQlIScRkSKVLQrHAQvcfYOZXQjcBmwMLpbUBWcf1QaAB8bNDzmJiBSpbFF4GthmZj2AG4BFwL8q2sjMBprZAjPLMbPh5aw31MzczLIqmUfSwNHtW4QdQURKqWxRyHd3B4YAT7j7k0Cz8jYws3rAk8AZQHfgAjPrnmC9ZsD1wJdVCS7p5dOFuWFHEBEqXxQ2m9nNRC5FHWNmGUTOK5SnN5Dj7ovdPQ8YRaSolHYP8ACwo5JZJI0UdXlx8d+nhJxERKDyReF8YCeR+xVWAW2BhyrYpg3wfdz8smhbjJkdDbRz9zGVzCFp5m11pS2SVCpVFKKF4N9AczP7CbDD3Ss8p1Ce6N7Gn4mco6ho3cvNLNvMsnNzdZghnTRuUNzlxaqN2lkUCVtlu7k4D5gC/Aw4D/jSzM6tYLPlQLu4+bbRtiLNgMOBT8xsKdAHeDvRyWZ3f87ds9w9KzMzszKRJYVkRG9ZeGXKd+EGEZFKHz66lcg9Che7+y+JnC+4vYJtpgJdzayTmT
UEhgFvFy10943u3srdO7p7R+AL4Cx3z67ybyEp7ZkLewHw2IRvQk4iIpUtChnuvjpufm1F27p7PnANMB6YB4x29zlmdreZnVWttJKWBhx2AABDeh4UchIRqV/J9caZ2Xjglej8+cDYijZy97Gl13P3O8pYt18ls0gaqp9hvDVzBY8NOyrsKCJ1WqWKgrvfaGZDgROiTc+5+xvBxZK6Jl/jNYskhcruKeDurwGvBZhF6rAG9YxdBc4zny7iypO7hB1HpM4q97yAmW02s00JHpvNbFNthZT0d9dZhwPw5ozlFawpIkGq6GRxM3ffJ8GjmbvvU1shJf0NOyZy9XLTRpXeeRWRAGiMZkkKGdGbFbK/XR9yEpG6TUVBRERiVBQk6Tyo8RVEQqOiIEnjtsGHAvDUJ4tCTiJSd6koSNL4zYmdAWjVtGHISUTqLhUFSTprtuSFHUGkzlJRkKSiS1JFwqWiIEnlsIMit7+c/+znIScRqZtUFCSpzF+1GYAvl6xj6878kNOI1D0qCpJUZt5xWmx6dPb35awpIkFQUZCkYmZc0Ls9AHe9MzfkNCJ1j4qCJJ0R5xwRdgSROktFQUREYlQUJCld++ODAdiik80itUpFQZJSy70jdzW/9MW3IScRqVsCLQpmNtDMFphZjpkNT7D8SjObbWYzzex/ZtY9yDySOk477AAARr6nzvFEalNgRcHM6gFPAmcA3YELEnzov+zuR7h7T+BB4M9B5ZHU0mbfvcKOIFInBbmn0BvIcffF7p4HjAKGxK/g7vFDeu4NaPR22c2aLTvDjiBSZwRZFNoA8XcfLYu2lWBmV5vZIiJ7CtcFmEdSTJfMvQEYNeW7kJOI1B2hn2h29yfdvQtwE3BbonXM7HIzyzaz7Nzc3NoNKKEZfkZkfIWN23eFnESk7giyKCwH2sXNt422lWUUcHaiBe7+nLtnuXtWZmZmDUaUZNb34FYA7NtE4yuI1JYgi8JUoKuZdTKzhsAw4O34Fcysa9zsYOCbAPNIitmrYT0AHhq/IOQkInVHYEXB3fOBa4DxwDxgtLvPMbO7zeys6GrXmNkcM5sJ/AG4OKg8kto+WbA67AgidUKgI5q4+1hgbKm2O+Kmrw/y9SV9XPJ/U1k6cnDYMUTSXugnmkXK87+bTgk7gkidoqIgSa1tiyax6bW6X0EkcCoKkvTO7HEQoM7xRGqDioIkvTOPPBCAOSs2VbCmiOwpFQVJes0aNwDgt/+eHnISkfSnoiBJr0/n/cKOIFJnqChI0jOz2PSCVZtDTCKS/lQUJCX8pm8nADZsyws5iUh6U1GQlDC0V1sAzn/ui5CTiKQ3FQVJCfvv0zg2Pe3bdSEmEUlvKgqSEvbbu7in1KFPfx5iEpH0pqIgKeO1q44HoGF9/bcVCYr+uiRl9OrQAoC8/MKQk4ikLxUFSUkFhRrOWyQIKgqSkp6duCjsCCJpSUVBUsofBxwCwIPjNBqbSBBUFCSlnNb9gNj0fWPmhphEpPbk5Rdy7SszWLJma+CvpaIgKeVHBzSLTT8/aUmISURqT/a363hn1gpufv2rwF9LRUFSztRb+8emN+/YFWISkdrRtFFk5OSj27cI/LVUFCTltGpafCPbEXe+H2ISkeBt3LaLN2esAKBnu30Df71Ai4KZDTSzBWaWY2bDEyz/g5nNNbOvzGyCmXUIMo+kBzPjT2d2j81rmE5JF1OXrmPhDyV7Ar7mlen8/bPIodIG9YL/Hh/YK5hZPeBJ4AygO3CBmXUvtdoMIMvdjwReBR4MKo+kl4uP6xibfvh9XYkk6eFnz3zOgL9MjM1P/249k75ZE5uvX88SbVajgiw7vYEcd1/s7nnAKGBI/Aru/rG7b4vOfgG0DTCPpJGMjOI/jqlL14eYRKTm7dhVwJwVGznnqckl2ldvCn6vOMii0Ab4Pm5+WbStLJcC7wWYR9LMlFtOBSBn9ZaQk4jUrNzNO7k6wfCzKzZsD/y1k+JEs5ldCGQBD5Wx/HIzyz
az7Nzc3NoNJ0mrdVx32q9M+S7EJCJ7Zub3G3hh0uLY/NyVm1i6dttu6/XqmNpXHy0H2sXNt422lWBm/YFbgbPcPeG+kbs/5+5Z7p6VmZkZSFhJbTe/PjvsCCLV8unCXM5+8jPuHTMv1nbFi9PokeBKo4MzmwaeJ8iiMBXoamadzKwhMAx4O34FMzsKeJZIQVgdYBZJUwe3Lv4jyd2sq5AkPO7OYx9+w7L1u3/DL8+3axPfpXzYQfvs1ha/dxyUwIqCu+cD1wDjgXnAaHefY2Z3m9lZ0dUeApoC/zWzmWb2dhlPJ5LQ+N+dFJs+5r4PQ0widd3Stdv4y4cL6fvAx1XarujGtNJe/jKcQ6KJ09QQdx8LjC3VdkfcdP/dNhKpgnoZxu0/6c4976ofJElN3yTZhRJJcaJZZE9c2rdTbPqtmbudthKpFTt2FQDQrIxv/mXZd68Gu7U1aVhvt7aiHoKDpqIgaeX6UTPDjiB1wORFa+g4fAzzV22KtW2PFoXNO/Or9FwH7rtXifkTDm5J+/2alGjrf+j+XPPjrtVMWzUqCpIWFt8/KDZdqFHZJGDjv14FwP/i7ja+653iQ5hrSnW9snH7LjoOH0PH4WOY+f0GOg4fE9ur3VVqeNnWzRozf1XJri7OObq8W7xqloqCpIX4O5w73zK2nDVF9oy788/PvwXg3jHzKCx0duwqYNb3G2LrLCj1od7jruKOG89+8jOgeK/29re+BuCKkzsz7bb+jJm9ssS2L17amzMOP4DaoqIgaaN3p/3CjiBpKL+gkDve+prv10UuNf3l36eUWN7jrvfpdvu4Em1FndotWLWZ2cs2lvncD4ybz7a8yGGny07sTMumjcgrtedwYtdMzILv86iIioKkjb+c3zM2PearlWzSWAspYceuAtZvzavWtpO+yeXj+YlvcXJ3Hhw3n5Ubd+8aYv3WPL5LcMdwvB827eC7tdv4cN5q/vX5t9zw31nMWbGxRAd1fQ9ulfAcwl3vzCW/oJDTH53ImU/8r8zXePqT4rHGi3pA7dWh+K7lZy7sVW7GIKgoSNpoE3fC7uqXp3OkxlpIeuu25tHt9nEcdc8HACxbv42P5v9QYp2CQmfj9uICv2bLTqZ9G+kE8aK/TeFX/5haojB8u3YrO3YV8PnitTz1ySKOG/HRbq971D0fcNJDu99P0P/Pn/LqtGUAHHv/BE566GOufGkaAJu272Lw4yU/4Bs32P0qoSJ3vjOn3N+9tAbRHlCvOKlzrG1gLR42KqKiIGmtaBznrTvzY5cMSvI4OloMALbszKfvAx/z639kl/i3uvHVWfS46/3YBQRXvjiNoU9PZmd+8Tq/+sdU8vILKSx0Tn7oE7rdPo4J84oLxba8fDbtiJzsHRt3zH5XQfGhmvFzVpGzegt//O+shFnnr9rMwMNKfkhv2l5yb/T4Li1j0y99UbWbzzKih4hO674/Z/c8iDevPqFK29cUc0+tKzWysrI8Ozs77BiSpH7YtINj759Q5vKlIwfXYhpJ5IFx8+l2QDOG9GxDx+FjYu1n9jiId2ZFRhi7tG8nbhl0KMfc9yHrooeWTuu+P51b7c2zExcnfF6Ad6/ty0/+Gvk2f+2PD+avH+WUm+VnvdrywNAjycgwLnzhS/6XEzk0dOXJXXjm00XlbpvIPo3rszWvgIIEV8A1rJ+x2/mCeEtGDAr03IGZTXP3rIrW056CpJX9K+gbZtzXK8tdLsF7+pNFXD9qJuc983mJ9qKCANCj3b7c9ubXsYIA8MHcH8otCECsIAAVFgSA/05bxruzV/Lsp4tiBQGosCBcdmKnhO23De7O1accnHDZzDtOi01/dMPJseme7fZl6cjBtXoyuTwqCpJ2HhvWk8Pb7N6ZGMCVL01nxNh5CZdJ8OI7LZyydF2Z6xUUFtZad+jXvTKDEe/Nr9I2tw4uOYjkxcd14OyeB/HTo9vw+/6JbzJr0rA+T/3iaF676jg6x/V2WlYRCYsOH0na6nP/BF
Zt2lHm8sX3Dypxf4PUrMJCZ/XmnRzQPLL3tjh3Cz9+5NNKbXt8l5ZMXrQ2yHhVclr3/flgbvEJ8KUjB5c49JVz3xnUjxs/eceuAj6ev5p6GUZms0Zs3L6Lfj9qXeI5Zy/byJ3vzGHU5X1qZexlHT6SOm/EOUfEphsm+KPrfMvYEle1SM265Y3Z9BkxgaVrIl1DV7YgABUWhD6d9+P/LjmGsdedyKL7B3HTwG6Vfu57zz680utC5AP/+V9m7Ta+QdHIf0CJggCRq5LOOOJABhx2AEe1b7FbQQA4om1zXrvq+FopCFWRXGlEatAp3Voz8pwjmH77aSy87wyaJ+h4rMdd77N6c9l7E1I9a7fsZNTUyGi8/R7+pMz1urZuytd3nc7dQw4D4OGf9ajU84+6/DhO6daa7gftQ70M46p+XTg/q91u6z049EguPq5DbP7Svp24sE8H3r22L1/dOaDEur077ceFfdqXaBt73YmxD/yHzz2yxLLW+zTml8dFniud6PCR1ClvzljO7/6ze6d5uiqpZj35cQ4PjV8Qm5944ykl7guY9P9OYb+9G7J3gh5F4w/LlHZi11bcePqPOLLt7qOSld72kuM7cudZh5VoL33I8PY3v+bFLyJdViwdOZhJ3+Ry0d+m8MyFR5O7JY+L+hQXFHfn4fcXcF5WOzq03Lvc3z8ZVfbwUaDjKYgkm7OPasP6bXklOi87oBZGs6pr4gsCsNuNYu1K9QJalr9fksXoqcsYNyfSAd2Llx5b7vpTbjmVZo0bsFeprqe7HdCM+as273YO6bafHMp367bx+AVHAZEuJebdPXC37QHMjBtPr/xhqlSlPQWpk0595BMW5RYPgzj06LY8cl7lDl1Ixcr7tl/RXpm78/fPlnJRnw40rB85dPPWzOV0bd2M7gmGqKyMjdt3sWLDdg49sHrbpwOdaBYpx4Qb+vGH04oHLXlt+jJS7QtSslmxYTuDH59U4k7jFk1Knse5ql+XCp/HzLi0b6dYQQAY0rNNtQsCQPO9GtTpglAVKgpSZ113asnryYtOjGo8hqorLHSOH/kRc1Zs4up/z2DwEQcCcHT7FiXWu3HAj8KIJ1UQaFEws4FmtsDMcsxseILlJ5nZdDPLN7Nzg8wiksj8ewbGpm9+fTaPfriQzreM5evlZXd3LLs7fmRxp3MN6hmNGmTQZt+96Nu1Vaz9lkHddF9ICgjsRLOZ1QOeBE4DlgFTzextd48fYf074BLgj0HlEClP6V4uH/3wGyDSXcLSkYPZsaug3J4wJSL+JsH3oqOStdy7Ib84tgPb8gq4tG8nvY8pIsg9hd5Ajrsvdvc8YBQwJH4Fd1/q7l8BZfcSJRKwsk58jp29km63j+ObHzYnXC4RZR1uW7s1j4b1M7j6lINVEFJIkEWhDfB93PyyaJtISvjtv6cDcNpfJoacJLmM+3olOasjhfLdr1Zo+NM0kxInms3scjPLNrPs3NzcsONIGnr5smP5Td9OXPfjxJ2T/T7BDW91kbtz5UvT6f/nSKG85uUZsWW/OLZ9pa4ukuQWZFFYDsTfd9422lZl7v6cu2e5e1ZmZmaNhBOJd3yXVtz2k+78YcCP+OSP/XZb/saM5RQWOtlL1zHju/W1H7AWFRZ6mZfn/iV6zgVg4KMl96AKHW4a2C12Ger/bjoluJASmCDvaJ4KdDWzTkSKwTDg5wG+nkiN6NgqcRcG8YdJ3rv+RDq0bEKThpE/ocJCT4sra3I37+SY+z4EIudapixZx31j5vL3S46hoNB5fEJxUZi/quS5lu/WRW4GnHFHyT6FJLUEtqfg7vnANcB4YB4w2t3nmNndZnYWgJkdY2bLgJ8Bz5pZ1QY1FQnIU784GoDs2/onXH7GY5Pofsd4znv2c578OIfOt4xlcs4aOg4fw+DHJ5X5vOO+XkXH4WNC7521sNATniAuPYbBec9+zqxlG+l174f0LmdEO4iMViapT91ciFSgvC4byjPwsAP4dd9O9O60H7
sKClm1cQcnPhjpA+iiPh24J9qFs7tT6FCvmnsac1e7AO6uAAALjElEQVRs4tY3Z/PGbyNj+q7etAMMWjdrzHMTF3H/2MgAMo/8rAdDe7Ut8TstuHcgjerXI7+gkEc//IYnPq54tLJE3v/9SRyyf7NqbSu1Qx3iidSQr+4cwJF3vl/l7cbNWcW4OatYOnIwXW99r8Syti32YuvOfA770/hYW0V9Aq3bmseM79Zz6qH7l2gfFN0zeWHSYu4dUzyq3NKRg2MFAeCG/85iaK+2fBg3WMyPbhvHG789np8+NbnSv9fUW/uT2awR170yg4NbN+UXx7anZdNGld5ekpv2FEQqYdI3ubw2bRlvzlxR8cqlvHn1CZz95GcVrvenM7vTtXUz3pixnE6tmvDw+wuZfecARmcv4553i+/5nHjjKbRvWdzLaHX3ZKpL3YynpsruKagoiFRR3wc+Ytn67SwZMYhON0dOPt80sBsPjJtPj3b78splx9L9jvEVPMueWXz/IIr+crvU4H0C++3dkI//2I95Kzcx7LkvEq6jopCaVBREQjTt23UMffrzsGMAcP2pXXks7qqhIo0bZLBjV8nOBMr7wP9+3TZaNm0Yu+JKUou6zhYJUa8O+5UYBrIsz13Uq8Zf+4mfH8XZPQ+KzV/z44NZeO8ZLBkxiH/86phY+/9u+jFLRgyq9PO226+JCkIdoD0FkYAsWLWZ0x+dyLMX9eLA5o05/KDmZGQYm3fs4v6x87n/p4ezNa+A4+6fwLGdW/LhvMgJ4KwOLcj+dj0t927IhX06cHyXlpxfxqGcIhNuOJlTH/kUKP62n5dfyIZtebSuYGS5XQWFHHPfhzw27ChOPkQ3h6YrHT4SSTGbd+yicYN6NKi3+w78Zzlr+MULX7JkxCB+2LSTxz/6hgyDl774jteuOp5eHVokeEaRYioKIiISo3MKIiJSZSoKIiISo6IgIiIxKgoiIhKjoiAiIjEqCiIiEqOiICIiMSoKIiISk3I3r5nZRqCod6/mwMa4xUXzZf1sBaypwsuVfv6KlsW3VTStbBXPK1vZ7cpWfi5l210Hd6+4HxN3T6kH8Fyi6fj5cn5mV/e1KrOsrGyJppWt4nllq9y/Z13Otid/B3U5W3mPVDx89E4Z0/HzZf3ck9eqzLKysiWaVraK55Wt7HZlq3gbZauGlDt8tCfMLNsr0fdHGJStepStepSteupCtlTcU9gTz4UdoBzKVj3KVj3KVj1pn61O7SmIiEj56tqegoiIlENFQUREYlQUREQkRkUhysxONLNnzOwFM5scdp54ZpZhZveZ2V/N7OKw88Qzs35mNin63vULO09pZra3mWWb2U/CzhLPzA6NvmevmtlVYeeJZ2Znm9nzZvYfMxsQdp54ZtbZzP5mZq8mQZa9zeyf0ffqF2Hnibcn71NaFAUz+7uZrTazr0u1DzSzBWaWY2bDy3sOd5/k7lcC7wL/TKZswBCgLbALWJZk2RzYAjROwmwANwGjaypXTWVz93nR/2/nASckWbY33f0y4Erg/CTLttjdL62pTHuY8Rzg1eh7dVZQmaqTbY/ep5q4Ay7sB3AScDTwdVxbPWAR0BloCMwCugNHEPngj3+0jttuNNAsmbIBw4Erotu+mmTZMqLb7Q/8O8mynQYMAy4BfpJM2aLbnAW8B/w82bJFt3sEODpJs9XY38EeZLwZ6Bld5+Ug8lQ32568T/VJA+4+0cw6lmruDeS4+2IAMxsFDHH3EUDCQwlm1h7Y6O6bkymbmS0D8qKzBcmULc56oFEyZYseztqbyB/wdjMb6+6FyZAt+jxvA2+b2Rjg5T3NVVPZzMyAkcB77j69JnLVVLagVSUjkT3jtsBMauGoSxWzza3u66TF4aMytAG+j5tfFm0rz6XA/wWWqFhVs70OnG5mfwUmBhmMKmYzs3PM7FngReCJZMrm7re6+++IfOA+XxMFoaayRc/FPB5978YGmKvK2YBrgf7AuWZ2ZZDBqPr71t
LMngGOMrObA85WpKyMrwNDzexpaririSpImG1P3qe02FOoKe7+p7AzJOLu24gUrKTj7q8T+eNIWu7+j7AzlObunwCfhBwjIXd/HHg87ByJuPtaIuc6QufuW4FfhZ0jkT15n9J5T2E50C5uvm20LRkoW/UoW/Uo255J5ow1ni2di8JUoKuZdTKzhkROOL4dcqYiylY9ylY9yrZnkjljzWcL+ox5bTyAV4CVFF+yeWm0fRCwkMjZ+VuVTdmUTdlSNWNtZVOHeCIiEpPOh49ERKSKVBRERCRGRUFERGJUFEREJEZFQUREYlQUREQkRkVBAmdmW2rhNc6qZFfaNfma/czs+Gpsd5SZ/S06fYmZBd1nVKWYWcfS3TInWCfTzMbVViapfSoKkjLMrF5Zy9z9bXcfGcBrltc/WD+gykUBuIUk7VuoIu6eC6w0sxobA0KSi4qC1Cozu9HMpprZV2Z2V1z7m2Y2zczmmNnlce1bzOwRM5sFHGdmS83sLjObbmazzaxbdL3YN24z+0e0B9LJZrbYzM6NtmeY2VNmNt/MPjCzsUXLSmX8xMweNbNs4HozO9PMvjSzGWb2oZntH+3C+Erg92Y20yIj92Wa2WvR329qog9OM2sGHOnusxIs62hmH0XfmwkW6codM+tiZl9Ef997E+15WWQUsDFmNsvMvjaz86Ptx0Tfh1lmNsXMmkVfZ1L0PZyeaG/HzOqZ2UNx/1ZXxC1+E0iqkcakBoV5S7kedeMBbIn+HAA8BxiRLyTvAidFl+0X/bkX8DXQMjrvwHlxz7UUuDY6/Vvghej0JcAT0el/AP+NvkZ3Iv3NA5xLpJvqDOAAImNAnJsg7yfAU3HzLSB29/9vgEei03cCf4xb72Wgb3S6PTAvwXOfArwWNx+f+x3g4uj0r4E3o9PvAhdEp68sej9LPe9QIt2DF803JzLoymLgmGjbPkR6Rm4CNI62dQWyo9MdiQ7gAlwO3BadbgRkA52i822A2WH/v9IjmIe6zpbaNCD6mBGdb0rkQ2kicJ2Z/TTa3i7avpbIoEKvlXqeoq66pxEZEjGRNz0yfsJcM9s/2tYX+G+0fZWZfVxO1v/ETbcF/mNmBxL5oF1Sxjb9ge5mVjS/j5k1dff4b/YHArllbH9c3O/zIvBgXPvZ0emXgYcTbDsbeMTMHgDedfdJZnYEsNLdpwK4+yaI7FUAT5hZTyLv7yEJnm8AcGTcnlRzIv8mS4DVwEFl/A6S4lQUpDYZMMLdny3RGBkhrT9wnLtvM7NPiIz5DLDD3UuPNrcz+rOAsv8P74ybtjLWKc/WuOm/An9297ejWe8sY5sMoI+77yjnebdT/LvVGHdfaGZHE+kc7V4zmwC8Ucbqvwd+AHoQyZworxHZIxufYFljIr+HpCGdU5DaNB74tZk1BTCzNmbWmsi30PXRgtAN6BPQ639GZKSsjOjeQ79Kbtec4j7qL45r3ww0i5t/n8ioZQBEv4mXNg84uIzXmUyk62OIHLOfFJ3+gsjhIeKWl2BmBwHb3P0l4CEiY/kuAA40s2Oi6zSLnjhvTmQPohC4iMg4v6WNB64yswbRbQ+J7mFAZM+i3KuUJHWpKEitcff3iRz++NzMZgOvEvlQHQfUN7N5RMYG/iKgCK8R6XJ4LvASMB3YWInt7gT+a2bTgDVx7e8APy060QxcB2RFT8zOJcHIV+4+H2gePeFc2rXAr8zsKyIf1tdH238H/CHafnAZmY8AppjZTOBPwL3ungecD/w1eqL+AyLf8p8CLo62daPkXlGRF4i8T9Ojl6k+S/Fe2SnAmATbSBpQ19lSpxQd4zezlsAU4AR3X1XLGX4PbHb3Fyq5fhNgu7u7mQ0jctJ5SKAhy88zERji7uvDyiDB0TkFqWveNbN9iZwwvqe2C0LU08DPqrB+LyInhg3YQOTKpFCYWSaR8ysqCGlKewoiIhKjcwoiIhKjoiAiIjEqCiIiEqOiICIiMSoKIiISo6IgIiIx/x/PmBUwcQ
QDLgAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learner.lr_plot()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, we will train our model for 8 epochs using ```autofit``` with a learning rate of 0.0007. Because we have explicitly specified the number of epochs, ```autofit``` will automatically employ a triangular learning rate policy. Our final ROC-AUC score is **0.98**.\n", "\n", "As shown in [this example notebook](https://github.com/amaiya/ktrain/blob/master/examples/text/toxic_comments-bigru.ipynb) on our GitHub project, even better results can be obtained using a Bidirectional GRU with pretrained word vectors (called ‘bigru’ in *ktrain*)." ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "begin training using triangular learning rate policy with max lr of 0.0007...\n", "Train on 143613 samples, validate on 15958 samples\n", "Epoch 1/8\n", "143613/143613 [==============================] - 48s 333us/step - loss: 0.1140 - acc: 0.9630 - val_loss: 0.0530 - val_acc: 0.9812\n", "Epoch 2/8\n", "143613/143613 [==============================] - 47s 330us/step - loss: 0.0625 - acc: 0.9790 - val_loss: 0.0501 - val_acc: 0.9819\n", "Epoch 3/8\n", "143613/143613 [==============================] - 48s 331us/step - loss: 0.0572 - acc: 0.9801 - val_loss: 0.0491 - val_acc: 0.9821\n", "Epoch 4/8\n", "143613/143613 [==============================] - 47s 331us/step - loss: 0.0538 - acc: 0.9806 - val_loss: 0.0481 - val_acc: 0.9823\n", "Epoch 5/8\n", "143613/143613 [==============================] - 47s 329us/step - loss: 0.0517 - acc: 0.9813 - val_loss: 0.0476 - val_acc: 0.9823\n", "Epoch 6/8\n", "143613/143613 [==============================] - 47s 329us/step - loss: 0.0501 - acc: 0.9815 - val_loss: 0.0470 - val_acc: 0.9825\n", "Epoch 7/8\n", "143613/143613 [==============================] - 47s 331us/step - loss: 0.0486 - 
acc: 0.9820 - val_loss: 0.0468 - val_acc: 0.9824\n", "Epoch 8/8\n", "143613/143613 [==============================] - 47s 330us/step - loss: 0.0471 - acc: 0.9824 - val_loss: 0.0470 - val_acc: 0.9826\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learner.autofit(0.0007, 8)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Let's compute the ROC-AUC of our final model for identifying toxic online behavior:" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", " ROC-AUC score: 0.980092 \n", "\n" ] } ], "source": [ "from sklearn.metrics import roc_auc_score\n", "x_test, y_test = val  # unpack the validation data returned by texts_from_csv\n", "y_pred = learner.model.predict(x_test, verbose=0)\n", "score = roc_auc_score(y_test, y_pred)\n", "print(\"\\n ROC-AUC score: %.6f \\n\" % (score))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Making Predictions\n", "\n", "As before, let's make some predictions about toxic comments using our model by wrapping it in a Predictor instance." 
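, "\n", "\n", "For multi-label problems, `predict` returns a list of `(class_name, probability)` pairs for each input text. A small helper of our own (not part of *ktrain*; the probability values below are illustrative) can threshold these pairs to obtain the predicted label set:\n", "```python\n", "def labels_above_threshold(pairs, threshold=0.5):\n", "    \"\"\"Return the class names whose predicted probability meets the threshold.\"\"\"\n", "    return [label for label, prob in pairs if prob >= threshold]\n", "\n", "# illustrative (label, probability) pairs in the format returned by predictor.predict\n", "pairs = [('toxic', 0.55), ('severe_toxic', 0.02), ('obscene', 0.08),\n", "         ('threat', 0.41), ('insult', 0.17), ('identity_hate', 0.09)]\n", "labels_above_threshold(pairs)       # ['toxic']\n", "labels_above_threshold(pairs, 0.4)  # ['toxic', 'threat']\n", "```"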
] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [ "predictor = ktrain.get_predictor(learner.model, preproc)" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[[('toxic', 0.5491581),\n", " ('severe_toxic', 0.02454061),\n", " ('obscene', 0.084347874),\n", " ('threat', 0.4110818),\n", " ('insult', 0.17229997),\n", " ('identity_hate', 0.08519211)]]" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# correctly predict a toxic comment that includes a threat\n", "predictor.predict([\"If you don't stop immediately, I will kill you.\"])" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[[('toxic', 0.021799222),\n", " ('severe_toxic', 7.991817e-07),\n", " ('obscene', 0.000504758),\n", " ('threat', 5.477591e-05),\n", " ('insult', 0.001496369),\n", " ('identity_hate', 9.472556e-05)]]" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# non-toxic comment\n", "predictor.predict([\"Okay - I'm calling it a night. 
See you tomorrow.\"])" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [], "source": [ "predictor.save('/tmp/toxic_detector')" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [], "source": [ "predictor = ktrain.load_predictor('/tmp/toxic_detector')" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[[('toxic', 0.86799675),\n", " ('severe_toxic', 0.008107864),\n", " ('obscene', 0.26740596),\n", " ('threat', 0.006626291),\n", " ('insult', 0.39607796),\n", " ('identity_hate', 0.023489485)]]" ] }, "execution_count": 45, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# model works correctly and as expected after reloading from disk\n", "predictor.predict([\"You have a really ugly face.\"])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## The `Transformers` API in *ktrain*\n", "\n", "If using transformer models like BERT, DistilBERT, or RoBERTa, *ktrain* includes an alternative API for text classification, which allows the use of **any** Hugging Face `transformers` model. 
This API can be used as follows:\n", "\n", "```python\n", "import ktrain\n", "from ktrain import text\n", "MODEL_NAME = 'bert-base-uncased'\n", "t = text.Transformer(MODEL_NAME, maxlen=500, \n", " classes=label_list)\n", "trn = t.preprocess_train(x_train, y_train)\n", "val = t.preprocess_test(x_test, y_test)\n", "model = t.get_classifier()\n", "learner = ktrain.get_learner(model, train_data=trn, val_data=val, batch_size=6)\n", "learner.fit_onecycle(3e-5, 1)\n", "```\n", "\n", "Note that `x_train` and `x_test` are the raw texts here:\n", "```python\n", "x_train = ['I hate this movie.', 'I like this movie.']\n", "```\n", "Similar to `texts_from_array`, the labels are arrays in one of the following forms:\n", "```python\n", "# string labels\n", "y_train = ['negative', 'positive']\n", "# integer labels\n", "y_train = [0, 1]\n", "# multi or one-hot encoded labels\n", "y_train = [[1,0], [0,1]]\n", "```\n", "In the latter two cases, you must supply a `class_names` argument to the `Transformer` constructor, which tells *ktrain* how indices map to class names. 
In this case, `class_names=['negative', 'positive']` because 0=negative and 1=positive.\n", "\n", "For an example, see [this notebook](https://nbviewer.jupyter.org/github/amaiya/ktrain/blob/master/examples/text/ArabicHotelReviews-AraBERT.ipynb), which builds an Arabic sentiment analysis model using [AraBERT](https://huggingface.co/aubmindlab/bert-base-arabert).\n", "\n", "\n", "For more information, see our tutorial on [text classification with Hugging Face Transformers](https://github.com/amaiya/ktrain/blob/master/tutorials/tutorial-A3-hugging_face_transformers.ipynb).\n", "\n", "You may also be interested in some of our blog posts on text classification:\n", "- [Text Classification With Hugging Face Transformers in TensorFlow 2 (Without Tears)](https://towardsdatascience.com/text-classification-with-hugging-face-transformers-in-tensorflow-2-without-tears-ee50e4f3e7ed)\n", "- [BERT Text Classification in 3 Lines of Code](https://towardsdatascience.com/bert-text-classification-in-3-lines-of-code-using-keras-264db7e7a358)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" } }, "nbformat": 4, "nbformat_minor": 2 }