Not a member of Pastebin yet?
                        Sign Up,
                        it unlocks many cool features!                    
                - #!/usr/bin/env python
 - """
 - A drop in replacement for numpy.dot
 - Avoid temporary copies of non C-contiguous arrays
 - """
 - import numpy as np
 - import scipy.linalg as sp
 - from numpy.testing import assert_equal
 - def dot(A, B, out=None):
 - """ A drop in replaement for numpy.dot
 - Computes A.B optimized using fblas call """
 - if A.ndim != 2 or B.ndim != 2:
 - raise ValueError("only 2D numpy arrays are supported")
 - gemm = sp.get_blas_funcs('gemm', arrays=(A, B))
 - if out is None:
 - lda, x, y, ldb = A.shape + B.shape
 - if x != y:
 - raise ValueError("matrices are not aligned")
 - dtype = np.max([x.dtype for x in (A, B)])
 - out = np.empty((lda, ldb), dtype, order='C')
 - if A.flags.c_contiguous and B.flags.c_contiguous:
 - gemm(alpha=1., a=A.T, b=B.T,
 - c=out.T, overwrite_c=True)
 - if A.flags.c_contiguous and B.flags.f_contiguous:
 - gemm(alpha=1., a=A.T, b=B, trans_a=True,
 - c=out.T, overwrite_c=True)
 - if A.flags.f_contiguous and B.flags.c_contiguous:
 - gemm(alpha=1., a=A, b=B.T, trans_b=True,
 - c=out.T, overwrite_c=True)
 - if A.flags.f_contiguous and B.flags.f_contiguous:
 - gemm(alpha=1., a=A, b=B, trans_a=True, trans_b=True,
 - c=out.T, overwrite_c=True)
 - return out
 - def test_dot():
 - A = np.random.randn(1000, 1000)
 - assert_equal(A.dot(A), dot(A, A))
 - assert_equal(A.dot(A.T), dot(A, A.T))
 - assert_equal(A.T.dot(A), dot(A.T, A))
 - assert_equal(A.T.dot(A.T), dot(A.T, A.T))
 - assert(dot(A, A).flags.c_contiguous)
 - def test_to_fix():
 - """ 1d array, complex and 3d """
 - v = np.random.randn(1000)
 - dot(v, v)
 - c = np.asarray(np.random.randn(100, 100), np.complex)
 - dot(c, c)
 - t = np.random.randn(2, 2, 3)
 - dot(t, t)
 
Advertisement
 
                    Add Comment                
                
                        Please, Sign In to add comment