{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%reload_ext autoreload\n", "%autoreload 2\n", "%matplotlib inline\n", "import os\n", "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\";\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"; " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this example, we will apply *ktrain* to the dataset employed in the **scikit-learn** [Working with Text Data](https://scikit-learn.org/stable/tutorial/text_analytics/working_with_text_data.html) tutorial. As in the tutorial, we will sample 4 newsgroups to create a relatively small multiclass text classification dataset. This will give us an opportunity to see BERT in action on a very small training set. Let's fetch the [20newsgroups dataset](http://qwone.com/~jason/20Newsgroups/) using **scikit-learn**." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "size of training set: 2257\n", "size of validation set: 1502\n", "classes: ['alt.atheism', 'comp.graphics', 'sci.med', 'soc.religion.christian']\n" ] } ], "source": [ "categories = ['alt.atheism', 'soc.religion.christian',\n", " 'comp.graphics', 'sci.med']\n", "from sklearn.datasets import fetch_20newsgroups\n", "train_b = fetch_20newsgroups(subset='train',\n", " categories=categories, shuffle=True, random_state=42)\n", "test_b = fetch_20newsgroups(subset='test',\n", " categories=categories, shuffle=True, random_state=42)\n", "\n", "print('size of training set: %s' % (len(train_b['data'])))\n", "print('size of validation set: %s' % (len(test_b['data'])))\n", "print('classes: %s' % (train_b.target_names))\n", "\n", "x_train = train_b.data\n", "y_train = train_b.target\n", "x_test = test_b.data\n", "y_test = test_b.target" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using TensorFlow backend.\n" ] 
} ], "source": [ "import ktrain\n", "from ktrain import text" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "preprocessing train...\n" ] }, { "data": { "text/html": [ "done." ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "preprocessing test...\n" ] }, { "data": { "text/html": [ "done." ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "trn, val, preproc = text.texts_from_array(x_train=x_train, y_train=y_train,\n", " x_test=x_test, y_test=y_test,\n", " class_names=train_b.target_names,\n", " preprocess_mode='bert',\n", " maxlen=350)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Is Multi-Label? False\n", "maxlen is 350\n", "done.\n" ] } ], "source": [ "model = text.text_classifier('bert', train_data=trn, preproc=preproc)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "learner = ktrain.get_learner(model, train_data=trn, batch_size=6)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "simulating training for different learning rates... this may take a few moments...\n", "Epoch 1/1024\n", "2257/2257 [==============================] - 110s 49ms/step - loss: 1.3606 - acc: 0.3833\n", "Epoch 2/1024\n", "2257/2257 [==============================] - 94s 42ms/step - loss: 0.3663 - acc: 0.8861\n", "Epoch 3/1024\n", " 186/2257 [=>............................] 
- ETA: 1:25 - loss: 1.2041 - acc: 0.4946\n", "\n", "done.\n", "Please invoke the Learner.lr_plot() method to visually inspect the loss plot to help identify the maximal learning rate associated with falling loss.\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEKCAYAAAAfGVI8AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3dd3gVVfrA8e+b3gOkEQgQSuidACKgIIiACiqsimtHEde67rr6s+y6llXXtS7oiq6LXcGKWJAqIDUB6S300BISCGmknt8f9yYGTAMyd3Jz38/z3MeZM+fOvBkv971nzsw5YoxBKaWU5/KyOwCllFL20kSglFIeThOBUkp5OE0ESinl4TQRKKWUh9NEoJRSHs7H7gDOVGRkpImPj7c7DKWUcivJyclHjTFRlW1zu0QQHx9PUlKS3WEopZRbEZG9VW3TS0NKKeXhNBEopZSH00SglFIeThOBUkp5OE0ESinl4TQRKKWUh7MsEYjIOyKSJiIbq6kzRER+EZFNIvKTVbHUVkpaNvmFJXaHoZRSLmVli2A6MLKqjSLSCHgdGGOM6QL8zsJYapSWfZLhLy3mydmb7AxDKaVczrJEYIxZDGRWU+U64AtjzD5n/TSrYqkoLfsk//t5N6Wlp07I88mq/QCs3Xf8lPL9mXmMmbKUuz5aQ0GxthaUUg2PnX0E7YHGIrJIRJJF5MaqKorIJBFJEpGk9PT0czrohGkr+Ps3m3ll/o5Tyj9d7UgEWw9nsz8zr7z83WV7WJ+axbfrD/H+8iofzDtFSanhcNbJc4pTKaVcxc5E4AP0AS4FLgEeF5H2lVU0xkwzxiQaYxKjoiodKqPWjuYUAvDa/B3856ednCwqYfb6gxw4nk9iq8YA3P3RGopKSgFYvSeTfvFN6N+6Ce8u30NNU3tm5Rdx3VsrOO/Z+UxZsOM32zNyCjiUlX9Of4NSStUlO8caSgUyjDG5QK6ILAZ6ANutPGiIvw9Z+UUAPPf9Vt5btof4yGAAnrqiK1+tPcCbi3cx8LkF/Ot3PViXmsVDIzvSKMiX//tiA+tTs2gdFcyM1fv5aXs6S3Yc5c0b+nAiv4jP16TSKNCPlbsdV8T+9eN2Any9uXVga04WlzB92R7++cM2AHY/OxoRsfJPVUqpWrEzEXwNTBERH8AP6A+8bPVBs08W0SoiiL0Zjss/B7NOcrK4lKt6N6dTbBjxEcG8uXgXadkF3PjOKgAu6x5LqbMlMHbqz7/Z5x3vJ9MuOoSUtBwAhneK5sL2UTz+9Sae/nYLT3+7hZgwf46cKCh/z5p9x+njbIEopZSdrLx99GNgOdBBRFJFZKKITBaRyQDGmC3AD8B6YBXwtjGmyltN60JpqSG7oJjR3WJPKc/MLaR9TCgAgX7etHa2EMrENQ6kVUQwky5oA0Cn2DDOa9OEjk1Dy+uUJQGAIR2i6dws/JR9VEwCAOPeWMZTszef+x+llFLnyLIWgTFmQi3qvAC8YFUMp8suKMYYiAzx59JusXy74VD5thaNg8qXv757IPsz8/ho5T4aB/mVX8J5ZHQnHhnd6ZR9Lt1xlOv/uxKAN37fm+JSw6XdYvHyEpY+NBRjYNa6g2xIzeK5cd04mlPArdOT2JeZx3+X7mbyhW2JCvWvPu6TRYT4++ilJKWUJdxuPoJzccLZNxAW4MNTV3QlISaEV+Y5OnSbNw4srxcW4EuXZuE8c2W3Gvd5ftsIHrykA80bBTLqtJZGn
DO53DW0XXlZoyA/3ru1Hy/O3c436w5y+3tJfHXXwCr3vys9h4tedDxrN+f+C+hQoRWilFJ1waOGmMjMddwx1CTYjybBftw/vD2PXer4hX/65aDa8vIS7hrajit6Na/1e+Ijg/n3hF786eL2/LL/OFMXppTfpVTRt+sPMf4/y8vXf9rukkctlFIexqMSQUau4zp9RMivl2JuG9yGnf8YTXigr8vjGdvTkTxemLONhEe/Z+rCFA5l5VNcUsr61OPc9dEaMnML+c/1fWjRJJBPVu+vNGEopdS58KhEUPYMQUSw3ynl3l72XHtvGRHEH4f/+ujEC3O2MeDZBbR79HvGTHHcnTS8UzQjuzZlZJem7ErP5d1le2yJVSnVcHlUH0FGWSII8auhputMuqANMWH+BPv78M7Pu2kdGcyREyf5OSWDO4e05aGRHQG4d1gCby3ZzTPfbaGguJTbBrfG38fb5uiVUg2BxySC3IJiPl+TSqCvN0F+9efPDvTz5tp+LQG4vEezKuuFBvjywMXteWnudl6Ys40VuzL45/juxIb/2sldWFzK9iPZhPj7cPB4PgPaRuidRkqpGtWfb0SL/bj5MClpObjz9+I9F7VjXJ845m46zBPfbGbAswv4/r7BrNiVwcer9nE0p7C8Qxzgki4x/HNcD8KDXN//oZRyHx6TCPrGNwGghqGC6jURoXmjQK4/rxVPfON4GG3Uq0vKt/eIC+fei9qRmVvIawtSmLPpCGnZq3jwkg4UFpdyQUIUXjb1hyil6i+PSQRxjYPo37oJ43rH2R3KOfPx9mL9EyPYl5HHil0Z5BaUMKF/C6JDA8rrjO4eyz0frWXtvuNc95bjgbeLO8fw5xEdaB0ZjK+38Mv+47SJDNEWg1IeTmoaTbO+SUxMNElJSXaH4RbyCovp9sSPhPj70KVZGMt2ZpRvC/D14mRRKbHhASz88xACfLXjWamGTESSjTGJlW3zqNtHPU2Qnw+rHhnGmscvZup1vbm8RzMinc9QnCxyPI9wKOskf5q5zs4wlVI285hLQ56q7OG5xsF+/HtCL4wxpGcXsGTHUUZ3i+UPHybz7fpDXNQhlUu7x2rLQCkPpJeGPNy+jDxGv7aEnIJiGgU5blHt1aIx3eLCa36zUspt6KUhVaWWEUEkPTac2we35nheEX/9ehOXT1nK9J931zgbm1KqYdBLQ4oAX28evbQzQzpEIwKvL9zJE99sJrew5JSRU5VSDZNeGlK/UVpquPjln9iZnkuwnzd/vbwz1/RtaXdYSqlzoJeG1Bnx8hKm3ZjIQyM7Eh8ZzEOfb+DJbzZzsqjE7tCUUhbQFoGq1oHj+fzt643M25JGbHgAo7rGctvg1jRrFFjzm5VS9Ya2CNRZa94okLdv6svz47rh4y288/Nuzn9uAat2Z9odmlKqjlg5ef07IpImItVOSC8ifUWkWETGWxWLOnfX9G3Jkr9cxHNXOabv/MOHyezNyLU5KqVUXbCyRTAdGFldBRHxBp4HfrQwDlWHru3Xkg8m9udkUSlXTP2ZzQdP2B2SUuocWZYIjDGLgZquH9wDfA7oZLxuZFBCJN/eOwh/H28mvruaD1bsJa+w2O6wlFJnybY+AhFpDlwJvFGLupNEJElEktLT060PTtWoVUQwL17dg8zcQh77aiMXv7SYt5fsYsbq/RzNKbA7PKXUGbD0riERiQdmG2O6VrJtJvCiMWaFiEx31vuspn3qXUP1S25BMb/sP84TszaxIy0HgBB/H169tifDOsXYHJ1Sqkx1dw3ZmQh2A2WzpEQCecAkY8xX1e1TE0H9VFpq2JuZx56MXJ6avZk9R3O5pEtTft+/FU3D/WkXHWp3iEp5tOoSgW1DTBhjWpctV2gRVJsEVP3l5SW0jgymdWQwPeMaMXVhCp+u3s/3Gw8DMOmCNjwyupPNUSqlKmPl7aMfA8uBDiKSKiITRWSyiEy26piqfmgc7Mdjl3VmyUNDGZwQCcC0xbvo+eSP7MvIs
zk6pdTp9MliZbncgmL+NGMdP2w6TL/WTejVohHd4xpxafdYu0NTymPUy0tDynME+/vwnxv68O6yPfxt1qbyp5I3HGjLjQNa6XAVStlMWwTKpbYePsGREwV8uGIvP24+AsBVvZrzx4vb06JJkM3RKdVwaYtA1Rsdm4bRsSlckBBJ8t5jvL9iL99tPMQXaw9w55C2/Oni9vh46xBYSrmSJgJlCxEhMb4JifFNSEnL4YlZm3hj0U6O5Rby7FXdEJGad6KUqhP600vZrl10CB/c1p/xfeL4ZPV+ZialApCeXaBzICjlAtpHoOqN0lLDhLdWsHJ3JmUNgmEdY3j7pkovayqlzoD2ESi3UDYz2jPfbmbhtnTiI4KYt+UI499YRmZuIVGh/twyMJ7hnWK0H0GpOqQtAlVvnSwq4fdvryR577FTyhOiQ7hraDvG9mymfQlK1ZK2CJRbCvD15rPJA8gtLOFYbiGRIf5MX7aH53/Yyv2f/sLRnAJuG9zG7jCVcnvavlb1mogQ4u9DiyZBBPp5c+eQtux4ZhTDO0Xz0tztOuS1UnVAE4FyO77eXvzf6E4UFJcyZUGK3eEo5fY0ESi31DYqhKsTWzB92R5+2q6TFSl1LjQRKLf1t8s706JJIDe9s4qBzy1g44Esu0NSyi1pIlBuK8DXm/dv7c+wjtEcOJ7P5VOW8thXG9ibkWt3aEq5Fb19VDUI6dkFTP4gmeS9x/AS2P70KH3WQKkKqrt9VP+lqAYhKtSfKdf1omlYAKUGXpu/w+6QlHIbmghUgxEbHsiKR4ZxeY9mvLYghakLU8jKL7I7LKXqPU0EqsF5YXx3WkUE8cKcbQx78ScOHM+3OySl6jVNBKrBCfD1ZuYdA3h+XDey8gt56LP1lJa6V1+YUq5k5eT174hImohsrGL770VkvYhsEJFlItLDqliU54kOC+Cavi15aGRHlqYcZdx/luntpUpVwcoWwXRgZDXbdwMXGmO6AU8B0yyMRXmoWwa25qYBrdhy6ASX/XspL83dTom2DpQ6hWWJwBizGMisZvsyY0zZsJIrgDirYlGey9tL+PvYrnx8+3l0iAnltfk7ePzrjazaXeVHUymPU1/6CCYC31e1UUQmiUiSiCSlp+twAurM9WrZmB/uH0zn2DA+WrmPq99cTl5hsd1hKVUv2J4IRGQojkTwUFV1jDHTjDGJxpjEqKgo1wWnGhQR4Zkru9ImMhiAy/+9VO8oUgqbE4GIdAfeBsYaYzLsjEV5hl4tG7Pgz0P4YGJ/DhzP56Uft9sdklK2sy0RiEhL4AvgBmOM/mtULjUoIZLr+rXi8zWpDP3XIr29VHk0K28f/RhYDnQQkVQRmSgik0VksrPKX4EI4HUR+UVEdAAh5VJ3X9QOgN1Hc1m2UxukynPpoHPKo+UVFjP8xZ8I9PPm23sHE+DrbXdISllCB51TqgpBfj48N647O9NzeWWeDlSnPJMmAuXxLmgfxfg+cfx36S4ydA5k5YE0ESgF3HFBG4pKDB+u3Gd3KEq5nCYCpYCEmFAGJ0Ty0tztjJmylENZ+nyB8hyaCJRy+vuYLgCsT83isS8rHStRqQZJE4FSTm2iQnj6iq6M6tqUhdvSdLRS5TE0EShVwfXnteIfV3Yj2M+H95bvsTscpVxCE4FSp2kc7MeQjtHMSErl+w2H7A5HKctpIlCqEn8Y0haAOz9cw8JtaTZHo5S1NBEoVYlOsWFseXIkbaKCueV/q3lp7nYdj0g1WJoIlKpCoJ83L4zvDsBr83fw1pJdLN6ezsmiEpsjU6puaSJQqhp9WjVh6UNDAXj2+63c+M4q/jVnm81RKVW3NBEoVYO4xkF8NnlA+fqnq/drq0A1KJoIlKqFxPgmzJw8gGev6kZ2QTFzNh22OySl6owmAqVqqW98E65JbEHzRoF8vuaA3eEoVWc0ESh1Bry8hHF94liyI51th7PtDkepOqGJQKkzdMv58fh6ezEjab/doShVJzQRKHWGGgf7M
ahdJHM2HcbdZvhTqjKaCJQ6CyO7NCX1WD7/Xbrb7lCUOmdWTl7/joikiUil4/mKw2sikiIi60Wkt1WxKFXXrurdnNaRwcxad9DuUJQ6Z1a2CKYDI6vZPgpIcL4mAW9YGItSdcrH24urejVnw4Es9mXk2R2OUufEskRgjFkMZFZTZSzwnnFYATQSkVir4lGqro3rE4cx8M16bRUo92ZnH0FzoOJtF6nOst8QkUkikiQiSenp6S4JTqmaNGsUSJdmYfy0TT+Tyr25RWexMWaaMSbRGJMYFRVldzhKlRvSIYrkfcfIyi+yOxSlzpqdieAA0KLCepyzTCm3MbxTDCWlhh826gQ2yn3ZmQhmATc67x46D8gyxui/JuVWerZoRIsmgczfopPXKPflY9WOReRjYAgQKSKpwN8AXwBjzH+A74DRQAqQB9xiVSxKWUVEOK91BHO3HKG01ODlJXaHpNQZsywRGGMm1LDdAHdZdXylXKV/mwhmJqey7Ug2nWLD7A5HqTPmFp3FStVn57eNwMdLePHH7XaHotRZqVUiEJH7RCTMeT3/vyKyRkRGWB2cUu6gWaNA7h2WwLwtR9h0MMvucJQ6Y7VtEdxqjDkBjAAaAzcAz1kWlVJu5qbz4wny8+b95XvtDkWpM1bbRFDWAzYaeN8Ys6lCmVIeLzzQl2GdYpiz6TDFJaV2h6PUGaltIkgWkR9xJII5IhIK6KddqQpGd23KsbwiVu2ubmQVpeqf2iaCicDDQF9jTB6O20D1dk+lKriwQxQBvl58v1HnM1bupbaJYACwzRhzXESuBx4DtFdMqQqC/HwY2iGaHzYdprRUJ6xR7qO2ieANIE9EegB/AnYC71kWlVJuamTXpqRnF5C875jdoShVa7VNBMXOB8DGAlOMMVOBUOvCUso9XdQxGj8fL77+RYfNUu6jtokgW0T+D8dto9+KiBfO4SKUUr8KDfDlsm6xfL32IEV695ByE7VNBNcABTieJziMY6TQFyyLSik3NqJLU7ILivk8OdXuUJSqlVolAueX/4dAuIhcBpw0xmgfgVKVGNIhirZRwTqxvXIbtR1i4mpgFfA74GpgpYiMtzIwpdxVgK831/ZtyY60HD5cuRdH95pS9VdtLw09iuMZgpuMMTcC/YDHrQtLKff2u8Q4AB79ciMfrNBhJ1T9VttE4GWMqTjzRsYZvFcpj9MoyI+ZkwcA8N5ybRWo+q22X+Y/iMgcEblZRG4GvsUxsYxSqgp945vwwvju7EjLYdnODLvDUapKte0sfhCYBnR3vqYZYx6yMjClGoLLezQjMsSfR77cwOGsk3aHo1Slan15xxjzuTHmAefrSyuDUqqhCPD15rmrunE46yQPf7He7nCUqlS1iUBEskXkRCWvbBE5UdPORWSkiGwTkRQRebiS7S1FZKGIrBWR9SIy+lz+GKXqo+GdY/jjxe1ZtC2d5L06Mqmqf6pNBMaYUGNMWCWvUGNMtZOziog3MBUYBXQGJohI59OqPQbMMMb0Aq4FXj/7P0Wp+uvGAa0I9ffh09X77Q5Fqd+w8s6ffkCKMWaXMaYQ+ATHWEUVGaAsoYQDBy2MRynbBPn5cHHnGL7fcJis/CK7w1HqFFYmguZAxZ8/qc6yip4ArheRVBx3Id1jYTxK2erWQa3JLijm/eV77A5FqVPY/SzABGC6MSYO5zSYzgHtTiEik0QkSUSS0tPTXR6kUnWha/NwBrSJ4LPkVH2uQNUrViaCA0CLCutxzrKKJgIzAIwxy4EAIPL0HRljphljEo0xiVFRURaFq5T1xveJY09GHl+u1WGqVf1hZSJYDSSISGsR8cPRGTzrtDr7gGEAItIJRyLQn/yqwbqqd3OC/bx5YMY69mfm2R2OUoCFicAYUwzcDcwBtuC4O2iTiDwpImOc1f4E3C4i64CPgZuNtplVAyYi3Da4DQCfr9FhqlX9IO72vZuYmGiSkpLsDkOpc3LNm8vJzC1k7gMX2h2K8hAikmyMSaxsm
92dxUp5pEu7x7IjLYftR7LtDkUpTQRK2WFk16aIwOz1h+wORSlNBErZITo0gP6tm/Dl2lSKdW5jZTNNBErZZOKgNuzPzGfKwhS7Q1EeThOBUja5uHMMF3WM5oMV+7RVoGyliUApG13TtwVHcwpYvEMfn1HVs/IOT00EStloaIdomgT78VmyPlOgqmaM4YIXFjLVosuImgiUspGfjxdjezZj3uY0thw6oWMQqUqlZxewPzOfQF9vS/aviUApm43rHUdhSSmjXl3Ca/O141j91tbDjudNOsaGWrJ/TQRK2axLs1/neHpl/nZtFajf2FaWCJpWOx/YWdNEoJTNRITXJvQCwBiY+G4SWw7VOBOs8iBbD2cTFepPk2A/S/aviUCpemBMj2bs/Mdogvy8WbA1jZv/t4oivaVUOW07coKOTa25LASaCJSqN7y9hFeu6UmwnzdHThTwxKxNPDDjF1LScuwOTdmopNSw40gOHWI0ESjlEUZ0acqmJ0cyqF0kH67cxxdrDvD3bzbZHZay0Z6MXAqKS+kYa03/AGgiUKpeevTSTsRHBBHo683SlKPsTNdWgafaeqiso1hbBEp5lE6xYSx6cCiLHhxCiJ8PT83ebHdIyibbjmTjJdAuOsSyY2giUKoeiwkL4OaB8Szenk7aiZN2h6NscDgrn+jQAAIsepgMNBEoVe9d0as5pQY+XLmPvMJiu8NRLnY0p5CIEGtuGy2jiUCpeq5tVAjdmofz6vwd/P7tlXaHo1wsI6eAiBB/S49haSIQkZEisk1EUkTk4SrqXC0im0Vkk4h8ZGU8SrmrR0Z3AmDtvuPMWL3f5miUKx3NKSTSogfJyliWCETEG5gKjAI6AxNEpPNpdRKA/wMGGmO6APdbFY9S7mxA2wh2PDOKge0ieOzrjczdfMTukJQLFBaXkpZ9kpjwAEuPY2WLoB+QYozZZYwpBD4Bxp5W53ZgqjHmGIAxJs3CeJRya77eXjw5tiuFxaXc/l4SyXuPkZlbaHdYykJbD5+gqMTQrXm4pcexMhE0Byq2YVOdZRW1B9qLyM8iskJERloYj1Jur21UCEsfGoq3lzDujWX0fmouN/9vFRk5BXaHpiywbv9xAHq0aGTpcezuLPYBEoAhwATgLRH5zV8sIpNEJElEktLTdSYn5dniGgfx4CUdytcXbUvny7UHbIxIWWXt/uNEhvjTzI0vDR0AWlRYj3OWVZQKzDLGFBljdgPbcSSGUxhjphljEo0xiVFRUZYFrJS7uPn8eK7r35KXr+lB1+ZhfLP+kN0hqTqWlV/E9xsOc0FCJCJi6bGsTASrgQQRaS0ifsC1wKzT6nyFozWAiETiuFS0y8KYlGoQAny9+ceV3biyVxyXdW/Guv3HGfnKYh2xtAHZeugE+UUljOnZzPJjWZYIjDHFwN3AHGALMMMYs0lEnhSRMc5qc4AMEdkMLAQeNMZkWBWTUg3RZd1jAceY9e8v32tzNKqu7D6aCzj6hazmY+XOjTHfAd+dVvbXCssGeMD5UkqdhbjGQSz5y1Ae+nw9z36/hYs7x9CiSZDdYalztOngCQJ8vWjWKNDyY9ndWayUqgMtmgTx7FXdMAYe/Gyd3eGoOjBvyxEubB+Ft5e1/QOgiUCpBqNVRDB3DmnLil2ZOmy1m8stKOZQ1km6x1l722gZTQRKNSC/79+K0AAfbn83ib0ZuXaHo85S6rF8AJdd4tNEoFQD0jQ8gDsuaMOuo7lc+MIiXpm3nWK9k8jtpB7LAyCusfX9A6CJQKkG5w9D2nFVb8dD/K/M28G101aQlq1zGbiTjBzH0CFRFo86WkYTgVINjJeX8NLVPVn7+MW0aBJI0t5j9HtmPln5RXaHpmrpeL4jETQK8nXJ8TQRKNVANQ724/PJ55ev/33WJhx3bKv67lheET5eQoi/pXf4l9NEoFQDFh0WwNKHhnLnkLZ8sfYAM5NT7Q5J1cLxvCIaBflaPrREGU0ESjVwcY2DeHBEB3q0aMSUB
SmUlGqroL7Lyi8kPNA1l4VAE4FSHsHLS5h8QRv2ZebR9pHvOJSVjzGGEye136A+2pmW65Inisu45gKUUsp2I7o0LV++4/1kjuUVsj/Tcb/6kr8M1WEp6omjOQVsO5LNlb1Pn77FOtoiUMpDeHsJqx4ZxmOXdmLjgSyOZP06mc0jX26wMTJV0QHnw2TtXDDYXBltESjlQaLDArhtcBuu6NWcQF9vAP73827+9eN2dhzJJiEm1OYIVVq2I0FHh7nmGQLQFoFSHikyxJ9gfx+C/X2Y0K8lvt7Ca9qRbLvM3EJufy8JgJgwa2clq0gTgVIeLiLEn4mD2vDNuoO8Mm+73eF4HGMMby/ZxdbDJ/jHd1vKyyOC/VwWg14aUkrx0MgO7ErP4b9Ld3PT+fFEumhoA0+y52guAb7eND1t/uEDx/N5+ttfE8Dobk25b1h7fLxd9ztdWwRKKUSEv4zsSF5hCe8t22N3OA3OweP5jHh5MaNeXcyhrPzTtp06DtQr1/SiQ1PX9tVoIlBKAdAuOoThnWJ4f8Ve8gtL7A6nQVmfepzCklKO5RXx55nrKK3QF3PguGOk0cgQP1Y9Mgw/H9d/LWsiUEqVm3RBG47lFfHZGh2K4mylZZ/8zZhOu5zzDz92aSd+Tsngu42HyreVtQiW/OUiol3YQVyRpYlAREaKyDYRSRGRh6upN05EjIgkWhmPUqp6feMb06NFI16bv4OMnIKa36BOsfFAFv2emc87P+8pLysoLmHzwRNEh/pz68DWtIoIYnqF7QeO5xMR7Eegn7frA3ayLBGIiDcwFRgFdAYmiEjnSuqFAvcBK62KRSlVOyLCP67sSkZOAa8v2ml3OG5j3uYj/P7tFYyd+jMAT83ezI+bDpOSls2Ilxcze/0hujYPx8tLuOG8ViTtPcbGA1mA4wEyVw4nURkrWwT9gBRjzC5jTCHwCTC2knpPAc8DOnOGUvVAl2bhjOsdx/sr9rL54An2Z+bx9pJdTF2YYndo9daXvxzg55QMSkoNY3s2A2DS+8kMf2kxezMcfQAD2kQA8LvEFgT6evPoVxvZfiSbg8fzadbInktCZay8fbQ5sL/CeirQv2IFEekNtDDGfCsiD1a1IxGZBEwCaNmypQWhKqUquueiBGYmpzL6tSWnlA9oG0Hvlo1tiqr+Ss8uIDrUn2v7teQPQ9oypEMUf/x0Xfn2V6/tyWXdHQkiPNCXWwbG8/qinYx4eTEAgxOibIm7jG3PEYiIF/AScHNNdY0x04BpAImJifroo1IWaxkRxJTrenH3R2tPKX9j0U7eulG78ioqLillx5FshneK4YGL2wNwRc/meInQvFEgHZqGEhpw6pDSfxnZkdhGgTz+1UvO9LMAABD1SURBVEaABt0iOAC0qLAe5ywrEwp0BRY5J19oCswSkTHGmCQL41JK1cJl3ZsxpEM0KWk5+HoLP246wqvzd/Dolxt4+oquLps0pb5bsSuTY3lFDO8cU14mIoztWf3ooTec14qv1h4gee8xGge57iniyljZR7AaSBCR1iLiB1wLzCrbaIzJMsZEGmPijTHxwApAk4BS9UiIvw89WzSiS7NwJl3QhjZRwXy4ch+z1h20O7R6Y/WeTAD6t25yxu997qpudGkWxsB2kXUd1hmxLBEYY4qBu4E5wBZghjFmk4g8KSJjrDquUsoawf4+zPvjhbSPCWHqwpRTHoryVCdOFvHq/B0ANDqLX/UJMaF8e+/g3ww74WqWPkdgjPnOGNPeGNPWGPOMs+yvxphZldQdoq0Bpeo3Ly/hnosS2H4kh7eX7rI7HNuVzR3Qs0UjmyM5N/pksVLqjFzWPZbhnaJ5dd6O34yb42nK5g549NJONkdybjQRKKXOiIjw6KWdKS41DHxuAQeOe24ySC+bRCbUvUdr1USglDpjrSOD+c8NfSg18G/nNXJPlLzX0VHs7sN2ayJQSp2VoR2iufn8eGYmp7LbOaiaJ8nKL2JGUiqX92hGsL97T+2ii
UApddb+MLQtvt7Cq3U4s1lxSSmfrNrHyaL6NRR2cUkpnyenknbCMRrOC3O2UlJquHFAK5sjO3eaCJRSZy06NICbz2/N1+sOsvXwCYwxvxmCubaMMexKz2HB1jQe/mIDf/lsfR1He26+WHuAP81cR79/zGdneg7r9mfRtXkYfePP/PmB+sa92zNKKdvdcUEbPl61j9GvLqHUQHxEEDPuGFCrsfVX7c5kT0YuwzpG8/jXG/luw+HybbPWHeTeYe1oF+3a2bqqMnfzkfLlYS/+BMD15zWMsc+0RaCUOieNg/14f2I/BjkHTtuTkcdLc2u+VDR1YQpXv7mcv3y2nj5PzzslCYT6+xAe6Ms9H/9SLy4R5RUWs3h7OjefH899wxLKy9tGhdgYVd3RFoFS6px1j2vEe7f2IyuviCkLd/DWkt1c0qUpQztGV/meN3/67XwHu58dzYpdmSTEhLA+9Ti3Tk/ir19v5Plx3W0b2yj1WB6Dnl8IwCVdmpIY35jOzcLw9RYGtLF3aIi6oolAKVVnwoN8+dOIDizZcZQ7P0zm72O6cE3fyi+fNG8cRAuB2fcM4rX5KUSE+CEiDGjrGLf/oo4x3DmkLW8s2sk36w5xTd8WPDGmS53Fmnosj/s++YXtR7K5f3h7gv286dOqMQkxv16Kyswt5AHncNKdYsPo17oJ3l7CJV2a1lkc9YEmAqVUnQrw9eZ/t/Rl0nvJPPLlRqJC/bmoY8wpdYwx7M/MY3yfOESE+4YnVLqvey5qx7zNR9iRlsP0ZXvo2jyc8X3izjnG5L3HGPfGsvL1p2ZvLl/e/vSo8gnkv1x7gFV7Mrn+vJY8fUW3cz5ufaV9BEqpOhcbHsh7t/ajbVQwj3yxkZyC4lO2H88rIqegmBZNgqrdT5CfD9/fN5gV/zeM2PAA/jxzHd+c48iny1KOlieBzrFhPDyqI20igwn0dcwZ/EaFKTr3ZuQSGuDDU2O7ntMx6ztNBEopSzQO9uP5cd05kn2S295dzbKUo+Xbnvt+KwAta0gEAD7eXjQND+B/t/QlxN+Hx7/eyK70nLOO693lewD45/jufHffYCZf2JYFfx7ClqdGcmm3WF6et728/2JfZh6tIoIa/NwLmgiUUpbp1bIxN57XihW7Mrnu7ZWMnbKUrLwiPk1yzGIbH1FzIijTsWkYX911PoXFpYx4eTFr9x07q5i2Hc5mZJemXJ3Y4jfbXr22J5d0ieHFudvZevgE+zLyapWs3J0mAqWUpf52eRc+mzyAAW0iWJeaxe3vO0abb94okHbRZ3b7ZbvoUN69tR/FpYab3llFknNSmNoyxnD4xElaNAmsdLuPtxdPje2Kv7cXI19Zwq6juTVevmoINBEopSzl5SUkxjfh3Vv7cf15LVm12/HlfcOAVmd1yaVvfBN+enAI/r7e/O7N5WeUDPIKSzhZVEpENYPERYcF8OkdA/B3dhjH1uLBOHeniUAp5RJ+Po5f2x/d3p9RXZsyonNMzW+qQquIYD6ddB4Rwf48MGMda/Ydq9XQFkdzHMNGRwRXP5tY52ZhzL5nEN2ahzMooWE8K1AdTQRKKZcREc5vG8kb1/ehzTk+ldsmKoRXr+3Jvsw8rnp9GY98uYG8wuLf3KFU0ZuLHbOqRdZi/oCEmFC+uWdQvRniwkqaCJRSbmtgu0geHtURgI9X7afzX+fQ9+l5rNiVwfYj2ZScNq/yyl0ZAHSJDXN5rPWZpYlAREaKyDYRSRGRhyvZ/oCIbBaR9SIyX0TcfzxXpZRLTb6wLVufGsnN58fjJZBfVMK101Yw4uXF3PXhGkpKDQXFJUxZsIOd6bn8rk9crQbE8ySWJQIR8QamAqOAzsAEEel8WrW1QKIxpjvwGfBPq+JRSjVcAb7ePDGmC5ufHMkP9w8myM/xcNgPmw7z4Mx1LNqWzr9+dAyE1zRck8DprBxioh+QYozZBSAinwBjgfJnuY0xCyvUXwFcb2E8SqkGLsDXm45Nw1jxyDB8vIT/L
tnNi3O388XaAwA8dmknrup97kNUNDRWJoLmwP4K66lA/2rqTwS+r2yDiEwCJgG0bNkwxv9WSlknLMAXgHuGJVBq4GXnDGoTB7Vu8E8Jn416MeiciFwPJAIXVrbdGDMNmAaQmJh4dtMfKaU80n3DEzAYMnMLNQlUwcpEcACo+Ax3nLPsFCIyHHgUuNAYU2BhPEopD3X/8PZ2h1CvWXnX0GogQURai4gfcC0wq2IFEekFvAmMMcakWRiLUkqpKliWCIwxxcDdwBxgCzDDGLNJRJ4UkTHOai8AIcBMEflFRGZVsTullFIWsbSPwBjzHfDdaWV/rbA83MrjK6WUqpk+WayUUh5OE4FSSnk4TQRKKeXhNBEopZSH00SglFIeTmozmUN9IiLpwF7najiQVWHz6euRwFFc5/TjW/3+muqf7fbKymtT1pDPd23qVlfnTM51ZeWV1XPl+a5vn+2a6uhn+7fvb2WMiaq0tjHGbV/AtBrWk+yMx+r311T/bLdXVl6bsoZ8vmtTt7o6Z3Kuqzi3lZ1/l53v+vbZrsvz7emfbWOM218a+qaGdVc71+Of6ftrqn+22ysrr22ZK7nyfNembnV1zuRcV1buSee6tvXr6nx7+mfb/S4NnQkRSTLGJNodh6fQ8+1aer5dp6Gfa3dvEdRkmt0BeBg9366l59t1GvS5btAtAqWUUjVr6C0CpZRSNdBEoJRSHk4TgVJKeTiPTQQiMlhE/iMib4vIMrvjaehExEtEnhGRf4vITXbH05CJyBARWeL8fA+xOx5PICLBIpIkIpfZHcvZcMtEICLviEiaiGw8rXykiGwTkRQRebi6fRhjlhhjJgOzgXetjNfd1cX5BsbimK60CEi1KlZ3V0fn2gA5QAB6rqtVR+cb4CFghjVRWs8t7xoSkQtwfNDfM8Z0dZZ5A9uBi3F8+FcDEwBv4NnTdnGrcU6NKSIzgInGmGwXhe926uJ8O1/HjDFvishnxpjxrorfndTRuT5qjCkVkRjgJWPM710Vv7upo/PdA4jAkXiPGmNmuyb6umPpDGVWMcYsFpH404r7ASnGmF0AIvIJMNYY8yxQaXNNRFoCWZoEqlcX51tEUoFC52qJddG6t7r6bDsdA/ytiLOhqKPP9hAgGOgM5IvId8aYUivjrmtumQiq0BzYX2E9Fehfw3smAv+zLKKG7UzP9xfAv0VkMLDYysAaoDM61yJyFXAJ0AiYYm1oDdIZnW9jzKMAInIzztaYpdFZoCElgjNmjPmb3TF4CmNMHo7EqyxmjPkCR+JVLmSMmW53DGfLLTuLq3AAaFFhPc5Zpqyh59t19Fy7lsed74aUCFYDCSLSWkT8gGuBWTbH1JDp+XYdPdeu5XHn2y0TgYh8DCwHOohIqohMNMYUA3cDc4AtwAxjzCY742wo9Hy7jp5r19Lz7eCWt48qpZSqO27ZIlBKKVV3NBEopZSH00SglFIeThOBUkp5OE0ESinl4TQRKKWUh9NEoCwnIjkuOMaYWg4XXJfHHCIi55/F+3qJyH+dyzeLSL0YD0hE4k8fjrmSOlEi8oOrYlKuoYlAuQ3n8MCVMsbMMsY8Z8ExqxuPawhwxokAeAR47awCspkxJh04JCID7Y5F1R1NBMqlRORBEVktIutF5O8Vyr8SkWQR2SQikyqU54jIiyKyDhggIntE5O8iskZENohIR2e98l/WIjJdRF4TkWUisktExjvLvUTkdRHZKiJzReS7sm2nxbhIRF4RkSTgPhG5XERWishaEZknIjHOoYsnA38UkV/EMeNdlIh87vz7Vlf2ZSkioUB3Y8y6SrbFi8gC57mZ7xwmHRFpKyIrnH/v05W1sMQxQ9a3IrJORDaKyDXO8r7O87BORFaJSKjzOEuc53BNZa0aEfEWkRcq/L+6o8LmrwCd46AhMcboS1+WvoAc539HANMAwfEjZDZwgXNbE+d/A4GNQIRz3QBXV9jXHuAe5/IfgLedyzcDU5zL04GZzmN0xjG2PMB44
DtneVMc4/WPryTeRcDrFdYb8+tT+LcBLzqXnwD+XKHeR8Ag53JLYEsl+x4KfF5hvWLc3wA3OZdvBb5yLs8GJjiXJ5edz9P2Ow54q8J6OOAH7AL6OsvCcIw4HAQEOMsSgCTncjyw0bk8CXjMuewPJAGtnevNgQ12f670VXcvjx6GWrncCOdrrXM9BMcX0WLgXhG50lnewlmegWMSm89P20/ZEMvJwFVVHOsr4xgXfrM4ZuoCGATMdJYfFpGF1cT6aYXlOOBTEYnF8eW6u4r3DAc6i0jZepiIhBhjKv6CjwXSq3j/gAp/z/vAPyuUX+Fc/gj4VyXv3QC8KCLPA7ONMUtEpBtwyBizGsAYcwIcrQdgioj0xHF+21eyvxFA9wotpnAc/092A2lAsyr+BuWGNBEoVxLgWWPMm6cUOmZ4Gg4MMMbkicgiHNP+AZw0xpw+o1mB878lVP0ZLqiwLFXUqU5uheV/45jycZYz1ieqeI8XcJ4x5mQ1+83n17+tzhhjtotIb2A08LSIzAe+rKL6H4EjOKZY9AIqi1dwtLzmVLItAMffoRoI7SNQrjQHuFVEQgBEpLmIROP4tXnMmQQ6AudZdPyfgXHOvoIYHJ29tRHOr+PR31ShPBsIrbD+I3BP2YrzF/fptgDtqjjOMhxDHoPjGvwS5/IKHJd+qLD9FCLSDMgzxnwAvAD0BrYBsSLS11kn1Nn5HY6jpVAK3IBjLt7TzQHuFBFf53vbO1sS4GhBVHt3kXIvmgiUyxhjfsRxaWO5iGwAPsPxRfoD4CMiW4DncHzxWeFzHNMObgY+ANYAWbV43xPATBFJBo5WKP8GuLKssxi4F0h0dq5uxnE9/xTGmK1AuLPT+HT3ALeIyHocX9D3OcvvBx5wlrerIuZuwCoR+QX4G/C0MaYQuAbHFKHrgLk4fs2/DtzkLOvIqa2fMm/jOE9rnLeUvsmvra+hwLeVvEe5KR2GWnmUsmv2IhIBrAIGGmMOuziGPwLZxpi3a1k/CMg3xhgRuRZHx/FYS4OsPp7FOCZzP2ZXDKpuaR+B8jSzRaQRjk7fp1ydBJzeAH53BvX74OjcFeA4jjuKbCEiUTj6SzQJNCDaIlBKKQ+nfQRKKeXhNBEopZSH00SglFIeThOBUkp5OE0ESinl4TQRKKWUh/t/GT/yitWaWNsAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learner.lr_find()\n", "learner.lr_plot()" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "begin training using triangular learning rate policy with max lr of 2e-05...\n", "Epoch 1/5\n", "2257/2257 [==============================] - 104s 46ms/step - loss: 0.5075 - acc: 0.8019\n", "Epoch 2/5\n", "2257/2257 [==============================] - 95s 42ms/step - loss: 0.0892 - acc: 0.9703\n", "Epoch 3/5\n", "2257/2257 [==============================] - 95s 42ms/step - loss: 0.0302 - acc: 0.9925\n", "Epoch 4/5\n", "2257/2257 [==============================] - 95s 42ms/step - loss: 0.0156 - acc: 0.9951\n", "Epoch 5/5\n", "2257/2257 [==============================] - 95s 42ms/step - loss: 0.0040 - acc: 0.9987\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learner.autofit(2e-5, 5)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " precision recall f1-score support\n", "\n", " 0 0.95 0.92 0.93 319\n", " 1 0.96 0.98 0.97 389\n", " 2 0.98 0.95 0.97 396\n", " 3 0.94 0.98 0.96 398\n", "\n", " accuracy 0.96 1502\n", " macro avg 0.96 0.96 0.96 1502\n", "weighted avg 0.96 0.96 0.96 1502\n", "\n" ] }, { "data": { "text/plain": [ "array([[292, 6, 3, 18],\n", " [ 4, 383, 1, 1],\n", " [ 9, 7, 376, 4],\n", " [ 2, 1, 3, 392]])" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learner.validate(val_data=val)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "predictor = ktrain.get_predictor(learner.model, preproc)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['alt.atheism', 
'comp.graphics', 'sci.med', 'soc.religion.christian']" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predictor.get_classes()" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "['sci.med', 'sci.med', 'sci.med']" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predictor.predict(test_b.data[0:3])" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([2, 2, 2])" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test_b.target[:3]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" } }, "nbformat": 4, "nbformat_minor": 2 }