{ "cells": [ { "cell_type": "markdown", "source": [ "# L+S 2D dynamic recon\n", "\n", "This page illustrates dynamic parallel MRI image reconstruction\n", "using a low-rank plus sparse (L+S) model\n", "optimized by a fast algorithm\n", "described in the paper\n", "by Claire Lin and Jeff Fessler\n", "[Efficient Dynamic Parallel MRI Reconstruction for the Low-Rank Plus Sparse Model](https://doi.org/10.1109/TCI.2018.2882089),\n", "IEEE Trans. on Computational Imaging, 5(1):17-26, 2019,\n", "by Claire Lin and Jeff Fessler,\n", "EECS Department, University of Michigan.\n", "\n", "The Julia code here is a translation\n", "of part of the\n", "[Matlab code](https://github.com/JeffFessler/reproduce-l-s-dynamic-mri)\n", "used in the original paper.\n", "\n", "If you use this code,\n", "please cite that paper." ], "metadata": {} }, { "cell_type": "markdown", "source": [ "### Setup" ], "metadata": {} }, { "cell_type": "markdown", "source": [ "Packages needed here." ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "# using Unitful: s\n", "using Plots; cgrad, default(markerstrokecolor=:auto, label=\"\")\n", "using MIRT: Afft, Asense, embed\n", "using MIRT: pogm_restart, poweriter\n", "using MIRTjim: jim, prompt\n", "using FFTW: fft!, bfft!, fftshift!\n", "using LinearMapsAA: LinearMapAA, block_diag, redim, undim\n", "using MAT: matread\n", "import Downloads # todo: use Fetch or DataDeps?\n", "using LinearAlgebra: dot, norm, svd, svdvals, Diagonal, I\n", "using Random: seed!\n", "using StatsBase: mean\n", "using LaTeXStrings" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "The following line is helpful when running this file as a script;\n", "this way it will prompt user to hit a key after each figure is displayed." ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "jif(args...; kwargs...) = jim(args...; prompt=false, kwargs...)\n", "isinteractive() ? jim(:prompt, true) : prompt(:draw);" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "## Overview\n", "\n", "Dynamic image reconstruction\n", "using a \"low-rank plus sparse\"\n", "or \"L+S\" approach\n", "was proposed by\n", "[Otazo et al.](https://doi.org/10.1002/mrm.25240)\n", "and uses the following cost function:\n", "\n", "$$\n", "X = \\hat{L} + \\hat{S}\n", ",\\qquad\n", "(\\hat{L}, \\hat{S})\n", "= \\arg \\min_{L,S} \\frac{1}{2} \\| E (L + S) - d \\|_2^2\n", " + λ_L \\| L \\|_*\n", " + λ_S \\| vec(T S) \\|_1\n", "$$\n", "where $T$ is a temporal unitary FFT,\n", "$E$ is an encoding operator (system matrix),\n", "and $d$\n", "is Cartesian undersampled multicoil k-space data.\n", "\n", "The Otazo paper used an\n", "iterative soft thresholding algorithm (ISTA)\n", "to solve this optimization problem.\n", "Using FISTA is faster,\n", "but using\n", "the\n", "[proximal optimized gradient method (POGM)](https://doi.org/10.1137/16m108104x)\n", "with\n", "[adaptive restart](https://doi.org/10.1007/s10957-018-1287-4)\n", "is even faster.\n", "\n", "This example reproduces part of Figures 1 & 2 in\n", "[Claire Lin's paper](https://doi.org/10.1109/TCI.2018.2882089),\n", "based on the\n", "[cardiac perfusion example](https://github.com/JeffFessler/reproduce-l-s-dynamic-mri/blob/master/examples/example_cardiac_perf.m)." 
], "metadata": {} }, { "cell_type": "markdown", "source": [ "## Read data" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "if !@isdefined(data)\n", " url = \"https://github.com/JeffFessler/MIRTdata/raw/main/mri/lin-19-edp/\"\n", " dataurl = url * \"cardiac_perf_R8.mat\"\n", " data = matread(Downloads.download(dataurl))\n", " xinfurl = url * \"Xinf.mat\"\n", " Xinf = matread(Downloads.download(xinfurl))[\"Xinf\"][\"perf\"] # (128,128,40)\n", "end;" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "Show converged image as a preview:" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "pinf = jim(Xinf, L\"\\mathrm{Converged\\ image\\ sequence } X_∞\")" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "Organize k-space data:" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "if !@isdefined(ydata0)\n", " ydata0 = data[\"kdata\"] # k-space data full of zeros\n", " ydata0 = permutedims(ydata0, [1, 2, 4, 3]) # (nx,ny,nc,nt)\n", " ydata0 = ComplexF32.(ydata0)\n", "end\n", "(nx, ny, nc, nt) = size(ydata0)" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "Extract sampling pattern from zeros of k-space data:" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "if !@isdefined(samp)\n", " samp = ydata0[:,:,1,:] .!= 0\n", " for ic in 2:nc # verify it is same for all coils\n", " @assert samp == (ydata0[:,:,ic,:] .!= 0)\n", " end\n", " kx = -(nx÷2):(nx÷2-1)\n", " ky = -(ny÷2):(ny÷2-1)\n", " psamp = jim(kx, ky, samp, \"Sampling patterns for $nt frames\";\n", " xlabel=L\"k_x\", ylabel=L\"k_y\")\n", "end" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "Are all k-space rows are sampled in one of the 40 frames?\n", "Sadly no.\n", "The 10 blue rows shown below are never sampled.\n", "A better sampling pattern design\n", "could have avoided this issue." 
], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "samp_sum = sum(samp, dims=3)\n", "color = cgrad([:blue, :black, :white], [0, 1/2nt, 1])\n", "pssum = jim(kx, ky, samp_sum; xlabel=\"kx\", ylabel=\"ky\",\n", " color, clim=(0,nt), title=\"Number of sampled frames out of $nt\")" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "Prepare coil sensitivity maps" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "if !@isdefined(smaps)\n", " smaps_raw = data[\"b1\"] # raw coil sensitivity maps\n", " jim(smaps_raw, \"Raw |coil maps| for $nc coils\")\n", " sum_last = (f, x) -> selectdim(sum(f, x; dims=ndims(x)), ndims(x), 1)\n", " ssos_fun = smap -> sqrt.(sum_last(abs2, smap)) # SSoS\n", " ssos_raw = ssos_fun(smaps_raw)\n", " smaps = smaps_raw ./ ssos_raw\n", " ssos = ssos_fun(smaps)\n", " @assert all(≈(1), ssos)\n", " pmap = jim(smaps, \"Normalized |coil maps| for $nc coils\")\n", "end" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "Temporal unitary FFT sparsifying transform\n", "for image sequence of size `(nx, ny, nt)`:" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "TF = Afft((nx,ny,nt), 3; unitary=true) # unitary FFT along 3rd (time) dimension\n", "if false # verify adjoint\n", " tmp1 = randn(ComplexF32, nx, ny, nt)\n", " tmp2 = randn(ComplexF32, nx, ny, nt)\n", " @assert dot(tmp2, TF * tmp1) ≈ dot(TF' * tmp2, tmp1)\n", " @assert TF' * (TF * tmp1) ≈ tmp1\n", " (size(TF), TF._odim, TF._idim)\n", "end" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "Examine temporal Fourier sparsity of Xinf.\n", "The low temporal frequencies dominate,\n", "as expected,\n", "because Xinf was reconstructed\n", "using this regularizer!" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "tmp = TF * Xinf\n", "ptfft = jim(tmp, \"|Temporal FFT of Xinf|\")" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "## System matrix\n", "Construct dynamic parallel MRI system model.\n", "It is block diagonal\n", "where each frame has its own sampling pattern.\n", "The input (image) here has size `(nx=128, ny=128, nt=40)`.\n", "The output (data) has size `(nsamp=2048, nc=12, nt=40)`\n", "because every frame\n", "has 16 phase-encode lines of 128 samples.\n", "\n", "todo: precompute (i)fft along readout direction to save time\n", "\n", "The code in the original Otazo et al. paper\n", "used an `ifft` in the forward model\n", "and an `fft` in the adjoint,\n", "so we must use a flag here to match that model." 
], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "Aotazo = (samp, smaps) -> Asense(samp, smaps; unitary=true, fft_forward=false) # Otazo style\n", "A = block_diag([Aotazo(s, smaps) for s in eachslice(samp, dims=3)]...)\n", "#A = ComplexF32(1/sqrt(nx*ny)) * A # match Otazo's scaling\n", "(size(A), A._odim, A._idim)" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "Reshape data to match the system model" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "if !@isdefined(ydata)\n", " tmp = reshape(ydata0, :, nc, nt)\n", " tmp = [tmp[vec(s),:,it] for (it,s) in enumerate(eachslice(samp, dims=3))]\n", " ydata = cat(tmp..., dims=3) # (nsamp,nc,nt) = (2048,12,40) no \"zeros\"\n", "end\n", "size(ydata)" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "Final encoding operator `E` for L+S because we stack `X = [L;S]`" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "tmp = LinearMapAA(I(nx*ny*nt);\n", " odim=(nx,ny,nt), idim=(nx,ny,nt), T=ComplexF32, prop=(;name=\"I\"))\n", "tmp = kron([1 1], tmp)\n", "AII = redim(tmp; odim=(nx,ny,nt), idim=(nx,ny,nt,2)) # \"squeeze\" odim\n", "E = A * AII;" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "Run power iteration to verify that `opnorm(E) = √2`" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "if false\n", " (_, σ1E) = poweriter(undim(E)) # 1.413 ≈ √2\n", "else\n", " σ1E = √2\n", "end" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "Check scale factor of Xinf. (It should be ≈1.)" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "tmp = A * Xinf\n", "scale0 = dot(tmp, ydata) / norm(tmp)^2 # 1.009 ≈ 1" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "Crude initial image" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "L0 = A' * ydata # adjoint (zero-filled)\n", "S0 = zeros(ComplexF32, nx, ny, nt)\n", "X0 = cat(L0, S0, dims=ndims(L0)+1) # (nx, ny, nt, 2) = (128, 128, 40, 2)\n", "M0 = AII * X0 # L0 + S0\n", "pm0 = jim(M0, \"|Initial L+S via zero-filled recon|\")" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "## L+S reconstruction\n", "Prepare for proximal gradient methods" ], "metadata": {} }, { "cell_type": "markdown", "source": [ "Scalars to match Otazo's results" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "scaleL = 130 / 1.2775 # Otazo's stopping St(1) / b1 constant squared\n", "scaleS = 1 / 1.2775; # 1 / b1 constant squared" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "L+S regularizer" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "lambda_L = 0.01 # regularization parameter\n", "lambda_S = 0.01 * scaleS\n", "Lpart = X -> selectdim(X, ndims(X), 1) # extract \"L\" from X\n", "Spart = X -> selectdim(X, ndims(X), 2) # extract \"S\" from X\n", "nucnorm(L::AbstractMatrix) = sum(svdvals(L)) # nuclear norm\n", "nucnorm(L::AbstractArray) = nucnorm(reshape(L, :, nt)); # (nx*ny, nt) for L" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "Optimization overall composite cost function" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "Fcost = X -> 0.5 * norm(E * X - ydata)^2 +\n", " lambda_L * scaleL * nucnorm(Lpart(X)) + # note scaleL !\n", " lambda_S * norm(TF * 
Spart(X), 1);\n", "\n", "f_grad = X -> E' * (E * X - ydata); # gradient of data-fit term" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "Lipschitz constant of data-fit term is 2\n", "because A is unitary and AII is like ones(2,2)." ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "f_L = 2; # σ1E^2" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "Proximal operator for scaled nuclear norm $β | X |_*$:\n", "singular value soft thresholding (SVST)." ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "function SVST(X::AbstractArray, β)\n", " dims = size(X)\n", " X = reshape(X, :, dims[end]) # assume time frame is the last dimension\n", " U,s,V = svd(X)\n", " sthresh = @. max(s - β, 0)\n", " keep = findall(>(0), sthresh)\n", " X = U[:,keep] * Diagonal(sthresh[keep]) * V[:,keep]'\n", " X = reshape(X, dims)\n", " return X\n", "end;" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "Combine proximal operators for L and S parts to make overall prox for `X`" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "soft = (v,c) -> sign(v) * max(abs(v) - c, 0) # soft threshold function\n", "S_prox = (S, β) -> TF' * soft.(TF * S, β) # 1-norm proximal mapping for unitary TF\n", "g_prox = (X, c) -> cat(dims=ndims(X),\n", " SVST(Lpart(X), c * lambda_L * scaleL),\n", " S_prox(Spart(X), c * lambda_S),\n", ");\n", "\n", "if false # check functions\n", " @assert Fcost(X0) isa Real\n", " tmp = f_grad(X0)\n", " @assert size(tmp) == size(X0)\n", " tmp = SVST(Lpart(X0), 1)\n", " @assert size(tmp) == size(L0)\n", " tmp = S_prox(S0, 1)\n", " @assert size(tmp) == size(S0)\n", " tmp = g_prox(X0, 1)\n", " @assert size(tmp) == size(X0)\n", "end\n", "\n", "\n", "niter = 10\n", "fun = (iter, xk, yk, is_restart) -> (Fcost(xk), xk); # logger" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "Run PGM" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "if !@isdefined(Mpgm)\n", " f_mu = 2/0.99 - f_L # trick to match 0.99 step size in Lin 1999\n", " f_mu = 0\n", " xpgm, out_pgm = pogm_restart(X0, (x) -> 0, f_grad, f_L ;\n", " f_mu, mom = :pgm, niter, g_prox, fun)\n", " Mpgm = AII * xpgm\n", "end;" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "Run FPGM (FISTA)" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "if !@isdefined(Mfpgm)\n", " xfpgm, out_fpgm = pogm_restart(X0, (x) -> 0, f_grad, f_L ;\n", " mom = :fpgm, niter, g_prox, fun)\n", " Mfpgm = AII * xfpgm\n", "end;" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "Run POGM" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "if !@isdefined(Mpogm)\n", " xpogm, out_pogm = pogm_restart(X0, (x) -> 0, f_grad, f_L ;\n", " mom = :pogm, niter, g_prox, fun)\n", " Mpogm = AII * xpogm\n", "end;" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "Look at final POGM image components" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "px = jim(\n", " jif(Lpart(xpogm), \"L\"),\n", " jif(Spart(xpogm), \"S\"),\n", " jif(Mpogm, \"M=L+S\"),\n", " jif(Xinf, \"Minf\"),\n", ")" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "Plot cost function" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "costs = out -> [o[1] for o in out]\n", "nrmsd = 
out -> [norm(AII*o[2]-Xinf)/norm(Xinf) for o in out]\n", "cost_pgm = costs(out_pgm)\n", "cost_fpgm = costs(out_fpgm)\n", "cost_pogm = costs(out_pogm)\n", "pc = plot(xlabel = \"iteration\", ylabel = \"cost\")\n", "plot!(0:niter, cost_pgm, marker=:circle, label=\"PGM (ISTA)\")\n", "plot!(0:niter, cost_fpgm, marker=:square, label=\"FPGM (FISTA)\")\n", "plot!(0:niter, cost_pogm, marker=:star, label=\"POGM\")" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "Plot NRMSD vs Matlab Xinf" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "nrmsd_pgm = nrmsd(out_pgm)\n", "nrmsd_fpgm = nrmsd(out_fpgm)\n", "nrmsd_pogm = nrmsd(out_pogm)\n", "pd = plot(xlabel = \"iteration\", ylabel = \"NRMSD vs Matlab Xinf\")\n", "plot!(0:niter, nrmsd_pgm, marker=:circle, label=\"PGM (ISTA)\")\n", "plot!(0:niter, nrmsd_fpgm, marker=:square, label=\"FPGM (FISTA)\")\n", "plot!(0:niter, nrmsd_pogm, marker=:star, label=\"POGM\")" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "## Discussion\n", "\n", "todo" ], "metadata": {} }, { "cell_type": "markdown", "source": [ "---\n", "\n", "*This notebook was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).*" ], "metadata": {} } ], "nbformat_minor": 3, "metadata": { "language_info": { "file_extension": ".jl", "mimetype": "application/julia", "name": "julia", "version": "1.10.3" }, "kernelspec": { "name": "julia-1.10", "display_name": "Julia 1.10.3", "language": "julia" } }, "nbformat": 4 }