{ "cells": [ { "cell_type": "markdown", "source": [ "# Building a classifier with DistilBERT\n", "This notebook builds the movie_overview_classification model. Given a movie's overview, the model predicts whether the movie passes the Bechdel test. On its own it reaches an F1 score of only 0.77 (accuracy 0.74) on the held-out test set, but it is meant to serve as one component of a larger ensemble." ], "metadata": { "collapsed": false }, "id": "2cd645ab5faa1775" }, { "cell_type": "markdown", "source": [ "## Imports and Data" ], "metadata": { "collapsed": false }, "id": "93a1ebe96b0c37cb" }, { "cell_type": "code", "execution_count": 33, "outputs": [], "source": [ "import warnings\n", "\n", "import numpy as np\n", "import pandas as pd\n", "import datasets\n", "from datasets import load_metric\n", "from sklearn.model_selection import train_test_split\n", "from transformers import DistilBertTokenizer, DistilBertForSequenceClassification, DataCollatorWithPadding\n", "\n", "warnings.filterwarnings(\"ignore\")" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-07-23T18:26:24.518079800Z", "start_time": "2024-07-23T18:26:24.514056300Z" } }, "id": "61b223d9fc28303d" }, { "cell_type": "code", "execution_count": 2, "outputs": [], "source": [ "import BechdelDataImporter as data\n", "df = data.NoScripts()" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-07-23T16:20:18.413548400Z", "start_time": "2024-07-23T16:20:18.183977700Z" } }, "id": "f39f162e2864f877" }, { "cell_type": "markdown", "source": [ "## Text Cleaning\n", "First, instantiate a tokenizer, data collator, and model:" ], "metadata": { "collapsed": false }, "id": "d4e49ecdb60e42c0" }, { "cell_type": "code", "execution_count": 3, "id": "initial_id", "metadata": { "collapsed": true, "ExecuteTime": { "end_time": "2024-07-23T16:20:20.917109Z", "start_time": "2024-07-23T16:20:19.948053500Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ 
"Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] } ], "source": [ "tokenizer = DistilBertTokenizer.from_pretrained(\"distilbert-base-uncased\")\n", "model = DistilBertForSequenceClassification.from_pretrained(\"distilbert-base-uncased\", num_labels=2)\n", "data_collator = DataCollatorWithPadding(tokenizer=tokenizer)" ] }, { "cell_type": "markdown", "source": [ "- Drop rows with a duplicate or missing overview\n", "- Tokenize the overviews\n", "- Collapse the 4-level Bechdel rating (0 to 3) into a pass/fail label: a rating of 3 is a pass (1), anything lower is a fail (0)" ], "metadata": { "collapsed": false }, "id": "4758f34ceb7509f9" }, { "cell_type": "code", "execution_count": 4, "outputs": [], "source": [ "# Drop rows whose overview is duplicated or missing\n", "df = df.drop_duplicates(subset=['overview']).dropna(subset=['overview']).copy()\n", "\n", "# Tokenize each overview, keeping one PyTorch-tensor encoding per row\n", "df['overview_tokenized'] = df['overview'].apply(lambda text: tokenizer(text, return_tensors=\"pt\"))\n", "\n", "# Label 1 = passes the Bechdel test (rating 3), 0 = fails (ratings 0 to 2)\n", "df['label'] = (df['bechdel_rating'] == 3).astype(int)" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-07-23T16:20:32.787565800Z", "start_time": "2024-07-23T16:20:21.777045300Z" } }, "id": "9557e1ea9bd67821" }, { "cell_type": "markdown", "source": [ "Split off a test set:" ], "metadata": { "collapsed": false }, "id": "682080881db3f13" }, { "cell_type": "code", "execution_count": 6, "outputs": [], "source": [ "X_train, X_test, y_train, y_test = train_test_split(df[['overview', 'overview_tokenized']], df['label'], test_size=0.2, random_state=42)" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-07-23T16:20:40.513637Z", "start_time": "2024-07-23T16:20:40.487424500Z" } }, "id": "eacbec3d3fbde8c6" }, { "cell_type": "markdown", "source": [ "## Connecting to Hugging Face" ], "metadata": { "collapsed": false }, "id": "ee02108df36daa41" }, { "cell_type": "code", "execution_count": 13, "outputs": [ { "data": { "text/plain": "VBox(children=(HTML(value='
datasets.Dataset: \n", "    # Work on an explicit copy to avoid SettingWithCopyWarning when adding columns\n", "    X = X.copy()\n", "    # Unpack each stored encoding into plain Python lists of token ids and mask values\n", "    X['input_ids'] = X['overview_tokenized'].apply(lambda enc: enc.input_ids[0].tolist())\n", "    X['attention_mask'] = X['overview_tokenized'].apply(lambda enc: enc.attention_mask[0].tolist())\n", "    # Attach the labels, drop the tensor column, and rename 'overview' to 'text'\n", "    return datasets.Dataset.from_pandas(X.join(y).drop(columns=['overview_tokenized']).rename(columns={'overview':'text'}))\n", "\n", "train_df, test_df = processing(X_train, y_train), processing(X_test, y_test)" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-07-23T16:23:08.144192500Z", "start_time": "2024-07-23T16:23:05.853498400Z" } }, "id": "86ed4759f5c016d" }, { "cell_type": "markdown", "source": [ "## Training the Model" ], "metadata": { "collapsed": false }, "id": "2e85a93d14c1870e" }, { "cell_type": "code", "execution_count": 15, "outputs": [], "source": [ "from transformers import TrainingArguments, Trainer\n", "\n", "repo_name = \"movie_overview_classification\"\n", "\n", "training_args = TrainingArguments(\n", "    output_dir=repo_name,\n", "    learning_rate=2e-5,\n", "    per_device_train_batch_size=16,\n", "    per_device_eval_batch_size=16,\n", "    num_train_epochs=2,\n", "    weight_decay=0.01,\n", "    save_strategy=\"epoch\",\n", "    push_to_hub=True\n", ")\n", "\n", "trainer = Trainer(\n", "    model=model,\n", "    args=training_args,\n", "    train_dataset=train_df,\n", "    eval_dataset=test_df,\n", "    tokenizer=tokenizer,\n", "    data_collator=data_collator,\n", "    compute_metrics=compute_metrics\n", ")" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-07-23T16:23:20.952560300Z", "start_time": "2024-07-23T16:23:17.185501700Z" } }, "id": "a50a809b77860b4" }, { "cell_type": "code", "execution_count": 16, "outputs": [ { "data": { "text/plain": "", "text/html": "\n
" }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": "TrainOutput(global_step=1010, training_loss=0.5062790998137824, metrics={'train_runtime': 5485.4763, 'train_samples_per_second': 2.944, 'train_steps_per_second': 0.184, 'total_flos': 545236330318872.0, 'train_loss': 0.5062790998137824, 'epoch': 2.0})" }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainer.train()" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-07-23T18:01:30.248929900Z", "start_time": "2024-07-23T16:23:44.701626400Z" } }, "id": "9361a3c143040d03" }, { "cell_type": "code", "execution_count": 17, "outputs": [ { "data": { "text/plain": "", "text/html": "\n

\n " }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": "{'eval_loss': 0.5222412347793579,\n 'eval_accuracy': 0.7439326399207529,\n 'eval_f1': 0.7701200533570476,\n 'eval_runtime': 272.8653,\n 'eval_samples_per_second': 7.399,\n 'eval_steps_per_second': 0.465,\n 'epoch': 2.0}" }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainer.evaluate()" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-07-23T18:06:03.246574600Z", "start_time": "2024-07-23T18:01:30.243814700Z" } }, "id": "f183c0cdcc65ce21" }, { "cell_type": "markdown", "source": [ "Push the model to the Hugging Face Hub:" ], "metadata": { "collapsed": false }, "id": "f6f0609cd15e5cc2" }, { "cell_type": "code", "execution_count": 18, "outputs": [ { "data": { "text/plain": "events.out.tfevents.1721757963.Marks_Laptop.61328.1: 0%| | 0.00/457 [00:00
     | text | input_ids | attention_mask | label | __index_level_0__ | preds
0 | A young and devoted morning television produce... | [101, 1037, 2402, 1998, 7422, 2851, 2547, 3135... | [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... | 1 | 6240 | [{'label': 'LABEL_1', 'score': 0.8961271643638...
1 | Don Birnam, a long-time alcoholic, has been so... | [101, 2123, 12170, 12789, 2213, 1010, 1037, 21... | [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... | 0 | 574 | [{'label': 'LABEL_0', 'score': 0.7674303054809...
2 | One peaceful day on Earth, two remnants of Fri... | [101, 2028, 9379, 2154, 2006, 3011, 1010, 2048... | [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... | 0 | 8304 | [{'label': 'LABEL_0', 'score': 0.6576490402221...
3 | Dominic Toretto and his crew battle the most s... | [101, 11282, 9538, 9284, 1998, 2010, 3626, 264... | [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... | 1 | 9737 | [{'label': 'LABEL_0', 'score': 0.8074591755867...
4 | The Martins family are optimistic dreamers, qu... | [101, 1996, 19953, 2155, 2024, 21931, 24726, 2... | [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... | 1 | 10031 | [{'label': 'LABEL_1', 'score': 0.8910151124000...
... | ... | ... | ... | ... | ... | ...
2014 | Seven short films - each one focused on the pl... | [101, 2698, 2460, 3152, 1011, 2169, 2028, 4208... | [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... | 1 | 4951 | [{'label': 'LABEL_1', 'score': 0.6277556419372...
2015 | After an unprecedented series of natural disas... | [101, 2044, 2019, 15741, 2186, 1997, 3019, 186... | [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... | 1 | 8786 | [{'label': 'LABEL_1', 'score': 0.6020154356956...
2016 | Girl Lost tackles the issue of underage prosti... | [101, 2611, 2439, 10455, 1996, 3277, 1997, 210... | [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... | 1 | 8693 | [{'label': 'LABEL_1', 'score': 0.9668087363243...
2017 | Loosely based on the true-life tale of Ron Woo... | [101, 11853, 2241, 2006, 1996, 2995, 1011, 216... | [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... | 1 | 7366 | [{'label': 'LABEL_1', 'score': 0.9210842251777...
2018 | A young black pianist becomes embroiled in the... | [101, 1037, 2402, 2304, 9066, 4150, 7861, 1261... | [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... | 0 | 2045 | [{'label': 'LABEL_1', 'score': 0.9472063183784...
\n

2019 rows × 6 columns

\n" }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test_data" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-07-23T18:20:38.981650800Z", "start_time": "2024-07-23T18:20:38.968907100Z" } }, "id": "de058528ac4dc23c" }, { "cell_type": "code", "execution_count": 27, "outputs": [], "source": [], "metadata": { "collapsed": false }, "id": "732fc0611fabe4fa" }, { "cell_type": "code", "execution_count": null, "outputs": [], "source": [], "metadata": { "collapsed": false }, "id": "21821b4e974dfb78" } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.6" } }, "nbformat": 4, "nbformat_minor": 5 }