{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%reload_ext autoreload\n", "%autoreload 2\n", "%matplotlib inline\n", "import os\n", "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\";\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"; " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Text Classification with Hugging Face Transformers in *ktrain*\n", "\n", "As of v0.8.x, *ktrain* includes an easy-to-use, thin wrapper around the Hugging Face transformers library for text classification." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load Data Into Arrays" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "size of training set: 2257\n", "size of validation set: 1502\n", "classes: ['alt.atheism', 'comp.graphics', 'sci.med', 'soc.religion.christian']\n" ] } ], "source": [ "categories = ['alt.atheism', 'soc.religion.christian',\n", " 'comp.graphics', 'sci.med']\n", "from sklearn.datasets import fetch_20newsgroups\n", "train_b = fetch_20newsgroups(subset='train',\n", " categories=categories, shuffle=True, random_state=42)\n", "test_b = fetch_20newsgroups(subset='test',\n", " categories=categories, shuffle=True, random_state=42)\n", "\n", "print('size of training set: %s' % (len(train_b['data'])))\n", "print('size of validation set: %s' % (len(test_b['data'])))\n", "print('classes: %s' % (train_b.target_names))\n", "\n", "x_train = train_b.data\n", "y_train = train_b.target\n", "x_test = test_b.data\n", "y_test = test_b.target" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## STEP 1: Preprocess Data and Build a Transformer Model\n", "\n", "For `MODEL_NAME`, *ktrain* supports both the \"official\" built-in models [available here](https://huggingface.co/transformers/pretrained_models.html) and the [community-uploaded models available here](https://huggingface.co/models)." 
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "using Keras version: 2.2.4-tf\n", "preprocessing train...\n", "language: en\n" ] }, { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "preprocessing test...\n", "language: en\n" ] }, { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import ktrain\n", "from ktrain import text\n", "MODEL_NAME = 'distilbert-base-uncased'\n", "t = text.Transformer(MODEL_NAME, maxlen=500, class_names=train_b.target_names)\n", "trn = t.preprocess_train(x_train, y_train)\n", "val = t.preprocess_test(x_test, y_test)\n", "model = t.get_classifier()\n", "learner = ktrain.get_learner(model, train_data=trn, val_data=val, batch_size=6)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that `x_train` and `x_test` are the raw texts that look like this:\n", "```python\n", "x_train = ['I hate this movie.', 'I like this movie.']\n", "```\n", "The labels are arrays in one of the following forms:\n", "```python\n", "# string labels\n", "y_train = ['negative', 'positive']\n", "# integer labels\n", "y_train = [0, 1] # labels must start from 0 if in integer format\n", "# multi or one-hot encoded labels\n", "y_train = [[1,0], [0,1]]\n", "```\n", "In the latter two cases, you must supply a `class_names` argument to the `Transformer` constructor, which tells *ktrain* how indices map to class names. In this case, `class_names=['alt.atheism', 'comp.graphics', 'sci.med', 'soc.religion.christian']` because 0=alt.atheism, 1=comp.graphics, etc." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## STEP 2 [Optional]: Estimate a Good Learning Rate\n", "\n", "Learning rates between `2e-5` and `5e-5` tend to work well with transformer models based on papers from Google. 
However, we will run our learning-rate finder for two epochs to estimate the LR on this particular dataset.\n", "\n", "As shown below, our results are consistent with Google's findings." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "simulating training for different learning rates... this may take a few moments...\n", "Train for 376 steps\n", "Epoch 1/2\n", "376/376 [==============================] - 73s 194ms/step - loss: 1.0788 - accuracy: 0.5191\n", "Epoch 2/2\n", "115/376 [========>.....................] - ETA: 43s - loss: 1.9950 - accuracy: 0.2482\n", "\n", "done.\n" ] }, { "data": { "text/plain": [ "[Figure: learning-rate finder plot of loss versus learning rate]" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learner.lr_find(show_plot=True, max_epochs=2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## STEP 3: Train Model\n", "\n", "Train using a [1cycle learning rate schedule](https://arxiv.org/pdf/1803.09820.pdf)." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "begin training using onecycle policy with max lr of 8e-05...\n", "Train for 377 steps, validate for 251 steps\n", "Epoch 1/4\n", "377/377 [==============================] - 89s 236ms/step - loss: 0.5214 - accuracy: 0.8285 - val_loss: 0.2847 - val_accuracy: 0.9081\n", "Epoch 2/4\n", "377/377 [==============================] - 80s 213ms/step - loss: 0.1524 - accuracy: 0.9513 - val_loss: 0.5775 - val_accuracy: 0.8309\n", "Epoch 3/4\n", "377/377 [==============================] - 81s 215ms/step - loss: 0.1066 - accuracy: 0.9739 - val_loss: 0.2469 - val_accuracy: 0.9387\n", "Epoch 4/4\n", "377/377 [==============================] - 81s 215ms/step - loss: 0.0318 - accuracy: 0.9907 - val_loss: 0.1645 - val_accuracy: 0.9561\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learner.fit_onecycle(8e-5, 4)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## STEP 4: Evaluate/Inspect Model" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "                        precision    recall  f1-score   support\n", "\n", "           alt.atheism       0.92      0.93      0.93       319\n", "         comp.graphics       0.97      0.97      0.97       389\n", "               sci.med       0.97      0.95      0.96       396\n", "soc.religion.christian       0.96      0.96      0.96       398\n", "\n", "              accuracy                           0.96      1502\n", "             macro avg       0.95      0.96      0.95      1502\n", "          weighted avg       0.96      0.96      0.96      1502\n", "\n" ] }, { "data": { "text/plain": [ "array([[298, 2, 8, 11],\n", " [ 7, 378, 3, 1],\n", " [ 5, 
8, 378, 5],\n", " [ 15, 0, 1, 382]])" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learner.validate(class_names=t.get_classes())" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "----------\n", "id:521 | loss:7.12 | true:sci.med | pred:comp.graphics)\n", "\n" ] } ], "source": [ "# the one we got most wrong\n", "learner.view_top_losses(n=1, preproc=t)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "From: jim.zisfein@factory.com (Jim Zisfein) \n", "Subject: Data of skull\n", "Distribution: world\n", "Organization: Invention Factory's BBS - New York City, NY - 212-274-8298v.32bis\n", "Reply-To: jim.zisfein@factory.com (Jim Zisfein) \n", "Lines: 11\n", "\n", "GT> From: gary@concave.cs.wits.ac.za (Gary Taylor)\n", "GT> Hi, We are trying to develop a image reconstruction simulation for the skull\n", "\n", "You could do high resolution CT (computed tomographic) scanning of\n", "the skull. Many CT scanners have an algorithm to do 3-D\n", "reconstructions in any plane you want. If you did reconstructions\n", "every 2 degrees or so in all planes, you could use the resultant\n", "images to create user-controlled animation.\n", "---\n", " . SLMR 2.1 . 
E-mail: jim.zisfein@factory.com (Jim Zisfein)\n", " \n", "\n" ] } ], "source": [ "# understandable mistake - this sci.med post talks a lot about computer graphics\n", "print(x_test[521])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## STEP 5: Make Predictions on New Data in Deployment" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "predictor = ktrain.get_predictor(learner.model, preproc=t)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'soc.religion.christian'" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predictor.predict('Jesus Christ is the central figure of Christianity.')" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "<p><b>y=soc.religion.christian</b> (probability <b>0.998</b>, score <b>7.287</b>) top features</p>\n", "<table>\n", "<thead><tr><th>Contribution?</th><th>Feature</th></tr></thead>\n", "<tbody>\n", "<tr><td>+7.336</td><td>Highlighted in text (sum)</td></tr>\n", "<tr><td>-0.049</td><td>&lt;BIAS&gt;</td></tr>\n", "</tbody>\n", "</table>\n", "<p>jesus christ is the central figure of christianity.</p>\n" ], "text/plain": [ "" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predictor.explain('Jesus Christ is the central figure of Christianity.')" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "predictor.save('/tmp/my_20newsgroup_predictor')" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "reloaded_predictor = ktrain.load_predictor('/tmp/my_20newsgroup_predictor')" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'soc.religion.christian'" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "reloaded_predictor.predict('Jesus Christ is the central figure of Christianity.')" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "array([8.9553175e-03, 3.1522836e-04, 3.8172584e-04, 9.9034774e-01],\n", " dtype=float32)" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "reloaded_predictor.predict_proba('Jesus Christ is the central figure of Christianity.')" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['alt.atheism', 'comp.graphics', 'sci.med', 'soc.religion.christian']" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "reloaded_predictor.get_classes()" ] }, { "cell_type": "markdown", 
"metadata": {}, "source": [ "## Additional Tips and Tricks" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you have a **transformers** model that has already been trained/fine-tuned, you can easily wrap it into a **ktrain** `Predictor`. The example below loads the pre-fine-tuned [coronabert model](https://huggingface.co/jakelever/coronabert) into **ktrain** to make predictions:\n", "\n", "```python\n", "# Import ktrain along with a couple things from transformers\n", "import ktrain\n", "from transformers import TFAutoModelForSequenceClassification\n", "\n", "# Load the model and compile it for ktrain/tf.Keras\n", "model = TFAutoModelForSequenceClassification.from_pretrained(\"jakelever/coronabert\")\n", "model.compile(loss='binary_crossentropy',optimizer='adam', metrics=['accuracy'])\n", "\n", "# Pull the categories out of the model (or set class_names manually)\n", "class_names = list(model.config.id2label.values())\n", "\n", "# Set up a ktrain preprocessor (which manages the tokenization) with the labels\n", "preproc = ktrain.text.Transformer('jakelever/coronabert',maxlen=500,class_names=class_names)\n", "preproc.preprocess_train_called = True # needed to suppress warnings about not calling preprocess_train\n", "\n", "# Get the predictor (which takes in the model and tokenizer info)\n", "predictor = ktrain.get_predictor(model, preproc)\n", "\n", "# Make predictions\n", "text = [\"A genomic region associated with protection against severe COVID-19 is inherited from Neandertals.\"]\n", "predictor.predict(text)\n", "\n", "# OUTPUT:\n", "# [[('Clinical Reports', 0.0003284997),\n", "# ('Comment/Editorial', 0.0022700194),\n", "# ('Communication', 0.00060458254),\n", "# ('Contact Tracing', 0.00027690193),\n", "# ('Diagnostics', 0.0003987006),\n", "# ('Drug Targets', 0.0008852846),\n", "# ('Education', 0.00018228142),\n", "# ('Effect on Medical Specialties', 0.00045943243),\n", "# ('Forecasting & Modelling', 0.00047854715),\n", "# ('Health Policy', 
0.00042494797),\n", "# ('Healthcare Workers', 6.292213e-05),\n", "# ('Imaging', 0.00021008229),\n", "# ('Immunology', 0.00072542584),\n", "# ('Inequality', 0.0007106358),\n", "# ('Infection Reports', 0.00033797201),\n", "# ('Long Haul', 0.00034338655),\n", "# ('Medical Devices', 0.0002488097),\n", "# ('Meta-analysis', 0.00030506376),\n", "# ('Misinformation', 0.0012771417),\n", "# ('Model Systems & Tools', 0.0020338537),\n", "# ('Molecular Biology', 0.9950799),\n", "# ('News', 0.00034808667),\n", "# ('Non-human', 0.98562455),\n", "# ('Non-medical', 0.0005655724),\n", "# ('Pediatrics', 0.00042545484),\n", "# ('Prevalence', 0.0011711525),\n", "# ('Prevention', 0.00043099752),\n", "# ('Psychology', 0.00045698017),\n", "# ('Recommendations', 0.0004172316),\n", "# ('Review', 0.002200645),\n", "# ('Risk Factors', 0.00014382145),\n", "# ('Surveillance', 0.00081551325),\n", "# ('Therapeutics', 0.0010580326),\n", "# ('Transmission', 0.0031670583),\n", "# ('Vaccines', 0.0011023124)]]\n", "\n", "```\n", "\n", "Finally, to make predictions with a smaller deployment footprint, you can export the model to ONNX format as described in [this example notebook](https://nbviewer.jupyter.org/github/amaiya/ktrain/blob/develop/examples/text/ktrain-ONNX-TFLite-examples.ipynb)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" } }, "nbformat": 4, "nbformat_minor": 2 }