{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "hideCode": true, "hideOutput": true, "hidePrompt": true, "jupyter": { "source_hidden": true }, "slideshow": { "slide_type": "skip" }, "tags": [ "remove-cell", "skip-execution" ] }, "outputs": [], "source": [ "# WARNING: advised to install a specific version, e.g. tensorwaves==0.1.2\n", "%pip install -q tensorwaves[doc,jax,pwa,viz] IPython" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hideCode": true, "hideOutput": true, "hidePrompt": true, "jupyter": { "source_hidden": true }, "slideshow": { "slide_type": "skip" }, "tags": [ "remove-cell" ] }, "outputs": [], "source": [ "import os\n", "\n", "STATIC_WEB_PAGE = {\"EXECUTE_NB\", \"READTHEDOCS\"}.intersection(os.environ)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```{autolink-concat}\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Speed up lambdifying" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "jupyter": { "source_hidden": true }, "tags": [ "hide-cell" ] }, "outputs": [], "source": [ "import logging\n", "\n", "import ampform\n", "import graphviz\n", "import qrules\n", "import sympy as sp\n", "from ampform.dynamics.builder import create_relativistic_breit_wigner_with_ff\n", "from IPython.display import HTML, SVG\n", "\n", "from tensorwaves.function.sympy import (\n", " create_parametrized_function,\n", " fast_lambdify,\n", " split_expression,\n", ")\n", "\n", "logging.getLogger(\"tensorwaves.data\").setLevel(logging.ERROR) # hide progress bars" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ ":::{note}\n", "\n", "Since [#374](https://github.com/ComPWA/tensorwaves/pull/374), expressions are lambdified with common sub-expressions. 
This should already reduce lambdification time significantly and also result in faster computational functions.\n", "\n", ":::" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Split expression" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Lambdifying a SymPy expression can take a long time when the expression is complicated. Fortunately, TensorWaves offers a way to speed up the lambdify process. The idea is to split up an expression into sub-expressions, lambdify those separately, and then recombine them. Let's illustrate that idea with the following simplified example:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x, y, z = sp.symbols(\"x:z\")\n", "expr = x**z + 2 * y + sp.log(y * z)\n", "expr" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This expression can be represented as a tree of mathematical operations." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "jupyter": { "source_hidden": true }, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "dot = sp.dotprint(expr, bgcolor=\"none\")\n", "graphviz.Source(dot)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The function {func}`.split_expression` can now be used to split up this expression tree into a 'top expression' plus definitions for each of the sub-expressions into which it was split:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "top_expr, sub_expressions = split_expression(expr, max_complexity=3)\n", "top_expr" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sub_expressions" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The original expression can easily be reconstructed with {meth}`~sympy.core.basic.Basic.subs` or {meth}`~sympy.core.basic.Basic.xreplace`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
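"# Sanity check (added sketch): substituting each sub-expression symbol by its\n", "# definition should restore the original expression exactly, i.e. the split\n", "# loses no information. This relies on the reconstruction claim made above.\n", "assert top_expr.xreplace(sub_expressions) == expr\n",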
"top_expr.xreplace(sub_expressions)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Each of the expression trees are now smaller than the original:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "jupyter": { "source_hidden": true }, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "dot = sp.dotprint(top_expr, bgcolor=\"none\")\n", "graphviz.Source(dot)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "jupyter": { "source_hidden": true }, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "for symbol, definition in sub_expressions.items():\n", " dot = sp.dotprint(definition, bgcolor=\"none\")\n", " graph = graphviz.Source(dot)\n", " graph.render(filename=f\"sub_expr_{symbol.name}\", format=\"svg\")\n", "\n", "html = \"
{symbol.name} | \\n'\n", " for symbol in sub_expressions\n", ")\n", "html += \"
---|
{svg} | \\n'\n", "html += \"