Advertisement
Guest User

apply_along_axis named arguments

a guest
Mar 6th, 2014
152
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.47 KB | None | 0 0
  1. def apply_along_axis(func1d,axis,arr,*args,**moreargs):
  2.     """
  3.    Apply a function to 1-D slices along the given axis.
  4.  
  5.    Execute `func1d(a, *args)` where `func1d` operates on 1-D arrays and `a`
  6.    is a 1-D slice of `arr` along `axis`.
  7.  
  8.    Parameters
  9.    ----------
  10.    func1d : function
  11.        This function should accept 1-D arrays. It is applied to 1-D
  12.        slices of `arr` along the specified axis.
  13.    axis : integer
  14.        Axis along which `arr` is sliced.
  15.    arr : ndarray
  16.        Input array.
  17.    args : any
  18.        Additional arguments to `func1d`.
  19.  
  20.    Returns
  21.    -------
  22.    apply_along_axis : ndarray
  23.        The output array. The shape of `outarr` is identical to the shape of
  24.        `arr`, except along the `axis` dimension, where the length of `outarr`
  25.        is equal to the size of the return value of `func1d`.  If `func1d`
  26.        returns a scalar `outarr` will have one fewer dimensions than `arr`.
  27.  
  28.    See Also
  29.    --------
  30.    apply_over_axes : Apply a function repeatedly over multiple axes.
  31.  
  32.    Examples
  33.    --------
  34.    >>> def my_func(a):
  35.    ...     \"\"\"Average first and last element of a 1-D array\"\"\"
  36.    ...     return (a[0] + a[-1]) * 0.5
  37.    >>> b = np.array([[1,2,3], [4,5,6], [7,8,9]])
  38.    >>> np.apply_along_axis(my_func, 0, b)
  39.    array([ 4.,  5.,  6.])
  40.    >>> np.apply_along_axis(my_func, 1, b)
  41.    array([ 2.,  5.,  8.])
  42.  
  43.    For a function that doesn't return a scalar, the number of dimensions in
  44.    `outarr` is the same as `arr`.
  45.  
  46.    >>> b = np.array([[8,1,7], [4,3,9], [5,2,6]])
  47.    >>> np.apply_along_axis(sorted, 1, b)
  48.    array([[1, 7, 8],
  49.           [3, 4, 9],
  50.           [2, 5, 6]])
  51.  
  52.    """
  53.     arr = asarray(arr)
  54.     nd = arr.ndim
  55.     if axis < 0:
  56.         axis += nd
  57.     if (axis >= nd):
  58.         raise ValueError("axis must be less than arr.ndim; axis=%d, rank=%d."
  59.             % (axis, nd))
  60.     ind = [0]*(nd-1)
  61.     i = zeros(nd, 'O')
  62.     indlist = list(range(nd))
  63.     indlist.remove(axis)
  64.     i[axis] = slice(None, None)
  65.     outshape = asarray(arr.shape).take(indlist)
  66.     i.put(indlist, ind)
  67.     print
  68.     res = func1d(arr[tuple(i.tolist())],*args,**moreargs)
  69.     #  if res is a number, then we have a smaller output array
  70.     if isscalar(res):
  71.         outarr = zeros(outshape, asarray(res).dtype)
  72.         outarr[tuple(ind)] = res
  73.         Ntot = product(outshape)
  74.         k = 1
  75.         while k < Ntot:
  76.             # increment the index
  77.             ind[-1] += 1
  78.             n = -1
  79.             while (ind[n] >= outshape[n]) and (n > (1-nd)):
  80.                 ind[n-1] += 1
  81.                 ind[n] = 0
  82.                 n -= 1
  83.             i.put(indlist, ind)
  84.             res = func1d(arr[tuple(i.tolist())],*args,**moreargs)
  85.             outarr[tuple(ind)] = res
  86.             k += 1
  87.         return outarr
  88.     else:
  89.         Ntot = product(outshape)
  90.         holdshape = outshape
  91.         outshape = list(arr.shape)
  92.         outshape[axis] = len(res)
  93.         outarr = zeros(outshape, asarray(res).dtype)
  94.         outarr[tuple(i.tolist())] = res
  95.         k = 1
  96.         while k < Ntot:
  97.             # increment the index
  98.             ind[-1] += 1
  99.             n = -1
  100.             while (ind[n] >= holdshape[n]) and (n > (1-nd)):
  101.                 ind[n-1] += 1
  102.                 ind[n] = 0
  103.                 n -= 1
  104.             i.put(indlist, ind)
  105.             res = func1d(arr[tuple(i.tolist())],*args,**moreargs)
  106.             outarr[tuple(i.tolist())] = res
  107.             k += 1
  108.         return outarr
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement