{ "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "NYJWmpEfR9lJ" }, "source": [ "# Linear Support Vector Machine (SVM)\n", "\n", "We've seen how to frame a problem as a differentiable program in the optimising functions example. \n", "Now we can take a look at a more usable example: a linear Support Vector Machine (SVM). Note that the model and loss used\n", "in this guide are based on the code found [here](https://github.com/kazuto1011/svm-pytorch).\n", "\n", "**Note**: The easiest way to use this tutorial is as a Colab notebook, which allows you to dive in with no setup. We recommend you enable a free GPU with\n", "\n", "> **Runtime**   →   **Change runtime type**   →   **Hardware Accelerator: GPU**\n", "\n", "## Install Torchbearer\n", "\n", "First we install torchbearer if needed. " ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.3.2\n" ] } ], "source": [ "try:\n", "    import torchbearer\n", "except ImportError:\n", "    # torchbearer is not installed yet, so install it quietly and retry\n", "    !pip install -q torchbearer\n", "    import torchbearer\n", "\n", "print(torchbearer.__version__)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "a31Upx80S0Wf" }, "source": [ "## SVM Recap\n", "\n", "Recall that an SVM tries to find the maximum margin hyperplane which separates the data classes. For a soft margin SVM\n", "where $\\textbf{x}$ is our data, we minimize:\n", "\n", "\\begin{equation}\n", "\\left[\\frac 1 n \\sum_{i=1}^n \\max\\left(0, 1 - y_i(\\textbf{w}\\cdot \\textbf{x}_i + b)\\right) \\right] + \\lambda\\lVert \\textbf{w} \\rVert^2\n", "\\end{equation}\n", "\n", "We can formulate this as an optimization over our weights $\\textbf{w}$ and bias $b$, where we minimize the\n", "hinge loss plus an L2 weight decay term. 
The hinge loss for some model outputs\n", "$z = \\textbf{w}\\cdot\\textbf{x} + b$ with targets $y$ is given by:\n", "\n", "\\begin{equation}\n", "\\ell(y,z) = \\max\\left(0, 1 - yz \\right)\n", "\\end{equation}\n", "\n", "## Defining the Model\n", "\n", "Let's put this into code. First, we define a module that projects the data through our weights and offsets it by\n", "a bias. Note that this is functionally identical to a linear layer." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": {}, "colab_type": "code", "id": "27hRy0i8Sze4" }, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "\n", "class LinearSVM(nn.Module):\n", "    \"\"\"Support Vector Machine\"\"\"\n", "\n", "    def __init__(self):\n", "        super(LinearSVM, self).__init__()\n", "        # Weights for 2D inputs and a scalar bias, both randomly initialised\n", "        self.w = nn.Parameter(torch.randn(1, 2), requires_grad=True)\n", "        self.b = nn.Parameter(torch.randn(1), requires_grad=True)\n", "\n", "    def forward(self, x):\n", "        # z = w . x + b, giving one output per sample with shape (N, 1)\n", "        h = x.matmul(self.w.t()) + self.b\n", "        return h" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "PyfrGG7uS59t" }, "source": [ "Next, we define the hinge loss function:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": {}, "colab_type": "code", "id": "a-0v2QecS6YP" }, "outputs": [], "source": [ "import torch\n", "\n", "def hinge_loss(y_pred, y_true):\n", "    # y_pred has shape (N, 1); transpose it so it broadcasts against y_true of shape (N,)\n", "    return torch.mean(torch.clamp(1 - y_pred.t() * y_true, min=0))" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "diJonMlwS7z4" }, "source": [ "Creating Synthetic Data\n", "-----------------------------------------------\n", "\n", "Now for some data, 1024 samples should do the trick. 
We normalise here so that our random init is in the same space as\n", "the data:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "colab": {}, "colab_type": "code", "id": "U4U7FpoiS946" }, "outputs": [], "source": [ "import numpy as np\n", "from sklearn.datasets import make_blobs\n", "\n", "X, Y = make_blobs(n_samples=1024, centers=2, cluster_std=1.2, random_state=1)\n", "X = (X - X.mean()) / X.std()\n", "# Relabel class 0 as -1 so that the targets are in {-1, +1}\n", "Y[np.where(Y == 0)] = -1\n", "X, Y = torch.FloatTensor(X), torch.FloatTensor(Y)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "OR_iJX_6TJRF" }, "source": [ "Visualizing the Training\n", "----------------------------------------\n", "\n", "We now aim to create a nice visualisation, such as the one below. \n", "\n", "![svmgif](https://raw.githubusercontent.com/ecs-vlc/torchbearer/master/docs/_static/img/svm_fit.gif)\n", "\n", "The code for the visualisation (using [pyplot](https://matplotlib.org/api/pyplot_api.html)) is a bit ugly but we'll\n", "try to explain it to some degree. First, we need a mesh grid `xy` over the range of our data:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "colab": {}, "colab_type": "code", "id": "9WWuOIt5TeAA" }, "outputs": [], "source": [ "delta = 0.01\n", "x = np.arange(X[:, 0].min(), X[:, 0].max(), delta)\n", "y = np.arange(X[:, 1].min(), X[:, 1].max(), delta)\n", "x, y = np.meshgrid(x, y)\n", "# Flatten the grid into two coordinate arrays, one per input dimension\n", "xy = list(map(np.ravel, [x, y]))" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "Wm9gBsuzTy7t" }, "source": [ "Now things get a little strange. We start by evaluating our model over the mesh grid from earlier.\n", "\n", "For our outputs $z \\in \\textbf{Z}$, we can make some observations about the decision boundary. First, we are\n", "outside the margin if $z \\lt -1$ or $z \\gt 1$. Conversely, we are inside the margin where $z \\gt -1$\n", "and $z \\lt 1$. 
\n", "\n", "The next bit is a bit of a hack to get the contour plot to update. Every 10 batches we clear the figure, redraw the data points and the filled contours, and then call `fig.canvas.draw()` to trigger an update. You could use `plt.pause` instead; however, it grabs the mouse focus each time it is called, which can be annoying.\n", "\n", "This whole process is shown in the callback below:\n" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "colab": {}, "colab_type": "code", "id": "QEcC8BsoTzQ9" }, "outputs": [], "source": [ "from torchbearer import callbacks\n", "\n", "%matplotlib notebook\n", "import matplotlib\n", "import matplotlib.pyplot as plt\n", "\n", "@callbacks.on_step_training\n", "@callbacks.only_if(lambda state: state[torchbearer.BATCH] % 10 == 0)\n", "def draw_margin(state):\n", "    # Copy the current weights and bias to NumPy\n", "    w = state[torchbearer.MODEL].w[0].detach().to('cpu').numpy()\n", "    b = state[torchbearer.MODEL].b[0].detach().to('cpu').numpy()\n", "\n", "    # Evaluate the model over the mesh grid and bucket the outputs into four bands\n", "    z = (w.dot(xy) + b).reshape(x.shape)\n", "    z[np.where(z > 1.)] = 4\n", "    z[np.where((z > 0.) & (z <= 1.))] = 3\n", "    z[np.where((z > -1.) & (z <= 0.))] = 2\n", "    z[np.where(z <= -1.)] = 1\n", "\n", "    # Redraw the data and the bands; `fig` must be a figure created before training starts\n", "    plt.clf()\n", "    plt.scatter(x=X[:, 0], y=X[:, 1], c=\"black\", s=10)\n", "    plt.contourf(x, y, z, cmap=plt.cm.jet, alpha=0.5)\n", "    fig.canvas.draw()" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "GAdaOug0S_Nf" }, "source": [ "Subgradient Descent\n", "----------------------------------------------\n", "\n", "Since we don't know whether our data is linearly separable, we would like to use a soft-margin SVM. That is, an SVM for\n", "which the data does not all have to be outside of the margin. The trade-off between a wide margin and margin violations is controlled by the weight decay term,\n", "$\\lambda\\lVert \\textbf{w} \\rVert^2$ in the above equation. 
This term is called weight decay because the gradient\n", "corresponds to subtracting some amount ($2\\lambda\\textbf{w}$) from our weights at each step. With torchbearer we\n", "can use the `L2WeightDecay` callback to do this. This whole process is known as stochastic subgradient descent: *subgradient* because the hinge loss is not differentiable where $yz = 1$, and *stochastic* because we\n", "only use a mini-batch (of size 32 in our example) at each step to approximate the gradient over all of the data. With a suitably decaying step size, this is\n", "proven to converge to the minimum for convex functions such as our SVM objective. At this point we are ready to create and train\n", "our model:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "colab": {}, "colab_type": "code", "id": "gpKBohTtTHdr" }, "outputs": [ { "data": { "application/javascript": [ "/* Put everything inside the global mpl namespace */\n", "window.mpl = {};\n", "\n", "\n", "mpl.get_websocket_type = function() {\n", " if (typeof(WebSocket) !== 'undefined') {\n", " return WebSocket;\n", " } else if (typeof(MozWebSocket) !== 'undefined') {\n", " return MozWebSocket;\n", " } else {\n", " alert('Your browser does not have WebSocket support.' +\n", " 'Please try Chrome, Safari or Firefox ≥ 6. ' +\n", " 'Firefox 4 and 5 are also supported but you ' +\n", " 'have to enable WebSockets in about:config.');\n", " };\n", "}\n", "\n", "mpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n", " this.id = figure_id;\n", "\n", " this.ws = websocket;\n", "\n", " this.supports_binary = (this.ws.binaryType != undefined);\n", "\n", " if (!this.supports_binary) {\n", " var warnings = document.getElementById(\"mpl-warnings\");\n", " if (warnings) {\n", " warnings.style.display = 'block';\n", " warnings.textContent = (\n", " \"This browser does not support binary websocket messages. 
\" +\n", " \"Performance may be slow.\");\n", " }\n", " }\n", "\n", " this.imageObj = new Image();\n", "\n", " this.context = undefined;\n", " this.message = undefined;\n", " this.canvas = undefined;\n", " this.rubberband_canvas = undefined;\n", " this.rubberband_context = undefined;\n", " this.format_dropdown = undefined;\n", "\n", " this.image_mode = 'full';\n", "\n", " this.root = $('
');\n", " this._root_extra_style(this.root)\n", " this.root.attr('style', 'display: inline-block');\n", "\n", " $(parent_element).append(this.root);\n", "\n", " this._init_header(this);\n", " this._init_canvas(this);\n", " this._init_toolbar(this);\n", "\n", " var fig = this;\n", "\n", " this.waiting = false;\n", "\n", " this.ws.onopen = function () {\n", " fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n", " fig.send_message(\"send_image_mode\", {});\n", " if (mpl.ratio != 1) {\n", " fig.send_message(\"set_dpi_ratio\", {'dpi_ratio': mpl.ratio});\n", " }\n", " fig.send_message(\"refresh\", {});\n", " }\n", "\n", " this.imageObj.onload = function() {\n", " if (fig.image_mode == 'full') {\n", " // Full images could contain transparency (where diff images\n", " // almost always do), so we need to clear the canvas so that\n", " // there is no ghosting.\n", " fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n", " }\n", " fig.context.drawImage(fig.imageObj, 0, 0);\n", " };\n", "\n", " this.imageObj.onunload = function() {\n", " fig.ws.close();\n", " }\n", "\n", " this.ws.onmessage = this._make_on_message_function(this);\n", "\n", " this.ondownload = ondownload;\n", "}\n", "\n", "mpl.figure.prototype._init_header = function() {\n", " var titlebar = $(\n", " '
');\n", " var titletext = $(\n", " '
');\n", " titlebar.append(titletext)\n", " this.root.append(titlebar);\n", " this.header = titletext[0];\n", "}\n", "\n", "\n", "\n", "mpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "\n", "mpl.figure.prototype._root_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "mpl.figure.prototype._init_canvas = function() {\n", " var fig = this;\n", "\n", " var canvas_div = $('
');\n", "\n", " canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n", "\n", " function canvas_keyboard_event(event) {\n", " return fig.key_event(event, event['data']);\n", " }\n", "\n", " canvas_div.keydown('key_press', canvas_keyboard_event);\n", " canvas_div.keyup('key_release', canvas_keyboard_event);\n", " this.canvas_div = canvas_div\n", " this._canvas_extra_style(canvas_div)\n", " this.root.append(canvas_div);\n", "\n", " var canvas = $('');\n", " canvas.addClass('mpl-canvas');\n", " canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n", "\n", " this.canvas = canvas[0];\n", " this.context = canvas[0].getContext(\"2d\");\n", "\n", " var backingStore = this.context.backingStorePixelRatio ||\n", "\tthis.context.webkitBackingStorePixelRatio ||\n", "\tthis.context.mozBackingStorePixelRatio ||\n", "\tthis.context.msBackingStorePixelRatio ||\n", "\tthis.context.oBackingStorePixelRatio ||\n", "\tthis.context.backingStorePixelRatio || 1;\n", "\n", " mpl.ratio = (window.devicePixelRatio || 1) / backingStore;\n", "\n", " var rubberband = $('');\n", " rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n", "\n", " var pass_mouse_events = true;\n", "\n", " canvas_div.resizable({\n", " start: function(event, ui) {\n", " pass_mouse_events = false;\n", " },\n", " resize: function(event, ui) {\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " stop: function(event, ui) {\n", " pass_mouse_events = true;\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " });\n", "\n", " function mouse_event_fn(event) {\n", " if (pass_mouse_events)\n", " return fig.mouse_event(event, event['data']);\n", " }\n", "\n", " rubberband.mousedown('button_press', mouse_event_fn);\n", " rubberband.mouseup('button_release', mouse_event_fn);\n", " // Throttle sequential mouse events to 1 every 20ms.\n", " rubberband.mousemove('motion_notify', mouse_event_fn);\n", "\n", " 
rubberband.mouseenter('figure_enter', mouse_event_fn);\n", " rubberband.mouseleave('figure_leave', mouse_event_fn);\n", "\n", " canvas_div.on(\"wheel\", function (event) {\n", " event = event.originalEvent;\n", " event['data'] = 'scroll'\n", " if (event.deltaY < 0) {\n", " event.step = 1;\n", " } else {\n", " event.step = -1;\n", " }\n", " mouse_event_fn(event);\n", " });\n", "\n", " canvas_div.append(canvas);\n", " canvas_div.append(rubberband);\n", "\n", " this.rubberband = rubberband;\n", " this.rubberband_canvas = rubberband[0];\n", " this.rubberband_context = rubberband[0].getContext(\"2d\");\n", " this.rubberband_context.strokeStyle = \"#000000\";\n", "\n", " this._resize_canvas = function(width, height) {\n", " // Keep the size of the canvas, canvas container, and rubber band\n", " // canvas in synch.\n", " canvas_div.css('width', width)\n", " canvas_div.css('height', height)\n", "\n", " canvas.attr('width', width * mpl.ratio);\n", " canvas.attr('height', height * mpl.ratio);\n", " canvas.attr('style', 'width: ' + width + 'px; height: ' + height + 'px;');\n", "\n", " rubberband.attr('width', width);\n", " rubberband.attr('height', height);\n", " }\n", "\n", " // Set the figure to an initial 600x600px, this will subsequently be updated\n", " // upon first draw.\n", " this._resize_canvas(600, 600);\n", "\n", " // Disable right mouse context menu.\n", " $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n", " return false;\n", " });\n", "\n", " function set_focus () {\n", " canvas.focus();\n", " canvas_div.focus();\n", " }\n", "\n", " window.setTimeout(set_focus, 100);\n", "}\n", "\n", "mpl.figure.prototype._init_toolbar = function() {\n", " var fig = this;\n", "\n", " var nav_element = $('
')\n", " nav_element.attr('style', 'width: 100%');\n", " this.root.append(nav_element);\n", "\n", " // Define a callback function for later on.\n", " function toolbar_event(event) {\n", " return fig.toolbar_button_onclick(event['data']);\n", " }\n", " function toolbar_mouse_event(event) {\n", " return fig.toolbar_button_onmouseover(event['data']);\n", " }\n", "\n", " for(var toolbar_ind in mpl.toolbar_items) {\n", " var name = mpl.toolbar_items[toolbar_ind][0];\n", " var tooltip = mpl.toolbar_items[toolbar_ind][1];\n", " var image = mpl.toolbar_items[toolbar_ind][2];\n", " var method_name = mpl.toolbar_items[toolbar_ind][3];\n", "\n", " if (!name) {\n", " // put a spacer in here.\n", " continue;\n", " }\n", " var button = $('