Advertisement
kubpica

Untitled

Mar 23rd, 2020
115
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 3.27 KB | None | 0 0
  1. memset((void*)y, 0, n * sizeof(double));
  2.     double* ptr_a = a;
  3.     double* ptr_x = NULL;
  4.     const int vectorSize = 4;
  5.     const int outerStep = 4;
  6.     const int innerStep = 16;
  7.     __m256d ra0, ra1, ra2, ra3;
  8.     __m256d rx0, rx1, rx2, rx3;
  9.     __m256d ry0, ry1, ry2, ry3;
  10.     __declspec(align(32)) double buf0[vectorSize];
  11.     __declspec(align(32)) double buf1[vectorSize];
  12.     __declspec(align(32)) double buf2[vectorSize];
  13.     __declspec(align(32)) double buf3[vectorSize];
  14.  
  15.     for (int i = 0; i < n; i += outerStep, ptr_a += (outerStep - 1) * n) {
  16.         ry0 = ry1 = ry2 = ry3 = _mm256_setzero_pd();
  17.         ptr_x = x;
  18.         for (int j = 0; j < n; j += innerStep, ptr_a += innerStep, ptr_x += innerStep) {
  19.             _mm_prefetch((char*)(ptr_x + innerStep), _MM_HINT_T0);
  20.             _mm_prefetch((char*)(ptr_x + innerStep + 8), _MM_HINT_T0);
  21.             _mm_prefetch((char*)(ptr_a + innerStep), _MM_HINT_NTA);
  22.             _mm_prefetch((char*)(ptr_a + innerStep + 8), _MM_HINT_NTA);
  23.             _mm_prefetch((char*)(ptr_a + n + innerStep), _MM_HINT_NTA);
  24.             _mm_prefetch((char*)(ptr_a + n + innerStep + 8), _MM_HINT_NTA);
  25.             _mm_prefetch((char*)(ptr_a + 2 * n + innerStep), _MM_HINT_NTA);
  26.             _mm_prefetch((char*)(ptr_a + 2 * n + innerStep + 8), _MM_HINT_NTA);
  27.             _mm_prefetch((char*)(ptr_a + 3 * n + innerStep), _MM_HINT_NTA);
  28.             _mm_prefetch((char*)(ptr_a + 3 * n + innerStep + 8), _MM_HINT_NTA);
  29.             rx0 = _mm256_load_pd(ptr_x);
  30.             rx1 = _mm256_load_pd(ptr_x + vectorSize);
  31.             rx2 = _mm256_load_pd(ptr_x + 2 * vectorSize);
  32.             rx3 = _mm256_load_pd(ptr_x + 3 * vectorSize);
  33.             ra0 = _mm256_load_pd(ptr_a);
  34.             ra1 = _mm256_load_pd(ptr_a + n);
  35.             ra2 = _mm256_load_pd(ptr_a + 2 * n);
  36.             ra3 = _mm256_load_pd(ptr_a + 3 * n);
  37.             ry0 = _mm256_fmadd_pd(ra0, rx0, ry0);
  38.             ry1 = _mm256_fmadd_pd(ra1, rx0, ry1);
  39.             ry2 = _mm256_fmadd_pd(ra2, rx0, ry2);
  40.             ry3 = _mm256_fmadd_pd(ra3, rx0, ry3);
  41.             ra0 = _mm256_load_pd(ptr_a + vectorSize);
  42.             ra1 = _mm256_load_pd(ptr_a + n + vectorSize);
  43.             ra2 = _mm256_load_pd(ptr_a + 2 * n + vectorSize);
  44.             ra3 = _mm256_load_pd(ptr_a + 3 * n + vectorSize);
  45.             ry0 = _mm256_fmadd_pd(ra0, rx1, ry0);
  46.             ry1 = _mm256_fmadd_pd(ra1, rx1, ry1);
  47.             ry2 = _mm256_fmadd_pd(ra2, rx1, ry2);
  48.             ry3 = _mm256_fmadd_pd(ra3, rx1, ry3);
  49.             ra0 = _mm256_load_pd(ptr_a + 2 * vectorSize);
  50.             ra1 = _mm256_load_pd(ptr_a + n + 2 * vectorSize);
  51.             ra2 = _mm256_load_pd(ptr_a + 2 * n + 2 * vectorSize);
  52.             ra3 = _mm256_load_pd(ptr_a + 3 * n + 2 * vectorSize);
  53.             ry0 = _mm256_fmadd_pd(ra0, rx2, ry0);
  54.             ry1 = _mm256_fmadd_pd(ra1, rx2, ry1);
  55.             ry2 = _mm256_fmadd_pd(ra2, rx2, ry2);
  56.             ry3 = _mm256_fmadd_pd(ra3, rx2, ry3);
  57.             ra0 = _mm256_load_pd(ptr_a + 3 * vectorSize);
  58.             ra1 = _mm256_load_pd(ptr_a + n + 3 * vectorSize);
  59.             ra2 = _mm256_load_pd(ptr_a + 2 * n + 3 * vectorSize);
  60.             ra3 = _mm256_load_pd(ptr_a + 3 * n + 3 * vectorSize);
  61.             ry0 = _mm256_fmadd_pd(ra0, rx3, ry0);
  62.             ry1 = _mm256_fmadd_pd(ra1, rx3, ry1);
  63.             ry2 = _mm256_fmadd_pd(ra2, rx3, ry2);
  64.             ry3 = _mm256_fmadd_pd(ra3, rx3, ry3);
  65.         }
  66.         _mm256_store_pd(buf0, ry0);
  67.         _mm256_store_pd(buf1, ry1);
  68.         _mm256_store_pd(buf2, ry2);
  69.         _mm256_store_pd(buf3, ry3);
  70.         y[i] = buf0[0] + buf0[1] + buf0[2] + buf0[3];
  71.         y[i + 1] = buf1[0] + buf1[1] + buf1[2] + buf1[3];
  72.         y[i + 2] = buf2[0] + buf2[1] + buf2[2] + buf2[3];
  73.         y[i + 3] = buf3[0] + buf3[1] + buf3[2] + buf3[3];
  74.     }
  75.     ptr_a = ptr_x = NULL;
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement