Advertisement
stgatilov

strlen-like vectorization (question by George)

Jul 31st, 2015
445
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 9.70 KB | None | 0 0
  1. #include <stdio.h>
  2. #include <time.h>
  3. #include <nmmintrin.h>
  4. #include <stdint.h>
  5. #include <stdlib.h>
  6. #include <string.h>
  #ifdef _MSC_VER
    // <intrin.h> is the MSVC intrinsics header (needed for _BitScanForward).
    // It is included only here because other toolchains may not provide it.
    #include <intrin.h>
    #define ALIGN(n) __declspec(align(n))
    // MSVC has no __builtin_ctz, so emulate it with _BitScanForward.
    // Returns the index of the lowest set bit of x. As with the GCC/Clang
    // builtin, the result is undefined for x == 0.
    static inline int __builtin_ctz(unsigned int x) {
      unsigned long res;
      _BitScanForward(&res, x);
      return (int)res;
    }
  #else
    #define ALIGN(n) __attribute__((aligned(n)))
  #endif
  // Returns the smaller of a and b (b when they compare equal).
  inline int min(int a, int b) {
    if (a < b)
      return a;
    return b;
  }
  // Returns the larger of a and b (b when they compare equal).
  inline int max(int a, int b) {
    if (a > b)
      return a;
    return b;
  }
  20.  
  21.  
  //OP's original code: scalar reference, one byte per iteration.
  //Returns how many leading bytes of p (looking at no more than len of them)
  //lie in the range [32, 127].
  int CommonAsciiLength_original(int len, unsigned char *p) {
    int count;
    for (count = 0; count < len; count++) {
      unsigned char c = p[count];
      if (c < 32 || c > 127)
        break;
    }
    return count;
  }
  29.  
  //OP's improved code with 64-bit integers: checks eight bytes per iteration.
  //Returns how many leading bytes of p (looking at no more than len of them)
  //lie in the range [32, 127]. p may have any alignment.
  int CommonAsciiLength_bitmask(int len, unsigned char *p) {
    int i = 0;
    while (len - i >= 8) {
      //memcpy instead of *(uint64_t *)(p + i): the pointer cast is a misaligned,
      //aliasing-violating access; compilers turn this into a single load anyway
      uint64_t bytes;
      memcpy(&bytes, p + i, sizeof bytes);
      uint64_t middleBits = bytes & 0x6060606060606060ULL;  //bits 5 and 6 of every byte
      uint64_t highBits = bytes & 0x8080808080808080ULL;    //bit 7 of every byte
      middleBits |= (middleBits >> 1);    //bit 5 now holds (bit 5 | bit 6), i.e. byte >= 32
      middleBits &= ~(highBits >> 2);     //clear it again when bit 7 is set, i.e. byte >= 128
      //every byte is acceptable only if its bit 5 survived; the test is the same
      //for all eight byte lanes, so byte order does not matter
      if ((middleBits & 0x2020202020202020ULL) != 0x2020202020202020ULL)
        break;
      i += 8;
    }
    //locate the exact stopping point (and handle the last <8 bytes) one byte at a time
    while (i < len && p[i] >= 32 && p[i] <= 127)
      i++;
    return i;
  }
  47.  
  //SSE2 vectorized answer by Pete.
  //Checks 16 bytes at a time with shifts and bit logic, then finishes byte by byte.
  //Returns how many leading bytes of p (looking at no more than len of them)
  //lie in [32, 127]. p is read through an __m128i pointer, so it has to be
  //16-byte aligned.
  int CommonAsciiLength_Pete(int len, unsigned char *p) {
    const __m128i *vec = (const __m128i *)p;
    const int fullBlocks = len / 16;

    int blk = 0;
    for (; blk < fullBlocks; ++blk) {
      const __m128i v = vec[blk];
      //after the shifts, bit 7 of each byte holds (bit 6 | bit 5) of the input byte
      const __m128i bit5or6 = _mm_or_si128(_mm_slli_epi32(v, 1), _mm_slli_epi32(v, 2));
      //andnot drops bytes whose own bit 7 was set, so a mask bit is 1 for an accepted byte
      const int okMask = _mm_movemask_epi8(_mm_andnot_si128(v, bit5or6));
      if (okMask == 0xFFFF)
        continue;               //all 16 bytes accepted
      if (okMask == 0)
        return blk * 16;        //all 16 rejected, so the run ends at this block's first byte
      break;                    //mixed block: let the scalar loop find the exact byte
    }

    int pos = blk * 16;
    while (pos < len && p[pos] >= 32 && p[pos] <= 127)
      pos++;
    return pos;
  }
  80.  
  //SSE2 vectorization by stgatilov: no unrolling, per-byte tail.
  //Returns how many leading bytes of p (looking at no more than len of them)
  //lie in [32, 127]. p is read through an __m128i pointer, so it has to be
  //16-byte aligned.
  int CommonAsciiLength_sse2(int len, unsigned char *p) {
    const __m128i *ptr = (const __m128i *)p;
    const __m128i limit = _mm_set1_epi8(32);
    int blocks = len >> 4;

    //a signed per-byte compare rejects both <32 and >=128 (the latter are negative)
    int cnt;
    for (cnt = 0; cnt < blocks; cnt++) {
      __m128i mask = _mm_cmplt_epi8(ptr[cnt], limit);
      if (_mm_movemask_epi8(mask))
        break;
    }
    //rescan the offending block (or the last <16 bytes) one byte at a time;
    //signed char is spelled out because plain char may be unsigned, which
    //would let bytes >= 128 through
    for (int pos = cnt * 16; pos < len; pos++)
      if ((signed char)p[pos] < 32)
        return pos;
    return len;
  }
  98.  
  //SSE2 vectorization by stgatilov: 4x unrolling, per-byte tail.
  //Returns how many leading bytes of p (looking at no more than len of them)
  //lie in [32, 127]. p is read through an __m128i pointer, so it has to be
  //16-byte aligned.
  int CommonAsciiLength_sse2_x4(int len, unsigned char *p) {
    const __m128i *ptr = (const __m128i *)p;
    const __m128i limit = _mm_set1_epi8(32);
    int blocks = (len >> 6) << 2;   //number of 16-byte blocks that form complete groups of four

    //test 64 bytes per iteration: OR the four signed-compare results and look
    //at the combined mask once
    int cnt;
    for (cnt = 0; cnt < blocks; cnt += 4) {
      __m128i m0 = _mm_cmplt_epi8(ptr[cnt+0], limit);
      __m128i m1 = _mm_cmplt_epi8(ptr[cnt+1], limit);
      __m128i m2 = _mm_cmplt_epi8(ptr[cnt+2], limit);
      __m128i m3 = _mm_cmplt_epi8(ptr[cnt+3], limit);
      __m128i mask = _mm_or_si128(_mm_or_si128(m0, m1), _mm_or_si128(m2, m3));
      if (_mm_movemask_epi8(mask))
        break;
    }
    //rescan the offending group (or the last <64 bytes) one byte at a time;
    //signed char is spelled out because plain char may be unsigned, which
    //would let bytes >= 128 through
    for (int pos = cnt * 16; pos < len; pos++)
      if ((signed char)p[pos] < 32)
        return pos;
    return len;
  }
  120.  
  121. //SSE4.2 vectorization by stgatilov: using range string operations
  122. int CommonAsciiLength_sse42(int len, unsigned char *p) {
  123.   const __m128i *ptr = (const __m128i *)p;
  124.   int blocks = (len >> 4);
  125.  
  126.   __m128i range = _mm_set_epi8(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 127, 32);
  127.  
  128.   int cnt;
  129.   for (cnt = 0; cnt < blocks; cnt++) {
  130.     int res = _mm_cmpestri(range, 2, ptr[cnt], 16, _SIDD_UBYTE_OPS | _SIDD_CMP_RANGES | _SIDD_MASKED_NEGATIVE_POLARITY | _SIDD_LEAST_SIGNIFICANT);
  131.     if (res < 16)
  132.       return cnt * 16 + res;
  133.   }
  134.   int pos = 16 * cnt;
  135.   int res = _mm_cmpestri(range, 2, ptr[cnt], len - pos, _SIDD_UBYTE_OPS | _SIDD_CMP_RANGES | _SIDD_MASKED_NEGATIVE_POLARITY | _SIDD_LEAST_SIGNIFICANT);
  136.   res = (res == 16 ? len - pos : res);
  137.   return pos + res;
  138. }
  139.  
  //SSE2 vectorization by stgatilov: no unrolling, fast BSF tail.
  //Returns how many leading bytes of p (looking at no more than len of them)
  //lie in [32, 127]. p is read through an __m128i pointer, so it has to be
  //16-byte aligned. When len is not a multiple of 16, the last load still
  //reads the whole 16-byte vector containing the tail; bytes past len do not
  //affect the result.
  int CommonAsciiLength_sse2_end(int len, unsigned char *p) {
    const __m128i *ptr = (const __m128i *)p;
    const __m128i limit = _mm_set1_epi8(32);
    int blocks = len >> 4;

    //signed compare: a mask bit is set for every byte that is <32 or >=128;
    //the lowest set bit is the first rejected byte
    int cnt;
    for (cnt = 0; cnt < blocks; cnt++) {
      __m128i mask = _mm_cmplt_epi8(ptr[cnt], limit);
      int val = _mm_movemask_epi8(mask);
      if (val)
        return 16 * cnt + __builtin_ctz(val);
    }
    int rem = len - 16 * cnt;
    //nothing left: do not load ptr[cnt], that vector lies entirely past the buffer
    if (rem <= 0)
      return 16 * cnt;
    __m128i mask = _mm_cmplt_epi8(ptr[cnt], limit);
    int val = _mm_movemask_epi8(mask);
    //set every bit from position rem upwards so the bit scan stops at the
    //end of the data even when all remaining bytes are accepted
    val |= -(1 << rem);
    return 16 * cnt + __builtin_ctz(val);
  }
  157.  
  //SSE2 vectorization by stgatilov: x4 unrolling, fast SSE2 + BSF tail.
  //Returns how many leading bytes of p (looking at no more than len of them)
  //lie in [32, 127]. p is read through an __m128i pointer, so it has to be
  //16-byte aligned. When len is not a multiple of 16, the last load still
  //reads the whole 16-byte vector containing the tail; bytes past len do not
  //affect the result.
  int CommonAsciiLength_sse2_x4_end(int len, unsigned char *p) {
    const __m128i *ptr = (const __m128i *)p;
    const __m128i limit = _mm_set1_epi8(32);
    int blocks = len >> 4;
    int fastBlocks = (blocks >> 2) << 2;  //blocks that form complete groups of four

    //stage 1: 64 bytes per iteration, only detecting that some byte is rejected
    int cnt;
    for (cnt = 0; cnt < fastBlocks; cnt += 4) {
      __m128i m0 = _mm_cmplt_epi8(ptr[cnt+0], limit);
      __m128i m1 = _mm_cmplt_epi8(ptr[cnt+1], limit);
      __m128i m2 = _mm_cmplt_epi8(ptr[cnt+2], limit);
      __m128i m3 = _mm_cmplt_epi8(ptr[cnt+3], limit);
      __m128i mask = _mm_or_si128(_mm_or_si128(m0, m1), _mm_or_si128(m2, m3));
      if (_mm_movemask_epi8(mask))
        break;
    }
    //stage 2: single blocks (the offending group, then any leftover full
    //blocks); the lowest set mask bit is the first rejected byte
    for (; cnt < blocks; cnt++) {
      __m128i mask = _mm_cmplt_epi8(ptr[cnt], limit);
      int val = _mm_movemask_epi8(mask);
      if (val)
        return 16 * cnt + __builtin_ctz(val);
    }
    int rem = len - 16 * cnt;
    //nothing left: do not load ptr[cnt], that vector lies entirely past the buffer
    if (rem <= 0)
      return 16 * cnt;
    //stage 3: partial last block
    __m128i mask = _mm_cmplt_epi8(ptr[cnt], limit);
    int val = _mm_movemask_epi8(mask);
    //set every bit from position rem upwards so the bit scan stops at the
    //end of the data even when all remaining bytes are accepted
    val |= -(1 << rem);
    return 16 * cnt + __builtin_ctz(val);
  }
  185.  
  186. //========================================== TESTING =================================================
  187.  
  188. static const int SIZE = 1<<20;
  189. unsigned char data[SIZE];             //number of chars in data buffer
  190.  
  191. static const int CNT = 1<<16;         //number of queries generated (all in one buffer)
  192. unsigned char *queryPtr[CNT];
  193. int queryLen[CNT];
  194. static const int64_t WORK = 1LL<<32;  //total amount of work is proportional to this
  195. static const int AVGLEN = 100;        //with probability (1 - 1/A) ASCII character is generated
  196.  
  197. int main() {
  198.   for (int i = 0; i < SIZE; i++) {
  199.     if (rand() % AVGLEN == 0)
  200.       data[i] = rand() % 256;
  201.     else
  202.       data[i] = 32 + rand() % (128-32);
  203.   }
  204.  
  205.   double avgLen = 0.0;
  206.   for (int i = 0; i < CNT; i++) {
  207.     int start = rand() % (SIZE/16) * 16;
  208.     int len = rand() % min(SIZE - start + 1, 1<<12);
  209.     queryPtr[i] = data + start;
  210.     queryLen[i] = len;
  211.  
  212.     #define VERSIONS 8
  213.     int res[VERSIONS];
  214.     res[0] = CommonAsciiLength_original(queryLen[i], queryPtr[i]);
  215.     res[1] = CommonAsciiLength_bitmask(queryLen[i], queryPtr[i]);
  216.     res[2] = CommonAsciiLength_Pete(queryLen[i], queryPtr[i]);
  217.     res[3] = CommonAsciiLength_sse2(queryLen[i], queryPtr[i]);
  218.     res[4] = CommonAsciiLength_sse2_x4(queryLen[i], queryPtr[i]);
  219.     res[5] = CommonAsciiLength_sse42(queryLen[i], queryPtr[i]);
  220.     res[6] = CommonAsciiLength_sse2_end(queryLen[i], queryPtr[i]);
  221.     res[7] = CommonAsciiLength_sse2_x4_end(queryLen[i], queryPtr[i]);
  222.     avgLen += res[0];
  223.  
  224.     bool same = true;
  225.     for (int i = 0; i+1 < VERSIONS; i++)
  226.       if (res[i] != res[i+1])
  227.         same = false;
  228.     if (!same)
  229.       printf("Error: %d %d %d %d %d %d %d %d\n", res[0], res[1], res[2], res[3], res[4], res[5], res[6], res[7]);
  230.   }
  231.   avgLen /= CNT;
  232.  
  233.   printf("All checked.\n");
  234.   printf("Average answer = %0.1lf\n", avgLen);
  235.  
  236.   int CALLS = WORK / max(AVGLEN, 16);
  237.  
  238.   {
  239.     int start = clock();
  240.     int sum = 0;
  241.     for (int i = 0; i < CALLS; i++) {
  242.       int k = i & (CNT-1);
  243.       sum += CommonAsciiLength_original(queryLen[k], queryPtr[k]);
  244.     }
  245.     printf("Time = %0.3lf   (%d) original\n", double(clock() - start) / CLOCKS_PER_SEC, sum);
  246.   }
  247.   {
  248.     int start = clock();
  249.     int sum = 0;
  250.     for (int i = 0; i < CALLS; i++) {
  251.       int k = i & (CNT-1);
  252.       sum += CommonAsciiLength_bitmask(queryLen[k], queryPtr[k]);
  253.     }
  254.     printf("Time = %0.3lf   (%d) bitmask\n", double(clock() - start) / CLOCKS_PER_SEC, sum);
  255.   }
  256.   {
  257.     int start = clock();
  258.     int sum = 0;
  259.     for (int i = 0; i < CALLS; i++) {
  260.       int k = i & (CNT-1);
  261.       sum += CommonAsciiLength_Pete(queryLen[k], queryPtr[k]);
  262.     }
  263.     printf("Time = %0.3lf   (%d) Pete\n", double(clock() - start) / CLOCKS_PER_SEC, sum);
  264.   }
  265.   {
  266.     int start = clock();
  267.     int sum = 0;
  268.     for (int i = 0; i < CALLS; i++) {
  269.       int k = i & (CNT-1);
  270.       sum += CommonAsciiLength_sse2(queryLen[k], queryPtr[k]);
  271.     }
  272.     printf("Time = %0.3lf   (%d) sse2\n", double(clock() - start) / CLOCKS_PER_SEC, sum);
  273.   }
  274.   {
  275.     int start = clock();
  276.     int sum = 0;
  277.     for (int i = 0; i < CALLS; i++) {
  278.       int k = i & (CNT-1);
  279.       sum += CommonAsciiLength_sse2_x4(queryLen[k], queryPtr[k]);
  280.     }
  281.     printf("Time = %0.3lf   (%d) sse2_x4\n", double(clock() - start) / CLOCKS_PER_SEC, sum);
  282.   }
  283.   {
  284.     int start = clock();
  285.     int sum = 0;
  286.     for (int i = 0; i < CALLS; i++) {
  287.       int k = i & (CNT-1);
  288.       sum += CommonAsciiLength_sse42(queryLen[k], queryPtr[k]);
  289.     }
  290.     printf("Time = %0.3lf   (%d) sse42\n", double(clock() - start) / CLOCKS_PER_SEC, sum);
  291.   }
  292.   {
  293.     int start = clock();
  294.     int sum = 0;
  295.     for (int i = 0; i < CALLS; i++) {
  296.       int k = i & (CNT-1);
  297.       sum += CommonAsciiLength_sse2_end(queryLen[k], queryPtr[k]);
  298.     }
  299.     printf("Time = %0.3lf   (%d) sse2_end\n", double(clock() - start) / CLOCKS_PER_SEC, sum);
  300.   }
  301.   {
  302.     int start = clock();
  303.     int sum = 0;
  304.     for (int i = 0; i < CALLS; i++) {
  305.       int k = i & (CNT-1);
  306.       sum += CommonAsciiLength_sse2_x4_end(queryLen[k], queryPtr[k]);
  307.     }
  308.     printf("Time = %0.3lf   (%d) sse2_x4_end\n", double(clock() - start) / CLOCKS_PER_SEC, sum);
  309.   }
  310.  
  311.   return 0;
  312. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement