Advertisement
Tvor0zhok

Умножение разреженных матриц

Jun 12th, 2022
949
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 4.19 KB | None | 0 0
  1. /// <summary>
  2. /// Умножение матрицы на матрицу
  3. /// </summary>
  4. /// <typeparam name="T"> тип элементов матриц </typeparam>
  5. /// <param name="A"> 1-ая матрица </param>
  6. /// <param name="B"> 2-ая матрица </param>
  7. /// <returns> матрица, являющаяся результатом произведения обеих матриц </returns>
  8. template <class T>
  9. CSR_Matrix<T> operator * (CSR_Matrix<T> A, CSR_Matrix<T> B)
  10. {
  11.     assert(A.m() == B.n());
  12.  
  13.     // Число строк, столбцов, ненулевых элементов матрицы-результата
  14.  
  15.     int _N = A.n(), _M = B.m(), _NNZ = 0;
  16.  
  17.     int p = GridThreadsNum;
  18.  
  19.     // Распределяем разреженные вектор-строки матрицы по потокам
  20.  
  21.     vector <int> thread_pos(p + 1);
  22.  
  23.     for (int i = p - 1; i >= 0; --i)
  24.         thread_pos[p - i] = thread_pos[p - i - 1] + (_N + i) / p;
  25.  
  26.     vector <T> values_A = A.values(), values_B = B.values();
  27.     vector <int> cols_A = A.cols(), cols_B = B.cols();
  28.     vector <int> rowptr_A = A.rowptr(), rowptr_B = B.rowptr();
  29.  
  30.     vector <int> rowptr(_N + 1);
  31.     vector <vector <T>> rows(p, vector <T>(_M));
  32.  
  33.     /*
  34.     // Сжатый массив строк для матрицы-результата
  35.     vector <T> values;
  36.     vector <int> cols, rowptr(1);
  37.  
  38.     // Вспомогательный массив, хранящий текущую рассматриваемую потоком
  39.     // строку (в плотном формате, то есть не в разреженном)
  40.     vector <vector <T>> rows(p, vector <T> (_M));
  41.  
  42.     // Преподсчет элементов массива rowptr
  43.     // rowptr[i + 2] = k <=> в i-ой строке
  44.     // матрицы-результата k ненулевых элементов
  45.  
  46.     omp_set_num_threads(p);
  47.     #pragma omp parallel
  48.     {
  49.         int ThreadID = omp_get_thread_num();
  50.  
  51.         for (int i = thread_pos[ThreadID]; i < thread_pos[ThreadID + 1]; ++i)
  52.         {
  53.             for (int j = rowptr_A[i]; j < rowptr_A[i + 1]; ++j)
  54.             {
  55.                 int col_A = cols_A[j]; T value_A = values_A[j];
  56.  
  57.                 for (int k = rowptr_B[col_A]; k < rowptr_B[col_A + 1]; ++k)
  58.                 {
  59.                     int col_B = cols_B[k]; T value_B = values_B[k];
  60.                     rows[ThreadID][col_B] += value_A * value_B;
  61.                 }
  62.             }
  63.  
  64.             for (int j = 0; j < _M; ++j)
  65.                 if (rows[ThreadID][j])
  66.                 {
  67.                     values.push_back(rows[ThreadID][j]);
  68.                     cols.push_back(j); ++_NNZ;
  69.  
  70.                     rows[ThreadID][j] = 0;
  71.                 }
  72.  
  73.             rowptr.push_back(_NNZ);
  74.         }
  75.     } */
  76.  
  77.     omp_set_num_threads(p);
  78.     #pragma omp parallel
  79.     {
  80.         int ThreadID = omp_get_thread_num();
  81.  
  82.         for (int i = thread_pos[ThreadID]; i < thread_pos[ThreadID + 1]; ++i)
  83.         {
  84.             for (int j = rowptr_A[i]; j < rowptr_A[i + 1]; ++j)
  85.             {
  86.                 int col_A = cols_A[j]; T value_A = values_A[j];
  87.  
  88.                 for (int k = rowptr_B[col_A]; k < rowptr_B[col_A + 1]; ++k)
  89.                 {
  90.                     int col_B = cols_B[k]; T value_B = values_B[k];
  91.                     rows[ThreadID][col_B] += value_A * value_B;
  92.                 }
  93.             }
  94.  
  95.             for (int j = 0; j < _M; ++j)
  96.                 if (rows[ThreadID][j])
  97.                 {
  98.                     if (i == _N - 1) ++_NNZ;
  99.                     else ++rowptr[i + 2];
  100.  
  101.                     rows[ThreadID][j] = 0;
  102.                 }
  103.         }
  104.     }
  105.  
  106.     for (int i = 3; i <= _N; ++i)
  107.         rowptr[i] += rowptr[i - 1];
  108.  
  109.     _NNZ += rowptr[_N];
  110.  
  111.     vector <T> values(_NNZ);
  112.     vector <int> cols(_NNZ);
  113.  
  114.     // Основной алгоритм: заполнение массивов values и
  115.     // cols соответствующими значениями
  116.  
  117.     #pragma omp parallel
  118.     {
  119.         int ThreadID = omp_get_thread_num();
  120.  
  121.         for (int i = thread_pos[ThreadID]; i < thread_pos[ThreadID + 1]; ++i)
  122.         {
  123.             for (int j = rowptr_A[i]; j < rowptr_A[i + 1]; ++j)
  124.             {
  125.                 int col_A = cols_A[j]; T value_A = values_A[j];
  126.  
  127.                 for (int k = rowptr_B[col_A]; k < rowptr_B[col_A + 1]; ++k)
  128.                 {
  129.                     int col_B = cols_B[k]; T value_B = values_B[k];
  130.                     rows[ThreadID][col_B] += value_A * value_B;
  131.                 }
  132.             }
  133.  
  134.             for (int j = 0; j < _M; ++j)
  135.                 if (rows[ThreadID][j])
  136.                 {
  137.                     int pos = rowptr[i + 1];
  138.  
  139.                     values[pos] = rows[ThreadID][j];
  140.                     cols[pos] = j;
  141.  
  142.                     rows[ThreadID][j] = 0;
  143.                     ++rowptr[i + 1];
  144.                 }
  145.         }
  146.     }
  147.        
  148.     CSR_Matrix<T>* res = new CSR_Matrix<T>(_N, _M, _NNZ, values, cols, rowptr);
  149.     return *res;
  150. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement