SHOW:
|
|
- or go back to the newest paste.
1 | #!/usr/bin/env python | |
2 | ||
3 | """ | |
4 | - | import scipy.linalg as sp |
4 | + | A drop-in replacement for numpy.dot |
5 | - | gemm = sp.get_blas_funcs('gemm', arrays=(A,B)) |
5 | + | |
6 | Avoid temporary copies of non C-contiguous arrays | |
7 | """ | |
8 | ||
9 | import numpy as np | |
10 | import scipy.linalg as sp | |
11 | from numpy.testing import assert_equal | |
12 | - | out = np.empty((lda, ldb), dtype, order='F') |
12 | + | |
13 | ||
def _transposed_operand(M):
    """Return ``(X, trans)`` such that BLAS ``op(X)`` equals ``M.T``.

    ``X`` is always a view of ``M``.  It is Fortran-contiguous whenever ``M``
    is C- or Fortran-contiguous, which is the layout the ``gemm`` wrapper
    expects for its inputs.
    """
    if M.flags.c_contiguous:
        # M.T is a Fortran-contiguous view: hand it over untransposed.
        return M.T, 0
    # Fortran-contiguous (or arbitrarily strided) input: pass M itself and
    # ask BLAS to transpose it.
    return M, 1


def dot(A, B, out=None):
    """Compute the matrix product of ``A`` and ``B`` through BLAS ``gemm``.

    The product is evaluated as ``out.T = B.T . A.T`` so that the
    C-contiguous result buffer can be handed to the Fortran routine as a
    Fortran-contiguous view.  C- or Fortran-contiguous inputs are passed
    as views (with the matching ``trans_*`` flag) rather than being
    transposed into new arrays.

    Parameters
    ----------
    A, B : numpy.ndarray
        Two-dimensional arrays with ``A.shape[1] == B.shape[0]``.
    out : numpy.ndarray, optional
        C-contiguous array of shape ``(A.shape[0], B.shape[1])`` whose dtype
        is the promoted dtype of ``A`` and ``B``.  When given, the result is
        written into it and it is returned.

    Returns
    -------
    numpy.ndarray
        C-contiguous array holding ``A . B``.

    Raises
    ------
    ValueError
        If an input is not 2-D, if the inner dimensions differ, or if
        ``out`` has the wrong shape, dtype or memory layout.
    """
    if A.ndim != 2 or B.ndim != 2:
        raise ValueError("only 2D numpy arrays are supported")

    gemm = sp.get_blas_funcs('gemm', arrays=(A, B))

    lda, x, y, ldb = A.shape + B.shape
    if x != y:
        raise ValueError("matrices are not aligned")

    dtype = np.promote_types(A.dtype, B.dtype)
    if np.dtype(gemm.dtype) != dtype:
        # No BLAS routine for this dtype (e.g. integers): gemm would work on
        # a converted copy, so defer to numpy to keep numpy.dot semantics.
        return np.dot(A, B, out)

    if out is None:
        out = np.empty((lda, ldb), dtype, order='C')
    elif (not isinstance(out, np.ndarray) or out.shape != (lda, ldb)
            or out.dtype != dtype or not out.flags.c_contiguous):
        raise ValueError("output array is not acceptable (must be a "
                         "C-contiguous array of the right shape and dtype)")

    if 0 in (lda, x, ldb):
        # Empty product: nothing to hand to BLAS, the result is all zeros.
        out.fill(0)
        return out

    # (A . B).T == B.T . A.T, hence B supplies `a` and A supplies `b`.
    a, trans_a = _transposed_operand(B)
    b, trans_b = _transposed_operand(A)
    res = gemm(alpha=1., a=a, b=b, trans_a=trans_a, trans_b=trans_b,
               c=out.T, overwrite_c=True)
    if not np.may_share_memory(res, out):
        # The wrapper did not reuse our buffer; copy the result back.
        out.T[...] = res
    return out
42 | ||
43 | ||
def test_dot():
    """Check ``dot`` against ``ndarray.dot`` for each memory-layout pairing.

    Only square self-products are covered (``A`` and ``A.T`` combined with
    each other).  Results are compared with a tolerance because the two
    sides may go through different BLAS routines and so may differ in the
    last bits.
    """
    rng = np.random.RandomState(0)  # fixed seed keeps failures reproducible
    A = rng.randn(1000, 1000)
    for left, right in ((A, A), (A, A.T), (A.T, A), (A.T, A.T)):
        np.testing.assert_allclose(dot(left, right), left.dot(right),
                                   rtol=1e-7, atol=1e-8)
    assert dot(A, A).flags.c_contiguous
51 | ||
52 | ||
def test_to_fix():
    """Call ``dot`` on 1-D, complex and 3-D inputs (cases still to be fixed).

    No result is asserted; the calls simply have to run without raising.
    """
    v = np.random.randn(1000)
    dot(v, v)
    # np.complex128 replaces the removed ``np.complex`` alias (same dtype).
    c = np.asarray(np.random.randn(100, 100), np.complex128)
    dot(c, c)
    t = np.random.randn(2, 2, 3)
    dot(t, t)