View difference between Paste ID: zaTwZkTY and PSrEUrPw
SHOW: | | - or go back to the newest paste.
1
package com.example.foo;
2
3
import java.util.Arrays;
4
import java.util.concurrent.ForkJoinPool;
5
import java.util.concurrent.ForkJoinTask;
6
7
/**
8
 * https://humanoidreadable.wordpress.com/2014/12/31/forkjoin-nonrecursive-task/
9
 * 
10
 * @author Claude Martin
11
 *
12-
public class SomeClass {
12+
13
public final class SomeClass {
14-
  private static class Task1 extends ForkJoinTask<Integer> {
14+
	/** This is used to fork/join once per entry of the resulting matrix. */
15-
    private static final long serialVersionUID = 4577085552193799470L;
15+
	private static class MatrixProduct extends ForkJoinTask<int[][]> {
16
		private static final long serialVersionUID = 1527639909439152978L;
17-
    private final Integer input;
17+
		private final int[][] A;
18-
    private Integer result = null;
18+
		private final int[][] B;
19
		private int[][] result = null;
20-
    public Task1(Integer input) {
20+
21-
      this.input = input;
21+
		public MatrixProduct(final int[][] A, final int[][] B) {
22-
    }
22+
			this.A = A;
23
			this.B = B;
24-
    @Override
24+
		}
25-
    public Integer getRawResult() {
25+
26-
      return this.result;
26+
		@Override
27-
    }
27+
		public int[][] getRawResult() {
28
			return result;
29-
    @Override
29+
		}
30-
    protected void setRawResult(Integer value) {
30+
31-
      this.result = value;
31+
		@Override
32-
    }
32+
		protected void setRawResult(final int[][] value) {
33
			this.result = value;
34-
    @Override
34+
		}
35-
    protected boolean exec() {
35+
36-
      ForkJoinTask<Integer> a = new Task2(this.input).fork();
36+
		// see https://en.wikipedia.org/wiki/Matrix_multiplication#Definition
37-
      ForkJoinTask<Integer> b = new Task2(this.input + 1).fork();
37+
		@Override
38-
      setRawResult(a.join() + b.join());
38+
		protected boolean exec() {
39-
      return true;
39+
			final int n = A.length;
40-
    }
40+
			final int p = B[0].length;
41-
  }
41+
			final int m = B.length; // == A[0].length
42
			this.result = new int[n][p];
43-
  private static class Task2 extends ForkJoinTask<Integer> {
43+
			@SuppressWarnings("unchecked")
44-
    private static final long serialVersionUID = -7037328571331917872L;
44+
			final ForkJoinTask<Integer>[][] tasks = new ForkJoinTask[n][p];
45-
    private final Integer input;
45+
			if (m != A[0].length) {
46-
    private Integer result = null;
46+
				throw new IllegalArgumentException("Can't calculate matrix product for given input");
47
			}
48-
    public Task2(Integer input) {
48+
			if (n == 0 || p == 0)
49-
      this.input = input;
49+
				return true;
50-
    }
50+
			for (int i = 0; i < n; i++)
51
				for (int j = 0; j < p; j++)
52-
    @Override
52+
					tasks[i][j] = new CalculateEntry(A, B, i, j).fork();
53-
    public Integer getRawResult() {
53+
			for (int i = 0; i < n; i++)
54-
      return this.result;
54+
				for (int j = 0; j < p; j++)
55-
    }
55+
					result[i][j] = tasks[i][j].join();
56
			return true;
57-
    @Override
57+
		}
58-
    protected void setRawResult(Integer value) {
58+
	}
59-
      this.result = value;
59+
60-
    }
60+
	/** This is used to calculate each entry. It won't create more subtasks. */
61
	private static class CalculateEntry extends ForkJoinTask<Integer> {
62-
    @Override
62+
		private static final long serialVersionUID = 2781862671220715328L;
63-
    protected boolean exec() {
63+
		private final int[][] A;
64-
      System.out.format("Task2(%d) started%n", this.input);
64+
		private final int[][] B;
65
		private final int i;
66-
      // This is the actual work:
66+
		private final int j;
67-
      setRawResult(this.input * this.input);
67+
		private Integer result = null;
68
69-
      // Just to show that they really run in parallel:
69+
		public CalculateEntry(final int[][] A, final int[][] B, final int i, final int j) {
70-
      try {
70+
			this.A = A;
71-
        Thread.sleep(5000);
71+
			this.B = B;
72-
      } catch (InterruptedException e) {
72+
			this.i = i;
73-
        e.printStackTrace();
73+
			this.j = j;
74-
      }
74+
		}
75
76-
      System.out.format("Task2(%d) finished%n", this.input);
76+
		@Override
77-
      return true;
77+
		public Integer getRawResult() {
78-
    }
78+
			return this.result;
79-
  }
79+
		}
80
81-
  public static void main(String[] args) {
81+
		@Override
82
		protected void setRawResult(final Integer value) {
83-
    Integer result = ForkJoinPool.commonPool().invoke(new Task1(5));
83+
			this.result = value;
84-
    System.out.println(result);
84+
		}
85
86-
    // 5*5 = 25
86+
		@Override
87-
    // 6*6 = 36
87+
		protected boolean exec() {
88-
    // 25+36= 61
88+
			int sum = 0;
89
			for (int k = 0; k < B.length; k++) {
90-
  }
90+
				sum += A[i][k] * B[k][j];
91
			}
92
			this.setRawResult(sum);
93
			return true;
94
		}
95
	}
96
97
	public static void main(String[] args) {
98
		// Note: The matrices would have to be much larger to get better performance than 
99
		// doing the same without parallel execution. But then the output is hard to read. 
100
		int[][] A = { { 1, 2, 3 }, { 4, 5, 6 } };
101
		int[][] B = { { 7, 8 }, { 9, 10 }, { 11, 12 } };
102
		int[][] result = ForkJoinPool.commonPool().invoke(new MatrixProduct(A, B));
103
		for (int[] row : result)
104
			System.out.println(Arrays.toString(row));
105
	}
106
}