Pastebin launched a little side project called VERYVIRAL.com, check it out ;-) Want more features on Pastebin? Sign Up, it's FREE!
Guest

main.cpp

By: a guest on Nov 29th, 2010  |  syntax: C++  |  size: 2.67 KB  |  views: 1,090  |  expires: Never
download  |  raw  |  embed  |  report abuse  |  print
Text below is selected. Please press Ctrl+C to copy to your clipboard. (⌘+C on Mac)
  1. #include <iostream>
  2. #include <stdlib.h>
  3. #include <time.h>
  4. #include <sys/time.h>
  5. #include "matrix.h"
  6.  
  7. typedef pair<matrix, long> result;
  8.  
  9. int cut = 64;
  10.  
  11. matrix mult_std(matrix a, matrix b) {
  12.         matrix c(a.dim(), false, false);
  13.         for (int i = 0; i < a.dim(); i++)
  14.                 for (int k = 0; k < a.dim(); k++)
  15.                         for (int j = 0; j < a.dim(); j++)
  16.                                 c(i,j) += a(i,k) * b(k,j);
  17.        
  18.         return c;
  19. }
  20.  
  21. matrix get_part(int pi, int pj, matrix m) {
  22.         matrix p(m.dim() / 2, false, true);
  23.         pi = pi * p.dim();
  24.         pj = pj * p.dim();
  25.        
  26.         for (int i = 0; i < p.dim(); i++)
  27.                 for (int j = 0; j < p.dim(); j++)
  28.                         p(i,j) = m(i + pi,j + pj);
  29.                        
  30.         return p;
  31. }
  32.  
  33. void set_part(int pi, int pj, matrix* m, matrix p) {
  34.         pi = pi * p.dim();
  35.         pj = pj * p.dim();
  36.        
  37.         for (int i = 0; i < p.dim(); i++)
  38.                 for (int j = 0; j < p.dim(); j++)
  39.                         (*m)(i + pi,j + pj) = p(i,j);
  40. }
  41.  
  42. matrix mult_strassen(matrix a, matrix b) {
  43.         if (a.dim() <= cut)
  44.                 return mult_std(a, b);
  45.  
  46.         matrix a11 = get_part(0, 0, a);
  47.         matrix a12 = get_part(0, 1, a);
  48.         matrix a21 = get_part(1, 0, a);
  49.         matrix a22 = get_part(1, 1, a);
  50.        
  51.         matrix b11 = get_part(0, 0, b);
  52.         matrix b12 = get_part(0, 1, b);
  53.         matrix b21 = get_part(1, 0, b);
  54.         matrix b22 = get_part(1, 1, b);
  55.        
  56.         matrix m1 = mult_strassen(a11 + a22, b11 + b22);
  57.         matrix m2 = mult_strassen(a21 + a22, b11);
  58.         matrix m3 = mult_strassen(a11, b12 - b22);
  59.         matrix m4 = mult_strassen(a22, b21 - b11);
  60.         matrix m5 = mult_strassen(a11 + a12, b22);
  61.         matrix m6 = mult_strassen(a21 - a11, b11 + b12);
  62.         matrix m7 = mult_strassen(a12 - a22, b21 + b22);
  63.        
  64.         matrix c(a.dim(), false, true);
  65.         set_part(0, 0, &c, m1 + m4 - m5 + m7);
  66.         set_part(0, 1, &c, m3 + m5);
  67.         set_part(1, 0, &c, m2 + m4);
  68.         set_part(1, 1, &c, m1 - m2 + m3 + m6);
  69.        
  70.         return c;
  71. }
  72.  
  73. pair<matrix, long> run(matrix (*f)(matrix, matrix), matrix a, matrix b) {
  74.         struct timeval start, end;
  75.        
  76.         gettimeofday(&start, NULL);
  77.         matrix c = f(a, b);
  78.         gettimeofday(&end, NULL);
  79.         long e = (end.tv_sec * 1000 + end.tv_usec / 1000);
  80.         long s =(start.tv_sec * 1000 + start.tv_usec / 1000);
  81.        
  82.         return pair<matrix, long> (c, e - s);
  83. }
  84.  
  85. int main() {
  86.         /* test cut of for strassen
  87.         /*
  88.         for (cut = 2; cut <= 512; cut++) {
  89.                 matrix a(512, true, true);
  90.                 matrix b(512, true, true);
  91.                 result r = run(mult_strassen, a, b);
  92.                 cout << cut << " " << r.second << "\n";
  93.         }
  94.         */
  95.        
  96.         /* performance test: standard and strassen */
  97.         for (int dim = 0; dim <= 1024; dim += 64) {
  98.                 matrix a(dim, true, false);
  99.                 matrix b(dim, true, false);
  100.                 result std = run(mult_std, a, b);
  101.                 matrix c(dim, true, true);
  102.                 matrix d(dim, true, true);
  103.                 result strassen = run(mult_strassen, c, d);
  104.                 cout << dim << " " << std.second << " " << strassen.second << "\n";
  105.         }
  106. }