{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Mass-Spring-Damper Model Estimation\n", "-----------------------------------" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "### System Description" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is a very simple model: a 1-DOF (single degree of freedom) mass-spring-damper system." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from IPython.display import Image\n", "\n", "Image(url='https://github.com/stuckeyr/msd/raw/master/mass_spring_damper.png')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See [here](http://ctms.engin.umich.edu/CTMS/index.php?example=Introduction&section=SystemModeling#5) for a more detailed explanation." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The equation of motion can be written:\n", "\n", "$$\n", "\\sum F_x = F(t) - b \\dot{x} - kx = m \\ddot{x}\n", "$$\n", "\n", "where $x$ is the displacement, $m$ is the mass, $k$ is the spring constant, $b$ is the damping constant and $F$ is the input force, expressed here as a function of time." 
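, "\n", "\n", "As a quick orientation (a standard textbook rearrangement, not something stated in the original notebook), dividing through by $m$ gives the canonical second-order form:\n", "\n", "$$\n", "\\ddot{x} + 2 \\zeta \\omega_n \\dot{x} + \\omega_n^2 x = \\frac{F(t)}{m}, \\qquad \\omega_n = \\sqrt{\\frac{k}{m}}, \\qquad \\zeta = \\frac{b}{2 \\sqrt{k m}}\n", "$$\n", "\n", "Assuming the magnitudes used by the model class later in this notebook ($m = 30.48$, $k = 50$, $b = 10$; the code stores the last two with a negative sign), this gives $\\omega_n \\approx 1.28$ rad per unit time and $\\zeta \\approx 0.128$, so the response should be lightly damped and oscillatory."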
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Reframing the system in state-space form, using the state vector:\n", "\n", "$$\n", "\\mathbf{x} = \\begin{bmatrix} x \\\\\\ \\dot{x} \\end{bmatrix}\n", "$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "the equations are written:\n", "\n", "$$\n", "m \\mathbf{\\dot{x}} = m \\begin{bmatrix} \\dot{x} \\\\\\ \\ddot{x} \\end{bmatrix} = \\begin{bmatrix} 0 & m \\\\\\ -k & -b \\end{bmatrix} \\begin{bmatrix} x \\\\\\ \\dot{x} \\end{bmatrix} + \\begin{bmatrix} 0 \\\\\\ 1 \\end{bmatrix} F(t)\n", "$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Simulation in Python" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, reset the workspace and import the necessary libraries (the following assumes you have built both the Boost and Cython models)." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "%matplotlib notebook\n", "# %matplotlib inline" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "%reset -f" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import numpy.matlib as ml\n", "from scipy import interpolate, integrate" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Create the model. Below is the pure Python version, which I have included for reference." 
] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# ------------------------------------------------------------------------------\n", "# MSD class\n", "# ------------------------------------------------------------------------------\n", "class MSD(object):\n", "    \"\"\"\n", "    The MSD class represents a Mass-Spring-Damper system.\n", "    \"\"\"\n", "    # System parameters\n", "    m = 30.48\n", "\n", "    def __init__(self, name, **kwargs):\n", "        \"\"\"\n", "        Initialise the msd object.\n", "\n", "        :param: name = system name\n", "        \"\"\"\n", "        self.name = name\n", "\n", "        # Pass through any other keyword arguments\n", "        for key in kwargs:\n", "            self.__dict__[key] = kwargs[key]\n", "\n", "        self.c_idx = [ 'k', 'b', 'd' ]\n", "\n", "        # Model coefficients\n", "        self.C = { 'k': -50.0, 'b': -10.0, 'd': 1.0, 'z': 0.0 }\n", "\n", "        # External force function\n", "        self.d_func = None\n", "\n", "        # State noise function\n", "        self.w_func = None\n", "\n", "        self.init()\n", "\n", "    def __str__(self):\n", "        return self.name\n", "\n", "    def init(self):\n", "        \"\"\"\n", "        Construct the force and moment matrices.\n", "        \"\"\"\n", "        # Rigid body mass matrix\n", "        self.C_M_I = 1.0/self.m\n", "\n", "        self.C_SD = ml.zeros((2, 2))\n", "        self.M_EF = ml.zeros((2, 1))\n", "\n", "    def get_coeffs(self):\n", "        \"\"\"\n", "        Get the model coefficients.\n", "        \"\"\"\n", "        return self.C\n", "\n", "    def set_coeffs(self, C):\n", "        \"\"\"\n", "        Set the model coefficients.\n", "        \"\"\"\n", "        for ck in C.keys():\n", "            self.C[ck] = C[ck]\n", "\n", "    def set_external_forces(self, T_S, D_S, interp_kind):\n", "        \"\"\"\n", "        Set the external force interpolant points.\n", "        \"\"\"\n", "        self.d_func = interpolate.interp1d(T_S, D_S, kind=interp_kind, axis=0, bounds_error=False)\n", "\n", "    def add_state_noise(self, T_S, W_S):\n", "        \"\"\"\n", "        Set the state noise interpolant points.\n", "        \"\"\"\n", "        self.w_func = interpolate.interp1d(T_S, W_S, kind='linear', axis=0, 
bounds_error=False)\n", "\n", "    def rates(self, x, t):\n", "        \"\"\"\n", "        Calculate the system state-rate for the current state x.\n", "\n", "        :param: x = current system state [ xp, xpd ]\n", "        :param: t = current time\n", "\n", "        :returns: xdot = system state-rate\n", "        \"\"\"\n", "        # Spring-damper forces\n", "        # C_SD = np.mat([[ 0.0, self.m ],\n", "        #                [ self.C['k'], self.C['b']]])\n", "        self.C_SD[0, 0] = 0.0\n", "        self.C_SD[0, 1] = self.m\n", "        self.C_SD[1, 0] = self.C['k']\n", "        self.C_SD[1, 1] = self.C['b']\n", "\n", "        M_SD = self.C_SD*x.reshape((-1, 1))\n", "\n", "        d = np.nan_to_num(self.d_func(t))\n", "\n", "        # External force\n", "        # M_EF = np.mat([[ 0.0 ],\n", "        #                [ self.C['d']*d ]])\n", "        self.M_EF[0, 0] = 0.0\n", "        self.M_EF[1, 0] = self.C['d']*d\n", "\n", "        xdot = np.ravel(self.C_M_I*(M_SD + self.M_EF))\n", "\n", "        if (self.w_func is not None):\n", "            xdot[1] += np.nan_to_num(self.w_func(t))\n", "\n", "        return xdot\n", "\n", "    def rrates(self, t, x):\n", "        \"\"\"\n", "        Rates method with arguments reversed.\n", "        \"\"\"\n", "        # rates() takes only (x, t); the force function is read from self.d_func\n", "        return self.rates(x, t)\n", "\n", "    def forces(self, xdot, x):\n", "        \"\"\"\n", "        Calculate the forces from recorded state data.\n", "\n", "        :param: xdot = system state rate\n", "        :param: x = system state [ xp, xpd ]\n", "\n", "        :returns: f = state forces\n", "        \"\"\"\n", "        xpddot = xdot[1]\n", "\n", "        f = self.m*xpddot\n", "\n", "        return f\n", "\n", "    def integrate(self, x0, T):\n", "        \"\"\"\n", "        Integrate the differential equations and calculate the resulting rates and forces.\n", "\n", "        :param: x0 = initial system state\n", "        :param: T = sequence of time points for which to solve for x\n", "\n", "        :returns: X = system state array\n", "        :returns: Xdot = state rates array\n", "        :returns: F = state force array\n", "        \"\"\"\n", "        N = T.shape[0]\n", "\n", "        # Initialise the model\n", "        self.init()\n", "\n", "        # Perform the integration\n", "        X = integrate.odeint(self.rates, x0, T, rtol=1.0e-6, atol=1.0e-6)\n", "\n", "        Xdot 
= np.zeros((N, len(x0)))\n", "        for n in range(N):\n", "            Xdot[n] = self.rates(X[n], T[n])\n", "\n", "        # Force and moment matrix\n", "        F = np.zeros((N, 1))\n", "        for n in range(N):\n", "            F[n] = self.forces(Xdot[n], X[n])\n", "\n", "        return X, Xdot, F" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "But we are going to use one of the compiled implementations here (the PyUblas extension is selected below), as the pure Python model is very slow. Simulate the system and add (Gaussian) noise to the output." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# Simulation model\n", "MODEL = 'pyublas' # ['python', 'cython', 'pyublas', 'numba', 'numba_jc', 'boost']\n", "\n", "if (MODEL == 'python'):\n", "    # Pure Python\n", "    from msd import MSD\n", "elif (MODEL == 'cython'):\n", "    # Cython\n", "    from msd.msdc import MSD_CYTHON\n", "elif (MODEL == 'pyublas'):\n", "    # PyUblas extension\n", "    from msd.msdu import MSD_PYUBLAS\n", "elif (MODEL == 'numba'):\n", "    # Numba JIT\n", "    from msd.msdn import MSD_NUMBA\n", "elif (MODEL == 'numba_jc'):\n", "    # Numba JIT\n", "    from msd.msdn import MSD_NUMBA_JC, msd_integrate\n", "elif (MODEL == 'boost'):\n", "    # Boost extension\n", "    from msd.msdb import MSD_BOOST" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# Measurement and state noise standard deviations\n", "NOISE_SD = [ 0.001 for _ in range(3) ]\n", "STATE_NOISE_SD = 0.0\n", "\n", "VERBOSE = False\n", "\n", "# Fix the RNG seed for reproducibility\n", "np.random.seed(1)\n", "\n", "# Initial system state and external force input\n", "x0 = np.zeros((2, ))\n", "d0 = 0.0\n", "\n", "# Sample period\n", "dt = 0.01\n", "\n", "# Start and end time\n", "t0 = 0.0\n", "tN = 15.0\n", "\n", "# Create the time vector\n", "T = np.arange(t0, tN, dt)\n", "N = T.shape[0]\n", "\n", "# Create the predefined external force vector\n", "T_S0 = np.hstack((t0, np.arange(t0 + 1.0, tN + 1.0, 1.0), tN))\n", "D_S0 = np.hstack((d0, np.array([ d0 + ((_ % 2)*2 - 1) * 1.0 for _ in range(T_S0.shape[0] - 2) ]), 
d0))\n", "interpfun = interpolate.interp1d(T_S0, D_S0, kind='zero', axis=0, bounds_error=False)\n", "D0 = np.array([ [ interpfun(t) ] for t in T ])\n", "T_S = T_S0.copy()\n", "D_S = D_S0.copy()\n", "D = D0.copy()\n", "\n", "# Create the simulation model\n", "if (MODEL == 'python'):\n", "    # Pure Python\n", "    msd = MSD(\"Mass-Spring-Damper (Python)\")\n", "    msd.set_external_forces(T_S, D_S, 'zero')\n", "elif (MODEL == 'cython'):\n", "    # Cython\n", "    msd = MSD_CYTHON(\"Mass-Spring-Damper (Cython)\")\n", "    msd.set_external_forces(T_S, D_S, 'zero')\n", "elif (MODEL == 'pyublas'):\n", "    # PyUblas extension\n", "    msd = MSD_PYUBLAS(\"Mass-Spring-Damper (PyUblas)\", N)\n", "    msd.set_external_forces(T_S, D_S, 'zero')\n", "elif (MODEL == 'numba'):\n", "    # Numba JIT\n", "    msd = MSD_NUMBA(\"Mass-Spring-Damper (Numba)\", N)\n", "    msd.set_external_forces(T_S, D_S, 'zero')\n", "elif (MODEL == 'numba_jc'):\n", "    # Numba JIT\n", "    msd = MSD_NUMBA_JC(N)\n", "    msd.set_external_forces(T_S, D_S, 0)\n", "elif (MODEL == 'boost'):\n", "    # Boost extension\n", "    msd = MSD_BOOST(\"Mass-Spring-Damper (Boost)\", N)\n", "    msd.set_external_forces(T_S, D_S, 'zero')\n", "\n", "# Identification keys\n", "c_idx = ['k', 'b', 'd']\n", "\n", "# True parameter set\n", "CT = [ msd.get_coeffs()[ck] for ck in c_idx ]\n", "\n", "# Initial parameter set\n", "C0 = [ 0.5*msd.get_coeffs()[ck] for ck in c_idx ]\n", "\n", "# Add any state noise\n", "if (STATE_NOISE_SD > 0.0):\n", "    sdw = STATE_NOISE_SD\n", "    W = np.random.randn(N, 1)*sdw\n", "    msd.add_state_noise(T, W)\n", "\n", "# Compute the response\n", "if (MODEL in ['python', 'cython', 'pyublas', 'numba', 'boost']):\n", "    X, Xdot, F = msd.integrate(x0, T)\n", "elif (MODEL == 'numba_jc'):\n", "    X, Xdot, F = msd_integrate(msd, x0, T)\n", "\n", "# Measurement noise standard deviation vector\n", "sdz = np.zeros((len(x0),))\n", "if any(w > 0.0 for w in NOISE_SD[:2]):\n", "    sdz = np.array(NOISE_SD[:2])\n", "\n", "# Measured state matrix\n", "Z = X + np.random.randn(N, 
len(x0))*sdz\n", "\n", "# Set the initial measured state equal to the initial true state\n", "z0 = x0\n", "\n", "Nu = Z[:,1]\n", "sdnu = sdz[1]\n", "\n", "# State rate noise standard deviation vector\n", "sdzdot = np.zeros((len(x0),))\n", "if any(w > 0.0 for w in NOISE_SD[:2]):\n", " sdzdot = np.array(NOISE_SD[:2])\n", "\n", "# Measured state rate matrix\n", "Zdot = Xdot + np.random.randn(N, len(x0))*sdzdot\n", "\n", "# External force noise standard deviation vector\n", "sde = 0.0\n", "if (NOISE_SD[2] > 0.0):\n", " sde = NOISE_SD[2]\n", "\n", "# Measured external force matrix\n", "E = D + np.random.randn(N, 1)*sde\n", "\n", "# Set the initial measured external force equal to the initial true external force\n", "e0 = d0\n", "\n", "# Compute the inertial force and noise standard deviation\n", "G = F.copy()\n", "sdg = 0.0\n", "if any(w > 0.0 for w in NOISE_SD):\n", " # Forces are calculated from (measured) accelerations, not measured directly\n", " for n in range(N):\n", " G[n] = msd.forces(Zdot[n], Z[n])\n", " sdg = np.std(F - G)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Plot the measured response and control input." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as pp" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "application/javascript": [ "/* Put everything inside the global mpl namespace */\n", "window.mpl = {};\n", "\n", "\n", "mpl.get_websocket_type = function() {\n", " if (typeof(WebSocket) !== 'undefined') {\n", " return WebSocket;\n", " } else if (typeof(MozWebSocket) !== 'undefined') {\n", " return MozWebSocket;\n", " } else {\n", " alert('Your browser does not have WebSocket support.' +\n", " 'Please try Chrome, Safari or Firefox ≥ 6. 
' +\n", " 'Firefox 4 and 5 are also supported but you ' +\n", " 'have to enable WebSockets in about:config.');\n", " };\n", "}\n", "\n", "mpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n", " this.id = figure_id;\n", "\n", " this.ws = websocket;\n", "\n", " this.supports_binary = (this.ws.binaryType != undefined);\n", "\n", " if (!this.supports_binary) {\n", " var warnings = document.getElementById(\"mpl-warnings\");\n", " if (warnings) {\n", " warnings.style.display = 'block';\n", " warnings.textContent = (\n", " \"This browser does not support binary websocket messages. \" +\n", " \"Performance may be slow.\");\n", " }\n", " }\n", "\n", " this.imageObj = new Image();\n", "\n", " this.context = undefined;\n", " this.message = undefined;\n", " this.canvas = undefined;\n", " this.rubberband_canvas = undefined;\n", " this.rubberband_context = undefined;\n", " this.format_dropdown = undefined;\n", "\n", " this.image_mode = 'full';\n", "\n", " this.root = $('
');\n", " this._root_extra_style(this.root)\n", " this.root.attr('style', 'display: inline-block');\n", "\n", " $(parent_element).append(this.root);\n", "\n", " this._init_header(this);\n", " this._init_canvas(this);\n", " this._init_toolbar(this);\n", "\n", " var fig = this;\n", "\n", " this.waiting = false;\n", "\n", " this.ws.onopen = function () {\n", " fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n", " fig.send_message(\"send_image_mode\", {});\n", " if (mpl.ratio != 1) {\n", " fig.send_message(\"set_dpi_ratio\", {'dpi_ratio': mpl.ratio});\n", " }\n", " fig.send_message(\"refresh\", {});\n", " }\n", "\n", " this.imageObj.onload = function() {\n", " if (fig.image_mode == 'full') {\n", " // Full images could contain transparency (where diff images\n", " // almost always do), so we need to clear the canvas so that\n", " // there is no ghosting.\n", " fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n", " }\n", " fig.context.drawImage(fig.imageObj, 0, 0);\n", " };\n", "\n", " this.imageObj.onunload = function() {\n", " fig.ws.close();\n", " }\n", "\n", " this.ws.onmessage = this._make_on_message_function(this);\n", "\n", " this.ondownload = ondownload;\n", "}\n", "\n", "mpl.figure.prototype._init_header = function() {\n", " var titlebar = $(\n", " '
');\n", " var titletext = $(\n", " '
');\n", " titlebar.append(titletext)\n", " this.root.append(titlebar);\n", " this.header = titletext[0];\n", "}\n", "\n", "\n", "\n", "mpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "\n", "mpl.figure.prototype._root_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "mpl.figure.prototype._init_canvas = function() {\n", " var fig = this;\n", "\n", " var canvas_div = $('
');\n", "\n", " canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n", "\n", " function canvas_keyboard_event(event) {\n", " return fig.key_event(event, event['data']);\n", " }\n", "\n", " canvas_div.keydown('key_press', canvas_keyboard_event);\n", " canvas_div.keyup('key_release', canvas_keyboard_event);\n", " this.canvas_div = canvas_div\n", " this._canvas_extra_style(canvas_div)\n", " this.root.append(canvas_div);\n", "\n", " var canvas = $('');\n", " canvas.addClass('mpl-canvas');\n", " canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n", "\n", " this.canvas = canvas[0];\n", " this.context = canvas[0].getContext(\"2d\");\n", "\n", " var backingStore = this.context.backingStorePixelRatio ||\n", "\tthis.context.webkitBackingStorePixelRatio ||\n", "\tthis.context.mozBackingStorePixelRatio ||\n", "\tthis.context.msBackingStorePixelRatio ||\n", "\tthis.context.oBackingStorePixelRatio ||\n", "\tthis.context.backingStorePixelRatio || 1;\n", "\n", " mpl.ratio = (window.devicePixelRatio || 1) / backingStore;\n", "\n", " var rubberband = $('');\n", " rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n", "\n", " var pass_mouse_events = true;\n", "\n", " canvas_div.resizable({\n", " start: function(event, ui) {\n", " pass_mouse_events = false;\n", " },\n", " resize: function(event, ui) {\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " stop: function(event, ui) {\n", " pass_mouse_events = true;\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " });\n", "\n", " function mouse_event_fn(event) {\n", " if (pass_mouse_events)\n", " return fig.mouse_event(event, event['data']);\n", " }\n", "\n", " rubberband.mousedown('button_press', mouse_event_fn);\n", " rubberband.mouseup('button_release', mouse_event_fn);\n", " // Throttle sequential mouse events to 1 every 20ms.\n", " rubberband.mousemove('motion_notify', mouse_event_fn);\n", "\n", " 
rubberband.mouseenter('figure_enter', mouse_event_fn);\n", " rubberband.mouseleave('figure_leave', mouse_event_fn);\n", "\n", " canvas_div.on(\"wheel\", function (event) {\n", " event = event.originalEvent;\n", " event['data'] = 'scroll'\n", " if (event.deltaY < 0) {\n", " event.step = 1;\n", " } else {\n", " event.step = -1;\n", " }\n", " mouse_event_fn(event);\n", " });\n", "\n", " canvas_div.append(canvas);\n", " canvas_div.append(rubberband);\n", "\n", " this.rubberband = rubberband;\n", " this.rubberband_canvas = rubberband[0];\n", " this.rubberband_context = rubberband[0].getContext(\"2d\");\n", " this.rubberband_context.strokeStyle = \"#000000\";\n", "\n", " this._resize_canvas = function(width, height) {\n", " // Keep the size of the canvas, canvas container, and rubber band\n", " // canvas in synch.\n", " canvas_div.css('width', width)\n", " canvas_div.css('height', height)\n", "\n", " canvas.attr('width', width * mpl.ratio);\n", " canvas.attr('height', height * mpl.ratio);\n", " canvas.attr('style', 'width: ' + width + 'px; height: ' + height + 'px;');\n", "\n", " rubberband.attr('width', width);\n", " rubberband.attr('height', height);\n", " }\n", "\n", " // Set the figure to an initial 600x600px, this will subsequently be updated\n", " // upon first draw.\n", " this._resize_canvas(600, 600);\n", "\n", " // Disable right mouse context menu.\n", " $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n", " return false;\n", " });\n", "\n", " function set_focus () {\n", " canvas.focus();\n", " canvas_div.focus();\n", " }\n", "\n", " window.setTimeout(set_focus, 100);\n", "}\n", "\n", "mpl.figure.prototype._init_toolbar = function() {\n", " var fig = this;\n", "\n", " var nav_element = $('
')\n", " nav_element.attr('style', 'width: 100%');\n", " this.root.append(nav_element);\n", "\n", " // Define a callback function for later on.\n", " function toolbar_event(event) {\n", " return fig.toolbar_button_onclick(event['data']);\n", " }\n", " function toolbar_mouse_event(event) {\n", " return fig.toolbar_button_onmouseover(event['data']);\n", " }\n", "\n", " for(var toolbar_ind in mpl.toolbar_items) {\n", " var name = mpl.toolbar_items[toolbar_ind][0];\n", " var tooltip = mpl.toolbar_items[toolbar_ind][1];\n", " var image = mpl.toolbar_items[toolbar_ind][2];\n", " var method_name = mpl.toolbar_items[toolbar_ind][3];\n", "\n", " if (!name) {\n", " // put a spacer in here.\n", " continue;\n", " }\n", " var button = $('');\n", " button.click(method_name, toolbar_event);\n", " button.mouseover(tooltip, toolbar_mouse_event);\n", " nav_element.append(button);\n", " }\n", "\n", " // Add the status bar.\n", " var status_bar = $('');\n", " nav_element.append(status_bar);\n", " this.message = status_bar[0];\n", "\n", " // Add the close button to the window.\n", " var buttongrp = $('
');\n", " var button = $('');\n", " button.click(function (evt) { fig.handle_close(fig, {}); } );\n", " button.mouseover('Stop Interaction', toolbar_mouse_event);\n", " buttongrp.append(button);\n", " var titlebar = this.root.find($('.ui-dialog-titlebar'));\n", " titlebar.prepend(buttongrp);\n", "}\n", "\n", "mpl.figure.prototype._root_extra_style = function(el){\n", " var fig = this\n", " el.on(\"remove\", function(){\n", "\tfig.close_ws(fig, {});\n", " });\n", "}\n", "\n", "mpl.figure.prototype._canvas_extra_style = function(el){\n", " // this is important to make the div 'focusable\n", " el.attr('tabindex', 0)\n", " // reach out to IPython and tell the keyboard manager to turn it's self\n", " // off when our div gets focus\n", "\n", " // location in version 3\n", " if (IPython.notebook.keyboard_manager) {\n", " IPython.notebook.keyboard_manager.register_events(el);\n", " }\n", " else {\n", " // location in version 2\n", " IPython.keyboard_manager.register_events(el);\n", " }\n", "\n", "}\n", "\n", "mpl.figure.prototype._key_event_extra = function(event, name) {\n", " var manager = IPython.notebook.keyboard_manager;\n", " if (!manager)\n", " manager = IPython.keyboard_manager;\n", "\n", " // Check for shift+enter\n", " if (event.shiftKey && event.which == 13) {\n", " this.canvas_div.blur();\n", " event.shiftKey = false;\n", " // Send a \"J\" for go to next cell\n", " event.which = 74;\n", " event.keyCode = 74;\n", " manager.command_mode();\n", " manager.handle_keydown(event);\n", " }\n", "}\n", "\n", "mpl.figure.prototype.handle_save = function(fig, msg) {\n", " fig.ondownload(fig, null);\n", "}\n", "\n", "\n", "mpl.find_output_cell = function(html_output) {\n", " // Return the cell and output element which can be found *uniquely* in the notebook.\n", " // Note - this is a bit hacky, but it is done because the \"notebook_saving.Notebook\"\n", " // IPython event is triggered only after the cells have been serialised, which for\n", " // our purposes (turning an 
active figure into a static one), is too late.\n", " var cells = IPython.notebook.get_cells();\n", " var ncells = cells.length;\n", " for (var i=0; i= 3 moved mimebundle to data attribute of output\n", " data = data.data;\n", " }\n", " if (data['text/html'] == html_output) {\n", " return [cell, data, j];\n", " }\n", " }\n", " }\n", " }\n", "}\n", "\n", "// Register the function which deals with the matplotlib target/channel.\n", "// The kernel may be null if the page has been refreshed.\n", "if (IPython.notebook.kernel != null) {\n", " IPython.notebook.kernel.comm_manager.register_target('matplotlib', mpl.mpl_figure_comm);\n", "}\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "fig, Axes, Lines, Text = plot(msd_est.name, T, E, Z, G, Xe=Xe, Fe=Fe)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Nice. Ok, now let's increase the amount of noise in the state & state rate by a factor of 5." ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "# Observation and state noise standard deviations\n", "NOISE_SD = [ 0.005 for _ in range(3) ]" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "PLOT_SIM = False\n", "%run -i sim.py" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "application/javascript": [ "/* Put everything inside the global mpl namespace */\n", "window.mpl = {};\n", "\n", "\n", "mpl.get_websocket_type = function() {\n", " if (typeof(WebSocket) !== 'undefined') {\n", " return WebSocket;\n", " } else if (typeof(MozWebSocket) !== 'undefined') {\n", " return MozWebSocket;\n", " } else {\n", " alert('Your browser does not have WebSocket support.' +\n", " 'Please try Chrome, Safari or Firefox ≥ 6. 
' +\n", " 'Firefox 4 and 5 are also supported but you ' +\n", " 'have to enable WebSockets in about:config.');\n", " };\n", "}\n", "\n", "mpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n", " this.id = figure_id;\n", "\n", " this.ws = websocket;\n", "\n", " this.supports_binary = (this.ws.binaryType != undefined);\n", "\n", " if (!this.supports_binary) {\n", " var warnings = document.getElementById(\"mpl-warnings\");\n", " if (warnings) {\n", " warnings.style.display = 'block';\n", " warnings.textContent = (\n", " \"This browser does not support binary websocket messages. \" +\n", " \"Performance may be slow.\");\n", " }\n", " }\n", "\n", " this.imageObj = new Image();\n", "\n", " this.context = undefined;\n", " this.message = undefined;\n", " this.canvas = undefined;\n", " this.rubberband_canvas = undefined;\n", " this.rubberband_context = undefined;\n", " this.format_dropdown = undefined;\n", "\n", " this.image_mode = 'full';\n", "\n", " this.root = $('
');\n", " this._root_extra_style(this.root)\n", " this.root.attr('style', 'display: inline-block');\n", "\n", " $(parent_element).append(this.root);\n", "\n", " this._init_header(this);\n", " this._init_canvas(this);\n", " this._init_toolbar(this);\n", "\n", " var fig = this;\n", "\n", " this.waiting = false;\n", "\n", " this.ws.onopen = function () {\n", " fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n", " fig.send_message(\"send_image_mode\", {});\n", " if (mpl.ratio != 1) {\n", " fig.send_message(\"set_dpi_ratio\", {'dpi_ratio': mpl.ratio});\n", " }\n", " fig.send_message(\"refresh\", {});\n", " }\n", "\n", " this.imageObj.onload = function() {\n", " if (fig.image_mode == 'full') {\n", " // Full images could contain transparency (where diff images\n", " // almost always do), so we need to clear the canvas so that\n", " // there is no ghosting.\n", " fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n", " }\n", " fig.context.drawImage(fig.imageObj, 0, 0);\n", " };\n", "\n", " this.imageObj.onunload = function() {\n", " fig.ws.close();\n", " }\n", "\n", " this.ws.onmessage = this._make_on_message_function(this);\n", "\n", " this.ondownload = ondownload;\n", "}\n", "\n", "mpl.figure.prototype._init_header = function() {\n", " var titlebar = $(\n", " '
');\n", " var titletext = $(\n", " '
');\n", " titlebar.append(titletext)\n", " this.root.append(titlebar);\n", " this.header = titletext[0];\n", "}\n", "\n", "\n", "\n", "mpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "\n", "mpl.figure.prototype._root_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "mpl.figure.prototype._init_canvas = function() {\n", " var fig = this;\n", "\n", " var canvas_div = $('
');\n", "\n", " canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n", "\n", " function canvas_keyboard_event(event) {\n", " return fig.key_event(event, event['data']);\n", " }\n", "\n", " canvas_div.keydown('key_press', canvas_keyboard_event);\n", " canvas_div.keyup('key_release', canvas_keyboard_event);\n", " this.canvas_div = canvas_div\n", " this._canvas_extra_style(canvas_div)\n", " this.root.append(canvas_div);\n", "\n", " var canvas = $('');\n", " canvas.addClass('mpl-canvas');\n", " canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n", "\n", " this.canvas = canvas[0];\n", " this.context = canvas[0].getContext(\"2d\");\n", "\n", " var backingStore = this.context.backingStorePixelRatio ||\n", "\tthis.context.webkitBackingStorePixelRatio ||\n", "\tthis.context.mozBackingStorePixelRatio ||\n", "\tthis.context.msBackingStorePixelRatio ||\n", "\tthis.context.oBackingStorePixelRatio ||\n", "\tthis.context.backingStorePixelRatio || 1;\n", "\n", " mpl.ratio = (window.devicePixelRatio || 1) / backingStore;\n", "\n", " var rubberband = $('');\n", " rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n", "\n", " var pass_mouse_events = true;\n", "\n", " canvas_div.resizable({\n", " start: function(event, ui) {\n", " pass_mouse_events = false;\n", " },\n", " resize: function(event, ui) {\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " stop: function(event, ui) {\n", " pass_mouse_events = true;\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " });\n", "\n", " function mouse_event_fn(event) {\n", " if (pass_mouse_events)\n", " return fig.mouse_event(event, event['data']);\n", " }\n", "\n", " rubberband.mousedown('button_press', mouse_event_fn);\n", " rubberband.mouseup('button_release', mouse_event_fn);\n", " // Throttle sequential mouse events to 1 every 20ms.\n", " rubberband.mousemove('motion_notify', mouse_event_fn);\n", "\n", " 
rubberband.mouseenter('figure_enter', mouse_event_fn);\n", " rubberband.mouseleave('figure_leave', mouse_event_fn);\n", "\n", " canvas_div.on(\"wheel\", function (event) {\n", " event = event.originalEvent;\n", " event['data'] = 'scroll'\n", " if (event.deltaY < 0) {\n", " event.step = 1;\n", " } else {\n", " event.step = -1;\n", " }\n", " mouse_event_fn(event);\n", " });\n", "\n", " canvas_div.append(canvas);\n", " canvas_div.append(rubberband);\n", "\n", " this.rubberband = rubberband;\n", " this.rubberband_canvas = rubberband[0];\n", " this.rubberband_context = rubberband[0].getContext(\"2d\");\n", " this.rubberband_context.strokeStyle = \"#000000\";\n", "\n", " this._resize_canvas = function(width, height) {\n", " // Keep the size of the canvas, canvas container, and rubber band\n", " // canvas in synch.\n", " canvas_div.css('width', width)\n", " canvas_div.css('height', height)\n", "\n", " canvas.attr('width', width * mpl.ratio);\n", " canvas.attr('height', height * mpl.ratio);\n", " canvas.attr('style', 'width: ' + width + 'px; height: ' + height + 'px;');\n", "\n", " rubberband.attr('width', width);\n", " rubberband.attr('height', height);\n", " }\n", "\n", " // Set the figure to an initial 600x600px, this will subsequently be updated\n", " // upon first draw.\n", " this._resize_canvas(600, 600);\n", "\n", " // Disable right mouse context menu.\n", " $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n", " return false;\n", " });\n", "\n", " function set_focus () {\n", " canvas.focus();\n", " canvas_div.focus();\n", " }\n", "\n", " window.setTimeout(set_focus, 100);\n", "}\n", "\n", "mpl.figure.prototype._init_toolbar = function() {\n", " var fig = this;\n", "\n", " var nav_element = $('
')\n", " nav_element.attr('style', 'width: 100%');\n", " this.root.append(nav_element);\n", "\n", " // Define a callback function for later on.\n", " function toolbar_event(event) {\n", " return fig.toolbar_button_onclick(event['data']);\n", " }\n", " function toolbar_mouse_event(event) {\n", " return fig.toolbar_button_onmouseover(event['data']);\n", " }\n", "\n", " for(var toolbar_ind in mpl.toolbar_items) {\n", " var name = mpl.toolbar_items[toolbar_ind][0];\n", " var tooltip = mpl.toolbar_items[toolbar_ind][1];\n", " var image = mpl.toolbar_items[toolbar_ind][2];\n", " var method_name = mpl.toolbar_items[toolbar_ind][3];\n", "\n", " if (!name) {\n", " // put a spacer in here.\n", " continue;\n", " }\n", " var button = $('');\n", " button.click(method_name, toolbar_event);\n", " button.mouseover(tooltip, toolbar_mouse_event);\n", " nav_element.append(button);\n", " }\n", "\n", " // Add the status bar.\n", " var status_bar = $('');\n", " nav_element.append(status_bar);\n", " this.message = status_bar[0];\n", "\n", " // Add the close button to the window.\n", " var buttongrp = $('
');\n", " var button = $('');\n", " button.click(function (evt) { fig.handle_close(fig, {}); } );\n", " button.mouseover('Stop Interaction', toolbar_mouse_event);\n", " buttongrp.append(button);\n", " var titlebar = this.root.find($('.ui-dialog-titlebar'));\n", " titlebar.prepend(buttongrp);\n", "}\n", "\n", "mpl.figure.prototype._root_extra_style = function(el){\n", " var fig = this\n", " el.on(\"remove\", function(){\n", "\tfig.close_ws(fig, {});\n", " });\n", "}\n", "\n", "mpl.figure.prototype._canvas_extra_style = function(el){\n", " // this is important to make the div 'focusable\n", " el.attr('tabindex', 0)\n", " // reach out to IPython and tell the keyboard manager to turn it's self\n", " // off when our div gets focus\n", "\n", " // location in version 3\n", " if (IPython.notebook.keyboard_manager) {\n", " IPython.notebook.keyboard_manager.register_events(el);\n", " }\n", " else {\n", " // location in version 2\n", " IPython.keyboard_manager.register_events(el);\n", " }\n", "\n", "}\n", "\n", "mpl.figure.prototype._key_event_extra = function(event, name) {\n", " var manager = IPython.notebook.keyboard_manager;\n", " if (!manager)\n", " manager = IPython.keyboard_manager;\n", "\n", " // Check for shift+enter\n", " if (event.shiftKey && event.which == 13) {\n", " this.canvas_div.blur();\n", " event.shiftKey = false;\n", " // Send a \"J\" for go to next cell\n", " event.which = 74;\n", " event.keyCode = 74;\n", " manager.command_mode();\n", " manager.handle_keydown(event);\n", " }\n", "}\n", "\n", "mpl.figure.prototype.handle_save = function(fig, msg) {\n", " fig.ondownload(fig, null);\n", "}\n", "\n", "\n", "mpl.find_output_cell = function(html_output) {\n", " // Return the cell and output element which can be found *uniquely* in the notebook.\n", " // Note - this is a bit hacky, but it is done because the \"notebook_saving.Notebook\"\n", " // IPython event is triggered only after the cells have been serialised, which for\n", " // our purposes (turning an 
active figure into a static one), is too late.\n", " var cells = IPython.notebook.get_cells();\n", " var ncells = cells.length;\n", " for (var i=0; i<ncells; i++) {\n", " var cell = cells[i];\n", " if (cell.cell_type === 'code'){\n", " for (var j=0; j<cell.output_area.outputs.length; j++) {\n", " var data = cell.output_area.outputs[j];\n", " if (data.data) {\n", " // IPython >= 3 moved mimebundle to data attribute of output\n", " data = data.data;\n", " }\n", " if (data['text/html'] == html_output) {\n", " return [cell, data, j];\n", " }\n", " }\n", " }\n", " }\n", "}\n", "\n", "// Register the function which deals with the matplotlib target/channel.\n", "// The kernel may be null if the page has been refreshed.\n", "if (IPython.notebook.kernel != null) {\n", " IPython.notebook.kernel.comm_manager.register_target('matplotlib', mpl.mpl_figure_comm);\n", "}\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "fig, Axes, Lines, Text = plot(msd_est.name, T, E, Z, G, Xe=Xe, Fe=Fe)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "That is not good! This is primarily because errors in the independent variables introduce bias into the LS estimate.\n", "\n", "Let's see how the same model compares with a different response dataset." 
] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "C_LS0 = C_LS" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "# Create the predefined external force vector\n", "T_S1 = np.hstack((t0, np.arange(t0 + 1.5, tN + 1.5, 1.5), tN))\n", "D_S1 = np.hstack((d0, np.array([ d0 + np.random.randint(-2, 3)/2.0*(1.0 - np.abs(d0)) for _ in range(T_S1.shape[0] - 2) ]), d0))\n", "interpfun = interpolate.interp1d(T_S1, D_S1, kind='zero', axis=0, bounds_error=False)\n", "D1 = np.array([ [ interpfun(t) ] for t in T ])\n", "T_S = T_S1.copy()\n", "D_S = D_S1.copy()\n", "D = D1.copy()" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "scrolled": true }, "outputs": [], "source": [ "%run -i sim.py" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we'll use the new external force input to stimulate our model." ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "msd_est.set_external_forces(T, E, 'linear_uniform')\n", "\n", "# Apply the coefficients estimated from the first dataset\n", "msd_est.set_coeffs({ 'k': C_LS0[0], 'b': C_LS0[1], 'd': C_LS0[2] })\n", "\n", "# Compute the response\n", "Xe, Xedot, Fe = msd_est.integrate(z0, T)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And plot the system response." ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "application/javascript": [ "/* Put everything inside the global mpl namespace */\n", "window.mpl = {};\n", "\n", "\n", "mpl.get_websocket_type = function() {\n", " if (typeof(WebSocket) !== 'undefined') {\n", " return WebSocket;\n", " } else if (typeof(MozWebSocket) !== 'undefined') {\n", " return MozWebSocket;\n", " } else {\n", " alert('Your browser does not have WebSocket support.' +\n", " 'Please try Chrome, Safari or Firefox ≥ 6. 
' +\n", " 'Firefox 4 and 5 are also supported but you ' +\n", " 'have to enable WebSockets in about:config.');\n", " };\n", "}\n", "\n", "mpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n", " this.id = figure_id;\n", "\n", " this.ws = websocket;\n", "\n", " this.supports_binary = (this.ws.binaryType != undefined);\n", "\n", " if (!this.supports_binary) {\n", " var warnings = document.getElementById(\"mpl-warnings\");\n", " if (warnings) {\n", " warnings.style.display = 'block';\n", " warnings.textContent = (\n", " \"This browser does not support binary websocket messages. \" +\n", " \"Performance may be slow.\");\n", " }\n", " }\n", "\n", " this.imageObj = new Image();\n", "\n", " this.context = undefined;\n", " this.message = undefined;\n", " this.canvas = undefined;\n", " this.rubberband_canvas = undefined;\n", " this.rubberband_context = undefined;\n", " this.format_dropdown = undefined;\n", "\n", " this.image_mode = 'full';\n", "\n", " this.root = $('
');\n", " this._root_extra_style(this.root)\n", " this.root.attr('style', 'display: inline-block');\n", "\n", " $(parent_element).append(this.root);\n", "\n", " this._init_header(this);\n", " this._init_canvas(this);\n", " this._init_toolbar(this);\n", "\n", " var fig = this;\n", "\n", " this.waiting = false;\n", "\n", " this.ws.onopen = function () {\n", " fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n", " fig.send_message(\"send_image_mode\", {});\n", " if (mpl.ratio != 1) {\n", " fig.send_message(\"set_dpi_ratio\", {'dpi_ratio': mpl.ratio});\n", " }\n", " fig.send_message(\"refresh\", {});\n", " }\n", "\n", " this.imageObj.onload = function() {\n", " if (fig.image_mode == 'full') {\n", " // Full images could contain transparency (where diff images\n", " // almost always do), so we need to clear the canvas so that\n", " // there is no ghosting.\n", " fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n", " }\n", " fig.context.drawImage(fig.imageObj, 0, 0);\n", " };\n", "\n", " this.imageObj.onunload = function() {\n", " fig.ws.close();\n", " }\n", "\n", " this.ws.onmessage = this._make_on_message_function(this);\n", "\n", " this.ondownload = ondownload;\n", "}\n", "\n", "mpl.figure.prototype._init_header = function() {\n", " var titlebar = $(\n", " '
');\n", " var titletext = $(\n", " '
');\n", " titlebar.append(titletext)\n", " this.root.append(titlebar);\n", " this.header = titletext[0];\n", "}\n", "\n", "\n", "\n", "mpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "\n", "mpl.figure.prototype._root_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "mpl.figure.prototype._init_canvas = function() {\n", " var fig = this;\n", "\n", " var canvas_div = $('
');\n", "\n", " canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n", "\n", " function canvas_keyboard_event(event) {\n", " return fig.key_event(event, event['data']);\n", " }\n", "\n", " canvas_div.keydown('key_press', canvas_keyboard_event);\n", " canvas_div.keyup('key_release', canvas_keyboard_event);\n", " this.canvas_div = canvas_div\n", " this._canvas_extra_style(canvas_div)\n", " this.root.append(canvas_div);\n", "\n", " var canvas = $('');\n", " canvas.addClass('mpl-canvas');\n", " canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n", "\n", " this.canvas = canvas[0];\n", " this.context = canvas[0].getContext(\"2d\");\n", "\n", " var backingStore = this.context.backingStorePixelRatio ||\n", "\tthis.context.webkitBackingStorePixelRatio ||\n", "\tthis.context.mozBackingStorePixelRatio ||\n", "\tthis.context.msBackingStorePixelRatio ||\n", "\tthis.context.oBackingStorePixelRatio ||\n", "\tthis.context.backingStorePixelRatio || 1;\n", "\n", " mpl.ratio = (window.devicePixelRatio || 1) / backingStore;\n", "\n", " var rubberband = $('');\n", " rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n", "\n", " var pass_mouse_events = true;\n", "\n", " canvas_div.resizable({\n", " start: function(event, ui) {\n", " pass_mouse_events = false;\n", " },\n", " resize: function(event, ui) {\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " stop: function(event, ui) {\n", " pass_mouse_events = true;\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " });\n", "\n", " function mouse_event_fn(event) {\n", " if (pass_mouse_events)\n", " return fig.mouse_event(event, event['data']);\n", " }\n", "\n", " rubberband.mousedown('button_press', mouse_event_fn);\n", " rubberband.mouseup('button_release', mouse_event_fn);\n", " // Throttle sequential mouse events to 1 every 20ms.\n", " rubberband.mousemove('motion_notify', mouse_event_fn);\n", "\n", " 
rubberband.mouseenter('figure_enter', mouse_event_fn);\n", " rubberband.mouseleave('figure_leave', mouse_event_fn);\n", "\n", " canvas_div.on(\"wheel\", function (event) {\n", " event = event.originalEvent;\n", " event['data'] = 'scroll'\n", " if (event.deltaY < 0) {\n", " event.step = 1;\n", " } else {\n", " event.step = -1;\n", " }\n", " mouse_event_fn(event);\n", " });\n", "\n", " canvas_div.append(canvas);\n", " canvas_div.append(rubberband);\n", "\n", " this.rubberband = rubberband;\n", " this.rubberband_canvas = rubberband[0];\n", " this.rubberband_context = rubberband[0].getContext(\"2d\");\n", " this.rubberband_context.strokeStyle = \"#000000\";\n", "\n", " this._resize_canvas = function(width, height) {\n", " // Keep the size of the canvas, canvas container, and rubber band\n", " // canvas in synch.\n", " canvas_div.css('width', width)\n", " canvas_div.css('height', height)\n", "\n", " canvas.attr('width', width * mpl.ratio);\n", " canvas.attr('height', height * mpl.ratio);\n", " canvas.attr('style', 'width: ' + width + 'px; height: ' + height + 'px;');\n", "\n", " rubberband.attr('width', width);\n", " rubberband.attr('height', height);\n", " }\n", "\n", " // Set the figure to an initial 600x600px, this will subsequently be updated\n", " // upon first draw.\n", " this._resize_canvas(600, 600);\n", "\n", " // Disable right mouse context menu.\n", " $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n", " return false;\n", " });\n", "\n", " function set_focus () {\n", " canvas.focus();\n", " canvas_div.focus();\n", " }\n", "\n", " window.setTimeout(set_focus, 100);\n", "}\n", "\n", "mpl.figure.prototype._init_toolbar = function() {\n", " var fig = this;\n", "\n", " var nav_element = $('
')\n", " nav_element.attr('style', 'width: 100%');\n", " this.root.append(nav_element);\n", "\n", " // Define a callback function for later on.\n", " function toolbar_event(event) {\n", " return fig.toolbar_button_onclick(event['data']);\n", " }\n", " function toolbar_mouse_event(event) {\n", " return fig.toolbar_button_onmouseover(event['data']);\n", " }\n", "\n", " for(var toolbar_ind in mpl.toolbar_items) {\n", " var name = mpl.toolbar_items[toolbar_ind][0];\n", " var tooltip = mpl.toolbar_items[toolbar_ind][1];\n", " var image = mpl.toolbar_items[toolbar_ind][2];\n", " var method_name = mpl.toolbar_items[toolbar_ind][3];\n", "\n", " if (!name) {\n", " // put a spacer in here.\n", " continue;\n", " }\n", " var button = $('');\n", " button.click(method_name, toolbar_event);\n", " button.mouseover(tooltip, toolbar_mouse_event);\n", " nav_element.append(button);\n", " }\n", "\n", " // Add the status bar.\n", " var status_bar = $('');\n", " nav_element.append(status_bar);\n", " this.message = status_bar[0];\n", "\n", " // Add the close button to the window.\n", " var buttongrp = $('
');\n", " var button = $('');\n", " button.click(function (evt) { fig.handle_close(fig, {}); } );\n", " button.mouseover('Stop Interaction', toolbar_mouse_event);\n", " buttongrp.append(button);\n", " var titlebar = this.root.find($('.ui-dialog-titlebar'));\n", " titlebar.prepend(buttongrp);\n", "}\n", "\n", "mpl.figure.prototype._root_extra_style = function(el){\n", " var fig = this\n", " el.on(\"remove\", function(){\n", "\tfig.close_ws(fig, {});\n", " });\n", "}\n", "\n", "mpl.figure.prototype._canvas_extra_style = function(el){\n", " // this is important to make the div 'focusable\n", " el.attr('tabindex', 0)\n", " // reach out to IPython and tell the keyboard manager to turn it's self\n", " // off when our div gets focus\n", "\n", " // location in version 3\n", " if (IPython.notebook.keyboard_manager) {\n", " IPython.notebook.keyboard_manager.register_events(el);\n", " }\n", " else {\n", " // location in version 2\n", " IPython.keyboard_manager.register_events(el);\n", " }\n", "\n", "}\n", "\n", "mpl.figure.prototype._key_event_extra = function(event, name) {\n", " var manager = IPython.notebook.keyboard_manager;\n", " if (!manager)\n", " manager = IPython.keyboard_manager;\n", "\n", " // Check for shift+enter\n", " if (event.shiftKey && event.which == 13) {\n", " this.canvas_div.blur();\n", " event.shiftKey = false;\n", " // Send a \"J\" for go to next cell\n", " event.which = 74;\n", " event.keyCode = 74;\n", " manager.command_mode();\n", " manager.handle_keydown(event);\n", " }\n", "}\n", "\n", "mpl.figure.prototype.handle_save = function(fig, msg) {\n", " fig.ondownload(fig, null);\n", "}\n", "\n", "\n", "mpl.find_output_cell = function(html_output) {\n", " // Return the cell and output element which can be found *uniquely* in the notebook.\n", " // Note - this is a bit hacky, but it is done because the \"notebook_saving.Notebook\"\n", " // IPython event is triggered only after the cells have been serialised, which for\n", " // our purposes (turning an 
active figure into a static one), is too late.\n", " var cells = IPython.notebook.get_cells();\n", " var ncells = cells.length;\n", " for (var i=0; i<ncells; i++) {\n", " var cell = cells[i];\n", " if (cell.cell_type === 'code'){\n", " for (var j=0; j<cell.output_area.outputs.length; j++) {\n", " var data = cell.output_area.outputs[j];\n", " if (data.data) {\n", " // IPython >= 3 moved mimebundle to data attribute of output\n", " data = data.data;\n", " }\n", " if (data['text/html'] == html_output) {\n", " return [cell, data, j];\n", " }\n", " }\n", " }\n", " }\n", "}\n", "\n", "// Register the function which deals with the matplotlib target/channel.\n", "// The kernel may be null if the page has been refreshed.\n", "if (IPython.notebook.kernel != null) {\n", " IPython.notebook.kernel.comm_manager.register_target('matplotlib', mpl.mpl_figure_comm);\n", "}\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "POWELL'S MINIMIZATION:\n", "Optimization terminated successfully.\n", " Current function value: 87.164648\n", " Iterations: 6\n", " Function evaluations: 434\n", "Time elapsed: 73.060531 seconds\n", "\n", " TRUE F_EST\n", "k : -50.0000 -50.2010\n", "b : -10.0000 -8.9262\n", "d : 1.0000 0.9761\n" ] } ], "source": [ "FF = ml.repmat(None, 50, 1)\n", "\n", "# Create the simulation model\n", "if (MODEL == 'python'):\n", "    # Pure Python\n", "    msd_fest = MSD(\"Mass-Spring-Damper_FMIN_EST\")\n", "    msd_fest.set_external_forces(T, E, 'linear_uniform')\n", "elif (MODEL == 'cython'):\n", "    # Cython\n", "    msd_fest = MSD_CYTHON(\"Mass-Spring-Damper_FMIN_EST (Cython)\")\n", "    msd_fest.set_external_forces(T, E, 'linear_uniform')\n", "elif (MODEL == 'pyublas'):\n", "    # PyUblas extension\n", "    msd_fest = MSD_PYUBLAS(\"Mass-Spring-Damper_FMIN_EST (PyUblas)\", N)\n", "    msd_fest.set_external_forces(T, E, 'linear_uniform')\n", "elif (MODEL == 'numba'):\n", "    # Numba JIT\n", "    msd_fest = MSD_NUMBA(\"Mass-Spring-Damper_FMIN_EST (Numba)\", N)\n", "    msd_fest.set_external_forces(T, E, 'linear_uniform')\n", "elif (MODEL == 
'numba_jc'):\n", "    # Numba JIT\n", "    msd_fest = MSD_NUMBA_JC(N)\n", "    msd_fest.set_external_forces(T, E, 1)\n", "elif (MODEL == 'boost'):\n", "    # Boost extension\n", "    msd_fest = MSD_BOOST(\"Mass-Spring-Damper_FMIN_EST (Boost)\", N)\n", "    msd_fest.set_external_forces(T, E, 'linear_uniform')\n", "\n", "fig, Axes, Lines, Text = plot(msd_fest.name, T, E, Z, G, Xe=np.zeros(X.shape), Fe=np.zeros(F.shape), FF=FF)\n", "fig.canvas.draw()\n", "\n", "kws = { 'fig': fig, 'Axes': Axes, 'Lines': Lines, 'Text': Text }\n", "\n", "c_idx = ['k', 'b', 'd']\n", "\n", "print(\"POWELL'S MINIMIZATION:\")\n", "\n", "class Objfun(object):\n", "\n", "    def __init__(self, z0, T, G, FF):\n", "        self.z0 = z0\n", "        self.T = T\n", "        self.G = G\n", "        self.FF = FF\n", "        self.fopt_max = None\n", "        self.it = 0\n", "\n", "    def __call__(self, C, fig, Axes, Lines, Text):\n", "        msd_fest.set_coeffs({ 'k': C[0], 'b': C[1], 'd': C[2] })\n", "\n", "        # Compute the response\n", "        if (MODEL in ['python', 'cython', 'pyublas', 'numba', 'boost']):\n", "            Xe, Xedot, Fe = msd_fest.integrate(self.z0, self.T)\n", "        elif (MODEL == 'numba_jc'):\n", "            Xe, Xedot, Fe = msd_integrate(msd_fest, self.z0, self.T)\n", "\n", "        # Scalar sum-of-squares cost, as required by fmin, fmin_powell, fmin_bfgs, fmin_l_bfgs_b\n", "        dF = self.G - Fe\n", "        fopt_sum = np.sum(dF*dF)\n", "\n", "        # Keep a rolling history of the log-cost for the convergence plot\n", "        if (self.it < np.size(self.FF, 0)):\n", "            self.FF[self.it, 0] = math.log(fopt_sum)\n", "        else:\n", "            self.FF = np.roll(self.FF, -1)\n", "            self.FF[-1, 0] = math.log(fopt_sum)\n", "\n", "        f_max = None\n", "        if ((self.fopt_max is None) or (self.fopt_max < math.log(fopt_sum))):\n", "            f_max = math.log(fopt_sum) * 1.1\n", "            self.fopt_max = math.log(fopt_sum)\n", "            # rescale = True\n", "        f_txt = '{:.4f}'.format(fopt_sum)\n", "\n", "        updateplot(fig, Axes, Lines, Text, Xe, Fe, self.FF, f_max=f_max, f_txt=f_txt, c_txt=C)\n", "\n", "        self.it += 1\n", "\n", "        return fopt_sum\n", "\n", "# time.clock() was removed in Python 3.8; perf_counter() is its replacement for timing\n", "tic = time.perf_counter()\n", "\n", "objfun = Objfun(z0, T, G, FF)\n", "\n", "# Need to start with a nontrivial parameter set to avoid getting stuck in a 
local minimum straight away...\n", "C = optimize.fmin_powell(objfun, C0, args=( fig, Axes, Lines, Text ), maxiter=100)\n", "\n", "toc = time.perf_counter() - tic\n", "print(\"Time elapsed: {:f} seconds\".format(toc))\n", "\n", "C_PM = C.tolist()\n", "\n", "print()\n", "print(\" TRUE F_EST\")\n", "for i in range(len(c_idx)):\n", "    ck = c_idx[i]\n", "    print(\"{:5s}: {:10.4f} {:10.4f}\".format(ck, msd.get_coeffs()[ck], C_PM[i]))\n", "\n", "msd_fest.set_coeffs({ 'k': C_PM[0], 'b': C_PM[1], 'd': C_PM[2] })\n", "\n", "# Compute the response\n", "if (MODEL in ['python', 'cython', 'pyublas', 'numba', 'boost']):\n", "    Xe, Xedot, Fe = msd_fest.integrate(z0, T)\n", "elif (MODEL == 'numba_jc'):\n", "    Xe, Xedot, Fe = msd_integrate(msd_fest, z0, T)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Much better, although it did take some time. The system response is quite good too. Let's check again using a different input and response dataset." ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [], "source": [ "C_PM0 = C_PM" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [], "source": [ "T_S = T_S1.copy()\n", "D_S = D_S1.copy()\n", "D = D1.copy()" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [], "source": [ "%run -i sim.py" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [], "source": [ "msd_est.set_external_forces(T, E, 'linear_uniform')\n", "\n", "# Apply the coefficients estimated by Powell's method\n", "msd_est.set_coeffs({ 'k': C_PM0[0], 'b': C_PM0[1], 'd': C_PM0[2] })\n", "\n", "# Compute the response\n", "if (MODEL in ['python', 'cython', 'pyublas', 'numba', 'boost']):\n", "    Xe, Xedot, Fe = msd_est.integrate(z0, T)\n", "elif (MODEL == 'numba_jc'):\n", "    Xe, Xedot, Fe = msd_integrate(msd_est, z0, T)" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "data": { "application/javascript": [ "/* Put everything inside the global mpl namespace 
*/\n", "window.mpl = {};\n", "\n", "\n", "mpl.get_websocket_type = function() {\n", " if (typeof(WebSocket) !== 'undefined') {\n", " return WebSocket;\n", " } else if (typeof(MozWebSocket) !== 'undefined') {\n", " return MozWebSocket;\n", " } else {\n", " alert('Your browser does not have WebSocket support.' +\n", " 'Please try Chrome, Safari or Firefox ≥ 6. ' +\n", " 'Firefox 4 and 5 are also supported but you ' +\n", " 'have to enable WebSockets in about:config.');\n", " };\n", "}\n", "\n", "mpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n", " this.id = figure_id;\n", "\n", " this.ws = websocket;\n", "\n", " this.supports_binary = (this.ws.binaryType != undefined);\n", "\n", " if (!this.supports_binary) {\n", " var warnings = document.getElementById(\"mpl-warnings\");\n", " if (warnings) {\n", " warnings.style.display = 'block';\n", " warnings.textContent = (\n", " \"This browser does not support binary websocket messages. \" +\n", " \"Performance may be slow.\");\n", " }\n", " }\n", "\n", " this.imageObj = new Image();\n", "\n", " this.context = undefined;\n", " this.message = undefined;\n", " this.canvas = undefined;\n", " this.rubberband_canvas = undefined;\n", " this.rubberband_context = undefined;\n", " this.format_dropdown = undefined;\n", "\n", " this.image_mode = 'full';\n", "\n", " this.root = $('
');\n", " this._root_extra_style(this.root)\n", " this.root.attr('style', 'display: inline-block');\n", "\n", " $(parent_element).append(this.root);\n", "\n", " this._init_header(this);\n", " this._init_canvas(this);\n", " this._init_toolbar(this);\n", "\n", " var fig = this;\n", "\n", " this.waiting = false;\n", "\n", " this.ws.onopen = function () {\n", " fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n", " fig.send_message(\"send_image_mode\", {});\n", " if (mpl.ratio != 1) {\n", " fig.send_message(\"set_dpi_ratio\", {'dpi_ratio': mpl.ratio});\n", " }\n", " fig.send_message(\"refresh\", {});\n", " }\n", "\n", " this.imageObj.onload = function() {\n", " if (fig.image_mode == 'full') {\n", " // Full images could contain transparency (where diff images\n", " // almost always do), so we need to clear the canvas so that\n", " // there is no ghosting.\n", " fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n", " }\n", " fig.context.drawImage(fig.imageObj, 0, 0);\n", " };\n", "\n", " this.imageObj.onunload = function() {\n", " fig.ws.close();\n", " }\n", "\n", " this.ws.onmessage = this._make_on_message_function(this);\n", "\n", " this.ondownload = ondownload;\n", "}\n", "\n", "mpl.figure.prototype._init_header = function() {\n", " var titlebar = $(\n", " '
');\n", " var titletext = $(\n", " '
');\n", " titlebar.append(titletext)\n", " this.root.append(titlebar);\n", " this.header = titletext[0];\n", "}\n", "\n", "\n", "\n", "mpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "\n", "mpl.figure.prototype._root_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "mpl.figure.prototype._init_canvas = function() {\n", " var fig = this;\n", "\n", " var canvas_div = $('
');\n", "\n", " canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n", "\n", " function canvas_keyboard_event(event) {\n", " return fig.key_event(event, event['data']);\n", " }\n", "\n", " canvas_div.keydown('key_press', canvas_keyboard_event);\n", " canvas_div.keyup('key_release', canvas_keyboard_event);\n", " this.canvas_div = canvas_div\n", " this._canvas_extra_style(canvas_div)\n", " this.root.append(canvas_div);\n", "\n", " var canvas = $('');\n", " canvas.addClass('mpl-canvas');\n", " canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n", "\n", " this.canvas = canvas[0];\n", " this.context = canvas[0].getContext(\"2d\");\n", "\n", " var backingStore = this.context.backingStorePixelRatio ||\n", "\tthis.context.webkitBackingStorePixelRatio ||\n", "\tthis.context.mozBackingStorePixelRatio ||\n", "\tthis.context.msBackingStorePixelRatio ||\n", "\tthis.context.oBackingStorePixelRatio ||\n", "\tthis.context.backingStorePixelRatio || 1;\n", "\n", " mpl.ratio = (window.devicePixelRatio || 1) / backingStore;\n", "\n", " var rubberband = $('');\n", " rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n", "\n", " var pass_mouse_events = true;\n", "\n", " canvas_div.resizable({\n", " start: function(event, ui) {\n", " pass_mouse_events = false;\n", " },\n", " resize: function(event, ui) {\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " stop: function(event, ui) {\n", " pass_mouse_events = true;\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " });\n", "\n", " function mouse_event_fn(event) {\n", " if (pass_mouse_events)\n", " return fig.mouse_event(event, event['data']);\n", " }\n", "\n", " rubberband.mousedown('button_press', mouse_event_fn);\n", " rubberband.mouseup('button_release', mouse_event_fn);\n", " // Throttle sequential mouse events to 1 every 20ms.\n", " rubberband.mousemove('motion_notify', mouse_event_fn);\n", "\n", " 
rubberband.mouseenter('figure_enter', mouse_event_fn);\n", " rubberband.mouseleave('figure_leave', mouse_event_fn);\n", "\n", " canvas_div.on(\"wheel\", function (event) {\n", " event = event.originalEvent;\n", " event['data'] = 'scroll'\n", " if (event.deltaY < 0) {\n", " event.step = 1;\n", " } else {\n", " event.step = -1;\n", " }\n", " mouse_event_fn(event);\n", " });\n", "\n", " canvas_div.append(canvas);\n", " canvas_div.append(rubberband);\n", "\n", " this.rubberband = rubberband;\n", " this.rubberband_canvas = rubberband[0];\n", " this.rubberband_context = rubberband[0].getContext(\"2d\");\n", " this.rubberband_context.strokeStyle = \"#000000\";\n", "\n", " this._resize_canvas = function(width, height) {\n", " // Keep the size of the canvas, canvas container, and rubber band\n", " // canvas in synch.\n", " canvas_div.css('width', width)\n", " canvas_div.css('height', height)\n", "\n", " canvas.attr('width', width * mpl.ratio);\n", " canvas.attr('height', height * mpl.ratio);\n", " canvas.attr('style', 'width: ' + width + 'px; height: ' + height + 'px;');\n", "\n", " rubberband.attr('width', width);\n", " rubberband.attr('height', height);\n", " }\n", "\n", " // Set the figure to an initial 600x600px, this will subsequently be updated\n", " // upon first draw.\n", " this._resize_canvas(600, 600);\n", "\n", " // Disable right mouse context menu.\n", " $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n", " return false;\n", " });\n", "\n", " function set_focus () {\n", " canvas.focus();\n", " canvas_div.focus();\n", " }\n", "\n", " window.setTimeout(set_focus, 100);\n", "}\n", "\n", "mpl.figure.prototype._init_toolbar = function() {\n", " var fig = this;\n", "\n", " var nav_element = $('
')\n", " nav_element.attr('style', 'width: 100%');\n", " this.root.append(nav_element);\n", "\n", " // Define a callback function for later on.\n", " function toolbar_event(event) {\n", " return fig.toolbar_button_onclick(event['data']);\n", " }\n", " function toolbar_mouse_event(event) {\n", " return fig.toolbar_button_onmouseover(event['data']);\n", " }\n", "\n", " for(var toolbar_ind in mpl.toolbar_items) {\n", " var name = mpl.toolbar_items[toolbar_ind][0];\n", " var tooltip = mpl.toolbar_items[toolbar_ind][1];\n", " var image = mpl.toolbar_items[toolbar_ind][2];\n", " var method_name = mpl.toolbar_items[toolbar_ind][3];\n", "\n", " if (!name) {\n", " // put a spacer in here.\n", " continue;\n", " }\n", " var button = $('');\n", " button.click(method_name, toolbar_event);\n", " button.mouseover(tooltip, toolbar_mouse_event);\n", " nav_element.append(button);\n", " }\n", "\n", " // Add the status bar.\n", " var status_bar = $('');\n", " nav_element.append(status_bar);\n", " this.message = status_bar[0];\n", "\n", " // Add the close button to the window.\n", " var buttongrp = $('
');\n", " var button = $('');\n", " button.click(function (evt) { fig.handle_close(fig, {}); } );\n", " button.mouseover('Stop Interaction', toolbar_mouse_event);\n", " buttongrp.append(button);\n", " var titlebar = this.root.find($('.ui-dialog-titlebar'));\n", " titlebar.prepend(buttongrp);\n", "}\n", "\n", "mpl.figure.prototype._root_extra_style = function(el){\n", " var fig = this\n", " el.on(\"remove\", function(){\n", "\tfig.close_ws(fig, {});\n", " });\n", "}\n", "\n", "mpl.figure.prototype._canvas_extra_style = function(el){\n", " // this is important to make the div 'focusable\n", " el.attr('tabindex', 0)\n", " // reach out to IPython and tell the keyboard manager to turn it's self\n", " // off when our div gets focus\n", "\n", " // location in version 3\n", " if (IPython.notebook.keyboard_manager) {\n", " IPython.notebook.keyboard_manager.register_events(el);\n", " }\n", " else {\n", " // location in version 2\n", " IPython.keyboard_manager.register_events(el);\n", " }\n", "\n", "}\n", "\n", "mpl.figure.prototype._key_event_extra = function(event, name) {\n", " var manager = IPython.notebook.keyboard_manager;\n", " if (!manager)\n", " manager = IPython.keyboard_manager;\n", "\n", " // Check for shift+enter\n", " if (event.shiftKey && event.which == 13) {\n", " this.canvas_div.blur();\n", " event.shiftKey = false;\n", " // Send a \"J\" for go to next cell\n", " event.which = 74;\n", " event.keyCode = 74;\n", " manager.command_mode();\n", " manager.handle_keydown(event);\n", " }\n", "}\n", "\n", "mpl.figure.prototype.handle_save = function(fig, msg) {\n", " fig.ondownload(fig, null);\n", "}\n", "\n", "\n", "mpl.find_output_cell = function(html_output) {\n", " // Return the cell and output element which can be found *uniquely* in the notebook.\n", " // Note - this is a bit hacky, but it is done because the \"notebook_saving.Notebook\"\n", " // IPython event is triggered only after the cells have been serialised, which for\n", " // our purposes (turning an 
active figure into a static one), is too late.\n", " var cells = IPython.notebook.get_cells();\n", " var ncells = cells.length;\n", " for (var i=0; i<ncells; i++) {\n", " var cell = cells[i];\n", " if (cell.cell_type === 'code'){\n", " for (var j=0; j<cell.output_area.outputs.length; j++) {\n", " var data = cell.output_area.outputs[j];\n", " if (data.data) {\n", " // IPython >= 3 moved mimebundle to data attribute of output\n", " data = data.data;\n", " }\n", " if (data['text/html'] == html_output) {\n", " return [cell, data, j];\n", " }\n", " }\n", " }\n", " }\n", "}\n", "\n", "// Register the function which deals with the matplotlib target/channel.\n", "// The kernel may be null if the page has been refreshed.\n", "if (IPython.notebook.kernel != null) {\n", " IPython.notebook.kernel.comm_manager.register_target('matplotlib', mpl.mpl_figure_comm);\n", "}\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "LEVENBERG-MARQUARDT OPTIMIZATION:\n", "Time elapsed: 5.299416 seconds\n", "\n", " TRUE F_EST\n", "k : -50.0000 -49.5020\n", "b : -10.0000 -9.0243\n", "d : 1.0000 0.9791\n" ] } ], "source": [ "FF = ml.repmat(None, 50, 1)\n", "\n", "# Create the simulation model\n", "if (MODEL == 'python'):\n", "    # Pure Python\n", "    msd_fest = MSD(\"Mass-Spring-Damper_FMIN_EST\")\n", "    msd_fest.set_external_forces(T, E, 'linear_uniform')\n", "elif (MODEL == 'cython'):\n", "    # Cython\n", "    msd_fest = MSD_CYTHON(\"Mass-Spring-Damper_FMIN_EST (Cython)\")\n", "    msd_fest.set_external_forces(T, E, 'linear_uniform')\n", "elif (MODEL == 'pyublas'):\n", "    # PyUblas extension\n", "    msd_fest = MSD_PYUBLAS(\"Mass-Spring-Damper_FMIN_EST (PyUblas)\", N)\n", "    msd_fest.set_external_forces(T, E, 'linear_uniform')\n", "elif (MODEL == 'numba'):\n", "    # Numba JIT\n", "    msd_fest = MSD_NUMBA(\"Mass-Spring-Damper_FMIN_EST (Numba)\", N)\n", "    msd_fest.set_external_forces(T, E, 'linear_uniform')\n", "elif (MODEL == 'numba_jc'):\n", "    # Numba JIT\n", "    msd_fest = MSD_NUMBA_JC(N)\n", "    msd_fest.set_external_forces(T, E, 1)\n", "elif (MODEL == 
'boost'):\n", " # Boost extension\n", " msd_fest = MSD_BOOST(\"Mass-Spring-Damper_FMIN_EST (Boost)\", N)\n", " msd_fest.set_external_forces(T, E, 'linear_uniform')\n", "\n", "fig, Axes, Lines, Text = plot(msd_fest.name, T, E, Z, G, Xe=np.zeros(X.shape), Fe=np.zeros(F.shape), FF=FF)\n", "fig.canvas.draw()\n", "\n", "kws = { 'fig': fig, 'Axes': Axes, 'Lines': Lines, 'Text': Text }\n", "\n", "c_idx = ['k', 'b', 'd']\n", "\n", "print(\"LEVENBERG-MARQUARDT OPTIMIZATION:\")\n", "\n", "class Fcn2min(object):\n", "\n", " def __init__(self, z0, T, G, FF):\n", " self.z0 = z0\n", " self.T = T\n", " self.G = G\n", " self.FF = FF\n", " self.fopt_max = None\n", " self.it = 0;\n", "\n", " def __call__(self, P, **kws):\n", " C = [ P[c_idx[i]].value for i in range(len(c_idx)) ]\n", " msd_fest.set_coeffs({ 'k': C[0], 'b': C[1], 'd': C[2] })\n", "\n", " # Compute the response\n", " if (MODEL in ['python', 'cython', 'pyublas', 'numba', 'boost']):\n", " Xe, Xedot, Fe = msd_fest.integrate(self.z0, self.T)\n", " elif (MODEL == 'numba_jc'):\n", " Xe, Xedot, Fe = msd_integrate(msd_fest, self.z0, self.T)\n", "\n", " fopt = np.ravel(self.G - Fe)\n", " fopt_sum = np.sum(fopt*fopt)\n", "\n", " if (self.it < np.size(FF, 0)):\n", " self.FF[self.it, 0] = fopt_sum\n", " else:\n", " self.FF = np.roll(self.FF, -1)\n", " self.FF[-1, 0] = fopt_sum\n", "\n", " f_max = None\n", " if ((self.fopt_max is None) or (self.fopt_max < fopt_sum)):\n", " f_max = fopt_sum * 1.1\n", " self.fopt_max = fopt_sum\n", " f_txt = '{:.4f}'.format(fopt_sum)\n", "\n", " updateplot(kws['fig'], kws['Axes'], kws['Lines'], kws['Text'], Xe, Fe, self.FF, f_max=f_max, f_txt=f_txt, c_txt=C)\n", "\n", " self.it += 1\n", "\n", " return fopt\n", "\n", "# Create a set of Parameters\n", "P = lm.Parameters()\n", "for i in range(len(c_idx)):\n", " ck = c_idx[i]\n", " P.add(ck, value=C0[i])\n", "\n", "tic = time.clock()\n", "\n", "fcn2min = Fcn2min(z0, T, G, FF)\n", "\n", "# Do fit, here with leastsq model\n", "res = lm.minimize(fcn2min, P, 
kws=kws, method='leastsq', epsfcn=0.1)\n", "\n", "toc = time.clock() - tic\n", "print(\"Time elapsed: {:f} seconds\".format(toc))\n", "\n", "C_LM = [ res.params[c_idx[i]].value for i in range(len(c_idx)) ]\n", "\n", "print()\n", "print(\"             TRUE      F_EST\")\n", "for i in range(len(c_idx)):\n", "    ck = c_idx[i]\n", "    print(\"{:5s}: {:10.4f} {:10.4f}\".format(ck, msd.get_coeffs()[ck], C_LM[i]))\n", "\n", "msd_fest.set_coeffs({ 'k': C_LM[0], 'b': C_LM[1], 'd': C_LM[2] })\n", "\n", "# Compute the response\n", "if (MODEL in ['python', 'cython', 'pyublas', 'numba', 'boost']):\n", "    Xe, Xedot, Fe = msd_fest.integrate(z0, T)\n", "elif (MODEL == 'numba_jc'):\n", "    Xe, Xedot, Fe = msd_integrate(msd_fest, z0, T)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Getting close... and a lot faster!" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Last, we're going to try to fit our model using Markov chain Monte Carlo (MCMC)." ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [], "source": [ "import pymc as mc" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Build the model and run the sampler." ] }, { "cell_type": "code", "execution_count": 40, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "BUILDING PROBABILITY DISTRIBUTION MODEL\n" ] } ], "source": [ "C_str = { 'k' : r'$k$', 'b' : r'$b$', 'd' : r'$d$' }\n", "\n", "nc = 3 # number of columns\n", "nr = len(c_idx) # number of rows\n", "\n", "figname = msd_best.name + \" Coefficient Traces, Autocorrelations & Histograms\"\n", "\n", "fig, AxesArr = pp.subplots(nr, nc, figsize=(10.0, 6.0))\n", "\n", "acorr_maxlags = 100\n", "# Integer division: the number of histogram bins must be an int\n", "hist_num_bins = min(50, max(10, Nc[-1]//250))\n", "hist_hpd_alpha = 0.05\n", "\n", "# Cumulative sample counts marking the chain boundaries\n", "Nccs = [ 0 ] + Nc\n", "for k in range(1, len(Nccs)):\n", "    Nccs[k] += Nccs[k - 1]\n", "\n", "for i in range(nr):\n", "    ck = c_idx[i]\n", "\n", "    assert (len(mcmc_trace[-1][ck]) == Nc[-1]), \"Trace length for {:s} ({:d}) is not equal to {:d}!\".format(ck, len(mcmc_trace[-1][ck]), Nc[-1])\n", "\n", "    c = np.mean(mcmc_trace[-1][ck])\n", "\n", "    ax = AxesArr[i, 0]\n", "    ax.grid(color='lightgrey', linestyle=':')\n", "    ax.tick_params(axis='both', which='major', labelsize=10)\n", "    for k in range(len(Nc)):\n", "        alpha = 0.5\n", "        if (k == len(Nc) - 1):\n", "            alpha = 1.0\n", "        ax.plot(range(Nccs[k], Nccs[k + 1]), mcmc_trace[k][ck], alpha=alpha, color='seagreen', linestyle='-', linewidth=1.0, zorder=2)\n", "        if (k > 0):\n", "            ax.axvline(Nccs[k], 
alpha=0.75, linestyle='--', linewidth=1.5, color='darkgreen')\n", "    ax.set_xlim(0, sum(Nc))\n", "    ax.set_ylabel(C_str[ck], rotation='horizontal')\n", "    if (i == 0):\n", "        ax.set_title(\"Trace\")\n", "\n", "    ax = AxesArr[i, 1]\n", "    ax.grid(color='lightgrey', linestyle=':')\n", "    # Calculate the autocorrelation (raw and detrended)\n", "    (acorr_lags, acorr_c, acorr_line, acorr_b) = ax.acorr(mcmc_trace[-1][ck], detrend=mm.detrend_none, linewidth=0.0, markersize=0.0, maxlags=acorr_maxlags, usevlines=False)\n", "    ax.fill_between(acorr_lags, acorr_line.get_ydata(), alpha=0.25, color='crimson', linewidth=0.0)\n", "    ax.acorr(mcmc_trace[-1][ck], color='crimson', detrend=mm.detrend_mean, linestyle='-', linewidth=1.5, maxlags=acorr_maxlags)\n", "    ax.set_xlim(-acorr_maxlags, acorr_maxlags)\n", "    ax.set_ylim(-0.1, 1.1)\n", "    ax.set_ylabel(C_str[ck], rotation='horizontal')\n", "    if (i == 0):\n", "        ax.set_title(\"Autocorrelation (detrended)\")\n", "\n", "    ax = AxesArr[i, 2]\n", "    ax.grid(color='lightgrey', linestyle=':')\n", "    # Calculate the median and the 95% Highest Posterior Density (HPD) interval, i.e. the minimum-width Bayesian credible interval (BCI)\n", "    hist_quant = calc_quantiles(mcmc_trace[-1][ck])\n", "    hist_hpd = calc_hpd(mcmc_trace[-1][ck], hist_hpd_alpha)\n", "    (hist_n, hist_bins, hist_patches) = ax.hist(mcmc_trace[-1][ck], bins=hist_num_bins, color='steelblue', histtype='stepfilled', linewidth=0.0, normed=True, zorder=2)\n", "    ax.set_ylim(0.0, max(hist_n)*1.1)\n", "    ax.axvspan(hist_hpd[0], hist_hpd[1], alpha=0.25, facecolor='darkslategray', linewidth=1.5)\n", "    ax.axvline(hist_quant[50], linestyle='-', linewidth=1.5, color='darkslategray')\n", "    ax.set_ylabel(C_str[ck], rotation='horizontal')\n", "    if (i == 0):\n", "        ax.set_title(\"Posterior ({:2.0f}% HPD)\".format((1.0 - hist_hpd_alpha)*100.0))\n", "\n", "pp.subplots_adjust(left=0.1, wspace=0.3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Nice, but we can do better by starting with a more accurate initial 
estimate and running more samples." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Number of samples for chain-0 (initial estimation) and chain-1 (training)\n", "NUM_SAMPLES = [ 5000 ]\n", "\n", "# Start the sampler from the Levenberg-Marquardt estimate\n", "C0 = C_LM" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%run -i bms.py" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Plot the traces, autocorrelation and posterior distributions for each of the coefficients." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "%run -i bmsplot.py" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Excellent! With a fast, accurate initial optimisation and more samples for the adaptive phase, we can get a good estimate of the posteriors." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.2" } }, "nbformat": 4, "nbformat_minor": 1 }