Advertisement
Dmitrey15

concatenate() for numpypy

Jan 26th, 2012
192
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.74 KB | None | 0 0
  1. try:
  2.     import numpypy as N
  3. except:
  4.     import numpy as N
  5.  
  6. TypePriorities = {
  7.                   N.uint8:0,
  8.                   N.uint16:1,
  9.                   N.uint32:2,
  10.                   N.unsignedinteger:3,
  11.                   N.uint64:4,
  12.                  
  13.                   N.int8:10,
  14.                   N.int16:11,
  15.                   N.int32:12,
  16.                   int:13,
  17.                   N.int64:14,
  18.                  
  19.                   N.float32:20,
  20.                   N.float64:21,
  21.                  
  22.                   # TODO: add complex types when they will be implemented
  23.                  
  24.                   object: N.inf
  25.                   }
  26.                  
  27. if 'float128' in N.__dict__:
  28.     TypePriorities[N.float128] = 22
  29.  
  30. if 'complex' in N.__dict__:
  31.     assert 0, 'you should update TypePriorities with complex types'
  32.  
  33. reversedTypePriorities = dict((v,k) for k, v in TypePriorities.items())
  34.  
  35. def concatenate(arrays, axis=0):
  36.     '''
  37.    concatenate(...)
  38.    concatenate((a1, a2, ...), axis=0)
  39.    
  40.    Join a sequence of arrays together.
  41.    
  42.    Parameters
  43.    ----------
  44.    a1, a2, ... : sequence of array_like
  45.        The arrays must have the same shape, except in the dimension
  46.        corresponding to `axis` (the first, by default).
  47.    axis : int, optional
  48.        The axis along which the arrays will be joined.  Default is 0.
  49.    
  50.    Returns
  51.    -------
  52.    res : ndarray
  53.        The concatenated array.
  54.    '''
  55.     Arrays = [(arr if isinstance(arr, N.ndarray) else N.array(arr)) for arr in arrays]
  56.     Ndims = [arr.ndim for arr in Arrays]
  57.     assert all([Ndims[i] == Ndims[0] for i in range(1, len(Ndims))]), 'arrays must be of same dimension'
  58.     ndim = Ndims[0]
  59.    
  60.     #cpython numpy-style: if ndim == 1: axis = 0
  61.     # our approach:
  62.     assert axis < ndim, 'incorrect axis: must be less than array.ndim'
  63.    
  64.     # TODO: in numpy concatenation with numbers  or zero-shape ndarrays gives error,
  65.     # maybe it's better to convert them automatically to ndarrays of correct shape (if possible)?
  66.    
  67.     TypePriority = TypePriorities.get(Arrays[0].dtype.type, N.inf)
  68.     for arr in Arrays[1:]:
  69.         if arr.ndim != ndim:
  70.             raise ValueError('arrays must have same number of dimensions')
  71.         tmp = TypePriorities.get(arr.dtype.type, N.inf)
  72.         if tmp > TypePriority: TypePriority = tmp # I think it should work faster than TypePriority = max(TypePriority, tmp)
  73.     dtype = reversedTypePriorities[TypePriority]
  74.        
  75.     Shapes = [arr.shape for arr in Arrays]
  76.     S = Shapes[0]
  77.     result_length_along_involved_axis = Arrays[0].shape[axis]
  78.     ArrConcatAxisLengths = [result_length_along_involved_axis]
  79.    
  80.     for s in Shapes[1:]:
  81.         if s[:axis] != S[:axis] or s[axis+1:] != S[axis+1:]:
  82.             raise ValueError('array dimensions must agree except for d_0')
  83.         ArrConcatAxisLengths.append(s[axis])
  84.         result_length_along_involved_axis += s[axis]
  85.  
  86.     S = list(S)
  87.     S[axis] = result_length_along_involved_axis
  88.     r = N.empty(S, dtype)
  89.     S = [slice(None)]*axis
  90.     ind = 0
  91.     for i, arr in enumerate(Arrays):
  92.         tmp = r[S + [slice(ind, ind+ArrConcatAxisLengths[i])]]
  93.         tmp[:] = N.array(arr, dtype) if arr.dtype != dtype else arr.copy()
  94.         ind += ArrConcatAxisLengths[i]
  95.  
  96.     return r
  97.    
  98.  
  99. # tests
  100. assert N.all(concatenate(((1, 2), (3, 4))) == N.array((1, 2, 3, 4)))
  101. assert N.all(concatenate((N.array((1.0, 2.0)).reshape(-1, 1), N.array((3, 4)).reshape(-1, 1)), axis=1) ==N.array(((1, 3), (2, 4))))
  102. a = N.array([[1, 2], [3, 4]])
  103. b = N.array([[5, 6]])
  104. assert N.all(concatenate((a, b), axis=0) == N.array([[1, 2], [3, 4], [5, 6]]))
  105. assert N.all(concatenate((a, b.T), axis=1) == N.array([[1, 2, 5], [3, 4, 6]]))
  106. print('passed')
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement