Advertisement
Guest User

LLM AVX

a guest
May 7th, 2025
36
0
Never
Not a member of Pastebin yet? Sign up — it unlocks many cool features!
C++ 1.82 KB | None | 0 0
  1. // My Version
  2. void clCompute::IMPL_BASE_UnPackBits_Scatter(i64 num, const u32 *src, u8 *dst)
  3. {
  4.   i64 bitIndex = 0;
  5.   for (i32 b = 31; b >= 0; b--)
  6.   {
  7.     u32 mask = 1 << b;
  8.     i64 localBitIndex = bitIndex;
  9.     for (i64 i = 0; i < num; i++, localBitIndex++)
  10.       dst[localBitIndex >> 3] |= ((src[i] & mask) >> b) << (7 - (localBitIndex & 7));
  11.     bitIndex += num;
  12.   }
  13. }
  14.  
  15. // AI Version
  16. void clCompute::IMPL_AVX2_UnPackBits_Scatter(i64 num, const u32 * src, u8 * dst)
  17. {
  18.   memset(dst, 0, num * sizeof(u32));
  19.  
  20.   // Process 8 elements at a time
  21.   const int64_t numChunks = num >> 3;
  22.   const int64_t remainder = num & 3;
  23.   for (int bit = 31; bit >= 0; bit--)
  24.   {
  25.     __m256i bitMask = _mm256_set1_epi32(1 << bit);
  26.     int64_t bitOffset = (31 - bit) * num;
  27.     for (int64_t i = 0; i < numChunks; i++)
  28.     {
  29.       __m256i data = _mm256_loadu_si256((__m256i *) & src[i * 8]);
  30.       __m256i masked = _mm256_and_si256(data, bitMask);
  31.       __m256i shifted = _mm256_srli_epi32(masked, bit);
  32.       uint32_t bitResults =
  33.         ((_mm256_extract_epi32(shifted, 0) & 1) << 7) |
  34.         ((_mm256_extract_epi32(shifted, 1) & 1) << 6) |
  35.         ((_mm256_extract_epi32(shifted, 2) & 1) << 5) |
  36.         ((_mm256_extract_epi32(shifted, 3) & 1) << 4) |
  37.         ((_mm256_extract_epi32(shifted, 4) & 1) << 3) |
  38.         ((_mm256_extract_epi32(shifted, 5) & 1) << 2) |
  39.         ((_mm256_extract_epi32(shifted, 6) & 1) << 1) |
  40.         ((_mm256_extract_epi32(shifted, 7) & 1) << 0);
  41.       dst[(bitOffset + i * 8) >> 3] = (uint8_t)bitResults;
  42.     }
  43.  
  44.     // Process remaining elements
  45.     int64_t offset = numChunks * 8;
  46.     for (int64_t i = 0; i < remainder; i++)
  47.     {
  48.       uint32_t bitValue = (src[offset + i] & (1 << bit)) >> bit;
  49.       dst[(bitOffset + offset + i) >> 3] |= (bitValue << (7 - ((bitOffset + offset + i) & 7)));
  50.     }
  51.   }
  52. }
  53.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement