Advertisement
Guest User

Matrix multiplication

a guest
Oct 18th, 2016
2,454
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Java 17.29 KB | None | 0 0
  1. import java.util.ArrayList;
  2. import java.util.Arrays;
  3. import java.util.Collections;
  4. import java.util.Random;
  5. import java.util.concurrent.ForkJoinPool;
  6. import java.util.concurrent.RecursiveTask;
  7.  
  8. /**
  9.  * The {@code MatrixMultiplication} class implements
  10.  * fast multiplication 2 matrices at each other.
  11.  * The {@code MatrixMultiplication} uses Strassen algorithm and
  12.  * parallelize it with the {@link java.util.concurrent.ForkJoinPool}
  13.  *
  14.  * @author Evgeny Usov
  15.  * @author Alexey Falko
  16.  */
  17. public class MatrixMultiplication {
  18.  
  19.     //******************************************************************************************
  20.  
  21.     public static int[][] multiply(int[][] a, int[][] b) {
  22.  
  23.         int rowsA = a.length;
  24.         int columnsB = b[0].length;
  25.         int columnsA_rowsB = a[0].length;
  26.  
  27.         int[][] c = new int[rowsA][columnsB];
  28.  
  29.         for (int i = 0; i < rowsA; i++) {
  30.             for (int j = 0; j < columnsB; j++) {
  31.                 int sum = 0;
  32.                 for (int k = 0; k < columnsA_rowsB; k++) {
  33.                     sum += a[i][k] * b[k][j];
  34.                 }
  35.                 c[i][j] = sum;
  36.             }
  37.         }
  38.  
  39.         return c;
  40.     }
  41.  
  42.     //******************************************************************************************
  43.  
  44.     public static int[][] multiplyTransposed(int[][] a, int[][] b) {
  45.  
  46.         int rowsA = a.length;
  47.         int columnsB = b[0].length;
  48.         int columnsA_rowsB = a[0].length;
  49.  
  50.         int columnB[] = new int[columnsA_rowsB];
  51.         int[][] c = new int[rowsA][columnsB];
  52.  
  53.  
  54.         for (int j = 0; j < columnsB; j++) {
  55.             for (int k = 0; k < columnsA_rowsB; k++) {
  56.                 columnB[k] = b[k][j];
  57.             }
  58.  
  59.             for (int i = 0; i < rowsA; i++) {
  60.                 int rowA[] = a[i];
  61.                 int sum = 0;
  62.                 for (int k = 0; k < columnsA_rowsB; k++) {
  63.                     sum += rowA[k] * columnB[k];
  64.                 }
  65.                 c[i][j] = sum;
  66.             }
  67.         }
  68.  
  69.         return c;
  70.     }
  71.  
  72.     //******************************************************************************************
  73.  
  74.     private static int[][] summation(int[][] a, int[][] b) {
  75.  
  76.         int n = a.length;
  77.         int m = a[0].length;
  78.         int[][] c = new int[n][m];
  79.  
  80.         for (int i = 0; i < n; i++) {
  81.             for (int j = 0; j < m; j++) {
  82.                 c[i][j] = a[i][j] + b[i][j];
  83.             }
  84.         }
  85.         return c;
  86.     }
  87.  
  88.     //******************************************************************************************
  89.  
  90.     private static int[][] subtraction(int[][] a, int[][] b) {
  91.  
  92.         int n = a.length;
  93.         int m = a[0].length;
  94.         int[][] c = new int[n][m];
  95.  
  96.         for (int i = 0; i < n; i++) {
  97.             for (int j = 0; j < m; j++) {
  98.                 c[i][j] = a[i][j] - b[i][j];
  99.             }
  100.         }
  101.         return c;
  102.     }
  103.  
  104.     //******************************************************************************************
  105.  
  106.     private static int[][] addition2SquareMatrix(int[][] a, int n) {
  107.  
  108.         int[][] result = new int[n][n];
  109.  
  110.         for (int i = 0; i < a.length; i++) {
  111.             System.arraycopy(a[i], 0, result[i], 0, a[i].length);
  112.         }
  113.         return result;
  114.     }
  115.  
  116.     //******************************************************************************************
  117.  
  118.     private static int[][] getSubmatrix(int[][] a, int n, int m) {
  119.         int[][] result = new int[n][m];
  120.         for (int i = 0; i < n; i++) {
  121.             System.arraycopy(a[i], 0, result[i], 0, m);
  122.         }
  123.         return result;
  124.     }
  125.  
  126.     //******************************************************************************************
  127.  
  128.     private static void splitMatrix(int[][] a, int[][] a11, int[][] a12, int[][] a21, int[][] a22) {
  129.  
  130.         int n = a.length >> 1;
  131.  
  132.         for (int i = 0; i < n; i++) {
  133.             System.arraycopy(a[i], 0, a11[i], 0, n);
  134.             System.arraycopy(a[i], n, a12[i], 0, n);
  135.             System.arraycopy(a[i + n], 0, a21[i], 0, n);
  136.             System.arraycopy(a[i + n], n, a22[i], 0, n);
  137.         }
  138.     }
  139.  
  140.     //******************************************************************************************
  141.  
  142.     private static int[][] collectMatrix(int[][] a11, int[][] a12, int[][] a21, int[][] a22) {
  143.  
  144.         int n = a11.length;
  145.         int[][] a = new int[n << 1][n << 1];
  146.  
  147.         for (int i = 0; i < n; i++) {
  148.             System.arraycopy(a11[i], 0, a[i], 0, n);
  149.             System.arraycopy(a12[i], 0, a[i], n, n);
  150.             System.arraycopy(a21[i], 0, a[i + n], 0, n);
  151.             System.arraycopy(a22[i], 0, a[i + n], n, n);
  152.         }
  153.  
  154.         return a;
  155.     }
  156.  
  157.     //******************************************************************************************
  158.  
  159.     /**
  160.      * Multi-threaded matrix multiplication
  161.      * algorithm by Strassen
  162.      */
  163.     private static class myRecursiveTask extends RecursiveTask<int[][]> {
  164.         private static final long serialVersionUID = -433764214304695286L;
  165.  
  166.         int n;
  167.         int[][] a;
  168.         int[][] b;
  169.  
  170.         public myRecursiveTask(int[][] a, int[][] b, int n) {
  171.             this.a = a;
  172.             this.b = b;
  173.             this.n = n;
  174.         }
  175.  
  176.         /**
  177.          * @return the integer matrix by
  178.          * multiplying 2 matrices at each other
  179.          */
  180.         @Override
  181.         protected int[][] compute() {
  182.             if (n <= 128) {
  183.                 return multiplyTransposed(a, b);
  184.             }
  185.  
  186.             n >>= 1;
  187.  
  188.             int[][] a11 = new int[n][n];
  189.             int[][] a12 = new int[n][n];
  190.             int[][] a21 = new int[n][n];
  191.             int[][] a22 = new int[n][n];
  192.  
  193.             int[][] b11 = new int[n][n];
  194.             int[][] b12 = new int[n][n];
  195.             int[][] b21 = new int[n][n];
  196.             int[][] b22 = new int[n][n];
  197.  
  198.             splitMatrix(a, a11, a12, a21, a22);
  199.             splitMatrix(b, b11, b12, b21, b22);
  200.  
  201.             myRecursiveTask task_p1 = new myRecursiveTask(summation(a11, a22), summation(b11, b22), n);
  202.             myRecursiveTask task_p2 = new myRecursiveTask(summation(a21, a22), b11, n);
  203.             myRecursiveTask task_p3 = new myRecursiveTask(a11, subtraction(b12, b22), n);
  204.             myRecursiveTask task_p4 = new myRecursiveTask(a22, subtraction(b21, b11), n);
  205.             myRecursiveTask task_p5 = new myRecursiveTask(summation(a11, a12), b22, n);
  206.             myRecursiveTask task_p6 = new myRecursiveTask(subtraction(a21, a11), summation(b11, b12), n);
  207.             myRecursiveTask task_p7 = new myRecursiveTask(subtraction(a12, a22), summation(b21, b22), n);
  208.  
  209.             task_p1.fork();
  210.             task_p2.fork();
  211.             task_p3.fork();
  212.             task_p4.fork();
  213.             task_p5.fork();
  214.             task_p6.fork();
  215.             task_p7.fork();
  216.  
  217.             int[][] p1 = task_p1.join();
  218.             int[][] p2 = task_p2.join();
  219.             int[][] p3 = task_p3.join();
  220.             int[][] p4 = task_p4.join();
  221.             int[][] p5 = task_p5.join();
  222.             int[][] p6 = task_p6.join();
  223.             int[][] p7 = task_p7.join();
  224.  
  225.             int[][] c11 = summation(summation(p1, p4), subtraction(p7, p5));
  226.             int[][] c12 = summation(p3, p5);
  227.             int[][] c21 = summation(p2, p4);
  228.             int[][] c22 = summation(subtraction(p1, p2), summation(p3, p6));
  229.  
  230.             return collectMatrix(c11, c12, c21, c22);
  231.         }
  232.  
  233.     }
  234.  
  235.     //******************************************************************************************
  236.  
  237.     public static int[][] multiStrassenForkJoin(int[][] a, int[][] b) {
  238.  
  239.         int nn = getNewDimension(a, b);
  240.         int[][] a_n = addition2SquareMatrix(a, nn);
  241.         int[][] b_n = addition2SquareMatrix(b, nn);
  242.  
  243.         myRecursiveTask task = new myRecursiveTask(a_n, b_n, nn);
  244.         ForkJoinPool pool = new ForkJoinPool();
  245.         int[][] fastFJ = pool.invoke(task);
  246.  
  247.         return getSubmatrix(fastFJ, a.length, b[0].length);
  248.     }
  249.  
  250.     //******************************************************************************************
  251.  
  252.     @Deprecated
  253.     /**
  254.      * Single-threaded matrix multiplication
  255.      * algorithm by Strassen
  256.      * */
  257.     private static int[][] multiStrassen(int[][] a, int[][] b, int n) {
  258.         if (n <= 128) {
  259.             return multiplyTransposed(a, b);
  260.         }
  261.  
  262.         n = n >> 1;
  263.         ArrayList<Object> objects = new ArrayList<>();
  264.  
  265.         int[][] a11 = new int[n][n];
  266.         int[][] a12 = new int[n][n];
  267.         int[][] a21 = new int[n][n];
  268.         int[][] a22 = new int[n][n];
  269.  
  270.         int[][] b11 = new int[n][n];
  271.         int[][] b12 = new int[n][n];
  272.         int[][] b21 = new int[n][n];
  273.         int[][] b22 = new int[n][n];
  274.  
  275.         splitMatrix(a, a11, a12, a21, a22);
  276.         splitMatrix(b, b11, b12, b21, b22);
  277.  
  278.         int[][] p1 = multiStrassen(summation(a11, a22), summation(b11, b22), n);
  279.         int[][] p2 = multiStrassen(summation(a21, a22), b11, n);
  280.         int[][] p3 = multiStrassen(a11, subtraction(b12, b22), n);
  281.         int[][] p4 = multiStrassen(a22, subtraction(b21, b11), n);
  282.         int[][] p5 = multiStrassen(summation(a11, a12), b22, n);
  283.         int[][] p6 = multiStrassen(subtraction(a21, a11), summation(b11, b12), n);
  284.         int[][] p7 = multiStrassen(subtraction(a12, a22), summation(b21, b22), n);
  285.  
  286.         int[][] c11 = summation(summation(p1, p4), subtraction(p7, p5));
  287.         int[][] c12 = summation(p3, p5);
  288.         int[][] c21 = summation(p2, p4);
  289.         int[][] c22 = summation(subtraction(p1, p2), summation(p3, p6));
  290.  
  291.         return collectMatrix(c11, c12, c21, c22);
  292.     }
  293.  
  294.     //******************************************************************************************
  295.  
  296.     private static int log2(int x) {
  297.         int result = 1;
  298.         while ((x >>= 1) != 0) {
  299.             result++;
  300.         }
  301.  
  302.         return result;
  303.     }
  304.  
  305.     //******************************************************************************************
  306.  
  307.     private static int getNewDimension(int[][] a, int[][] b) {
  308.         return 1 << log2(Collections.max(Arrays.asList(a.length, a[0].length, b[0].length)));
  309.     }
  310.  
  311.     //******************************************************************************************
  312.  
  313.     public static int[][] randomMatrix(int m, int n) {
  314.         int[][] a = new int[m][n];
  315.         for (int i = 0; i < m; i++) {
  316.             for (int j = 0; j < n; j++) {
  317.                 a[i][j] = new Random().nextInt(100);
  318.             }
  319.         }
  320.         return a;
  321.     }
  322.  
  323.     //******************************************************************************************
  324.  
  325.     public static void printMatrix(int[][] a) {
  326.         for (int i = 0; i < a[0].length; i++) {
  327.             System.out.print("-------");
  328.         }
  329.         System.out.println();
  330.         for (int[] anA : a) {
  331.             System.out.print("|");
  332.             for (int anAnA : anA) {
  333.                 System.out.printf("%4d |", anAnA);
  334.             }
  335.  
  336.             System.out.println();
  337.             for (int i = 0; i < a[0].length; i++) {
  338.                 System.out.print("-------");
  339.             }
  340.             System.out.println();
  341.         }
  342.     }
  343.  
  344.     //******************************************************************************************
  345.  
  346.     public static void test(int n, int m, int l) {
  347.  
  348.         int[][] a = randomMatrix(n, l);
  349.         int[][] b = randomMatrix(l, m);
  350.         long start, end;
  351.  
  352.         //****************************************
  353.         //  TEST 1
  354.         start = System.currentTimeMillis();
  355.         int[][] matrixByStrassenFJ = multiStrassenForkJoin(a, b);
  356.         end = System.currentTimeMillis();
  357.         System.out.printf("Strassen Fork-Join Multiply [A:%dx%d; B:%dx%d]: \tElapsed: %dms\n", n, l, l, m, end - start);
  358.         //****************************************
  359.  
  360.         //****************************************
  361.         //  TEST 2
  362.         start = System.currentTimeMillis();
  363.         int nn = getNewDimension(a, b);
  364.  
  365.         int[][] a_n = addition2SquareMatrix(a, nn);
  366.         int[][] b_n = addition2SquareMatrix(b, nn);
  367.  
  368.         int[][] temp = multiStrassen(a_n, b_n, nn);
  369.         int[][] matrixByStrassen = getSubmatrix(temp, n, m);
  370.         end = System.currentTimeMillis();
  371.         System.out.printf("Strassen Multiply [A:%dx%d; B:%dx%d]: \tElapsed: %dms\n", n, l, l, m, end - start);
  372.         //****************************************
  373.  
  374.         //****************************************
  375.         //  TEST 3
  376.             start = System.currentTimeMillis();
  377.             int[][] matrixByUsual = multiply(a, b);
  378.             end = System.currentTimeMillis();
  379.             System.out.printf("Usual Multiply [A:%dx%d; B:%dx%d]: \tElapsed: %dms\n", n, l, l, m, end - start);
  380.         //****************************************
  381.  
  382.         //****************************************
  383.         //  TEST 4
  384.         start = System.currentTimeMillis();
  385.         int[][] matrixByUsualTransposed = multiplyTransposed(a, b);
  386.         end = System.currentTimeMillis();
  387.         System.out.printf("Usual Multiply Transposed [A:%dx%d; B:%dx%d]: \tElapsed: %dms\n", n, l, l, m, end - start);
  388.         //****************************************
  389.  
  390.         System.out.println("Matrices are equal: " + Arrays.deepEquals(matrixByStrassenFJ, matrixByStrassen));
  391.         System.out.println("Matrices are equal: " + Arrays.deepEquals(matrixByStrassenFJ, matrixByUsual));
  392.         System.out.println("Matrices are equal: " + Arrays.deepEquals(matrixByStrassenFJ, matrixByUsualTransposed));
  393.  
  394.     }
  395.  
  396.     //******************************************************************************************
  397.  
  398.     private static class Multipliers {
  399.         private final int[][] matrixA;
  400.         private final int[][] matrixB;
  401.  
  402.         public Multipliers(int[][] a, int[][] b) {
  403.             matrixA = a;
  404.             matrixB = b;
  405.         }
  406.  
  407.         public int[][] getMatrixB() {
  408.             return matrixB;
  409.         }
  410.  
  411.         public int[][] getMatrixA() {
  412.             return matrixA;
  413.         }
  414.     }
  415.  
  416.  
  417.     //******************************************************************************************
  418.     private static Multipliers validation(String[] args) {
  419.         int rowsA;
  420.         int columnsA;
  421.         int rowsB;
  422.         int columnsB;
  423.  
  424.         if (args.length < 6) {
  425.             throw new IllegalArgumentException("Too few parameters. Should be not less then 6.");
  426.         }
  427.  
  428.         /*
  429.          * Note: method parseInt returns NumberFormatException if the argument String
  430.          * does not contain a parsable int
  431.          * */
  432.  
  433.         rowsA = Integer.parseInt(args[0]);
  434.         columnsA = Integer.parseInt(args[1]);
  435.         rowsB = Integer.parseInt(args[2]);
  436.         columnsB = Integer.parseInt(args[3]);
  437.  
  438.         if (rowsA <= 0 || columnsA <= 0 || rowsB <= 0 || columnsB <= 0) {
  439.             throw new IllegalArgumentException("Array dimension can't be negative or zero");
  440.         }
  441.  
  442.         if (args.length - (rowsA * columnsA + rowsB * columnsB) != 4) {
  443.             throw new IllegalArgumentException("Incorrect number of values to initialize two arrays.");
  444.         }
  445.  
  446.         if (columnsA != rowsB) {
  447.             throw new IllegalArgumentException("The number of columns of the matrix A is not equal to the number of rows of the matrix B.");
  448.         }
  449.  
  450.         int[][] a = new int[rowsA][columnsA];
  451.         int[][] b = new int[rowsB][columnsB];
  452.  
  453.         int k = 4;
  454.  
  455.         //***************************************
  456.  
  457.         for (int i = 0; i < a.length; i++) {
  458.             for (int j = 0; j < a[0].length; j++) {
  459.                 a[i][j] = Integer.parseInt(args[k++]);
  460.             }
  461.         }
  462.  
  463.         //***************************************
  464.  
  465.         for (int i = 0; i < b.length; i++) {
  466.             for (int j = 0; j < b[0].length; j++) {
  467.                 b[i][j] = Integer.parseInt(args[k++]);
  468.             }
  469.         }
  470.  
  471.         //***************************************
  472.  
  473.         return new Multipliers(a, b);
  474.     }
  475.  
  476.     //******************************************************************************************
  477.  
  478.     /*
  479.         Матрицы подаются как аргументы программы в следующем формате
  480.         N M X Y A_1_1 ... A_N_M B_1_1 ... B_X_Y
  481.  
  482.         где N и M - размерность первой матрицы A,
  483.         A_1_1 ... A_N_M - элементы матрицы A,
  484.         X и Y - размерность второй матрицы B,
  485.         B_1_1 ... B_X_Y - элементы матрицы B.
  486.  
  487.         Например, для умножения единичной матрицы размером 2 на 2 на вектор (-1, -1)
  488.         необходимо на вход приложению пожать следующие аргументы
  489.         2 2 2 1 1 0 0 1 -1 -1
  490.         В консоль должен распечататься вектор:
  491.         -1
  492.         -1
  493.     */
  494.     public static void main(String[] args) {
  495.         Multipliers multipliers = validation(args);
  496.  
  497.         int[][] matrixByStrassenFJ = multiStrassenForkJoin(multipliers.getMatrixA(), multipliers.getMatrixB());
  498.         int[][] matrixByUsual = multiply(multipliers.getMatrixA(), multipliers.getMatrixB());
  499.  
  500.         printMatrix(matrixByStrassenFJ);
  501.         //printMatrix(matrixByUsual);
  502.  
  503.         //System.out.println(Arrays.deepEquals(matrixByStrassenFJ, matrixByUsual));
  504.     }
  505.  
  506. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement