{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%reload_ext autoreload\n", "%autoreload 2\n", "%matplotlib inline\n", "import os\n", "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\";\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"; " ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using TensorFlow backend.\n" ] } ], "source": [ "import ktrain\n", "from ktrain import text" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Multi-Label Text Classification: Identifying Toxic Online Comments\n", "\n", "Here, we will classify Wikipedia comments into one or more categories of so-called *toxic comments*. Categories of toxic online behavior include toxic, severe_toxic, obscene, threat, insult, and identity_hate. The dataset can be downloaded from the [Kaggle Toxic Comment Classification Challenge](https://www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge/data) as a CSV file (i.e., download the file ```train.csv```). We will load the data using the ```texts_from_csv``` method, which assumes the label_columns are already one-hot-encoded in the spreadsheet. 
Since *val_filepath* is None, 10% of the data will automatically be used as a validation set.\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Word Counts: 197516\n", "Nrows: 143613\n", "143613 train sequences\n", "Average train sequence length: 66\n", "x_train shape: (143613,150)\n", "y_train shape: (143613,6)\n", "15958 test sequences\n", "Average test sequence length: 66\n", "x_test shape: (15958,150)\n", "y_test shape: (15958,6)\n" ] } ], "source": [ "DATA_PATH = 'data/toxic-comments/train.csv'\n", "NUM_WORDS = 50000\n", "MAXLEN = 150\n", "(x_train, y_train), (x_test, y_test), preproc = text.texts_from_csv(DATA_PATH,\n", " 'comment_text',\n", " label_columns = [\"toxic\", \"severe_toxic\", \"obscene\", \"threat\", \"insult\", \"identity_hate\"],\n", " val_filepath=None, # if None, 10% of data will be used for validation\n", " max_features=NUM_WORDS, maxlen=MAXLEN,\n", " ngram_range=1)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "fasttext: a fastText-like model (http://arxiv.org/pdf/1607.01759.pdf)\n", "logreg: logistic regression using a trainable Embedding layer\n", "nbsvm: NBSVM model (http://www.aclweb.org/anthology/P12-2018)\n", "bigru: Bidirectional GRU with pretrained word vectors\n", "bert: Bidirectional Encoder Representations from Transformers (https://arxiv.org/abs/1810.04805)\n" ] } ], "source": [ "text.print_text_classifiers()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will employ a Bidirectional GRU with pretrained word vectors. The following code cell loads a BIGRU model and defines a ```Learner``` object based on that model. The file ```crawl-300d-2M.vec``` contains 2 million word vectors trained by Facebook and will be automatically downloaded for use with this model."
] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Is Multi-Label? True\n", "compiling word ID features...\n", "max_features is 49325\n", "processing pretrained word vectors...\n", "done.\n" ] } ], "source": [ "model = text.text_classifier('bigru', (x_train, y_train), preproc=preproc)\n", "\n", "learner = ktrain.get_learner(model, train_data=(x_train, y_train), val_data=(x_test, y_test))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As before, we use our learning rate finder to find a good learning rate. In this case, a learning rate of 0.0007 appears to be good." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "simulating training for different learning rates... this may take a few moments...\n", "Epoch 1/1\n", "100352/143613 [===================>..........] - ETA: 10:26 - loss: 0.3649 - acc: 0.7886\n", "\n", "done.\n", "Please invoke the Learner.lr_plot() method to visually inspect the loss plot to help identify the maximal learning rate associated with falling loss.\n" ] }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYUAAAEOCAYAAABmVAtTAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvqOYd8AAAIABJREFUeJzt3Xl8XHW9//HXJ5N9adKk6Za2pCul0AJtWgplUxELKChwEVQuCFi9XtQfer0ielVwAcXlXgEXREFRRIQrt0Kh7AplawoUukK6QJvSNi1d0jZ7Pr8/ZhrSkLWdMyczeT8fjzw658x35nxOks4753zP+X7N3REREQFIC7sAERHpPxQKIiLSRqEgIiJtFAoiItJGoSAiIm0UCiIi0kahICIibRQKIiLSRqEgIiJtFAoiItImPewC+mrIkCFeXl4edhkiIkllyZIl29y9tKd2SRcK5eXlVFZWhl2GiEhSMbM3e9NOp49ERKSNQkFERNooFEREpI1CQURE2igURESkjUJBRETaJN0lqQervqmFJW/u4OjRReRnvbvbzS2tbNxRx6addYwsysGB7Iw0DCOSZuRnpUeXzcIrXkQkQQZMKNz8RBU3P1lFJM2YNKyAzbvqSDNjd30TTS3dz1NdkJ3OkPws0tOMjEgaGREjPZJGbmaE4YOyyc6IkGYwKCeDVndGDc6lvCSP3MwImelpFGSnU5CVgaXB9j2NDM7NYF9jC3samml1Z1hBNoPzMhP0nRAR6dqACYVPzh7Drromdtc3sWlnHXMmDKEgO52i3EzKS3IZWpDN9r2NpBnsa2whzYyW1lZqG5rZtLOOXXXNNLe00tTiNLe20tzi1NQ28Fr1LlpbnVaHPQ3NB11fXmaEUYNzGTskj0nDCxhVlEN+djqDczM5rCSXotwMcjPTaW2NBlhamo5cRCT+BkwojCjM4bsfPSrQbbg7La3Oq9W72FPfTHNrKw1Nreyub6K2vpnmVqc4N5Md+xoZlJNBQXY6hrFhxz7eemcfW3fX8/rWWh5evrnT9y/ISmdPYzMGDM7NZHBeJsV5mQzKTmfSsAIOH15AaUEW00YdeIpMRKS39MkRR2ZGesSYPmbwIb1PXWMLNbUN7G2MHqVsrW3gnb2NbNpZR1FuBobxzr5GduxtZPveRjbuqOPJ1TW0xI4izGBIfhal+VkcOXIQU0cVMm1UEYNzMyjKyaQwNyMeuysiKUih0A/lZEYYU5ILwBEjBvXqNfsam9s6zF/ZsJMtu+up3lnPE6u28tclGw9oW5iTwQcmD+XCWWOYPqaI9IguQhORKHPvvpO1v6moqHANiNd77s6mXfUs3bCTvQ3NbN/bSNXWPSxctpnahmayM9KYPa6EUyaVcszoIo4cWUhmukJCJNWY2RJ3r+ipnY4UUpyZUVaUQ1lRzgHrrzvnSB5dsYWX39rJoyu28NTqGiB6Oe60UUVUHDaYGbGvolxdGSUyUOhIQXB3Nu+u5+W3dlK5fgdL3nyHZZt2t/VRzCwfzAUVozlj6gh1YIskqd4eKSgUpFN1jS28smEnL6zbzv0vV7N++z4y09M4fcowzj56JO+bPJQM9UWIJI1+EQpmNhf4HyAC3ObuN3R4/mfA+2KLucBQdy/q7j0VConn7ix5cwfzl27igVff5p290RvwPnL0SC6oGM1RZYVhlygiPQg9FMwsArwOfBDYCCwGLnL3FV20/wJwrLtf1t37KhTC1dzSyj/fqOFvL2/ikeWbaWhuZVZ5MZ9/33hOmVSq4UBE+qn+0NE8C6hy97Wxgu4GzgE6DQXgIuDbAdYjcZAeSeP9k4fx/snD2LWvib8u2cBvn1nHpbcvZmpZIf/+vvGcPmW47rgWSVJBnhQuAza0W94YW/ceZnYYMBZ4IsB6JM4KczO44qRx/OOr7+OH502ltr6Jz/3xJc78+dM888a2sMsTkYPQX3oKLwTudfeWzp40s3lmVmlmlTU1NQkuTXq
SmZ7Gx2eO4fGvnMr/XHgMexqa+dRvX+Czd1ay4Z19YZcnIn0QZChUA6PbLY+KrevMhcCfu3ojd7/V3SvcvaK0tDSOJUo8RdKMc44p47Evn8JXP3Q4/3i9hlNufJKfP/5G20B+ItK/BRkKi4GJZjbWzDKJfvDP79jIzCYDg4HnAqxFEig7I8K/v28Cj3/lVD48bSQ/ffR1Pn7rcyzdsDPs0kSkB4GFgrs3A1cCC4GVwD3uvtzMrjOzs9s1vRC425PthgnpUVlRDv9z4THceP40Xt+yh4/9YhHfmb+cusZOzxKKSD+gm9ckIXbXN/Hjhav5w3NvMro4h59ecAwzy4vDLktkwOjtJan9paNZUtyg7AyuO+co7p43mzQzLrz1eX751Br1NYj0MwoFSajZ40p44AsnMvfI4fzw4VV85g+V7NrXFHZZIhKjUJCEK8jO4OZPHMu1Zx/JP9+o4WO/WMSamj1hlyUiKBQkJGbGJSeUc9dnZrOrromP3ryI59ZsD7sskQFPoSChmllezPwvnMjwwmwuvf1F3QktEjKFgoSurCiHP8+bzdgheVz++8U8uWpr2CWJDFgKBekXhuRncddnZjNpWAHz7qxk4fLNYZckMiApFKTfKM7L5I9XHMeRIwu58q6XeHjZ22GXJDLgKBSkXynMyeD3l81ialkhV971Mouq1McgkkgKBel3CnMyuP3TsxhXmsfn/riE1Ztrwy5JZMBQKEi/tD8YcjIifPr2F9m8qz7skkQGBIWC9FtlRTnc/umZ7K5v5tLbX2R3ve58FgmaQkH6tSNHFvKrT82gause/uOepbRorCSRQCkUpN87ceIQvn7mETyyYgvfe7CrKb5FJB7Swy5ApDcuP3Es1Tvq+N2idUwtK+Tc6aPCLkkkJelIQZLGNWdOZmb5YK79+wq27WkIuxyRlKRQkKSRHknj+nOnsq+xmavve41kmyBKJBkoFCSpTBhawH9+aDKPrdzC/a9Uh12OSMpRKEjS+fSccmaVF3P1fa/x6sadYZcjklIUCpJ00iNp3PLJ6RRkZ/D5P71Ere5fEImbQEPBzOaa2WozqzKzq7toc4GZrTCz5WZ2V5D1SOooLcji1xdPp3pnHTc9URV2OSIpI7BQMLMIcAtwBjAFuMjMpnRoMxH4OjDH3Y8E/l9Q9UjqmXFYMRfMGM3vnlnHqs27wy5HJCUEeaQwC6hy97Xu3gjcDZzToc1ngFvcfQeAu2t2FemTr50xmaLcDK76y1Iam1vDLkck6QUZCmXAhnbLG2Pr2psETDKzRWb2vJnNDbAeSUHFeZnccO40Vr69m9ueWRt2OSJJL+yO5nRgInAqcBHwGzMr6tjIzOaZWaWZVdbU1CS4ROnvTpsyjA8dOYybHq/i7V11YZcjktSCDIVqYHS75VGxde1tBOa7e5O7rwNeJxoSB3D3W929wt0rSktLAytYktc3z5pCqzvfe3Bl2KWIJLUgQ2ExMNHMxppZJnAhML9Dm/uJHiVgZkOInk7SOQDps9HFufzbqeN58NW3WbpB9y6IHKzAQsHdm4ErgYXASuAed19uZteZ2dmxZguB7Wa2AngS+Kq7bw+qJkltV5w0jsG5GfzssdfDLkUkaQU6Sqq7LwAWdFj3rXaPHfhy7EvkkORnpfPZU8Zzw0OrWPLmDmYcNjjskkSSTtgdzSJxdfHswxicm8EtT+qGNpGDoVCQlJKXlc5lc8byxKqtLKveFXY5IklHoSAp519PKKcgK11HCyIHQaEgKacwJ4NLTijn4eWbeWNLbdjliCQVhYKkpMtOHEt2eoRfPLUm7FJEkopCQVJScV4mnzxuDPOXbuLN7XvDLkckaSgUJGV95uRxRMz4pY4WRHpNoSApa9igbM6bUcbfXq7WRDwivaRQkJR2/oxRNDS38tBrm8MuRSQpKBQkpU0fM5jJwwu45akqmls034JITxQKktLMjKs+OIk3t+/j4eU6WhDpiUJBUt4HjxhGeUkutz29LuxSRPo
9hYKkvLQ049NzxvLKhp0sefOdsMsR6dcUCjIgnD9jFIOy03W0INIDhYIMCHlZ6Vx8/GE8vHwzb23fF3Y5Iv2WQkEGjItnl5Nmxl8q3wq7FJF+S6EgA8bwwmxOGF/C/KWbiM7vJCIdKRRkQDn76JFseKeOl97SPM4inVEoyIAy96jh5GREuHfJxrBLEemXAg0FM5trZqvNrMrMru7k+UvNrMbMXol9XRFkPSIF2RmccdRwHnh1Ew3NLWGXI9LvBBYKZhYBbgHOAKYAF5nZlE6a/sXdj4l93RZUPSL7feTokdTWN7OoalvYpYj0O0EeKcwCqtx9rbs3AncD5wS4PZFemTNhCIOy03nwVQ17IdJRkKFQBmxot7wxtq6j88zsVTO718xGB1iPCACZ6Wl8cMpwHl2xmSYNkidygLA7mv8OlLv7NOBR4PedNTKzeWZWaWaVNTU1CS1QUtPpRw5jd30zL67TsBci7QUZCtVA+7/8R8XWtXH37e7eEFu8DZjR2Ru5+63uXuHuFaWlpYEUKwPLyRNLycmI8PAynUISaS/IUFgMTDSzsWaWCVwIzG/fwMxGtFs8G1gZYD0ibXIyI5x6eCkPL99MS6tuZBPZL7BQcPdm4EpgIdEP+3vcfbmZXWdmZ8eafdHMlpvZUuCLwKVB1SPS0RlTR1BT28CSN3eEXYpIv5Ee5Ju7+wJgQYd132r3+OvA14OsQaQr7588lKz0NOYvrWbW2OKwyxHpF8LuaBYJTX5WOmdOHcH/vaIb2UT2UyjIgPbhaSOorW/mhbW6CkkEFAoywM2ZMIScjAiPrdwSdiki/YJCQQa07IwIJ00cwmMrtmg4bREUCiKcNmUYm3bVs+Lt3WGXIhI6hYIMeO+fPBQzeHSFTiGJKBRkwBuSn8Wxo4t4crWGUBFRKIgQ7XBeVr2L2vqmsEsRCZVCQQSYPa6EllanUnc3ywCnUBABZhw2mIyI8fza7WGXIhIqhYII0UtTjx5VpJvYZMBTKIjEHDeumNeqd7G3oTnsUkRCo1AQidnfr6BTSDKQKRREYmaNLSYvM8KTq7eGXYpIaBQKIjFZ6RFmji3mefUryACmUBBpZ/a4Eqq27mFrbX3YpYiEQqEg0s7x40oAqFyv+xVkYOpVKJjZl8xskEX91sxeMrPTgy5OJNEOH15AJM1YsUmD48nA1NsjhcvcfTdwOjAYuBi4IbCqREKSnRHhyJGDeHG9+hVkYOptKFjs3zOBO919ebt1Iinl+HElvPzWDuoaNUWnDDy9DYUlZvYI0VBYaGYFQGtPLzKzuWa22syqzOzqbtqdZ2ZuZhW9rEckMMePL6GpxVmicZBkAOptKFwOXA3MdPd9QAbw6e5eYGYR4BbgDGAKcJGZTemkXQHwJeCFPtQtEpiZ5cWkpxnPrtkWdikiCdfbUDgeWO3uO83sU8A3gV09vGYWUOXua929EbgbOKeTdt8FfgjoGkDpF/Ky0jl6dBHP6c5mGYB6Gwq/BPaZ2dHAV4A1wB96eE0ZsKHd8sbYujZmNh0Y7e4P9rIOkYQ4flwJr27cxR6NgyQDTG9Dodmjs5qfA9zs7rcABYeyYTNLA35KNGR6ajvPzCrNrLKmRrNjSfBOGB8dB2nxOl2FJANLb0Oh1sy+TvRS1AdjH+gZPbymGhjdbnlUbN1+BcBRwFNmth6YDczvrLPZ3W919wp3rygtLe1lySIHb3psfoXFujRVBpjehsLHgQai9ytsJvoBf2MPr1kMTDSzsWaWCVwIzN//pLvvcvch7l7u7uXA88DZ7l7Z150QibfsjAjlJXm8sXVP2KWIJFSvQiEWBH8CCs3sw0C9u3fbp+DuzcCVwEJgJXCPuy83s+vM7OxDrFskcBOG5rNGoSADTHpvGpnZBUSPDJ4ietPaTWb2VXe/t7vXufsCYEGHdd/qou2pvalFJFEmDM3nkRVbaGxuJTNdw4TJwNCrUAC+QfQeha0AZlYKPAZ0GwoiyWzC0HxaWp312/cyadg
hXVchkjR6++dP2v5AiNneh9eKJKXxpfkAVOkUkgwgvT1SeNjMFgJ/ji1/nA6nhURSzfjSfMwUCjKw9CoU3P2rZnYeMCe26lZ3/1twZYmELyczQllRjkJBBpTeHing7vcB9wVYi0i/M2FovkJBBpRuQ8HMagHv7CnA3X1QIFWJ9BMTSvN5fu12WludtDSNFi+pr9tQcHddciED2oSh+dQ3tVK9s47RxblhlyMSOF1BJNKNCUOjVyC9sbU25EpEEkOhINKN/aGgfgUZKBQKIt0oys1kSH6WQkEGDIWCSA8mDs3XwHgyYCgURHowcVg+VVv2EJ1SRCS1KRREejBxaD61Dc1srW0IuxSRwCkURHowPtbZ/PoWXYEkqU+hINKDI0cUAvDqxl0hVyISPIWCSA8KczMYNiiLtTV7wy5FJHAKBZFeGDskj7XbdAWSpD6FgkgvjCvNZ902HSlI6lMoiPTCuCF57NzXxDt7G8MuRSRQgYaCmc01s9VmVmVmV3fy/OfM7DUze8XMnjGzKUHWI3KwNAubDBSBhYKZRYBbgDOAKcBFnXzo3+XuU939GOBHwE+DqkfkUEweER0weNXm3SFXIhKsII8UZgFV7r7W3RuBu4Fz2jdw9/b/w/LofO4GkdANH5RNcV4mKzYpFCS1BRkKZcCGdssbY+sOYGb/bmZriB4pfDHAekQOmplRVpTD3Ys39NxYJImF3tHs7re4+3jga8A3O2tjZvPMrNLMKmtqahJboEhMeiQ681ptfVPIlYgEJ8hQqAZGt1seFVvXlbuBj3b2hLvf6u4V7l5RWloaxxJFeu/yE8cCsPJtDXchqSvIUFgMTDSzsWaWCVwIzG/fwMwmtls8C3gjwHpEDknFYcWAOpsltXU7R/OhcPdmM7sSWAhEgN+5+3Izuw6odPf5wJVmdhrQBOwALgmqHpFDNWxQFlnpaWx4Z1/YpYgEJrBQAHD3BcCCDuu+1e7xl4Lcvkg8mRmji3NZt02hIKkr9I5mkWQyaVg+a2t0A5ukLoWCSB+UFeWwcWcdra26pUZSk0JBpA9GDc6lsbmVbXs1C5ukJoWCSB+MKckFYL36FSRFKRRE+uCI4YMAWPm2LkuV1KRQEOmDYYOyGJyboTGQJGUpFET6wMzYsa+Jv1RuYLeGu5AUpFAQOUgLl20OuwSRuFMoiPTRsms/BMBburNZUpBCQaSP8rOiAwHc9ERVyJWIxJ9CQeQgjCmOXpq6dMPOkCsRiS+FgshB+MUnpwPw++fWh1qHSLwpFEQOwlFlhaQZpJmFXYpIXCkURA7SzPJillXvCrsMkbhSKIgcpKNHF7F2215aNDiepBCFgshBGpybSWNzK4uqtoVdikjcKBREDtJx46LTc67QOEiSQhQKIgfpmFFFZKWnsa1Ww2hL6lAoiByktDSjvCSP9dv3hl2KSNwoFEQOwaThBbxWvQt3dTZLagg0FMxsrpmtNrMqM7u6k+e/bGYrzOxVM3vczA4Lsh6ReJs9rpgtuxtYt01HC5IaAgsFM4sAtwBnAFOAi8xsSodmLwMV7j4NuBf4UVD1iAThhPFDAHhu7faQKxGJjyCPFGYBVe6+1t0bgbuBc9o3cPcn3X3/UJPPA6MCrEck7spLchlRmM2zVQoFSQ1BhkIZsKHd8sbYuq5cDjwUYD0icWdmnDhhCE+/UaOb2CQl9IuOZjP7FFAB3NjF8/PMrNLMKmtqahJbnEgPZo8rYXd9M+u27Qm7FJFDFmQoVAOj2y2Piq07gJmdBnwDONvdO73g291vdfcKd68oLS0NpFiRgzVl5CAAXt2ocZAk+QUZCouBiWY21swygQuB+e0bmNmxwK+JBsLWAGsRCcykYQUU5mTw3Br1K0jyCywU3L0ZuBJYCKwE7nH35WZ2nZmdHWt2I5AP/NXMXjGz+V28nUi/FUkzTpw4hKder9H9ChKIllbnqdVbE9JvlR7km7v7AmBBh3Xfavf4tCC3L5Iop0wq5cFX32ZNzR4mDC0
IuxxJMY+v3MK8O5fwjTOP4DMnjwt0W/2io1kk2U0tKwRgxdu1IVciqWjV5ujv1f5BGIOkUBCJg/Gl+WREjOWb1Nks8ffTR18H3v3jI0gKBZE4yExP49gxg/n1P9Zy35KNbHhnX88vEumlzEgaeZkRLAHTvwbapyAykORlRgD4yl+XArDmB2cSSdMcznJomlpacZxLThibkO3pSEEkTv5z7uQDlu9dsoFr/76c8qsfpKG5JaSqJNnd8NAqmlqccaX5CdmeQkEkTo4YMYj3Tx7atvy1+17j9kXrAbjsjsUhVSXJzN357TPrAA743QqSQkEkjn536UzW33DWe04bLarazjNvaC5n6Zt9je8eYRbnZSZkm+pTEAnAS9/8IIvWbOPlt3YwtCCb7y9YyRV/WMxJE0v52tzDdS+D9Mo7exsB+NF50xK2TYWCSAAKczM4c+oIzpw6AoA3ttZyT+VGHl2xhaqte3jsy6cA0OpORkQH7NK5nfuaACjKzUjYNhUKIgnwLxWjuadyIwDrtu1l/DULuKBiFPdUbmR8aR5raqIzty26+v2UFeWEWar0I5t21QEwOEGnjkB9CiIJMbO8mAe+cCILvnhS27r9IbE/EADm3PAE6zW1p8S8Fht5d1hBdsK2qVAQSZCjygqZMnIQa39wZtu6ScPee5nhf8TucxBJi12wMGpw4o4edfpIJMHS0owHvnAiP35kNb/85AyWvLmDGYcNJi0NDv/mwyzduJPmllbS1dcw4O1taCYvM9IWDomg3zqREBxVVsgdn55FTmaEEycOISczQlZ6hFnlxTS1OBO+8dB7hkl2d3751Bo27axj6Yad1DXqhrhUt6e+mfzsxP7triMFkX7k+vOm8oGf/AOA7z6wgrlHDeelt3Zw+6L11NRGJyb84cOr2trPPXI4N5w3lcKcjISMiyOJtaehmfysxH5MW7JNClJRUeGVlZVhlyESmOfWbOei3zzf59ddc+Zk5p08PoCKJCzlVz9IXmaE5dfNPeT3MrMl7l7RUzudPhLpZ44fX8K8TiZSmTOhhF99ajpnTR3B1zqMswTwgwWrKL/6QZ5ctVUzwKWA/eNlDSnISuh2daQg0k9d/NsXALjlk9Op3lHHESMGvadNS6tz/8vV3PHsel6rfncuh/Omj+InFxydsFolvlpanfHXRCet/PG/HM35M0Yd8nv29khBoSCSIrbsrufcXzxL9c66tnVLv3U6jlOUm7ibn+TQ/fP1Gv71dy8CMP/KOUwbVXTI79kvTh+Z2VwzW21mVWZ2dSfPn2xmL5lZs5mdH2QtIqlu2KBsFl39/gP+qjz6ukc45rpHeXXjTlpbXaeVksSdz78JQGlBVkJmW2svsFAwswhwC3AGMAW4yMymdGj2FnApcFdQdYgMNDeeP43Hv3LKAevOvnkR465ZwKwfPK5gSAKLqqIj6j7/9Q8k/KqyIK91mgVUuftaADO7GzgHWLG/gbuvjz3XGmAdIgOKmTG+NJ/1N5zFc2u2s2lnXdtscDW1DYz9evRc9c8+fjQfO/bQz1VLfG3dXc++xhauOm1SKDP3BRkKZcCGdssbgeMC3J6IdHD8+BIg2nG5q66J7y9Y2fbcVX9ZypjiXNZv20fNngYumzOWzHRdkBi2d/ZFh8ue2MkQKImQFDevmdk8YB7AmDFjQq5GJPlcMHM0AJ85eRy765uY9p1HADjvl8+1tbnhoVVMHJrPoJwMfntJBXVNLSyq2s72PQ1c/9AqjhtbzK8+NSOhI3YORFf8PnohTaJvWtsvyK1WA6PbLY+Kreszd78VuBWiVx8demkiA9eg7AzW33AWdz63nv/6v+UHPPfG1j0AHHPdo+953Qvr3uHY7767/vhxJfzxiuOIpBl7G5o575fPsmpzLSV5mbxwzQc0dtNB2rgjevXYsWMO/YqjgxFkKCwGJprZWKJhcCHwiQC3JyJ98KnZhxFJS+Ojx44kJyPCb59Zx/ceXNlp26/NnXzA8BoAz63d3nY
tfXvb9zYy4RsPse76MzX0Rh+trYmG8rnHllGQnbiJddoLLBTcvdnMrgQWAhHgd+6+3MyuAyrdfb6ZzQT+BgwGPmJm17r7kUHVJCLvMjM+cdy7p2OvOGkcx48vYeyQPKZ8ayFD8rN44ZoPtHV2fvbkcZz586fJTE/jex89irNvXvSe93zgCyfy4ZueAeC+l6rZXdfEp2Yfpr6KHtQ3tTD5vx5uW07kUNkd6eY1ETkoexuaOfLbCzllUim/vngG2RkRAGrrm5ga67MAuPSEcr5ztv7W686zVdv4xG3RO9hL8jJ58IsnMbwwvhPr9PbmtaToaBaR/icvK531N5z1nvUF2RncevEM5t25BIA7nl3PsWOKOOeYskSXmBRaWp3fLVpHdkYaL//X6eRkRkKtR0cKIhKYt7bv4+Qbn2xbfuSqk5k4NH/A9jXsbWjmgVc3MWVEIZ+9s5JNu+rbnjtmdBH3//ucwLatIwURCd2Yklz+c+7h/Ojh1QCc/rN/AnD0qEJGFOZw/blTB8wlrm/vquP465/o8vnrzukfp9h0pCAigWtqaWXiNx56z/rDhxXwm3+tYExJbghVHZrP/2kJjyzfwhUnjeOSEw5jRGG0c3j/h39ZUQ7/UjGK/37sjU5fP2XEIMoG5/DhaSM4a+qIwC/h1SipItIvLavexYdveobyklzWb993wHPLr/0QeVnpvL6llkHZGd12ttY3tZCdEQl8Pmt3f8/prk076zjhhgP/6n/syyczYWgBs3/wOJt319OZ0cU5PPGVU8kI4R4OhYKI9HuX37GYx1dt7bbNFSeO5ZsfPnAszd8/u55vz4/eeGcGf/v8HKp31LHgtbe55qwjKCs6tEs6n1y9lX+srmHzrnoeW7mFz586nqs+OIntextpaG5lzg1dnwbab1B2Orvrm/nsKeM4YfwQVr69m3OnlzG0IL5XFfWWQkFEksKehmZufHgV971UzZ6G5i7bvXDNB/jvx97gzy++1ev3fvEbH6AkL4va+qZO55RoamklI5LGvsZm/rG6hn+8XsPfXq6mobnnMTpPO2IYv/nXGbjDeb96lpff2tn23DfPOoIrTnrv7HlhUiiISFKpa2zhV/9Yw6mHl3LM6CIp3WWSAAAJ6klEQVTe2dvIgmWb+a/7l3Xa/vsfO4qyohyWb9rNjQujHdn5WendBgvAoqvfzxtbarn09sU91lScl8ktn5je6ZzZHe/Yrt5Zx+V3LOarHzqckyaW9rsb9hQKIpISmltamX3942zb00hRbgY/Om8aH5wy7IAP5P2fY/vXtbY635q/jD8+3/ujCoDhg7KZflgRZ00dyVnTRhzwXF1jC46Tm5mcF20qFERkQHN3nly9lRljilm/fS+vbNjJ8MJsPhu7qe6Mo4bz/Y9N5caFqynNz+TLpx8ecsXBUiiIiEibfjFHs4iIJBeFgoiItFEoiIhIG4WCiIi0USiIiEgbhYKIiLRRKIiISBuFgoiItEm6m9fMrAZ4EygEdrV7qv1yZ4+HANsOcfMdt9nXNl09192+dFxOpn3raV1vHh/qvvVmv3pq15t9668/s+7a6fex631pv5yMv48d1xUCRe5e2uOW3T0pv4Bbu1ru7DFQGe9t9rVNV891ty/JvG89revl40Pat97sVzz2rb/+zLprp9/Hrvelw/4k3e9jT/vW3Vcynz76ezfLXT2O9zb72qar57rbl47LybRvPa3rL/vVU7ve7Ft//Zl1106/j93X//cu1h+KRP0+dlzX631IutNHB8vMKr0X434kI+1b8knV/QLtW7JL5iOFvro17AICpH1LPqm6X6B9S2oD5khBRER6NpCOFEREpAcKBRERaaNQEBGRNgoFwMxOMrNfmdltZvZs2PXEk5mlmdn3zewmM7sk7HriycxONbOnYz+7U8OuJ57MLM/MKs3sw2HXEk9mdkTs53Wvmf1b2PXEk5l91Mx+Y2Z/MbPTw67nYCV9KJjZ78xsq5kt67B+rpmtNrMqM7u
6u/dw96fd/XPAA8Dvg6y3L+Kxb8A5wCigCdgYVK19Fad9c2APkE0/2bc47RfA14B7gqny4MTp/9rK2P+1C4A5QdbbF3Hat/vd/TPA54CPB1lvkJL+6iMzO5noB8Mf3P2o2LoI8DrwQaIfFouBi4AIcH2Ht7jM3bfGXncPcLm71yao/G7FY99iXzvc/ddmdq+7n5+o+rsTp33b5u6tZjYM+Km7fzJR9XclTvt1NFBCNOy2ufsDiam+e/H6v2ZmZwP/Btzp7nclqv7uxPlz5CfAn9z9pQSVH1fpYRdwqNz9n2ZW3mH1LKDK3dcCmNndwDnufj3Q6eG4mY0BdvWXQID47JuZbQQaY4stwVXbN/H6ucXsALKCqLOv4vQzOxXIA6YAdWa2wN1bg6y7N+L1M3P3+cB8M3sQ6BehEKefmwE3AA8layBACoRCF8qADe2WNwLH9fCay4HbA6sofvq6b/8L3GRmJwH/DLKwOOjTvpnZucCHgCLg5mBLOyR92i93/waAmV1K7Ggo0OoOTV9/ZqcC5xIN8QWBVnbo+vp/7QvAaUChmU1w918FWVxQUjUU+szdvx12DUFw931EAy/luPv/Eg29lOTud4RdQ7y5+1PAUyGXEQh3/znw87DrOFRJ39HchWpgdLvlUbF1qUD7lnxSdb9A+5ZyUjUUFgMTzWysmWUCFwLzQ64pXrRvySdV9wu0bykn6UPBzP4MPAccbmYbzexyd28GrgQWAiuBe9x9eZh1HgztW/LtW6ruF2jfSNJ966ukvyRVRETiJ+mPFEREJH4UCiIi0kahICIibRQKIiLSRqEgIiJtFAoiItJGoSCBM7M9CdjG2b0ckjqe2zzVzE44iNcda2a/jT2+1Mz6xbhNZlbecejoTtqUmtnDiapJEk+hIEkjNpRxp9x9vrvfEMA2uxsf7FSgz6EAXEOSjpHj7jXA22bWb+ZCkPhSKEhCmdlXzWyxmb1qZte2W3+/mS0xs+VmNq/d+j1m9hMzWwocb2brzexaM3vJzF4zs8mxdm1/cZvZHWb2czN71szWmtn5sfVpZvYLM1tlZo+a2YL9z3Wo8Skz+28zqwS+ZGYfMbMXzOxlM3vMzIbFhln+HHCVmb1i0dn7Ss3svtj+Le7sg9PMCoBp7r60k+fKzeyJ2Pfm8dhw7pjZeDN7Pra/3+vsyMuiM7U9aGZLzWyZmX08tn5m7Puw1MxeNLOC2Haejn0PX+rsaMfMImZ2Y7uf1WfbPX0/EPrcFRIQd9eXvgL9AvbE/j0duBUwon+QPACcHHuuOPZvDrAMKIktO3BBu/daD3wh9vjzwG2xx5cCN8ce3wH8NbaNKUTHxAc4n+hwzWnAcKLzMJzfSb1PAb9otzyYd+/+vwL4Sezxd4D/aNfuLuDE2OMxwMpO3vt9wH3tltvX/Xfgktjjy4D7Y48fAC6KPf7c/u9nh/c9D/hNu+VCIBNYC8yMrRtEdGTkXCA7tm4iUBl7XA4siz2eB3wz9jgLqATGxpbLgNfC/r3SVzBfGjpbEun02NfLseV8oh9K/wS+aGYfi60fHVu/nejEQPd1eJ/9w2UvITo2f2fu9+g8BCssOjMbwInAX2PrN5vZk93U+pd2j0cBfzGzEUQ/aNd18ZrTgClmtn95kJnlu3v7v+xHADVdvP74dvtzJ/Cjdus/Gnt8F/DjTl77GvATM/sh8IC7P21mU4G33X0xgLvvhuhRBXCzmR1D9Ps7qZP3Ox2Y1u5IqpDoz2QdsBUY2cU+SJJTKEgiGXC9u//6gJXRiVdOA453931m9hTRqSgB6t2944xxDbF/W+j6d7ih3WProk139rZ7fBPR6T7nx2r9ThevSQNmu3t9N+9bx7v7Fjfu/rqZTQfOBL5nZo8Df+ui+VXAFqLTfqYBndVrRI/IFnbyXDbR/ZAUpD4FSaSFwGVmlg9gZmVmNpToX6E7YoEwGZgd0PYXAefF+haGEe0o7o1C3h1H/5J262uBgnbLjxCdfQuA2F/iHa0
EJnSxnWeJDs8M0XP2T8ceP0/09BDtnj+AmY0E9rn7H4EbgenAamCEmc2MtSmIdZwXEj2CaAUuJjrncEcLgX8zs4zYayfFjjAgemTR7VVKkrwUCpIw7v4I0dMfz5nZa8C9RD9UHwbSzWwl0Tlunw+ohPuITqm4Avgj8BKwqxev+w7wVzNbAmxrt/7vwMf2dzQDXwQqYh2zK4ie/z+Au68iOl1jQcfniAbKp83sVaIf1l+Krf9/wJdj6yd0UfNU4EUzewX4NvA9d28EPk50OtalwKNE/8r/BXBJbN1kDjwq2u82ot+nl2KXqf6ad4/K3gc82MlrJAVo6GwZUPaf4zezEuBFYI67b05wDVcBte5+Wy/b5wJ17u5mdiHRTudzAi2y+3r+SXQC+x1h1SDBUZ+CDDQPmFkR0Q7j7yY6EGJ+CfxLH9rPINoxbMBOolcmhcLMSon2rygQUpSOFEREpI36FEREpI1CQURE2igURESkjUJBRETaKBRERKSNQkFERNr8f+257ojQLAcWAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learner.lr_find()\n", "learner.lr_plot()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, we will train our model for 8 epochs using ```autofit``` with a learning rate of 0.001 for two epochs." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# define a custom callback for ROC-AUC\n", "from tensorflow.keras.callbacks import Callback\n", "from sklearn.metrics import roc_auc_score\n", "class RocAucEvaluation(Callback):\n", " def __init__(self, validation_data=(), interval=1):\n", " super(Callback, self).__init__()\n", "\n", " self.interval = interval\n", " self.X_val, self.y_val = validation_data\n", "\n", " def on_epoch_end(self, epoch, logs={}):\n", " if epoch % self.interval == 0:\n", " y_pred = self.model.predict(self.X_val, verbose=0)\n", " score = roc_auc_score(self.y_val, y_pred)\n", " print(\"\\n ROC-AUC - epoch: %d - score: %.6f \\n\" % (epoch+1, score))\n", "RocAuc = RocAucEvaluation(validation_data=(x_test, y_test), interval=1)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "begin training using triangular learning rate policy with max lr of 0.001...\n", "Train on 143613 samples, validate on 15958 samples\n", "Epoch 1/2\n", "143613/143613 [==============================] - 2157s 15ms/step - loss: 0.0668 - acc: 0.9787 - val_loss: 0.0410 - val_acc: 0.9843\n", "\n", " ROC-AUC - epoch: 1 - score: 0.986431 \n", "\n", "Epoch 2/2\n", "143613/143613 [==============================] - 2142s 15ms/step - loss: 0.0398 - acc: 0.9846 - val_loss: 0.0392 - val_acc: 0.9849\n", "\n", " ROC-AUC - epoch: 2 - score: 0.989871 \n", "\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# train\n", "learner.autofit(0.001, 2, 
callbacks=[RocAuc])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Our final ROC-AUC score is **0.9899**." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.8" } }, "nbformat": 4, "nbformat_minor": 2 }