{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Text Classification with NLTK\n", "> In this post, we will build on our NLP foundation and explore different ways to improve text classification with NLTK and Scikit-learn. Specifically, we will build an SMS spam filter.\n", "\n", "- toc: true \n", "- badges: true\n", "- comments: true\n", "- author: Chanseok Kang\n", "- categories: [Python, Machine_Learning, Natural_Language_Processing]\n", "- image: " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Required Packages" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import sys\n", "import nltk\n", "import sklearn\n", "import pandas as pd\n", "import numpy as np\n", "\n", "# Download the NLTK resources used later: the English stopword list and the models behind word_tokenize\n", "# (recent NLTK releases may ask for 'punkt_tab' as well)\n", "for resource in ['stopwords', 'punkt']:\n", "    nltk.download(resource, quiet=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Version Check" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Python: 3.7.6 (default, Jan 8 2020, 20:23:39) [MSC v.1916 64 bit (AMD64)]\n", "NLTK: 3.4.5\n", "Scikit-learn: 0.22.1\n", "Pandas: 1.0.1\n", "NumPy: 1.18.1\n" ] } ], "source": [ "print('Python: {}'.format(sys.version))\n", "print('NLTK: {}'.format(nltk.__version__))\n", "print('Scikit-learn: {}'.format(sklearn.__version__))\n", "print('Pandas: {}'.format(pd.__version__))\n", "print('NumPy: {}'.format(np.__version__))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load the dataset\n", "Now that we have confirmed that our libraries are installed correctly, let's load the dataset as a pandas DataFrame and extract some useful information, such as the columns and the class distribution.\n", "\n", "The dataset comes from the [UCI Machine Learning Repository](https://archive.ics.uci.edu/ml/datasets/sms+spam+collection). It contains 5,572 labeled SMS messages that were collected for mobile phone spam research. 
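The raw file has no header row. Each line holds a label, a TAB character and the message text. The minimal sketch below parses two made-up lines in that layout. Passing `names=`, which the original cell does not use, is one way to get readable column names instead of `0` and `1`.

```python
import io

import pandas as pd

# Two made-up lines in the same layout as SMSSpamCollection:
# a label, a TAB character (chr(9)), then the message text; one message per line (chr(10))
TAB, NEWLINE = chr(9), chr(10)
raw = ('ham' + TAB + 'Ok lar... Joking wif u oni...' + NEWLINE
       + 'spam' + TAB + 'Free entry in 2 a wkly comp to win FA Cup final' + NEWLINE)

# read_table splits on TAB by default; names= replaces the default 0/1 column labels
sms_demo = pd.read_table(io.StringIO(raw), header=None, names=['label', 'message'])
print(sms_demo)
```

With named columns, later steps could use `sms['label']` and `sms['message']` instead of `sms[0]` and `sms[1]`. The rest of this post keeps the original integer column labels.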
" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " 0 1\n", "0 ham Go until jurong point, crazy.. Available only ...\n", "1 ham Ok lar... Joking wif u oni...\n", "2 spam Free entry in 2 a wkly comp to win FA Cup fina...\n", "3 ham U dun say so early hor... U c already then say...\n", "4 ham Nah I don't think he goes to usf, he lives aro..." ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Load the dataset of SMS messages (tab-separated: label, then message text)\n", "sms = pd.read_table('./dataset/SMSSpamCollection', header=None, encoding='utf-8')\n", "sms.head()" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "<class 'pandas.core.frame.DataFrame'>\n", "RangeIndex: 5572 entries, 0 to 5571\n", "Data columns (total 2 columns):\n", " # Column Non-Null Count Dtype \n", "--- ------ -------------- ----- \n", " 0 0 5572 non-null object\n", " 1 1 5572 non-null object\n", "dtypes: object(2)\n", "memory usage: 87.2+ KB\n" ] } ], "source": [ "sms.info()" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "ham 4825\n", "spam 747\n", "Name: 0, dtype: int64" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Check class distribution\n", "sms[0].value_counts()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "From the summary, we can see that spam messages are labeled `spam` and legitimate messages are labeled `ham`. The classes are imbalanced: only 747 of the 5,572 messages (about 13%) are spam." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data preprocessing\n", "The labels are stored as strings, so we need to convert them to numbers before a model can use them. Mapping each class to an integer (here `ham` to 0 and `spam` to 1) is called **label encoding**. **One-hot encoding**, by contrast, creates one binary indicator column per class. For a binary label like ours, both carry the same information. 
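To see how `LabelEncoder` and `pd.get_dummies` differ, here is a minimal sketch on a hypothetical four-message label column.

```python
import pandas as pd
from sklearn.preprocessing import LabelEncoder

# A hypothetical four-message label column
toy_labels = pd.Series(['ham', 'ham', 'spam', 'ham'])

# Label encoding: a single integer column; classes are sorted, so ham -> 0 and spam -> 1
codes = LabelEncoder().fit_transform(toy_labels)

# One-hot encoding: one indicator column per class (ham, spam)
dummies = pd.get_dummies(toy_labels)

print(codes)
print(dummies)
```

Depending on the pandas version, the indicator columns print as 0/1 or as True/False.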
There are several ways to encode the labels:\n", "\n", "- use `pd.get_dummies` (one-hot encoding)\n", "- use `LabelEncoder` in `sklearn.preprocessing` (label encoding)\n", "\n", "Here, we use `LabelEncoder`:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0 0 1 0 0 1 0 0 1 1]\n", "0 ham\n", "1 ham\n", "2 spam\n", "3 ham\n", "4 ham\n", "5 spam\n", "6 ham\n", "7 ham\n", "8 spam\n", "9 spam\n", "Name: 0, dtype: object\n" ] } ], "source": [ "from sklearn.preprocessing import LabelEncoder\n", "\n", "# Classes are sorted alphabetically, so ham -> 0 and spam -> 1\n", "enc = LabelEncoder()\n", "label = enc.fit_transform(sms[0])\n", "print(label[:10])\n", "print(sms[0][:10])" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0 Go until jurong point, crazy.. Available only ...\n", "1 Ok lar... Joking wif u oni...\n", "2 Free entry in 2 a wkly comp to win FA Cup fina...\n", "3 U dun say so early hor... U c already then say...\n", "4 Nah I don't think he goes to usf, he lives aro...\n", "5 FreeMsg Hey there darling it's been 3 week's n...\n", "6 Even my brother is not like to speak with me. ...\n", "7 As per your request 'Melle Melle (Oru Minnamin...\n", "8 WINNER!! As a valued network customer you have...\n", "9 Had your mobile 11 months or more? U R entitle...\n", "Name: 1, dtype: object" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "text = sms[1]\n", "text[:10]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now it is time for text preprocessing. In the previous post, we learned several text preprocessing techniques. Before applying them, however, we need to normalize the text by replacing or removing special tokens such as email addresses, URLs, money symbols, phone numbers and other numbers. To do this, we can use **regular expressions** (regex for short), which describe text patterns to search for and replace. 
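One detail matters a lot for the substitutions in this notebook. `^` and `$` anchor a pattern to the start and end of the whole string, so an anchored pattern replaces text only when the entire message matches it. The minimal sketch below shows the difference using Python's `re` module on a made-up message, with `[0-9]` standing in for digits. Pandas `str.replace` with `regex=True` applies the pattern to each message in the same way. As a result, the email, URL and phone-number patterns wrapped in `^...$` in the next code cell only fire for messages that consist of nothing but that token.

```python
import re

msg = 'Call 0800 now or reply 4 info'

# Unanchored: every run of digits anywhere in the message is replaced
unanchored = re.sub('[0-9]+', 'numbr', msg)

# Anchored with ^ and $: the pattern has to match the WHOLE string,
# so a message that merely contains digits is left unchanged
anchored = re.sub('^[0-9]+$', 'numbr', msg)

# The anchored pattern only fires when the entire string is digits
whole = re.sub('^[0-9]+$', 'numbr', '0800')

print(unanchored)  # Call numbr now or reply numbr info
print(anchored)    # Call 0800 now or reply 4 info
print(whole)       # numbr
```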
Here are some common regex constructs, as described on Wikipedia:\n", "\n", "- **^** Matches the starting position within the string. In line-based tools, it matches the starting position of any line.\n", "\n", "- **.** Matches any single character (many applications exclude newlines, and exactly which characters are considered newlines is flavor-, character-encoding-, and platform-specific, but it is safe to assume that the line feed character is included). Within POSIX bracket expressions, the dot character matches a literal dot. For example, `a.c` matches \"abc\", etc., but `[a.c]` matches only \"a\", \".\", or \"c\".\n", "\n", "- **[ ]** A bracket expression. Matches a single character that is contained within the brackets. For example, `[abc]` matches \"a\", \"b\", or \"c\". `[a-z]` specifies a range which matches any lowercase letter from \"a\" to \"z\". These forms can be mixed: `[abcx-z]` matches \"a\", \"b\", \"c\", \"x\", \"y\", or \"z\", as does `[a-cx-z]`. The - character is treated as a literal character if it is the last or the first (after the ^, if present) character within the brackets: `[abc-]`, `[-abc]`. Note that backslash escapes are not allowed in POSIX bracket expressions (Python's `re` module does allow them, as in the punctuation pattern `[^\\w\\d\\s]` used later). The ] character can be included in a bracket expression if it is the first (after the ^) character: `[]abc]`.\n", "\n", "- **[^ ]** Matches a single character that is not contained within the brackets. For example, `[^abc]` matches any character other than \"a\", \"b\", or \"c\". `[^a-z]` matches any single character that is not a lowercase letter from \"a\" to \"z\". Likewise, literal characters and ranges can be mixed.\n", "\n", "- **\\$** Matches the ending position of the string or the position just before a string-ending newline. In line-based tools, it matches the ending position of any line.\n", "\n", "- **( )** Defines a marked subexpression. The string matched within the parentheses can be recalled later (see the next entry, `\\n`). A marked subexpression is also called a block or capturing group. BRE mode requires `\\( \\)`.\n", "\n", "- **\\\\n** Matches what the nth marked subexpression matched, where n is a digit from 1 to 9. This construct is vaguely defined in the POSIX.2 standard. Some tools allow referencing more than nine capturing groups.\n", "\n", "- **\\*** Matches the preceding element zero or more times. For example, `ab*c` matches \"ac\", \"abc\", \"abbbc\", etc. `[xyz]*` matches \"\", \"x\", \"y\", \"z\", \"zx\", \"zyx\", \"xyzzy\", and so on. `(ab)*` matches \"\", \"ab\", \"abab\", \"ababab\", and so on.\n", "\n", "- **{m,n}** Matches the preceding element at least m and not more than n times. For example, `a{3,5}` matches only \"aaa\", \"aaaa\", and \"aaaaa\". This is not found in a few older instances of regexes. BRE mode requires `\\{m,n\\}`.\n", "\n", "You can test your regular expressions interactively at [regexr.com](https://regexr.com/)." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "# Use regular expressions to normalize special tokens.\n", "# regex=True matches the pandas 1.x default used when this notebook was run;\n", "# pandas 2.0 changed str.replace to treat the pattern as a literal string by default.\n", "# Note that the email, URL, and phone patterns are anchored with ^...$, so they only\n", "# match a message that consists entirely of that token.\n", "\n", "# Replace email addresses with 'emailaddress'\n", "processed = text.str.replace(r'^.+@[^\\.].*\\.[a-z]{2,}$', 'emailaddress', regex=True)\n", "\n", "# Replace URLs with 'webaddress'\n", "processed = processed.str.replace(r'^http\\://[a-zA-Z0-9\\-\\.]+\\.[a-zA-Z]{2,3}(/\\S*)?$', 'webaddress', regex=True)\n", "\n", "# Replace money symbols with 'moneysymb' (£ can be typed with ALT key + 156)\n", "processed = processed.str.replace(r'£|\\$', 'moneysymb', regex=True)\n", "\n", "# Replace 10-digit phone numbers (formats include parentheses, spaces, no spaces, dashes) with 'phonenumbr'\n", "processed = processed.str.replace(r'^\\(?[\\d]{3}\\)?[\\s-]?[\\d]{3}[\\s-]?[\\d]{4}$', 'phonenumbr', regex=True)\n", "\n", "# Replace numbers with 'numbr'\n", "processed = processed.str.replace(r'\\d+(\\.\\d+)?', 'numbr', regex=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we remove punctuation and collapse repeated whitespace into single spaces." 
] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "# Remove punctuation\n", "processed = processed.str.replace(r'[^\\w\\d\\s]', ' ', regex=True)\n", "\n", "# Replace whitespace between terms with a single space\n", "processed = processed.str.replace(r'\\s+', ' ', regex=True)\n", "\n", "# Remove leading and trailing whitespace\n", "processed = processed.str.replace(r'^\\s+|\\s+?$', '', regex=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "After that, we convert every message to lowercase." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0 go until jurong point crazy available only in ...\n", "1 ok lar joking wif u oni\n", "2 free entry in numbr a wkly comp to win fa cup ...\n", "3 u dun say so early hor u c already then say\n", "4 nah i don t think he goes to usf he lives arou...\n", " ... \n", "5567 this is the numbrnd time we have tried numbr c...\n", "5568 will ü b going to esplanade fr home\n", "5569 pity was in mood for that so any other suggest...\n", "5570 the guy did some bitching but i acted like i d...\n", "5571 rofl its true to its name\n", "Name: 1, Length: 5572, dtype: object" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "processed = processed.str.lower()\n", "processed" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In the previous post, we also learned about stopword removal, so we apply it here as well." ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "from nltk.corpus import stopwords\n", "\n", "stop_words = set(stopwords.words('english'))\n", "\n", "processed = processed.apply(lambda x: ' '.join(term for term in x.split() if term not in stop_words))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, we use `PorterStemmer` to reduce each word to its stem." 
] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "ps = nltk.PorterStemmer()\n", "\n", "processed = processed.apply(lambda x: ' '.join(ps.stem(term) for term in x.split()))" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0 go jurong point crazi avail bugi n great world...\n", "1 ok lar joke wif u oni\n", "2 free entri numbr wkli comp win fa cup final tk...\n", "3 u dun say earli hor u c alreadi say\n", "4 nah think goe usf live around though\n", " ... \n", "5567 numbrnd time tri numbr contact u u moneysymbnu...\n", "5568 ü b go esplanad fr home\n", "5569 piti mood suggest\n", "5570 guy bitch act like interest buy someth els nex...\n", "5571 rofl true name\n", "Name: 1, Length: 5572, dtype: object" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "processed" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As you can see, the processed messages look quite different from the originals, because regex substitution, stopword removal and stemming have all been applied." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Feature extraction\n", "After preprocessing, we need to extract features from the messages. To do this, we first tokenize each message into words and count how often each word occurs. We then use the 1,500 most common words as features." 
] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of words: 6579\n", "Most common words: [('numbr', 2648), ('u', 1207), ('call', 674), ('go', 456), ('get', 451), ('ur', 391), ('gt', 318), ('lt', 316), ('come', 304), ('moneysymbnumbr', 303), ('ok', 293), ('free', 284), ('day', 276), ('know', 275), ('love', 266)]\n" ] } ], "source": [ "from nltk.tokenize import word_tokenize\n", "\n", "all_words = []\n", "\n", "for message in processed:\n", "    words = word_tokenize(message)\n", "    for w in words:\n", "        all_words.append(w)\n", "\n", "# Count how often each distinct token occurs\n", "all_words = nltk.FreqDist(all_words)\n", "\n", "# Print the result\n", "print('Number of words: {}'.format(len(all_words)))\n", "print('Most common words: {}'.format(all_words.most_common(15)))" ] }, { "cell_type": "code", "execution_count": 58, "metadata": {}, "outputs": [], "source": [ "# Use the 1500 most common words as features\n", "word_features = [x[0] for x in all_words.most_common(1500)]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now that we have the feature list, we need to determine which of these features appear in each message." 
] }, { "cell_type": "code", "execution_count": 63, "metadata": {}, "outputs": [], "source": [ "def find_features(message):\n", "    # A set gives fast membership tests; the result is the same as with a list\n", "    words = set(word_tokenize(message))\n", "    features = {}\n", "    for word in word_features:\n", "        features[word] = (word in words)\n", "\n", "    return features" ] }, { "cell_type": "code", "execution_count": 64, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "go\n", "got\n", "n\n", "great\n", "wat\n", "e\n", "world\n", "point\n", "avail\n", "crazi\n", "bugi\n", "la\n", "cine\n" ] } ], "source": [ "# Print the feature words that appear in the first message\n", "features = find_features(processed[0])\n", "for key, value in features.items():\n", "    if value:\n", "        print(key)" ] }, { "cell_type": "code", "execution_count": 70, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[('numbr', False),\n", " ('u', False),\n", " ('call', False),\n", " ('go', True),\n", " ('get', False),\n", " ('ur', False),\n", " ('gt', False),\n", " ('lt', False),\n", " ('come', False),\n", " ('moneysymbnumbr', False)]" ] }, "execution_count": 70, "metadata": {}, "output_type": "execute_result" } ], "source": [ "list(features.items())[:10]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "At this point, each message can be represented as a dictionary that maps each of the 1,500 feature words to `True` or `False`. Paired with its label, this is the training-data format NLTK classifiers expect. The same approach can be applied to other datasets. 
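To make this format concrete, here is a minimal sketch with a hypothetical four-word vocabulary. It mirrors `find_features`, but `toy_find_features`, `toy_vocab` and `toy_messages` are illustrative names only. `str.split` stands in for `word_tokenize` so that the snippet needs no NLTK data; on text whose punctuation has already been stripped, the two produce similar tokens.

```python
# Hypothetical 4-word vocabulary standing in for the 1,500 most common words
toy_vocab = ['free', 'call', 'numbr', 'love']

def toy_find_features(message, vocab):
    # A set makes each membership test constant-time instead of a scan over a list
    tokens = set(message.split())
    return {word: (word in tokens) for word in vocab}

# NLTK classifiers train on a list of (feature_dict, label) pairs
toy_messages = [('free entri numbr call', 1), ('love u', 0)]
toy_feature_set = [(toy_find_features(m, toy_vocab), y) for m, y in toy_messages]

print(toy_feature_set[0])
# ({'free': True, 'call': True, 'numbr': True, 'love': False}, 1)
```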
Then, we need to split the data into a training set and a test set." ] }, { "cell_type": "code", "execution_count": 77, "metadata": {}, "outputs": [], "source": [ "messages = list(zip(processed, label))\n", "\n", "np.random.seed(1)\n", "np.random.shuffle(messages)\n", "\n", "# Call find_features for each SMS message; each item becomes a (feature_dict, label) pair\n", "feature_set = [(find_features(msg), target) for (msg, target) in messages]" ] }, { "cell_type": "code", "execution_count": 79, "metadata": {}, "outputs": [], "source": [ "from sklearn.model_selection import train_test_split\n", "\n", "training, test = train_test_split(feature_set, test_size=0.25, random_state=1)" ] }, { "cell_type": "code", "execution_count": 80, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "4179\n", "1393\n" ] } ], "source": [ "print(len(training))\n", "print(len(test))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Scikit-learn Classifier with NLTK\n", "Now that we have the training and test sets, we can build machine learning models with scikit-learn. 
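NLTK's `SklearnClassifier` wrapper is what lets a scikit-learn estimator train on `(feature_dict, label)` pairs. Internally, it converts the feature dictionaries with scikit-learn's `DictVectorizer` and the labels with `LabelEncoder`. The sketch below reproduces roughly that pipeline by hand on hypothetical toy data. It approximates what the wrapper does and is not its actual code; `toy_classify` is an illustrative helper.

```python
from sklearn.feature_extraction import DictVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.preprocessing import LabelEncoder

# Hypothetical toy featuresets in the same (feature_dict, label) shape as the notebook's
toy_train = [
    ({'free': True, 'win': True, 'love': False}, 'spam'),
    ({'free': True, 'win': False, 'love': False}, 'spam'),
    ({'free': False, 'win': False, 'love': True}, 'ham'),
    ({'free': False, 'win': False, 'love': True}, 'ham'),
]
toy_dicts, toy_labels = zip(*toy_train)

# DictVectorizer gives every feature name its own column (sorted by name);
# True/False values become 1.0/0.0
toy_vec = DictVectorizer()
X = toy_vec.fit_transform(toy_dicts)

# Class names are turned into integer targets for the estimator
toy_enc = LabelEncoder()
y = toy_enc.fit_transform(toy_labels)

toy_clf = MultinomialNB().fit(X, y)

def toy_classify(featureset):
    # Vectorize with the fitted vocabulary, predict, then map back to the class name
    pred = toy_clf.predict(toy_vec.transform([featureset]))
    return toy_enc.inverse_transform(pred)[0]

print(toy_vec.feature_names_)                                    # ['free', 'love', 'win']
print(toy_classify({'free': True, 'win': True, 'love': False}))  # spam
print(toy_classify({'free': False, 'win': False, 'love': True})) # ham
```

This is also why any scikit-learn classifier that accepts a numeric feature matrix can be plugged into the wrapper in the loop that follows.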
We will try the following algorithms and compare their performance:\n", "\n", "- K-Nearest Neighbors\n", "- Random Forest\n", "- Decision Tree\n", "- Logistic Regression\n", "- SGD Classifier (a linear model trained with stochastic gradient descent)\n", "- Naive Bayes\n", "- Support Vector Machine" ] }, { "cell_type": "code", "execution_count": 83, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "K Nearest Neighbors model Accuracy: 0.9454414931801867\n", "Decision Tree model Accuracy: 0.95908111988514\n", "Random Forest model Accuracy: 0.9813352476669059\n", "Logistic Regression model Accuracy: 0.9834888729361091\n", "SGD Classifier model Accuracy: 0.9806173725771715\n", "Naive Bayes model Accuracy: 0.9856424982053122\n", "Support Vector Classifier model Accuracy: 0.9820531227566404\n" ] } ], "source": [ "from nltk.classify.scikitlearn import SklearnClassifier\n", "from sklearn.neighbors import KNeighborsClassifier\n", "from sklearn.tree import DecisionTreeClassifier\n", "from sklearn.ensemble import RandomForestClassifier\n", "from sklearn.linear_model import LogisticRegression, SGDClassifier\n", "from sklearn.naive_bayes import MultinomialNB\n", "from sklearn.svm import SVC\n", "from sklearn.metrics import classification_report, accuracy_score, confusion_matrix\n", "\n", "names = ['K Nearest Neighbors', 'Decision Tree', 'Random Forest', 'Logistic Regression', 'SGD Classifier',\n", "         'Naive Bayes', 'Support Vector Classifier']\n", "\n", "classifiers = [\n", "    KNeighborsClassifier(),\n", "    DecisionTreeClassifier(),\n", "    RandomForestClassifier(),\n", "    LogisticRegression(),\n", "    SGDClassifier(max_iter=100),\n", "    MultinomialNB(),\n", "    SVC(kernel='linear')\n", "]\n", "\n", "models = zip(names, classifiers)\n", "\n", "# Wrap each scikit-learn estimator so it can train on NLTK (feature_dict, label) pairs\n", "for name, model in models:\n", "    nltk_model = SklearnClassifier(model)\n", "    nltk_model.train(training)\n", "    accuracy = nltk.classify.accuracy(nltk_model, test)\n", "    print(\"{} model Accuracy: {}\".format(name, accuracy))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "From the results, the models reach between about 94.5% (K-Nearest Neighbors) and 98.6% (Naive Bayes) accuracy, and five of the seven score above 98%. We can also combine the models with a simple ensemble method: majority voting, in which each classifier predicts a label and the most common prediction wins. Scikit-learn provides this as `VotingClassifier` in `sklearn.ensemble`. You can find the details of `VotingClassifier` [here](https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.VotingClassifier.html)." ] }, { "cell_type": "code", "execution_count": 93, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Voting Classifier model Accuracy: 0.9842067480258435\n" ] } ], "source": [ "from sklearn.ensemble import VotingClassifier\n", "\n", "# VotingClassifier expects a list of (name, estimator) tuples; the zip above was already consumed\n", "models = list(zip(names, classifiers))\n", "\n", "# voting='hard' takes the majority vote of the predicted labels\n", "nltk_ensemble = SklearnClassifier(VotingClassifier(estimators=models, voting='hard', n_jobs=-1))\n", "nltk_ensemble.train(training)\n", "accuracy = nltk.classify.accuracy(nltk_ensemble, test)\n", "print(\"Voting Classifier model Accuracy: {}\".format(accuracy))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "On this split, the hard-voting ensemble (98.42%) scores slightly below the best single model, Naive Bayes (98.56%). Since accuracy alone can be misleading on imbalanced data, let's also look at per-class precision and recall and at the confusion matrix." 
] }, { "cell_type": "code", "execution_count": 94, "metadata": {}, "outputs": [], "source": [ "text_features, labels = zip(*test)\n", "prediction = nltk_ensemble.classify_many(text_features)" ] }, { "cell_type": "code", "execution_count": 95, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " precision recall f1-score support\n", "\n", " 0 0.99 1.00 0.99 1199\n", " 1 0.98 0.91 0.94 194\n", "\n", " accuracy 0.98 1393\n", " macro avg 0.98 0.95 0.97 1393\n", "weighted avg 0.98 0.98 0.98 1393\n", "\n" ] } ], "source": [ "# Label 0 is ham and 1 is spam (see the LabelEncoder step)\n", "print(classification_report(labels, prediction))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can also display the confusion matrix as a labeled DataFrame, which is easier to read." ] }, { "cell_type": "code", "execution_count": 96, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " predicted \n", " ham spam\n", "actual ham 1195 4\n", " spam 18 176" ] }, "execution_count": 96, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pd.DataFrame(confusion_matrix(labels, prediction),\n", "             index=[['actual', 'actual'], ['ham', 'spam']],\n", "             columns=[['predicted', 'predicted'], ['ham', 'spam']])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Summary\n", "In this post, we built an SMS spam filter from the UCI SMS Spam Collection. To do this, we preprocessed the text (regex substitution, lowercasing, stopword removal and stemming, building on the previous post), tokenized it, and extracted word-presence features to build the dataset. NLTK is a great tool for these steps, and its `SklearnClassifier` wrapper lets us train scikit-learn models on NLTK feature sets. Finally, we combined the models with a voting classifier (an ensemble method), which reached about 98% accuracy on the test set." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.6" } }, "nbformat": 4, "nbformat_minor": 4 }