{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%reload_ext autoreload\n", "%autoreload 2\n", "%matplotlib inline\n", "import os\n", "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\";\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"; " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this example, we will apply *ktrain* to the dataset used in the **scikit-learn** [Working with Text Data](https://scikit-learn.org/stable/tutorial/text_analytics/working_with_text_data.html) tutorial. As in the tutorial, we will sample 4 newsgroups to create a relatively small multiclass text classification dataset. This gives us an opportunity to see BERT in action on a very small training set. Let's fetch the [20newsgroups dataset](http://qwone.com/~jason/20Newsgroups/) using **scikit-learn**." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "size of training set: 2257\n", "size of validation set: 1502\n", "classes: ['alt.atheism', 'comp.graphics', 'sci.med', 'soc.religion.christian']\n" ] } ], "source": [ "categories = ['alt.atheism', 'soc.religion.christian',\n", " 'comp.graphics', 'sci.med']\n", "from sklearn.datasets import fetch_20newsgroups\n", "train_b = fetch_20newsgroups(subset='train',\n", " categories=categories, shuffle=True, random_state=42)\n", "test_b = fetch_20newsgroups(subset='test',\n", " categories=categories, shuffle=True, random_state=42)\n", "\n", "print('size of training set: %s' % (len(train_b['data'])))\n", "print('size of validation set: %s' % (len(test_b['data'])))\n", "print('classes: %s' % (train_b.target_names))\n", "\n", "x_train = train_b.data\n", "y_train = train_b.target\n", "x_test = test_b.data\n", "y_test = test_b.target" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using TensorFlow backend.\n" ] 
} ], "source": [ "import ktrain\n", "from ktrain import text" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "preprocessing train...\n" ] }, { "data": { "text/html": [ "done." ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "preprocessing test...\n" ] }, { "data": { "text/html": [ "done." ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "trn, val, preproc = text.texts_from_array(x_train=x_train, y_train=y_train,\n", " x_test=x_test, y_test=y_test,\n", " class_names=train_b.target_names,\n", " preprocess_mode='bert',\n", " maxlen=350)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Is Multi-Label? False\n", "maxlen is 350\n", "done.\n" ] } ], "source": [ "model = text.text_classifier('bert', train_data=trn, preproc=preproc)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "learner = ktrain.get_learner(model, train_data=trn, batch_size=6)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "simulating training for different learning rates... this may take a few moments...\n", "Epoch 1/1024\n", "2257/2257 [==============================] - 110s 49ms/step - loss: 1.3606 - acc: 0.3833\n", "Epoch 2/1024\n", "2257/2257 [==============================] - 94s 42ms/step - loss: 0.3663 - acc: 0.8861\n", "Epoch 3/1024\n", " 186/2257 [=>............................] - ETA: 1:25 - loss: 1.2041 - acc: 0.4946\n", "\n", "done.\n", "Please invoke the Learner.lr_plot() method to visually inspect the loss plot to help identify the maximal learning rate associated with falling loss.\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learner.lr_find()\n", "learner.lr_plot()" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "begin training using triangular learning rate policy with max lr of 2e-05...\n", "Epoch 1/5\n", "2257/2257 [==============================] - 104s 46ms/step - loss: 0.5075 - acc: 0.8019\n", "Epoch 2/5\n", "2257/2257 [==============================] - 95s 42ms/step - loss: 0.0892 - acc: 0.9703\n", "Epoch 3/5\n", "2257/2257 [==============================] - 95s 42ms/step - loss: 0.0302 - acc: 0.9925\n", "Epoch 4/5\n", "2257/2257 [==============================] - 95s 42ms/step - loss: 0.0156 - acc: 0.9951\n", "Epoch 5/5\n", "2257/2257 [==============================] - 95s 42ms/step - loss: 0.0040 - acc: 0.9987\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learner.autofit(2e-5, 5)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " precision recall f1-score support\n", "\n", " 0 0.95 0.92 0.93 319\n", " 1 0.96 0.98 0.97 389\n", " 2 0.98 0.95 0.97 396\n", " 3 0.94 0.98 0.96 398\n", "\n", " accuracy 0.96 1502\n", " macro avg 0.96 0.96 0.96 1502\n", "weighted avg 0.96 0.96 0.96 1502\n", "\n" ] }, { "data": { "text/plain": [ "array([[292, 6, 3, 18],\n", " [ 4, 383, 1, 1],\n", " [ 9, 7, 376, 4],\n", " [ 2, 1, 3, 392]])" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learner.validate(val_data=val)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "predictor = ktrain.get_predictor(learner.model, preproc)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['alt.atheism', 
'comp.graphics', 'sci.med', 'soc.religion.christian']" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predictor.get_classes()" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "['sci.med', 'sci.med', 'sci.med']" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predictor.predict(test_b.data[0:3])" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([2, 2, 2])" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test_b.target[:3]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" } }, "nbformat": 4, "nbformat_minor": 2 }