View difference between Paste ID: M8TfbURi and QrRk0kEf
SHOW: | | - or go back to the newest paste.
1
#!/usr/bin/env python
2
3
"""
4-
    import scipy.linalg as sp
4+
A drop in replacement for numpy.dot
5-
    gemm = sp.get_blas_funcs('gemm', arrays=(A,B))
5+
6
Avoid temporary copies of non C-contiguous arrays
7
"""
8
9
import numpy as np
10
import scipy.linalg as sp
11
from numpy.testing import assert_equal
12-
        out = np.empty((lda, ldb), dtype, order='F')
12+
13
14
def _transposed_gemm_operand(M):
    """Return ``(array, trans)`` such that ``op(array)`` equals ``M.T``.

    ``trans`` is the flag to hand to ``gemm`` (1 = transpose, 0 = as is).
    The returned array is Fortran-contiguous whenever ``M`` is C- or
    F-contiguous, so ``gemm`` does not need to make a temporary copy.
    """
    if M.flags.f_contiguous and not M.flags.c_contiguous:
        # M is already in Fortran order: let BLAS transpose it on the fly.
        return M, 1
    # C-contiguous M: the view M.T is Fortran-contiguous, no copy needed.
    # Neither layout: gemm copies M.T internally, the result is still right.
    return M.T, 0


def dot(A, B, out=None):
    """Compute the matrix product ``A.dot(B)`` with one BLAS ``gemm`` call.

    The result is written into a C-contiguous array without first copying
    C- or F-contiguous operands.  This relies on the identity
    ``out.T == B.T . A.T``: ``out.T`` is a Fortran-ordered view of ``out``,
    which is the layout ``gemm`` writes to.

    Parameters
    ----------
    A, B : 2-D arrays of shapes ``(m, k)`` and ``(k, n)``.
    out : optional C-contiguous array of shape ``(m, n)`` whose dtype is the
        one of the selected ``gemm`` routine; receives the result.

    Returns
    -------
    The ``(m, n)`` product.  Its dtype is the dtype of the BLAS routine
    picked by ``get_blas_funcs`` for ``A`` and ``B`` (so integer inputs yield
    a floating point result, unlike ``numpy.dot``).

    Raises
    ------
    ValueError
        If an operand is not 2-D, the inner dimensions differ, or ``out`` is
        unsuitable.
    """
    A = np.asarray(A)
    B = np.asarray(B)
    if A.ndim != 2 or B.ndim != 2:
        raise ValueError("only 2D numpy arrays are supported")

    m, k = A.shape
    k_b, n = B.shape
    if k != k_b:
        raise ValueError("matrices are not aligned")

    gemm = sp.get_blas_funcs('gemm', arrays=(A, B))
    # Use the dtype of the routine that will actually run, so that the
    # output buffer can be filled in place.
    dtype = np.dtype(gemm.dtype)

    if out is None:
        out = np.empty((m, n), dtype, order='C')
    elif (not isinstance(out, np.ndarray) or out.shape != (m, n)
            or out.dtype != dtype or not out.flags.c_contiguous):
        raise ValueError("out must be a C-contiguous array of shape %r "
                         "and dtype %s" % ((m, n), dtype))

    if 0 in (m, k, n):
        # Empty product: nothing to hand to BLAS, the result is all zeros.
        out.fill(0)
        return out

    # out.T = B.T . A.T, hence B supplies the left operand and A the right.
    lhs, trans_lhs = _transposed_gemm_operand(B)
    rhs, trans_rhs = _transposed_gemm_operand(A)
    result = gemm(alpha=1., a=lhs, b=rhs, trans_a=trans_lhs,
                  trans_b=trans_rhs, c=out.T, overwrite_c=True)
    if not np.may_share_memory(result, out):
        # gemm worked on a private copy of c; store the product explicitly.
        out[...] = result.T
    return out
42
43
44
def test_dot():
    """Check ``dot`` against ``numpy.dot`` for every memory-layout pairing.

    The operands are distinct and non-square on purpose: multiplying a
    matrix with itself (or with its own transpose) yields the same values
    when the operands are swapped or the result is transposed, which would
    hide that class of bug.
    """
    rng = np.random.RandomState(0)
    A = rng.randn(30, 40)
    B = rng.randn(40, 50)
    expected = A.dot(B)

    # C/C, C/F, F/C and F/F contiguous operands.
    for a in (A, np.asfortranarray(A)):
        for b in (B, np.asfortranarray(B)):
            # Tolerance-based comparison: the summation order inside BLAS
            # is not guaranteed to match between the two code paths.
            np.testing.assert_allclose(dot(a, b), expected)

    # Transposed views are F-contiguous without any copy being made.
    np.testing.assert_allclose(dot(B.T, A.T), expected.T)

    assert dot(A, B).flags.c_contiguous
51
52
53
def test_to_fix():
    """ 1d array, complex and 3d

    Calls ``dot`` on the input kinds listed above.  No result is checked;
    the calls only have to complete.  ``dot`` raises ``ValueError`` for any
    operand that is not 2-D, so the 1-D call below stops this test until
    that case is supported.
    """
    v = np.random.randn(1000)
    dot(v, v)
    # The builtin ``complex`` replaces the ``np.complex`` alias, which was
    # removed from NumPy; both denote complex128.
    c = np.asarray(np.random.randn(100, 100), complex)
    dot(c, c)
    t = np.random.randn(2, 2, 3)
    dot(t, t)