{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Probabilistic tractography\n",
    "\n",
    "Probabilistic fiber tracking is a way of reconstructing the white matter\n",
    "structural connectivity using diffusion MRI data. Much like deterministic fiber\n",
    "tracking, the probabilistic approach follows the trajectory of a possible\n",
    "pathway in a step-wise fashion and propagating streamlines based on the local\n",
    "orientations reconstructed at each voxel.\n",
    "\n",
    "In probabilistic tracking, however, the tracking direction at each point along\n",
    "the path is chosen at random from a distribution of possible directions, and\n",
    "thus is no longer deterministic. The distribution at each point is different and\n",
    "depends on the observed diffusion data at that point. The distribution of\n",
    "tracking directions at each point can be represented as a probability mass\n",
    "function (PMF) if the possible tracking directions are restricted to a set of\n",
    "directions distributed points on a sphere.\n",
    "\n",
    "Like their deterministic counterparts, probabilistic tracking methods start\n",
    "propagating streamlines from a *seed map*, which contains a number of\n",
    "coordinates per voxel to initiate the procedure. The higher the number of seeds\n",
    "per voxel (i.e. the seed density), the larger will be the number of potentially\n",
    "recovered long-range connections. However, this comes at the cost of a longer\n",
    "running time.\n",
    "\n",
    "This episode builds on top of the results of the CSD local orientation\n",
    "reconstruction method presented in a previous episode.\n",
    "\n",
    "We will first get the necessary diffusion data, and compute the local\n",
    "orientation information using the CSD method:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import nibabel as nib\n",
    "import numpy as np\n",
    "\n",
    "import bids\n",
    "from bids.layout import BIDSLayout\n",
    "\n",
    "from dipy.core.gradients import gradient_table\n",
    "from dipy.io.gradients import read_bvals_bvecs\n",
    "\n",
    "\n",
    "bids.config.set_option('extension_initial_dot', True)\n",
    "\n",
    "# Get the diffusion files\n",
    "dwi_layout = BIDSLayout(\n",
    "    '../data/ds000221/derivatives/uncorrected_topup_eddy/', validate=False)\n",
    "gradient_layout = BIDSLayout(\n",
    "    '../data/ds000221/sub-010006/ses-01/dwi/', validate=False)\n",
    "\n",
    "subj = '010006'\n",
    "\n",
    "dwi_fname = dwi_layout.get(subject=subj, suffix='dwi',\n",
    "                           extension='nii.gz', return_type='file')[0]\n",
    "bval_fname = gradient_layout.get(\n",
    "    subject=subj, suffix='dwi', extension='bval', return_type='file')[0]\n",
    "bvec_fname = dwi_layout.get(\n",
    "    subject=subj, extension='eddy_rotated_bvecs', return_type='file')[0]\n",
    "\n",
    "dwi_img = nib.load(dwi_fname)\n",
    "affine = dwi_img.affine\n",
    "\n",
    "bvals, bvecs = read_bvals_bvecs(bval_fname, bvec_fname)\n",
    "gtab = gradient_table(bvals, bvecs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will now create the seeding mask and the seeds using an estimate of the\n",
    "white matter tissue based on the FA values obtained from the diffusion tensor:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dipy.reconst import dti\n",
    "from dipy.segment.mask import median_otsu\n",
    "from dipy.tracking import utils\n",
    "\n",
    "dwi_data = dwi_img.get_fdata()\n",
    "\n",
    "# Specify the volume index to the b0 volumes\n",
    "dwi_data, dwi_mask = median_otsu(dwi_data, vol_idx=[0], numpass=1)\n",
    "\n",
    "dti_model = dti.TensorModel(gtab)\n",
    "\n",
    "# This step may take a while\n",
    "dti_fit = dti_model.fit(dwi_data, mask=dwi_mask)\n",
    "\n",
    "# Create the seeding mask\n",
    "fa_img = dti_fit.fa\n",
    "seed_mask = fa_img.copy()\n",
    "seed_mask[seed_mask >= 0.2] = 1\n",
    "seed_mask[seed_mask < 0.2] = 0\n",
    "\n",
    "# Create the seeds\n",
    "seeds = utils.seeds_from_mask(seed_mask, affine=affine, density=1)"
   ]
  },
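  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To get a feeling for how the seed density drives the number of seeds, the\n",
    "sketch below places seeds on a regular grid inside each voxel of a toy mask.\n",
    "It is a simplified, NumPy-only illustration: the `grid_seeds` helper is\n",
    "hypothetical, works in voxel coordinates only, and is not `DIPY`'s\n",
    "implementation. According to the `DIPY` documentation, `seeds_from_mask`\n",
    "interprets an integer `density` as the number of seeds along each dimension,\n",
    "so a density of 2 yields 8 seeds per voxel.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def grid_seeds(mask, density=1):\n",
    "    # Hypothetical helper: seeds on a regular grid inside each voxel of a\n",
    "    # binary mask, in voxel coordinates (voxel centers at integer indices)\n",
    "    voxels = np.argwhere(mask > 0)\n",
    "    # Offsets of the sub-voxel grid along one axis, centered in the voxel\n",
    "    offsets_1d = (np.arange(density) + 0.5) / density - 0.5\n",
    "    grid = np.stack(np.meshgrid(offsets_1d, offsets_1d, offsets_1d,\n",
    "                                indexing='ij'), axis=-1).reshape(-1, 3)\n",
    "    # Every voxel in the mask gets its own copy of the grid\n",
    "    return (voxels[:, None, :] + grid[None, :, :]).reshape(-1, 3)\n",
    "\n",
    "\n",
    "toy_mask = np.zeros((4, 4, 4))\n",
    "toy_mask[1:3, 1:3, 1:3] = 1  # 8 voxels in the mask\n",
    "\n",
    "for density in (1, 2, 3):\n",
    "    print(density, len(grid_seeds(toy_mask, density)))\n",
    "```\n",
    "\n",
    "The number of seeds grows with the cube of the density, which is why the\n",
    "running time of the tracking increases quickly with denser seeding."
   ]
  },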
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will now estimate the FRF and set the CSD model to feed the local orientation\n",
    "information to the streamline propagation object:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dipy.reconst.csdeconv import (ConstrainedSphericalDeconvModel,\n",
    "                                   auto_response_ssst)\n",
    "\n",
    "response, ratio = auto_response_ssst(gtab, dwi_data, roi_radii=10, fa_thr=0.7)\n",
    "sh_order = 2\n",
    "csd_model = ConstrainedSphericalDeconvModel(gtab, response, sh_order=sh_order)\n",
    "csd_fit = csd_model.fit(dwi_data, mask=seed_mask)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Tracking methods are provided with a criterion to stop propagating streamlines\n",
    "beyond non-white matter tissues. One way to do this is to use the Generalized\n",
    "Fractional Anisotropy (GFA). Much like the Fractional Anisotropy issued by the\n",
    "DTI model measures anisotropy, the GFA uses samples of the ODF to quantify the\n",
    "anisotropy of tissues, and hence, it provides an estimation of the underlying\n",
    "tissue type."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy import ndimage  # To rotate image for visualization purposes\n",
    "import matplotlib.pyplot as plt\n",
    "from dipy.reconst.shm import CsaOdfModel\n",
    "from dipy.tracking.stopping_criterion import ThresholdStoppingCriterion\n",
    "\n",
    "csa_model = CsaOdfModel(gtab, sh_order=sh_order)\n",
    "gfa = csa_model.fit(dwi_data, mask=seed_mask).gfa\n",
    "stopping_criterion = ThresholdStoppingCriterion(gfa, .25)\n",
    "\n",
    "# Create the directory to save the results\n",
    "out_dir = '../data/ds000221/derivatives/dwi/tractography/sub-%s/ses-01/dwi/' % subj\n",
    "\n",
    "if not os.path.exists(out_dir):\n",
    "    os.makedirs(out_dir)\n",
    "\n",
    "# Save the GFA\n",
    "gfa_img = nib.Nifti1Image(gfa.astype(np.float32), affine)\n",
    "nib.save(gfa_img, os.path.join(out_dir, 'gfa.nii.gz'))\n",
    "\n",
    "# Plot the GFA\n",
    "%matplotlib inline\n",
    "\n",
    "fig, ax = plt.subplots(1, 3, figsize=(10, 10))\n",
    "ax[0].imshow(ndimage.rotate(gfa[:, gfa.shape[1]//2, :], 90, reshape=False))\n",
    "ax[1].imshow(ndimage.rotate(gfa[gfa.shape[0]//2, :, :], 90, reshape=False))\n",
    "ax[2].imshow(ndimage.rotate(gfa[:, :, gfa.shape[-1]//2], 90, reshape=False))\n",
    "fig.savefig(os.path.join(out_dir, \"gfa.png\"), dpi=300, bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The GFA threshold stopping criterion value must be adjusted to the data in\n",
    "order to avoid creating a mask that will exclude white matter areas (which\n",
    "would result in streamlines being unable to propagate to other white matter\n",
    "areas). Visually inspecting the GFA map might provide with a sufficient\n",
    "guarantee about the goodness of the value."
   ]
  },
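  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To interpret the threshold, it helps to recall what the GFA measures: the\n",
    "standard deviation of the ODF samples divided by their root mean square. The\n",
    "sketch below is a minimal NumPy version of that definition, assuming the ODF\n",
    "samples are stored along the last axis. The `gfa_from_odf` name is ours, and\n",
    "`DIPY`'s own implementation may differ in details such as the handling of\n",
    "degenerate voxels.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def gfa_from_odf(odf):\n",
    "    # GFA = std(odf) / rms(odf), computed over the last axis (ODF samples)\n",
    "    n = odf.shape[-1]\n",
    "    mean = odf.mean(axis=-1, keepdims=True)\n",
    "    num = n * ((odf - mean) ** 2).sum(axis=-1)\n",
    "    den = (n - 1) * (odf ** 2).sum(axis=-1)\n",
    "    # Voxels with an all-zero ODF get a GFA of 0\n",
    "    safe_den = np.where(den > 0, den, 1.0)\n",
    "    return np.where(den > 0, np.sqrt(num / safe_den), 0.0)\n",
    "\n",
    "\n",
    "n_dirs = 100\n",
    "isotropic = np.ones(n_dirs)     # same value in every direction\n",
    "single_peak = np.zeros(n_dirs)\n",
    "single_peak[0] = 1.0            # all the mass in a single direction\n",
    "\n",
    "print(gfa_from_odf(isotropic), gfa_from_odf(single_peak))\n",
    "```\n",
    "\n",
    "An isotropic ODF gives a GFA of 0 and an ODF concentrated in a single\n",
    "direction gives a GFA of 1, so a threshold such as the one used above sits\n",
    "between these two extremes."
   ]
  },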
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The Fiber Orientation Distribution (FOD) of the CSD model estimates the\n",
    "distribution of small fiber bundles within each voxel. We can use this\n",
    "distribution for probabilistic fiber tracking. One way to do this is to\n",
    "represent the FOD using a discrete sphere. This discrete FOD can be used by the\n",
    "``ProbabilisticDirectionGetter`` as a PMF for sampling tracking directions. We\n",
    "need to clip the FOD to use it as a PMF because the latter cannot have negative\n",
    "values. Ideally, the FOD should be strictly positive, but because of noise\n",
    "and/or model failures sometimes it can have negative values.\n",
    "\n",
    "The set of possible directions to choose to propagate a streamline is restricted\n",
    "by a cone angle $\\theta$, named `max_angle` in `DIPY`'s\n",
    "`ProbabilisticDirectionGetter::from_pmf` method.\n",
    "\n",
    "Another relevant parameter of the propagation is the step size, which dictates\n",
    "how much the propagation will advance to the next point. Note that it is a real\n",
    "number, since the tracking procedure operates in physical coordinates.\n",
    "\n",
    "Note that the `LocalTracking` class accepts a `StoppingCriterion` class instance\n",
    "as its second argument, and thus a different criterion can be used if the GFA\n",
    "criterion does not fit into our framework, or if different data is available in\n",
    "our workflow."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dipy.direction import ProbabilisticDirectionGetter\n",
    "from dipy.data import small_sphere\n",
    "from dipy.io.stateful_tractogram import Space, StatefulTractogram\n",
    "from dipy.io.streamline import save_tractogram\n",
    "from dipy.tracking.local_tracking import LocalTracking\n",
    "from dipy.tracking.streamline import Streamlines\n",
    "\n",
    "fod = csd_fit.odf(small_sphere)\n",
    "pmf = fod.clip(min=0)\n",
    "prob_dg = ProbabilisticDirectionGetter.from_pmf(pmf, max_angle=30.,\n",
    "                                                sphere=small_sphere)\n",
    "streamline_generator = LocalTracking(prob_dg, stopping_criterion, seeds,\n",
    "                                     affine, step_size=.5)\n",
    "streamlines = Streamlines(streamline_generator)\n",
    "sft = StatefulTractogram(streamlines, dwi_img, Space.RASMM)\n",
    "\n",
    "# Save the tractogram\n",
    "save_tractogram(sft, os.path.join(\n",
    "    out_dir, 'tractogram_probabilistic_dg_pmf.trk'))"
   ]
  },
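  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the roles of the PMF, `max_angle` and the step size more concrete,\n",
    "the toy sketch below performs a single propagation step with NumPy. It is\n",
    "not `DIPY`'s implementation: the `probabilistic_step` function, the random\n",
    "set of directions and the toy FOD are all assumptions made for illustration,\n",
    "and details such as the antipodal symmetry of the sphere are ignored.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "\n",
    "# Toy set of unit directions standing in for the vertices of a sphere\n",
    "# (the z axis is appended so that the cone used below is never empty)\n",
    "directions = rng.normal(size=(200, 3))\n",
    "directions /= np.linalg.norm(directions, axis=1, keepdims=True)\n",
    "directions = np.vstack([directions, [0., 0., 1.]])\n",
    "\n",
    "\n",
    "def probabilistic_step(position, previous_dir, fod, max_angle=30.,\n",
    "                       step_size=.5):\n",
    "    # Clip the negative FOD values so that it can be used as a PMF\n",
    "    pmf = np.clip(fod, 0, None)\n",
    "    # Keep only the directions inside the cone around the previous direction\n",
    "    cos_sim = directions @ previous_dir\n",
    "    pmf = np.where(cos_sim >= np.cos(np.deg2rad(max_angle)), pmf, 0.)\n",
    "    if pmf.sum() == 0:\n",
    "        return None, None  # No valid direction: the streamline stops\n",
    "    # Draw the new direction with probability proportional to the PMF\n",
    "    idx = rng.choice(len(directions), p=pmf / pmf.sum())\n",
    "    new_dir = directions[idx]\n",
    "    return position + step_size * new_dir, new_dir\n",
    "\n",
    "\n",
    "previous_dir = np.array([0., 0., 1.])\n",
    "# Toy FOD with a lobe along z and small negative values elsewhere\n",
    "fod = (directions @ previous_dir) ** 4 - 0.05\n",
    "position, new_dir = probabilistic_step(np.zeros(3), previous_dir, fod)\n",
    "print(position, new_dir)\n",
    "```\n",
    "\n",
    "Running the last lines repeatedly gives a different position each time, which\n",
    "is precisely what makes the tracking probabilistic rather than deterministic."
   ]
  },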
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will easily generate the anatomical views on the generated tractogram using the `generate_anatomical_volume_figure` helper function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "nbval-skip"
    ]
   },
   "outputs": [],
   "source": [
    "from fury import actor, colormap\n",
    "\n",
    "from utils.visualization_utils import generate_anatomical_volume_figure\n",
    "\n",
    "# Plot the tractogram\n",
    "\n",
    "# Build the representation of the data\n",
    "streamlines_actor = actor.line(streamlines, colormap.line_colors(streamlines))\n",
    "\n",
    "# Compute the slices to be shown\n",
    "slices = tuple(elem // 2 for elem in dwi_data.shape[:-1])\n",
    "\n",
    "# Generate the figure\n",
    "fig = generate_anatomical_volume_figure(streamlines_actor)\n",
    "\n",
    "fig.savefig(os.path.join(out_dir, \"tractogram_probabilistic_dg_pmf.png\"),\n",
    "            dpi=300, bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One disadvantage of using a discrete PMF to represent possible tracking\n",
    "directions is that it tends to take up a lot of memory (RAM). The size of the\n",
    "PMF, the FOD in this case, must be equal to the number of possible tracking\n",
    "directions on the hemisphere, and every voxel has a unique PMF. In this case\n",
    "the data is ``(81, 106, 76)`` and ``small_sphere`` has 181 directions so the\n",
    "FOD is ``(81, 106, 76, 181)``. One way to avoid sampling the PMF and holding it\n",
    "in memory is to build the direction getter directly from the spherical harmonic\n",
    "(SH) representation of the FOD. By using this approach, we can also use a\n",
    "larger sphere, like ``default_sphere`` which has 362 directions on the\n",
    "hemisphere, without having to worry about memory limitations."
   ]
  },
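  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A back-of-the-envelope estimate shows the difference in memory footprint. The\n",
    "sketch below assumes that the values are stored as 64-bit floats and uses the\n",
    "volume shape and number of directions quoted above, together with the 6\n",
    "coefficients of a symmetric SH basis of order 2 (the `sh_order` set earlier).\n",
    "The `size_in_mib` helper is ours and the real footprint depends on the data\n",
    "type used internally.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "volume_shape = (81, 106, 76)\n",
    "n_directions = 181  # small_sphere, as quoted above\n",
    "n_sh_coeffs = 6     # (l + 1) * (l + 2) / 2 coefficients for l = 2\n",
    "bytes_per_value = np.dtype(np.float64).itemsize\n",
    "\n",
    "\n",
    "def size_in_mib(n_values_per_voxel):\n",
    "    # Size of a dense 4D array with the given number of values per voxel\n",
    "    n_bytes = np.prod(volume_shape) * n_values_per_voxel * bytes_per_value\n",
    "    return n_bytes / 2**20\n",
    "\n",
    "\n",
    "print('PMF: %.0f MiB' % size_in_mib(n_directions))\n",
    "print('SH:  %.0f MiB' % size_in_mib(n_sh_coeffs))\n",
    "```\n",
    "\n",
    "Under these assumptions the sampled PMF takes roughly 900 MiB, against about\n",
    "30 MiB for the SH coefficients."
   ]
  },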
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dipy.data import default_sphere\n",
    "\n",
    "prob_dg = ProbabilisticDirectionGetter.from_shcoeff(csd_fit.shm_coeff,\n",
    "                                                    max_angle=30.,\n",
    "                                                    sphere=default_sphere)\n",
    "streamline_generator = LocalTracking(prob_dg, stopping_criterion, seeds,\n",
    "                                     affine, step_size=.5)\n",
    "streamlines = Streamlines(streamline_generator)\n",
    "sft = StatefulTractogram(streamlines, dwi_img, Space.RASMM)\n",
    "\n",
    "# Save the tractogram\n",
    "save_tractogram(sft, os.path.join(\n",
    "    out_dir, 'tractogram_probabilistic_dg_sh.trk'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will visualize the tractogram using the three usual anatomical views:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "nbval-skip"
    ]
   },
   "outputs": [],
   "source": [
    "# Plot the tractogram\n",
    "\n",
    "# Build the representation of the data\n",
    "streamlines_actor = actor.line(streamlines, colormap.line_colors(streamlines))\n",
    "\n",
    "# Generate the figure\n",
    "fig = generate_anatomical_volume_figure(streamlines_actor)\n",
    "\n",
    "fig.savefig(os.path.join(out_dir, \"tractogram_probabilistic_dg_sh.png\"),\n",
    "            dpi=300, bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Not all model fits have the ``shm_coeff`` attribute because not all models use\n",
    "this basis to represent the data internally. However we can fit the ODF of any\n",
    "model to the spherical harmonic basis using the ``peaks_from_model`` function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dipy.direction import peaks_from_model\n",
    "\n",
    "peaks = peaks_from_model(csd_model, dwi_data, default_sphere, .5, 25,\n",
    "                         mask=seed_mask, return_sh=True, parallel=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It is always good practice to (save and) visualize the peaks as a check towards ensuring that the orientation information conforms to what is expected prior to the tracking process."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the peaks\n",
    "from dipy.io.peaks import reshape_peaks_for_visualization\n",
    "\n",
    "nib.save(nib.Nifti1Image(reshape_peaks_for_visualization(peaks),\n",
    "                         affine), os.path.join(out_dir, 'peaks.nii.gz'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As usual, we will use `fury` to visualize the peaks:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "nbval-skip"
    ]
   },
   "outputs": [],
   "source": [
    "from utils.visualization_utils import generate_anatomical_slice_figure\n",
    "\n",
    "# Visualize the peaks\n",
    "\n",
    "# Build the representation of the data\n",
    "peaks_actor = actor.peak_slicer(peaks.peak_dirs, peaks.peak_values)\n",
    "\n",
    "# Compute the slices to be shown\n",
    "slices = tuple(elem // 2 for elem in dwi_data.shape[:-1])\n",
    "\n",
    "# Generate the figure\n",
    "fig = generate_anatomical_slice_figure(slices, peaks_actor)\n",
    "\n",
    "fig.savefig(os.path.join(out_dir, \"peaks.png\"), dpi=300, bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fod_coeff = peaks.shm_coeff\n",
    "\n",
    "prob_dg = ProbabilisticDirectionGetter.from_shcoeff(fod_coeff, max_angle=30.,\n",
    "                                                    sphere=default_sphere)\n",
    "streamline_generator = LocalTracking(prob_dg, stopping_criterion, seeds,\n",
    "                                     affine, step_size=.5)\n",
    "streamlines = Streamlines(streamline_generator)\n",
    "sft = StatefulTractogram(streamlines, dwi_img, Space.RASMM)\n",
    "\n",
    "# Save the tractogram\n",
    "save_tractogram(sft, os.path.join(\n",
    "    out_dir, \"tractogram_probabilistic_dg_sh_pmf.trk\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will again visualize the tractogram using the three usual anatomical views:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "nbval-skip"
    ]
   },
   "outputs": [],
   "source": [
    "# Plot the tractogram\n",
    "\n",
    "# Build the representation of the data\n",
    "streamlines_actor = actor.line(streamlines, colormap.line_colors(streamlines))\n",
    "\n",
    "# Generate the figure\n",
    "fig = generate_anatomical_volume_figure(streamlines_actor)\n",
    "\n",
    "fig.savefig(os.path.join(\n",
    "    out_dir, \"tractogram_probabilistic_dg_sh_pmf.png\"), dpi=300, bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Tip: Making sure your tractogram is well aligned with the data\n",
    "\n",
    "If for whatever reason the anatomical and diffusion images were not correctly aligned, you may find that your tractogram is not well aligned with the anatomical data. This may also happen derived from the different formats in which a tractogram is saved/loaded, some conventions specifying the origin at the voxel corner and other specifying it at the center of the voxel. Visualizing the computed features is always recommended. There are some tools that allow to ensure that the matrices specifying the orientation and positioning of the data should be correct.\n",
    "\n",
    "`MRtrix`'s `mrinfo` command can be used to visualize the affine matrix of a `NIfTI` file as:\n",
    "\n",
    "`mrinfo dwi.nii.gz`\n",
    "\n",
    "which would output something like:\n",
    "\n",
    "```\n",
    "************************************************\n",
    "Image:               \"/data/dwi.nii.gz\"\n",
    "************************************************\n",
    "  Dimensions:        90 x 108 x 90 x 33\n",
    "  Voxel size:        2 x 2 x 2 x 1\n",
    "  Data strides:      [ -1 -2 3 4 ]\n",
    "  Format:            NIfTI-1.1 (GZip compressed)\n",
    "  Data type:         signed 16 bit integer (little endian)\n",
    "  Intensity scaling: offset = 0, multiplier = 1\n",
    "  Transform:              1          -0           0      -178\n",
    "                         -0           1           0      -214\n",
    "                         -0          -0           1        -0\n",
    "```\n",
    "\n",
    "Similarly, for your tractograms, you may use the command `track_info` from `TrackVis`' `Diffusion Toolkit` set of command-line tools:\n",
    "\n",
    "`trk_info tractogram.trk`\n",
    "\n",
    "which would output something like:\n",
    "```\n",
    "ID string:             TRACK\n",
    "Version:               2\n",
    "Dimension:             180 216 180\n",
    "Voxel size:            1 1 1\n",
    "Voxel order:           LPS\n",
    "Voxel order original:  LPS\n",
    "Voxel to RAS matrix:\n",
    "     -1.0000     0.0000     0.0000     0.5000\n",
    "      0.0000    -1.0000     0.0000     0.5000\n",
    "      0.0000     0.0000     1.0000    -0.5000\n",
    "      0.0000     0.0000     0.0000     1.0000\n",
    "\n",
    "Image Orientation:  1.0000/0.0000/0.0000/0.0000/1.0000/0.0000\n",
    "Orientation patches:   none\n",
    "Number of scalars:  0\n",
    "Number of properties:  0\n",
    "Number of tracks:  200433\n",
    "```\n",
    "\n",
    "Note that, a `TRK` file contains orientational and positional information. If you choose to store your tractograms using the `TCK` format, this informationwill not be contained in the file. To see the file header information you may use the `MRtrix` `tckinfo` command:\n",
    "\n",
    "`tckinfo tractogram.tck`\n",
    "\n",
    "which would output something like:\n",
    "\n",
    "```\n",
    "***********************************\n",
    " Tracks file: \"/data/tractogram.tck\"\n",
    "   count:                0000200433\n",
    "   dimensions:           (180, 216, 180)\n",
    "   voxel_order:          LPS\n",
    "   voxel_sizes:          (1.0, 1.0, 1.0)\n",
    "```"
   ]
  }
 ],
 "metadata": {
  "celltoolbar": "Tags",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}