Advertisement
Guest User

B

a guest
Jan 23rd, 2020
251
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 5.65 KB | None | 0 0
  1. #include <thread>
  2. #include <tuple>
  3. #include <iostream>
  4. #include <chrono>
  5. #include <cassert>
  6. #include <vector>
  7. #include <cstdint>
  8. #include <future>
  9. #include <random>
  10. #include <fstream>
  11.  
  12. class DenseMat {
  13.  public:
  14.     DenseMat(int32_t rows = 0, int32_t cols = 0) : rows_(rows), cols_(cols), data_(rows * cols) {}
  15.  
  16.     DenseMat(int32_t rows, int32_t cols, const std::vector<int32_t> &data) : rows_(rows), cols_(cols), data_(data) {
  17.         assert((int32_t) data.size() == rows * cols);
  18.     }
  19.  
  20.     int32_t Rows() const {
  21.         return rows_;
  22.     }
  23.  
  24.     int32_t Cols() const {
  25.         return cols_;
  26.     }
  27.  
  28.     const int32_t &operator()(int row, int col) const {
  29.         return data_[row * cols_ + col];
  30.     }
  31.  
  32.     int32_t &operator()(int row, int col) {
  33.         return data_[row * cols_ + col];
  34.     }
  35.  
  36.     bool operator==(const DenseMat &other) const {
  37.         if (other.Rows() != Rows() || other.Cols() != Cols()) {
  38.             return false;
  39.         }
  40.  
  41.         for (int i = 0; i < Rows(); i++) {
  42.             for (int j = 0; j < Cols(); j++) {
  43.                 if (this->operator()(i, j) != other(i, j)) {
  44.                     std::cout<<this->operator()(i, j)<< " "<<other(i, j) << std::flush;
  45.                     return false;
  46.                 }
  47.             }
  48.         }
  49.  
  50.         return true;
  51.     }
  52.  
  53.     bool operator!=(const DenseMat &other) const {
  54.         return !(*this == other);
  55.     }
  56.  
  57.  private:
  58.     int32_t rows_;
  59.     int32_t cols_;
  60.     std::vector<int32_t> data_;
  61. };
  62.  
  63. DenseMat ReadMat(std::ifstream &input) {
  64.     int32_t rows, cols;
  65.     input >> rows >> cols;
  66.  
  67.     std::vector<int32_t> data;
  68.     data.reserve(rows * cols);
  69.  
  70.     for (size_t i = 0; i < data.size(); i++) {
  71.         int v;
  72.         input >> v;
  73.  
  74.         data.push_back(v);
  75.     }
  76.  
  77.     return DenseMat(rows, cols, data);
  78. }
  79.  
  80. #include <vector>
  81. #include <thread>
  82. #include <functional>
  83.  
  84. void find_cell(DenseMat &res, int i, int j, int &e,int &e_,int count) {
  85.     e = (i * res.Cols() + j + count)/res.Cols();
  86.     //e_ = (i * res.Cols() + j + count)/res.Cols();
  87.     e_ = e;
  88. }
  89.  
  90. void Multiply(DenseMat *res, const DenseMat &a, const DenseMat &b,
  91.         int b_i, int b_j, int e_i, int e_j) {
  92.     for (int i = b_i; i <= e_i; i++) {
  93.         for (int j = b_j; j < res->Cols(); j++) {
  94.             if ((i == e_i && j == e_j) || i == res->Rows()){
  95.                 return;
  96.             }
  97.  
  98.             for (int k = 0; k < a.Cols(); k++) {
  99.                 res->operator()(i, j) += a(i, k) * b(k, j);
  100.             }
  101. //            if (i== 99 && j == 99) {
  102. //                std::cout<< res->operator()(i, j);
  103. //            }
  104.         }
  105.         b_j = 0;
  106.     }
  107. }
  108.  
  109. DenseMat MatMulParal(const DenseMat &a, const DenseMat &b, int thread_count) {
  110.     DenseMat result(a.Rows(), b.Cols());
  111.     std::vector<std::thread> tasks;
  112.     int count = a.Rows() * b.Cols() / thread_count;
  113.     int b_i = 0;
  114.     int b_j = 0;
  115.     int e_j = 0;
  116.     int e_i = 0;
  117.     for (int i = 0; i < thread_count; ++i) {
  118.         if (i+1 == thread_count) {
  119.             count = result.Rows()*result.Cols() - count * (i);
  120.         }
  121.         find_cell(result, b_i, b_j, e_i, e_j, count);
  122.         Multiply(&result, a, b, b_i, b_j, e_i, e_j);
  123.         //tasks.push_back(std::thread(Multiply, &result, a, b, b_i, b_j, e_i, e_j));
  124.         b_i = e_i;
  125.         b_j = e_j;
  126.  
  127.     }
  128.     for (int kI = 0; kI < tasks.size(); ++kI) {
  129.         tasks[kI].join();
  130.     }
  131.         std::cout<< result.operator()(99, 99);
  132.  
  133.     return result;
  134. }
  135.  
  136. DenseMat SimpleMul(const DenseMat &a, const DenseMat &b) {
  137.     DenseMat res(a.Rows(), b.Cols());
  138.  
  139.     for (int i = 0; i < a.Rows(); i++) {
  140.         for (int j = 0; j < b.Cols(); j++) {
  141.             for (int k = 0; k < a.Cols(); k++) {
  142.                 res(i, j) += a(i, k) * b(k, j);
  143.             }
  144.         }
  145.     }
  146.  
  147.     return res;
  148. }
  149.  
  150. template<class Generator>
  151. DenseMat RandomMat(int32_t rows, int32_t cols, Generator &gen) {
  152.     std::uniform_int_distribution<> dis(-1000, 1000);
  153.     std::vector<int32_t> data(rows * cols);
  154.     for (size_t i = 0; i < data.size(); i++) {
  155.         data[i] = dis(gen);
  156.     }
  157.  
  158.     return DenseMat(rows, cols, data);
  159. }
  160.  
  161. template<class Func, class ... Args>
  162. std::tuple<int32_t, DenseMat> BenchmarkFunc(Func f, Args &&... args) {
  163.     std::chrono::steady_clock::time_point start = std::chrono::steady_clock::now();
  164.     auto res = f(std::forward<Args>(args)...);
  165.     std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now();
  166.  
  167.     return {std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count(), res};
  168. }
  169.  
  170. int main(int argc, char **argv) {
  171.     std::mt19937 gen(12312);
  172.  
  173.     std::cerr << "Generating random matrices..." << std::endl;
  174.     auto a = RandomMat(100, 1000, gen);
  175.     auto b = RandomMat(1000, 100, gen);
  176.  
  177.     std::cerr << "Start benchmarking..." << std::endl;
  178.  
  179.     int32_t trivial_ms;
  180.     DenseMat ans;
  181.     std::tie(trivial_ms, ans) = BenchmarkFunc(SimpleMul, a, b);
  182.  
  183.     const int32_t best_thread_count = 4; // ThreadSanitizer??? // std::thread::hardware_concurrency();
  184.  
  185.     int32_t last_time = 1e9;
  186.     for (int tcount = 1; tcount <= 10; tcount++) {
  187.         int32_t mul_ms;
  188.         DenseMat mat_res;
  189.         std::tie(mul_ms, mat_res) = BenchmarkFunc(MatMulParal, a, b, tcount);
  190.  
  191.         assert(mat_res == ans);
  192.  
  193.         std::cerr << "Thread count: " << tcount << std::endl;
  194.         std::cerr << "Time, ms: " << mul_ms << std::endl;
  195.  
  196.         if (tcount <= best_thread_count) {
  197.             //assert(mul_ms < last_time);
  198.             last_time = mul_ms;
  199.         }
  200.     }
  201.  
  202.     assert(last_time < 140);
  203.  
  204.     std::cerr << "Trivial ms: " << trivial_ms << std::endl;
  205.     std::cout << 1 << std::endl;
  206. }B
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement