{ "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "KthJSHkGQR7Z" }, "source": [ "# Bag-of-Words Text Classification\n", "\n", "In this tutorial we show how to build a simple bag-of-words (BoW) text classifier using PyTorch. The classifier is trained on the IMDB movie reviews dataset." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from pathlib import Path\n", "\n", "import pandas as pd\n", "import torch\n", "import torch.nn.functional as F\n", "import torch.nn as nn\n", "import torch.optim as optim\n", "from google_drive_downloader import GoogleDriveDownloader as gdd\n", "from torch.utils.data import DataLoader, Dataset\n", "from sklearn.feature_extraction.text import CountVectorizer\n", "from tqdm import tqdm, tqdm_notebook" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "device(type='cpu')" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "device" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": {}, "colab_type": "code", "id": "j8-WlORVQR7n" }, "outputs": [], "source": [ "DATA_PATH = 'data/imdb_reviews.csv'\n", "if not Path(DATA_PATH).is_file():\n", "    gdd.download_file_from_google_drive(\n", "        file_id='1zfM5E6HvKIe7f3rEt1V2gBpw5QOSSKQz',\n", "        dest_path=DATA_PATH,\n", "    )" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "| | review | label |\n", "
|---|---|---|\n", "
| 22003 | This movie was excellent from start-to-finish.... | 1 |\n", "
| 25415 | Diana Guzman is an angry young woman. Survivin... | 0 |\n", "
| 51237 | First off, I agree with quite a bit that escap... | 0 |\n", "
| 32410 | \"Seed\" is torture porn...no doubt about it. Bu... | 0 |\n", "
| 35670 | Well, basically, the movie blows! It's Blair W... | 0 |\n", "