Advertisement
Guest User

StackOverflow:user3572032

a guest
Nov 6th, 2014
250
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 30.17 KB | None | 0 0
  1. // Standard outputs
  2. #include <iostream>
  3.  
  4. // get AVX intrinsics
  5. #include <immintrin.h>
  6.  
  7. // get CPUID capability
  8. //#include <intrin.h>
  9.  
  10. #ifdef __MACH__                             // OS X (32-bit and 64-bit)
  11. #include <mach/mach.h>
  12. #include <mach/mach_time.h>
  13. #include <unistd.h>
  14. #elif defined(WIN32)                        // WIN32
  15. extern __declspec(dllimport) BOOL __stdcall QueryPerformanceFrequency(LONGLONG * f);
  16. extern __declspec(dllimport) BOOL __stdcall QueryPerformanceCounter(LONGLONG * t);
  17. #else                                       // Linux et al
  18. #include <time.h>
  19. #endif
  20. #include <math.h>                           // floor()
  21.  
  22. typedef struct
  23. {
  24.     uint64_t timestamp[2];
  25. } TimeStamp_t;
  26.  
  27. typedef struct
  28. {
  29.     uint32_t secs;
  30.     uint32_t ns;
  31. } TimeStampDiff_t;
  32.  
  33. static void TimeStamp(TimeStamp_t *timeStamp)
  34. {
  35. #ifdef __MACH__
  36.     *(uint64_t *)timeStamp = mach_absolute_time();
  37. #elif defined(WIN32)
  38.     (void)QueryPerformanceCounter((LONGLONG *)timeStamp);
  39. #else
  40.     clock_gettime(CLOCK_PROCESS_CPUTIME_ID, (struct timespec *)timeStamp);
  41. #endif
  42. }
  43.  
  44. static void TimeStampDiff(
  45.     const TimeStamp_t *timeStamp1,
  46.     const TimeStamp_t *timeStamp2,
  47.     TimeStampDiff_t *timeStampDiff)
  48. {
  49.     double ns, s;
  50.  
  51. #ifdef __MACH__
  52.     static mach_timebase_info_data_t sTimebaseInfo;
  53.     uint64_t t1, t2;
  54.  
  55.     if (sTimebaseInfo.denom == 0)
  56.     {
  57.         (void)mach_timebase_info(&sTimebaseInfo);
  58.     }
  59.     t1 = *(uint64_t *)timeStamp1;
  60.     t2 = *(uint64_t *)timeStamp2;
  61.     ns = (double)(t2 - t1) * (double)sTimebaseInfo.numer / (double)sTimebaseInfo.denom;
  62.                                             // get difference in ns
  63. #elif defined(WIN32)
  64.     static int64_t sTimeBaseFreq = 0;
  65.     int64_t t1, t2;
  66.  
  67.     if (sTimeBaseFreq == 0)
  68.     {
  69.         (void)QueryPerformanceFrequency((LONGLONG *)&sTimeBaseFreq);
  70.     }
  71.     t1 = *(int64_t *)timeStamp1;
  72.     t2 = *(int64_t *)timeStamp2;
  73.     ns = (double)(t2 - t1) / (double)sTimeBaseFreq * 1.0e9;
  74. #else
  75.     struct timespec t1 = *(struct timespec *)timeStamp1;
  76.     struct timespec t2 = *(struct timespec *)timeStamp2;
  77.     ns = (double)(t2.tv_sec - t1.tv_sec) * 1.0e9 + (double)(t2.tv_nsec - t1.tv_nsec);
  78.                                             // get difference in ns
  79. #endif
  80.     s = ns * 1.0e-9;                        // get difference in s
  81.     timeStampDiff->secs = (uint32_t)floor(s);   // get integer seconds
  82.     timeStampDiff->ns = (uint32_t)((s - (double)timeStampDiff->secs) * 1.0e9);
  83.                                             // get integer ns
  84. }
  85.  
  86. using namespace std;
  87.  
  88. #define ONE 0x80000000
  89. #define ZERO 0x00000009
  90.  
  91. static const __m256d SIGNMASK = _mm256_castsi256_pd(_mm256_set_epi32(ONE,ZERO,ONE,ZERO,ONE,ZERO,ONE,ZERO));
  92.  
  93. const int dim = 4;
  94.  
  95. typedef void (*Test_Function)(
  96.     double lam_11[4][4],
  97.     double lam_12[4][4],
  98.     double lam_13[4][4],
  99.     double lam_22[4][4],
  100.     double lam_23[4][4],
  101.     double lam_33[4][4],
  102.     const double rjk[4],
  103.     const double a[4],
  104.     const double b[4],
  105.     const double sqrt_gamma,
  106.     const double SPab,
  107.     const double d1_phi[16],
  108.     double d2_phi[192]);
  109.  
  110. static void foo_ref(
  111.     double lam_11[4][4],
  112.     double lam_12[4][4],
  113.     double lam_13[4][4],
  114.     double lam_22[4][4],
  115.     double lam_23[4][4],
  116.     double lam_33[4][4],
  117.     const double rjk[4],
  118.     const double a[4],
  119.     const double b[4],
  120.     const double sqrt_gamma,
  121.     const double SPab,
  122.     const double d1_phi[16],
  123.     double d2_phi[192])
  124. {
  125.     const double re_sqrt_gamma = 1.0 / sqrt_gamma;
  126.  
  127.     memset(d2_phi, 0.0, 192*sizeof(double));
  128.  
  129.     for (int alpha=0; alpha < 4; alpha++) {
  130.         for (int k=0; k < 3; k++) {
  131.                 for (int beta=0; beta < 4; beta++) {
  132.                     for (int l=0; l < 4 ; l++) {
  133.                         d2_phi[(alpha*3+k)*16 + beta*dim+l] =
  134.                          SPab * d1_phi[alpha*dim+k] * d1_phi[beta*dim+l]
  135.                                 +   a[k] * (  lam_11[alpha][ beta]*  a[l]
  136.                                                 + lam_12[alpha][ beta]*  b[l]
  137.                                                 + lam_13[alpha][ beta]*rjk[l]  );
  138.  
  139.                     }
  140.                 }
  141.             }
  142.     }
  143.  
  144.     for (int alpha=0; alpha < 4; alpha++) {
  145.         for (int k=0; k < 3; k++) {
  146.             for (int beta=0; beta < 4; beta++) {
  147.                 for (int l=0; l < 4 ; l++) {
  148.                     d2_phi[(alpha*3+k)*16 + beta*4+l] = - (     d2_phi[(alpha*3+k)*16 + beta*dim+l]
  149.  
  150.                             +   b[k] * (  lam_12[ beta][alpha] *  a[l]
  151.                                         + lam_22[alpha][ beta] *  b[l]
  152.                                         + lam_23[alpha][ beta] * rjk[l]  )
  153.  
  154.                             + rjk[k] * (  lam_13[ beta][alpha] *  a[l]
  155.                                         + lam_23[ beta][alpha] *  b[l]
  156.                                         + lam_33[alpha][ beta] *rjk[l]  )
  157.                             ) * re_sqrt_gamma;
  158.                 }
  159.             }
  160.         }
  161.     }
  162. }
  163.  
  164. static void foo_ref_1(
  165.     double lam_11[4][4],
  166.     double lam_12[4][4],
  167.     double lam_13[4][4],
  168.     double lam_22[4][4],
  169.     double lam_23[4][4],
  170.     double lam_33[4][4],
  171.     const double rjk[4],
  172.     const double a[4],
  173.     const double b[4],
  174.     const double sqrt_gamma,
  175.     const double SPab,
  176.     const double d1_phi[16],
  177.     double d2_phi[192])
  178. {
  179.     const double re_sqrt_gamma = 1.0 / sqrt_gamma;
  180.  
  181.     memset(d2_phi, 0.0, 192*sizeof(double));
  182.  
  183.     for (int alpha=0; alpha < 4; alpha++) {
  184.         for (int beta=0; beta < 4; beta++) {
  185.             for (int k=0; k < 3; k++) {
  186.                     for (int l=0; l < 4 ; l++) {
  187.                         d2_phi[(alpha*3+k)*16 + beta*dim+l] =
  188.                          SPab * d1_phi[alpha*dim+k] * d1_phi[beta*dim+l]
  189.                                 +   a[k] * (  lam_11[alpha][ beta]*  a[l]
  190.                                                 + lam_12[alpha][ beta]*  b[l]
  191.                                                 + lam_13[alpha][ beta]*rjk[l]  );
  192.  
  193.                     }
  194.                 }
  195.             }
  196.     }
  197.  
  198.     for (int alpha=0; alpha < 4; alpha++) {
  199.         for (int beta=0; beta < 4; beta++) {
  200.             for (int k=0; k < 3; k++) {
  201.                 for (int l=0; l < 4 ; l++) {
  202.                     d2_phi[(alpha*3+k)*16 + beta*4+l] = - (     d2_phi[(alpha*3+k)*16 + beta*dim+l]
  203.  
  204.                             +   b[k] * (  lam_12[ beta][alpha] *  a[l]
  205.                                         + lam_22[alpha][ beta] *  b[l]
  206.                                         + lam_23[alpha][ beta] * rjk[l]  )
  207.  
  208.                             + rjk[k] * (  lam_13[ beta][alpha] *  a[l]
  209.                                         + lam_23[ beta][alpha] *  b[l]
  210.                                         + lam_33[alpha][ beta] *rjk[l]  )
  211.                             ) * re_sqrt_gamma;
  212.                 }
  213.             }
  214.         }
  215.     }
  216. }
  217.  
  218. static void foo_test(
  219.     double lam_11[4][4],
  220.     double lam_12[4][4],
  221.     double lam_13[4][4],
  222.     double lam_22[4][4],
  223.     double lam_23[4][4],
  224.     double lam_33[4][4],
  225.     const double rjk[4],
  226.     const double a[4],
  227.     const double b[4],
  228.     const double sqrt_gamma,
  229.     const double SPab,
  230.     const double d1_phi[16],
  231.     double d2_phi[192])
  232. {
  233.     const double re_sqrt_gamma = 1.0 / sqrt_gamma;
  234.  
  235.     const double * const lam_11_p = &lam_11[0][0];
  236.     const double * const lam_12_p = &lam_12[0][0];
  237.     const double * const lam_13_p = &lam_13[0][0];
  238.     const double * const lam_22_p = &lam_22[0][0];
  239.     const double * const lam_23_p = &lam_23[0][0];
  240.     const double * const lam_33_p = &lam_33[0][0];
  241.  
  242.     memset(d2_phi, 0.0, 192*sizeof(double));
  243.  
  244.     double* addr = d2_phi;
  245.     __m256d ymm0; __m256d ymm1; __m256d ymm2; __m256d ymm3;
  246.     __m256d ymm4; __m256d ymm5; __m256d ymm6; __m256d ymm7;
  247.     __m256d ymm8; __m256d ymm9; __m256d ymm10; __m256d ymm11;
  248.     __m256d ymm12; __m256d ymm13; __m256d ymm14; __m256d ymm15;
  249.  
  250.     // load SPab, because it is constant
  251.     ymm0 = _mm256_broadcast_sd(&SPab);
  252.  
  253.     for (int alpha=0; alpha < 4; alpha++) {
  254.         for (int k=0; k < 3; k++) {
  255.  
  256.             ymm1 = _mm256_broadcast_sd(d1_phi + alpha*4 + k); // load d1_phi[alpha*dim+k] to all
  257.             ymm2 = _mm256_broadcast_sd(a + k); // load a[k] to all
  258.             // Precalculate a part here, because it is reusable
  259.             ymm10 = _mm256_mul_pd(ymm0, ymm1); // SPab * d1_phi[alpha*dim+k] = PRE
  260.  
  261.             for (int beta=0; beta < 4; beta++) {
  262.                 // Load the three lambdas to all
  263.                 ymm3 = _mm256_broadcast_sd(lam_11_p + alpha*4 + beta);
  264.                 ymm4 = _mm256_broadcast_sd(lam_12_p + alpha*4 + beta);
  265.                 ymm5 = _mm256_broadcast_sd(lam_13_p + alpha*4 + beta);
  266.  
  267.                 ymm6 = _mm256_load_pd(a); // load the whole 4-vector 'a' into register
  268.                 ymm7 = _mm256_load_pd(b); // load the whole 4-vector 'b' into register
  269.                 ymm8 = _mm256_load_pd(rjk); // load the whole 4-vector 'b' into register
  270.                 ymm9 = _mm256_load_pd(d1_phi + beta*4);
  271.  
  272.                 ymm11 = _mm256_mul_pd(ymm10, ymm9); // PRE * d1_phi[beta*dim+l] = SUM1
  273.  
  274.                 // Do the three Multiplications
  275.                 ymm12 = _mm256_mul_pd(ymm3,ymm6); // lam_11[alpha][ beta] *  a[l] = PROD1
  276.                 ymm13 = _mm256_mul_pd(ymm4,ymm7); // lam_12[alpha][ beta] *  b[l] = PROD2
  277.                 ymm14 = _mm256_mul_pd(ymm5,ymm8); // lam_13[alpha][ beta] * rjk[l] = PROD3
  278.  
  279.                 ymm12 = _mm256_add_pd(ymm12, ymm13); // PROD1 + PROD2 = PROD12
  280.                 ymm12 = _mm256_add_pd(ymm12, ymm14); // PROD12 + PROD3 = PROD123
  281.  
  282.                 ymm12 = _mm256_mul_pd(ymm12, ymm2); // a[k] * PROD123 = SUM2
  283.                 ymm12 = _mm256_add_pd(ymm11, ymm12); // SUM1 + SUM2
  284.  
  285.                 _mm256_stream_pd(addr, ymm12);
  286.                 addr+=4;
  287.             }
  288.         }
  289.     }
  290.  
  291.     addr = d2_phi;
  292.     // load sqrt_gamma, because it is constant
  293.     ymm7 = _mm256_broadcast_sd(&re_sqrt_gamma);
  294.  
  295.     for (int alpha=0; alpha < 4; alpha++) {
  296.         for (int k=0; k < 3; k++) {
  297.             // Load values that are only dependent on k
  298.             ymm9 = _mm256_broadcast_sd(b+k); // all b[k]
  299.             ymm8 = _mm256_broadcast_sd(rjk+k); // all rjk[k]
  300.  
  301.             for (int beta=0; beta < 4; beta++) {
  302.                 // Load the lambdas, because they will stay the same for nine iterations
  303.                 ymm15 = _mm256_broadcast_sd(lam_12_p + 4*beta + alpha);   // all lam_12[ beta][alpha]
  304.                 ymm14 = _mm256_broadcast_sd(lam_22_p + 4*alpha + beta);   // all lam_22[alpha][ beta]
  305.                 ymm13 = _mm256_broadcast_sd(lam_23_p + 4*alpha + beta);   // all lam_23[alpha][ beta]
  306.                 ymm12 = _mm256_broadcast_sd(lam_13_p + 4*beta + alpha);   // all lam_13[ beta][alpha]
  307.                 ymm11 = _mm256_broadcast_sd(lam_23_p + 4*beta + alpha); // all lam_23[ beta][alpha]
  308.                 ymm10 = _mm256_broadcast_sd(lam_33_p + 4*alpha + beta); // lam_33[alpha][ beta]
  309.  
  310.                 // Load the values that depend on the innermost loop, which is removed do to AVX
  311.                 ymm6 =_mm256_load_pd(a); // a[i] until a[l+3]
  312.                 ymm5 =_mm256_load_pd(b); // b[i] until b[l+3]
  313.                 ymm4 =_mm256_load_pd(rjk); // rjk[i] until rjk[l+3]
  314.                 //__m256d ymm3 =_mm256_load_pd(d2_phi +  (alpha*3+k)*16  + beta*dim); // d2_phi[(alpha*3+k)*12 + beta*dim] until d2_phi[(alpha*3+k)*12 + beta*dim +3]
  315.                 ymm3 =_mm256_load_pd(addr);
  316.                 // Block that is later on multiplied with b[k]
  317.                 ymm2 = _mm256_mul_pd(ymm15 , ymm6); // lam_12[ beta][alpha] *  a[l]
  318.                 ymm1 = _mm256_mul_pd(ymm14 , ymm5); // lam_22[alpha][ beta] * b[l];
  319.  
  320.                 ymm0 = _mm256_add_pd(ymm2, ymm1); // lam_12[ beta][alpha]* a[l] + lam_22[alpha][ beta]*b[l];
  321.  
  322.                 ymm2 = _mm256_mul_pd(ymm13 , ymm4); // lam_23[alpha][ beta] * rjk[l]
  323.                 ymm0 = _mm256_add_pd(ymm2, ymm0); // lam_12[ beta][alpha]* a[l] + lam_22[alpha][ beta]*b[l] + lam_23[alpha][ beta] * b[i];
  324.  
  325.                 ymm0 = _mm256_mul_pd(ymm9 , ymm0); // b[k] * (first sum of three)
  326.  
  327.                 // Block that is later on multiplied with rjk[k]
  328.                 ymm2 = _mm256_mul_pd(ymm12 , ymm6); // lam_13[ beta][alpha] *  a[l]
  329.                 ymm1 = _mm256_mul_pd(ymm11 , ymm5); // lam_23[ beta][alpha] *  b[l]
  330.  
  331.                 ymm2 = _mm256_add_pd(ymm2, ymm1); // lam_13[ beta][alpha] *  a[l] + lam_22[alpha][ beta]*b[l];
  332.  
  333.                 ymm1 = _mm256_mul_pd(ymm10 , ymm4); // lam_33[alpha][ beta] * rjk[l]
  334.                 ymm2 = _mm256_add_pd(ymm2 , ymm1); // lam_13[ beta][alpha] *  a[l] + lam_22[alpha][ beta]*b[l] + lam_33[alpha][ beta] *rjk[l]
  335.  
  336.                 ymm2 = _mm256_mul_pd(ymm2 , ymm8); // rjk[k] * (second sum of three)
  337.                 ymm0 = _mm256_add_pd(ymm0, ymm2); // add to temporal result in ymm0
  338.                 ymm0 = _mm256_add_pd(ymm3, ymm0); // Old value of d2 Phi;
  339.  
  340.                 ymm0 = _mm256_mul_pd(ymm0, ymm7); // all divided by sqrt_gamma
  341.                 ymm0 = _mm256_xor_pd(ymm0, SIGNMASK);
  342.  
  343.                 _mm256_store_pd(addr, ymm0);
  344.                 //_mm256_stream_pd(d2_phi + (alpha*3+k)*16  + beta*dim, ymm0);
  345.                 addr += 4;
  346.             }
  347.         }
  348.     }
  349. }
  350.  
  351. static void foo_test_1(
  352.     double lam_11[4][4],
  353.     double lam_12[4][4],
  354.     double lam_13[4][4],
  355.     double lam_22[4][4],
  356.     double lam_23[4][4],
  357.     double lam_33[4][4],
  358.     const double rjk[4],
  359.     const double a[4],
  360.     const double b[4],
  361.     const double sqrt_gamma,
  362.     const double SPab,
  363.     const double d1_phi[16],
  364.     double d2_phi[192])
  365. {
  366.     const double re_sqrt_gamma = 1.0 / sqrt_gamma;
  367.  
  368.     memset(d2_phi, 0.0, 192*sizeof(double));
  369.  
  370.     const __m256d ymm6 = _mm256_load_pd(a); // load the whole 4-vector 'a' into register
  371.  
  372.     {
  373.         // load SPab, because it is constant
  374.         const __m256d ymm0 = _mm256_broadcast_sd(&SPab);
  375.         const __m256d ymm7 = _mm256_load_pd(b); // load the whole 4-vector 'b' into register
  376.         const __m256d ymm8 = _mm256_load_pd(rjk); // load the whole 4-vector 'rjk' into register
  377.  
  378.         double* addr = d2_phi;
  379.  
  380.         for (int alpha=0; alpha < 4; alpha++)
  381.         {
  382.             for (int k=0; k < 3; k++)
  383.             {
  384.                 const __m256d ymm1 = _mm256_broadcast_sd(&d1_phi[alpha*dim + k]); // load d1_phi[alpha*dim+k] to all
  385.                 const __m256d ymm2 = _mm256_broadcast_sd(&a[k]); // load a[k] to all
  386.                 const __m256d ymm10 = _mm256_mul_pd(ymm0, ymm1); // SPab * d1_phi[alpha*dim+k] = PRE
  387.  
  388.                 for (int beta=0; beta < 4; beta++) {
  389.                     // Load the three lambdas to all
  390.                     const __m256d ymm3 = _mm256_broadcast_sd(&lam_11[alpha][beta]);
  391.                     const __m256d ymm4 = _mm256_broadcast_sd(&lam_12[alpha][beta]);
  392.                     const __m256d ymm5 = _mm256_broadcast_sd(&lam_13[alpha][beta]);
  393.  
  394.                     const __m256d ymm9 = _mm256_load_pd(d1_phi + beta*4);
  395.  
  396.                     const __m256d ymm11 = _mm256_mul_pd(ymm10, ymm9); // PRE * d1_phi[beta*dim+l] = SUM1
  397.  
  398.                     // Do the three Multiplications
  399.                     __m256d ymm12 = _mm256_mul_pd(ymm3,ymm6); // lam_11[alpha][ beta] *  a[l] = PROD1
  400.                     const __m256d ymm13 = _mm256_mul_pd(ymm4,ymm7); // lam_12[alpha][ beta] *  b[l] = PROD2
  401.                     const __m256d ymm14 = _mm256_mul_pd(ymm5,ymm8); // lam_13[alpha][ beta] * rjk[l] = PROD3
  402.  
  403.                     __m256d ymm15 = _mm256_add_pd(ymm12, ymm13); // PROD1 + PROD2 = PROD12
  404.                     ymm12 = _mm256_add_pd(ymm15, ymm14); // PROD12 + PROD3 = PROD123
  405.  
  406.                     ymm15 = _mm256_mul_pd(ymm12, ymm2); // a[k] * PROD123 = SUM2
  407.                     ymm12 = _mm256_add_pd(ymm11, ymm15); // SUM1 + SUM2
  408.  
  409.                     _mm256_stream_pd(addr, ymm12);
  410.                     addr+=4;
  411.                 }
  412.             }
  413.         }
  414.     }
  415.  
  416.     {
  417.         const __m256d ymm4 =_mm256_load_pd(rjk); // rjk[i] until rjk[l+3]
  418.         const __m256d ymm5 =_mm256_load_pd(b); // b[l] until b[l+3]
  419.  
  420.         // load sqrt_gamma, because it is constant
  421.         const __m256d ymm7 = _mm256_broadcast_sd(&re_sqrt_gamma);
  422.  
  423.         double* addr = d2_phi;
  424.  
  425.         for (int alpha=0; alpha < 4; alpha++)
  426.         {
  427.             for (int k=0; k < 3; k++)
  428.             {
  429.                 // Load values that are only dependent on k
  430.                 const __m256d ymm9 = _mm256_broadcast_sd(b+k); // all b[k]
  431.                 const __m256d ymm8 = _mm256_broadcast_sd(rjk+k); // all rjk[k]
  432.  
  433.                 for (int beta=0; beta < 4; beta++)
  434.                 {
  435.                     __m256d ymm0, ymm1, ymm2;
  436.  
  437.                     // Load the lambdas, because they will stay the same for nine iterations
  438.                     const __m256d ymm15 = _mm256_broadcast_sd(&lam_12[beta][alpha]);   // all lam_12[ beta][alpha]
  439.                     const __m256d ymm14 = _mm256_broadcast_sd(&lam_22[alpha][beta]);   // all lam_22[alpha][ beta]
  440.                     const __m256d ymm13 = _mm256_broadcast_sd(&lam_23[alpha][beta]);   // all lam_23[alpha][ beta]
  441.                     const __m256d ymm12 = _mm256_broadcast_sd(&lam_13[beta][alpha]);   // all lam_13[ beta][alpha]
  442.                     const __m256d ymm11 = _mm256_broadcast_sd(&lam_23[beta][alpha]); // all lam_23[ beta][alpha]
  443.                     const __m256d ymm10 = _mm256_broadcast_sd(&lam_33[alpha][beta]); // lam_33[alpha][ beta]
  444.  
  445.                     // Load the values that depend on the innermost loop, which is removed do to AVX
  446.  
  447.                     const __m256d ymm3 =_mm256_load_pd(addr);
  448.                     // Block that is later on multiplied with b[k]
  449.                     ymm2 = _mm256_mul_pd(ymm15 , ymm6); // lam_12[ beta][alpha] *  a[l]
  450.                     ymm1 = _mm256_mul_pd(ymm14 , ymm5); // lam_22[alpha][ beta] * b[l];
  451.  
  452.                     ymm0 = _mm256_add_pd(ymm2, ymm1); // lam_12[ beta][alpha]* a[l] + lam_22[alpha][ beta]*b[l];
  453.  
  454.                     ymm2 = _mm256_mul_pd(ymm13 , ymm4); // lam_23[alpha][ beta] * rjk[l]
  455.                     ymm0 = _mm256_add_pd(ymm2, ymm0); // lam_12[ beta][alpha]* a[l] + lam_22[alpha][ beta]*b[l] + lam_23[alpha][ beta] * b[i];
  456.  
  457.                     ymm0 = _mm256_mul_pd(ymm9 , ymm0); // b[k] * (first sum of three)
  458.  
  459.                     // Block that is later on multiplied with rjk[k]
  460.                     ymm2 = _mm256_mul_pd(ymm12 , ymm6); // lam_13[ beta][alpha] *  a[l]
  461.                     ymm1 = _mm256_mul_pd(ymm11 , ymm5); // lam_23[ beta][alpha] *  b[l]
  462.  
  463.                     ymm2 = _mm256_add_pd(ymm2, ymm1); // lam_13[ beta][alpha] *  a[l] + lam_22[alpha][ beta]*b[l];
  464.  
  465.                     ymm1 = _mm256_mul_pd(ymm10 , ymm4); // lam_33[alpha][ beta] * rjk[l]
  466.                     ymm2 = _mm256_add_pd(ymm2 , ymm1); // lam_13[ beta][alpha] *  a[l] + lam_22[alpha][ beta]*b[l] + lam_33[alpha][ beta] *rjk[l]
  467.  
  468.                     ymm1 = _mm256_mul_pd(ymm2 , ymm8); // rjk[k] * (second sum of three)
  469.                     ymm2 = _mm256_add_pd(ymm0, ymm1); // add to temporal result in ymm0
  470.                     ymm1 = _mm256_add_pd(ymm3, ymm2); // Old value of d2 Phi;
  471.  
  472.                     ymm2 = _mm256_mul_pd(ymm1, ymm7); // all divided by sqrt_gamma
  473.                     ymm1 = _mm256_xor_pd(ymm2, SIGNMASK);
  474.  
  475.                     _mm256_store_pd(addr, ymm1);
  476.                     //_mm256_stream_pd(d2_phi + (alpha*3+k)*16  + beta*dim, ymm0);
  477.                     addr += 4;
  478.                 }
  479.             }
  480.         }
  481.     }
  482. }
  483.  
  484. static void foo_test_2(
  485.     double lam_11[4][4],
  486.     double lam_12[4][4],
  487.     double lam_13[4][4],
  488.     double lam_22[4][4],
  489.     double lam_23[4][4],
  490.     double lam_33[4][4],
  491.     const double rjk[4],
  492.     const double a[4],
  493.     const double b[4],
  494.     const double sqrt_gamma,
  495.     const double SPab,
  496.     const double d1_phi[16],
  497.     double d2_phi[192])
  498. {
  499.     const double re_sqrt_gamma = 1.0 / sqrt_gamma;
  500.  
  501.     memset(d2_phi, 0.0, 192*sizeof(double));
  502.  
  503.     const __m256d ymm6 = _mm256_load_pd(a); // load the whole 4-vector 'a' into register
  504.  
  505.     {
  506.         // load SPab, because it is constant
  507.         const __m256d ymm0 = _mm256_broadcast_sd(&SPab);
  508.         const __m256d ymm7 = _mm256_load_pd(b); // load the whole 4-vector 'b' into register
  509.         const __m256d ymm8 = _mm256_load_pd(rjk); // load the whole 4-vector 'rjk' into register
  510.  
  511.         for (int alpha=0; alpha < 4; alpha++)
  512.         {
  513.             for (int beta=0; beta < 4; beta++)
  514.             {
  515.                 // Load the three lambdas to all
  516.                 const __m256d ymm3 = _mm256_broadcast_sd(&lam_11[alpha][beta]);
  517.                 const __m256d ymm4 = _mm256_broadcast_sd(&lam_12[alpha][beta]);
  518.                 const __m256d ymm5 = _mm256_broadcast_sd(&lam_13[alpha][beta]);
  519.  
  520.                 const __m256d ymm9 = _mm256_load_pd(d1_phi + beta*4);
  521.  
  522.                 // Do the three Multiplications
  523.                 const __m256d ymm13 = _mm256_mul_pd(ymm4,ymm7); // lam_12[alpha][ beta] *  b[l] = PROD2
  524.                 const __m256d ymm14 = _mm256_mul_pd(ymm5,ymm8); // lam_13[alpha][ beta] * rjk[l] = PROD3
  525.                 const __m256d ymm15 = _mm256_mul_pd(ymm3,ymm6); // lam_11[alpha][ beta] *  a[l] = PROD1
  526.                 __m256d ymm12 = _mm256_add_pd(ymm15, ymm13); // PROD1 + PROD2 = PROD12
  527.                 ymm12 = _mm256_add_pd(ymm12, ymm14); // PROD12 + PROD3 = PROD123
  528.  
  529.                 double* addr = d2_phi + alpha*3*16  + beta*dim;
  530.  
  531.                 for (int k=0; k < 3; k++)
  532.                 {
  533.                     const __m256d ymm1 = _mm256_broadcast_sd(&d1_phi[alpha*dim + k]); // load d1_phi[alpha*dim+k] to all
  534.                     const __m256d ymm2 = _mm256_broadcast_sd(&a[k]); // load a[k] to all
  535.                     const __m256d ymm10 = _mm256_mul_pd(ymm0, ymm1); // SPab * d1_phi[alpha*dim+k] = PRE
  536.                     const __m256d ymm11 = _mm256_mul_pd(ymm10, ymm9); // PRE * d1_phi[beta*dim+l] = SUM1
  537.  
  538.                     __m256d ymm12t = _mm256_mul_pd(ymm12, ymm2); // a[k] * PROD123 = SUM2
  539.                     ymm12t = _mm256_add_pd(ymm11, ymm12t); // SUM1 + SUM2
  540.  
  541.                     _mm256_store_pd(addr, ymm12t);
  542.  
  543.                     addr+=16;
  544.                 }
  545.             }
  546.         }
  547.     }
  548.  
  549.     {
  550.         const __m256d ymm4 =_mm256_load_pd(rjk); // rjk[i] until rjk[l+3]
  551.         const __m256d ymm5 =_mm256_load_pd(b); // b[l] until b[l+3]
  552.  
  553.         // load sqrt_gamma, because it is constant
  554.         const __m256d ymm7 = _mm256_broadcast_sd(&re_sqrt_gamma);
  555.  
  556.         for (int alpha=0; alpha < 4; alpha++)
  557.         {
  558.             for (int beta=0; beta < 4; beta++)
  559.             {
  560.                 // Load the lambdas, because they will stay the same for nine iterations
  561.                 const __m256d ymm15 = _mm256_broadcast_sd(&lam_12[beta][alpha]);   // all lam_12[ beta][alpha]
  562.                 const __m256d ymm14 = _mm256_broadcast_sd(&lam_22[alpha][beta]);   // all lam_22[alpha][ beta]
  563.                 const __m256d ymm13 = _mm256_broadcast_sd(&lam_23[alpha][beta]);   // all lam_23[alpha][ beta]
  564.                 const __m256d ymm12 = _mm256_broadcast_sd(&lam_13[beta][alpha]);   // all lam_13[ beta][alpha]
  565.                 const __m256d ymm11 = _mm256_broadcast_sd(&lam_23[beta][alpha]); // all lam_23[ beta][alpha]
  566.                 const __m256d ymm10 = _mm256_broadcast_sd(&lam_33[alpha][beta]); // lam_33[alpha][ beta]
  567.  
  568.                 __m256d ymm0, ymm1, ymm2;
  569.  
  570.                 // Block that is later on multiplied with b[k]
  571.                 ymm2 = _mm256_mul_pd(ymm15 , ymm6); // lam_12[ beta][alpha] *  a[l]
  572.                 ymm1 = _mm256_mul_pd(ymm14 , ymm5); // lam_22[alpha][ beta] * b[l];
  573.                 ymm0 = _mm256_add_pd(ymm2, ymm1);   // lam_12[ beta][alpha]* a[l] + lam_22[alpha][ beta]*b[l];
  574.                 ymm2 = _mm256_mul_pd(ymm13 , ymm4); // lam_23[alpha][ beta] * rjk[l]
  575.                 ymm0 = _mm256_add_pd(ymm2, ymm0);   // lam_12[ beta][alpha]* a[l] + lam_22[alpha][ beta]*b[l] + lam_23[alpha][ beta] * b[i];
  576.  
  577.                 // Block that is later on multiplied with rjk[k]
  578.                 ymm2 = _mm256_mul_pd(ymm12 , ymm6); // lam_13[ beta][alpha] *  a[l]
  579.                 ymm1 = _mm256_mul_pd(ymm11 , ymm5); // lam_23[ beta][alpha] *  b[l]
  580.                 ymm2 = _mm256_add_pd(ymm2, ymm1);   // lam_13[ beta][alpha] *  a[l] + lam_22[alpha][ beta]*b[l];
  581.                 ymm1 = _mm256_mul_pd(ymm10 , ymm4); // lam_33[alpha][ beta] * rjk[l]
  582.                 ymm2 = _mm256_add_pd(ymm2 , ymm1);  // lam_13[ beta][alpha] *  a[l] + lam_22[alpha][ beta]*b[l] + lam_33[alpha][ beta] *rjk[l]
  583.  
  584.                 double* addr = d2_phi + alpha*3*16  + beta*dim;
  585.  
  586.                 for (int k=0; k < 3; k++)
  587.                 {
  588.                     // Load values that are only dependent on k
  589.                     const __m256d ymm9 = _mm256_broadcast_sd(b+k); // all b[k]
  590.                     const __m256d ymm8 = _mm256_broadcast_sd(rjk+k); // all rjk[k]
  591.  
  592.                     // Load the values that depend on the innermost loop, which is removed do to AVX
  593.  
  594.                     const __m256d ymm3 =_mm256_load_pd(addr);
  595.  
  596.                     __m256d ymm0t, ymm1t, ymm2t;
  597.  
  598.                     // Block that is later on multiplied with b[k]
  599.  
  600.                     ymm0t = _mm256_mul_pd(ymm9 , ymm0); // b[k] * (first sum of three)
  601.  
  602.                     // Block that is later on multiplied with rjk[k]
  603.  
  604.                     ymm1t = _mm256_mul_pd(ymm2 , ymm8); // rjk[k] * (second sum of three)
  605.                     ymm2t = _mm256_add_pd(ymm0t, ymm1t); // add to temporal result in ymm0
  606.                     ymm1t = _mm256_add_pd(ymm3, ymm2t);  // Old value of d2 Phi;
  607.  
  608.                     ymm2t = _mm256_mul_pd(ymm1t, ymm7); // all divided by sqrt_gamma
  609.                     ymm1t = _mm256_xor_pd(ymm2t, SIGNMASK);
  610.  
  611.                     _mm256_store_pd(addr, ymm1t);
  612.  
  613.                     addr += 16;
  614.                 }
  615.             }
  616.         }
  617.     }
  618. }
  619.  
  620. static void test(const char *foo_name, Test_Function foo)
  621. {
  622.     TimeStamp_t t0, t1;
  623.     TimeStampDiff_t t_diff;
  624.     double dtime;
  625.     double tmp;
  626.  
  627.     __declspec(align(64)) double lam_11[4][4] = {
  628.         {217.91920572807675,-363.19867621346157,217.91920572807777,-72.639735242692979},
  629.         {-363.19867621346157,651.03497117894915,-396.79589782952706,108.95960286403944},
  630.         {217.91920572807777,-396.79589782952706,215.19655972279568,-36.319867621346461},
  631.         {-72.639735242692979,108.95960286403944,-36.319867621346461,0.00000000000000000}};
  632.  
  633.     __declspec(align(64)) double lam_12[4][4] = {{-72.639735242692979,145.27947048538556,-145.27947048538485,72.639735242692254},
  634.         {108.95960286403944,-287.83629496548929,324.15616258683468,-145.27947048538482},
  635.         {-36.319867621346461,178.87669210145020,-287.83629496548929,145.27947048538559},
  636.         {0.00000000000000000,-36.319867621346489,108.95960286403948,-72.639735242692979}};
  637.  
  638.     __declspec(align(64)) double lam_13[4][4] = {{0.00000000000000000,0.00000000000000000,0.00000000000000000,0.00000000000000000},
  639.         {-441.12078924256735,882.24157848513676,-882.24157848514119,441.12078924257179},
  640.         {441.12078924256735,-882.24157848513676,882.24157848514119,-441.12078924257179},
  641.         {0.00000000000000000,0.00000000000000000,0.00000000000000000,0.00000000000000000}};
  642.  
  643.     __declspec(align(64)) double lam_22[4][4] = {{0.00000000000000000,-36.319867621346489,108.95960286403948,-72.639735242692979},
  644.         {-36.319867621346489,215.19655972279583,-396.79589782952723,217.91920572807783},
  645.         {108.95960286403948,-396.79589782952723,651.03497117894938,-363.19867621346168},
  646.         {-72.639735242692979,217.91920572807783,-363.19867621346168,217.91920572807675}};
  647.  
  648.     __declspec(align(64)) double lam_23[4][4] = {{-0.00000000000000000,-0.00000000000000000,-0.00000000000000000,-0.00000000000000000},
  649.         {441.12078924257179,-882.24157848514119,882.24157848513676,-441.12078924256735},
  650.         {-441.12078924257179,882.24157848514119,-882.24157848513676,441.12078924256735},
  651.         {-0.00000000000000000,-0.00000000000000000,-0.00000000000000000,-0.00000000000000000}};
  652.  
  653.     __declspec(align(64)) double lam_33[4][4] = {{3759.6260671131017,-7519.2521342262207,7519.2521342262580,-3759.6260671131390},
  654.         {-7519.2521342262207,15038.504268452458,-15038.504268452496,7519.2521342262580},
  655.         {7519.2521342262589,-15038.504268452496,15038.504268452458,-7519.2521342262207},
  656.         {-3759.6260671131390,7519.2521342262580,-7519.2521342262207,3759.6260671131017}};
  657.  
  658.     __declspec(align(64)) double rjk[4] = {0.13900000000000001,0.00000000000000000,0.00000000000000000,0.00000000000000000};
  659.     __declspec(align(64)) double a  [4] = {0.00000000000000000,-0.92388179295480133,0.38267797512611235,0.00000000000000000};
  660.     __declspec(align(64)) double b  [4] = {0.00000000000000000,-0.92388179295480133,0.38267797512611235,0.00000000000000000};
  661.  
  662.     const double sqrt_gamma=1.4136482746161726e-007;
  663.     const double SPab=0.99999999999999001;
  664.  
  665.     __declspec(align(64)) const double d1_phi[16] = {-0.00000000000000000, 5.5917642862595271e-007, -2.2932516454884578e-007, 0.00000000000000000,
  666.         -0.00000000000000000, -5.5289354740543637e-007,2.2618372393858761e-007, 0.00000000000000000,
  667.         -0.00000000000000000,-5.5289354740543637e-007, 2.2618372393858761e-007, 0.00000000000000000,
  668.         -0.00000000000000000, 5.5917642862595271e-007, -2.2932516454884578e-007, 0.00000000000000000};
  669.  
  670.     __declspec(align(64)) double d2_phi[192] = { 0 };
  671.  
  672.     TimeStamp(&t0);
  673.  
  674.     for(int j=0; j<10; j++)
  675.     {
  676.         for(int i=0; i<100000; i++)
  677.         {
  678.             memset(d2_phi, 0.0, 192*sizeof(double));
  679.             foo(lam_11, lam_12, lam_13,
  680.                 lam_22, lam_23, lam_33,
  681.                 rjk, a, b,
  682.                 sqrt_gamma, SPab,
  683.                 d1_phi,
  684.                 d2_phi);
  685.         }
  686.     }
  687.  
  688.     TimeStamp(&t1);
  689.     TimeStampDiff(&t0, &t1, &t_diff);
  690.     dtime = (double)t_diff.secs + (double)t_diff.ns * 1.0e-9;
  691.     cout << foo_name << ": needed " << dtime * 1000.0 << " milliseconds to complete.\n";
  692.  
  693.     tmp = 0.0;
  694.     for(int i=0; i<192; i++) {
  695.         tmp += d2_phi[i];
  696.         tmp /= 2.0;
  697.     }
  698.  
  699.     cout << foo_name << ": Lala  " << tmp  << "\n";
  700. }
  701.  
  702. int main()
  703. {
  704.     test("foo_ref   ", &foo_ref);
  705.     test("foo_ref_1 ", &foo_ref_1);
  706.     test("foo_test  ", &foo_test);
  707.     test("foo_test_1", &foo_test_1);
  708.     test("foo_test_2", &foo_test_2);
  709.     return 0;
  710. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement