{ "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "KthJSHkGQR7Z" }, "source": [ "# Bag-of-Words Text Classification\n", "\n", "In this tutorial we will show how to build a simple Bag of Words (BoW) text classifier using PyTorch. The classifier is trained on IMDB movie reviews dataset." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from pathlib import Path\n", "\n", "import pandas as pd\n", "import torch\n", "import torch.nn.functional as F\n", "import torch.nn as nn\n", "import torch.optim as optim\n", "from google_drive_downloader import GoogleDriveDownloader as gdd\n", "from torch.utils.data import DataLoader, Dataset\n", "from sklearn.feature_extraction.text import CountVectorizer\n", "from tqdm import tqdm, tqdm_notebook" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "device(type='cpu')" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "device" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": {}, "colab_type": "code", "id": "j8-WlORVQR7n" }, "outputs": [], "source": [ "DATA_PATH = 'data/imdb_reviews.csv'\n", "if not Path(DATA_PATH).is_file():\n", " gdd.download_file_from_google_drive(\n", " file_id='1zfM5E6HvKIe7f3rEt1V2gBpw5QOSSKQz',\n", " dest_path=DATA_PATH,\n", " )" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
reviewlabel
22003This movie was excellent from start-to-finish....1
25415Diana Guzman is an angry young woman. Survivin...0
51237First off, I agree with quite a bit that escap...0
32410\"Seed\" is torture porn...no doubt about it. Bu...0
35670Well, basically, the movie blows! It's Blair W...0
\n", "
" ], "text/plain": [ " review label\n", "22003 This movie was excellent from start-to-finish.... 1\n", "25415 Diana Guzman is an angry young woman. Survivin... 0\n", "51237 First off, I agree with quite a bit that escap... 0\n", "32410 \"Seed\" is torture porn...no doubt about it. Bu... 0\n", "35670 Well, basically, the movie blows! It's Blair W... 0" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# View some example records\n", "pd.read_csv(DATA_PATH).sample(5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Bag-of-Words Representation\n", "\n", "![](images/bow_diagram.png)\n", "\n", "So the final bag-of-words vector for `['the', 'gray', 'cat', 'sat', 'on', 'the', 'gray', 'mat']` is `[0, 1, 1, 2, 2, 1, 0, 1]`." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "colab": {}, "colab_type": "code", "id": "GHCoa8R_QR8W" }, "outputs": [], "source": [ "class Sequences(Dataset):\n", " def __init__(self, path):\n", " df = pd.read_csv(path)\n", " self.vectorizer = CountVectorizer(stop_words='english', max_df=0.99, min_df=0.005)\n", " self.sequences = self.vectorizer.fit_transform(df.review.tolist())\n", " self.labels = df.label.tolist()\n", " self.token2idx = self.vectorizer.vocabulary_\n", " self.idx2token = {idx: token for token, idx in self.token2idx.items()}\n", " \n", " def __getitem__(self, i):\n", " return self.sequences[i, :].toarray(), self.labels[i]\n", " \n", " def __len__(self):\n", " return self.sequences.shape[0]" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(1, 3028)\n" ] } ], "source": [ "dataset = Sequences(DATA_PATH)\n", "train_loader = DataLoader(dataset, batch_size=4096)\n", "\n", "print(dataset[5][0].shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model Definition\n", "\n", "![](images/bow_training_diagram.png)\n", "\n", "Layer 1 affine: $$x_1 = W_1 X + b_1$$\n", "Layer 1 activation: $$h_1 = \\textrm{Relu}(x_1)$$\n", "Layer 2 affine: $$x_2 = W_2 h_1 + b_2$$\n", "output: $$p = \\sigma(x_2)$$\n", "Loss: $$L = −(ylog(p)+(1−y)log(1−p))$$\n", "Gradient: \n", "$$\\frac{\\partial }{\\partial W_1}L(W_1, b_1, W_2, b_2) = \\frac{\\partial L}{\\partial p}\\frac{\\partial p}{\\partial x_2}\\frac{\\partial x_2}{\\partial h_1}\\frac{\\partial h_1}{\\partial x_1}\\frac{\\partial x_1}{\\partial W_1}$$\n", "\n", "Parameter update:\n", "$$W_1 = W_1 - \\alpha \\frac{\\partial L}{\\partial W_1}$$" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "colab": {}, "colab_type": "code", "id": "-eKgEFZOQR8s" }, "outputs": [], "source": [ "class BagOfWordsClassifier(nn.Module):\n", " def __init__(self, vocab_size, hidden1, hidden2):\n", " super(BagOfWordsClassifier, self).__init__()\n", " self.fc1 = nn.Linear(vocab_size, hidden1)\n", " self.fc2 = nn.Linear(hidden1, hidden2)\n", " self.fc3 = nn.Linear(hidden2, 1)\n", " \n", " def forward(self, inputs):\n", " x = F.relu(self.fc1(inputs.squeeze(1).float()))\n", " x = F.relu(self.fc2(x))\n", " return self.fc3(x)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 0 }, "colab_type": "code", "id": "M0Q9ze4EQR8t", "outputId": "54fd71a7-c105-4ce4-a016-949dccfa0947" }, "outputs": [ { "data": { "text/plain": [ "BagOfWordsClassifier(\n", " (fc1): Linear(in_features=3028, out_features=128, bias=True)\n", " (fc2): Linear(in_features=128, out_features=64, bias=True)\n", " (fc3): Linear(in_features=64, out_features=1, bias=True)\n", ")" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model = BagOfWordsClassifier(len(dataset.token2idx), 128, 64)\n", "model" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "colab": {}, "colab_type": "code", "id": "H2emTYNqE4yC" }, "outputs": [], "source": [ "criterion = nn.BCEWithLogitsLoss()\n", "optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad], lr=0.001)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, max=16), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Epoch #1\tTrain Loss: 0.711\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, max=16), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Epoch #2\tTrain Loss: 0.676\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, max=16), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Epoch #3\tTrain Loss: 0.640\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, max=16), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Epoch #4\tTrain Loss: 0.568\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, max=16), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Epoch #5\tTrain Loss: 0.484\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, max=16), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Epoch #6\tTrain Loss: 0.413\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, max=16), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Epoch #7\tTrain Loss: 0.363\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, max=16), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Epoch #8\tTrain Loss: 0.331\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, max=16), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Epoch #9\tTrain Loss: 0.309\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, max=16), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Epoch #10\tTrain Loss: 0.294\n" ] } ], "source": [ "model.train()\n", "train_losses = []\n", "for epoch in range(10):\n", " progress_bar = tqdm_notebook(train_loader, leave=False)\n", " losses = []\n", " total = 0\n", " for inputs, target in progress_bar:\n", " model.zero_grad()\n", "\n", " output = model(inputs)\n", " loss = criterion(output.squeeze(), target.float())\n", " \n", " loss.backward()\n", " \n", " nn.utils.clip_grad_norm_(model.parameters(), 3)\n", "\n", " optimizer.step()\n", " \n", " progress_bar.set_description(f'Loss: {loss.item():.3f}')\n", " \n", " losses.append(loss.item())\n", " total += 1\n", " \n", " epoch_loss = sum(losses) / total\n", " train_losses.append(epoch_loss)\n", " \n", " tqdm.write(f'Epoch #{epoch + 1}\\tTrain Loss: {epoch_loss:.3f}')" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "def predict_sentiment(text):\n", " model.eval()\n", " with torch.no_grad():\n", " test_vector = torch.LongTensor(dataset.vectorizer.transform([text]).toarray())\n", "\n", " output = model(test_vector)\n", " prediction = torch.sigmoid(output).item()\n", "\n", " if prediction > 0.5:\n", " print(f'{prediction:0.3}: Positive sentiment')\n", " else:\n", " print(f'{prediction:0.3}: Negative sentiment')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Analyzing reviews for \"Cool Cat Saves the Kids\"\n", "\n", "![](https://m.media-amazon.com/images/M/MV5BNzE1OTY3OTk5M15BMl5BanBnXkFtZTgwODE0Mjc1NDE@._V1_UY268_CR11,0,182,268_AL_.jpg)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.000917: Negative sentiment\n" ] } ], "source": [ "test_text = \"\"\"\n", "This poor excuse for a movie is terrible. It has been 'so good it's bad' for a\n", "while, and the high ratings are a good form of sarcasm, I have to admit. But\n", "now it has to stop. Technically inept, spoon-feeding mundane messages with the\n", "artistic weight of an eighties' commercial, hypocritical to say the least, it\n", "deserves to fall into oblivion. Mr. Derek, I hope you realize you are like that\n", "weird friend that everybody know is lame, but out of kindness and Christian\n", "duty is treated like he's cool or something. That works if you are a good\n", "decent human being, not if you are a horrible arrogant bully like you are. Yes,\n", "Mr. 'Daddy' Derek will end on the history books of the internet for being a\n", "delusional sour old man who thinks to be a good example for kids, but actually\n", "has a poster of Kim Jong-Un in his closet. Destroy this movie if you all have a\n", "conscience, as I hope IHE and all other youtube channel force-closed by Derek\n", "out of SPITE would destroy him in the courts.This poor excuse for a movie is\n", "terrible. It has been 'so good it's bad' for a while, and the high ratings are\n", "a good form of sarcasm, I have to admit. But now it has to stop. Technically\n", "inept, spoon-feeding mundane messages with the artistic weight of an eighties'\n", "commercial, hypocritical to say the least, it deserves to fall into oblivion.\n", "Mr. Derek, I hope you realize you are like that weird friend that everybody\n", "know is lame, but out of kindness and Christian duty is treated like he's cool\n", "or something. That works if you are a good decent human being, not if you are a\n", "horrible arrogant bully like you are. Yes, Mr. 'Daddy' Derek will end on the\n", "history books of the internet for being a delusional sour old man who thinks to\n", "be a good example for kids, but actually has a poster of Kim Jong-Un in his\n", "closet. Destroy this movie if you all have a conscience, as I hope IHE and all\n", "other youtube channel force-closed by Derek out of SPITE would destroy him in\n", "the courts.\n", "\"\"\"\n", "predict_sentiment(test_text)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.946: Positive sentiment\n" ] } ], "source": [ "test_text = \"\"\"\n", "Cool Cat Saves The Kids is a symbolic masterpiece directed by Derek Savage that\n", "is not only satirical in the way it makes fun of the media and politics, but in\n", "the way in questions as how we humans live life and how society tells us to\n", "live life.\n", "\n", "Before I get into those details, I wanna talk about the special effects in this\n", "film. They are ASTONISHING, and it shocks me that Cool Cat Saves The Kids got\n", "snubbed by the Oscars for Best Special Effects. This film makes 2001 look like\n", "garbage, and the directing in this film makes Stanley Kubrick look like the\n", "worst director ever. You know what other film did that? Birdemic: Shock and\n", "Terror. Both of these films are masterpieces, but if I had to choose my\n", "favorite out of the 2, I would have to go with Cool Cat Saves The Kids. It is\n", "now my 10th favorite film of all time.\n", "\n", "Now, lets get into the symbolism: So you might be asking yourself, Why is Cool\n", "Cat Orange? Well, I can easily explain. Orange is a color. Orange is also a\n", "fruit, and its a very good fruit. You know what else is good? Good behavior.\n", "What behavior does Cool Cat have? He has good behavior. This cannot be a\n", "coincidence, since cool cat has good behavior in the film.\n", "\n", "Now, why is Butch The Bully fat? Well, fat means your wide. You wanna know who\n", "was wide? Hitler. Nuff said this cannot be a coincidence.\n", "\n", "Why does Erik Estrada suspect Butch The Bully to be a bully? Well look at it\n", "this way. What color of a shirt was Butchy wearing when he walks into the area?\n", "I don't know, its looks like dark purple/dark blue. Why rhymes with dark? Mark.\n", "Mark is that guy from the Room. The Room is the best movie of all time. What is\n", "the opposite of best? Worst. This is how Erik knew Butch was a bully.\n", "\n", "and finally, how come Vivica A. Fox isn't having a successful career after\n", "making Kill Bill.\n", "\n", "I actually can't answer that question.\n", "\n", "Well thanks for reading my review.\n", "\"\"\"\n", "predict_sentiment(test_text)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.96: Positive sentiment\n" ] } ], "source": [ "test_text = \"\"\"\n", "Don't let any bullies out there try and shape your judgment on this gem of a\n", "title.\n", "\n", "Some people really don't have anything better to do, except trash a great movie\n", "with annoying 1-star votes and spread lies on the Internet about how \"dumb\"\n", "Cool Cat is.\n", "\n", "I wouldn't be surprised to learn if much of the unwarranted negativity hurled\n", "at this movie is coming from people who haven't even watched this movie for\n", "themselves in the first place. Those people are no worse than the Butch the\n", "Bully, the film's repulsive antagonist.\n", "\n", "As it just so happens, one of the main points of \"Cool Cat Saves the Kids\" is\n", "in addressing the attitudes of mean naysayers who try to demean others who\n", "strive to bring good attitudes and fun vibes into people's lives. The message\n", "to be learned here is that if one is friendly and good to others, the world is\n", "friendly and good to one in return, and that is cool. Conversely, if one is\n", "miserable and leaving 1-star votes on IMDb, one is alone and doesn't have any\n", "friends at all. Ain't that the truth?\n", "\n", "The world has uncovered a great, new, young filmmaking talent in \"Cool Cat\"\n", "creator Derek Savage, and I sure hope that this is only the first of many\n", "amazing films and stories that the world has yet to appreciate.\n", "\n", "If you are a cool person who likes to have lots of fun, I guarantee that this\n", "is a movie with charm that will uplift your spirits and reaffirm your positive\n", "attitudes towards life.\n", "\"\"\"\n", "predict_sentiment(test_text)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.00168: Negative sentiment\n" ] } ], "source": [ "test_text = \"\"\"\n", "What the heck is this ? There is not one redeeming quality about this terrible\n", "and very poorly done \"movie\". I can't even say that it's a \"so bad it's good\n", "movie\".It is undeniably pointless to address all the things wrong here but\n", "unfortunately even the \"life lessons\" about bullies and stuff like this are so\n", "wrong and terrible that no kid should hear them.The costume is also horrible\n", "and the acting...just unbelievable.No effort whatsoever was put into this thing\n", "and it clearly shows,I have no idea what were they thinking or who was it even\n", "meant for. I feel violated after watching this trash and I deeply recommend you\n", "stay as far away as possible.This is certainly one of the worst pieces of c***\n", "I have ever seen.\n", "\"\"\"\n", "predict_sentiment(test_text)" ] } ], "metadata": { "accelerator": "GPU", "colab": { "name": "1_BoW_text_classification.ipynb", "provenance": [], "toc_visible": true, "version": "0.3.2" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" } }, "nbformat": 4, "nbformat_minor": 1 }