{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Conjugate gradient method applied to a 1-D BVP" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Consider solving\n", "\n", "$$\n", "-u_{xx} = f(x), \\qquad x \\in [0,1]\n", "$$\n", "\n", "with boundary conditions\n", "\n", "$$\n", "u(0) = u(1) = 0\n", "$$\n", "\n", "Choose\n", "\n", "$$\n", "f(x) = 1\n", "$$\n", "\n", "The exact solution is\n", "\n", "$$\n", "u(x) = \\frac{1}{2}x(1-x)\n", "$$\n", "\n", "Partition $[0,1]$ into $n$ equal intervals, with spacing and grid points\n", "\n", "$$\n", "h = \\frac{1}{n}, \\qquad x_i = i h, \\qquad i=0,1,\\ldots,n\n", "$$\n", "\n", "The finite difference scheme is\n", "\n", "\\begin{eqnarray}\n", "u_0 &=& 0 \\\\\n", "- \\frac{u_{i-1} - 2 u_i + u_{i+1}}{h^2} &=& f_i, \\qquad i=1,2,\\ldots,n-1 \\\\\n", "u_n &=& 0\n", "\\end{eqnarray}\n", "\n", "This gives the linear system\n", "\n", "$$\n", "Au = f\n", "$$\n", "\n", "Restricted to the interior unknowns $u_1,\\ldots,u_{n-1}$, the matrix $A$ is tridiagonal, symmetric and positive definite, which is the property the conjugate gradient method requires." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Algorithm" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Set initial guess $u_0 = 0$, $r_0 = f - A u_0$, $p_0 = 0$\n", "* For $k=0,1,\\ldots$\n", "    * If $\\| r_k \\| < TOL \\cdot \\|r_0\\|$, then stop\n", "    * If $k=0$, $\\beta_1 = 0$\n", "    * If $k > 0$, $\\beta_{k+1} = \\frac{r_k^\\top r_k}{r_{k-1}^\\top r_{k-1}}$\n", "    * $p_{k+1} = r_k + \\beta_{k+1} p_k$\n", "    * $\\alpha_{k+1} = \\frac{r_k^\\top r_k}{p_{k+1}^\\top A p_{k+1}}$\n", "    * $u_{k+1} = u_k + \\alpha_{k+1} p_{k+1}$\n", "    * $r_{k+1} = r_k - \\alpha_{k+1} A p_{k+1}$\n", " \n", "The algorithm above is written in terms of the indexed sequences $u_k$, $r_k$, $p_k$, which suggests storing every iterate, but an implementation only needs the most recent values. 
We modify it below so that only the quantities needed for the next iteration are stored.\n", " \n", "* Set initial guess $u = 0$\n", "* $r = f - A u$\n", "* $p = 0$\n", "* $r_{norm} = \\| r \\|$, $r_{old} = r_{new} = 0$\n", "* For $k=0,1,\\ldots,\\text{itmax}-1$\n", "    * $r_{new} = \\| r \\|$\n", "    * If $r_{new} < TOL \\cdot r_{norm}$, then stop\n", "    * If $k=0$, $\\beta = 0$\n", "    * If $k > 0$, $\\beta = r_{new}^2 / r_{old}^2$\n", "    * $p = r + \\beta p$\n", "    * $A_p = A p$\n", "    * $\\alpha = r_{new}^2 / (p^\\top A_p)$\n", "    * $u = u + \\alpha p$\n", "    * $r = r - \\alpha A_p$\n", "    * $r_{old} = r_{new}$\n", " \n", "Written this way, each iteration requires one matrix-vector product ($A p$), two dot products (for $\\| r \\|$ and $p^\\top A_p$) and three saxpy operations (the updates of $p$, $u$ and $r$)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Code" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "%config InlineBackend.figure_format = 'svg'\n", "import numpy as np\n", "from matplotlib import pyplot as plt" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The function `ax` computes the matrix-vector product $Au$." ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "# The first and last components of r are always zero;\n", "# this keeps the boundary values of the solution unchanged.\n", "def ax(h,u):\n", "    n = len(u) - 1\n", "    r = np.zeros(n+1)\n", "    for i in range(1,n):\n", "        r[i] = -(u[i-1]-2*u[i]+u[i+1])/h**2\n", "    return r" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We compute this by applying the three-point stencil directly, without forming the matrix $A$. 
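As a sanity check on the matrix-free approach, the sketch below (an addition, not part of the original notebook) assembles the same operator as a dense NumPy array and compares it with `ax`. The helpers `ax_loop` (a copy of `ax`), `ax_vec` (a vectorized slicing variant) and `assemble_matrix`, as well as the small grid size `n_demo = 10`, are hypothetical choices made only for this illustration. It also checks that the interior block of the matrix is symmetric positive definite, the property CG relies on, by comparing its eigenvalues with the known closed form for the second-difference matrix.

```python
import numpy as np

def ax_loop(h, u):
    # Copy of the notebook's ax: loop over interior points, boundary entries stay zero.
    n = len(u) - 1
    r = np.zeros(n + 1)
    for i in range(1, n):
        r[i] = -(u[i-1] - 2*u[i] + u[i+1]) / h**2
    return r

def ax_vec(h, u):
    # Hypothetical vectorized variant: the same three-point stencil via array slicing.
    r = np.zeros_like(u, dtype=float)
    r[1:-1] = -(u[:-2] - 2*u[1:-1] + u[2:]) / h**2
    return r

def assemble_matrix(n, h):
    # Hypothetical helper: dense (n+1) x (n+1) matrix whose rows reproduce ax.
    # Rows 0 and n are left zero, matching the zero boundary entries of ax.
    A = np.zeros((n + 1, n + 1))
    for i in range(1, n):
        A[i, i-1] = -1.0 / h**2
        A[i, i] = 2.0 / h**2
        A[i, i+1] = -1.0 / h**2
    return A

n_demo = 10
h_demo = 1.0 / n_demo
x_demo = np.linspace(0.0, 1.0, n_demo + 1)
v = np.sin(np.pi * x_demo) + x_demo**3   # arbitrary test vector

A = assemble_matrix(n_demo, h_demo)
# Dense product, loop version and vectorized version should all agree.
print(np.allclose(A @ v, ax_loop(h_demo, v)), np.allclose(ax_vec(h_demo, v), ax_loop(h_demo, v)))
# prints: True True

# The interior block A[1:n, 1:n] should be symmetric positive definite.
# Its eigenvalues are (4/h^2) sin^2(k pi h / 2) for k = 1, ..., n-1.
A_int = A[1:-1, 1:-1]
k = np.arange(1, n_demo)
lam_exact = 4.0 / h_demo**2 * np.sin(k * np.pi * h_demo / 2)**2
lam_num = np.linalg.eigvalsh(A_int)   # returned in ascending order
print(np.allclose(A_int, A_int.T), lam_num.min() > 0, np.allclose(lam_num, lam_exact))
# prints: True True True
```

The dense matrix needs O(n^2) storage, while `ax` costs only O(n) work and memory per product, which is why the CG loop below only ever calls `ax`.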
Next, we define the problem parameters" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [], "source": [ "xmin, xmax = 0.0, 1.0\n", "n = 100\n", "\n", "h = (xmax - xmin)/n\n", "x = np.linspace(xmin, xmax, n+1)  # Grid\n", "f = np.ones(n+1)  # rhs\n", "ue = 0.5*x*(1.0-x)  # exact solution" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "and implement the CG method as a separate function." ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [], "source": [ "def solve_cg(res, u):\n", "    # Uses the globals h, TOL and itmax defined in this notebook.\n", "    # res and u are updated in place.\n", "    r_norm = np.linalg.norm(res)\n", "    res_old, res_new, res_data = 0.0, 0.0, []\n", "    p = np.zeros_like(u)\n", "    for it in range(itmax):\n", "        res_new = np.linalg.norm(res); res_data.append(res_new)\n", "        print('iter,res =', it, res_new)\n", "        if res_new < TOL * r_norm:\n", "            break\n", "        if it == 0:\n", "            beta = 0.0\n", "        else:\n", "            beta = res_new**2 / res_old**2\n", "        p = res + beta * p\n", "        ap = ax(h,p)\n", "        alpha = res_new**2 / p.dot(ap)\n", "        u += alpha * p\n", "        res -= alpha * ap\n", "        res_old = res_new\n", "\n", "    print(\"Number of iterations = %d\" % it)\n", "    print(\"Final residual norm = %e\" % np.linalg.norm(res))\n", "    return u, res_data" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "iter,res = 0 9.9498743710662\n", "iter,res = 1 69.29285677470659\n", "iter,res = 2 67.87856804618082\n", "iter,res = 3 66.46427611882942\n", "iter,res = 4 65.04998078400943\n", "iter,res = 5 63.63568181452914\n", "iter,res = 6 62.22137896253984\n", "iter,res = 7 60.80707195713337\n", "iter,res = 8 59.39276050159645\n", "iter,res = 9 57.97844427026304\n", "iter,res = 10 56.56412290489438\n", "iter,res = 11 55.14979601050214\n", "iter,res = 12 53.73546315051168\n", "iter,res = 13 52.32112384114085\n", "iter,res = 14 50.90677754484169\n", "iter,res = 15 49.49242366261726\n", "iter,res = 16 48.078061524982445\n", "iter,res = 17 46.66369038128036\n", "iter,res = 18 45.24930938699504\n", "iter,res = 19 43.83491758860735\n", "iter,res = 20 42.42051390542079\n", "iter,res = 21 41.00609710762534\n", "iter,res = 22 39.59166578965829\n", "iter,res = 23 38.17721833764213\n", "iter,res = 24 36.762752889303584\n", "iter,res = 25 35.34826728426726\n", "iter,res = 26 33.9337590019143\n", "iter,res = 27 32.51922508301819\n", "iter,res = 28 31.104662029991573\n", "iter,res = 29 29.69006567860704\n", "iter,res = 30 28.275431031197378\n", "iter,res = 31 26.86075203712658\n", "iter,res = 32 25.44602129999894\n", "iter,res = 33 24.03122968139583\n", "iter,res = 34 22.616365755797283\n", "iter,res = 35 21.20141504711419\n", "iter,res = 36 19.786358937409382\n", "iter,res = 37 18.371173070873837\n", "iter,res = 38 16.95582495781317\n", "iter,res = 39 15.54027026792005\n", "iter,res = 40 14.124446891825535\n", "iter,res = 41 12.708265027138836\n", "iter,res = 42 11.291589790636213\n", "iter,res = 43 9.874208829065747\n", "iter,res = 44 8.45576726264388\n", "iter,res = 45 7.035623639735143\n", "iter,res = 46 5.612486080160912\n", "iter,res = 47 4.183300132670378\n", "iter,res = 48 2.7386127875258306\n", "iter,res = 49 1.2247448713915892\n", "iter,res = 50 6.587568684622976e-15\n", "Number of iterations = 50\n", "Final residual norm = 6.587569e-15\n" ] } ], "source": [ "TOL = 1.0e-6\n", "itmax = 100\n", "\n", "u = np.zeros(n+1)\n", "res = np.array(f)\n", "\n", "# The solution is fixed to zero at the first and last grid points.\n", "# Setting the residual to zero there means the CG updates never change\n", "# u at these points; since ax also returns zero there, these two\n", "# residual entries remain zero throughout the iterations.\n", "res[0] = 0.0\n", "res[n] = 0.0\n", "\n", "u, res_data = solve_cg(res, u)" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Figure (image not preserved): left panel, exact and numerical u versus x; right panel, residual norm versus number of iterations (semilogy)" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.figure(figsize=(12,5))\n", "plt.subplot(121)\n", "plt.plot(x,ue,label=\"Exact\")\n", "plt.plot(x,u,label=\"Numerical\")\n", "plt.xlabel(\"x\"); plt.ylabel(\"u\")\n", "plt.legend()\n", "\n", "plt.subplot(122)\n", "plt.semilogy(res_data)\n", "plt.xlabel(\"Number of iterations\")\n", "plt.ylabel(\"Residual norm\");" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.8" } }, "nbformat": 4, "nbformat_minor": 4 }