{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Comparison of two algorithms\n", "\n", "We will see in this notebook how we can compare the prediction accuracy of two algorithms." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from __future__ import (absolute_import, division, print_function, \n", " unicode_literals) \n", "import pickle\n", "import os\n", "\n", "import pandas as pd\n", "\n", "from recsys import SVD\n", "from recsys import KNNBasic\n", "from recsys import Dataset \n", "from recsys import Reader \n", "from recsys import dump\n", "from recsys.accuracy import rmse" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Computing the msd similarity matrix...\n", "RMSE: 0.9525\n", "RMSE: 0.9889\n", "The dump has been saved as file ./dump_SVD\n", "The dump has been saved as file ./dump_KNN\n" ] } ], "source": [ "# We will train and test on the u1.base and u1.test files of the movielens-100k dataset.\n", "# if you haven't already, you need to download the movielens-100k dataset\n", "# You can do it manually, or by running:\n", "\n", "#Dataset.load_builtin('ml-100k')\n", "\n", "# Now, let's load the dataset\n", "train_file = os.path.expanduser('~') + '/.recsys_data/ml-100k/ml-100k/u1.base'\n", "test_file = os.path.expanduser('~') + '/.recsys_data/ml-100k/ml-100k/u1.test'\n", "data = Dataset.load_from_folds([(train_file, test_file)], Reader('ml-100k'))\n", "\n", " \n", "# We'll use the well-known SVD algorithm and a basic nearest neighbors approach.\n", "algo_svd = SVD() \n", "algo_knn = KNNBasic()\n", "\n", "for trainset, testset in data.folds(): \n", " algo_svd.train(trainset) \n", " predictions_svd = algo_svd.test(testset)\n", " \n", " algo_knn.train(trainset)\n", " predictions_knn = algo_knn.test(testset)\n", " \n", " rmse(predictions_svd)\n", " rmse(predictions_knn) \n", " \n", " dump('./dump_SVD', predictions_svd, trainset, algo_svd)\n", " dump('./dump_KNN', predictions_knn, trainset, algo_knn)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# The dumps have been saved and we can now use them whenever we want.\n", "\n", "dump_obj_svd = pickle.load(open('./dump_SVD', 'rb'))\n", "dump_obj_knn = pickle.load(open('./dump_KNN', 'rb'))\n", "\n", "df_svd = pd.DataFrame(dump_obj_svd['predictions'], columns=['uid', 'iid', 'rui', 'est', 'details']) \n", "df_knn = pd.DataFrame(dump_obj_knn['predictions'], columns=['uid', 'iid', 'rui', 'est', 'details']) \n", "\n", "df_svd['err'] = abs(df_svd.est - df_svd.rui)\n", "df_knn['err'] = abs(df_knn.est - df_knn.rui)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We now have two dataframes with the all the predictions for each algorithm. The cool thing is that, as both algorithm have been tested on the same testset, the indexes of the two dataframes are the same!" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
uidiidruiestdetailserr
0165.03.796291{'was_impossible': False}1.203709
11103.03.955134{'was_impossible': False}0.955134
21125.04.477002{'was_impossible': False}0.522998
31145.03.990782{'was_impossible': False}1.009218
41173.03.376097{'was_impossible': False}0.376097
\n", "
" ], "text/plain": [ " uid iid rui est details err\n", "0 1 6 5.0 3.796291 {'was_impossible': False} 1.203709\n", "1 1 10 3.0 3.955134 {'was_impossible': False} 0.955134\n", "2 1 12 5.0 4.477002 {'was_impossible': False} 0.522998\n", "3 1 14 5.0 3.990782 {'was_impossible': False} 1.009218\n", "4 1 17 3.0 3.376097 {'was_impossible': False} 0.376097" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_svd.head()" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
uidiidruiestdetailserr
0165.03.468613{'was_impossible': False, 'actual_k': 20}1.531387
11103.03.866290{'was_impossible': False, 'actual_k': 40}0.866290
21125.04.538194{'was_impossible': False, 'actual_k': 40}0.461806
31145.04.235741{'was_impossible': False, 'actual_k': 40}0.764259
41173.03.228002{'was_impossible': False, 'actual_k': 40}0.228002
\n", "
" ], "text/plain": [ " uid iid rui est details err\n", "0 1 6 5.0 3.468613 {'was_impossible': False, 'actual_k': 20} 1.531387\n", "1 1 10 3.0 3.866290 {'was_impossible': False, 'actual_k': 40} 0.866290\n", "2 1 12 5.0 4.538194 {'was_impossible': False, 'actual_k': 40} 0.461806\n", "3 1 14 5.0 4.235741 {'was_impossible': False, 'actual_k': 40} 0.764259\n", "4 1 17 3.0 3.228002 {'was_impossible': False, 'actual_k': 40} 0.228002" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_knn.head()" ] }, { "cell_type": "code", "execution_count": 38, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
uidiidruiestdetailserr
1905382111.04.136955{'was_impossible': False, 'actual_k': 40}3.136955
1925384321.04.064878{'was_impossible': False, 'actual_k': 40}3.064878
1930385261.04.115078{'was_impossible': False, 'actual_k': 40}3.115078
5024996941.04.078664{'was_impossible': False, 'actual_k': 36}3.078664
73901671691.04.664991{'was_impossible': False, 'actual_k': 40}3.664991
139722951831.04.202611{'was_impossible': False, 'actual_k': 40}3.202611
153063122651.04.131875{'was_impossible': False, 'actual_k': 40}3.131875
190964051925.03.763118{'was_impossible': False, 'actual_k': 40}1.236882
191554056735.03.433994{'was_impossible': False, 'actual_k': 40}1.566006
\n", "
" ], "text/plain": [ " uid iid rui est details \\\n", "1905 38 211 1.0 4.136955 {'was_impossible': False, 'actual_k': 40} \n", "1925 38 432 1.0 4.064878 {'was_impossible': False, 'actual_k': 40} \n", "1930 38 526 1.0 4.115078 {'was_impossible': False, 'actual_k': 40} \n", "5024 99 694 1.0 4.078664 {'was_impossible': False, 'actual_k': 36} \n", "7390 167 169 1.0 4.664991 {'was_impossible': False, 'actual_k': 40} \n", "13972 295 183 1.0 4.202611 {'was_impossible': False, 'actual_k': 40} \n", "15306 312 265 1.0 4.131875 {'was_impossible': False, 'actual_k': 40} \n", "19096 405 192 5.0 3.763118 {'was_impossible': False, 'actual_k': 40} \n", "19155 405 673 5.0 3.433994 {'was_impossible': False, 'actual_k': 40} \n", "\n", " err \n", "1905 3.136955 \n", "1925 3.064878 \n", "1930 3.115078 \n", "5024 3.078664 \n", "7390 3.664991 \n", "13972 3.202611 \n", "15306 3.131875 \n", "19096 1.236882 \n", "19155 1.566006 " ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Let's check how good are the KNN predictions when the SVD has a huge error:\n", "df_knn[df_svd.err >= 3.5]" ] }, { "cell_type": "code", "execution_count": 39, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
uidiidruiestdetailserr
94062083021.04.386044{'was_impossible': False}3.386044
190894051691.02.087386{'was_impossible': False}1.087386
197854361321.04.389942{'was_impossible': False}3.389942
15723151.04.176330{'was_impossible': False}3.176330
8503193561.03.893228{'was_impossible': False}2.893228
55311139765.02.924792{'was_impossible': False}2.075208
79171814081.01.976466{'was_impossible': False}0.976466
73901671691.04.738044{'was_impossible': False}3.738044
741216713065.03.942998{'was_impossible': False}1.057002
555311411045.03.338453{'was_impossible': False}1.661547
\n", "
" ], "text/plain": [ " uid iid rui est details err\n", "9406 208 302 1.0 4.386044 {'was_impossible': False} 3.386044\n", "19089 405 169 1.0 2.087386 {'was_impossible': False} 1.087386\n", "19785 436 132 1.0 4.389942 {'was_impossible': False} 3.389942\n", "157 2 315 1.0 4.176330 {'was_impossible': False} 3.176330\n", "8503 193 56 1.0 3.893228 {'was_impossible': False} 2.893228\n", "5531 113 976 5.0 2.924792 {'was_impossible': False} 2.075208\n", "7917 181 408 1.0 1.976466 {'was_impossible': False} 0.976466\n", "7390 167 169 1.0 4.738044 {'was_impossible': False} 3.738044\n", "7412 167 1306 5.0 3.942998 {'was_impossible': False} 1.057002\n", "5553 114 1104 5.0 3.338453 {'was_impossible': False} 1.661547" ] }, "execution_count": 39, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Well... Not much better.\n", "# Now, let's look at the predictions of SVD on the 10 worst predictions for KNN\n", "df_svd.iloc[df_knn.sort_values(by='err')[-10:].index]" ] }, { "cell_type": "code", "execution_count": 47, "metadata": { "collapsed": false }, "outputs": [ { "data": { "application/javascript": [ "/* Put everything inside the global mpl namespace */\n", "window.mpl = {};\n", "\n", "mpl.get_websocket_type = function() {\n", " if (typeof(WebSocket) !== 'undefined') {\n", " return WebSocket;\n", " } else if (typeof(MozWebSocket) !== 'undefined') {\n", " return MozWebSocket;\n", " } else {\n", " alert('Your browser does not have WebSocket support.' +\n", " 'Please try Chrome, Safari or Firefox ≥ 6. ' +\n", " 'Firefox 4 and 5 are also supported but you ' +\n", " 'have to enable WebSockets in about:config.');\n", " };\n", "}\n", "\n", "mpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n", " this.id = figure_id;\n", "\n", " this.ws = websocket;\n", "\n", " this.supports_binary = (this.ws.binaryType != undefined);\n", "\n", " if (!this.supports_binary) {\n", " var warnings = document.getElementById(\"mpl-warnings\");\n", " if (warnings) {\n", " warnings.style.display = 'block';\n", " warnings.textContent = (\n", " \"This browser does not support binary websocket messages. \" +\n", " \"Performance may be slow.\");\n", " }\n", " }\n", "\n", " this.imageObj = new Image();\n", "\n", " this.context = undefined;\n", " this.message = undefined;\n", " this.canvas = undefined;\n", " this.rubberband_canvas = undefined;\n", " this.rubberband_context = undefined;\n", " this.format_dropdown = undefined;\n", "\n", " this.image_mode = 'full';\n", "\n", " this.root = $('
');\n", " this._root_extra_style(this.root)\n", " this.root.attr('style', 'display: inline-block');\n", "\n", " $(parent_element).append(this.root);\n", "\n", " this._init_header(this);\n", " this._init_canvas(this);\n", " this._init_toolbar(this);\n", "\n", " var fig = this;\n", "\n", " this.waiting = false;\n", "\n", " this.ws.onopen = function () {\n", " fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n", " fig.send_message(\"send_image_mode\", {});\n", " fig.send_message(\"refresh\", {});\n", " }\n", "\n", " this.imageObj.onload = function() {\n", " if (fig.image_mode == 'full') {\n", " // Full images could contain transparency (where diff images\n", " // almost always do), so we need to clear the canvas so that\n", " // there is no ghosting.\n", " fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n", " }\n", " fig.context.drawImage(fig.imageObj, 0, 0);\n", " };\n", "\n", " this.imageObj.onunload = function() {\n", " this.ws.close();\n", " }\n", "\n", " this.ws.onmessage = this._make_on_message_function(this);\n", "\n", " this.ondownload = ondownload;\n", "}\n", "\n", "mpl.figure.prototype._init_header = function() {\n", " var titlebar = $(\n", " '
');\n", " var titletext = $(\n", " '
');\n", " titlebar.append(titletext)\n", " this.root.append(titlebar);\n", " this.header = titletext[0];\n", "}\n", "\n", "\n", "\n", "mpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "\n", "mpl.figure.prototype._root_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "mpl.figure.prototype._init_canvas = function() {\n", " var fig = this;\n", "\n", " var canvas_div = $('
');\n", "\n", " canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n", "\n", " function canvas_keyboard_event(event) {\n", " return fig.key_event(event, event['data']);\n", " }\n", "\n", " canvas_div.keydown('key_press', canvas_keyboard_event);\n", " canvas_div.keyup('key_release', canvas_keyboard_event);\n", " this.canvas_div = canvas_div\n", " this._canvas_extra_style(canvas_div)\n", " this.root.append(canvas_div);\n", "\n", " var canvas = $('');\n", " canvas.addClass('mpl-canvas');\n", " canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n", "\n", " this.canvas = canvas[0];\n", " this.context = canvas[0].getContext(\"2d\");\n", "\n", " var rubberband = $('');\n", " rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n", "\n", " var pass_mouse_events = true;\n", "\n", " canvas_div.resizable({\n", " start: function(event, ui) {\n", " pass_mouse_events = false;\n", " },\n", " resize: function(event, ui) {\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " stop: function(event, ui) {\n", " pass_mouse_events = true;\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " });\n", "\n", " function mouse_event_fn(event) {\n", " if (pass_mouse_events)\n", " return fig.mouse_event(event, event['data']);\n", " }\n", "\n", " rubberband.mousedown('button_press', mouse_event_fn);\n", " rubberband.mouseup('button_release', mouse_event_fn);\n", " // Throttle sequential mouse events to 1 every 20ms.\n", " rubberband.mousemove('motion_notify', mouse_event_fn);\n", "\n", " rubberband.mouseenter('figure_enter', mouse_event_fn);\n", " rubberband.mouseleave('figure_leave', mouse_event_fn);\n", "\n", " canvas_div.on(\"wheel\", function (event) {\n", " event = event.originalEvent;\n", " event['data'] = 'scroll'\n", " if (event.deltaY < 0) {\n", " event.step = 1;\n", " } else {\n", " event.step = -1;\n", " }\n", " mouse_event_fn(event);\n", " });\n", "\n", " canvas_div.append(canvas);\n", " canvas_div.append(rubberband);\n", "\n", " this.rubberband = rubberband;\n", " this.rubberband_canvas = rubberband[0];\n", " this.rubberband_context = rubberband[0].getContext(\"2d\");\n", " this.rubberband_context.strokeStyle = \"#000000\";\n", "\n", " this._resize_canvas = function(width, height) {\n", " // Keep the size of the canvas, canvas container, and rubber band\n", " // canvas in synch.\n", " canvas_div.css('width', width)\n", " canvas_div.css('height', height)\n", "\n", " canvas.attr('width', width);\n", " canvas.attr('height', height);\n", "\n", " rubberband.attr('width', width);\n", " rubberband.attr('height', height);\n", " }\n", "\n", " // Set the figure to an initial 600x600px, this will subsequently be updated\n", " // upon first draw.\n", " this._resize_canvas(600, 600);\n", "\n", " // Disable right mouse context menu.\n", " $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n", " return false;\n", " });\n", "\n", " function set_focus () {\n", " canvas.focus();\n", " canvas_div.focus();\n", " }\n", "\n", " window.setTimeout(set_focus, 100);\n", "}\n", "\n", "mpl.figure.prototype._init_toolbar = function() {\n", " var fig = this;\n", "\n", " var nav_element = $('
')\n", " nav_element.attr('style', 'width: 100%');\n", " this.root.append(nav_element);\n", "\n", " // Define a callback function for later on.\n", " function toolbar_event(event) {\n", " return fig.toolbar_button_onclick(event['data']);\n", " }\n", " function toolbar_mouse_event(event) {\n", " return fig.toolbar_button_onmouseover(event['data']);\n", " }\n", "\n", " for(var toolbar_ind in mpl.toolbar_items) {\n", " var name = mpl.toolbar_items[toolbar_ind][0];\n", " var tooltip = mpl.toolbar_items[toolbar_ind][1];\n", " var image = mpl.toolbar_items[toolbar_ind][2];\n", " var method_name = mpl.toolbar_items[toolbar_ind][3];\n", "\n", " if (!name) {\n", " // put a spacer in here.\n", " continue;\n", " }\n", " var button = $('