{ "cells": [ { "cell_type": "markdown", "id": "9a515258", "metadata": {}, "source": [ "Basis Pursuit DeNoising (APGM)\n", "==============================\n", "\n", "This example demonstrates the solution of the the sparse coding problem\n", "\n", " $$\\mathrm{argmin}_{\\mathbf{x}} \\; (1/2) \\| \\mathbf{y} - D \\mathbf{x}\n", " \\|_2^2 + \\lambda \\| \\mathbf{x} \\|_1\\;,$$\n", "\n", "where $D$ the dictionary, $\\mathbf{y}$ the signal to be represented,\n", "and $\\mathbf{x}$ is the sparse representation." ] }, { "cell_type": "code", "execution_count": 1, "id": "46e9afac", "metadata": { "collapsed": false, "execution": { "iopub.execute_input": "2021-12-16T21:53:44.804160Z", "iopub.status.busy": "2021-12-16T21:53:44.803272Z", "iopub.status.idle": "2021-12-16T21:53:47.399774Z", "shell.execute_reply": "2021-12-16T21:53:47.398649Z" } }, "outputs": [], "source": [ "# This scico project Jupyter notebook has been automatically modified\n", "# to install the dependencies required for running it on Google Colab.\n", "# If you encounter any problems in running it, please open an issue at\n", "# https://github.com/lanl/scico-data/issues\n", "\n", "!pip install 'scico[examples] @ git+https://github.com/lanl/scico'\n", "\n", "import numpy as np\n", "\n", "import scico.numpy as snp\n", "from scico import functional, linop, loss, plot\n", "from scico.optimize.pgm import AcceleratedPGM\n", "from scico.util import device_info\n", "plot.config_notebook_plotting()" ] }, { "cell_type": "markdown", "id": "d4811c44", "metadata": {}, "source": [ "Construct a random dictionary, a reference random sparse\n", "representation, and a test signal consisting of the synthesis of the\n", "reference sparse representation." 
] }, { "cell_type": "code", "execution_count": 2, "id": "09d25dd4", "metadata": { "collapsed": false, "execution": { "iopub.execute_input": "2021-12-16T21:53:47.407013Z", "iopub.status.busy": "2021-12-16T21:53:47.406537Z", "iopub.status.idle": "2021-12-16T21:53:49.436430Z", "shell.execute_reply": "2021-12-16T21:53:49.435063Z" } }, "outputs": [], "source": [ "m = 512 # Signal size\n", "n = 4 * m # Dictionary size\n", "s = 32 # Sparsity level (number of non-zeros)\n", "σ = 0.5 # Noise level\n", "\n", "np.random.seed(12345)\n", "D = np.random.randn(m, n).astype(np.float32)\n", "# Lipschitz constant of the gradient of (1/2) ||y - D x||_2^2, i.e. the\n", "# squared spectral norm (largest singular value) of D.\n", "L0 = np.linalg.norm(D, 2) ** 2\n", "\n", "x_gt = np.zeros(n, dtype=np.float32) # true signal\n", "# Choose s distinct support indices uniformly from all n atoms (0, ..., n - 1).\n", "idx = np.random.permutation(n)\n", "x_gt[idx[0:s]] = np.random.randn(s)\n", "# Cast to float32 so y has the same dtype as D and x_gt.\n", "y = (D @ x_gt + σ * np.random.randn(m)).astype(np.float32) # synthetic signal\n", "\n", "x_gt = snp.array(x_gt) # convert to jax array\n", "y = snp.array(y) # convert to jax array" ] }, { "cell_type": "markdown", "id": "3ca055bd", "metadata": {}, "source": [ "Set up the forward operator and `AcceleratedPGM` solver object." ] }, { "cell_type": "code", "execution_count": 3, "id": "211e1c3a", "metadata": { "collapsed": false, "execution": { "iopub.execute_input": "2021-12-16T21:53:49.443362Z", "iopub.status.busy": "2021-12-16T21:53:49.442761Z", "iopub.status.idle": "2021-12-16T21:53:49.900380Z", "shell.execute_reply": "2021-12-16T21:53:49.901258Z" } }, "outputs": [], "source": [ "maxiter = 100 # maximum number of solver iterations\n", "λ = 2.98e1 # regularization parameter (weight of the l1 term)\n", "A = linop.MatrixOperator(D)\n", "f = loss.SquaredL2Loss(y=y, A=A)\n", "g = λ * functional.L1Norm()\n", "# Initialize the solver with the adjoint of the forward operator applied to y.\n", "solver = AcceleratedPGM(\n", " f=f, g=g, L0=L0, x0=A.adj(y), maxiter=maxiter, itstat_options={\"display\": True, \"period\": 10}\n", ")" ] }, { "cell_type": "markdown", "id": "493ca9ed", "metadata": {}, "source": [ "Run the solver."
] }, { "cell_type": "code", "execution_count": 4, "id": "30c49e8b", "metadata": { "collapsed": false, "execution": { "iopub.execute_input": "2021-12-16T21:53:49.957466Z", "iopub.status.busy": "2021-12-16T21:53:49.947684Z", "iopub.status.idle": "2021-12-16T21:53:51.931612Z", "shell.execute_reply": "2021-12-16T21:53:51.931975Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Solving on GPU (NVIDIA GeForce RTX 2080 Ti)\n", "\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Iter Time Objective L Residual\n", "---------------------------------------------\n", " 0 9.71e-01 7.795e+09 4.611e+03 4.126e+03\n", " 1 1.80e+00 2.097e+09 4.611e+03 1.310e+03\r", " 2 1.80e+00 5.399e+08 4.611e+03 4.524e+02\r", " 3 1.81e+00 1.406e+08 4.611e+03 1.845e+02\r", " 4 1.81e+00 4.021e+07 4.611e+03 1.002e+02\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ " 5 1.81e+00 1.382e+07 4.611e+03 6.608e+01\r", " 6 1.81e+00 5.955e+06 4.611e+03 4.542e+01\r", " 7 1.81e+00 3.351e+06 4.611e+03 3.125e+01\r", " 8 1.82e+00 2.373e+06 4.611e+03 2.261e+01\r", " 9 1.82e+00 1.771e+06 4.611e+03 1.751e+01\r", " 10 1.82e+00 1.211e+06 4.611e+03 1.369e+01\n", " 11 1.82e+00 7.238e+05 4.611e+03 1.023e+01\r", " 12 1.82e+00 3.867e+05 4.611e+03 7.277e+00\r", " 13 1.83e+00 2.045e+05 4.611e+03 5.249e+00\r", " 14 1.83e+00 1.251e+05 4.611e+03 4.212e+00\r", " 15 1.83e+00 9.184e+04 4.611e+03 3.683e+00\r", " 16 1.83e+00 7.175e+04 4.611e+03 3.203e+00\r", " 17 1.83e+00 5.452e+04 4.611e+03 2.665e+00\r", " 18 1.84e+00 4.005e+04 4.611e+03 2.153e+00\r", " 19 1.84e+00 2.904e+04 4.611e+03 1.756e+00\r", " 20 1.84e+00 2.102e+04 4.611e+03 1.476e+00\n", " 21 1.84e+00 1.503e+04 4.611e+03 1.254e+00\r", " 22 1.85e+00 1.064e+04 4.611e+03 1.066e+00\r", " 23 1.85e+00 7.613e+03 4.611e+03 9.034e-01\r", " 24 1.85e+00 5.804e+03 4.611e+03 7.621e-01\r", " 25 1.85e+00 4.800e+03 4.611e+03 6.525e-01\r", " 26 1.86e+00 4.155e+03 4.611e+03 5.744e-01\r", " 27 1.86e+00 3.596e+03 4.611e+03 5.036e-01\r", " 28 
1.86e+00 3.034e+03 4.611e+03 4.274e-01\r", " 29 1.87e+00 2.518e+03 4.611e+03 3.595e-01\r", " 30 1.87e+00 2.126e+03 4.611e+03 3.078e-01\n", " 31 1.87e+00 1.849e+03 4.611e+03 2.687e-01\r", " 32 1.87e+00 1.642e+03 4.611e+03 2.391e-01\r", " 33 1.88e+00 1.466e+03 4.611e+03 2.222e-01\r", " 34 1.88e+00 1.312e+03 4.611e+03 2.004e-01\r", " 35 1.88e+00 1.186e+03 4.611e+03 1.782e-01\r", " 36 1.88e+00 1.083e+03 4.611e+03 1.549e-01\r", " 37 1.88e+00 1.006e+03 4.611e+03 1.361e-01\r", " 38 1.89e+00 9.468e+02 4.611e+03 1.197e-01\r", " 39 1.89e+00 9.051e+02 4.611e+03 1.049e-01\r", " 40 1.89e+00 8.819e+02 4.611e+03 9.747e-02\n", " 41 1.89e+00 8.693e+02 4.611e+03 8.916e-02\r", " 42 1.90e+00 8.638e+02 4.611e+03 8.379e-02\r", " 43 1.90e+00 8.617e+02 4.611e+03 7.405e-02\r", " 44 1.90e+00 8.629e+02 4.611e+03 6.725e-02\r", " 45 1.90e+00 8.617e+02 4.611e+03 5.961e-02\r", " 46 1.90e+00 8.588e+02 4.611e+03 5.390e-02\r", " 47 1.91e+00 8.519e+02 4.611e+03 4.903e-02\r", " 48 1.91e+00 8.442e+02 4.611e+03 4.319e-02\r", " 49 1.91e+00 8.367e+02 4.611e+03 3.636e-02\r", " 50 1.91e+00 8.306e+02 4.611e+03 2.501e-02\n", " 51 1.92e+00 8.270e+02 4.611e+03 2.013e-02\r", " 52 1.92e+00 8.251e+02 4.611e+03 1.787e-02\r", " 53 1.92e+00 8.243e+02 4.611e+03 1.639e-02\r", " 54 1.92e+00 8.244e+02 4.611e+03 1.716e-02\r", " 55 1.93e+00 8.246e+02 4.611e+03 1.604e-02\r", " 56 1.93e+00 8.249e+02 4.611e+03 1.536e-02\r", " 57 1.93e+00 8.249e+02 4.611e+03 1.426e-02\r", " 58 1.93e+00 8.247e+02 4.611e+03 1.538e-02\r", " 59 1.93e+00 8.244e+02 4.611e+03 1.266e-02\r", " 60 1.94e+00 8.240e+02 4.611e+03 1.229e-02\n", " 61 1.94e+00 8.234e+02 4.611e+03 1.125e-02\r", " 62 1.94e+00 8.228e+02 4.611e+03 1.075e-02\r", " 63 1.94e+00 8.221e+02 4.611e+03 9.987e-03\r", " 64 1.95e+00 8.215e+02 4.611e+03 9.180e-03\r", " 65 1.95e+00 8.210e+02 4.611e+03 1.101e-02\r", " 66 1.95e+00 8.206e+02 4.611e+03 7.595e-03\r", " 67 1.95e+00 8.203e+02 4.611e+03 8.156e-03\r", " 68 1.95e+00 8.202e+02 4.611e+03 6.299e-03\r", " 69 1.96e+00 8.201e+02 4.611e+03 
7.833e-03\r", " 70 1.96e+00 8.201e+02 4.611e+03 6.280e-03\n", " 71 1.96e+00 8.201e+02 4.611e+03 7.016e-03\r", " 72 1.96e+00 8.201e+02 4.611e+03 5.411e-03\r", " 73 1.96e+00 8.201e+02 4.611e+03 4.842e-03\r", " 74 1.97e+00 8.200e+02 4.611e+03 4.669e-03\r", " 75 1.97e+00 8.200e+02 4.611e+03 4.320e-03\r", " 76 1.97e+00 8.200e+02 4.611e+03 4.308e-03\r", " 77 1.97e+00 8.199e+02 4.611e+03 3.889e-03\r", " 78 1.98e+00 8.198e+02 4.611e+03 3.467e-03\r", " 79 1.98e+00 8.197e+02 4.611e+03 3.102e-03\r", " 80 1.98e+00 8.197e+02 4.611e+03 2.978e-03\n", " 81 1.98e+00 8.196e+02 4.611e+03 2.504e-03\r", " 82 1.99e+00 8.196e+02 4.611e+03 2.324e-03\r", " 83 1.99e+00 8.196e+02 4.611e+03 2.238e-03\r", " 84 1.99e+00 8.196e+02 4.611e+03 2.207e-03\r", " 85 1.99e+00 8.196e+02 4.611e+03 2.227e-03\r", " 86 1.99e+00 8.196e+02 4.611e+03 2.541e-03\r", " 87 2.00e+00 8.196e+02 4.611e+03 2.320e-03\r", " 88 2.00e+00 8.196e+02 4.611e+03 2.407e-03\r", " 89 2.00e+00 8.196e+02 4.611e+03 2.051e-03\r", " 90 2.00e+00 8.195e+02 4.611e+03 1.815e-03\n", " 91 2.00e+00 8.195e+02 4.611e+03 1.688e-03\r", " 92 2.01e+00 8.195e+02 4.611e+03 1.948e-03\r", " 93 2.01e+00 8.195e+02 4.611e+03 1.494e-03\r", " 94 2.01e+00 8.195e+02 4.611e+03 1.362e-03\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ " 95 2.01e+00 8.195e+02 4.611e+03 1.294e-03\r", " 96 2.02e+00 8.195e+02 4.611e+03 1.317e-03\r", " 97 2.02e+00 8.195e+02 4.611e+03 1.384e-03\r", " 98 2.02e+00 8.195e+02 4.611e+03 1.132e-03\r", " 99 2.02e+00 8.195e+02 4.611e+03 1.059e-03\r\n" ] } ], "source": [ "print(f\"Solving on {device_info()}\\n\")\n", "# Iteration statistics are printed while solving (display=True in itstat_options).\n", "x = solver.solve()\n", "\n", "# Per-iteration statistics; columns are attributes, e.g. hist.Objective, hist.Residual.\n", "itstat = solver.itstat_object\n", "hist = itstat.history(transpose=True)" ] }, { "cell_type": "markdown", "id": "901e3dcd", "metadata": {}, "source": [ "Plot the recovered coefficients and convergence statistics."
] }, { "cell_type": "code", "execution_count": 5, "id": "e942c39d", "metadata": { "collapsed": false, "execution": { "iopub.execute_input": "2021-12-16T21:53:51.955842Z", "iopub.status.busy": "2021-12-16T21:53:51.951901Z", "iopub.status.idle": "2021-12-16T21:53:52.715080Z", "shell.execute_reply": "2021-12-16T21:53:52.716233Z" } }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5))\n", "plot.plot(\n", " np.vstack((x_gt, x)).T,\n", " title=\"Coefficients\",\n", " lgnd=(\"Ground Truth\", \"Recovered\"),\n", " fig=fig,\n", " ax=ax[0],\n", ")\n", "plot.plot(\n", " np.vstack((hist.Objective, hist.Residual)).T,\n", " ptyp=\"semilogy\",\n", " title=\"Convergence\",\n", " xlbl=\"Iteration\",\n", " lgnd=(\"Objective\", \"Residual\"),\n", " fig=fig,\n", " ax=ax[1],\n", ")\n", "fig.show()" ] } ], "metadata": { "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.12" } }, "nbformat": 4, "nbformat_minor": 5 }