{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Comparison of two algorithms\n", "\n", "We will see in this notebook how we can compare the prediction accuracy of two algorithms." ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "from __future__ import (absolute_import, division, print_function, \n", " unicode_literals) \n", "import pickle\n", "import os\n", "\n", "import pandas as pd\n", "\n", "from surprise import SVD\n", "from surprise import KNNBasic\n", "from surprise import Dataset \n", "from surprise import Reader \n", "from surprise.model_selection import PredefinedKFold\n", "from surprise import dump\n", "from surprise.accuracy import rmse" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Computing the msd similarity matrix...\n", "Done computing similarity matrix.\n" ] } ], "source": [ "# We will train and test on the u1.base and u1.test files of the movielens-100k dataset.\n", "# if you haven't already, you need to download the movielens-100k dataset\n", "# You can do it manually, or by running:\n", "\n", "# Dataset.load_builtin('ml-100k')\n", "\n", "# Now, let's load the dataset\n", "train_file = os.path.expanduser('~') + '/.surprise_data/ml-100k/ml-100k/u1.base'\n", "test_file = os.path.expanduser('~') + '/.surprise_data/ml-100k/ml-100k/u1.test'\n", "data = Dataset.load_from_folds([(train_file, test_file)], Reader('ml-100k'))\n", "\n", "pkf = PredefinedKFold()\n", "\n", " \n", "# We'll use the well-known SVD algorithm and a basic nearest neighbors approach.\n", "algo_svd = SVD() \n", "algo_knn = KNNBasic()\n", "\n", "for trainset, testset in pkf.split(data): \n", " algo_svd.fit(trainset) \n", " predictions_svd = algo_svd.test(testset)\n", " \n", " algo_knn.fit(trainset)\n", " predictions_knn = algo_knn.test(testset)\n", " \n", " rmse(predictions_svd)\n", " rmse(predictions_knn) \n", " \n", " 
dump.dump('./dump_SVD', predictions_svd, algo_svd)\n", " dump.dump('./dump_KNN', predictions_knn, algo_knn)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "# The dumps have been saved and we can now use them whenever we want.\n", "\n", "predictions_svd, algo_svd = dump.load('./dump_SVD')\n", "predictions_knn, algo_knn = dump.load('./dump_KNN')\n", "\n", "df_svd = pd.DataFrame(predictions_svd, columns=['uid', 'iid', 'rui', 'est', 'details']) \n", "df_knn = pd.DataFrame(predictions_knn, columns=['uid', 'iid', 'rui', 'est', 'details']) \n", "\n", "df_svd['err'] = abs(df_svd.est - df_svd.rui)\n", "df_knn['err'] = abs(df_knn.est - df_knn.rui)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We now have two dataframes with all the predictions for each algorithm. The cool thing is that, as both algorithms have been tested on the same testset, the indexes of the two dataframes are the same!" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
uidiidruiestdetailserr
0184673.03.070263{'was_impossible': False}0.070263
17664873.03.797903{'was_impossible': False}0.797903
22631173.03.594508{'was_impossible': False}0.594508
35451684.03.961151{'was_impossible': False}0.038849
45252551.03.306502{'was_impossible': False}2.306502
\n", "
" ], "text/plain": [ " uid iid rui est details err\n", "0 184 67 3.0 3.070263 {'was_impossible': False} 0.070263\n", "1 766 487 3.0 3.797903 {'was_impossible': False} 0.797903\n", "2 263 117 3.0 3.594508 {'was_impossible': False} 0.594508\n", "3 545 168 4.0 3.961151 {'was_impossible': False} 0.038849\n", "4 525 255 1.0 3.306502 {'was_impossible': False} 2.306502" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_svd.head()" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
uidiidruiestdetailserr
0184673.03.043189{'actual_k': 40, 'was_impossible': False}0.043189
17664873.04.139804{'actual_k': 40, 'was_impossible': False}1.139804
22631173.03.525691{'actual_k': 40, 'was_impossible': False}0.525691
35451684.04.393259{'actual_k': 40, 'was_impossible': False}0.393259
45252551.03.638801{'actual_k': 40, 'was_impossible': False}2.638801
\n", "
" ], "text/plain": [ " uid iid rui est details \\\n", "0 184 67 3.0 3.043189 {'actual_k': 40, 'was_impossible': False} \n", "1 766 487 3.0 4.139804 {'actual_k': 40, 'was_impossible': False} \n", "2 263 117 3.0 3.525691 {'actual_k': 40, 'was_impossible': False} \n", "3 545 168 4.0 4.393259 {'actual_k': 40, 'was_impossible': False} \n", "4 525 255 1.0 3.638801 {'actual_k': 40, 'was_impossible': False} \n", "\n", " err \n", "0 0.043189 \n", "1 1.139804 \n", "2 0.525691 \n", "3 0.393259 \n", "4 2.638801 " ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_knn.head()" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
uidiidruiestdetailserr
5334054525.02.370203{'actual_k': 40, 'was_impossible': False}2.629797
15572951831.04.275709{'actual_k': 40, 'was_impossible': False}3.275709
44314813181.04.855612{'actual_k': 40, 'was_impossible': False}3.855612
657940512185.03.329299{'actual_k': 21, 'was_impossible': False}1.670701
100322395141.04.250013{'actual_k': 40, 'was_impossible': False}3.250013
143114253131.04.093898{'actual_k': 40, 'was_impossible': False}3.093898
1597940510535.03.497124{'actual_k': 17, 'was_impossible': False}1.502876
1929211311.03.779858{'actual_k': 40, 'was_impossible': False}2.779858
\n", "
" ], "text/plain": [ " uid iid rui est details \\\n", "533 405 452 5.0 2.370203 {'actual_k': 40, 'was_impossible': False} \n", "1557 295 183 1.0 4.275709 {'actual_k': 40, 'was_impossible': False} \n", "4431 481 318 1.0 4.855612 {'actual_k': 40, 'was_impossible': False} \n", "6579 405 1218 5.0 3.329299 {'actual_k': 21, 'was_impossible': False} \n", "10032 239 514 1.0 4.250013 {'actual_k': 40, 'was_impossible': False} \n", "14311 425 313 1.0 4.093898 {'actual_k': 40, 'was_impossible': False} \n", "15979 405 1053 5.0 3.497124 {'actual_k': 17, 'was_impossible': False} \n", "19292 1 131 1.0 3.779858 {'actual_k': 40, 'was_impossible': False} \n", "\n", " err \n", "533 2.629797 \n", "1557 3.275709 \n", "4431 3.855612 \n", "6579 1.670701 \n", "10032 3.250013 \n", "14311 3.093898 \n", "15979 1.502876 \n", "19292 2.779858 " ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Let's check how good are the KNN predictions when the SVD has a huge error:\n", "df_knn[df_svd.err >= 3.5]" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
uidiidruiestdetailserr
14619771981.04.153106{'was_impossible': False}3.153106
15572951831.04.632378{'was_impossible': False}3.632378
1175940514051.01.915805{'was_impossible': False}0.915805
449318112421.02.277950{'was_impossible': False}1.277950
1097027912421.03.515677{'was_impossible': False}2.515677
26572393181.04.144616{'was_impossible': False}3.144616
44314813181.04.580412{'was_impossible': False}3.580412
1283816713065.03.136852{'was_impossible': False}1.863148
1668128813585.03.253280{'was_impossible': False}1.746720
1286936315121.03.425335{'was_impossible': False}2.425335
\n", "
" ], "text/plain": [ " uid iid rui est details err\n", "14619 771 98 1.0 4.153106 {'was_impossible': False} 3.153106\n", "1557 295 183 1.0 4.632378 {'was_impossible': False} 3.632378\n", "11759 405 1405 1.0 1.915805 {'was_impossible': False} 0.915805\n", "4493 181 1242 1.0 2.277950 {'was_impossible': False} 1.277950\n", "10970 279 1242 1.0 3.515677 {'was_impossible': False} 2.515677\n", "2657 239 318 1.0 4.144616 {'was_impossible': False} 3.144616\n", "4431 481 318 1.0 4.580412 {'was_impossible': False} 3.580412\n", "12838 167 1306 5.0 3.136852 {'was_impossible': False} 1.863148\n", "16681 288 1358 5.0 3.253280 {'was_impossible': False} 1.746720\n", "12869 363 1512 1.0 3.425335 {'was_impossible': False} 2.425335" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Well... Not much better.\n", "# Now, let's look at the predictions of SVD on the 10 worst predictions for KNN\n", "# sort_values(by='err') is ascending, so the last 10 rows are the largest KNN errors.\n", "# Their .index holds row *labels*, so select with label-based .loc; positional .iloc\n", "# would only return the same rows while both frames keep the default 0..n-1 index.\n", "df_svd.loc[df_knn.sort_values(by='err')[-10:].index]" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "application/javascript": [ "/* Put everything inside the global mpl namespace */\n", "window.mpl = {};\n", "\n", "\n", "mpl.get_websocket_type = function() {\n", " if (typeof(WebSocket) !== 'undefined') {\n", " return WebSocket;\n", " } else if (typeof(MozWebSocket) !== 'undefined') {\n", " return MozWebSocket;\n", " } else {\n", " alert('Your browser does not have WebSocket support. ' +\n", " 'Please try Chrome, Safari or Firefox ≥ 6. 
' +\n", " 'Firefox 4 and 5 are also supported but you ' +\n", " 'have to enable WebSockets in about:config.');\n", " };\n", "}\n", "\n", "mpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n", " this.id = figure_id;\n", "\n", " this.ws = websocket;\n", "\n", " this.supports_binary = (this.ws.binaryType != undefined);\n", "\n", " if (!this.supports_binary) {\n", " var warnings = document.getElementById(\"mpl-warnings\");\n", " if (warnings) {\n", " warnings.style.display = 'block';\n", " warnings.textContent = (\n", " \"This browser does not support binary websocket messages. \" +\n", " \"Performance may be slow.\");\n", " }\n", " }\n", "\n", " this.imageObj = new Image();\n", "\n", " this.context = undefined;\n", " this.message = undefined;\n", " this.canvas = undefined;\n", " this.rubberband_canvas = undefined;\n", " this.rubberband_context = undefined;\n", " this.format_dropdown = undefined;\n", "\n", " this.image_mode = 'full';\n", "\n", " this.root = $('
');\n", " this._root_extra_style(this.root)\n", " this.root.attr('style', 'display: inline-block');\n", "\n", " $(parent_element).append(this.root);\n", "\n", " this._init_header(this);\n", " this._init_canvas(this);\n", " this._init_toolbar(this);\n", "\n", " var fig = this;\n", "\n", " this.waiting = false;\n", "\n", " this.ws.onopen = function () {\n", " fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n", " fig.send_message(\"send_image_mode\", {});\n", " if (mpl.ratio != 1) {\n", " fig.send_message(\"set_dpi_ratio\", {'dpi_ratio': mpl.ratio});\n", " }\n", " fig.send_message(\"refresh\", {});\n", " }\n", "\n", " this.imageObj.onload = function() {\n", " if (fig.image_mode == 'full') {\n", " // Full images could contain transparency (where diff images\n", " // almost always do), so we need to clear the canvas so that\n", " // there is no ghosting.\n", " fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n", " }\n", " fig.context.drawImage(fig.imageObj, 0, 0);\n", " };\n", "\n", " this.imageObj.onunload = function() {\n", " fig.ws.close();\n", " }\n", "\n", " this.ws.onmessage = this._make_on_message_function(this);\n", "\n", " this.ondownload = ondownload;\n", "}\n", "\n", "mpl.figure.prototype._init_header = function() {\n", " var titlebar = $(\n", " '
');\n", " var titletext = $(\n", " '
');\n", " titlebar.append(titletext)\n", " this.root.append(titlebar);\n", " this.header = titletext[0];\n", "}\n", "\n", "\n", "\n", "mpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "\n", "mpl.figure.prototype._root_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "mpl.figure.prototype._init_canvas = function() {\n", " var fig = this;\n", "\n", " var canvas_div = $('
');\n", "\n", " canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n", "\n", " function canvas_keyboard_event(event) {\n", " return fig.key_event(event, event['data']);\n", " }\n", "\n", " canvas_div.keydown('key_press', canvas_keyboard_event);\n", " canvas_div.keyup('key_release', canvas_keyboard_event);\n", " this.canvas_div = canvas_div\n", " this._canvas_extra_style(canvas_div)\n", " this.root.append(canvas_div);\n", "\n", " var canvas = $('');\n", " canvas.addClass('mpl-canvas');\n", " canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n", "\n", " this.canvas = canvas[0];\n", " this.context = canvas[0].getContext(\"2d\");\n", "\n", " var backingStore = this.context.backingStorePixelRatio ||\n", "\tthis.context.webkitBackingStorePixelRatio ||\n", "\tthis.context.mozBackingStorePixelRatio ||\n", "\tthis.context.msBackingStorePixelRatio ||\n", "\tthis.context.oBackingStorePixelRatio ||\n", "\tthis.context.backingStorePixelRatio || 1;\n", "\n", " mpl.ratio = (window.devicePixelRatio || 1) / backingStore;\n", "\n", " var rubberband = $('');\n", " rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n", "\n", " var pass_mouse_events = true;\n", "\n", " canvas_div.resizable({\n", " start: function(event, ui) {\n", " pass_mouse_events = false;\n", " },\n", " resize: function(event, ui) {\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " stop: function(event, ui) {\n", " pass_mouse_events = true;\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " });\n", "\n", " function mouse_event_fn(event) {\n", " if (pass_mouse_events)\n", " return fig.mouse_event(event, event['data']);\n", " }\n", "\n", " rubberband.mousedown('button_press', mouse_event_fn);\n", " rubberband.mouseup('button_release', mouse_event_fn);\n", " // Throttle sequential mouse events to 1 every 20ms.\n", " rubberband.mousemove('motion_notify', mouse_event_fn);\n", "\n", " 
rubberband.mouseenter('figure_enter', mouse_event_fn);\n", " rubberband.mouseleave('figure_leave', mouse_event_fn);\n", "\n", " canvas_div.on(\"wheel\", function (event) {\n", " event = event.originalEvent;\n", " event['data'] = 'scroll'\n", " if (event.deltaY < 0) {\n", " event.step = 1;\n", " } else {\n", " event.step = -1;\n", " }\n", " mouse_event_fn(event);\n", " });\n", "\n", " canvas_div.append(canvas);\n", " canvas_div.append(rubberband);\n", "\n", " this.rubberband = rubberband;\n", " this.rubberband_canvas = rubberband[0];\n", " this.rubberband_context = rubberband[0].getContext(\"2d\");\n", " this.rubberband_context.strokeStyle = \"#000000\";\n", "\n", " this._resize_canvas = function(width, height) {\n", " // Keep the size of the canvas, canvas container, and rubber band\n", " // canvas in synch.\n", " canvas_div.css('width', width)\n", " canvas_div.css('height', height)\n", "\n", " canvas.attr('width', width * mpl.ratio);\n", " canvas.attr('height', height * mpl.ratio);\n", " canvas.attr('style', 'width: ' + width + 'px; height: ' + height + 'px;');\n", "\n", " rubberband.attr('width', width);\n", " rubberband.attr('height', height);\n", " }\n", "\n", " // Set the figure to an initial 600x600px, this will subsequently be updated\n", " // upon first draw.\n", " this._resize_canvas(600, 600);\n", "\n", " // Disable right mouse context menu.\n", " $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n", " return false;\n", " });\n", "\n", " function set_focus () {\n", " canvas.focus();\n", " canvas_div.focus();\n", " }\n", "\n", " window.setTimeout(set_focus, 100);\n", "}\n", "\n", "mpl.figure.prototype._init_toolbar = function() {\n", " var fig = this;\n", "\n", " var nav_element = $('
');\n", " nav_element.attr('style', 'width: 100%');\n", " this.root.append(nav_element);\n", "\n", " // Define a callback function for later on.\n", " function toolbar_event(event) {\n", " return fig.toolbar_button_onclick(event['data']);\n", " }\n", " function toolbar_mouse_event(event) {\n", " return fig.toolbar_button_onmouseover(event['data']);\n", " }\n", "\n", " for(var toolbar_ind in mpl.toolbar_items) {\n", " var name = mpl.toolbar_items[toolbar_ind][0];\n", " var tooltip = mpl.toolbar_items[toolbar_ind][1];\n", " var image = mpl.toolbar_items[toolbar_ind][2];\n", " var method_name = mpl.toolbar_items[toolbar_ind][3];\n", "\n", " if (!name) {\n", " // put a spacer in here.\n", " continue;\n", " }\n", " var button = $('