Advertisement
DulcetAirman

Matrix Multiplication with Fork/Join

Aug 23rd, 2019
454
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. package com.example.foo;
  2.  
  3. import java.util.Arrays;
  4. import java.util.concurrent.ForkJoinPool;
  5. import java.util.concurrent.ForkJoinTask;
  6.  
  7. /**
  8.  * https://humanoidreadable.wordpress.com/2014/12/31/forkjoin-nonrecursive-task/
  9.  *
  10.  * @author Claude Martin
  11.  *
  12.  */
  13. public final class SomeClass {
  14.     /** This is used to fork/join once per entry of the resulting matrix. */
  15.     private static class MatrixProduct extends ForkJoinTask<int[][]> {
  16.         private static final long serialVersionUID = 1527639909439152978L;
  17.         private final int[][] A;
  18.         private final int[][] B;
  19.         private int[][] result = null;
  20.  
  21.         public MatrixProduct(final int[][] A, final int[][] B) {
  22.             this.A = A;
  23.             this.B = B;
  24.         }
  25.  
  26.         @Override
  27.         public int[][] getRawResult() {
  28.             return result;
  29.         }
  30.  
  31.         @Override
  32.         protected void setRawResult(final int[][] value) {
  33.             this.result = value;
  34.         }
  35.  
  36.         // see https://en.wikipedia.org/wiki/Matrix_multiplication#Definition
  37.         @Override
  38.         protected boolean exec() {
  39.             final int n = A.length;
  40.             final int p = B[0].length;
  41.             final int m = B.length; // == A[0].length
  42.             this.result = new int[n][p];
  43.             @SuppressWarnings("unchecked")
  44.             final ForkJoinTask<Integer>[][] tasks = new ForkJoinTask[n][p];
  45.             if (m != A[0].length) {
  46.                 throw new IllegalArgumentException("Can't calculate matrix product for given input");
  47.             }
  48.             if (n == 0 || p == 0)
  49.                 return true;
  50.             for (int i = 0; i < n; i++)
  51.                 for (int j = 0; j < p; j++)
  52.                     tasks[i][j] = new CalculateEntry(A, B, i, j).fork();
  53.             for (int i = 0; i < n; i++)
  54.                 for (int j = 0; j < p; j++)
  55.                     result[i][j] = tasks[i][j].join();
  56.             return true;
  57.         }
  58.     }
  59.  
  60.     /** This is used to calculate each entry. It won't create more subtasks. */
  61.     private static class CalculateEntry extends ForkJoinTask<Integer> {
  62.         private static final long serialVersionUID = 2781862671220715328L;
  63.         private final int[][] A;
  64.         private final int[][] B;
  65.         private final int i;
  66.         private final int j;
  67.         private Integer result = null;
  68.  
  69.         public CalculateEntry(final int[][] A, final int[][] B, final int i, final int j) {
  70.             this.A = A;
  71.             this.B = B;
  72.             this.i = i;
  73.             this.j = j;
  74.         }
  75.  
  76.         @Override
  77.         public Integer getRawResult() {
  78.             return this.result;
  79.         }
  80.  
  81.         @Override
  82.         protected void setRawResult(final Integer value) {
  83.             this.result = value;
  84.         }
  85.  
  86.         @Override
  87.         protected boolean exec() {
  88.             int sum = 0;
  89.             for (int k = 0; k < B.length; k++) {
  90.                 sum += A[i][k] * B[k][j];
  91.             }
  92.             this.setRawResult(sum);
  93.             return true;
  94.         }
  95.     }
  96.  
  97.     public static void main(String[] args) {
  98.         // Note: The matrices would have to be much larger to get better performance than
  99.         // doing the same without parallel execution. But then the output is hard to read.
  100.         int[][] A = { { 1, 2, 3 }, { 4, 5, 6 } };
  101.         int[][] B = { { 7, 8 }, { 9, 10 }, { 11, 12 } };
  102.         int[][] result = ForkJoinPool.commonPool().invoke(new MatrixProduct(A, B));
  103.         for (int[] row : result)
  104.             System.out.println(Arrays.toString(row));
  105.     }
  106. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement