{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%reload_ext autoreload\n", "%autoreload 2\n", "%matplotlib inline\n", "import os\n", "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\";\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"; " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this example, we will apply ktrain to the dataset employed in the **scikit-learn** [*Working with Text Data*](https://scikit-learn.org/stable/tutorial/text_analytics/working_with_text_data.html) tutorial. As in the tutorial, we will sample 4 newsgroups to create a small multiclass text classification dataset. Let's fetch the [20newsgroups dataset](http://qwone.com/~jason/20Newsgroups/) using **scikit-learn**." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "dict_keys(['data', 'filenames', 'target_names', 'target', 'DESCR'])\n", "['alt.atheism', 'comp.graphics', 'sci.med', 'soc.religion.christian']\n", "['/home/amaiya/scikit_learn_data/20news_home/20news-bydate-train/comp.graphics/38440'\n", " '/home/amaiya/scikit_learn_data/20news_home/20news-bydate-train/comp.graphics/38479'\n", " '/home/amaiya/scikit_learn_data/20news_home/20news-bydate-train/soc.religion.christian/20737'\n", " '/home/amaiya/scikit_learn_data/20news_home/20news-bydate-train/soc.religion.christian/20942'\n", " '/home/amaiya/scikit_learn_data/20news_home/20news-bydate-train/soc.religion.christian/20487']\n", "[1 1 3 3 3]\n", "From: sd345@city.ac.uk (Michael Collier)\n", "Subject: Converting images to HP LaserJet III?\n", "Nntp-Posting-Host: hampton\n", "Organization: The City University\n", "Lines: 14\n", "\n", "Does anyone know of a good way (standard PC application/PD utility) to\n", "convert tif/img/tga files into LaserJet III format. 
We would also li\n", "1\n" ] } ], "source": [ "categories = ['alt.atheism', 'soc.religion.christian',\n", " 'comp.graphics', 'sci.med']\n", "from sklearn.datasets import fetch_20newsgroups\n", "train_b = fetch_20newsgroups(subset='train',\n", " categories=categories, shuffle=True, random_state=42)\n", "test_b = fetch_20newsgroups(subset='test',\n", " categories=categories, shuffle=True, random_state=42)\n", "\n", "# # inspect\n", "print(train_b.keys())\n", "print(train_b['target_names'])\n", "print(train_b['filenames'][:5])\n", "print(train_b['target'][:5])\n", "print(train_b['data'][0][:300])\n", "print(train_b['target'][0])\n", "#print(set(train_b['target']))\n", "\n", "x_train = train_b.data\n", "y_train = train_b.target\n", "x_test = test_b.data\n", "y_test = test_b.target" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using TensorFlow backend.\n" ] } ], "source": [ "import ktrain\n", "from ktrain import text" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Word Counts: 36393\n", "Nrows: 2257\n", "2257 train sequences\n", "Average train sequence length: 321\n", "x_train shape: (2257,350)\n", "y_train shape: (2257,4)\n", "1502 test sequences\n", "Average test sequence length: 342\n", "x_test shape: (1502,350)\n", "y_test shape: (1502,4)\n" ] } ], "source": [ "(x_train, y_train), (x_test, y_test), preproc = text.texts_from_array(x_train=x_train, y_train=y_train,\n", " x_test=x_test, y_test=y_test,\n", " class_names=train_b.target_names,\n", " ngram_range=1, \n", " maxlen=350, \n", " max_features=35000)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Is Multi-Label? False\n", "compiling word ID features...\n", "max_features is 27645\n", "maxlen is 350\n", "building document-term matrix... 
this may take a few moments...\n", "rows: 1-2257\n", "computing log-count ratios...\n", "done.\n" ] } ], "source": [ "model = text.text_classifier('nbsvm', train_data=(x_train, y_train), preproc=preproc)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "learner = ktrain.get_learner(model, train_data=(x_train, y_train), val_data=(x_test, y_test))" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "simulating training for different learning rates... this may take a few moments...\n", "Epoch 1/1024\n", "2257/2257 [==============================] - 1s 293us/step - loss: 1.4131 - acc: 0.3549\n", "Epoch 2/1024\n", "2257/2257 [==============================] - 0s 147us/step - loss: 1.4120 - acc: 0.3567\n", "Epoch 3/1024\n", "2257/2257 [==============================] - 0s 150us/step - loss: 1.4098 - acc: 0.3589\n", "Epoch 4/1024\n", "2257/2257 [==============================] - 0s 152us/step - loss: 1.4052 - acc: 0.3624\n", "Epoch 5/1024\n", "2257/2257 [==============================] - 0s 150us/step - loss: 1.3961 - acc: 0.3753\n", "Epoch 6/1024\n", "2257/2257 [==============================] - 0s 152us/step - loss: 1.3781 - acc: 0.4072\n", "Epoch 7/1024\n", "2257/2257 [==============================] - 0s 146us/step - loss: 1.3429 - acc: 0.4657\n", "Epoch 8/1024\n", "2257/2257 [==============================] - 0s 148us/step - loss: 1.2772 - acc: 0.5782\n", "Epoch 9/1024\n", "2257/2257 [==============================] - 0s 150us/step - loss: 1.1628 - acc: 0.7297\n", "Epoch 10/1024\n", "2257/2257 [==============================] - 0s 146us/step - loss: 0.9845 - acc: 0.8666\n", "Epoch 11/1024\n", "2257/2257 [==============================] - 0s 147us/step - loss: 0.7499 - acc: 0.9526\n", "Epoch 12/1024\n", "2257/2257 [==============================] - 0s 145us/step - loss: 0.5081 - acc: 0.9809\n", "Epoch 13/1024\n", "2257/2257 
[==============================] - 0s 147us/step - loss: 0.3162 - acc: 0.9880\n", "Epoch 14/1024\n", "2257/2257 [==============================] - 0s 153us/step - loss: 0.1874 - acc: 0.9951\n", "Epoch 15/1024\n", "2257/2257 [==============================] - 0s 156us/step - loss: 0.1073 - acc: 0.9973\n", "Epoch 16/1024\n", "2257/2257 [==============================] - 0s 154us/step - loss: 0.0584 - acc: 0.9982\n", "Epoch 17/1024\n", "2257/2257 [==============================] - 0s 148us/step - loss: 0.0292 - acc: 0.9996\n", "Epoch 18/1024\n", "2257/2257 [==============================] - 0s 156us/step - loss: 0.0140 - acc: 1.0000\n", "Epoch 19/1024\n", "2257/2257 [==============================] - 0s 152us/step - loss: 0.0065 - acc: 1.0000\n", "Epoch 20/1024\n", "2257/2257 [==============================] - 0s 156us/step - loss: 0.0030 - acc: 1.0000\n", "Epoch 21/1024\n", "2257/2257 [==============================] - 0s 156us/step - loss: 0.0014 - acc: 1.0000\n", "Epoch 22/1024\n", "2257/2257 [==============================] - 0s 145us/step - loss: 6.4536e-04 - acc: 1.0000\n", "Epoch 23/1024\n", "2257/2257 [==============================] - 0s 149us/step - loss: 3.0513e-04 - acc: 1.0000\n", "Epoch 24/1024\n", "2257/2257 [==============================] - 0s 148us/step - loss: 1.4576e-04 - acc: 1.0000\n", "Epoch 25/1024\n", "2257/2257 [==============================] - 0s 153us/step - loss: 6.8675e-05 - acc: 1.0000\n", "Epoch 26/1024\n", "2257/2257 [==============================] - 0s 159us/step - loss: 3.2571e-05 - acc: 1.0000\n", "Epoch 27/1024\n", "2257/2257 [==============================] - 0s 158us/step - loss: 1.5664e-05 - acc: 1.0000\n", "Epoch 28/1024\n", "2257/2257 [==============================] - 0s 154us/step - loss: 7.4376e-06 - acc: 1.0000\n", "Epoch 29/1024\n", "2257/2257 [==============================] - 0s 150us/step - loss: 3.5614e-06 - acc: 1.0000\n", "Epoch 30/1024\n", "2257/2257 [==============================] - 0s 151us/step - loss: 
1.7064e-06 - acc: 1.0000\n", "Epoch 31/1024\n", "2257/2257 [==============================] - 0s 156us/step - loss: 8.4131e-07 - acc: 1.0000\n", "Epoch 32/1024\n", "2257/2257 [==============================] - 0s 156us/step - loss: 4.3300e-07 - acc: 1.0000\n", "Epoch 33/1024\n", "2257/2257 [==============================] - 0s 153us/step - loss: 2.4621e-07 - acc: 1.0000\n", "Epoch 34/1024\n", "2257/2257 [==============================] - 0s 153us/step - loss: 1.6358e-07 - acc: 1.0000\n", "Epoch 35/1024\n", "2257/2257 [==============================] - 0s 148us/step - loss: 1.3281e-07 - acc: 1.0000\n", "Epoch 36/1024\n", "2257/2257 [==============================] - 0s 151us/step - loss: 1.2280e-07 - acc: 1.0000\n", "Epoch 37/1024\n", "2257/2257 [==============================] - 0s 155us/step - loss: 1.1990e-07 - acc: 1.0000\n", "Epoch 38/1024\n", "2257/2257 [==============================] - 0s 160us/step - loss: 1.1939e-07 - acc: 1.0000\n", "Epoch 39/1024\n", "2257/2257 [==============================] - 0s 158us/step - loss: 1.1921e-07 - acc: 1.0000\n", "Epoch 40/1024\n", "1760/2257 [======================>.......] - ETA: 0s - loss: 1.2209e-07 - acc: 1.0000\n", "\n", "done.\n", "Please invoke the Learner.lr_plot() method to visually inspect the loss plot to help identify the maximal learning rate associated with falling loss.\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learner.lr_find()\n", "learner.lr_plot()" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "early_stopping automatically enabled at patience=5\n", "reduce_on_plateau automatically enabled at patience=2\n", "\n", "\n", "begin training using triangular learning rate policy with max lr of 0.01...\n", "Train on 2257 samples, validate on 1502 samples\n", "Epoch 1/1024\n", "2257/2257 [==============================] - 1s 234us/step - loss: 0.2529 - acc: 0.9313 - val_loss: 0.2315 - val_acc: 0.9301\n", "Epoch 2/1024\n", "2257/2257 [==============================] - 0s 214us/step - loss: 0.0155 - acc: 0.9978 - val_loss: 0.2289 - val_acc: 0.9294\n", "Epoch 3/1024\n", "2257/2257 [==============================] - 0s 221us/step - loss: 0.0085 - acc: 0.9996 - val_loss: 0.2292 - val_acc: 0.9294\n", "Epoch 4/1024\n", "2257/2257 [==============================] - 0s 221us/step - loss: 0.0064 - acc: 0.9996 - val_loss: 0.2285 - val_acc: 0.9288\n", "Epoch 5/1024\n", "2257/2257 [==============================] - 1s 222us/step - loss: 0.0050 - acc: 1.0000 - val_loss: 0.2288 - val_acc: 0.9288\n", "Epoch 6/1024\n", "2257/2257 [==============================] - 0s 217us/step - loss: 0.0041 - acc: 1.0000 - val_loss: 0.2274 - val_acc: 0.9294\n", "Epoch 7/1024\n", "2257/2257 [==============================] - 0s 214us/step - loss: 0.0035 - acc: 1.0000 - val_loss: 0.2277 - val_acc: 0.9294\n", "Epoch 8/1024\n", "2257/2257 [==============================] - 0s 209us/step - loss: 0.0031 - acc: 1.0000 - val_loss: 0.2276 - val_acc: 0.9308\n", "\n", "Epoch 00008: Reducing Max LR on Plateau: new max lr will be 0.005 (if not early_stopping).\n", "Epoch 9/1024\n", "2257/2257 [==============================] - 0s 208us/step - loss: 0.0028 - acc: 1.0000 - val_loss: 0.2275 - val_acc: 0.9301\n", "Epoch 10/1024\n", 
"2257/2257 [==============================] - 0s 213us/step - loss: 0.0026 - acc: 1.0000 - val_loss: 0.2274 - val_acc: 0.9301\n", "\n", "Epoch 00010: Reducing Max LR on Plateau: new max lr will be 0.0025 (if not early_stopping).\n", "Epoch 11/1024\n", "2257/2257 [==============================] - 0s 215us/step - loss: 0.0025 - acc: 1.0000 - val_loss: 0.2274 - val_acc: 0.9301\n", "Epoch 12/1024\n", "2257/2257 [==============================] - 0s 218us/step - loss: 0.0024 - acc: 1.0000 - val_loss: 0.2274 - val_acc: 0.9314\n", "Epoch 13/1024\n", "2257/2257 [==============================] - 0s 214us/step - loss: 0.0024 - acc: 1.0000 - val_loss: 0.2275 - val_acc: 0.9314\n", "\n", "Epoch 00013: Reducing Max LR on Plateau: new max lr will be 0.00125 (if not early_stopping).\n", "Epoch 14/1024\n", "2257/2257 [==============================] - 0s 215us/step - loss: 0.0023 - acc: 1.0000 - val_loss: 0.2275 - val_acc: 0.9314\n", "Epoch 15/1024\n", "2257/2257 [==============================] - 0s 218us/step - loss: 0.0023 - acc: 1.0000 - val_loss: 0.2276 - val_acc: 0.9314\n", "\n", "Epoch 00015: Reducing Max LR on Plateau: new max lr will be 0.000625 (if not early_stopping).\n", "Epoch 16/1024\n", "2257/2257 [==============================] - 0s 201us/step - loss: 0.0022 - acc: 1.0000 - val_loss: 0.2276 - val_acc: 0.9314\n", "Restoring model weights from the end of the best epoch\n", "Epoch 00016: early stopping\n", "Weights from best epoch have been loaded into model.\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learner.autofit(0.01)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " precision recall f1-score support\n", "\n", " 0 0.93 0.86 0.89 319\n", " 1 0.95 0.95 0.95 389\n", " 2 0.95 0.92 0.94 396\n", " 3 0.90 0.97 0.93 398\n", "\n", " accuracy 0.93 1502\n", " macro avg 0.93 0.93 0.93 1502\n", 
"weighted avg 0.93 0.93 0.93 1502\n", "\n" ] }, { "data": { "text/plain": [ "array([[274, 4, 8, 33],\n", " [ 4, 371, 10, 4],\n", " [ 10, 13, 366, 7],\n", " [ 6, 4, 2, 386]])" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learner.validate()" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "predictor = ktrain.get_predictor(learner.model, preproc)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['alt.atheism', 'comp.graphics', 'sci.med', 'soc.religion.christian']" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predictor.get_classes()" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['sci.med', 'sci.med', 'sci.med']" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predictor.predict(test_b.data[0:3])" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([2, 2, 2])" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test_b.target[:3]" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.8" } }, "nbformat": 4, "nbformat_minor": 2 }