{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# k-means clustering example demonstration" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "k-means clustering aims to partition n observations into k clusters in which each \n", "observation belongs to the cluster with the nearest mean, serving as a prototype of the cluster. \n", "The problem is computationally difficult (NP-hard).\n", " \n", "I have made a very simple example involving 20 points and 4 means here using basic techniques.\n", "It's interesting to see that after only a few iterations (about 4-5) the \n", "4 means in this example settle into fixed positions. Note that since this is a randomized \n", "algorithm, the output may vary depending on where the centroids are initialized.\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": false }, "outputs": [ { "data": { "application/javascript": [ "if(window['d3'] === undefined ||\n", " window['Nyaplot'] === undefined){\n", " var path = {\"d3\":\"https://cdnjs.cloudflare.com/ajax/libs/d3/3.5.5/d3.min\",\"downloadable\":\"http://cdn.rawgit.com/domitry/d3-downloadable/master/d3-downloadable\"};\n", "\n", "\n", "\n", " var shim = {\"d3\":{\"exports\":\"d3\"},\"downloadable\":{\"exports\":\"downloadable\"}};\n", "\n", " require.config({paths: path, shim:shim});\n", "\n", "\n", "require(['d3'], function(d3){window['d3']=d3;console.log('finished loading d3');require(['downloadable'], function(downloadable){window['downloadable']=downloadable;console.log('finished loading downloadable');\n", "\n", "\tvar script = d3.select(\"head\")\n", "\t .append(\"script\")\n", "\t .attr(\"src\", \"http://cdn.rawgit.com/domitry/Nyaplotjs/master/release/nyaplot.js\")\n", "\t .attr(\"async\", true);\n", "\n", "\tscript[0][0].onload = script[0][0].onreadystatechange = function(){\n", "\n", "\n", "\t var event = document.createEvent(\"HTMLEvents\");\n", "\t event.initEvent(\"load_nyaplot\",false,false);\n", "\t 
window.dispatchEvent(event);\n", "\t console.log('Finished loading Nyaplotjs');\n", "\n", "\t};\n", "\n", "\n", "});});\n", "}\n" ], "text/plain": [ "\"if(window['d3'] === undefined ||\\n window['Nyaplot'] === undefined){\\n var path = {\\\"d3\\\":\\\"https://cdnjs.cloudflare.com/ajax/libs/d3/3.5.5/d3.min\\\",\\\"downloadable\\\":\\\"http://cdn.rawgit.com/domitry/d3-downloadable/master/d3-downloadable\\\"};\\n\\n\\n\\n var shim = {\\\"d3\\\":{\\\"exports\\\":\\\"d3\\\"},\\\"downloadable\\\":{\\\"exports\\\":\\\"downloadable\\\"}};\\n\\n require.config({paths: path, shim:shim});\\n\\n\\nrequire(['d3'], function(d3){window['d3']=d3;console.log('finished loading d3');require(['downloadable'], function(downloadable){window['downloadable']=downloadable;console.log('finished loading downloadable');\\n\\n\\tvar script = d3.select(\\\"head\\\")\\n\\t .append(\\\"script\\\")\\n\\t .attr(\\\"src\\\", \\\"http://cdn.rawgit.com/domitry/Nyaplotjs/master/release/nyaplot.js\\\")\\n\\t .attr(\\\"async\\\", true);\\n\\n\\tscript[0][0].onload = script[0][0].onreadystatechange = function(){\\n\\n\\n\\t var event = document.createEvent(\\\"HTMLEvents\\\");\\n\\t event.initEvent(\\\"load_nyaplot\\\",false,false);\\n\\t window.dispatchEvent(event);\\n\\t console.log('Finished loading Nyaplotjs');\\n\\n\\t};\\n\\n\\n});});\\n}\\n\"" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "#[#[#:scatter, :options=>{:x=>:v1, :y=>:v2, :title=>\"Points\", :color=>\"#00FF00\"}, :data=>\"a7ac7c1b-8046-4051-af20-14864f07f1df\"}, @xrange=[[1.0], [120.0]], @yrange=[[1.0], [144.0]]>, #:scatter, :options=>{:x=>:v3, :y=>:v4, :title=>\"K means\", :color=>\"#FFFF00\"}, :data=>\"a7ac7c1b-8046-4051-af20-14864f07f1df\"}, @xrange=[-1000, 89.8], @yrange=[-1000, 113.2]>], :options=>{:legend=>true, :xrange=>[-40, 190], :yrange=>[-40, 180], :x_label=>\"x axis\", :y_label=>\"y axis\", :zoom=>true, :width=>800}}>], :data=>{\"a7ac7c1b-8046-4051-af20-14864f07f1df\"=>#[3.0], :v2=>[4.0], :v3=>88.2, :v4=>107.8}, {:v1=>[89.0], :v2=>[31.0], :v3=>8.2, :v4=>8.4}, {:v1=>[23.0], :v2=>[144.0], :v3=>21.2, :v4=>113.2}, {:v1=>[80.0], :v2=>[1.0], :v3=>89.8, :v4=>9.8}, {:v1=>[6.0], :v2=>[15.0], :v3=>-1000, :v4=>-1000}, {:v1=>[21.0], :v2=>[10.0], :v3=>-1000, :v4=>-1000}, {:v1=>[100.0], :v2=>[89.0], :v3=>-1000, :v4=>-1000}, {:v1=>[90.0], :v2=>[124.0], :v3=>-1000, :v4=>-1000}, {:v1=>[80.0], :v2=>[93.0], :v3=>-1000, :v4=>-1000}, {:v1=>[80.0], :v2=>[123.0], :v3=>-1000, :v4=>-1000}, {:v1=>[91.0], :v2=>[110.0], :v3=>-1000, :v4=>-1000}, {:v1=>[120.0], :v2=>[14.0], :v3=>-1000, :v4=>-1000}, {:v1=>[70.0], :v2=>[2.0], :v3=>-1000, :v4=>-1000}, {:v1=>[90.0], :v2=>[1.0], :v3=>-1000, :v4=>-1000}, {:v1=>[1.0], :v2=>[2.0], :v3=>-1000, :v4=>-1000}, {:v1=>[10.0], :v2=>[11.0], :v3=>-1000, :v4=>-1000}, {:v1=>[21.0], :v2=>[121.0], :v3=>-1000, :v4=>-1000}, {:v1=>[1.0], :v2=>[100.0], :v3=>-1000, :v4=>-1000}, {:v1=>[30.0], :v2=>[90.0], :v3=>-1000, :v4=>-1000}, {:v1=>[31.0], :v2=>[111.0], :v3=>-1000, :v4=>-1000}]>}, :extension=>[]}>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "require 'slearn'\n", "points = N[[3.0,4.0],[89.0,31.0],[23,144],[80,1],[6.0,15.0],[21.0,10.0], \\\n", " [100.0,89.0],[90,124],[80,93],[80,123],[91,110], \\\n", " [120,14],[70,2],[90,1],[1.0,2.0],[10.0,11.0], \\\n", " [21.0,121.0],[1,100],[30,90],[31,111]]\n", "\n", "v1 = 
Daru::Vector.new(points[0..points.shape[0]-1,0])\n", "v2 = Daru::Vector.new(points[0..points.shape[0]-1,1])\n", "v3 = Array.new(points.shape[0], -1000) \n", "v4 = Array.new(points.shape[0], -1000) \n", "\n", "means = points.kmeans(4,20)\n", "0.upto(means.shape[0] - 1) do |i|\n", " v3[i]=means[i,0]\n", " v4[i]=means[i,1]\n", "end\n", "\n", "ploter=Daru::DataFrame.new({v1: v1, v2: v2,v3: v3,v4: v4})\n", "\n", "\n", "ploter.plot type: :scatter, x1: :v1, y1: :v2, x2: :v3, y2: :v4 do |plot, diagrams|\n", " points = diagrams[0]\n", " means = diagrams[1]\n", " \n", " points.title \"Points\"\n", " points.color \"#00FF00\"\n", " \n", " means.title \"K means\"\n", " means.color \"#FFFF00\"\n", " \n", " \n", " plot.legend true\n", " plot.xrange [-40,190]\n", " plot.yrange [-40,180]\n", " plot.x_label \"x axis\"\n", " plot.y_label \"y axis\"\n", "end" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The co-ordinates of the resulting centroids are " ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/latex": [ "$$\\left(\\begin{array}{cc}\n", " 88.2&107.8\\\\\n", " 8.2&8.4\\\\\n", " 21.2&113.2\\\\\n", " 89.8&9.8\\\\\n", "\\end{array}\\right)$$" ], "text/plain": [ "#" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "means" ] } ], "metadata": { "kernelspec": { "display_name": "Ruby 2.2.1", "language": "ruby", "name": "ruby" }, "language_info": { "file_extension": ".rb", "mimetype": "application/x-ruby", "name": "ruby", "version": "2.2.1" } }, "nbformat": 4, "nbformat_minor": 0 }