{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "### Recommendation engine using collaborative filtering on MovieLens" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import torch" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "%reload_ext autoreload\n", "%autoreload 2\n", "%matplotlib inline\n", "\n", "from fastai.learner import *\n", "from fastai.column_data import *\n", "from fastai.imports import *" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "path = '.'" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "collaborating filter.ipynb ml-latest-small.zip movielens.ipynb tmp\r\n", "ml-latest-small\t\t models\t\t ratings_small.csv\r\n" ] } ], "source": [ "! ls ." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "ratings = pd.read_csv('ratings_small.csv')" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", " | userId | \n", "movieId | \n", "rating | \n", "timestamp | \n", "
---|---|---|---|---|
0 | \n", "1 | \n", "31 | \n", "2.5 | \n", "1260759144 | \n", "
1 | \n", "1 | \n", "1029 | \n", "3.0 | \n", "1260759179 | \n", "
2 | \n", "1 | \n", "1061 | \n", "3.0 | \n", "1260759182 | \n", "
3 | \n", "1 | \n", "1129 | \n", "2.0 | \n", "1260759185 | \n", "
4 | \n", "1 | \n", "1172 | \n", "4.0 | \n", "1260759205 | \n", "
movieId | \n", "1 | \n", "110 | \n", "260 | \n", "296 | \n", "318 | \n", "356 | \n", "480 | \n", "527 | \n", "589 | \n", "593 | \n", "608 | \n", "1196 | \n", "1198 | \n", "1270 | \n", "2571 | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
userId | \n", "\n", " | \n", " | \n", " | \n", " | \n", " | \n", " | \n", " | \n", " | \n", " | \n", " | \n", " | \n", " | \n", " | \n", " | \n", " |
15 | \n", "2.0 | \n", "3.0 | \n", "5.0 | \n", "5.0 | \n", "2.0 | \n", "1.0 | \n", "3.0 | \n", "4.0 | \n", "4.0 | \n", "5.0 | \n", "5.0 | \n", "5.0 | \n", "4.0 | \n", "5.0 | \n", "5.0 | \n", "
30 | \n", "4.0 | \n", "5.0 | \n", "4.0 | \n", "5.0 | \n", "5.0 | \n", "5.0 | \n", "4.0 | \n", "5.0 | \n", "4.0 | \n", "4.0 | \n", "5.0 | \n", "4.0 | \n", "5.0 | \n", "5.0 | \n", "3.0 | \n", "
73 | \n", "5.0 | \n", "4.0 | \n", "4.5 | \n", "5.0 | \n", "5.0 | \n", "5.0 | \n", "4.0 | \n", "5.0 | \n", "3.0 | \n", "4.5 | \n", "4.0 | \n", "5.0 | \n", "5.0 | \n", "5.0 | \n", "4.5 | \n", "
212 | \n", "3.0 | \n", "5.0 | \n", "4.0 | \n", "4.0 | \n", "4.5 | \n", "4.0 | \n", "3.0 | \n", "5.0 | \n", "3.0 | \n", "4.0 | \n", "NaN | \n", "NaN | \n", "3.0 | \n", "3.0 | \n", "5.0 | \n", "
213 | \n", "3.0 | \n", "2.5 | \n", "5.0 | \n", "NaN | \n", "NaN | \n", "2.0 | \n", "5.0 | \n", "NaN | \n", "4.0 | \n", "2.5 | \n", "2.0 | \n", "5.0 | \n", "3.0 | \n", "3.0 | \n", "4.0 | \n", "
294 | \n", "4.0 | \n", "3.0 | \n", "4.0 | \n", "NaN | \n", "3.0 | \n", "4.0 | \n", "4.0 | \n", "4.0 | \n", "3.0 | \n", "NaN | \n", "NaN | \n", "4.0 | \n", "4.5 | \n", "4.0 | \n", "4.5 | \n", "
311 | \n", "3.0 | \n", "3.0 | \n", "4.0 | \n", "3.0 | \n", "4.5 | \n", "5.0 | \n", "4.5 | \n", "5.0 | \n", "4.5 | \n", "2.0 | \n", "4.0 | \n", "3.0 | \n", "4.5 | \n", "4.5 | \n", "4.0 | \n", "
380 | \n", "4.0 | \n", "5.0 | \n", "4.0 | \n", "5.0 | \n", "4.0 | \n", "5.0 | \n", "4.0 | \n", "NaN | \n", "4.0 | \n", "5.0 | \n", "4.0 | \n", "4.0 | \n", "NaN | \n", "3.0 | \n", "5.0 | \n", "
452 | \n", "3.5 | \n", "4.0 | \n", "4.0 | \n", "5.0 | \n", "5.0 | \n", "4.0 | \n", "5.0 | \n", "4.0 | \n", "4.0 | \n", "5.0 | \n", "5.0 | \n", "4.0 | \n", "4.0 | \n", "4.0 | \n", "2.0 | \n", "
468 | \n", "4.0 | \n", "3.0 | \n", "3.5 | \n", "3.5 | \n", "3.5 | \n", "3.0 | \n", "2.5 | \n", "NaN | \n", "NaN | \n", "3.0 | \n", "4.0 | \n", "3.0 | \n", "3.5 | \n", "3.0 | \n", "3.0 | \n", "
509 | \n", "3.0 | \n", "5.0 | \n", "5.0 | \n", "5.0 | \n", "4.0 | \n", "4.0 | \n", "3.0 | \n", "5.0 | \n", "2.0 | \n", "4.0 | \n", "4.5 | \n", "5.0 | \n", "5.0 | \n", "3.0 | \n", "4.5 | \n", "
547 | \n", "3.5 | \n", "NaN | \n", "NaN | \n", "5.0 | \n", "5.0 | \n", "2.0 | \n", "3.0 | \n", "5.0 | \n", "NaN | \n", "5.0 | \n", "5.0 | \n", "2.5 | \n", "2.0 | \n", "3.5 | \n", "3.5 | \n", "
564 | \n", "4.0 | \n", "1.0 | \n", "2.0 | \n", "5.0 | \n", "NaN | \n", "3.0 | \n", "5.0 | \n", "4.0 | \n", "5.0 | \n", "5.0 | \n", "5.0 | \n", "5.0 | \n", "5.0 | \n", "3.0 | \n", "3.0 | \n", "
580 | \n", "4.0 | \n", "4.5 | \n", "4.0 | \n", "4.5 | \n", "4.0 | \n", "3.5 | \n", "3.0 | \n", "4.0 | \n", "4.5 | \n", "4.0 | \n", "4.5 | \n", "4.0 | \n", "3.5 | \n", "3.0 | \n", "4.5 | \n", "
624 | \n", "5.0 | \n", "NaN | \n", "5.0 | \n", "5.0 | \n", "NaN | \n", "3.0 | \n", "3.0 | \n", "NaN | \n", "3.0 | \n", "5.0 | \n", "4.0 | \n", "5.0 | \n", "5.0 | \n", "5.0 | \n", "2.0 | \n", "
Failed to display Jupyter Widget of type HBox
.
\n", " If you're reading this message in Jupyter Notebook or JupyterLab, it may mean\n", " that the widgets JavaScript is still loading. If this message persists, it\n", " likely means that the widgets JavaScript library is either not installed or\n", " not enabled. See the Jupyter\n", " Widgets Documentation for setup instructions.\n", "
\n", "\n", " If you're reading this message in another notebook frontend (for example, a static\n", " rendering on GitHub or NBViewer),\n", " it may mean that your frontend doesn't currently support widgets.\n", "
\n" ], "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=3), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "[ 0. 0.7727 0.80396] \n", "[ 1. 0.77782 0.77585] \n", "[ 2. 0.58389 0.76542] \n", "\n" ] } ], "source": [ "learn.fit(1e-2,2, wds = wd, cycle_len=1, cycle_mult=2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We got about 0.76 (last column of the final epoch above)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Collaborative filtering from scratch" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "u_uniq = ratings.userId.unique()\n", "user2idx = {o:i for i,o in enumerate(u_uniq)}\n", "ratings.userId = ratings.userId.apply(lambda x: user2idx[x])\n", "\n", "m_uniq = ratings.movieId.unique()\n", "movie2idx = {o:i for i,o in enumerate(m_uniq)}\n", "ratings.movieId = ratings.movieId.apply(lambda x: movie2idx[x])\n" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(671, 9066)" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "n_users, n_movies" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`nn.Embedding` creates a lookup table that stores embeddings for a dictionary of fixed size, so a stored embedding can be retrieved by its index. After creating an embedding layer (called `u` below), we get its weight matrix for free as `u.weight`." ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "val_indx = get_cv_idxs(len(ratings)) # index for validation set\n", "wd = 2e-4 # weight decay\n", "n_factors = 50 # n_factors i.e. 
1 dimension of embeddings (random)" ] },
{ "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(0.5, 5.0)" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [
 "min_rating,max_rating = ratings.rating.min(),ratings.rating.max()\n",
 "min_rating,max_rating" ] },
{ "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [
 "def get_emb(ni,nf):\n",
 "    \"\"\"Return an `nn.Embedding(ni, nf)` whose weights are re-initialised with `uniform_(-0.01, 0.01)`.\"\"\"\n",
 "    e = nn.Embedding(ni, nf)\n",
 "    e.weight.data.uniform_(-0.01,0.01)\n",
 "    #e.weight.data.normal_(0,0.003)\n",
 "\n",
 "    return e" ] },
{ "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [
 "x = ratings.drop(['rating'],axis=1)\n",
 "y = ratings['rating'].astype(np.float32)\n",
 "\n",
 "data = ColumnarModelData.from_data_frame(path, val_indx, x, y, ['userId', 'movieId'], 64)" ] },
{ "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [
 "# nh = dimension of hidden linear layer\n",
 "# p1 = dropout1\n",
 "# p2 = dropout2\n",
 "\n",
 "class EmbeddingNet(nn.Module):\n",
 "    \"\"\"User and movie embeddings concatenated and passed through a small MLP.\n",
 "\n",
 "    Args:\n",
 "        n_users: number of rows in the user embedding tables.\n",
 "        _n_movies: number of rows in the movie embedding tables.\n",
 "        nh: width of the hidden linear layer.\n",
 "        p1: dropout probability applied to the concatenated embeddings.\n",
 "        p2: dropout probability applied after the hidden layer.\n",
 "\n",
 "    The embedding width is read from the notebook-level `n_factors`, and the\n",
 "    output range from the notebook-level `min_rating` / `max_rating`.\n",
 "    \"\"\"\n",
 "    def __init__(self, n_users, _n_movies, nh = 10, p1 = 0.05, p2= 0.5):\n",
 "        super().__init__()\n",
 "        # Size the movie tables from the `_n_movies` argument; reading the notebook-level\n",
 "        # `n_movies` here would silently ignore whatever the caller passed in.\n",
 "        # `ub` / `mb` (one value per user / movie) are created but not used by forward().\n",
 "        (self.u, self.m, self.ub, self.mb) = [get_emb(*o) for o in [\n",
 "            (n_users, n_factors), (_n_movies, n_factors),\n",
 "            (n_users,1), (_n_movies,1)\n",
 "        ]]\n",
 "\n",
 "        self.lin1 = nn.Linear(n_factors*2, nh) # bias is True by default\n",
 "        self.lin2 = nn.Linear(nh, 1)\n",
 "        self.drop1 = nn.Dropout(p = p1)\n",
 "        self.drop2 = nn.Dropout(p = p2)\n",
 "\n",
 "    def forward(self, cats, conts):\n",
 "        \"\"\"Predict a rating per row of `cats`; `conts` is accepted but not used.\"\"\"\n",
 "        # Columns 0 and 1 of `cats` are used as indices into the user / movie embeddings.\n",
 "        users,movies = cats[:,0],cats[:,1]\n",
 "        u2,m2 = self.u(users) , self.m(movies)\n",
 "\n",
 "        # Concatenate along dim 1 so each row carries n_factors*2 features, then apply dropout.\n",
 "        x = self.drop1(torch.cat([u2,m2], 1))\n",
 "        x = self.drop2(F.relu(self.lin1(x))) # hidden linear layer + ReLU, then dropout\n",
 "        # Rescale the sigmoid output from (0, 1) to (min_rating, max_rating).\n",
 "        r = F.sigmoid(self.lin2(x)) * (max_rating - min_rating) + min_rating\n",
 "        return r" ] },
{ "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [
 "wd=1e-5\n",
 "model = EmbeddingNet(n_users, n_movies)\n",
 "model = model.cuda()\n",
 "opt = optim.Adam(model.parameters(), 1e-3, weight_decay=wd) # model.parameters() is provided by nn.Module; lr = 1e-3" ] },
{ "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "EmbeddingNet (\n", " (u): Embedding(671, 50)\n", " (m): Embedding(9066, 50)\n", " (ub): Embedding(671, 1)\n", " (mb): Embedding(9066, 1)\n", " (lin1): Linear (100 -> 10)\n", " (lin2): Linear (10 -> 1)\n", " (drop1): Dropout (p = 0.05)\n", " (drop2): Dropout (p = 0.5)\n", ")" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model" ] },
{ "cell_type": "code", "execution_count": 28, "metadata": { "scrolled": true }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2f45cf8456bc4101a27367f9c26d9f6a", "version_major": 2, "version_minor": 0 }, "text/html": [ "Failed to display Jupyter Widget of type HBox
.
\n", " If you're reading this message in Jupyter Notebook or JupyterLab, it may mean\n", " that the widgets JavaScript is still loading. If this message persists, it\n", " likely means that the widgets JavaScript library is either not installed or\n", " not enabled. See the Jupyter\n", " Widgets Documentation for setup instructions.\n", "
\n", "\n", " If you're reading this message in another notebook frontend (for example, a static\n", " rendering on GitHub or NBViewer),\n", " it may mean that your frontend doesn't currently support widgets.\n", "
\n" ], "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=3), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "[ 0. 0.74293 0.79247] \n", "[ 1. 0.74748 0.79483] \n", "[ 2. 0.75364 0.79638] \n", "\n" ] } ], "source": [ "fit(model, data, 3, opt, F.mse_loss)" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "# from tqdm import tqdm as tqdm_cls\n", "\n", "# inst = tqdm_cls._instances\n", "# for i in range(len(inst)): inst.pop().close()" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [], "source": [ "set_lrs(opt, 1e-3)" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "scrolled": true }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ea1cd58e12574699969fd0040c18f1c2", "version_major": 2, "version_minor": 0 }, "text/html": [ "Failed to display Jupyter Widget of type HBox
.
\n", " If you're reading this message in Jupyter Notebook or JupyterLab, it may mean\n", " that the widgets JavaScript is still loading. If this message persists, it\n", " likely means that the widgets JavaScript library is either not installed or\n", " not enabled. See the Jupyter\n", " Widgets Documentation for setup instructions.\n", "
\n", "\n", " If you're reading this message in another notebook frontend (for example, a static\n", " rendering on GitHub or NBViewer),\n", " it may mean that your frontend doesn't currently support widgets.\n", "
\n" ], "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=3), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "[ 0. 0.79631 0.78994] \n", "[ 1. 0.78677 0.79127] \n", "[ 2. 0.7614 0.7906] \n", "\n" ] } ], "source": [ "fit(model, data, 3, opt, F.mse_loss)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Surprise package" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "collaborating filter.ipynb ml-latest-small.zip movielens.ipynb tmp\r\n", "ml-latest-small\t\t models\t\t ratings_small.csv\r\n" ] } ], "source": [ "! ls ." ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "196\t242\t3\t881250949\r\n", "186\t302\t3\t891717742\r\n", "22\t377\t1\t878887116\r\n", "244\t51\t2\t880606923\r\n", "166\t346\t1\t886397596\r\n" ] } ], "source": [ "! head -5 '/home/ubuntu/.surprise_data/ml-100k/ml-100k/u.data'" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "userId,movieId,rating,timestamp\r", "\r\n", "1,31,2.5,1260759144\r", "\r\n", "1,1029,3.0,1260759179\r", "\r\n", "1,1061,3.0,1260759182\r", "\r\n", "1,1129,2.0,1260759185\r", "\r\n" ] } ], "source": [ "! 
head -5 'ratings_small.csv'" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "from surprise import Reader, Dataset\n", "# Define the format\n", "\n", "reader = Reader(line_format='user item rating timestamp', sep='\\t')\n", "# Load the data from the file using the reader format\n", "\n", "data = Dataset.load_from_file('/home/ubuntu/.surprise_data/ml-100k/ml-100k/u.data', reader=reader)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " | userId | \n", "movieId | \n", "rating | \n", "timestamp | \n", "
---|---|---|---|---|
0 | \n", "1 | \n", "31 | \n", "2.5 | \n", "1260759144 | \n", "
1 | \n", "1 | \n", "1029 | \n", "3.0 | \n", "1260759179 | \n", "
\n", " | userId | \n", "movieId | \n", "rating | \n", "timestamp | \n", "
---|---|---|---|---|
0 | \n", "1 | \n", "31 | \n", "2.5 | \n", "1260759144 | \n", "
1 | \n", "1 | \n", "1029 | \n", "3.0 | \n", "1260759179 | \n", "
\n", " | userId | \n", "movieId | \n", "rating | \n", "timestamp | \n", "
---|---|---|---|---|
0 | \n", "0 | \n", "0 | \n", "2.5 | \n", "1260759144 | \n", "
1 | \n", "0 | \n", "1 | \n", "3.0 | \n", "1260759179 | \n", "