Want more features on Pastebin? Sign Up, it's FREE!
Guest

Strassen's Matrix Multiplication

By: a guest on Nov 26th, 2012  |  syntax: C++  |  size: 5.96 KB  |  views: 263  |  expires: Never
download  |  raw  |  embed  |  report abuse  |  print
Text below is selected. Please press Ctrl+C to copy to your clipboard. (⌘+C on Mac)
  1. #include <iostream>
  2. #include <ctime>
  3. using namespace std;
  4.  
  5. void strassen(int **a, int **b, int **c, int tam);
  6. void sum(int **a, int **b, int **result, int tam);
  7. void subtract(int **a, int **b, int **result, int tam);
  8. int  **allocate_matrix(int size);
  9. void dealloc(int **matrix, int size);
  10.  
  11. int main()
  12. {
  13.         for(int i = 0 ; i < 3; i++)
  14.         {
  15.                 clock_t start = clock();
  16.  
  17.                 int m = 1024;
  18.                 int **A = new int*[m];
  19.                 int **B = new int*[m];
  20.                 int **product = new int*[m];
  21.  
  22.                 A = allocate_matrix(m);
  23.                 B = allocate_matrix(m);
  24.                 product = allocate_matrix(m);
  25.  
  26.                 for(int row = 0; row < m; row++) {
  27.                         for(int inner = 0; inner < m; inner++){
  28.                                 A[row][inner] = 5;
  29.                                 //cout << A[row][inner] << " ";
  30.                         }
  31.                         //cout << "\n";
  32.                 }
  33.  
  34.                 for(int row = 0; row < m; row++) {
  35.                         for(int inner = 0; inner < m; inner++){
  36.                                 B[row][inner] = 3;
  37.                                 //cout << B[row][inner] << " ";
  38.                         }
  39.                         //cout << "\n";
  40.                 }
  41.  
  42.                 /*for(int row = 0; row < m; row++) {
  43.                 for(int inner = 0; inner < m; inner++){
  44.                 product[row][inner] = 0;
  45.                 cout << product[row][inner] << " ";
  46.                 }
  47.                 cout << "\n";
  48.                 }*/
  49.  
  50.                 strassen (A,B,product,m);
  51.  
  52.                 //for(int row = 0; row < m; row++) {
  53.                 //      for(int inner = 0; inner < m; inner++){
  54.                 //              cout << product[row][inner] << " ";
  55.                 //      }
  56.                 //      cout << "\n";
  57.                 //}
  58.  
  59.  
  60.                 //deallocation
  61.                 dealloc(A, m);
  62.                 dealloc(B, m);
  63.                 dealloc(product, m);
  64.                 clock_t end = clock();
  65.                 double cpu_time = static_cast<double>(end - start)/CLOCKS_PER_SEC;
  66.  
  67.  
  68.                 cout << cpu_time << endl;
  69.         }
  70.  
  71.  
  72.         cout << "done" << endl;
  73.  
  74.         return 0;
  75. }
  76. void strassen(int **a, int **b, int **c, int tam) {
  77.  
  78.         // trivial case: when the matrix is 1 X 1:
  79.         if (tam == 1) {
  80.                 c[0][0] = a[0][0] * b[0][0];
  81.                 return;
  82.         }
  83.  
  84.         // other cases are treated here:
  85.         int newTam = tam/2;
  86.         int **a11, **a12, **a21, **a22;
  87.         int **b11, **b12, **b21, **b22;
  88.         int **c11, **c12, **c21, **c22;
  89.         int **p1, **p2, **p3, **p4, **p5, **p6, **p7;
  90.  
  91.         // memory allocation:
  92.         a11 = allocate_matrix(newTam);
  93.         a12 = allocate_matrix(newTam);
  94.         a21 = allocate_matrix(newTam);
  95.         a22 = allocate_matrix(newTam);
  96.  
  97.         b11 = allocate_matrix(newTam);
  98.         b12 = allocate_matrix(newTam);
  99.         b21 = allocate_matrix(newTam);
  100.         b22 = allocate_matrix(newTam);
  101.  
  102.         c11 = allocate_matrix(newTam);
  103.         c12 = allocate_matrix(newTam);
  104.         c21 = allocate_matrix(newTam);
  105.         c22 = allocate_matrix(newTam);
  106.  
  107.         p1 = allocate_matrix(newTam);
  108.         p2 = allocate_matrix(newTam);
  109.         p3 = allocate_matrix(newTam);
  110.         p4 = allocate_matrix(newTam);
  111.         p5 = allocate_matrix(newTam);
  112.         p6 = allocate_matrix(newTam);
  113.         p7 = allocate_matrix(newTam);
  114.  
  115.         int **aResult = allocate_matrix(newTam);
  116.         int **bResult = allocate_matrix(newTam);
  117.  
  118.         int i, j;
  119.  
  120.         //dividing the matrices in 4 sub-matrices:
  121.         for (i = 0; i < newTam; i++) {
  122.                 for (j = 0; j < newTam; j++) {
  123.                         a11[i][j] = a[i][j];
  124.                         a12[i][j] = a[i][j + newTam];
  125.                         a21[i][j] = a[i + newTam][j];
  126.                         a22[i][j] = a[i + newTam][j + newTam];
  127.  
  128.                         b11[i][j] = b[i][j];
  129.                         b12[i][j] = b[i][j + newTam];
  130.                         b21[i][j] = b[i + newTam][j];
  131.                         b22[i][j] = b[i + newTam][j + newTam];
  132.                 }
  133.         }
  134.  
  135.         // Calculating p1 to p7:
  136.  
  137.         sum(a11, a22, aResult, newTam); // a11 + a22
  138.         sum(b11, b22, bResult, newTam); // b11 + b22
  139.         strassen(aResult, bResult, p1, newTam); // p1 = (a11+a22) * (b11+b22)
  140.  
  141.         sum(a21, a22, aResult, newTam); // a21 + a22
  142.         strassen(aResult, b11, p2, newTam); // p2 = (a21+a22) * (b11)
  143.  
  144.         subtract(b12, b22, bResult, newTam); // b12 - b22
  145.         strassen(a11, bResult, p3, newTam); // p3 = (a11) * (b12 - b22)
  146.  
  147.         subtract(b21, b11, bResult, newTam); // b21 - b11
  148.         strassen(a22, bResult, p4, newTam); // p4 = (a22) * (b21 - b11)
  149.  
  150.         sum(a11, a12, aResult, newTam); // a11 + a12
  151.         strassen(aResult, b22, p5, newTam); // p5 = (a11+a12) * (b22)  
  152.  
  153.         subtract(a21, a11, aResult, newTam); // a21 - a11
  154.         sum(b11, b12, bResult, newTam); // b11 + b12
  155.         strassen(aResult, bResult, p6, newTam); // p6 = (a21-a11) * (b11+b12)
  156.  
  157.         subtract(a12, a22, aResult, newTam); // a12 - a22
  158.         sum(b21, b22, bResult, newTam); // b21 + b22
  159.         strassen(aResult, bResult, p7, newTam); // p7 = (a12-a22) * (b21+b22)
  160.  
  161.         // calculating c21, c21, c11 e c22:
  162.  
  163.         sum(p3, p5, c12, newTam); // c12 = p3 + p5
  164.         sum(p2, p4, c21, newTam); // c21 = p2 + p4
  165.  
  166.         sum(p1, p4, aResult, newTam); // p1 + p4
  167.         sum(aResult, p7, bResult, newTam); // p1 + p4 + p7
  168.         subtract(bResult, p5, c11, newTam); // c11 = p1 + p4 - p5 + p7
  169.  
  170.         sum(p1, p3, aResult, newTam); // p1 + p3
  171.         sum(aResult, p6, bResult, newTam); // p1 + p3 + p6
  172.         subtract(bResult, p2, c22, newTam); // c22 = p1 + p3 - p2 + p6
  173.  
  174.         // Grouping the results obtained in a single matrix:
  175.         for (i = 0; i < newTam ; i++) {
  176.                 for (j = 0 ; j < newTam ; j++) {
  177.                         c[i][j] = c11[i][j];
  178.                         c[i][j + newTam] = c12[i][j];
  179.                         c[i + newTam][j] = c21[i][j];
  180.                         c[i + newTam][j + newTam] = c22[i][j];
  181.                 }
  182.         }
  183.  
  184.         // deallocating memory (free):
  185.         dealloc(a11, newTam);
  186.         dealloc(a12, newTam);
  187.         dealloc(a21, newTam);
  188.         dealloc(a22, newTam);
  189.  
  190.         dealloc(b11, newTam);
  191.         dealloc(b12, newTam);
  192.         dealloc(b21, newTam);
  193.         dealloc(b22, newTam);
  194.  
  195.         dealloc(c11, newTam);
  196.         dealloc(c12, newTam);
  197.         dealloc(c21, newTam);
  198.         dealloc(c22, newTam);
  199.  
  200.         dealloc(p1, newTam);
  201.         dealloc(p2, newTam);
  202.         dealloc(p3, newTam);
  203.         dealloc(p4, newTam);
  204.         dealloc(p5, newTam);
  205.         dealloc(p6, newTam);
  206.         dealloc(p7, newTam);
  207.         dealloc(aResult, newTam);
  208.         dealloc(bResult, newTam);
  209.  
  210. } // end of Strassen function
  211.  
  212. void sum(int **a, int **b, int **result, int tam) {
  213.         int i, j;
  214.         for (i = 0; i < tam; i++) {
  215.                 for (j = 0; j < tam; j++) {
  216.                         result[i][j] = a[i][j] + b[i][j];
  217.                 }
  218.         }
  219. }
  220.  
  221. void subtract(int **a, int **b, int **result, int tam) {
  222.         int i, j;
  223.         for (i = 0; i < tam; i++) {
  224.                 for (j = 0; j < tam; j++) {
  225.                         result[i][j] = a[i][j] - b[i][j];
  226.                 }
  227.         }  
  228. }
  229. int **allocate_matrix(int size) {
  230.         int **temp = new int*[size];
  231.         for ( int i = 0 ; i < size ; i++)
  232.         {
  233.                 temp[i] = new int[size];
  234.         }
  235.         return (temp);
  236. }
  237. void dealloc (int **matrix,int size) {
  238.         if (matrix == NULL) {
  239.                 return;
  240.         }
  241.         for ( int i = 0 ; i < size ; i++)
  242.         {
  243.                 delete[] matrix[i];
  244.         }
  245.         delete[] matrix;
  246. }
clone this paste RAW Paste Data