{ "cells": [ { "cell_type": "markdown", "source": [ "# SPECTrecon deep learning use\n", "\n", "This page describes how to end-to-end train unrolled deep learning algorithms\n", "using the Julia package\n", "[`SPECTrecon`](https://github.com/JuliaImageRecon/SPECTrecon.jl)." ], "metadata": {} }, { "cell_type": "markdown", "source": [ "### Setup" ], "metadata": {} }, { "cell_type": "markdown", "source": [ "Packages needed here." ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "using LinearAlgebra: norm, mul!\n", "using SPECTrecon: SPECTplan, project!, backproject!, psf_gauss, mlem!\n", "using MIRTjim: jim, prompt\n", "using Plots: default; default(markerstrokecolor=:auto)\n", "using ZygoteRules: @adjoint\n", "using Flux: Chain, Conv, SamePad, relu, params, unsqueeze\n", "import Flux # apparently needed for BSON @load\n", "import NNlib\n", "using LinearMapsAA: LinearMapAA\n", "using Distributions: Poisson\n", "using BSON: @load, @save\n", "import BSON # load\n", "using InteractiveUtils: versioninfo\n", "import Downloads # download" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "The following line is helpful when running this example.jl file as a script;\n", "this way it will prompt user to hit a key after each figure is displayed." ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "isinteractive() ? 
jim(:prompt, true) : prompt(:draw);" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "### Overview\n", "\n", "Regularized expectation-maximization (reg-EM)\n", "is a commonly used algorithm for performing SPECT image reconstruction.\n", "This page considers regularizers of the form $β/2 * ||x - u||^2$,\n", "where $u$ is an auxiliary variable that often refers to the image denoised by a CNN.\n", "\n", "### Data generation\n", "\n", "Simulated data used in this page are identical to\n", "[`SPECTrecon ML-EM`](https://jefffessler.github.io/SPECTrecon.jl/stable/examples/4-mlem/).\n", "We repeat it again here for convenience." ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "nx,ny,nz = 64,64,50\n", "T = Float32\n", "xtrue = zeros(T, nx,ny,nz)\n", "xtrue[(1nx÷4):(2nx÷3), 1ny÷5:(3ny÷5), 2nz÷6:(3nz÷6)] .= 1\n", "xtrue[(2nx÷5):(3nx÷5), 1ny÷5:(2ny÷5), 4nz÷6:(5nz÷6)] .= 2\n", "\n", "average(x) = sum(x) / length(x)\n", "function mid3(x::AbstractArray{T,3}) where {T}\n", " (nx,ny,nz) = size(x)\n", " xy = x[:,:,ceil(Int, nz÷2)]\n", " xz = x[:,ceil(Int,end/2),:]\n", " zy = x[ceil(Int, nx÷2),:,:]'\n", " return [xy xz; zy fill(average(xy), nz, nz)]\n", "end\n", "jim(mid3(xtrue), \"Middle slices of xtrue\")" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "### PSF" ], "metadata": {} }, { "cell_type": "markdown", "source": [ "Create a synthetic depth-dependent PSF for a single view" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "px = 11\n", "psf1 = psf_gauss( ; ny, px)\n", "jim(psf1, \"PSF for each of $ny planes\")" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "In general the PSF can vary from view to view\n", "due to non-circular detector orbits.\n", "For simplicity, here we illustrate the case\n", "where the PSF is the same for every view." 
], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "nview = 60\n", "psfs = repeat(psf1, 1, 1, 1, nview)\n", "size(psfs)" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "### SPECT system model using `LinearMapAA`" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "dy = 8 # transaxial pixel size in mm\n", "mumap = zeros(T, size(xtrue)) # zero μ-map just for illustration here\n", "plan = SPECTplan(mumap, psfs, dy; T)\n", "\n", "forw! = (y,x) -> project!(y, x, plan)\n", "back! = (x,y) -> backproject!(x, y, plan)\n", "idim = (nx,ny,nz)\n", "odim = (nx,nz,nview)\n", "A = LinearMapAA(forw!, back!, (prod(odim),prod(idim)); T, odim, idim)" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "Generate noisy data" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "if !@isdefined(ynoisy) # generate (scaled) Poisson data\n", " ytrue = A * xtrue\n", " target_mean = 20 # aim for mean of 20 counts per ray\n", " scale = target_mean / average(ytrue)\n", " scatter_fraction = 0.1 # 10% uniform scatter for illustration\n", " scatter_mean = scatter_fraction * average(ytrue) # uniform for simplicity\n", " background = scatter_mean * ones(T,nx,nz,nview)\n", " ynoisy = rand.(Poisson.(scale * (ytrue + background))) / scale\n", "end\n", "jim(ynoisy, \"$nview noisy projection views\"; ncol=10)" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "### ML-EM algorithm" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "x0 = ones(T, nx, ny, nz) # initial uniform image\n", "\n", "niter = 30\n", "\n", "if !@isdefined(xhat1)\n", " xhat1 = copy(x0)\n", " mlem!(xhat1, x0, ynoisy, background, A; niter)\n", "end;" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "Define evaluation metric" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "nrmse(x) = round(100 * 
norm(mid3(x) - mid3(xtrue)) / norm(mid3(xtrue)); digits=1)\n", "prompt()\n", "# jim(mid3(xhat1), \"MLEM NRMSE=$(nrmse(xhat1))%\") # display ML-EM reconstructed image" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "### Implement a 3D CNN denoiser" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "cnn = Chain(\n", " Conv((3,3,3), 1 => 4, relu; stride = 1, pad = SamePad(), bias = true),\n", " Conv((3,3,3), 4 => 4, relu; stride = 1, pad = SamePad(), bias = true),\n", " Conv((3,3,3), 4 => 1; stride = 1, pad = SamePad(), bias = true),\n", ")" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "Show how many parameters the CNN has" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "paramCount = sum([sum(length, params(layer)) for layer in cnn])" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "### Custom backpropagation\n", "\n", "Forward and back-projection are linear operations,\n", "so their Jacobians are very simple:\n", "the adjoint of one is the other.\n", "There is no need to auto-differentiate through the system matrix,\n", "which would be very computationally expensive.\n", "Instead, we tell Flux.jl to use the customized Jacobian when doing backpropagation." 
], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "projectb(x) = A * x\n", "@adjoint projectb(x) = A * x, dy -> (A' * dy, )\n", "\n", "backprojectb(y) = A' * y\n", "@adjoint backprojectb(y) = A' * y, dx -> (A * dx, )" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "### Backpropagatable regularized EM algorithm\n", "First define a function for unsqueezing the data\n", "because Flux CNN model expects a 5-dim tensor" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "function unsqueeze45(x)\n", " return unsqueeze(unsqueeze(x, 4), 5)\n", "end\n", "\n", "\"\"\"\n", " bregem(projectb, backprojectb, y, r, Asum, x, cnn, β; niter = 1)\n", "\n", "Backpropagatable regularized EM reconstruction with CNN regularization\n", "-`projectb`: backpropagatable forward projection\n", "-`backprojectb`: backpropagatable backward projection\n", "-`y`: projections\n", "-`r`: scatters\n", "-`Asum`: A' * 1\n", "-`x`: current iterate\n", "-`cnn`: the CNN model\n", "-`β`: regularization parameter\n", "-`niter`: number of iteration for inner EM\n", "\"\"\"\n", "function bregem(\n", " projectb::Function,\n", " backprojectb::Function,\n", " y::AbstractArray,\n", " r::AbstractArray,\n", " Asum::AbstractArray,\n", " x::AbstractArray,\n", " cnn::Union{Chain,Function},\n", " β::Real;\n", " niter::Int = 1,\n", ")\n", "\n", " u = cnn(unsqueeze45(x))[:,:,:,1,1]\n", " Asumu = Asum - β * u\n", " Asumu2 = Asumu.^2\n", " T = eltype(x)\n", " for iter = 1:niter\n", " eterm = backprojectb((y ./ (projectb(x) + r)))\n", " eterm_beta = 4 * β * (x .* eterm)\n", " x = max.(0, T(1/2β) * (-Asumu + sqrt.(Asumu2 + eterm_beta)))\n", " end\n", " return x\n", "end" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "### Loss function\n", "We set β = 1 and train 2 outer iterations." 
], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "β = 1\n", "Asum = A' * ones(T, nx, nz, nview)\n", "function loss(xrecon, xtrue)\n", " xiter1 = bregem(projectb, backprojectb, ynoisy, background,\n", " Asum, xrecon, cnn, β; niter = 1)\n", " xiter2 = bregem(projectb, backprojectb, ynoisy, background,\n", " Asum, xiter1, cnn, β; niter = 1)\n", " return sum(abs2, xiter2 - xtrue)\n", "end" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "Initial loss" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "@show loss(xhat1, xtrue)" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "### Train the CNN\n", "Uncomment the following code to train:\n", "\n", "```\n", "using Printf: @printf\n", "nepoch = 200\n", "ps = Flux.params(cnn) # trainable CNN parameters\n", "opt = Flux.ADAMW(0.002) # create once so the optimizer state persists across epochs\n", "for e in 1:nepoch\n", " @printf(\"epoch = %d, loss = %.2f\\n\", e, loss(xhat1, xtrue))\n", " gs = Flux.gradient(ps) do\n", " loss(xhat1, xtrue) # we start with the 30 iteration EM reconstruction\n", " end\n", " Flux.Optimise.update!(opt, ps, gs)\n", "end\n", "```" ], "metadata": {} }, { "cell_type": "markdown", "source": [ "Uncomment to save your trained model:\n", "```\n", "file = \"../data/trained-cnn-example-6-dl.bson\" # adjust path/name as needed\n", "@save file cnn\n", "```\n", "\n", "Load the pre-trained model (uncomment if you saved your own model):\n", "```\n", "@load file cnn\n", "```" ], "metadata": {} }, { "cell_type": "markdown", "source": [ "The code below works fine when run via `include` from the REPL,\n", "but it fails with the error `UndefVarError: NNlib not defined`\n", "on the `BSON.load` step when run via Literate/Documenter.\n", "So for now it is just fenced off with `isinteractive()`." 
], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "if isinteractive()\n", " url = \"https://github.com/JuliaImageRecon/SPECTrecon.jl/blob/main/data/trained-cnn-example-6-dl.bson?raw=true\"\n", " tmp = tempname()\n", " Downloads.download(url, tmp)\n", " cnn = BSON.load(tmp)[:cnn]\n", "else\n", " cnn = x -> x # fake \"do-nothing CNN\" for Literate/Documenter version\n", "end" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "Perform recon with pre-trained model." ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "xiter1 = bregem(projectb, backprojectb, ynoisy, background,
 Asum, xhat1, cnn, β; niter = 1)
xiter2 = bregem(projectb, backprojectb, ynoisy, background,
 Asum, xiter1, cnn, β; niter = 1)

clim = (0,2)
jim(
 jim(mid3(xtrue), "xtrue"; clim),
 jim(mid3(xhat1), "EM recon, NRMSE = $(nrmse(xhat1))%"; clim),
 jim(mid3(xiter1), "Iter 1, NRMSE = $(nrmse(xiter1))%"; clim),
 jim(mid3(xiter2), "Iter 2, NRMSE = $(nrmse(xiter2))%"; clim),
)
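
Each call to `bregem` above applies a closed-form voxel-wise update. As a rough check of that formula, the sketch below uses hypothetical scalar stand-ins for a single voxel (the names `a`, `u`, `x`, `e` are illustrative and not part of SPECTrecon) and confirms that the update is the nonnegative root of the quadratic obtained by setting the derivative of the EM surrogate plus the $β/2 * (x - u)^2$ penalty to zero.

```julia
# Hypothetical scalar stand-ins for one voxel (illustrative names only)
β = 1.0   # regularization parameter
a = 3.0   # sensitivity term, playing the role of one element of Asum
u = 0.8   # denoised value, playing the role of one element of the CNN output
x = 1.2   # current iterate
e = 2.5   # EM ratio term, playing the role of one element of A' * (y ./ (A*x + r))

b = a - β * u   # corresponds to Asumu in bregem
xnew = max(0, (-b + sqrt(b^2 + 4β * x * e)) / (2β))

# The update should satisfy β t^2 + b t - x e = 0 (up to rounding error)
residual = β * xnew^2 + b * xnew - x * e
@show xnew residual
```

Because `x * e` is nonnegative, the square root is never smaller than `abs(b)`, so the `max.(0, ...)` in `bregem` only guards against rounding error.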

For the web-based Documenter/Literate version,
the three NRMSE values (EM recon, Iter 1, and Iter 2) will be the same
because of the "do-nothing" CNN above.
If you download this file and run it locally,
you will see that the CNN reduces the NRMSE.
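
To make the comparison concrete, here is a minimal sketch of a percent-NRMSE computation on toy arrays. The helper name `nrmse_percent` and the toy images are assumptions for illustration only; the `nrmse` function defined earlier on this page evaluates the same ratio on the `mid3` slices of the reconstruction.

```julia
using LinearAlgebra: norm

# Percent NRMSE of an estimate relative to a reference (hypothetical helper)
nrmse_percent(xhat, xref) = round(100 * norm(xhat - xref) / norm(xref); digits=1)

xref = Float32[0 1 1; 0 2 2]   # toy "true" image
xem = xref .+ 0.2f0            # toy stand-in for an EM reconstruction
xcnn = xref .+ 0.1f0           # toy stand-in for a CNN-regularized result

@show nrmse_percent(xem, xref) nrmse_percent(xcnn, xref)
```

A smaller value means the estimate is closer to the reference, so a lower NRMSE after the CNN-regularized iterations indicates an improvement over the EM starting point.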

A more thorough investigation
would compare the CNN approach
to a suitably optimized regularized approach;
see [https://doi.org/10.1109/EMBC46164.2021.9630985](https://doi.org/10.1109/EMBC46164.2021.9630985).