SHARE
TWEET

Vectorizing linear search

stgatilov Jul 19th, 2015 (edited) 261 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  #include <assert.h>
  #include <limits.h>     //INT_MAX
  #include <stdint.h>
  #include <stdio.h>
  #include <stdlib.h>     //rand
  #include <time.h>
  #include <algorithm>
  #include <emmintrin.h>
  #include <tmmintrin.h>  //_mm_hadd_epi32
  #include <smmintrin.h>  //_mm_extract_epi32
  #include <intrin.h>
  10.  
  11. //implementation by OP
  // Scalar reference search: returns the index of the first element of arr
  // that is >= key. If no such element exists among the first n elements the
  // result is n; for n <= 0 the result is 0.
  static int linear_OP (const int *arr, int n, int key) {
      int pos = 0;
      // Advance while still inside the array and the current element is too small.
      while (pos < n && arr[pos] < key)
          pos++;
      return pos;
  }
  21.  
  22. //scalar implementation by stgatilov (that's me=))
  // Branch-free scalar search: counts how many of the first n elements of arr
  // are strictly less than key. Every element is always visited (no early
  // exit). For an ascending-sorted array this count is the index of the first
  // element >= key.
  static int linear_stgatilov_scalar (const int *arr, int n, int key) {
      int below = 0;
      int idx = 0;
      while (idx < n) {
          below += (arr[idx] < key) ? 1 : 0;
          ++idx;
      }
      return below;
  }
  29.  
  30. //vectorized implementation by stgatilov (that's me again=))
  // Vectorized branch-free search (SSE2 only): counts the elements that are
  // strictly less than key. For an ascending-sorted array this count is the
  // index of the first element >= key.
  //
  // Requirements on the caller:
  //   * arr must be 16-byte aligned (aligned loads are used);
  //   * elements are consumed in groups of 16, so the array must be readable
  //     up to n rounded up to a multiple of 16, and the slots between n and
  //     that bound must hold values >= key (e.g. INT_MAX) so that they are
  //     not counted.
  static int linear_stgatilov_vec (const int *arr, int n, int key) {
      assert((uintptr_t)arr % 16 == 0);
      const __m128i vkey = _mm_set1_epi32(key);
      __m128i cnt = _mm_setzero_si128();
      for (int i = 0; i < n; i += 16) {
          // Each compare yields -1 (all ones) in the lanes where arr[...] < key.
          __m128i mask0 = _mm_cmplt_epi32(_mm_load_si128((const __m128i *)&arr[i+0]), vkey);
          __m128i mask1 = _mm_cmplt_epi32(_mm_load_si128((const __m128i *)&arr[i+4]), vkey);
          __m128i mask2 = _mm_cmplt_epi32(_mm_load_si128((const __m128i *)&arr[i+8]), vkey);
          __m128i mask3 = _mm_cmplt_epi32(_mm_load_si128((const __m128i *)&arr[i+12]), vkey);
          __m128i sum = _mm_add_epi32(_mm_add_epi32(mask0, mask1), _mm_add_epi32(mask2, mask3));
          // Lanes of sum are 0 or negative, so subtracting adds the hit count.
          cnt = _mm_sub_epi32(cnt, sum);
      }
      // Horizontal sum of the four 32-bit lanes using SSE2 shuffles:
      // first fold the upper half onto the lower half, then the odd lane
      // onto the even lane, leaving the total in lane 0.
      cnt = _mm_add_epi32(cnt, _mm_shuffle_epi32(cnt, _MM_SHUFFLE(1, 0, 3, 2)));
      cnt = _mm_add_epi32(cnt, _mm_shuffle_epi32(cnt, _MM_SHUFFLE(2, 3, 0, 1)));
      // Take all 32 bits of lane 0 so counts of 65536 or more are not truncated.
      return _mm_cvtsi128_si32(cnt);
  }
  49.  
  50. //implementation by Paul R
  51. static int linear_PaulR(const int32_t *A, int n, int32_t key)
  52. {
  53. #define VEC_INT_ELEMS 4
  54. #define BLOCK_SIZE (VEC_INT_ELEMS * 32)
  55.     const __m128i vkey = _mm_set1_epi32(key);
  56.     int vresult = -1;
  57.     int result = -1;
  58.     int i, j;
  59.  
  60.     for (i = 0; i <= n - BLOCK_SIZE; i += BLOCK_SIZE)
  61.     {
  62.         __m128i vmask0 = _mm_set1_epi32(-1);
  63.         __m128i vmask1 = _mm_set1_epi32(-1);
  64.         int mask0, mask1;
  65.  
  66.         for (j = 0; j < BLOCK_SIZE; j += VEC_INT_ELEMS * 2)
  67.         {
  68.             __m128i vA0 = _mm_load_si128((__m128i *)&A[i + j]);
  69.             __m128i vA1 = _mm_load_si128((__m128i *)&A[i + j + VEC_INT_ELEMS]);
  70.             __m128i vcmp0 = _mm_cmpgt_epi32(vkey, vA0);
  71.             __m128i vcmp1 = _mm_cmpgt_epi32(vkey, vA1);
  72.             vmask0 = _mm_and_si128(vmask0, vcmp0);
  73.             vmask1 = _mm_and_si128(vmask1, vcmp1);
  74.         }
  75.         mask0 = _mm_movemask_epi8(vmask0);
  76.         mask1 = _mm_movemask_epi8(vmask1);
  77.         if ((mask0 & mask1) != 0xffff)
  78.         {
  79.             vresult = i;
  80.             break;
  81.         }
  82.     }
  83.     if (vresult > -1)
  84.     {
  85.         result = vresult + linear_OP(&A[vresult], BLOCK_SIZE, key);
  86.     }
  87.     else if (i < n)
  88.     {
  89.         result = i + linear_OP(&A[i], n - i, key);
  90.     }
  91.     return result;
  92. #undef BLOCK_SIZE
  93. #undef VEC_INT_ELEMS
  94. }
  95.  
  96. //Testing code below
  97.  
  // Compile-time benchmark parameters.
  static const int SIZE = 197;         // number of keys stored in the searched array
  static const int TRIES = 1 << 23;    // number of queries in each timed run
  static const int SAMPLES = 1 << 10;  // size of the query-key pool (a power of two)

  // Element count handed to the search functions; a mutable global initialised from SIZE.
  int n = SIZE;

  // The searched array with 16 spare slots after the SIZE keys (room for
  // padding). Sharing a union with an __m128i member gives it 16-byte alignment.
  static union {
      int arr[SIZE + 16];
      __m128i align;
  };

  // Pool of query keys that the timing loops cycle through.
  int tmp[SAMPLES];
  107.  
  108. int main() {
  109.     for (int i = 0; i < n; i++)
  110.         arr[i] = rand();
  111.     std::sort(arr, arr+n);
  112.     //Note: padding by maximal values is required!
  113.     for (int i = n; i < (i+15)/16*16; i++)
  114.         arr[i] = INT_MAX;
  115.  
  116.     for (int t = 0; t < TRIES/10; t++) {
  117.         int q = rand();
  118.         int ans = linear_OP(arr, n, q);
  119.         int res1 = linear_PaulR(arr, n, q);
  120.         int res2 = linear_stgatilov_scalar(arr, n, q);
  121.         int res3 = linear_stgatilov_vec(arr, n, q);
  122.         if (ans != res1 || ans != res2 || ans != res3)
  123.             printf("error (%d): %d, %d, %d, %d\n", q, ans, res1, res2, res3);
  124.     }
  125.  
  126.     for (int i = 0; i < SAMPLES; i++)
  127.         tmp[i] = rand();
  128.  
  129.     {
  130.         int start = clock();
  131.         int check = 0;
  132.         for (int t = 0; t < TRIES; t++) {
  133.             int q = tmp[t & (SAMPLES-1)];
  134.             int res = linear_OP(arr, n, q);
  135.             check += res;
  136.         }
  137.         double elapsed = double(clock() - start) / CLOCKS_PER_SEC;
  138.         printf("[OP]\n");
  139.         printf("Time = %0.3lf (%d)\n", elapsed, check);
  140.     }
  141.  
  142.     {
  143.         int start = clock();
  144.         int check = 0;
  145.         for (int t = 0; t < TRIES; t++) {
  146.             int q = tmp[t & (SAMPLES-1)];
  147.             int res = linear_PaulR(arr, n, q);
  148.             check += res;
  149.         }
  150.         double elapsed = double(clock() - start) / CLOCKS_PER_SEC;
  151.         printf("[Paul R]\n");
  152.         printf("Time = %0.3lf (%d)\n", elapsed, check);
  153.     }
  154.  
  155.     {
  156.         int start = clock();
  157.         int check = 0;
  158.         for (int t = 0; t < TRIES; t++) {
  159.             int q = tmp[t & (SAMPLES-1)];
  160.             int res = linear_stgatilov_vec(arr, n, q);
  161.             check += res;
  162.         }
  163.         double elapsed = double(clock() - start) / CLOCKS_PER_SEC;
  164.         printf("[stgatilov]\n");
  165.         printf("Time = %0.3lf (%d)\n", elapsed, check);
  166.     }
  167.  
  168.     return 0;
  169. }
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top