{ "cells": [ { "cell_type": "markdown", "source": [ "# Compressed Sensing 2D\n", "\n", "This example illustrates how to perform\n", "2D compressed sensing image reconstruction\n", "from Cartesian sampled MRI data\n", "with 1-norm regularization of orthogonal wavelet coefficients,\n", "using the Julia language." ], "metadata": {} }, { "cell_type": "markdown", "source": [ "This demo is somewhat similar to Fig. 3 in the survey paper\n", "\"[Optimization methods for MR image reconstruction](https://doi.org/10.1109/MSP.2019.2943645),\"\n", "in the Jan. 2020 issue of IEEE Signal Processing Magazine,\n", "except that the sampling here is 1D phase encoding instead of 2D." ], "metadata": {} }, { "cell_type": "markdown", "source": [ "Packages used in this demo (run `Pkg.add` as needed):" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "using ImagePhantoms: shepp_logan, SheppLoganEmis, spectrum, phantom\n", "using MIRT: embed, Afft, Aodwt\n", "using MIRTjim: jim, prompt\n", "using MIRT: pogm_restart\n", "using LinearAlgebra: norm\n", "using Plots; default(markerstrokecolor=:auto, label=\"\")\n", "using FFTW: fft\n", "using Random: seed!\n", "using InteractiveUtils: versioninfo" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "The following line is helpful when running this jl-file as a script;\n", "it makes the script prompt the user to hit a key after each image is displayed."
], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "isinteractive() && jim(:prompt, true);" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "## Create (synthetic) data" ], "metadata": {} }, { "cell_type": "markdown", "source": [ "Shepp-Logan phantom (unrealistic because real-valued):" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "nx,ny = 192,256\n", "object = shepp_logan(SheppLoganEmis(); fovs=(ny,ny))\n", "Xtrue = phantom(-(nx÷2):(nx÷2-1), -(ny÷2):(ny÷2-1), object, 2)\n", "Xtrue = reverse(Xtrue, dims=2)\n", "clim = (0,9)\n", "jim(Xtrue, \"true image\"; clim)" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "Pseudo-random 1D phase-encode sampling, with fully sampled central k-space lines:" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "seed!(0); sampfrac = 0.2; samp = rand(ny) .< sampfrac; sig = 1\n", "mod2 = (N) -> mod.((0:N-1) .+ Int(N/2), N) .- Int(N/2) # signed DFT frequency index\n", "samp .|= (abs.(mod2(ny)) .< Int(ny/8)) # fully sampled center rows\n", "samp = trues(nx) * samp' # same phase-encode pattern for every kx\n", "jim(samp, fft0=true, title=\"k-space sampling ($(round(100count(samp)/(nx*ny), digits=1))%)\")" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "Generate noisy, under-sampled k-space data (inverse-crime simulation):" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "ytrue = fft(Xtrue)[samp]\n", "y = ytrue + sig * √(2) * randn(ComplexF32, size(ytrue)) # complex Gaussian noise; real and imaginary parts each have std sig\n", "y = ComplexF32.(y) # save memory\n", "ysnr = 20 * log10(norm(ytrue) / norm(y-ytrue)) # data SNR in dB" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "Display zero-filled data:" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "logger = (x; min=-6) -> log10.(max.(abs.(x) / maximum(abs, x), (10.)^min)) # log magnitude relative to the maximum, floored at 10^min\n", "jim(:abswarn, false) # suppress warnings about showing magnitude\n", "jim(logger(embed(ytrue,samp)), fft0=true, title=\"k-space |data| (zero-filled)\",\n", " xlabel=\"kx\", ylabel=\"ky\")" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "## Prepare to reconstruct\n", "\n", "Create a system matrix (encoding matrix) and an initial image.\n", "The system matrix is a `LinearMapAO` operator (from LinearMapsAA.jl),\n", "akin to a `fatrix` in Matlab MIRT." ], "metadata": {} }, { "cell_type": "markdown", "source": [ "System model (\"encoding matrix\") from MIRT:" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "F = Afft(samp) # operator!" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "Initial image based on zero-filled reconstruction\n", "(the adjoint `F'` of the unnormalized DFT is scaled by `1/(nx*ny)`):" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "nrmse = (x) -> round(norm(x - Xtrue) / norm(Xtrue) * 100, digits=1) # NRMSE in percent\n", "X0 = 1.0f0/(nx*ny) * (F' * y)\n", "jim(X0, \"|X0|: initial image; NRMSE $(nrmse(X0))%\")" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "## Wavelet sparsity in synthesis form\n", "\n", "The image reconstruction optimization problem here is\n", "$$\n", "\\arg \\min_{x}\n", "\\frac{1}{2} \\| A x - y \\|_2^2 + \\beta \\; \\| W x \\|_1\n", "$$\n", "where\n", "$y$ is the k-space data,\n", "$A$ is the system model (simply Fourier encoding `F` here),\n", "$W$ is an orthogonal discrete (Haar) wavelet transform,\n", "again implemented as a `LinearMapAO` operator.\n", "Because $W$ is unitary,\n", "we make the change of variables\n", "$z = W x$\n", "and solve for $z$\n", "and then let $x = W' z$\n", "at the end.\n", "In fact we use a weighted 1-norm,\n", "where only the detail wavelet coefficients are regularized,\n", "not the approximation coefficients."
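, "\n", "\n", "For reference, here is the problem actually solved below, restated in terms of $z$ (a sketch that follows from the definitions above).\n", "Let $d_j = 1$ for detail coefficients and $d_j = 0$ for approximation coefficients (the `isdetail` mask in the code),\n", "and let $\\beta$ correspond to `regz`. Then\n", "$$\n", "\\hat{z} = \\arg \\min_{z}\n", "\\frac{1}{2} \\| A W' z - y \\|_2^2 + \\beta \\sum_j d_j \\, |z_j|,\n", "\\qquad \\hat{x} = W' \\hat{z}.\n", "$$\n", "The proximal operator of $c$ times this regularizer is elementwise (complex) soft thresholding.\n", "This is what `g_prox(z, c)` computes, with threshold `regz * c` applied to the detail coefficients only:\n", "$$\n", "\\left[ \\operatorname{prox}_{c \\beta \\sum_j d_j | \\cdot |}(z) \\right]_j\n", "= \\frac{z_j}{|z_j|} \\, \\max\\left( |z_j| - c \\beta d_j, \\, 0 \\right),\n", "\\qquad z_j \\neq 0,\n", "$$\n", "and the output is $0$ when $z_j = 0$."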
], "metadata": {} }, { "cell_type": "markdown", "source": [ "Orthogonal discrete wavelet transform operator (`LinearMapAO`):" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "W, scales, _ = Aodwt((nx,ny) ; T = eltype(F))\n", "isdetail = scales .> 0\n", "jim(\n", " jim(scales, \"wavelet scales\"),\n", " jim(real(W * Xtrue) .* isdetail, \"wavelet detail coefficients\"),\n", ")" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "Inputs needed for proximal gradient methods:" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "Az = F * W' # another operator!\n", "Fnullz = (z) -> 0 # cost function in `z` not needed\n", "f_gradz = (z) -> Az' * (Az * z - y)\n", "f_Lz = nx*ny # Lipschitz constant of f_gradz: single-coil Cartesian unnormalized DFT, unitary W\n", "regz = 0.03 * nx * ny # β: oracle value tuned using the Xtrue wavelet coefficients!\n", "costz = (z) -> 1/2 * norm(Az * z - y)^2 + regz * norm(isdetail .* z, 1) # weighted 1-norm (details only), matching g_prox\n", "soft = (z,c) -> sign(z) * max(abs(z) - c, 0) # soft thresholding\n", "g_prox = (z,c) -> soft.(z, isdetail .* (regz * c)) # proximal operator (shrink details only)\n", "z0 = W * X0\n", "jim(z0, \"Initial wavelet coefficients\")" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "## Iterate\n", "\n", "Run ISTA (=PGM), FISTA (=FPGM), and POGM, the latter two with adaptive restart.\n", "See [Kim & Fessler, 2018](https://doi.org/10.1007/s10957-018-1287-4)\n", "for adaptive restart algorithm details."
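, "\n", "\n", "For reference, the textbook forms of the ISTA and FISTA updates for this problem are sketched below.\n", "They use $f(z) = \\frac{1}{2} \\| A W' z - y \\|_2^2$ and $g(z) = \\beta \\sum_j d_j |z_j|$, where $d_j = 1$ only for detail coefficients.\n", "`pogm_restart` adds restart logic and the POGM momentum terms, and its internal bookkeeping may differ from these formulas.\n", "ISTA (PGM) alternates a gradient step and the proximal (soft-thresholding) step:\n", "$$\n", "z_{k+1} = \\operatorname{prox}_{g / L}\\left( z_k - \\frac{1}{L} \\nabla f(z_k) \\right),\n", "\\qquad \\nabla f(z) = (A W')' (A W' z - y).\n", "$$\n", "Here $L = N_x N_y$ (`f_Lz = nx*ny`). $A$ keeps a subset of the rows of the unnormalized 2D DFT, whose Gram matrix is $N_x N_y I$,\n", "and $W$ is unitary, so $\\| A W' \\|_2^2 = N_x N_y$.\n", "FISTA (FPGM) applies the same step at an extrapolated point $\\bar{z}_k$, starting from $\\bar{z}_0 = z_0$ and $t_0 = 1$:\n", "$$\n", "z_{k+1} = \\operatorname{prox}_{g / L}\\left( \\bar{z}_k - \\frac{1}{L} \\nabla f(\\bar{z}_k) \\right), \\quad\n", "t_{k+1} = \\frac{1 + \\sqrt{1 + 4 t_k^2}}{2}, \\quad\n", "\\bar{z}_{k+1} = z_{k+1} + \\frac{t_k - 1}{t_{k+1}} \\left( z_{k+1} - z_k \\right).\n", "$$"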
], "metadata": {} }, { "cell_type": "markdown", "source": [ "Functions for tracking progress:" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "function fun_ista(iter, xk_z, yk, is_restart)\n", " xh = W' * xk_z # image-domain estimate\n", " return (costz(xk_z), nrmse(xh), is_restart)\n", "end\n", "\n", "function fun_fista(iter, xk, yk_z, is_restart)\n", " xh = W' * yk_z # image-domain estimate\n", " return (costz(yk_z), nrmse(xh), is_restart)\n", "end;" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "Run and compare three proximal gradient methods:" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "niter = 50\n", "z_ista, out_ista = pogm_restart(z0, Fnullz, f_gradz, f_Lz; mom=:pgm, niter=niter,\n", " restart=:none, restart_cutoff=0., g_prox=g_prox, fun=fun_ista)\n", "Xista = W'*z_ista\n", "@show nrmse(Xista)\n", "\n", "z_fista, out_fista = pogm_restart(z0, Fnullz, f_gradz, f_Lz; mom=:fpgm, niter=niter,\n", " restart=:gr, restart_cutoff=0., g_prox=g_prox, fun=fun_fista)\n", "Xfista = W'*z_fista\n", "@show nrmse(Xfista)\n", "\n", "z_pogm, out_pogm = pogm_restart(z0, Fnullz, f_gradz, f_Lz; mom=:pogm, niter=niter,\n", " restart=:gr, restart_cutoff=0., g_prox=g_prox, fun=fun_fista)\n", "Xpogm = W'*z_pogm\n", "@show nrmse(Xpogm)\n", "\n", "jim(\n", " jim(Xfista, \"FISTA/FPGM\"),\n", " jim(Xpogm, \"POGM with ODWT\"),\n", ")" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "## POGM is fastest" ], "metadata": {} }, { "cell_type": "markdown", "source": [ "Plot cost function (minus the smallest cost reached by any method) vs iteration:" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "cost_ista = [out_ista[k][1] for k=1:niter+1]\n", "cost_fista = [out_fista[k][1] for k=1:niter+1]\n", "cost_pogm = [out_pogm[k][1] for k=1:niter+1]\n", "cost_min = min(minimum(cost_ista), minimum(cost_fista), minimum(cost_pogm))\n", "plot(xlabel=\"iteration k\", ylabel=\"Relative cost\")\n", "scatter!(0:niter, cost_ista .- cost_min, label=\"Cost ISTA\")\n", "scatter!(0:niter, cost_fista .- cost_min, markershape=:square, label=\"Cost FISTA\")\n", "scatter!(0:niter, cost_pogm .- cost_min, markershape=:utriangle, label=\"Cost POGM\")" ], "metadata": {}, "execution_count": null }, { "outputs": [], "cell_type": "code", "source": [ "isinteractive() && prompt();" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "Plot NRMSE vs iteration:" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "nrmse_ista = [out_ista[k][2] for k=1:niter+1]\n", "nrmse_fista = [out_fista[k][2] for k=1:niter+1]\n", "nrmse_pogm = [out_pogm[k][2] for k=1:niter+1]\n", "pn = plot(xlabel=\"iteration k\", ylabel=\"NRMSE %\", ylims=(3,6.5))\n", "scatter!(0:niter, nrmse_ista, label=\"NRMSE ISTA\")\n", "scatter!(0:niter, nrmse_fista, markershape=:square, label=\"NRMSE FISTA\")\n", "scatter!(0:niter, nrmse_pogm, markershape=:utriangle, label=\"NRMSE POGM\")" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "Show the initial and POGM images and the magnitudes of their errors:" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "p2 = jim(X0, \"X0: initial\")\n", "p3 = jim(Xpogm, \"POGM recon\")\n", "p5 = jim(X0 - Xtrue, \"X0 error\", clim=(0,2))\n", "p6 = jim(Xpogm - Xtrue, \"Xpogm error\", clim=(0,2))\n", "pe = jim(p2, p3, p5, p6)" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "---\n", "\n", "*This notebook was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).*" ], "metadata": {} } ], "nbformat_minor": 3, "metadata": { "language_info": { "file_extension": ".jl", "mimetype": "application/julia", "name": "julia", "version": "1.9.1" }, "kernelspec": { "name": "julia-1.9", "display_name": "Julia 1.9.1", "language": "julia" } }, "nbformat": 4 }