{ "cells": [ { "cell_type": "markdown", "id": "2ce28f2b-008b-4beb-ba51-5c478e4bdec7", "metadata": {}, "source": [ "# Symbolic Regression" ] }, { "cell_type": "markdown", "id": "d01a09d3-f78c-4b46-acd9-3a4250bb34c9", "metadata": {}, "source": [ "This example combines neural differential equations with regularised evolution to discover the equations\n", "\n", "$\\frac{\\mathrm{d} x}{\\mathrm{d} t}(t) = \\frac{y(t)}{1 + y(t)}$\n", "\n", "$\\frac{\\mathrm{d} y}{\\mathrm{d} t}(t) = \\frac{-x(t)}{1 + x(t)}$\n", "\n", "directly from data.\n", "\n", "**References:**\n", "\n", "This example appears in Section 6.1 of:\n", "\n", "```bibtex\n", "@phdthesis{kidger2021on,\n", " title={{O}n {N}eural {D}ifferential {E}quations},\n", " author={Patrick Kidger},\n", " year={2021},\n", " school={University of Oxford},\n", "}\n", "```\n", "\n", "It also draws heavy inspiration from:\n", "\n", "```bibtex\n", "@inproceedings{cranmer2020discovering,\n", " title={{D}iscovering {S}ymbolic {M}odels from {D}eep {L}earning with {I}nductive\n", " {B}iases},\n", " author={Cranmer, Miles and Sanchez Gonzalez, Alvaro and Battaglia, Peter and\n", " Xu, Rui and Cranmer, Kyle and Spergel, David and Ho, Shirley},\n", " booktitle={Advances in Neural Information Processing Systems},\n", " publisher={Curran Associates, Inc.},\n", " year={2020},\n", "}\n", "\n", "@software{cranmer2020pysr,\n", " title={PySR: Fast \\& Parallelized Symbolic Regression in Python/Julia},\n", " author={Miles Cranmer},\n", " publisher={Zenodo},\n", " url={http://doi.org/10.5281/zenodo.4041459},\n", " year={2020},\n", "}\n", "```\n", "\n", "This example is available as a Jupyter notebook [here](https://github.com/patrick-kidger/diffrax/blob/main/examples/symbolic_regression.ipynb)." 
] }, { "cell_type": "code", "execution_count": 1, "id": "dea04fa4-a95b-47f8-b0eb-be297037bc7d", "metadata": {}, "outputs": [], "source": [ "import tempfile\n", "\n", "import equinox as eqx # https://github.com/patrick-kidger/equinox\n", "import jax\n", "import jax.numpy as jnp\n", "import optax # https://github.com/deepmind/optax\n", "import pysr # https://github.com/MilesCranmer/PySR\n", "import sympy\n", "import sympy2jax # https://github.com/google/sympy2jax\n", "\n", "\n", "# Note that PySR, which we use for symbolic regression, uses Julia as a backend.\n", "# You'll need to install a recent version of Julia if you don't have one.\n", "# (You can get odd errors if the version of Julia you already have is too old.)\n", "# You may also need to restart Python after running `pysr.install()` the first time.\n", "pysr.install(quiet=True)" ] }, { "cell_type": "markdown", "id": "4d26c41f-7682-4ad0-aa33-77e22b2768f8", "metadata": {}, "source": [ "Now two helpers. We'll use these in a moment; skip over them for now." 
] }, { "cell_type": "code", "execution_count": 2, "id": "0d294688-3ea9-43ce-855f-cc587de908a5", "metadata": {}, "outputs": [], "source": [ "class Stack(eqx.Module):\n", "    modules: list[eqx.Module]\n", "\n", "    def __call__(self, x):\n", "        assert x.shape[-1] == 2\n", "        x0 = x[..., 0]\n", "        x1 = x[..., 1]\n", "        return jnp.stack([module(x0=x0, x1=x1) for module in self.modules], axis=-1)\n", "\n", "\n", "def quantise(expr, quantise_to):\n", "    if isinstance(expr, sympy.Float):\n", "        return expr.func(round(float(expr) / quantise_to) * quantise_to)\n", "    elif isinstance(expr, sympy.Symbol):\n", "        return expr\n", "    else:\n", "        return expr.func(*[quantise(arg, quantise_to) for arg in expr.args])" ] }, { "cell_type": "markdown", "id": "35245e23-17e0-4c7e-bb60-3e462b9b3b3c", "metadata": {}, "source": [ "Okay, let's get started.\n", "\n", "We start by running the [Neural ODE example](./neural_ode.ipynb).\n", "Then we extract the learnt neural vector field, and symbolically regress across this.\n", "Finally we fine-tune the resulting symbolic expression.\n" ] }, { "cell_type": "code", "execution_count": 3, "id": "b0525ef3-979f-4cf5-b15f-530eae7b8ae0", "metadata": {}, "outputs": [], "source": [ "def main(\n", "    symbolic_dataset_size=2000,\n", "    symbolic_num_populations=100,\n", "    symbolic_population_size=20,\n", "    symbolic_migration_steps=4,\n", "    symbolic_mutation_steps=30,\n", "    symbolic_descent_steps=50,\n", "    pareto_coefficient=2,\n", "    fine_tuning_steps=500,\n", "    fine_tuning_lr=3e-3,\n", "    quantise_to=0.01,\n", "):\n", "    #\n", "    # First obtain a neural approximation to the dynamics.\n", "    # We begin by running the previous example.\n", "    #\n", "\n", "    # Runs the Neural ODE example.\n", "    # This defines the variables `ts`, `ys`, `model`.\n", "    print(\"Training neural differential equation.\")\n", "    %run neural_ode.ipynb\n", "\n", "    #\n", "    # Now symbolically regress across the learnt vector field, to obtain a Pareto\n", "    # frontier of symbolic equations, that trades loss against complexity of the\n", "    # equation. Select the \"best\" from this frontier.\n", "    #\n", "\n",
"    print(\"Symbolically regressing across the vector field.\")\n", "    vector_field = model.func.mlp # noqa: F821\n", "    dataset_size, length_size, data_size = ys.shape # noqa: F821\n", "    in_ = ys.reshape(dataset_size * length_size, data_size) # noqa: F821\n", "    in_ = in_[:symbolic_dataset_size]\n", "    out = jax.vmap(vector_field)(in_)\n", "    with tempfile.TemporaryDirectory() as tempdir:\n", "        symbolic_regressor = pysr.PySRRegressor(\n", "            niterations=symbolic_migration_steps,\n", "            ncyclesperiteration=symbolic_mutation_steps,\n", "            populations=symbolic_num_populations,\n", "            population_size=symbolic_population_size,\n", "            optimizer_iterations=symbolic_descent_steps,\n", "            optimizer_nrestarts=1,\n", "            procs=1,\n", "            model_selection=\"score\",\n", "            progress=False,\n", "            tempdir=tempdir,\n", "            temp_equation_file=True,\n", "        )\n", "        symbolic_regressor.fit(in_, out)\n", "        best_expressions = [b.sympy_format for b in symbolic_regressor.get_best()]\n", "\n", "    #\n", "    # Now the constants in this expression have been optimised for regressing across\n", "    # the neural vector field. This was good enough to obtain the symbolic expression,\n", "    # but won't quite be perfect -- some of the constants will be slightly off.\n", "    #\n", "    # To fix this we now plug our symbolic function back into the original dataset\n", "    # and apply gradient descent.\n", "    #\n", "\n",
"    print(\"\\nOptimising symbolic expression.\")\n", "\n", "    symbolic_fn = Stack([sympy2jax.SymbolicModule(expr) for expr in best_expressions])\n", "    symbolic_model = eqx.tree_at(lambda m: m.func.mlp, model, symbolic_fn) # noqa: F821\n", "\n", "    @eqx.filter_grad\n", "    def grad_loss(symbolic_model):\n", "        vmap_model = jax.vmap(symbolic_model, in_axes=(None, 0))\n", "        pred_ys = vmap_model(ts, ys[:, 0]) # noqa: F821\n", "        return jnp.mean((ys - pred_ys) ** 2) # noqa: F821\n", "\n", "    optim = optax.adam(fine_tuning_lr)\n", "    opt_state = optim.init(eqx.filter(symbolic_model, eqx.is_inexact_array))\n", "\n", "    @eqx.filter_jit\n", "    def make_step(symbolic_model, opt_state):\n", "        grads = grad_loss(symbolic_model)\n", "        updates, opt_state = optim.update(grads, opt_state)\n", "        symbolic_model = eqx.apply_updates(symbolic_model, updates)\n", "        return symbolic_model, opt_state\n", "\n", "    for _ in range(fine_tuning_steps):\n", "        symbolic_model, opt_state = make_step(symbolic_model, opt_state)\n", "\n", "    #\n", "    # Finally we round each constant to the nearest multiple of `quantise_to`.\n", "    #\n", "\n", "    trained_expressions = []\n", "    for symbolic_module in symbolic_model.func.mlp.modules:\n", "        expression = symbolic_module.sympy()\n", "        expression = quantise(expression, quantise_to)\n", "        trained_expressions.append(expression)\n", "\n", "    print(f\"Expressions found: {trained_expressions}\")" ] }, { "cell_type": "code", "execution_count": 4, "id": "042fd565-825a-40fb-a4da-25e3e0da106a", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training neural differential equation.\n", "Step: 0, Loss: 0.16657482087612152, Computation time: 
11.210124731063843\n", "Step: 100, Loss: 0.01115578692406416, Computation time: 0.002620220184326172\n", "Step: 200, Loss: 0.006481764372438192, Computation time: 0.0026247501373291016\n", "Step: 300, Loss: 0.0013819701271131635, Computation time: 0.003179311752319336\n", "Step: 400, Loss: 0.0010746140033006668, Computation time: 0.0031697750091552734\n", "Step: 499, Loss: 0.0007994902553036809, Computation time: 0.0031609535217285156\n", "Step: 0, Loss: 0.028307927772402763, Computation time: 11.210363626480103\n", "Step: 100, Loss: 0.005411561578512192, Computation time: 0.020294666290283203\n", "Step: 200, Loss: 0.004366496577858925, Computation time: 0.022084712982177734\n", "Step: 300, Loss: 0.0018046485492959619, Computation time: 0.022309064865112305\n", "Step: 400, Loss: 0.001767474808730185, Computation time: 0.021766185760498047\n", "Step: 499, Loss: 0.0011962582357227802, Computation time: 0.022264480590820312\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" },
{ "name": "stdout", "output_type": "stream", "text": [ "Symbolically regressing across the vector field.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Started!\n", "\n", "Cycles per second: 5.190e+03\n", "Head worker occupation: 3.3%\n", "Progress: 434 / 800 total iterations (54.250%)\n", "==============================\n", "Best equations for output 1\n", "Hall of Fame:\n", "-----------------------------------------\n", "Complexity Loss Score Equation\n", "1 4.883e-02 1.206e+00 x1\n", "3 2.746e-02 2.877e-01 (x1 + -0.14616892)\n", "5 6.162e-04 1.899e+00 (x1 / (x1 - -1.0118991))\n", "7 4.476e-04 1.598e-01 ((x1 / 0.92953163) / (x1 + 1.0533974))\n", "9 3.997e-04 5.664e-02 (((x1 * 1.0935224) + -0.008988203) / (x1 + 1.0716586))\n", "13 3.364e-04 4.306e-02 (x1 * ((((x0 * -0.94923264) / 11.808947) - -1.087501) / (x1 + 1.0548282)))\n", "15 3.062e-04 4.714e-02 (x1 * ((((x0 * (-1.1005011 - x1)) / 13.075972) - -1.0955853) / (x1 + 1.0604433)))\n", "\n", "==============================\n", "Best equations for output 2\n", "Hall of Fame:\n", "-----------------------------------------\n", "Complexity Loss Score Equation\n", "1 1.588e-01 -1.000e-10 -0.002322703\n", "3 2.034e-02 1.028e+00 (0.14746223 - x0)\n", "5 1.413e-03 1.333e+00 (x0 / (-1.046938 - x0))\n", "7 6.958e-04 3.543e-01 (x0 / ((x0 + 1.1405994) / -1.1647526))\n", "9 2.163e-04 5.841e-01 (((x0 + -0.026584703) / (x0 + 1.2191753)) * -1.2456053)\n", "11 2.163e-04 7.749e-06 ((x0 - 0.026545616) / (((x0 / -1.2450436) + -0.9980602) - -0.019172505))\n", "\n", "==============================\n", "Press 'q' and then <enter> to stop execution early.\n", "\n", "Optimising symbolic expression.\n", "Expressions found: [x1/(x1 + 1.0), x0/(-x0 - 1.0)]\n" ] } ], "source": [ "main()" ] } ], "metadata": { "kernelspec": { "display_name": "py38", "language": "python", "name": "py38" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": 
".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.16" } }, "nbformat": 4, "nbformat_minor": 5 }