SHOW:
|
|
- or go back to the newest paste.
1 | package com.example.foo; | |
2 | ||
3 | import java.util.Arrays; | |
4 | import java.util.concurrent.ForkJoinPool; | |
5 | import java.util.concurrent.ForkJoinTask; | |
6 | ||
7 | /** | |
8 | * https://humanoidreadable.wordpress.com/2014/12/31/forkjoin-nonrecursive-task/ | |
9 | * | |
10 | * @author Claude Martin | |
11 | * | |
12 | - | public class SomeClass { |
12 | + | |
13 | public final class SomeClass { | |
14 | - | private static class Task1 extends ForkJoinTask<Integer> { |
14 | + | /** This is used to fork/join once per entry of the resulting matrix. */ |
15 | - | private static final long serialVersionUID = 4577085552193799470L; |
15 | + | private static class MatrixProduct extends ForkJoinTask<int[][]> { |
16 | private static final long serialVersionUID = 1527639909439152978L; | |
17 | - | private final Integer input; |
17 | + | private final int[][] A; |
18 | - | private Integer result = null; |
18 | + | private final int[][] B; |
19 | private int[][] result = null; | |
20 | - | public Task1(Integer input) { |
20 | + | |
21 | - | this.input = input; |
21 | + | public MatrixProduct(final int[][] A, final int[][] B) { |
22 | - | } |
22 | + | this.A = A; |
23 | this.B = B; | |
24 | - | @Override |
24 | + | } |
25 | - | public Integer getRawResult() { |
25 | + | |
26 | - | return this.result; |
26 | + | @Override |
27 | - | } |
27 | + | public int[][] getRawResult() { |
28 | return result; | |
29 | - | @Override |
29 | + | } |
30 | - | protected void setRawResult(Integer value) { |
30 | + | |
31 | - | this.result = value; |
31 | + | @Override |
32 | - | } |
32 | + | protected void setRawResult(final int[][] value) { |
33 | this.result = value; | |
34 | - | @Override |
34 | + | } |
35 | - | protected boolean exec() { |
35 | + | |
36 | - | ForkJoinTask<Integer> a = new Task2(this.input).fork(); |
36 | + | // see https://en.wikipedia.org/wiki/Matrix_multiplication#Definition |
37 | - | ForkJoinTask<Integer> b = new Task2(this.input + 1).fork(); |
37 | + | @Override |
38 | - | setRawResult(a.join() + b.join()); |
38 | + | protected boolean exec() { |
39 | - | return true; |
39 | + | final int n = A.length; |
40 | - | } |
40 | + | final int p = B[0].length; |
41 | - | } |
41 | + | final int m = B.length; // == A[0].length |
42 | this.result = new int[n][p]; | |
43 | - | private static class Task2 extends ForkJoinTask<Integer> { |
43 | + | @SuppressWarnings("unchecked") |
44 | - | private static final long serialVersionUID = -7037328571331917872L; |
44 | + | final ForkJoinTask<Integer>[][] tasks = new ForkJoinTask[n][p]; |
45 | - | private final Integer input; |
45 | + | if (m != A[0].length) { |
46 | - | private Integer result = null; |
46 | + | throw new IllegalArgumentException("Can't calculate matrix product for given input"); |
47 | } | |
48 | - | public Task2(Integer input) { |
48 | + | if (n == 0 || p == 0) |
49 | - | this.input = input; |
49 | + | return true; |
50 | - | } |
50 | + | for (int i = 0; i < n; i++) |
51 | for (int j = 0; j < p; j++) | |
52 | - | @Override |
52 | + | tasks[i][j] = new CalculateEntry(A, B, i, j).fork(); |
53 | - | public Integer getRawResult() { |
53 | + | for (int i = 0; i < n; i++) |
54 | - | return this.result; |
54 | + | for (int j = 0; j < p; j++) |
55 | - | } |
55 | + | result[i][j] = tasks[i][j].join(); |
56 | return true; | |
57 | - | @Override |
57 | + | } |
58 | - | protected void setRawResult(Integer value) { |
58 | + | } |
59 | - | this.result = value; |
59 | + | |
60 | - | } |
60 | + | /** This is used to calculate each entry. It won't create more subtasks. */ |
61 | private static class CalculateEntry extends ForkJoinTask<Integer> { | |
62 | - | @Override |
62 | + | private static final long serialVersionUID = 2781862671220715328L; |
63 | - | protected boolean exec() { |
63 | + | private final int[][] A; |
64 | - | System.out.format("Task2(%d) started%n", this.input); |
64 | + | private final int[][] B; |
65 | private final int i; | |
66 | - | // This is the actual work: |
66 | + | private final int j; |
67 | - | setRawResult(this.input * this.input); |
67 | + | private Integer result = null; |
68 | ||
69 | - | // Just to show that they really run in parallel: |
69 | + | public CalculateEntry(final int[][] A, final int[][] B, final int i, final int j) { |
70 | - | try { |
70 | + | this.A = A; |
71 | - | Thread.sleep(5000); |
71 | + | this.B = B; |
72 | - | } catch (InterruptedException e) { |
72 | + | this.i = i; |
73 | - | e.printStackTrace(); |
73 | + | this.j = j; |
74 | - | } |
74 | + | } |
75 | ||
76 | - | System.out.format("Task2(%d) finished%n", this.input); |
76 | + | @Override |
77 | - | return true; |
77 | + | public Integer getRawResult() { |
78 | - | } |
78 | + | return this.result; |
79 | - | } |
79 | + | } |
80 | ||
81 | - | public static void main(String[] args) { |
81 | + | @Override |
82 | protected void setRawResult(final Integer value) { | |
83 | - | Integer result = ForkJoinPool.commonPool().invoke(new Task1(5)); |
83 | + | this.result = value; |
84 | - | System.out.println(result); |
84 | + | } |
85 | ||
86 | - | // 5*5 = 25 |
86 | + | @Override |
87 | - | // 6*6 = 36 |
87 | + | protected boolean exec() { |
88 | - | // 25+36= 61 |
88 | + | int sum = 0; |
89 | for (int k = 0; k < B.length; k++) { | |
90 | - | } |
90 | + | sum += A[i][k] * B[k][j]; |
91 | } | |
92 | this.setRawResult(sum); | |
93 | return true; | |
94 | } | |
95 | } | |
96 | ||
97 | public static void main(String[] args) { | |
98 | // Note: The matrices would have to be much larger to get better performance than | |
99 | // doing the same without parallel execution. But then the output is hard to read. | |
100 | int[][] A = { { 1, 2, 3 }, { 4, 5, 6 } }; | |
101 | int[][] B = { { 7, 8 }, { 9, 10 }, { 11, 12 } }; | |
102 | int[][] result = ForkJoinPool.commonPool().invoke(new MatrixProduct(A, B)); | |
103 | for (int[] row : result) | |
104 | System.out.println(Arrays.toString(row)); | |
105 | } | |
106 | } |