{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%reload_ext autoreload\n", "%autoreload 2\n", "%matplotlib inline\n", "import os\n", "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "using Keras version: 2.2.4-tf\n" ] } ], "source": [ "import ktrain\n", "from ktrain import text" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load the Data Into Arrays" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "size of training set: 2257\n", "size of validation set: 1502\n", "classes: ['alt.atheism', 'comp.graphics', 'sci.med', 'soc.religion.christian']\n" ] } ], "source": [ "categories = ['alt.atheism', 'soc.religion.christian',\n", " 'comp.graphics', 'sci.med']\n", "from sklearn.datasets import fetch_20newsgroups\n", "train_b = fetch_20newsgroups(subset='train',\n", " categories=categories, shuffle=True, random_state=42)\n", "test_b = fetch_20newsgroups(subset='test',\n", " categories=categories, shuffle=True, random_state=42)\n", "\n", "print('size of training set: %s' % (len(train_b['data'])))\n", "print('size of validation set: %s' % (len(test_b['data'])))\n", "print('classes: %s' % (train_b.target_names))\n", "\n", "x_train = train_b.data\n", "y_train = train_b.target\n", "x_test = test_b.data\n", "y_test = test_b.target" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## STEP 1: Preprocess Data" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "preprocessing train...\n", "language: en\n" ] }, { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "preprocessing test...\n", "language: en\n" ] }, { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "trn, val, preproc = text.texts_from_array(x_train=x_train, y_train=y_train,\n", " x_test=x_test, y_test=y_test,\n", " class_names=train_b.target_names,\n", " preprocess_mode='distilbert',\n", " maxlen=350)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## STEP 2: Build a Model and Wrap in Learner" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "fasttext: a fastText-like model [http://arxiv.org/pdf/1607.01759.pdf]\n", "logreg: logistic regression using a trainable Embedding layer\n", "nbsvm: NBSVM model [http://www.aclweb.org/anthology/P12-2018]\n", "bigru: Bidirectional GRU with pretrained word vectors [https://arxiv.org/abs/1712.09405]\n", "standard_gru: simple 2-layer GRU with randomly initialized embeddings\n", "bert: Bidirectional Encoder Representations from Transformers (BERT) [https://arxiv.org/abs/1810.04805]\n", "distilbert: distilled, smaller, and faster BERT from Hugging Face [https://arxiv.org/abs/1910.01108]\n" ] } ], "source": [ "text.print_text_classifiers()" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Is Multi-Label? False\n", "maxlen is 350\n", "done.\n" ] } ], "source": [ "model = text.text_classifier('distilbert', train_data=trn, preproc=preproc)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "learner = ktrain.get_learner(model, train_data=trn, val_data=val, batch_size=6)" ] },
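{ "cell_type": "markdown", "metadata": {}, "source": [ "Optionally, before training we could use ktrain's learning-rate finder to sanity-check the maximum learning rate passed to `fit_onecycle` below. The following cell is only a sketch and was not executed as part of this run: `lr_find` briefly trains the model over a range of increasing learning rates, and `lr_plot` displays the resulting loss curve so a suitable peak value can be read off." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Optional (not executed in this run): estimate a good maximum learning rate.\n", "# lr_find briefly trains the model with an increasing learning rate;\n", "# lr_plot then shows loss vs. learning rate for choosing the peak value.\n", "learner.lr_find()\n", "learner.lr_plot()" ] },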
{ "cell_type": "markdown", "metadata": {}, "source": [ "## STEP 3: Train Model" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "begin training using onecycle policy with max lr of 3e-05...\n", "Train for 377 steps, validate for 251 steps\n", "Epoch 1/4\n", "377/377 [==============================] - 69s 183ms/step - loss: 0.6874 - accuracy: 0.7882 - val_loss: 0.3342 - val_accuracy: 0.8948\n", "Epoch 2/4\n", "377/377 [==============================] - 58s 154ms/step - loss: 0.1288 - accuracy: 0.9659 - val_loss: 0.2021 - val_accuracy: 0.9341\n", "Epoch 3/4\n", "377/377 [==============================] - 60s 159ms/step - loss: 0.0538 - accuracy: 0.9849 - val_loss: 0.1343 - val_accuracy: 0.9621\n", "Epoch 4/4\n", "377/377 [==============================] - 61s 161ms/step - loss: 0.0176 - accuracy: 0.9973 - val_loss: 0.1415 - val_accuracy: 0.9634\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learner.fit_onecycle(3e-5, 4)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Predict on New Data" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "p = ktrain.get_predictor(model, preproc)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'comp.graphics'" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "p.predict(\"There is a problem with my computer monitor's resolution. Everything is blurry.\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.8" } }, "nbformat": 4, "nbformat_minor": 2 }