{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "fastai_collab.ipynb", "provenance": [], "collapsed_sections": [], "authorship_tag": "ABX9TyN1v8T4yqV4jCEYOlmXCsEI", "include_colab_link": true }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "view-in-github", "colab_type": "text" }, "source": [ "<a href=\"https://colab.research.google.com/github/bipinKrishnan/fastai_course/blob/master/fastai_collab.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" ] }, { "cell_type": "code", "metadata": { "id": "12UvxQcEEo4o" }, "source": [ "!pip install fastai --upgrade" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "neBLmF00EsVi" }, "source": [ "import fastai.collab as c\n", "import fastai.tabular.all as t\n", "import torch\n", "\n", "import pandas as pd\n", "import os\n", "os.environ['CUDA_LAUNCH_BLOCKING'] = \"1\"" ], "execution_count": 25, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "pZHSdM3ZE3XT", "outputId": "e5ea3cdf-5399-4219-9a42-d0b5eb451d02", "colab": { "base_uri": "https://localhost:8080/", "height": 54 } }, "source": [ "path = c.untar_data(c.URLs.ML_100k)\n", "path.ls()" ], "execution_count": 26, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(#23) [Path('/root/.fastai/data/ml-100k/ub.base'),Path('/root/.fastai/data/ml-100k/u5.base'),Path('/root/.fastai/data/ml-100k/u5.test'),Path('/root/.fastai/data/ml-100k/u.info'),Path('/root/.fastai/data/ml-100k/u2.base'),Path('/root/.fastai/data/ml-100k/u3.base'),Path('/root/.fastai/data/ml-100k/ua.test'),Path('/root/.fastai/data/ml-100k/u4.test'),Path('/root/.fastai/data/ml-100k/mku.sh'),Path('/root/.fastai/data/ml-100k/u.genre')...]" ] }, "metadata": { "tags": [] }, "execution_count": 26 } ] }, { "cell_type": "code", "metadata": { "id": 
"ldPlun7NFBqw", "outputId": "99e6eadd-d2dd-4e69-9d49-da0d5ab40f1b", "colab": { "base_uri": "https://localhost:8080/", "height": 206 } }, "source": [ "df = pd.read_csv(path/'u.data', delimiter='\\t', names=['user', 'movie', 'rating', 'timestamp'])\n", "df.head()" ], "execution_count": 27, "outputs": [ { "output_type": "execute_result", "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>user</th>\n", " <th>movie</th>\n", " <th>rating</th>\n", " <th>timestamp</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>196</td>\n", " <td>242</td>\n", " <td>3</td>\n", " <td>881250949</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>186</td>\n", " <td>302</td>\n", " <td>3</td>\n", " <td>891717742</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>22</td>\n", " <td>377</td>\n", " <td>1</td>\n", " <td>878887116</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>244</td>\n", " <td>51</td>\n", " <td>2</td>\n", " <td>880606923</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>166</td>\n", " <td>346</td>\n", " <td>1</td>\n", " <td>886397596</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " user movie rating timestamp\n", "0 196 242 3 881250949\n", "1 186 302 3 891717742\n", "2 22 377 1 878887116\n", "3 244 51 2 880606923\n", "4 166 346 1 886397596" ] }, "metadata": { "tags": [] }, "execution_count": 27 } ] }, { "cell_type": "code", "metadata": { "id": "SnFX7W8NF3K2", "outputId": "f79c3180-7c71-42db-c167-66f8183e9910", "colab": { "base_uri": "https://localhost:8080/", "height": 225 } }, "source": [ "df.info(), df.shape" ], "execution_count": 28, 
"outputs": [ { "output_type": "stream", "text": [ "<class 'pandas.core.frame.DataFrame'>\n", "RangeIndex: 100000 entries, 0 to 99999\n", "Data columns (total 4 columns):\n", " # Column Non-Null Count Dtype\n", "--- ------ -------------- -----\n", " 0 user 100000 non-null int64\n", " 1 movie 100000 non-null int64\n", " 2 rating 100000 non-null int64\n", " 3 timestamp 100000 non-null int64\n", "dtypes: int64(4)\n", "memory usage: 3.1 MB\n" ], "name": "stdout" }, { "output_type": "execute_result", "data": { "text/plain": [ "(None, (100000, 4))" ] }, "metadata": { "tags": [] }, "execution_count": 28 } ] }, { "cell_type": "code", "metadata": { "id": "LHL4rnnPGf2E", "outputId": "a323bdaa-2c89-49ac-8816-f749451c6294", "colab": { "base_uri": "https://localhost:8080/", "height": 206 } }, "source": [ "movies = pd.read_csv(path/'u.item', delimiter='|', names=['movie', 'title'], encoding='latin-1', usecols=(0, 1))\n", "movies.head()" ], "execution_count": 29, "outputs": [ { "output_type": "execute_result", "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>movie</th>\n", " <th>title</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>1</td>\n", " <td>Toy Story (1995)</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>2</td>\n", " <td>GoldenEye (1995)</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>3</td>\n", " <td>Four Rooms (1995)</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>4</td>\n", " <td>Get Shorty (1995)</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>5</td>\n", " <td>Copycat (1995)</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": 
[ " movie title\n", "0 1 Toy Story (1995)\n", "1 2 GoldenEye (1995)\n", "2 3 Four Rooms (1995)\n", "3 4 Get Shorty (1995)\n", "4 5 Copycat (1995)" ] }, "metadata": { "tags": [] }, "execution_count": 29 } ] }, { "cell_type": "code", "metadata": { "id": "pL6bEndiIzyT", "outputId": "026ecd26-0e1c-4854-b968-0b1ff5d21773", "colab": { "base_uri": "https://localhost:8080/", "height": 206 } }, "source": [ "df = df.merge(movies)\n", "df.head()" ], "execution_count": 30, "outputs": [ { "output_type": "execute_result", "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>user</th>\n", " <th>movie</th>\n", " <th>rating</th>\n", " <th>timestamp</th>\n", " <th>title</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>196</td>\n", " <td>242</td>\n", " <td>3</td>\n", " <td>881250949</td>\n", " <td>Kolya (1996)</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>63</td>\n", " <td>242</td>\n", " <td>3</td>\n", " <td>875747190</td>\n", " <td>Kolya (1996)</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>226</td>\n", " <td>242</td>\n", " <td>5</td>\n", " <td>883888671</td>\n", " <td>Kolya (1996)</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>154</td>\n", " <td>242</td>\n", " <td>3</td>\n", " <td>879138235</td>\n", " <td>Kolya (1996)</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>306</td>\n", " <td>242</td>\n", " <td>5</td>\n", " <td>876503793</td>\n", " <td>Kolya (1996)</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " user movie rating timestamp title\n", "0 196 242 3 881250949 Kolya (1996)\n", "1 63 242 3 875747190 Kolya (1996)\n", "2 226 242 5 
883888671 Kolya (1996)\n", "3 154 242 3 879138235 Kolya (1996)\n", "4 306 242 5 876503793 Kolya (1996)" ] }, "metadata": { "tags": [] }, "execution_count": 30 } ] }, { "cell_type": "code", "metadata": { "id": "YXShxk83Jmnm", "outputId": "19e3c40e-aa0a-4cb9-8f28-446988ea5ee7", "colab": { "base_uri": "https://localhost:8080/", "height": 363 } }, "source": [ "dls = c.CollabDataLoaders.from_df(\n", " df,\n", " user_name='user',\n", " item_name='title',\n", " rating_name='rating',\n", " bs=32\n", ")\n", "\n", "dls.show_batch()" ], "execution_count": 31, "outputs": [ { "output_type": "display_data", "data": { "text/html": [ "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>user</th>\n", " <th>title</th>\n", " <th>rating</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>303</td>\n", " <td>Shall We Dance? (1937)</td>\n", " <td>4</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>880</td>\n", " <td>Get Shorty (1995)</td>\n", " <td>4</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>355</td>\n", " <td>Kolya (1996)</td>\n", " <td>4</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>334</td>\n", " <td>Sound of Music, The (1965)</td>\n", " <td>2</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>826</td>\n", " <td>Striking Distance (1993)</td>\n", " <td>3</td>\n", " </tr>\n", " <tr>\n", " <th>5</th>\n", " <td>942</td>\n", " <td>It Happened One Night (1934)</td>\n", " <td>4</td>\n", " </tr>\n", " <tr>\n", " <th>6</th>\n", " <td>289</td>\n", " <td>Time to Kill, A (1996)</td>\n", " <td>3</td>\n", " </tr>\n", " <tr>\n", " <th>7</th>\n", " <td>405</td>\n", " <td>My Left Foot (1989)</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>8</th>\n", " <td>452</td>\n", " <td>Fantasia (1940)</td>\n", " <td>2</td>\n", " </tr>\n", " <tr>\n", " <th>9</th>\n", " <td>747</td>\n", " <td>Thirty-Two Short Films About Glenn Gould (1993)</td>\n", " <td>3</td>\n", " </tr>\n", " 
</tbody>\n", "</table>" ], "text/plain": [ "<IPython.core.display.HTML object>" ] }, "metadata": { "tags": [] } } ] }, { "cell_type": "code", "metadata": { "id": "NTOLvhrqKgs0", "outputId": "7cedb26d-ecb2-44f6-bb0a-f436a752ef96", "colab": { "base_uri": "https://localhost:8080/", "height": 72 } }, "source": [ "dls.classes" ], "execution_count": 32, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "{'title': (#1665) ['#na#',\"'Til There Was You (1997)\",'1-900 (1994)','101 Dalmatians (1996)','12 Angry Men (1957)','187 (1997)','2 Days in the Valley (1996)','20,000 Leagues Under the Sea (1954)','2001: A Space Odyssey (1968)','3 Ninjas: High Noon At Mega Mountain (1998)'...],\n", " 'user': (#944) ['#na#',1,2,3,4,5,6,7,8,9...]}" ] }, "metadata": { "tags": [] }, "execution_count": 32 } ] }, { "cell_type": "code", "metadata": { "id": "_lneJ6HaKqZR" }, "source": [ "n_users = len(dls.classes['user'])\n", "n_mov = len(dls.classes['title'])\n", "n_factors = 5\n", "\n", "user_factors = torch.randn(n_users, n_factors)\n", "mov_factors = torch.randn(n_mov, n_factors)" ], "execution_count": 33, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "BZRvoW1eLL02", "outputId": "999f3bdc-f7fc-4ba0-e4ed-fa8588018a99", "colab": { "base_uri": "https://localhost:8080/", "height": 34 } }, "source": [ "user_factors.shape, mov_factors.shape" ], "execution_count": 34, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(torch.Size([944, 5]), torch.Size([1665, 5]))" ] }, "metadata": { "tags": [] }, "execution_count": 34 } ] }, { "cell_type": "code", "metadata": { "id": "SDCVJkEkLXBX", "outputId": "9a482c2a-9b4c-4889-c82d-f9638a7df559", "colab": { "base_uri": "https://localhost:8080/", "height": 34 } }, "source": [ "user_factors.t()@t.one_hot(3, n_users).float() " ], "execution_count": 35, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "tensor([-0.6902, 0.9260, -1.2527, 0.5396, 0.1056])" ] }, "metadata": { "tags": 
[] }, "execution_count": 35 } ] }, { "cell_type": "code", "metadata": { "id": "qaAy_vwAMQGG", "outputId": "e566925f-4ba7-415f-e29f-1023f4b31c10", "colab": { "base_uri": "https://localhost:8080/", "height": 34 } }, "source": [ "user_factors[3]" ], "execution_count": 36, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "tensor([-0.6902, 0.9260, -1.2527, 0.5396, 0.1056])" ] }, "metadata": { "tags": [] }, "execution_count": 36 } ] }, { "cell_type": "markdown", "metadata": { "id": "cz5KC59ASGch" }, "source": [ "# Baseline model" ] }, { "cell_type": "code", "metadata": { "id": "7XsV_cbKM0Ez" }, "source": [ "class DotProduct(torch.nn.Module):\n", " def __init__(self, n_users, n_movies, n_factors):\n", " super().__init__()\n", " self.user_factors = c.Embedding(n_users, n_factors)\n", " self.movie_factors = c.Embedding(n_movies, n_factors)\n", "\n", " def forward(self, x):\n", " users = self.user_factors(x[:, 0])\n", " movies = self.movie_factors(x[:, 1])\n", "\n", " return (users*movies).sum(dim=1)" ], "execution_count": 37, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "WLLPAeKZPanH", "outputId": "d1319df4-c0f9-4c46-a226-558297ee7177", "colab": { "base_uri": "https://localhost:8080/", "height": 34 } }, "source": [ "dls.one_batch()[0].shape, dls.one_batch()[1].shape" ], "execution_count": 38, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(torch.Size([32, 2]), torch.Size([32, 1]))" ] }, "metadata": { "tags": [] }, "execution_count": 38 } ] }, { "cell_type": "code", "metadata": { "id": "ZnGFkHQ6Pofs" }, "source": [ "model = DotProduct(n_users, n_mov, 50)\n", "learn = c.Learner(dls, model, c.MSELossFlat())" ], "execution_count": 39, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "-UM8TmTOQsxV", "outputId": "bff9753b-57a4-4160-ef09-34bca03e11cb", "colab": { "base_uri": "https://localhost:8080/", "height": 206 } }, "source": [ "learn.fit_one_cycle(5, 5e-3)" ], "execution_count": 40, "outputs": [ { 
"output_type": "display_data", "data": { "text/html": [ "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: left;\">\n", " <th>epoch</th>\n", " <th>train_loss</th>\n", " <th>valid_loss</th>\n", " <th>time</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <td>0</td>\n", " <td>1.280856</td>\n", " <td>1.324047</td>\n", " <td>00:16</td>\n", " </tr>\n", " <tr>\n", " <td>1</td>\n", " <td>1.139197</td>\n", " <td>1.145741</td>\n", " <td>00:16</td>\n", " </tr>\n", " <tr>\n", " <td>2</td>\n", " <td>0.951527</td>\n", " <td>1.011723</td>\n", " <td>00:16</td>\n", " </tr>\n", " <tr>\n", " <td>3</td>\n", " <td>0.780049</td>\n", " <td>0.897130</td>\n", " <td>00:16</td>\n", " </tr>\n", " <tr>\n", " <td>4</td>\n", " <td>0.747144</td>\n", " <td>0.874317</td>\n", " <td>00:16</td>\n", " </tr>\n", " </tbody>\n", "</table>" ], "text/plain": [ "<IPython.core.display.HTML object>" ] }, "metadata": { "tags": [] } } ] }, { "cell_type": "markdown", "metadata": { "id": "APF-yfBMSeWG" }, "source": [ "# Adding sigmoid range" ] }, { "cell_type": "code", "metadata": { "id": "Z3LOXdCcRHgu" }, "source": [ "class DotProduct(torch.nn.Module):\n", " def __init__(self, n_users, n_movies, n_factors, y_range=(0, 5.5)):\n", " super().__init__()\n", " self.user_factors = c.Embedding(n_users, n_factors)\n", " self.movie_factors = c.Embedding(n_movies, n_factors)\n", " self.y_range = y_range\n", "\n", " def forward(self, x):\n", " users = self.user_factors(x[:, 0])\n", " movies = self.movie_factors(x[:, 1])\n", "\n", " return c.sigmoid_range((users*movies).sum(dim=1), *self.y_range)" ], "execution_count": 41, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "SEpfIhH2Qzef", "outputId": "23177f6d-af02-4827-c628-d1d90afe3300", "colab": { "base_uri": "https://localhost:8080/", "height": 206 } }, "source": [ "model = DotProduct(n_users, n_mov, 50)\n", "learn = c.Learner(dls, model, c.MSELossFlat())\n", "learn.fit_one_cycle(5, 5e-3)" ], "execution_count": 
42, "outputs": [ { "output_type": "display_data", "data": { "text/html": [ "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: left;\">\n", " <th>epoch</th>\n", " <th>train_loss</th>\n", " <th>valid_loss</th>\n", " <th>time</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <td>0</td>\n", " <td>1.038298</td>\n", " <td>1.009943</td>\n", " <td>00:16</td>\n", " </tr>\n", " <tr>\n", " <td>1</td>\n", " <td>0.880899</td>\n", " <td>0.922624</td>\n", " <td>00:16</td>\n", " </tr>\n", " <tr>\n", " <td>2</td>\n", " <td>0.685187</td>\n", " <td>0.892175</td>\n", " <td>00:16</td>\n", " </tr>\n", " <tr>\n", " <td>3</td>\n", " <td>0.453028</td>\n", " <td>0.901179</td>\n", " <td>00:16</td>\n", " </tr>\n", " <tr>\n", " <td>4</td>\n", " <td>0.308590</td>\n", " <td>0.908446</td>\n", " <td>00:16</td>\n", " </tr>\n", " </tbody>\n", "</table>" ], "text/plain": [ "<IPython.core.display.HTML object>" ] }, "metadata": { "tags": [] } } ] }, { "cell_type": "markdown", "metadata": { "id": "DaC8i0nVSNYU" }, "source": [ "# Adding bias parameters" ] }, { "cell_type": "code", "metadata": { "id": "PH2yzzsPSk87" }, "source": [ "class DotProduct(torch.nn.Module):\n", " def __init__(self, n_users, n_movies, n_factors, y_range=(0, 5.5)):\n", " super().__init__()\n", " self.user_factors = c.Embedding(n_users, n_factors)\n", " self.user_bias = c.Embedding(n_users, 1)\n", "\n", " self.movie_factors = c.Embedding(n_movies, n_factors)\n", " self.movie_bias = c.Embedding(n_movies, 1)\n", "\n", " self.y_range = y_range\n", "\n", " def forward(self, x):\n", " users = self.user_factors(x[:, 0])\n", " movies = self.movie_factors(x[:, 1])\n", "\n", " out = (users*movies).sum(dim=1, keepdim=True)\n", " out += self.user_bias(x[:, 0]) + self.movie_bias(x[:, 1])\n", "\n", " return c.sigmoid_range(out, *self.y_range)" ], "execution_count": 43, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "FeTFDuhBSk9O", "outputId": "99da5355-1941-4e0c-81df-7b5d90888b26", 
"colab": { "base_uri": "https://localhost:8080/", "height": 206 } }, "source": [ "model = DotProduct(n_users, n_mov, 50)\n", "learn = c.Learner(dls, model, c.MSELossFlat())\n", "learn.fit_one_cycle(5, 5e-3)" ], "execution_count": 44, "outputs": [ { "output_type": "display_data", "data": { "text/html": [ "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: left;\">\n", " <th>epoch</th>\n", " <th>train_loss</th>\n", " <th>valid_loss</th>\n", " <th>time</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <td>0</td>\n", " <td>0.981148</td>\n", " <td>0.945925</td>\n", " <td>00:18</td>\n", " </tr>\n", " <tr>\n", " <td>1</td>\n", " <td>0.842618</td>\n", " <td>0.882195</td>\n", " <td>00:18</td>\n", " </tr>\n", " <tr>\n", " <td>2</td>\n", " <td>0.604646</td>\n", " <td>0.913617</td>\n", " <td>00:18</td>\n", " </tr>\n", " <tr>\n", " <td>3</td>\n", " <td>0.407241</td>\n", " <td>0.941966</td>\n", " <td>00:18</td>\n", " </tr>\n", " <tr>\n", " <td>4</td>\n", " <td>0.237349</td>\n", " <td>0.950696</td>\n", " <td>00:18</td>\n", " </tr>\n", " </tbody>\n", "</table>" ], "text/plain": [ "<IPython.core.display.HTML object>" ] }, "metadata": { "tags": [] } } ] }, { "cell_type": "code", "metadata": { "id": "rpLDE7RwSMvZ", "outputId": "8783ea29-ce75-401d-b01d-8ea78964025d", "colab": { "base_uri": "https://localhost:8080/", "height": 206 } }, "source": [ "model = DotProduct(n_users, n_mov, 50)\n", "learn = c.Learner(dls, model, c.MSELossFlat())\n", "learn.fit_one_cycle(5, 5e-3, wd=0.1)" ], "execution_count": 45, "outputs": [ { "output_type": "display_data", "data": { "text/html": [ "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: left;\">\n", " <th>epoch</th>\n", " <th>train_loss</th>\n", " <th>valid_loss</th>\n", " <th>time</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <td>0</td>\n", " <td>0.979508</td>\n", " <td>0.941654</td>\n", " <td>00:18</td>\n", " </tr>\n", " <tr>\n", " <td>1</td>\n", " 
<td>0.858030</td>\n", " <td>0.898452</td>\n", " <td>00:18</td>\n", " </tr>\n", " <tr>\n", " <td>2</td>\n", " <td>0.798919</td>\n", " <td>0.857021</td>\n", " <td>00:18</td>\n", " </tr>\n", " <tr>\n", " <td>3</td>\n", " <td>0.672152</td>\n", " <td>0.827143</td>\n", " <td>00:18</td>\n", " </tr>\n", " <tr>\n", " <td>4</td>\n", " <td>0.545060</td>\n", " <td>0.824844</td>\n", " <td>00:18</td>\n", " </tr>\n", " </tbody>\n", "</table>" ], "text/plain": [ "<IPython.core.display.HTML object>" ] }, "metadata": { "tags": [] } } ] }, { "cell_type": "markdown", "metadata": { "id": "bVaSDYAaWP3z" }, "source": [ "## Embedding layer from scratch" ] }, { "cell_type": "code", "metadata": { "id": "RTnq89LkVh4w" }, "source": [ "def create_params(size): return torch.nn.Parameter(torch.zeros(*size).normal_(0, 0.01))" ], "execution_count": 46, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "Wg_XfhQ3WkXv", "outputId": "81e90c5f-9045-4403-ff01-8a22323087a8", "colab": { "base_uri": "https://localhost:8080/", "height": 86 } }, "source": [ "create_params((3, 4))" ], "execution_count": 47, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "Parameter containing:\n", "tensor([[ 0.0043, -0.0059, -0.0025, -0.0032],\n", " [ 0.0016, -0.0125, 0.0049, 0.0115],\n", " [ 0.0234, 0.0013, -0.0026, -0.0069]], requires_grad=True)" ] }, "metadata": { "tags": [] }, "execution_count": 47 } ] }, { "cell_type": "code", "metadata": { "id": "1CYDjgIGWmu3" }, "source": [ "class DotProduct(torch.nn.Module):\n", " def __init__(self, n_users, n_movies, n_factors, y_range=(0, 5.5)):\n", " super().__init__()\n", " self.user_factors = create_params((n_users, n_factors))\n", " self.user_bias = create_params((n_users, 1))\n", "\n", " self.movie_factors = create_params((n_movies, n_factors))\n", " self.movie_bias = create_params((n_movies, 1))\n", "\n", " self.y_range = y_range\n", "\n", " def forward(self, x):\n", " users = self.user_factors[x[:, 0]]\n", " movies = self.movie_factors[x[:, 
1]]\n", "\n", " out = (users*movies).sum(dim=1, keepdim=True)\n", " out += self.user_bias[x[:, 0]] + self.movie_bias[x[:, 1]]\n", "\n", " return c.sigmoid_range(out, *self.y_range)" ], "execution_count": 48, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "ivmygidxW8LV", "outputId": "c1c399a1-04f7-46e9-83b4-ddb75d8c09d3", "colab": { "base_uri": "https://localhost:8080/", "height": 86 } }, "source": [ "for p in DotProduct(3, 4, 5).parameters():\n", " print(p.shape)" ], "execution_count": 49, "outputs": [ { "output_type": "stream", "text": [ "torch.Size([3, 5])\n", "torch.Size([3, 1])\n", "torch.Size([4, 5])\n", "torch.Size([4, 1])\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "_X1lQzkmXCSX", "outputId": "387a8399-a372-4cf3-afc3-2250c8064d12", "colab": { "base_uri": "https://localhost:8080/", "height": 206 } }, "source": [ "model = DotProduct(n_users, n_mov, 50)\n", "learn = c.Learner(dls, model, c.MSELossFlat())\n", "learn.fit_one_cycle(5, 5e-3, wd=0.1)" ], "execution_count": 50, "outputs": [ { "output_type": "display_data", "data": { "text/html": [ "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: left;\">\n", " <th>epoch</th>\n", " <th>train_loss</th>\n", " <th>valid_loss</th>\n", " <th>time</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <td>0</td>\n", " <td>0.934559</td>\n", " <td>0.959005</td>\n", " <td>00:19</td>\n", " </tr>\n", " <tr>\n", " <td>1</td>\n", " <td>0.847594</td>\n", " <td>0.889536</td>\n", " <td>00:19</td>\n", " </tr>\n", " <tr>\n", " <td>2</td>\n", " <td>0.797565</td>\n", " <td>0.851402</td>\n", " <td>00:19</td>\n", " </tr>\n", " <tr>\n", " <td>3</td>\n", " <td>0.709136</td>\n", " <td>0.822152</td>\n", " <td>00:19</td>\n", " </tr>\n", " <tr>\n", " <td>4</td>\n", " <td>0.585161</td>\n", " <td>0.821672</td>\n", " <td>00:19</td>\n", " </tr>\n", " </tbody>\n", "</table>" ], "text/plain": [ "<IPython.core.display.HTML object>" ] }, "metadata": { "tags": [] } } ] 
}, { "cell_type": "code", "metadata": { "id": "ZzbSdATHYBi8", "outputId": "8dcb5ca1-06e6-4cba-b20b-04b6aadc0894", "colab": { "base_uri": "https://localhost:8080/", "height": 104 } }, "source": [ "b = learn.model.movie_bias.squeeze()\n", "idx = b.argsort()[:5]\n", "\n", "[dls.classes['title'][i] for i in idx]" ], "execution_count": 51, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "['Children of the Corn: The Gathering (1996)',\n", " 'Mortal Kombat: Annihilation (1997)',\n", " 'Crow: City of Angels, The (1996)',\n", " 'Island of Dr. Moreau, The (1996)',\n", " 'Lawnmower Man 2: Beyond Cyberspace (1996)']" ] }, "metadata": { "tags": [] }, "execution_count": 51 } ] }, { "cell_type": "code", "metadata": { "id": "WLnDvYYgaFKn", "outputId": "697abb6f-30d5-412e-add3-59c336225c9c", "colab": { "base_uri": "https://localhost:8080/", "height": 104 } }, "source": [ "b = learn.model.movie_bias.squeeze()\n", "idx = b.argsort(descending=True)[:5]\n", "\n", "[dls.classes['title'][i] for i in idx]" ], "execution_count": 52, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "['Titanic (1997)',\n", " \"Schindler's List (1993)\",\n", " 'Star Wars (1977)',\n", " 'Silence of the Lambs, The (1991)',\n", " 'L.A. 
Confidential (1997)']" ] }, "metadata": { "tags": [] }, "execution_count": 52 } ] }, { "cell_type": "markdown", "metadata": { "id": "6yyEP2jFeusk" }, "source": [ "# Training using fastai collab_learner" ] }, { "cell_type": "code", "metadata": { "id": "sDYjMvOccvcH" }, "source": [ "learn = c.collab_learner(dls, n_factors=50, y_range=(0, 5.5))" ], "execution_count": 53, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "0e489As1e9Th", "outputId": "873671c4-34d7-460b-d460-6360169fb5e8", "colab": { "base_uri": "https://localhost:8080/", "height": 121 } }, "source": [ "learn.model, learn.loss_func" ], "execution_count": 54, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(EmbeddingDotBias(\n", " (u_weight): Embedding(944, 50)\n", " (i_weight): Embedding(1665, 50)\n", " (u_bias): Embedding(944, 1)\n", " (i_bias): Embedding(1665, 1)\n", " ), FlattenedLoss of MSELoss())" ] }, "metadata": { "tags": [] }, "execution_count": 54 } ] }, { "cell_type": "code", "metadata": { "id": "ZIJroJerfATp", "outputId": "c543e862-4d14-4ade-c0bf-f4560212eb69", "colab": { "base_uri": "https://localhost:8080/", "height": 206 } }, "source": [ "learn.fit_one_cycle(5, 5e-3, wd=0.1)" ], "execution_count": null, "outputs": [] }, 
{ "cell_type": "code", "metadata": { "id": "_Omm2PYTfIvz", "outputId": "fdcd73d5-6a79-4df8-9f59-c762909095cd", "colab": { "base_uri": "https://localhost:8080/", "height": 34 } }, "source": [ "movie_idx = dls.classes['title'].o2i['2 Days in the Valley (1996)']\n", "movie_idx" ], "execution_count": 56, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "6" ] }, "metadata": { "tags": [] }, "execution_count": 56 } ] }, { "cell_type": "code", "metadata": { "id": "buDN2R2Qf_mE", "outputId": "ac686ec7-63db-43d0-c0d9-49d99edb8921", "colab": { "base_uri": "https://localhost:8080/", "height": 34 } }, "source": [ "# Position 0 of the ranking is the query movie itself (cosine similarity 1), so take position 1.\n", "sims = torch.nn.CosineSimilarity(dim=1)(learn.model.i_weight.weight[movie_idx][None], learn.model.i_weight.weight)\n", "most_similar = sims.argsort(descending=True)[1].item()\n", "most_similar" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "DZl4FVkYgtK8", "outputId": "0291cda3-c609-4ba1-c577-3e2eb955ac45", "colab": { "base_uri": "https://localhost:8080/", "height": 35 } }, "source": [ "dls.classes['title'][most_similar]" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "1NyNymNplOn2" }, "source": [ "## Using deep learning" ] }, { "cell_type": "code", "metadata": { "id": "Qtj1sDoWktPs" }, "source": [ "class CollabNN(c.Module):\n", " def __init__(self, user_sz, item_sz, n_act=100, y_range=(0, 5.5)):\n", " super().__init__()\n", " self.user_factors = c.Embedding(*user_sz)\n", " self.movie_factors = 
c.Embedding(*item_sz)\n", " self.layers = torch.nn.Sequential(\n", " torch.nn.Linear(user_sz[1]+item_sz[1], n_act),\n", " torch.nn.ReLU(),\n", " torch.nn.Linear(n_act, 1)\n", " )\n", "\n", " self.y_range = y_range\n", "\n", " def forward(self, x):\n", " embs = self.user_factors(x[:, 0]), self.movie_factors(x[:, 1])\n", " x = self.layers(torch.cat(embs, dim=1))\n", "\n", " return c.sigmoid_range(x, *self.y_range)" ], "execution_count": 59, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "Xrj70IfSm04W", "outputId": "1dc3be0a-c23f-49be-a635-ef3cd64a5bba", "colab": { "base_uri": "https://localhost:8080/", "height": 34 } }, "source": [ "sz = c.get_emb_sz(dls)\n", "sz" ], "execution_count": 60, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "[(944, 74), (1665, 102)]" ] }, "metadata": { "tags": [] }, "execution_count": 60 } ] }, { "cell_type": "code", "metadata": { "id": "JUe26GT4m78Z" }, "source": [ "model = CollabNN(*sz)" ], "execution_count": 61, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "oSFfmWHJnLKX", "outputId": "93dc3396-1890-41e8-a703-81b13f363480", "colab": { "base_uri": "https://localhost:8080/", "height": 206 } }, "source": [ "learn = c.Learner(dls, model, c.MSELossFlat())\n", "learn.fit_one_cycle(5, 5e-3, wd=0.01)" ], "execution_count": 62, "outputs": [ { "output_type": "display_data", "data": { "text/html": [ "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: left;\">\n", " <th>epoch</th>\n", " <th>train_loss</th>\n", " <th>valid_loss</th>\n", " <th>time</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <td>0</td>\n", " <td>0.942369</td>\n", " <td>0.967605</td>\n", " <td>00:20</td>\n", " </tr>\n", " <tr>\n", " <td>1</td>\n", " <td>0.902668</td>\n", " <td>0.912845</td>\n", " <td>00:20</td>\n", " </tr>\n", " <tr>\n", " <td>2</td>\n", " <td>0.868656</td>\n", " <td>0.884016</td>\n", " <td>00:20</td>\n", " </tr>\n", " <tr>\n", " <td>3</td>\n", " 
<td>0.841839</td>\n", " <td>0.872575</td>\n", " <td>00:20</td>\n", " </tr>\n", " <tr>\n", " <td>4</td>\n", " <td>0.764020</td>\n", " <td>0.873379</td>\n", " <td>00:20</td>\n", " </tr>\n", " </tbody>\n", "</table>" ], "text/plain": [ "<IPython.core.display.HTML object>" ] }, "metadata": { "tags": [] } } ] }, { "cell_type": "code", "metadata": { "id": "VAYV4lo6qGjV", "outputId": "5bea2e98-95be-4343-8918-6137a3dc4481", "colab": { "base_uri": "https://localhost:8080/", "height": 206 } }, "source": [ "learn = c.collab_learner(dls, use_nn=True, layers=[10, 50], y_range=(0, 5.5))\n", "learn.fit_one_cycle(5, 5e-3, wd=0.1)" ], "execution_count": 63, "outputs": [ { "output_type": "display_data", "data": { "text/html": [ "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: left;\">\n", " <th>epoch</th>\n", " <th>train_loss</th>\n", " <th>valid_loss</th>\n", " <th>time</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <td>0</td>\n", " <td>0.978604</td>\n", " <td>0.979922</td>\n", " <td>00:23</td>\n", " </tr>\n", " <tr>\n", " <td>1</td>\n", " <td>0.958372</td>\n", " <td>0.930927</td>\n", " <td>00:24</td>\n", " </tr>\n", " <tr>\n", " <td>2</td>\n", " <td>0.900048</td>\n", " <td>0.904019</td>\n", " <td>00:23</td>\n", " </tr>\n", " <tr>\n", " <td>3</td>\n", " <td>0.877367</td>\n", " <td>0.876002</td>\n", " <td>00:23</td>\n", " </tr>\n", " <tr>\n", " <td>4</td>\n", " <td>0.816585</td>\n", " <td>0.871966</td>\n", " <td>00:24</td>\n", " </tr>\n", " </tbody>\n", "</table>" ], "text/plain": [ "<IPython.core.display.HTML object>" ] }, "metadata": { "tags": [] } } ] }, { "cell_type": "code", "metadata": { "id": "DttVouzZqwY1", "outputId": "855f0e14-7deb-4701-99cd-6de2d73d4efa", "colab": { "base_uri": "https://localhost:8080/", "height": 434 } }, "source": [ "learn.model" ], "execution_count": 64, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "EmbeddingNN(\n", " (embeds): ModuleList(\n", " (0): Embedding(944, 
74)\n", " (1): Embedding(1665, 102)\n", " )\n", " (emb_drop): Dropout(p=0.0, inplace=False)\n", " (bn_cont): BatchNorm1d(0, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (layers): Sequential(\n", " (0): LinBnDrop(\n", " (0): BatchNorm1d(176, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (1): Linear(in_features=176, out_features=10, bias=False)\n", " (2): ReLU(inplace=True)\n", " )\n", " (1): LinBnDrop(\n", " (0): BatchNorm1d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (1): Linear(in_features=10, out_features=50, bias=False)\n", " (2): ReLU(inplace=True)\n", " )\n", " (2): LinBnDrop(\n", " (0): Linear(in_features=50, out_features=1, bias=True)\n", " )\n", " (3): SigmoidRange(low=0, high=5.5)\n", " )\n", ")" ] }, "metadata": { "tags": [] }, "execution_count": 64 } ] }, { "cell_type": "code", "metadata": { "id": "5P_LC2p0rW9J", "outputId": "c98e77a9-f757-4ad5-f321-c883e4dd4e45", "colab": { "base_uri": "https://localhost:8080/", "height": 173 } }, "source": [ "# Same sizes as shown earlier ([(944, 74), (1665, 102)]), taken from get_emb_sz(dls) instead of hard-coded.\n", "CollabNN(*sz)" ], "execution_count": 65, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "CollabNN(\n", " (user_factors): Embedding(944, 74)\n", " (movie_factors): Embedding(1665, 102)\n", " (layers): Sequential(\n", " (0): Linear(in_features=176, out_features=100, bias=True)\n", " (1): ReLU()\n", " (2): Linear(in_features=100, out_features=1, bias=True)\n", " )\n", ")" ] }, "metadata": { "tags": [] }, "execution_count": 65 } ] }, { "cell_type": "code", "metadata": { "id": "arUEK0RbvuI8", "outputId": "2c3f11db-f504-44fd-bd46-8b70278d99b2", "colab": { "base_uri": "https://localhost:8080/", "height": 34 } }, "source": [ "# fastcore's `delegates` decorator (re-exported by fastai.tabular.all), used on EmbeddingNN below.\n", "t.delegates" ], "execution_count": 66, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "<function fastcore.meta.delegates>" ] }, "metadata": { "tags": [] }, "execution_count": 66 } ] }, { "cell_type": "code", "metadata": { "id": "WKWbh8YKsArO" },
"source": [ "@t.delegates(t.TabularModel)\n", "class EmbeddingNN(t.TabularModel):\n", "    \"\"\"`TabularModel` with categorical embeddings only (`n_cont=0`) and a single output (`out_sz=1`).\n", "\n", "    Extra keyword arguments are forwarded to `TabularModel.__init__`; the `delegates`\n", "    decorator makes them visible in this class's signature.\n", "    \"\"\"\n", "\n", "    def __init__(self, emb_sz, layers, **kwargs):\n", "        super().__init__(emb_sz, layers=layers, n_cont=0, out_sz=1, **kwargs)" ], "execution_count": 67, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "Y22fCzLzuIAt" }, "source": [ "### `*args` and `**kwargs`" ] }, { "cell_type": "markdown", "metadata": { "id": "Cp4Rd-D7wy1I" }, "source": [ "`**kwargs`: extra keyword arguments are collected into a dict." ] }, { "cell_type": "code", "metadata": { "id": "xQnyy_I0s3MR" }, "source": [ "def s(a, **kwargs):\n", "    \"\"\"Ignore `a` and return the extra keyword arguments as a dict.\"\"\"\n", "    return kwargs" ], "execution_count": 68, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "wRH7RrBjtIsJ", "outputId": "d272d756-cd43-4d55-f6cf-445ca1b6ae9c", "colab": { "base_uri": "https://localhost:8080/", "height": 34 } }, "source": [ "s(2, s='dhdj', j=3)" ], "execution_count": 69, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "{'j': 3, 's': 'dhdj'}" ] }, "metadata": { "tags": [] }, "execution_count": 69 } ] }, { "cell_type": "markdown", "metadata": { "id": "T4C4MTfhwxYJ" }, "source": [ "`*args`: extra positional arguments are collected into a tuple." ] }, { "cell_type": "code", "metadata": { "id": "hFbqzi1Et_Yb" }, "source": [ "def a(b, *size):\n", "    \"\"\"Return `b` together with the tuple of extra positional arguments.\"\"\"\n", "    return b, size" ], "execution_count": 70, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "F99TCHZVwjNz", "outputId": "43bef199-62e5-47bf-9176-279bed4d6b0b", "colab": { "base_uri": "https://localhost:8080/", "height": 34 } }, "source": [ "a(2, 3, 'klk', 5)" ], "execution_count": 71, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(2, (3, 'klk', 5))" ] }, "metadata": { "tags": [] }, "execution_count": 71 } ] }, { "cell_type": "code", "metadata": { "id": "voovs7v5yLfn", "outputId": "d5fd900c-1807-4de2-daf4-879f91e572ca", "colab": { "base_uri": "https://localhost:8080/", "height": 332 } }, "source": [ "learn.show_results()" ], "execution_count": 72, "outputs": [ { "output_type": "display_data", "data": { "text/html": [ "" ], "text/plain": [
"<IPython.core.display.HTML object>" ] }, "metadata": { "tags": [] } }, { "output_type": "display_data", "data": { "text/html": [ "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>user</th>\n", " <th>title</th>\n", " <th>rating</th>\n", " <th>rating_pred</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>621</td>\n", " <td>1366</td>\n", " <td>2</td>\n", " <td>4.168081</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>577</td>\n", " <td>1498</td>\n", " <td>4</td>\n", " <td>3.707471</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>846</td>\n", " <td>920</td>\n", " <td>4</td>\n", " <td>4.199018</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>793</td>\n", " <td>785</td>\n", " <td>2</td>\n", " <td>2.995941</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>336</td>\n", " <td>407</td>\n", " <td>3</td>\n", " <td>3.256781</td>\n", " </tr>\n", " <tr>\n", " <th>5</th>\n", " <td>394</td>\n", " <td>1544</td>\n", " <td>5</td>\n", " <td>4.228620</td>\n", " </tr>\n", " <tr>\n", " <th>6</th>\n", " <td>345</td>\n", " <td>1336</td>\n", " <td>4</td>\n", " <td>4.130409</td>\n", " </tr>\n", " <tr>\n", " <th>7</th>\n", " <td>815</td>\n", " <td>1581</td>\n", " <td>4</td>\n", " <td>4.287453</td>\n", " </tr>\n", " <tr>\n", " <th>8</th>\n", " <td>99</td>\n", " <td>552</td>\n", " <td>4</td>\n", " <td>4.037339</td>\n", " </tr>\n", " </tbody>\n", "</table>" ], "text/plain": [ "<IPython.core.display.HTML object>" ] }, "metadata": { "tags": [] } } ] }, { "cell_type": "markdown", "metadata": { "id": "Fpl_W8kG_It8" }, "source": [ "## Using cross-entropy loss (ratings as classes)" ] }, { "cell_type": "code", "metadata": { "id": "P6_WKrdT-WnY" }, "source": [ "# CrossEntropyLossFlat expects 0-based class indices, so shift the ratings (1-5 in the tables above) down by one.\n", "df1 = df.assign(rating=df['rating'] - 1)" ], "execution_count": 73, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "FA9yUSor-VgU", "outputId": "c5c34ea5-8333-450f-dec7-5d7e80d6290e", "colab": { "base_uri":
"https://localhost:8080/", "height": 363 } }, "source": [ "dls = c.CollabDataLoaders.from_df(\n", " df1,\n", " user_name='user',\n", " item_name='title',\n", " rating_name='rating',\n", " bs=32\n", ")\n", "\n", "dls.show_batch()" ], "execution_count": 74, "outputs": [ { "output_type": "display_data", "data": { "text/html": [ "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>user</th>\n", " <th>title</th>\n", " <th>rating</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>891</td>\n", " <td>Courage Under Fire (1996)</td>\n", " <td>4</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>806</td>\n", " <td>Raiders of the Lost Ark (1981)</td>\n", " <td>4</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>790</td>\n", " <td>Romeo Is Bleeding (1993)</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>311</td>\n", " <td>Last of the Mohicans, The (1992)</td>\n", " <td>2</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>790</td>\n", " <td>Emma (1996)</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>5</th>\n", " <td>151</td>\n", " <td>Crimson Tide (1995)</td>\n", " <td>2</td>\n", " </tr>\n", " <tr>\n", " <th>6</th>\n", " <td>489</td>\n", " <td>Devil's Advocate, The (1997)</td>\n", " <td>3</td>\n", " </tr>\n", " <tr>\n", " <th>7</th>\n", " <td>314</td>\n", " <td>Pallbearer, The (1996)</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>8</th>\n", " <td>411</td>\n", " <td>Rear Window (1954)</td>\n", " <td>4</td>\n", " </tr>\n", " <tr>\n", " <th>9</th>\n", " <td>483</td>\n", " <td>Powder (1995)</td>\n", " <td>1</td>\n", " </tr>\n", " </tbody>\n", "</table>" ], "text/plain": [ "<IPython.core.display.HTML object>" ] }, "metadata": { "tags": [] } } ] }, { "cell_type": "code", "metadata": { "id": "8Aa0_9dL0yhn" }, "source": [ "class CollabNN(c.Module):\n", "    \"\"\"Rating classifier: concatenated user/item embeddings -> MLP -> `n_classes` logits.\n", "\n", "    Redefines the regression `CollabNN` above. It returns raw logits (no\n", "    `sigmoid_range`) so it can be trained with `CrossEntropyLossFlat`.\n", "\n", "    Args:\n", "        user_sz: `(num_embeddings, embedding_dim)` for the user embedding.\n", "        item_sz: `(num_embeddings, embedding_dim)` for the item embedding.\n", "        n_act: width of the hidden layer.\n", "        n_classes: number of output logits; the default 5 matches the 0-4 shifted ratings.\n", "    \"\"\"\n", "\n", "    def __init__(self, user_sz, item_sz, n_act=100, n_classes=5):\n", "        super().__init__()\n", "        self.user_factors = c.Embedding(*user_sz)\n", "        self.movie_factors = c.Embedding(*item_sz)\n", "        self.layers = torch.nn.Sequential(\n", "            torch.nn.Linear(user_sz[1] + item_sz[1], n_act),\n", "            torch.nn.ReLU(),\n", "            torch.nn.Linear(n_act, n_classes),\n", "        )\n", "\n", "    def forward(self, x):\n", "        # Column 0 feeds the user embedding and column 1 the item embedding.\n", "        embs = self.user_factors(x[:, 0]), self.movie_factors(x[:, 1])\n", "        return self.layers(torch.cat(embs, dim=1))" ], "execution_count": 75, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "1zl2oS8607PD" }, "source": [ "# Recompute embedding sizes from the current dls instead of reusing `sz` from the regression section.\n", "model = CollabNN(*c.get_emb_sz(dls))\n", "learn = c.Learner(dls, model, loss_func=c.CrossEntropyLossFlat(), metrics=c.accuracy)" ], "execution_count": 76, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "_4TX0Hvt0_EU", "outputId": "39180f1e-4a9c-4471-b6c0-3cfacf32e1d6", "colab": { "base_uri": "https://localhost:8080/", "height": 206 } }, "source": [ "learn.fit_one_cycle(5, 5e-3, wd=0.01)" ], "execution_count": 77, "outputs": [ { "output_type": "display_data", "data": { "text/html": [ "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: left;\">\n", " <th>epoch</th>\n", " <th>train_loss</th>\n", " <th>valid_loss</th>\n", " <th>accuracy</th>\n", " <th>time</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <td>0</td>\n", " <td>1.284872</td>\n", " <td>1.289579</td>\n", " <td>0.418500</td>\n", " <td>00:20</td>\n", " </tr>\n", " <tr>\n", " <td>1</td>\n", " <td>1.248881</td>\n", " <td>1.264122</td>\n", " <td>0.440350</td>\n", " <td>00:20</td>\n", " </tr>\n", " <tr>\n", " <td>2</td>\n", " <td>1.215094</td>\n", " <td>1.232578</td>\n", " <td>0.447150</td>\n", " <td>00:20</td>\n", " </tr>\n", " <tr>\n", " <td>3</td>\n", " <td>1.162452</td>\n", " <td>1.226760</td>\n", " <td>0.456100</td>\n", " <td>00:20</td>\n", " </tr>\n", " <tr>\n", " <td>4</td>\n", " <td>1.091269</td>\n", " <td>1.234871</td>\n", " <td>0.454250</td>\n", " <td>00:20</td>\n", " </tr>\n", " </tbody>\n", "</table>" ], "text/plain": [
"<IPython.core.display.HTML object>" ] }, "metadata": { "tags": [] } } ] } ] }