{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Analysis of the KNNBasic algorithm\n", "\n", "In this notebook, we will run a basic neighborhood algorithm on the MovieLens dataset, dump the results, and use pandas to analyze the predictions." ] },
 { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [
 "from __future__ import (absolute_import, division, print_function,\n",
 "                        unicode_literals)\n",
 "import pickle\n",
 "import os\n",
 "\n",
 "import pandas as pd\n",
 "\n",
 "from surprise import KNNBasic\n",
 "from surprise import Dataset\n",
 "from surprise import Reader\n",
 "from surprise import dump\n",
 "from surprise.accuracy import rmse\n",
 "from surprise.model_selection import PredefinedKFold" ] },
 { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Computing the msd similarity matrix...\n", "Done computing similarity matrix.\n", "RMSE: 0.9889\n" ] } ], "source": [
 "# We will train and test on the u1.base and u1.test files of the movielens-100k dataset.\n",
 "# If you haven't already, you need to download the movielens-100k dataset.\n",
 "# You can do it manually, or by running:\n",
 "\n",
 "#Dataset.load_builtin('ml-100k')\n",
 "\n",
 "# Now, let's load the dataset\n",
 "data_dir = os.path.join(os.path.expanduser('~'), '.surprise_data', 'ml-100k', 'ml-100k')\n",
 "train_file = os.path.join(data_dir, 'u1.base')\n",
 "test_file = os.path.join(data_dir, 'u1.test')\n",
 "data = Dataset.load_from_folds([(train_file, test_file)], Reader('ml-100k'))\n",
 "\n",
 "# The folds are predefined by the two files above, so we iterate over them with\n",
 "# PredefinedKFold (Dataset.folds() and algo.train() no longer exist in recent\n",
 "# versions of Surprise).\n",
 "pkf = PredefinedKFold()\n",
 "\n",
 "# We'll use a basic nearest neighbor approach, where similarities are computed\n",
 "# between users.\n",
 "algo = KNNBasic()\n",
 "\n",
 "for trainset, testset in pkf.split(data):\n",
 "    algo.fit(trainset)\n",
 "    predictions = algo.test(testset)\n",
 "    rmse(predictions)\n",
 "\n",
 "    dump.dump('./dump_file', predictions, algo)" ] },
 { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [
 "# The dump has been saved and we can now use it whenever we want.\n",
 "# Let's load it and see what we can do.\n",
 "# NOTE: dump files are pickles, so only load a dump file that you created yourself.\n",
 "predictions, algo = dump.load('./dump_file')" ] },
 { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "algo: KNNBasic, k = 40, min_k = 1\n" ] } ], "source": [
 "trainset = algo.trainset\n",
 "print('algo: {0}, k = {1}, min_k = {2}'.format(algo.__class__.__name__, algo.k, algo.min_k))" ] },
 { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [
 "# Let's build a pandas dataframe with all the predictions\n",
 "\n",
 "def get_Iu(uid):\n",
 "    \"\"\"Return the number of items rated by given user\n",
 "\n",
 "    Args:\n",
 "        uid: The raw id of the user.\n",
 "    Returns:\n",
 "        The number of items rated by the user.\n",
 "    \"\"\"\n",
 "\n",
 "    try:\n",
 "        return len(trainset.ur[trainset.to_inner_uid(uid)])\n",
 "    except ValueError:  # user was not part of the trainset\n",
 "        return 0\n",
 "\n",
 "def get_Ui(iid):\n",
 "    \"\"\"Return the number of users that have rated given item\n",
 "\n",
 "    Args:\n",
 "        iid: The raw id of the item.\n",
 "    Returns:\n",
 "        The number of users that have rated the item.\n",
 "    \"\"\"\n",
 "\n",
 "    try:\n",
 "        return len(trainset.ir[trainset.to_inner_iid(iid)])\n",
 "    except ValueError:  # item was not part of the trainset\n",
 "        return 0\n",
 "\n",
 "# One row per prediction: raw user id, raw item id, true rating, estimate, details\n",
 "df = pd.DataFrame(predictions, columns=['uid', 'iid', 'rui', 'est', 'details'])\n",
 "df['Iu'] = df.uid.apply(get_Iu)\n",
 "df['Ui'] = df.iid.apply(get_Ui)\n",
 "df['err'] = abs(df.est - df.rui)" ] },
 { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
uidiidruiestdetailsIuUierr
0165.03.468613{'actual_k': 20, 'was_impossible': False}135201.531387
11103.03.866290{'actual_k': 40, 'was_impossible': False}135730.866290
21125.04.538194{'actual_k': 40, 'was_impossible': False}1352110.461806
31145.04.235741{'actual_k': 40, 'was_impossible': False}1351400.764259
41173.03.228002{'actual_k': 40, 'was_impossible': False}135720.228002
\n", "
" ], "text/plain": [ " uid iid rui est details Iu Ui \\\n", "0 1 6 5.0 3.468613 {'actual_k': 20, 'was_impossible': False} 135 20 \n", "1 1 10 3.0 3.866290 {'actual_k': 40, 'was_impossible': False} 135 73 \n", "2 1 12 5.0 4.538194 {'actual_k': 40, 'was_impossible': False} 135 211 \n", "3 1 14 5.0 4.235741 {'actual_k': 40, 'was_impossible': False} 135 140 \n", "4 1 17 3.0 3.228002 {'actual_k': 40, 'was_impossible': False} 135 72 \n", "\n", " err \n", "0 1.531387 \n", "1 0.866290 \n", "2 0.461806 \n", "3 0.764259 \n", "4 0.228002 " ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.head()" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "# Sort once by absolute error: the head of the ranking holds the best\n", "# predictions and the tail holds the worst ones.\n", "df_by_err = df.sort_values(by='err')\n", "best_predictions = df_by_err[:10]\n", "worst_predictions = df_by_err[-10:]" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
uidiidruiestdetailsIuUierr
27254391.01.0{'actual_k': 3, 'was_impossible': False}9130.0
886133141.01.0{'actual_k': 2, 'was_impossible': False}37320.0
15623141.01.0{'actual_k': 2, 'was_impossible': False}4020.0
926134371.01.0{'actual_k': 3, 'was_impossible': False}37330.0
92762063141.01.0{'actual_k': 1, 'was_impossible': False}3320.0
191184054371.01.0{'actual_k': 3, 'was_impossible': False}58230.0
803218113341.01.0{'actual_k': 1, 'was_impossible': False}21810.0
804118113541.01.0{'actual_k': 1, 'was_impossible': False}21810.0
920220114243.03.0{'actual_k': 1, 'was_impossible': False}21510.0
30186011234.04.0{'actual_k': 1, 'was_impossible': False}11910.0
\n", "
" ], "text/plain": [ " uid iid rui est details Iu Ui \\\n", "272 5 439 1.0 1.0 {'actual_k': 3, 'was_impossible': False} 91 3 \n", "886 13 314 1.0 1.0 {'actual_k': 2, 'was_impossible': False} 373 2 \n", "156 2 314 1.0 1.0 {'actual_k': 2, 'was_impossible': False} 40 2 \n", "926 13 437 1.0 1.0 {'actual_k': 3, 'was_impossible': False} 373 3 \n", "9276 206 314 1.0 1.0 {'actual_k': 1, 'was_impossible': False} 33 2 \n", "19118 405 437 1.0 1.0 {'actual_k': 3, 'was_impossible': False} 582 3 \n", "8032 181 1334 1.0 1.0 {'actual_k': 1, 'was_impossible': False} 218 1 \n", "8041 181 1354 1.0 1.0 {'actual_k': 1, 'was_impossible': False} 218 1 \n", "9202 201 1424 3.0 3.0 {'actual_k': 1, 'was_impossible': False} 215 1 \n", "3018 60 1123 4.0 4.0 {'actual_k': 1, 'was_impossible': False} 119 1 \n", "\n", " err \n", "272 0.0 \n", "886 0.0 \n", "156 0.0 \n", "926 0.0 \n", "9276 0.0 \n", "19118 0.0 \n", "8032 0.0 \n", "8041 0.0 \n", "9202 0.0 \n", "3018 0.0 " ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Let's take a look at the best predictions of the algorithm\n", "best_predictions" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It's interesting to note that these perfect predictions are actually lucky shots: $|U_i|$ is always very small, meaning that very few users have rated the target item. This implies that the set of neighbors is very small (see the ``actual_k`` field)... And it just happens that all the ratings from the neighbors are the same (and, in most cases, equal to the rating of the target user).\n", "\n", "This may be a bit surprising, but these lucky shots are actually very important to the accuracy of the algorithm... Try running the same algorithm with a value of ``min_k`` equal to $10$. This means that if there are fewer than $10$ neighbors, the prediction is set to the mean of all ratings. You'll see your accuracy decrease!" 
] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
uidiidruiestdetailsIuUierr
94062083021.04.308447{'actual_k': 40, 'was_impossible': False}112453.308447
190894051691.04.364728{'actual_k': 40, 'was_impossible': False}582973.364728
197854361321.04.365369{'actual_k': 40, 'was_impossible': False}1262003.365369
15723151.04.381308{'actual_k': 40, 'was_impossible': False}401363.381308
8503193561.04.386478{'actual_k': 40, 'was_impossible': False}613123.386478
55311139765.01.610771{'actual_k': 7, 'was_impossible': False}3173.389229
79171814081.04.421499{'actual_k': 40, 'was_impossible': False}218933.421499
73901671691.04.664991{'actual_k': 40, 'was_impossible': False}38973.664991
741216713065.01.000000{'actual_k': 1, 'was_impossible': False}3814.000000
555311411045.01.000000{'actual_k': 1, 'was_impossible': False}2714.000000
\n", "
" ], "text/plain": [ " uid iid rui est details \\\n", "9406 208 302 1.0 4.308447 {'actual_k': 40, 'was_impossible': False} \n", "19089 405 169 1.0 4.364728 {'actual_k': 40, 'was_impossible': False} \n", "19785 436 132 1.0 4.365369 {'actual_k': 40, 'was_impossible': False} \n", "157 2 315 1.0 4.381308 {'actual_k': 40, 'was_impossible': False} \n", "8503 193 56 1.0 4.386478 {'actual_k': 40, 'was_impossible': False} \n", "5531 113 976 5.0 1.610771 {'actual_k': 7, 'was_impossible': False} \n", "7917 181 408 1.0 4.421499 {'actual_k': 40, 'was_impossible': False} \n", "7390 167 169 1.0 4.664991 {'actual_k': 40, 'was_impossible': False} \n", "7412 167 1306 5.0 1.000000 {'actual_k': 1, 'was_impossible': False} \n", "5553 114 1104 5.0 1.000000 {'actual_k': 1, 'was_impossible': False} \n", "\n", " Iu Ui err \n", "9406 11 245 3.308447 \n", "19089 582 97 3.364728 \n", "19785 126 200 3.365369 \n", "157 40 136 3.381308 \n", "8503 61 312 3.386478 \n", "5531 31 7 3.389229 \n", "7917 218 93 3.421499 \n", "7390 38 97 3.664991 \n", "7412 38 1 4.000000 \n", "5553 27 1 4.000000 " ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Now, let's look at the prediction with the biggest error\n", "worst_predictions" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's focus first on the last two predictions. Well, we can't do much about them. We should have predicted $5$, but the only available neighbor had a rating of $1$, so we were screwed. The only way to avoid this kind of error would be to increase the ``min_k`` parameter, but it would actually worsen the accuracy (see note above).\n", "\n", "How about the other ones? It seems that for each prediction, the user is some kind of outsider: they rated the item with a rating of $1$ when most of the ratings for the item were high (or conversely, rated a *bad* item with a rating of $5$). 
See the plot below as an illustration for the first rating.\n", "\n", "These are situations where baseline estimates would be quite helpful, in order to deal with highly biased users (and items)." ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "application/javascript": [ "/* Put everything inside the global mpl namespace */\n", "window.mpl = {};\n", "\n", "\n", "mpl.get_websocket_type = function() {\n", " if (typeof(WebSocket) !== 'undefined') {\n", " return WebSocket;\n", " } else if (typeof(MozWebSocket) !== 'undefined') {\n", " return MozWebSocket;\n", " } else {\n", " alert('Your browser does not have WebSocket support.' +\n", " 'Please try Chrome, Safari or Firefox ≥ 6. ' +\n", " 'Firefox 4 and 5 are also supported but you ' +\n", " 'have to enable WebSockets in about:config.');\n", " };\n", "}\n", "\n", "mpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n", " this.id = figure_id;\n", "\n", " this.ws = websocket;\n", "\n", " this.supports_binary = (this.ws.binaryType != undefined);\n", "\n", " if (!this.supports_binary) {\n", " var warnings = document.getElementById(\"mpl-warnings\");\n", " if (warnings) {\n", " warnings.style.display = 'block';\n", " warnings.textContent = (\n", " \"This browser does not support binary websocket messages. \" +\n", " \"Performance may be slow.\");\n", " }\n", " }\n", "\n", " this.imageObj = new Image();\n", "\n", " this.context = undefined;\n", " this.message = undefined;\n", " this.canvas = undefined;\n", " this.rubberband_canvas = undefined;\n", " this.rubberband_context = undefined;\n", " this.format_dropdown = undefined;\n", "\n", " this.image_mode = 'full';\n", "\n", " this.root = $('
');\n", " this._root_extra_style(this.root)\n", " this.root.attr('style', 'display: inline-block');\n", "\n", " $(parent_element).append(this.root);\n", "\n", " this._init_header(this);\n", " this._init_canvas(this);\n", " this._init_toolbar(this);\n", "\n", " var fig = this;\n", "\n", " this.waiting = false;\n", "\n", " this.ws.onopen = function () {\n", " fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n", " fig.send_message(\"send_image_mode\", {});\n", " if (mpl.ratio != 1) {\n", " fig.send_message(\"set_dpi_ratio\", {'dpi_ratio': mpl.ratio});\n", " }\n", " fig.send_message(\"refresh\", {});\n", " }\n", "\n", " this.imageObj.onload = function() {\n", " if (fig.image_mode == 'full') {\n", " // Full images could contain transparency (where diff images\n", " // almost always do), so we need to clear the canvas so that\n", " // there is no ghosting.\n", " fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n", " }\n", " fig.context.drawImage(fig.imageObj, 0, 0);\n", " };\n", "\n", " this.imageObj.onunload = function() {\n", " this.ws.close();\n", " }\n", "\n", " this.ws.onmessage = this._make_on_message_function(this);\n", "\n", " this.ondownload = ondownload;\n", "}\n", "\n", "mpl.figure.prototype._init_header = function() {\n", " var titlebar = $(\n", " '
');\n", " var titletext = $(\n", " '
');\n", " titlebar.append(titletext)\n", " this.root.append(titlebar);\n", " this.header = titletext[0];\n", "}\n", "\n", "\n", "\n", "mpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "\n", "mpl.figure.prototype._root_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "mpl.figure.prototype._init_canvas = function() {\n", " var fig = this;\n", "\n", " var canvas_div = $('
');\n", "\n", " canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n", "\n", " function canvas_keyboard_event(event) {\n", " return fig.key_event(event, event['data']);\n", " }\n", "\n", " canvas_div.keydown('key_press', canvas_keyboard_event);\n", " canvas_div.keyup('key_release', canvas_keyboard_event);\n", " this.canvas_div = canvas_div\n", " this._canvas_extra_style(canvas_div)\n", " this.root.append(canvas_div);\n", "\n", " var canvas = $('');\n", " canvas.addClass('mpl-canvas');\n", " canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n", "\n", " this.canvas = canvas[0];\n", " this.context = canvas[0].getContext(\"2d\");\n", "\n", " var backingStore = this.context.backingStorePixelRatio ||\n", "\tthis.context.webkitBackingStorePixelRatio ||\n", "\tthis.context.mozBackingStorePixelRatio ||\n", "\tthis.context.msBackingStorePixelRatio ||\n", "\tthis.context.oBackingStorePixelRatio ||\n", "\tthis.context.backingStorePixelRatio || 1;\n", "\n", " mpl.ratio = (window.devicePixelRatio || 1) / backingStore;\n", "\n", " var rubberband = $('');\n", " rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n", "\n", " var pass_mouse_events = true;\n", "\n", " canvas_div.resizable({\n", " start: function(event, ui) {\n", " pass_mouse_events = false;\n", " },\n", " resize: function(event, ui) {\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " stop: function(event, ui) {\n", " pass_mouse_events = true;\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " });\n", "\n", " function mouse_event_fn(event) {\n", " if (pass_mouse_events)\n", " return fig.mouse_event(event, event['data']);\n", " }\n", "\n", " rubberband.mousedown('button_press', mouse_event_fn);\n", " rubberband.mouseup('button_release', mouse_event_fn);\n", " // Throttle sequential mouse events to 1 every 20ms.\n", " rubberband.mousemove('motion_notify', mouse_event_fn);\n", "\n", " 
rubberband.mouseenter('figure_enter', mouse_event_fn);\n", " rubberband.mouseleave('figure_leave', mouse_event_fn);\n", "\n", " canvas_div.on(\"wheel\", function (event) {\n", " event = event.originalEvent;\n", " event['data'] = 'scroll'\n", " if (event.deltaY < 0) {\n", " event.step = 1;\n", " } else {\n", " event.step = -1;\n", " }\n", " mouse_event_fn(event);\n", " });\n", "\n", " canvas_div.append(canvas);\n", " canvas_div.append(rubberband);\n", "\n", " this.rubberband = rubberband;\n", " this.rubberband_canvas = rubberband[0];\n", " this.rubberband_context = rubberband[0].getContext(\"2d\");\n", " this.rubberband_context.strokeStyle = \"#000000\";\n", "\n", " this._resize_canvas = function(width, height) {\n", " // Keep the size of the canvas, canvas container, and rubber band\n", " // canvas in synch.\n", " canvas_div.css('width', width)\n", " canvas_div.css('height', height)\n", "\n", " canvas.attr('width', width * mpl.ratio);\n", " canvas.attr('height', height * mpl.ratio);\n", " canvas.attr('style', 'width: ' + width + 'px; height: ' + height + 'px;');\n", "\n", " rubberband.attr('width', width);\n", " rubberband.attr('height', height);\n", " }\n", "\n", " // Set the figure to an initial 600x600px, this will subsequently be updated\n", " // upon first draw.\n", " this._resize_canvas(600, 600);\n", "\n", " // Disable right mouse context menu.\n", " $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n", " return false;\n", " });\n", "\n", " function set_focus () {\n", " canvas.focus();\n", " canvas_div.focus();\n", " }\n", "\n", " window.setTimeout(set_focus, 100);\n", "}\n", "\n", "mpl.figure.prototype._init_toolbar = function() {\n", " var fig = this;\n", "\n", " var nav_element = $('
')\n", " nav_element.attr('style', 'width: 100%');\n", " this.root.append(nav_element);\n", "\n", " // Define a callback function for later on.\n", " function toolbar_event(event) {\n", " return fig.toolbar_button_onclick(event['data']);\n", " }\n", " function toolbar_mouse_event(event) {\n", " return fig.toolbar_button_onmouseover(event['data']);\n", " }\n", "\n", " for(var toolbar_ind in mpl.toolbar_items) {\n", " var name = mpl.toolbar_items[toolbar_ind][0];\n", " var tooltip = mpl.toolbar_items[toolbar_ind][1];\n", " var image = mpl.toolbar_items[toolbar_ind][2];\n", " var method_name = mpl.toolbar_items[toolbar_ind][3];\n", "\n", " if (!name) {\n", " // put a spacer in here.\n", " continue;\n", " }\n", " var button = $('