Schef

scipy dot v2

Nov 8th, 2012
570
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. #!/usr/bin/env python
  2.  
  3. """
  4. A drop in replacement for numpy.dot
  5.  
  6. Avoid temporary copies of non C-contiguous arrays
  7. """
  8.  
  9. import numpy as np
  10. import scipy.linalg as sp
  11. from numpy.testing import assert_equal
  12.  
  13.  
  14. def dot(A, B, out=None):
  15.     """ A drop in replaement for numpy.dot
  16.    Computes A.B optimized using fblas call """
  17.     if A.ndim != 2 or B.ndim != 2:
  18.         raise ValueError("only 2D numpy arrays are supported")
  19.  
  20.     gemm = sp.get_blas_funcs('gemm', arrays=(A, B))
  21.  
  22.     if out is None:
  23.         lda, x, y, ldb = A.shape + B.shape
  24.         if x != y:
  25.             raise ValueError("matrices are not aligned")
  26.         dtype = np.max([x.dtype for x in (A, B)])
  27.         out = np.empty((lda, ldb), dtype, order='C')
  28.  
  29.     if A.flags.c_contiguous and B.flags.c_contiguous:
  30.         gemm(alpha=1., a=A.T, b=B.T,
  31.                 c=out.T, overwrite_c=True)
  32.     if A.flags.c_contiguous and B.flags.f_contiguous:
  33.         gemm(alpha=1., a=A.T, b=B, trans_a=True,
  34.                 c=out.T, overwrite_c=True)
  35.     if A.flags.f_contiguous and B.flags.c_contiguous:
  36.         gemm(alpha=1., a=A, b=B.T, trans_b=True,
  37.                 c=out.T, overwrite_c=True)
  38.     if A.flags.f_contiguous and B.flags.f_contiguous:
  39.         gemm(alpha=1., a=A, b=B, trans_a=True, trans_b=True,
  40.                 c=out.T, overwrite_c=True)
  41.     return out
  42.  
  43.  
  44. def test_dot():
  45.     A = np.random.randn(1000, 1000)
  46.     assert_equal(A.dot(A), dot(A, A))
  47.     assert_equal(A.dot(A.T), dot(A, A.T))
  48.     assert_equal(A.T.dot(A), dot(A.T, A))
  49.     assert_equal(A.T.dot(A.T), dot(A.T, A.T))
  50.     assert(dot(A, A).flags.c_contiguous)
  51.  
  52.  
  53. def test_to_fix():
  54.     """ 1d array, complex and 3d """
  55.     v = np.random.randn(1000)
  56.     dot(v, v)
  57.     c = np.asarray(np.random.randn(100, 100), np.complex)
  58.     dot(c, c)
  59.     t = np.random.randn(2, 2, 3)
  60.     dot(t, t)
RAW Paste Data