Advertisement
Guest User

scilpy's dti script

a guest
Mar 17th, 2014
28
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 5.70 KB | None | 0 0
  1. #! /usr/bin/env python
  2. from __future__ import division, print_function
  3.  
  4. import nibabel as nib
  5. import numpy as np
  6. import argparse
  7.  
  8. from dipy.segment.mask import median_otsu
  9. from dipy.core.gradients import gradient_table_from_bvals_bvecs
  10. from dipy.io.gradients import read_bvals_bvecs
  11. from dipy.reconst.dti import (TensorModel, color_fa, fractional_anisotropy,
  12.                               mean_diffusivity, axial_diffusivity, radial_diffusivity,
  13.                               lower_triangular)
  14.  
  15. DESCRIPTION = """
  16.    Convenient script to compute all of the dti metrics on a dwi.
  17.    """
  18.  
  19.  
  20. def buildArgsParser():
  21.  
  22.     p = argparse.ArgumentParser(description=DESCRIPTION)
  23.  
  24.     p.add_argument('input', action='store', metavar='input', type=str,
  25.                    help='Path of the input diffusion volume.')
  26.  
  27.     p.add_argument('bvals', action='store', metavar='bvals',
  28.                    help='Path of the bvals file, in FSL format.')
  29.  
  30.     p.add_argument('bvecs', action='store', metavar='bvecs',
  31.                    help='Path of the bvecs file, in FSL format.')
  32.  
  33.     p.add_argument('-o', action='store', dest='savename',
  34.                    metavar='savename', required=False, default=None, type=str,
  35.                    help='Path and prefix for the saved metrics files. The name is always appended \
  36.                   with _(metric_name).nii.gz, where (metric_name) if the name of the computed metric.')
  37.  
  38.     p.add_argument('-mask', action='store', dest='mask',
  39.                    metavar='mask', required=False, default=None, type=str,
  40.                    help='Path to a binary mask. Only data inside the mask will be used \
  41.                   for computations and reconstruction.')
  42.  
  43.     return p
  44.  
  45.  
  46. def isotropic(qform):
  47.     tr_A = qform[..., 0, 0] = qform[..., 1, 1] + qform[..., 2, 2]
  48.     n_dims = len(qform.shape)
  49.     add_dims = n_dims - 2
  50.     my_I = np.eye(3).reshape(add_dims * (1,) + (3, 3))
  51.     tr_AI = (tr_A.reshape(tr_A.shape + (1, 1)) * my_I)
  52.     return (1 / 3.0) * tr_AI
  53.  
  54.  
  55. def deviatoric(qform):
  56.     a_squiggle = qform - isotropic(qform)
  57.     return a_squiggle
  58.  
  59.  
  60. def tensor_norm(qform):
  61.     return np.sqrt(np.sum(np.sum(np.abs(qform ** 2), -1), -1))
  62.  
  63.  
  64. def tensor_determinant(qform):
  65.     aei = qform[..., 0, 0] * qform[..., 1, 1] * qform[..., 2, 2]
  66.     bfg = qform[..., 0, 1] * qform[..., 1, 2] * qform[..., 2, 0]
  67.     cdh = qform[..., 0, 2] * qform[..., 1, 0] * qform[..., 2, 1]
  68.     ceg = qform[..., 0, 2] * qform[..., 1, 1] * qform[..., 2, 0]
  69.     bdi = qform[..., 0, 1] * qform[..., 1, 0] * qform[..., 2, 2]
  70.     afh = qform[..., 0, 0] * qform[..., 1, 2] * qform[..., 2, 1]
  71.     return aei + bfg + cdh - ceg - bdi - afh
  72.  
  73.  
  74. def tensor_mode(qform):
  75.     a_squiggle = deviatoric(qform)
  76.     a_s_norm = tensor_norm(a_squiggle)
  77.     a_s_norm = a_s_norm.reshape(a_s_norm.shape + (1, 1))
  78.     return 3 * np.sqrt(6) * tensor_determinant((a_squiggle / a_s_norm))
  79.  
  80.  
  81. def main():
  82.     parser = buildArgsParser()
  83.     args = parser.parse_args()
  84.  
  85.     # Load data
  86.     img = nib.load(args.input)
  87.     data = img.get_data()
  88.     affine = img.get_affine()
  89.  
  90.     # Setting suffix savename
  91.     if args.savename is None:
  92.         filename = ""
  93.     else:
  94.         filename = args.savename + "_"
  95.  
  96.     if args.mask is not None:
  97.         mask = nib.load(args.mask).get_data()
  98.     else:
  99.         print("No mask specified. Computing mask with median_otsu.")
  100.         data, mask = median_otsu(data)
  101.         mask_img = nib.Nifti1Image(mask.astype(np.float32), affine)
  102.         nib.save(mask_img, filename + 'mask.nii.gz')
  103.  
  104.     # Get tensors
  105.     print('Tensor estimation...')
  106.     b_vals, b_vecs = read_bvals_bvecs(args.bvals, args.bvecs)
  107.     gtab = gradient_table_from_bvals_bvecs(b_vals, b_vecs)
  108.     tenmodel = TensorModel(gtab)
  109.     tenfit = tenmodel.fit(data, mask)
  110.  
  111.     # FA
  112.     print('Computing FA...')
  113.     FA = fractional_anisotropy(tenfit.evals)
  114.     FA[np.isnan(FA)] = 0
  115.  
  116.     # RGB
  117.     print('Computing RGB...')
  118.     FA = np.clip(FA, 0, 1)
  119.     RGB = color_fa(FA, tenfit.evecs)
  120.  
  121.     print('Computing Diffusivities...')
  122.     # diffusivities
  123.     MD = mean_diffusivity(tenfit.evals)
  124.     AD = axial_diffusivity(tenfit.evals)
  125.     RD = radial_diffusivity(tenfit.evals)
  126.  
  127.     print('Computing Mode...')
  128.     # Compute tensor mode
  129.     inter_mode = tensor_mode(tenfit.quadratic_form)
  130.  
  131.     # Since the mode computation is not masked, we need to remove nans.
  132.     non_nan_indices = zip(np.where(np.isnan(inter_mode) is False))
  133.     mode = np.zeros(inter_mode.shape)
  134.     mode[non_nan_indices] = inter_mode[non_nan_indices]
  135.  
  136.     print('Saving tensor coefficients and metrics...')
  137.     # Get the Tensor values and format them for visualisation in the Fibernavigator.
  138.     tensor_vals = lower_triangular(tenfit.quadratic_form)
  139.     correct_order = [0, 1, 3, 2, 4, 5]
  140.     tensor_vals_reordered = tensor_vals[..., correct_order]
  141.     fiber_tensors = nib.Nifti1Image(tensor_vals_reordered.astype(np.float32), affine)
  142.     nib.save(fiber_tensors, filename + 'tensors.nii.gz')
  143.  
  144.     # Save - for some reason this is not read properly by the FiberNav
  145.     fa_img = nib.Nifti1Image(FA.astype(np.float32), affine)
  146.     nib.save(fa_img, filename + 'fa.nii.gz')
  147.     rgb_img = nib.Nifti1Image(np.array(255 * RGB, 'uint8'), affine)
  148.     nib.save(rgb_img, filename + 'rgb.nii.gz')
  149.     md_img = nib.Nifti1Image(MD.astype(np.float32), affine)
  150.     nib.save(md_img, filename + 'md.nii.gz')
  151.     ad_img = nib.Nifti1Image(AD.astype(np.float32), affine)
  152.     nib.save(ad_img, filename + 'ad.nii.gz')
  153.     rd_img = nib.Nifti1Image(RD.astype(np.float32), affine)
  154.     nib.save(rd_img, filename + 'rd.nii.gz')
  155.     mode_img = nib.Nifti1Image(mode.astype(np.float32), affine)
  156.     nib.save(mode_img, filename + 'mode.nii.gz')
  157.  
# Run the command-line tool only when executed as a script, not on import.
if __name__ == "__main__":
    main()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement