Advertisement
Guest User

main.cpp

a guest
Nov 29th, 2010
4,486
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 2.67 KB | None | 0 0
  1. #include <iostream>
  2. #include <stdlib.h>
  3. #include <time.h>
  4. #include <sys/time.h>
  5. #include "matrix.h"
  6.  
  7. typedef pair<matrix, long> result;
  8.  
  9. int cut = 64;
  10.  
  11. matrix mult_std(matrix a, matrix b) {
  12.     matrix c(a.dim(), false, false);
  13.     for (int i = 0; i < a.dim(); i++)
  14.         for (int k = 0; k < a.dim(); k++)
  15.             for (int j = 0; j < a.dim(); j++)
  16.                 c(i,j) += a(i,k) * b(k,j);
  17.    
  18.     return c;
  19. }
  20.  
  21. matrix get_part(int pi, int pj, matrix m) {
  22.     matrix p(m.dim() / 2, false, true);
  23.     pi = pi * p.dim();
  24.     pj = pj * p.dim();
  25.    
  26.     for (int i = 0; i < p.dim(); i++)
  27.         for (int j = 0; j < p.dim(); j++)
  28.             p(i,j) = m(i + pi,j + pj);
  29.            
  30.     return p;
  31. }
  32.  
  33. void set_part(int pi, int pj, matrix* m, matrix p) {
  34.     pi = pi * p.dim();
  35.     pj = pj * p.dim();
  36.    
  37.     for (int i = 0; i < p.dim(); i++)
  38.         for (int j = 0; j < p.dim(); j++)
  39.             (*m)(i + pi,j + pj) = p(i,j);
  40. }
  41.  
  42. matrix mult_strassen(matrix a, matrix b) {
  43.     if (a.dim() <= cut)
  44.         return mult_std(a, b);
  45.  
  46.     matrix a11 = get_part(0, 0, a);
  47.     matrix a12 = get_part(0, 1, a);
  48.     matrix a21 = get_part(1, 0, a);
  49.     matrix a22 = get_part(1, 1, a);
  50.    
  51.     matrix b11 = get_part(0, 0, b);
  52.     matrix b12 = get_part(0, 1, b);
  53.     matrix b21 = get_part(1, 0, b);
  54.     matrix b22 = get_part(1, 1, b);
  55.    
  56.     matrix m1 = mult_strassen(a11 + a22, b11 + b22);
  57.     matrix m2 = mult_strassen(a21 + a22, b11);
  58.     matrix m3 = mult_strassen(a11, b12 - b22);
  59.     matrix m4 = mult_strassen(a22, b21 - b11);
  60.     matrix m5 = mult_strassen(a11 + a12, b22);
  61.     matrix m6 = mult_strassen(a21 - a11, b11 + b12);
  62.     matrix m7 = mult_strassen(a12 - a22, b21 + b22);
  63.    
  64.     matrix c(a.dim(), false, true);
  65.     set_part(0, 0, &c, m1 + m4 - m5 + m7);
  66.     set_part(0, 1, &c, m3 + m5);
  67.     set_part(1, 0, &c, m2 + m4);
  68.     set_part(1, 1, &c, m1 - m2 + m3 + m6);
  69.    
  70.     return c;
  71. }
  72.  
  73. pair<matrix, long> run(matrix (*f)(matrix, matrix), matrix a, matrix b) {
  74.     struct timeval start, end;
  75.    
  76.     gettimeofday(&start, NULL);
  77.     matrix c = f(a, b);
  78.     gettimeofday(&end, NULL);
  79.     long e = (end.tv_sec * 1000 + end.tv_usec / 1000);
  80.     long s =(start.tv_sec * 1000 + start.tv_usec / 1000);
  81.    
  82.     return pair<matrix, long> (c, e - s);
  83. }
  84.  
  85. int main() {
  86.     /* test cut of for strassen
  87.     /*
  88.     for (cut = 2; cut <= 512; cut++) {
  89.         matrix a(512, true, true);
  90.         matrix b(512, true, true);
  91.         result r = run(mult_strassen, a, b);
  92.         cout << cut << " " << r.second << "\n";
  93.     }
  94.     */
  95.    
  96.     /* performance test: standard and strassen */
  97.     for (int dim = 0; dim <= 1024; dim += 64) {
  98.         matrix a(dim, true, false);
  99.         matrix b(dim, true, false);
  100.         result std = run(mult_std, a, b);
  101.         matrix c(dim, true, true);
  102.         matrix d(dim, true, true);
  103.         result strassen = run(mult_strassen, c, d);
  104.         cout << dim << " " << std.second << " " << strassen.second << "\n";
  105.     }
  106. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement