Advertisement
Dmitrey15

getShape

Jan 26th, 2012
141
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.54 KB | None | 0 0
  1. import numpy as np
  2.  
  3. if 'matrix' in np.__dict__:
  4.     matrix = np.matrix  
  5. else:
  6.     class matrix(np.ndarray):
  7.         pass
  8.    
  9.    
  10.  
  11. def getShape(arrLikeItem):
  12.     # doesn't work with sparse matrices, but it doesn't matter -
  13.     # asarray().shape will yield also incorrect result in the case,
  14.     # treating sparse matrices as objects
  15.     shape = []
  16.     tmp = arrLikeItem
  17.     while 1:
  18.         if isinstance(tmp, np.ndarray):
  19.             if tmp.shape == ():
  20.                 break
  21.             elif tmp.shape[0] == 0:
  22.                 shape += tmp.shape
  23.                 break
  24.             elif isinstance(tmp, matrix):
  25.                 shape += tmp.shape
  26.                 break
  27.             else:
  28.                 shape.append(tmp.shape[0])
  29.                 tmp = tmp[0]
  30.         elif type(tmp) in (tuple, list):
  31.             shape.append(len(tmp))
  32.             tmp = tmp[0]
  33.         else:
  34.             # a number or another item which should terminate
  35.             break
  36.     return tuple(shape)
  37.    
  38. # version that should work with RPython:
  39. def getShapeRPython(arrLikeItem):
  40.     # doesn't work with sparse matrices, but it doesn't matter -
  41.     # asarray().shape will yield also incorrect result in the case,
  42.     # treating sparse matrices as objects
  43.     shape = []
  44.     Tmp = [arrLikeItem]
  45.    
  46.     while 1:
  47.         tmp = Tmp[-1]
  48.         if isinstance(tmp, np.ndarray):
  49.             if tmp.shape == ():
  50.                 break
  51.             elif tmp.shape[0] == 0:
  52.                 shape += tmp.shape
  53.                 break
  54.             elif isinstance(tmp, matrix):
  55.                 shape += tmp.shape
  56.                 break            
  57.             else:
  58.                 shape.append(tmp.shape[0])
  59.                 Tmp.append(tmp[0])
  60.         elif type(tmp) in (tuple, list):
  61.             shape.append(len(tmp))
  62.             Tmp.append(tmp[0])
  63.         else:
  64.             # a number or another item which should terminate
  65.             break
  66.     return tuple(shape)
  67.  
  68. for a in [
  69.           1,
  70.           [1],
  71.           [1, 2, 3],
  72.           np.array((1, 2, 3)),
  73.           np.array((1, 2, 3), object),
  74.           np.matrix(np.array((1, 2, 3))),
  75.           np.ones((2, 3, 4)),
  76.           np.ones((0, 3, 4)),
  77.           (((1, 2, 3), (4, 5, 6)), ((7, 8, 9), (10, 11, 12))),
  78.           np.array((((1, 2, 3), (4, 5, 6)), ((7, 8, 9), (10, 11, 12)))),
  79.           ((np.array((1, 2, 3)), np.array((4, 5, 6))), (np.array((7, 8, 9)), np.array((10, 11, 12))))
  80.           ]:
  81.     assert np.asarray(a).shape == getShape(a) == getShapeRPython(a)
  82. print('passed')
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement