Guest User

Full source

a guest
Nov 23rd, 2015
90
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C 3.84 KB | None | 0 0
  1. #include <stdio.h>
  2. #include <emmintrin.h>
  3. #include <time.h>
  4. #include <math.h>
  5.  
  6. double now()
  7. {
  8.     return (double)clock() / (double)CLOCKS_PER_SEC;
  9. }
  10.  
  11. void transpose(float *M, int N)
  12. {
  13.     int i, j;
  14.     for(i = 0; i < N; i++)
  15.         for(j = i + 1; j < N; j++)
  16.         {
  17.             float tmp = M[i*N + j];
  18.             M[i*N + j] = M[j*N + i];
  19.             M[j*N + i] = tmp;
  20.         }
  21. }
  22.  
  23. int main(void)
  24. {
  25.     double t0, t1, t2;
  26.     float *A, *B, *C, *Correct;
  27.     float *ptrA, *ptrB;
  28.     int N = 1000;
  29.     int i, j, k, n, mode;
  30.  
  31.     t0 = now();
  32.  
  33.     A = (float*)_aligned_malloc(N*N*sizeof(float), 16);
  34.     B = (float*)_aligned_malloc(N*N*sizeof(float), 16);
  35.     C = (float*)_aligned_malloc(N*N*sizeof(float), 16);
  36.     Correct = NULL;
  37.  
  38.     for(i = 0; i < N; i++)
  39.         for(j = 0; j < N; j++)
  40.         {
  41.             A[i*N + j] = (float)(rand() % 1000) / 1000.0f;
  42.             B[i*N + j] = (float)(rand() % 1000) / 1000.0f;
  43.         }
  44.  
  45.     printf("%dx%d matricies; element size: %d byte(s)\n", N, N, sizeof(float));
  46.  
  47.     for(mode = 0; mode <= 2; mode++)
  48.     {
  49.         for(n = 0; n < 1; n++)
  50.         {
  51.             printf("Mode: %d; threads: %d\n", mode, n);
  52.             memset(C, 0, N*N*sizeof(float));
  53.             t1 = now() - t0;
  54.  
  55.             switch(mode)
  56.             {
  57.             case 0:
  58.                 for(i = 0; i < N; i++)
  59.                     for(j = 0; j < N; j++)
  60.                     {
  61.                         C[i*N + j] = 0.0f;
  62.                         for(k = 0; k < N; k++)
  63.                             C[i*N + j] += A[i*N + k]*B[k*N + j];
  64.                     }
  65.                 break;
  66.             case 1:
  67.                 transpose(B, N);
  68.                 for(i = 0; i < N; i++)
  69.                     for(j = 0; j < N; j++)
  70.                     {
  71.                         float tmp;
  72.                         tmp = 0.0f;
  73.                         ptrA = A + i*N;
  74.                         ptrB = B + j*N;
  75.                         for(k = 0; k < N; k++, ptrA++, ptrB++)
  76.                             tmp = tmp + (*ptrA) * (*ptrB);
  77.                         C[i*N + j] = tmp;
  78.                     }
  79.                 transpose(B, N);
  80.                 break;
  81.             case 2:
  82.                 transpose(B, N);
  83.                 for(i = 0; i < N; i++)
  84.                     for(j = 0; j < N; j++)
  85.                     {
  86.                         __m128 tmp;
  87.                         tmp = _mm_set1_ps(0.0f);
  88.                         ptrA = A + i*N;
  89.                         ptrB = B + j*N;
  90.                         for(k = 0; k < N/4; k++, ptrA += 4, ptrB += 4)
  91.                             tmp = _mm_add_ps(tmp,
  92.                                              _mm_mul_ps(
  93.                                                  _mm_load_ps(ptrA),
  94.                                                  _mm_load_ps(ptrB)));
  95.                         C[i*N + j] = 0.0f;
  96.                         for(k = 0; k < 4; k++)
  97.                             C[i*N + j] += tmp.m128_f32[k];
  98.                     }
  99.                 transpose(B, N);
  100.                 break;
  101.             }
  102.  
  103.             t2 = now() - t0;
  104.             printf("Time elapsed: %.3lf ms\n", (t2 - t1) * 1000.0);
  105.  
  106.             if(!Correct)
  107.             {
  108.                 Correct = (float*)_aligned_malloc(N*N*sizeof(float), 16);
  109.                 memcpy(Correct, C, N*N*sizeof(float));
  110.             }
  111.             else
  112.             {
  113.                 for(i = 0; i < N; i++)
  114.                     for(j = 0; j < N; j++)
  115.                         if(fabs((double)C[i*N + j] - (double)Correct[i*N + j]) > 1e-2f)
  116.                         {
  117.                             printf("Wrong result: C[%d,%d] = %lg =/= %lg\n",
  118.                                    i, j, (double)C[i*N + j], (double)Correct[i*N + j]);
  119.                             i = j = N;
  120.                         }
  121.             }
  122.         }
  123.     }
  124.  
  125.     return 0;
  126. }
Advertisement
Add Comment
Please, Sign In to add comment