{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "fastai_collab.ipynb", "provenance": [], "collapsed_sections": [], "authorship_tag": "ABX9TyN1v8T4yqV4jCEYOlmXCsEI", "include_colab_link": true }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "view-in-github", "colab_type": "text" }, "source": [ "\"Open" ] }, { "cell_type": "code", "metadata": { "id": "12UvxQcEEo4o" }, "source": [ "!pip install fastai --upgrade" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "neBLmF00EsVi" }, "source": [ "import fastai.collab as c\n", "import fastai.tabular.all as t\n", "import torch\n", "\n", "import pandas as pd\n", "import os\n", "os.environ['CUDA_LAUNCH_BLOCKING'] = \"1\"" ], "execution_count": 25, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "pZHSdM3ZE3XT", "outputId": "e5ea3cdf-5399-4219-9a42-d0b5eb451d02", "colab": { "base_uri": "https://localhost:8080/", "height": 54 } }, "source": [ "path = c.untar_data(c.URLs.ML_100k)\n", "path.ls()" ], "execution_count": 26, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(#23) [Path('/root/.fastai/data/ml-100k/ub.base'),Path('/root/.fastai/data/ml-100k/u5.base'),Path('/root/.fastai/data/ml-100k/u5.test'),Path('/root/.fastai/data/ml-100k/u.info'),Path('/root/.fastai/data/ml-100k/u2.base'),Path('/root/.fastai/data/ml-100k/u3.base'),Path('/root/.fastai/data/ml-100k/ua.test'),Path('/root/.fastai/data/ml-100k/u4.test'),Path('/root/.fastai/data/ml-100k/mku.sh'),Path('/root/.fastai/data/ml-100k/u.genre')...]" ] }, "metadata": { "tags": [] }, "execution_count": 26 } ] }, { "cell_type": "code", "metadata": { "id": "ldPlun7NFBqw", "outputId": "99e6eadd-d2dd-4e69-9d49-da0d5ab40f1b", "colab": { "base_uri": "https://localhost:8080/", "height": 206 } }, "source": [ "df = pd.read_csv(path/'u.data', delimiter='\\t', names=['user', 'movie', 'rating', 'timestamp'])\n", "df.head()" ], "execution_count": 27, "outputs": [ { "output_type": "execute_result", "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
usermovieratingtimestamp
01962423881250949
11863023891717742
2223771878887116
3244512880606923
41663461886397596
\n", "
" ], "text/plain": [ " user movie rating timestamp\n", "0 196 242 3 881250949\n", "1 186 302 3 891717742\n", "2 22 377 1 878887116\n", "3 244 51 2 880606923\n", "4 166 346 1 886397596" ] }, "metadata": { "tags": [] }, "execution_count": 27 } ] }, { "cell_type": "code", "metadata": { "id": "SnFX7W8NF3K2", "outputId": "f79c3180-7c71-42db-c167-66f8183e9910", "colab": { "base_uri": "https://localhost:8080/", "height": 225 } }, "source": [ "df.info(), df.shape" ], "execution_count": 28, "outputs": [ { "output_type": "stream", "text": [ "\n", "RangeIndex: 100000 entries, 0 to 99999\n", "Data columns (total 4 columns):\n", " # Column Non-Null Count Dtype\n", "--- ------ -------------- -----\n", " 0 user 100000 non-null int64\n", " 1 movie 100000 non-null int64\n", " 2 rating 100000 non-null int64\n", " 3 timestamp 100000 non-null int64\n", "dtypes: int64(4)\n", "memory usage: 3.1 MB\n" ], "name": "stdout" }, { "output_type": "execute_result", "data": { "text/plain": [ "(None, (100000, 4))" ] }, "metadata": { "tags": [] }, "execution_count": 28 } ] }, { "cell_type": "code", "metadata": { "id": "LHL4rnnPGf2E", "outputId": "a323bdaa-2c89-49ac-8816-f749451c6294", "colab": { "base_uri": "https://localhost:8080/", "height": 206 } }, "source": [ "movies = pd.read_csv(path/'u.item', delimiter='|', names=['movie', 'title'], encoding='latin-1', usecols=(0, 1))\n", "movies.head()" ], "execution_count": 29, "outputs": [ { "output_type": "execute_result", "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
movietitle
01Toy Story (1995)
12GoldenEye (1995)
23Four Rooms (1995)
34Get Shorty (1995)
45Copycat (1995)
\n", "
" ], "text/plain": [ " movie title\n", "0 1 Toy Story (1995)\n", "1 2 GoldenEye (1995)\n", "2 3 Four Rooms (1995)\n", "3 4 Get Shorty (1995)\n", "4 5 Copycat (1995)" ] }, "metadata": { "tags": [] }, "execution_count": 29 } ] }, { "cell_type": "code", "metadata": { "id": "pL6bEndiIzyT", "outputId": "026ecd26-0e1c-4854-b968-0b1ff5d21773", "colab": { "base_uri": "https://localhost:8080/", "height": 206 } }, "source": [ "df = df.merge(movies)\n", "df.head()" ], "execution_count": 30, "outputs": [ { "output_type": "execute_result", "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
usermovieratingtimestamptitle
01962423881250949Kolya (1996)
1632423875747190Kolya (1996)
22262425883888671Kolya (1996)
31542423879138235Kolya (1996)
43062425876503793Kolya (1996)
\n", "
" ], "text/plain": [ " user movie rating timestamp title\n", "0 196 242 3 881250949 Kolya (1996)\n", "1 63 242 3 875747190 Kolya (1996)\n", "2 226 242 5 883888671 Kolya (1996)\n", "3 154 242 3 879138235 Kolya (1996)\n", "4 306 242 5 876503793 Kolya (1996)" ] }, "metadata": { "tags": [] }, "execution_count": 30 } ] }, { "cell_type": "code", "metadata": { "id": "YXShxk83Jmnm", "outputId": "19e3c40e-aa0a-4cb9-8f28-446988ea5ee7", "colab": { "base_uri": "https://localhost:8080/", "height": 363 } }, "source": [ "dls = c.CollabDataLoaders.from_df(\n", " df,\n", " user_name='user',\n", " item_name='title',\n", " rating_name='rating',\n", " bs=32\n", ")\n", "\n", "dls.show_batch()" ], "execution_count": 31, "outputs": [ { "output_type": "display_data", "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
usertitlerating
0303Shall We Dance? (1937)4
1880Get Shorty (1995)4
2355Kolya (1996)4
3334Sound of Music, The (1965)2
4826Striking Distance (1993)3
5942It Happened One Night (1934)4
6289Time to Kill, A (1996)3
7405My Left Foot (1989)1
8452Fantasia (1940)2
9747Thirty-Two Short Films About Glenn Gould (1993)3
" ], "text/plain": [ "" ] }, "metadata": { "tags": [] } } ] }, { "cell_type": "code", "metadata": { "id": "NTOLvhrqKgs0", "outputId": "7cedb26d-ecb2-44f6-bb0a-f436a752ef96", "colab": { "base_uri": "https://localhost:8080/", "height": 72 } }, "source": [ "dls.classes" ], "execution_count": 32, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "{'title': (#1665) ['#na#',\"'Til There Was You (1997)\",'1-900 (1994)','101 Dalmatians (1996)','12 Angry Men (1957)','187 (1997)','2 Days in the Valley (1996)','20,000 Leagues Under the Sea (1954)','2001: A Space Odyssey (1968)','3 Ninjas: High Noon At Mega Mountain (1998)'...],\n", " 'user': (#944) ['#na#',1,2,3,4,5,6,7,8,9...]}" ] }, "metadata": { "tags": [] }, "execution_count": 32 } ] }, { "cell_type": "code", "metadata": { "id": "_lneJ6HaKqZR" }, "source": [ "n_users = len(dls.classes['user'])\n", "n_mov = len(dls.classes['title'])\n", "n_factors = 5\n", "\n", "user_factors = torch.randn(n_users, n_factors)\n", "mov_factors = torch.randn(n_mov, n_factors)" ], "execution_count": 33, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "BZRvoW1eLL02", "outputId": "999f3bdc-f7fc-4ba0-e4ed-fa8588018a99", "colab": { "base_uri": "https://localhost:8080/", "height": 34 } }, "source": [ "user_factors.shape, mov_factors.shape" ], "execution_count": 34, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(torch.Size([944, 5]), torch.Size([1665, 5]))" ] }, "metadata": { "tags": [] }, "execution_count": 34 } ] }, { "cell_type": "code", "metadata": { "id": "SDCVJkEkLXBX", "outputId": "9a482c2a-9b4c-4889-c82d-f9638a7df559", "colab": { "base_uri": "https://localhost:8080/", "height": 34 } }, "source": [ "user_factors.t()@t.one_hot(3, n_users).float() " ], "execution_count": 35, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "tensor([-0.6902, 0.9260, -1.2527, 0.5396, 0.1056])" ] }, "metadata": { "tags": [] }, "execution_count": 35 } ] }, { "cell_type": "code", "metadata": { "id": "qaAy_vwAMQGG", "outputId": "e566925f-4ba7-415f-e29f-1023f4b31c10", "colab": { "base_uri": "https://localhost:8080/", "height": 34 } }, "source": [ "user_factors[3]" ], "execution_count": 36, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "tensor([-0.6902, 0.9260, -1.2527, 0.5396, 0.1056])" ] }, "metadata": { "tags": [] }, "execution_count": 36 } ] }, { "cell_type": "markdown", "metadata": { "id": "cz5KC59ASGch" }, "source": [ "#Base line model" ] }, { "cell_type": "code", "metadata": { "id": "7XsV_cbKM0Ez" }, "source": [ "class DotProduct(torch.nn.Module):\n", " def __init__(self, n_users, n_movies, n_factors):\n", " super().__init__()\n", " self.user_factors = c.Embedding(n_users, n_factors)\n", " self.movie_factors = c.Embedding(n_movies, n_factors)\n", "\n", " def forward(self, x):\n", " users = self.user_factors(x[:, 0])\n", " movies = self.movie_factors(x[:, 1])\n", "\n", " return (users*movies).sum(dim=1)" ], "execution_count": 37, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "WLLPAeKZPanH", "outputId": "d1319df4-c0f9-4c46-a226-558297ee7177", "colab": { "base_uri": "https://localhost:8080/", "height": 34 } }, "source": [ "dls.one_batch()[0].shape, dls.one_batch()[1].shape" ], "execution_count": 38, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(torch.Size([32, 2]), torch.Size([32, 1]))" ] }, "metadata": { "tags": [] }, "execution_count": 38 } ] }, { "cell_type": "code", "metadata": { "id": "ZnGFkHQ6Pofs" }, "source": [ "model = DotProduct(n_users, n_mov, 50)\n", "learn = c.Learner(dls, model, c.MSELossFlat())" ], "execution_count": 39, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "-UM8TmTOQsxV", "outputId": "bff9753b-57a4-4160-ef09-34bca03e11cb", "colab": { "base_uri": "https://localhost:8080/", "height": 206 } }, "source": [ "learn.fit_one_cycle(5, 5e-3)" ], "execution_count": 40, "outputs": [ { "output_type": "display_data", "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_losstime
01.2808561.32404700:16
11.1391971.14574100:16
20.9515271.01172300:16
30.7800490.89713000:16
40.7471440.87431700:16
" ], "text/plain": [ "" ] }, "metadata": { "tags": [] } } ] }, { "cell_type": "markdown", "metadata": { "id": "APF-yfBMSeWG" }, "source": [ "# Adding sigmoid range" ] }, { "cell_type": "code", "metadata": { "id": "Z3LOXdCcRHgu" }, "source": [ "class DotProduct(torch.nn.Module):\n", " def __init__(self, n_users, n_movies, n_factors, y_range=(0, 5.5)):\n", " super().__init__()\n", " self.user_factors = c.Embedding(n_users, n_factors)\n", " self.movie_factors = c.Embedding(n_movies, n_factors)\n", " self.y_range = y_range\n", "\n", " def forward(self, x):\n", " users = self.user_factors(x[:, 0])\n", " movies = self.movie_factors(x[:, 1])\n", "\n", " return c.sigmoid_range((users*movies).sum(dim=1), *self.y_range)" ], "execution_count": 41, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "SEpfIhH2Qzef", "outputId": "23177f6d-af02-4827-c628-d1d90afe3300", "colab": { "base_uri": "https://localhost:8080/", "height": 206 } }, "source": [ "model = DotProduct(n_users, n_mov, 50)\n", "learn = c.Learner(dls, model, c.MSELossFlat())\n", "learn.fit_one_cycle(5, 5e-3)" ], "execution_count": 42, "outputs": [ { "output_type": "display_data", "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_losstime
01.0382981.00994300:16
10.8808990.92262400:16
20.6851870.89217500:16
30.4530280.90117900:16
40.3085900.90844600:16
" ], "text/plain": [ "" ] }, "metadata": { "tags": [] } } ] }, { "cell_type": "markdown", "metadata": { "id": "DaC8i0nVSNYU" }, "source": [ "#Adding bias parameter" ] }, { "cell_type": "code", "metadata": { "id": "PH2yzzsPSk87" }, "source": [ "class DotProduct(torch.nn.Module):\n", " def __init__(self, n_users, n_movies, n_factors, y_range=(0, 5.5)):\n", " super().__init__()\n", " self.user_factors = c.Embedding(n_users, n_factors)\n", " self.user_bias = c.Embedding(n_users, 1)\n", "\n", " self.movie_factors = c.Embedding(n_movies, n_factors)\n", " self.movie_bias = c.Embedding(n_movies, 1)\n", "\n", " self.y_range = y_range\n", "\n", " def forward(self, x):\n", " users = self.user_factors(x[:, 0])\n", " movies = self.movie_factors(x[:, 1])\n", "\n", " out = (users*movies).sum(dim=1, keepdim=True)\n", " out += self.user_bias(x[:, 0]) + self.movie_bias(x[:, 1])\n", "\n", " return c.sigmoid_range(out, *self.y_range)" ], "execution_count": 43, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "FeTFDuhBSk9O", "outputId": "99da5355-1941-4e0c-81df-7b5d90888b26", "colab": { "base_uri": "https://localhost:8080/", "height": 206 } }, "source": [ "model = DotProduct(n_users, n_mov, 50)\n", "learn = c.Learner(dls, model, c.MSELossFlat())\n", "learn.fit_one_cycle(5, 5e-3)" ], "execution_count": 44, "outputs": [ { "output_type": "display_data", "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_losstime
00.9811480.94592500:18
10.8426180.88219500:18
20.6046460.91361700:18
30.4072410.94196600:18
40.2373490.95069600:18
" ], "text/plain": [ "" ] }, "metadata": { "tags": [] } } ] }, { "cell_type": "code", "metadata": { "id": "rpLDE7RwSMvZ", "outputId": "8783ea29-ce75-401d-b01d-8ea78964025d", "colab": { "base_uri": "https://localhost:8080/", "height": 206 } }, "source": [ "model = DotProduct(n_users, n_mov, 50)\n", "learn = c.Learner(dls, model, c.MSELossFlat())\n", "learn.fit_one_cycle(5, 5e-3, wd=0.1)" ], "execution_count": 45, "outputs": [ { "output_type": "display_data", "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_losstime
00.9795080.94165400:18
10.8580300.89845200:18
20.7989190.85702100:18
30.6721520.82714300:18
40.5450600.82484400:18
" ], "text/plain": [ "" ] }, "metadata": { "tags": [] } } ] }, { "cell_type": "markdown", "metadata": { "id": "bVaSDYAaWP3z" }, "source": [ "## Embedding layer from scratch" ] }, { "cell_type": "code", "metadata": { "id": "RTnq89LkVh4w" }, "source": [ "def create_params(size): return torch.nn.Parameter(torch.zeros(*size).normal_(0, 0.01))" ], "execution_count": 46, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "Wg_XfhQ3WkXv", "outputId": "81e90c5f-9045-4403-ff01-8a22323087a8", "colab": { "base_uri": "https://localhost:8080/", "height": 86 } }, "source": [ "create_params((3, 4))" ], "execution_count": 47, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "Parameter containing:\n", "tensor([[ 0.0043, -0.0059, -0.0025, -0.0032],\n", " [ 0.0016, -0.0125, 0.0049, 0.0115],\n", " [ 0.0234, 0.0013, -0.0026, -0.0069]], requires_grad=True)" ] }, "metadata": { "tags": [] }, "execution_count": 47 } ] }, { "cell_type": "code", "metadata": { "id": "1CYDjgIGWmu3" }, "source": [ "class DotProduct(torch.nn.Module):\n", " def __init__(self, n_users, n_movies, n_factors, y_range=(0, 5.5)):\n", " super().__init__()\n", " self.user_factors = create_params((n_users, n_factors))\n", " self.user_bias = create_params((n_users, 1))\n", "\n", " self.movie_factors = create_params((n_movies, n_factors))\n", " self.movie_bias = create_params((n_movies, 1))\n", "\n", " self.y_range = y_range\n", "\n", " def forward(self, x):\n", " users = self.user_factors[x[:, 0]]\n", " movies = self.movie_factors[x[:, 1]]\n", "\n", " out = (users*movies).sum(dim=1, keepdim=True)\n", " out += self.user_bias[x[:, 0]] + self.movie_bias[x[:, 1]]\n", "\n", " return c.sigmoid_range(out, *self.y_range)" ], "execution_count": 48, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "ivmygidxW8LV", "outputId": "c1c399a1-04f7-46e9-83b4-ddb75d8c09d3", "colab": { "base_uri": "https://localhost:8080/", "height": 86 } }, "source": [ "for p in DotProduct(3, 4, 5).parameters():\n", " print(p.shape)" ], "execution_count": 49, "outputs": [ { "output_type": "stream", "text": [ "torch.Size([3, 5])\n", "torch.Size([3, 1])\n", "torch.Size([4, 5])\n", "torch.Size([4, 1])\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "_X1lQzkmXCSX", "outputId": "387a8399-a372-4cf3-afc3-2250c8064d12", "colab": { "base_uri": "https://localhost:8080/", "height": 206 } }, "source": [ "model = DotProduct(n_users, n_mov, 50)\n", "learn = c.Learner(dls, model, c.MSELossFlat())\n", "learn.fit_one_cycle(5, 5e-3, wd=0.1)" ], "execution_count": 50, "outputs": [ { "output_type": "display_data", "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_losstime
00.9345590.95900500:19
10.8475940.88953600:19
20.7975650.85140200:19
30.7091360.82215200:19
40.5851610.82167200:19
" ], "text/plain": [ "" ] }, "metadata": { "tags": [] } } ] }, { "cell_type": "code", "metadata": { "id": "ZzbSdATHYBi8", "outputId": "8dcb5ca1-06e6-4cba-b20b-04b6aadc0894", "colab": { "base_uri": "https://localhost:8080/", "height": 104 } }, "source": [ "b = learn.model.movie_bias.squeeze()\n", "idx = b.argsort()[:5]\n", "\n", "[dls.classes['title'][i] for i in idx]" ], "execution_count": 51, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "['Children of the Corn: The Gathering (1996)',\n", " 'Mortal Kombat: Annihilation (1997)',\n", " 'Crow: City of Angels, The (1996)',\n", " 'Island of Dr. Moreau, The (1996)',\n", " 'Lawnmower Man 2: Beyond Cyberspace (1996)']" ] }, "metadata": { "tags": [] }, "execution_count": 51 } ] }, { "cell_type": "code", "metadata": { "id": "WLnDvYYgaFKn", "outputId": "697abb6f-30d5-412e-add3-59c336225c9c", "colab": { "base_uri": "https://localhost:8080/", "height": 104 } }, "source": [ "b = learn.model.movie_bias.squeeze()\n", "idx = b.argsort(descending=True)[:5]\n", "\n", "[dls.classes['title'][i] for i in idx]" ], "execution_count": 52, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "['Titanic (1997)',\n", " \"Schindler's List (1993)\",\n", " 'Star Wars (1977)',\n", " 'Silence of the Lambs, The (1991)',\n", " 'L.A. Confidential (1997)']" ] }, "metadata": { "tags": [] }, "execution_count": 52 } ] }, { "cell_type": "markdown", "metadata": { "id": "6yyEP2jFeusk" }, "source": [ "# Training using fastai collab_learner" ] }, { "cell_type": "code", "metadata": { "id": "sDYjMvOccvcH" }, "source": [ "learn = c.collab_learner(dls, n_factors=50, y_range=(0, 5.5))" ], "execution_count": 53, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "0e489As1e9Th", "outputId": "873671c4-34d7-460b-d460-6360169fb5e8", "colab": { "base_uri": "https://localhost:8080/", "height": 121 } }, "source": [ "learn.model, learn.loss_func" ], "execution_count": 54, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(EmbeddingDotBias(\n", " (u_weight): Embedding(944, 50)\n", " (i_weight): Embedding(1665, 50)\n", " (u_bias): Embedding(944, 1)\n", " (i_bias): Embedding(1665, 1)\n", " ), FlattenedLoss of MSELoss())" ] }, "metadata": { "tags": [] }, "execution_count": 54 } ] }, { "cell_type": "code", "metadata": { "id": "ZIJroJerfATp", "outputId": "c543e862-4d14-4ade-c0bf-f4560212eb69", "colab": { "base_uri": "https://localhost:8080/", "height": 206 } }, "source": [ "learn.fit_one_cycle(5, 5e-5, wd=0.1)" ], "execution_count": 55, "outputs": [ { "output_type": "display_data", "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_losstime
01.8947371.82987200:18
11.8073731.75565900:18
21.6845631.69813400:18
31.6735231.66849100:18
41.6642411.66346300:18
" ], "text/plain": [ "" ] }, "metadata": { "tags": [] } } ] }, { "cell_type": "code", "metadata": { "id": "_Omm2PYTfIvz", "outputId": "fdcd73d5-6a79-4df8-9f59-c762909095cd", "colab": { "base_uri": "https://localhost:8080/", "height": 34 } }, "source": [ "dls.classes['title'].o2i['2 Days in the Valley (1996)']" ], "execution_count": 56, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "6" ] }, "metadata": { "tags": [] }, "execution_count": 56 } ] }, { "cell_type": "code", "metadata": { "id": "buDN2R2Qf_mE", "outputId": "ac686ec7-63db-43d0-c0d9-49d99edb8921", "colab": { "base_uri": "https://localhost:8080/", "height": 34 } }, "source": [ "(torch.nn.CosineSimilarity(dim=1)(learn.model.i_weight.weight[6][None], learn.model.i_weight.weight)).argsort(descending=True)[1]" ], "execution_count": 57, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "tensor(37, device='cuda:0')" ] }, "metadata": { "tags": [] }, "execution_count": 57 } ] }, { "cell_type": "code", "metadata": { "id": "DZl4FVkYgtK8", "outputId": "0291cda3-c609-4ba1-c577-3e2eb955ac45", "colab": { "base_uri": "https://localhost:8080/", "height": 35 } }, "source": [ "dls.classes['title'][686]" ], "execution_count": 58, "outputs": [ { "output_type": "execute_result", "data": { "application/vnd.google.colaboratory.intrinsic+json": { "type": "string" }, "text/plain": [ "'Highlander (1986)'" ] }, "metadata": { "tags": [] }, "execution_count": 58 } ] }, { "cell_type": "markdown", "metadata": { "id": "1NyNymNplOn2" }, "source": [ "## Using deep learning" ] }, { "cell_type": "code", "metadata": { "id": "Qtj1sDoWktPs" }, "source": [ "class CollabNN(c.Module):\n", " def __init__(self, user_sz, item_sz, n_act=100, y_range=(0, 5.5)):\n", " super().__init__()\n", " self.user_factors = c.Embedding(*user_sz)\n", " self.movie_factors = c.Embedding(*item_sz)\n", " self.layers = torch.nn.Sequential(\n", " torch.nn.Linear(user_sz[1]+item_sz[1], n_act),\n", " torch.nn.ReLU(),\n", " torch.nn.Linear(n_act, 1)\n", " )\n", "\n", " self.y_range = y_range\n", "\n", " def forward(self, x):\n", " embs = self.user_factors(x[:, 0]), self.movie_factors(x[:, 1])\n", " x = self.layers(torch.cat(embs, dim=1))\n", "\n", " return c.sigmoid_range(x, *self.y_range)" ], "execution_count": 59, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "Xrj70IfSm04W", "outputId": "1dc3be0a-c23f-49be-a635-ef3cd64a5bba", "colab": { "base_uri": "https://localhost:8080/", "height": 34 } }, "source": [ "sz = c.get_emb_sz(dls)\n", "sz" ], "execution_count": 60, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "[(944, 74), (1665, 102)]" ] }, "metadata": { "tags": [] }, "execution_count": 60 } ] }, { "cell_type": "code", "metadata": { "id": "JUe26GT4m78Z" }, "source": [ "model = CollabNN(*sz)" ], "execution_count": 61, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "oSFfmWHJnLKX", "outputId": "93dc3396-1890-41e8-a703-81b13f363480", "colab": { "base_uri": "https://localhost:8080/", "height": 206 } }, "source": [ "learn = c.Learner(dls, model, c.MSELossFlat())\n", "learn.fit_one_cycle(5, 5e-3, wd=0.01)" ], "execution_count": 62, "outputs": [ { "output_type": "display_data", "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_losstime
00.9423690.96760500:20
10.9026680.91284500:20
20.8686560.88401600:20
30.8418390.87257500:20
40.7640200.87337900:20
" ], "text/plain": [ "" ] }, "metadata": { "tags": [] } } ] }, { "cell_type": "code", "metadata": { "id": "VAYV4lo6qGjV", "outputId": "5bea2e98-95be-4343-8918-6137a3dc4481", "colab": { "base_uri": "https://localhost:8080/", "height": 206 } }, "source": [ "learn = c.collab_learner(dls, use_nn=True, layers=[10, 50], y_range=(0, 5.5))\n", "learn.fit_one_cycle(5, 5e-3, wd=0.1)" ], "execution_count": 63, "outputs": [ { "output_type": "display_data", "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_losstime
00.9786040.97992200:23
10.9583720.93092700:24
20.9000480.90401900:23
30.8773670.87600200:23
40.8165850.87196600:24
" ], "text/plain": [ "" ] }, "metadata": { "tags": [] } } ] }, { "cell_type": "code", "metadata": { "id": "DttVouzZqwY1", "outputId": "855f0e14-7deb-4701-99cd-6de2d73d4efa", "colab": { "base_uri": "https://localhost:8080/", "height": 434 } }, "source": [ "learn.model" ], "execution_count": 64, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "EmbeddingNN(\n", " (embeds): ModuleList(\n", " (0): Embedding(944, 74)\n", " (1): Embedding(1665, 102)\n", " )\n", " (emb_drop): Dropout(p=0.0, inplace=False)\n", " (bn_cont): BatchNorm1d(0, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (layers): Sequential(\n", " (0): LinBnDrop(\n", " (0): BatchNorm1d(176, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (1): Linear(in_features=176, out_features=10, bias=False)\n", " (2): ReLU(inplace=True)\n", " )\n", " (1): LinBnDrop(\n", " (0): BatchNorm1d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (1): Linear(in_features=10, out_features=50, bias=False)\n", " (2): ReLU(inplace=True)\n", " )\n", " (2): LinBnDrop(\n", " (0): Linear(in_features=50, out_features=1, bias=True)\n", " )\n", " (3): SigmoidRange(low=0, high=5.5)\n", " )\n", ")" ] }, "metadata": { "tags": [] }, "execution_count": 64 } ] }, { "cell_type": "code", "metadata": { "id": "5P_LC2p0rW9J", "outputId": "c98e77a9-f757-4ad5-f321-c883e4dd4e45", "colab": { "base_uri": "https://localhost:8080/", "height": 173 } }, "source": [ "CollabNN((944, 74), (1665, 102))" ], "execution_count": 65, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "CollabNN(\n", " (user_factors): Embedding(944, 74)\n", " (movie_factors): Embedding(1665, 102)\n", " (layers): Sequential(\n", " (0): Linear(in_features=176, out_features=100, bias=True)\n", " (1): ReLU()\n", " (2): Linear(in_features=100, out_features=1, bias=True)\n", " )\n", ")" ] }, "metadata": { "tags": [] }, "execution_count": 65 } ] }, { "cell_type": "code", "metadata": { "id": "arUEK0RbvuI8", "outputId": "2c3f11db-f504-44fd-bd46-8b70278d99b2", "colab": { "base_uri": "https://localhost:8080/", "height": 34 } }, "source": [ "t.delegates" ], "execution_count": 66, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "" ] }, "metadata": { "tags": [] }, "execution_count": 66 } ] }, { "cell_type": "code", "metadata": { "id": "WKWbh8YKsArO" }, "source": [ "@t.delegates(t.TabularModel)\n", "class EmbeddingNN(t.TabularModel):\n", " def __init__(self, emb_sz, layers, **kwargs):\n", " super().__init__(emb_sz, layers=layers, n_cont=0, out_sz=1, **kwargs)" ], "execution_count": 67, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "Y22fCzLzuIAt" }, "source": [ "### kwargs && args" ] }, { "cell_type": "markdown", "metadata": { "id": "Cp4Rd-D7wy1I" }, "source": [ "kwargs" ] }, { "cell_type": "code", "metadata": { "id": "xQnyy_I0s3MR" }, "source": [ "def s(a, **kwargs): \n", " return kwargs" ], "execution_count": 68, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "wRH7RrBjtIsJ", "outputId": "d272d756-cd43-4d55-f6cf-445ca1b6ae9c", "colab": { "base_uri": "https://localhost:8080/", "height": 34 } }, "source": [ "s(2, s='dhdj', j=3)" ], "execution_count": 69, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "{'j': 3, 's': 'dhdj'}" ] }, "metadata": { "tags": [] }, "execution_count": 69 } ] }, { "cell_type": "markdown", "metadata": { "id": "T4C4MTfhwxYJ" }, "source": [ "args" ] }, { "cell_type": "code", "metadata": { "id": "hFbqzi1Et_Yb" }, "source": [ "def a(b, *size): return b, size" ], "execution_count": 70, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "F99TCHZVwjNz", "outputId": "43bef199-62e5-47bf-9176-279bed4d6b0b", "colab": { "base_uri": "https://localhost:8080/", "height": 34 } }, "source": [ "a(2, 3, 'klk', 5)" ], "execution_count": 71, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(2, (3, 'klk', 5))" ] }, "metadata": { "tags": [] }, "execution_count": 71 } ] }, { "cell_type": "code", "metadata": { "id": "voovs7v5yLfn", "outputId": "d5fd900c-1807-4de2-daf4-879f91e572ca", "colab": { "base_uri": "https://localhost:8080/", "height": 332 } }, "source": [ "learn.show_results()" ], "execution_count": 72, "outputs": [ { "output_type": "display_data", "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": { "tags": [] } }, { "output_type": "display_data", "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
usertitleratingrating_pred
0621136624.168081
1577149843.707471
284692044.199018
379378522.995941
433640733.256781
5394154454.228620
6345133644.130409
7815158144.287453
89955244.037339
" ], "text/plain": [ "" ] }, "metadata": { "tags": [] } } ] }, { "cell_type": "markdown", "metadata": { "id": "Fpl_W8kG_It8" }, "source": [ "## Using cross entropy loss" ] }, { "cell_type": "code", "metadata": { "id": "P6_WKrdT-WnY" }, "source": [ "df1 = df.copy()\n", "df1['rating'] = df1['rating']-1" ], "execution_count": 73, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "FA9yUSor-VgU", "outputId": "c5c34ea5-8333-450f-dec7-5d7e80d6290e", "colab": { "base_uri": "https://localhost:8080/", "height": 363 } }, "source": [ "dls = c.CollabDataLoaders.from_df(\n", " df1,\n", " user_name='user',\n", " item_name='title',\n", " rating_name='rating',\n", " bs=32\n", ")\n", "\n", "dls.show_batch()" ], "execution_count": 74, "outputs": [ { "output_type": "display_data", "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
usertitlerating
0891Courage Under Fire (1996)4
1806Raiders of the Lost Ark (1981)4
2790Romeo Is Bleeding (1993)1
3311Last of the Mohicans, The (1992)2
4790Emma (1996)1
5151Crimson Tide (1995)2
6489Devil's Advocate, The (1997)3
7314Pallbearer, The (1996)1
8411Rear Window (1954)4
9483Powder (1995)1
" ], "text/plain": [ "" ] }, "metadata": { "tags": [] } } ] }, { "cell_type": "code", "metadata": { "id": "8Aa0_9dL0yhn" }, "source": [ "class CollabNN(c.Module):\n", " def __init__(self, user_sz, item_sz, n_act=100):\n", " super().__init__()\n", " self.user_factors = c.Embedding(*user_sz)\n", " self.movie_factors = c.Embedding(*item_sz)\n", " self.layers = torch.nn.Sequential(\n", " torch.nn.Linear(user_sz[1]+item_sz[1], n_act),\n", " torch.nn.ReLU(),\n", " torch.nn.Linear(n_act, 5)\n", " )\n", "\n", " def forward(self, x):\n", " embs = self.user_factors(x[:, 0]), self.movie_factors(x[:, 1])\n", " x = self.layers(torch.cat(embs, dim=1))\n", "\n", " return x" ], "execution_count": 75, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "1zl2oS8607PD" }, "source": [ "model = CollabNN(*sz)\n", "learn = c.Learner(dls, model, loss_func=c.CrossEntropyLossFlat(), metrics=c.accuracy)" ], "execution_count": 76, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "_4TX0Hvt0_EU", "outputId": "39180f1e-4a9c-4471-b6c0-3cfacf32e1d6", "colab": { "base_uri": "https://localhost:8080/", "height": 206 } }, "source": [ "learn.fit_one_cycle(5, 5e-3, wd=0.01)" ], "execution_count": 77, "outputs": [ { "output_type": "display_data", "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracytime
01.2848721.2895790.41850000:20
11.2488811.2641220.44035000:20
21.2150941.2325780.44715000:20
31.1624521.2267600.45610000:20
41.0912691.2348710.45425000:20
" ], "text/plain": [ "" ] }, "metadata": { "tags": [] } } ] } ] }