{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "fastai_collab.ipynb",
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyN1v8T4yqV4jCEYOlmXCsEI",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"
"
]
},
{
"cell_type": "code",
"metadata": {
"id": "12UvxQcEEo4o"
},
"source": [
"!pip install fastai --upgrade"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "neBLmF00EsVi"
},
"source": [
"import fastai.collab as c\n",
"import fastai.tabular.all as t\n",
"import torch\n",
"\n",
"import pandas as pd\n",
"import os\n",
"os.environ['CUDA_LAUNCH_BLOCKING'] = \"1\""
],
"execution_count": 25,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "pZHSdM3ZE3XT",
"outputId": "e5ea3cdf-5399-4219-9a42-d0b5eb451d02",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 54
}
},
"source": [
"path = c.untar_data(c.URLs.ML_100k)\n",
"path.ls()"
],
"execution_count": 26,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(#23) [Path('/root/.fastai/data/ml-100k/ub.base'),Path('/root/.fastai/data/ml-100k/u5.base'),Path('/root/.fastai/data/ml-100k/u5.test'),Path('/root/.fastai/data/ml-100k/u.info'),Path('/root/.fastai/data/ml-100k/u2.base'),Path('/root/.fastai/data/ml-100k/u3.base'),Path('/root/.fastai/data/ml-100k/ua.test'),Path('/root/.fastai/data/ml-100k/u4.test'),Path('/root/.fastai/data/ml-100k/mku.sh'),Path('/root/.fastai/data/ml-100k/u.genre')...]"
]
},
"metadata": {
"tags": []
},
"execution_count": 26
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "ldPlun7NFBqw",
"outputId": "99e6eadd-d2dd-4e69-9d49-da0d5ab40f1b",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 206
}
},
"source": [
"df = pd.read_csv(path/'u.data', delimiter='\\t', names=['user', 'movie', 'rating', 'timestamp'])\n",
"df.head()"
],
"execution_count": 27,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"
\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" user | \n",
" movie | \n",
" rating | \n",
" timestamp | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 196 | \n",
" 242 | \n",
" 3 | \n",
" 881250949 | \n",
"
\n",
" \n",
" | 1 | \n",
" 186 | \n",
" 302 | \n",
" 3 | \n",
" 891717742 | \n",
"
\n",
" \n",
" | 2 | \n",
" 22 | \n",
" 377 | \n",
" 1 | \n",
" 878887116 | \n",
"
\n",
" \n",
" | 3 | \n",
" 244 | \n",
" 51 | \n",
" 2 | \n",
" 880606923 | \n",
"
\n",
" \n",
" | 4 | \n",
" 166 | \n",
" 346 | \n",
" 1 | \n",
" 886397596 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" user movie rating timestamp\n",
"0 196 242 3 881250949\n",
"1 186 302 3 891717742\n",
"2 22 377 1 878887116\n",
"3 244 51 2 880606923\n",
"4 166 346 1 886397596"
]
},
"metadata": {
"tags": []
},
"execution_count": 27
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "SnFX7W8NF3K2",
"outputId": "f79c3180-7c71-42db-c167-66f8183e9910",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 225
}
},
"source": [
"df.info(), df.shape"
],
"execution_count": 28,
"outputs": [
{
"output_type": "stream",
"text": [
"\n",
"RangeIndex: 100000 entries, 0 to 99999\n",
"Data columns (total 4 columns):\n",
" # Column Non-Null Count Dtype\n",
"--- ------ -------------- -----\n",
" 0 user 100000 non-null int64\n",
" 1 movie 100000 non-null int64\n",
" 2 rating 100000 non-null int64\n",
" 3 timestamp 100000 non-null int64\n",
"dtypes: int64(4)\n",
"memory usage: 3.1 MB\n"
],
"name": "stdout"
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(None, (100000, 4))"
]
},
"metadata": {
"tags": []
},
"execution_count": 28
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "LHL4rnnPGf2E",
"outputId": "a323bdaa-2c89-49ac-8816-f749451c6294",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 206
}
},
"source": [
"movies = pd.read_csv(path/'u.item', delimiter='|', names=['movie', 'title'], encoding='latin-1', usecols=(0, 1))\n",
"movies.head()"
],
"execution_count": 29,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" movie | \n",
" title | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 1 | \n",
" Toy Story (1995) | \n",
"
\n",
" \n",
" | 1 | \n",
" 2 | \n",
" GoldenEye (1995) | \n",
"
\n",
" \n",
" | 2 | \n",
" 3 | \n",
" Four Rooms (1995) | \n",
"
\n",
" \n",
" | 3 | \n",
" 4 | \n",
" Get Shorty (1995) | \n",
"
\n",
" \n",
" | 4 | \n",
" 5 | \n",
" Copycat (1995) | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" movie title\n",
"0 1 Toy Story (1995)\n",
"1 2 GoldenEye (1995)\n",
"2 3 Four Rooms (1995)\n",
"3 4 Get Shorty (1995)\n",
"4 5 Copycat (1995)"
]
},
"metadata": {
"tags": []
},
"execution_count": 29
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "pL6bEndiIzyT",
"outputId": "026ecd26-0e1c-4854-b968-0b1ff5d21773",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 206
}
},
"source": [
"df = df.merge(movies)\n",
"df.head()"
],
"execution_count": 30,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" user | \n",
" movie | \n",
" rating | \n",
" timestamp | \n",
" title | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 196 | \n",
" 242 | \n",
" 3 | \n",
" 881250949 | \n",
" Kolya (1996) | \n",
"
\n",
" \n",
" | 1 | \n",
" 63 | \n",
" 242 | \n",
" 3 | \n",
" 875747190 | \n",
" Kolya (1996) | \n",
"
\n",
" \n",
" | 2 | \n",
" 226 | \n",
" 242 | \n",
" 5 | \n",
" 883888671 | \n",
" Kolya (1996) | \n",
"
\n",
" \n",
" | 3 | \n",
" 154 | \n",
" 242 | \n",
" 3 | \n",
" 879138235 | \n",
" Kolya (1996) | \n",
"
\n",
" \n",
" | 4 | \n",
" 306 | \n",
" 242 | \n",
" 5 | \n",
" 876503793 | \n",
" Kolya (1996) | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" user movie rating timestamp title\n",
"0 196 242 3 881250949 Kolya (1996)\n",
"1 63 242 3 875747190 Kolya (1996)\n",
"2 226 242 5 883888671 Kolya (1996)\n",
"3 154 242 3 879138235 Kolya (1996)\n",
"4 306 242 5 876503793 Kolya (1996)"
]
},
"metadata": {
"tags": []
},
"execution_count": 30
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "YXShxk83Jmnm",
"outputId": "19e3c40e-aa0a-4cb9-8f28-446988ea5ee7",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 363
}
},
"source": [
"dls = c.CollabDataLoaders.from_df(\n",
" df,\n",
" user_name='user',\n",
" item_name='title',\n",
" rating_name='rating',\n",
" bs=32\n",
")\n",
"\n",
"dls.show_batch()"
],
"execution_count": 31,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"\n",
" \n",
" \n",
" | \n",
" user | \n",
" title | \n",
" rating | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 303 | \n",
" Shall We Dance? (1937) | \n",
" 4 | \n",
"
\n",
" \n",
" | 1 | \n",
" 880 | \n",
" Get Shorty (1995) | \n",
" 4 | \n",
"
\n",
" \n",
" | 2 | \n",
" 355 | \n",
" Kolya (1996) | \n",
" 4 | \n",
"
\n",
" \n",
" | 3 | \n",
" 334 | \n",
" Sound of Music, The (1965) | \n",
" 2 | \n",
"
\n",
" \n",
" | 4 | \n",
" 826 | \n",
" Striking Distance (1993) | \n",
" 3 | \n",
"
\n",
" \n",
" | 5 | \n",
" 942 | \n",
" It Happened One Night (1934) | \n",
" 4 | \n",
"
\n",
" \n",
" | 6 | \n",
" 289 | \n",
" Time to Kill, A (1996) | \n",
" 3 | \n",
"
\n",
" \n",
" | 7 | \n",
" 405 | \n",
" My Left Foot (1989) | \n",
" 1 | \n",
"
\n",
" \n",
" | 8 | \n",
" 452 | \n",
" Fantasia (1940) | \n",
" 2 | \n",
"
\n",
" \n",
" | 9 | \n",
" 747 | \n",
" Thirty-Two Short Films About Glenn Gould (1993) | \n",
" 3 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {
"tags": []
}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "NTOLvhrqKgs0",
"outputId": "7cedb26d-ecb2-44f6-bb0a-f436a752ef96",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 72
}
},
"source": [
"dls.classes"
],
"execution_count": 32,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{'title': (#1665) ['#na#',\"'Til There Was You (1997)\",'1-900 (1994)','101 Dalmatians (1996)','12 Angry Men (1957)','187 (1997)','2 Days in the Valley (1996)','20,000 Leagues Under the Sea (1954)','2001: A Space Odyssey (1968)','3 Ninjas: High Noon At Mega Mountain (1998)'...],\n",
" 'user': (#944) ['#na#',1,2,3,4,5,6,7,8,9...]}"
]
},
"metadata": {
"tags": []
},
"execution_count": 32
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "_lneJ6HaKqZR"
},
"source": [
"n_users = len(dls.classes['user'])\n",
"n_mov = len(dls.classes['title'])\n",
"n_factors = 5\n",
"\n",
"user_factors = torch.randn(n_users, n_factors)\n",
"mov_factors = torch.randn(n_mov, n_factors)"
],
"execution_count": 33,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "BZRvoW1eLL02",
"outputId": "999f3bdc-f7fc-4ba0-e4ed-fa8588018a99",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"user_factors.shape, mov_factors.shape"
],
"execution_count": 34,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(torch.Size([944, 5]), torch.Size([1665, 5]))"
]
},
"metadata": {
"tags": []
},
"execution_count": 34
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "SDCVJkEkLXBX",
"outputId": "9a482c2a-9b4c-4889-c82d-f9638a7df559",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"user_factors.t()@t.one_hot(3, n_users).float() "
],
"execution_count": 35,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([-0.6902, 0.9260, -1.2527, 0.5396, 0.1056])"
]
},
"metadata": {
"tags": []
},
"execution_count": 35
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "qaAy_vwAMQGG",
"outputId": "e566925f-4ba7-415f-e29f-1023f4b31c10",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"user_factors[3]"
],
"execution_count": 36,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([-0.6902, 0.9260, -1.2527, 0.5396, 0.1056])"
]
},
"metadata": {
"tags": []
},
"execution_count": 36
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cz5KC59ASGch"
},
"source": [
"#Base line model"
]
},
{
"cell_type": "code",
"metadata": {
"id": "7XsV_cbKM0Ez"
},
"source": [
"class DotProduct(torch.nn.Module):\n",
" def __init__(self, n_users, n_movies, n_factors):\n",
" super().__init__()\n",
" self.user_factors = c.Embedding(n_users, n_factors)\n",
" self.movie_factors = c.Embedding(n_movies, n_factors)\n",
"\n",
" def forward(self, x):\n",
" users = self.user_factors(x[:, 0])\n",
" movies = self.movie_factors(x[:, 1])\n",
"\n",
" return (users*movies).sum(dim=1)"
],
"execution_count": 37,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "WLLPAeKZPanH",
"outputId": "d1319df4-c0f9-4c46-a226-558297ee7177",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"dls.one_batch()[0].shape, dls.one_batch()[1].shape"
],
"execution_count": 38,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(torch.Size([32, 2]), torch.Size([32, 1]))"
]
},
"metadata": {
"tags": []
},
"execution_count": 38
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "ZnGFkHQ6Pofs"
},
"source": [
"model = DotProduct(n_users, n_mov, 50)\n",
"learn = c.Learner(dls, model, c.MSELossFlat())"
],
"execution_count": 39,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "-UM8TmTOQsxV",
"outputId": "bff9753b-57a4-4160-ef09-34bca03e11cb",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 206
}
},
"source": [
"learn.fit_one_cycle(5, 5e-3)"
],
"execution_count": 40,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"\n",
" \n",
" \n",
" | epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 1.280856 | \n",
" 1.324047 | \n",
" 00:16 | \n",
"
\n",
" \n",
" | 1 | \n",
" 1.139197 | \n",
" 1.145741 | \n",
" 00:16 | \n",
"
\n",
" \n",
" | 2 | \n",
" 0.951527 | \n",
" 1.011723 | \n",
" 00:16 | \n",
"
\n",
" \n",
" | 3 | \n",
" 0.780049 | \n",
" 0.897130 | \n",
" 00:16 | \n",
"
\n",
" \n",
" | 4 | \n",
" 0.747144 | \n",
" 0.874317 | \n",
" 00:16 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {
"tags": []
}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "APF-yfBMSeWG"
},
"source": [
"# Adding sigmoid range"
]
},
{
"cell_type": "code",
"metadata": {
"id": "Z3LOXdCcRHgu"
},
"source": [
"class DotProduct(torch.nn.Module):\n",
" def __init__(self, n_users, n_movies, n_factors, y_range=(0, 5.5)):\n",
" super().__init__()\n",
" self.user_factors = c.Embedding(n_users, n_factors)\n",
" self.movie_factors = c.Embedding(n_movies, n_factors)\n",
" self.y_range = y_range\n",
"\n",
" def forward(self, x):\n",
" users = self.user_factors(x[:, 0])\n",
" movies = self.movie_factors(x[:, 1])\n",
"\n",
" return c.sigmoid_range((users*movies).sum(dim=1), *self.y_range)"
],
"execution_count": 41,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "SEpfIhH2Qzef",
"outputId": "23177f6d-af02-4827-c628-d1d90afe3300",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 206
}
},
"source": [
"model = DotProduct(n_users, n_mov, 50)\n",
"learn = c.Learner(dls, model, c.MSELossFlat())\n",
"learn.fit_one_cycle(5, 5e-3)"
],
"execution_count": 42,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"\n",
" \n",
" \n",
" | epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 1.038298 | \n",
" 1.009943 | \n",
" 00:16 | \n",
"
\n",
" \n",
" | 1 | \n",
" 0.880899 | \n",
" 0.922624 | \n",
" 00:16 | \n",
"
\n",
" \n",
" | 2 | \n",
" 0.685187 | \n",
" 0.892175 | \n",
" 00:16 | \n",
"
\n",
" \n",
" | 3 | \n",
" 0.453028 | \n",
" 0.901179 | \n",
" 00:16 | \n",
"
\n",
" \n",
" | 4 | \n",
" 0.308590 | \n",
" 0.908446 | \n",
" 00:16 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {
"tags": []
}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DaC8i0nVSNYU"
},
"source": [
"#Adding bias parameter"
]
},
{
"cell_type": "code",
"metadata": {
"id": "PH2yzzsPSk87"
},
"source": [
"class DotProduct(torch.nn.Module):\n",
" def __init__(self, n_users, n_movies, n_factors, y_range=(0, 5.5)):\n",
" super().__init__()\n",
" self.user_factors = c.Embedding(n_users, n_factors)\n",
" self.user_bias = c.Embedding(n_users, 1)\n",
"\n",
" self.movie_factors = c.Embedding(n_movies, n_factors)\n",
" self.movie_bias = c.Embedding(n_movies, 1)\n",
"\n",
" self.y_range = y_range\n",
"\n",
" def forward(self, x):\n",
" users = self.user_factors(x[:, 0])\n",
" movies = self.movie_factors(x[:, 1])\n",
"\n",
" out = (users*movies).sum(dim=1, keepdim=True)\n",
" out += self.user_bias(x[:, 0]) + self.movie_bias(x[:, 1])\n",
"\n",
" return c.sigmoid_range(out, *self.y_range)"
],
"execution_count": 43,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "FeTFDuhBSk9O",
"outputId": "99da5355-1941-4e0c-81df-7b5d90888b26",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 206
}
},
"source": [
"model = DotProduct(n_users, n_mov, 50)\n",
"learn = c.Learner(dls, model, c.MSELossFlat())\n",
"learn.fit_one_cycle(5, 5e-3)"
],
"execution_count": 44,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"\n",
" \n",
" \n",
" | epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 0.981148 | \n",
" 0.945925 | \n",
" 00:18 | \n",
"
\n",
" \n",
" | 1 | \n",
" 0.842618 | \n",
" 0.882195 | \n",
" 00:18 | \n",
"
\n",
" \n",
" | 2 | \n",
" 0.604646 | \n",
" 0.913617 | \n",
" 00:18 | \n",
"
\n",
" \n",
" | 3 | \n",
" 0.407241 | \n",
" 0.941966 | \n",
" 00:18 | \n",
"
\n",
" \n",
" | 4 | \n",
" 0.237349 | \n",
" 0.950696 | \n",
" 00:18 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {
"tags": []
}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "rpLDE7RwSMvZ",
"outputId": "8783ea29-ce75-401d-b01d-8ea78964025d",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 206
}
},
"source": [
"model = DotProduct(n_users, n_mov, 50)\n",
"learn = c.Learner(dls, model, c.MSELossFlat())\n",
"learn.fit_one_cycle(5, 5e-3, wd=0.1)"
],
"execution_count": 45,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"\n",
" \n",
" \n",
" | epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 0.979508 | \n",
" 0.941654 | \n",
" 00:18 | \n",
"
\n",
" \n",
" | 1 | \n",
" 0.858030 | \n",
" 0.898452 | \n",
" 00:18 | \n",
"
\n",
" \n",
" | 2 | \n",
" 0.798919 | \n",
" 0.857021 | \n",
" 00:18 | \n",
"
\n",
" \n",
" | 3 | \n",
" 0.672152 | \n",
" 0.827143 | \n",
" 00:18 | \n",
"
\n",
" \n",
" | 4 | \n",
" 0.545060 | \n",
" 0.824844 | \n",
" 00:18 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {
"tags": []
}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bVaSDYAaWP3z"
},
"source": [
"## Embedding layer from scratch"
]
},
{
"cell_type": "code",
"metadata": {
"id": "RTnq89LkVh4w"
},
"source": [
"def create_params(size): return torch.nn.Parameter(torch.zeros(*size).normal_(0, 0.01))"
],
"execution_count": 46,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Wg_XfhQ3WkXv",
"outputId": "81e90c5f-9045-4403-ff01-8a22323087a8",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 86
}
},
"source": [
"create_params((3, 4))"
],
"execution_count": 47,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"Parameter containing:\n",
"tensor([[ 0.0043, -0.0059, -0.0025, -0.0032],\n",
" [ 0.0016, -0.0125, 0.0049, 0.0115],\n",
" [ 0.0234, 0.0013, -0.0026, -0.0069]], requires_grad=True)"
]
},
"metadata": {
"tags": []
},
"execution_count": 47
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "1CYDjgIGWmu3"
},
"source": [
"class DotProduct(torch.nn.Module):\n",
" def __init__(self, n_users, n_movies, n_factors, y_range=(0, 5.5)):\n",
" super().__init__()\n",
" self.user_factors = create_params((n_users, n_factors))\n",
" self.user_bias = create_params((n_users, 1))\n",
"\n",
" self.movie_factors = create_params((n_movies, n_factors))\n",
" self.movie_bias = create_params((n_movies, 1))\n",
"\n",
" self.y_range = y_range\n",
"\n",
" def forward(self, x):\n",
" users = self.user_factors[x[:, 0]]\n",
" movies = self.movie_factors[x[:, 1]]\n",
"\n",
" out = (users*movies).sum(dim=1, keepdim=True)\n",
" out += self.user_bias[x[:, 0]] + self.movie_bias[x[:, 1]]\n",
"\n",
" return c.sigmoid_range(out, *self.y_range)"
],
"execution_count": 48,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "ivmygidxW8LV",
"outputId": "c1c399a1-04f7-46e9-83b4-ddb75d8c09d3",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 86
}
},
"source": [
"for p in DotProduct(3, 4, 5).parameters():\n",
" print(p.shape)"
],
"execution_count": 49,
"outputs": [
{
"output_type": "stream",
"text": [
"torch.Size([3, 5])\n",
"torch.Size([3, 1])\n",
"torch.Size([4, 5])\n",
"torch.Size([4, 1])\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "_X1lQzkmXCSX",
"outputId": "387a8399-a372-4cf3-afc3-2250c8064d12",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 206
}
},
"source": [
"model = DotProduct(n_users, n_mov, 50)\n",
"learn = c.Learner(dls, model, c.MSELossFlat())\n",
"learn.fit_one_cycle(5, 5e-3, wd=0.1)"
],
"execution_count": 50,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"\n",
" \n",
" \n",
" | epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 0.934559 | \n",
" 0.959005 | \n",
" 00:19 | \n",
"
\n",
" \n",
" | 1 | \n",
" 0.847594 | \n",
" 0.889536 | \n",
" 00:19 | \n",
"
\n",
" \n",
" | 2 | \n",
" 0.797565 | \n",
" 0.851402 | \n",
" 00:19 | \n",
"
\n",
" \n",
" | 3 | \n",
" 0.709136 | \n",
" 0.822152 | \n",
" 00:19 | \n",
"
\n",
" \n",
" | 4 | \n",
" 0.585161 | \n",
" 0.821672 | \n",
" 00:19 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {
"tags": []
}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "ZzbSdATHYBi8",
"outputId": "8dcb5ca1-06e6-4cba-b20b-04b6aadc0894",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 104
}
},
"source": [
"b = learn.model.movie_bias.squeeze()\n",
"idx = b.argsort()[:5]\n",
"\n",
"[dls.classes['title'][i] for i in idx]"
],
"execution_count": 51,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"['Children of the Corn: The Gathering (1996)',\n",
" 'Mortal Kombat: Annihilation (1997)',\n",
" 'Crow: City of Angels, The (1996)',\n",
" 'Island of Dr. Moreau, The (1996)',\n",
" 'Lawnmower Man 2: Beyond Cyberspace (1996)']"
]
},
"metadata": {
"tags": []
},
"execution_count": 51
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "WLnDvYYgaFKn",
"outputId": "697abb6f-30d5-412e-add3-59c336225c9c",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 104
}
},
"source": [
"b = learn.model.movie_bias.squeeze()\n",
"idx = b.argsort(descending=True)[:5]\n",
"\n",
"[dls.classes['title'][i] for i in idx]"
],
"execution_count": 52,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"['Titanic (1997)',\n",
" \"Schindler's List (1993)\",\n",
" 'Star Wars (1977)',\n",
" 'Silence of the Lambs, The (1991)',\n",
" 'L.A. Confidential (1997)']"
]
},
"metadata": {
"tags": []
},
"execution_count": 52
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6yyEP2jFeusk"
},
"source": [
"# Training using fastai collab_learner"
]
},
{
"cell_type": "code",
"metadata": {
"id": "sDYjMvOccvcH"
},
"source": [
"learn = c.collab_learner(dls, n_factors=50, y_range=(0, 5.5))"
],
"execution_count": 53,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "0e489As1e9Th",
"outputId": "873671c4-34d7-460b-d460-6360169fb5e8",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 121
}
},
"source": [
"learn.model, learn.loss_func"
],
"execution_count": 54,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(EmbeddingDotBias(\n",
" (u_weight): Embedding(944, 50)\n",
" (i_weight): Embedding(1665, 50)\n",
" (u_bias): Embedding(944, 1)\n",
" (i_bias): Embedding(1665, 1)\n",
" ), FlattenedLoss of MSELoss())"
]
},
"metadata": {
"tags": []
},
"execution_count": 54
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "ZIJroJerfATp",
"outputId": "c543e862-4d14-4ade-c0bf-f4560212eb69",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 206
}
},
"source": [
"learn.fit_one_cycle(5, 5e-5, wd=0.1)"
],
"execution_count": 55,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"\n",
" \n",
" \n",
" | epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 1.894737 | \n",
" 1.829872 | \n",
" 00:18 | \n",
"
\n",
" \n",
" | 1 | \n",
" 1.807373 | \n",
" 1.755659 | \n",
" 00:18 | \n",
"
\n",
" \n",
" | 2 | \n",
" 1.684563 | \n",
" 1.698134 | \n",
" 00:18 | \n",
"
\n",
" \n",
" | 3 | \n",
" 1.673523 | \n",
" 1.668491 | \n",
" 00:18 | \n",
"
\n",
" \n",
" | 4 | \n",
" 1.664241 | \n",
" 1.663463 | \n",
" 00:18 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {
"tags": []
}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "_Omm2PYTfIvz",
"outputId": "fdcd73d5-6a79-4df8-9f59-c762909095cd",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"dls.classes['title'].o2i['2 Days in the Valley (1996)']"
],
"execution_count": 56,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"6"
]
},
"metadata": {
"tags": []
},
"execution_count": 56
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "buDN2R2Qf_mE",
"outputId": "ac686ec7-63db-43d0-c0d9-49d99edb8921",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"(torch.nn.CosineSimilarity(dim=1)(learn.model.i_weight.weight[6][None], learn.model.i_weight.weight)).argsort(descending=True)[1]"
],
"execution_count": 57,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor(37, device='cuda:0')"
]
},
"metadata": {
"tags": []
},
"execution_count": 57
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "DZl4FVkYgtK8",
"outputId": "0291cda3-c609-4ba1-c577-3e2eb955ac45",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
}
},
"source": [
"dls.classes['title'][686]"
],
"execution_count": 58,
"outputs": [
{
"output_type": "execute_result",
"data": {
"application/vnd.google.colaboratory.intrinsic+json": {
"type": "string"
},
"text/plain": [
"'Highlander (1986)'"
]
},
"metadata": {
"tags": []
},
"execution_count": 58
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1NyNymNplOn2"
},
"source": [
"## Using deep learning"
]
},
{
"cell_type": "code",
"metadata": {
"id": "Qtj1sDoWktPs"
},
"source": [
"class CollabNN(c.Module):\n",
" def __init__(self, user_sz, item_sz, n_act=100, y_range=(0, 5.5)):\n",
" super().__init__()\n",
" self.user_factors = c.Embedding(*user_sz)\n",
" self.movie_factors = c.Embedding(*item_sz)\n",
" self.layers = torch.nn.Sequential(\n",
" torch.nn.Linear(user_sz[1]+item_sz[1], n_act),\n",
" torch.nn.ReLU(),\n",
" torch.nn.Linear(n_act, 1)\n",
" )\n",
"\n",
" self.y_range = y_range\n",
"\n",
" def forward(self, x):\n",
" embs = self.user_factors(x[:, 0]), self.movie_factors(x[:, 1])\n",
" x = self.layers(torch.cat(embs, dim=1))\n",
"\n",
" return c.sigmoid_range(x, *self.y_range)"
],
"execution_count": 59,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Xrj70IfSm04W",
"outputId": "1dc3be0a-c23f-49be-a635-ef3cd64a5bba",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"sz = c.get_emb_sz(dls)\n",
"sz"
],
"execution_count": 60,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"[(944, 74), (1665, 102)]"
]
},
"metadata": {
"tags": []
},
"execution_count": 60
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "JUe26GT4m78Z"
},
"source": [
"model = CollabNN(*sz)"
],
"execution_count": 61,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "oSFfmWHJnLKX",
"outputId": "93dc3396-1890-41e8-a703-81b13f363480",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 206
}
},
"source": [
"learn = c.Learner(dls, model, c.MSELossFlat())\n",
"learn.fit_one_cycle(5, 5e-3, wd=0.01)"
],
"execution_count": 62,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"\n",
" \n",
" \n",
" | epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 0.942369 | \n",
" 0.967605 | \n",
" 00:20 | \n",
"
\n",
" \n",
" | 1 | \n",
" 0.902668 | \n",
" 0.912845 | \n",
" 00:20 | \n",
"
\n",
" \n",
" | 2 | \n",
" 0.868656 | \n",
" 0.884016 | \n",
" 00:20 | \n",
"
\n",
" \n",
" | 3 | \n",
" 0.841839 | \n",
" 0.872575 | \n",
" 00:20 | \n",
"
\n",
" \n",
" | 4 | \n",
" 0.764020 | \n",
" 0.873379 | \n",
" 00:20 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {
"tags": []
}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "VAYV4lo6qGjV",
"outputId": "5bea2e98-95be-4343-8918-6137a3dc4481",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 206
}
},
"source": [
"learn = c.collab_learner(dls, use_nn=True, layers=[10, 50], y_range=(0, 5.5))\n",
"learn.fit_one_cycle(5, 5e-3, wd=0.1)"
],
"execution_count": 63,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"\n",
" \n",
" \n",
" | epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 0.978604 | \n",
" 0.979922 | \n",
" 00:23 | \n",
"
\n",
" \n",
" | 1 | \n",
" 0.958372 | \n",
" 0.930927 | \n",
" 00:24 | \n",
"
\n",
" \n",
" | 2 | \n",
" 0.900048 | \n",
" 0.904019 | \n",
" 00:23 | \n",
"
\n",
" \n",
" | 3 | \n",
" 0.877367 | \n",
" 0.876002 | \n",
" 00:23 | \n",
"
\n",
" \n",
" | 4 | \n",
" 0.816585 | \n",
" 0.871966 | \n",
" 00:24 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {
"tags": []
}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "DttVouzZqwY1",
"outputId": "855f0e14-7deb-4701-99cd-6de2d73d4efa",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 434
}
},
"source": [
"learn.model"
],
"execution_count": 64,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"EmbeddingNN(\n",
" (embeds): ModuleList(\n",
" (0): Embedding(944, 74)\n",
" (1): Embedding(1665, 102)\n",
" )\n",
" (emb_drop): Dropout(p=0.0, inplace=False)\n",
" (bn_cont): BatchNorm1d(0, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (layers): Sequential(\n",
" (0): LinBnDrop(\n",
" (0): BatchNorm1d(176, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (1): Linear(in_features=176, out_features=10, bias=False)\n",
" (2): ReLU(inplace=True)\n",
" )\n",
" (1): LinBnDrop(\n",
" (0): BatchNorm1d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (1): Linear(in_features=10, out_features=50, bias=False)\n",
" (2): ReLU(inplace=True)\n",
" )\n",
" (2): LinBnDrop(\n",
" (0): Linear(in_features=50, out_features=1, bias=True)\n",
" )\n",
" (3): SigmoidRange(low=0, high=5.5)\n",
" )\n",
")"
]
},
"metadata": {
"tags": []
},
"execution_count": 64
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "5P_LC2p0rW9J",
"outputId": "c98e77a9-f757-4ad5-f321-c883e4dd4e45",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 173
}
},
"source": [
"CollabNN((944, 74), (1665, 102))"
],
"execution_count": 65,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"CollabNN(\n",
" (user_factors): Embedding(944, 74)\n",
" (movie_factors): Embedding(1665, 102)\n",
" (layers): Sequential(\n",
" (0): Linear(in_features=176, out_features=100, bias=True)\n",
" (1): ReLU()\n",
" (2): Linear(in_features=100, out_features=1, bias=True)\n",
" )\n",
")"
]
},
"metadata": {
"tags": []
},
"execution_count": 65
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "arUEK0RbvuI8",
"outputId": "2c3f11db-f504-44fd-bd46-8b70278d99b2",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"t.delegates"
],
"execution_count": 66,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
""
]
},
"metadata": {
"tags": []
},
"execution_count": 66
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "WKWbh8YKsArO"
},
"source": [
"@t.delegates(t.TabularModel)\n",
"class EmbeddingNN(t.TabularModel):\n",
" def __init__(self, emb_sz, layers, **kwargs):\n",
" super().__init__(emb_sz, layers=layers, n_cont=0, out_sz=1, **kwargs)"
],
"execution_count": 67,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "Y22fCzLzuIAt"
},
"source": [
"### kwargs && args"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Cp4Rd-D7wy1I"
},
"source": [
"kwargs"
]
},
{
"cell_type": "code",
"metadata": {
"id": "xQnyy_I0s3MR"
},
"source": [
"def s(a, **kwargs): \n",
" return kwargs"
],
"execution_count": 68,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "wRH7RrBjtIsJ",
"outputId": "d272d756-cd43-4d55-f6cf-445ca1b6ae9c",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"s(2, s='dhdj', j=3)"
],
"execution_count": 69,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{'j': 3, 's': 'dhdj'}"
]
},
"metadata": {
"tags": []
},
"execution_count": 69
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "T4C4MTfhwxYJ"
},
"source": [
"args"
]
},
{
"cell_type": "code",
"metadata": {
"id": "hFbqzi1Et_Yb"
},
"source": [
"def a(b, *size): return b, size"
],
"execution_count": 70,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "F99TCHZVwjNz",
"outputId": "43bef199-62e5-47bf-9176-279bed4d6b0b",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"a(2, 3, 'klk', 5)"
],
"execution_count": 71,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(2, (3, 'klk', 5))"
]
},
"metadata": {
"tags": []
},
"execution_count": 71
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "voovs7v5yLfn",
"outputId": "d5fd900c-1807-4de2-daf4-879f91e572ca",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 332
}
},
"source": [
"learn.show_results()"
],
"execution_count": 72,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
""
],
"text/plain": [
""
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "display_data",
"data": {
"text/html": [
"\n",
" \n",
" \n",
" | \n",
" user | \n",
" title | \n",
" rating | \n",
" rating_pred | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 621 | \n",
" 1366 | \n",
" 2 | \n",
" 4.168081 | \n",
"
\n",
" \n",
" | 1 | \n",
" 577 | \n",
" 1498 | \n",
" 4 | \n",
" 3.707471 | \n",
"
\n",
" \n",
" | 2 | \n",
" 846 | \n",
" 920 | \n",
" 4 | \n",
" 4.199018 | \n",
"
\n",
" \n",
" | 3 | \n",
" 793 | \n",
" 785 | \n",
" 2 | \n",
" 2.995941 | \n",
"
\n",
" \n",
" | 4 | \n",
" 336 | \n",
" 407 | \n",
" 3 | \n",
" 3.256781 | \n",
"
\n",
" \n",
" | 5 | \n",
" 394 | \n",
" 1544 | \n",
" 5 | \n",
" 4.228620 | \n",
"
\n",
" \n",
" | 6 | \n",
" 345 | \n",
" 1336 | \n",
" 4 | \n",
" 4.130409 | \n",
"
\n",
" \n",
" | 7 | \n",
" 815 | \n",
" 1581 | \n",
" 4 | \n",
" 4.287453 | \n",
"
\n",
" \n",
" | 8 | \n",
" 99 | \n",
" 552 | \n",
" 4 | \n",
" 4.037339 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {
"tags": []
}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Fpl_W8kG_It8"
},
"source": [
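"## Using cross-entropy loss\n",
"\n",
"Here rating prediction is treated as classification: the ratings are shifted down by one so they can be used as class indices, and the model outputs one score per rating class (five in total)."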
]
},
{
"cell_type": "code",
"metadata": {
"id": "P6_WKrdT-WnY"
},
"source": [
"# Shift every rating down by one (presumably 1-5 becomes 0-4) so the values can be\n",
"# used as class indices by the cross-entropy loss; `df` itself is left untouched.\n",
"df1 = df.assign(rating=df['rating'] - 1)"
],
"execution_count": 73,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "FA9yUSor-VgU",
"outputId": "c5c34ea5-8333-450f-dec7-5d7e80d6290e",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 363
}
},
"source": [
"# Rebuild the collaborative-filtering loaders from `df1`, the frame holding the shifted ratings.\n",
"dls = c.CollabDataLoaders.from_df(\n",
"    df1, bs=32,\n",
"    user_name='user', item_name='title', rating_name='rating',\n",
")\n",
"dls.show_batch()"
],
"execution_count": 74,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"\n",
" \n",
" \n",
" | \n",
" user | \n",
" title | \n",
" rating | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 891 | \n",
" Courage Under Fire (1996) | \n",
" 4 | \n",
"
\n",
" \n",
" | 1 | \n",
" 806 | \n",
" Raiders of the Lost Ark (1981) | \n",
" 4 | \n",
"
\n",
" \n",
" | 2 | \n",
" 790 | \n",
" Romeo Is Bleeding (1993) | \n",
" 1 | \n",
"
\n",
" \n",
" | 3 | \n",
" 311 | \n",
" Last of the Mohicans, The (1992) | \n",
" 2 | \n",
"
\n",
" \n",
" | 4 | \n",
" 790 | \n",
" Emma (1996) | \n",
" 1 | \n",
"
\n",
" \n",
" | 5 | \n",
" 151 | \n",
" Crimson Tide (1995) | \n",
" 2 | \n",
"
\n",
" \n",
" | 6 | \n",
" 489 | \n",
" Devil's Advocate, The (1997) | \n",
" 3 | \n",
"
\n",
" \n",
" | 7 | \n",
" 314 | \n",
" Pallbearer, The (1996) | \n",
" 1 | \n",
"
\n",
" \n",
" | 8 | \n",
" 411 | \n",
" Rear Window (1954) | \n",
" 4 | \n",
"
\n",
" \n",
" | 9 | \n",
" 483 | \n",
" Powder (1995) | \n",
" 1 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {
"tags": []
}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "8Aa0_9dL0yhn"
},
"source": [
"class CollabNN(c.Module):\n",
"    \"\"\"Collaborative-filtering network that produces one score per rating class.\n",
"\n",
"    The user and item indices are embedded, the two embeddings are concatenated\n",
"    and passed through a small MLP with a single hidden layer.\n",
"\n",
"    user_sz, item_sz: pairs unpacked into `c.Embedding`; element 1 of each pair\n",
"        is used as that embedding's width when sizing the first linear layer.\n",
"    n_act: number of activations in the hidden layer.\n",
"    n_out: number of output scores (classes). Defaults to 5, the value that was\n",
"        previously hard-coded, so existing calls behave exactly as before.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, user_sz, item_sz, n_act=100, n_out=5):\n",
"        super().__init__()\n",
"        self.user_factors = c.Embedding(*user_sz)\n",
"        self.movie_factors = c.Embedding(*item_sz)\n",
"        self.layers = torch.nn.Sequential(\n",
"            torch.nn.Linear(user_sz[1] + item_sz[1], n_act),\n",
"            torch.nn.ReLU(),\n",
"            torch.nn.Linear(n_act, n_out),\n",
"        )\n",
"\n",
"    def forward(self, x):\n",
"        # Column 0 of the batch indexes the user embedding, column 1 the item embedding.\n",
"        user_emb = self.user_factors(x[:, 0])\n",
"        item_emb = self.movie_factors(x[:, 1])\n",
"        # Returns the raw outputs of the final linear layer (no activation is applied here).\n",
"        return self.layers(torch.cat((user_emb, item_emb), dim=1))"
],
"execution_count": 75,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "1zl2oS8607PD"
},
"source": [
"# NOTE(review): `sz` is not defined in this section - presumably the embedding sizes\n",
"# computed in an earlier cell; confirm they still match the rebuilt `dls`.\n",
"model = CollabNN(*sz)\n",
"\n",
"# Train the class scores with flattened cross-entropy and report accuracy.\n",
"learn = c.Learner(\n",
"    dls,\n",
"    model,\n",
"    loss_func=c.CrossEntropyLossFlat(),\n",
"    metrics=c.accuracy,\n",
")"
],
"execution_count": 76,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "_4TX0Hvt0_EU",
"outputId": "39180f1e-4a9c-4471-b6c0-3cfacf32e1d6",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 206
}
},
"source": [
"# Train for 5 epochs with the one-cycle schedule (maximum learning rate 5e-3, weight decay 0.01).\n",
"learn.fit_one_cycle(5, 5e-3, wd=0.01)"
],
"execution_count": 77,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"\n",
" \n",
" \n",
" | epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" accuracy | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 1.284872 | \n",
" 1.289579 | \n",
" 0.418500 | \n",
" 00:20 | \n",
"
\n",
" \n",
" | 1 | \n",
" 1.248881 | \n",
" 1.264122 | \n",
" 0.440350 | \n",
" 00:20 | \n",
"
\n",
" \n",
" | 2 | \n",
" 1.215094 | \n",
" 1.232578 | \n",
" 0.447150 | \n",
" 00:20 | \n",
"
\n",
" \n",
" | 3 | \n",
" 1.162452 | \n",
" 1.226760 | \n",
" 0.456100 | \n",
" 00:20 | \n",
"
\n",
" \n",
" | 4 | \n",
" 1.091269 | \n",
" 1.234871 | \n",
" 0.454250 | \n",
" 00:20 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {
"tags": []
}
}
]
}
]
}