Advertisement
Guest User

Strassen's Matrix Multiplication

a guest
Nov 26th, 2012
1,606
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 5.96 KB | None | 0 0
  1. #include <iostream>
  2. #include <ctime>
  3. using namespace std;
  4.  
  // Forward declarations: main() and strassen() use helpers that are
  // defined later in this file.

  // Matrix storage helpers. A matrix is an array of `size` row pointers,
  // each row holding `size` ints.
  int **allocate_matrix(int size);
  void dealloc(int **matrix, int size);

  // Element-wise result = a + b and result = a - b over tam x tam matrices.
  void sum(int **a, int **b, int **result, int tam);
  void subtract(int **a, int **b, int **result, int tam);

  // c = a * b for tam x tam matrices (Strassen's algorithm).
  void strassen(int **a, int **b, int **c, int tam);
  10.  
  11. int main()
  12. {
  13.     for(int i = 0 ; i < 3; i++)
  14.     {
  15.         clock_t start = clock();
  16.  
  17.         int m = 1024;
  18.         int **A = new int*[m];
  19.         int **B = new int*[m];
  20.         int **product = new int*[m];
  21.  
  22.         A = allocate_matrix(m);
  23.         B = allocate_matrix(m);
  24.         product = allocate_matrix(m);
  25.  
  26.         for(int row = 0; row < m; row++) {
  27.             for(int inner = 0; inner < m; inner++){
  28.                 A[row][inner] = 5;
  29.                 //cout << A[row][inner] << " ";
  30.             }
  31.             //cout << "\n";
  32.         }
  33.  
  34.         for(int row = 0; row < m; row++) {
  35.             for(int inner = 0; inner < m; inner++){
  36.                 B[row][inner] = 3;
  37.                 //cout << B[row][inner] << " ";
  38.             }
  39.             //cout << "\n";
  40.         }
  41.  
  42.         /*for(int row = 0; row < m; row++) {
  43.         for(int inner = 0; inner < m; inner++){
  44.         product[row][inner] = 0;
  45.         cout << product[row][inner] << " ";
  46.         }
  47.         cout << "\n";
  48.         }*/
  49.  
  50.         strassen (A,B,product,m);
  51.  
  52.         //for(int row = 0; row < m; row++) {
  53.         //  for(int inner = 0; inner < m; inner++){
  54.         //      cout << product[row][inner] << " ";
  55.         //  }
  56.         //  cout << "\n";
  57.         //}
  58.  
  59.  
  60.         //deallocation
  61.         dealloc(A, m);
  62.         dealloc(B, m);
  63.         dealloc(product, m);
  64.         clock_t end = clock();
  65.         double cpu_time = static_cast<double>(end - start)/CLOCKS_PER_SEC;
  66.  
  67.  
  68.         cout << cpu_time << endl;
  69.     }
  70.  
  71.  
  72.     cout << "done" << endl;
  73.  
  74.     return 0;
  75. }
  76. void strassen(int **a, int **b, int **c, int tam) {
  77.  
  78.     // trivial case: when the matrix is 1 X 1:
  79.     if (tam == 1) {
  80.         c[0][0] = a[0][0] * b[0][0];
  81.         return;
  82.     }
  83.  
  84.     // other cases are treated here:
  85.     int newTam = tam/2;
  86.     int **a11, **a12, **a21, **a22;
  87.     int **b11, **b12, **b21, **b22;
  88.     int **c11, **c12, **c21, **c22;
  89.     int **p1, **p2, **p3, **p4, **p5, **p6, **p7;
  90.  
  91.     // memory allocation:
  92.     a11 = allocate_matrix(newTam);
  93.     a12 = allocate_matrix(newTam);
  94.     a21 = allocate_matrix(newTam);
  95.     a22 = allocate_matrix(newTam);
  96.  
  97.     b11 = allocate_matrix(newTam);
  98.     b12 = allocate_matrix(newTam);
  99.     b21 = allocate_matrix(newTam);
  100.     b22 = allocate_matrix(newTam);
  101.  
  102.     c11 = allocate_matrix(newTam);
  103.     c12 = allocate_matrix(newTam);
  104.     c21 = allocate_matrix(newTam);
  105.     c22 = allocate_matrix(newTam);
  106.  
  107.     p1 = allocate_matrix(newTam);
  108.     p2 = allocate_matrix(newTam);
  109.     p3 = allocate_matrix(newTam);
  110.     p4 = allocate_matrix(newTam);
  111.     p5 = allocate_matrix(newTam);
  112.     p6 = allocate_matrix(newTam);
  113.     p7 = allocate_matrix(newTam);
  114.  
  115.     int **aResult = allocate_matrix(newTam);
  116.     int **bResult = allocate_matrix(newTam);
  117.  
  118.     int i, j;
  119.  
  120.     //dividing the matrices in 4 sub-matrices:
  121.     for (i = 0; i < newTam; i++) {
  122.         for (j = 0; j < newTam; j++) {
  123.             a11[i][j] = a[i][j];
  124.             a12[i][j] = a[i][j + newTam];
  125.             a21[i][j] = a[i + newTam][j];
  126.             a22[i][j] = a[i + newTam][j + newTam];
  127.  
  128.             b11[i][j] = b[i][j];
  129.             b12[i][j] = b[i][j + newTam];
  130.             b21[i][j] = b[i + newTam][j];
  131.             b22[i][j] = b[i + newTam][j + newTam];
  132.         }
  133.     }
  134.  
  135.     // Calculating p1 to p7:
  136.  
  137.     sum(a11, a22, aResult, newTam); // a11 + a22
  138.     sum(b11, b22, bResult, newTam); // b11 + b22
  139.     strassen(aResult, bResult, p1, newTam); // p1 = (a11+a22) * (b11+b22)
  140.  
  141.     sum(a21, a22, aResult, newTam); // a21 + a22
  142.     strassen(aResult, b11, p2, newTam); // p2 = (a21+a22) * (b11)
  143.  
  144.     subtract(b12, b22, bResult, newTam); // b12 - b22
  145.     strassen(a11, bResult, p3, newTam); // p3 = (a11) * (b12 - b22)
  146.  
  147.     subtract(b21, b11, bResult, newTam); // b21 - b11
  148.     strassen(a22, bResult, p4, newTam); // p4 = (a22) * (b21 - b11)
  149.  
  150.     sum(a11, a12, aResult, newTam); // a11 + a12
  151.     strassen(aResult, b22, p5, newTam); // p5 = (a11+a12) * (b22)  
  152.  
  153.     subtract(a21, a11, aResult, newTam); // a21 - a11
  154.     sum(b11, b12, bResult, newTam); // b11 + b12
  155.     strassen(aResult, bResult, p6, newTam); // p6 = (a21-a11) * (b11+b12)
  156.  
  157.     subtract(a12, a22, aResult, newTam); // a12 - a22
  158.     sum(b21, b22, bResult, newTam); // b21 + b22
  159.     strassen(aResult, bResult, p7, newTam); // p7 = (a12-a22) * (b21+b22)
  160.  
  161.     // calculating c21, c21, c11 e c22:
  162.  
  163.     sum(p3, p5, c12, newTam); // c12 = p3 + p5
  164.     sum(p2, p4, c21, newTam); // c21 = p2 + p4
  165.  
  166.     sum(p1, p4, aResult, newTam); // p1 + p4
  167.     sum(aResult, p7, bResult, newTam); // p1 + p4 + p7
  168.     subtract(bResult, p5, c11, newTam); // c11 = p1 + p4 - p5 + p7
  169.  
  170.     sum(p1, p3, aResult, newTam); // p1 + p3
  171.     sum(aResult, p6, bResult, newTam); // p1 + p3 + p6
  172.     subtract(bResult, p2, c22, newTam); // c22 = p1 + p3 - p2 + p6
  173.  
  174.     // Grouping the results obtained in a single matrix:
  175.     for (i = 0; i < newTam ; i++) {
  176.         for (j = 0 ; j < newTam ; j++) {
  177.             c[i][j] = c11[i][j];
  178.             c[i][j + newTam] = c12[i][j];
  179.             c[i + newTam][j] = c21[i][j];
  180.             c[i + newTam][j + newTam] = c22[i][j];
  181.         }
  182.     }
  183.  
  184.     // deallocating memory (free):
  185.     dealloc(a11, newTam);
  186.     dealloc(a12, newTam);
  187.     dealloc(a21, newTam);
  188.     dealloc(a22, newTam);
  189.  
  190.     dealloc(b11, newTam);
  191.     dealloc(b12, newTam);
  192.     dealloc(b21, newTam);
  193.     dealloc(b22, newTam);
  194.  
  195.     dealloc(c11, newTam);
  196.     dealloc(c12, newTam);
  197.     dealloc(c21, newTam);
  198.     dealloc(c22, newTam);
  199.  
  200.     dealloc(p1, newTam);
  201.     dealloc(p2, newTam);
  202.     dealloc(p3, newTam);
  203.     dealloc(p4, newTam);
  204.     dealloc(p5, newTam);
  205.     dealloc(p6, newTam);
  206.     dealloc(p7, newTam);
  207.     dealloc(aResult, newTam);
  208.     dealloc(bResult, newTam);
  209.  
  210. } // end of Strassen function
  211.  
  // sum: element-wise addition over a tam x tam region,
  //   result[row][col] = a[row][col] + b[row][col].
  // Each output element depends only on the same element of a and b, so
  // result may be the same matrix as a or b.
  void sum(int **a, int **b, int **result, int tam) {
      for (int row = 0; row < tam; ++row) {
          const int *lhs = a[row];
          const int *rhs = b[row];
          int *out = result[row];
          for (int col = 0; col < tam; ++col) {
              out[col] = lhs[col] + rhs[col];
          }
      }
  }
  220.  
  // subtract: element-wise difference over a tam x tam region,
  //   result[row][col] = a[row][col] - b[row][col].
  // Each output element depends only on the same element of a and b, so
  // result may be the same matrix as a or b.
  void subtract(int **a, int **b, int **result, int tam) {
      for (int row = 0; row < tam; ++row) {
          const int *minuend = a[row];
          const int *subtrahend = b[row];
          int *out = result[row];
          for (int col = 0; col < tam; ++col) {
              out[col] = minuend[col] - subtrahend[col];
          }
      }
  }
  // allocate_matrix: allocates a size x size int matrix as an array of
  // `size` row pointers, each row its own `new int[size]` array.
  // Every element is zero-initialized. Release it with dealloc(matrix, size).
  // If allocating a row throws, the rows allocated so far and the
  // row-pointer array are freed before the exception propagates, so a
  // failed call leaks nothing.
  int **allocate_matrix(int size) {
      int **rows = new int*[size];
      int built = 0;
      try {
          for (; built < size; built++) {
              rows[built] = new int[size]();  // "()" value-initializes to 0
          }
      } catch (...) {
          for (int i = 0; i < built; i++) {
              delete[] rows[i];
          }
          delete[] rows;
          throw;
      }
      return rows;
  }
  // dealloc: frees a matrix laid out as `size` separately allocated rows
  // plus a row-pointer array (the layout allocate_matrix produces).
  // Each row is released with delete[], then the pointer array itself.
  // A null matrix is accepted and ignored.
  void dealloc(int **matrix, int size) {
      if (!matrix) {
          return;
      }
      int row = size;
      while (row-- > 0) {
          delete[] matrix[row];
      }
      delete[] matrix;
  }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement