{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "06b92af3",
   "metadata": {},
   "source": [
    "## sunode without PyMC: fitting a Lotka-Volterra model\n",
    "\n",
    "Solve a Lotka-Volterra ODE with sunode's adjoint solver, compute the gradient of a least-squares loss with respect to the four model parameters, verify it against finite differences, and recover the true parameters with Adam.\n",
    "\n",
    "Based on https://sunode.readthedocs.io/en/latest/without_pymc.html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "dd396fb0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sunode 0.4.0\n"
     ]
    }
   ],
   "source": [
    "%matplotlib inline\n",
    "import sunode\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "print(\"sunode\", sunode.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1b384ff8",
   "metadata": {},
   "source": [
    "## ODE model\n",
    "\n",
    "Predator-prey dynamics for hares $h$ and lynxes $l$:\n",
    "\n",
    "$$\\frac{dh}{dt} = \\alpha h - \\beta h l, \\qquad \\frac{dl}{dt} = \\delta h l - \\gamma l$$\n",
    "\n",
    "All parameters and states are scalars (shape `()`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "0553e453",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shapes of the model parameters and state variables; () means scalar\n",
    "params = {\n",
    "    'α': (),\n",
    "    'β': (),\n",
    "    'γ': (),\n",
    "    'δ': (),\n",
    "}\n",
    "\n",
    "states = {\n",
    "    'hares': (),\n",
    "    'lynxes': (),\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "716b4af5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def lotka_volterra(t, y, p):\n",
    "    \"\"\"Right hand side of Lotka-Volterra equation.\n",
    "\n",
    "    All inputs are dataclasses of sympy variables, or in the case\n",
    "    of non-scalar variables numpy arrays of sympy variables.\n",
    "    Returns a dict mapping each state name to its time derivative.\n",
    "    \"\"\"\n",
    "    return {\n",
    "        'hares': p.α * y.hares - p.β * y.lynxes * y.hares,\n",
    "        'lynxes': p.δ * y.hares * y.lynxes - p.γ * y.lynxes,\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "4e2da0b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "problem = sunode.SympyProblem(\n",
    "    params=params,\n",
    "    states=states,\n",
    "    rhs_sympy=lotka_volterra,\n",
    "    # gradients are requested with respect to all four parameters\n",
    "    derivative_params=[('α',), ('β',), ('γ',), ('δ',)]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "fa61dfbc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The adjoint solver is used in two steps: solve_forward computes the trajectory,\n",
    "# then solve_backward computes gradients with respect to the parameters.\n",
    "# Forward-sensitivity alternative: sunode.solver.Solver(problem, sens_mode=\"simultaneous\")\n",
    "solver = sunode.solver.AdjointSolver(problem)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "8543d464",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initial state: one value per state variable\n",
    "y0 = np.zeros((), dtype=problem.state_dtype)\n",
    "y0['hares'] = 1\n",
    "y0['lynxes'] = 0.1\n",
    "\n",
    "# Time points at which to evaluate the solution (np.linspace defaults to 50 points)\n",
    "t = np.linspace(0, 10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "2204fa35",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ground-truth parameters; the trajectory y computed below serves as the data to fit\n",
    "α, β, γ, δ = 0.1, 0.2, 0.3, 0.4\n",
    "θ = α, β, γ, δ\n",
    "solver.set_params_dict({\n",
    "    'α': α,\n",
    "    'β': β,\n",
    "    'γ': γ,\n",
    "    'δ': δ,\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "8b028b79",
   "metadata": {},
   "outputs": [],
   "source": [
    "# y: trajectory at t; grad: gradient w.r.t. the parameters; lam: adjoint variables\n",
    "# (grad and lam are filled later by solve_backward)\n",
    "y, grad, lam = solver.make_output_buffers(t)\n",
    "solver.solve_forward(t0=t[0], tvals=t, y0=y0, y_out=y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "90f6640b",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(t, y)\n",
    "plt.xlabel('t')\n",
    "plt.ylabel('y')\n",
    "plt.xlim(t[0], t[-1])\n",
    "plt.title('Reference trajectory (true parameters)')\n",
    "plt.legend(['hares', 'lynx']);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "0078b93b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# `grads` is the derivative of the objective w.r.t. y at each time point and state;\n",
    "# with all ones the objective is y.sum(), so grad becomes d(y.sum())/d(α, β, γ, δ).\n",
    "solver.solve_backward(t0=t[-1], tend=t[0], tvals=t,\n",
    "                      grads=np.ones((len(t), y.shape[-1])),\n",
    "                      grad_out=grad, lamda_out=lam)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "5a87288c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([-82.13485284, 25.18562792]),\n",
       " array([ 465.25049911, -103.50680243, -25.88314792, 38.67576041]))"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lam, grad"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7a256883",
   "metadata": {},
   "source": [
    "## Loss and gradient\n",
    "\n",
    "The loss is the sum of squared residuals between the reference trajectory `y` and the trajectory predicted for a parameter guess $\\theta$. Its gradient with respect to $\\theta$ takes one forward solve and one adjoint (backward) solve."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "4f0665dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict(θ):\n",
    "    \"\"\"Solve the ODE forward for θ = (α, β, γ, δ) and return the trajectory at t.\n",
    "\n",
    "    Side effect: leaves the shared `solver` configured with θ.\n",
    "    \"\"\"\n",
    "    α, β, γ, δ = θ\n",
    "    solver.set_params_dict({\n",
    "        'α': α,\n",
    "        'β': β,\n",
    "        'γ': γ,\n",
    "        'δ': δ,\n",
    "    })\n",
    "    # only the trajectory buffer is needed; the gradient/adjoint buffers are unused here\n",
    "    y_out, _, _ = solver.make_output_buffers(t)\n",
    "    solver.solve_forward(t0=t[0], tvals=t, y0=y0, y_out=y_out)\n",
    "    return y_out\n",
    "\n",
    "θ_guess = 0.1, 0.1, 0.1, 0.1\n",
    "yhat = predict(θ_guess)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "0b965493",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "14.487093439626495"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def loss(y, θ):\n",
    "    \"\"\"Sum (not mean) of squared residuals between trajectory y and predict(θ).\"\"\"\n",
    "    yhat = predict(θ)\n",
    "    resid = y - yhat\n",
    "    return (resid * resid).sum()\n",
    "\n",
    "loss(y, θ_guess)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "2a05a6d3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([433.82952474, -50.6250329 , 52.96938998, -75.86014686])"
      ]
     },
     "execution_count": 17,
     "metadata": {},
"output_type": "execute_result"
    }
   ],
   "source": [
    "def gradient(θ):\n",
    "    \"\"\"Gradient of loss(y, θ) with respect to θ = (α, β, γ, δ).\n",
    "\n",
    "    Runs a forward solve, then an adjoint (backward) solve seeded with\n",
    "    dL/dyhat = -2 * (y - yhat), the derivative of the sum of squared residuals.\n",
    "    Reads the global reference trajectory `y` and leaves `solver` configured with θ.\n",
    "    \"\"\"\n",
    "    α, β, γ, δ = θ\n",
    "    solver.set_params_dict({\n",
    "        'α': α,\n",
    "        'β': β,\n",
    "        'γ': γ,\n",
    "        'δ': δ,\n",
    "    })\n",
    "    y_out, grad_out, lam_out = solver.make_output_buffers(t)\n",
    "    solver.solve_forward(t0=t[0], tvals=t, y0=y0, y_out=y_out)\n",
    "    res = y - y_out\n",
    "    solver.solve_backward(t0=t[-1], tend=t[0], tvals=t,\n",
    "                          grads=-2 * res,\n",
    "                          grad_out=grad_out, lamda_out=lam_out)\n",
    "    return grad_out\n",
    "\n",
    "θ_guess = 0.1, 0.1, 0.1, 0.1\n",
    "grad = gradient(θ_guess)\n",
    "grad"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6cbde33a",
   "metadata": {},
   "source": [
    "## Gradient checking\n",
    "\n",
    "Compare each component of the adjoint gradient with the central finite difference $\\big(L(\\theta + \\tfrac{\\Delta}{2} e_i) - L(\\theta - \\tfrac{\\Delta}{2} e_i)\\big) / \\Delta$, where $e_i$ is the $i$-th unit vector. In each cell below, the first number is the finite difference and the second is the adjoint gradient."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "3191c190",
   "metadata": {},
   "outputs": [],
   "source": [
    "α, β, γ, δ = θ_guess  # rebinds α, β, γ, δ to the guess; θ still holds the true values\n",
    "Δ = 1e-6  # finite-difference step"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "12c4b677",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(433.82952247483786, 433.8295247366786)"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "L1 = loss(y, (α+Δ/2, β, γ, δ))\n",
    "L2 = loss(y, (α-Δ/2, β, γ, δ))\n",
    "(L1-L2)/Δ, grad[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "33db0a2e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(-50.62503427133436, -50.62503290111857)"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "L1 = loss(y, (α, β+Δ/2, γ, δ))\n",
    "L2 = loss(y, (α, β-Δ/2, γ, δ))\n",
    "(L1-L2)/Δ, grad[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "040e69c8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(52.96939328225392, 52.969389983518944)"
      ]
     },
     "execution_count": 22,
"metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "L1 = loss(y, (α, β, γ+Δ/2, δ))\n",
    "L2 = loss(y, (α, β, γ-Δ/2, δ))\n",
    "(L1-L2)/Δ, grad[2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "40a3446a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(-75.86015293803428, -75.860146862688)"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "L1 = loss(y, (α, β, γ, δ+Δ/2))\n",
    "L2 = loss(y, (α, β, γ, δ-Δ/2))\n",
    "(L1-L2)/Δ, grad[3]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3e16036c",
   "metadata": {},
   "source": [
    "## Gradient descent\n",
    "\n",
    "Starting from $\\theta = (0.1, 0.1, 0.1, 0.1)$, fit the parameters with the Adam optimizer (Kingma & Ba), using the adjoint gradient from `gradient`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "388c110b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def average(p, c, β):\n",
    "    \"\"\"Exponential moving average update: β * p + (1 - β) * c.\"\"\"\n",
    "    return β * p + (1 - β) * c\n",
    "\n",
    "\n",
    "class AdamOptimizer:\n",
    "    \"\"\"Adam optimizer that returns parameter updates.\n",
    "\n",
    "    Call `send(grad)` once per step; it returns the update Δθ to add to θ.\n",
    "    Bias correction is folded into the step size:\n",
    "    αt = α * sqrt(1 - β2**t) / (1 - β1**t).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, α=0.001, β1=0.9, β2=0.999, ϵ=1e-8):\n",
    "        self.α = α      # step size\n",
    "        self.β1 = β1    # decay rate of the first-moment (gradient) average\n",
    "        self.β2 = β2    # decay rate of the second-moment (squared gradient) average\n",
    "        self.ϵ = ϵ      # keeps the denominator away from zero\n",
    "        self.m = None   # first-moment estimate, initialised on the first send()\n",
    "        self.v = None   # second-moment estimate, initialised on the first send()\n",
    "        self.t = 0      # number of steps taken so far\n",
    "\n",
    "    def send(self, grad):\n",
    "        \"\"\"Record one gradient and return the corresponding update (to be added to θ).\"\"\"\n",
    "        if self.m is None:\n",
    "            self.m = 0\n",
    "        if self.v is None:\n",
    "            self.v = 0\n",
    "\n",
    "        self.t += 1\n",
    "        αt = self.α * np.sqrt(1 - self.β2**self.t) / (1 - self.β1**self.t)\n",
    "        self.m = average(self.m, grad, self.β1)\n",
    "        self.v = average(self.v, grad * grad, self.β2)\n",
    "\n",
    "        updates = -αt * self.m / (np.sqrt(self.v) + self.ϵ)\n",
    "        assert np.isfinite(updates).all()\n",
    "        return updates"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "7362b84d",
   "metadata": {},
   "outputs": [],
   "source": [
    "θ_init = np.array([0.1, 0.1, 0.1, 0.1])  # starting point (same values as θ_guess)\n",
    "imax = 10000  # number of Adam steps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "671de937",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "L = 0, θ = (0.1000, 0.2000, 0.3000, 0.4000)\n",
      "L = 14, θ_hat = (0.0990, 0.1010, 0.0990, 0.1010) \n",
      "L = 0.097, θ_hat = (0.0857, 0.1328, 0.0120, 0.1865) \n",
      "L = 0.054, θ_hat = (0.0925, 0.1573, 0.0141, 0.1865) \n",
      "L = 0.034, θ_hat = (0.0977, 0.1765, 0.0164, 0.1869) \n",
      "L = 0.028, θ_hat = (0.1008, 0.1884, 0.0187, 0.1878) \n",
      "L = 0.026, θ_hat = (0.1024, 0.1943, 0.0211, 0.1891) \n",
      "L = 0.026, θ_hat = (0.1030, 0.1967, 0.0237, 0.1909) \n",
      "L = 0.025, θ_hat = (0.1032, 0.1975, 0.0265, 0.1930) \n",
      "L = 0.025, θ_hat = (0.1032, 0.1978, 0.0297, 0.1953) \n",
      "L = 0.024, θ_hat = (0.1032, 0.1978, 0.0332, 0.1980) \n",
      "L = 0.023, θ_hat = (0.1031, 0.1978, 0.0371, 0.2010) \n",
      "L = 0.022, θ_hat = (0.1030, 0.1978, 0.0415, 0.2042) \n",
      "L = 0.022, θ_hat = (0.1030, 0.1978, 0.0462, 0.2078) \n",
      "L = 0.021, θ_hat = (0.1029, 0.1978, 0.0514, 0.2118) \n",
      "L = 0.02, θ_hat = (0.1028, 0.1978, 0.0570, 0.2160) \n",
      "L = 0.019, θ_hat = (0.1027, 0.1978, 0.0632, 0.2207) \n",
      "L = 0.018, θ_hat = (0.1027, 0.1978, 0.0698, 0.2257) \n",
      "L = 0.017, θ_hat = (0.1026, 0.1978, 0.0769, 0.2310) \n",
      "L = 0.015, θ_hat = (0.1025, 0.1978, 0.0845, 0.2368) \n",
      "L = 0.014, θ_hat = (0.1024, 0.1979, 0.0926, 0.2429) \n",
      "L = 0.013, θ_hat = (0.1022, 0.1979, 0.1012, 0.2495) \n",
      "L = 0.012, θ_hat = (0.1021, 0.1979, 0.1103, 0.2564) \n",
      "L = 0.011, θ_hat = (0.1020, 0.1980, 0.1199, 0.2636) \n",
      "L = 0.0096, θ_hat = (0.1019, 0.1980, 0.1299, 0.2711) \n",
      "L = 0.0084, θ_hat = (0.1018, 0.1981, 0.1402, 0.2790) \n",
      "L = 0.0073, θ_hat = (0.1016, 0.1982, 0.1510, 0.2871) \n",
      "L = 0.0062, θ_hat = (0.1015, 0.1983, 0.1619, 0.2954) \n",
      "L = 0.0053, θ_hat = (0.1014, 0.1984, 0.1731, 0.3039) \n",
      "L = 0.0044, θ_hat = (0.1012, 0.1985, 0.1843, 0.3124) \n",
      "L = 0.0035, θ_hat = (0.1011, 0.1986, 0.1956, 0.3209) \n",
      "L = 0.0028, θ_hat = (0.1010, 0.1987, 0.2067, 0.3293) \n",
      "L = 0.0022, θ_hat = (0.1009, 0.1988, 0.2176, 0.3376) \n",
      "L = 0.0017, θ_hat = (0.1008, 0.1989, 0.2281, 0.3455) \n",
      "L = 0.0012, θ_hat = (0.1006, 0.1991, 0.2382, 0.3532) \n",
      "L = 0.00088, θ_hat = (0.1005, 0.1992, 0.2476, 0.3603) \n",
      "L = 0.00061, θ_hat = (0.1005, 0.1993, 0.2564, 0.3669) \n",
      "L = 0.00041, θ_hat = (0.1004, 0.1994, 0.2643, 0.3730) \n",
      "L = 0.00026, θ_hat = (0.1003, 0.1995, 0.2715, 0.3784) \n",
      "L = 0.00016, θ_hat = (0.1002, 0.1996, 0.2777, 0.3831) \n",
      "L = 9.2e-05, θ_hat = (0.1002, 0.1997, 0.2830, 0.3871) \n",
      "L = 5.1e-05, θ_hat = (0.1001, 0.1998, 0.2874, 0.3904) \n",
      "L = 2.6e-05, θ_hat = (0.1001, 0.1999, 0.2909, 0.3931) \n",
      "L = 1.3e-05, θ_hat = (0.1001, 0.1999, 0.2937, 0.3952) \n",
      "L = 5.6e-06, θ_hat = (0.1000, 0.1999, 0.2958, 0.3968) \n",
      "L = 2.3e-06, θ_hat = (0.1000, 0.2000, 0.2973, 0.3980) \n",
      "L = 8.4e-07, θ_hat = (0.1000, 0.2000, 0.2984, 0.3988) \n",
      "L = 2.8e-07, θ_hat = (0.1000, 0.2000, 0.2991, 0.3993) \n",
      "L = 8.3e-08, θ_hat = (0.1000, 0.2000, 0.2995, 0.3996) \n",
      "L = 2.2e-08, θ_hat = (0.1000, 0.2000, 0.2997, 0.3998) \n",
      "L = 7.1e-09, θ_hat = (0.1000, 0.2000, 0.2999, 0.3999) \n"
     ]
    }
   ],
   "source": [
    "# Fresh copy and fresh optimizer so re-running this cell restarts the fit\n",
    "# instead of silently continuing from the previous run's state.\n",
    "θ_hat = θ_init.copy()\n",
    "opt = AdamOptimizer()  # default step size α = 0.001\n",
    "# Roughly 50 progress lines; max() avoids a modulo-by-zero when imax < 50.\n",
    "log_every = max(1, imax // 50)\n",
    "\n",
    "# Reference line: loss at the true parameters θ (0 by construction)\n",
    "print(\"L = {:.2g}, θ = ({:.4f}, {:.4f}, {:.4f}, {:.4f})\".format(loss(y, θ), *θ))\n",
    "for i in range(imax):\n",
    "    grad = gradient(θ_hat)\n",
    "    Δθ = opt.send(grad)\n",
    "    θ_hat += Δθ\n",
    "    if i % log_every == 0:\n",
    "        print(\"L = {:.2g}, θ_hat = ({:.4f}, {:.4f}, {:.4f}, {:.4f}) \".format(loss(y, θ_hat), *θ_hat))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2c79f2c1",
   "metadata": {},
   "source": [
    "## Result\n",
    "\n",
    "After 10,000 Adam steps the estimate reaches θ_hat ≈ (0.1000, 0.2000, 0.2999, 0.3999) with loss ≈ 7e-9, recovering the true parameters (0.1, 0.2, 0.3, 0.4). α and β settle near their true values early on, while γ and δ converge much more slowly."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:scipy]",
   "language": "python",
   "name": "conda-env-scipy-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}