diff --git a/workflows/relperm/dist_func_utils.py b/workflows/relperm/dist_func_utils.py
new file mode 100644
index 00000000..9d8ad113
--- /dev/null
+++ b/workflows/relperm/dist_func_utils.py
@@ -0,0 +1,270 @@
+#!/usr/bin/env python
+
+import numpy as np
+from scipy import ndimage
+from scipy import spatial
+from skimage.morphology import medial_axis, skeletonize_3d
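+
+# Utility functions for processing signed-distance maps (e.g. 'SignDist.xxxxx'
+# files) in the relative-permeability workflow: medial-axis based detection of
+# skeleton intersection points, clustering of nearby candidate points, and
+# detection of local maxima/minima of the distance map.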
+
+def detect_intersection_point_with_cluster_filter(input_data, lx, ly, lz=0):
+    # Input is a signed-distance array read from a file such as 'SignDist.xxxxx'
+    if lz > 0:  # i.e. a 3D input
+        # Calculate the medial axis of the segmented image
+        dist = input_data.copy()
+        dist.shape = (lz, ly, lx)
+        # skeletonize_3d may return a uint8 (0/255) array; cast to bool
+        skel = skeletonize_3d(dist > 0.0).astype(bool)
+        dist_on_skel = skel*dist
+        [z_skel, y_skel, x_skel] = np.where(skel)
+
+        # Building two search trees is potentially beneficial for a 3D image
+        grid = np.indices(dist.shape)
+        grid_points = np.column_stack((grid[0].ravel(), grid[1].ravel(), grid[2].ravel()))
+        tree_grid = spatial.cKDTree(grid_points)
+        points_for_search = np.column_stack((z_skel, y_skel, x_skel))
+        tree_points_for_search = spatial.cKDTree(points_for_search)
+        neighbor_all = tree_points_for_search.query_ball_tree(tree_grid, np.sqrt(3.0))
+
+        idx_glb_table = np.ravel_multi_index([z_skel, y_skel, x_skel], skel.shape)
+        idx_glb_candidate = np.empty((0,), dtype=int)
+        for k, idx_glb_neighbor in enumerate(neighbor_all):
+            if k % 4000 == 0:
+                print('Search for intersection points: '+str(k)+'/'+str(len(neighbor_all)-1))
+            #end if
+            mask = np.in1d(idx_glb_neighbor, idx_glb_table, assume_unique=True)
+            idx_glb_candidate = np.hstack((idx_glb_candidate, mask.sum()))
+        #end for
+
+        # Statistics: number of neighbors for each voxel on the medial axis
+        values, counts = np.unique(idx_glb_candidate, return_counts=True)
+        connectivity_stats = np.column_stack((values, counts))
+        # 'connectivity_stats' has the following format:
+        # number_of_neighbors_for_each_voxel :: corresponding_number_of_voxels
+        # array([[  1,  41],
+        #        [  2, 143],
+        #        [  3, 185],
+        #        [  4,   5]])
+        #
+        # If a voxel is identified as an intersection point, the number of its
+        # neighboring voxels should be greater than 'benchmark'
+        benchmark = connectivity_stats[np.argmax(connectivity_stats[:, 1]), 0]
+        # Update the medial axis and 'dist_on_skel'
+        skel[:] = False
+        skel[np.unravel_index(idx_glb_table[idx_glb_candidate > benchmark], skel.shape)] = True
+    else:  # i.e. a 2D input
+        # Calculate the medial axis of the segmented image
+        dist = input_data.copy()
+        dist.shape = (ly, lx)
+        skel = medial_axis(dist > 0.0)
+        dist_on_skel = skel*dist
+        [y_skel, x_skel] = np.where(skel)
+
+        # The same two-tree search as in the 3D branch
+        grid = np.indices(dist.shape)
+        grid_points = np.column_stack((grid[0].ravel(), grid[1].ravel()))
+        tree_grid = spatial.cKDTree(grid_points)
+        points_for_search = np.column_stack((y_skel, x_skel))
+        tree_points_for_search = spatial.cKDTree(points_for_search)
+        neighbor_all = tree_points_for_search.query_ball_tree(tree_grid, np.sqrt(2.0))
+
+        idx_glb_table = np.ravel_multi_index([y_skel, x_skel], skel.shape)
+        idx_glb_candidate = np.empty((0,), dtype=int)
+        for k, idx_glb_neighbor in enumerate(neighbor_all):
+            if k % 4000 == 0:
+                print('Search for intersection points: '+str(k)+'/'+str(len(neighbor_all)-1))
+            #end if
+            mask = np.in1d(idx_glb_neighbor, idx_glb_table, assume_unique=True)
+            idx_glb_candidate = np.hstack((idx_glb_candidate, mask.sum()))
+        #end for
+
+        # Statistics: number of neighbors for each pixel on the medial axis
+        values, counts = np.unique(idx_glb_candidate, return_counts=True)
+        connectivity_stats = np.column_stack((values, counts))
+        # 'connectivity_stats' has the following format:
+        # number_of_neighbors_for_each_pixel :: corresponding_number_of_pixels
+        # array([[  1,  41],
+        #        [  2, 143],
+        #        [  3, 185],
+        #        [  4,   5]])
+        #
+        # If a pixel is identified as an intersection point, the number of its
+        # neighboring pixels should be greater than 'benchmark'
+        benchmark = connectivity_stats[np.argmax(connectivity_stats[:, 1]), 0]
+        # Update the medial axis and 'dist_on_skel'
+        skel[:] = False
+        skel[np.unravel_index(idx_glb_table[idx_glb_candidate > benchmark], skel.shape)] = True
+    #end if
+    return (_Filter_cluster_close_points(skel*dist), connectivity_stats)
+#end def
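+
+# --- Illustrative sketch (not part of the original workflow) ----------------
+# How the return values of detect_intersection_point_with_cluster_filter()
+# might be consumed.  'dist_data', 'lx', 'ly', 'lz' are placeholders for a
+# signed-distance volume loaded elsewhere and its dimensions.
+def _example_intersection_points(dist_data, lx, ly, lz=0):
+    points, connectivity_stats = detect_intersection_point_with_cluster_filter(
+        dist_data, lx, ly, lz)
+    # Non-zero entries of 'points' carry the signed distance at each detected
+    # (and cluster-filtered) intersection point; argwhere recovers their
+    # grid coordinates.
+    coords = np.argwhere(points > 0.0)
+    return coords, connectivity_stats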
+
+def _Filter_cluster_close_points(arr):
+    # 'arr' is a 2D/3D signed-distance array
+    # Return an 'arr' in which nearest neighboring points are clustered
+    if arr.ndim == 2:
+        [y_idx, x_idx] = np.where(arr > 0.0)
+        idx_glb = np.ravel_multi_index([y_idx, x_idx], arr.shape)
+        grid = np.indices(arr.shape)
+
+        dist = arr.ravel()[idx_glb]
+        candidate = np.empty((0,), dtype=int)
+        # TODO: use a search tree to do this (see the KD-tree sketch at the
+        #       end of this module)
+        for k in np.arange(idx_glb.size):
+            if k % 200 == 0:
+                print('Clustering close points: '+str(k)+'/'+str(idx_glb.size-1))
+            #end if
+            # All grid points within the signed-distance radius of point k
+            mask_circle = (grid[0]-y_idx[k])**2 + (grid[1]-x_idx[k])**2 <= dist[k]**2
+            idx_glb_circle = np.ravel_multi_index(np.where(mask_circle), mask_circle.shape)
+            mask = np.in1d(idx_glb, idx_glb_circle, assume_unique=True)
+            # Keep only the candidate with the largest distance in this circle
+            candidate = np.hstack((candidate, idx_glb[mask][np.argmax(dist[mask])]))
+        #end for
+        candidate = np.unique(candidate)
+        mask = np.in1d(idx_glb, candidate, assume_unique=True)
+        output_glb_idx = idx_glb[mask]
+        output = np.zeros_like(arr.ravel())
+        output[output_glb_idx] = arr.ravel()[output_glb_idx]
+        output.shape = arr.shape
+    else:
+        [z_idx, y_idx, x_idx] = np.where(arr > 0.0)
+        idx_glb = np.ravel_multi_index([z_idx, y_idx, x_idx], arr.shape)
+        grid = np.indices(arr.shape)
+
+        dist = arr.ravel()[idx_glb]
+        candidate = np.empty((0,), dtype=int)
+        # TODO: use a search tree to do this (see the KD-tree sketch at the
+        #       end of this module)
+        for k in np.arange(idx_glb.size):
+            if k % 200 == 0:
+                print('Clustering close points: '+str(k)+'/'+str(idx_glb.size-1))
+            #end if
+            # All grid points within the signed-distance radius of point k
+            mask_circle = (grid[0]-z_idx[k])**2 + (grid[1]-y_idx[k])**2 + (grid[2]-x_idx[k])**2 <= dist[k]**2
+            idx_glb_circle = np.ravel_multi_index(np.where(mask_circle), mask_circle.shape)
+            mask = np.in1d(idx_glb, idx_glb_circle, assume_unique=True)
+            # Keep only the candidate with the largest distance in this sphere
+            candidate = np.hstack((candidate, idx_glb[mask][np.argmax(dist[mask])]))
+        #end for
+        candidate = np.unique(candidate)
+        mask = np.in1d(idx_glb, candidate, assume_unique=True)
+        output_glb_idx = idx_glb[mask]
+        output = np.zeros_like(arr.ravel())
+        output[output_glb_idx] = arr.ravel()[output_glb_idx]
+        output.shape = arr.shape
+    #end if
+    return output
+#end def
+
+def get_Dist(f):
+    return ndimage.distance_transform_edt(f)
+#end def
+
+def detect_local_maxima(arr, medial_axis_arr, patch_size=3):
+    local_max = _find_local_maxima(arr, patch_size)
+    background_value = 1e5  # for denoising in finding local minima
+    arr2 = arr.copy()
+    arr2[np.logical_not(medial_axis_arr)] = background_value
+    local_min = _find_local_minima(arr2, background_val=background_value)
+    local_min = np.logical_and(medial_axis_arr, local_min)  # Correct min indices with the medial axis
+    local_max_reduced = np.logical_xor(local_max, local_min)
+    local_max = np.logical_and(local_max, local_max_reduced)
+    local_max = np.logical_and(medial_axis_arr, local_max)  # Correct max indices with the medial axis
+    return local_max
+#end def
+
+def detect_local_maxima_with_cluster_filter(arr, medial_axis_arr, patch_size=3):
+    # Apply the cluster filtering only once
+    local_max = detect_local_maxima(arr, medial_axis_arr, patch_size=patch_size)
+    arr2 = arr.copy()
+    arr2[np.logical_not(local_max)] = 0.0
+    return _Filter_cluster_close_points(arr2)
+#end def
+
+def detect_local_maxima_with_cluster_filter_loop(arr, medial_axis_arr, patch_start=3, patch_stop=9):
+    # Apply the search for local maxima (with cluster filtering) repeatedly,
+    # increasing the patch size until the number of maxima stops decreasing
+    arr1 = detect_local_maxima_with_cluster_filter(arr, medial_axis_arr, patch_size=patch_start)
+    [y_idx_1, x_idx_1] = np.where(arr1 > 0.0)
+    arr2 = detect_local_maxima_with_cluster_filter(arr1, medial_axis_arr, patch_size=patch_start+2)
+    [y_idx_2, x_idx_2] = np.where(arr2 > 0.0)
+    if y_idx_1.size > y_idx_2.size:
+        counter = 2  # Record how many times the filtering is applied
+        patch = np.arange((patch_start+2)+2, patch_stop+2, 2)
+        for k in patch:
+
+            y_idx_1 = y_idx_2.copy()
+            x_idx_1 = x_idx_2.copy()
+
+            arr1 = detect_local_maxima_with_cluster_filter(arr2, medial_axis_arr, patch_size=k)
+            [y_idx_2, x_idx_2] = np.where(arr1 > 0.0)
+
+            arr2 = arr1.copy()
+
+            counter = counter + 1
+            if not (y_idx_1.size > y_idx_2.size):
+                print('Note: Local maxima are already found before patch_stop is reached.')
+                print('      All local maxima are found at the patch size of '+str(k)+' !')
+                break
+            #end if
+        #end for
+    else:
+        print('Note: All local maxima are found at the patch size of '+str(patch_start)+' !')
+        return arr2
+    #end if
+    return arr2
+#end def
+
+#def Filter_cluster(arr):
+#    return _Filter_cluster_close_points(arr)
+##end def
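+
+# --- Illustrative sketch (not part of the original workflow) ----------------
+# Assumed chaining of the helpers above: starting from a 2D boolean pore image
+# ('segmented' is a placeholder name), compute the Euclidean distance map and
+# its medial axis, then extract the cluster-filtered local maxima.
+def _example_local_maxima_pipeline(segmented):
+    dist = get_Dist(segmented)       # Euclidean distance transform of the pore space
+    skel = medial_axis(segmented)    # 2D medial axis used to restrict the maxima
+    return detect_local_maxima_with_cluster_filter_loop(dist, skel)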
+
+def _find_local_maxima(arr, patch_size):
+    # http://stackoverflow.com/questions/3684484/peak-detection-in-a-2d-array/3689710#3689710
+    """
+    Takes an array and detects the peaks using the local maximum filter.
+    Returns a boolean mask of the peaks (i.e. 1 when the pixel's value is
+    the neighborhood maximum, 0 otherwise).
+    """
+    # define a connected neighborhood
+    # http://www.scipy.org/doc/api_docs/SciPy.ndimage.morphology.html#generate_binary_structure
+    #neighborhood = ndimage.generate_binary_structure(len(arr.shape),2)
+    #neighborhood = np.ones((patch_size,patch_size),dtype=bool)
+    # apply the local maximum filter; all locations of maximum value
+    # in their neighborhood are set to True
+    local_max = (ndimage.maximum_filter(arr, size=patch_size, mode='constant', cval=0) == arr)
+    # 'local_max' is a mask that contains the peaks we are looking for,
+    # but also the background.
+    # In order to isolate the peaks we must remove the background from the mask.
+    #
+    # we create the mask of the background
+    background = (arr == 0)
+    #
+    # a little technicality: we must erode the background in order to
+    # successfully remove it from 'local_max', otherwise a line will
+    # appear along the background border (artifact of the maximum filter)
+    # http://www.scipy.org/doc/api_docs/SciPy.ndimage.morphology.html#binary_erosion
+    if arr.ndim == 3:
+        neighborhood = np.ones((patch_size, patch_size, patch_size), dtype=bool)
+    elif arr.ndim == 2:
+        neighborhood = np.ones((patch_size, patch_size), dtype=bool)
+    else:
+        neighborhood = np.ones((patch_size,), dtype=bool)
+    #end if
+    eroded_background = ndimage.binary_erosion(
+        background, structure=neighborhood, border_value=1)
+    #
+    # we obtain the final mask, containing only peaks, by removing the eroded
+    # background from the local-maximum mask (boolean subtraction is not
+    # supported by numpy, so use logical operations instead of '-')
+    detected_maxima = np.logical_and(local_max, np.logical_not(eroded_background))
+    return detected_maxima
+#end def
+
+def _find_local_minima(arr, patch_size=3, background_val=1e5):
+    local_min = (ndimage.minimum_filter(arr, size=patch_size, mode='constant', cval=background_val) == arr)
+    background = (arr == background_val)
+    if arr.ndim == 3:
+        neighborhood = np.ones((patch_size, patch_size, patch_size), dtype=bool)
+    elif arr.ndim == 2:
+        neighborhood = np.ones((patch_size, patch_size), dtype=bool)
+    else:
+        neighborhood = np.ones((patch_size,), dtype=bool)
+    #end if
+    eroded_background = ndimage.binary_erosion(
+        background, structure=neighborhood, border_value=1)
+    # Remove the eroded background from the local-minimum mask
+    detected_minima = np.logical_and(local_min, np.logical_not(eroded_background))
+    return detected_minima
+#end def
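+
+# --- Illustrative sketch (not part of the original workflow) ----------------
+# One possible answer to the "use a search tree" TODO in
+# _Filter_cluster_close_points(): build a cKDTree over the candidate points
+# and replace the brute-force circular/spherical masks with radius queries.
+# This is an assumed alternative implementation, shown for reference only.
+def _filter_cluster_close_points_kdtree(arr):
+    coords = np.argwhere(arr > 0.0)                      # (N, ndim) candidate coordinates
+    idx_glb = np.ravel_multi_index(tuple(coords.T), arr.shape)
+    dist = arr.ravel()[idx_glb]
+    tree = spatial.cKDTree(coords)
+    keep = []
+    for k in range(coords.shape[0]):
+        # Candidates lying within the signed-distance radius of point k
+        neighbors = tree.query_ball_point(coords[k], r=dist[k])
+        # Keep only the neighbor with the largest distance value,
+        # mirroring the brute-force version above
+        keep.append(idx_glb[neighbors][np.argmax(dist[neighbors])])
+    keep = np.unique(np.asarray(keep, dtype=int))
+    output = np.zeros_like(arr.ravel())
+    output[keep] = arr.ravel()[keep]
+    return output.reshape(arr.shape)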