Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import os
- import operator
- import string
- import copy
- import itertools
- import numpy as np
- import networkx as nx
- import scipy.ndimage as nd
- import image_ops
- from from_fits import create_image_from_fits_file
- from skimage.morphology import medial_axis
- from scipy import nanmean
- import matplotlib.pyplot as plt
- # Create 4 to 8-connected elements to use with binary hit-or-miss
- struct1 = np.array([[1, 0, 0],
- [0, 1, 1],
- [0, 0, 0]])
- struct2 = np.array([[0, 0, 1],
- [1, 1, 0],
- [0, 0, 0]])
- # Next check the three elements which will be double counted
- check1 = np.array([[1, 1, 0, 0],
- [0, 0, 1, 1]])
- check2 = np.array([[0, 0, 1, 1],
- [1, 1, 0, 0]])
- check3 = np.array([[1, 1, 0],
- [0, 0, 1],
- [0, 0, 1]])
- def eight_con():
- return np.ones((3, 3))
- def _fix_small_holes(mask_array, rel_size=0.1):
- '''
- Helper function to remove only small holes within a masked region.
- Parameters
- ----------
- mask_array : numpy.ndarray
- Array containing the masked region.
- rel_size : float, optional
- If < 1.0, sets the minimum size a hole must be relative to the area
- of the mask. Otherwise, this is the maximum number of pixels the hole
- must have to be deleted.
- Returns
- -------
- mask_array : numpy.ndarray
- Altered array.
- '''
- if rel_size <= 0.0:
- raise ValueError("rel_size must be positive.")
- elif rel_size > 1.0:
- pixel_flag = True
- else:
- pixel_flag = False
- # Find the region area
- reg_area = len(np.where(mask_array == 1)[0])
- # Label the holes
- holes = np.logical_not(mask_array).astype(float)
- lab_holes, n_holes = nd.label(holes, eight_con())
- # If no holes, return
- if n_holes == 1:
- return mask_array
- # Ignore area outside of the region.
- out_label = lab_holes[0, 0]
- # Set size to be just larger than the region. Thus it can never be
- # deleted.
- holes[np.where(lab_holes == out_label)] = reg_area + 1.
- # Sum up the regions and find holes smaller than the threshold.
- sums = nd.sum(holes, lab_holes, range(1, n_holes + 1))
- if pixel_flag: # Use number of pixels
- delete_holes = np.where(sums < rel_size)[0]
- else: # Use relative size of holes.
- delete_holes = np.where(sums / reg_area < rel_size)[0]
- # Return if there is nothing to delete.
- if delete_holes == []:
- return mask_array
- # Add one to take into account 0 in list if object label 1.
- delete_holes += 1
- for label in delete_holes:
- mask_array[np.where(lab_holes == label)] = 1
- return mask_array
- def isolateregions(binary_array, size_threshold=0, pad_size=5,
- fill_hole=False, rel_size=0.1, morph_smooth=False):
- '''
- Labels regions in a boolean array and returns individual arrays for each
- region. Regions below a threshold can optionlly be removed. Small holes
- may also be filled in.
- Parameters
- ----------
- binary_array : numpy.ndarray
- A binary array of regions.
- size_threshold : int, optional
- Sets the pixel size on the size of regions.
- pad_size : int, optional
- Padding to be added to the individual arrays.
- fill_hole : int, optional
- Enables hole filling.
- rel_size : float or int, optional
- If < 1.0, sets the minimum size a hole must be relative to the area
- of the mask. Otherwise, this is the maximum number of pixels the hole
- must have to be deleted.
- morph_smooth : bool, optional
- Morphologically smooth the image using a binar opening and closing.
- Returns
- -------
- output_arrays : list
- Regions separated into individual arrays.
- num : int
- Number of filaments
- corners : list
- Contains the indices where each skeleton array was taken from
- the original.
- '''
- output_arrays = []
- corners = []
- # Label skeletons
- labels, num = nd.label(binary_array, eight_con())
- # Remove skeletons which have fewer pixels than the threshold.
- if size_threshold != 0:
- sums = nd.sum(binary_array, labels, range(1, num + 1))
- remove_fils = np.where(sums <= size_threshold)[0]
- for lab in remove_fils:
- binary_array[np.where(labels == lab + 1)] = 0
- # Relabel after deleting short skeletons.
- labels, num = nd.label(binary_array, eight_con())
- # Split each skeleton into its own array.
- for n in range(1, num + 1):
- x, y = np.where(labels == n)
- # Make an array shaped to the skeletons size and padded on each edge
- # the +1 is because, e.g., range(0, 5) only has 5 elements, but the
- # indices we're using are range(0, 6)
- shapes = (x.max() - x.min() + 2 * pad_size,
- y.max() - y.min() + 2 * pad_size)
- eachfil = np.zeros(shapes)
- eachfil[x - x.min() + pad_size, y - y.min() + pad_size] = 1
- # Fill in small holes
- if fill_hole:
- eachfil = _fix_small_holes(eachfil, rel_size=rel_size)
- if morph_smooth:
- eachfil = nd.binary_opening(eachfil, np.ones((3, 3)))
- eachfil = nd.binary_closing(eachfil, np.ones((3, 3)))
- output_arrays.append(eachfil)
- # Keep the coordinates from the original image
- lower = (x.min() - pad_size, y.min() - pad_size)
- upper = (x.max() + pad_size + 1, y.max() + pad_size + 1)
- corners.append([lower, upper])
- return output_arrays, num, corners
- def shifter(l, n):
- return l[n:] + l[:n]
- def distance(x, x1, y, y1):
- return np.sqrt((x - x1) ** 2.0 + (y - y1) ** 2.0)
- def find_filpix(branches, labelfil, final=True):
- '''
- Identifies the types of pixels in the given skeletons. Identification is
- based on the connectivity of the pixel.
- Parameters
- ----------
- branches : list
- Contains the number of branches in each skeleton.
- labelfil : list
- Contains the arrays of each skeleton.
- final : bool, optional
- If true, corner points, intersections, and body points are all
- labeled as a body point for use when the skeletons have already
- been cleaned.
- Returns
- -------
- fila_pts : list
- All points on the body of each skeleton.
- inters : list
- All points associated with an intersection in each skeleton.
- labelfil : list
- Contains the arrays of each skeleton where all intersections
- have been removed.
- endpts_return : list
- The end points of each branch of each skeleton.
- '''
- initslices = []
- initlist = []
- shiftlist = []
- sublist = []
- endpts = []
- blockpts = []
- bodypts = []
- slices = []
- vallist = []
- shiftvallist = []
- cornerpts = []
- subvallist = []
- subslist = []
- pix = []
- filpix = []
- intertemps = []
- fila_pts = []
- inters = []
- repeat = []
- temp_group = []
- all_pts = []
- pairs = []
- endpts_return = []
- for k in range(1, branches + 1):
- x, y = np.where(labelfil == k)
- # pixel_slices = np.empty((len(x)+1,8))
- for i in range(len(x)):
- if x[i] < labelfil.shape[0] - 1 and y[i] < labelfil.shape[1] - 1:
- pix.append((x[i], y[i]))
- initslices.append(np.array([[labelfil[x[i] - 1, y[i] + 1],
- labelfil[x[i], y[i] + 1],
- labelfil[x[i] + 1, y[i] + 1]],
- [labelfil[x[i] - 1, y[i]], 0,
- labelfil[x[i] + 1, y[i]]],
- [labelfil[x[i] - 1, y[i] - 1],
- labelfil[x[i], y[i] - 1],
- labelfil[x[i] + 1, y[i] - 1]]]))
- filpix.append(pix)
- slices.append(initslices)
- initslices = []
- pix = []
- for i in range(len(slices)):
- for k in range(len(slices[i])):
- initlist.append([slices[i][k][0, 0],
- slices[i][k][0, 1],
- slices[i][k][0, 2],
- slices[i][k][1, 2],
- slices[i][k][2, 2],
- slices[i][k][2, 1],
- slices[i][k][2, 0],
- slices[i][k][1, 0]])
- vallist.append(initlist)
- initlist = []
- for i in range(len(slices)):
- for k in range(len(slices[i])):
- shiftlist.append(shifter(vallist[i][k], 1))
- shiftvallist.append(shiftlist)
- shiftlist = []
- for k in range(len(slices)):
- for i in range(len(vallist[k])):
- for j in range(8):
- sublist.append(
- int(vallist[k][i][j]) - int(shiftvallist[k][i][j]))
- subslist.append(sublist)
- sublist = []
- subvallist.append(subslist)
- subslist = []
- # x represents the subtracted list (step-ups) and y is the values of the
- # surrounding pixels. The categories of pixels are ENDPTS (x<=1),
- # BODYPTS (x=2,y=2),CORNERPTS (x=2,y=3),BLOCKPTS (x=3,y>=4), and
- # INTERPTS (x>=3).
- # A cornerpt is [*,0,0] (*s) associated with an intersection,
- # but their exclusion from
- # [1,*,0] the intersection keeps eight-connectivity, they are included
- # [0,1,0] intersections for this reason.
- # A blockpt is [1,0,1] They are typically found in a group of four,
- # where all four
- # [0,*,*] constitute a single intersection.
- # [1,*,*]
- # The "final" designation is used when finding the final branch lengths.
- # At this point, blockpts and cornerpts should be eliminated.
- for k in range(branches):
- for l in range(len(filpix[k])):
- x = [j for j, y in enumerate(subvallist[k][l]) if y == k + 1]
- y = [j for j, z in enumerate(vallist[k][l]) if z == k + 1]
- if len(x) <= 1:
- endpts.append(filpix[k][l])
- endpts_return.append(filpix[k][l])
- elif len(x) == 2:
- if final:
- bodypts.append(filpix[k][l])
- else:
- if len(y) == 2:
- bodypts.append(filpix[k][l])
- elif len(y) == 3:
- cornerpts.append(filpix[k][l])
- elif len(y) >= 4:
- blockpts.append(filpix[k][l])
- elif len(x) >= 3:
- intertemps.append(filpix[k][l])
- endpts = list(set(endpts))
- bodypts = list(set(bodypts))
- dups = set(endpts) & set(bodypts)
- if len(dups) > 0:
- for i in dups:
- bodypts.remove(i)
- # Cornerpts without a partner diagonally attached can be included as a
- # bodypt.
- if len(cornerpts) > 0:
- deleted_cornerpts = []
- for i, j in zip(cornerpts, cornerpts):
- if i != j:
- if distance(i[0], j[0], i[1], j[1]) == np.sqrt(2.0):
- proximity = [(i[0], i[1] - 1),
- (i[0], i[1] + 1),
- (i[0] - 1, i[1]),
- (i[0] + 1, i[1]),
- (i[0] - 1, i[1] + 1),
- (i[0] + 1, i[1] + 1),
- (i[0] - 1, i[1] - 1),
- (i[0] + 1, i[1] - 1)]
- match = set(intertemps) & set(proximity)
- if len(match) == 1:
- pairs.append([i, j])
- deleted_cornerpts.append(i)
- deleted_cornerpts.append(j)
- cornerpts = list(set(cornerpts).difference(set(deleted_cornerpts)))
- if len(cornerpts) > 0:
- for l in cornerpts:
- proximity = [(l[0], l[1] - 1),
- (l[0], l[1] + 1),
- (l[0] - 1, l[1]),
- (l[0] + 1, l[1]),
- (l[0] - 1, l[1] + 1),
- (l[0] + 1, l[1] + 1),
- (l[0] - 1, l[1] - 1),
- (l[0] + 1, l[1] - 1)]
- match = set(intertemps) & set(proximity)
- if len(match) == 1:
- intertemps.append(l)
- fila_pts.append(endpts + bodypts)
- else:
- fila_pts.append(endpts + bodypts + [l])
- # cornerpts.remove(l)
- else:
- fila_pts.append(endpts + bodypts)
- # Reset lists
- cornerpts = []
- endpts = []
- bodypts = []
- if len(pairs) > 0:
- for i in range(len(pairs)):
- for j in pairs[i]:
- all_pts.append(j)
- if len(blockpts) > 0:
- for i in blockpts:
- all_pts.append(i)
- if len(intertemps) > 0:
- for i in intertemps:
- all_pts.append(i)
- # Pairs of cornerpts, blockpts, and interpts are combined into an
- # array. If there is eight connectivity between them, they are labelled
- # as a single intersection.
- arr = np.zeros((labelfil.shape))
- for z in all_pts:
- labelfil[z[0], z[1]] = 0
- arr[z[0], z[1]] = 1
- lab, nums = nd.label(arr, eight_con())
- for k in range(1, nums + 1):
- objs_pix = np.where(lab == k)
- for l in range(len(objs_pix[0])):
- temp_group.append((objs_pix[0][l], objs_pix[1][l]))
- inters.append(temp_group)
- temp_group = []
- for i in range(len(inters) - 1):
- if inters[i] == inters[i + 1]:
- repeat.append(inters[i])
- for i in repeat:
- inters.remove(i)
- return fila_pts, inters, labelfil, endpts_return
- def pix_identify(isolatefilarr, num):
- '''
- This function is essentially a wrapper on find_filpix. It returns the
- outputs of find_filpix in the form that are used during the analysis.
- Parameters
- ----------
- isolatefilarr : list
- Contains individual arrays of each skeleton.
- num : int
- The number of skeletons.
- Returns
- -------
- interpts : list
- Contains lists of all intersections points in each skeleton.
- hubs : list
- Contains the number of intersections in each filament. This is
- useful for identifying those with no intersections as their analysis
- is straight-forward.
- ends : list
- Contains the positions of all end points in each skeleton.
- filbranches : list
- Contains the number of branches in each skeleton.
- labelisofil : list
- Contains individual arrays for each skeleton where the
- branches are labeled and the intersections have been removed.
- '''
- interpts = []
- hubs = []
- ends = []
- filbranches = []
- labelisofil = []
- for n in range(num):
- funcreturn = find_filpix(1, isolatefilarr[n], final=False)
- interpts.append(funcreturn[1])
- hubs.append(len(funcreturn[1]))
- isolatefilarr.pop(n)
- isolatefilarr.insert(n, funcreturn[2])
- ends.append(funcreturn[3])
- label_branch, num_branch = nd.label(isolatefilarr[n], eight_con())
- filbranches.append(num_branch)
- labelisofil.append(label_branch)
- return interpts, hubs, ends, filbranches, labelisofil
- def skeleton_length(skeleton):
- '''
- Length finding via morphological operators. We use the differences in
- connectivity between 4 and 8-connected to split regions. Connections
- between 4 and 8-connected regions are found using a series of hit-miss
- operators.
- The inputted skeleton MUST have no intersections otherwise the returned
- length will not be correct!
- Parameters
- ----------
- skeleton : numpy.ndarray
- Array containing the skeleton.
- Returns
- -------
- length : float
- Length of the skeleton.
- '''
- # 4-connected labels
- four_labels = nd.label(skeleton)[0]
- four_sizes = nd.sum(skeleton, four_labels, range(np.max(four_labels) + 1))
- # Lengths is the number of pixels minus number of objects with more
- # than 1 pixel.
- four_length = np.sum(
- four_sizes[four_sizes > 1]) - len(four_sizes[four_sizes > 1])
- # Find pixels which a 4-connected and subtract them off the skeleton
- four_objects = np.where(four_sizes > 1)[0]
- skel_copy = copy.copy(skeleton)
- for val in four_objects:
- skel_copy[np.where(four_labels == val)] = 0
- # Remaining pixels are only 8-connected
- # Lengths is same as before, multiplied by sqrt(2)
- eight_labels = nd.label(skel_copy, eight_con())[0]
- eight_sizes = nd.sum(
- skel_copy, eight_labels, range(np.max(eight_labels) + 1))
- eight_length = (
- np.sum(eight_sizes) - np.max(eight_labels)) * np.sqrt(2)
- # If there are no 4-connected pixels, we don't need the hit-miss portion.
- if four_length == 0.0:
- conn_length = 0.0
- else:
- store = np.zeros(skeleton.shape)
- # Loop through the 4 rotations of the structuring elements
- for k in range(0, 4):
- hm1 = nd.binary_hit_or_miss(
- skeleton, structure1=np.rot90(struct1, k=k))
- store += hm1
- hm2 = nd.binary_hit_or_miss(
- skeleton, structure1=np.rot90(struct2, k=k))
- store += hm2
- hm_check3 = nd.binary_hit_or_miss(
- skeleton, structure1=np.rot90(check3, k=k))
- store -= hm_check3
- if k <= 1:
- hm_check1 = nd.binary_hit_or_miss(
- skeleton, structure1=np.rot90(check1, k=k))
- store -= hm_check1
- hm_check2 = nd.binary_hit_or_miss(
- skeleton, structure1=np.rot90(check2, k=k))
- store -= hm_check2
- conn_length = np.sqrt(2) * \
- np.sum(np.sum(store, axis=1), axis=0) # hits
- return conn_length + eight_length + four_length
- def init_lengths(labelisofil, filbranches, array_offsets, img):
- '''
- This is a wrapper on fil_length for running on the branches of the
- skeletons.
- Parameters
- ----------
- labelisofil : list
- Contains individual arrays for each skeleton where the
- branches are labeled and the intersections have been removed.
- filbranches : list
- Contains the number of branches in each skeleton.
- array_offsets : List
- The indices of where each filament array fits in the
- original image.
- img : numpy.ndarray
- Original image.
- Returns
- -------
- branch_properties: dict
- Contains the lengths and intensities of the branches.
- Keys are *length* and *intensity*.
- '''
- num = len(labelisofil)
- # Initialize Lists
- lengths = []
- av_branch_intensity = []
- for n in range(num):
- leng = []
- av_intensity = []
- label_copy = copy.copy(labelisofil[n])
- objects = nd.find_objects(label_copy)
- for obj in objects:
- # Scale the branch array to the branch size
- branch_array = label_copy[obj]
- # Find the skeleton points and set those to 1
- branch_pts = np.where(branch_array > 0)
- branch_array[branch_pts] = 1
- # Now find the length on the branch
- branch_length = skeleton_length(branch_array)
- if branch_length == 0.0:
- # For use in longest path algorithm, will be set to zero for
- # final analysis
- branch_length = 0.5
- leng.append(branch_length)
- # Now let's find the average intensity along each branch
- # Get the offsets from the original array and
- # add on the offset the branch array introduces.
- x_offset = obj[0].start + array_offsets[n][0][0]
- y_offset = obj[1].start + array_offsets[n][0][1]
- av_intensity.append(
- nanmean([img[x + x_offset, y + y_offset]
- for x, y in zip(*branch_pts)
- if np.isfinite(img[x + x_offset, y + y_offset]) and
- not img[x + x_offset, y + y_offset] < 0.0]))
- lengths.append(leng)
- av_branch_intensity.append(av_intensity)
- branch_properties = {
- "length": lengths, "intensity": av_branch_intensity}
- return branch_properties
- def product_gen(n):
- for r in itertools.count(1):
- for i in itertools.product(n, repeat=r):
- yield "".join(i)
- def pre_graph(labelisofil, branch_properties, interpts, ends):
- '''
- This function converts the skeletons into a graph object compatible with
- networkx. The graphs have nodes corresponding to end and
- intersection points and edges defining the connectivity as the branches
- with the weights set to the branch length.
- Parameters
- ----------
- labelisofil : list
- Contains individual arrays for each skeleton where the
- branches are labeled and the intersections have been removed.
- branch_properties : dict
- Contains the lengths and intensities of all branches.
- interpts : list
- Contains the pixels which belong to each intersection.
- ends : list
- Contains the end pixels for each skeleton.
- Returns
- -------
- end_nodes : list
- Contains the nodes corresponding to end points.
- inter_nodes : list
- Contains the nodes corresponding to intersection points.
- edge_list : list
- Contains the connectivity information for the graphs.
- nodes : list
- A complete list of all of the nodes. The other nodes lists have
- been separated as they are labeled differently.
- '''
- num = len(labelisofil)
- end_nodes = []
- inter_nodes = []
- nodes = []
- edge_list = []
- def path_weighting(idx, length, intensity, w=0.5):
- '''
- Relative weighting for the shortest path algorithm using the branch
- lengths and the average intensity along the branch.
- '''
- if w > 1.0 or w < 0.0:
- raise ValueError(
- "Relative weighting w must be between 0.0 and 1.0.")
- return (1 - w) * (length[idx] / np.sum(length)) + \
- w * (intensity[idx] / np.sum(intensity))
- lengths = branch_properties["length"]
- branch_intensity = branch_properties["intensity"]
- for n in range(num):
- inter_nodes_temp = []
- # Create end_nodes, which contains lengths, and nodes, which we will
- # later add in the intersections
- end_nodes.append([(labelisofil[n][i[0], i[1]],
- path_weighting(int(labelisofil[n][i[0], i[1]] - 1),
- lengths[n],
- branch_intensity[n]),
- lengths[n][int(labelisofil[n][i[0], i[1]] - 1)],
- branch_intensity[n][int(labelisofil[n][i[0], i[1]] - 1)])
- for i in ends[n]])
- nodes.append([labelisofil[n][i[0], i[1]] for i in ends[n]])
- # Intersection nodes are given by the intersections points of the filament.
- # They are labeled alphabetically (if len(interpts[n])>26,
- # subsequent labels are AA,AB,...).
- # The branch labels attached to each intersection are included for future
- # use.
- for intersec in interpts[n]:
- uniqs = []
- for i in intersec: # Intersections can contain multiple pixels
- int_arr = np.array([[labelisofil[n][i[0] - 1, i[1] + 1],
- labelisofil[n][i[0], i[1] + 1],
- labelisofil[n][i[0] + 1, i[1] + 1]],
- [labelisofil[n][i[0] - 1, i[1]], 0,
- labelisofil[n][i[0] + 1, i[1]]],
- [labelisofil[n][i[0] - 1, i[1] - 1],
- labelisofil[n][i[0], i[1] - 1],
- labelisofil[n][i[0] + 1, i[1] - 1]]]).astype(int)
- for x in np.unique(int_arr[np.nonzero(int_arr)]):
- uniqs.append((x,
- path_weighting(x - 1, lengths[n],
- branch_intensity[n]),
- lengths[n][x - 1],
- branch_intensity[n][x - 1]))
- # Intersections with multiple pixels can give the same branches.
- # Get rid of duplicates
- uniqs = list(set(uniqs))
- inter_nodes_temp.append(uniqs)
- # Add the intersection labels. Also append those to nodes
- inter_nodes.append(
- zip(product_gen(string.ascii_uppercase), inter_nodes_temp))
- for alpha, node in zip(product_gen(string.ascii_uppercase),
- inter_nodes_temp):
- nodes[n].append(alpha)
- # Edges are created from the information contained in the nodes.
- edge_list_temp = []
- for i, inters in enumerate(inter_nodes[n]):
- end_match = list(set(inters[1]) & set(end_nodes[n]))
- for k in end_match:
- edge_list_temp.append((inters[0], k[0], k))
- for j, inters_2 in enumerate(inter_nodes[n]):
- if i != j:
- match = list(set(inters[1]) & set(inters_2[1]))
- new_edge = None
- if len(match) == 1:
- new_edge = (inters[0], inters_2[0], match[0])
- elif len(match) > 1:
- multi = [match[l][1] for l in range(len(match))]
- keep = multi.index(min(multi))
- new_edge = (inters[0], inters_2[0], match[keep])
- if new_edge is not None:
- if not (new_edge[1], new_edge[0], new_edge[2]) in edge_list_temp \
- and new_edge not in edge_list_temp:
- edge_list_temp.append(new_edge)
- # Remove duplicated edges between intersections
- edge_list.append(edge_list_temp)
- return edge_list, nodes
- def try_mkdir(name):
- '''
- Checks if a folder exists, and makes it if it doesn't
- '''
- if not os.path.isdir(os.path.join(os.getcwd(), name)):
- os.mkdir(os.path.join(os.getcwd(), name))
- def longest_path(edge_list, nodes, verbose=False,
- skeleton_arrays=None, save_png=False, save_name=None):
- '''
- Takes the output of pre_graph and runs the shortest path algorithm.
- Parameters
- ----------
- edge_list : list
- Contains the connectivity information for the graphs.
- nodes : list
- A complete list of all of the nodes. The other nodes lists have
- been separated as they are labeled differently.
- verbose : bool, optional
- If True, enables the plotting of the graph.
- skeleton_arrays : list, optional
- List of the skeleton arrays. Required when verbose=True.
- save_png : bool, optional
- Saves the plot made in verbose mode. Disabled by default.
- save_name : str, optional
- For use when ``save_png`` is enabled.
- **MUST be specified when ``save_png`` is enabled.**
- Returns
- -------
- max_path : list
- Contains the paths corresponding to the longest lengths for
- each skeleton.
- extremum : list
- Contains the starting and ending points of max_path
- '''
- num = len(nodes)
- # Initialize lists
- max_path = []
- extremum = []
- graphs = []
- for n in range(num):
- G = nx.Graph()
- G.add_nodes_from(nodes[n])
- for i in edge_list[n]:
- G.add_edge(i[0], i[1], weight=i[2][1])
- paths = nx.shortest_path_length(G, weight='weight')
- # Fix new API
- new_paths = dict()
- for path in paths:
- new_paths[path[0]] = path[1]
- values = []
- node_extrema = []
- for i in new_paths.keys():
- j = max(new_paths[i].items(), key=operator.itemgetter(1))
- node_extrema.append((j[0], i))
- values.append(j[1])
- start, finish = node_extrema[values.index(max(values))]
- extremum.append([start, finish])
- max_path.append(nx.shortest_path(G, start, finish))
- graphs.append(G)
- if verbose or save_png:
- if not skeleton_arrays:
- Warning("Must input skeleton arrays if verbose or save_png is"
- " enabled. No plots will be created.")
- elif save_png and save_name is None:
- Warning("Must give a save_name when save_png is enabled. No"
- " plots will be created.")
- else:
- # Check if skeleton_arrays is a list
- assert isinstance(skeleton_arrays, list)
- import matplotlib.pyplot as p
- if verbose:
- print("Filament: %s / %s" % (n + 1, num))
- p.subplot(1, 2, 1)
- p.imshow(skeleton_arrays[n], interpolation="nearest",
- origin="lower")
- p.subplot(1, 2, 2)
- elist = [(u, v) for (u, v, d) in G.edges(data=True)]
- pos = nx.spring_layout(G)
- nx.draw_networkx_nodes(G, pos, node_size=200)
- nx.draw_networkx_edges(G, pos, edgelist=elist, width=2)
- nx.draw_networkx_labels(
- G, pos, font_size=10, font_family='sans-serif')
- p.axis('off')
- if save_png:
- try_mkdir(save_name)
- p.savefig(os.path.join(save_name,
- save_name + "_longest_path_" + str(n) + ".png"))
- if verbose:
- p.show()
- p.clf()
- return max_path, extremum, graphs
- def prune_graph(G, nodes, edge_list, max_path, labelisofil, branch_properties,
- length_thresh, relintens_thresh=0.2):
- '''
- Function to remove unnecessary branches, while maintaining connectivity
- in the graph. Also updates edge_list, nodes, branch_lengths and
- filbranches.
- Parameters
- ----------
- G : list
- Contains the networkx Graph objects.
- nodes : list
- A complete list of all of the nodes. The other nodes lists have
- been separated as they are labeled differently.
- edge_list : list
- Contains the connectivity information for the graphs.
- max_path : list
- Contains the paths corresponding to the longest lengths for
- each skeleton.
- labelisofil : list
- Contains individual arrays for each skeleton where the
- branches are labeled and the intersections have been removed.
- branch_properties : dict
- Contains the lengths and intensities of all branches.
- length_thresh : int or float
- Minimum length a branch must be to be kept. Can be overridden if the
- branch is bright relative to the entire skeleton.
- relintens_thresh : float between 0 and 1, optional.
- Threshold for how bright the branch must be relative to the entire
- skeleton. Can be overridden by length.
- Returns
- -------
- labelisofil : list
- Updated from input.
- edge_list : list
- Updated from input.
- nodes : list
- Updated from input.
- branch_properties : dict
- Updated from input.
- '''
- num = len(labelisofil)
- for n in range(num):
- degree = G[n].degree()
- # To make compatible with new iterator API of networkx
- new_degree = dict()
- for d in degree:
- new_degree[d[0]] = d[1]
- single_connect = [key for key in new_degree.keys() if new_degree[key] == 1]
- delete_candidate = list(
- (set(nodes[n]) - set(max_path[n])) & set(single_connect))
- if not delete_candidate: # Nothing to delete!
- continue
- edge_candidates = [edge for edge in edge_list[n] if edge[
- 0] in delete_candidate or edge[1] in delete_candidate]
- intensities = [edge[2][3] for edge in edge_list[n]]
- for edge in edge_candidates:
- # In the odd case where a loop meets at the same intersection,
- # ensure that edge is kept.
- if isinstance(edge[0], str) & isinstance(edge[1], str):
- continue
- # If its too short and relatively not as intense, delete it
- length = edge[2][2]
- av_intensity = edge[2][3]
- if length < length_thresh \
- and (av_intensity / np.sum(intensities)) < relintens_thresh:
- edge_pts = np.where(labelisofil[n] == edge[2][0])
- labelisofil[n][edge_pts] = 0
- edge_list[n].remove(edge)
- nodes[n].remove(edge[1])
- branch_properties["length"][n].remove(length)
- branch_properties["intensity"][n].remove(av_intensity)
- branch_properties["number"][n] -= 1
- return labelisofil, edge_list, nodes, branch_properties
- def extremum_pts(labelisofil, extremum, ends):
- '''
- This function returns the the farthest extents of each filament. This
- is useful for determining how well the shortest path algorithm has worked.
- Parameters
- ----------
- labelisofil : list
- Contains individual arrays for each skeleton.
- extremum : list
- Contains the extents as determined by the shortest
- path algorithm.
- ends : list
- Contains the positions of each end point in eahch filament.
- Returns
- -------
- extrem_pts : list
- Contains the indices of the extremum points.
- '''
- num = len(labelisofil)
- extrem_pts = []
- for n in range(num):
- per_fil = []
- for i, j in ends[n]:
- if labelisofil[n][i, j] == extremum[n][0] or labelisofil[n][i, j] == extremum[n][1]:
- per_fil.append([i, j])
- extrem_pts.append(per_fil)
- return extrem_pts
- def main_length(max_path, edge_list, labelisofil, interpts, branch_lengths,
- img_scale, verbose=False, save_png=False, save_name=None):
- '''
- Wraps previous functionality together for all of the skeletons in the
- image. To find the overall length for each skeleton, intersections are
- added back in, and any extraneous pixels they bring with them are deleted.
- Parameters
- ----------
- max_path : list
- Contains the paths corresponding to the longest lengths for
- each skeleton.
- edge_list : list
- Contains the connectivity information for the graphs.
- labelisofil : list
- Contains individual arrays for each skeleton where the
- branches are labeled and the intersections have been removed.
- interpts : list
- Contains the pixels which belong to each intersection.
- branch_lengths : list
- Lengths of individual branches in each skeleton.
- img_scale : float
- Conversion from pixel to physical units.
- verbose : bool, optional
- Returns plots of the longest path skeletons.
- save_png : bool, optional
- Saves the plot made in verbose mode. Disabled by default.
- save_name : str, optional
- For use when ``save_png`` is enabled.
- **MUST be specified when ``save_png`` is enabled.**
- Returns
- -------
- main_lengths : list
- Lengths of the skeletons.
- longpath_arrays : list
- Arrays of the longest paths in the skeletons.
- '''
- main_lengths = []
- longpath_arrays = []
- for num, (path, edges, inters, skel_arr, lengths) in \
- enumerate(zip(max_path, edge_list, interpts, labelisofil,
- branch_lengths)):
- if len(path) == 1:
- main_lengths.append(lengths[0] * img_scale)
- skeleton = skel_arr # for viewing purposes when verbose
- else:
- skeleton = np.zeros(skel_arr.shape)
- # Add edges along longest path
- good_edge_list = [(path[i], path[i + 1])
- for i in range(len(path) - 1)]
- # Find the branches along the longest path.
- for i in good_edge_list:
- for j in edges:
- if (i[0] == j[0] and i[1] == j[1]) or \
- (i[0] == j[1] and i[1] == j[0]):
- label = j[2][0]
- skeleton[np.where(skel_arr == label)] = 1
- # Add intersections along longest path
- intersec_pts = []
- for label in path:
- try:
- label = int(label)
- except ValueError:
- pass
- if not isinstance(label, int):
- k = 1
- while zip(product_gen(string.ascii_uppercase),
- [1] * k)[-1][0] != label:
- k += 1
- intersec_pts.extend(inters[k - 1])
- skeleton[zip(*inters[k - 1])] = 2
- # Remove unnecessary pixels
- count = 0
- while True:
- for pt in intersec_pts:
- # If we have already eliminated the point, continue
- if skeleton[pt] == 0:
- continue
- skeleton[pt] = 0
- lab_try, n = nd.label(skeleton, eight_con())
- if n > 1:
- skeleton[pt] = 1
- else:
- count += 1
- if count == 0:
- break
- count = 0
- main_lengths.append(skeleton_length(skeleton) * img_scale)
- longpath_arrays.append(skeleton.astype(int))
- if verbose or save_png:
- if save_png and save_name is None:
- Warning("Must give a save_name when save_png is enabled. No"
- " plots will be created.")
- import matplotlib.pyplot as p
- if verbose:
- print("Filament: %s / %s" % (num + 1, len(labelisofil)))
- p.subplot(121)
- p.imshow(skeleton, origin='lower', interpolation="nearest")
- p.subplot(122)
- p.imshow(labelisofil[num], origin='lower',
- interpolation="nearest")
- if save_png:
- try_mkdir(save_name)
- p.savefig(os.path.join(save_name,
- save_name + "_main_length_" + str(num) + ".png"))
- if verbose:
- p.show()
- p.clf()
- return main_lengths, longpath_arrays
- def find_extran(branches, labelfil):
- '''
- Identify pixels that are not necessary to keep the connectivity of the
- skeleton. It uses the same labeling process as find_filpix. Extraneous
- pixels tend to be those from former intersections, whose attached branch
- was eliminated in the cleaning process.
- Parameters
- ----------
- branches : list
- Contains the number of branches in each skeleton.
- labelfil : list
- Contains arrays of the labeled versions of each skeleton.
- Returns
- -------
- labelfil : list
- Contains the updated labeled arrays with extraneous pieces
- removed.
- '''
- initslices = []
- initlist = []
- shiftlist = []
- sublist = []
- extran = []
- slices = []
- vallist = []
- shiftvallist = []
- subvallist = []
- subslist = []
- pix = []
- filpix = []
- for k in range(1, branches + 1):
- x, y = np.where(labelfil == k)
- for i in range(len(x)):
- if x[i] < labelfil.shape[0] - 1 and y[i] < labelfil.shape[1] - 1:
- pix.append((x[i], y[i]))
- initslices.append(np.array([[labelfil[x[i] - 1, y[i] + 1],
- labelfil[x[i], y[i] + 1],
- labelfil[x[i] + 1, y[i] + 1]],
- [labelfil[x[i] - 1, y[i]], 0,
- labelfil[x[i] + 1, y[i]]],
- [labelfil[x[i] - 1, y[i] - 1],
- labelfil[x[i], y[i] - 1],
- labelfil[x[i] + 1, y[i] - 1]]]))
- filpix.append(pix)
- slices.append(initslices)
- initslices = []
- pix = []
- for i in range(len(slices)):
- for k in range(len(slices[i])):
- initlist.append([slices[i][k][0, 0],
- slices[i][k][0, 1],
- slices[i][k][0, 2],
- slices[i][k][1, 2],
- slices[i][k][2, 2],
- slices[i][k][2, 1],
- slices[i][k][2, 0],
- slices[i][k][1, 0]])
- vallist.append(initlist)
- initlist = []
- for i in range(len(slices)):
- for k in range(len(slices[i])):
- shiftlist.append(shifter(vallist[i][k], 1))
- shiftvallist.append(shiftlist)
- shiftlist = []
- for k in range(len(slices)):
- for i in range(len(vallist[k])):
- for j in range(8):
- sublist.append(
- int(vallist[k][i][j]) - int(shiftvallist[k][i][j]))
- subslist.append(sublist)
- sublist = []
- subvallist.append(subslist)
- subslist = []
- for k in range(len(slices)):
- for l in range(len(filpix[k])):
- x = [j for j, y in enumerate(subvallist[k][l]) if y == k + 1]
- y = [j for j, z in enumerate(vallist[k][l]) if z == k + 1]
- if len(x) == 0:
- labelfil[filpix[k][l][0], filpix[k][l][1]] = 0
- if len(x) == 1:
- if len(y) >= 2:
- extran.append(filpix[k][l])
- labelfil[filpix[k][l][0], filpix[k][l][1]] = 0
- # if len(extran) >= 2:
- # for i in extran:
- # for j in extran:
- # if i != j:
- # if distance(i[0], j[0], i[1], j[1]) == np.sqrt(2.0):
- # proximity = [(i[0], i[1] - 1),
- # (i[0], i[1] + 1),
- # (i[0] - 1, i[1]),
- # (i[0] + 1, i[1]),
- # (i[0] - 1, i[1] + 1),
- # (i[0] + 1, i[1] + 1),
- # (i[0] - 1, i[1] - 1),
- # (i[0] + 1, i[1] - 1)]
- # match = set(filpix[k]) & set(proximity)
- # if len(match) > 0:
- # for z in match:
- # labelfil[z[0], z[1]] = 0
- return labelfil
- def in_ipynb():
- try:
- cfg = get_ipython().config
- if cfg['IPKernelApp']['parent_appname'] == 'ipython-notebook':
- return True
- else:
- return False
- except NameError:
- return False
- def make_final_skeletons(labelisofil, inters, verbose=False, save_png=False,
- save_name=None):
- '''
- Creates the final skeletons outputted by the algorithm.
- Parameters
- ----------
- labelisofil : list
- List of labeled skeletons.
- inters : list
- Positions of the intersections in each skeleton.
- verbose : bool, optional
- Enables plotting of the final skeleton.
- save_png : bool, optional
- Saves the plot made in verbose mode. Disabled by default.
- save_name : str, optional
- For use when ``save_png`` is enabled.
- **MUST be specified when ``save_png`` is enabled.**
- Returns
- -------
- filament_arrays : list
- List of the final skeletons.
- '''
- filament_arrays = []
- for n, (skel_array, intersec) in enumerate(zip(labelisofil, inters)):
- copy_array = np.zeros(skel_array.shape, dtype=int)
- for inter in intersec:
- for pts in inter:
- x, y = pts
- copy_array[x, y] = 1
- copy_array[np.where(skel_array >= 1)] = 1
- cleaned_array = find_extran(1, copy_array)
- filament_arrays.append(cleaned_array)
- if verbose or save_png:
- if save_png and save_name is None:
- Warning("Must give a save_name when save_png is enabled. No"
- " plots will be created.")
- plt.clf()
- plt.imshow(cleaned_array, origin='lower', interpolation='nearest')
- if save_png:
- try_mkdir(save_name)
- plt.savefig(os.path.join(save_name,
- save_name+"_final_skeleton_"+str(n)+".png"))
- if verbose:
- plt.show()
- if in_ipynb():
- plt.clf()
- return filament_arrays
- def recombine_skeletons(skeletons, offsets, orig_size, pad_size,
- verbose=False):
- '''
- Takes a list of skeleton arrays and combines them back into
- the original array.
- Parameters
- ----------
- skeletons : list
- Arrays of each skeleton.
- offsets : list
- Coordinates where the skeleton arrays have been sliced from the
- image.
- orig_size : tuple
- Size of the image.
- pad_size : int
- Size of the array padding.
- verbose : bool, optional
- Enables printing when a skeleton array needs to be resized to fit
- into the image.
- Returns
- -------
- master_array : numpy.ndarray
- Contains all skeletons placed in their original positions in the image.
- '''
- num = len(skeletons)
- master_array = np.zeros(orig_size)
- for n in range(num):
- x_off, y_off = offsets[n][0] # These are the coordinates of the bottom
- # left in the master array.
- x_top, y_top = offsets[n][1]
- # Now check if padding will put the array outside of the original array
- # size
- excess_x_top = x_top - orig_size[0]
- excess_y_top = y_top - orig_size[1]
- copy_skeleton = copy.copy(skeletons[n])
- size_change_flag = False
- if excess_x_top > 0:
- copy_skeleton = copy_skeleton[:-excess_x_top, :]
- size_change_flag = True
- if excess_y_top > 0:
- copy_skeleton = copy_skeleton[:, :-excess_y_top]
- size_change_flag = True
- if x_off < 0:
- copy_skeleton = copy_skeleton[-x_off:, :]
- x_off = 0
- size_change_flag = True
- if y_off < 0:
- copy_skeleton = copy_skeleton[:, -y_off:]
- y_off = 0
- size_change_flag = True
- if verbose & size_change_flag:
- print("REDUCED FILAMENT %s/%s TO FIT IN ORIGINAL ARRAY" % (n, num))
- x, y = np.where(copy_skeleton >= 1)
- for i in range(len(x)):
- master_array[x[i] + x_off, y[i] + y_off] = 1
- return master_array
- if __name__ == "__main__":
- # 2D numpy array with image
- image = None
- # rms of the image
- rms = None
- data = image.copy()
- from scipy.ndimage.filters import gaussian_filter
- data = gaussian_filter(data, 5)
- mask = data < 3. * rms
- data[mask] = 0
- data[~mask] = 1
- skel, distance = medial_axis(data, return_distance=True)
- dist_on_skel = distance * skel
- # Plot area and skeleton
- fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4), sharex=True, sharey=True,
- subplot_kw={'adjustable': 'box-forced'})
- ax1.imshow(data, cmap=plt.cm.gray, interpolation='nearest')
- ax1.axis('off')
- ax2.imshow(dist_on_skel, cmap=plt.cm.spectral, interpolation='nearest')
- ax2.contour(data, [0.5], colors='w')
- ax2.axis('off')
- fig.tight_layout()
- plt.show()
- fig.savefig('skeleton_orig.png')
- plt.close()
- isolated_filaments, num, offsets = isolateregions(skel)
- interpts, hubs, ends, filbranches, labeled_fil_arrays =\
- pix_identify(isolated_filaments, num)
- branch_properties = init_lengths(labeled_fil_arrays, filbranches, offsets, data)
- branch_properties["number"] = filbranches
- edge_list, nodes = pre_graph(labeled_fil_arrays, branch_properties, interpts,
- ends)
- max_path, extremum, G = longest_path(edge_list, nodes, verbose=True,
- save_png=False,
- skeleton_arrays=labeled_fil_arrays)
- updated_lists = prune_graph(G, nodes, edge_list, max_path, labeled_fil_arrays,
- branch_properties, length_thresh=20,
- relintens_thresh=0.1)
- labeled_fil_arrays, edge_list, nodes, branch_properties = updated_lists
- filament_extents = extremum_pts(labeled_fil_arrays, extremum, ends)
- length_output = main_length(max_path, edge_list, labeled_fil_arrays, interpts,
- branch_properties["length"], 1, verbose=True)
- filament_arrays = {}
- lengths, filament_arrays["long path"] = length_output
- lengths = np.asarray(lengths)
- filament_arrays["final"] = make_final_skeletons(labeled_fil_arrays, interpts,
- verbose=True)
- skeleton = recombine_skeletons(filament_arrays["final"], offsets, data.shape,
- 0, verbose=True)
- skeleton_longpath = recombine_skeletons(filament_arrays["long path"], offsets,
- data.shape, 1)
- skeleton_longpath_dist = skeleton_longpath * distance
- # Plot area and skeleton
- fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4), sharex=True, sharey=True,
- subplot_kw={'adjustable': 'box-forced'})
- ax1.imshow(data, cmap=plt.cm.gray, interpolation='nearest')
- ax1.axis('off')
- ax2.imshow(skeleton_longpath_dist, cmap=plt.cm.spectral,
- interpolation='nearest')
- ax2.contour(data, [0.5], colors='w')
- ax2.axis('off')
- fig.tight_layout()
- plt.savefig('skeleton.png')
- plt.show()
- plt.close()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement