zhangsongcui

memeq

Jun 3rd, 2022 (edited)
679
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 3.53 KB | None | 0 0
  1. #include <immintrin.h>
  2. #include <stdint.h>
  3. #include <stddef.h>
  4.  
  5. bool memeq(const uint8_t* pu1, const uint8_t* pu2, size_t n) {
  6.     if (n >= sizeof(__m512i)) {
  7.         size_t off;
  8.         for (off = 0; off <= n - sizeof(__m512i); off += sizeof(__m512i)) {
  9.             auto ymm1 = _mm512_loadu_si512((const __m512i*)(pu1 + off));
  10.             auto ymm2 = _mm512_loadu_si512((const __m512i*)(pu2 + off));
  11.             auto res = _mm512_cmpeq_epi8_mask(ymm1, ymm2);
  12.             if (_kortestz_mask64_u8(res, res)) return false;
  13.         }
  14.         if (off < n) {
  15.             auto ymm1 = _mm512_loadu_si512((const __m512i*)(pu1 + n - sizeof(__m512i)));
  16.             auto ymm2 = _mm512_loadu_si512((const __m512i*)(pu2 + n - sizeof(__m512i)));
  17.             auto res = _mm512_cmpeq_epi8_mask(ymm1, ymm2);
  18.             if (_kortestz_mask64_u8(res, res)) return false;
  19.         }
  20.     }
  21.     else if (n >= sizeof(__m256i)) {
  22.         size_t off;
  23.         for (off = 0; off <= n - sizeof(__m256i); off += sizeof(__m256i)) {
  24.             auto ymm1 = _mm256_loadu_si256((const __m256i*)(pu1 + off));
  25.             auto ymm2 = _mm256_loadu_si256((const __m256i*)(pu2 + off));
  26.             auto res = _mm256_cmpeq_epi8(ymm1, ymm2);
  27.             if (!_mm256_testc_si256((res), _mm256_set1_epi8(~0))) return false;
  28.         }
  29.         if (off < n) {
  30.             auto ymm1 = _mm256_loadu_si256((const __m256i*)(pu1 + n - sizeof(__m256i)));
  31.             auto ymm2 = _mm256_loadu_si256((const __m256i*)(pu2 + n - sizeof(__m256i)));
  32.             auto res = _mm256_cmpeq_epi8(ymm1, ymm2);
  33.             if (!_mm256_testc_si256((res), _mm256_set1_epi8(~0))) return false;
  34.         }
  35.     }
  36.     else if (n >= sizeof(__m128i)) {
  37.         {
  38.             auto xmm1 = _mm_loadu_si128((const __m128i*)pu1);
  39.             auto xmm2 = _mm_loadu_si128((const __m128i*)pu2);
  40.             auto res = _mm_cmpeq_epi8(xmm1, xmm2);
  41.             if (!_mm_test_all_ones(res)) return false;
  42.         }
  43.         if (n > sizeof(__m128i)) {
  44.             auto xmm1 = _mm_loadu_si128((const __m128i*)(pu1 + n - sizeof(__m128i)));
  45.             auto xmm2 = _mm_loadu_si128((const __m128i*)(pu2 + n - sizeof(__m128i)));
  46.             auto res = _mm_cmpeq_epi8(xmm1, xmm2);
  47.             if (!_mm_test_all_ones(res)) return false;
  48.         }
  49.     }
  50.     else if (n >= sizeof(uint64_t)) {
  51.         {
  52.             auto v1 = *(const uint64_t*)pu1;
  53.             auto v2 = *(const uint64_t*)pu2;
  54.             if (v1 != v2) return false;
  55.         }
  56.         if (n > sizeof(uint64_t)) {
  57.             auto v1 = *(const uint64_t*)(pu1 + n - sizeof(uint64_t));
  58.             auto v2 = *(const uint64_t*)(pu2 + n - sizeof(uint64_t));
  59.             if (v1 != v2) return false;
  60.         }
  61.     }
  62.     else if (n >= sizeof(uint32_t)) {
  63.         {
  64.             auto v1 = *(const uint32_t*)pu1;
  65.             auto v2 = *(const uint32_t*)pu2;
  66.             if (v1 != v2) return false;
  67.         }
  68.         if (n > sizeof(uint32_t)) {
  69.             auto v1 = *(const uint32_t*)(pu1 + n - sizeof(uint32_t));
  70.             auto v2 = *(const uint32_t*)(pu2 + n - sizeof(uint32_t));
  71.             if (v1 != v2) return false;
  72.         }
  73.     }
  74.     else if (n >= sizeof(uint16_t)) {
  75.         {
  76.             auto v1 = *(const uint16_t*)pu1;
  77.             auto v2 = *(const uint16_t*)pu2;
  78.             if (v1 != v2) return false;
  79.         }
  80.         if (n > sizeof(uint16_t)) {
  81.             auto v1 = *(pu1 + 2);
  82.             auto v2 = *(pu2 + 2);
  83.             if (v1 != v2) return false;
  84.         }
  85.     }
  86.     else {
  87.         if (*pu1 != *pu2) return false;
  88.     }
  89.  
  90.     return true;
  91. }
Add Comment
Please, Sign In to add comment