{ "cells": [ { "cell_type": "markdown", "id": "a1cc726f", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "# Physics Informed Neural Networks\n", "\n", "**Presenter:** Filippo Maria Bianchi\n", "\n", "**Repository:** [github.com/FilippoMB/Physics-Informed-Neural-Networks-tutorial](https://github.com/FilippoMB/Physics-Informed-Neural-Networks-tutorial)" ] }, { "cell_type": "markdown", "id": "e363db51", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## Introduction\n", "\n", "What are PINNs?\n", "\n", "- PINNs are Neural Networks used to learn a generic function $f$.\n", "- Like standard NNs, PINNs account for observation data $\\{ x_i \\}_{i=1}^N$ in learning $f$.\n", "- In addition, the optimization of $f$ is guided by a regularization term, which encourages $f$ to be the solution of a Partial Differential Equation (PDE)." ] }, { "cell_type": "markdown", "id": "cffd09a7", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "### Traditional PDE solvers\n", "\n", "- Simple problems can be solved analytically.\n", "- E.g., consider the velocity of an object whose position $x(t)$ is piecewise linear in time:\n", "\n", "$$v(t) = \\frac{d x}{d t} = \\lim_{h \\rightarrow 0} \\frac{x(t+h) - x(t)}{h}$$\n", "\n", "- Solution (the slope of each linear segment of $x(t)$):\n", "\n", "$$\n", "v(t) = \n", "\\begin{cases}\n", "3/2 & \\text{if}\\; t \\in (0, 2) \\\\\n", "0 & \\text{if}\\; t \\in (2, 4) \\\\\n", "-1/3 & \\text{if}\\; t \\in (4, 7)\n", "\\end{cases}\n", "$$" ] }, { "attachments": {}, "cell_type": "markdown", "id": "08b1c65b", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "- In most real-world problems, solutions cannot be found analytically.\n", "- Differential equations are then solved numerically.\n", "- E.g., finite-difference methods apply the definition of the derivative, with a small but finite step $h$, at *every* point of a discretized time domain."
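, "\n", "\n", "For instance, the explicit (forward) Euler method, one of the simplest numerical schemes, replaces the limit with a fixed step $h$ and marches forward in time from the initial value $x_0 = x(t_0)$. Writing the equation in the generic form $\\frac{dx}{dt} = g(t, x)$, where $g$ is an illustrative right-hand side:\n", "\n", "$$ x_{k+1} = x_k + h\\, g(t_k, x_k), \\qquad t_k = t_0 + k h, \\quad k = 0, 1, 2, \\dots $$\n", "\n", "Its global error is $O(h)$, so an accurate solution requires a small $h$ and therefore many grid points, which contributes to the computational cost of these solvers."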
] }, { "cell_type": "markdown", "id": "45aec619", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "**Limitations of PDE solvers**\n", "\n", "- They are computationally expensive and scale poorly to large problems.\n", "- Integrating external data sources (e.g., sensor measurements) is difficult." ] }, { "attachments": {}, "cell_type": "markdown", "id": "6ee24fff", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## Neural Networks\n", "\n", "- They are universal function approximators.\n", "- They can consume any kind of data $\\boldsymbol{X}$.\n", "- They are trained to minimize a loss, e.g., the error between the predictions $\\boldsymbol{\\hat{y}}$ and the desired outputs $\\boldsymbol{y}$." ] }, { "cell_type": "code", "execution_count": 1, "id": "bf86ecd0", "metadata": { "run_control": { "marked": false }, "slideshow": { "slide_type": "skip" } }, "outputs": [], "source": [ "# Imports\n", "import torch\n", "from torch import nn\n", "import numpy as np\n", "from scipy.integrate import solve_ivp\n", "import matplotlib.pyplot as plt\n", "from matplotlib import cm" ] }, { "attachments": {}, "cell_type": "markdown", "id": "82760add", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "Let's start by creating a simple neural network in PyTorch."
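, "\n", "\n", "The network defined in the next cell maps a scalar input to a scalar output through three hidden layers of 16, 32, and 16 units with $\\tanh$ activations, i.e.,\n", "\n", "$$\\rm{NN}(x) = W_4 \\tanh\\Big(W_3 \\tanh\\big(W_2 \\tanh(W_1 x + b_1) + b_2\\big) + b_3\\Big) + b_4,$$\n", "\n", "with $W_1 \\in \\mathbb{R}^{16 \\times 1}$, $W_2 \\in \\mathbb{R}^{32 \\times 16}$, $W_3 \\in \\mathbb{R}^{16 \\times 32}$, $W_4 \\in \\mathbb{R}^{1 \\times 16}$, and biases $b_1, \\dots, b_4$ of matching sizes."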
] }, { "cell_type": "code", "execution_count": 2, "id": "980ee1b5", "metadata": { "run_control": { "marked": false }, "slideshow": { "slide_type": "-" } }, "outputs": [], "source": [ "# Define a simple neural network for regression\n", "class simple_NN(nn.Module):\n", "    def __init__(self):\n", "        super(simple_NN, self).__init__()\n", "        self.linear_tanh_stack = nn.Sequential(\n", "            nn.Linear(1, 16),\n", "            nn.Tanh(),\n", "            nn.Linear(16, 32),\n", "            nn.Tanh(),\n", "            nn.Linear(32, 16),\n", "            nn.Tanh(),\n", "            nn.Linear(16, 1),\n", "        )\n", "\n", "    def forward(self, x):\n", "        out = self.linear_tanh_stack(x)\n", "        return out" ] }, { "cell_type": "markdown", "id": "a908a777", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "Then, we:\n", "- Create a small dataset $\\{x_i, y_i\\}_{i=1}^{5}$.\n", "- Use the NN to make predictions: $\\hat{y}_i = \\rm{NN}(x_i)$.\n", "- Train the NN by minimizing $\\rm{MSE}(\\boldsymbol{y}, \\boldsymbol{\\hat{y}})$." ] }, { "cell_type": "code", "execution_count": 3, "id": "3a554192", "metadata": { "run_control": { "marked": false }, "slideshow": { "slide_type": "-" } }, "outputs": [], "source": [ "# Define dataset\n", "x_train = torch.tensor([[1.1437e-04],\n", "                        [1.4676e-01],\n", "                        [3.0233e-01],\n", "                        [4.1702e-01],\n", "                        [7.2032e-01]], dtype=torch.float32)\n", "y_train = torch.tensor([[1.0000],\n", "                        [1.0141],\n", "                        [1.0456],\n", "                        [1.0753],\n", "                        [1.1565]], dtype=torch.float32)" ] }, { "cell_type": "code", "execution_count": 4, "id": "6d76616a", "metadata": { "run_control": { "marked": false }, "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch: 0, loss: 0.865811\n", "epoch: 200, loss: 0.000253\n", "epoch: 400, loss: 0.000151\n", "epoch: 600, loss: 0.000068\n", "epoch: 800, loss: 0.000017\n" ] } ], "source": [ "# Initialize the model\n", "model = simple_NN()\n", "\n", "# define loss and optimizer\n", "loss_fn = nn.MSELoss()\n", "optimizer = 
torch.optim.Adam(model.parameters(), lr=1e-2)\n", "\n", "# Train\n", "for ep in range(1000):\n", "\n", "    # Compute prediction error\n", "    pred = model(x_train)\n", "    loss = loss_fn(pred, y_train)\n", "\n", "    # Backpropagation\n", "    optimizer.zero_grad()\n", "    loss.backward()\n", "    optimizer.step()\n", "\n", "    if ep % 200 == 0:\n", "        print(f\"epoch: {ep}, loss: {loss.item():>7f}\")" ] }, { "cell_type": "code", "execution_count": 5, "id": "aa4644a4", "metadata": { "run_control": { "marked": false }, "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# evaluate the model on all data points in the domain\n", "domain = [0.0, 1.5]\n", "x_eval = torch.linspace(domain[0], domain[1], steps=100).reshape(-1, 1)\n", "f_eval = model(x_eval)\n", "\n", "# plotting\n", "fig, ax = plt.subplots(figsize=(12, 5))\n", "ax.scatter(x_train.detach().numpy(), y_train.detach().numpy(), label=\"Training data\", color=\"blue\")\n", "ax.plot(x_eval.detach().numpy(), f_eval.detach().numpy(), label=\"NN approximation\", color=\"black\")\n", "ax.set(title=\"Neural Network Regression\", xlabel=\"$x$\", ylabel=\"$y$\")\n", "ax.legend();" ] }, { "cell_type": "markdown", "id": "ce277406", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "- The NN fits the data samples well.\n", "- However, it has no information about which function it should learn beyond the last training point ($x > 0.72$)." ] }, { "attachments": {}, "cell_type": "markdown", "id": "cd3a8bf4", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## Physics Informed NNs\n", "\n", "- Use PDEs to guide the NN output.\n", "- Train the model with an additional loss that penalizes violations of the PDE.\n", "\n", "$$ \\mathcal{L}_{\\text{tot}} = \\mathcal{L}_{\\text{data}} + \\mathcal{L}_{\\text{PDE}}$$" ] }, { "cell_type": "markdown", "id": "4ad6a57a", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "**Advantages**\n", "\n", "PINNs combine information from both data and physical models:\n", "\n", "- Compared to traditional NNs, $\\mathcal{L}_{\\text{PDE}}$ regularizes the model, limiting overfitting and improving generalization.\n", "- Compared to traditional PDE solvers, PINNs are more scalable and can consume any kind of data."
] }, { "cell_type": "markdown", "id": "a7379c57", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Example I: population growth\n", "\n", "We model population growth with a simplified variant of the logistic equation $\\frac{df}{dt} = Rf(1-f)$, in which the right-hand side depends on time $t$ instead of on the population $f$:\n", "\n", "$$ \\frac{d f(t)}{d t} = Rt(1-t)$$\n", "\n", "- $f(t)$ is the population at time $t$\n", "- $R$ is a constant that scales the growth rate\n", "- To identify a unique solution, a boundary condition must be imposed, e.g., at $t=0$:\n", "\n", "$$f(t=0)=1$$" ] }, { "cell_type": "code", "execution_count": 6, "id": "b8b0ee7a", "metadata": { "run_control": { "marked": false }, "slideshow": { "slide_type": "fragment" } }, "outputs": [], "source": [ "R = 1.0\n", "ft0 = 1.0" ] }, { "cell_type": "markdown", "id": "7c6782f4", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "- Use the NN to model $f(t)$, i.e., $f(t) = \\rm{NN}(t)$.\n", "- We can easily compute the derivative $\\frac{d\\rm{NN}(t)}{dt}$ thanks to the automatic differentiation provided by deep learning libraries.\n" ] }, { "cell_type": "code", "execution_count": 7, "id": "308d79c8", "metadata": { "run_control": { "marked": false }, "slideshow": { "slide_type": "fragment" } }, "outputs": [], "source": [ "def df(f: simple_NN, x: torch.Tensor = None, order: int = 1) -> torch.Tensor:\n", "    \"\"\"Compute neural network derivative with respect to input features using PyTorch autograd engine\"\"\"\n", "    df_value = f(x)\n", "    for _ in range(order):\n", "        df_value = torch.autograd.grad(\n", "            df_value,\n", "            x,\n", "            grad_outputs=torch.ones_like(df_value),  # same shape as the tensor being differentiated\n", "            create_graph=True,  # keep the graph for higher orders and for loss.backward()\n", "            retain_graph=True,\n", "        )[0]\n", "\n", "    return df_value" ] }, { "cell_type": "markdown", "id": "46db65df", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "- We want our NN to satisfy the following equation:\n", "\n", "$$ \\frac{d\\rm{NN}(t)}{dt} - Rt(1-t) = 0 $$\n", "\n", "- To do that, we add the following physics-informed regularization term to the loss:\n", "\n", "$$ \\mathcal{L}_\\rm{PDE} = 
\\frac{1}{N} \\sum_{i=1}^N \\left( \\frac{d\\rm{NN}}{dt} \\bigg\\rvert_{t_i} - R t_i (1-t_i) \\right)^2 $$\n", "\n", "- Here, $t_i$ are **collocation points**, i.e., points sampled from the domain at which we evaluate the differential equation." ] }, { "cell_type": "code", "execution_count": 8, "id": "f7b2db5f", "metadata": { "run_control": { "marked": false }, "slideshow": { "slide_type": "fragment" } }, "outputs": [], "source": [ "# Generate 10 evenly distributed collocation points\n", "t = torch.linspace(domain[0], domain[1], steps=10, requires_grad=True).reshape(-1, 1)" ] }, { "cell_type": "markdown", "id": "0babb74f", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "- Minimizing $\\mathcal{L}_\\rm{PDE}$ alone does not determine a unique solution.\n", "- We must include the boundary condition by adding the following loss, with $t_0 = 0$:\n", "\n", "$$ \\mathcal{L}_\\rm{BC} = \\left( \\rm{NN}(t_0) - 1 \\right)^2 $$\n", "\n", "- This selects the desired solution among the infinitely many possible ones."
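, "\n", "\n", "For this simple equation, the role of the boundary condition can be seen by integrating analytically:\n", "\n", "$$ f(t) = \\int R\\, t(1-t)\\, dt = R\\left(\\frac{t^2}{2} - \\frac{t^3}{3}\\right) + c, $$\n", "\n", "where every constant $c$ yields a valid solution of the differential equation. Imposing $f(0) = 1$ gives $c = 1$, hence\n", "\n", "$$ f(t) = 1 + R\\left(\\frac{t^2}{2} - \\frac{t^3}{3}\\right). $$"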
] }, { "cell_type": "markdown", "id": "3de1ccf0", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "The total loss is given by:\n", "\n", "$$ \\mathcal{L}_\\rm{tot} = \\mathcal{L}_\\rm{PDE} + \\mathcal{L}_\\rm{BC} + \\mathcal{L}_\\rm{data} $$\n", "\n", "where $\\mathcal{L}_\\rm{data}$ is the MSE between the NN predictions and the observation data." ] }, { "cell_type": "code", "execution_count": 9, "id": "66cb667d", "metadata": { "run_control": { "marked": false }, "slideshow": { "slide_type": "subslide" } }, "outputs": [], "source": [ "# Wrap everything into a function\n", "def compute_loss(model: simple_NN,\n", "                 t: torch.Tensor = None,\n", "                 x: torch.Tensor = None,\n", "                 y: torch.Tensor = None,\n", "                 ) -> torch.Tensor:\n", "    \"\"\"Compute the full loss as PDE loss + boundary loss + data (MSE) loss.\n", "    This custom loss is built entirely from differentiable tensors, therefore\n", "    the .backward() method can be applied to it.\n", "    \"\"\"\n", "\n", "    # PDE residual at the collocation points t\n", "    pde_loss = df(model, t) - R * t * (1 - t)\n", "    pde_loss = pde_loss.pow(2).mean()\n", "\n", "    # boundary condition f(0) = ft0, evaluated on a batch with a single point\n", "    boundary = torch.zeros(1, 1)\n", "    bc_loss = (model(boundary) - ft0).pow(2).mean()\n", "\n", "    # data term: MSE on the observation data\n", "    mse_loss = torch.nn.MSELoss()(model(x), y)\n", "\n", "    tot_loss = pde_loss + bc_loss + mse_loss\n", "\n", "    return tot_loss" ] }, { "cell_type": "code", "execution_count": 10, "id": "e4655d8b", "metadata": { "run_control": { "marked": false }, "scrolled": false, "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch: 0, loss: 3.063274\n", "epoch: 200, loss: 0.000948\n", "epoch: 400, loss: 0.000451\n", "epoch: 600, loss: 0.000110\n", "epoch: 800, loss: 0.000090\n", "epoch: 1000, loss: 0.000087\n", "epoch: 1200, loss: 0.000085\n", "epoch: 1400, loss: 0.000084\n", "epoch: 1600, loss: 0.000083\n", "epoch: 1800, loss: 0.000082\n" ] } ], "source": [ "model = simple_NN()\n", "optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)\n", "\n", "# Train\n", "for ep in range(2000):\n", "\n", "    loss = compute_loss(model, t, x_train, y_train)\n", 
"\n", "    # Backpropagation\n", "    optimizer.zero_grad()\n", "    loss.backward()\n", "    optimizer.step()\n", "\n", "    if ep % 200 == 0:\n", "        print(f\"epoch: {ep}, loss: {loss.item():>7f}\")" ] }, { "cell_type": "code", "execution_count": 11, "id": "c4ed19a2", "metadata": { "run_control": { "marked": false }, "slideshow": { "slide_type": "skip" } }, "outputs": [], "source": [ "# numeric solution (solve_ivp calls the right-hand side as fun(time, state))\n", "def logistic_eq_fn(time, f):\n", "    return R * time * (1 - time)\n", "\n", "numeric_solution = solve_ivp(\n", "    logistic_eq_fn, domain, [ft0], t_eval=x_eval.squeeze().detach().numpy()\n", ")\n", "\n", "f_colloc = solve_ivp(\n", "    logistic_eq_fn, domain, [ft0], t_eval=t.squeeze().detach().numpy()\n", ").y.T" ] }, { "cell_type": "code", "execution_count": 12, "id": "705c3b81", "metadata": { "run_control": { "marked": false }, "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# evaluation on the domain [0, 1.5]\n", "f_eval = model(x_eval)\n", "\n", "# plotting\n", "fig, ax = plt.subplots(figsize=(12, 5))\n", "ax.scatter(t.detach().numpy(), f_colloc, label=\"Collocation points\", color=\"magenta\", alpha=0.75)\n", "ax.scatter(x_train.detach().numpy(), y_train.detach().numpy(), label=\"Observation data\", color=\"blue\")\n", "ax.plot(x_eval.detach().numpy(), f_eval.detach().numpy(), label=\"NN solution\", color=\"black\")\n", "ax.plot(x_eval.detach().numpy(), numeric_solution.y.T,\n", "        label=\"Numerical solution\", color=\"magenta\", alpha=0.75)\n", "ax.set(title=\"Logistic equation solved with NNs\", xlabel=\"t\", ylabel=\"f(t)\")\n", "ax.legend();" ] }, { "cell_type": "markdown", "id": "e7ce2baa", "metadata": { "run_control": { "marked": false }, "slideshow": { "slide_type": "slide" } }, "source": [ "### Example II: 1D wave equation\n", "\n", "- Now, we want our NN to learn a function $f(x,t)$ that satisfies the following $2^\\rm{nd}$-order PDE:\n", "\n", "$$\\frac{\\partial^2 f}{\\partial x^2} = \\frac{1}{C^2} \\frac{\\partial^2 f}{\\partial t^2}$$\n", "\n", "- Here, $C > 0$ is the wave propagation speed." ] }, { "cell_type": "markdown", "id": "980f2ca0", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "- Unlike before, $f$ depends on two variables: space ($x$) and time ($t$).\n", "- We modify our neural network to accept two input variables."
] }, { "cell_type": "code", "execution_count": 13, "id": "2ba361a9", "metadata": { "slideshow": { "slide_type": "fragment" } }, "outputs": [], "source": [ "class simple_NN2(nn.Module):\n", "    def __init__(self):\n", "        super(simple_NN2, self).__init__()\n", "        self.linear_tanh_stack = nn.Sequential(\n", "            nn.Linear(2, 16),  # <--- 2 input variables\n", "            nn.Tanh(),\n", "            nn.Linear(16, 32),\n", "            nn.Tanh(),\n", "            nn.Linear(32, 16),\n", "            nn.Tanh(),\n", "            nn.Linear(16, 1),\n", "        )\n", "\n", "    def forward(self, x, t):\n", "        x_stack = torch.cat([x, t], dim=1)  # <--- concatenate x and t\n", "        out = self.linear_tanh_stack(x_stack)\n", "        return out" ] }, { "cell_type": "markdown", "id": "473d690a", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "- The function we defined before, `df()`, computes derivatives of any order of the NN w.r.t. its only input variable.\n", "- We modify it to take the NN output directly, so that we can differentiate w.r.t. either $x$ or $t$." ] }, { "cell_type": "code", "execution_count": 14, "id": "ba9d5e25", "metadata": { "slideshow": { "slide_type": "fragment" } }, "outputs": [], "source": [ "def df(output: torch.Tensor, input_var: torch.Tensor, order: int = 1) -> torch.Tensor:\n", "    \"\"\"Compute neural network derivative with respect to input features using PyTorch autograd engine\"\"\"\n", "    df_value = output  # <-- we directly take the output of the NN\n", "    for _ in range(order):\n", "        df_value = torch.autograd.grad(\n", "            df_value,\n", "            input_var,\n", "            grad_outputs=torch.ones_like(df_value),  # same shape as the tensor being differentiated\n", "            create_graph=True,\n", "            retain_graph=True,\n", "        )[0]\n", "    return df_value\n", "\n", "def dfdt(model: simple_NN2, x: torch.Tensor, t: torch.Tensor, order: int = 1):\n", "    \"\"\"Derivative with respect to the time variable of arbitrary order\"\"\"\n", "    f_value = model(x, t)\n", "    return df(f_value, t, order=order)\n", "\n", "def dfdx(model: simple_NN2, x: torch.Tensor, t: torch.Tensor, order: int = 1):\n", "    \"\"\"Derivative with respect to the spatial 
variable of arbitrary order\"\"\"\n", "    f_value = model(x, t)\n", "    return df(f_value, x, order=order)" ] }, { "cell_type": "markdown", "id": "6013750a", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "#### Loss definition\n", "\n", "- For this example, we do not consider measurement data.\n", "- We train the NN with a loss that only accounts for physical equations.\n", "- The first term of the loss enforces the 1-dimensional wave equation at the collocation points $(x_i, t_i)$ of a space-time grid:\n", "\n", "$$\\mathcal{L}_\\rm{PDE} = \\frac{1}{N} \\sum_{i=1}^N \\left( \\frac{\\partial^2 f}{\\partial x^2} \\bigg\\rvert_{(x_i, t_i)} - \\frac{1}{C^2} \\frac{\\partial^2 f}{\\partial t^2} \\bigg\\rvert_{(x_i, t_i)} \\right)^2 $$" ] }, { "cell_type": "markdown", "id": "80aa22c1", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "- As before, infinitely many solutions satisfy this equation.\n", "- We need to restrict the possible solutions by:\n", "    1. imposing Dirichlet (fixed-end) boundary conditions at the domain extrema.\n", "    2. imposing an initial condition on $f(x, t_0)$.\n", "    3. imposing an initial condition on $\\frac{\\partial f(x, t)}{\\partial t} \\bigg\\rvert_{t=0}$."
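, "\n", "\n", "With the choices made in the following cells ($x_0 = 0$, $x_1 = 1$, boundary values $f(x_0, t) = f(x_1, t) = 0$, initial shape $f(x, 0) = \\frac{1}{2}\\sin(2\\pi x)$, zero initial velocity) and the default $C = 1$ used in the code, these conditions single out the solution\n", "\n", "$$ f(x, t) = \\frac{1}{2} \\sin(2\\pi x) \\cos(2\\pi t), $$\n", "\n", "as can be checked by substitution: $\\frac{\\partial^2 f}{\\partial x^2} = \\frac{\\partial^2 f}{\\partial t^2} = -4\\pi^2 f$; $\\sin(0) = \\sin(2\\pi) = 0$ at the boundaries; $f(x, 0) = \\frac{1}{2}\\sin(2\\pi x)$; and $\\frac{\\partial f}{\\partial t} = -\\pi \\sin(2\\pi x) \\sin(2\\pi t)$ vanishes at $t = 0$. This closed form is a convenient reference for checking the trained PINN."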
] }, { "attachments": {}, "cell_type": "markdown", "id": "a3d7795d", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "- We define the domain of $x$ as $[x_0, x_1]$.\n", "- In this example, $x_0 = 0$ and $x_1 = 1$, but they could be different values.\n", "- The following loss penalizes violations of the boundary conditions $f(x_0, t) = f(x_1, t) = 0$, i.e., both ends are held fixed at zero for all $t$:\n", "\n", "$$\\mathcal{L}_\\rm{BC} = f(x_0, t)^2 + f(x_1, t)^2$$" ] }, { "attachments": {}, "cell_type": "markdown", "id": "fde287b2", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "- Next, we must define the initial condition on $f(x, t_0)$, with $t_0 = 0$.\n", "- The following loss penalizes departures from the desired initial shape:\n", "\n", "$$\\mathcal{L}_\\rm{initF} = \\left( f(x, t_0) - \\frac{1}{2} \\sin(2\\pi x) \\right)^2 $$" ] }, { "attachments": {}, "cell_type": "markdown", "id": "3b477992", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "- Finally, we must specify the initial condition on $\\frac{\\partial f(x, t)}{\\partial t} \\bigg\\rvert_{t=0}$.\n", "- The following loss penalizes departures from the desired initial condition on the first-order time derivative, which we set to zero (the wave starts at rest):\n", "\n", "$$\\mathcal{L}_\\rm{initDF} = \\left( \\frac{\\partial f}{\\partial t} \\bigg\\rvert_{t=0} \\right)^2 $$" ] }, { "cell_type": "markdown", "id": "a631559d", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "The total loss is given by:\n", "\n", "$$\\mathcal{L}_\\rm{tot} = \\mathcal{L}_\\rm{PDE} + \\mathcal{L}_\\rm{BC} + \\mathcal{L}_\\rm{initF} + \\mathcal{L}_\\rm{initDF}$$" ] }, { "cell_type": "code", "execution_count": 15, "id": "f2934df9", "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [], "source": [ "def initial_condition(x) -> torch.Tensor:\n", "    # initial shape f(x, 0) = 0.5 * sin(2 * pi * x), returned as a column vector\n", "    res = torch.sin(2 * np.pi * x).reshape(-1, 1) * 0.5\n", "    return res" ] }, { "cell_type": "code", "execution_count": 16, "id": "b1e50e13", "metadata": { "slideshow": { "slide_type": 
"subslide" } }, "outputs": [], "source": [ "def compute_loss(\n", "    model: simple_NN2,\n", "    x: torch.Tensor = None,\n", "    t: torch.Tensor = None,\n", "    x_idx: torch.Tensor = None,\n", "    t_idx: torch.Tensor = None,\n", "    C: float = 1.0,\n", "    device: str = None,\n", ") -> torch.Tensor:\n", "\n", "    # PDE residual on the full space-time grid (x, t)\n", "    pde_loss = dfdx(model, x, t, order=2) - (1/C**2) * dfdt(model, x, t, order=2)\n", "\n", "    # boundary conditions: f(x0, t) = 0 and f(x1, t) = 0 for every t in t_idx\n", "    # (x[0] and x[-1] are the domain extrema because the meshgrid uses 'ij' indexing)\n", "    boundary_x0 = torch.ones_like(t_idx, requires_grad=True).to(device) * x[0]\n", "    boundary_loss_x0 = model(boundary_x0, t_idx)  # f(x0, t)\n", "    boundary_x1 = torch.ones_like(t_idx, requires_grad=True).to(device) * x[-1]\n", "    boundary_loss_x1 = model(boundary_x1, t_idx)  # f(x1, t)\n", "\n", "    # initial conditions at t0 = 0 for every x in x_idx\n", "    f_initial = initial_condition(x_idx)  # 0.5*sin(2*pi*x)\n", "    t_initial = torch.zeros_like(x_idx)  # t0\n", "    t_initial.requires_grad = True  # needed to differentiate w.r.t. t\n", "    initial_loss_f = model(x_idx, t_initial) - f_initial  # L_initF\n", "    initial_loss_df = dfdt(model, x_idx, t_initial, order=1)  # L_initDF\n", "\n", "    # obtain the final loss by averaging each term and summing them up\n", "    final_loss = (\n", "        pde_loss.pow(2).mean()\n", "        + boundary_loss_x0.pow(2).mean()\n", "        + boundary_loss_x1.pow(2).mean()\n", "        + initial_loss_f.pow(2).mean()\n", "        + initial_loss_df.pow(2).mean()\n", "    )\n", "\n", "    return final_loss" ] }, { "cell_type": "code", "execution_count": 17, "id": "5a39ad2a", "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [], "source": [ "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "\n", "# generate the time-space meshgrid\n", "x_domain = [0.0, 1.0]; n_points_x = 100\n", "t_domain = [0.0, 1.0]; n_points_t = 150\n", "x_idx = torch.linspace(x_domain[0], x_domain[1], steps=n_points_x, requires_grad=True)\n", "t_idx = torch.linspace(t_domain[0], t_domain[1], steps=n_points_t, requires_grad=True)\n", "grids = torch.meshgrid(x_idx, t_idx, indexing=\"ij\")\n", "x_idx, t_idx = x_idx.reshape(-1, 
1).to(device), t_idx.reshape(-1, 1).to(device)\n", "x, t = grids[0].flatten().reshape(-1, 1).to(device), grids[1].flatten().reshape(-1, 1).to(device)\n", "\n", "# initialize the neural network model\n", "model = simple_NN2().to(device)" ] }, { "cell_type": "code", "execution_count": 18, "id": "36cc375c", "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch: 0, loss: 0.133278\n", "epoch: 300, loss: 0.058259\n", "epoch: 600, loss: 0.033736\n", "epoch: 900, loss: 0.026548\n", "epoch: 1200, loss: 0.023956\n", "epoch: 1500, loss: 0.076558\n", "epoch: 1800, loss: 0.014723\n", "epoch: 2100, loss: 0.012283\n", "epoch: 2400, loss: 0.003159\n", "epoch: 2700, loss: 0.001655\n" ] } ], "source": [ "# Train\n", "optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)\n", "for ep in range(3000):\n", "\n", "    loss = compute_loss(model, x=x, t=t, x_idx=x_idx, t_idx=t_idx, device=device)\n", "\n", "    # Backpropagation\n", "    optimizer.zero_grad()\n", "    loss.backward()\n", "    optimizer.step()\n", "\n", "    if ep % 300 == 0:\n", "        print(f\"epoch: {ep}, loss: {loss.item():>7f}\")" ] }, { "cell_type": "code", "execution_count": 19, "id": "65b3ec36", "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "data": { "application/javascript": "/* Put everything inside the global mpl namespace */\n/* global mpl */\nwindow.mpl = {};\n\nmpl.get_websocket_type = function () {\n if (typeof WebSocket !== 'undefined') {\n return WebSocket;\n } else if (typeof MozWebSocket !== 'undefined') {\n return MozWebSocket;\n } else {\n alert(\n 'Your browser does not have WebSocket support. ' +\n 'Please try Chrome, Safari or Firefox ≥ 6. 
' +\n 'Firefox 4 and 5 are also supported but you ' +\n 'have to enable WebSockets in about:config.'\n );\n }\n};\n\nmpl.figure = function (figure_id, websocket, ondownload, parent_element) {\n this.id = figure_id;\n\n this.ws = websocket;\n\n this.supports_binary = this.ws.binaryType !== undefined;\n\n if (!this.supports_binary) {\n var warnings = document.getElementById('mpl-warnings');\n if (warnings) {\n warnings.style.display = 'block';\n warnings.textContent =\n 'This browser does not support binary websocket messages. ' +\n 'Performance may be slow.';\n }\n }\n\n this.imageObj = new Image();\n\n this.context = undefined;\n this.message = undefined;\n this.canvas = undefined;\n this.rubberband_canvas = undefined;\n this.rubberband_context = undefined;\n this.format_dropdown = undefined;\n\n this.image_mode = 'full';\n\n this.root = document.createElement('div');\n this.root.setAttribute('style', 'display: inline-block');\n this._root_extra_style(this.root);\n\n parent_element.appendChild(this.root);\n\n this._init_header(this);\n this._init_canvas(this);\n this._init_toolbar(this);\n\n var fig = this;\n\n this.waiting = false;\n\n this.ws.onopen = function () {\n fig.send_message('supports_binary', { value: fig.supports_binary });\n fig.send_message('send_image_mode', {});\n if (fig.ratio !== 1) {\n fig.send_message('set_device_pixel_ratio', {\n device_pixel_ratio: fig.ratio,\n });\n }\n fig.send_message('refresh', {});\n };\n\n this.imageObj.onload = function () {\n if (fig.image_mode === 'full') {\n // Full images could contain transparency (where diff images\n // almost always do), so we need to clear the canvas so that\n // there is no ghosting.\n fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n }\n fig.context.drawImage(fig.imageObj, 0, 0);\n };\n\n this.imageObj.onunload = function () {\n fig.ws.close();\n };\n\n this.ws.onmessage = this._make_on_message_function(this);\n\n this.ondownload = 
ondownload;\n};\n\nmpl.figure.prototype._init_header = function () {\n var titlebar = document.createElement('div');\n titlebar.classList =\n 'ui-dialog-titlebar ui-widget-header ui-corner-all ui-helper-clearfix';\n var titletext = document.createElement('div');\n titletext.classList = 'ui-dialog-title';\n titletext.setAttribute(\n 'style',\n 'width: 100%; text-align: center; padding: 3px;'\n );\n titlebar.appendChild(titletext);\n this.root.appendChild(titlebar);\n this.header = titletext;\n};\n\nmpl.figure.prototype._canvas_extra_style = function (_canvas_div) {};\n\nmpl.figure.prototype._root_extra_style = function (_canvas_div) {};\n\nmpl.figure.prototype._init_canvas = function () {\n var fig = this;\n\n var canvas_div = (this.canvas_div = document.createElement('div'));\n canvas_div.setAttribute('tabindex', '0');\n canvas_div.setAttribute(\n 'style',\n 'border: 1px solid #ddd;' +\n 'box-sizing: content-box;' +\n 'clear: both;' +\n 'min-height: 1px;' +\n 'min-width: 1px;' +\n 'outline: 0;' +\n 'overflow: hidden;' +\n 'position: relative;' +\n 'resize: both;' +\n 'z-index: 2;'\n );\n\n function on_keyboard_event_closure(name) {\n return function (event) {\n return fig.key_event(event, name);\n };\n }\n\n canvas_div.addEventListener(\n 'keydown',\n on_keyboard_event_closure('key_press')\n );\n canvas_div.addEventListener(\n 'keyup',\n on_keyboard_event_closure('key_release')\n );\n\n this._canvas_extra_style(canvas_div);\n this.root.appendChild(canvas_div);\n\n var canvas = (this.canvas = document.createElement('canvas'));\n canvas.classList.add('mpl-canvas');\n canvas.setAttribute(\n 'style',\n 'box-sizing: content-box;' +\n 'pointer-events: none;' +\n 'position: relative;' +\n 'z-index: 0;'\n );\n\n this.context = canvas.getContext('2d');\n\n var backingStore =\n this.context.backingStorePixelRatio ||\n this.context.webkitBackingStorePixelRatio ||\n this.context.mozBackingStorePixelRatio ||\n this.context.msBackingStorePixelRatio ||\n 
this.context.oBackingStorePixelRatio ||\n this.context.backingStorePixelRatio ||\n 1;\n\n this.ratio = (window.devicePixelRatio || 1) / backingStore;\n\n var rubberband_canvas = (this.rubberband_canvas = document.createElement(\n 'canvas'\n ));\n rubberband_canvas.setAttribute(\n 'style',\n 'box-sizing: content-box;' +\n 'left: 0;' +\n 'pointer-events: none;' +\n 'position: absolute;' +\n 'top: 0;' +\n 'z-index: 1;'\n );\n\n // Apply a ponyfill if ResizeObserver is not implemented by browser.\n if (this.ResizeObserver === undefined) {\n if (window.ResizeObserver !== undefined) {\n this.ResizeObserver = window.ResizeObserver;\n } else {\n var obs = _JSXTOOLS_RESIZE_OBSERVER({});\n this.ResizeObserver = obs.ResizeObserver;\n }\n }\n\n this.resizeObserverInstance = new this.ResizeObserver(function (entries) {\n var nentries = entries.length;\n for (var i = 0; i < nentries; i++) {\n var entry = entries[i];\n var width, height;\n if (entry.contentBoxSize) {\n if (entry.contentBoxSize instanceof Array) {\n // Chrome 84 implements new version of spec.\n width = entry.contentBoxSize[0].inlineSize;\n height = entry.contentBoxSize[0].blockSize;\n } else {\n // Firefox implements old version of spec.\n width = entry.contentBoxSize.inlineSize;\n height = entry.contentBoxSize.blockSize;\n }\n } else {\n // Chrome <84 implements even older version of spec.\n width = entry.contentRect.width;\n height = entry.contentRect.height;\n }\n\n // Keep the size of the canvas and rubber band canvas in sync with\n // the canvas container.\n if (entry.devicePixelContentBoxSize) {\n // Chrome 84 implements new version of spec.\n canvas.setAttribute(\n 'width',\n entry.devicePixelContentBoxSize[0].inlineSize\n );\n canvas.setAttribute(\n 'height',\n entry.devicePixelContentBoxSize[0].blockSize\n );\n } else {\n canvas.setAttribute('width', width * fig.ratio);\n canvas.setAttribute('height', height * fig.ratio);\n }\n /* This rescales the canvas back to display pixels, so that it\n * appears 
correct on HiDPI screens. */\n canvas.style.width = width + 'px';\n canvas.style.height = height + 'px';\n\n rubberband_canvas.setAttribute('width', width);\n rubberband_canvas.setAttribute('height', height);\n\n // And update the size in Python. We ignore the initial 0/0 size\n // that occurs as the element is placed into the DOM, which should\n // otherwise not happen due to the minimum size styling.\n if (fig.ws.readyState == 1 && width != 0 && height != 0) {\n fig.request_resize(width, height);\n }\n }\n });\n this.resizeObserverInstance.observe(canvas_div);\n\n function on_mouse_event_closure(name) {\n /* User Agent sniffing is bad, but WebKit is busted:\n * https://bugs.webkit.org/show_bug.cgi?id=144526\n * https://bugs.webkit.org/show_bug.cgi?id=181818\n * The worst that happens here is that they get an extra browser\n * selection when dragging, if this check fails to catch them.\n */\n var UA = navigator.userAgent;\n var isWebKit = /AppleWebKit/.test(UA) && !/Chrome/.test(UA);\n if(isWebKit) {\n return function (event) {\n /* This prevents the web browser from automatically changing to\n * the text insertion cursor when the button is pressed. 
We\n * want to control all of the cursor setting manually through\n * the 'cursor' event from matplotlib */\n event.preventDefault()\n return fig.mouse_event(event, name);\n };\n } else {\n return function (event) {\n return fig.mouse_event(event, name);\n };\n }\n }\n\n canvas_div.addEventListener(\n 'mousedown',\n on_mouse_event_closure('button_press')\n );\n canvas_div.addEventListener(\n 'mouseup',\n on_mouse_event_closure('button_release')\n );\n canvas_div.addEventListener(\n 'dblclick',\n on_mouse_event_closure('dblclick')\n );\n // Throttle sequential mouse events to 1 every 20ms.\n canvas_div.addEventListener(\n 'mousemove',\n on_mouse_event_closure('motion_notify')\n );\n\n canvas_div.addEventListener(\n 'mouseenter',\n on_mouse_event_closure('figure_enter')\n );\n canvas_div.addEventListener(\n 'mouseleave',\n on_mouse_event_closure('figure_leave')\n );\n\n canvas_div.addEventListener('wheel', function (event) {\n if (event.deltaY < 0) {\n event.step = 1;\n } else {\n event.step = -1;\n }\n on_mouse_event_closure('scroll')(event);\n });\n\n canvas_div.appendChild(canvas);\n canvas_div.appendChild(rubberband_canvas);\n\n this.rubberband_context = rubberband_canvas.getContext('2d');\n this.rubberband_context.strokeStyle = '#000000';\n\n this._resize_canvas = function (width, height, forward) {\n if (forward) {\n canvas_div.style.width = width + 'px';\n canvas_div.style.height = height + 'px';\n }\n };\n\n // Disable right mouse context menu.\n canvas_div.addEventListener('contextmenu', function (_e) {\n event.preventDefault();\n return false;\n });\n\n function set_focus() {\n canvas.focus();\n canvas_div.focus();\n }\n\n window.setTimeout(set_focus, 100);\n};\n\nmpl.figure.prototype._init_toolbar = function () {\n var fig = this;\n\n var toolbar = document.createElement('div');\n toolbar.classList = 'mpl-toolbar';\n this.root.appendChild(toolbar);\n\n function on_click_closure(name) {\n return function (_event) {\n return 
fig.toolbar_button_onclick(name);\n };\n }\n\n function on_mouseover_closure(tooltip) {\n return function (event) {\n if (!event.currentTarget.disabled) {\n return fig.toolbar_button_onmouseover(tooltip);\n }\n };\n }\n\n fig.buttons = {};\n var buttonGroup = document.createElement('div');\n buttonGroup.classList = 'mpl-button-group';\n for (var toolbar_ind in mpl.toolbar_items) {\n var name = mpl.toolbar_items[toolbar_ind][0];\n var tooltip = mpl.toolbar_items[toolbar_ind][1];\n var image = mpl.toolbar_items[toolbar_ind][2];\n var method_name = mpl.toolbar_items[toolbar_ind][3];\n\n if (!name) {\n /* Instead of a spacer, we start a new button group. */\n if (buttonGroup.hasChildNodes()) {\n toolbar.appendChild(buttonGroup);\n }\n buttonGroup = document.createElement('div');\n buttonGroup.classList = 'mpl-button-group';\n continue;\n }\n\n var button = (fig.buttons[name] = document.createElement('button'));\n button.classList = 'mpl-widget';\n button.setAttribute('role', 'button');\n button.setAttribute('aria-disabled', 'false');\n button.addEventListener('click', on_click_closure(method_name));\n button.addEventListener('mouseover', on_mouseover_closure(tooltip));\n\n var icon_img = document.createElement('img');\n icon_img.src = '_images/' + image + '.png';\n icon_img.srcset = '_images/' + image + '_large.png 2x';\n icon_img.alt = tooltip;\n button.appendChild(icon_img);\n\n buttonGroup.appendChild(button);\n }\n\n if (buttonGroup.hasChildNodes()) {\n toolbar.appendChild(buttonGroup);\n }\n\n var fmt_picker = document.createElement('select');\n fmt_picker.classList = 'mpl-widget';\n toolbar.appendChild(fmt_picker);\n this.format_dropdown = fmt_picker;\n\n for (var ind in mpl.extensions) {\n var fmt = mpl.extensions[ind];\n var option = document.createElement('option');\n option.selected = fmt === mpl.default_extension;\n option.innerHTML = fmt;\n fmt_picker.appendChild(option);\n }\n\n var status_bar = document.createElement('span');\n status_bar.classList = 
'mpl-message';\n toolbar.appendChild(status_bar);\n this.message = status_bar;\n};\n\nmpl.figure.prototype.request_resize = function (x_pixels, y_pixels) {\n // Request matplotlib to resize the figure. Matplotlib will then trigger a resize in the client,\n // which will in turn request a refresh of the image.\n this.send_message('resize', { width: x_pixels, height: y_pixels });\n};\n\nmpl.figure.prototype.send_message = function (type, properties) {\n properties['type'] = type;\n properties['figure_id'] = this.id;\n this.ws.send(JSON.stringify(properties));\n};\n\nmpl.figure.prototype.send_draw_message = function () {\n if (!this.waiting) {\n this.waiting = true;\n this.ws.send(JSON.stringify({ type: 'draw', figure_id: this.id }));\n }\n};\n\nmpl.figure.prototype.handle_save = function (fig, _msg) {\n var format_dropdown = fig.format_dropdown;\n var format = format_dropdown.options[format_dropdown.selectedIndex].value;\n fig.ondownload(fig, format);\n};\n\nmpl.figure.prototype.handle_resize = function (fig, msg) {\n var size = msg['size'];\n if (size[0] !== fig.canvas.width || size[1] !== fig.canvas.height) {\n fig._resize_canvas(size[0], size[1], msg['forward']);\n fig.send_message('refresh', {});\n }\n};\n\nmpl.figure.prototype.handle_rubberband = function (fig, msg) {\n var x0 = msg['x0'] / fig.ratio;\n var y0 = (fig.canvas.height - msg['y0']) / fig.ratio;\n var x1 = msg['x1'] / fig.ratio;\n var y1 = (fig.canvas.height - msg['y1']) / fig.ratio;\n x0 = Math.floor(x0) + 0.5;\n y0 = Math.floor(y0) + 0.5;\n x1 = Math.floor(x1) + 0.5;\n y1 = Math.floor(y1) + 0.5;\n var min_x = Math.min(x0, x1);\n var min_y = Math.min(y0, y1);\n var width = Math.abs(x1 - x0);\n var height = Math.abs(y1 - y0);\n\n fig.rubberband_context.clearRect(\n 0,\n 0,\n fig.canvas.width / fig.ratio,\n fig.canvas.height / fig.ratio\n );\n\n fig.rubberband_context.strokeRect(min_x, min_y, width, height);\n};\n\nmpl.figure.prototype.handle_figure_label = function (fig, msg) {\n // Updates the figure 
title.\n fig.header.textContent = msg['label'];\n};\n\nmpl.figure.prototype.handle_cursor = function (fig, msg) {\n fig.canvas_div.style.cursor = msg['cursor'];\n};\n\nmpl.figure.prototype.handle_message = function (fig, msg) {\n fig.message.textContent = msg['message'];\n};\n\nmpl.figure.prototype.handle_draw = function (fig, _msg) {\n // Request the server to send over a new figure.\n fig.send_draw_message();\n};\n\nmpl.figure.prototype.handle_image_mode = function (fig, msg) {\n fig.image_mode = msg['mode'];\n};\n\nmpl.figure.prototype.handle_history_buttons = function (fig, msg) {\n for (var key in msg) {\n if (!(key in fig.buttons)) {\n continue;\n }\n fig.buttons[key].disabled = !msg[key];\n fig.buttons[key].setAttribute('aria-disabled', !msg[key]);\n }\n};\n\nmpl.figure.prototype.handle_navigate_mode = function (fig, msg) {\n if (msg['mode'] === 'PAN') {\n fig.buttons['Pan'].classList.add('active');\n fig.buttons['Zoom'].classList.remove('active');\n } else if (msg['mode'] === 'ZOOM') {\n fig.buttons['Pan'].classList.remove('active');\n fig.buttons['Zoom'].classList.add('active');\n } else {\n fig.buttons['Pan'].classList.remove('active');\n fig.buttons['Zoom'].classList.remove('active');\n }\n};\n\nmpl.figure.prototype.updated_canvas_event = function () {\n // Called whenever the canvas gets updated.\n this.send_message('ack', {});\n};\n\n// A function to construct a web socket function for onmessage handling.\n// Called in the figure constructor.\nmpl.figure.prototype._make_on_message_function = function (fig) {\n return function socket_on_message(evt) {\n if (evt.data instanceof Blob) {\n var img = evt.data;\n if (img.type !== 'image/png') {\n /* FIXME: We get \"Resource interpreted as Image but\n * transferred with MIME type text/plain:\" errors on\n * Chrome. But how to set the MIME type? 
It doesn't seem\n * to be part of the websocket stream */\n img.type = 'image/png';\n }\n\n /* Free the memory for the previous frames */\n if (fig.imageObj.src) {\n (window.URL || window.webkitURL).revokeObjectURL(\n fig.imageObj.src\n );\n }\n\n fig.imageObj.src = (window.URL || window.webkitURL).createObjectURL(\n img\n );\n fig.updated_canvas_event();\n fig.waiting = false;\n return;\n } else if (\n typeof evt.data === 'string' &&\n evt.data.slice(0, 21) === 'data:image/png;base64'\n ) {\n fig.imageObj.src = evt.data;\n fig.updated_canvas_event();\n fig.waiting = false;\n return;\n }\n\n var msg = JSON.parse(evt.data);\n var msg_type = msg['type'];\n\n // Call the \"handle_{type}\" callback, which takes\n // the figure and JSON message as its only arguments.\n try {\n var callback = fig['handle_' + msg_type];\n } catch (e) {\n console.log(\n \"No handler for the '\" + msg_type + \"' message type: \",\n msg\n );\n return;\n }\n\n if (callback) {\n try {\n // console.log(\"Handling '\" + msg_type + \"' message: \", msg);\n callback(fig, msg);\n } catch (e) {\n console.log(\n \"Exception inside the 'handler_\" + msg_type + \"' callback:\",\n e,\n e.stack,\n msg\n );\n }\n }\n };\n};\n\nfunction getModifiers(event) {\n var mods = [];\n if (event.ctrlKey) {\n mods.push('ctrl');\n }\n if (event.altKey) {\n mods.push('alt');\n }\n if (event.shiftKey) {\n mods.push('shift');\n }\n if (event.metaKey) {\n mods.push('meta');\n }\n return mods;\n}\n\n/*\n * return a copy of an object with only non-object keys\n * we need this to avoid circular references\n * https://stackoverflow.com/a/24161582/3208463\n */\nfunction simpleKeys(original) {\n return Object.keys(original).reduce(function (obj, key) {\n if (typeof original[key] !== 'object') {\n obj[key] = original[key];\n }\n return obj;\n }, {});\n}\n\nmpl.figure.prototype.mouse_event = function (event, name) {\n if (name === 'button_press') {\n this.canvas.focus();\n this.canvas_div.focus();\n }\n\n // from 
https://stackoverflow.com/q/1114465\n var boundingRect = this.canvas.getBoundingClientRect();\n var x = (event.clientX - boundingRect.left) * this.ratio;\n var y = (event.clientY - boundingRect.top) * this.ratio;\n\n this.send_message(name, {\n x: x,\n y: y,\n button: event.button,\n step: event.step,\n modifiers: getModifiers(event),\n guiEvent: simpleKeys(event),\n });\n\n return false;\n};\n\nmpl.figure.prototype._key_event_extra = function (_event, _name) {\n // Handle any extra behaviour associated with a key event\n};\n\nmpl.figure.prototype.key_event = function (event, name) {\n // Prevent repeat events\n if (name === 'key_press') {\n if (event.key === this._key) {\n return;\n } else {\n this._key = event.key;\n }\n }\n if (name === 'key_release') {\n this._key = null;\n }\n\n var value = '';\n if (event.ctrlKey && event.key !== 'Control') {\n value += 'ctrl+';\n }\n else if (event.altKey && event.key !== 'Alt') {\n value += 'alt+';\n }\n else if (event.shiftKey && event.key !== 'Shift') {\n value += 'shift+';\n }\n\n value += 'k' + event.key;\n\n this._key_event_extra(event, name);\n\n this.send_message(name, { key: value, guiEvent: simpleKeys(event) });\n return false;\n};\n\nmpl.figure.prototype.toolbar_button_onclick = function (name) {\n if (name === 'download') {\n this.handle_save(this, null);\n } else {\n this.send_message('toolbar_button', { name: name });\n }\n};\n\nmpl.figure.prototype.toolbar_button_onmouseover = function (tooltip) {\n this.message.textContent = tooltip;\n};\n\n///////////////// REMAINING CONTENT GENERATED BY embed_js.py /////////////////\n// prettier-ignore\nvar _JSXTOOLS_RESIZE_OBSERVER=function(A){var t,i=new WeakMap,n=new WeakMap,a=new WeakMap,r=new WeakMap,o=new Set;function s(e){if(!(this instanceof s))throw new TypeError(\"Constructor requires 'new' operator\");i.set(this,e)}function h(){throw new TypeError(\"Function is not a constructor\")}function c(e,t,i,n){e=0 in arguments?Number(arguments[0]):0,t=1 in 
arguments?Number(arguments[1]):0,i=2 in arguments?Number(arguments[2]):0,n=3 in arguments?Number(arguments[3]):0,this.right=(this.x=this.left=e)+(this.width=i),this.bottom=(this.y=this.top=t)+(this.height=n),Object.freeze(this)}function d(){t=requestAnimationFrame(d);var s=new WeakMap,p=new Set;o.forEach((function(t){r.get(t).forEach((function(i){var r=t instanceof window.SVGElement,o=a.get(t),d=r?0:parseFloat(o.paddingTop),f=r?0:parseFloat(o.paddingRight),l=r?0:parseFloat(o.paddingBottom),u=r?0:parseFloat(o.paddingLeft),g=r?0:parseFloat(o.borderTopWidth),m=r?0:parseFloat(o.borderRightWidth),w=r?0:parseFloat(o.borderBottomWidth),b=u+f,F=d+l,v=(r?0:parseFloat(o.borderLeftWidth))+m,W=g+w,y=r?0:t.offsetHeight-W-t.clientHeight,E=r?0:t.offsetWidth-v-t.clientWidth,R=b+v,z=F+W,M=r?t.width:parseFloat(o.width)-R-E,O=r?t.height:parseFloat(o.height)-z-y;if(n.has(t)){var k=n.get(t);if(k[0]===M&&k[1]===O)return}n.set(t,[M,O]);var S=Object.create(h.prototype);S.target=t,S.contentRect=new c(u,d,M,O),s.has(i)||(s.set(i,[]),p.add(i)),s.get(i).push(S)}))})),p.forEach((function(e){i.get(e).call(e,s.get(e),e)}))}return s.prototype.observe=function(i){if(i instanceof window.Element){r.has(i)||(r.set(i,new Set),o.add(i),a.set(i,window.getComputedStyle(i)));var n=r.get(i);n.has(this)||n.add(this),cancelAnimationFrame(t),t=requestAnimationFrame(d)}},s.prototype.unobserve=function(i){if(i instanceof window.Element&&r.has(i)){var n=r.get(i);n.has(this)&&(n.delete(this),n.size||(r.delete(i),o.delete(i))),n.size||r.delete(i),o.size||cancelAnimationFrame(t)}},A.DOMRectReadOnly=c,A.ResizeObserver=s,A.ResizeObserverEntry=h,A}; // eslint-disable-line\nmpl.toolbar_items = [[\"Home\", \"Reset original view\", \"fa fa-home\", \"home\"], [\"Back\", \"Back to previous view\", \"fa fa-arrow-left\", \"back\"], [\"Forward\", \"Forward to next view\", \"fa fa-arrow-right\", \"forward\"], [\"\", \"\", \"\", \"\"], [\"Pan\", \"Left button pans, Right button zooms\\nx/y fixes axis, CTRL fixes aspect\", \"fa 
fa-arrows\", \"pan\"], [\"Zoom\", \"Zoom to rectangle\\nx/y fixes axis\", \"fa fa-square-o\", \"zoom\"], [\"\", \"\", \"\", \"\"], [\"Download\", \"Download plot\", \"fa fa-floppy-o\", \"download\"]];\n\nmpl.extensions = [\"eps\", \"jpeg\", \"pgf\", \"pdf\", \"png\", \"ps\", \"raw\", \"svg\", \"tif\", \"webp\"];\n\nmpl.default_extension = \"png\";/* global mpl */\n\nvar comm_websocket_adapter = function (comm) {\n // Create a \"websocket\"-like object which calls the given IPython comm\n // object with the appropriate methods. Currently this is a non binary\n // socket, so there is still some room for performance tuning.\n var ws = {};\n\n ws.binaryType = comm.kernel.ws.binaryType;\n ws.readyState = comm.kernel.ws.readyState;\n function updateReadyState(_event) {\n if (comm.kernel.ws) {\n ws.readyState = comm.kernel.ws.readyState;\n } else {\n ws.readyState = 3; // Closed state.\n }\n }\n comm.kernel.ws.addEventListener('open', updateReadyState);\n comm.kernel.ws.addEventListener('close', updateReadyState);\n comm.kernel.ws.addEventListener('error', updateReadyState);\n\n ws.close = function () {\n comm.close();\n };\n ws.send = function (m) {\n //console.log('sending', m);\n comm.send(m);\n };\n // Register the callback with on_msg.\n comm.on_msg(function (msg) {\n //console.log('receiving', msg['content']['data'], msg);\n var data = msg['content']['data'];\n if (data['blob'] !== undefined) {\n data = {\n data: new Blob(msg['buffers'], { type: data['blob'] }),\n };\n }\n // Pass the mpl event to the overridden (by mpl) onmessage function.\n ws.onmessage(data);\n });\n return ws;\n};\n\nmpl.mpl_figure_comm = function (comm, msg) {\n // This is the function which gets called when the mpl process\n // starts-up an IPython Comm through the \"matplotlib\" channel.\n\n var id = msg.content.data.id;\n // Get hold of the div created by the display call when the Comm\n // socket was opened in Python.\n var element = document.getElementById(id);\n var ws_proxy = 
comm_websocket_adapter(comm);\n\n function ondownload(figure, _format) {\n window.open(figure.canvas.toDataURL());\n }\n\n var fig = new mpl.figure(id, ws_proxy, ondownload, element);\n\n // Call onopen now - mpl needs it, as it is assuming we've passed it a real\n // web socket which is closed, not our websocket->open comm proxy.\n ws_proxy.onopen();\n\n fig.parent_element = element;\n fig.cell_info = mpl.find_output_cell(\"
\");\n if (!fig.cell_info) {\n console.error('Failed to find cell for figure', id, fig);\n return;\n }\n fig.cell_info[0].output_area.element.on(\n 'cleared',\n { fig: fig },\n fig._remove_fig_handler\n );\n};\n\nmpl.figure.prototype.handle_close = function (fig, msg) {\n var width = fig.canvas.width / fig.ratio;\n fig.cell_info[0].output_area.element.off(\n 'cleared',\n fig._remove_fig_handler\n );\n fig.resizeObserverInstance.unobserve(fig.canvas_div);\n\n // Update the output cell to use the data from the current canvas.\n fig.push_to_output();\n var dataURL = fig.canvas.toDataURL();\n // Re-enable the keyboard manager in IPython - without this line, in FF,\n // the notebook keyboard shortcuts fail.\n IPython.keyboard_manager.enable();\n fig.parent_element.innerHTML =\n '';\n fig.close_ws(fig, msg);\n};\n\nmpl.figure.prototype.close_ws = function (fig, msg) {\n fig.send_message('closing', msg);\n // fig.ws.close()\n};\n\nmpl.figure.prototype.push_to_output = function (_remove_interactive) {\n // Turn the data on the canvas into data in the output cell.\n var width = this.canvas.width / this.ratio;\n var dataURL = this.canvas.toDataURL();\n this.cell_info[1]['text/html'] =\n '';\n};\n\nmpl.figure.prototype.updated_canvas_event = function () {\n // Tell IPython that the notebook contents must change.\n IPython.notebook.set_dirty(true);\n this.send_message('ack', {});\n var fig = this;\n // Wait a second, then push the new image to the DOM so\n // that it is saved nicely (might be nice to debounce this).\n setTimeout(function () {\n fig.push_to_output();\n }, 1000);\n};\n\nmpl.figure.prototype._init_toolbar = function () {\n var fig = this;\n\n var toolbar = document.createElement('div');\n toolbar.classList = 'btn-toolbar';\n this.root.appendChild(toolbar);\n\n function on_click_closure(name) {\n return function (_event) {\n return fig.toolbar_button_onclick(name);\n };\n }\n\n function on_mouseover_closure(tooltip) {\n return function (event) {\n if 
(!event.currentTarget.disabled) {\n return fig.toolbar_button_onmouseover(tooltip);\n }\n };\n }\n\n fig.buttons = {};\n var buttonGroup = document.createElement('div');\n buttonGroup.classList = 'btn-group';\n var button;\n for (var toolbar_ind in mpl.toolbar_items) {\n var name = mpl.toolbar_items[toolbar_ind][0];\n var tooltip = mpl.toolbar_items[toolbar_ind][1];\n var image = mpl.toolbar_items[toolbar_ind][2];\n var method_name = mpl.toolbar_items[toolbar_ind][3];\n\n if (!name) {\n /* Instead of a spacer, we start a new button group. */\n if (buttonGroup.hasChildNodes()) {\n toolbar.appendChild(buttonGroup);\n }\n buttonGroup = document.createElement('div');\n buttonGroup.classList = 'btn-group';\n continue;\n }\n\n button = fig.buttons[name] = document.createElement('button');\n button.classList = 'btn btn-default';\n button.href = '#';\n button.title = name;\n button.innerHTML = '';\n button.addEventListener('click', on_click_closure(method_name));\n button.addEventListener('mouseover', on_mouseover_closure(tooltip));\n buttonGroup.appendChild(button);\n }\n\n if (buttonGroup.hasChildNodes()) {\n toolbar.appendChild(buttonGroup);\n }\n\n // Add the status bar.\n var status_bar = document.createElement('span');\n status_bar.classList = 'mpl-message pull-right';\n toolbar.appendChild(status_bar);\n this.message = status_bar;\n\n // Add the close button to the window.\n var buttongrp = document.createElement('div');\n buttongrp.classList = 'btn-group inline pull-right';\n button = document.createElement('button');\n button.classList = 'btn btn-mini btn-primary';\n button.href = '#';\n button.title = 'Stop Interaction';\n button.innerHTML = '';\n button.addEventListener('click', function (_evt) {\n fig.handle_close(fig, {});\n });\n button.addEventListener(\n 'mouseover',\n on_mouseover_closure('Stop Interaction')\n );\n buttongrp.appendChild(button);\n var titlebar = this.root.querySelector('.ui-dialog-titlebar');\n titlebar.insertBefore(buttongrp, 
titlebar.firstChild);\n};\n\nmpl.figure.prototype._remove_fig_handler = function (event) {\n var fig = event.data.fig;\n if (event.target !== this) {\n // Ignore bubbled events from children.\n return;\n }\n fig.close_ws(fig, {});\n};\n\nmpl.figure.prototype._root_extra_style = function (el) {\n el.style.boxSizing = 'content-box'; // override notebook setting of border-box.\n};\n\nmpl.figure.prototype._canvas_extra_style = function (el) {\n // this is important to make the div 'focusable\n el.setAttribute('tabindex', 0);\n // reach out to IPython and tell the keyboard manager to turn it's self\n // off when our div gets focus\n\n // location in version 3\n if (IPython.notebook.keyboard_manager) {\n IPython.notebook.keyboard_manager.register_events(el);\n } else {\n // location in version 2\n IPython.keyboard_manager.register_events(el);\n }\n};\n\nmpl.figure.prototype._key_event_extra = function (event, _name) {\n // Check for shift+enter\n if (event.shiftKey && event.which === 13) {\n this.canvas_div.blur();\n // select the cell after this one\n var index = IPython.notebook.find_cell_index(this.cell_info[0]);\n IPython.notebook.select(index + 1);\n }\n};\n\nmpl.figure.prototype.handle_save = function (fig, _msg) {\n fig.ondownload(fig, null);\n};\n\nmpl.find_output_cell = function (html_output) {\n // Return the cell and output element which can be found *uniquely* in the notebook.\n // Note - this is a bit hacky, but it is done because the \"notebook_saving.Notebook\"\n // IPython event is triggered only after the cells have been serialised, which for\n // our purposes (turning an active figure into a static one), is too late.\n var cells = IPython.notebook.get_cells();\n var ncells = cells.length;\n for (var i = 0; i < ncells; i++) {\n var cell = cells[i];\n if (cell.cell_type === 'code') {\n for (var j = 0; j < cell.output_area.outputs.length; j++) {\n var data = cell.output_area.outputs[j];\n if (data.data) {\n // IPython >= 3 moved mimebundle to data 
attribute of output\n data = data.data;\n }\n if (data['text/html'] === html_output) {\n return [cell, data, j];\n }\n }\n }\n }\n};\n\n// Register the function which deals with the matplotlib target/channel.\n// The kernel may be null if the page has been refreshed.\nif (IPython.notebook.kernel !== null) {\n IPython.notebook.kernel.comm_manager.register_target(\n 'matplotlib',\n mpl.mpl_figure_comm\n );\n}\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Prediction\n", "y = model(x, t)\n", "y_np = y.reshape([100,-1]).to(\"cpu\").detach().numpy()\n", "\n", "# Plot\n", "X, Y = np.meshgrid(np.linspace(0, 1, 150), np.linspace(0, 1, 100))\n", "fig, ax = plt.subplots(subplot_kw={\"projection\": \"3d\"})\n", "ax.plot_surface(X, Y, y_np, linewidth=0, antialiased=False, cmap=cm.coolwarm,)\n", "ax.set_xlabel(\"t\"), ax.set_ylabel(\"x\"), ax.set_zlabel(\"f\")\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "bd07f2dd", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## Conclusions\n", "\n", "**Growth rate**\n", "\n", "- We saw the difference between:\n", " - Fitting a NN only on observations.\n", " - Adding a regularization term from a 1st order PDE." ] }, { "cell_type": "markdown", "id": "132a6a45", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "**1d wave**\n", "\n", "- We saw how to include:\n", " - A 2nd order PDE.\n", " - Multiple constraints on the initial conditions." 
] }, { "cell_type": "markdown", "id": "c0e47416", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "**Next steps**\n", "\n", "- With more complex equations, training does not converge as easily.\n", "- For time-dependent problems, several useful tricks have been devised in recent years, such as:\n", " - Decomposing the solution domain into different parts, each solved by a separate neural network.\n", " - Smart weighting of the different loss contributions to avoid converging to trivial solutions." ] }, { "cell_type": "markdown", "id": "49a91deb", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## References\n", "\n", "[[1](https://www.sciencedirect.com/science/article/pii/S0021999118307125)] Raissi, Maziar, Paris Perdikaris, and George E. Karniadakis. \"Physics-informed neural networks: A deep learning framework for solving forward and inverse problems involving nonlinear partial differential equations.\" Journal of Computational Physics 378 (2019): 686-707.\n", "\n", "[[2](https://maziarraissi.github.io/PINNs/)] Raissi, Maziar, Paris Perdikaris, and George E. Karniadakis. \"Physics Informed Deep Learning\".\n", "\n", "[[3](https://www.sciencedirect.com/science/article/pii/S095219762030292X)] Nascimento, R. G., K. Fricke, and F. A. Viana. \"A tutorial on solving ordinary differential equations using Python and hybrid physics-informed neural network.\" Engineering Applications of Artificial Intelligence 96 (2020): 103996.\n", "\n", "[[4](https://towardsdatascience.com/solving-differential-equations-with-neural-networks-afdcf7b8bcc4)] Dagrada, Dario. \"Introduction to Physics-informed Neural Networks\" ([code](https://github.com/madagra/basic-pinn)).\n", "\n", "[[5](https://towardsdatascience.com/physics-and-artificial-intelligence-introduction-to-physics-informed-neural-networks-24548438f2d5)] Paialunga, Piero.
\"Physics and Artificial Intelligence: Introduction to Physics Informed Neural Networks\".\n", "\n", "[[6](https://github.com/omniscientoctopus/Physics-Informed-Neural-Networks)] \"Physics-Informed-Neural-Networks (PINNs)\" - implementation of PINNs in TensorFlow 2 and PyTorch for the Burgers' and Helmholtz PDEs." ] } ], "metadata": { "celltoolbar": "Slideshow", "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.8" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": false, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false }, "varInspector": { "cols": { "lenName": 16, "lenType": 16, "lenVar": 40 }, "kernels_config": { "python": { "delete_cmd_postfix": "", "delete_cmd_prefix": "del ", "library": "var_list.py", "varRefreshCmd": "print(var_dic_list())" }, "r": { "delete_cmd_postfix": ") ", "delete_cmd_prefix": "rm(", "library": "var_list.r", "varRefreshCmd": "cat(var_dic_list()) " } }, "types_to_exclude": [ "module", "function", "builtin_function_or_method", "instance", "_Feature" ], "window_display": false } }, "nbformat": 4, "nbformat_minor": 5 }