{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# Predicting Movie Genres from Scripts with Naive Bayes\n"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "bfbabb8e41b2e21d"
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Imports"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "a93e0ae149b40d37"
  },
  {
   "cell_type": "code",
   "execution_count": 148,
   "id": "initial_id",
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2024-07-21T16:56:50.314173800Z",
     "start_time": "2024-07-21T16:56:50.293081300Z"
    }
   },
   "outputs": [],
   "source": [
    "import nltk\n",
    "import statistics\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import string\n",
    "from nltk.corpus import stopwords\n",
    "from nltk.tokenize import word_tokenize\n",
    "import psycopg2\n",
    "import warnings\n",
    "from sklearn.model_selection import train_test_split\n",
    "import math\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Data\n",
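    "The methodology for constructing this database is described in the 'Building a Database' notebook on [GitHub](https://github.com/mocboch/Movie-Script-Data-Analysis/blob/master/Building%20a%20Database.ipynb). The cell below joins each script to its Bechdel rating and TMDB metadata, and separately loads the (imdb_id, genre) pairs for every script."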
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "87949522dbe1f66e"
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "outputs": [],
   "source": [
    "conn = psycopg2.connect(dbname='bechdel_test', user='postgres', password='guest')\n",
    "cur = conn.cursor()\n",
    "\n",
    "cur.execute('SELECT * FROM imsdb_scripts JOIN bechdel_ratings ON imsdb_scripts.imdb_id = bechdel_ratings.imdb_id JOIN tmdb_data ON tmdb_data.imdb_id = imsdb_scripts.imdb_id;')\n",
    "data = pd.DataFrame(cur.fetchall())\n",
    "df = data.copy()\n",
    "df.set_index(0, inplace=True)\n",
    "\n",
    "cur.execute('SELECT genre.imdb_id, genre FROM genre JOIN imsdb_scripts ON imsdb_scripts.imdb_id = genre.imdb_id;')\n",
    "genre = pd.DataFrame(cur.fetchall())\n",
    "cur.close()\n",
    "conn.close()"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-07-21T00:40:37.610414800Z",
     "start_time": "2024-07-21T00:40:36.902696700Z"
    }
   },
   "id": "71592ed565d416ac"
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "outputs": [],
   "source": [
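    "# One 0/1 indicator column per genre\n",
    "for genre_ in genre[1].unique():\n",
    "    df[genre_] = 0\n",
    "for imdb_id, genre_name in genre.itertuples(index=False):\n",
    "    # .loc avoids chained assignment (df[col][row] = 1), which does not update df under copy-on-write\n",
    "    if imdb_id in df.index:\n",
    "        df.loc[imdb_id, genre_name] = 1\n",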
    "df.rename(columns={0:'imdb_id',\n",
    "                        1:'script_date',\n",
    "                        2:'script',\n",
    "                        3:'bechdel_id',\n",
    "                        5:'title',\n",
    "                        6:'release_year',\n",
    "                        7:'bechdel_rating',\n",
    "                        11:'language',\n",
    "                        13:'popularity',\n",
    "                        14:'vote_average',\n",
    "                        15:'vote_count',\n",
    "                        16:'overview'\n",
    "                        }, \n",
    "               inplace=True)\n",
    "df.drop(columns=[4, 8, 9, 10, 12], inplace=True)\n",
    "df.fillna(0, inplace=True)\n",
    "df.replace('none', np.nan, inplace=True)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-07-21T00:40:38.595796100Z",
     "start_time": "2024-07-21T00:40:38.377646300Z"
    }
   },
   "id": "b1634fe99060fec9"
  },
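  {
   "cell_type": "markdown",
   "source": [
    "The loop above builds a multi-hot encoding: one 0/1 column per genre, because a movie can belong to several genres at once. The sketch below shows the same idea with plain dictionaries; the IDs and genres are made up for illustration and are not rows from the database.\n",
    "\n",
    "```python\n",
    "# Toy (imdb_id, genre) pairs; hypothetical values, not from the database\n",
    "pairs = [(1, 'Crime'), (1, 'Thriller'), (2, 'Drama'), (3, 'Drama'), (3, 'Romance')]\n",
    "\n",
    "# Keep genres and IDs in first-seen order, like genre[1].unique()\n",
    "genres = list(dict.fromkeys(g for _, g in pairs))\n",
    "ids = list(dict.fromkeys(i for i, _ in pairs))\n",
    "\n",
    "# Start every movie at 0 for every genre, then set the listed genres to 1\n",
    "multi_hot = {i: dict.fromkeys(genres, 0) for i in ids}\n",
    "for imdb_id, g in pairs:\n",
    "    multi_hot[imdb_id][g] = 1\n",
    "\n",
    "for imdb_id, row in multi_hot.items():\n",
    "    print(imdb_id, row)\n",
    "```\n",
    "\n",
    "Movie 1 gets a 1 under both Crime and Thriller; a single-label encoding could not represent that."
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "sketch-multi-hot"
  },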
  {
   "cell_type": "markdown",
   "source": [
    "## Cleaning the Text\n",
    "This function lowercases and tokenizes a script, then removes English stop words and punctuation tokens."
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "fb7bd04932339043"
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "outputs": [],
   "source": [
    "def clean_text(text: str) -> list[str]:\n",
    "    tokens = word_tokenize(text.lower())\n",
    "    # A set gives O(1) membership tests; the extra entries are punctuation tokens NLTK emits\n",
    "    stop_tokens = set(string.punctuation) | set(stopwords.words('english')) | {'...', '--', \"''\", '``'}\n",
    "    # Single filtering pass instead of deleting tokens one by one with list.remove\n",
    "    return [token for token in tokens if token not in stop_tokens]"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-07-21T00:40:39.607310700Z",
     "start_time": "2024-07-21T00:40:39.593141700Z"
    }
   },
   "id": "7e5ae395034f5af4"
  },
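  {
   "cell_type": "markdown",
   "source": [
    "Deleting stop tokens in place with a `while` loop and `list.remove` rescans the list on every deletion, and testing membership against a list is a linear scan too. Because every token before position `i` has already been kept, `list.remove` always deletes the token at position `i`, so such a loop gives the same result as one filtering pass against a set. The sketch below checks this on a hand-written token list; the stop list is a small illustrative stand-in for NLTK's.\n",
    "\n",
    "```python\n",
    "def remove_loop(tokens, stop_list):\n",
    "    # Delete in place; advance the index only when a token is kept\n",
    "    tokens = list(tokens)\n",
    "    i = 0\n",
    "    while i < len(tokens):\n",
    "        if tokens[i] in stop_list:\n",
    "            tokens.remove(tokens[i])\n",
    "        else:\n",
    "            i += 1\n",
    "    return tokens\n",
    "\n",
    "def filter_pass(tokens, stop_list):\n",
    "    # One pass with O(1) set lookups\n",
    "    stop_set = set(stop_list)\n",
    "    return [t for t in tokens if t not in stop_set]\n",
    "\n",
    "# Illustrative tokens and stop list, not NLTK output\n",
    "tokens = ['the', 'heist', ',', 'and', 'the', 'vault', '...', 'opens', '.']\n",
    "stop_list = ['the', 'and', ',', '.', '...']\n",
    "\n",
    "print(remove_loop(tokens, stop_list))\n",
    "print(filter_pass(tokens, stop_list))\n",
    "```\n",
    "\n",
    "Both print `['heist', 'vault', 'opens']`; the set-based pass scales linearly with script length, which matters for full screenplays."
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "sketch-clean-filter"
  },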
  {
   "cell_type": "markdown",
   "source": [
    "A few scripts are still missing (stored as the string 'none', now NaN), so those rows are dropped before every remaining script is cleaned. This step took about 12 minutes in the run recorded here."
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "c44a88a7bc6959db"
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "outputs": [],
   "source": [
    "df = df.dropna(subset='script')\n",
    "df['clean_text'] = [clean_text(text) for text in df['script']]"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-07-21T00:52:52.715479900Z",
     "start_time": "2024-07-21T00:40:40.714210300Z"
    }
   },
   "id": "cb345e387cab14ed"
  },
  {
   "cell_type": "markdown",
   "source": [
    "## The UpdateWeights Function\n",
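    "This function updates the Naive Bayes counts for a single movie. Every genre the movie is labeled with is credited with the script's length in `total_words_per_genre`, and every token in the script increments `weights[token][genre]` for each of those genres."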
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "2f174eef2875f5d7"
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "outputs": [],
   "source": [
    "# Genre indicator columns, in the order they were added to df\n",
    "genres = list(genre[1].unique())\n",
    "\n",
    "def UpdateWeights(row: pd.Series,\n",
    "                  weights: dict[str, dict[str, int]],\n",
    "                  total_words_per_genre: dict[str, int],\n",
    "                  genres: list[str] = genres) -> None:\n",
    "    # Genres this movie is labeled with; each is credited with the script's length\n",
    "    genre_list = []\n",
    "    for genre in genres:\n",
    "        if row[genre] == 1:\n",
    "            total_words_per_genre[genre] += len(row['clean_text'])\n",
    "            genre_list.append(genre)\n",
    "\n",
    "    # Count every token occurrence once for each of the movie's genres (weights is updated in place)\n",
    "    for token in row['clean_text']:\n",
    "        if token not in weights:\n",
    "            weights[token] = dict.fromkeys(genres, 0)\n",
    "        for genre in genre_list:\n",
    "            weights[token][genre] += 1"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-07-21T00:52:52.845325100Z",
     "start_time": "2024-07-21T00:52:52.833577200Z"
    }
   },
   "id": "33f83c4dc36e0c03"
  },
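  {
   "cell_type": "markdown",
   "source": [
    "Applied to every training script, `UpdateWeights` accumulates how often each word appears in scripts of each genre, plus the total token count per genre; a script with several genres contributes its words to each of them. A toy sketch of the same bookkeeping with `collections.Counter`, on two made-up scripts:\n",
    "\n",
    "```python\n",
    "from collections import Counter, defaultdict\n",
    "\n",
    "# Hypothetical tokenized scripts and their genre labels\n",
    "training = [\n",
    "    (['gun', 'vault', 'gun', 'cop'], ['Crime', 'Thriller']),\n",
    "    (['kiss', 'rain', 'cop'], ['Romance']),\n",
    "]\n",
    "\n",
    "word_counts = defaultdict(Counter)  # genre -> Counter of word occurrences\n",
    "total_words = Counter()             # genre -> total tokens seen for that genre\n",
    "\n",
    "for tokens, labels in training:\n",
    "    counts = Counter(tokens)\n",
    "    # A multi-genre script contributes all of its tokens to each of its genres\n",
    "    for genre in labels:\n",
    "        word_counts[genre].update(counts)\n",
    "        total_words[genre] += len(tokens)\n",
    "\n",
    "print(word_counts['Crime']['gun'], total_words['Crime'])\n",
    "print(word_counts['Romance']['cop'], total_words['Romance'])\n",
    "```\n",
    "\n",
    "This sketch stores the counts as genre -> word, whereas the notebook's `weights` is keyed word -> genre; the numbers are the same either way."
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "sketch-counting"
  },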
  {
   "cell_type": "markdown",
   "source": [
    "A few duplicate scripts remain in the dataset and are removed here:"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "62b80a9c88d4711d"
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "outputs": [],
   "source": [
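    "# A boolean mask keeps the first copy of each script; dropping by index label would also\n",
    "# remove the first copy whenever the duplicates share the same imdb_id\n",
    "df = df[~df.duplicated(subset='script')]"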
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-07-21T02:10:37.205136700Z",
     "start_time": "2024-07-21T02:10:37.131414900Z"
    }
   },
   "id": "8ce12dd0d9829924"
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Splitting Off a Test Set"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "52ebd0a28f4318ae"
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "outputs": [],
   "source": [
    "X_train, X_test, y_train, y_test = train_test_split(df['clean_text'], df.loc[:,'Drama':'History'], test_size=0.2, random_state=42)\n",
    "train_df = y_train.join(X_train)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-07-21T00:52:52.906396400Z",
     "start_time": "2024-07-21T00:52:52.899837600Z"
    }
   },
   "id": "895321b94195deeb"
  },
  {
   "cell_type": "markdown",
   "source": [
    "## The NaiveBayes Function\n",
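    "This function initializes an empty `weights` dictionary, accumulates word counts over every training script with `UpdateWeights`, and then divides each count by its genre's total word count, so that `weights[word][genre]` estimates `P(word | genre)`."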
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "392ea9cd0c7488b8"
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "outputs": [],
   "source": [
    "def NaiveBayes(df: pd.DataFrame) -> dict[str, dict[str, float]]:\n",
    "    total_words_per_genre = dict.fromkeys(genres, 0)\n",
    "    weights = {}\n",
    "    # Accumulate raw word counts per genre over every training script\n",
    "    for _, row in df.iterrows():\n",
    "        UpdateWeights(row, weights, total_words_per_genre)\n",
    "\n",
    "    # Normalize: count(word, genre) / total words in genre estimates P(word | genre)\n",
    "    for word in weights:\n",
    "        for genre in weights[word]:\n",
    "            weights[word][genre] /= total_words_per_genre[genre]\n",
    "    return weights"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-07-21T00:52:52.923653800Z",
     "start_time": "2024-07-21T00:52:52.906396400Z"
    }
   },
   "id": "fb80f986256f551d"
  },
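  {
   "cell_type": "markdown",
   "source": [
    "Dividing each count by the genre's total token count is the maximum-likelihood estimate used in multinomial Naive Bayes, `P(word | genre) = count(word, genre) / total_words(genre)`. For a fixed genre these estimates sum to 1 over the vocabulary, which makes a handy sanity check. A sketch with hypothetical counts in the same word -> genre layout as `weights`:\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "# Hypothetical raw counts: word -> genre -> count\n",
    "counts = {\n",
    "    'gun': {'Crime': 2, 'Romance': 0},\n",
    "    'cop': {'Crime': 1, 'Romance': 1},\n",
    "    'kiss': {'Crime': 0, 'Romance': 2},\n",
    "    'vault': {'Crime': 1, 'Romance': 0},\n",
    "}\n",
    "total_words = {'Crime': 4, 'Romance': 3}\n",
    "\n",
    "# Maximum-likelihood estimate of P(word | genre)\n",
    "probs = {\n",
    "    word: {genre: c / total_words[genre] for genre, c in per_genre.items()}\n",
    "    for word, per_genre in counts.items()\n",
    "}\n",
    "\n",
    "# Each genre's probabilities should sum to 1 over the vocabulary\n",
    "for genre in total_words:\n",
    "    column_sum = sum(probs[word][genre] for word in probs)\n",
    "    print(genre, math.isclose(column_sum, 1.0))\n",
    "```\n",
    "\n",
    "The same check can be run on the notebook's `weights` right after `NaiveBayes` returns, before `LogWeights` converts the values to logarithms."
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "sketch-normalize"
  },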
  {
   "cell_type": "code",
   "execution_count": 12,
   "outputs": [],
   "source": [
    "weights = NaiveBayes(train_df)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-07-21T00:52:55.126525400Z",
     "start_time": "2024-07-21T00:52:52.914147700Z"
    }
   },
   "id": "a6a5c8953907d78c"
  },
  {
   "cell_type": "markdown",
   "source": [
    "## The LogWeights Function\n",
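    "This function replaces each weight, in place, with its natural logarithm. A weight of 0 (a word never seen in a given genre) has no finite logarithm, so it is set to a large penalty of -10,000 instead."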
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "d2bb0092f5f77545"
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "outputs": [],
   "source": [
    "def LogWeights(weights: dict[str, dict[str, float]]) -> None:\n",
    "    for word in weights.keys():\n",
    "        for genre in weights[word]:\n",
    "            if weights[word][genre] == 0:\n",
    "                weights[word][genre] = -10000\n",
    "            else:\n",
    "                weights[word][genre] = math.log(weights[word][genre])"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-07-21T00:52:55.141207Z",
     "start_time": "2024-07-21T00:52:55.129517Z"
    }
   },
   "id": "24218d5584304bda"
  },
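  {
   "cell_type": "markdown",
   "source": [
    "Working in log space matters because a script contributes thousands of factors to the product of word probabilities, and that product underflows to 0.0 in double precision, which would make every genre tie. Summing logarithms keeps the scores finite and preserves their order. A quick illustration with a made-up per-word probability:\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "p = 0.001       # hypothetical probability of each word under some genre\n",
    "n_words = 400   # far fewer tokens than a real script\n",
    "\n",
    "product = 1.0\n",
    "for _ in range(n_words):\n",
    "    product *= p  # repeated multiplication underflows to 0.0\n",
    "\n",
    "log_score = n_words * math.log(p)  # the same quantity in log space\n",
    "\n",
    "print(product)    # 0.0\n",
    "print(log_score)  # finite and negative\n",
    "```"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "sketch-log-underflow"
  },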
  {
   "cell_type": "code",
   "execution_count": 14,
   "outputs": [],
   "source": [
    "LogWeights(weights)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-07-21T00:52:55.408535700Z",
     "start_time": "2024-07-21T00:52:55.133048700Z"
    }
   },
   "id": "33acddf4cb72bfa2"
  },
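  {
   "cell_type": "markdown",
   "source": [
    "The -10,000 value stands in for log(0), which is negative infinity. A common alternative, and not what this notebook does, is add-one (Laplace) smoothing: `P(word | genre) = (count + 1) / (total + V)`, where V is the vocabulary size, so no estimate is ever exactly zero and every word still sums to a proper distribution per genre. A sketch assuming the raw word -> genre -> count layout (before normalization), with hypothetical numbers; `alpha` generalizes the +1:\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def laplace_log_weights(counts, total_words, alpha=1.0):\n",
    "    # counts: word -> genre -> raw count; total_words: genre -> total tokens\n",
    "    vocab_size = len(counts)\n",
    "    return {\n",
    "        word: {\n",
    "            genre: math.log((c + alpha) / (total_words[genre] + alpha * vocab_size))\n",
    "            for genre, c in per_genre.items()\n",
    "        }\n",
    "        for word, per_genre in counts.items()\n",
    "    }\n",
    "\n",
    "# Hypothetical counts\n",
    "counts = {\n",
    "    'gun': {'Crime': 2, 'Romance': 0},\n",
    "    'kiss': {'Crime': 0, 'Romance': 2},\n",
    "}\n",
    "total_words = {'Crime': 2, 'Romance': 2}\n",
    "\n",
    "log_w = laplace_log_weights(counts, total_words)\n",
    "print(log_w['gun'])  # the unseen (gun, Romance) pair gets a finite penalty\n",
    "```\n",
    "\n",
    "Unlike a fixed -10,000, the smoothed penalty scales with how much text each genre has, so genres with little training data are not punished as harshly for unseen words."
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "sketch-laplace"
  },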
  {
   "cell_type": "markdown",
   "source": [
    "## The FeatureFunction and Score Functions\n",
    "`FeatureFunction` turns a tokenized script into (word, count) pairs, and `Score` adds up count × the genre's log weight for each of those words, for every genre; higher (less negative) scores mean a better fit. The n highest-scoring genres are taken as the model's predictions, where n is the number of genres listed for the movie."
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "5cb6011ff30bd287"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
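    "from collections import Counter\n",
    "\n",
    "def FeatureFunction(tokens: list[str]) -> list[tuple[str, int]]:\n",
    "    # Count each token in a single pass rather than calling tokens.count once per unique token\n",
    "    return list(Counter(tokens).items())"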
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "69e73e6aaf9ddca3"
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "outputs": [],
   "source": [
    "def Score(script: list[str], weights: dict[str, dict[str, float]]=weights, genres: list[str]=genres) -> dict[str, float]:\n",
    "    score = dict.fromkeys(genres, 0.0)\n",
    "    for word, count in FeatureFunction(script):\n",
    "        # Words never seen in training carry no information and are skipped\n",
    "        if word not in weights:\n",
    "            continue\n",
    "        for genre in score:\n",
    "            score[genre] += weights[word][genre] * count\n",
    "    return score"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-07-21T01:28:58.844985500Z",
     "start_time": "2024-07-21T01:28:58.813798600Z"
    }
   },
   "id": "d08313e152c68912"
  },
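  {
   "cell_type": "markdown",
   "source": [
    "`Score` computes the multinomial Naive Bayes log-likelihood of the script under each genre (up to a constant shared by all genres). The full Naive Bayes posterior would also add a log prior, log `P(genre)`, which `Score` leaves out, so every genre is implicitly treated as equally likely. The sketch below scores a toy script both ways and keeps the top n genres; the log weights and priors are hypothetical.\n",
    "\n",
    "```python\n",
    "import math\n",
    "from collections import Counter\n",
    "\n",
    "# Hypothetical log P(word | genre) values\n",
    "log_weights = {\n",
    "    'gun': {'Crime': math.log(0.5), 'Romance': math.log(0.1), 'Drama': math.log(0.2)},\n",
    "    'kiss': {'Crime': math.log(0.05), 'Romance': math.log(0.6), 'Drama': math.log(0.3)},\n",
    "}\n",
    "# Hypothetical genre frequencies in a training set\n",
    "log_prior = {'Crime': math.log(0.2), 'Romance': math.log(0.2), 'Drama': math.log(0.6)}\n",
    "\n",
    "def score(tokens, log_weights, log_prior=None):\n",
    "    genres = next(iter(log_weights.values())).keys()\n",
    "    # Start from the log prior if given, else from 0 (uniform prior)\n",
    "    total = {g: (log_prior[g] if log_prior else 0.0) for g in genres}\n",
    "    for word, count in Counter(tokens).items():\n",
    "        if word in log_weights:  # unseen words are skipped\n",
    "            for g in genres:\n",
    "                total[g] += count * log_weights[word][g]\n",
    "    return total\n",
    "\n",
    "def top_n(scores, n):\n",
    "    # Highest (least negative) log score first\n",
    "    return sorted(scores, key=scores.get, reverse=True)[:n]\n",
    "\n",
    "script = ['gun', 'gun', 'kiss', 'vault']\n",
    "print(top_n(score(script, log_weights), 2))\n",
    "print(top_n(score(script, log_weights, log_prior), 2))\n",
    "```\n",
    "\n",
    "Without the prior the top two are Crime and Drama; with a prior that favours Drama the order flips to Drama and Crime. Whether a prior helps here would have to be measured on the test set."
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "sketch-top-n"
  },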
  {
   "cell_type": "markdown",
   "source": [
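    "The first training example, index 349903, is listed as Crime and Thriller. `Score` ranks those two genres highest by well over an order of magnitude (about -1.2e5 for both, versus -6.4e6 for the next genre, Drama). Since this movie is in the training set, this is a sanity check rather than evidence of how well the model generalizes."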
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "43c144490bd36337"
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "outputs": [
    {
     "data": {
      "text/plain": "         Drama  Romance  Adventure  Fantasy  Family  Mystery  Crime  Thriller  \\\n0                                                                               \n349903       0        0          0        0       0        0      1         1   \n43014        1        0          0        0       0        0      0         0   \n86510        1        0          0        0       0        0      0         0   \n114369       0        0          0        0       0        1      1         1   \n758758       1        0          1        0       0        0      0         0   \n...        ...      ...        ...      ...     ...      ...    ...       ...   \n100405       0        1          0        0       0        0      0         0   \n110632       1        0          0        0       0        0      1         1   \n448157       0        0          0        1       0        0      0         0   \n1441326      1        0          0        0       0        0      0         1   \n109830       1        1          0        0       0        0      0         0   \n\n         War  Comedy  Music  Western  Horror  Science Fiction  Action  \\\n0                                                                       \n349903     0       0      0        0       0                0       0   \n43014      0       0      0        0       0                0       0   \n86510      1       0      0        0       0                0       0   \n114369     0       0      0        0       0                0       0   \n758758     0       0      0        0       0                0       0   \n...      ...     ...    ...      ...     ...              ...     ...   
\n100405     0       1      0        0       0                0       0   \n110632     0       0      0        0       0                0       0   \n448157     0       0      0        0       0                0       1   \n1441326    0       0      0        0       0                0       0   \n109830     0       1      0        0       0                0       0   \n\n         Animation  History                                         clean_text  \n0                                                                               \n349903           0        0  [ocean, 's, twelve, written, george, nolfi, ro...  \n43014            0        0  [sunset, boulevard, charles, brackett, billy, ...  \n86510            0        0  [fire, screenplay, clayton, frohman, ron, shel...  \n114369           0        0  [seven, andrew, kevin, walker, january, 27,199...  \n758758           0        0  [wild, written, sean, penn, based, book, jon, ...  \n...            ...      ...                                                ...  \n100405           0        0  [p, r, e, w, n, jonathan, lawton, stephen, met...  \n110632           0        0  [natural, born, killers, written, quentin, tar...  \n448157           0        0  [hancock, written, vincent, ngo, vince, gillig...  \n1441326          0        0  [martha, marcy, may, marlene, written, sean, d...  \n109830           0        0  [forrest, gump, screenplay, eric, roth, based,...  \n\n[330 rows x 18 columns]",
      "text/html": "<div>\n<style scoped>\n    .dataframe tbody tr th:only-of-type {\n        vertical-align: middle;\n    }\n\n    .dataframe tbody tr th {\n        vertical-align: top;\n    }\n\n    .dataframe thead th {\n        text-align: right;\n    }\n</style>\n<table border=\"1\" class=\"dataframe\">\n  <thead>\n    <tr style=\"text-align: right;\">\n      <th></th>\n      <th>Drama</th>\n      <th>Romance</th>\n      <th>Adventure</th>\n      <th>Fantasy</th>\n      <th>Family</th>\n      <th>Mystery</th>\n      <th>Crime</th>\n      <th>Thriller</th>\n      <th>War</th>\n      <th>Comedy</th>\n      <th>Music</th>\n      <th>Western</th>\n      <th>Horror</th>\n      <th>Science Fiction</th>\n      <th>Action</th>\n      <th>Animation</th>\n      <th>History</th>\n      <th>clean_text</th>\n    </tr>\n    <tr>\n      <th>0</th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n    </tr>\n  </thead>\n  <tbody>\n    <tr>\n      <th>349903</th>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>1</td>\n      <td>1</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>[ocean, 's, twelve, written, george, nolfi, ro...</td>\n    </tr>\n    <tr>\n      <th>43014</th>\n      <td>1</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>[sunset, boulevard, charles, brackett, 
billy, ...</td>\n    </tr>\n    <tr>\n      <th>86510</th>\n      <td>1</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>1</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>[fire, screenplay, clayton, frohman, ron, shel...</td>\n    </tr>\n    <tr>\n      <th>114369</th>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>1</td>\n      <td>1</td>\n      <td>1</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>[seven, andrew, kevin, walker, january, 27,199...</td>\n    </tr>\n    <tr>\n      <th>758758</th>\n      <td>1</td>\n      <td>0</td>\n      <td>1</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>[wild, written, sean, penn, based, book, jon, ...</td>\n    </tr>\n    <tr>\n      <th>...</th>\n      <td>...</td>\n      <td>...</td>\n      <td>...</td>\n      <td>...</td>\n      <td>...</td>\n      <td>...</td>\n      <td>...</td>\n      <td>...</td>\n      <td>...</td>\n      <td>...</td>\n      <td>...</td>\n      <td>...</td>\n      <td>...</td>\n      <td>...</td>\n      <td>...</td>\n      <td>...</td>\n      <td>...</td>\n      <td>...</td>\n    </tr>\n    <tr>\n      <th>100405</th>\n      <td>0</td>\n      <td>1</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>1</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n   
   <td>0</td>\n      <td>[p, r, e, w, n, jonathan, lawton, stephen, met...</td>\n    </tr>\n    <tr>\n      <th>110632</th>\n      <td>1</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>1</td>\n      <td>1</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>[natural, born, killers, written, quentin, tar...</td>\n    </tr>\n    <tr>\n      <th>448157</th>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>1</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>1</td>\n      <td>0</td>\n      <td>0</td>\n      <td>[hancock, written, vincent, ngo, vince, gillig...</td>\n    </tr>\n    <tr>\n      <th>1441326</th>\n      <td>1</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>1</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>[martha, marcy, may, marlene, written, sean, d...</td>\n    </tr>\n    <tr>\n      <th>109830</th>\n      <td>1</td>\n      <td>1</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>1</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>[forrest, gump, screenplay, eric, roth, based,...</td>\n    </tr>\n  </tbody>\n</table>\n<p>330 rows × 18 columns</p>\n</div>"
     },
     "execution_count": 78,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_df"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-07-21T02:16:04.062220300Z",
     "start_time": "2024-07-21T02:16:04.045331100Z"
    }
   },
   "id": "cb70fef545eba14a"
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "outputs": [
    {
     "data": {
      "text/plain": "Crime             -1.169829e+05\nThriller          -1.187586e+05\nDrama             -6.394165e+06\nComedy            -9.610512e+06\nRomance           -1.141861e+07\nAction            -1.331699e+07\nAdventure         -1.500677e+07\nScience Fiction   -1.607495e+07\nMystery           -1.653374e+07\nFantasy           -1.806365e+07\nHorror            -1.845203e+07\nHistory           -2.054818e+07\nFamily            -2.386636e+07\nMusic             -2.577168e+07\nAnimation         -2.709325e+07\nWestern           -4.132910e+07\nWar               -4.240746e+07\ndtype: float64"
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pd.Series(Score(train_df['clean_text'][349903])).sort_values(ascending=False)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-07-21T01:23:16.498180300Z",
     "start_time": "2024-07-21T01:23:15.670678300Z"
    }
   },
   "id": "e703dc7ae4b57aab"
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "outputs": [
    {
     "data": {
      "text/plain": "         Drama  Romance  Adventure  Fantasy  Family  Mystery  Crime  Thriller  \\\n0                                                                               \n1126590      1        0          0        0       0        0      0         0   \n1655420      1        1          0        0       0        0      0         0   \n1365050      1        0          0        0       0        0      0         0   \n1067774      0        1          1        0       0        0      0         0   \n164052       0        0          0        0       0        0      0         1   \n...        ...      ...        ...      ...     ...      ...    ...       ...   \n1201167      1        0          0        0       0        0      0         0   \n1027718      1        0          0        0       0        0      1         0   \n162346       1        0          0        0       0        0      0         0   \n824747       1        0          0        0       0        1      1         0   \n2473602      1        0          0        0       0        0      0         0   \n\n         War  Comedy  Music  Western  Horror  Science Fiction  Action  \\\n0                                                                       \n1126590    0       0      0        0       0                0       0   \n1655420    0       0      0        0       0                0       0   \n1365050    1       0      0        0       0                0       0   \n1067774    0       1      0        0       0                0       0   \n164052     0       0      0        0       0                1       1   \n...      ...     ...    ...      ...     ...              ...     ...   
\n1201167    0       1      0        0       0                0       0   \n1027718    0       0      0        0       0                0       0   \n162346     0       1      0        0       0                0       0   \n824747     0       0      0        0       0                0       0   \n2473602    0       0      1        0       0                0       0   \n\n         Animation  History  \\\n0                             \n1126590          0        0   \n1655420          0        0   \n1365050          0        0   \n1067774          0        0   \n164052           0        0   \n...            ...      ...   \n1201167          0        0   \n1027718          0        0   \n162346           0        0   \n824747           0        0   \n2473602          0        0   \n\n                                                clean_text genres_listed  \n0                                                                         \n1126590  [big, eyes, written, scott, alexander, larry, ...             1  \n1655420  [week, marilyn, written, adrian, hodges, 1, ex...             2  \n1365050  [beasts, nation, written, cary, joji, fukunaga...             2  \n1067774  [monte, carlo, written, ron, bass, based, nove...             3  \n164052   [hollow, man, written, andrew, w., marlowe, re...             3  \n...                                                    ...           ...  \n1201167  [funny, people, written, judd, apatow, april, ...             2  \n1027718  [wall, street, money, never, sleeps, written, ...             2  \n162346   [ghost, world, daniel, clowes, terry, zwigoff,...             2  \n824747   [changeling, true, story, written, j., michael...             3  \n2473602  [get, written, steven, baigelman, jez, butterw...             2  \n\n[83 rows x 19 columns]",
      "text/html": "<div>\n<style scoped>\n    .dataframe tbody tr th:only-of-type {\n        vertical-align: middle;\n    }\n\n    .dataframe tbody tr th {\n        vertical-align: top;\n    }\n\n    .dataframe thead th {\n        text-align: right;\n    }\n</style>\n<table border=\"1\" class=\"dataframe\">\n  <thead>\n    <tr style=\"text-align: right;\">\n      <th></th>\n      <th>Drama</th>\n      <th>Romance</th>\n      <th>Adventure</th>\n      <th>Fantasy</th>\n      <th>Family</th>\n      <th>Mystery</th>\n      <th>Crime</th>\n      <th>Thriller</th>\n      <th>War</th>\n      <th>Comedy</th>\n      <th>Music</th>\n      <th>Western</th>\n      <th>Horror</th>\n      <th>Science Fiction</th>\n      <th>Action</th>\n      <th>Animation</th>\n      <th>History</th>\n      <th>clean_text</th>\n      <th>genres_listed</th>\n    </tr>\n    <tr>\n      <th>0</th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n    </tr>\n  </thead>\n  <tbody>\n    <tr>\n      <th>1126590</th>\n      <td>1</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>[big, eyes, written, scott, alexander, larry, ...</td>\n      <td>1</td>\n    </tr>\n    <tr>\n      <th>1655420</th>\n      <td>1</td>\n      <td>1</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n     
 <td>0</td>\n      <td>[week, marilyn, written, adrian, hodges, 1, ex...</td>\n      <td>2</td>\n    </tr>\n    <tr>\n      <th>1365050</th>\n      <td>1</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>1</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>[beasts, nation, written, cary, joji, fukunaga...</td>\n      <td>2</td>\n    </tr>\n    <tr>\n      <th>1067774</th>\n      <td>0</td>\n      <td>1</td>\n      <td>1</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>1</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>[monte, carlo, written, ron, bass, based, nove...</td>\n      <td>3</td>\n    </tr>\n    <tr>\n      <th>164052</th>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>1</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>1</td>\n      <td>1</td>\n      <td>0</td>\n      <td>0</td>\n      <td>[hollow, man, written, andrew, w., marlowe, re...</td>\n      <td>3</td>\n    </tr>\n    <tr>\n      <th>...</th>\n      <td>...</td>\n      <td>...</td>\n      <td>...</td>\n      <td>...</td>\n      <td>...</td>\n      <td>...</td>\n      <td>...</td>\n      <td>...</td>\n      <td>...</td>\n      <td>...</td>\n      <td>...</td>\n      <td>...</td>\n      <td>...</td>\n      <td>...</td>\n      <td>...</td>\n      <td>...</td>\n      <td>...</td>\n      <td>...</td>\n      <td>...</td>\n    </tr>\n    <tr>\n      <th>1201167</th>\n      <td>1</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      
<td>0</td>\n      <td>0</td>\n      <td>1</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>[funny, people, written, judd, apatow, april, ...</td>\n      <td>2</td>\n    </tr>\n    <tr>\n      <th>1027718</th>\n      <td>1</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>1</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>[wall, street, money, never, sleeps, written, ...</td>\n      <td>2</td>\n    </tr>\n    <tr>\n      <th>162346</th>\n      <td>1</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>1</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>[ghost, world, daniel, clowes, terry, zwigoff,...</td>\n      <td>2</td>\n    </tr>\n    <tr>\n      <th>824747</th>\n      <td>1</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>1</td>\n      <td>1</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>[changeling, true, story, written, j., michael...</td>\n      <td>3</td>\n    </tr>\n    <tr>\n      <th>2473602</th>\n      <td>1</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>1</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>[get, written, steven, baigelman, jez, butterw...</td>\n      <td>2</td>\n    
</tr>\n  </tbody>\n</table>\n<p>83 rows × 19 columns</p>\n</div>"
     },
     "execution_count": 65,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_df = y_test.join(X_test)\n",
    "test_df['genres_listed'] =  pd.Series()\n",
    "for i in test_df.index:\n",
    "    test_df['genres_listed'][i] = sum(test_df.loc[i][:'History'])\n",
    "test_df"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-07-21T01:46:07.118036200Z",
     "start_time": "2024-07-21T01:46:07.067951900Z"
    }
   },
   "id": "1f3f14a651d70387"
  },
  {
   "cell_type": "markdown",
   "source": [
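    "## Calculating Accuracy\n",
    "For a script with n listed genres, we take the model's n highest-scoring genres, count how many of them are among the script's true genres, and divide by n. The accuracy reported below is the mean of this fraction over every script in the test set. Informally, it is the percentage of a film's correct genres that the model is able to identify."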
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "8339aebf6d0484a9"
  },
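  {
   "cell_type": "markdown",
   "source": [
    "As a sanity check on this definition, the toy example below scores two hypothetical films by hand. The genre scores and true labels are invented for illustration, and `topn_accuracy` is a simplified stand-in written only for this sketch: unlike `PredictionAccuracy`, it takes precomputed scores instead of calling the trained `Score` function."
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "5d0b7e2c91a4f3e6"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "def topn_accuracy(scores: dict, true_genres: dict) -> float:\n",
    "    '''Mean over films of (correct genres among the top-n scores) / n, where n = number of true genres.'''\n",
    "    total = 0.0\n",
    "    for film, genre_scores in scores.items():\n",
    "        n = len(true_genres[film])\n",
    "        # Rank genres by score, highest first, and keep the top n\n",
    "        top_n = list(pd.Series(genre_scores).sort_values(ascending=False).index)[:n]\n",
    "        total += sum(g in true_genres[film] for g in top_n) / n\n",
    "    return total / len(scores)\n",
    "\n",
    "# Hypothetical log-probability-style scores (higher is more likely); not real model output\n",
    "scores = {\n",
    "    'film_a': {'Drama': -10.0, 'Romance': -11.0, 'Comedy': -12.0, 'Horror': -30.0},\n",
    "    'film_b': {'Drama': -5.0, 'Horror': -8.0, 'Romance': -20.0, 'Comedy': -40.0},\n",
    "}\n",
    "true_genres = {\n",
    "    'film_a': {'Drama', 'Comedy'},  # top 2 = Drama, Romance -> 1 of 2 correct\n",
    "    'film_b': {'Drama', 'Horror'},  # top 2 = Drama, Horror -> 2 of 2 correct\n",
    "}\n",
    "print(topn_accuracy(scores, true_genres))  # (0.5 + 1.0) / 2 = 0.75"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "8a3f61c0d29e7b45"
  },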
  {
   "cell_type": "code",
   "execution_count": 85,
   "outputs": [],
   "source": [
    "def PredictionAccuracy(test_df: pd.DataFrame) -> float:\n",
    "    '''Mean fraction of each script's true genres found among the model's top-n predictions,\n",
    "    where n is the number of genres listed for that script.'''\n",
    "    total_score = 0\n",
    "    for i in test_df.index:\n",
    "        num_genres = int(test_df.loc[i, 'genres_listed'])\n",
    "\n",
    "        # Rank genres by the model's score, highest first\n",
    "        preds = list(pd.Series(Score(test_df.loc[i, 'clean_text'])).sort_values(ascending=False).index)\n",
    "\n",
    "        # Count how many of the top-n predicted genres are actually listed for this script\n",
    "        score = sum(test_df.loc[i, genre] == 1 for genre in preds[:num_genres])\n",
    "        total_score += score / num_genres\n",
    "    return total_score / len(test_df)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-07-21T02:26:01.649181300Z",
     "start_time": "2024-07-21T02:26:01.625009100Z"
    }
   },
   "id": "de34a4e2542460b9"
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "outputs": [
    {
     "data": {
      "text/plain": "0.40843373493975893"
     },
     "execution_count": 83,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "PredictionAccuracy(test_df)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-07-21T02:21:42.893046800Z",
     "start_time": "2024-07-21T02:20:42.879695300Z"
    }
   },
   "id": "199b63bebe2f52b3"
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Precision, Recall, and F-score\n",
    "Precision, recall, and the F-score give a more complete picture of the model's performance than accuracy alone.\n",
    "\n",
    "Under the prediction scheme above, all three scores are identical. Because the model predicts exactly as many genres as each film actually has, every false positive (a wrong genre predicted) takes the place of a correct genre and is therefore matched by a false negative (a correct genre missed). With FP = FN, precision TP/(TP+FP) equals recall TP/(TP+FN), and the F-score, their harmonic mean, equals both. This fixed-count approach can be useful for tuning a larger generative model, which is how I ultimately plan to use this code. To evaluate the model in a more granular way, however, we can define a prediction threshold, either as a constant or as a function of the scores the model assigns to the entire set of classes."
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "47aa47ad044cc643"
  },
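  {
   "cell_type": "markdown",
   "source": [
    "To illustrate why the three scores coincide, the sketch below pools true positives, false positives, and false negatives over two hypothetical films. With exactly n predictions per film, FP equals FN, so precision, recall, and F are all equal; predicting one extra genre per film (similar in spirit to `correct_number_of_preds` with `args=(1,)`) raises recall and lowers precision. The scores, labels, and the `micro_prf` helper are invented for this sketch and are not the notebook's `Precision_Recall_F`."
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "e27c4b9a0f1d6358"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "def micro_prf(scores: dict, true_genres: dict, extra: int = 0) -> tuple:\n",
    "    '''Micro-averaged precision, recall, and F over (film, genre) pairs, predicting the top n + extra genres.'''\n",
    "    tp = fp = fn = 0\n",
    "    for film, genre_scores in scores.items():\n",
    "        ranked = list(pd.Series(genre_scores).sort_values(ascending=False).index)\n",
    "        predicted = set(ranked[:len(true_genres[film]) + extra])\n",
    "        tp += len(predicted & true_genres[film])  # predicted and listed\n",
    "        fp += len(predicted - true_genres[film])  # predicted but not listed\n",
    "        fn += len(true_genres[film] - predicted)  # listed but not predicted\n",
    "    precision = tp / (tp + fp) if tp + fp else 0.0\n",
    "    recall = tp / (tp + fn) if tp + fn else 0.0\n",
    "    f = 2 * precision * recall / (precision + recall) if precision + recall else 0.0\n",
    "    return precision, recall, f\n",
    "\n",
    "# Hypothetical scores and labels, for illustration only\n",
    "scores = {\n",
    "    'film_a': {'Drama': -10.0, 'Romance': -11.0, 'Comedy': -12.0, 'Horror': -30.0},\n",
    "    'film_b': {'Drama': -5.0, 'Horror': -8.0, 'Romance': -20.0, 'Comedy': -40.0},\n",
    "}\n",
    "true_genres = {'film_a': {'Drama', 'Comedy'}, 'film_b': {'Drama', 'Horror'}}\n",
    "\n",
    "print(micro_prf(scores, true_genres))           # TP=3, FP=1, FN=1 -> (0.75, 0.75, 0.75)\n",
    "print(micro_prf(scores, true_genres, extra=1))  # TP=4, FP=2, FN=0 -> precision 4/6, recall 1.0"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "3b9d0e6f72a1c584"
  },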
  {
   "cell_type": "code",
   "execution_count": 202,
   "outputs": [],
   "source": [
    "def Precision_Recall_F(test_df: pd.DataFrame, threshold_function, thresh_func_args: tuple) -> tuple[float, float, float]:\n",
    "    '''Micro-averaged precision, recall, and F-score over every (script, genre) pair in test_df.'''\n",
    "    # correct_number_of_preds reads num_genres as a global, so publish each row's count at module level\n",
    "    global num_genres\n",
    "    true_positives = 0\n",
    "    true_negatives = 0\n",
    "    false_positives = 0\n",
    "    false_negatives = 0\n",
    "    genres = list(test_df.loc[:, 'Drama':'History'].columns)\n",
    "    for i in test_df.index:\n",
    "        num_genres = int(test_df.loc[i, 'genres_listed'])\n",
    "\n",
    "        # Genres ranked by model score, highest first; the threshold function picks the positives\n",
    "        p = pd.Series(Score(test_df.loc[i, 'clean_text'])).sort_values(ascending=False)\n",
    "        preds = threshold_function(p, thresh_func_args)\n",
    "\n",
    "        for genre in genres:\n",
    "            pred = genre in preds                    # positive prediction?\n",
    "            obs = bool(test_df.loc[i, genre] != 0)  # positive observed value? (plain bool so the patterns match)\n",
    "\n",
    "            match (pred, obs):\n",
    "                case (True, True):\n",
    "                    true_positives += 1\n",
    "                case (True, False):\n",
    "                    false_positives += 1\n",
    "                case (False, False):\n",
    "                    true_negatives += 1\n",
    "                case (False, True):\n",
    "                    false_negatives += 1\n",
    "\n",
    "    # Guard against empty denominators (e.g. a threshold that predicts nothing)\n",
    "    precision = true_positives / (true_positives + false_positives) if true_positives + false_positives else 0.0\n",
    "    recall = true_positives / (true_positives + false_negatives) if true_positives + false_negatives else 0.0\n",
    "    f = 2 * precision * recall / (precision + recall) if precision + recall else 0.0\n",
    "\n",
    "    return (precision, recall, f)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-07-21T17:29:52.004189600Z",
     "start_time": "2024-07-21T17:29:52.000583Z"
    }
   },
   "id": "1e1e45abd0649f27"
  },
  {
   "cell_type": "code",
   "execution_count": 203,
   "outputs": [],
   "source": [
    "def thresh_stdev(p, args=(1,)):\n",
    "    '''Threshold is defined at the given Z score for each row's predicted probabilities'''\n",
    "    return list(p[p > statistics.mean(p) + (statistics.stdev(p) * args[0])].index)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-07-21T17:29:52.173915500Z",
     "start_time": "2024-07-21T17:29:52.168700400Z"
    }
   },
   "id": "de4fc9f25169a23e"
  },
  {
   "cell_type": "code",
   "execution_count": 204,
   "outputs": [],
   "source": [
    "def correct_number_of_preds(p, args=(0,)):\n",
    "    '''Predict the top-ranked genres: as many as the script actually lists (num_genres), plus args[0] extra predictions'''\n",
    "    global num_genres\n",
    "    return list(p.index)[:num_genres + args[0]]"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-07-21T17:29:52.344684200Z",
     "start_time": "2024-07-21T17:29:52.338936500Z"
    }
   },
   "id": "70026921b106e28a"
  },
  {
   "cell_type": "code",
   "execution_count": 205,
   "outputs": [],
   "source": [
    "def thresh_constant(p, args=(-3000000,)):\n",
    "    '''Threshold is a given constant: predict every genre whose score exceeds args[0]'''\n",
    "    return list(p[p > args[0]].index)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-07-21T17:29:52.509987800Z",
     "start_time": "2024-07-21T17:29:52.502530100Z"
    }
   },
   "id": "bfa03cc79e243ade"
  },
  {
   "cell_type": "code",
   "execution_count": 206,
   "outputs": [],
   "source": [
    "def thresh_linear_wrt_mean(p, args=(10,)):\n",
    "    '''Threshold is the mean of each row's predicted probabilities divided by a given constant'''\n",
    "    return list(p[p > statistics.mean(p) / args[0]].index)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-07-21T17:29:52.937577500Z",
     "start_time": "2024-07-21T17:29:52.919854800Z"
    }
   },
   "id": "27b8ab751f833d69"
  },
  {
   "cell_type": "code",
   "execution_count": 223,
   "outputs": [],
   "source": [
    "(precision, recall, f) = Precision_Recall_F(test_df, correct_number_of_preds, (0,))"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-07-21T17:47:03.164427100Z",
     "start_time": "2024-07-21T17:46:15.204319Z"
    }
   },
   "id": "64659fde5944528a"
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Discussion\n",
    "Ideally, I would have tuned these hyperparameters (the threshold function and its arguments) on a separate validation set before making predictions on the test set. As it is, picking the best threshold by its test-set scores would overfit the available data.\n",
    "\n",
    "The data is also very limited to begin with. The model favors certain categories, predicting Drama and Thriller significantly more often than the other classes, and the training data shows why: these categories are heavily overrepresented (Drama appears in 160 training scripts and Thriller in 118, versus 5 for War and 4 for Western)."
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "eabcbc69317f6519"
  },
  {
   "cell_type": "code",
   "execution_count": 221,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Drama 160\n",
      "Romance 53\n",
      "Adventure 59\n",
      "Fantasy 40\n",
      "Family 19\n",
      "Mystery 45\n",
      "Crime 70\n",
      "Thriller 118\n",
      "War 5\n",
      "Comedy 90\n",
      "Music 7\n",
      "Western 4\n",
      "Horror 46\n",
      "Science Fiction 61\n",
      "Action 81\n",
      "Animation 15\n",
      "History 14\n"
     ]
    }
   ],
   "source": [
    "for genre in train_df.columns[:17]:\n",
    "    print(genre, len(train_df[train_df[genre] == 1]))\n"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-07-21T17:43:52.258013100Z",
     "start_time": "2024-07-21T17:43:52.244831100Z"
    }
   },
   "id": "b91fae83e77efcb9"
  },
  {
   "cell_type": "markdown",
   "source": [
    "With more data, this model could likely be made more accurate, and with fresh validation and test data, a 'final' model could be tuned properly. Most likely, this simple bag-of-words classifier will be useful, if at all, only as one component of a larger ensemble or GAN."
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "f751fc881733f2ee"
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}