{ "cells": [ { "cell_type": "markdown", "source": [ "# SPECTrecon deep learning use\n", "\n", "This page describes how to end-to-end train unrolled deep learning algorithms\n", "using the Julia package\n", "[`SPECTrecon`](https://github.com/JuliaImageRecon/SPECTrecon.jl)." ], "metadata": {} }, { "cell_type": "markdown", "source": [ "### Setup" ], "metadata": {} }, { "cell_type": "markdown", "source": [ "Packages needed here." ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "using LinearAlgebra: norm, mul!\n", "using SPECTrecon: SPECTplan, project!, backproject!, psf_gauss, mlem!\n", "using MIRTjim: jim, prompt\n", "using Plots: default; default(markerstrokecolor=:auto)\n", "using ZygoteRules: @adjoint\n", "using Flux: Chain, Conv, SamePad, relu, params, unsqueeze\n", "import Flux # apparently needed for BSON @load\n", "import NNlib\n", "using LinearMapsAA: LinearMapAA\n", "using Distributions: Poisson\n", "using BSON: @load, @save\n", "import BSON # load\n", "using InteractiveUtils: versioninfo\n", "import Downloads # download" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "The following line is helpful when running this example.jl file as a script;\n", "this way it will prompt the user to hit a key after each figure is displayed." ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "isinteractive() ? jim(:prompt, true) : prompt(:draw);" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "### Overview\n", "\n", "Regularized expectation-maximization (reg-EM)\n", "is a commonly used algorithm for performing SPECT image reconstruction.\n", "This page considers regularizers of the form $β/2 * ||x - u||^2$,\n", "where $u$ is an auxiliary variable, typically the image denoised by a CNN.\n", "\n", "### Data generation\n", "\n", "Simulated data used in this page are identical to those in\n", "[`SPECTrecon ML-EM`](https://jefffessler.github.io/SPECTrecon.jl/stable/examples/4-mlem/).\n", "We repeat the setup here for convenience." ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "nx,ny,nz = 64,64,50\n", "T = Float32\n", "xtrue = zeros(T, nx,ny,nz)\n", "xtrue[(1nx÷4):(2nx÷3), 1ny÷5:(3ny÷5), 2nz÷6:(3nz÷6)] .= 1\n", "xtrue[(2nx÷5):(3nx÷5), 1ny÷5:(2ny÷5), 4nz÷6:(5nz÷6)] .= 2\n", "\n", "average(x) = sum(x) / length(x)\n", "function mid3(x::AbstractArray{T,3}) where {T}\n", " (nx,ny,nz) = size(x)\n", " xy = x[:,:,ceil(Int, nz÷2)]\n", " xz = x[:,ceil(Int,end/2),:]\n", " zy = x[ceil(Int, nx÷2),:,:]'\n", " return [xy xz; zy fill(average(xy), nz, nz)]\n", "end\n", "jim(mid3(xtrue), \"Middle slices of xtrue\")" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "### PSF" ], "metadata": {} }, { "cell_type": "markdown", "source": [ "Create a synthetic depth-dependent PSF for a single view" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "px = 11\n", "psf1 = psf_gauss( ; ny, px)\n", "jim(psf1, \"PSF for each of $ny planes\")" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "In general the PSF can vary from view to view\n", "due to non-circular detector orbits.\n", "For simplicity, here we illustrate the case\n", "where the PSF is the same for every view."
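], "metadata": {} }, { "cell_type": "markdown", "source": [ "To make the general case concrete, here is a small sketch (a hypothetical illustration, not part of the original example) of what a view-varying PSF stack could look like:\n", "each view gets Gaussian kernels whose widths depend on a made-up per-view orbit radius,\n", "stacked into the same 4-dimensional `(px, px, ny, nview)` layout used below." ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "# Hypothetical sketch of a view-varying PSF stack (illustration only).\n", "# The widths below are made up; a real system model would use measured or modeled PSFs per view.\n", "nview_demo = 60 # number of views, just for this sketch\n", "kernel_center = (px + 1) / 2\n", "gauss2d(σ) = [exp(-((i - kernel_center)^2 + (j - kernel_center)^2) / (2σ^2)) for i in 1:px, j in 1:px]\n", "normalize2d(k) = k / sum(k) # each PSF sums to 1\n", "orbit_radius = range(22, 28, nview_demo) # made-up per-view orbit radius\n", "width(iy, v) = 0.6 + 0.02 * iy + 0.03 * orbit_radius[v] # made-up depth and view dependence\n", "psfs_vary = cat(dims = 4,\n", "    [cat(dims = 3, [normalize2d(gauss2d(width(iy, v))) for iy in 1:ny]...) for v in 1:nview_demo]...)\n", "size(psfs_vary) # (px, px, ny, nview_demo)" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "Returning to the simpler view-independent case used in the rest of this example:"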
], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "nview = 60\n", "psfs = repeat(psf1, 1, 1, 1, nview)\n", "size(psfs)" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "### SPECT system model using `LinearMapAA`" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "dy = 8 # transaxial pixel size in mm\n", "mumap = zeros(T, size(xtrue)) # zero μ-map just for illustration here\n", "plan = SPECTplan(mumap, psfs, dy; T)\n", "\n", "forw! = (y,x) -> project!(y, x, plan)\n", "back! = (x,y) -> backproject!(x, y, plan)\n", "idim = (nx,ny,nz)\n", "odim = (nx,nz,nview)\n", "A = LinearMapAA(forw!, back!, (prod(odim),prod(idim)); T, odim, idim)" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "Generate noisy data" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "if !@isdefined(ynoisy) # generate (scaled) Poisson data\n", " ytrue = A * xtrue\n", " target_mean = 20 # aim for mean of 20 counts per ray\n", " scale = target_mean / average(ytrue)\n", " scatter_fraction = 0.1 # 10% uniform scatter for illustration\n", " scatter_mean = scatter_fraction * average(ytrue) # uniform for simplicity\n", " background = scatter_mean * ones(T,nx,nz,nview)\n", " ynoisy = rand.(Poisson.(scale * (ytrue + background))) / scale\n", "end\n", "jim(ynoisy, \"$nview noisy projection views\"; ncol=10)" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "### ML-EM algorithm" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "x0 = ones(T, nx, ny, nz) # initial uniform image\n", "\n", "niter = 30\n", "\n", "if !@isdefined(xhat1)\n", " xhat1 = copy(x0)\n", " mlem!(xhat1, x0, ynoisy, background, A; niter)\n", "end;" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "Define evaluation metric" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "nrmse(x) = round(100 * norm(mid3(x) - mid3(xtrue)) / norm(mid3(xtrue)); digits=1)\n", "prompt()\n", "# jim(mid3(xhat1), \"MLEM NRMSE=$(nrmse(xhat1))%\") # display ML-EM reconstructed image" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "### Implement a 3D CNN denoiser" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "cnn = Chain(\n", " Conv((3,3,3), 1 => 4, relu; stride = 1, pad = SamePad(), bias = true),\n", " Conv((3,3,3), 4 => 4, relu; stride = 1, pad = SamePad(), bias = true),\n", " Conv((3,3,3), 4 => 1; stride = 1, pad = SamePad(), bias = true),\n", ")" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "Show how many parameters the CNN has" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "paramCount = sum([sum(length, params(layer)) for layer in cnn])" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "### Custom backpropagation\n", "\n", "Forward and back-projection are linear operations,\n", "so their Jacobians are very simple\n", "and there is no need to auto-differentiate through the system matrix;\n", "doing so would also be very computationally expensive.\n", "Instead, we tell Flux.jl to use the custom adjoints below when doing backpropagation."
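, "\n", "\n", "For a linear map $f(x) = A x$, the vector-Jacobian product needed in reverse-mode differentiation\n", "is simply the adjoint applied to the incoming gradient: $\\bar{x} = A' \\bar{y}$;\n", "likewise, the pullback of the back-projector $g(y) = A' y$ is $\\bar{y} = A \\bar{x}$.\n", "The two `@adjoint` definitions below encode exactly that."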
], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "projectb(x) = A * x\n", "@adjoint projectb(x) = A * x, dy -> (A' * dy, )\n", "\n", "backprojectb(y) = A' * y\n", "@adjoint backprojectb(y) = A' * y, dx -> (A * dx, )" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "### Backpropagatable regularized EM algorithm\n", "First define a function for unsqueezing the data\n", "because the Flux CNN model expects a 5-dim tensor" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "function unsqueeze45(x)\n", " return unsqueeze(unsqueeze(x, 4), 5)\n", "end\n", "\n", "\"\"\"\n", " bregem(projectb, backprojectb, y, r, Asum, x, cnn, β; niter = 1)\n", "\n", "Backpropagatable regularized EM reconstruction with CNN regularization.\n", "- `projectb`: backpropagatable forward projection\n", "- `backprojectb`: backpropagatable backward projection\n", "- `y`: projections\n", "- `r`: scatters\n", "- `Asum`: A' * 1\n", "- `x`: current iterate\n", "- `cnn`: the CNN model\n", "- `β`: regularization parameter\n", "- `niter`: number of iterations for the inner EM\n", "\"\"\"\n", "function bregem(\n", " projectb::Function,\n", " backprojectb::Function,\n", " y::AbstractArray,\n", " r::AbstractArray,\n", " Asum::AbstractArray,\n", " x::AbstractArray,\n", " cnn::Union{Chain,Function},\n", " β::Real;\n", " niter::Int = 1,\n", ")\n", "\n", " u = cnn(unsqueeze45(x))[:,:,:,1,1]\n", " Asumu = Asum - β * u\n", " Asumu2 = Asumu.^2\n", " T = eltype(x)\n", " for iter = 1:niter\n", " eterm = backprojectb((y ./ (projectb(x) + r)))\n", " eterm_beta = 4 * β * (x .* eterm)\n", " x = max.(0, T(1/2β) * (-Asumu + sqrt.(Asumu2 + eterm_beta)))\n", " end\n", " return x\n", "end" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "### Loss function\n", "We set β = 1 and train with 2 unrolled outer iterations."
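, "\n", "\n", "The training loss applies `bregem` twice, starting from the ML-EM image,\n", "and compares the result to the true image: $||x^{(2)} - x_{true}||^2$, where $x^{(k)}$ denotes the $k$-th unrolled iterate.\n", "For reference, each inner iteration of `bregem` above computes (elementwise)\n", "$$ x \\leftarrow \\max\\left(0, \\frac{-(a - β u) + \\sqrt{(a - β u)^2 + 4 β \\, x \\, e}}{2 β}\\right),\n", "\\qquad a = A' \\mathbf{1}, \\quad e = A' \\left(\\frac{y}{A x + r}\\right), $$\n", "where $u$ is the CNN output;\n", "this is the nonnegative root of the quadratic that arises from the EM surrogate plus the penalty $β/2 * ||x - u||^2$."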
], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "β = 1\n", "Asum = A' * ones(T, nx, nz, nview)\n", "function loss(xrecon, xtrue)\n", " xiter1 = bregem(projectb, backprojectb, ynoisy, background,\n", " Asum, xrecon, cnn, β; niter = 1)\n", " xiter2 = bregem(projectb, backprojectb, ynoisy, background,\n", " Asum, xiter1, cnn, β; niter = 1)\n", " return sum(abs2, xiter2 - xtrue)\n", "end" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "Initial loss" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "@show loss(xhat1, xtrue)" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "### Train the CNN\n", "Uncomment the following code to train:\n", "\n", "```\n", "using Printf\n", "using Flux: ADAMW, gradient # these are not in the partial `using Flux:` list above\n", "nepoch = 200\n", "ps = Flux.params(cnn)\n", "opt = ADAMW(0.002) # construct the optimizer once so its state persists across epochs\n", "for e in 1:nepoch\n", " @printf(\"epoch = %d, loss = %.2f\\n\", e, loss(xhat1, xtrue))\n", " gs = gradient(ps) do\n", " loss(xhat1, xtrue) # we start with the 30 iteration EM reconstruction\n", " end\n", " Flux.Optimise.update!(opt, ps, gs)\n", "end\n", "```" ], "metadata": {} }, { "cell_type": "markdown", "source": [ "Uncomment to save your trained model:\n", "```\n", "file = \"../data/trained-cnn-example-6-dl.bson\" # adjust path/name as needed\n", "@save file cnn\n", "```\n", "\n", "Load the pre-trained model (uncomment if you saved your own model):\n", "```\n", "@load file cnn\n", "```" ], "metadata": {} }, { "cell_type": "markdown", "source": [ "The code below works fine when run via `include` from the REPL,\n", "but it fails with the error `UndefVarError: NNlib not defined`\n", "on the `BSON.load` step when run via Literate/Documenter.\n", "So for now it is just fenced off with `isinteractive()`." ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "if isinteractive()\n", " url = \"https://github.com/JuliaImageRecon/SPECTrecon.jl/blob/main/data/trained-cnn-example-6-dl.bson?raw=true\"\n", " tmp = tempname()\n", " Downloads.download(url, tmp)\n", " cnn = BSON.load(tmp)[:cnn]\n", "else\n", " cnn = x -> x # fake \"do-nothing CNN\" for Literate/Documenter version\n", "end" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "Perform reconstruction with the pre-trained model." ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "xiter1 = bregem(projectb, backprojectb, ynoisy, background,\n", " Asum, xhat1, cnn, β; niter = 1)\n", "xiter2 = bregem(projectb, backprojectb, ynoisy, background,\n", " Asum, xiter1, cnn, β; niter = 1)\n", "\n", "clim = (0,2)\n", "jim(\n", " jim(mid3(xtrue), \"xtrue\"; clim),\n", " jim(mid3(xhat1), \"EM recon, NRMSE = $(nrmse(xhat1))%\"; clim),\n", " jim(mid3(xiter1), \"Iter 1, NRMSE = $(nrmse(xiter1))%\"; clim),\n", " jim(mid3(xiter2), \"Iter 2, NRMSE = $(nrmse(xiter2))%\"; clim),\n", ")" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "For the web-based Documenter/Literate version,\n", "the three NRMSE values will be the same\n", "because of the \"do-nothing\" CNN above.\n", "But if you download this file and run it locally,\n", "then you will see that the CNN reduces the NRMSE.\n", "\n", "A more thorough investigation\n", "would compare the CNN approach\n", "to a suitably optimized regularized approach;\n", "see [https://doi.org/10.1109/EMBC46164.2021.9630985](https://doi.org/10.1109/EMBC46164.2021.9630985)."
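], "metadata": {} }, { "cell_type": "markdown", "source": [ "As a small, hypothetical extension (not part of the original training setup),\n", "one can keep applying `bregem` beyond the two unrolled iterations and track the NRMSE,\n", "using a helper `nrmse_trend` defined here only for illustration.\n", "Because the CNN was trained for exactly two unrolled iterations,\n", "there is no guarantee that additional iterations keep improving the result." ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "# Hypothetical extension: run a few extra outer iterations and record the NRMSE trend.\n", "function nrmse_trend(x0; nouter::Int = 4)\n", "    x = copy(x0)\n", "    trend = [nrmse(x)] # NRMSE of the starting image\n", "    for _ in 1:nouter\n", "        x = bregem(projectb, backprojectb, ynoisy, background, Asum, x, cnn, β; niter = 1)\n", "        push!(trend, nrmse(x))\n", "    end\n", "    return trend\n", "end\n", "nrmse_trend(xhat1) # NRMSE (%) after 0 through 4 outer iterations" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "With the \"do-nothing\" CNN used for the web version, this trend just reflects additional regularized EM iterations;\n", "with the trained CNN loaded locally, it shows whether the learned regularizer keeps helping."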
], "metadata": {} }, { "cell_type": "markdown", "source": [ "---\n", "\n", "*This notebook was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).*" ], "metadata": {} } ], "nbformat_minor": 3, "metadata": { "language_info": { "file_extension": ".jl", "mimetype": "application/julia", "name": "julia", "version": "1.11.1" }, "kernelspec": { "name": "julia-1.11", "display_name": "Julia 1.11.1", "language": "julia" } }, "nbformat": 4 }