{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "eb148e64-7d50-49ae-b2a3-0998f2db1c18",
   "metadata": {},
   "outputs": [],
   "source": [
    "#%matplotlib notebook\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.optimize import minimize, Bounds\n",
    "import ipywidgets as widgets\n",
    "from IPython.display import display, clear_output\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85d4bdad-a91e-4b3f-9a64-5b1418ccdf53",
   "metadata": {},
   "outputs": [],
   "source": [
    "check_FAplot = widgets.Checkbox(value=False,description=\"show FA-plot\")\n",
    "check_TIplot = widgets.Checkbox(value=False,description=\"show TI-plot\")\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "22146bd8-a0b3-45e0-8ebe-d5cb836d0378",
   "metadata": {},
   "outputs": [],
   "source": [
    "def reset(T1):\n",
    "    Mz = np.ones(len(T1))\n",
    "    return Mz\n",
    "def invert(Mz,eff=1.0):\n",
    "    Mz = Mz * -eff\n",
    "    return Mz\n",
    "\n",
    "def relax(Mz,T1,t):\n",
    "    Mz = 1 - (1-Mz) * np.exp(-t/T1);\n",
    "    return Mz\n",
    "\n",
    "def read(Mz,FA,N,T1,TR1):\n",
    "    for i in range(N):\n",
    "        sig = np.sin(FA / 180 * np.pi) * Mz;\n",
    "        Mz = Mz * np.cos(FA / 180 * np.pi);\n",
    "        Mz = relax(Mz,T1,TR1)\n",
    "    return Mz,sig\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "b1b201bf-889e-4dbc-8536-8b51f3c2f90b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def singleTR(Mz,param):\n",
    "    # calculate helper times\n",
    "    readDur = param['N'] * param['TR1']\n",
    "    TIfill = param['TI'] - readDur/2\n",
    "    TRfill = param['TR0'] - readDur - TIfill\n",
    "    N1 = int(param['N']/2)\n",
    "    N2 = param['N'] - N1\n",
    "    Mz = invert(Mz)\n",
    "    Mz = relax(Mz,param['T1'],TIfill)\n",
    "    Mz,sig = read(Mz,FA=param['FA'],N=N1,T1=param['T1'],TR1=param['TR1'])\n",
    "    Mz = read(Mz,FA=param['FA'],N=N2,T1=param['T1'],TR1=param['TR1'])[0]\n",
    "    Mz = relax(Mz,param['T1'],TRfill)\n",
    "    return Mz,sig\n",
    "\n",
    "# run enough times to approximate steady state\n",
    "def runParameters(param):\n",
    "    Mz = reset(param['T1'])\n",
    "    niter=7\n",
    "    for i in range(niter):\n",
    "        Mz,sig = singleTR(Mz,param)\n",
    "    return sig\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "86da35d7-d640-4316-b5c9-172b08e0c157",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "<style>\n",
       "    .jp-OutputArea-child {\n",
       "        display: inline-block;\n",
       "    }\n",
       "</style>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "d7e3edec-1e00-475b-98bc-7c42b6905547",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plotFAdependence(param):\n",
    "    param2 = param.copy()\n",
    "    FAs = np.linspace(param['FA']-2,param['FA']+2,30)\n",
    "\n",
    "    cont = np.zeros(len(FAs))\n",
    "    sG   = np.zeros(len(FAs))\n",
    "    sW   = np.zeros(len(FAs))\n",
    "    for i,f in enumerate(FAs):\n",
    "        param2['FA'] = f\n",
    "        sig = runParameters(param2)\n",
    "        cont[i] = sig[1]/sig[0]\n",
    "        sG[i] = sig[0]\n",
    "        sW[i] = sig[1]\n",
    "\n",
    "    fig, ax1 = plt.subplots()\n",
    "    ax1.plot(FAs,cont,color='C0')\n",
    "    ax1.set_title(f'FA variation (FA={param[\"FA\"]:.1f} deg)')\n",
    "    ax1.set_ylabel('WM/GM contrast', color='C0')\n",
    "    ax1.vlines(param['FA'],ymin=min(cont),ymax=max(cont),color='gray')\n",
    "    ax1.set_xlabel('FA [deg]')\n",
    "    ax2 = ax1.twinx()\n",
    "    ax2.plot(FAs,sG,color='C1')\n",
    "    ax2.plot(FAs,sW,color='C2')\n",
    "    plt.grid()\n",
    "    ax2.set_ylabel('GM signal', color='C1')\n",
    "    ax2.legend(['GM','WM'],loc='right')\n",
    "    plt.show()\n",
    "def plotTIdependence(param):\n",
    "    param2 = param.copy()\n",
    "    TIs = np.linspace(param[\"TI\"]-200,param[\"TI\"]+200,30)\n",
    "\n",
    "    cont = np.zeros(len(TIs))\n",
    "    sG   = np.zeros(len(TIs))\n",
    "    sW   = np.zeros(len(TIs))\n",
    "    for i,t in enumerate(TIs):\n",
    "        param2['TI'] = t\n",
    "        sig = runParameters(param2)\n",
    "        cont[i] = sig[1]/sig[0]\n",
    "        sG[i] = sig[0]\n",
    "        sW[i] = sig[1]\n",
    "\n",
    "    fig, ax1 = plt.subplots()\n",
    "    ax1.plot(TIs,cont,color='C0')\n",
    "    ax1.set_title(f'TI variation (TI={param[\"TI\"]:.0f} ms)')\n",
    "    ax1.set_ylabel('WM/GM contrast', color='C0')\n",
    "    ax1.vlines(x=param[\"TI\"],ymin=min(cont),ymax=max(cont),color='gray')\n",
    "    ax2 = ax1.twinx()\n",
    "    ax2.plot(TIs,sG,color='C1')\n",
    "    ax2.plot(TIs,sW,color='C2')\n",
    "    plt.grid()\n",
    "    ax1.set_xlabel('TI [ms]')\n",
    "    ax2.set_ylabel('GM signal', color='C1')\n",
    "    ax2.legend(['GM','WM'],loc='right')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "82928335-a344-41d8-8af5-9c2d23418a1e",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def optimize_values(param,target_cont,target_sig):\n",
    "    def cost(x,verbose=False):\n",
    "        param['FA'] = x[0]\n",
    "        param['TI'] = x[1]\n",
    "        sig = runParameters(param)\n",
    "        cont = sig[1]/sig[0]\n",
    "        cost_cont = abs(cont-target_cont)**2\n",
    "        cost_sig = abs(sig[0]-target_sig)**2 * 1e1\n",
    "        cost_sig2 = abs(1/sig[0])* 1e-3\n",
    "\n",
    "        cost_total = cost_cont + cost_sig\n",
    "\n",
    "        if verbose:\n",
    "            print(f'cont:    {cont:6.3f}')\n",
    "            print(f'GM-sig:  {sig[0]*100:6.3f}')\n",
    "        return cost_total\n",
    "    \n",
    "    x0 = [param['FA'],param['TI']]\n",
    "    lb = [1,np.ceil(param['TR1']*param['N']/2)]\n",
    "    ub = [90,param['TR0']-np.ceil(param['TR1']*param['N']/2)]\n",
    "    bounds = Bounds(lb,ub)\n",
    "    \n",
    "    x1 = minimize(cost,x0,method='L-BFGS-B',bounds=bounds)\n",
    "    \n",
    "    with output:\n",
    "        clear_output()\n",
    "        print(f'Duration: {param[\"N_part\"]*param[\"TR0\"]/1000/60:.2f} min')\n",
    "        # print('\\nOld:')\n",
    "        # print(f'FA:    {x0[0]:7.2f} deg')\n",
    "        # print(f'TI:    {x0[1]:7.2f} ms')\n",
    "        # cost(x0,verbose=True)\n",
    "        print('\\nNew:')\n",
    "        print(f'FA:    {x1.x[0]:7.2f} deg')\n",
    "        print(f'TI:    {x1.x[1]:7.2f} ms')\n",
    "        cost(x1.x,verbose=True)\n",
    "        if check_FAplot.value:\n",
    "            plotFAdependence(param)\n",
    "        if check_TIplot.value:\n",
    "            plotTIdependence(param)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "1c665e58-e1a7-44ec-a63d-7f2d278f0988",
   "metadata": {},
   "outputs": [],
   "source": [
    "TR_slider = widgets.IntSlider(\n",
    "    value=3360,\n",
    "    min=1000,\n",
    "    max=8000,\n",
    "    step=10,\n",
    "    description='Segment-TR:',\n",
    "    disabled=False,\n",
    "    continuous_update=False,\n",
    "    orientation='horizontal',\n",
    "    readout=True,\n",
    "    readout_format='d'\n",
    ")\n",
    "TR1_slider = widgets.FloatSlider(\n",
    "    value=6.1,\n",
    "    min=1,\n",
    "    max=20,\n",
    "    step=0.1,\n",
    "    description='Readout-TR:',\n",
    "    disabled=False,\n",
    "    continuous_update=False,\n",
    "    orientation='horizontal',\n",
    "    readout=True,\n",
    "    readout_format='.1f'\n",
    ")\n",
    "N_slider = widgets.IntSlider(\n",
    "    value=132,\n",
    "    min=20,\n",
    "    max=500,\n",
    "    step=1,\n",
    "    description='Seg Length:',\n",
    "    disabled=False,\n",
    "    continuous_update=False,\n",
    "    orientation='horizontal',\n",
    "    readout=True,\n",
    "    readout_format='d'\n",
    ")\n",
    "Npart_slider = widgets.IntSlider(\n",
    "    value=120,\n",
    "    min=10,\n",
    "    max=500,\n",
    "    step=1,\n",
    "    description='N_segments:',\n",
    "    disabled=False,\n",
    "    continuous_update=False,\n",
    "    orientation='horizontal',\n",
    "    readout=True,\n",
    "    readout_format='d'\n",
    ")\n",
    "Cont_slider = widgets.FloatSlider(\n",
    "    value=1.7,\n",
    "    min=1.0,\n",
    "    max=3.0,\n",
    "    step=0.1,\n",
    "    description='contrast:',\n",
    "    disabled=False,\n",
    "    continuous_update=False,\n",
    "    orientation='horizontal',\n",
    "    readout=True,\n",
    "    readout_format='.1f'\n",
    ")\n",
    "sig_slider = widgets.FloatSlider(\n",
    "    value=4.0,\n",
    "    min=0.0,\n",
    "    max=4.0,\n",
    "    step=0.1,\n",
    "    description='Signal target:',\n",
    "    disabled=False,\n",
    "    continuous_update=False,\n",
    "    orientation='horizontal',\n",
    "    readout=True,\n",
    "    readout_format='.1f'\n",
    ")\n",
    "output = widgets.Output()\n",
    "def update_values(change):\n",
    "    param = dict()\n",
    "    param['T1'] = np.asarray([2002,1425])\n",
    "    param['N'] = 132\n",
    "    param['TR1'] = 6.06\n",
    "    param['TR0'] = 3360\n",
    "    param['FA'] = 9.0\n",
    "    param['TI'] = 1340\n",
    "    param['N_part'] = Npart_slider.value\n",
    "    \n",
    "    param['TR0'] = TR_slider.value\n",
    "    param['TR1'] = TR1_slider.value\n",
    "    param['N'] = N_slider.value\n",
    "    target_cont = Cont_slider.value\n",
    "    target_sig = sig_slider.value/100\n",
    "    optimize_values(param,target_cont,target_sig)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6a46c5f6-58a6-4fa1-b72e-f7770ea3c4d8",
   "metadata": {},
   "source": [
    "# MPRAGE parameter optimization tool\n",
    "\n",
    "Assuming T1=2002 for GM and T1=1425 for WM\n",
    "\n",
    "Set sequence parameters (TR,segment length, etc) first.\n",
    "Then select target WM/GM contrast.\n",
    "\n",
    "The displayed signal is `sin(alpha) * Mz * 100`, so a 90° pulse after full relaxation would result in signal 100. This is a slightly arbitrary value and only gives an idea about the performance.\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "9dd2f4ca-a471-4a58-a1bd-0f3364b6ff05",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "169a2706db3b4a578ee4638c0d3ef8a4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "IntSlider(value=3360, continuous_update=False, description='Segment-TR:', max=8000, min=1000, step=10)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "be12b9b54abf4c6b848be3272c508152",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "FloatSlider(value=6.1, continuous_update=False, description='Readout-TR:', max=20.0, min=1.0, readout_format='…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0b3b6b9a6c164df9b5a36fdf193cde75",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "IntSlider(value=132, continuous_update=False, description='Seg Length:', max=500, min=20)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "51c5af4b59f0484b81a85fde63f085bc",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "IntSlider(value=120, continuous_update=False, description='N_segments:', max=500, min=10)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7733c4653676488c9c6c8fb6e0308492",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "FloatSlider(value=1.7, continuous_update=False, description='contrast:', max=3.0, min=1.0, readout_format='.1f…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "609d53bca6ae4a88a9c941228b1a928d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "sig_slider.observe(update_values, 'value')\n",
    "TR_slider.observe(update_values, 'value')\n",
    "N_slider.observe(update_values, 'value')\n",
    "Cont_slider.observe(update_values, 'value')\n",
    "Npart_slider.observe(update_values,'value')\n",
    "TR1_slider.observe(update_values,'value')\n",
    "\n",
    "check_FAplot.observe(update_values,'value')\n",
    "check_TIplot.observe(update_values,'value')\n",
    "\n",
    "display(TR_slider)\n",
    "display(TR1_slider)\n",
    "display(N_slider)\n",
    "display(Npart_slider)\n",
    "display(Cont_slider)\n",
    "display(check_FAplot)\n",
    "display(check_TIplot)\n",
    "#display(sig_slider)\n",
    "display(output)\n",
    "update_values(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "39598f09-11f1-4867-85a6-a752a285193c",
   "metadata": {},
   "source": [
    "```\n",
    "Old:\n",
    "FA:       9.00 deg\n",
    "TI:    1340.00 ms\n",
    "cont:     1.692\n",
    "GM-sig:   1.909\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e43f96a2-3754-4864-933d-acfc6891bda3",
   "metadata": {},
   "outputs": [],
   "source": [
    "param = dict()\n",
    "param['T1'] = np.asarray([2002,1425]) # GM,WM\n",
    "param['N'] = 132\n",
    "param['N_part'] = 120\n",
    "param['TR1'] = 6.06\n",
    "param['TR0'] = 3360\n",
    "param['FA'] = 9.0\n",
    "param['TI'] = 1340\n",
    "target_cont = 2.0\n",
    "target_sig = 0.01\n",
    "optimize_values(param,target_cont,target_sig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb20c9b3-022b-4b86-afc6-4da40d15903f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}