{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Naive Bayes from scratch" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import re\n", "import glob\n", "from pathlib import Path\n", "from random import shuffle\n", "from math import exp, log\n", "from collections import defaultdict, Counter\n", "from typing import NamedTuple, List, Set, Tuple" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Ensure that we have a `data` directory we use to store downloaded data\n", "!mkdir -p data\n", "data_dir: Path = Path('data')" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--2020-02-09 12:03:06-- http://nlp.cs.aueb.gr/software_and_datasets/Enron-Spam/preprocessed/enron1.tar.gz\n", "Resolving nlp.cs.aueb.gr (nlp.cs.aueb.gr)... 195.251.248.252\n", "Connecting to nlp.cs.aueb.gr (nlp.cs.aueb.gr)|195.251.248.252|:80... connected.\n", "HTTP request sent, awaiting response... 200 OK\n", "Length: 1802573 (1.7M) [application/x-gzip]\n", "Saving to: ‘data/enron1.tar.gz’\n", "\n", "enron1.tar.gz 100%[===================>] 1.72M 920KB/s in 1.9s \n", "\n", "2020-02-09 12:03:08 (920 KB/s) - ‘data/enron1.tar.gz’ saved [1802573/1802573]\n", "\n" ] } ], "source": [ "# We're using the \"Enron Spam\" data set\n", "!wget -nc -P data http://nlp.cs.aueb.gr/software_and_datasets/Enron-Spam/preprocessed/enron1.tar.gz" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "!tar -xzf data/enron1.tar.gz -C data" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# The data set has 2 directories: One for `spam` messages, one for `ham` messages\n", "spam_data_path: Path = data_dir / 'enron1' / 'spam'\n", "ham_data_path: Path = data_dir / 'enron1' / 'ham'" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# Our data container for `spam` and `ham` messages\n", "class Message(NamedTuple):\n", " text: str\n", " is_spam: bool" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['data/enron1/spam/4743.2005-06-25.GP.spam.txt',\n", " 'data/enron1/spam/1309.2004-06-08.GP.spam.txt',\n", " 'data/enron1/spam/0726.2004-03-26.GP.spam.txt',\n", " 'data/enron1/spam/0202.2004-01-13.GP.spam.txt',\n", " 'data/enron1/spam/3988.2005-03-06.GP.spam.txt']" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Globbing for all the `.txt` files in both (`spam` and `ham`) directories\n", "spam_message_paths: List[str] = glob.glob(str(spam_data_path / '*.txt'))\n", "ham_message_paths: List[str] = glob.glob(str(ham_data_path / '*.txt'))\n", "\n", "message_paths: List[str] = spam_message_paths + ham_message_paths\n", "message_paths[:5]" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# The list which eventually contains all the parsed Enron `spam` and `ham` messages\n", "messages: List[Message] = []" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "# Open every file individually, turn it into a `Message` and append it to our `messages` list\n", "for path in message_paths:\n", " with open(path, errors='ignore') as file:\n", " is_spam: bool = True if 'spam' in path else False\n", " # We're only interested in the subject for the time being \n", " text: 
{ "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# The data set has two directories: one for `spam` messages, one for `ham` messages\n", "spam_data_path: Path = data_dir / 'enron1' / 'spam'\n", "ham_data_path: Path = data_dir / 'enron1' / 'ham'" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# Our data container for `spam` and `ham` messages\n", "class Message(NamedTuple):\n", "    text: str\n", "    is_spam: bool" ] },
{ "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['data/enron1/spam/4743.2005-06-25.GP.spam.txt',\n", " 'data/enron1/spam/1309.2004-06-08.GP.spam.txt',\n", " 'data/enron1/spam/0726.2004-03-26.GP.spam.txt',\n", " 'data/enron1/spam/0202.2004-01-13.GP.spam.txt',\n", " 'data/enron1/spam/3988.2005-03-06.GP.spam.txt']" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Globbing for all the `.txt` files in both (`spam` and `ham`) directories\n", "spam_message_paths: List[str] = glob.glob(str(spam_data_path / '*.txt'))\n", "ham_message_paths: List[str] = glob.glob(str(ham_data_path / '*.txt'))\n", "\n", "message_paths: List[str] = spam_message_paths + ham_message_paths\n", "message_paths[:5]" ] },
{ "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# The list which eventually contains all the parsed Enron `spam` and `ham` messages\n", "messages: List[Message] = []" ] },
{ "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "# Open every file individually, turn it into a `Message` and append it to our `messages` list\n", "for path in message_paths:\n", "    with open(path, errors='ignore') as file:\n", "        is_spam: bool = 'spam' in path\n", "        # We're only interested in the subject line for the time being\n", "        text: str = file.readline().replace('Subject:', '').strip()\n", "        messages.append(Message(text, is_spam))" ] },
{ "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[Message(text='january production estimate', is_spam=False),\n", " Message(text='re : your code # 5 g 6878', is_spam=True),\n", " Message(text='account # 20367 s tue , 28 jun 2005 11 : 41 : 41 - 0800', is_spam=True),\n", " Message(text='congratulations', is_spam=True),\n", " Message(text='fw : hpl imbalance payback', is_spam=False)]" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "shuffle(messages)\n", "messages[:5]" ] },
{ "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "5172" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(messages)" ] },
{ "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "# Given a string, normalize it and extract all words with a length of at least 2\n", "def tokenize(text: str) -> Set[str]:\n", "    words: List[str] = []\n", "    for word in re.findall(r'[A-Za-z0-9\\']+', text):\n", "        if len(word) >= 2:\n", "            words.append(word.lower())\n", "    return set(words)\n", "\n", "assert tokenize('Is this a text? If so, Tokenize this text!...') == {'is', 'this', 'text', 'if', 'so', 'tokenize'}" ] },
{ "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'estimate', 'january', 'production'}" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tokenize(messages[0].text)" ] },
{ "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "# Split the list of messages into a `train` and a `test` set (defaults to an 80/20 train/test split)\n", "def train_test_split(messages: List[Message], pct=0.8) -> Tuple[List[Message], List[Message]]:\n", "    shuffle(messages)\n", "    num_train = int(round(len(messages) * pct, 0))\n", "    return messages[:num_train], messages[num_train:]\n", "\n", "assert len(train_test_split(messages)[0]) + len(train_test_split(messages)[1]) == len(messages)" ] },
{ "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "# The Naive Bayes classifier\n", "class NaiveBayes:\n", "    def __init__(self, k=1) -> None:\n", "        # `k` is the smoothing factor\n", "        self._k: int = k\n", "        self._num_spam_messages: int = 0\n", "        self._num_ham_messages: int = 0\n", "        self._num_word_in_spam: Dict[str, int] = defaultdict(int)\n", "        self._num_word_in_ham: Dict[str, int] = defaultdict(int)\n", "        self._spam_words: Set[str] = set()\n", "        self._ham_words: Set[str] = set()\n", "        self._words: Set[str] = set()\n", "\n", "    # Iterate through the given messages and gather the necessary statistics\n", "    def train(self, messages: List[Message]) -> None:\n", "        msg: Message\n", "        token: str\n", "        for msg in messages:\n", "            tokens: Set[str] = tokenize(msg.text)\n", "            self._words.update(tokens)\n", "            if msg.is_spam:\n", "                self._num_spam_messages += 1\n", "                self._spam_words.update(tokens)\n", "                for token in tokens:\n", "                    self._num_word_in_spam[token] += 1\n", "            else:\n", "                self._num_ham_messages += 1\n", "                self._ham_words.update(tokens)\n", "                for token in tokens:\n", "                    self._num_word_in_ham[token] += 1\n", "\n", "    # Smoothed probability that `word` appears in a spam message\n", "    def _p_word_spam(self, word: str) -> float:\n", "        return (self._k + self._num_word_in_spam[word]) / ((2 * self._k) + self._num_spam_messages)\n", "\n", "    # Smoothed probability that `word` appears in a ham message\n", "    def _p_word_ham(self, word: str) -> float:\n", "        return (self._k + self._num_word_in_ham[word]) / ((2 * self._k) + self._num_ham_messages)\n", "\n", "    # Given a `text`, how likely is it spam?\n", "    def predict(self, text: str) -> float:\n", "        text_words: Set[str] = tokenize(text)\n", "        log_p_spam: float = 0.0\n", "        log_p_ham: float = 0.0\n", "\n", "        for word in self._words:\n", "            p_spam: float = self._p_word_spam(word)\n", "            p_ham: float = self._p_word_ham(word)\n", "            if word in text_words:\n", "                log_p_spam += log(p_spam)\n", "                log_p_ham += log(p_ham)\n", "            else:\n", "                log_p_spam += log(1 - p_spam)\n", "                log_p_ham += log(1 - p_ham)\n", "\n", "        p_if_spam: float = exp(log_p_spam)\n", "        p_if_ham: float = exp(log_p_ham)\n", "        return p_if_spam / (p_if_spam + p_if_ham)\n", "\n", "# Tests\n", "def test_naive_bayes():\n", "    messages: List[Message] = [\n", "        Message('Spam message', is_spam=True),\n", "        Message('Ham message', is_spam=False),\n", "        Message('Ham message about Spam', is_spam=False)]\n", "\n", "    nb: NaiveBayes = NaiveBayes()\n", "    nb.train(messages)\n", "\n", "    assert nb._num_spam_messages == 1\n", "    assert nb._num_ham_messages == 2\n", "    assert nb._spam_words == {'spam', 'message'}\n", "    assert nb._ham_words == {'ham', 'message', 'about', 'spam'}\n", "    assert nb._num_word_in_spam == {'spam': 1, 'message': 1}\n", "    assert nb._num_word_in_ham == {'ham': 2, 'message': 2, 'about': 1, 'spam': 1}\n", "    assert nb._words == {'spam', 'message', 'ham', 'about'}\n", "\n", "    # Our test message\n", "    text: str = 'A spam message'\n", "\n", "    # Reminder: The `_words` we iterate over are: {'spam', 'message', 'ham', 'about'}\n", "\n", "    # Calculate how spammy the `text` might be\n", "    p_if_spam: float = exp(sum([\n", "        log((1 + 1) / ((2 * 1) + 1)),        # `spam` (also in `text`)\n", "        log((1 + 1) / ((2 * 1) + 1)),        # `message` (also in `text`)\n", "        log(1 - ((1 + 0) / ((2 * 1) + 1))),  # `ham` (NOT in `text`)\n", "        log(1 - ((1 + 0) / ((2 * 1) + 1))),  # `about` (NOT in `text`)\n", "    ]))\n", "\n", "    # Calculate how hammy the `text` might be\n", "    p_if_ham: float = exp(sum([\n", "        log((1 + 1) / ((2 * 1) + 2)),        # `spam` (also in `text`)\n", "        log((1 + 2) / ((2 * 1) + 2)),        # `message` (also in `text`)\n", "        log(1 - ((1 + 2) / ((2 * 1) + 2))),  # `ham` (NOT in `text`)\n", "        log(1 - ((1 + 1) / ((2 * 1) + 2))),  # `about` (NOT in `text`)\n", "    ]))\n", "\n", "    p_spam: float = p_if_spam / (p_if_spam + p_if_ham)\n", "\n", "    assert p_spam == nb.predict(text)\n", "\n", "test_naive_bayes()" ] },
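{ "cell_type": "markdown", "metadata": {}, "source": [ "For reference, the per-word likelihood computed by `_p_word_spam` is the Laplace-smoothed fraction of spam messages containing the word:\n", "\n", "$$P(w \\mid \\text{spam}) = \\frac{k + n_{w,\\text{spam}}}{2k + n_{\\text{spam}}}$$\n", "\n", "where $n_{w,\\text{spam}}$ is the number of spam messages containing $w$ and $n_{\\text{spam}}$ is the total number of spam messages. `predict` combines these likelihoods over the *entire* vocabulary, with absent words contributing $1 - P(w \\mid \\cdot)$, which makes this a Bernoulli Naive Bayes model; summing logarithms instead of multiplying raw probabilities avoids floating-point underflow. The final score\n", "\n", "$$P(\\text{spam} \\mid \\text{text}) = \\frac{P(\\text{text} \\mid \\text{spam})}{P(\\text{text} \\mid \\text{spam}) + P(\\text{text} \\mid \\text{ham})}$$\n", "\n", "implicitly assumes equal priors, $P(\\text{spam}) = P(\\text{ham})$." ] },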
{ "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "train: List[Message]\n", "test: List[Message]\n", "\n", "# Splitting our Enron messages into a `train` and a `test` set\n", "train, test = train_test_split(messages)" ] },
{ "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Spam messages in training data: 1227\n", "Ham messages in training data: 2911\n", "Most spammy words: [('you', 115), ('the', 104), ('your', 104), ('for', 86), ('to', 83), ('re', 81), ('on', 56), ('and', 51), ('get', 48), ('is', 48), ('in', 43), ('with', 40), ('of', 38), ('it', 35), ('at', 35), ('online', 34), ('all', 33), ('from', 33), ('this', 32), ('new', 31)]\n" ] } ], "source": [ "# Train our Naive Bayes classifier with the `train` set\n", "nb: NaiveBayes = NaiveBayes()\n", "nb.train(train)\n", "\n", "print(f'Spam messages in training data: {nb._num_spam_messages}')\n", "print(f'Ham messages in training data: {nb._num_ham_messages}')\n", "print(f'Most spammy words: {Counter(nb._num_word_in_spam).most_common(20)}')" ] },
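{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# As a quick sanity check, score every message in the `test` set, treating a\n", "# predicted spam probability above 0.5 as spam. (Note: `predict` iterates the\n", "# whole vocabulary per message, so this takes a moment.)\n", "num_correct: int = sum(1 for msg in test if (nb.predict(msg.text) > 0.5) == msg.is_spam)\n", "print(f'Accuracy on the test set: {num_correct / len(test):.3f}')" ] },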
{ "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[Message(text=\"a witch . i don ' t\", is_spam=True),\n", " Message(text='active and strong', is_spam=True),\n", " Message(text='get great prices on medications', is_spam=True),\n", " Message(text='', is_spam=True),\n", " Message(text='popular software at low low prices . misunderstand developments', is_spam=True)]" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Grabbing all the spam messages from our `test` set\n", "spam_messages: List[Message] = [item for item in test if item.is_spam]\n", "spam_messages[:5]" ] },
{ "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Predicting likelihood of \"get your hand clock repliacs todday carson\" being spam.\n" ] }, { "data": { "text/plain": [ "0.9884313222593173" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Using our trained Naive Bayes classifier to classify a spam message\n", "message: str = spam_messages[10].text\n", "\n", "print(f'Predicting likelihood of \"{message}\" being spam.')\n", "nb.predict(message)" ] },
{ "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[Message(text='new update for buybacks', is_spam=False),\n", " Message(text='enron and blockbuster to launch entertainment on - demand service', is_spam=False),\n", " Message(text='re : astros web site comments', is_spam=False),\n", " Message(text='re : formosa meter # : 1000', is_spam=False),\n", " Message(text='re : deal extension for 11 / 21 / 2000 for 98 - 439', is_spam=False)]" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Grabbing all the ham messages from our `test` set\n", "ham_messages: List[Message] = [item for item in test if not item.is_spam]\n", "ham_messages[:5]" ] },
{ "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Predicting likelihood of \"associate & analyst mid - year 2001 prc process\" being spam.\n" ] }, { "data": { "text/plain": [ "5.3089147140900964e-05" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Using our trained Naive Bayes classifier to classify a ham message\n", "message: str = ham_messages[10].text\n", "\n", "print(f'Predicting likelihood of \"{message}\" being spam.')\n", "nb.predict(message)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" } }, "nbformat": 4, "nbformat_minor": 2 }