{ "cells": [ { "cell_type": "markdown", "metadata": { "heading_collapsed": true }, "source": [ "## Movielens" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": true, "hidden": true }, "outputs": [], "source": [ "%reload_ext autoreload\n", "%autoreload 2\n", "%matplotlib inline\n", "\n", "from fastai.learner import *\n", "from fastai.column_data import *" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "Data available from http://files.grouplens.org/datasets/movielens/ml-latest-small.zip" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": true, "hidden": true }, "outputs": [], "source": [ "path='data/ml-latest-small/'" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "We're working with the movielens data, which contains one rating per row, like this:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/html": [ "
\n", " | userId | \n", "movieId | \n", "rating | \n", "timestamp | \n", "
---|---|---|---|---|
0 | \n", "1 | \n", "31 | \n", "2.5 | \n", "1260759144 | \n", "
1 | \n", "1 | \n", "1029 | \n", "3.0 | \n", "1260759179 | \n", "
2 | \n", "1 | \n", "1061 | \n", "3.0 | \n", "1260759182 | \n", "
3 | \n", "1 | \n", "1129 | \n", "2.0 | \n", "1260759185 | \n", "
4 | \n", "1 | \n", "1172 | \n", "4.0 | \n", "1260759205 | \n", "
\n", " | movieId | \n", "title | \n", "genres | \n", "
---|---|---|---|
0 | \n", "1 | \n", "Toy Story (1995) | \n", "Adventure|Animation|Children|Comedy|Fantasy | \n", "
1 | \n", "2 | \n", "Jumanji (1995) | \n", "Adventure|Children|Fantasy | \n", "
2 | \n", "3 | \n", "Grumpier Old Men (1995) | \n", "Comedy|Romance | \n", "
3 | \n", "4 | \n", "Waiting to Exhale (1995) | \n", "Comedy|Drama|Romance | \n", "
4 | \n", "5 | \n", "Father of the Bride Part II (1995) | \n", "Comedy | \n", "
movieId | \n", "1 | \n", "110 | \n", "260 | \n", "296 | \n", "318 | \n", "356 | \n", "480 | \n", "527 | \n", "589 | \n", "593 | \n", "608 | \n", "1196 | \n", "1198 | \n", "1270 | \n", "2571 | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
userId | \n", "\n", " | \n", " | \n", " | \n", " | \n", " | \n", " | \n", " | \n", " | \n", " | \n", " | \n", " | \n", " | \n", " | \n", " | \n", " |
15 | \n", "2.0 | \n", "3.0 | \n", "5.0 | \n", "5.0 | \n", "2.0 | \n", "1.0 | \n", "3.0 | \n", "4.0 | \n", "4.0 | \n", "5.0 | \n", "5.0 | \n", "5.0 | \n", "4.0 | \n", "5.0 | \n", "5.0 | \n", "
30 | \n", "4.0 | \n", "5.0 | \n", "4.0 | \n", "5.0 | \n", "5.0 | \n", "5.0 | \n", "4.0 | \n", "5.0 | \n", "4.0 | \n", "4.0 | \n", "5.0 | \n", "4.0 | \n", "5.0 | \n", "5.0 | \n", "3.0 | \n", "
73 | \n", "5.0 | \n", "4.0 | \n", "4.5 | \n", "5.0 | \n", "5.0 | \n", "5.0 | \n", "4.0 | \n", "5.0 | \n", "3.0 | \n", "4.5 | \n", "4.0 | \n", "5.0 | \n", "5.0 | \n", "5.0 | \n", "4.5 | \n", "
212 | \n", "3.0 | \n", "5.0 | \n", "4.0 | \n", "4.0 | \n", "4.5 | \n", "4.0 | \n", "3.0 | \n", "5.0 | \n", "3.0 | \n", "4.0 | \n", "NaN | \n", "NaN | \n", "3.0 | \n", "3.0 | \n", "5.0 | \n", "
213 | \n", "3.0 | \n", "2.5 | \n", "5.0 | \n", "NaN | \n", "NaN | \n", "2.0 | \n", "5.0 | \n", "NaN | \n", "4.0 | \n", "2.5 | \n", "2.0 | \n", "5.0 | \n", "3.0 | \n", "3.0 | \n", "4.0 | \n", "
294 | \n", "4.0 | \n", "3.0 | \n", "4.0 | \n", "NaN | \n", "3.0 | \n", "4.0 | \n", "4.0 | \n", "4.0 | \n", "3.0 | \n", "NaN | \n", "NaN | \n", "4.0 | \n", "4.5 | \n", "4.0 | \n", "4.5 | \n", "
311 | \n", "3.0 | \n", "3.0 | \n", "4.0 | \n", "3.0 | \n", "4.5 | \n", "5.0 | \n", "4.5 | \n", "5.0 | \n", "4.5 | \n", "2.0 | \n", "4.0 | \n", "3.0 | \n", "4.5 | \n", "4.5 | \n", "4.0 | \n", "
380 | \n", "4.0 | \n", "5.0 | \n", "4.0 | \n", "5.0 | \n", "4.0 | \n", "5.0 | \n", "4.0 | \n", "NaN | \n", "4.0 | \n", "5.0 | \n", "4.0 | \n", "4.0 | \n", "NaN | \n", "3.0 | \n", "5.0 | \n", "
452 | \n", "3.5 | \n", "4.0 | \n", "4.0 | \n", "5.0 | \n", "5.0 | \n", "4.0 | \n", "5.0 | \n", "4.0 | \n", "4.0 | \n", "5.0 | \n", "5.0 | \n", "4.0 | \n", "4.0 | \n", "4.0 | \n", "2.0 | \n", "
468 | \n", "4.0 | \n", "3.0 | \n", "3.5 | \n", "3.5 | \n", "3.5 | \n", "3.0 | \n", "2.5 | \n", "NaN | \n", "NaN | \n", "3.0 | \n", "4.0 | \n", "3.0 | \n", "3.5 | \n", "3.0 | \n", "3.0 | \n", "
509 | \n", "3.0 | \n", "5.0 | \n", "5.0 | \n", "5.0 | \n", "4.0 | \n", "4.0 | \n", "3.0 | \n", "5.0 | \n", "2.0 | \n", "4.0 | \n", "4.5 | \n", "5.0 | \n", "5.0 | \n", "3.0 | \n", "4.5 | \n", "
547 | \n", "3.5 | \n", "NaN | \n", "NaN | \n", "5.0 | \n", "5.0 | \n", "2.0 | \n", "3.0 | \n", "5.0 | \n", "NaN | \n", "5.0 | \n", "5.0 | \n", "2.5 | \n", "2.0 | \n", "3.5 | \n", "3.5 | \n", "
564 | \n", "4.0 | \n", "1.0 | \n", "2.0 | \n", "5.0 | \n", "NaN | \n", "3.0 | \n", "5.0 | \n", "4.0 | \n", "5.0 | \n", "5.0 | \n", "5.0 | \n", "5.0 | \n", "5.0 | \n", "3.0 | \n", "3.0 | \n", "
580 | \n", "4.0 | \n", "4.5 | \n", "4.0 | \n", "4.5 | \n", "4.0 | \n", "3.5 | \n", "3.0 | \n", "4.0 | \n", "4.5 | \n", "4.0 | \n", "4.5 | \n", "4.0 | \n", "3.5 | \n", "3.0 | \n", "4.5 | \n", "
624 | \n", "5.0 | \n", "NaN | \n", "5.0 | \n", "5.0 | \n", "NaN | \n", "3.0 | \n", "3.0 | \n", "NaN | \n", "3.0 | \n", "5.0 | \n", "4.0 | \n", "5.0 | \n", "5.0 | \n", "5.0 | \n", "2.0 | \n", "
Failed to display Jupyter Widget of type HBox
.
\n", " If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n", " that the widgets JavaScript is still loading. If this message persists, it\n", " likely means that the widgets JavaScript library is either not installed or\n", " not enabled. See the Jupyter\n", " Widgets Documentation for setup instructions.\n", "
\n", "\n", " If you're reading this message in another frontend (for example, a static\n", " rendering on GitHub or NBViewer),\n", " it may mean that your frontend doesn't currently support widgets.\n", "
\n" ], "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=3), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "[ 0. 0.69763 1.14979] \n", "[ 1. 0.70115 1.13657] \n", "[ 2. 0.66739 1.1303 ] \n", "\n" ] } ], "source": [ "fit(model, data, 3, opt, F.mse_loss)" ] },
{ "cell_type": "markdown", "metadata": { "heading_collapsed": true }, "source": [ "### Bias" ] },
{ "cell_type": "code", "execution_count": 17, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "(0.5, 5.0)" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "min_rating,max_rating = ratings.rating.min(),ratings.rating.max()\n", "min_rating,max_rating" ] },
{ "cell_type": "code", "execution_count": 18, "metadata": { "collapsed": true, "hidden": true }, "outputs": [], "source": [
  "def get_emb(ni, nf):\n",
  "    \"\"\"Return an `nn.Embedding(ni, nf)` whose weights are drawn uniformly from [-0.01, 0.01].\"\"\"\n",
  "    e = nn.Embedding(ni, nf)\n",
  "    # small init so the initial dot products and bias terms start close to 0\n",
  "    e.weight.data.uniform_(-0.01,0.01)\n",
  "    return e\n",
  "\n",
  "class EmbeddingDotBias(nn.Module):\n",
  "    \"\"\"Dot-product matrix factorisation with per-user and per-movie bias terms.\n",
  "\n",
  "    The raw score is squashed with a sigmoid into (min_rating, max_rating).\n",
  "    Relies on the notebook globals `n_factors`, `min_rating` and `max_rating`.\n",
  "    \"\"\"\n",
  "    def __init__(self, n_users, n_movies):\n",
  "        super().__init__()\n",
  "        # latent factors for users/movies plus a single scalar bias per user/movie\n",
  "        (self.u, self.m, self.ub, self.mb) = [get_emb(*o) for o in [\n",
  "            (n_users, n_factors), (n_movies, n_factors), (n_users,1), (n_movies,1)\n",
  "        ]]\n",
  "\n",
  "    def forward(self, cats, conts):\n",
  "        # column 0 of cats holds user indices, column 1 movie indices; conts is unused\n",
  "        users,movies = cats[:,0],cats[:,1]\n",
  "        um = (self.u(users) * self.m(movies)).sum(1)\n",
  "        # squeeze(1), not squeeze(): drop only the bias column, never the batch dim (batch of 1)\n",
  "        res = um + self.ub(users).squeeze(1) + self.mb(movies).squeeze(1)\n",
  "        # F.sigmoid is deprecated; the tensor method is equivalent\n",
  "        res = res.sigmoid() * (max_rating-min_rating) + min_rating\n",
  "        return res.view(-1, 1)"
 ] },
{ "cell_type": "code", "execution_count": 19, "metadata": { "collapsed": true, "hidden": true, "scrolled": true }, "outputs": [], "source": [ "wd=2e-4\n", "model = EmbeddingDotBias(cf.n_users, cf.n_items).cuda()\n", "opt = optim.SGD(model.parameters(), 1e-1, weight_decay=wd, momentum=0.9)" ] },
{ "cell_type": "code", "execution_count": 20, "metadata": { "hidden": true }, "outputs": [ { "data": {
"application/vnd.jupyter.widget-view+json": { "model_id": "95093026b28a415783ac620cc5ade85e", "version_major": 2, "version_minor": 0 }, "text/plain": [ "A Jupyter Widget" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "[ 0. 0.88212 0.83626] \n", "[ 1. 0.8108 0.81831] \n", "[ 2. 0.78864 0.80989] \n", "\n" ] } ], "source": [ "fit(model, data, 3, opt, F.mse_loss)" ] }, { "cell_type": "code", "execution_count": 47, "metadata": { "collapsed": true, "hidden": true }, "outputs": [], "source": [ "set_lrs(opt, 1e-2)" ] }, { "cell_type": "code", "execution_count": 48, "metadata": { "hidden": true }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "4ab6c0fd5887430b9b5f0cda8f8a1772", "version_major": 2, "version_minor": 0 }, "text/plain": [ "A Jupyter Widget" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "[ 0. 0.72795 0.80337] \n", "[ 1. 0.75064 0.80203] \n", "[ 2. 
0.75122 0.80124] \n", "\n" ] } ], "source": [ "fit(model, data, 3, opt, F.mse_loss)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Mini net" ] },
{ "cell_type": "code", "execution_count": 49, "metadata": { "code_folding": [], "collapsed": true }, "outputs": [], "source": [
  "class EmbeddingNet(nn.Module):\n",
  "    \"\"\"Concatenate user and movie embeddings and pass them through a one-hidden-layer MLP.\n",
  "\n",
  "    nh: number of hidden units; p1 / p2: dropout on the embeddings / on the hidden layer.\n",
  "    Relies on the notebook globals `n_factors`, `min_rating` and `max_rating`.\n",
  "    \"\"\"\n",
  "    def __init__(self, n_users, n_movies, nh=10, p1=0.05, p2=0.5):\n",
  "        super().__init__()\n",
  "        (self.u, self.m) = [get_emb(*o) for o in [\n",
  "            (n_users, n_factors), (n_movies, n_factors)]]\n",
  "        self.lin1 = nn.Linear(n_factors*2, nh)\n",
  "        self.lin2 = nn.Linear(nh, 1)\n",
  "        self.drop1 = nn.Dropout(p1)\n",
  "        self.drop2 = nn.Dropout(p2)\n",
  "\n",
  "    def forward(self, cats, conts):\n",
  "        # column 0 of cats holds user indices, column 1 movie indices; conts is unused\n",
  "        users,movies = cats[:,0],cats[:,1]\n",
  "        x = self.drop1(torch.cat([self.u(users),self.m(movies)], dim=1))\n",
  "        x = self.drop2(F.relu(self.lin1(x)))\n",
  "        # output range is (min_rating-0.5, max_rating+0.5), so the extreme ratings are reachable\n",
  "        # without saturating the sigmoid; F.sigmoid is deprecated, the tensor method is equivalent\n",
  "        return self.lin2(x).sigmoid() * (max_rating-min_rating+1) + min_rating-0.5"
 ] },
{ "cell_type": "code", "execution_count": 54, "metadata": { "collapsed": true }, "outputs": [], "source": [ "wd=1e-5\n", "model = EmbeddingNet(n_users, n_movies).cuda()\n", "opt = optim.Adam(model.parameters(), 1e-3, weight_decay=wd)" ] },
{ "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "0ca8d1c156c9403fab6adec1f863786d", "version_major": 2, "version_minor": 0 }, "text/plain": [ "A Jupyter Widget" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "[ 0. 0.88043 0.82363] \n", "[ 1. 0.8941 0.81264] \n", "[ 2. 
0.86179 0.80706] \n", "\n" ] } ], "source": [ "fit(model, data, 3, opt, F.mse_loss)" ] }, { "cell_type": "code", "execution_count": 52, "metadata": { "collapsed": true }, "outputs": [], "source": [ "set_lrs(opt, 1e-3)" ] }, { "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "9d4c659d07b543b796fe37c789daebad", "version_major": 2, "version_minor": 0 }, "text/plain": [ "A Jupyter Widget" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "[ 0. 0.7669 0.78622] \n", "[ 1. 0.74277 0.78152] \n", "[ 2. 0.69891 0.78075] \n", "\n" ] } ], "source": [ "fit(model, data, 3, opt, F.mse_loss)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.4" }, "toc": { "colors": { "hover_highlight": "#DAA520", "navigate_num": "#000000", "navigate_text": "#333333", "running_highlight": "#FF0000", "selected_highlight": "#FFD700", "sidebar_border": "#EEEEEE", "wrapper_background": "#FFFFFF" }, "moveMenuLeft": true, "nav_menu": { "height": "123px", "width": "252px" }, "navigate_menu": true, "number_sections": true, "sideBar": true, "threshold": 4, "toc_cell": false, "toc_section_display": "block", "toc_window_display": false, "widenNotebook": false } }, "nbformat": 4, "nbformat_minor": 2 }