{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Classification of Spam\n", "\n", "We'll now be exploring the field of Natural Language Processing (NLP), which concerns itself with interpreting, predicting, and classifying the written word. As a first foray into this field, we'll construct a simple spam classifier.\n", "\n", "Our goal in this project is to classify text messages as either spam or not spam (also playfully known as \"ham\" messages). We'll be using a collection of English SMS messages from [Kaggle](https://www.kaggle.com/uciml/sms-spam-collection-dataset/home) as our dataset. The first thing to notice is that this dataset is NOT homogeneous! For one, the classes are unbalanced: there are only (425+322)=747 spam messages, compared to (3375+450+1002)=4827 ham messages. For another, the data doesn't all come from the same source:\n", "\n", "* [Grumbletext](www.grumbletext.co.uk): UK forum for people complaining about spam messages. (425 spam messages)\n", "* [NUS SMS Corpus](www.comp.nus.edu.sg/~rpnlpir/downloads/corpora/smsCorpus/): Dataset of legitimate messages collected by the Dept. of Computer Science at the National University of Singapore. The vast majority of messages are from students attending the university. (3375 ham messages)\n", "* [Caroline Tagg's PhD thesis](http://etheses.bham.ac.uk/253/1/Tagg09PhD.pdf): Dataset collected by Caroline Tagg during her doctoral studies on the linguistic aspects of texting. (450 ham messages)\n", "* [SMS Spam Corpus v.0.1 Big](http://www.esp.uem.es/jmgomez/smsspamcorpus/): Dataset (itself composed of other datasets) collected by researchers at the Universidad Europea de Madrid. (1002 ham and 322 spam messages)\n", "\n", "We'll need to take special care to make sure that the unbalanced nature of the dataset and the various sources and nationalities of the senders do not skew the classification. 
As such, a simple classification accuracy measure will not suffice, as our classifier could still get 4827/(4827+747) = 87% accuracy by just labeling everything as ham and misclassifying every spam sample." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## EDA and data preprocessing\n", "Ok, let's dive in. First, we'll download our dataset and extract the CSV file therein (this requires installing the [Kaggle API](https://github.com/Kaggle/kaggle-api)):" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Warning: Your Kaggle API key is readable by otherusers on this system! To fix this, you can run'chmod 600 /home/ecotner/.kaggle/kaggle.json'\n", "Archive: sms-spam-collection-dataset.zip\n", " inflating: Data/spam.csv \n" ] } ], "source": [ "%%bash\n", "kaggle datasets download uciml/sms-spam-collection-dataset --quiet\n", "mkdir -p Data\n", "unzip sms-spam-collection-dataset.zip -d Data\n", "rm sms-spam-collection-dataset.zip" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's load this using pandas and take a look at the data." ] }, { "cell_type": "code", "execution_count": 77, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<table>\n",
" <thead>\n",
" <tr><th></th><th>v1</th><th>v2</th><th>Unnamed: 2</th><th>Unnamed: 3</th><th>Unnamed: 4</th></tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr><th>0</th><td>ham</td><td>Go until jurong point, crazy.. Available only ...</td><td>NaN</td><td>NaN</td><td>NaN</td></tr>\n",
" <tr><th>1</th><td>ham</td><td>Ok lar... Joking wif u oni...</td><td>NaN</td><td>NaN</td><td>NaN</td></tr>\n",
" <tr><th>2</th><td>spam</td><td>Free entry in 2 a wkly comp to win FA Cup fina...</td><td>NaN</td><td>NaN</td><td>NaN</td></tr>\n",
" <tr><th>3</th><td>ham</td><td>U dun say so early hor... U c already then say...</td><td>NaN</td><td>NaN</td><td>NaN</td></tr>\n",
" <tr><th>4</th><td>ham</td><td>Nah I don't think he goes to usf, he lives aro...</td><td>NaN</td><td>NaN</td><td>NaN</td></tr>\n",
" <tr><th>5</th><td>spam</td><td>FreeMsg Hey there darling it's been 3 week's n...</td><td>NaN</td><td>NaN</td><td>NaN</td></tr>\n",
" <tr><th>6</th><td>ham</td><td>Even my brother is not like to speak with me. ...</td><td>NaN</td><td>NaN</td><td>NaN</td></tr>\n",
" <tr><th>7</th><td>ham</td><td>As per your request 'Melle Melle (Oru Minnamin...</td><td>NaN</td><td>NaN</td><td>NaN</td></tr>\n",
" <tr><th>8</th><td>spam</td><td>WINNER!! As a valued network customer you have...</td><td>NaN</td><td>NaN</td><td>NaN</td></tr>\n",
" <tr><th>9</th><td>spam</td><td>Had your mobile 11 months or more? U R entitle...</td><td>NaN</td><td>NaN</td><td>NaN</td></tr>\n",
" </tbody>\n",
"</table>" ], "text/plain": [ " v1 v2 Unnamed: 2 \\\n", "0 ham Go until jurong point, crazy.. Available only ... NaN \n", "1 ham Ok lar... Joking wif u oni... NaN \n", "2 spam Free entry in 2 a wkly comp to win FA Cup fina... NaN \n", "3 ham U dun say so early hor... U c already then say... NaN \n", "4 ham Nah I don't think he goes to usf, he lives aro... NaN \n", "5 spam FreeMsg Hey there darling it's been 3 week's n... NaN \n", "6 ham Even my brother is not like to speak with me. ... NaN \n", "7 ham As per your request 'Melle Melle (Oru Minnamin... NaN \n", "8 spam WINNER!! As a valued network customer you have... NaN \n", "9 spam Had your mobile 11 months or more? U R entitle... NaN \n", "\n", " Unnamed: 3 Unnamed: 4 \n", "0 NaN NaN \n", "1 NaN NaN \n", "2 NaN NaN \n", "3 NaN NaN \n", "4 NaN NaN \n", "5 NaN NaN \n", "6 NaN NaN \n", "7 NaN NaN \n", "8 NaN NaN \n", "9 NaN NaN " ] }, "execution_count": 77, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import pandas as pd\n", "raw_data = pd.read_csv(\"./Data/spam.csv\", encoding=\"latin\") # Need to use latin encoding since UTF throws error\n", "raw_data.head(10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "So, it looks like the class (spam/ham) is under column \"v1\", and the actual text of the message is in \"v2\". What's in the other 3 unnamed columns? Is there even anything in them? Also, you can already tell from this snippet that the use of texting slang is pervasive throughout this corpus. Let's take a look at the \"unnamed\" columns." ] }, { "cell_type": "code", "execution_count": 79, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<table>\n",
" <thead>\n",
" <tr><th></th><th>v1</th><th>v2</th><th>Unnamed: 2</th><th>Unnamed: 3</th><th>Unnamed: 4</th></tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr><th>95</th><td>spam</td><td>Your free ringtone is waiting to be collected....</td><td>PO Box 5249</td><td>MK17 92H. 450Ppw 16\"</td><td>NaN</td></tr>\n",
" <tr><th>281</th><td>ham</td><td>\\Wen u miss someone</td><td>the person is definitely special for u..... B...</td><td>why to miss them</td><td>just Keep-in-touch\\\" gdeve..\"</td></tr>\n",
" <tr><th>444</th><td>ham</td><td>\\HEY HEY WERETHE MONKEESPEOPLE SAY WE MONKEYAR...</td><td>HOWU DOIN? FOUNDURSELF A JOBYET SAUSAGE?LOVE ...</td><td>NaN</td><td>NaN</td></tr>\n",
" <tr><th>671</th><td>spam</td><td>SMS. ac sun0819 posts HELLO:\\You seem cool</td><td>wanted to say hi. HI!!!\\\" Stop? Send STOP to ...</td><td>NaN</td><td>NaN</td></tr>\n",
" <tr><th>710</th><td>ham</td><td>Height of Confidence: All the Aeronautics prof...</td><td>this wont even start........ Datz confidence..\"</td><td>NaN</td><td>NaN</td></tr>\n",
" </tbody>\n",
"</table>" ], "text/plain": [ " v1 v2 \\\n", "95 spam Your free ringtone is waiting to be collected.... \n", "281 ham \\Wen u miss someone \n", "444 ham \\HEY HEY WERETHE MONKEESPEOPLE SAY WE MONKEYAR... \n", "671 spam SMS. ac sun0819 posts HELLO:\\You seem cool \n", "710 ham Height of Confidence: All the Aeronautics prof... \n", "\n", " Unnamed: 2 Unnamed: 3 \\\n", "95 PO Box 5249 MK17 92H. 450Ppw 16\" \n", "281 the person is definitely special for u..... B... why to miss them \n", "444 HOWU DOIN? FOUNDURSELF A JOBYET SAUSAGE?LOVE ... NaN \n", "671 wanted to say hi. HI!!!\\\" Stop? Send STOP to ... NaN \n", "710 this wont even start........ Datz confidence..\" NaN \n", "\n", " Unnamed: 4 \n", "95 NaN \n", "281 just Keep-in-touch\\\" gdeve..\" \n", "444 NaN \n", "671 NaN \n", "710 NaN " ] }, "execution_count": 79, "metadata": {}, "output_type": "execute_result" } ], "source": [ "raw_data[pd.notna(raw_data[\"Unnamed: 2\"]) | pd.notna(raw_data[\"Unnamed: 3\"]) | pd.notna(raw_data[\"Unnamed: 4\"])].head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Looks like the extra columns contain other messages or metadata? Some of it may be relevant to the classification (e.g. it looks like spam messages have some kind of PO box), but only a very small fraction of the dataset (50/5572=0.9%) has these, so I'm just going to drop these extra columns." ] }, { "cell_type": "code", "execution_count": 86, "metadata": {}, "outputs": [], "source": [ "raw_data = raw_data[[\"v1\", \"v2\"]].rename(columns={\"v1\": \"y\", \"v2\": \"msg\"})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, let's take a look at some summary statistics for the dataset - what is the class balance, number of messages, etc. We'll make use of the Python package NLTK (Natural Language ToolKit) to simplify the analysis. First, we'll have to do some tokenization. 
I'm going to convert all messages to lowercase, then split tokens based on non-alphanumeric characters, discarding the punctuation itself. Also, I'll convert the \"spam\"/\"ham\" labels to binary identifiers (where \"spam\" --> 1)." ] }, { "cell_type": "code", "execution_count": 198, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<table>\n",
" <thead>\n",
" <tr><th></th><th>y</th><th>msg</th></tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr><th>0</th><td>0</td><td>[go, until, jurong, point, crazy, available, o...</td></tr>\n",
" <tr><th>1</th><td>0</td><td>[ok, lar, joking, wif, u, oni]</td></tr>\n",
" <tr><th>2</th><td>1</td><td>[free, entry, in, 2, a, wkly, comp, to, win, f...</td></tr>\n",
" <tr><th>3</th><td>0</td><td>[u, dun, say, so, early, hor, u, c, already, t...</td></tr>\n",
" <tr><th>4</th><td>0</td><td>[nah, i, don, t, think, he, goes, to, usf, he,...</td></tr>\n",
" </tbody>\n",
"</table>" ], "text/plain": [ " y msg\n", "0 0 [go, until, jurong, point, crazy, available, o...\n", "1 0 [ok, lar, joking, wif, u, oni]\n", "2 1 [free, entry, in, 2, a, wkly, comp, to, win, f...\n", "3 0 [u, dun, say, so, early, hor, u, c, already, t...\n", "4 0 [nah, i, don, t, think, he, goes, to, usf, he,..." ] }, "execution_count": 198, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import nltk\n", "from nltk.tokenize import RegexpTokenizer\n", "tokenizer = RegexpTokenizer(r\"\\w+\")\n", "\n", "data = []\n", "for _, row in raw_data.iterrows():\n", "    y, msg = row\n", "    y = 0 if (y==\"ham\") else 1\n", "    msg = msg.lower()\n", "    data.append([y, tokenizer.tokenize(msg)])\n", "data = pd.DataFrame(data, columns=[\"y\", \"msg\"])\n", "data.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, before we go any further and start to make any more assumptions about how to process the messages, we should split this into training, validation and test sets. Then we'll blind the test set. Even though we haven't got to the typical training procedure yet, any further preprocessing decisions made while looking at the full dataset would leak information from the test set, so we need to keep it separate. Since we only have a few thousand data points, I'll use an 80/10/10 train/val/test split to get good statistics when evaluating the validation/test sets."
] }, { "cell_type": "code", "execution_count": 310, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training set:\n", "Size: 4458, # spam examples: 601 (13.5%), # ham examples: 3857 (86.5%)\n", "\n", "Validation set:\n", "Size: 557, # spam examples: 66 (11.8%), # ham examples: 491 (88.2%)\n", "\n", "Test set:\n", "Size: 557, # spam examples: 80 (14.4%), # ham examples: 477 (85.6%)\n", "\n" ] } ], "source": [ "import numpy as np\n", "np.random.seed(0)\n", "shuffled_data = data.sample(len(data), random_state=0)\n", "\n", "# Determine the split\n", "M = len(shuffled_data)\n", "train_idx = int(M*0.80)+1\n", "val_idx = int(M*(0.80+0.10))+1\n", "X_train = shuffled_data[\"msg\"].values[:train_idx]\n", "y_train = shuffled_data[\"y\"].values[:train_idx]\n", "X_val = shuffled_data[\"msg\"].values[train_idx:val_idx]\n", "y_val = shuffled_data[\"y\"].values[train_idx:val_idx]\n", "X_test = shuffled_data[\"msg\"].values[val_idx:]\n", "y_test = shuffled_data[\"y\"].values[val_idx:]\n", "\n", "# Check split stats\n", "for name, X, y in [(\"Training\", X_train, y_train), (\"Validation\", X_val, y_val), (\"Test\", X_test, y_test)]:\n", " print(name + \" set:\")\n", " M = len(X); nspam = np.sum(y); nham = M - nspam\n", " print(\"Size: {}, # spam examples: {} ({:.1f}%), # ham examples: {} ({:.1f}%)\\n\".format(M, nspam, 100*nspam/M, nham, 100*nham/M))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let's compute those summary stats. We'll gather frequency data on all the tokens in the training set." 
] }, { "cell_type": "code", "execution_count": 311, "metadata": {}, "outputs": [], "source": [ "# Concatenate all tokens together\n", "all_tokens = []\n", "all_spam_tokens = []\n", "all_ham_tokens = []\n", "for x, y in zip(X_train, y_train):\n", " if (y==1): all_spam_tokens.extend(x)\n", " elif (y==0): all_ham_tokens.extend(x)\n", " all_tokens.extend(x)\n", "\n", "# Get frequency distribution of words\n", "from nltk.probability import FreqDist\n", "freq_dist = FreqDist(token for token in all_tokens)\n", "spam_freq_dist = FreqDist(token for token in all_spam_tokens)\n", "ham_freq_dist = FreqDist(token for token in all_ham_tokens)" ] }, { "cell_type": "code", "execution_count": 312, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Total # of tokens with frequency of at least...\n", "1: 7790, 2: 3793, 3: 2535, 4: 1969, 5: 1645, 10: 914, 50: 220, 100: 118, \n", "\n", "Most frequent tokens (both spam/ham):\n", " i to you a the u and in is me my it for your of call s that have on \n", "2388 1800 1761 1161 1067 942 783 716 712 648 614 590 561 555 482 473 461 459 444 441 \n", "\n", "Most freqent spam tokens:\n", " to a call å you your the 2 free for now or txt u is on ur 4 have and \n", " 552 288 274 244 238 220 176 169 168 168 161 150 133 132 128 119 117 115 103 103 \n", "\n", "Most frequent ham tokens:\n", " i you to the a u and in me my is it that of for s so but can have \n", "2337 1523 1248 891 873 810 680 652 624 607 584 561 442 405 393 385 365 359 357 341 \n", "\n", "LEAST frequent tokens:\n", " mega asda counts toyota camry olayiwola mileage landing kane shud \n", " 1 1 1 1 1 1 1 1 1 1 \n" ] } ], "source": [ "import matplotlib.pyplot as plt\n", "\n", "print(\"Total # of tokens with frequency of at least...\")\n", "freq_count = {}\n", "for token in freq_dist:\n", " f = freq_dist[token]\n", " freq_count[f] = freq_count.get(f, 0) + 1\n", "freq_counts = [sum([freq_count[f] for f in freq_count if f >= n]) for n in 
[1,2,3,4,5,10,50,100]]\n", "print(\"\".join([str(n)+\": {}, \" for n in [1,2,3,4,5,10,50,100]]).format(*freq_counts), end=\"\\n\\n\")\n", "print(\"Most frequent tokens (both spam/ham):\")\n", "freq_dist.tabulate(20)\n", "print(\"\\nMost freqent spam tokens:\")\n", "spam_freq_dist.tabulate(20)\n", "print(\"\\nMost frequent ham tokens:\")\n", "ham_freq_dist.tabulate(20)\n", "print(\"\\nLEAST frequent tokens:\")\n", "FreqDist(dict(freq_dist.most_common()[-10:])).tabulate()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Looking at the number of tokens with a given count, we can see that the vast majority of distinct tokens only appear a handful of times throughout the entire corpus. This means that we can safely ignore tokens that have counts less than a small number (say, 5?) because they are too rare to help discriminate between the classes.\n", "\n", "Looking at the most frequent tokens, we can see that \"stop words\" such as \"i\", \"to\", \"you\", etc. appear quite often; however, the most frequent spam tokens are not exactly the same as the most frequent ham tokens! For example, the token \"i\" is the most frequent ham token, yet it doesn't even appear on the top 20 list of spam tokens. Likewise, the token \"å\" is one of the top 5 spam tokens, yet doesn't appear in the top 20 ham tokens.\n", "\n", "But just because a token is frequent throughout the corpus doesn't mean it appears in most messages (i.e. has a high \"document frequency\"). 
Let's take a look at the document frequency and token frequency of the union of the top 10 tokens in either set:" ] }, { "cell_type": "code", "execution_count": 313, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "'me' - spam document frequency: 3.7%, ham document frequency: 13.7%\n", "\tspam token frequency: 0.2%, ham token frequency: 1.1%\n", "\n", "'the' - spam document frequency: 24.0%, ham document frequency: 17.8%\n", "\tspam token frequency: 1.1%, ham token frequency: 1.6%\n", "\n", "'for' - spam document frequency: 24.5%, ham document frequency: 9.1%\n", "\tspam token frequency: 1.1%, ham token frequency: 0.7%\n", "\n", "'i' - spam document frequency: 6.0%, ham document frequency: 41.9%\n", "\tspam token frequency: 0.3%, ham token frequency: 4.1%\n", "\n", "'in' - spam document frequency: 9.3%, ham document frequency: 15.3%\n", "\tspam token frequency: 0.4%, ham token frequency: 1.1%\n", "\n", "'u' - spam document frequency: 16.3%, ham document frequency: 14.6%\n", "\tspam token frequency: 0.9%, ham token frequency: 1.4%\n", "\n", "'your' - spam document frequency: 31.1%, ham document frequency: 7.6%\n", "\tspam token frequency: 1.4%, ham token frequency: 0.6%\n", "\n", "'free' - spam document frequency: 21.8%, ham document frequency: 1.3%\n", "\tspam token frequency: 1.1%, ham token frequency: 0.1%\n", "\n", "'my' - spam document frequency: 1.2%, ham document frequency: 12.7%\n", "\tspam token frequency: 0.0%, ham token frequency: 1.1%\n", "\n", "'call' - spam document frequency: 42.3%, ham document frequency: 4.8%\n", "\tspam token frequency: 1.8%, ham token frequency: 0.4%\n", "\n", "'2' - spam document frequency: 21.6%, ham document frequency: 5.7%\n", "\tspam token frequency: 1.1%, ham token frequency: 0.5%\n", "\n", "'you' - spam document frequency: 32.4%, ham document frequency: 27.6%\n", "\tspam token frequency: 1.5%, ham token frequency: 2.7%\n", "\n", "'a' - spam document frequency: 36.6%, ham document 
frequency: 18.6%\n", "\tspam token frequency: 1.9%, ham token frequency: 1.5%\n", "\n", "'å' - spam document frequency: 32.3%, ham document frequency: 0.1%\n", "\tspam token frequency: 1.6%, ham token frequency: 0.0%\n", "\n", "'and' - spam document frequency: 15.0%, ham document frequency: 14.1%\n", "\tspam token frequency: 0.7%, ham token frequency: 1.2%\n", "\n", "'to' - spam document frequency: 61.9%, ham document frequency: 25.5%\n", "\tspam token frequency: 3.6%, ham token frequency: 2.2%\n", "\n" ] } ], "source": [ "_ = [dist.most_common(10) for dist in [spam_freq_dist, ham_freq_dist]]\n", "for t in {item[0] for sublist in _ for item in sublist}:\n", " spam_doc_freq = sum([(t in x) for x in X_train[y_train==1]])/len(X_train[y_train==1])\n", " ham_doc_freq = sum([(t in x) for x in X_train[y_train!=1]])/len(X_train[y_train!=1])\n", " print(\"'{}' - spam document frequency: {:.1f}%, ham document frequency: {:.1f}%\".format(t, 100*spam_doc_freq, 100*ham_doc_freq))\n", " print(\"\\tspam token frequency: {:.1f}%, ham token frequency: {:.1f}%\".format(100*spam_freq_dist.freq(t), 100*ham_freq_dist.freq(t)), end=\"\\n\\n\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can see that just because a token appears frequently within a class doesn't mean that it is helpful in discriminating classes. Take a look at \"u\": even though it's in the top most frequent tokens in both classes, its document frequency in both is about 15%; it will not be useful to discriminate between spam and ham. However, there are a number of tokens that appear primarily in one class and not the other, such as \"me\", \"i\", \"free\", \"my\", \"call\", and \"å\". 
We can find the most discriminative tokens by looking at their \"discriminative ratio\" (I just made that up), defined by\n", "\n", "$$\\text{dr}(t,D) = \\ln\\left(\\frac{\\text{df}(t,D_\\text{spam})}{\\text{df}(t,D_\\text{ham})}\\right),$$\n", "\n", "where $t$ is the token, $D_i$ is a set of documents, and $\\text{df}(t,D) = |\\{d \\in D: t \\in d\\}|/|D|$ is the document frequency. Tokens with large $|\\text{dr}(t,D)|$ should have good discriminative power since the document frequency within one class is significantly larger, and the sign of this metric specifies which class the token is biased towards (positive for spam, negative for ham)." ] }, { "cell_type": "code", "execution_count": 603, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:8: RuntimeWarning: divide by zero encountered in log\n", " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Top spam features:\n", "('å', 6.028295894068239)\n", "('uk', 5.810293742254448)\n", "('www', 5.496636183399406)\n", "('nokia', 4.474009801709218)\n", "('txt', 4.392746837630453)\n", "('mobile', 3.7367519227018002)\n", "('cash', 3.521597761421069)\n", "('won', 3.2630439618923117)\n", "('stop', 3.007672732915791)\n", "('free', 2.802421714149846)\n", "('reply', 2.5941614033319977)\n", "('1', 2.3460649987452924)\n", "('text', 2.2468155546817834)\n", "('call', 2.181448533082571)\n", "('our', 2.12992497780842)\n", "('new', 1.8590500236730203)\n", "('or', 1.6686030594351484)\n", "('from', 1.625435172491515)\n", "('4', 1.6067073176665483)\n", "('your', 1.4031832841877871)\n", "('now', 1.4022916211773053)\n", "('only', 1.3890463944272846)\n", "('ur', 1.352488798693487)\n", "('send', 1.3414676174441496)\n", "('2', 1.3329569277762412)\n", "('for', 0.9886963869858907)\n", "('with', 0.9196517167887156)\n", "('to', 0.8863179808939132)\n", "('this', 0.8620479114156929)\n", "('on', 0.7266507374275272)\n", "('have', 0.6851157284663115)\n", "('a', 
0.6762080009771575)\n", "('get', 0.6758803469768475)\n", "('t', 0.5706811997234378)\n", "('just', 0.48577103066519994)\n", "('is', 0.43262356112681216)\n", "('no', 0.4076163603821683)\n", "('of', 0.2988023554296916)\n", "('the', 0.29798569552313564)\n", "('are', 0.2835136629146011)\n", "('s', 0.20226599064172146)\n", "('you', 0.16225891233517728)\n", "('be', 0.15826233265069203)\n", "('u', 0.10896325084553317)\n", "('and', 0.06359369256363905)\n", "('we', 0.01464682095954321)\n", "\n", "Top ham features:\n", "('lt', -inf)\n", "('gt', -inf)\n", "('my', -2.3853552331248133)\n", "('but', -2.189832164472323)\n", "('ll', -2.062923312608294)\n", "('how', -1.9782494355591895)\n", "('ok', -1.9695913728160748)\n", "('i', -1.9451402769519104)\n", "('me', -1.317108071510026)\n", "('when', -1.2812634786920782)\n", "('that', -1.2583791848584907)\n", "('it', -1.0919463560395792)\n", "('so', -1.0225111045409898)\n", "('do', -0.9518575628688977)\n", "('m', -0.9409376727548754)\n", "('what', -0.903876787655024)\n", "('not', -0.8882208905824712)\n", "('at', -0.7619888004395601)\n", "('can', -0.7217979198563149)\n", "('if', -0.5776097925755947)\n", "('know', -0.544178595100713)\n", "('up', -0.5251150563134485)\n", "('in', -0.49402446924341725)\n", "('will', -0.1181126688863976)\n" ] } ], "source": [ "dr = []\n", "_ = [dist.most_common(50) for dist in [spam_freq_dist, ham_freq_dist]]\n", "for t in {item[0] for sublist in _ for item in sublist}:\n", "    spam_doc_freq = sum([(t in x) for x in X_train[y_train==1]])/len(X_train[y_train==1])\n", "    ham_doc_freq = sum([(t in x) for x in X_train[y_train!=1]])/len(X_train[y_train!=1])\n", "    #idf = -np.log(sum([(t in x) for x in X_train])/len(X_train))\n", "    try:\n", "        dr.append((t, np.log(spam_doc_freq/ham_doc_freq)))\n", "    except ZeroDivisionError:\n", "        pass\n", "dr.sort(key=lambda e: abs(e[1]), reverse=True)\n", "print(\"Top spam features:\", *filter(lambda e: e[1]>0, dr), sep=\"\\n\", end=\"\\n\\n\")\n", "print(\"Top ham features:\", 
*filter(lambda e: e[1]<0, dr), sep=\"\\n\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The tokens near the top of each list above should then be the most discriminative features when trying to classify spam messages. We can see that 'å' is clearly the top feature, with a heavy positive bias, as predicted. It appears that most of the features with strong negative bias (\"my\", \"i\", \"me\") are first-person pronouns - obviously whoever is sending spam messages doesn't like to talk about themselves that much! It also appears as though there are two tokens (\"lt\" and \"gt\") which only appear in ham messages, and are thus completely correlated with the ham class. If you look at the messages which contain these tokens, it's obvious that they're some kind of transcription error; I'm pretty sure that they are remnants of the HTML entities `&lt;` and `&gt;` (i.e. escaped versions of the less/greater than symbols '<' and '>'). The uncorrupted versions are entirely within the spam class, so I don't think it would be fair to use these tokens as classification features. On that note, it appears that 'å' is also a corrupted version of the symbol for GBP (£), but since it's essentially a stand-in, I think it'll be fine to keep it." ] }, { "cell_type": "code", "execution_count": 604, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Occurrences of lt and gt:\n", "[['ham'\n", " 'Great! I hope you like your man well endowed. I am &lt;#&gt; inches...']\n", " ['ham'\n", " 'A gram usually runs like &lt;#&gt; , a half eighth is smarter though and gets you almost a whole second gram for &lt;#&gt;']\n", " ['ham'\n", " 'Do you know what Mallika Sherawat did yesterday? 
Find out now @ &lt;URL&gt;']\n", " ['ham' 'Does not operate after &lt;#&gt; or what']\n", " ['ham'\n", " \"Turns out my friends are staying for the whole show and won't be back til ~ &lt;#&gt; , so feel free to go ahead and smoke that $ &lt;#&gt; worth\"]]\n", "\n", "Occurrences of '<' and '>':\n", "[['spam'\n", " 'SIX chances to win CASH! From 100 to 20,000 pounds txt> CSH11 and send to 87575. Cost 150p/day, 6days, 16+ TsandCs apply Reply HL 4 info']\n", " ['spam'\n", " 'XXXMobileMovieClub: To use your credit, click the WAP link in the next txt message or click here>> http://wap. xxxmobilemovieclub.com?n=QJKGIGHJJGCBL']\n", " ['spam'\n", " 'TheMob> Check out our newest selection of content, Games, Tones, Gossip, babes and sport, Keep your mobile fit and funky text WAP to 82468']\n", " ['spam'\n", " 'Please CALL 08712404000 immediately as there is an urgent message waiting for you.']\n", " ['spam'\n", " 'RT-KIng Pro Video Club>> Need help? info@ringtoneking.co.uk or call 08701237397 You must be 16+ Club credits redeemable at www.ringtoneking.co.uk! Enjoy!']]\n", "\n", "Occurrences of 'å':\n", "[['spam'\n", " \"FreeMsg Hey there darling it's been 3 week's now and no word back! I'd like some fun you up for it still? Tb ok! XxX std chgs to send, å£1.50 to rcv\"]\n", " ['spam'\n", " 'WINNER!! As a valued network customer you have been selected to receivea å£900 prize reward! To claim call 09061701461. Claim code KL341. Valid 12 hours only.']\n", " ['spam'\n", " 'URGENT! You have won a 1 week FREE membership in our å£100,000 Prize Jackpot! Txt the word: CLAIM to No: 81010 T&C www.dbuk.net LCCLTD POBOX 4403LDNW1A7RW18']\n", " ['ham' 'Fine if thatåÕs the way u feel. ThatåÕs the way its gota b']\n", " ['spam'\n", " 'Thanks for your subscription to Ringtone UK your mobile will be charged å£5/month Please confirm by replying YES or NO. 
If you reply NO you will not be charged']]\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:2: UserWarning: This pattern has match groups. To actually get the groups, use str.extract.\n", " \n" ] } ], "source": [ "print(\"Occurrences of lt and gt:\")\n", "print(raw_data[raw_data.msg.str.contains(r\"\\W(gt|lt)\\W\")].values[:5], end=\"\\n\\n\")\n", "print(\"Occurrences of '<' and '>':\")\n", "print(raw_data[raw_data.msg.str.contains(r\"<|>\")].values[:5], end=\"\\n\\n\")\n", "print(\"Occurrences of 'å':\")\n", "print(raw_data[raw_data.msg.str.contains(r\"å\")].values[:5])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We also need to make sure that there is some subset of these discriminative tokens $T_\\text{disc}$ such that every document $d \\in D$ contains at least one element of $T_\\text{disc}$, i.e. $\\forall d \\in D \\; \\exists t \\in T_\\text{disc} : t \\in d$." ] }, { "cell_type": "code", "execution_count": 861, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of tokens which completely covers entire dataset: |T_disc| = 68\n", "Top 20 tokens: å, uk, www, nokia, txt, mobile, cash, won, stop, free, reply, my, 1, text, but, call, our, ll, how, ok\n" ] } ], "source": [ "# Iterate over all tokens\n", "T_disc = []\n", "X_all = list(X_train) + list(X_val)\n", "for t, ratio in dr:\n", "    # Remove 'lt' and 'gt' from consideration because that's probably cheating\n", "    if np.isinf(ratio): continue\n", "    # Remove every remaining document that contains this token\n", "    for i in reversed(range(len(X_all))):\n", "        if t in X_all[i]:\n", "            del X_all[i]\n", "    T_disc.append(t)\n", "    if len(X_all) == 0:\n", "        break\n", "print(\"Number of tokens which completely covers entire dataset: |T_disc| =\", len(T_disc))\n", "print(\"Top 20 tokens: \", end=\"\")\n", "print(*T_disc[:20], sep=\", \")" ] }, { "cell_type": "markdown", 
"metadata": {}, "source": [ "## Logistic regression and assessment of performance metrics\n", "As a simple first attempt at classification, let's just take a subset of the top discriminative tokens (except for 'lt' and 'gt') and use those as features in a linear classifier. First, we'll need to convert our documents to (sparse) binary feature vectors, where each entry records whether the corresponding token is present in the message." ] }, { "cell_type": "code", "execution_count": 862, "metadata": {}, "outputs": [], "source": [ "import scipy as sp\n", "from scipy.sparse import dok_matrix\n", "\n", "token_to_idx_dict = {T_disc[i]: i+1 for i in range(len(T_disc))} # Reserve 0 index for unknown token\n", "X_train_sparse = dok_matrix((len(X_train), len(token_to_idx_dict)+1), dtype=np.float32)\n", "X_val_sparse = dok_matrix((len(X_val), len(token_to_idx_dict)+1), dtype=np.float32)\n", "X_test_sparse = dok_matrix((len(X_test), len(token_to_idx_dict)+1), dtype=np.float32)\n", "for X, S in [(X_train, X_train_sparse), (X_val, X_val_sparse), (X_test, X_test_sparse)]:\n", "    for i, x in enumerate(X):\n", "        for t in x:\n", "            S[i, token_to_idx_dict.get(t, 0)] = 1" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we'll import a logistic regression classifier from scikit-learn and fit it to these features. Since there is a class imbalance in the dataset, we'll rebalance it by assigning extra weight to the spam class (the ratio of ham:spam is about 6:1)."
] }, { "cell_type": "code", "execution_count": 863, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "LogisticRegression(C=100.0, class_weight={0: 1, 1: 6}, dual=False,\n", " fit_intercept=True, intercept_scaling=1, max_iter=100,\n", " multi_class='ovr', n_jobs=1, penalty='l2', random_state=None,\n", " solver='liblinear', tol=0.0001, verbose=0, warm_start=False)" ] }, "execution_count": 863, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.linear_model import LogisticRegression\n", "\n", "model = {}\n", "model[1] = LogisticRegression(penalty=\"l2\", C=100.0, solver=\"liblinear\", class_weight={0:1, 1:6})\n", "model[1].fit(X_train_sparse, y_train)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let's see how our classifier performs. First, let's look at some standard performance metrics across all the classes such as ROC, precision/recall, and $F_1$ score." ] }, { "cell_type": "code", "execution_count": 864, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Training statistics:\n", " precision recall f1-score support\n", "\n", " ham 0.99 0.96 0.98 3857\n", " spam 0.80 0.94 0.86 601\n", "\n", "avg / total 0.96 0.96 0.96 4458\n", "\n", "Training classification accuracy: 0.960\n", "\n", "Validation statistics:\n", " precision recall f1-score support\n", "\n", " ham 0.99 0.97 0.98 491\n", " spam 0.79 0.94 0.86 66\n", "\n", "avg / total 0.97 0.96 0.97 557\n", "\n", "Validation classification accuracy: 0.964\n" ] } ], "source": [ "from sklearn.metrics import classification_report, roc_curve, precision_recall_curve, accuracy_score\n", "\n", "# Get training/validation predictions\n", "y_train_pred = model[1].predict_proba(X_train_sparse)\n", "y_val_pred = model[1].predict_proba(X_val_sparse)\n", "# Plot ROC curves\n", "def roc_plot(y_train, y_train_pred, y_val, y_val_pred):\n", " fpr, tpr, _ = roc_curve(y_train, y_train_pred[:,1])\n", " plt.plot(fpr, tpr, label=\"Training\")\n", " fpr, tpr, _ = roc_curve(y_val, y_val_pred[:,1])\n", " plt.plot(fpr, tpr, label=\"Validation\")\n", " plt.ylim(ymin=0.8, ymax=1)\n", " plt.xlabel(\"False positive rate\")\n", " plt.ylabel(\"True positive rate\")\n", " plt.title(\"ROC\")\n", " plt.legend(loc=4)\n", "# Plot precision/recall curves\n", "def pr_plot(y_train, y_train_pred, y_val, y_val_pred):\n", " precision, recall, _ = precision_recall_curve(y_train, y_train_pred[:,1])\n", " plt.plot(precision, recall, label=\"Training\")\n", " precision, recall, _ = precision_recall_curve(y_val, y_val_pred[:,1])\n", " plt.plot(precision, recall, label=\"Validation\")\n", " plt.ylim(ymin=0.6, ymax=1)\n", " plt.xlabel(\"Precision\")\n", " plt.ylabel(\"Recall\")\n", " plt.title(\"Precision/recall\")\n", " plt.legend(loc=3)\n", "def f1_plot(y_train, y_train_pred, y_val, y_val_pred):\n", " p, r, thresh = precision_recall_curve(y_train, y_train_pred[:,1]); p=p[:-1]; r=r[:-1]\n", " plt.plot(thresh, 
2*p*r/(p+r), label=\"Training\")\n", " p, r, thresh = precision_recall_curve(y_val, y_val_pred[:,1]); p=p[:-1]; r=r[:-1]\n", " plt.plot(thresh, 2*p*r/(p+r), label=\"Validation\")\n", " #plt.ylim(ymin=0.6, ymax=1)\n", " plt.xlabel(\"Threshold\")\n", " plt.ylabel(\"F1\")\n", " plt.title(\"$F_1$ score\")\n", " plt.legend(loc=3)\n", "\n", "plt.subplot(131)\n", "roc_plot(y_train, y_train_pred, y_val, y_val_pred)\n", "plt.subplot(132)\n", "pr_plot(y_train, y_train_pred, y_val, y_val_pred)\n", "plt.subplot(133)\n", "f1_plot(y_train, y_train_pred, y_val, y_val_pred)\n", "plt.subplots_adjust(right=2, wspace=0.25)\n", "plt.show()\n", "\n", "# Show classification reports\n", "print(\"Training statistics:\")\n", "print(classification_report(y_train, np.argmax(y_train_pred, axis=1), target_names=[\"ham\", \"spam\"]))\n", "print(\"Training classification accuracy: {:.3f}\".format(accuracy_score(y_train, np.argmax(y_train_pred, axis=1))))\n", "print(\"\\nValidation statistics:\")\n", "print(classification_report(y_val, np.argmax(y_val_pred, axis=1), target_names=[\"ham\", \"spam\"]))\n", "print(\"Validation classification accuracy: {:.3f}\".format(accuracy_score(y_val, np.argmax(y_val_pred, axis=1))))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This looks pretty good at first glance; 96% classification accuracy doesn't look too bad. But since there is a class imbalance, accuracy isn't everything. Also, we have to consider the cost of each kind of misclassification. Specifically, if we want to make sure we're catching all the spam (true positives), we want a high recall (true positive rate). However, we REALLY don't want to misclassify a ham message as spam (false positive), since if a legitimate message is accidentally deleted, it could have grave consequences, whereas if a spam message gets through the filter, it is simply a minor nuisance. 
As it is now, with a threshold of 0.5, our classifier catches 94% of the spam in the validation set and incorrectly flags about 3% of ham messages as spam, which doesn't sound bad. But the fraction of messages it classifies as spam that are actually ham is about 20%!" ] }, { "cell_type": "code", "execution_count": 865, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Fraction of ham messages incorrectly identified as spam (FPR): 0.0326\n", "Fraction of spam messages correctly identified (TPR): 0.9394\n", "Fraction of messages identified as spam which are actually ham (FDR): 0.2051\n" ] } ], "source": [ "thresh = 0.5\n", "print(\"Fraction of ham messages incorrectly identified as spam (FPR): {:.4f}\".format(np.mean(y_val_pred[y_val==0][:,1] > thresh)))\n", "print(\"Fraction of spam messages correctly identified (TPR): {:.4f}\".format(np.mean(y_val_pred[y_val==1][:,1] > thresh)))\n", "print(\"Fraction of messages identified as spam which are actually ham (FDR): {:.4f}\".format(np.mean(y_val[y_val_pred[:,1]>thresh] == 0)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Our primary goal is to minimize the number of false positives. Since the ratio of ham:spam is already relatively large (about 6:1), we don't expect to be making a lot of positive identifications anyway. Therefore, the false positive rate ($FPR=FP/N$) is always going to be relatively small, regardless of performance. The false discovery rate ($FDR=FP/(TP+FP)$ or 1-precision), on the other hand, will be relatively large since the denominator $TP+FP$ is smaller than the total number of negative examples $N=TN+FP$. So minimizing the FDR should be our primary goal. This is equivalent to maximizing the precision ($TP/(TP+FP)$), but large relative changes in a small FDR may not amount to significant changes in precision (e.g. tripling the FDR from 0.01 to 0.03 only changes precision from 0.99 to 0.97). 
However, in the end, we really don't want to end up throwing a large fraction of legitimate messages away, so FPR will be useful to us as a constraint (e.g. we don't want to discard more than 1 in 1000 ham messages).\n", "\n", "Our secondary goal is to positively identify as much actual spam as possible. This is simply measured by the recall or true positive rate $TPR=TP/P=TP/(TP+FN)$, where $P$ is the total number of positive (spam) examples.\n", "\n", "Ideally, we would be able to combine these two metrics (1/FDR and TPR) into a single one so that we may measure performance at a glance, but since $1/FDR \\in [1, \\infty)$ and $TPR \\in [0,1]$, this does not seem feasible. However, what we can do is try to maximize recall while maintaining a constraint on the false positive rate. To evaluate our model using this metric, we can plot TPR against FPR as the threshold varies (the typical ROC curve) on a semilog plot, decide on a maximum allowable cutoff for the FPR (which determines the classification threshold), and then check the TPR at this threshold. Let's see how our current model is evaluated using this metric:" ] }, { "cell_type": "code", "execution_count": 792, "metadata": { "scrolled": true }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "fpr, tpr, _ = roc_curve(y_train, y_train_pred[:,1])\n", "plt.semilogx(fpr, tpr, label=\"Training\")\n", "fpr, tpr, _ = roc_curve(y_val, y_val_pred[:,1])\n", "plt.semilogx(fpr, tpr, label=\"Validation\")\n", "plt.ylim(ymin=0.6, ymax=1)\n", "plt.xlabel(\"False positive rate\")\n", "plt.ylabel(\"True positive rate\")\n", "plt.title(\"ROC\")\n", "plt.legend(loc=4)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "I think a false positive rate of 1 in 100 (0.01) is probably a pretty good constraint (also our validation set only has ~500 samples in it so it's impossible to measure an FPR less than ~1/500). We'll use that from now on. Looks like this corresponds to a recall of roughly 0.80-0.85. Let's write a function to do the evaluation for us:" ] }, { "cell_type": "code", "execution_count": 793, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training - FPR=9.852e-03, TPR=0.875, thresh=0.824\n", "Validation - FPR=8.147e-03, TPR=0.879, thresh=0.880\n" ] } ], "source": [ "def evaluate_model(model, X, y, FPR_max=0.01):\n", " # Find y_pred\n", " y_pred = model.predict_proba(X)\n", " \n", " # Get FPR, TPR, and thresholds from sklearn's roc_curve function\n", " fpr, tpr, thresh = roc_curve(y, y_pred[:,1])\n", " \n", " # Find the threshold with the largest FPR below the max\n", " for i in reversed(range(len(fpr))):\n", " if fpr[i] <= FPR_max: break\n", " return fpr[i], tpr[i], thresh[i]\n", "print(\"Training - FPR={:.3e}, TPR={:.3f}, thresh={:.3f}\".format(*evaluate_model(model[1], X_train_sparse, y_train)))\n", "print(\"Validation - FPR={:.3e}, TPR={:.3f}, thresh={:.3f}\".format(*evaluate_model(model[1], X_val_sparse, y_val)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now that we have our metric down we should start tuning hyperparameters. 
I'll just roll the dice on a random set of hyperparameters a number of times, evaluate them using k-fold cross-validation, then pick the best model." ] }, { "cell_type": "code", "execution_count": 830, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model 0: FPR=9.174e-03, TPR=0.849±0.069, thresh=0.921\n", "Model 2: FPR=9.335e-03, TPR=0.859±0.070, thresh=0.532\n", "Model 17: FPR=9.259e-03, TPR=0.860±0.041, thresh=0.774\n", "Model 87: FPR=5.807e-03, TPR=0.866±0.052, thresh=0.633\n", "Model 94: FPR=8.226e-03, TPR=0.873±0.056, thresh=0.679\n" ] } ], "source": [ "from sklearn.model_selection import KFold\n", "kf = KFold(n_splits=5, shuffle=True, random_state=None)\n", "\n", "best_hyperparameters = None\n", "best_tpr = 0\n", "X = sp.sparse.vstack([X_train_sparse, X_val_sparse]).todok()\n", "y = np.concatenate([y_train, y_val], axis=0)\n", "\n", "for trial in range(100):\n", "    # Pick hyperparameters\n", "    C = 10**np.random.uniform(-2,2)\n", "    spam_weight = np.random.uniform(0.5, 10)\n", "    penalty = np.random.choice([\"l1\", \"l2\"])\n", "    \n", "    # Instantiate, train, and evaluate model using k-fold CV\n", "    tpr_list = []\n", "    for train_idx, val_idx in kf.split(X):\n", "        temp_model = LogisticRegression(penalty=penalty, C=C, solver=\"liblinear\", class_weight={0:1, 1:spam_weight})\n", "        temp_model.fit(X[train_idx], y[train_idx])\n", "        fpr, tpr, thresh = evaluate_model(temp_model, X[val_idx], y[val_idx])\n", "        tpr_list.append(tpr)\n", "    # Keep the hyperparameters with the best mean TPR across folds\n", "    if np.mean(tpr_list) > best_tpr:\n", "        best_tpr = np.mean(tpr_list)\n", "        best_hyperparameters = (C, spam_weight, penalty)\n", "        # Note: the FPR and threshold printed here come from the last fold only\n", "        print(\"Model {}: FPR={:.3e}, TPR={:.3f}\\u00B1{:.3f}, thresh={:.3f}\".format(trial, fpr, best_tpr, 2*np.std(tpr_list), thresh), end=\"\\n\")" ] }, { "cell_type": "code", "execution_count": 870, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "C=9.16e-01, spam_weight=3.00, penalty=l1\n", "Test set: FPR=0.4864, TPR=0.988, threshold=0.013\n", "Test 
set: FPR=0.0964, TPR=0.938, threshold=0.104\n", "Test set: FPR=0.0042, TPR=0.825, threshold=0.857\n", "Test set: FPR=0.0000, TPR=0.750, threshold=0.964\n" ] } ], "source": [ "C, spam_weight, penalty = best_hyperparameters\n", "model[2] = LogisticRegression(penalty=penalty, C=C, solver=\"liblinear\", class_weight={0:1, 1:spam_weight})\n", "model[2].fit(X_train_sparse, y_train)\n", "print(\"C={:.2e}, spam_weight={:.2f}, penalty={}\".format(*best_hyperparameters))\n", "for FPR_max in [1/2, 1/10, 1/100, 1/1000]:\n", "    print(\"Test set: FPR={:.4f}, TPR={:.3f}, threshold={:.3f}\".format(*evaluate_model(model[2], X_test_sparse, y_test, FPR_max=FPR_max)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "At our FPR cutoff of 1/100, we have a recall of 82.5% and a false positive rate of <1% (0.42%) on the test set (which has been completely blinded up until this point). Looking around for other studies/implementations to benchmark against, I found [this study](https://arxiv.org/abs/cs/0009009), which compared several spam filtering techniques. For comparison, our constraint of a max FPR of 1/100 would be roughly comparable to one of their models with parameter $\\lambda = 100$ (this comparison falls apart for $\\lambda\\sim 1$). The value of $\\lambda$ can be understood as the relative cost of misclassifying ham as spam vs misclassifying spam as ham: it is $\\lambda$ times worse to make a false positive identification than a false negative. As we can see from their Table 1, our logistic regression model outperforms all the models they considered with regard to spam recall (TPR) for equivalent values of $\\lambda$." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.5" } }, "nbformat": 4, "nbformat_minor": 2 }