Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- #!/usr/bin/env python
- import argparse
- import os
- import nibabel as nib
- import numpy as np
- from scipy import ndimage
- import vtk
- from vtk.util.numpy_support import vtk_to_numpy
- from dipy.tracking.streamline import transform_streamlines
# Text printed by ``--help``. The parser is built with RawTextHelpFormatter,
# so the embedded newlines are kept as written.
DESCRIPTION = (
    "\n"
    "Linear (and nonlinear) transformation for vtk file, useful for quick glass\n"
    "brain without FreeSurfer, output adapted to MI-Brain or SurfIce\n"
)
def _buildArgsParser():
    """Build the command-line parser for this script.

    Returns
    -------
    argparse.ArgumentParser
        Parser with three positional arguments (SURFACE, MATRIX, OUT_NAME)
        followed by the ``--warp``, ``--to_lps`` and ``-f`` options.
    """
    parser = argparse.ArgumentParser(
        description=DESCRIPTION,
        formatter_class=argparse.RawTextHelpFormatter)

    # (dest, metavar, help) for each positional, in command-line order.
    positional_specs = (
        ('surface', 'SURFACE', 'File that will be deformed (vtk)'),
        ('matrix', 'MATRIX', '4x4 transformation matrix from ants'),
        ('out_name', 'OUT_NAME',
         'Output filename of the transformed surface.'),
    )
    for dest_name, meta, help_text in positional_specs:
        parser.add_argument(dest_name, action='store', metavar=meta,
                            type=str, help=help_text)

    parser.add_argument('--warp', action='store', dest='warp_field',
                        metavar='WARP_FIELD',
                        help='Inverse deformation field from ants')
    parser.add_argument('--to_lps', action='store_true',
                        help='Flip the data for SurfIce')
    parser.add_argument('-f', action='store_true', dest='force_overwrite',
                        help='force (overwrite output file if present)')
    return parser
def _check_paths(parser, args):
    """Validate the CLI paths, exiting through ``parser.error`` on failure.

    SURFACE, MATRIX and (when given) WARP_FIELD must be existing files, and
    OUT_NAME must not exist unless ``-f`` was passed.
    """
    inputs = [args.surface, args.matrix]
    if args.warp_field:
        inputs.append(args.warp_field)
    for path in inputs:
        if not os.path.isfile(path):
            parser.error('"{0}" must be a file!'.format(path))
    if os.path.isfile(args.out_name) and not args.force_overwrite:
        parser.error('"{0}" already exists! Use -f to overwrite it.'
                     .format(args.out_name))


def _apply_matrix(polydata, np_matrix):
    """Return ``polydata`` run through a vtkTransformPolyDataFilter.

    ``np_matrix`` is a 4x4 numpy array copied element-wise into a
    vtkMatrix4x4.
    """
    vtk_matrix = vtk.vtkMatrix4x4()
    for row in range(4):
        for col in range(4):
            vtk_matrix.SetElement(row, col, float(np_matrix[row, col]))
    transform = vtk.vtkTransform()
    transform.Concatenate(vtk_matrix)
    transform_filter = vtk.vtkTransformPolyDataFilter()
    transform_filter.SetTransform(transform)
    transform_filter.SetInputData(polydata)
    transform_filter.Update()
    return transform_filter.GetOutput()


def main():
    """Apply an ANTs matrix (and optionally a warp field) to a VTK surface.

    Steps:
      1. validate the input/output paths;
      2. read the legacy ``.vtk`` polydata;
      3. read the 4x4 matrix with ``np.loadtxt`` and conjugate it by
         diag(-1, -1, 1, 1);
      4. apply the inverse of that matrix to the surface;
      5. if ``--warp`` is given, displace every vertex with
         ``warp_vtk_points``;
      6. if ``--to_lps`` is given, negate the x and y coordinates;
      7. write the result to OUT_NAME.
    """
    parser = _buildArgsParser()
    args = parser.parse_args()
    _check_paths(parser, args)

    reader = vtk.vtkPolyDataReader()
    reader.SetFileName(args.surface)
    reader.Update()
    polydata = reader.GetOutput()

    affine = np.loadtxt(args.matrix)
    if affine.shape != (4, 4):
        parser.error('"{0}" must contain a 4x4 matrix!'.format(args.matrix))

    if args.warp_field:
        # Load early so a bad warp file fails before any processing.
        deformation = nib.load(args.warp_field)
        # get_data() is deprecated (removed in nibabel 5.0);
        # np.asanyarray(img.dataobj) is its dtype-preserving replacement.
        deformation_data = np.squeeze(np.asanyarray(deformation.dataobj))

    # Conjugate by diag(-1, -1, 1, 1): every element coupling x/y with
    # z/translation changes sign.
    # NOTE(review): presumably converts the ANTs (LPS) matrix to the
    # surface's RAS convention; applied with or without --warp -- confirm.
    affine[0:2, 2:4] *= -1
    affine[2:4, 0:2] *= -1
    # The inverse is computed once and applied to the vertices.
    polydata = _apply_matrix(polydata, np.linalg.inv(affine))
    print('Linear transform...Done')

    if args.warp_field:
        vertices = list(vtk_to_numpy(polydata.GetPoints().GetData()))
        warped = warp_vtk_points(vertices, deformation.affine,
                                 deformation_data)
        new_points = vtk.vtkPoints()
        for vertex in warped:
            new_points.InsertNextPoint(vertex)
        polydata.SetPoints(new_points)
        print('Nonlinear transform...Done')

    if args.to_lps:
        # Negate x and y (RAS <-> LPS).
        polydata = _apply_matrix(polydata, np.diag([-1.0, -1.0, 1.0, 1.0]))
        print('Back to LPS transform...Done')

    # Write the output *.vtk
    writer = vtk.vtkPolyDataWriter()
    writer.SetFileName(args.out_name)
    writer.SetInputData(polydata)
    writer.Write()
def warp_vtk_points(points, transformation, deformation_data,
                    chunk_size=1000000):
    """Displace points by a dense deformation field, in place.

    Parameters
    ----------
    points : list of length-3 array-likes, or an (N, 3) array
        World coordinates, treated as LPS (the deformation affine is flipped
        to LPS before being inverted). Modified in place.
    transformation : (4, 4) array
        Voxel-to-world affine of the deformation image (``img.affine``).
    deformation_data : array indexed as (X, Y, Z, component)
        Displacement field. Components 0, 1 and 2 are the x, y and z
        displacements, which are added directly to the world coordinates.
    chunk_size : int, optional
        Number of points processed per iteration, which bounds the size of
        the temporary arrays. Must be >= 1.

    Returns
    -------
    The same ``points`` object, with each point replaced by its warped
    position.

    Raises
    ------
    ValueError
        If ``chunk_size`` is smaller than 1.

    Notes
    -----
    Displacements are linearly interpolated (``order=1``). Points outside
    the field get the ``map_coordinates`` default constant value of 0, so
    they are left unmoved.
    """
    if chunk_size < 1:
        raise ValueError('chunk_size must be a positive integer')

    # VTK surface are in LPS: turn voxel->RAS into voxel->LPS, then invert it
    # to map the points' world coordinates to voxel indices.
    flip_matrix = np.diag([-1.0, -1.0, 1.0, 1.0])
    world_to_vox = np.linalg.inv(np.dot(flip_matrix, transformation))
    rotation = world_to_vox[:3, :3]
    translation = world_to_vox[:3, 3]

    nb_points = len(points)
    # Iterate over chunks so a big dataset does not need huge temporaries.
    for start in range(0, nb_points, chunk_size):
        stop = min(start + chunk_size, nb_points)
        world = np.asarray(points[start:stop], dtype=float)  # (n, 3)
        # Vectorized equivalent of applying the affine to each point.
        vox_coords = (np.dot(world, rotation.T) + translation).T  # (3, n)
        displacement = np.array([
            ndimage.map_coordinates(deformation_data[..., axis],
                                    vox_coords, order=1)
            for axis in range(3)])
        # The deformation obtained is in world space: add it directly.
        points[start:stop] = (displacement + world.T).T
    return points
if __name__ == '__main__':
    # Run the command-line interface only when executed as a script.
    main()
Add Comment
Please, Sign In to add comment