4plus2equals42

Matrix multiplication with intrinsics

Oct 26th, 2016
91
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C 2.58 KB | None | 0 0
  1. #include <complex.h>
  2. #include <xmmintrin.h>
  3. #include <pmmintrin.h>
  4.  
  5.  
  /*
   * Lane-wise product of packed single-precision complex numbers.
   *
   * Each __m128 is treated as two complex values stored as
   * (re0, im0, re1, im1).  For every complex lane the result is v1 * v2:
   *   re = v1.re * v2.re - v1.im * v2.im
   *   im = v1.re * v2.im + v1.im * v2.re
   *
   * Uses SSE3 instructions (moveldup / movehdup / addsub).
   */
  __m128 complex_multiplication(__m128 v1, __m128 v2) {
    // Broadcast the real and the imaginary part of every v1 lane.
    const __m128 v1_real = _mm_moveldup_ps(v1);   // (re0, re0, re1, re1)
    const __m128 v1_imag = _mm_movehdup_ps(v1);   // (im0, im0, im1, im1)

    // Swap re/im inside every complex lane of v2: (im0, re0, im1, re1).
    const __m128 v2_swapped = _mm_shuffle_ps(v2, v2, _MM_SHUFFLE(2, 3, 0, 1));

    const __m128 direct_terms = _mm_mul_ps(v1_real, v2);         // (re*re, re*im, ...)
    const __m128 cross_terms = _mm_mul_ps(v1_imag, v2_swapped);  // (im*im, im*re, ...)

    // addsub subtracts in even slots and adds in odd slots, which yields
    // exactly the real / imaginary combination listed above.
    return _mm_addsub_ps(direct_terms, cross_terms);
  }
  14.  
  15. void chemm(complex float* A,
  16.         complex float* B,
  17.         complex float* C,
  18.         int m,
  19.         int n,
  20.         complex float alpha,
  21.         complex float beta){
  22.  
  23.   float* alpha_array = (float*)&alpha;
  24.   float* beta_array = (float*)&beta;
  25.   float* c_complex_array;
  26.   float* y_array_1;
  27.   float* y_array_2;
  28.  
  29.  
  30.   // Create two vectors with the repeated values of the 'alpha' and 'beta' constants
  31.   __m128 alpha_vector = _mm_set_ps(alpha_array[1], alpha_array[0], alpha_array[1], alpha_array[0]);
  32.   __m128 beta_vector = _mm_set_ps(beta_array[1], beta_array[0], beta_array[1], beta_array[0]);
  33.  
  34.   // Defining vectors used later
  35.   __m128 x_vector, y_vector, c_vector, c_double_upper_vector, c_lower_sum_vector;
  36.  
  37.   for (int x = 0; x < n; x++) {
  38.     for (int y = 0; y < m; y++) {
  39.       // Load one complex number twice from C
  40.       c_complex_array = (float*)&C[y*n + x];
  41.       c_vector =  _mm_castpd_ps(_mm_loaddup_pd((double*)c_complex_array));
  42.  
  43.       // Multiply with beta
  44.       c_vector = complex_multiplication(c_vector, beta_vector);
  45.  
  46.       // Work with two and two complex multiplications in parallel
  47.       for (int z = 0; z < m; z+=2) {
  48.         // Load two and two complex numbers from A and B
  49.         x_vector = _mm_loadu_ps((float*)&A[y*m + z]);
  50.         y_array_1 = (float*)&B[z*n + x];
  51.         y_array_2 = (float*)&B[(z+1)*n + x];
  52.         y_vector = _mm_set_ps(y_array_2[1], y_array_2[0], y_array_1[1], y_array_1[0]);
  53.  
  54.         // Get multiplication results
  55.         x_vector = complex_multiplication(x_vector, y_vector);
  56.  
  57.         // Also multiply with alpha
  58.         x_vector = complex_multiplication(x_vector, alpha_vector);
  59.  
  60.         // Add result to C's vector
  61.         c_vector = _mm_add_ps(c_vector, x_vector);
  62.       }
  63.  
  64.       // The following stores the answers from the above loop
  65.       //  and stores them in C
  66.  
  67.       // Copy the upper 64 bits of C's result vector to both upper and lower 64 bits
  68.       c_double_upper_vector = _mm_movehl_ps(c_vector, c_vector);
  69.       // Add vectors. This makes the lower 64 the correct sum
  70.       c_lower_sum_vector = _mm_add_ps(c_vector, c_double_upper_vector);
  71.       // Store the lower 64 bits in memory (C matrix)
  72.       _mm_store_sd((double*)c_complex_array, _mm_castps_pd(c_lower_sum_vector));
  73.     }
  74.   }
  75. }
Advertisement
Add Comment
Please, Sign In to add comment