Advertisement
zhangsongcui

char_count

Feb 3rd, 2019
270
0
Never
Not a member of Pastebin yet? Sign up — it unlocks many cool features!
C++ 4.13 KB | None | 0 0
  1. #include <immintrin.h>
  2. #include <cstdint>
  3. #include <string_view>
  4. #include <cstdio>
  5.  
  6. using namespace std::literals;
  7.  
  8. #define NOINLINE __attribute__((__noinline__))
  9.  
  10. inline static __m256i _mm256_shift_left(__m256i A, const uint32_t N) {
  11.     switch (N) {
  12. #define CASE_LT16(N) case N:\
  13. return _mm256_alignr_epi8(A, _mm256_permute2x128_si256(A, A, _MM_SHUFFLE(0, 0, 2, 0)), 16 - N);
  14.     CASE_LT16(1)
  15.     CASE_LT16(2)
  16.     CASE_LT16(3)
  17.     CASE_LT16(4)
  18.     CASE_LT16(5)
  19.     CASE_LT16(6)
  20.     CASE_LT16(7)
  21.     CASE_LT16(8)
  22.     CASE_LT16(9)
  23.     CASE_LT16(10)
  24.     CASE_LT16(11)
  25.     CASE_LT16(12)
  26.     CASE_LT16(13)
  27.     CASE_LT16(14)
  28.     CASE_LT16(15)
  29. #undef CASE_LT16
  30.  
  31.     case 16:
  32.         return _mm256_permute2x128_si256(A, A, _MM_SHUFFLE(0, 0, 2, 0));
  33.  
  34. #define CASE_GT16(N) case N:\
  35. return _mm256_slli_si256(_mm256_permute2x128_si256(A, A, _MM_SHUFFLE(0, 0, 2, 0)), N - 16);
  36.     CASE_GT16(17)
  37.     CASE_GT16(18)
  38.     CASE_GT16(19)
  39.     CASE_GT16(20)
  40.     CASE_GT16(21)
  41.     CASE_GT16(22)
  42.     CASE_GT16(23)
  43.     CASE_GT16(24)
  44.     CASE_GT16(25)
  45.     CASE_GT16(26)
  46.     CASE_GT16(27)
  47.     CASE_GT16(28)
  48.     CASE_GT16(29)
  49.     CASE_GT16(30)
  50.     CASE_GT16(31)
  51. #undef CASE_GT16
  52.  
  53.     default:
  54.         __builtin_unreachable();
  55.     }
  56. }
  57.  
  58. inline static __m256i _mm256_shift_right(__m256i A, const uint32_t N) {
  59.     switch (N) {
  60. #define CASE_LT16(N) case N:\
  61. return _mm256_alignr_epi8(_mm256_permute2x128_si256(A, A, _MM_SHUFFLE(2, 0, 0, 1)), A, N);
  62.     CASE_LT16(1)
  63.     CASE_LT16(2)
  64.     CASE_LT16(3)
  65.     CASE_LT16(4)
  66.     CASE_LT16(5)
  67.     CASE_LT16(6)
  68.     CASE_LT16(7)
  69.     CASE_LT16(8)
  70.     CASE_LT16(9)
  71.     CASE_LT16(10)
  72.     CASE_LT16(11)
  73.     CASE_LT16(12)
  74.     CASE_LT16(13)
  75.     CASE_LT16(14)
  76.     CASE_LT16(15)
  77. #undef CASE_LT16
  78.  
  79.     case 16:
  80.         return _mm256_permute2x128_si256(A, A, _MM_SHUFFLE(2, 0, 0, 1));
  81.  
  82. #define CASE_GT16(N) case N:\
  83. return _mm256_srli_si256(_mm256_permute2x128_si256(A, A, _MM_SHUFFLE(2, 0, 0, 1)), N - 16);
  84.     CASE_GT16(17)
  85.     CASE_GT16(18)
  86.     CASE_GT16(19)
  87.     CASE_GT16(20)
  88.     CASE_GT16(21)
  89.     CASE_GT16(22)
  90.     CASE_GT16(23)
  91.     CASE_GT16(24)
  92.     CASE_GT16(25)
  93.     CASE_GT16(26)
  94.     CASE_GT16(27)
  95.     CASE_GT16(28)
  96.     CASE_GT16(29)
  97.     CASE_GT16(30)
  98.     CASE_GT16(31)
  99. #undef CASE_GT16
  100.  
  101.     default:
  102.         __builtin_unreachable();
  103.     }
  104. }
  105.  
  106. NOINLINE uint64_t char_count(std::string_view sv, char c = ' ') {
  107.     uint64_t len = (uint32_t)sv.length(), count = 0;
  108.     const char *p = sv.data();
  109.    
  110.     if (len < sizeof(__m256i)) {
  111.         while (len--) {
  112.             count += *p++ == c;
  113.         }
  114.         return count;
  115.     }
  116.     auto vc = _mm256_set1_epi8(c);
  117.     if (auto align = (size_t)(p) % sizeof(__m256i)) {
  118.         p -= align;
  119.         auto buf = _mm256_load_si256((__m256i *)p);
  120.         buf = _mm256_shift_right(buf, (uint32_t)align);
  121.         auto result = _mm256_cmpeq_epi8(buf, vc);
  122.         auto mask = _mm256_movemask_epi8(result);
  123.         count += _mm_popcnt_u32(mask);
  124.         p += sizeof(__m256i);
  125.         len -= sizeof(__m256i) - align;
  126.     }
  127.     while (len >= sizeof(__m256i) * 2) {
  128.         auto buf = _mm256_load_si256((__m256i *)p);
  129.         auto buf1 = _mm256_load_si256(((__m256i *)p) + 1);
  130.         auto result = _mm256_cmpeq_epi8(buf, vc);
  131.         auto result1 = _mm256_cmpeq_epi8(buf1, vc);
  132.         uint64_t mask = (uint32_t)_mm256_movemask_epi8(result);
  133.         uint64_t mask1 = (uint32_t)_mm256_movemask_epi8(result1);
  134.         count += _mm_popcnt_u64((mask << 32) | mask1);
  135.         p += sizeof(__m256i) * 2;
  136.         len -= sizeof(__m256i) * 2;
  137.     }
  138.     if (len >= sizeof(__m256i)) {
  139.         auto buf = _mm256_load_si256((__m256i *)p);
  140.         auto result = _mm256_cmpeq_epi8(buf, vc);
  141.         auto mask = _mm256_movemask_epi8(result);
  142.         count += _mm_popcnt_u32(mask);
  143.         p += sizeof(__m256i);
  144.         len -= sizeof(__m256i);
  145.     }
  146.     if (len) {
  147.         auto buf = _mm256_load_si256((__m256i *)p);
  148.         buf = _mm256_shift_left(buf, (uint32_t)(sizeof(__m256i) - len));
  149.         auto result = _mm256_cmpeq_epi8(buf, vc);
  150.         auto mask = _mm256_movemask_epi8(result);
  151.         count += _mm_popcnt_u32(mask);
  152.     }
  153.     return count;
  154. }
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement