Guest User

Untitled

a guest
Mar 20th, 2018
100
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 5.88 KB | None | 0 0
  1. #!/usr/bin/env python
  2. import argparse
  3. import os
  4.  
  5. import nibabel as nib
  6. import numpy as np
  7. from scipy import ndimage
  8. import vtk
  9. from vtk.util.numpy_support import vtk_to_numpy
  10.  
  11. from dipy.tracking.streamline import transform_streamlines
  12.  
# Program description displayed in the command-line help; it is handed to
# argparse.ArgumentParser as ``description`` by _buildArgsParser().
DESCRIPTION = """
Linear (and nonlinear) transformation for vtk file, useful for quick glass
brain without FreeSurfer, output adapted to MI-Brain or SurfIce
"""
  17.  
  18. def _buildArgsParser():
  19. p = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter,
  20. description=DESCRIPTION)
  21.  
  22. p.add_argument('surface', action='store', metavar='SURFACE',
  23. type=str, help='File that will be deformed (vtk)')
  24.  
  25. p.add_argument('matrix', action='store', metavar='MATRIX',
  26. type=str, help='4x4 transformation matrix from ants')
  27.  
  28. p.add_argument('out_name', action='store', metavar='OUT_NAME',
  29. type=str, help='Output filename of the transformed surface.')
  30.  
  31. p.add_argument('--warp', action='store', dest='warp_field',
  32. metavar='WARP_FIELD',
  33. help='Inverse deformation field from ants')
  34.  
  35. p.add_argument('--to_lps', action='store_true',
  36. help='Flip the data for SurfIce')
  37.  
  38. p.add_argument('-f', action='store_true', dest='force_overwrite',
  39. help='force (overwrite output file if present)')
  40.  
  41. return p
  42.  
  43.  
  44. def main():
  45. parser = _buildArgsParser()
  46. args = parser.parse_args()
  47.  
  48. if not os.path.isfile(args.surface):
  49. parser.error('"{0}" must be a file!'.format(args.surface))
  50.  
  51. if not os.path.isfile(args.matrix):
  52. parser.error('"{0}" must be a file!'.format(args.matrix))
  53.  
  54. if args.warp_field and not os.path.isfile(args.warp_field):
  55. parser.error('"{0}" must be a file!'.format(args.warp_field))
  56.  
  57. if os.path.isfile(args.out_name) and not args.force_overwrite:
  58. parser.error('"{0}" already exists! Use -f to overwrite it.'
  59. .format(args.out_name))
  60.  
  61. reader = vtk.vtkPolyDataReader()
  62. reader.SetFileName(args.surface)
  63. reader.Update()
  64. polydata = reader.GetOutput()
  65.  
  66. affine = np.loadtxt(args.matrix)
  67. if args.warp_field:
  68. deformation = nib.load(args.warp_field)
  69. deformation_data = np.squeeze(deformation.get_data())
  70.  
  71. affine[0:2, 2:4] *= -1
  72. affine[2:4, 0:2] *= -1
  73. matrix = vtk.vtkMatrix4x4()
  74. matrix.Identity()
  75. for i in range(4):
  76. for j in range(4):
  77. matrix.SetElement(i, j, np.linalg.inv(affine)[i, j])
  78.  
  79. transform = vtk.vtkTransform()
  80. transform.Concatenate(matrix)
  81.  
  82. # flip_matrix = vtk.vtkMatrix4x4()
  83. # flip_matrix.Identity()
  84. # flip_matrix.SetElement(0, 0, -1)
  85. # flip_matrix.SetElement(1, 1, -1)
  86. # transform.Concatenate(flip_matrix)
  87.  
  88. transform_polydata = vtk.vtkTransformPolyDataFilter()
  89. transform_polydata.SetTransform(transform)
  90. transform_polydata.SetInputData(polydata)
  91. transform_polydata.Update()
  92. polydata = transform_polydata.GetOutput()
  93. print 'Linear transform...Done'
  94.  
  95. if args.warp_field:
  96. points = polydata.GetPoints()
  97. array = points.GetData()
  98. points_list = list(vtk_to_numpy(array))
  99. warped_points = warp_vtk_points(points_list, deformation.affine,
  100. deformation_data)
  101.  
  102. points = vtk.vtkPoints()
  103. for i in warped_points:
  104. points.InsertNextPoint(i)
  105. polydata.SetPoints(points)
  106. print 'Nonlinear transform...Done'
  107.  
  108. if args.to_lps:
  109. flip_matrix = vtk.vtkMatrix4x4()
  110. flip_matrix.Identity()
  111. flip_matrix.SetElement(0, 0, -1)
  112. flip_matrix.SetElement(1, 1, -1)
  113. transform = vtk.vtkTransform()
  114. transform.Concatenate(flip_matrix)
  115.  
  116. transform_polydata = vtk.vtkTransformPolyDataFilter()
  117. transform_polydata.SetTransform(transform)
  118. transform_polydata.SetInputData(polydata)
  119. transform_polydata.Update()
  120. polydata = transform_polydata.GetOutput()
  121. print 'Back to LPS transform...Done'
  122.  
  123. # Write the output *.vtk
  124. writer = vtk.vtkPolyDataWriter()
  125. writer.SetFileName(args.out_name)
  126. writer.SetInputData(polydata)
  127. writer.Update()
  128.  
  129.  
  130. def warp_vtk_points(points, transformation, deformation_data):
  131. # VTK surface are in LPS
  132. flip_matrix = np.eye(4)
  133. flip_matrix[0, 0] = -1
  134. flip_matrix[1, 1] = -1
  135. transformation = np.dot(flip_matrix, transformation)
  136. inv_transformation = np.linalg.inv(transformation)
  137.  
  138. # Because of duplication, an iteration over chunks of points is necessary
  139. # for a big dataset
  140. nb_points = len(list(points))
  141. current_position = 0
  142. chunk_size = 1000000
  143. nb_iteration = int(np.ceil(nb_points/float(chunk_size)))
  144.  
  145. while nb_iteration > 0:
  146. max_position = min(current_position + chunk_size, nb_points)
  147. points_sublist = points[current_position:max_position]
  148. # To access the deformation information, we need to go in voxel space
  149.  
  150. points_vox = transform_streamlines(points_sublist,
  151. inv_transformation)
  152.  
  153. current_points_vox = np.array(points_vox).T
  154. current_points_vox_list = current_points_vox.tolist()
  155. x_def = ndimage.map_coordinates(deformation_data[..., 0],
  156. current_points_vox_list, order=1)
  157. y_def = ndimage.map_coordinates(deformation_data[..., 1],
  158. current_points_vox_list, order=1)
  159. z_def = ndimage.map_coordinates(deformation_data[..., 2],
  160. current_points_vox_list, order=1)
  161.  
  162. final_points = np.array([x_def, y_def, z_def])
  163.  
  164. # The deformation obtained is in worldSpace
  165. final_points += np.array(points_sublist).T
  166.  
  167. points[current_position:max_position] = final_points.T
  168. current_position = max_position
  169. nb_iteration -= 1
  170.  
  171. return points
  172.  
  173.  
# Run the command-line entry point only when executed as a script, not when
# the module is imported.
if __name__ == "__main__":
    main()
Add Comment
Please, Sign In to add comment