{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%reload_ext autoreload\n", "%autoreload 2\n", "%matplotlib inline\n", "import os\n", "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\";\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"; " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Text Classification with Hugging Face Transformers in *ktrain*\n", "\n", "As of v0.8.x, *ktrain* includes an easy-to-use, thin wrapper around the Hugging Face transformers library for text classification." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load Data Into Arrays" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "size of training set: 2257\n", "size of validation set: 1502\n", "classes: ['alt.atheism', 'comp.graphics', 'sci.med', 'soc.religion.christian']\n" ] } ], "source": [ "categories = ['alt.atheism', 'soc.religion.christian',\n", " 'comp.graphics', 'sci.med']\n", "from sklearn.datasets import fetch_20newsgroups\n", "train_b = fetch_20newsgroups(subset='train',\n", " categories=categories, shuffle=True, random_state=42)\n", "test_b = fetch_20newsgroups(subset='test',\n", " categories=categories, shuffle=True, random_state=42)\n", "\n", "print('size of training set: %s' % (len(train_b['data'])))\n", "print('size of validation set: %s' % (len(test_b['data'])))\n", "print('classes: %s' % (train_b.target_names))\n", "\n", "x_train = train_b.data\n", "y_train = train_b.target\n", "x_test = test_b.data\n", "y_test = test_b.target" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## STEP 1: Preprocess Data and Build a Transformer Model\n", "\n", "For `MODEL_NAME`, *ktrain* supports both the \"official\" built-in models [available here](https://huggingface.co/transformers/pretrained_models.html) and the [community-uploaded models available here](https://huggingface.co/models)."
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "using Keras version: 2.2.4-tf\n", "preprocessing train...\n", "language: en\n" ] }, { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "preprocessing test...\n", "language: en\n" ] }, { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import ktrain\n", "from ktrain import text\n", "MODEL_NAME = 'distilbert-base-uncased'\n", "t = text.Transformer(MODEL_NAME, maxlen=500, class_names=train_b.target_names)\n", "trn = t.preprocess_train(x_train, y_train)\n", "val = t.preprocess_test(x_test, y_test)\n", "model = t.get_classifier()\n", "learner = ktrain.get_learner(model, train_data=trn, val_data=val, batch_size=6)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that `x_train` and `x_test` are the raw texts that look like this:\n", "```python\n", "x_train = ['I hate this movie.', 'I like this movie.']\n", "```\n", "The labels are arrays in one of the following forms:\n", "```python\n", "# string labels\n", "y_train = ['negative', 'positive']\n", "# integer labels\n", "y_train = [0, 1] # labels must start from 0 if in integer format\n", "# multi-hot or one-hot encoded labels\n", "y_train = [[1,0], [0,1]]\n", "```\n", "In the latter two cases, you must supply a `class_names` argument to the `Transformer` constructor, which tells *ktrain* how indices map to class names. In this case, `class_names=['alt.atheism', 'comp.graphics', 'sci.med', 'soc.religion.christian']` because 0=alt.atheism, 1=comp.graphics, etc." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## STEP 2 [Optional]: Estimate a Good Learning Rate\n", "\n", "According to papers from Google, learning rates between `2e-5` and `5e-5` tend to work well with transformer models.
However, we will run our learning-rate-finder for two epochs to estimate the LR on this particular dataset.\n", "\n", "As shown below, our results are consistent with Google's findings." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "simulating training for different learning rates... this may take a few moments...\n", "Train for 376 steps\n", "Epoch 1/2\n", "376/376 [==============================] - 73s 194ms/step - loss: 1.0788 - accuracy: 0.5191\n", "Epoch 2/2\n", "115/376 [========>.....................] - ETA: 43s - loss: 1.9950 - accuracy: 0.2482\n", "\n", "done.\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learner.lr_find(show_plot=True, max_epochs=2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## STEP 3: Train Model\n", "\n", "Train using a [1cycle learning rate schedule](https://arxiv.org/pdf/1803.09820.pdf)." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "begin training using onecycle policy with max lr of 8e-05...\n", "Train for 377 steps, validate for 251 steps\n", "Epoch 1/4\n", "377/377 [==============================] - 89s 236ms/step - loss: 0.5214 - accuracy: 0.8285 - val_loss: 0.2847 - val_accuracy: 0.9081\n", "Epoch 2/4\n", "377/377 [==============================] - 80s 213ms/step - loss: 0.1524 - accuracy: 0.9513 - val_loss: 0.5775 - val_accuracy: 0.8309\n", "Epoch 3/4\n", "377/377 [==============================] - 81s 215ms/step - loss: 0.1066 - accuracy: 0.9739 - val_loss: 0.2469 - val_accuracy: 0.9387\n", "Epoch 4/4\n", "377/377 [==============================] - 81s 215ms/step - loss: 0.0318 - accuracy: 0.9907 - val_loss: 0.1645 - val_accuracy: 0.9561\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learner.fit_onecycle(8e-5, 4)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## STEP 4: Evaluate/Inspect Model" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " precision recall f1-score support\n", "\n", " alt.atheism 0.92 0.93 0.93 319\n", " comp.graphics 0.97 0.97 0.97 389\n", " sci.med 0.97 0.95 0.96 396\n", "soc.religion.christian 0.96 0.96 0.96 398\n", "\n", " accuracy 0.96 1502\n", " macro avg 0.95 0.96 0.95 1502\n", " weighted avg 0.96 0.96 0.96 1502\n", "\n" ] }, { "data": { "text/plain": [ "array([[298, 2, 8, 11],\n", " [ 7, 378, 3, 1],\n", " [ 5, 
8, 378, 5],\n", " [ 15, 0, 1, 382]])" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learner.validate(class_names=t.get_classes())" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "----------\n", "id:521 | loss:7.12 | true:sci.med | pred:comp.graphics)\n", "\n" ] } ], "source": [ "# the one we got most wrong\n", "learner.view_top_losses(n=1, preproc=t)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "From: jim.zisfein@factory.com (Jim Zisfein) \n", "Subject: Data of skull\n", "Distribution: world\n", "Organization: Invention Factory's BBS - New York City, NY - 212-274-8298v.32bis\n", "Reply-To: jim.zisfein@factory.com (Jim Zisfein) \n", "Lines: 11\n", "\n", "GT> From: gary@concave.cs.wits.ac.za (Gary Taylor)\n", "GT> Hi, We are trying to develop a image reconstruction simulation for the skull\n", "\n", "You could do high resolution CT (computed tomographic) scanning of\n", "the skull. Many CT scanners have an algorithm to do 3-D\n", "reconstructions in any plane you want. If you did reconstructions\n", "every 2 degrees or so in all planes, you could use the resultant\n", "images to create user-controlled animation.\n", "---\n", " . SLMR 2.1 . 
E-mail: jim.zisfein@factory.com (Jim Zisfein)\n", " \n", "\n" ] } ], "source": [ "# understandable mistake - this sci.med post talks a lot about computer graphics\n", "print(x_test[521])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## STEP 5: Make Predictions on New Data in Deployment" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "predictor = ktrain.get_predictor(learner.model, preproc=t)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'soc.religion.christian'" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predictor.predict('Jesus Christ is the central figure of Christianity.')" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", " \n", "\n", "\n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", " \n", " \n", " \n", " \n", "

\n", "<p>y=soc.religion.christian (probability 0.998, score 7.287) top features</p>\n", "<table>\n", "<tr><th>Contribution?</th><th>Feature</th></tr>\n", "<tr><td>+7.336</td><td>Highlighted in text (sum)</td></tr>\n", "<tr><td>-0.049</td><td>&lt;BIAS&gt;</td></tr>\n", "</table>\n", "<p>jesus christ is the central figure of christianity.</p>\n", "
\n", "\n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predictor.explain('Jesus Christ is the central figure of Christianity.')" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "predictor.save('/tmp/my_20newsgroup_predictor')" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "reloaded_predictor = ktrain.load_predictor('/tmp/my_20newsgroup_predictor')" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'soc.religion.christian'" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "reloaded_predictor.predict('Jesus Christ is the central figure of Christianity.')" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "array([8.9553175e-03, 3.1522836e-04, 3.8172584e-04, 9.9034774e-01],\n", " dtype=float32)" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "reloaded_predictor.predict_proba('Jesus Christ is the central figure of Christianity.')" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['alt.atheism', 'comp.graphics', 'sci.med', 'soc.religion.christian']" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "reloaded_predictor.get_classes()" ] }, { "cell_type": "markdown", 
"metadata": {}, "source": [ "## Additional Tips and Tricks" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you have a **transformers** model that has already been trained/fine-tuned, you can easily wrap it into a **ktrain** `Predictor`. The example below loads the pre-fine-tuned [coronabert model](https://huggingface.co/jakelever/coronabert) into **ktrain** to make predictions:\n", "\n", "```python\n", "# Import ktrain along with a couple things from transformers\n", "import ktrain\n", "from transformers import TFAutoModelForSequenceClassification\n", "\n", "# Load the model and compile it for ktrain/tf.Keras\n", "model = TFAutoModelForSequenceClassification.from_pretrained(\"jakelever/coronabert\")\n", "model.compile(loss='binary_crossentropy',optimizer='adam', metrics=['accuracy'])\n", "\n", "# Pull the categories out of the model (or set class_names manually)\n", "class_names = list(model.config.id2label.values())\n", "\n", "# Set up a ktrain preprocessor (which manages the tokenization) with the labels\n", "preproc = ktrain.text.Transformer('jakelever/coronabert',maxlen=500,class_names=class_names)\n", "preproc.preprocess_train_called = True # needed to suppress warnings about not calling preprocess_train\n", "\n", "# Get the predictor (which takes in the model and tokenizer info)\n", "predictor = ktrain.get_predictor(model, preproc)\n", "\n", "# Make predictions\n", "text = [\"A genomic region associated with protection against severe COVID-19 is inherited from Neandertals.\"]\n", "predictor.predict(text)\n", "\n", "# OUTPUT:\n", "# [[('Clinical Reports', 0.0003284997),\n", "# ('Comment/Editorial', 0.0022700194),\n", "# ('Communication', 0.00060458254),\n", "# ('Contact Tracing', 0.00027690193),\n", "# ('Diagnostics', 0.0003987006),\n", "# ('Drug Targets', 0.0008852846),\n", "# ('Education', 0.00018228142),\n", "# ('Effect on Medical Specialties', 0.00045943243),\n", "# ('Forecasting & Modelling', 0.00047854715),\n", "# ('Health Policy', 
0.00042494797),\n", "# ('Healthcare Workers', 6.292213e-05),\n", "# ('Imaging', 0.00021008229),\n", "# ('Immunology', 0.00072542584),\n", "# ('Inequality', 0.0007106358),\n", "# ('Infection Reports', 0.00033797201),\n", "# ('Long Haul', 0.00034338655),\n", "# ('Medical Devices', 0.0002488097),\n", "# ('Meta-analysis', 0.00030506376),\n", "# ('Misinformation', 0.0012771417),\n", "# ('Model Systems & Tools', 0.0020338537),\n", "# ('Molecular Biology', 0.9950799),\n", "# ('News', 0.00034808667),\n", "# ('Non-human', 0.98562455),\n", "# ('Non-medical', 0.0005655724),\n", "# ('Pediatrics', 0.00042545484),\n", "# ('Prevalence', 0.0011711525),\n", "# ('Prevention', 0.00043099752),\n", "# ('Psychology', 0.00045698017),\n", "# ('Recommendations', 0.0004172316),\n", "# ('Review', 0.002200645),\n", "# ('Risk Factors', 0.00014382145),\n", "# ('Surveillance', 0.00081551325),\n", "# ('Therapeutics', 0.0010580326),\n", "# ('Transmission', 0.0031670583),\n", "# ('Vaccines', 0.0011023124)]]\n", "\n", "```\n", "\n", "Finally, to make predictions with a smaller deployment footprint, you can export the model to ONNX format as described in [this example notebook](https://nbviewer.jupyter.org/github/amaiya/ktrain/blob/develop/examples/text/ktrain-ONNX-TFLite-examples.ipynb)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" } }, "nbformat": 4, "nbformat_minor": 2 }