{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Optimal complexity finite element assembly" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Sum factorisation\n", "Suppose we are assembling the 2-form of the Laplace operator on a hexahedral element, e.g. the Q3 element:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
\n", "
\"Q3\"
\n", "
$Q_3$ element on a hexahedron. Image from the Periodic Table of the Finite Elements.
\n", "
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Since each quadrature point $X$ can be represented as a 3-tuple $q=\\{q_1,q_2,q_3\\}$, and each basis function by a 3-tuple $i = \\{i_1,i_2,i_3\\}$, the naive local assembly kernel for the local tensor $A_{ij}$ contains the loop structure:\n", "```\n", "for q1, q2, q3\n", " for i1, i2, i3\n", " for j1, j2, j3\n", " A[i1,i2,i3,j1,j2,j3] += ...\n", "```\n", "This requires $O(N_{q}^3N_{i}^6)$ FLOPs. For polynomial degree $p$, both $N_q$ and $N_i$ are $O(p)$, so this local assembly requires $O(p^9)$ FLOPs." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For *tensor product elements* like this, we can rearrange the contraction over quadrature points and hoist invariant sub-expressions out of the innermost loop into temporary variables. This is known as *sum factorisation*:\n", "```\n", "for q1, i1, j1\n", " t1[i1,j1] += ...\n", "for q2, i2, j2\n", " t2[i2,j2] += ...\n", "for q3\n", " for i1, i2, i3\n", " for j1, j2, j3\n", " A[i1,i2,i3,j1,j2,j3] += t1*t2*...\n", "```\n", "This reduces the complexity to $O(p^7)$." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "TSFC \\[1\\], the form compiler of Firedrake, is capable of exploiting this intrinsic structure of the finite element, provided by FInAT \\[2\\], and apply sum factorisation automatically to generate assembly kernels with optimal algorithmic complexity." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%matplotlib notebook\n", "from firedrake import *\n", "set_log_level(ERROR)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can create a hexahedral mesh by extruding a quadrilateral mesh." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "mesh = ExtrudedMesh(UnitSquareMesh(10, 10, quadrilateral=True), 10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's choose the continuous Lagrange element of degree 5 as our function space." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "p = 5\n", "V = FunctionSpace(mesh, \"CG\", p)\n", "u = TrialFunction(V)\n", "v = TestFunction(V)\n", "a = dot(grad(u), grad(v)) *dx # Laplace operator" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Firedrake internalises the process of local assembly. In order to look at the kernel, we need to import the compilation interface from TSFC." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "from tsfc import compile_form" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "TSFC organises the optimisation passes into *modes*. Let's first try the *vanilla* mode, which does as little optimisation as possible:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "kernel_vanilla, = compile_form(a, parameters={\"mode\": \"vanilla\"})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The COFFEE package provides some useful tools to inspect the local assembly kernel, such as the FLOPs estimator." 
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Local assembly FLOPs with vanilla mode is 1.15e+08\n" ] } ], "source": [ "from coffee.visitors import EstimateFlops\n", "print(\"Local assembly FLOPs with vanilla mode is {0:.3g}\".format(EstimateFlops().visit(kernel_vanilla.ast)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The default optimisation mode in TSFC is *spectral*, which applies sum factorisation to determine the tensor contraction order and, at each level, applies *argument factorisation* \\[3\\] to rearrange the expression using associative and distributive laws. Since *spectral* is the default mode, we do not need to specify it in the parameters." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Local assembly FLOPs with spectral mode is 2.66e+06\n" ] } ], "source": [ "kernel_spectral, = compile_form(a)\n", "print(\"Local assembly FLOPs with spectral mode is {0:.3g}\".format(EstimateFlops().visit(kernel_spectral.ast)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is a roughly 43x reduction in FLOPs. Not bad, but there's opportunity to do better. For spectral elements, if we use the Gauss–Lobatto–Legendre (GLL) quadrature scheme, whose quadrature points are collocated with the Lagrange basis function nodes, then the basis function tabulation is an identity matrix. TSFC and FInAT can exploit this to simplify the loop structure of the local assembly kernels further, which reduces the complexity to $O(p^5)$." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We need to specify the GLL quadrature scheme for hexahedra. We can do this with FIAT, which defines GLL on intervals, and FInAT, which makes the tensor product scheme from the interval scheme." 
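] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As a sketch of what the tensor product construction amounts to (the symbols $x_q$, $w_q$ and $n$ are notation introduced here for illustration, not names from FIAT or FInAT): if the interval rule has points $x_q$ and weights $w_q$ for $q = 1, \\dots, n$, then the rule on the cube has points and weights\n", "$$X_{q_1 q_2 q_3} = (x_{q_1}, x_{q_2}, x_{q_3}), \\qquad W_{q_1 q_2 q_3} = w_{q_1} w_{q_2} w_{q_3}.$$\n", "Keeping the rule in this factored form, rather than as a flat list of $n^3$ points, is what lets the form compiler see the product structure. The helper functions below build such a rule." 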
] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "import FIAT, finat\n", "\n", "def gauss_lobatto_legendre_line_rule(degree):\n", "    # Interval rule from FIAT, wrapped as a FInAT quadrature rule\n", "    fiat_make_rule = FIAT.quadrature.GaussLobattoLegendreQuadratureLineRule\n", "    fiat_rule = fiat_make_rule(FIAT.ufc_simplex(1), degree + 1)\n", "    finat_ps = finat.point_set.GaussLobattoLegendrePointSet\n", "    finat_qr = finat.quadrature.QuadratureRule\n", "    return finat_qr(finat_ps(fiat_rule.get_points()), fiat_rule.get_weights())\n", "\n", "def gauss_lobatto_legendre_cube_rule(dimension, degree):\n", "    # Tensor product of `dimension` interval rules\n", "    make_tensor_rule = finat.quadrature.TensorProductQuadratureRule\n", "    result = gauss_lobatto_legendre_line_rule(degree)\n", "    for _ in range(1, dimension):\n", "        line_rule = gauss_lobatto_legendre_line_rule(degree)\n", "        result = make_tensor_rule([result, line_rule])\n", "    return result" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We start by creating the spectral finite element function space of the same polynomial degree." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "element = FiniteElement('CG', mesh.ufl_cell(), degree=p, variant='spectral')\n", "V = FunctionSpace(mesh, element)\n", "u = TrialFunction(V)\n", "v = TestFunction(V)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We need to tell Firedrake to use the GLL quadrature rule for numerical integration." 
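] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To see why this choice of rule helps, here is a sketch under simplifying assumptions (the symbols $\\phi_i$, $x_q$, $w_q$ and $D_{qi}$ are notation introduced here for illustration, and the geometry factors are ignored). The spectral Lagrange basis is nodal at the GLL points, so in one dimension\n", "$$\\phi_i(x_q) = \\delta_{iq}, \\qquad \\phi_i'(x_q) =: D_{qi}.$$\n", "On the reference cube, the contribution of the first coordinate direction to the Laplace operator is then\n", "$$\\sum_{q_1, q_2, q_3} w_{q_1} w_{q_2} w_{q_3} \\, D_{q_1 i_1} \\delta_{q_2 i_2} \\delta_{q_3 i_3} \\, D_{q_1 j_1} \\delta_{q_2 j_2} \\delta_{q_3 j_3} = \\delta_{i_2 j_2} \\delta_{i_3 j_3} \\, w_{i_2} w_{i_3} \\sum_{q_1} w_{q_1} D_{q_1 i_1} D_{q_1 j_1}.$$\n", "Only $O(p^4)$ entries are nonzero and each costs $O(p)$ operations, which is consistent with the $O(p^5)$ complexity quoted above." 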
] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "gll_quadrature_rule = gauss_lobatto_legendre_cube_rule(dimension=3, degree=p)\n", "a_gll = dot(grad(u), grad(v)) *dx(rule=gll_quadrature_rule)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Local assembly FLOPs with GLL quadrature is 2.21e+05\n" ] } ], "source": [ "kernel_gll, = compile_form(a_gll)\n", "print(\"Local assembly FLOPs with GLL quadrature is {0:.3g}\".format(EstimateFlops().visit(kernel_gll.ast)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is a further reduction in FLOPs of roughly 12x.\n", "\n", "Now, let's verify that we achieve the expected asymptotic algorithmic complexity with respect to the polynomial degree." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "from collections import defaultdict\n", "import matplotlib.pyplot as plt\n", "import numpy" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "flops = defaultdict(list)\n", "ps = range(1, 33) # polynomial degrees\n", "modes = {\n", "    'gll': {'mode': 'spectral', 'variant': 'spectral', 'rule': gauss_lobatto_legendre_cube_rule},\n", "    'spectral': {'mode': 'spectral', 'variant': None, 'rule': lambda *args: None},\n", "    'vanilla': {'mode': 'vanilla', 'variant': None, 'rule': lambda *args: None}\n", "}\n", "\n", "for p in ps:\n", "    for mode in modes:\n", "        element = FiniteElement('CG', mesh.ufl_cell(), degree=p, variant=modes[mode]['variant'])\n", "        V = FunctionSpace(mesh, element)\n", "        u = TrialFunction(V)\n", "        v = TestFunction(V)\n", "        a = dot(grad(u), grad(v))*dx(rule=modes[mode]['rule'](3, p))\n", "        kernel, = compile_form(a, parameters={\"mode\": modes[mode]['mode']})\n", "        flops[mode].append(EstimateFlops().visit(kernel.ast))" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { 
"data": { "application/javascript": [ "/* Put everything inside the global mpl namespace */\n", "window.mpl = {};\n", "\n", "\n", "mpl.get_websocket_type = function() {\n", " if (typeof(WebSocket) !== 'undefined') {\n", " return WebSocket;\n", " } else if (typeof(MozWebSocket) !== 'undefined') {\n", " return MozWebSocket;\n", " } else {\n", " alert('Your browser does not have WebSocket support. ' +\n", " 'Please try Chrome, Safari or Firefox ≥ 6. ' +\n", " 'Firefox 4 and 5 are also supported but you ' +\n", " 'have to enable WebSockets in about:config.');\n", " };\n", "}\n", "\n", "mpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n", " this.id = figure_id;\n", "\n", " this.ws = websocket;\n", "\n", " this.supports_binary = (this.ws.binaryType != undefined);\n", "\n", " if (!this.supports_binary) {\n", " var warnings = document.getElementById(\"mpl-warnings\");\n", " if (warnings) {\n", " warnings.style.display = 'block';\n", " warnings.textContent = (\n", " \"This browser does not support binary websocket messages. \" +\n", " \"Performance may be slow.\");\n", " }\n", " }\n", "\n", " this.imageObj = new Image();\n", "\n", " this.context = undefined;\n", " this.message = undefined;\n", " this.canvas = undefined;\n", " this.rubberband_canvas = undefined;\n", " this.rubberband_context = undefined;\n", " this.format_dropdown = undefined;\n", "\n", " this.image_mode = 'full';\n", "\n", " this.root = $('
');\n", " this._root_extra_style(this.root)\n", " this.root.attr('style', 'display: inline-block');\n", "\n", " $(parent_element).append(this.root);\n", "\n", " this._init_header(this);\n", " this._init_canvas(this);\n", " this._init_toolbar(this);\n", "\n", " var fig = this;\n", "\n", " this.waiting = false;\n", "\n", " this.ws.onopen = function () {\n", " fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n", " fig.send_message(\"send_image_mode\", {});\n", " if (mpl.ratio != 1) {\n", " fig.send_message(\"set_dpi_ratio\", {'dpi_ratio': mpl.ratio});\n", " }\n", " fig.send_message(\"refresh\", {});\n", " }\n", "\n", " this.imageObj.onload = function() {\n", " if (fig.image_mode == 'full') {\n", " // Full images could contain transparency (where diff images\n", " // almost always do), so we need to clear the canvas so that\n", " // there is no ghosting.\n", " fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n", " }\n", " fig.context.drawImage(fig.imageObj, 0, 0);\n", " };\n", "\n", " this.imageObj.onunload = function() {\n", " fig.ws.close();\n", " }\n", "\n", " this.ws.onmessage = this._make_on_message_function(this);\n", "\n", " this.ondownload = ondownload;\n", "}\n", "\n", "mpl.figure.prototype._init_header = function() {\n", " var titlebar = $(\n", " '
');\n", " var titletext = $(\n", " '
');\n", " titlebar.append(titletext)\n", " this.root.append(titlebar);\n", " this.header = titletext[0];\n", "}\n", "\n", "\n", "\n", "mpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "\n", "mpl.figure.prototype._root_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "mpl.figure.prototype._init_canvas = function() {\n", " var fig = this;\n", "\n", " var canvas_div = $('
');\n", "\n", " canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n", "\n", " function canvas_keyboard_event(event) {\n", " return fig.key_event(event, event['data']);\n", " }\n", "\n", " canvas_div.keydown('key_press', canvas_keyboard_event);\n", " canvas_div.keyup('key_release', canvas_keyboard_event);\n", " this.canvas_div = canvas_div\n", " this._canvas_extra_style(canvas_div)\n", " this.root.append(canvas_div);\n", "\n", " var canvas = $('');\n", " canvas.addClass('mpl-canvas');\n", " canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n", "\n", " this.canvas = canvas[0];\n", " this.context = canvas[0].getContext(\"2d\");\n", "\n", " var backingStore = this.context.backingStorePixelRatio ||\n", "\tthis.context.webkitBackingStorePixelRatio ||\n", "\tthis.context.mozBackingStorePixelRatio ||\n", "\tthis.context.msBackingStorePixelRatio ||\n", "\tthis.context.oBackingStorePixelRatio ||\n", "\tthis.context.backingStorePixelRatio || 1;\n", "\n", " mpl.ratio = (window.devicePixelRatio || 1) / backingStore;\n", "\n", " var rubberband = $('');\n", " rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n", "\n", " var pass_mouse_events = true;\n", "\n", " canvas_div.resizable({\n", " start: function(event, ui) {\n", " pass_mouse_events = false;\n", " },\n", " resize: function(event, ui) {\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " stop: function(event, ui) {\n", " pass_mouse_events = true;\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " });\n", "\n", " function mouse_event_fn(event) {\n", " if (pass_mouse_events)\n", " return fig.mouse_event(event, event['data']);\n", " }\n", "\n", " rubberband.mousedown('button_press', mouse_event_fn);\n", " rubberband.mouseup('button_release', mouse_event_fn);\n", " // Throttle sequential mouse events to 1 every 20ms.\n", " rubberband.mousemove('motion_notify', mouse_event_fn);\n", "\n", " 
rubberband.mouseenter('figure_enter', mouse_event_fn);\n", " rubberband.mouseleave('figure_leave', mouse_event_fn);\n", "\n", " canvas_div.on(\"wheel\", function (event) {\n", " event = event.originalEvent;\n", " event['data'] = 'scroll'\n", " if (event.deltaY < 0) {\n", " event.step = 1;\n", " } else {\n", " event.step = -1;\n", " }\n", " mouse_event_fn(event);\n", " });\n", "\n", " canvas_div.append(canvas);\n", " canvas_div.append(rubberband);\n", "\n", " this.rubberband = rubberband;\n", " this.rubberband_canvas = rubberband[0];\n", " this.rubberband_context = rubberband[0].getContext(\"2d\");\n", " this.rubberband_context.strokeStyle = \"#000000\";\n", "\n", " this._resize_canvas = function(width, height) {\n", " // Keep the size of the canvas, canvas container, and rubber band\n", " // canvas in synch.\n", " canvas_div.css('width', width)\n", " canvas_div.css('height', height)\n", "\n", " canvas.attr('width', width * mpl.ratio);\n", " canvas.attr('height', height * mpl.ratio);\n", " canvas.attr('style', 'width: ' + width + 'px; height: ' + height + 'px;');\n", "\n", " rubberband.attr('width', width);\n", " rubberband.attr('height', height);\n", " }\n", "\n", " // Set the figure to an initial 600x600px, this will subsequently be updated\n", " // upon first draw.\n", " this._resize_canvas(600, 600);\n", "\n", " // Disable right mouse context menu.\n", " $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n", " return false;\n", " });\n", "\n", " function set_focus () {\n", " canvas.focus();\n", " canvas_div.focus();\n", " }\n", "\n", " window.setTimeout(set_focus, 100);\n", "}\n", "\n", "mpl.figure.prototype._init_toolbar = function() {\n", " var fig = this;\n", "\n", " var nav_element = $('
');\n", " nav_element.attr('style', 'width: 100%');\n", " this.root.append(nav_element);\n", "\n", " // Define a callback function for later on.\n", " function toolbar_event(event) {\n", " return fig.toolbar_button_onclick(event['data']);\n", " }\n", " function toolbar_mouse_event(event) {\n", " return fig.toolbar_button_onmouseover(event['data']);\n", " }\n", "\n", " for(var toolbar_ind in mpl.toolbar_items) {\n", " var name = mpl.toolbar_items[toolbar_ind][0];\n", " var tooltip = mpl.toolbar_items[toolbar_ind][1];\n", " var image = mpl.toolbar_items[toolbar_ind][2];\n", " var method_name = mpl.toolbar_items[toolbar_ind][3];\n", "\n", " if (!name) {\n", " // put a spacer in here.\n", " continue;\n", " }\n", " var button = $('');\n", " button.click(method_name, toolbar_event);\n", " button.mouseover(tooltip, toolbar_mouse_event);\n", " nav_element.append(button);\n", " }\n", "\n", " // Add the status bar.\n", " var status_bar = $('');\n", " nav_element.append(status_bar);\n", " this.message = status_bar[0];\n", "\n", " // Add the close button to the window.\n", " var buttongrp = $('
');\n", " var button = $('');\n", " button.click(function (evt) { fig.handle_close(fig, {}); } );\n", " button.mouseover('Stop Interaction', toolbar_mouse_event);\n", " buttongrp.append(button);\n", " var titlebar = this.root.find($('.ui-dialog-titlebar'));\n", " titlebar.prepend(buttongrp);\n", "}\n", "\n", "mpl.figure.prototype._root_extra_style = function(el){\n", " var fig = this\n", " el.on(\"remove\", function(){\n", "\tfig.close_ws(fig, {});\n", " });\n", "}\n", "\n", "mpl.figure.prototype._canvas_extra_style = function(el){\n", " // this is important to make the div 'focusable\n", " el.attr('tabindex', 0)\n", " // reach out to IPython and tell the keyboard manager to turn it's self\n", " // off when our div gets focus\n", "\n", " // location in version 3\n", " if (IPython.notebook.keyboard_manager) {\n", " IPython.notebook.keyboard_manager.register_events(el);\n", " }\n", " else {\n", " // location in version 2\n", " IPython.keyboard_manager.register_events(el);\n", " }\n", "\n", "}\n", "\n", "mpl.figure.prototype._key_event_extra = function(event, name) {\n", " var manager = IPython.notebook.keyboard_manager;\n", " if (!manager)\n", " manager = IPython.keyboard_manager;\n", "\n", " // Check for shift+enter\n", " if (event.shiftKey && event.which == 13) {\n", " this.canvas_div.blur();\n", " event.shiftKey = false;\n", " // Send a \"J\" for go to next cell\n", " event.which = 74;\n", " event.keyCode = 74;\n", " manager.command_mode();\n", " manager.handle_keydown(event);\n", " }\n", "}\n", "\n", "mpl.figure.prototype.handle_save = function(fig, msg) {\n", " fig.ondownload(fig, null);\n", "}\n", "\n", "\n", "mpl.find_output_cell = function(html_output) {\n", " // Return the cell and output element which can be found *uniquely* in the notebook.\n", " // Note - this is a bit hacky, but it is done because the \"notebook_saving.Notebook\"\n", " // IPython event is triggered only after the cells have been serialised, which for\n", " // our purposes (turning an 
active figure into a static one), is too late.\n", " var cells = IPython.notebook.get_cells();\n", " var ncells = cells.length;\n", " for (var i=0; i= 3 moved mimebundle to data attribute of output\n", " data = data.data;\n", " }\n", " if (data['text/html'] == html_output) {\n", " return [cell, data, j];\n", " }\n", " }\n", " }\n", " }\n", "}\n", "\n", "// Register the function which deals with the matplotlib target/channel.\n", "// The kernel may be null if the page has been refreshed.\n", "if (IPython.notebook.kernel != null) {\n", " IPython.notebook.kernel.comm_manager.register_target('matplotlib', mpl.mpl_figure_comm);\n", "}\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "fig, ax = plt.subplots(1, 1)\n", "ax.set_xscale('log')\n", "ax.set_yscale('log')\n", "for mode in modes:\n", " ax.plot(ps_curl, flops_curl[mode], label=mode)\n", "x = numpy.linspace(1, 16, 100)\n", "for p, style, offset in zip([5,7,9], ['-.','--',':'], [800,40,60]):\n", " ax.plot(x, numpy.power(x, p)*offset, label=r\"$p^{0}$\".format(p), color='grey', linestyle=style)\n", "ax.legend(loc='upper left');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## References\n", "\\[1\\] Homolya, M., Mitchell, L., Luporini, F. and Ham, D.A., 2018. TSFC: a structure-preserving form compiler. SIAM Journal on Scientific Computing, 40(3), pp.C401-C428.\n", "\n", "\\[2\\] Homolya, M., Kirby, R.C. and Ham, D.A., 2017. Exposing and exploiting structure: optimal code generation for high-order finite element methods. arXiv preprint arXiv:1711.02473.\n", "\n", "\\[3\\] Luporini, F., Ham, D.A. and Kelly, P.H., 2017. An algorithm for the optimization of finite element integration loops. ACM Transactions on Mathematical Software (TOMS), 44(1), p.3." 
] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.4" } }, "nbformat": 4, "nbformat_minor": 1 }