{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Declarative Node Worked Examples\n", "\n", "In this notebook we explore four simple examples of declarative nodes.\n", "The first example explores an unconstrained scalar-input/scalar-output problem,\n", "the second example solves a problem with a linear objective and a unit sphere constraint,\n", "the third example solves a problem with a quadratic objective and a unit sphere constraint,\n", "and the last example solves a problem with a quadratic objective and a unit ball constraint.\n", "All examples have one- or two-dimensional output to allow for easy visualization.\n", "\n", "It is assumed that you have read the [\"Deep Declarative Networks\"](https://arxiv.org/abs/1909.04866) paper.\n", "If you already have a good understanding of deep declarative networks and just want to know how to prototype a new declarative layer in PyTorch, then skip ahead to the tutorial on [implementing a declarative node using the `ddn.pytorch.node` module](08_ddn_pytorch_node.ipynb). " ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%matplotlib notebook" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import autograd.numpy as np\n", "from autograd import grad\n", "from autograd import jacobian\n", "from scipy.linalg import cho_factor, cho_solve\n", "\n", "import matplotlib.pyplot as plt\n", "\n", "from matplotlib import animation, rc\n", "from IPython.display import HTML\n", "\n", "import warnings\n", "warnings.filterwarnings('ignore')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Example 1: One-dimensional unconstrained polynomial\n", "\n", "In this example we explore the minimization of a one-dimensional parametrized polynomial from [Gould et al., 2016](https://arxiv.org/abs/1607.05447). 
Here the problem is\n", "$$\n", "y(x) = \\text{argmin}_u \\; f(x, u)\n", "$$\n", "where $f(x, u) = xu^4 + 2x^2u^3 - 12u^2$ with $x, u \\in \\mathbb{R}$.\n", "\n", "For fixed $x$, we can easily solve for the stationary points of the polynomial as\n", "$$\n", "\\begin{align*}\n", "0 &= \\frac{d}{du} f(x, u) \\\\\n", "&= 4xu^3 + 6x^2u^2 - 24u \\\\\n", "&= 2u(2xu^2 + 3x^2u - 12)\n", "\\end{align*}\n", "$$\n", "Therefore the stationary points are at\n", "$$\n", "u \\in \\left\\{0, \\frac{-3x^2 - \\sqrt{9x^4 + 96x}}{4x}, \\frac{-3x^2 + \\sqrt{9x^4 + 96x}}{4x} \\right\\}\n", "$$\n", "For $x > 0$ the leading coefficient of the quartic is positive, so $f(x, \\cdot)$ is bounded below and one of these stationary points is the global minimum, $y$. The others are local minima, local maxima or inflection points, depending on the value of $x$. We can therefore find the global minimum by evaluating $f$ at all three stationary points, as done in the `solve` function below." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "def f(x, u):\n", "    return x * u ** 4.0 + 2 * x ** 2.0 * u ** 3.0 - 12.0 * u ** 2.0\n", "\n", "def solve(x):\n", "    # discriminant term shared by the two non-zero stationary points\n", "    delta = np.sqrt(9.0 * x ** 4.0 + 96.0 * x)\n", "    y_stationary = [0.0, (-3.0 * x ** 2.0 - delta) / (4.0 * x), (-3.0 * x ** 2.0 + delta) / (4.0 * x)]\n", "    # return the stationary point with the lowest objective value\n", "    y_min_indx = np.argmin([f(x, y) for y in y_stationary])\n", "    return y_stationary[y_min_indx]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Visualization\n", "\n", "The following shows a contour plot of the function $f(x, u)$ for $0.25 \\leq x \\leq 2.25$ and $-6 \\leq u \\leq 4$. Only negative contour lines are shown to highlight the shape of the function around the minima. The black dashed line depicts the valley in which we find the global minimum for each $x$. 
Over this range, the global minimum turns out to be\n", "$$\n", "y(x) = \\frac{-3x^2 - \\sqrt{9x^4 + 96x}}{4x}\n", "$$" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "def gradient_by_closed_form(x):\n", "    \"\"\"Compute the gradient using the closed-form expression.\"\"\"\n", "    delta = np.sqrt(9.0 * x ** 4 + 96.0 * x)\n", "    return -0.75 - (9.0 * x ** 3 - 48.0) / (4 * x * delta)\n", "\n", "def gradient_by_ift(x, y):\n", "    \"\"\"Compute the gradient using the implicit function theorem result.\"\"\"\n", "    return -1.0 * (y ** 3 + 3.0 * x * y ** 2) / (3.0 * x * y ** 2 + 3.0 * x ** 2 * y - 6.0)\n", "\n", "y = [solve(xi) for xi in x]\n", "\n", "plt.figure()\n", "plt.subplot(2, 1, 1)\n", "plt.plot(x, y)\n", "plt.grid()\n", "plt.title(r'$y = argmin_u f(x, u)$'); plt.ylabel(r'$y$')\n", "\n", "plt.subplot(2, 1, 2)\n", "plt.plot(x, [gradient_by_closed_form(xi) for xi in x])\n", "plt.plot(x, [gradient_by_ift(xi, yi) for xi, yi in zip(x, y)])\n", "plt.legend(['closed form', 'implicit function theorem'])\n", "plt.xlabel(r'$x$'); plt.ylabel(r'$dy/dx$')\n", "plt.grid()\n", "\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Automatic differentiation\n", "\n", "Instead of deriving the Hessian $\\frac{\\partial^2 f}{\\partial u^2}$ and mixed partial derivative $\\frac{\\partial^2 f}{\\partial x \\partial u}$ by hand, we can also use Python's `autograd` module. 
However, this requires rewriting the objective function `f(x, u)` using `numpy` operators, as the following code demonstrates." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "def viz_problem(x, y, ax):\n", "    # arrow showing the cost vector (1, x)\n", "    ax.annotate(\"\", xy=(1.0, x), xytext=(0, 0), arrowprops=dict(arrowstyle=\"->\", linewidth=3, color='b'))\n", "    # dashed contour lines of the linear objective (1, x)^T u, evenly spaced in perpendicular distance\n", "    for delta in np.linspace(-2.0, 2.0, 5):\n", "        dx = delta * np.sqrt(1.0 + x**2.0)\n", "        ax.plot([-2.0 * x + dx, 2.0 * x + dx], [2.0, -2.0], 'b--', linewidth=1)\n", "\n", "    # unit circle constraint set and the solution y\n", "    ax.plot(np.cos(np.linspace(0.0, 2.0 * np.pi)), np.sin(np.linspace(0.0, 2.0 * np.pi)), 'r--', linewidth=1)\n", "    ax.plot(y[0], y[1], 'ro', markersize=12, linewidth=2)\n", "    ax.set_xlabel(r\"$u_1$\"); ax.set_ylabel(r\"$u_2$\")\n", "\n", "plt.figure()\n", "a = plt.gca()\n", "x = 1.0\n", "y, _ = solve(x)\n", "viz_problem(x, y, a)\n", "a.axis('square'); a.set_xlim(-2, 2); a.set_ylim(-2, 2)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Above we compute the gradient in closed form. 
It can also be computed by implicitly differentiating, with respect to $x$, the first-order optimality conditions of the Lagrangian ${\\cal L}(x, y, \\lambda) = f(x, y) - \\lambda h(y)$,\n", "\n", "$$\n", "\\begin{align*}\n", "0 &= \n", "\\text{D} \\begin{bmatrix}\n", " \\text{D}_{Y} {\\cal L}(x, y, \\lambda) \\\\\n", " \\text{D}_{\\Lambda} {\\cal L}(x, y, \\lambda)\n", "\\end{bmatrix}\n", "\\\\\n", "&=\n", "\\text{D} \\begin{bmatrix}\n", " \\text{D}_Y f(x, y) - \\lambda \\text{D}_Y h(y) \\\\\n", " h(y)\n", "\\end{bmatrix}\n", "\\\\\n", "&=\n", "\\begin{bmatrix}\n", " \\text{D}^2_{XY} f(x, y) + \\text{D}^2_{YY} f(x, y) \\text{D}y - \\lambda \\text{D}^2_{YY} h(y) \\text{D}y - (\\text{D}_Y h(y))^T \\text{D}\\lambda \\\\\n", " \\text{D}_{Y} h(y) \\text{D}y\n", "\\end{bmatrix}\n", "\\\\\n", "&=\n", "\\begin{bmatrix}\n", " \\text{D}^2_{XY} f(x, y) \\\\\n", " 0\n", "\\end{bmatrix}\n", "+\n", "\\begin{bmatrix}\n", " \\text{D}^2_{YY} f(x, y) - \\lambda \\text{D}^2_{YY} h(y) & -(\\text{D}_Y h(y))^T \\\\\n", " \\text{D}_{Y} h(y) & 0\n", "\\end{bmatrix}\n", "\\begin{bmatrix}\n", " \\text{D}y \\\\\n", " \\text{D}\\lambda\n", "\\end{bmatrix}\n", "\\\\\n", "&=\n", "\\begin{bmatrix}\n", " b \\\\\n", " 0\n", "\\end{bmatrix}\n", "+\n", "\\begin{bmatrix}\n", " H & -a \\\\\n", " a^T & 0\n", "\\end{bmatrix}\n", "\\begin{bmatrix}\n", " \\text{D}y \\\\\n", " \\text{D}\\lambda\n", "\\end{bmatrix}\n", "\\end{align*}\n", "$$\n", "\n", "where $a = (\\text{D}_Y h(y))^T$, $b = \\text{D}^2_{XY} f(x, y)$ and $H = \\text{D}^2_{YY} f(x, y) - \\lambda \\text{D}^2_{YY} h(y)$. Block elimination then gives\n", "\n", "$$\n", "\\begin{align*}\n", "\\text{D} y\n", "&= \\left( \\frac{H^{-1}aa^TH^{-1}}{a^T H^{-1} a} - H^{-1}\\right) b \\\\\n", "&= \\frac{v^T b}{v^T a} v - w\n", "\\end{align*}\n", "$$\n", "\n", "where $v = H^{-1}a$ and $w = H^{-1}b$, and the second line uses the symmetry of $H$ (so that $a^T H^{-1} b = v^T b$). We implement this below using automatic differentiation."
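] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As a quick sanity check on the block elimination step, the following sketch (an addition to this notebook, not part of the original derivation) solves the full bordered system directly and compares the result with $\\frac{v^T b}{v^T a} v - w$. The test data (a randomly generated symmetric positive definite $H$ and random vectors $a$ and $b$) and the helper name `check_block_elimination` are hypothetical. It uses plain NumPy, imported as `onp`, so that the notebook's `np` (which refers to `autograd.numpy`) is left untouched." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as onp  # plain NumPy; the notebook's np is autograd.numpy\n", "\n", "def check_block_elimination(n=2, seed=0):\n", "    # hypothetical test data: random symmetric positive definite H and random vectors a, b\n", "    rng = onp.random.RandomState(seed)\n", "    M = rng.randn(n, n)\n", "    H = M @ M.T + n * onp.eye(n)\n", "    a = rng.randn(n)\n", "    b = rng.randn(n)\n", "\n", "    # direct solve of [H, -a; a^T, 0] [Dy; Dlambda] = -[b; 0]\n", "    K = onp.block([[H, -a[:, None]], [a[None, :], onp.zeros((1, 1))]])\n", "    Dy_direct = onp.linalg.solve(K, -onp.concatenate([b, [0.0]]))[:n]\n", "\n", "    # block elimination: Dy = (v^T b / v^T a) v - w, with v = H^{-1} a and w = H^{-1} b\n", "    v = onp.linalg.solve(H, a)\n", "    w = onp.linalg.solve(H, b)\n", "    Dy_block = v.dot(b) / v.dot(a) * v - w\n", "\n", "    return onp.max(onp.abs(Dy_direct - Dy_block))\n", "\n", "print('Max. difference between direct and block-elimination solutions is {}'.format(check_block_elimination()))"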
] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "# gradient by implicit differentiation\n", "\n", "fY = grad(f, 1)\n", "hY = grad(h)\n", "fXY = jacobian(fY, 0)\n", "fYY = jacobian(fY, 1)\n", "hYY = jacobian(hY, 0)\n", "\n", "def dy(x):\n", "    \"\"\"Compute gradient of y with respect to x using implicit differentiation.\"\"\"\n", "\n", "    y, nu = solve(x)\n", "\n", "    # Here we solve a system of linear equations rather than inverting H. With\n", "    # a = D_Y h and b = D^2_{XY} f, the linear algebra solver gives v = H^{-1} a\n", "    # and w = H^{-1} b, from which we compute Dy(x) as (v^T b / v^T a) v - w.\n", "\n", "    a = hY(y)\n", "    b = fXY(x, y)\n", "    H = fYY(x, y) - nu * hYY(y)\n", "\n", "    # a and b are the columns of the right-hand side; transpose so that v and w are the rows\n", "    v, w = np.linalg.solve(H, np.stack((a, b)).T).T\n", "    return v.dot(b) / v.dot(a) * v - w" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The following compares the analytic gradient with the implicitly derived gradient over the range $x \\in [-2, 2]$." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Max. difference between implicit and analytic gradients is 1.1102230246251565e-16\n" ] } ], "source": [ "# generate data for different input\n", "x = np.linspace(-2.0, 2.0, num=51)\n", "y = [solve(xi)[0] for xi in x]\n", "dy_analytic = [dy_closed_form(xi) for xi in x]\n", "dy_implicit = [dy(xi) for xi in x]\n", "\n", "# print difference between analytic and implicit gradients\n", "err = np.abs(np.array(dy_analytic) - np.array(dy_implicit))\n", "print(\"Max. difference between implicit and analytic gradients is {}\".format(np.max(err)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Animation\n", "\n", "To get a better understanding of how the output of this constrained declarative node changes with its input (a requirement for end-to-end learning), we animate the geometry of the problem and its solution $y$ as we vary the input $x$. Alongside the animation, we plot both $y$ and $\\text{D}y$ as functions of $x$." ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "%%capture\n", "\n", "# simple animation of the problem\n", "\n", "def init():\n", "    for a in ax:\n", "        a.axis('square')\n", "        a.set_xlim(x[0], x[-1])\n", "    ax[0].set_ylim(x[0], x[-1])\n", "    ax[1].set_ylim(-1.0, 1.0)\n", "    ax[2].set_ylim(-1.0, 1.0)\n", "\n", "    return ax\n", "\n", "\n", "def animate(fnum, x, y, dy):\n", "\n", "    for a in ax:\n", "        a.clear()\n", "\n", "    viz_problem(x[fnum], y[fnum], ax[0])\n", "\n", "    ax[1].plot(x[0:fnum], [yi[0] for yi in y[0:fnum]], x[0:fnum], [yi[1] for yi in y[0:fnum]])\n", "    ax[1].legend([r\"$y_1$\", r\"$y_2$\"])\n", "\n", "    ax[2].plot(x[0:fnum], [di[0] for di in dy[0:fnum]], x[0:fnum], [di[1] for di in dy[0:fnum]])\n", "    ax[2].legend([r\"$Dy_1$\", r\"$Dy_2$\"], loc='upper left')\n", "\n", "    for a in ax:\n", "        a.axis('square')\n", "        a.set_xlim(x[0], x[-1])\n", "    ax[0].set_ylim(x[0], x[-1])\n", "    ax[1].set_ylim(-1.0, 1.0)\n", "    ax[2].set_ylim(-1.0, 1.0)\n", "\n", "    return ax\n", "\n", "\n", "fig = plt.figure()\n", "ax = [plt.subplot(1, 2, 1), plt.subplot(2, 2, 2), plt.subplot(2, 2, 4)]\n", "plt.suptitle(r\"$y$ = argmin $(1, x)^T u$ subject to $\\|u\\|^2 = 1$\")\n", "\n", "ani = animation.FuncAnimation(fig, animate, init_func=init, fargs=(x, y, dy_implicit),\n", "                              interval=100, frames=len(x), blit=False, repeat=False)\n", "\n", "plt.close(fig)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/html": [ "" ],
"text/plain": [ "" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# display using video or javascript\n", "\n", "HTML(ani.to_html5_video())\n", "#HTML(ani.to_jshtml())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Example 3: Minimize quadratic objective over unit circle\n", "\n", "Consider the problem\n", "$$\n", "\\begin{array}{rll}\n", "y(x) =& \\text{argmin}_u & \\frac{1}{2} u^T u - x^T u\n", "\\\\\n", "& \\text{subject to} & \\|u\\|_2^2 = 1\n", "\\end{array}\n", "$$\n", "with $x \\in \\mathbb{R}^2$ and $y \\in \\mathbb{R}^2$.\n", "\n", "This is just the Euclidean projection of the two-dimensional point $x$ onto the unit circle, which can be seen by replacing the objective function with\n", "$$\n", "\\frac{1}{2} \\|u - x\\|^2 = \\frac{1}{2} u^T u - x^T u + \\frac{1}{2} x^T x\n", "$$\n", "and recognizing that $\\frac{1}{2} x^T x$ plays no part in the optimization. For a discussion of projecting $n$-dimensional points onto $L_p$ spheres see Section 5.2 of [\"Deep Declarative Networks: A New Hope\"](https://arxiv.org/abs/1909.04866). The problem has analytic solution\n", "\n", "$$\n", "\\begin{align*}\n", "y &= \\frac{1}{\\|x\\|} x \\\\\n", "\\text{D} y &= \\begin{bmatrix}\n", "\\frac{\\partial y_1}{\\partial x_1} & \\frac{\\partial y_1}{\\partial x_2} \\\\\n", "\\frac{\\partial y_2}{\\partial x_1} & \\frac{\\partial y_2}{\\partial x_2}\n", "\\end{bmatrix}\n", "= \\frac{1}{\\|x\\|^3} \\begin{bmatrix}\n", " x_2^2 & - x_1 x_2 \\\\\n", " - x_1 x_2 & x_1^2\n", "\\end{bmatrix}\n", "\\end{align*}\n", "$$\n", "which is defined for all $x \\neq 0$." 
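, "\n", "\n", "The closed-form derivative can also be recovered from the implicit differentiation result used in Example 2, since $h$ is a single equality constraint that does not depend on $x$. Stationarity of the Lagrangian $f(x, u) - \\nu h(u)$ gives $(1 - 2\\nu) y = x$, so with $y = \\frac{1}{\\|x\\|} x$ the optimal multiplier is $\\nu = \\frac{1}{2}(1 - \\|x\\|)$. Writing $a = \\text{D}_Y h(y)^T = 2y$ and $B = \\text{D}^2_{XY} f(x, y) = -I$, we have\n", "$$\n", "\\begin{align*}\n", "H &= \\text{D}^2_{YY} f(x, y) - \\nu \\, \\text{D}^2_{YY} h(y) = (1 - 2\\nu) I = \\|x\\| I \\\\\n", "\\text{D} y &= \\frac{1}{a^T H^{-1} a} H^{-1} a \\, a^T H^{-1} B - H^{-1} B \\\\\n", "&= -\\frac{1}{\\|x\\|} y y^T + \\frac{1}{\\|x\\|} I \\\\\n", "&= \\frac{1}{\\|x\\|^3} \\left( \\|x\\|^2 I - x x^T \\right)\n", "\\end{align*}\n", "$$\n", "which agrees with the analytic expression above."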
] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "# objective and constraint functions\n", "\n", "def f(x, u):\n", " \"\"\"Objective function taking x, u in R^2.\"\"\"\n", " return 0.5 * np.dot(u, u) - np.dot(u, x)\n", "\n", "def h(u):\n", " \"\"\"Constraint function taking u in R^2.\"\"\"\n", " return np.dot(u, u) - 1.0\n", "\n", "# analytical solutions\n", "\n", "def solve(x):\n", " \"\"\"Analytical solution to min. f s.t. h = 0. Returns the optimal primal variable and None in place of the dual variable.\"\"\"\n", " return 1.0 / np.sqrt(np.dot(x, x)) * x, None\n", "\n", "def dy_closed_form(x):\n", " \"\"\"Analytical derivative of y with respect to x.\"\"\"\n", " return 1.0 / np.power(np.dot(x, x), 1.5) * np.array([[x[1] ** 2, -x[0]*x[1]], [-x[0]*x[1], x[0] ** 2]])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A visualization of the problem showing contour lines, constraint set and solution for $x = (0.5, 1.25)$ is shown below." ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "application/javascript": [ "/* Put everything inside the global mpl namespace */\n", "window.mpl = {};\n", "\n", "\n", "mpl.get_websocket_type = function() {\n", " if (typeof(WebSocket) !== 'undefined') {\n", " return WebSocket;\n", " } else if (typeof(MozWebSocket) !== 'undefined') {\n", " return MozWebSocket;\n", " } else {\n", " alert('Your browser does not have WebSocket support. ' +\n", " 'Please try Chrome, Safari or Firefox ≥ 6. 
' +\n", " 'Firefox 4 and 5 are also supported but you ' +\n", " 'have to enable WebSockets in about:config.');\n", " };\n", "}\n", "\n", "mpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n", " this.id = figure_id;\n", "\n", " this.ws = websocket;\n", "\n", " this.supports_binary = (this.ws.binaryType != undefined);\n", "\n", " if (!this.supports_binary) {\n", " var warnings = document.getElementById(\"mpl-warnings\");\n", " if (warnings) {\n", " warnings.style.display = 'block';\n", " warnings.textContent = (\n", " \"This browser does not support binary websocket messages. \" +\n", " \"Performance may be slow.\");\n", " }\n", " }\n", "\n", " this.imageObj = new Image();\n", "\n", " this.context = undefined;\n", " this.message = undefined;\n", " this.canvas = undefined;\n", " this.rubberband_canvas = undefined;\n", " this.rubberband_context = undefined;\n", " this.format_dropdown = undefined;\n", "\n", " this.image_mode = 'full';\n", "\n", " this.root = $('
');\n", " this._root_extra_style(this.root)\n", " this.root.attr('style', 'display: inline-block');\n", "\n", " $(parent_element).append(this.root);\n", "\n", " this._init_header(this);\n", " this._init_canvas(this);\n", " this._init_toolbar(this);\n", "\n", " var fig = this;\n", "\n", " this.waiting = false;\n", "\n", " this.ws.onopen = function () {\n", " fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n", " fig.send_message(\"send_image_mode\", {});\n", " if (mpl.ratio != 1) {\n", " fig.send_message(\"set_dpi_ratio\", {'dpi_ratio': mpl.ratio});\n", " }\n", " fig.send_message(\"refresh\", {});\n", " }\n", "\n", " this.imageObj.onload = function() {\n", " if (fig.image_mode == 'full') {\n", " // Full images could contain transparency (where diff images\n", " // almost always do), so we need to clear the canvas so that\n", " // there is no ghosting.\n", " fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n", " }\n", " fig.context.drawImage(fig.imageObj, 0, 0);\n", " };\n", "\n", " this.imageObj.onunload = function() {\n", " fig.ws.close();\n", " }\n", "\n", " this.ws.onmessage = this._make_on_message_function(this);\n", "\n", " this.ondownload = ondownload;\n", "}\n", "\n", "mpl.figure.prototype._init_header = function() {\n", " var titlebar = $(\n", " '
');\n", " var titletext = $(\n", " '
');\n", " titlebar.append(titletext)\n", " this.root.append(titlebar);\n", " this.header = titletext[0];\n", "}\n", "\n", "\n", "\n", "mpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "\n", "mpl.figure.prototype._root_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "mpl.figure.prototype._init_canvas = function() {\n", " var fig = this;\n", "\n", " var canvas_div = $('
');\n", "\n", " canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n", "\n", " function canvas_keyboard_event(event) {\n", " return fig.key_event(event, event['data']);\n", " }\n", "\n", " canvas_div.keydown('key_press', canvas_keyboard_event);\n", " canvas_div.keyup('key_release', canvas_keyboard_event);\n", " this.canvas_div = canvas_div\n", " this._canvas_extra_style(canvas_div)\n", " this.root.append(canvas_div);\n", "\n", " var canvas = $('');\n", " canvas.addClass('mpl-canvas');\n", " canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n", "\n", " this.canvas = canvas[0];\n", " this.context = canvas[0].getContext(\"2d\");\n", "\n", " var backingStore = this.context.backingStorePixelRatio ||\n", "\tthis.context.webkitBackingStorePixelRatio ||\n", "\tthis.context.mozBackingStorePixelRatio ||\n", "\tthis.context.msBackingStorePixelRatio ||\n", "\tthis.context.oBackingStorePixelRatio ||\n", "\tthis.context.backingStorePixelRatio || 1;\n", "\n", " mpl.ratio = (window.devicePixelRatio || 1) / backingStore;\n", "\n", " var rubberband = $('');\n", " rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n", "\n", " var pass_mouse_events = true;\n", "\n", " canvas_div.resizable({\n", " start: function(event, ui) {\n", " pass_mouse_events = false;\n", " },\n", " resize: function(event, ui) {\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " stop: function(event, ui) {\n", " pass_mouse_events = true;\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " });\n", "\n", " function mouse_event_fn(event) {\n", " if (pass_mouse_events)\n", " return fig.mouse_event(event, event['data']);\n", " }\n", "\n", " rubberband.mousedown('button_press', mouse_event_fn);\n", " rubberband.mouseup('button_release', mouse_event_fn);\n", " // Throttle sequential mouse events to 1 every 20ms.\n", " rubberband.mousemove('motion_notify', mouse_event_fn);\n", "\n", " 
rubberband.mouseenter('figure_enter', mouse_event_fn);\n", " rubberband.mouseleave('figure_leave', mouse_event_fn);\n", "\n", " canvas_div.on(\"wheel\", function (event) {\n", " event = event.originalEvent;\n", " event['data'] = 'scroll'\n", " if (event.deltaY < 0) {\n", " event.step = 1;\n", " } else {\n", " event.step = -1;\n", " }\n", " mouse_event_fn(event);\n", " });\n", "\n", " canvas_div.append(canvas);\n", " canvas_div.append(rubberband);\n", "\n", " this.rubberband = rubberband;\n", " this.rubberband_canvas = rubberband[0];\n", " this.rubberband_context = rubberband[0].getContext(\"2d\");\n", " this.rubberband_context.strokeStyle = \"#000000\";\n", "\n", " this._resize_canvas = function(width, height) {\n", " // Keep the size of the canvas, canvas container, and rubber band\n", " // canvas in synch.\n", " canvas_div.css('width', width)\n", " canvas_div.css('height', height)\n", "\n", " canvas.attr('width', width * mpl.ratio);\n", " canvas.attr('height', height * mpl.ratio);\n", " canvas.attr('style', 'width: ' + width + 'px; height: ' + height + 'px;');\n", "\n", " rubberband.attr('width', width);\n", " rubberband.attr('height', height);\n", " }\n", "\n", " // Set the figure to an initial 600x600px, this will subsequently be updated\n", " // upon first draw.\n", " this._resize_canvas(600, 600);\n", "\n", " // Disable right mouse context menu.\n", " $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n", " return false;\n", " });\n", "\n", " function set_focus () {\n", " canvas.focus();\n", " canvas_div.focus();\n", " }\n", "\n", " window.setTimeout(set_focus, 100);\n", "}\n", "\n", "mpl.figure.prototype._init_toolbar = function() {\n", " var fig = this;\n", "\n", " var nav_element = $('
');\n", " nav_element.attr('style', 'width: 100%');\n", " this.root.append(nav_element);\n", "\n", " // Define a callback function for later on.\n", " function toolbar_event(event) {\n", " return fig.toolbar_button_onclick(event['data']);\n", " }\n", " function toolbar_mouse_event(event) {\n", " return fig.toolbar_button_onmouseover(event['data']);\n", " }\n", "\n", " for(var toolbar_ind in mpl.toolbar_items) {\n", " var name = mpl.toolbar_items[toolbar_ind][0];\n", " var tooltip = mpl.toolbar_items[toolbar_ind][1];\n", " var image = mpl.toolbar_items[toolbar_ind][2];\n", " var method_name = mpl.toolbar_items[toolbar_ind][3];\n", "\n", " if (!name) {\n", " // put a spacer in here.\n", " continue;\n", " }\n", " var button = $('