{ "cells": [ { "cell_type": "markdown", "id": "2522791a", "metadata": {}, "source": [ "Non-Negative Basis Pursuit DeNoising (ADMM)\n", "===========================================\n", "\n", "This example demonstrates the solution of a non-negative sparse coding\n", "problem\n", "\n", " $$\\mathrm{argmin}_{\\mathbf{x}} \\; (1/2) \\| \\mathbf{y} - D \\mathbf{x} \\|_2^2\n", " + \\lambda \\| \\mathbf{x} \\|_1 + I(\\mathbf{x} \\geq 0) \\;,$$\n", "\n", "where $D$ is the dictionary, $\\mathbf{y}$ is the signal to be represented,\n", "$\\mathbf{x}$ is the sparse representation, and $I(\\mathbf{x} \\geq 0)$\n", "is the non-negative indicator." ] }, { "cell_type": "code", "execution_count": 1, "id": "7b28a87a", "metadata": { "collapsed": false, "execution": { "iopub.execute_input": "2021-12-20T19:54:14.459537Z", "iopub.status.busy": "2021-12-20T19:54:14.458987Z", "iopub.status.idle": "2021-12-20T19:54:16.972003Z", "shell.execute_reply": "2021-12-20T19:54:16.972517Z" } }, "outputs": [], "source": [ "# This scico project Jupyter notebook has been automatically modified\n", "# to install the dependencies required for running it on Google Colab.\n", "# If you encounter any problems in running it, please open an issue at\n", "# https://github.com/lanl/scico-data/issues\n", "\n", "%pip install 'scico[examples] @ git+https://github.com/lanl/scico'\n", "\n", "import numpy as np\n", "\n", "import jax\n", "\n", "from scico import functional, linop, loss, plot\n", "from scico.optimize.admm import ADMM, MatrixSubproblemSolver\n", "from scico.util import device_info\n",
"plot.config_notebook_plotting()" ] }, { "cell_type": "markdown", "id": "a5b824f6", "metadata": {}, "source": [ "Create a random dictionary, a reference random sparse representation, and\n", "a test signal consisting of the synthesis of the reference sparse\n", "representation." ] }, { "cell_type": "code", "execution_count": 2, "id": "0b2a4327", "metadata": { "collapsed": false, "execution": { "iopub.execute_input": "2021-12-20T19:54:16.978559Z", "iopub.status.busy": "2021-12-20T19:54:16.978115Z", "iopub.status.idle": "2021-12-20T19:54:18.254884Z", "shell.execute_reply": "2021-12-20T19:54:18.253629Z" } }, "outputs": [], "source": [ "m = 32 # signal size\n", "n = 128 # dictionary size\n", "s = 10 # sparsity level\n", "\n", "np.random.seed(1)\n", "D = np.random.randn(m, n)\n", "D = D / np.linalg.norm(D, axis=0, keepdims=True) # normalize dictionary\n", "\n", "xt = np.zeros(n) # true signal\n", "# note: randint samples with replacement, so repeated indices can leave\n", "# xt with fewer than s non-zero entries\n", "idx = np.random.randint(low=0, high=n, size=s) # support of xt\n", "xt[idx] = np.random.rand(s)\n", "y = D @ xt + 5e-2 * np.random.randn(m) # synthetic signal\n", "\n", "xt = jax.device_put(xt) # convert to jax array, push to default device (e.g. GPU)\n", "y = jax.device_put(y) # convert to jax array, push to default device (e.g. GPU)" ] }, { "cell_type": "markdown", "id": "6b93e839", "metadata": {}, "source": [ "Set up the forward operator and ADMM solver object."
] }, { "cell_type": "code", "execution_count": 3, "id": "49d9f0b8", "metadata": { "collapsed": false, "execution": { "iopub.execute_input": "2021-12-20T19:54:18.258831Z", "iopub.status.busy": "2021-12-20T19:54:18.258087Z", "iopub.status.idle": "2021-12-20T19:54:18.826467Z", "shell.execute_reply": "2021-12-20T19:54:18.827280Z" } }, "outputs": [], "source": [ "lmbda = 1e-1 # weight of the l1 penalty (lambda in the problem statement)\n", "A = linop.MatrixOperator(D)\n", "f = loss.SquaredL2Loss(y=y, A=A)\n", "g_list = [lmbda * functional.L1Norm(), functional.NonNegativeIndicator()]\n", "# (n,) is a 1-tuple shape; (n) would just be the int n\n", "C_list = [linop.Identity((n,)), linop.Identity((n,))]\n", "rho_list = [1.0, 1.0] # ADMM penalty parameter for each (g, C) pair\n", "maxiter = 100 # number of ADMM iterations\n", "\n", "solver = ADMM(\n", " f=f,\n", " g_list=g_list,\n", " C_list=C_list,\n", " rho_list=rho_list,\n", " x0=A.adj(y), # initial estimate: adjoint of A applied to y\n", " maxiter=maxiter,\n", " subproblem_solver=MatrixSubproblemSolver(),\n", " itstat_options={\"display\": True, \"period\": 10},\n", ")" ] }, { "cell_type": "markdown", "id": "27d797b8", "metadata": {}, "source": [ "Run the solver."
] }, { "cell_type": "code", "execution_count": 4, "id": "213cbf3f", "metadata": { "collapsed": false, "execution": { "iopub.execute_input": "2021-12-20T19:54:18.831329Z", "iopub.status.busy": "2021-12-20T19:54:18.829918Z", "iopub.status.idle": "2021-12-20T19:54:23.475714Z", "shell.execute_reply": "2021-12-20T19:54:23.475162Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Solving on GPU (NVIDIA GeForce RTX 2080 Ti)\n", "\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Iter Time Objective Prml Rsdl Dual Rsdl CG It CG Res \n", "-----------------------------------------------------------------\n", " 0 1.47e+00 2.810e+00 1.435e+00 4.750e+00 7 5.959e-05\n", " 1 2.21e+00 6.810e-01 4.430e-01 8.297e-01 9 8.225e-05\r", " 2 2.24e+00 6.931e-01 3.791e-01 2.579e-01 7 7.293e-05\r", " 3 2.27e+00 6.060e-01 2.025e-01 1.891e-01 7 7.780e-05\r", " 4 2.30e+00 5.527e-01 1.014e-01 1.657e-01 7 2.822e-05\r", " 5 2.33e+00 5.219e-01 7.434e-02 1.383e-01 6 7.671e-05\r", " 6 2.36e+00 5.079e-01 7.709e-02 1.088e-01 6 5.791e-05\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ " 7 2.39e+00 4.989e-01 5.198e-02 9.590e-02 6 5.759e-05\r", " 8 2.42e+00 4.928e-01 4.370e-02 8.704e-02 6 4.320e-05\r", " 9 2.45e+00 4.904e-01 5.082e-02 7.129e-02 6 3.517e-05\r", " 10 2.48e+00 4.879e-01 3.590e-02 6.160e-02 6 3.352e-05\n", " 11 2.51e+00 4.858e-01 3.085e-02 5.368e-02 5 9.916e-05\r", " 12 2.53e+00 4.847e-01 3.253e-02 4.530e-02 5 7.902e-05\r", " 13 2.55e+00 4.829e-01 2.272e-02 3.870e-02 6 2.912e-05\r", " 14 2.58e+00 4.811e-01 1.595e-02 3.524e-02 5 7.934e-05\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ " 15 2.61e+00 4.797e-01 1.319e-02 3.197e-02 5 6.346e-05\r", " 16 2.64e+00 4.785e-01 1.168e-02 2.921e-02 5 5.498e-05\r", " 17 2.67e+00 4.775e-01 1.092e-02 2.672e-02 5 5.073e-05\r", " 18 2.69e+00 4.765e-01 1.012e-02 2.467e-02 5 4.331e-05\r", " 19 2.72e+00 4.757e-01 9.883e-03 2.269e-02 5 4.049e-05\r", " 20 2.75e+00 4.753e-01 1.017e-02 2.036e-02 5 
3.832e-05\n", " 21 2.78e+00 4.750e-01 9.117e-03 1.848e-02 5 3.448e-05\r", " 22 2.80e+00 4.749e-01 8.487e-03 1.640e-02 5 3.111e-05\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ " 23 2.82e+00 4.750e-01 8.096e-03 1.461e-02 4 9.554e-05\r", " 24 2.85e+00 4.748e-01 6.643e-03 1.307e-02 4 9.588e-05\r", " 25 2.88e+00 4.746e-01 5.796e-03 1.176e-02 4 8.330e-05\r", " 26 2.90e+00 4.744e-01 5.057e-03 1.065e-02 4 7.405e-05\r", " 27 2.93e+00 4.742e-01 4.410e-03 9.722e-03 4 6.696e-05\r", " 28 2.96e+00 4.740e-01 3.857e-03 8.936e-03 4 5.982e-05\r", " 29 2.99e+00 4.738e-01 3.391e-03 8.266e-03 4 5.352e-05\r", " 30 3.02e+00 4.737e-01 3.005e-03 7.686e-03 4 4.805e-05\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " 31 3.04e+00 4.736e-01 2.688e-03 7.175e-03 4 4.409e-05\r", " 32 3.07e+00 4.735e-01 2.428e-03 6.718e-03 4 4.162e-05\r", " 33 3.10e+00 4.734e-01 2.215e-03 6.303e-03 4 3.903e-05\r", " 34 3.13e+00 4.733e-01 2.038e-03 5.922e-03 4 3.636e-05\r", " 35 3.15e+00 4.733e-01 1.865e-03 5.554e-03 3 9.948e-05\r", " 36 3.18e+00 4.732e-01 1.754e-03 5.243e-03 3 6.122e-05\r", " 37 3.20e+00 4.732e-01 1.630e-03 4.929e-03 3 8.706e-05\r", " 38 3.23e+00 4.731e-01 1.540e-03 4.651e-03 3 5.910e-05\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ " 39 3.26e+00 4.731e-01 1.435e-03 4.374e-03 3 7.613e-05\r", " 40 3.28e+00 4.732e-01 1.731e-03 3.707e-03 3 5.594e-05\n", " 41 3.31e+00 4.732e-01 1.758e-03 3.223e-03 3 6.391e-05\r", " 42 3.33e+00 4.732e-01 1.495e-03 2.781e-03 3 7.975e-05\r", " 43 3.36e+00 4.732e-01 1.411e-03 2.401e-03 3 6.982e-05\r", " 44 3.38e+00 4.732e-01 1.296e-03 2.075e-03 3 6.177e-05\r", " 45 3.41e+00 4.732e-01 1.166e-03 1.800e-03 3 5.345e-05\r", " 46 3.43e+00 4.731e-01 1.036e-03 1.566e-03 3 4.550e-05\r", " 47 3.45e+00 4.731e-01 9.215e-04 1.343e-03 3 3.840e-05\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ " 48 3.48e+00 4.731e-01 7.847e-04 1.084e-03 2 9.732e-05\r", " 49 3.50e+00 4.731e-01 6.643e-04 9.493e-04 2 9.293e-05\r", " 50 
3.53e+00 4.731e-01 5.812e-04 8.044e-04 2 7.877e-05\n", " 51 3.55e+00 4.731e-01 5.317e-04 7.038e-04 2 6.443e-05\r", " 52 3.57e+00 4.730e-01 4.626e-04 5.887e-04 2 5.985e-05\r", " 53 3.60e+00 4.730e-01 4.065e-04 5.070e-04 2 5.060e-05\r", " 54 3.62e+00 4.730e-01 3.520e-04 4.284e-04 2 4.389e-05\r", " 55 3.64e+00 4.730e-01 2.558e-04 1.825e-04 1 8.865e-05\r", " 56 3.66e+00 4.730e-01 2.483e-04 1.764e-04 1 8.417e-05\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ " 57 3.68e+00 4.730e-01 1.851e-04 1.027e-04 1 6.973e-05\r", " 58 3.70e+00 4.730e-01 1.851e-04 0.000e+00 0 9.308e-05\r", " 59 3.73e+00 4.730e-01 9.734e-05 1.078e-04 1 5.527e-05\r", " 60 3.75e+00 4.730e-01 9.734e-05 0.000e+00 0 7.698e-05\n", " 61 3.77e+00 4.730e-01 9.734e-05 0.000e+00 0 8.566e-05\r", " 62 3.79e+00 4.730e-01 9.734e-05 0.000e+00 0 9.952e-05\r", " 63 3.82e+00 4.730e-01 3.825e-05 6.954e-05 1 4.892e-05\r", " 64 3.84e+00 4.730e-01 3.825e-05 0.000e+00 0 6.494e-05\r", " 65 3.87e+00 4.730e-01 3.825e-05 0.000e+00 0 6.899e-05\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ " 66 3.89e+00 4.730e-01 3.825e-05 0.000e+00 0 7.402e-05\r", " 67 3.92e+00 4.730e-01 3.825e-05 0.000e+00 0 7.983e-05\r", " 68 3.95e+00 4.730e-01 3.825e-05 0.000e+00 0 8.627e-05\r", " 69 3.97e+00 4.730e-01 3.825e-05 0.000e+00 0 9.321e-05\r", " 70 4.00e+00 4.730e-01 5.895e-05 6.213e-05 1 4.818e-05\n", " 71 4.02e+00 4.730e-01 5.895e-05 0.000e+00 0 6.711e-05\r", " 72 4.04e+00 4.730e-01 5.895e-05 0.000e+00 0 7.636e-05\r", " 73 4.06e+00 4.730e-01 5.895e-05 0.000e+00 0 8.711e-05\r", " 74 4.08e+00 4.730e-01 5.895e-05 0.000e+00 0 9.886e-05\r", " 75 4.10e+00 4.730e-01 3.793e-05 5.370e-05 1 4.677e-05\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ " 76 4.12e+00 4.730e-01 3.793e-05 0.000e+00 0 5.983e-05\r", " 77 4.14e+00 4.730e-01 3.793e-05 0.000e+00 0 6.447e-05\r", " 78 4.15e+00 4.730e-01 3.793e-05 0.000e+00 0 7.008e-05\r", " 79 4.17e+00 4.730e-01 3.793e-05 0.000e+00 0 7.644e-05\r", " 80 4.19e+00 4.730e-01 
3.793e-05 0.000e+00 0 8.337e-05\n", " 81 4.21e+00 4.730e-01 3.793e-05 0.000e+00 0 9.076e-05\r", " 82 4.24e+00 4.730e-01 3.793e-05 0.000e+00 0 9.849e-05\r", " 83 4.27e+00 4.730e-01 5.727e-05 5.079e-05 1 4.646e-05\r", " 84 4.29e+00 4.730e-01 5.727e-05 0.000e+00 0 6.218e-05\r", " 85 4.31e+00 4.730e-01 5.727e-05 0.000e+00 0 7.058e-05\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ " 86 4.33e+00 4.730e-01 5.727e-05 0.000e+00 0 8.064e-05\r", " 87 4.35e+00 4.730e-01 5.727e-05 0.000e+00 0 9.179e-05\r", " 88 4.38e+00 4.730e-01 2.943e-05 5.005e-05 1 4.083e-05\r", " 89 4.40e+00 4.730e-01 2.943e-05 0.000e+00 0 5.311e-05\r", " 90 4.42e+00 4.730e-01 2.943e-05 0.000e+00 0 5.663e-05\n", " 91 4.44e+00 4.730e-01 2.943e-05 0.000e+00 0 6.083e-05\r", " 92 4.47e+00 4.730e-01 2.943e-05 0.000e+00 0 6.558e-05\r", " 93 4.49e+00 4.730e-01 2.943e-05 0.000e+00 0 7.077e-05\r", " 94 4.51e+00 4.730e-01 2.943e-05 0.000e+00 0 7.630e-05\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ " 95 4.53e+00 4.730e-01 2.943e-05 0.000e+00 0 8.211e-05\r", " 96 4.56e+00 4.730e-01 2.943e-05 0.000e+00 0 8.815e-05\r", " 97 4.58e+00 4.730e-01 2.928e-05 1.322e-06 0 9.436e-05\r", " 98 4.60e+00 4.730e-01 6.007e-05 4.695e-05 1 4.417e-05\r", " 99 4.62e+00 4.730e-01 6.007e-05 0.000e+00 0 5.957e-05\r\n" ] } ], "source": [ "print(f\"Solving on {device_info()}\\n\")\n", "x = solver.solve() # x holds the solution returned by the solver" ] }, { "cell_type": "markdown", "id": "e3f03fd6", "metadata": {}, "source": [ "Plot the recovered coefficients and signal, together with the ground truth." ] }, { "cell_type": "code", "execution_count": 5, "id": "c2e8b2c4", "metadata": { "collapsed": false, "execution": { "iopub.execute_input": "2021-12-20T19:54:23.492471Z", "iopub.status.busy": "2021-12-20T19:54:23.482292Z", "iopub.status.idle": "2021-12-20T19:54:23.952319Z", "shell.execute_reply": "2021-12-20T19:54:23.952764Z" } }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# left: coefficient vectors; right: synthesized signals\n", "fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5))\n", "plot.plot(\n", " np.vstack((xt, x)).T,\n", " title=\"Coefficients\",\n", " lgnd=(\"Ground Truth\", \"Recovered\"),\n", " fig=fig,\n", " ax=ax[0],\n", ")\n", "plot.plot(\n", " np.vstack((D @ xt, y, D @ x)).T,\n", " title=\"Signal\",\n", " lgnd=(\"Ground Truth\", \"Noisy\", \"Recovered\"),\n", " fig=fig,\n", " ax=ax[1],\n", ")\n", "fig.show()" ] } ], "metadata": { "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.12" } }, "nbformat": 4, "nbformat_minor": 5 }