{ "cells": [ { "cell_type": "markdown", "id": "78365427-e6e1-4ba3-ad42-69d9c8620a00", "metadata": {}, "source": [ "# DistilBERT Classifier as Feature Extractor Using Embetter" ] }, { "cell_type": "markdown", "id": "2c384b90-6074-4a96-b84d-3f1f29e10b7f", "metadata": {}, "source": [ "In this feature-based approach, we use the embeddings from a pretrained transformer to train a logistic regression model in scikit-learn:\n", "\n", "![](figures/feature-extractor.jpeg)" ] }, { "cell_type": "code", "execution_count": 1, "id": "9eea2bca-b66f-4751-8af9-0c1e7631563e", "metadata": {}, "outputs": [], "source": [ "# pip install transformers datasets" ] }, { "cell_type": "code", "execution_count": 2, "id": "33541bc3-0e07-4808-b6b6-4d7578814ab6", "metadata": {}, "outputs": [], "source": [ "# conda install scikit-learn --yes" ] }, { "cell_type": "markdown", "id": "9933a986-37ae-49a1-9da4-2dd0f54b2d61", "metadata": {}, "source": [ "In addition, we will be using [embetter](https://github.com/koaning/embetter), a scikit-learn-compatible library:" ] }, { "cell_type": "code", "execution_count": 3, "id": "a6ab918a-b0a2-4e39-a316-fa14b1fd3207", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch : 1.12.1\n", "transformers: 4.23.1\n", "datasets : 2.6.1\n", "sklearn : 0.0\n", "\n", "conda environment: dl-fundamentals\n", "\n" ] } ], "source": [ "%load_ext watermark\n", "%watermark --conda -p torch,transformers,datasets,sklearn" ] }, { "cell_type": "code", "execution_count": 4, "id": "1e58e25b-8ef2-4f03-87b2-9fea4728aef3", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "cuda:0\n" ] } ], "source": [ "import torch\n", "\n", "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", "print(device)" ] }, { "cell_type": "markdown", "id": "ecf02113-10e8-41a8-b7cb-029d24cb0593", "metadata": { "tags": [] }, "source": [ "# 1 Loading the Dataset" ] }, { "cell_type": "markdown", "id": 
"1f871d7d-d973-4f34-a014-5961e173280a", "metadata": {}, "source": [ "The IMDB movie review dataset consists of 50k movie reviews with sentiment label (0: negative, 1: positive)." ] }, { "cell_type": "markdown", "id": "31944168-80db-47ed-95d1-c40459e16343", "metadata": {}, "source": [ "## 1a) Load from `datasets` Hub" ] }, { "cell_type": "code", "execution_count": 5, "id": "53160e55-e19f-40b9-bfe3-12e15bc2835b", "metadata": {}, "outputs": [], "source": [ "from datasets import list_datasets, load_dataset" ] }, { "cell_type": "code", "execution_count": 6, "id": "4db26996-6616-4872-a894-448c9669c1e4", "metadata": {}, "outputs": [], "source": [ "# list_datasets()" ] }, { "cell_type": "code", "execution_count": 7, "id": "eef7f55e-f6a3-40e7-85ff-270ef11aca70", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Found cached dataset imdb (/home/raschka/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "8992eba4dc384426a34944a7f102123f", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/3 [00:00
The acting is horrendous... serious amateur hour. Throughout the movie I thought that it was interesting that they found someone who speaks and looks like Michael Madsen, only to find out that it is actually him! A new low even for him!!

The plot is terrible. People who claim that it is original or good have probably never seen a decent movie before. Even by the standard of Hollywood action flicks, this is a terrible movie.

Don't watch it!!! Go for a jog instead - at least you won't feel like killing yourself.\",\n", " 'label': 0}" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "imdb_data[\"train\"][99]" ] }, { "cell_type": "markdown", "id": "f7439730-f7fc-4b08-b7e6-b32ff72c9723", "metadata": { "tags": [] }, "source": [ "## 1b) Load from local directory" ] }, { "cell_type": "markdown", "id": "ca5f590b-7e8b-492a-b937-33f3ea49036c", "metadata": {}, "source": [ "The IMDB movie review set can be downloaded from http://ai.stanford.edu/~amaas/data/sentiment/. After downloading the dataset, decompress the files.\n", "\n", "A) If you are working with Linux or macOS, open a new terminal window, `cd` into the download directory, and execute\n", "\n", "    tar -zxf aclImdb_v1.tar.gz\n", "\n", "B) If you are working with Windows, download an archiver such as 7-Zip to extract the files from the download archive." ] }, { "cell_type": "markdown", "id": "07f3cf11-550d-4544-b9e1-9db656a665a3", "metadata": {}, "source": [ "C) Use the following code to download and unzip the dataset via Python" ] }, { "cell_type": "markdown", "id": "980e46a7-cfa0-420e-820d-c1425de5ca55", "metadata": {}, "source": [ "**Download the movie reviews**" ] }, { "cell_type": "code", "execution_count": 9, "id": "1c883c5d-27ee-4eac-bbd5-c31c70e0aaea", "metadata": {}, "outputs": [], "source": [ "import os\n", "import sys\n", "import tarfile\n", "import time\n", "import urllib.request\n", "\n", "source = \"http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz\"\n", "target = \"aclImdb_v1.tar.gz\"\n", "\n", "if os.path.exists(target):\n", "    os.remove(target)\n", "\n", "\n", "def reporthook(count, block_size, total_size):\n", "    global start_time\n", "    if count == 0:\n", "        start_time = time.time()\n", "        return\n", "    duration = time.time() - start_time\n", "    progress_size = int(count * block_size)\n", "    speed = progress_size / (1024.0**2 * duration)\n", "    percent = count * block_size * 
100.0 / total_size\n", "\n", "    sys.stdout.write(\n", "        f\"\\r{int(percent)}% | {progress_size / (1024.**2):.2f} MB \"\n", "        f\"| {speed:.2f} MB/s | {duration:.2f} sec elapsed\"\n", "    )\n", "    sys.stdout.flush()\n", "\n", "\n", "if not os.path.isdir(\"aclImdb\") and not os.path.isfile(\"aclImdb_v1.tar.gz\"):\n", "    urllib.request.urlretrieve(source, target, reporthook)" ] }, { "cell_type": "code", "execution_count": 10, "id": "be9ad2f8-7cb2-4c4f-937d-7d3f4315fbae", "metadata": {}, "outputs": [], "source": [ "if not os.path.isdir(\"aclImdb\"):\n", "\n", "    with tarfile.open(target, \"r:gz\") as tar:\n", "        tar.extractall()" ] }, { "cell_type": "markdown", "id": "75b2c486-1a19-4ddb-b073-6958050e525b", "metadata": {}, "source": [ "**Convert them to a pandas DataFrame and save them as CSV**" ] }, { "cell_type": "code", "execution_count": 11, "id": "bd725b35-2044-4f0b-9516-046f8690dc20", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|███████████████████████████████████████████████████████| 50000/50000 [00:55<00:00, 893.83it/s]\n" ] } ], "source": [ "import os\n", "import sys\n", "\n", "import numpy as np\n", "import pandas as pd\n", "from packaging import version\n", "from tqdm import tqdm\n", "\n", "# change the `basepath` to the directory of the\n", "# unzipped movie dataset\n", "\n", "basepath = \"aclImdb\"\n", "\n", "labels = {\"pos\": 1, \"neg\": 0}\n", "\n", "df = pd.DataFrame()\n", "\n", "with tqdm(total=50000) as pbar:\n", "    for s in (\"test\", \"train\"):\n", "        for l in (\"pos\", \"neg\"):\n", "            path = os.path.join(basepath, s, l)\n", "            for file in sorted(os.listdir(path)):\n", "                with open(os.path.join(path, file), \"r\", encoding=\"utf-8\") as infile:\n", "                    txt = infile.read()\n", "\n", "                if version.parse(pd.__version__) >= version.parse(\"1.3.2\"):\n", "                    x = pd.DataFrame(\n", "                        [[txt, labels[l]]], columns=[\"review\", \"sentiment\"]\n", "                    )\n", "                    df = pd.concat([df, x], ignore_index=False)\n", "\n", "                else:\n", "                    df = 
df.append([[txt, labels[l]]], ignore_index=True)\n", "                pbar.update()\n", "df.columns = [\"text\", \"label\"]" ] }, { "cell_type": "code", "execution_count": 12, "id": "f367f045-0da2-495c-8a24-115863925a15", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "np.random.seed(0)\n", "df = df.reindex(np.random.permutation(df.index))" ] }, { "cell_type": "markdown", "id": "e163e017-9674-4378-8454-69b8912a09f4", "metadata": {}, "source": [ "**Basic dataset analysis and sanity checks**" ] }, { "cell_type": "code", "execution_count": 13, "id": "2dd205ce-be31-410c-9d6b-d39895a9c634", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Class distribution:\n" ] }, { "data": { "text/plain": [ "array([25000, 25000])" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "print(\"Class distribution:\")\n", "np.bincount(df[\"label\"].values)" ] }, { "cell_type": "code", "execution_count": 14, "id": "93ea4c92-1e3c-497c-ae0e-e114437fd80c", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(4, 173.0, 2470)" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "text_len = df[\"text\"].apply(lambda x: len(x.split()))\n", "text_len.min(), text_len.median(), text_len.max()" ] }, { "cell_type": "markdown", "id": "9a95465c-35d6-4638-a3c1-1d8dfb3ff60b", "metadata": {}, "source": [ "**Split data into training, validation, and test sets**" ] }, { "cell_type": "code", "execution_count": 15, "id": "cd5f5326-8062-407f-8679-d701d9e2c169", "metadata": {}, "outputs": [], "source": [ "df_shuffled = df.sample(frac=1, random_state=1).reset_index()\n", "\n", "df_train = df_shuffled.iloc[:35_000]\n", "df_val = df_shuffled.iloc[35_000:40_000]\n", "df_test = df_shuffled.iloc[40_000:]\n", "\n", "df_train.to_csv(\"train.csv\", index=False, encoding=\"utf-8\")\n", "df_val.to_csv(\"validation.csv\", index=False, encoding=\"utf-8\")\n", "df_test.to_csv(\"test.csv\", 
index=False, encoding=\"utf-8\")" ] }, { "cell_type": "code", "execution_count": 16, "id": "d4cc8710-e962-4513-ad70-01ddbe535e0a", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
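<div>\n", "<table border=\"1\" class=\"dataframe\">\n", "  <thead>\n", "    <tr style=\"text-align: right;\">\n", "      <th></th>\n", "      <th>index</th>\n", "      <th>text</th>\n", "      <th>label</th>\n", "    </tr>\n", "  </thead>\n", "  <tbody>\n",
"    <tr>\n", "      <th>0</th>\n", "      <td>0</td>\n", "      <td>When we started watching this series on cable,...</td>\n", "      <td>1</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>1</th>\n", "      <td>0</td>\n", "      <td>Steve Biko was a black activist who tried to r...</td>\n", "      <td>1</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>2</th>\n", "      <td>0</td>\n", "      <td>My short comment for this flick is go pick it ...</td>\n", "      <td>1</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>3</th>\n", "      <td>0</td>\n", "      <td>As a serious horror fan, I get that certain ma...</td>\n", "      <td>0</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>4</th>\n", "      <td>0</td>\n", "      <td>Robert Cummings, Laraine Day and Jean Muir sta...</td>\n", "      <td>1</td>\n", "    </tr>\n",
"  </tbody>\n", "</table>\n", "</div>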
" ], "text/plain": [ " index text label\n", "0 0 When we started watching this series on cable,... 1\n", "1 0 Steve Biko was a black activist who tried to r... 1\n", "2 0 My short comment for this flick is go pick it ... 1\n", "3 0 As a serious horror fan, I get that certain ma... 0\n", "4 0 Robert Cummings, Laraine Day and Jean Muir sta... 1" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_train.head()" ] }, { "cell_type": "markdown", "id": "b9185130-6530-44aa-8579-1505227e1be3", "metadata": {}, "source": [ "# 2 Train Model on Embeddings (Extracted Features)" ] }, { "cell_type": "code", "execution_count": 17, "id": "0c7819bb-6778-4c47-a061-e99c9542385c", "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "from sklearn.pipeline import make_pipeline \n", "from sklearn.linear_model import LogisticRegression\n", "\n", "from embetter.text import SentenceEncoder\n", "\n", "classifier = make_pipeline(\n", " SentenceEncoder(\"distiluse-base-multilingual-cased-v2\"),\n", " LogisticRegression()\n", ")\n", "\n", "classifier.fit(df_train[\"text\"].values, df_train[\"label\"].values);" ] }, { "cell_type": "code", "execution_count": 18, "id": "d569d831-ea20-4cb1-a454-34e88719a6af", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.8" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "classifier.score(df_val[\"text\"].values, df_val[\"label\"].values)" ] }, { "cell_type": "code", "execution_count": 19, "id": "09a31990-004c-423f-ab3e-79720e491c4c", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.8032" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "classifier.score(df_test[\"text\"].values, df_test[\"label\"].values)" ] }, { "cell_type": "code", "execution_count": null, "id": "b095a33d-6a25-4694-9aab-b9e99fac2ca9", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { 
"display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.7" } }, "nbformat": 4, "nbformat_minor": 5 }