Advertisement
ali_m

norm_xcorr

Sep 3rd, 2012
12,285
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 20.47 KB | None | 0 0
  1. import numpy as np
  2. import scipy as sp
  3. from scipy.ndimage import convolve
  4.  
  5. # Try and use the faster Fourier transform functions from the anfft module if
  6. # available
  7. try:
  8.     import anfft as _anfft
  9.     # measure == True for self-optimisation of repeat Fourier transforms of
  10.     # similarly-shaped arrays
  11.     def fftn(A,shape=None):
  12.         if shape != None:
  13.             A = _checkffttype(A)
  14.             A = procrustes(A,target=shape,side='after',padval=0)
  15.         return _anfft.fftn(A,measure=True)
  16.     def ifftn(B,shape=None):
  17.         if shape != None:
  18.             B = _checkffttype(B)
  19.             B = procrustes(B,target=shape,side='after',padval=0)
  20.         return _anfft.ifftn(B,measure=True)
  21.     def _checkffttype(C):
  22.         # make sure input arrays are typed correctly for FFTW
  23.         if C.dtype == 'complex256':
  24.             # the only incompatible complex type --> complex64
  25.             C = np.complex128(C)
  26.         elif C.dtype not in ['float32','float64','complex64','complex128']:
  27.             # any other incompatible type --> float64
  28.             C = np.float64(C)
  29.         return C
  30. # Otherwise use the normal scipy fftpack ones instead (~2-3x slower!)
  31. except ImportError:
  32.     print \
  33.     "Module 'anfft' (FFTW Python bindings) could not be imported.\n"\
  34.     "To install it, try running 'easy_install anfft' from the terminal.\n"\
  35.     "Falling back on the slower 'fftpack' module for ND Fourier transforms."
  36.     from scipy.fftpack import fftn, ifftn
  37.  
  38. class TemplateMatch(object):
  39.     """
  40.     N-dimensional template search by normalized cross-correlation or sum of
  41.     squared differences.
  42.  
  43.     Arguments:
  44.     ------------------------
  45.         template    The template to search for
  46.         method      The search method. Can be "ncc", "ssd" or
  47.                 "both". See documentation for norm_xcorr and
  48.                 fast_ssd for more details.
  49.  
  50.     Example use:
  51.     ------------------------
  52.     from scipy.misc import lena
  53.     from matplotlib.pyplot import subplots
  54.  
  55.     image = lena()
  56.     template = image[240:281,240:281]
  57.     TM = TemplateMatch(template,method='both')
  58.     ncc,ssd = TM(image)
  59.     nccloc = np.nonzero(ncc == ncc.max())
  60.     ssdloc = np.nonzero(ssd == ssd.min())
  61.  
  62.     fig,[[ax1,ax2],[ax3,ax4]] = subplots(2,2,num='ND Template Search')
  63.     ax1.imshow(image,interpolation='nearest')
  64.     ax1.set_title('Search image')
  65.     ax2.imshow(template,interpolation='nearest')
  66.     ax2.set_title('Template')
  67.     ax3.hold(True)
  68.     ax3.imshow(ncc,interpolation='nearest')
  69.     ax3.plot(nccloc[1],nccloc[0],'w+')
  70.     ax3.set_title('Normalized cross-correlation')
  71.     ax4.hold(True)
  72.     ax4.imshow(ssd,interpolation='nearest')
  73.     ax4.plot(ssdloc[1],ssdloc[0],'w+')
  74.     ax4.set_title('Sum of squared differences')
  75.     fig.tight_layout()
  76.     fig.canvas.draw()
  77.     """
  78.     def __init__(self,template,method='ssd'):
  79.  
  80.         if method not in ['ncc','ssd','both']:
  81.             raise Exception('Invalid method "%s". '\
  82.                     'Valid methods are "ncc", "ssd" or "both"'
  83.                     %method)
  84.  
  85.         self.template = template
  86.         self.method = method
  87.  
  88.     def __call__(self,a):
  89.  
  90.         if a.ndim != self.template.ndim:
  91.             raise Exception('Input array must have the same number '\
  92.                     'of dimensions as the template (%i)'
  93.                     %self.template.ndim)
  94.  
  95.         if self.method == 'ssd':
  96.             return self.fast_ssd(self.template,a,trim=True)
  97.         elif self.method == 'ncc':
  98.             return norm_xcorr(self.template,a,trim=True)
  99.         elif self.method == 'both':
  100.             return norm_xcorr(self.template,a,trim=True,do_ssd=True)
  101.  
  102. def norm_xcorr(t,a,method=None,trim=True,do_ssd=False):
  103.     """
  104.     Fast normalized cross-correlation for n-dimensional arrays
  105.  
  106.     Inputs:
  107.     ----------------
  108.         t   The template. Must have at least 2 elements, which
  109.             cannot all be equal.
  110.  
  111.         a   The search space. Its dimensionality must match that of
  112.             the template.
  113.  
  114.         method  The convolution method to use when computing the
  115.             cross-correlation. Can be either 'direct', 'fourier' or
  116.             None. If method == None (default), the convolution time
  117.             is estimated for both methods and the best one is chosen
  118.             for the given input array sizes.
  119.  
  120.         trim    If True (default), the output array is trimmed down to  
  121.             the size of the search space. Otherwise, its size will  
  122.             be (f.shape[dd] + t.shape[dd] -1) for dimension dd.
  123.  
  124.         do_ssd  If True, the sum of squared differences between the
  125.             template and the search image will also be calculated.
  126.             It is very efficient to calculate normalized
  127.             cross-correlation and the SSD simultaneously, since they
  128.             require many of the same quantities.
  129.  
  130.     Output:
  131.     ----------------
  132.         nxcorr  An array of cross-correlation coefficients, which may  
  133.             vary from -1.0 to 1.0.
  134.         [ssd]   [Returned if do_ssd == True. See fast_ssd for details.]
  135.  
  136.     Wherever the search space has zero  variance under the template,
  137.     normalized  cross-correlation is undefined. In such regions, the
  138.     correlation coefficients are set to zero.
  139.  
  140.     References:
  141.         Hermosillo et al 2002: Variational Methods for Multimodal Image
  142.         Matching, International Journal of Computer Vision 50(3),
  143.         329-343, 2002
  144.         <http://www.springerlink.com/content/u4007p8871w10645/>
  145.  
  146.         Lewis 1995: Fast Template Matching, Vision Interface,
  147.         p.120-123, 1995
  148.         <http://www.idiom.com/~zilla/Papers/nvisionInterface/nip.html>
  149.  
  150.         <http://en.wikipedia.org/wiki/Cross-correlation#Normalized_cross-correlation>
  151.  
  152.     Alistair Muldal
  153.     Department of Pharmacology
  154.     University of Oxford
  155.     <alistair.muldal@pharm.ox.ac.uk>
  156.  
  157.     Sept 2012
  158.  
  159.     """
  160.  
  161.     if t.size < 2:
  162.         raise Exception('Invalid template')
  163.     if t.size > a.size:
  164.         raise Exception('The input array must be smaller than the template')
  165.  
  166.     std_t,mean_t = np.std(t),np.mean(t)
  167.  
  168.     if std_t == 0:
  169.         raise Exception('The values of the template must not all be equal')
  170.  
  171.     t = np.float64(t)
  172.     a = np.float64(a)
  173.  
  174.     # output dimensions of xcorr need to match those of local_sum
  175.     outdims = np.array([a.shape[dd]+t.shape[dd]-1 for dd in xrange(a.ndim)])
  176.  
  177.     # would it be quicker to convolve in the spatial or frequency domain? NB
  178.     # this is not very accurate since the speed of the Fourier transform
  179.     # varies quite a lot with the output dimensions (e.g. 2-radix case)
  180.     if method == None:
  181.         spatialtime, ffttime = get_times(t,a,outdims)
  182.         if spatialtime < ffttime:
  183.             method = 'spatial'
  184.         else:
  185.             method = 'fourier'
  186.  
  187.     if method == 'fourier':
  188.         # # in many cases, padding the dimensions to a power of 2
  189.         # # *dramatically* improves the speed of the Fourier transforms
  190.         # # since it allows using radix-2 FFTs
  191.         # fftshape = [nextpow2(ss) for ss in a.shape]
  192.  
  193.         # Fourier transform of the input array and the inverted template
  194.  
  195.         # af = fftn(a,shape=fftshape)
  196.         # tf = fftn(ndflip(t),shape=fftshape)
  197.  
  198.         af = fftn(a,shape=outdims)
  199.         tf = fftn(ndflip(t),shape=outdims)
  200.  
  201.         # 'non-normalized' cross-correlation
  202.         xcorr = np.real(ifftn(tf*af))
  203.  
  204.     else:
  205.         xcorr = convolve(a,t,mode='constant',cval=0)
  206.  
  207.     # local linear and quadratic sums of input array in the region of the
  208.     # template
  209.     ls_a = local_sum(a,t.shape)
  210.     ls2_a = local_sum(a**2,t.shape)
  211.  
  212.     # now we need to make sure xcorr is the same size as ls_a
  213.     xcorr = procrustes(xcorr,ls_a.shape,side='both')
  214.  
  215.     # local standard deviation of the input array
  216.     ls_diff = ls2_a - (ls_a**2)/t.size
  217.     ls_diff = np.where(ls_diff < 0,0,ls_diff)
  218.     sigma_a = np.sqrt(ls_diff)
  219.  
  220.     # standard deviation of the template
  221.     sigma_t = np.sqrt(t.size-1.)*std_t
  222.  
  223.     # denominator: product of standard deviations
  224.     denom = sigma_t*sigma_a
  225.  
  226.     # numerator: local mean corrected cross-correlation
  227.     numer = (xcorr - ls_a*mean_t)
  228.  
  229.     # sigma_t cannot be zero, so wherever the denominator is zero, this must
  230.     # be because sigma_a is zero (and therefore the normalized cross-
  231.     # correlation is undefined), so set nxcorr to zero in these regions
  232.     tol = np.sqrt(np.finfo(denom.dtype).eps)
  233.     nxcorr = np.where(denom < tol,0,numer/denom)
  234.  
  235.     # if any of the coefficients are outside the range [-1 1], they will be
  236.     # unstable to small variance in a or t, so set them to zero to reflect
  237.     # the undefined 0/0 condition
  238.     nxcorr = np.where(np.abs(nxcorr-1.) > np.sqrt(np.finfo(nxcorr.dtype).eps),nxcorr,0)
  239.  
  240.     # calculate the SSD if requested
  241.     if do_ssd:
  242.         # quadratic sum of the template
  243.         tsum2 = np.sum(t**2.)
  244.  
  245.         # SSD between template and image
  246.         ssd = ls2_a + tsum2 - 2.*xcorr
  247.  
  248.         # normalise to between 0 and 1
  249.         ssd -= ssd.min()
  250.         ssd /= ssd.max()
  251.  
  252.         if trim:
  253.             nxcorr = procrustes(nxcorr,a.shape,side='both')
  254.             ssd = procrustes(ssd,a.shape,side='both')
  255.         return nxcorr,ssd
  256.  
  257.     else:
  258.         if trim:
  259.             nxcorr = procrustes(nxcorr,a.shape,side='both')
  260.         return nxcorr
  261.  
  262. def fast_ssd(t,a,method=None,trim=True):
  263.     """
  264.  
  265.     Fast sum of squared differences (SSD block matching) for n-dimensional
  266.     arrays
  267.  
  268.     Inputs:
  269.     ----------------
  270.         t   The template. Must have at least 2 elements, which
  271.             cannot all be equal.
  272.  
  273.         a   The search space. Its dimensionality must match that of
  274.             the template.
  275.  
  276.         method  The convolution method to use when computing the
  277.             cross-correlation. Can be either 'direct', 'fourier' or
  278.             None. If method == None (default), the convolution time
  279.             is estimated for both methods and the best one is chosen
  280.             for the given input array sizes.
  281.  
  282.         trim    If True (default), the output array is trimmed down to  
  283.             the size of the search space. Otherwise, its size will  
  284.             be (f.shape[dd] + t.shape[dd] -1) for dimension dd.
  285.  
  286.     Output:
  287.     ----------------
  288.         ssd     An array containing the sum of squared differences
  289.             between the image and the template, with the values
  290.             normalized in the range -1.0 to 1.0.
  291.  
  292.     Wherever the search space has zero  variance under the template,
  293.     normalized  cross-correlation is undefined. In such regions, the
  294.     correlation coefficients are set to zero.
  295.  
  296.     References:
  297.         Hermosillo et al 2002: Variational Methods for Multimodal Image
  298.         Matching, International Journal of Computer Vision 50(3),
  299.         329-343, 2002
  300.         <http://www.springerlink.com/content/u4007p8871w10645/>
  301.  
  302.         Lewis 1995: Fast Template Matching, Vision Interface,
  303.         p.120-123, 1995
  304.         <http://www.idiom.com/~zilla/Papers/nvisionInterface/nip.html>
  305.  
  306.  
  307.     Alistair Muldal
  308.     Department of Pharmacology
  309.     University of Oxford
  310.     <alistair.muldal@pharm.ox.ac.uk>
  311.  
  312.     Sept 2012
  313.  
  314.     """
  315.  
  316.     if t.size < 2:
  317.         raise Exception('Invalid template')
  318.     if t.size > a.size:
  319.         raise Exception('The input array must be smaller than the template')
  320.  
  321.     std_t,mean_t = np.std(t),np.mean(t)
  322.  
  323.     if std_t == 0:
  324.         raise Exception('The values of the template must not all be equal')
  325.  
  326.     # output dimensions of xcorr need to match those of local_sum
  327.     outdims = np.array([a.shape[dd]+t.shape[dd]-1 for dd in xrange(a.ndim)])
  328.  
  329.     # would it be quicker to convolve in the spatial or frequency domain? NB
  330.     # this is not very accurate since the speed of the Fourier transform
  331.     # varies quite a lot with the output dimensions (e.g. 2-radix case)
  332.     if method == None:
  333.         spatialtime, ffttime = get_times(t,a,outdims)
  334.         if spatialtime < ffttime:
  335.             method = 'spatial'
  336.         else:
  337.             method = 'fourier'
  338.  
  339.     if method == 'fourier':
  340.         # # in many cases, padding the dimensions to a power of 2
  341.         # # *dramatically* improves the speed of the Fourier transforms
  342.         # # since it allows using radix-2 FFTs
  343.         # fftshape = [nextpow2(ss) for ss in a.shape]
  344.  
  345.         # Fourier transform of the input array and the inverted template
  346.  
  347.         # af = fftn(a,shape=fftshape)
  348.         # tf = fftn(ndflip(t),shape=fftshape)
  349.  
  350.         af = fftn(a,shape=outdims)
  351.         tf = fftn(ndflip(t),shape=outdims)
  352.  
  353.         # 'non-normalized' cross-correlation
  354.         xcorr = np.real(ifftn(tf*af))
  355.  
  356.     else:
  357.         xcorr = convolve(a,t,mode='constant',cval=0)
  358.  
  359.     # quadratic sum of the template
  360.     tsum2 = np.sum(t**2.)
  361.  
  362.     # local quadratic sum of input array in the region of the template
  363.     ls2_a = local_sum(a**2,t.shape)
  364.  
  365.     # now we need to make sure xcorr is the same size as ls2_a
  366.     xcorr = procrustes(xcorr,ls2_a.shape,side='both')
  367.  
  368.     # SSD between template and image
  369.     ssd = ls2_a + tsum2 - 2.*xcorr
  370.  
  371.     # normalise to between 0 and 1
  372.     ssd -= ssd.min()
  373.     ssd /= ssd.max()
  374.  
  375.     if trim:
  376.         ssd = procrustes(ssd,a.shape,side='both')
  377.  
  378.     return ssd
  379.  
  380.  
  381. def local_sum(a,tshape):
  382.     """For each element in an n-dimensional input array, calculate
  383.     the sum of the elements within a surrounding region the size of
  384.     the template"""
  385.  
  386.     # zero-padding
  387.     a = ndpad(a,tshape)
  388.  
  389.     # difference between shifted copies of an array along a given dimension
  390.     def shiftdiff(a,tshape,shiftdim):
  391.         ind1 = [slice(None,None),]*a.ndim
  392.         ind2 = [slice(None,None),]*a.ndim
  393.         ind1[shiftdim] = slice(tshape[shiftdim],a.shape[shiftdim]-1)
  394.         ind2[shiftdim] = slice(0,a.shape[shiftdim]-tshape[shiftdim]-1)
  395.         return a[ind1] - a[ind2]
  396.  
  397.     # take the cumsum along each dimension and subtracting a shifted version
  398.     # from itself. this reduces the number of computations to 2*N additions
  399.     # and 2*N subtractions for an N-dimensional array, independent of its
  400.     # size.
  401.     #
  402.     # See:
  403.     # <http://www.idiom.com/~zilla/Papers/nvisionInterface/nip.html>
  404.     for dd in xrange(a.ndim):
  405.         a = np.cumsum(a,dd)
  406.         a = shiftdiff(a,tshape,dd)
  407.     return a
  408.  
  409. # # for debugging purposes, ~10x slower than local_sum for a (512,512) array
  410. # def slow_2D_local_sum(a,tshape):
  411. #   out = np.zeros_like(a)
  412. #   for ii in xrange(a.shape[0]):
  413. #       istart = np.max((0,ii-tshape[0]//2))
  414. #       istop = np.min((a.shape[0],ii+tshape[0]//2+1))
  415. #       for jj in xrange(a.shape[1]):
  416. #           jstart = np.max((0,jj-tshape[1]//2))
  417. #           jstop = np.min((a.shape[1],jj+tshape[0]//2+1))
  418. #           out[ii,jj] = np.sum(a[istart:istop,jstart:jstop])
  419. #   return out
  420.  
  421. def get_times(t,a,outdims):
  422.  
  423.     k_conv = 1.21667E-09
  424.     k_fft = 2.65125E-08
  425.  
  426.     # # uncomment these lines to measure timing constants
  427.     # k_conv,k_fft,convreps,fftreps = benchmark(t,a,outdims,maxtime=60)
  428.     # print "-------------------------------------"
  429.     # print "Template size:\t\t%s" %str(t.shape)
  430.     # print "Search space size:\t%s" %str(a.shape)
  431.     # print "k_conv:\t%.6G\treps:\t%s" %(k_conv,str(convreps))
  432.     # print "k_fft:\t%.6G\treps:\t%s" %(k_fft,str(fftreps))
  433.     # print "-------------------------------------"
  434.  
  435.     # spatial convolution time scales with the total number of elements
  436.     convtime = k_conv*(t.size*a.size)
  437.  
  438.     # Fourier convolution time scales with N*log(N), cross-correlation
  439.     # requires 2x FFTs and 1x iFFT. ND FFT time scales with
  440.     # prod(dimensions)*log(prod(dimensions))
  441.     ffttime = 3*k_fft*(np.prod(outdims)*np.log(np.prod(outdims)))
  442.  
  443.     # print     "Predicted spatial:\t%.6G\nPredicted fourier:\t%.6G" %(convtime,ffttime)
  444.     return convtime,ffttime
  445.  
  446. def benchmark(t,a,outdims,maxtime=60):
  447.     import resource
  448.  
  449.     # benchmark spatial convolutions
  450.     # ---------------------------------
  451.     convreps = 0
  452.     tic = resource.getrusage(resource.RUSAGE_SELF).ru_utime
  453.     toc = tic
  454.     while (toc-tic) < maxtime:
  455.         convolve(a,t,mode='constant',cval=0)
  456.         # xcorr = convolve(a,t,mode='full')
  457.         convreps += 1
  458.         toc = resource.getrusage(resource.RUSAGE_SELF).ru_utime
  459.     convtime = (toc-tic)/convreps
  460.  
  461.     # convtime == k(N1+N2)
  462.     N = t.size*a.size
  463.     k_conv = convtime/N
  464.  
  465.     # benchmark 1D Fourier transforms
  466.     # ---------------------------------
  467.     veclist = [np.random.randn(ss) for ss in outdims]
  468.     fft1times = []
  469.     fftreps = []
  470.     for vec in veclist:
  471.         reps = 0
  472.         tic = resource.getrusage(resource.RUSAGE_SELF).ru_utime
  473.         toc = tic
  474.         while (toc-tic) < maxtime:
  475.             fftn(vec)
  476.             toc = resource.getrusage(resource.RUSAGE_SELF).ru_utime
  477.             reps += 1
  478.         fft1times.append((toc-tic)/reps)
  479.         fftreps.append(reps)
  480.     fft1times = np.asarray(fft1times)
  481.  
  482.     # fft1_time == k*N*log(N)
  483.     N = np.asarray([vec.size for vec in veclist])
  484.     k_fft = np.mean(fft1times/(N*np.log(N)))
  485.  
  486.     # # benchmark ND Fourier transforms
  487.     # # ---------------------------------
  488.     # arraylist = [t,a]
  489.     # fftntimes = []
  490.     # fftreps = []
  491.     # for array in arraylist:
  492.     #   reps = 0
  493.     #   tic = resource.getrusage(resource.RUSAGE_SELF).ru_utime
  494.     #   toc = tic
  495.     #   while (toc-tic) < maxtime:
  496.     #       fftn(array,shape=a.shape)
  497.     #       reps += 1
  498.     #       toc = resource.getrusage(resource.RUSAGE_SELF).ru_utime
  499.     #   fftntimes.append((toc-tic)/reps)
  500.     #   fftreps.append(reps)
  501.     # fftntimes = np.asarray(fftntimes)
  502.  
  503.     # # fftn_time == k*prod(dimensions)*log(prod(dimensions)) for an M-dimensional array
  504.     # nlogn = np.array([aa.size*np.log(aa.size) for aa in arraylist])
  505.     # k_fft = np.mean(fftntimes/nlogn)
  506.  
  507.     return k_conv,k_fft,convreps,fftreps
  508.     # return k_conv,k_fft1,k_fftn
  509.  
  510.  
  511. def ndpad(a,npad=None,padval=0):
  512.     """
  513.     Pads the edges of an n-dimensional input array with a constant value
  514.     across all of its dimensions.
  515.  
  516.     Inputs:
  517.     ----------------
  518.         a   The array to pad
  519.  
  520.         npad*   The pad width. Can either be array-like, with one
  521.             element per dimension, or a scalar, in which case the
  522.             same pad width is applied to all dimensions.
  523.  
  524.         padval  The value to pad with. Must be a scalar (default is 0).
  525.  
  526.     Output:
  527.     ----------------
  528.         b   The padded array
  529.  
  530.     *If npad is not a whole number, padding will be applied so that the
  531.     'left' edge of the output is padded less than the 'right', e.g.:
  532.  
  533.         a       == np.array([1,2,3,4,5,6])
  534.         ndpad(a,1.5)    == np.array([0,1,2,3,4,5,6,0,0])
  535.  
  536.     In this case, the average pad width is equal to npad (but if npad was
  537.     not a multiple of 0.5 this would not still hold). This is so that ndpad
  538.     can be used to pad an array out to odd final dimensions.
  539.     """
  540.  
  541.     if npad == None:
  542.         npad = np.ones(a.ndim)
  543.     elif np.isscalar(npad):
  544.         npad = (npad,)*a.ndim
  545.     elif len(npad) != a.ndim:
  546.         raise Exception('Length of npad (%i) does not match the '\
  547.                 'dimensionality of the input array (%i)'
  548.                 %(len(npad),a.ndim))
  549.  
  550.     # initialise padded output
  551.     padsize = [a.shape[dd]+2*npad[dd] for dd in xrange(a.ndim)]
  552.     b = np.ones(padsize,a.dtype)*padval
  553.  
  554.     # construct an N-dimensional list of slice objects
  555.     ind = [slice(np.floor(npad[dd]),a.shape[dd]+np.floor(npad[dd])) for dd in xrange(a.ndim)]
  556.  
  557.     # fill in the non-pad part of the array
  558.     b[ind] = a
  559.     return b
  560.  
  561. # def ndunpad(b,npad=None):
  562. #   """
  563. #   Removes padding from each dimension of an n-dimensional array (the
  564. #   reverse of ndpad)
  565.  
  566. #   Inputs:
  567. #   ----------------
  568. #       b   The array to unpad
  569.  
  570. #       npad*   The pad width. Can either be array-like, with one
  571. #           element per dimension, or a scalar, in which case the
  572. #           same pad width is applied to all dimensions.
  573.  
  574. #   Output:
  575. #   ----------------
  576. #       a   The unpadded array
  577.  
  578. #         *If npad is not a whole number, padding will be removed assuming that
  579. #   the 'left' edge of the output is padded less than the 'right', e.g.:
  580.  
  581. #       b       == np.array([0,1,2,3,4,5,6,0,0])
  582. #       ndpad(b,1.5)    == np.array([1,2,3,4,5,6])
  583.  
  584. #   This is consistent with the behaviour of ndpad.
  585. #   """
  586. #   if npad == None:
  587. #       npad = np.ones(b.ndim)
  588. #   elif np.isscalar(npad):
  589. #       npad = (npad,)*b.ndim
  590. #   elif len(npad) != b.ndim:
  591. #       raise Exception('Length of npad (%i) does not match the '\
  592. #               'dimensionality of the input array (%i)'
  593. #               %(len(npad),b.ndim))
  594. #   ind = [slice(np.floor(npad[dd]),b.shape[dd]-np.ceil(npad[dd])) for dd in xrange(b.ndim)]
  595. #   return b[ind]
  596.  
  597. def procrustes(a,target,side='both',padval=0):
  598.     """
  599.     Forces an array to a target size by either padding it with a constant or
  600.     truncating it
  601.  
  602.     Arguments:
  603.         a   Input array of any type or shape
  604.         target  Dimensions to pad/trim to, must be a list or tuple
  605.     """
  606.  
  607.     try:
  608.         if len(target) != a.ndim:
  609.             raise TypeError('Target shape must have the same number of dimensions as the input')
  610.     except TypeError:
  611.         raise TypeError('Target must be array-like')
  612.  
  613.     try:
  614.         b = np.ones(target,a.dtype)*padval
  615.     except TypeError:
  616.         raise TypeError('Pad value must be numeric')
  617.     except ValueError:
  618.         raise ValueError('Pad value must be scalar')
  619.  
  620.     aind = [slice(None,None)]*a.ndim
  621.     bind = [slice(None,None)]*a.ndim
  622.  
  623.     # pad/trim comes after the array in each dimension
  624.     if side == 'after':
  625.         for dd in xrange(a.ndim):
  626.             if a.shape[dd] > target[dd]:
  627.                 aind[dd] = slice(None,target[dd])
  628.             elif a.shape[dd] < target[dd]:
  629.                 bind[dd] = slice(None,a.shape[dd])
  630.  
  631.     # pad/trim comes before the array in each dimension
  632.     elif side == 'before':
  633.         for dd in xrange(a.ndim):
  634.             if a.shape[dd] > target[dd]:
  635.                 aind[dd] = slice(a.shape[dd]-target[dd],None)
  636.             elif a.shape[dd] < target[dd]:
  637.                 bind[dd] = slice(target[dd]-a.shape[dd],None)
  638.  
  639.     # pad/trim both sides of the array in each dimension
  640.     elif side == 'both':
  641.         for dd in xrange(a.ndim):
  642.             if a.shape[dd] > target[dd]:
  643.                 diff = (a.shape[dd]-target[dd])/2.
  644.                 aind[dd] = slice(np.floor(diff),a.shape[dd]-np.ceil(diff))
  645.             elif a.shape[dd] < target[dd]:
  646.                 diff = (target[dd]-a.shape[dd])/2.
  647.                 bind[dd] = slice(np.floor(diff),target[dd]-np.ceil(diff))
  648.    
  649.     else:
  650.         raise Exception('Invalid choice of pad type: %s' %side)
  651.  
  652.     b[bind] = a[aind]
  653.  
  654.     return b
  655.  
  656. def ndflip(a):
  657.     """Inverts an n-dimensional array along each of its axes"""
  658.     ind = (slice(None,None,-1),)*a.ndim
  659.     return a[ind]
  660.  
  661. # def nextpow2(n):
  662. #   """get the next power of 2 that's greater than n"""
  663. #   m_f = np.log2(n)
  664. #   m_i = np.ceil(m_f)
  665. #   return 2**m_i
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement