{ "cells": [ { "cell_type": "markdown", "source": [ "# Building a classifier with DistilBert\n", "This notebook contains the code used to create the movie_overview_classification model. The model accepts an overview of a movie and returns a prediction of whether the movie will pass the Bechdel test. On its own it achieves an F1 score of only 0.77, but it will be used as one component of a larger ensemble algorithm." ], "metadata": { "collapsed": false }, "id": "2cd645ab5faa1775" }, { "cell_type": "markdown", "source": [ "## Imports and Data" ], "metadata": { "collapsed": false }, "id": "93a1ebe96b0c37cb" }, { "cell_type": "code", "execution_count": 33, "outputs": [], "source": [ "import pandas as pd\n", "import warnings\n", "from sklearn.model_selection import train_test_split\n", "import numpy as np\n", "warnings.filterwarnings(\"ignore\")\n", "import datasets\n", "from transformers import DistilBertTokenizer, DistilBertForSequenceClassification, DataCollatorWithPadding\n", "from datasets import load_metric\n" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-07-23T18:26:24.518079800Z", "start_time": "2024-07-23T18:26:24.514056300Z" } }, "id": "61b223d9fc28303d" }, { "cell_type": "code", "execution_count": 2, "outputs": [], "source": [ "import BechdelDataImporter as data\n", "df = data.NoScripts()" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-07-23T16:20:18.413548400Z", "start_time": "2024-07-23T16:20:18.183977700Z" } }, "id": "f39f162e2864f877" }, { "cell_type": "markdown", "source": [ "## Text Cleaning\n", "First, instantiate a tokenizer, data collator, and model:" ], "metadata": { "collapsed": false }, "id": "d4e49ecdb60e42c0" }, { "cell_type": "code", "execution_count": 3, "id": "initial_id", "metadata": { "collapsed": true, "ExecuteTime": { "end_time": "2024-07-23T16:20:20.917109Z", "start_time": "2024-07-23T16:20:19.948053500Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] } ], "source": [ "tokenizer = DistilBertTokenizer.from_pretrained(\"distilbert-base-uncased\")\n", "model = DistilBertForSequenceClassification.from_pretrained(\"distilbert-base-uncased\", num_labels=2)\n", "data_collator = DataCollatorWithPadding(tokenizer=tokenizer)" ] }, { "cell_type": "markdown", "source": [ "- Drop duplicate and NA rows from the data\n", "- Tokenize the overviews\n", "- Convert the labels from the four-category Bechdel rating to a binary pass/fail label" ], "metadata": { "collapsed": false }, "id": "4758f34ceb7509f9" }, { "cell_type": "code", "execution_count": 4, "outputs": [], "source": [ "df['overview_tokenized'] = pd.Series(dtype='object')\n", "df['label'] = pd.Series(dtype='object')\n", "df = df.drop_duplicates(subset=['overview']).dropna(subset=['overview'])\n", "for i in df.index:\n", "    # .at avoids the chained-assignment pitfalls of df['col'][i] = ...\n", "    df.at[i, 'overview_tokenized'] = tokenizer(df.loc[i, 'overview'], return_tensors=\"pt\")\n", "    # a movie passes the Bechdel test only when its rating is 3\n", "    df.at[i, 'label'] = 1 if df.loc[i, 'bechdel_rating'] == 3 else 0" ], "metadata": { "collapsed": false,
"ExecuteTime": { "end_time": "2024-07-23T16:20:32.787565800Z", "start_time": "2024-07-23T16:20:21.777045300Z" } }, "id": "9557e1ea9bd67821" }, { "cell_type": "markdown", "source": [ "Split off a test set:" ], "metadata": { "collapsed": false }, "id": "682080881db3f13" }, { "cell_type": "code", "execution_count": 6, "outputs": [], "source": [ "X_train, X_test, y_train, y_test = train_test_split(df[['overview', 'overview_tokenized']], df['label'], test_size=0.2, random_state=42)" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-07-23T16:20:40.513637Z", "start_time": "2024-07-23T16:20:40.487424500Z" } }, "id": "eacbec3d3fbde8c6" }, { "cell_type": "markdown", "source": [ "## Connecting to Hugging Face" ], "metadata": { "collapsed": false }, "id": "ee02108df36daa41" },
{
"cell_type": "code",
"execution_count": 16,
"outputs": [
{
"data": {
"text/plain": "| Step | Training Loss |\n|------|---------------|\n"
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": "TrainOutput(global_step=1010, training_loss=0.5062790998137824, metrics={'train_runtime': 5485.4763, 'train_samples_per_second': 2.944, 'train_steps_per_second': 0.184, 'total_flos': 545236330318872.0, 'train_loss': 0.5062790998137824, 'epoch': 2.0})"
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"trainer.train()"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-07-23T18:01:30.248929900Z",
"start_time": "2024-07-23T16:23:44.701626400Z"
}
},
"id": "9361a3c143040d03"
},
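{
"cell_type": "markdown",
"source": [
"The next cell previews the held-out examples with a `preds` column whose entries look like `transformers` text-classification pipeline outputs. The notebook's own evaluation code is not reproduced in this section, so the cell below is a minimal illustrative sketch of how such predictions and the F1 score of roughly 0.77 quoted in the introduction could be computed from the fine-tuned `model`, the `tokenizer`, and the `X_test`/`y_test` split above; `LABEL_1` corresponds to a Bechdel pass, matching the labelling loop earlier."
],
"metadata": { "collapsed": false },
"id": "evaluation-sketch-md"
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# Illustrative sketch: predict on the held-out overviews and compute the F1 score.\n",
"from transformers import pipeline\n",
"from sklearn.metrics import f1_score\n",
"\n",
"clf = pipeline('text-classification', model=model, tokenizer=tokenizer)\n",
"raw_preds = clf(X_test['overview'].tolist(), truncation=True)\n",
"# each prediction looks like {'label': 'LABEL_1', 'score': 0.89}; map LABEL_0/LABEL_1 to 0/1\n",
"pred_labels = [int(p['label'].split('_')[-1]) for p in raw_preds]\n",
"print('F1:', f1_score(y_test.astype(int).tolist(), pred_labels))"
],
"metadata": { "collapsed": false },
"id": "evaluation-sketch-code"
},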
{
"cell_type": "code",
"execution_count": 17,
"outputs": [
{
"data": {
"text/plain": "      text                                            label  preds\n0     A young and devoted morning television produce...      1  [{'label': 'LABEL_1', 'score': 0.896...}]\n1     Don Birnam, a long-time alcoholic, has been so...      0  [{'label': 'LABEL_0', 'score': 0.767...}]\n2     One peaceful day on Earth, two remnants of Fri...      0  [{'label': 'LABEL_0', 'score': 0.658...}]\n3     Dominic Toretto and his crew battle the most s...      1  [{'label': 'LABEL_0', 'score': 0.807...}]\n4     The Martins family are optimistic dreamers, qu...      1  [{'label': 'LABEL_1', 'score': 0.891...}]\n...                                                      ...    ...\n2014  Seven short films - each one focused on the pl...      1  [{'label': 'LABEL_1', 'score': 0.628...}]\n2015  After an unprecedented series of natural disas...      1  [{'label': 'LABEL_1', 'score': 0.602...}]\n2016  Girl Lost tackles the issue of underage prosti...      1  [{'label': 'LABEL_1', 'score': 0.967...}]\n2017  Loosely based on the true-life tale of Ron Woo...      1  [{'label': 'LABEL_1', 'score': 0.921...}]\n2018  A young black pianist becomes embroiled in the...      0  [{'label': 'LABEL_1', 'score': 0.947...}]\n\n[2019 rows x 6 columns: text, input_ids, attention_mask, label, __index_level_0__, preds; input_ids, attention_mask and __index_level_0__ omitted above]