Advertisement
Guest User

Untitled

a guest
Apr 21st, 2018
132
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 3.53 KB | None | 0 0
  1. #ifdef YES_AVX
  2. void matvec_YMM(double* a, double* x, double* y, int n, int lb)
  3. {
  4.     int i, j;
  5.     __m256d rx0, ra0, ra1, ra2, ra3, ry0, ry1, ry2, ry3;
  6.     double *ptr_x, *ptr_a;
  7.     __declspec(align(16)) double buf0[4], buf1[4], buf2[4], buf3[4];
  8.     memset((void *)y, 0, n * sizeof(double));
  9.     ptr_a = a;
  10.     for (i = 0; i < n; i += 4)
  11.     {
  12.         ry0 = ry1 = ry2 = ry3 = _mm256_setzero_pd();
  13.         ptr_x = x;
  14.  
  15.  
  16.         for (j = 0; j < n; j += 16)
  17.         {
  18.             _mm_prefetch((const char *)(ptr_x + 16), _MM_HINT_T0);
  19.             _mm_prefetch((const char *)(ptr_x + 24), _MM_HINT_T0);
  20.  
  21.             _mm_prefetch((const char *)(ptr_a + 16), _MM_HINT_NTA);
  22.             _mm_prefetch((const char *)(ptr_a + 24), _MM_HINT_NTA);
  23.  
  24.                     _mm_prefetch((const char *)(ptr_a +n+ 16), _MM_HINT_NTA);
  25.             _mm_prefetch((const char *)(ptr_a +n+ 24), _MM_HINT_NTA);  
  26.  
  27.                     _mm_prefetch((const char *)(ptr_a +2*n+ 16), _MM_HINT_NTA);
  28.             _mm_prefetch((const char *)(ptr_a +2*n+ 24), _MM_HINT_NTA);
  29.  
  30.                     _mm_prefetch((const char *)(ptr_a +3*n+ 16), _MM_HINT_NTA);
  31.             _mm_prefetch((const char *)(ptr_a +3*n+ 24), _MM_HINT_NTA);
  32.  
  33.         //--------------------------0
  34.             rx0 = _mm256_load_pd(ptr_x);
  35.             ra0 = _mm256_load_pd(ptr_a);
  36.             ra1 = _mm256_load_pd(ptr_a + n);
  37.             ra2 = _mm256_load_pd(ptr_a + 2 * n);
  38.             ra3 = _mm256_load_pd(ptr_a + 3 * n);
  39.  
  40.             ra0 = _mm256_mul_pd(ra0, rx0);
  41.             ra1 = _mm256_mul_pd(ra1, rx0);
  42.             ra2 = _mm256_mul_pd(ra2, rx0);
  43.             ra3 = _mm256_mul_pd(ra3, rx0);
  44.  
  45.             ry0 = _mm256_add_pd(ry0, ra0);
  46.             ry1 = _mm256_add_pd(ry1, ra1);
  47.             ry2 = _mm256_add_pd(ry2, ra2);
  48.             ry3 = _mm256_add_pd(ry3, ra3);
  49.  
  50.             //-------256----------------1
  51.             rx0 = _mm256_load_pd(ptr_x + 4);
  52.             ra0 = _mm256_load_pd(ptr_a + 4);
  53.             ra1 = _mm256_load_pd(ptr_a + n + 4);
  54.             ra2 = _mm256_load_pd(ptr_a + 2 * n + 4);
  55.             ra3 = _mm256_load_pd(ptr_a + 3 * n + 4);
  56.  
  57.             ra0 = _mm256_mul_pd(ra0, rx0);
  58.             ra1 = _mm256_mul_pd(ra1, rx0);
  59.             ra2 = _mm256_mul_pd(ra2, rx0);
  60.             ra3 = _mm256_mul_pd(ra3, rx0);
  61.  
  62.             ry0 = _mm256_add_pd(ry0, ra0);
  63.             ry1 = _mm256_add_pd(ry1, ra1);
  64.             ry2 = _mm256_add_pd(ry2, ra2);
  65.             ry3 = _mm256_add_pd(ry3, ra3);
  66.  
  67.             //-------256----------------2
  68.             rx0 = _mm256_load_pd(ptr_x + 8);
  69.             ra0 = _mm256_load_pd(ptr_a + 8);
  70.             ra1 = _mm256_load_pd(ptr_a + n + 8);
  71.             ra2 = _mm256_load_pd(ptr_a + 2 * n + 8);
  72.             ra3 = _mm256_load_pd(ptr_a + 3 * n + 8);
  73.  
  74.             ra0 = _mm256_mul_pd(ra0, rx0);
  75.             ra1 = _mm256_mul_pd(ra1, rx0);
  76.             ra2 = _mm256_mul_pd(ra2, rx0);
  77.             ra3 = _mm256_mul_pd(ra3, rx0);
  78.  
  79.             ry0 = _mm256_add_pd(ry0, ra0);
  80.             ry1 = _mm256_add_pd(ry1, ra1);
  81.             ry2 = _mm256_add_pd(ry2, ra2);
  82.             ry3 = _mm256_add_pd(ry3, ra3);
  83.  
  84.             //-------256----------------3
  85.             rx0 = _mm256_load_pd(ptr_x + 12);
  86.             ra0 = _mm256_load_pd(ptr_a + 12);
  87.             ra1 = _mm256_load_pd(ptr_a + n + 12);
  88.             ra2 = _mm256_load_pd(ptr_a + 2 * n + 12);
  89.             ra3 = _mm256_load_pd(ptr_a + 3 * n + 12);
  90.  
  91.             ra0 = _mm256_mul_pd(ra0, rx0);
  92.             ra1 = _mm256_mul_pd(ra1, rx0);
  93.             ra2 = _mm256_mul_pd(ra2, rx0);
  94.             ra3 = _mm256_mul_pd(ra3, rx0);
  95.  
  96.             ry0 = _mm256_add_pd(ry0, ra0);
  97.             ry1 = _mm256_add_pd(ry1, ra1);
  98.             ry2 = _mm256_add_pd(ry2, ra2);
  99.             ry3 = _mm256_add_pd(ry3, ra3);
  100.  
  101.             ptr_a += 16;
  102.             ptr_x += 16;
  103.         }
  104.  
  105.         ptr_a += 3 * n;
  106.  
  107.         _mm256_store_pd(buf0, ry0);
  108.         _mm256_store_pd(buf1, ry1);
  109.         _mm256_store_pd(buf2, ry2);
  110.         _mm256_store_pd(buf3, ry3);
  111.  
  112.         y[i] = buf0[0] + buf0[1] + buf0[2] + buf0[3];
  113.         y[i + 1] = buf1[0] + buf1[1] + buf1[2] + buf1[3];
  114.         y[i + 2] = buf2[0] + buf2[1] + buf2[2] + buf2[3];
  115.         y[i + 3] = buf3[0] + buf3[1] + buf3[2] + buf3[3];
  116.     }
  117. }
  118. #endif
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement