{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%reload_ext autoreload\n", "%autoreload 2\n", "%matplotlib inline\n", "import os\n", "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\";\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"; " ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import ktrain\n", "from ktrain import text" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "detected encoding: utf-8\n", "preprocessing train...\n", "language: en\n" ] }, { "data": { "text/html": [ "done." ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Is Multi-Label? False\n", "preprocessing test...\n", "language: en\n" ] }, { "data": { "text/html": [ "done." ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "trn, val, preproc = text.texts_from_folder('data/aclImdb', \n", " maxlen=500, \n", " preprocess_mode='bert',\n", " train_test_names=['train', \n", " 'test'],\n", " classes=['pos', 'neg'])" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Is Multi-Label? False\n", "maxlen is 500\n", "done.\n" ] } ], "source": [ "model = text.text_classifier('bert', trn , preproc=preproc)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "learner = ktrain.get_learner(model, \n", " train_data=trn, \n", " val_data=val, \n", " batch_size=6)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "simulating training for different learning rates... this may take a few moments...\n", "Epoch 1/1024\n", " 6492/25000 [======>.......................] 
- ETA: 19:19 - loss: 0.6908 - acc: 0.6155\n", "\n", "done.\n", "Please invoke the Learner.lr_plot() method to visually inspect the loss plot to help identify the maximal learning rate associated with falling loss.\n" ] } ], "source": [ "learner.lr_find()" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learner.lr_plot()" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "begin training using onecycle policy with max lr of 2e-05...\n", "Train on 25000 samples, validate on 25000 samples\n", "25000/25000 [==============================] - 2304s 92ms/sample - loss: 0.2442 - accuracy: 0.9008 - val_loss: 0.1596 - val_accuracy: 0.9394\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 2e-5 is one of the LRs recommended by Google and is consistent with the plot above.\n", "learner.fit_onecycle(2e-5, 1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### **93.94%** accuracy in a single epoch." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's make some predictions on new data." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "predictor = ktrain.get_predictor(learner.model, preproc)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "data = [ 'This movie was horrible! The plot was boring. Acting was okay, though.',\n", " 'The film really sucked. I want my money back.',\n", " 'The plot had too many holes.',\n", " 'What a beautiful romantic comedy. 
10/10 would see again!',\n", " ]" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "['neg', 'neg', 'neg', 'pos']" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predictor.predict(data)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To save and reload the predictor for later use:\n", "```python\n", "predictor.save('/tmp/my_predictor')\n", "reloaded_predictor = ktrain.load_predictor('/tmp/my_predictor')\n", "```\n", "\n", "Please see the [text classification tutorial](https://nbviewer.jupyter.org/github/amaiya/ktrain/blob/master/tutorials/tutorial-04-text-classification.ipynb) for more details." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" } }, "nbformat": 4, "nbformat_minor": 2 }