Advertisement
ali_m

norm_xcorr

Sep 3rd, 2012
13,085
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 20.47 KB | None | 0 0
  1. import numpy as np
  2. import scipy as sp
  3. from scipy.ndimage import convolve
  4.  
  5. # Try and use the faster Fourier transform functions from the anfft module if
  6. # available
  7. try:
  8.     import anfft as _anfft
  9.     # measure == True for self-optimisation of repeat Fourier transforms of
  10.     # similarly-shaped arrays
  11.     def fftn(A,shape=None):
  12.         if shape != None:
  13.             A = _checkffttype(A)
  14.             A = procrustes(A,target=shape,side='after',padval=0)
  15.         return _anfft.fftn(A,measure=True)
  16.     def ifftn(B,shape=None):
  17.         if shape != None:
  18.             B = _checkffttype(B)
  19.             B = procrustes(B,target=shape,side='after',padval=0)
  20.         return _anfft.ifftn(B,measure=True)
  21.     def _checkffttype(C):
  22.         # make sure input arrays are typed correctly for FFTW
  23.         if C.dtype == 'complex256':
  24.             # the only incompatible complex type --> complex64
  25.             C = np.complex128(C)
  26.         elif C.dtype not in ['float32','float64','complex64','complex128']:
  27.             # any other incompatible type --> float64
  28.             C = np.float64(C)
  29.         return C
  30. # Otherwise use the normal scipy fftpack ones instead (~2-3x slower!)
  31. except ImportError:
  32.     print \
  33.     "Module 'anfft' (FFTW Python bindings) could not be imported.\n"\
  34.     "To install it, try running 'easy_install anfft' from the terminal.\n"\
  35.     "Falling back on the slower 'fftpack' module for ND Fourier transforms."
  36.     from scipy.fftpack import fftn, ifftn
  37.  
  38. class TemplateMatch(object):
  39.     """
  40.     N-dimensional template search by normalized cross-correlation or sum of
  41.     squared differences.
  42.  
  43.     Arguments:
  44.     ------------------------
  45.         template    The template to search for
  46.         method      The search method. Can be "ncc", "ssd" or
  47.                 "both". See documentation for norm_xcorr and
  48.                 fast_ssd for more details.
  49.  
  50.     Example use:
  51.     ------------------------
  52.     from scipy.misc import lena
  53.     from matplotlib.pyplot import subplots
  54.  
  55.     image = lena()
  56.     template = image[240:281,240:281]
  57.     TM = TemplateMatch(template,method='both')
  58.     ncc,ssd = TM(image)
  59.     nccloc = np.nonzero(ncc == ncc.max())
  60.     ssdloc = np.nonzero(ssd == ssd.min())
  61.  
  62.     fig,[[ax1,ax2],[ax3,ax4]] = subplots(2,2,num='ND Template Search')
  63.     ax1.imshow(image,interpolation='nearest')
  64.     ax1.set_title('Search image')
  65.     ax2.imshow(template,interpolation='nearest')
  66.     ax2.set_title('Template')
  67.     ax3.hold(True)
  68.     ax3.imshow(ncc,interpolation='nearest')
  69.     ax3.plot(nccloc[1],nccloc[0],'w+')
  70.     ax3.set_title('Normalized cross-correlation')
  71.     ax4.hold(True)
  72.     ax4.imshow(ssd,interpolation='nearest')
  73.     ax4.plot(ssdloc[1],ssdloc[0],'w+')
  74.     ax4.set_title('Sum of squared differences')
  75.     fig.tight_layout()
  76.     fig.canvas.draw()
  77.     """
  78.     def __init__(self,template,method='ssd'):
  79.  
  80.         if method not in ['ncc','ssd','both']:
  81.             raise Exception('Invalid method "%s". '\
  82.                     'Valid methods are "ncc", "ssd" or "both"'
  83.                     %method)
  84.  
  85.         self.template = template
  86.         self.method = method
  87.  
  88.     def __call__(self,a):
  89.  
  90.         if a.ndim != self.template.ndim:
  91.             raise Exception('Input array must have the same number '\
  92.                     'of dimensions as the template (%i)'
  93.                     %self.template.ndim)
  94.  
  95.         if self.method == 'ssd':
  96.             return self.fast_ssd(self.template,a,trim=True)
  97.         elif self.method == 'ncc':
  98.             return norm_xcorr(self.template,a,trim=True)
  99.         elif self.method == 'both':
  100.             return norm_xcorr(self.template,a,trim=True,do_ssd=True)
  101.  
  102. def norm_xcorr(t,a,method=None,trim=True,do_ssd=False):
  103.     """
  104.     Fast normalized cross-correlation for n-dimensional arrays
  105.  
  106.     Inputs:
  107.     ----------------
  108.         t   The template. Must have at least 2 elements, which
  109.             cannot all be equal.
  110.  
  111.         a   The search space. Its dimensionality must match that of
  112.             the template.
  113.  
  114.         method  The convolution method to use when computing the
  115.             cross-correlation. Can be either 'direct', 'fourier' or
  116.             None. If method == None (default), the convolution time
  117.             is estimated for both methods and the best one is chosen
  118.             for the given input array sizes.
  119.  
  120.         trim    If True (default), the output array is trimmed down to  
  121.             the size of the search space. Otherwise, its size will  
  122.             be (f.shape[dd] + t.shape[dd] -1) for dimension dd.
  123.  
  124.         do_ssd  If True, the sum of squared differences between the
  125.             template and the search image will also be calculated.
  126.             It is very efficient to calculate normalized
  127.             cross-correlation and the SSD simultaneously, since they
  128.             require many of the same quantities.
  129.  
  130.     Output:
  131.     ----------------
  132.         nxcorr  An array of cross-correlation coefficients, which may  
  133.             vary from -1.0 to 1.0.
  134.         [ssd]   [Returned if do_ssd == True. See fast_ssd for details.]
  135.  
  136.     Wherever the search space has zero  variance under the template,
  137.     normalized  cross-correlation is undefined. In such regions, the
  138.     correlation coefficients are set to zero.
  139.  
  140.     References:
  141.         Hermosillo et al 2002: Variational Methods for Multimodal Image
  142.         Matching, International Journal of Computer Vision 50(3),
  143.         329-343, 2002
  144.         <http://www.springerlink.com/content/u4007p8871w10645/>
  145.  
  146.         Lewis 1995: Fast Template Matching, Vision Interface,
  147.         p.120-123, 1995
  148.         <http://www.idiom.com/~zilla/Papers/nvisionInterface/nip.html>
  149.  
  150.         <http://en.wikipedia.org/wiki/Cross-correlation#Normalized_cross-correlation>
  151.  
  152.     Alistair Muldal
  153.     Department of Pharmacology
  154.     University of Oxford
  155.  
  156.     Sept 2012
  157.  
  158.     """
  159.  
  160.     if t.size < 2:
  161.         raise Exception('Invalid template')
  162.     if t.size > a.size:
  163.         raise Exception('The input array must be smaller than the template')
  164.  
  165.     std_t,mean_t = np.std(t),np.mean(t)
  166.  
  167.     if std_t == 0:
  168.         raise Exception('The values of the template must not all be equal')
  169.  
  170.     t = np.float64(t)
  171.     a = np.float64(a)
  172.  
  173.     # output dimensions of xcorr need to match those of local_sum
  174.     outdims = np.array([a.shape[dd]+t.shape[dd]-1 for dd in xrange(a.ndim)])
  175.  
  176.     # would it be quicker to convolve in the spatial or frequency domain? NB
  177.     # this is not very accurate since the speed of the Fourier transform
  178.     # varies quite a lot with the output dimensions (e.g. 2-radix case)
  179.     if method == None:
  180.         spatialtime, ffttime = get_times(t,a,outdims)
  181.         if spatialtime < ffttime:
  182.             method = 'spatial'
  183.         else:
  184.             method = 'fourier'
  185.  
  186.     if method == 'fourier':
  187.         # # in many cases, padding the dimensions to a power of 2
  188.         # # *dramatically* improves the speed of the Fourier transforms
  189.         # # since it allows using radix-2 FFTs
  190.         # fftshape = [nextpow2(ss) for ss in a.shape]
  191.  
  192.         # Fourier transform of the input array and the inverted template
  193.  
  194.         # af = fftn(a,shape=fftshape)
  195.         # tf = fftn(ndflip(t),shape=fftshape)
  196.  
  197.         af = fftn(a,shape=outdims)
  198.         tf = fftn(ndflip(t),shape=outdims)
  199.  
  200.         # 'non-normalized' cross-correlation
  201.         xcorr = np.real(ifftn(tf*af))
  202.  
  203.     else:
  204.         xcorr = convolve(a,t,mode='constant',cval=0)
  205.  
  206.     # local linear and quadratic sums of input array in the region of the
  207.     # template
  208.     ls_a = local_sum(a,t.shape)
  209.     ls2_a = local_sum(a**2,t.shape)
  210.  
  211.     # now we need to make sure xcorr is the same size as ls_a
  212.     xcorr = procrustes(xcorr,ls_a.shape,side='both')
  213.  
  214.     # local standard deviation of the input array
  215.     ls_diff = ls2_a - (ls_a**2)/t.size
  216.     ls_diff = np.where(ls_diff < 0,0,ls_diff)
  217.     sigma_a = np.sqrt(ls_diff)
  218.  
  219.     # standard deviation of the template
  220.     sigma_t = np.sqrt(t.size-1.)*std_t
  221.  
  222.     # denominator: product of standard deviations
  223.     denom = sigma_t*sigma_a
  224.  
  225.     # numerator: local mean corrected cross-correlation
  226.     numer = (xcorr - ls_a*mean_t)
  227.  
  228.     # sigma_t cannot be zero, so wherever the denominator is zero, this must
  229.     # be because sigma_a is zero (and therefore the normalized cross-
  230.     # correlation is undefined), so set nxcorr to zero in these regions
  231.     tol = np.sqrt(np.finfo(denom.dtype).eps)
  232.     nxcorr = np.where(denom < tol,0,numer/denom)
  233.  
  234.     # if any of the coefficients are outside the range [-1 1], they will be
  235.     # unstable to small variance in a or t, so set them to zero to reflect
  236.     # the undefined 0/0 condition
  237.     nxcorr = np.where(np.abs(nxcorr-1.) > np.sqrt(np.finfo(nxcorr.dtype).eps),nxcorr,0)
  238.  
  239.     # calculate the SSD if requested
  240.     if do_ssd:
  241.         # quadratic sum of the template
  242.         tsum2 = np.sum(t**2.)
  243.  
  244.         # SSD between template and image
  245.         ssd = ls2_a + tsum2 - 2.*xcorr
  246.  
  247.         # normalise to between 0 and 1
  248.         ssd -= ssd.min()
  249.         ssd /= ssd.max()
  250.  
  251.         if trim:
  252.             nxcorr = procrustes(nxcorr,a.shape,side='both')
  253.             ssd = procrustes(ssd,a.shape,side='both')
  254.         return nxcorr,ssd
  255.  
  256.     else:
  257.         if trim:
  258.             nxcorr = procrustes(nxcorr,a.shape,side='both')
  259.         return nxcorr
  260.  
  261. def fast_ssd(t,a,method=None,trim=True):
  262.     """
  263.  
  264.     Fast sum of squared differences (SSD block matching) for n-dimensional
  265.     arrays
  266.  
  267.     Inputs:
  268.     ----------------
  269.         t   The template. Must have at least 2 elements, which
  270.             cannot all be equal.
  271.  
  272.         a   The search space. Its dimensionality must match that of
  273.             the template.
  274.  
  275.         method  The convolution method to use when computing the
  276.             cross-correlation. Can be either 'direct', 'fourier' or
  277.             None. If method == None (default), the convolution time
  278.             is estimated for both methods and the best one is chosen
  279.             for the given input array sizes.
  280.  
  281.         trim    If True (default), the output array is trimmed down to  
  282.             the size of the search space. Otherwise, its size will  
  283.             be (f.shape[dd] + t.shape[dd] -1) for dimension dd.
  284.  
  285.     Output:
  286.     ----------------
  287.         ssd     An array containing the sum of squared differences
  288.             between the image and the template, with the values
  289.             normalized in the range -1.0 to 1.0.
  290.  
  291.     Wherever the search space has zero  variance under the template,
  292.     normalized  cross-correlation is undefined. In such regions, the
  293.     correlation coefficients are set to zero.
  294.  
  295.     References:
  296.         Hermosillo et al 2002: Variational Methods for Multimodal Image
  297.         Matching, International Journal of Computer Vision 50(3),
  298.         329-343, 2002
  299.         <http://www.springerlink.com/content/u4007p8871w10645/>
  300.  
  301.         Lewis 1995: Fast Template Matching, Vision Interface,
  302.         p.120-123, 1995
  303.         <http://www.idiom.com/~zilla/Papers/nvisionInterface/nip.html>
  304.  
  305.  
  306.     Alistair Muldal
  307.     Department of Pharmacology
  308.     University of Oxford
  309.  
  310.     Sept 2012
  311.  
  312.     """
  313.  
  314.     if t.size < 2:
  315.         raise Exception('Invalid template')
  316.     if t.size > a.size:
  317.         raise Exception('The input array must be smaller than the template')
  318.  
  319.     std_t,mean_t = np.std(t),np.mean(t)
  320.  
  321.     if std_t == 0:
  322.         raise Exception('The values of the template must not all be equal')
  323.  
  324.     # output dimensions of xcorr need to match those of local_sum
  325.     outdims = np.array([a.shape[dd]+t.shape[dd]-1 for dd in xrange(a.ndim)])
  326.  
  327.     # would it be quicker to convolve in the spatial or frequency domain? NB
  328.     # this is not very accurate since the speed of the Fourier transform
  329.     # varies quite a lot with the output dimensions (e.g. 2-radix case)
  330.     if method == None:
  331.         spatialtime, ffttime = get_times(t,a,outdims)
  332.         if spatialtime < ffttime:
  333.             method = 'spatial'
  334.         else:
  335.             method = 'fourier'
  336.  
  337.     if method == 'fourier':
  338.         # # in many cases, padding the dimensions to a power of 2
  339.         # # *dramatically* improves the speed of the Fourier transforms
  340.         # # since it allows using radix-2 FFTs
  341.         # fftshape = [nextpow2(ss) for ss in a.shape]
  342.  
  343.         # Fourier transform of the input array and the inverted template
  344.  
  345.         # af = fftn(a,shape=fftshape)
  346.         # tf = fftn(ndflip(t),shape=fftshape)
  347.  
  348.         af = fftn(a,shape=outdims)
  349.         tf = fftn(ndflip(t),shape=outdims)
  350.  
  351.         # 'non-normalized' cross-correlation
  352.         xcorr = np.real(ifftn(tf*af))
  353.  
  354.     else:
  355.         xcorr = convolve(a,t,mode='constant',cval=0)
  356.  
  357.     # quadratic sum of the template
  358.     tsum2 = np.sum(t**2.)
  359.  
  360.     # local quadratic sum of input array in the region of the template
  361.     ls2_a = local_sum(a**2,t.shape)
  362.  
  363.     # now we need to make sure xcorr is the same size as ls2_a
  364.     xcorr = procrustes(xcorr,ls2_a.shape,side='both')
  365.  
  366.     # SSD between template and image
  367.     ssd = ls2_a + tsum2 - 2.*xcorr
  368.  
  369.     # normalise to between 0 and 1
  370.     ssd -= ssd.min()
  371.     ssd /= ssd.max()
  372.  
  373.     if trim:
  374.         ssd = procrustes(ssd,a.shape,side='both')
  375.  
  376.     return ssd
  377.  
  378.  
  379. def local_sum(a,tshape):
  380.     """For each element in an n-dimensional input array, calculate
  381.     the sum of the elements within a surrounding region the size of
  382.     the template"""
  383.  
  384.     # zero-padding
  385.     a = ndpad(a,tshape)
  386.  
  387.     # difference between shifted copies of an array along a given dimension
  388.     def shiftdiff(a,tshape,shiftdim):
  389.         ind1 = [slice(None,None),]*a.ndim
  390.         ind2 = [slice(None,None),]*a.ndim
  391.         ind1[shiftdim] = slice(tshape[shiftdim],a.shape[shiftdim]-1)
  392.         ind2[shiftdim] = slice(0,a.shape[shiftdim]-tshape[shiftdim]-1)
  393.         return a[ind1] - a[ind2]
  394.  
  395.     # take the cumsum along each dimension and subtracting a shifted version
  396.     # from itself. this reduces the number of computations to 2*N additions
  397.     # and 2*N subtractions for an N-dimensional array, independent of its
  398.     # size.
  399.     #
  400.     # See:
  401.     # <http://www.idiom.com/~zilla/Papers/nvisionInterface/nip.html>
  402.     for dd in xrange(a.ndim):
  403.         a = np.cumsum(a,dd)
  404.         a = shiftdiff(a,tshape,dd)
  405.     return a
  406.  
  407. # # for debugging purposes, ~10x slower than local_sum for a (512,512) array
  408. # def slow_2D_local_sum(a,tshape):
  409. #   out = np.zeros_like(a)
  410. #   for ii in xrange(a.shape[0]):
  411. #       istart = np.max((0,ii-tshape[0]//2))
  412. #       istop = np.min((a.shape[0],ii+tshape[0]//2+1))
  413. #       for jj in xrange(a.shape[1]):
  414. #           jstart = np.max((0,jj-tshape[1]//2))
  415. #           jstop = np.min((a.shape[1],jj+tshape[0]//2+1))
  416. #           out[ii,jj] = np.sum(a[istart:istop,jstart:jstop])
  417. #   return out
  418.  
  419. def get_times(t,a,outdims):
  420.  
  421.     k_conv = 1.21667E-09
  422.     k_fft = 2.65125E-08
  423.  
  424.     # # uncomment these lines to measure timing constants
  425.     # k_conv,k_fft,convreps,fftreps = benchmark(t,a,outdims,maxtime=60)
  426.     # print "-------------------------------------"
  427.     # print "Template size:\t\t%s" %str(t.shape)
  428.     # print "Search space size:\t%s" %str(a.shape)
  429.     # print "k_conv:\t%.6G\treps:\t%s" %(k_conv,str(convreps))
  430.     # print "k_fft:\t%.6G\treps:\t%s" %(k_fft,str(fftreps))
  431.     # print "-------------------------------------"
  432.  
  433.     # spatial convolution time scales with the total number of elements
  434.     convtime = k_conv*(t.size*a.size)
  435.  
  436.     # Fourier convolution time scales with N*log(N), cross-correlation
  437.     # requires 2x FFTs and 1x iFFT. ND FFT time scales with
  438.     # prod(dimensions)*log(prod(dimensions))
  439.     ffttime = 3*k_fft*(np.prod(outdims)*np.log(np.prod(outdims)))
  440.  
  441.     # print     "Predicted spatial:\t%.6G\nPredicted fourier:\t%.6G" %(convtime,ffttime)
  442.     return convtime,ffttime
  443.  
  444. def benchmark(t,a,outdims,maxtime=60):
  445.     import resource
  446.  
  447.     # benchmark spatial convolutions
  448.     # ---------------------------------
  449.     convreps = 0
  450.     tic = resource.getrusage(resource.RUSAGE_SELF).ru_utime
  451.     toc = tic
  452.     while (toc-tic) < maxtime:
  453.         convolve(a,t,mode='constant',cval=0)
  454.         # xcorr = convolve(a,t,mode='full')
  455.         convreps += 1
  456.         toc = resource.getrusage(resource.RUSAGE_SELF).ru_utime
  457.     convtime = (toc-tic)/convreps
  458.  
  459.     # convtime == k(N1+N2)
  460.     N = t.size*a.size
  461.     k_conv = convtime/N
  462.  
  463.     # benchmark 1D Fourier transforms
  464.     # ---------------------------------
  465.     veclist = [np.random.randn(ss) for ss in outdims]
  466.     fft1times = []
  467.     fftreps = []
  468.     for vec in veclist:
  469.         reps = 0
  470.         tic = resource.getrusage(resource.RUSAGE_SELF).ru_utime
  471.         toc = tic
  472.         while (toc-tic) < maxtime:
  473.             fftn(vec)
  474.             toc = resource.getrusage(resource.RUSAGE_SELF).ru_utime
  475.             reps += 1
  476.         fft1times.append((toc-tic)/reps)
  477.         fftreps.append(reps)
  478.     fft1times = np.asarray(fft1times)
  479.  
  480.     # fft1_time == k*N*log(N)
  481.     N = np.asarray([vec.size for vec in veclist])
  482.     k_fft = np.mean(fft1times/(N*np.log(N)))
  483.  
  484.     # # benchmark ND Fourier transforms
  485.     # # ---------------------------------
  486.     # arraylist = [t,a]
  487.     # fftntimes = []
  488.     # fftreps = []
  489.     # for array in arraylist:
  490.     #   reps = 0
  491.     #   tic = resource.getrusage(resource.RUSAGE_SELF).ru_utime
  492.     #   toc = tic
  493.     #   while (toc-tic) < maxtime:
  494.     #       fftn(array,shape=a.shape)
  495.     #       reps += 1
  496.     #       toc = resource.getrusage(resource.RUSAGE_SELF).ru_utime
  497.     #   fftntimes.append((toc-tic)/reps)
  498.     #   fftreps.append(reps)
  499.     # fftntimes = np.asarray(fftntimes)
  500.  
  501.     # # fftn_time == k*prod(dimensions)*log(prod(dimensions)) for an M-dimensional array
  502.     # nlogn = np.array([aa.size*np.log(aa.size) for aa in arraylist])
  503.     # k_fft = np.mean(fftntimes/nlogn)
  504.  
  505.     return k_conv,k_fft,convreps,fftreps
  506.     # return k_conv,k_fft1,k_fftn
  507.  
  508.  
  509. def ndpad(a,npad=None,padval=0):
  510.     """
  511.     Pads the edges of an n-dimensional input array with a constant value
  512.     across all of its dimensions.
  513.  
  514.     Inputs:
  515.     ----------------
  516.         a   The array to pad
  517.  
  518.         npad*   The pad width. Can either be array-like, with one
  519.             element per dimension, or a scalar, in which case the
  520.             same pad width is applied to all dimensions.
  521.  
  522.         padval  The value to pad with. Must be a scalar (default is 0).
  523.  
  524.     Output:
  525.     ----------------
  526.         b   The padded array
  527.  
  528.     *If npad is not a whole number, padding will be applied so that the
  529.     'left' edge of the output is padded less than the 'right', e.g.:
  530.  
  531.         a       == np.array([1,2,3,4,5,6])
  532.         ndpad(a,1.5)    == np.array([0,1,2,3,4,5,6,0,0])
  533.  
  534.     In this case, the average pad width is equal to npad (but if npad was
  535.     not a multiple of 0.5 this would not still hold). This is so that ndpad
  536.     can be used to pad an array out to odd final dimensions.
  537.     """
  538.  
  539.     if npad == None:
  540.         npad = np.ones(a.ndim)
  541.     elif np.isscalar(npad):
  542.         npad = (npad,)*a.ndim
  543.     elif len(npad) != a.ndim:
  544.         raise Exception('Length of npad (%i) does not match the '\
  545.                 'dimensionality of the input array (%i)'
  546.                 %(len(npad),a.ndim))
  547.  
  548.     # initialise padded output
  549.     padsize = [a.shape[dd]+2*npad[dd] for dd in xrange(a.ndim)]
  550.     b = np.ones(padsize,a.dtype)*padval
  551.  
  552.     # construct an N-dimensional list of slice objects
  553.     ind = [slice(np.floor(npad[dd]),a.shape[dd]+np.floor(npad[dd])) for dd in xrange(a.ndim)]
  554.  
  555.     # fill in the non-pad part of the array
  556.     b[ind] = a
  557.     return b
  558.  
  559. # def ndunpad(b,npad=None):
  560. #   """
  561. #   Removes padding from each dimension of an n-dimensional array (the
  562. #   reverse of ndpad)
  563.  
  564. #   Inputs:
  565. #   ----------------
  566. #       b   The array to unpad
  567.  
  568. #       npad*   The pad width. Can either be array-like, with one
  569. #           element per dimension, or a scalar, in which case the
  570. #           same pad width is applied to all dimensions.
  571.  
  572. #   Output:
  573. #   ----------------
  574. #       a   The unpadded array
  575.  
  576. #         *If npad is not a whole number, padding will be removed assuming that
  577. #   the 'left' edge of the output is padded less than the 'right', e.g.:
  578.  
  579. #       b       == np.array([0,1,2,3,4,5,6,0,0])
  580. #       ndpad(b,1.5)    == np.array([1,2,3,4,5,6])
  581.  
  582. #   This is consistent with the behaviour of ndpad.
  583. #   """
  584. #   if npad == None:
  585. #       npad = np.ones(b.ndim)
  586. #   elif np.isscalar(npad):
  587. #       npad = (npad,)*b.ndim
  588. #   elif len(npad) != b.ndim:
  589. #       raise Exception('Length of npad (%i) does not match the '\
  590. #               'dimensionality of the input array (%i)'
  591. #               %(len(npad),b.ndim))
  592. #   ind = [slice(np.floor(npad[dd]),b.shape[dd]-np.ceil(npad[dd])) for dd in xrange(b.ndim)]
  593. #   return b[ind]
  594.  
  595. def procrustes(a,target,side='both',padval=0):
  596.     """
  597.     Forces an array to a target size by either padding it with a constant or
  598.     truncating it
  599.  
  600.     Arguments:
  601.         a   Input array of any type or shape
  602.         target  Dimensions to pad/trim to, must be a list or tuple
  603.     """
  604.  
  605.     try:
  606.         if len(target) != a.ndim:
  607.             raise TypeError('Target shape must have the same number of dimensions as the input')
  608.     except TypeError:
  609.         raise TypeError('Target must be array-like')
  610.  
  611.     try:
  612.         b = np.ones(target,a.dtype)*padval
  613.     except TypeError:
  614.         raise TypeError('Pad value must be numeric')
  615.     except ValueError:
  616.         raise ValueError('Pad value must be scalar')
  617.  
  618.     aind = [slice(None,None)]*a.ndim
  619.     bind = [slice(None,None)]*a.ndim
  620.  
  621.     # pad/trim comes after the array in each dimension
  622.     if side == 'after':
  623.         for dd in xrange(a.ndim):
  624.             if a.shape[dd] > target[dd]:
  625.                 aind[dd] = slice(None,target[dd])
  626.             elif a.shape[dd] < target[dd]:
  627.                 bind[dd] = slice(None,a.shape[dd])
  628.  
  629.     # pad/trim comes before the array in each dimension
  630.     elif side == 'before':
  631.         for dd in xrange(a.ndim):
  632.             if a.shape[dd] > target[dd]:
  633.                 aind[dd] = slice(a.shape[dd]-target[dd],None)
  634.             elif a.shape[dd] < target[dd]:
  635.                 bind[dd] = slice(target[dd]-a.shape[dd],None)
  636.  
  637.     # pad/trim both sides of the array in each dimension
  638.     elif side == 'both':
  639.         for dd in xrange(a.ndim):
  640.             if a.shape[dd] > target[dd]:
  641.                 diff = (a.shape[dd]-target[dd])/2.
  642.                 aind[dd] = slice(np.floor(diff),a.shape[dd]-np.ceil(diff))
  643.             elif a.shape[dd] < target[dd]:
  644.                 diff = (target[dd]-a.shape[dd])/2.
  645.                 bind[dd] = slice(np.floor(diff),target[dd]-np.ceil(diff))
  646.    
  647.     else:
  648.         raise Exception('Invalid choice of pad type: %s' %side)
  649.  
  650.     b[bind] = a[aind]
  651.  
  652.     return b
  653.  
  654. def ndflip(a):
  655.     """Inverts an n-dimensional array along each of its axes"""
  656.     ind = (slice(None,None,-1),)*a.ndim
  657.     return a[ind]
  658.  
  659. # def nextpow2(n):
  660. #   """get the next power of 2 that's greater than n"""
  661. #   m_f = np.log2(n)
  662. #   m_i = np.ceil(m_f)
  663. #   return 2**m_i
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement