• API
• FAQ
• Tools
• Archive
SHARE
TWEET

# Matrix multiplication

a guest Oct 18th, 2016 1,074 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
1. import java.util.ArrayList;
2. import java.util.Arrays;
3. import java.util.Collections;
4. import java.util.Random;
5. import java.util.concurrent.ForkJoinPool;
7.
8. /**
9.  * The {@code MatrixMultiplication} class implements
10.  * fast multiplication 2 matrices at each other.
11.  * The {@code MatrixMultiplication} uses Strassen algorithm and
12.  * parallelize it with the {@link java.util.concurrent.ForkJoinPool}
13.  *
14.  * @author Evgeny Usov
15.  * @author Alexey Falko
16.  */
17. public class MatrixMultiplication {
18.
19.     //******************************************************************************************
20.
21.     public static int[][] multiply(int[][] a, int[][] b) {
22.
23.         int rowsA = a.length;
24.         int columnsB = b[0].length;
25.         int columnsA_rowsB = a[0].length;
26.
27.         int[][] c = new int[rowsA][columnsB];
28.
29.         for (int i = 0; i < rowsA; i++) {
30.             for (int j = 0; j < columnsB; j++) {
31.                 int sum = 0;
32.                 for (int k = 0; k < columnsA_rowsB; k++) {
33.                     sum += a[i][k] * b[k][j];
34.                 }
35.                 c[i][j] = sum;
36.             }
37.         }
38.
39.         return c;
40.     }
41.
42.     //******************************************************************************************
43.
44.     public static int[][] multiplyTransposed(int[][] a, int[][] b) {
45.
46.         int rowsA = a.length;
47.         int columnsB = b[0].length;
48.         int columnsA_rowsB = a[0].length;
49.
50.         int columnB[] = new int[columnsA_rowsB];
51.         int[][] c = new int[rowsA][columnsB];
52.
53.
54.         for (int j = 0; j < columnsB; j++) {
55.             for (int k = 0; k < columnsA_rowsB; k++) {
56.                 columnB[k] = b[k][j];
57.             }
58.
59.             for (int i = 0; i < rowsA; i++) {
60.                 int rowA[] = a[i];
61.                 int sum = 0;
62.                 for (int k = 0; k < columnsA_rowsB; k++) {
63.                     sum += rowA[k] * columnB[k];
64.                 }
65.                 c[i][j] = sum;
66.             }
67.         }
68.
69.         return c;
70.     }
71.
72.     //******************************************************************************************
73.
74.     private static int[][] summation(int[][] a, int[][] b) {
75.
76.         int n = a.length;
77.         int m = a[0].length;
78.         int[][] c = new int[n][m];
79.
80.         for (int i = 0; i < n; i++) {
81.             for (int j = 0; j < m; j++) {
82.                 c[i][j] = a[i][j] + b[i][j];
83.             }
84.         }
85.         return c;
86.     }
87.
88.     //******************************************************************************************
89.
90.     private static int[][] subtraction(int[][] a, int[][] b) {
91.
92.         int n = a.length;
93.         int m = a[0].length;
94.         int[][] c = new int[n][m];
95.
96.         for (int i = 0; i < n; i++) {
97.             for (int j = 0; j < m; j++) {
98.                 c[i][j] = a[i][j] - b[i][j];
99.             }
100.         }
101.         return c;
102.     }
103.
104.     //******************************************************************************************
105.
106.     private static int[][] addition2SquareMatrix(int[][] a, int n) {
107.
108.         int[][] result = new int[n][n];
109.
110.         for (int i = 0; i < a.length; i++) {
111.             System.arraycopy(a[i], 0, result[i], 0, a[i].length);
112.         }
113.         return result;
114.     }
115.
116.     //******************************************************************************************
117.
118.     private static int[][] getSubmatrix(int[][] a, int n, int m) {
119.         int[][] result = new int[n][m];
120.         for (int i = 0; i < n; i++) {
121.             System.arraycopy(a[i], 0, result[i], 0, m);
122.         }
123.         return result;
124.     }
125.
126.     //******************************************************************************************
127.
128.     private static void splitMatrix(int[][] a, int[][] a11, int[][] a12, int[][] a21, int[][] a22) {
129.
130.         int n = a.length >> 1;
131.
132.         for (int i = 0; i < n; i++) {
133.             System.arraycopy(a[i], 0, a11[i], 0, n);
134.             System.arraycopy(a[i], n, a12[i], 0, n);
135.             System.arraycopy(a[i + n], 0, a21[i], 0, n);
136.             System.arraycopy(a[i + n], n, a22[i], 0, n);
137.         }
138.     }
139.
140.     //******************************************************************************************
141.
142.     private static int[][] collectMatrix(int[][] a11, int[][] a12, int[][] a21, int[][] a22) {
143.
144.         int n = a11.length;
145.         int[][] a = new int[n << 1][n << 1];
146.
147.         for (int i = 0; i < n; i++) {
148.             System.arraycopy(a11[i], 0, a[i], 0, n);
149.             System.arraycopy(a12[i], 0, a[i], n, n);
150.             System.arraycopy(a21[i], 0, a[i + n], 0, n);
151.             System.arraycopy(a22[i], 0, a[i + n], n, n);
152.         }
153.
154.         return a;
155.     }
156.
157.     //******************************************************************************************
158.
159.     /**
161.      * algorithm by Strassen
162.      */
164.         private static final long serialVersionUID = -433764214304695286L;
165.
166.         int n;
167.         int[][] a;
168.         int[][] b;
169.
170.         public myRecursiveTask(int[][] a, int[][] b, int n) {
171.             this.a = a;
172.             this.b = b;
173.             this.n = n;
174.         }
175.
176.         /**
177.          * @return the integer matrix by
178.          * multiplying 2 matrices at each other
179.          */
180.         @Override
181.         protected int[][] compute() {
182.             if (n <= 128) {
183.                 return multiplyTransposed(a, b);
184.             }
185.
186.             n >>= 1;
187.
188.             int[][] a11 = new int[n][n];
189.             int[][] a12 = new int[n][n];
190.             int[][] a21 = new int[n][n];
191.             int[][] a22 = new int[n][n];
192.
193.             int[][] b11 = new int[n][n];
194.             int[][] b12 = new int[n][n];
195.             int[][] b21 = new int[n][n];
196.             int[][] b22 = new int[n][n];
197.
198.             splitMatrix(a, a11, a12, a21, a22);
199.             splitMatrix(b, b11, b12, b21, b22);
200.
208.
216.
224.
225.             int[][] c11 = summation(summation(p1, p4), subtraction(p7, p5));
226.             int[][] c12 = summation(p3, p5);
227.             int[][] c21 = summation(p2, p4);
228.             int[][] c22 = summation(subtraction(p1, p2), summation(p3, p6));
229.
230.             return collectMatrix(c11, c12, c21, c22);
231.         }
232.
233.     }
234.
235.     //******************************************************************************************
236.
237.     public static int[][] multiStrassenForkJoin(int[][] a, int[][] b) {
238.
239.         int nn = getNewDimension(a, b);
240.         int[][] a_n = addition2SquareMatrix(a, nn);
241.         int[][] b_n = addition2SquareMatrix(b, nn);
242.
244.         ForkJoinPool pool = new ForkJoinPool();
246.
247.         return getSubmatrix(fastFJ, a.length, b[0].length);
248.     }
249.
250.     //******************************************************************************************
251.
252.     @Deprecated
253.     /**
255.      * algorithm by Strassen
256.      * */
257.     private static int[][] multiStrassen(int[][] a, int[][] b, int n) {
258.         if (n <= 128) {
259.             return multiplyTransposed(a, b);
260.         }
261.
262.         n = n >> 1;
263.         ArrayList<Object> objects = new ArrayList<>();
264.
265.         int[][] a11 = new int[n][n];
266.         int[][] a12 = new int[n][n];
267.         int[][] a21 = new int[n][n];
268.         int[][] a22 = new int[n][n];
269.
270.         int[][] b11 = new int[n][n];
271.         int[][] b12 = new int[n][n];
272.         int[][] b21 = new int[n][n];
273.         int[][] b22 = new int[n][n];
274.
275.         splitMatrix(a, a11, a12, a21, a22);
276.         splitMatrix(b, b11, b12, b21, b22);
277.
278.         int[][] p1 = multiStrassen(summation(a11, a22), summation(b11, b22), n);
279.         int[][] p2 = multiStrassen(summation(a21, a22), b11, n);
280.         int[][] p3 = multiStrassen(a11, subtraction(b12, b22), n);
281.         int[][] p4 = multiStrassen(a22, subtraction(b21, b11), n);
282.         int[][] p5 = multiStrassen(summation(a11, a12), b22, n);
283.         int[][] p6 = multiStrassen(subtraction(a21, a11), summation(b11, b12), n);
284.         int[][] p7 = multiStrassen(subtraction(a12, a22), summation(b21, b22), n);
285.
286.         int[][] c11 = summation(summation(p1, p4), subtraction(p7, p5));
287.         int[][] c12 = summation(p3, p5);
288.         int[][] c21 = summation(p2, p4);
289.         int[][] c22 = summation(subtraction(p1, p2), summation(p3, p6));
290.
291.         return collectMatrix(c11, c12, c21, c22);
292.     }
293.
294.     //******************************************************************************************
295.
296.     private static int log2(int x) {
297.         int result = 1;
298.         while ((x >>= 1) != 0) {
299.             result++;
300.         }
301.
302.         return result;
303.     }
304.
305.     //******************************************************************************************
306.
307.     private static int getNewDimension(int[][] a, int[][] b) {
308.         return 1 << log2(Collections.max(Arrays.asList(a.length, a[0].length, b[0].length)));
309.     }
310.
311.     //******************************************************************************************
312.
313.     public static int[][] randomMatrix(int m, int n) {
314.         int[][] a = new int[m][n];
315.         for (int i = 0; i < m; i++) {
316.             for (int j = 0; j < n; j++) {
317.                 a[i][j] = new Random().nextInt(100);
318.             }
319.         }
320.         return a;
321.     }
322.
323.     //******************************************************************************************
324.
325.     public static void printMatrix(int[][] a) {
326.         for (int i = 0; i < a[0].length; i++) {
327.             System.out.print("-------");
328.         }
329.         System.out.println();
330.         for (int[] anA : a) {
331.             System.out.print("|");
332.             for (int anAnA : anA) {
333.                 System.out.printf("%4d |", anAnA);
334.             }
335.
336.             System.out.println();
337.             for (int i = 0; i < a[0].length; i++) {
338.                 System.out.print("-------");
339.             }
340.             System.out.println();
341.         }
342.     }
343.
344.     //******************************************************************************************
345.
346.     public static void test(int n, int m, int l) {
347.
348.         int[][] a = randomMatrix(n, l);
349.         int[][] b = randomMatrix(l, m);
350.         long start, end;
351.
352.         //****************************************
353.         //  TEST 1
354.         start = System.currentTimeMillis();
355.         int[][] matrixByStrassenFJ = multiStrassenForkJoin(a, b);
356.         end = System.currentTimeMillis();
357.         System.out.printf("Strassen Fork-Join Multiply [A:%dx%d; B:%dx%d]: \tElapsed: %dms\n", n, l, l, m, end - start);
358.         //****************************************
359.
360.         //****************************************
361.         //  TEST 2
362.         start = System.currentTimeMillis();
363.         int nn = getNewDimension(a, b);
364.
365.         int[][] a_n = addition2SquareMatrix(a, nn);
366.         int[][] b_n = addition2SquareMatrix(b, nn);
367.
368.         int[][] temp = multiStrassen(a_n, b_n, nn);
369.         int[][] matrixByStrassen = getSubmatrix(temp, n, m);
370.         end = System.currentTimeMillis();
371.         System.out.printf("Strassen Multiply [A:%dx%d; B:%dx%d]: \tElapsed: %dms\n", n, l, l, m, end - start);
372.         //****************************************
373.
374.         //****************************************
375.         //  TEST 3
376.             start = System.currentTimeMillis();
377.             int[][] matrixByUsual = multiply(a, b);
378.             end = System.currentTimeMillis();
379.             System.out.printf("Usual Multiply [A:%dx%d; B:%dx%d]: \tElapsed: %dms\n", n, l, l, m, end - start);
380.         //****************************************
381.
382.         //****************************************
383.         //  TEST 4
384.         start = System.currentTimeMillis();
385.         int[][] matrixByUsualTransposed = multiplyTransposed(a, b);
386.         end = System.currentTimeMillis();
387.         System.out.printf("Usual Multiply Transposed [A:%dx%d; B:%dx%d]: \tElapsed: %dms\n", n, l, l, m, end - start);
388.         //****************************************
389.
390.         System.out.println("Matrices are equal: " + Arrays.deepEquals(matrixByStrassenFJ, matrixByStrassen));
391.         System.out.println("Matrices are equal: " + Arrays.deepEquals(matrixByStrassenFJ, matrixByUsual));
392.         System.out.println("Matrices are equal: " + Arrays.deepEquals(matrixByStrassenFJ, matrixByUsualTransposed));
393.
394.     }
395.
396.     //******************************************************************************************
397.
398.     private static class Multipliers {
399.         private final int[][] matrixA;
400.         private final int[][] matrixB;
401.
402.         public Multipliers(int[][] a, int[][] b) {
403.             matrixA = a;
404.             matrixB = b;
405.         }
406.
407.         public int[][] getMatrixB() {
408.             return matrixB;
409.         }
410.
411.         public int[][] getMatrixA() {
412.             return matrixA;
413.         }
414.     }
415.
416.
417.     //******************************************************************************************
418.     private static Multipliers validation(String[] args) {
419.         int rowsA;
420.         int columnsA;
421.         int rowsB;
422.         int columnsB;
423.
424.         if (args.length < 6) {
425.             throw new IllegalArgumentException("Too few parameters. Should be not less then 6.");
426.         }
427.
428.         /*
429.          * Note: method parseInt returns NumberFormatException if the argument String
430.          * does not contain a parsable int
431.          * */
432.
433.         rowsA = Integer.parseInt(args[0]);
434.         columnsA = Integer.parseInt(args[1]);
435.         rowsB = Integer.parseInt(args[2]);
436.         columnsB = Integer.parseInt(args[3]);
437.
438.         if (rowsA <= 0 || columnsA <= 0 || rowsB <= 0 || columnsB <= 0) {
439.             throw new IllegalArgumentException("Array dimension can't be negative or zero");
440.         }
441.
442.         if (args.length - (rowsA * columnsA + rowsB * columnsB) != 4) {
443.             throw new IllegalArgumentException("Incorrect number of values to initialize two arrays.");
444.         }
445.
446.         if (columnsA != rowsB) {
447.             throw new IllegalArgumentException("The number of columns of the matrix A is not equal to the number of rows of the matrix B.");
448.         }
449.
450.         int[][] a = new int[rowsA][columnsA];
451.         int[][] b = new int[rowsB][columnsB];
452.
453.         int k = 4;
454.
455.         //***************************************
456.
457.         for (int i = 0; i < a.length; i++) {
458.             for (int j = 0; j < a[0].length; j++) {
459.                 a[i][j] = Integer.parseInt(args[k++]);
460.             }
461.         }
462.
463.         //***************************************
464.
465.         for (int i = 0; i < b.length; i++) {
466.             for (int j = 0; j < b[0].length; j++) {
467.                 b[i][j] = Integer.parseInt(args[k++]);
468.             }
469.         }
470.
471.         //***************************************
472.
473.         return new Multipliers(a, b);
474.     }
475.
476.     //******************************************************************************************
477.
478.     /*
479.         Матрицы подаются как аргументы программы в следующем формате
480.         N M X Y A_1_1 ... A_N_M B_1_1 ... B_X_Y
481.
482.         где N и M - размерность первой матрицы A,
483.         A_1_1 ... A_N_M - элементы матрицы A,
484.         X и Y - размерность второй матрицы B,
485.         B_1_1 ... B_X_Y - элементы матрицы B.
486.
487.         Например, для умножения единичной матрицы размером 2 на 2 на вектор (-1, -1)
488.         необходимо на вход приложению пожать следующие аргументы
489.         2 2 2 1 1 0 0 1 -1 -1
490.         В консоль должен распечататься вектор:
491.         -1
492.         -1
493.     */
494.     public static void main(String[] args) {
495.         Multipliers multipliers = validation(args);
496.
497.         int[][] matrixByStrassenFJ = multiStrassenForkJoin(multipliers.getMatrixA(), multipliers.getMatrixB());
498.         int[][] matrixByUsual = multiply(multipliers.getMatrixA(), multipliers.getMatrixB());
499.
500.         printMatrix(matrixByStrassenFJ);
501.         //printMatrix(matrixByUsual);
502.
503.         //System.out.println(Arrays.deepEquals(matrixByStrassenFJ, matrixByUsual));
504.     }
505.
506. }
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy.
Top