{
"cells": [
{
"cell_type": "markdown",
"id": "78365427-e6e1-4ba3-ad42-69d9c8620a00",
"metadata": {},
"source": [
"# DistilBERT Classifier as Feature Extractor Using Embetter"
]
},
{
"cell_type": "markdown",
"id": "2c384b90-6074-4a96-b84d-3f1f29e10b7f",
"metadata": {},
"source": [
"In this feature-based approach, we use the embeddings from a pretrained transformer as input features to train a random forest and a logistic regression model in scikit-learn."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "9eea2bca-b66f-4751-8af9-0c1e7631563e",
"metadata": {},
"outputs": [],
"source": [
"# pip install transformers datasets"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "33541bc3-0e07-4808-b6b6-4d7578814ab6",
"metadata": {},
"outputs": [],
"source": [
"# conda install scikit-learn --yes"
]
},
{
"cell_type": "markdown",
"id": "9933a986-37ae-49a1-9da4-2dd0f54b2d61",
"metadata": {},
"source": [
"In addition, we will be using the [embetter](https://github.com/koaning/embetter) scikit-learn library:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "a6ab918a-b0a2-4e39-a316-fa14b1fd3207",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch : 1.12.1\n",
"transformers: 4.23.1\n",
"datasets : 2.6.1\n",
"sklearn : 0.0\n",
"\n",
"conda environment: dl-fundamentals\n",
"\n"
]
}
],
"source": [
"%load_ext watermark\n",
"%watermark --conda -p torch,transformers,datasets,sklearn"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "1e58e25b-8ef2-4f03-87b2-9fea4728aef3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"cuda:0\n"
]
}
],
"source": [
"import torch\n",
"\n",
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"print(device)"
]
},
{
"cell_type": "markdown",
"id": "ecf02113-10e8-41a8-b7cb-029d24cb0593",
"metadata": {
"tags": []
},
"source": [
"# 1 Loading the Dataset"
]
},
{
"cell_type": "markdown",
"id": "1f871d7d-d973-4f34-a014-5961e173280a",
"metadata": {},
"source": [
"The IMDB movie review dataset consists of 50k movie reviews, each with a sentiment label (0: negative, 1: positive)."
]
},
{
"cell_type": "markdown",
"id": "31944168-80db-47ed-95d1-c40459e16343",
"metadata": {},
"source": [
"## 1a) Load from `datasets` Hub"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "53160e55-e19f-40b9-bfe3-12e15bc2835b",
"metadata": {},
"outputs": [],
"source": [
"from datasets import list_datasets, load_dataset"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "4db26996-6616-4872-a894-448c9669c1e4",
"metadata": {},
"outputs": [],
"source": [
"# list_datasets()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "eef7f55e-f6a3-40e7-85ff-270ef11aca70",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Found cached dataset imdb (/home/raschka/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8992eba4dc384426a34944a7f102123f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/3 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"DatasetDict({\n",
" train: Dataset({\n",
" features: ['text', 'label'],\n",
" num_rows: 25000\n",
" })\n",
" test: Dataset({\n",
" features: ['text', 'label'],\n",
" num_rows: 25000\n",
" })\n",
" unsupervised: Dataset({\n",
" features: ['text', 'label'],\n",
" num_rows: 50000\n",
" })\n",
"})\n"
]
}
],
"source": [
"imdb_data = load_dataset(\"imdb\")\n",
"print(imdb_data)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "2ed3fec2-a35a-4dc0-b863-e0c5c1799654",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'text': \"This film is terrible. You don't really need to read this review further. If you are planning on watching it, suffice to say - don't (unless you are studying how not to make a good movie). The acting is horrendous... serious amateur hour. Throughout the movie I thought that it was interesting that they found someone who speaks and looks like Michael Madsen, only to find out that it is actually him! A new low even for him!! The plot is terrible. People who claim that it is original or good have probably never seen a decent movie before. Even by the standard of Hollywood action flicks, this is a terrible movie. Don't watch it!!! Go for a jog instead - at least you won't feel like killing yourself.\",\n",
" 'label': 0}"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"imdb_data[\"train\"][99]"
]
},
{
"cell_type": "markdown",
"id": "f7439730-f7fc-4b08-b7e6-b32ff72c9723",
"metadata": {
"tags": []
},
"source": [
"## 1b) Load from local directory"
]
},
{
"cell_type": "markdown",
"id": "ca5f590b-7e8b-492a-b937-33f3ea49036c",
"metadata": {},
"source": [
"The IMDB movie review set can be downloaded from http://ai.stanford.edu/~amaas/data/sentiment/. After downloading the dataset, decompress the files.\n",
"\n",
"A) If you are working with Linux or macOS, open a new terminal window, `cd` into the download directory, and execute\n",
"\n",
" tar -zxf aclImdb_v1.tar.gz\n",
"\n",
"B) If you are working with Windows, download an archiver such as 7Zip to extract the files from the download archive."
]
},
{
"cell_type": "markdown",
"id": "07f3cf11-550d-4544-b9e1-9db656a665a3",
"metadata": {},
"source": [
"C) Alternatively, use the following code to download and unzip the dataset via Python:"
]
},
{
"cell_type": "markdown",
"id": "980e46a7-cfa0-420e-820d-c1425de5ca55",
"metadata": {},
"source": [
"**Download the movie reviews**"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "1c883c5d-27ee-4eac-bbd5-c31c70e0aaea",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import sys\n",
"import tarfile\n",
"import time\n",
"import urllib.request\n",
"\n",
"source = \"http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz\"\n",
"target = \"aclImdb_v1.tar.gz\"\n",
"\n",
"if os.path.exists(target):\n",
"    os.remove(target)\n",
"\n",
"\n",
"def reporthook(count, block_size, total_size):\n",
"    global start_time\n",
"    if count == 0:\n",
"        start_time = time.time()\n",
"        return\n",
"    duration = time.time() - start_time\n",
"    progress_size = int(count * block_size)\n",
"    speed = progress_size / (1024.0**2 * duration)\n",
"    percent = count * block_size * 100.0 / total_size\n",
"\n",
"    sys.stdout.write(\n",
"        f\"\\r{int(percent)}% | {progress_size / (1024.**2):.2f} MB \"\n",
"        f\"| {speed:.2f} MB/s | {duration:.2f} sec elapsed\"\n",
"    )\n",
"    sys.stdout.flush()\n",
"\n",
"\n",
"if not os.path.isdir(\"aclImdb\") and not os.path.isfile(\"aclImdb_v1.tar.gz\"):\n",
"    urllib.request.urlretrieve(source, target, reporthook)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "be9ad2f8-7cb2-4c4f-937d-7d3f4315fbae",
"metadata": {},
"outputs": [],
"source": [
"if not os.path.isdir(\"aclImdb\"):\n",
"\n",
"    with tarfile.open(target, \"r:gz\") as tar:\n",
"        tar.extractall()"
]
},
{
"cell_type": "markdown",
"id": "75b2c486-1a19-4ddb-b073-6958050e525b",
"metadata": {},
"source": [
"**Convert them to a pandas DataFrame and save them as CSV**"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "bd725b35-2044-4f0b-9516-046f8690dc20",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|███████████████████████████████████████████████████████| 50000/50000 [00:55<00:00, 893.83it/s]\n"
]
}
],
"source": [
"import os\n",
"import sys\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"from packaging import version\n",
"from tqdm import tqdm\n",
"\n",
"# change the `basepath` to the directory of the\n",
"# unzipped movie dataset\n",
"\n",
"basepath = \"aclImdb\"\n",
"\n",
"labels = {\"pos\": 1, \"neg\": 0}\n",
"\n",
"df = pd.DataFrame()\n",
"\n",
"with tqdm(total=50000) as pbar:\n",
"    for s in (\"test\", \"train\"):\n",
"        for l in (\"pos\", \"neg\"):\n",
"            path = os.path.join(basepath, s, l)\n",
"            for file in sorted(os.listdir(path)):\n",
"                with open(os.path.join(path, file), \"r\", encoding=\"utf-8\") as infile:\n",
"                    txt = infile.read()\n",
"\n",
"                # `DataFrame.append` is deprecated in newer pandas versions,\n",
"                # so use `pd.concat` there instead\n",
"                if version.parse(pd.__version__) >= version.parse(\"1.3.2\"):\n",
"                    x = pd.DataFrame(\n",
"                        [[txt, labels[l]]], columns=[\"review\", \"sentiment\"]\n",
"                    )\n",
"                    # note: ignore_index=False leaves every row with index label 0\n",
"                    df = pd.concat([df, x], ignore_index=False)\n",
"\n",
"                else:\n",
"                    df = df.append([[txt, labels[l]]], ignore_index=True)\n",
"                pbar.update()\n",
"df.columns = [\"text\", \"label\"]"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "f367f045-0da2-495c-8a24-115863925a15",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"np.random.seed(0)\n",
"df = df.reindex(np.random.permutation(df.index))"
]
},
{
"cell_type": "markdown",
"id": "e163e017-9674-4378-8454-69b8912a09f4",
"metadata": {},
"source": [
"**Basic dataset analysis and sanity checks**"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "2dd205ce-be31-410c-9d6b-d39895a9c634",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Class distribution:\n"
]
},
{
"data": {
"text/plain": [
"array([25000, 25000])"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"print(\"Class distribution:\")\n",
"np.bincount(df[\"label\"].values)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "93ea4c92-1e3c-497c-ae0e-e114437fd80c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(4, 173.0, 2470)"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"text_len = df[\"text\"].apply(lambda x: len(x.split()))\n",
"text_len.min(), text_len.median(), text_len.max()"
]
},
{
"cell_type": "markdown",
"id": "9a95465c-35d6-4638-a3c1-1d8dfb3ff60b",
"metadata": {},
"source": [
"**Split data into training, validation, and test sets**"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "cd5f5326-8062-407f-8679-d701d9e2c169",
"metadata": {},
"outputs": [],
"source": [
"df_shuffled = df.sample(frac=1, random_state=1).reset_index()\n",
"\n",
"df_train = df_shuffled.iloc[:35_000]\n",
"df_val = df_shuffled.iloc[35_000:40_000]\n",
"df_test = df_shuffled.iloc[40_000:]\n",
"\n",
"df_train.to_csv(\"train.csv\", index=False, encoding=\"utf-8\")\n",
"df_val.to_csv(\"validation.csv\", index=False, encoding=\"utf-8\")\n",
"df_test.to_csv(\"test.csv\", index=False, encoding=\"utf-8\")"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "d4cc8710-e962-4513-ad70-01ddbe535e0a",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
| \n", " | index | \n", "text | \n", "label | \n", "
|---|---|---|---|
| 0 | \n", "0 | \n", "When we started watching this series on cable,... | \n", "1 | \n", "
| 1 | \n", "0 | \n", "Steve Biko was a black activist who tried to r... | \n", "1 | \n", "
| 2 | \n", "0 | \n", "My short comment for this flick is go pick it ... | \n", "1 | \n", "
| 3 | \n", "0 | \n", "As a serious horror fan, I get that certain ma... | \n", "0 | \n", "
| 4 | \n", "0 | \n", "Robert Cummings, Laraine Day and Jean Muir sta... | \n", "1 | \n", "