{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "hideCode": true, "hideOutput": true, "hidePrompt": true, "jupyter": { "source_hidden": true }, "slideshow": { "slide_type": "skip" }, "tags": [ "remove-cell", "skip-execution" ] }, "outputs": [], "source": [ "# WARNING: advised to install a specific version, e.g. tensorwaves==0.1.2\n", "%pip install -q tensorwaves[doc,jax,pwa,viz] IPython" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hideCode": true, "hideOutput": true, "hidePrompt": true, "jupyter": { "source_hidden": true }, "slideshow": { "slide_type": "skip" }, "tags": [ "remove-cell" ] }, "outputs": [], "source": [ "import os\n", "\n", "STATIC_WEB_PAGE = {\"EXECUTE_NB\", \"READTHEDOCS\"}.intersection(os.environ)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```{autolink-concat}\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Speed up lambdifying" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "jupyter": { "source_hidden": true }, "tags": [ "hide-cell" ] }, "outputs": [], "source": [ "import logging\n", "\n", "import ampform\n", "import graphviz\n", "import qrules\n", "import sympy as sp\n", "from ampform.dynamics.builder import create_relativistic_breit_wigner_with_ff\n", "from IPython.display import HTML, SVG\n", "\n", "from tensorwaves.function.sympy import (\n", " create_parametrized_function,\n", " fast_lambdify,\n", " split_expression,\n", ")\n", "\n", "logging.getLogger(\"tensorwaves.data\").setLevel(logging.ERROR) # hide progress bars" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ ":::{note}\n", "\n", "Since [#374](https://github.com/ComPWA/tensorwaves/pull/374), expressions are lambdified with common sub-expressions. 
This should already reduce lambdification time significantly and also result in faster computational functions.\n", "\n", ":::" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Split expression" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Lambdifying a SymPy expression can take a long time when the expression is complicated. Fortunately, TensorWaves offers a way to speed up the lambdify process. The idea is to split an expression into sub-expressions, lambdify those separately, and then recombine them. Let's illustrate that idea with the following simplified example:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x, y, z = sp.symbols(\"x:z\")\n", "expr = x**z + 2 * y + sp.log(y * z)\n", "expr" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This expression can be represented as a tree of mathematical operations." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "jupyter": { "source_hidden": true }, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "dot = sp.dotprint(expr, bgcolor=\"none\")\n", "graphviz.Source(dot)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The function {func}`.split_expression` can now be used to split up this expression tree into a 'top expression' plus definitions for each of the sub-expressions into which it was split:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "top_expr, sub_expressions = split_expression(expr, max_complexity=3)\n", "top_expr" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sub_expressions" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The original expression can easily be reconstructed with {meth}`~sympy.core.basic.Basic.subs` or {meth}`~sympy.core.basic.Basic.xreplace`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ 
"top_expr.xreplace(sub_expressions)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Each of the expression trees are now smaller than the original:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "jupyter": { "source_hidden": true }, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "dot = sp.dotprint(top_expr, bgcolor=\"none\")\n", "graphviz.Source(dot)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "jupyter": { "source_hidden": true }, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "for symbol, definition in sub_expressions.items():\n", " dot = sp.dotprint(definition, bgcolor=\"none\")\n", " graph = graphviz.Source(dot)\n", " graph.render(filename=f\"sub_expr_{symbol.name}\", format=\"svg\")\n", "\n", "html = \"\\n\"\n", "html += \" \\n\"\n", "html += \"\".join(\n", " f' \\n'\n", " for symbol in sub_expressions\n", ")\n", "html += \" \\n\"\n", "html += \" \\n\"\n", "for symbol in sub_expressions:\n", " svg = SVG(f\"sub_expr_{symbol.name}.svg\").data\n", " html += f' \\n'\n", "html += \" \\n\"\n", "html += \"
{symbol.name}
{svg}
\"\n", "HTML(html)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Fast lambdify" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Generally, the lambdify time scales exponentially with the size of an expression tree. With larger expression trees, it's therefore much faster to lambdify these sub-expressions separately and to recombine them. TensorWaves offers a function that does this for you: {func}`.fast_lambdify`. We'll use an {class}`~ampform.helicity.HelicityModel` to illustrate this:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "reaction = qrules.generate_transitions(\n", " initial_state=(\"J/psi(1S)\", [+1]),\n", " final_state=[\"gamma\", \"pi0\", \"pi0\"],\n", " allowed_intermediate_particles=[\"f(0)\"],\n", ")\n", "model_builder = ampform.get_builder(reaction)\n", "for name in reaction.get_intermediate_particles().names:\n", " model_builder.dynamics.assign(name, create_relativistic_breit_wigner_with_ff)\n", "model = model_builder.formulate()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "expression = model.expression.doit()\n", "sorted_symbols = sorted(expression.free_symbols, key=lambda s: s.name)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```{autolink-skip}\n", "```" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "\n", "split_function = fast_lambdify(\n", " expression,\n", " sorted_symbols,\n", " max_complexity=100,\n", " backend=\"numpy\",\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "jupyter": { "source_hidden": true }, "tags": [ "remove-cell" ] }, "outputs": [], "source": [ "split_function" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```{autolink-skip}\n", "```" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "%%time\n", "\n", 
"normal_function = sp.lambdify(sorted_symbols, expression)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "jupyter": { "source_hidden": true }, "tags": [ "remove-cell" ] }, "outputs": [], "source": [ "normal_function" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Specifying complexity" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "When creating a parametrized function, we use the {func}`.create_parametrized_function` function. By default, this internally calls SymPy's own {func}`~sympy.utilities.lambdify.lambdify` function. But if you specify its `max_complexity` argument, {func}`.create_parametrized_function` uses TensorWaves's {func}`.fast_lambdify`.\n", "\n", "```{autolink-skip}\n", "```" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "\n", "function = create_parametrized_function(\n", " expression=model.expression.doit(),\n", " parameters=model.parameter_defaults,\n", " max_complexity=100,\n", " backend=\"numpy\",\n", ")" ] } ], "metadata": { "colab": { "toc_visible": true }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.17" } }, "nbformat": 4, "nbformat_minor": 4 }