{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%reload_ext autoreload\n", "%autoreload 2\n", "%matplotlib inline\n", "import os\n", "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\";\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"; " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this example, we will apply ktrain to the dataset employed in the **scikit-learn** [*Working with Text Data*](https://scikit-learn.org/stable/tutorial/text_analytics/working_with_text_data.html) tutorial. As in the tutorial, we will sample 4 newsgroups to create a small multiclass text classification dataset. Let's fetch the [20newsgroups dataset](http://qwone.com/~jason/20Newsgroups/) using **scikit-learn**." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "dict_keys(['data', 'filenames', 'target_names', 'target', 'DESCR'])\n", "['alt.atheism', 'comp.graphics', 'sci.med', 'soc.religion.christian']\n", "['/home/amaiya/scikit_learn_data/20news_home/20news-bydate-train/comp.graphics/38440'\n", " '/home/amaiya/scikit_learn_data/20news_home/20news-bydate-train/comp.graphics/38479'\n", " '/home/amaiya/scikit_learn_data/20news_home/20news-bydate-train/soc.religion.christian/20737'\n", " '/home/amaiya/scikit_learn_data/20news_home/20news-bydate-train/soc.religion.christian/20942'\n", " '/home/amaiya/scikit_learn_data/20news_home/20news-bydate-train/soc.religion.christian/20487']\n", "[1 1 3 3 3]\n", "From: sd345@city.ac.uk (Michael Collier)\n", "Subject: Converting images to HP LaserJet III?\n", "Nntp-Posting-Host: hampton\n", "Organization: The City University\n", "Lines: 14\n", "\n", "Does anyone know of a good way (standard PC application/PD utility) to\n", "convert tif/img/tga files into LaserJet III format. 
We would also li\n", "1\n" ] } ], "source": [ "categories = ['alt.atheism', 'soc.religion.christian',\n", " 'comp.graphics', 'sci.med']\n", "from sklearn.datasets import fetch_20newsgroups\n", "train_b = fetch_20newsgroups(subset='train',\n", " categories=categories, shuffle=True, random_state=42)\n", "test_b = fetch_20newsgroups(subset='test',\n", " categories=categories, shuffle=True, random_state=42)\n", "\n", "# # inspect\n", "print(train_b.keys())\n", "print(train_b['target_names'])\n", "print(train_b['filenames'][:5])\n", "print(train_b['target'][:5])\n", "print(train_b['data'][0][:300])\n", "print(train_b['target'][0])\n", "#print(set(train_b['target']))\n", "\n", "# Keep the raw documents and labels under their own names so the\n", "# preprocessing cell never overwrites its own inputs and can be re-run.\n", "x_train_raw = train_b.data\n", "y_train_raw = train_b.target\n", "x_test_raw = test_b.data\n", "y_test_raw = test_b.target" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using TensorFlow backend.\n" ] } ], "source": [ "import ktrain\n", "from ktrain import text" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Word Counts: 36393\n", "Nrows: 2257\n", "2257 train sequences\n", "Average train sequence length: 321\n", "x_train shape: (2257,350)\n", "y_train shape: (2257,4)\n", "1502 test sequences\n", "Average test sequence length: 342\n", "x_test shape: (1502,350)\n", "y_test shape: (1502,4)\n" ] } ], "source": [ "# Preprocess the raw text. Inputs are the *_raw variables, so this cell is\n", "# idempotent: re-running it never feeds preprocessed arrays back in.\n", "(x_train, y_train), (x_test, y_test), preproc = text.texts_from_array(x_train=x_train_raw, y_train=y_train_raw,\n", " x_test=x_test_raw, y_test=y_test_raw,\n", " class_names=train_b.target_names,\n", " ngram_range=1, \n", " maxlen=350, \n", " max_features=35000)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Is Multi-Label? False\n", "compiling word ID features...\n", "max_features is 27645\n", "maxlen is 350\n", "building document-term matrix... 
this may take a few moments...\n", "rows: 1-2257\n", "computing log-count ratios...\n", "done.\n" ] } ], "source": [ "model = text.text_classifier('nbsvm', train_data=(x_train, y_train), preproc=preproc)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "learner = ktrain.get_learner(model, train_data=(x_train, y_train), val_data=(x_test, y_test))" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "simulating training for different learning rates... this may take a few moments...\n", "Epoch 1/1024\n", "2257/2257 [==============================] - 1s 293us/step - loss: 1.4131 - acc: 0.3549\n", "Epoch 2/1024\n", "2257/2257 [==============================] - 0s 147us/step - loss: 1.4120 - acc: 0.3567\n", "Epoch 3/1024\n", "2257/2257 [==============================] - 0s 150us/step - loss: 1.4098 - acc: 0.3589\n", "Epoch 4/1024\n", "2257/2257 [==============================] - 0s 152us/step - loss: 1.4052 - acc: 0.3624\n", "Epoch 5/1024\n", "2257/2257 [==============================] - 0s 150us/step - loss: 1.3961 - acc: 0.3753\n", "Epoch 6/1024\n", "2257/2257 [==============================] - 0s 152us/step - loss: 1.3781 - acc: 0.4072\n", "Epoch 7/1024\n", "2257/2257 [==============================] - 0s 146us/step - loss: 1.3429 - acc: 0.4657\n", "Epoch 8/1024\n", "2257/2257 [==============================] - 0s 148us/step - loss: 1.2772 - acc: 0.5782\n", "Epoch 9/1024\n", "2257/2257 [==============================] - 0s 150us/step - loss: 1.1628 - acc: 0.7297\n", "Epoch 10/1024\n", "2257/2257 [==============================] - 0s 146us/step - loss: 0.9845 - acc: 0.8666\n", "Epoch 11/1024\n", "2257/2257 [==============================] - 0s 147us/step - loss: 0.7499 - acc: 0.9526\n", "Epoch 12/1024\n", "2257/2257 [==============================] - 0s 145us/step - loss: 0.5081 - acc: 0.9809\n", "Epoch 13/1024\n", "2257/2257 
[==============================] - 0s 147us/step - loss: 0.3162 - acc: 0.9880\n", "Epoch 14/1024\n", "2257/2257 [==============================] - 0s 153us/step - loss: 0.1874 - acc: 0.9951\n", "Epoch 15/1024\n", "2257/2257 [==============================] - 0s 156us/step - loss: 0.1073 - acc: 0.9973\n", "Epoch 16/1024\n", "2257/2257 [==============================] - 0s 154us/step - loss: 0.0584 - acc: 0.9982\n", "Epoch 17/1024\n", "2257/2257 [==============================] - 0s 148us/step - loss: 0.0292 - acc: 0.9996\n", "Epoch 18/1024\n", "2257/2257 [==============================] - 0s 156us/step - loss: 0.0140 - acc: 1.0000\n", "Epoch 19/1024\n", "2257/2257 [==============================] - 0s 152us/step - loss: 0.0065 - acc: 1.0000\n", "Epoch 20/1024\n", "2257/2257 [==============================] - 0s 156us/step - loss: 0.0030 - acc: 1.0000\n", "Epoch 21/1024\n", "2257/2257 [==============================] - 0s 156us/step - loss: 0.0014 - acc: 1.0000\n", "Epoch 22/1024\n", "2257/2257 [==============================] - 0s 145us/step - loss: 6.4536e-04 - acc: 1.0000\n", "Epoch 23/1024\n", "2257/2257 [==============================] - 0s 149us/step - loss: 3.0513e-04 - acc: 1.0000\n", "Epoch 24/1024\n", "2257/2257 [==============================] - 0s 148us/step - loss: 1.4576e-04 - acc: 1.0000\n", "Epoch 25/1024\n", "2257/2257 [==============================] - 0s 153us/step - loss: 6.8675e-05 - acc: 1.0000\n", "Epoch 26/1024\n", "2257/2257 [==============================] - 0s 159us/step - loss: 3.2571e-05 - acc: 1.0000\n", "Epoch 27/1024\n", "2257/2257 [==============================] - 0s 158us/step - loss: 1.5664e-05 - acc: 1.0000\n", "Epoch 28/1024\n", "2257/2257 [==============================] - 0s 154us/step - loss: 7.4376e-06 - acc: 1.0000\n", "Epoch 29/1024\n", "2257/2257 [==============================] - 0s 150us/step - loss: 3.5614e-06 - acc: 1.0000\n", "Epoch 30/1024\n", "2257/2257 [==============================] - 0s 151us/step - loss: 
1.7064e-06 - acc: 1.0000\n", "Epoch 31/1024\n", "2257/2257 [==============================] - 0s 156us/step - loss: 8.4131e-07 - acc: 1.0000\n", "Epoch 32/1024\n", "2257/2257 [==============================] - 0s 156us/step - loss: 4.3300e-07 - acc: 1.0000\n", "Epoch 33/1024\n", "2257/2257 [==============================] - 0s 153us/step - loss: 2.4621e-07 - acc: 1.0000\n", "Epoch 34/1024\n", "2257/2257 [==============================] - 0s 153us/step - loss: 1.6358e-07 - acc: 1.0000\n", "Epoch 35/1024\n", "2257/2257 [==============================] - 0s 148us/step - loss: 1.3281e-07 - acc: 1.0000\n", "Epoch 36/1024\n", "2257/2257 [==============================] - 0s 151us/step - loss: 1.2280e-07 - acc: 1.0000\n", "Epoch 37/1024\n", "2257/2257 [==============================] - 0s 155us/step - loss: 1.1990e-07 - acc: 1.0000\n", "Epoch 38/1024\n", "2257/2257 [==============================] - 0s 160us/step - loss: 1.1939e-07 - acc: 1.0000\n", "Epoch 39/1024\n", "2257/2257 [==============================] - 0s 158us/step - loss: 1.1921e-07 - acc: 1.0000\n", "Epoch 40/1024\n", "1760/2257 [======================>.......] 
- ETA: 0s - loss: 1.2209e-07 - acc: 1.0000\n", "\n", "done.\n", "Please invoke the Learner.lr_plot() method to visually inspect the loss plot to help identify the maximal learning rate associated with falling loss.\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEKCAYAAAAfGVI8AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3de3xV9Znv8c+TGwkkJJCEW0IIN8EoIBgQvExhipc6U6mtbWFsq7WW4+nYsdNOZ9rTc6y1PVOnjnNxqrXUWse2aq22HlQUtV6wKkpEUAkXwz1cwz1cQkjynD/2BjcxCQlmZe3L9/165ZW1fmvttZ/FJvlmrd9a62fujoiIpK60sAsQEZFwKQhERFKcgkBEJMUpCEREUpyCQEQkxSkIRERSXEbYBXRVUVGRl5eXh12GiEhCeeutt3a5e3FbyxIuCMrLy6mqqgq7DBGRhGJmG9tbplNDIiIpTkEgIpLiFAQiIilOQSAikuIUBCIiKU5BICKS4lIqCA40HGPj7kNhlyEiEldSJgg27znM+Fue5WO3v8T/W7aFmp0H21zP3Vlbd5CjTc0ntbe0OKu313O4saknyhUR6TGB3VBmZvcBfw3sdPezO1hvMvA6MNvdHw2qnqWb9p6YvunhZSemPz2xhAMNx1i1vZ7avUdOes2YgXms3lHPnClDeWzpFhqbWk5a/so/zmBIQQ7paXairbnFT5oXEYl3FtQIZWb2F8BB4IH2gsDM0oHngAbgvs4EQWVlpZ/uncX3vLyW255edVqv7Z2VTnlhH6q3HfjQstxeGRhQfzRytHDNtGGUF/XBgFueqGbaiELu+NwEhhTknPS65hZn8brdjBmUR1Fur9OqS0SkM8zsLXevbHNZkENVmlk58GQHQfAN4BgwObpeoEHQ2oqt+/nhk9Vkpqex62Aj3770DP5y7EAAGo4182rNLvr3yeLJd7bxrUvOoHdW5ADK3dm85wh/9/DbLNu8D4D8nEz2HzlGXnYG9Q1tnz46d1g/Jg4toLRfDs+v3Mmfa3adtLxicF9+ctV4zi7J75b9ExE5Li6DwMxKgAeBGcB9dBAEZjYXmAtQVlZ27saN7T4yI3QtLc6SDXt4YPFG+mZncM7QAv749hYWr9vT6W2cObgvj94wjT69Eu5RUCISp+I1CH4P3OHui83sfkI4IuhJ1VsP0OLO1367lHEl+Vx69iAuHFVE/z5ZQOQo4zdvbOL/PP4eEOmfuPaCci49a9CJdURETle8BsF64HivahFwGJjr7o93tM1EDYKueOqdbfztg0tPajt/ZKSfYXB+TjuvEhFpX0dBENq5B3cffnw65oigwxBIFX81fjBDCs7n5TV1bNpzmD8s3cJra3cz7ccvUDG4Lz/7wiSGFfYJu0wRSRJBXj76EDAdKDKzWuD7QCaAu98T1Psmi4ll/ZhY1g+AG2eM4o7n1vDUO9uo3naAj9/xMk98/ULOHNw35CpFJBkEemooCKlwaqg9LS3OovfruPZXSwBY/N2PMyg/O+SqRCQRdHRqKGXuLE4GaWnG9DEDuPdLkc9y6o//xOrt9SFXJSKJTkGQgGZWDOS6CyJdLJf+xyL+7dnVIVckIolMQZCgbv5kBb/68mQA7nyhhiUbOn+fgohILAVBApsxZgAvfOtjAFz9izfYtPtwyBWJSCJSECS4EcW5/Ofsc2hsbuHif3+ZI43Np36RiEgMBUESmHVOCbd9ehxHm1q48u5XeWH
VjrBLEpEEoiBIErOnlHHDx0ayans9191fxe+rNoddkogkCAVBEvnHS8fw39dNYUh+Nj94opqanbq0VEROTUGQRNLSjI+dUcxDc6eSnmZ86Zdvsv/IsbDLEpE4pyBIQsMK+/CTq8azdX8D/7pQ9xiISMcUBEnq0rMGcUnFQH69eCPv1O4LuxwRiWMKgiT2z58eB8D1/52az2YSkc5RECSxotxeTB3Rn531R08MqSki0pqCIMnde81kcntlcM9La8MuRUTilIIgyeX2yuD6i4bzzIrtvFu7P+xyRCQOKQhSwHUXDicnM527X6oh0cafEJHgKQhSQN/sTOZMKePp97bz8pq6sMsRkTijIEgR37h4NL2z0rnt6VU0NrWEXY6IxJHAgsDM7jOznWb2XjvLrzazd8zsXTN7zcwmBFWLRI4K/nP2RFZtr+f6B3Q5qYh8IMgjgvuByzpYvh74mLuPA34IzAuwFgEurhjI2EF5LFpTx+Y9GrtARCICCwJ3XwS0O2yWu7/m7nujs4uB0qBqkQ/cd+1ksjLSuPXJ6rBLEZE4ES99BF8Bng67iFQwpCCHuReN4PmVO6jZeTDsckQkDoQeBGY2g0gQ/FMH68w1syozq6qr01UvH9WXzh9GZloaD76xKexSRCQOhBoEZjYeuBeY5e6721vP3ee5e6W7VxYXF/dcgUlqQF42M8YW88Q7W2lu0X0FIqkutCAwszLgD8AX3X1NWHWkqlnnlFBXf5TF69rNXxFJERlBbdjMHgKmA0VmVgt8H8gEcPd7gJuBQuBuMwNocvfKoOqRk/3l2AHk9srg8be3cMGoorDLEZEQBRYE7j7nFMuvB64P6v2lY9mZ6VxSMZAF727jR1eeTa+M9LBLEpGQhN5ZLOG5fNxgDjU281z1jrBLEZEQKQhS2IyxAxiQ14v5y7aGXYqIhEhBkMLS04zLxw3m5TV1HGlsDrscEQmJgiDFXVIxkKNNLbywamfYpYhISBQEKe68EYUM7NuLP75dG3YpIhISBUGKS08zPjWxhBdX17Fl35GwyxGRECgIhL+ZUkZzi6vTWCRFKQiEYYV9OG94f36zeGPYpYhICBQEAsAlZw1iy74jbNXpIZGUoyAQAM4b3h9AYxqLpCAFgQBw1pC+9O+TxVsb9556ZRFJKgoCAcDMmFTWT0EgkoIUBHLCRaOLWL/rEKu2Hwi7FBHpQQoCOeGyswcBsEj9BCIpRUEgJwzsm82YgXl6GqlIilEQyEmuOGcISzbsZfOew2GXIiI9REEgJ/nUxBLM4KE3NbC9SKpQEMhJSgpymHnmQB6p2kzDMT2aWiQVKAjkQ/5mShm7Djbyuga2F0kJgQWBmd1nZjvN7L12lpuZ3WlmNWb2jplNCqoW6ZppIwvJykjjlTW7wi5FRHpAkEcE9wOXdbD8E8Do6Ndc4GcB1iJdkJ2ZztQRhby4eifuHnY5IhKwwILA3RcBezpYZRbwgEcsBgrMbHBQ9UjXXHzmANbvOsTauoNhlyIiAQuzj6AE2BwzXxtt+xAzm2tmVWZWVVenm516wvQxAwB4ba36CUSSXUJ0Frv7PHevdPfK4uLisMtJCaX9cijK7cXyzfvDLkVEAhZmEGwBhsbMl0bbJA6YGRNK81leuy/sUkQkYGEGwXzgS9Grh6YC+919W4j1SCsThhawtu4g9Q3Hwi5FRAKUEdSGzewhYDpQZGa1wPeBTAB3vwdYAFwO1ACHgS8HVYucnglDC3CHd2v3c/6oorDLEZGABBYE7j7nFMsd+Nug3l8+unNKCwBYummvgkAkiSVEZ7GEI793JqMG5LJ0k/oJRJKZgkA6NKmsgKWb9urGMpEkpiCQDk0q68e+w8dYt+tQ2KWISEAUBNKhScP6AbBUYxmLJC0FgXRoVHEu+TmZLNnQ0dNCRCSRKQikQ2lpxuTy/ry5XkEgkqwUBHJKU0f0Z8Puw2xQP4FIUlIQyCnNPHMgAC+v0QP/RJKRgkBOqbyoDyUFObyuJ5G
KJCUFgXTKtJGFLF6/m5YW3U8gkmwUBNIp548sZN/hY6zaXh92KSLSzRQE0inTRhYC8NpajWMskmwUBNIpg/NzKC/szeJ16icQSTYKAum0aSOLeGPdHpqaW8IuRUS6kYJAOm3ayELqjzaxYuuBsEsRkW6kIJBOmzqiPwCv6/SQSFJREEinDcjLZvSAXJ6v3hF2KSLSjRQE0iWXnT2ItzbtZf9hjWMskiwUBNIlF40uxl2nh0SSSaBBYGaXmdlqM6sxs++0sbzMzF40s7fN7B0zuzzIeuSjO2doAb2z0vXcIZEkElgQmFk6cBfwCaACmGNmFa1W+9/AI+4+EZgN3B1UPdI9sjLSuHBUES+u2qnHTYgkiSCPCKYANe6+zt0bgYeBWa3WcaBvdDof2BpgPdJNZowdwPYDDazeocdNiCSDIIOgBNgcM18bbYt1C/AFM6sFFgBfb2tDZjbXzKrMrKquTqckwnbhqCIAqjRqmUhSCLuzeA5wv7uXApcDvzazD9Xk7vPcvdLdK4uLi3u8SDlZab8cBvXN5s0NGsdYJBkEGQRbgKEx86XRtlhfAR4BcPfXgWygKMCapBuYGZXl/XREIJIkggyCJcBoMxtuZllEOoPnt1pnE/BxADM7k0gQ6NxPAphQWsC2/Q3sPng07FJE5CMKLAjcvQm4EVgIrCRyddAKM7vVzK6IrvYt4Ktmthx4CLjW3XUpSgKoGBLp41+5TR3GIokuI8iNu/sCIp3AsW03x0xXAxcEWYMEo2JwJAhWbN3PhaN1Nk8kkXXqiMDMbjKzvhbxSzNbamaXBF2cxK9+fbIYVtibJeowFkl4nT01dJ27HwAuAfoBXwRuC6wqSQjTRhSyZMMe3VgmkuA6GwQW/X458Gt3XxHTJilqcnl/9h85xpqd6icQSWSdDYK3zOxZIkGw0MzyAA1TleIml0fGJ1iyXpeRiiSyzgbBV4DvAJPd/TCQCXw5sKokIQztn8OAvF5UbVQ/gUgi62wQTANWu/s+M/sCkYfF7Q+uLEkEZsaksn4s3aQgEElknQ2CnwGHzWwCkWv/1wIPBFaVJIxzh/Vj854j1NXrxjKRRNXZIGiK3ug1C/ipu98F5AVXliSKScMKAHRUIJLAOhsE9Wb2XSKXjT4VfTBcZnBlSaI4a0g+memmIBBJYJ0Ngs8DR4ncT7CdyAPkbg+sKkkY2ZnpnDUkn7c37gu7FBE5TZ0Kgugv/98C+Wb210CDu6uPQIDI8JXvbd1Ps24sE0lInX3ExOeAN4HPAp8D3jCzq4IsTBLHhKH5HG5spmbnwbBLEZHT0NmHzn2PyD0EOwHMrBh4Hng0qMIkcYwvjXQYL9+8jzGDdA2BSKLpbB9B2vEQiNrdhddKkhte2Ie8Xhksr1U/gUgi6uwRwTNmtpDImAEQ6Txe0MH6kkLS0ozxQ/MVBCIJqrOdxd8G5gHjo1/z3P2fgixMEsv40gJWbaun4Vhz2KWISBd1emAad38MeCzAWiSBTSgtoKnFWb55H+eNKAy7HBHpgg6PCMys3swOtPFVb2YHeqpIiX/TRhZiBovX6UmkIommwyMCd9clINIp+TmZjB6QqzuMRRJQoFf+mNllZrbazGrM7DvtrPM5M6s2sxVm9mCQ9Uiwzh3Wj7c37dWIZSIJJrAgMLN04C7gE0AFMMfMKlqtMxr4LnCBu58FfCOoeiR4k8r6caChibV1urFMJJEEeUQwBahx93Xu3gg8TOTppbG+Ctzl7nsBWt2rIAnm3GH9ADSgvUiCCTIISoDNMfO10bZYZwBnmNmrZrbYzC5ra0NmNtfMqsysqq6uLqBy5aMaXtSHQX2z+XONPiORRBL23cEZwGhgOjAH+IWZFbReyd3nuXulu1cWFxf3cInSWWbGjLHFvLy6TvcTiCSQIINgCzA0Zr402harFpjv7sfcfT2whkgwSIL65PghHGps5pn3toddioh0UpBBsAQYbWbDzSwLmA3Mb7XO40SOBjCzIiKnitYFWJM
EbOqIQsr69+axpbVhlyIinRRYELh7E3AjsBBYCTzi7ivM7FYzuyK62kJgt5lVAy8C33b33UHVJMFLSzNmnjmQN9fv0ekhkQTR6UdMnA53X0Crh9O5+80x0w58M/olSWLqiP7c9+p63tuyn8ry/mGXIyKnEHZnsSSh45eRVm3UZaQiiUBBIN2uMLcXI4r6ULVBzx0SSQQKAgnEtJGFvFqzm6bmlrBLEZFTUBBIIKYM78+RY82s2l4fdikicgoKAgnExKGRfoJlmzVqmUi8UxBIIIb2z6F/nyyWKwhE4p6CQAJhZkwo1TjGIolAQSCBGV9awPs7D3LwaFPYpYhIBxQEEphzhhbgDu9t2R92KSLSAQWBBGZ8aT4Af35/V8iViEhHFAQSmMLcXowryeelNRpvSCSeKQgkUBdXDGTF1gPsOdQYdiki0g4FgQTqwtFFuMOrNTo9JBKvFAQSqPEl+eRlZ/DyGg1fKRKvFAQSqIz0NC6pGMTT727jcKMuIxWJRwoCCdxV55ZyqLGZF1ap01gkHikIJHBThvenf58snqveEXYpItIGBYEELj3N+PjYAbywaifH9FhqkbgTaBCY2WVmttrMaszsOx2s9xkzczOrDLIeCc/FFQOpb2jijXUarEYk3gQWBGaWDtwFfAKoAOaYWUUb6+UBNwFvBFWLhO+i0cVkZ6bxbPX2sEsRkVaCPCKYAtS4+zp3bwQeBma1sd4PgX8BGgKsRUKWk5XORaOLea56B+4edjkiEiPIICgBNsfM10bbTjCzScBQd38qwDokTlx85kC27W9g5TaNWiYST0LrLDazNODfgG91Yt25ZlZlZlV1dboxKVFNH1uMGbp6SCTOBBkEW4ChMfOl0bbj8oCzgZfMbAMwFZjfVoexu89z90p3rywuLg6wZAnSgLxsJpQW8MIqBYFIPAkyCJYAo81suJllAbOB+ccXuvt+dy9y93J3LwcWA1e4e1WANUnIZp45gOW1+9l5QF1CIvEisCBw9ybgRmAhsBJ4xN1XmNmtZnZFUO8r8W1mxUAAnlupowKReJER5MbdfQGwoFXbze2sOz3IWiQ+jBmYx6gBufxx6RauPm9Y2OWICLqzWHqYmfGZSaVUbdzL5j2Hwy5HRFAQSAgujp4eekVDWIrEBQWB9LiRxX0Y2LeXnkYqEicUBNLjzIzLxw1m0Zo6Dh7VGAUiYVMQSCguPWsQjc0tLNLIZSKhUxBIKCqH9aOgdybPrtBD6ETCpiCQUESGsBzI8yt3cqSxOexyRFKagkBCc+XEUg4ebeKJ5VvDLkUkpSkIJDRTR/SnYnBf7n9tQ9iliKQ0BYGExsz49KQSqrcdoGanHk0tEhYFgYTqUxNL6J2Vzs9fXhd2KSIpS0EgoSrK7cWsc0p4fNkW9h5qDLsckZSkIJDQfWnaMI41O79evDHsUkRSkoJAQnfm4L5MLu/H/OVbaWnReMYiPU1BIHFhzpQyanYe5Ml3t4VdikjKURBIXPjUOSWMHZTHvz+3hqbmlrDLEUkpCgKJC2lpxjdmnsH6XYd4fqWeSirSkxQEEjdmnjmAotwsHn97S9iliKQUBYHEjYz0ND4zqZRnq7ezYuv+sMsRSRmBBoGZXWZmq82sxsy+08byb5pZtZm9Y2Z/MjMNYpvivjZjFLm9MrjzT++HXYpIyggsCMwsHbgL+ARQAcwxs4pWq70NVLr7eOBR4CdB1SOJIT8nky9OG8az1TvYsOtQ2OWIpIQgjwimADXuvs7dG4GHgVmxK7j7i+5+fATzxUBpgPVIgrhmWjmZaWncvnB12KWIpIQgg6AE2BwzXxtta89XgKfbWmBmc82sysyq6uo0olWyG9A3mxumj+Spd7fx3hb1FYgELS46i83sC0AlcHtby919nrtXuntlcXFxzxYnobjugnLyczK59clq3HW3sUiQggyCLcDQmPnSaNtJzGwm8D3gCnc/GmA9kkAKemfxD5eO4c31e3RfgUjAggyCJcBoMxtuZln
AbGB+7ApmNhH4OZEQ0E+7nGT25KGMKO7DbU+v1N3GIgEKLAjcvQm4EVgIrAQecfcVZnarmV0RXe12IBf4vZktM7P57WxOUlBmehrfungMa+t0t7FIkDKC3Li7LwAWtGq7OWZ6ZpDvL4nv0rMGUlKQw78+u5rpY4rJzkwPuySRpBMXncUi7clIT+PmT1ZQs/MgP3lGl5OKBEFBIHHv0rMGcfV5Zdz36nr+sLQ27HJEko6CQBLCdy8/kyH52XzzkeW8VrMr7HJEkoqCQBJCbq8M5n2pEoCbfreMfYc1vrFId1EQSMI4uySfJ79+IXsONfLPC1aGXY5I0lAQSEI5uySfr140gkeqanli+dawyxFJCgoCSTjfmDmaymH9+Pajy1m9vT7sckQSnoJAEk52Zjp3Xz2J3F6ZfO23b3HoaFPYJYkkNAWBJKQBfbO5c845rN91iJseXkZjkx5BIXK6FASSsM4fWcQPrjiL51fu4PoHqti+vyHskkQSkoJAEtoXp5Xz40+P4411u5l115/5l2dWcaSxOeyyRBKKgkAS3pwpZTx6w/nsOHCUn720livvfpUdB3R0INJZCgJJCuNK81nyvZlcPm4Q63Yd4sq7XmX55n1hlyWSEBQEkjSK83px99Xn8ru5UzEzPvvz15m3aK06kkVOQUEgSWdiWT+e+PqFXDiqiH9esIrL/mMRL63WeAYi7VEQSFLq3yeL+66dzK++PBmAa3+1hKvvXcyiNXUaA1mkFUu0H4rKykqvqqoKuwxJII1NLTzw+gbueXkduw4eZVxJPnOmlHHVuaVkZehvIUkNZvaWu1e2uUxBIKmisamFPyyt5b5X17Nmx0FK++XwmUmlTB9TzPjSAtLTLOwSRQKjIBCJ4e68tLqOn720liUb9+AO/XpnctHoYj52RjEXnVHEgLzssMsU6VYdBUGgYxab2WXAfwLpwL3uflur5b2AB4Bzgd3A5919Q5A1iZgZM8YOYMbYAew51Mgr79fx8uo6Fr1fx/zoE01HFPWhYkhfxg7KY/TAPEYU9aGssDe9MjRmsiSfwI4IzCwdWANcDNQCS4A57l4ds87XgPHufoOZzQaudPfPd7RdHRFIUFpanOptB3jl/V28vWkvK7YeYMu+IyeWpxkMKcihpCCH/JxM+uZk0jc7k745GdHvmfTNzji5PSeT3KwM0nTaSUIW1hHBFKDG3ddFi3gYmAVUx6wzC7glOv0o8FMzM0+081WSFNLSjLNL8jm7JP9E28GjTazdeZD1uw6xftchNuw+xLZ9DWzac5j6hiYOHDlG/SmefmoGuVkZZKQbGelpZKZFvmekG5lpaSfaM9IMA9LMwCLBYxhpaZHvZpGjmUj7B9MQbbPIa4+vd3ydzupqVHVh06ex7eDqTmQzxg7gkxOGdPt2gwyCEmBzzHwtcF5767h7k5ntBwqBkwalNbO5wFyAsrKyoOoV+ZDcXhlMGFrAhKEF7a7T3OIcbGjiQMMx9h85xoGGYxw40hT9fowDDU3UNxyjqdlpamnhWLPT1NzCsZbI96Zm51iL09zSgju4Q4s7LQ7uLXhzZN6BFgeOL8Oj60b6PTzaFjvfWV39y6srf6t1fdtdWLfLW09sZwzKC2S7gfYRdBd3nwfMg8ipoZDLETlJepqR3zuT/N6ZDA27GJHTEORF1FvgpJ+L0mhbm+uYWQaQT6TTWEREekiQQbAEGG1mw80sC5gNzG+1znzgmuj0VcAL6h8QEelZgZ0aip7zvxFYSOTy0fvcfYWZ3QpUuft84JfAr82sBthDJCxERKQHBdpH4O4LgAWt2m6OmW4APhtkDSIi0jE9aEVEJMUpCEREUpyCQEQkxSkIRERSXMI9fdTM6oCNYdfRjiJa3RWdoJJlPyB59kX7EV8ScT+GuXtxWwsSLgjimZlVtfdQp0SSLPsBybMv2o/4kiz7cZxODYmIpDgFgYhIilMQdK95YRfQTZJlPyB59kX7EV+SZT8A9RGIiKQ8HRGIiKQ
4BYGISIpTEIiIpDgFQQ8xszQz+79m9l9mds2pXxG/zKyPmVWZ2V+HXcvpMrNPmdkvzOx3ZnZJ2PV0RfTf/7+j9V8ddj2nK5E/g9YS/WdCQdAJZnafme00s/datV9mZqvNrMbMvnOKzcwiMkrbMSLjN/e4btoPgH8CHgmmylPrjv1w98fd/avADcDng6y3M7q4T58GHo3Wf0WPF9uBruxHvH0GsU7j/1ioPxMfla4a6gQz+wvgIPCAu58dbUsH1gAXE/nFvgSYQ2QQnh+32sR10a+97v5zM3vU3a/qqfqP66b9mAAUAtnALnd/smeq/0B37Ie774y+7g7gt+6+tIfKb1MX92kW8LS7LzOzB939b0Iq+0O6sh/uXh1dHhefQawufh4lhPwz8VElxOD1YXP3RWZW3qp5ClDj7usAzOxhYJa7/xj40OGhmdUCjdHZ5uCqbV837cd0oA9QARwxswXu3hJk3a11034YcBuRX6ih/wLqyj4R+SVUCiwjzo7qu7IfZraSOPoMYnXx88gl5J+Jj0pBcPpKgM0x87XAeR2s/wfgv8zsImBRkIV1UZf2w92/B2Bm1xL56yde/sN39fP4OjATyDezUe5+T5DFnab29ulO4Kdm9lfAE2EU1kXt7UcifAax2twPd78R4vJnotMUBD3E3Q8DXwm7ju7i7veHXcNH4e53EvmFmnDc/RDw5bDr+KgS+TNoSyL/TMTVYWWC2QIMjZkvjbYlGu1H/EqWfdJ+xDkFwelbAow2s+FmlgXMBuaHXNPp0H7Er2TZJ+1HnFMQdIKZPQS8Dowxs1oz+4q7NwE3AguBlcAj7r4izDpPRfsRv5Jln7QfiUmXj4qIpDgdEYiIpDgFgYhIilMQiIikOAWBiEiKUxCIiKQ4BYGISIpTEEjgzOxgD7zHFZ18hHZ3vud0Mzv/NF430cx+GZ2+1sx+2v3VdZ2Zlbd+7HIb6xSb2TM9VZP0DAWBJIzoY4Db5O7z3f22AN6zo+dxTQe6HATA/yJBn7Hj7nXANjO7IOxapPsoCKRHmdm3zWyJmb1jZj+IaX/czN4ysxVmNjem/aCZ3WFmy4FpZrbBzH5gZkvN7F0zGxtd78Rf1mZ2v5ndaWavmdk6M7sq2p5mZneb2Soze87MFhxf1qrGl8zsP8ysCrjJzD5pZm+Y2dtm9ryZDYw+ovgG4O/NbJmZXRT9a/mx6P4taeuXpZnlAePdfXkby8rN7IXov82fzKws2j7SzBZH9/dHbR1hWWSErKfMbLmZvWdmn4+2T47+Oyw3szfNLC/6Pq9E/w2XtnVUY2bpZnZ7zGf1P2IWPw4k7Mho0gZ315e+Av0CDka/XwLMA4zIHyFPAn8RXdY/+j0HeGj7FbkAAAORSURBVA8ojM478LmYbW0Avh6d/hpwb3T6WuCn0en7gd9H36OCyDPkAa4CFkTbBwF7gavaqPcl4O6Y+X58cBf+9cAd0elbgH+IWe9B4MLodBmwso1tzwAei5mPrfsJ4Jro9HXA49HpJ4kM5AKR8DnYxnY/A/wiZj4fyALWAZOjbX2JPHG4N5AdbRsNVEWny4H3otNzgf8dne4FVAHDo/MlwLth/7/SV/d96THU0pMuiX69HZ3PJfKLaBHwd2Z2ZbR9aLR9N5FBfB5rtZ0/RL+/RWTYxrY87pHnwleb2cBo24XA76Pt283sxQ5q/V3MdCnwOzMbTOSX6/p2XjMTqDCz4/N9zSzX3WP/gh8M1LXz+mkx+/Nr4Ccx7Z+KTj8I/Gsbr30XuMPM/gV40t1fMbNxwDZ3XwLg7gcgcvRAZDyDc4j8+57RxvYuAcbHHDHlE/lM1gM7gSHt7IMkIAWB9CQDfuzuPz+pMTLq2UxgmrsfNrOXiAz7B9Dg7q1HdDsa/d5M+/+Hj8ZMWzvrdORQzPR/Af/m7vOjtd7SzmvSgKnu3tDBdo/wwb51G3dfY2aTgMuBH5nZn4A/trP63wM7iAw7mga0Va8ROfJa2MaybCL7IUlCfQT
SkxYC15lZLoCZlZjZACJ/be6NhsBYYGpA7/8q8JloX8FAIp29nZHPB8+dvyamvR7Ii5l/lsioWwBE/+JubSUwqp33eY3Io40hcg7+lej0YiKnfohZfhIzGwIcdvffALcDk4DVwGAzmxxdJy/a+Z1P5EihBfgikXGdW1sI/E8zy4y+9ozokQREjiA6vLpIEouCQHqMuz9L5NTG62b2LvAokV+kzwAZ9sEYtosDKuExIsMLVgO/AZYC+zvxuluA35vZW8CumPYngCuPdxYDfwdURjtXq4mczz+Ju68iMjRjXutlRELky2b2DpFf0DdF278BfDPaPqqdmscBb5rZMuD7wI/cvRH4PJEhUpcDzxH5a/5u4Jpo21hOPvo57l4i/05Lo5eU/pwPjr5mAE+18RpJUHoMtaSU4+fszawQeBO4wN2393ANfw/Uu/u9nVy/N3DE3d3MZhPpOJ4VaJEd17MImOXue8OqQbqX+ggk1TxpZgVEOn1/2NMhEPUz4LNdWP9cIp27BuwjckVRKMysmEh/iUIgieiIQEQkxamPQEQkxSkIRERSnIJARCTFKQhERFKcgkBEJMUpCEREUtz/BxsZespaqIpnAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learner.lr_find()\n", "learner.lr_plot()" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "early_stopping automatically enabled at patience=5\n", "reduce_on_plateau automatically enabled at patience=2\n", "\n", "\n", "begin training using triangular learning rate policy with max lr of 0.01...\n", "Train on 2257 samples, validate on 1502 samples\n", "Epoch 1/1024\n", "2257/2257 [==============================] - 1s 234us/step - loss: 0.2529 - acc: 0.9313 - val_loss: 0.2315 - val_acc: 0.9301\n", "Epoch 2/1024\n", "2257/2257 [==============================] - 0s 214us/step - loss: 0.0155 - acc: 0.9978 - val_loss: 0.2289 - val_acc: 0.9294\n", "Epoch 3/1024\n", "2257/2257 [==============================] - 0s 221us/step - loss: 0.0085 - acc: 0.9996 - val_loss: 0.2292 - val_acc: 0.9294\n", "Epoch 4/1024\n", "2257/2257 [==============================] - 0s 221us/step - loss: 0.0064 - acc: 0.9996 - val_loss: 0.2285 - val_acc: 0.9288\n", "Epoch 5/1024\n", "2257/2257 [==============================] - 1s 222us/step - loss: 0.0050 - acc: 1.0000 - val_loss: 0.2288 - val_acc: 0.9288\n", "Epoch 6/1024\n", "2257/2257 [==============================] - 0s 217us/step - loss: 0.0041 - acc: 1.0000 - val_loss: 0.2274 - val_acc: 0.9294\n", "Epoch 7/1024\n", "2257/2257 [==============================] - 0s 214us/step - loss: 0.0035 - acc: 1.0000 - val_loss: 0.2277 - val_acc: 0.9294\n", "Epoch 8/1024\n", "2257/2257 [==============================] - 0s 209us/step - loss: 0.0031 - acc: 1.0000 - val_loss: 0.2276 - val_acc: 0.9308\n", "\n", "Epoch 00008: Reducing Max LR on Plateau: new max lr will be 0.005 (if not early_stopping).\n", "Epoch 9/1024\n", "2257/2257 [==============================] - 0s 208us/step - loss: 0.0028 - acc: 1.0000 - val_loss: 0.2275 - val_acc: 0.9301\n", "Epoch 10/1024\n", 
"2257/2257 [==============================] - 0s 213us/step - loss: 0.0026 - acc: 1.0000 - val_loss: 0.2274 - val_acc: 0.9301\n", "\n", "Epoch 00010: Reducing Max LR on Plateau: new max lr will be 0.0025 (if not early_stopping).\n", "Epoch 11/1024\n", "2257/2257 [==============================] - 0s 215us/step - loss: 0.0025 - acc: 1.0000 - val_loss: 0.2274 - val_acc: 0.9301\n", "Epoch 12/1024\n", "2257/2257 [==============================] - 0s 218us/step - loss: 0.0024 - acc: 1.0000 - val_loss: 0.2274 - val_acc: 0.9314\n", "Epoch 13/1024\n", "2257/2257 [==============================] - 0s 214us/step - loss: 0.0024 - acc: 1.0000 - val_loss: 0.2275 - val_acc: 0.9314\n", "\n", "Epoch 00013: Reducing Max LR on Plateau: new max lr will be 0.00125 (if not early_stopping).\n", "Epoch 14/1024\n", "2257/2257 [==============================] - 0s 215us/step - loss: 0.0023 - acc: 1.0000 - val_loss: 0.2275 - val_acc: 0.9314\n", "Epoch 15/1024\n", "2257/2257 [==============================] - 0s 218us/step - loss: 0.0023 - acc: 1.0000 - val_loss: 0.2276 - val_acc: 0.9314\n", "\n", "Epoch 00015: Reducing Max LR on Plateau: new max lr will be 0.000625 (if not early_stopping).\n", "Epoch 16/1024\n", "2257/2257 [==============================] - 0s 201us/step - loss: 0.0022 - acc: 1.0000 - val_loss: 0.2276 - val_acc: 0.9314\n", "Restoring model weights from the end of the best epoch\n", "Epoch 00016: early stopping\n", "Weights from best epoch have been loaded into model.\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learner.autofit(0.01)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " precision recall f1-score support\n", "\n", " 0 0.93 0.86 0.89 319\n", " 1 0.95 0.95 0.95 389\n", " 2 0.95 0.92 0.94 396\n", " 3 0.90 0.97 0.93 398\n", "\n", " accuracy 0.93 1502\n", " macro avg 0.93 0.93 0.93 1502\n", 
"weighted avg 0.93 0.93 0.93 1502\n", "\n" ] }, { "data": { "text/plain": [ "array([[274, 4, 8, 33],\n", " [ 4, 371, 10, 4],\n", " [ 10, 13, 366, 7],\n", " [ 6, 4, 2, 386]])" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learner.validate()" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "predictor = ktrain.get_predictor(learner.model, preproc)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['alt.atheism', 'comp.graphics', 'sci.med', 'soc.religion.christian']" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predictor.get_classes()" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['sci.med', 'sci.med', 'sci.med']" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predictor.predict(test_b.data[0:3])" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([2, 2, 2])" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test_b.target[:3]" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.8" } }, "nbformat": 4, "nbformat_minor": 2 }