# %% [markdown]
# # Predicting Movie Genres from Scripts with Naive Bayes

# %% [markdown]
# ## Imports

# %%
import nltk
import statistics
import pandas as pd
import numpy as np
import string
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
import psycopg2
import warnings
from sklearn.model_selection import train_test_split
import math

# NOTE(review): blanket filter also hides SettingWithCopyWarning and friends;
# consider narrowing to the specific categories that are actually noisy.
warnings.filterwarnings("ignore")

# %% [markdown]
# ## Data
# The methodology for constructing this database can be found in the 'Building a Database' notebook on [github.](https://github.com/mocboch/Movie-Script-Data-Analysis/blob/master/Building%20a%20Database.ipynb)

# %%
# Pull the joined scripts/Bechdel/TMDB rows plus the per-movie genre rows.
# FIX: the original left the cursor/connection open if a query raised; the
# context-managed cursor and try/finally guarantee both are released.
# NOTE(review): credentials are hardcoded -- move to environment variables
# (os.environ / getpass) before sharing this notebook.
conn = psycopg2.connect(dbname='bechdel_test', user='postgres', password='guest')
try:
    with conn.cursor() as cur:
        cur.execute(
            'SELECT * FROM imsdb_scripts '
            'JOIN bechdel_ratings ON imsdb_scripts.imdb_id = bechdel_ratings.imdb_id '
            'JOIN tmdb_data ON tmdb_data.imdb_id = imsdb_scripts.imdb_id;'
        )
        data = pd.DataFrame(cur.fetchall())
        df = data.copy()
        df.set_index(0, inplace=True)  # column 0 is the imdb_id key

        cur.execute(
            'SELECT genre.imdb_id, genre FROM genre '
            'JOIN imsdb_scripts ON imsdb_scripts.imdb_id = genre.imdb_id;'
        )
        genre = pd.DataFrame(cur.fetchall())
finally:
    conn.close()
# %%
# One indicator column per genre; a movie gets 1 in each genre it is tagged with.
for genre_ in genre[1].unique():
    df[genre_] = pd.Series()
# BUG FIX: the original used chained indexing (df[genre][imdb_id] = 1), which
# may assign to a temporary copy and silently drop the write (the warning was
# suppressed by the blanket filter above).  .loc writes on the frame itself.
for _, (movie_id, genre_name) in genre.iterrows():
    df.loc[movie_id, genre_name] = 1
df.rename(columns={0: 'imdb_id',
                   1: 'script_date',
                   2: 'script',
                   3: 'bechdel_id',
                   5: 'title',
                   6: 'release_year',
                   7: 'bechdel_rating',
                   11: 'language',
                   13: 'popularity',
                   14: 'vote_average',
                   15: 'vote_count',
                   16: 'overview'
                   },
          inplace=True)
df.drop(columns=[4, 8, 9, 10, 12], inplace=True)
df.fillna(0, inplace=True)                 # untagged genre cells become 0
df.replace('none', np.nan, inplace=True)   # 'none' scripts become real NaNs

# %% [markdown]
# ## Cleaning the Text
# This function will clean and tokenize each script, eliminating stop words and punctuation.

# %%
def clean_text(text: str) -> list[str]:
    """Lowercase, tokenize, and filter a raw script.

    Parameters
    ----------
    text : str
        Raw script text.

    Returns
    -------
    list[str]
        Tokens with punctuation, English stop words, and common tokenizer
        artifacts ('...', '--', quote marks) removed, in original order.
    """
    # PERF FIX: the original removed tokens in place with list.remove() inside
    # a while loop -- O(n^2) on ~100k-token scripts -- and tested membership
    # against a list.  A frozen set lookup plus one comprehension pass is O(n)
    # and filters exactly the same tokens.
    drop = set(string.punctuation) | set(stopwords.words('english')) | {'...', '--', "''", '``'}
    return [tok for tok in word_tokenize(text.lower()) if tok not in drop]

# %% [markdown]
# A couple of leftover NAs remain in the dataset; otherwise we can go ahead and run the function on the dataset.
# %%
df = df.dropna(subset='script')
df['clean_text'] = [clean_text(text) for text in df['script']]

# %% [markdown]
# ## The UpdateWeights Function
# This function updates the weights of the naive bayes classifier for a single row of data.

# %%
# Genre indicator columns sit between the metadata columns and clean_text;
# presumably columns 11..-1 are exactly those -- TODO confirm if columns change.
genres = list(df.columns[11:-1])

def UpdateWeights(row: pd.Series,
                  weights: dict[str, dict[str, int]],
                  total_words_per_genre: dict[str, int],
                  genres: list[str] = genres) -> dict[str, dict[str, int]]:
    """Accumulate per-genre token counts for one movie, in place.

    Parameters
    ----------
    row : pd.Series
        A training row with 0/1 genre indicator fields and a 'clean_text'
        token list.
    weights : dict[str, dict[str, int]]
        token -> {genre -> count}; mutated in place.
    total_words_per_genre : dict[str, int]
        genre -> total token count; mutated in place.
    genres : list[str]
        All genre names (default bound at definition time).

    Returns
    -------
    dict[str, dict[str, int]]
        The mutated ``weights`` dict.  BUG FIX: the original was annotated as
        returning this dict but returned None; returning it matches the
        annotation and is backward compatible (callers that ignore the return
        value are unaffected).  The malformed ``dict[str: ...]`` annotations
        (slice, not a type pair) are corrected to ``dict[str, ...]``.
    """
    # Genres this movie is tagged with; credit its full token total to each.
    genre_list = [g for g in genres if row[g] == 1]
    for g in genre_list:
        total_words_per_genre[g] += len(row['clean_text'])

    # Count every token occurrence once per tagged genre.
    for token in row['clean_text']:
        if token not in weights:
            weights[token] = dict.fromkeys(genres, 0)
        for g in genre_list:
            weights[token][g] += 1

    return weights

# %% [markdown]
# A couple of duplicates remain in the dataset:

# %%
# FIX: keep the first occurrence of each script.  The original collected the
# duplicate rows' index labels and df.drop()-ed them, which drops *every* row
# carrying a listed label -- including the first occurrence if index labels
# ever repeat.  A boolean mask removes exactly the duplicate rows.
df = df[~df.duplicated(subset='script')]
# %% [markdown]
# ## Splitting Off a Test Set

# %%
X_train, X_test, y_train, y_test = train_test_split(
    df['clean_text'], df.loc[:, 'Drama':'History'], test_size=0.2, random_state=42)
train_df = y_train.join(X_train)

# %% [markdown]
# ## The NaiveBayes Function
# This function initiates the weights variable and updates it for each row in the dataframe.

# %%
def NaiveBayes(df: pd.DataFrame) -> dict[str, dict[str, float]]:
    """Train per-word, per-genre relative frequencies.

    Parameters
    ----------
    df : pd.DataFrame
        Training rows: genre indicator columns plus 'clean_text'.

    Returns
    -------
    dict[str, dict[str, float]]
        token -> {genre -> count(token, genre) / total_words(genre)}.
    """
    total_words_per_genre = dict.fromkeys(genres, 0)
    weights = {}
    # ROBUSTNESS FIX: the original did df.loc[i] per index label, which
    # returns a DataFrame (not a Series) whenever a label repeats and would
    # break UpdateWeights; iterrows yields each row exactly once.
    for _, row in df.iterrows():
        UpdateWeights(row, weights, total_words_per_genre)

    # Normalize counts to frequencies.  FIX: guard genres with zero training
    # words -- the original raised ZeroDivisionError for any genre absent from
    # the training split; such genres now just keep weight 0.0.
    for word_weights in weights.values():
        for genre_name, count in word_weights.items():
            total = total_words_per_genre[genre_name]
            word_weights[genre_name] = count / total if total else 0.0
    return weights

# %%
weights = NaiveBayes(train_df)
# %% [markdown]
# ## The LogWeights Function
# This function returns the natural logarithm of each weight, or -10,000 if the weight is 0.

# %%
def LogWeights(weights: dict[str, dict[str, float]]):
    """Replace every weight with its natural logarithm, in place.

    A zero weight (word never seen in a genre) has no finite log, so it is
    clamped to -10000 -- a large negative penalty that effectively rules the
    genre out without producing -inf.
    """
    for genre_weights in weights.values():
        for genre_name, value in genre_weights.items():
            genre_weights[genre_name] = math.log(value) if value != 0 else -10000

# %%
LogWeights(weights)
# %% [markdown]
# ## The Feature Function and Score Functions
# These functions return the feature function and prediction scores for a script. The n highest scoring genres will be considered the model's predictions, where n is the number of genres listed for the movie.

# %%
def FeatureFunction(tokens: list[str]) -> list[tuple[str, int]]:
    """Bag-of-words features: one (token, occurrence count) pair per distinct token.

    PERF FIX: the original called tokens.count(token) once per distinct token,
    i.e. O(n * distinct) over a full script; Counter builds every count in one
    O(n) pass.  Pair order was never guaranteed (the original iterated a set).
    """
    from collections import Counter
    return list(Counter(tokens).items())

# %%
def Score(script: list[str],
          weights: dict[str, dict[str, float]] = weights,
          genres: list[str] = genres) -> dict[str, float]:
    """Sum log-weight contributions for each genre over a script.

    Parameters
    ----------
    script : list[str]
        Cleaned token list for one movie.
    weights : dict[str, dict[str, float]]
        Log-weights from LogWeights (default binds the trained dict at
        definition time, as in the original).
    genres : list[str]
        Genres to score.

    Returns
    -------
    dict[str, float]
        genre -> score; higher means a better fit.  Words absent from the
        training vocabulary contribute nothing.
    """
    score = dict.fromkeys(genres, 0)
    for word, count in FeatureFunction(script):
        # PERF FIX: the membership test does not depend on the genre, so it is
        # hoisted out of the inner loop (the original re-tested it per genre).
        if word in weights:
            word_weights = weights[word]
            for genre_name in score:
                score[genre_name] += word_weights[genre_name] * count
    return score

# %% [markdown]
# For the first entry, index #349903, Crime and Thriller are listed as the movie's genres. The Score function scores those two genres highest by an order of magnitude!
\n100405 0 1 0 0 0 0 0 \n110632 0 0 0 0 0 0 0 \n448157 0 0 0 0 0 0 1 \n1441326 0 0 0 0 0 0 0 \n109830 0 1 0 0 0 0 0 \n\n Animation History clean_text \n0 \n349903 0 0 [ocean, 's, twelve, written, george, nolfi, ro... \n43014 0 0 [sunset, boulevard, charles, brackett, billy, ... \n86510 0 0 [fire, screenplay, clayton, frohman, ron, shel... \n114369 0 0 [seven, andrew, kevin, walker, january, 27,199... \n758758 0 0 [wild, written, sean, penn, based, book, jon, ... \n... ... ... ... \n100405 0 0 [p, r, e, w, n, jonathan, lawton, stephen, met... \n110632 0 0 [natural, born, killers, written, quentin, tar... \n448157 0 0 [hancock, written, vincent, ngo, vince, gillig... \n1441326 0 0 [martha, marcy, may, marlene, written, sean, d... \n109830 0 0 [forrest, gump, screenplay, eric, roth, based,... \n\n[330 rows x 18 columns]", "text/html": "
| \n | Drama | \nRomance | \nAdventure | \nFantasy | \nFamily | \nMystery | \nCrime | \nThriller | \nWar | \nComedy | \nMusic | \nWestern | \nHorror | \nScience Fiction | \nAction | \nAnimation | \nHistory | \nclean_text | \n
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | \n\n | \n | \n | \n | \n | \n | \n | \n | \n | \n | \n | \n | \n | \n | \n | \n | \n | \n |
| 349903 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n1 | \n1 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n[ocean, 's, twelve, written, george, nolfi, ro... | \n
| 43014 | \n1 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n[sunset, boulevard, charles, brackett, billy, ... | \n
| 86510 | \n1 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n1 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n[fire, screenplay, clayton, frohman, ron, shel... | \n
| 114369 | \n0 | \n0 | \n0 | \n0 | \n0 | \n1 | \n1 | \n1 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n[seven, andrew, kevin, walker, january, 27,199... | \n
| 758758 | \n1 | \n0 | \n1 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n[wild, written, sean, penn, based, book, jon, ... | \n
| ... | \n... | \n... | \n... | \n... | \n... | \n... | \n... | \n... | \n... | \n... | \n... | \n... | \n... | \n... | \n... | \n... | \n... | \n... | \n
| 100405 | \n0 | \n1 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n1 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n[p, r, e, w, n, jonathan, lawton, stephen, met... | \n
| 110632 | \n1 | \n0 | \n0 | \n0 | \n0 | \n0 | \n1 | \n1 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n[natural, born, killers, written, quentin, tar... | \n
| 448157 | \n0 | \n0 | \n0 | \n1 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n1 | \n0 | \n0 | \n[hancock, written, vincent, ngo, vince, gillig... | \n
| 1441326 | \n1 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n1 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n[martha, marcy, may, marlene, written, sean, d... | \n
| 109830 | \n1 | \n1 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n1 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n[forrest, gump, screenplay, eric, roth, based,... | \n
330 rows × 18 columns
\n| \n | Drama | \nRomance | \nAdventure | \nFantasy | \nFamily | \nMystery | \nCrime | \nThriller | \nWar | \nComedy | \nMusic | \nWestern | \nHorror | \nScience Fiction | \nAction | \nAnimation | \nHistory | \nclean_text | \ngenres_listed | \n
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | \n\n | \n | \n | \n | \n | \n | \n | \n | \n | \n | \n | \n | \n | \n | \n | \n | \n | \n | \n |
| 1126590 | \n1 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n[big, eyes, written, scott, alexander, larry, ... | \n1 | \n
| 1655420 | \n1 | \n1 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n[week, marilyn, written, adrian, hodges, 1, ex... | \n2 | \n
| 1365050 | \n1 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n1 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n[beasts, nation, written, cary, joji, fukunaga... | \n2 | \n
| 1067774 | \n0 | \n1 | \n1 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n1 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n[monte, carlo, written, ron, bass, based, nove... | \n3 | \n
| 164052 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n1 | \n0 | \n0 | \n0 | \n0 | \n0 | \n1 | \n1 | \n0 | \n0 | \n[hollow, man, written, andrew, w., marlowe, re... | \n3 | \n
| ... | \n... | \n... | \n... | \n... | \n... | \n... | \n... | \n... | \n... | \n... | \n... | \n... | \n... | \n... | \n... | \n... | \n... | \n... | \n... | \n
| 1201167 | \n1 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n1 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n[funny, people, written, judd, apatow, april, ... | \n2 | \n
| 1027718 | \n1 | \n0 | \n0 | \n0 | \n0 | \n0 | \n1 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n[wall, street, money, never, sleeps, written, ... | \n2 | \n
| 162346 | \n1 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n1 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n[ghost, world, daniel, clowes, terry, zwigoff,... | \n2 | \n
| 824747 | \n1 | \n0 | \n0 | \n0 | \n0 | \n1 | \n1 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n[changeling, true, story, written, j., michael... | \n3 | \n
| 2473602 | \n1 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n1 | \n0 | \n0 | \n0 | \n0 | \n0 | \n0 | \n[get, written, steven, baigelman, jez, butterw... | \n2 | \n
83 rows × 19 columns
\n