Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import numpy
- import sys
def getMatrixFromFile(fileName):
    """Load a numeric matrix from a text file.

    The file is parsed with ``numpy.loadtxt`` using its default settings
    (whitespace-separated values, one matrix row per line).

    Args:
        fileName: Path of the text file to read.

    Returns:
        numpy.ndarray: The parsed values.
    """
    # Use a context manager so the file handle is closed even if parsing fails.
    with open(fileName, 'r') as inputFile:
        return numpy.loadtxt(inputFile)
def Strassen(X, Y):
    """Multiply two square matrices with Strassen's divide-and-conquer algorithm.

    Each operand is split into four equal quadrants, and the product is built
    from seven recursive sub-products instead of the naive eight. Blocks of
    size 2 or smaller, or of odd size, cannot be split into equal quadrants.
    Those blocks are multiplied directly with ``numpy.dot``.

    Args:
        X: Left operand; any array-like convertible to a 2-D square array.
        Y: Right operand; must have the same shape as ``X``.

    Returns:
        numpy.ndarray: The ``n x n`` product of ``X`` and ``Y``.

    Raises:
        ValueError: If the operands are not 2-D square arrays of equal shape.
    """
    # asarray accepts nested lists and converts numpy.matrix to a plain
    # ndarray, so the slicing and arithmetic below have ndarray semantics.
    X = numpy.asarray(X)
    Y = numpy.asarray(Y)
    if X.ndim != 2 or X.shape[0] != X.shape[1] or X.shape != Y.shape:
        raise ValueError(
            "Strassen requires two square matrices of the same shape, "
            "got %s and %s" % (X.shape, Y.shape))
    n = X.shape[0]
    # Base case. A 1x1 block cannot be halved, and an odd size does not split
    # into equal quadrants, so multiply these directly instead of recursing.
    if n <= 2 or n % 2:
        return numpy.dot(X, Y)
    # Integer division: in Python 3, n / 2 is a float and cannot be a slice index.
    h = n // 2
    # Quadrants of X:  [[A, B],    quadrants of Y:  [[E, F],
    #                   [C, D]]                      [G, H]]
    A, B = X[:h, :h], X[:h, h:]
    C, D = X[h:, :h], X[h:, h:]
    E, F = Y[:h, :h], Y[:h, h:]
    G, H = Y[h:, :h], Y[h:, h:]
    # Strassen's seven products.
    P1 = Strassen(A, F - H)
    P2 = Strassen(A + B, H)
    P3 = Strassen(C + D, E)
    P4 = Strassen(D, G - E)
    P5 = Strassen(A + D, E + H)
    P6 = Strassen(B - D, G + H)
    P7 = Strassen(A - C, E + F)
    # Recombine: top-left = AE+BG, top-right = AF+BH,
    #            bottom-left = CE+DG, bottom-right = CF+DH.
    return numpy.block([
        [P5 + P4 - P2 + P6, P1 + P2],
        [P3 + P4, P1 + P5 - P3 - P7],
    ])
def main():
    """Command-line entry point: multiply two matrices read from text files.

    Usage: ``python <script> MATRIX_A_FILE MATRIX_B_FILE``. Each file is
    loaded with ``getMatrixFromFile``, and the product computed by
    ``Strassen`` is printed to stdout. If fewer than two file arguments are
    given, a usage message is written to stderr and the process exits with
    status 2.
    """
    if len(sys.argv) < 3:
        prog = sys.argv[0] if sys.argv else "strassen"
        sys.stderr.write("usage: %s MATRIX_A_FILE MATRIX_B_FILE\n" % prog)
        sys.exit(2)
    X = getMatrixFromFile(sys.argv[1])
    Y = getMatrixFromFile(sys.argv[2])
    print(Strassen(X, Y))


if __name__ == "__main__":
    main()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement