paranid5

matrix multiply

May 18th, 2021 (edited)
544
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
import java.io.*;
import java.security.SecureRandom;
import java.time.LocalTime;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
  5.  
  6. class Main {
  7.     private static final BufferedReader reader = new BufferedReader(new InputStreamReader(System.in));
  8.     private static final PrintWriter writer = new PrintWriter(System.out, true);
  9.     private static final SecureRandom randomizer = new SecureRandom();
  10.  
  11.     private static final int[] intArrayInput() throws IOException {
  12.         return Arrays
  13.                 .stream(reader.readLine().trim().split(" "))
  14.                 .mapToInt(Integer::parseInt)
  15.                 .toArray();
  16.     }
  17.  
  18.     private static final void exitOnLess1(final int a, final int b) {
  19.         if (a <= 0 || b <= 0) {
  20.             writer.println("Size must be bigger than zero");
  21.             System.exit(0);
  22.         }
  23.     }
  24.  
  25.     private static final void printMatrix(final int[][] matrix) {
  26.         Arrays.stream(matrix).forEach(row -> {
  27.             Arrays.stream(row).forEach(x -> writer.print(x + " "));
  28.             writer.print('\n');
  29.         });
  30.     }
  31.  
  32.     private static final int[][] generateMatrix(final int n, final int m) {
  33.         final var matrix = new int[n][m];
  34.  
  35.         for (int i = 0; i < n; i++)
  36.             for (int q = 0; q < m; q++)
  37.                 matrix[i][q] = randomizer.nextInt(1000);
  38.  
  39.         return matrix;
  40.     }
  41.  
  42.     private static final int[][] matrixMultiplyST(final int[][] first, final int[][] second) {
  43.         final int n = first.length;
  44.         final int m = second[0].length;
  45.         final int k = second.length;
  46.         final var ans = new int[n][m];
  47.  
  48.         for (int i = 0; i < n; i++)
  49.             for (int q = 0; q < m; q++)
  50.                 for (int r = 0; r < k; r++)
  51.                     ans[i][q] += first[i][r] * second[r][q];
  52.  
  53.         return ans;
  54.     }
  55.  
  56.     private static final int[][] matrixMultiplyMT(
  57.             final int[][] first,
  58.             final int[][] second,
  59.             final int threadsAmount
  60.     ) {
  61.         final int n = first.length;
  62.         final int m = second[0].length;
  63.         final int k = second.length;
  64.         final var ans = new int[n][m];
  65.  
  66.         final var threads = new Thread[threadsAmount];
  67.         final var cellsInThreads = new int[threadsAmount];
  68.         final var ost = n * m % threadsAmount;
  69.         Arrays.fill(cellsInThreads, n * m / threadsAmount);
  70.  
  71.         for (int i = 0; i < ost; i++)
  72.             cellsInThreads[i]++;
  73.  
  74.         int cur = 0;
  75.  
  76.         for (int i = 0; i < threadsAmount; cur += cellsInThreads[i], i++) {
  77.             final int finalI = i;
  78.             final int finalCur = cur;
  79.  
  80.             threads[i] = new Thread(() -> {
  81.                 for (int t = finalCur; t < finalCur + cellsInThreads[finalI]; t++)
  82.                     for (int r = 0; r < k; r++)
  83.                         ans[t / n][t % n] += first[t / n][r] * second[r][t % n];
  84.             });
  85.  
  86.             threads[i].start();
  87.         }
  88.  
  89.         Arrays.stream(threads).forEach(x -> {
  90.             try {
  91.                 x.join();
  92.             } catch (final InterruptedException e) {
  93.                 e.printStackTrace();
  94.             }
  95.         });
  96.  
  97.         return ans;
  98.     }
  99.  
  100.     public static final void main(final String[] args) throws IOException {
  101.         writer.println("First matrix's size: ");
  102.         final var inp1 = intArrayInput();
  103.         final var n1 = inp1[0];
  104.         final var m1 = inp1[1];
  105.  
  106.         exitOnLess1(n1, m1);
  107.  
  108.         writer.println("Second matrix's size: ");
  109.         final var inp2 = intArrayInput();
  110.         final var n2 = inp2[0];
  111.         final var m2 = inp2[1];
  112.  
  113.         exitOnLess1(n2, m2);
  114.  
  115.         if (m1 != n2 && m2 != n1) {
  116.             writer.println("Colums amount of 1-st matrix must be equal to rows amount of 2-nd matrix or vice versa");
  117.             return;
  118.         }
  119.  
  120.         writer.println("Amount of threads:");
  121.         final int threadsAmount = Integer.parseInt(reader.readLine().trim());
  122.  
  123.         if (threadsAmount < 1) {
  124.             writer.println("Threads amount must be bigger than zero");
  125.             return;
  126.         }
  127.  
  128.         final var matrix1 = generateMatrix(n1, m1);
  129.         final var matrix2 = generateMatrix(n2, m2);
  130.  
  131.         // writer.println("First matrix:");
  132.         // printMatrix(matrix1);
  133.  
  134.         // writer.println("Second matrix:");
  135.         // printMatrix(matrix2);
  136.  
  137.         final var start1 = LocalTime.now().getNano();
  138.         final var firstMul = matrixMultiplyST(matrix1, matrix2);
  139.         final var finish1 = LocalTime.now().getNano();
  140.         writer.println("Single Thread: " + (Math.max(start1, finish1) - Math.min(start1, finish1)) + " nanos");
  141.  
  142.         // writer.println("Single Thread matrix:");
  143.         // printMatrix(firstMul);
  144.  
  145.         final var start2 = LocalTime.now().getNano();
  146.         final var secondMul = matrixMultiplyMT(matrix1, matrix2, threadsAmount);
  147.         final var finish2 = LocalTime.now().getNano();
  148.         writer.println("Multi Thread:  " + (Math.max(start2, finish2) - Math.min(start2, finish2)) + " nanos");
  149.  
  150.         // writer.println("Multi Thread matrix:");
  151.         // printMatrix(secondMul);
  152.  
  153.         assert(Arrays.deepEquals(firstMul, secondMul));
  154.  
  155.         reader.close();
  156.         writer.close();
  157.     }
  158. }
RAW Paste Data