Advertisement
Guest User

Untitled

a guest
Jul 22nd, 2021
27
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.75 KB | None | 0 0
  1. alignas(CacheLineSize) std::uint16_t nnzInputIndices[InputDimensions+16];
  2. IndexType numNnzInputIndices = 0;
  3. non_zero_indices(input, nnzInputIndices, numNnzInputIndices);
  4.  
  5. constexpr IndexType ChunkSize = 4;
  6. constexpr IndexType NumChunks = 8;
  7. constexpr IndexType TileSize = NumChunks * ChunkSize;
  8. static_assert(PaddedOutputDimensions % TileSize == 0);
  9. constexpr IndexType NumTiles = PaddedOutputDimensions / TileSize;
  10.  
  11. const __m128i ones = _mm_set1_epi16(1);
  12.  
  13. while (numNnzInputIndices % 4 != 0)
  14. nnzInputIndices[numNnzInputIndices++] = InputDimensions;
  15.  
  16. __m128i acc[NumChunks];
  17.  
  18. for (IndexType i = 0; i < NumTiles; ++i)
  19. {
  20. auto biasesTile = reinterpret_cast<const __m128i*>(&biases[i * TileSize]);
  21. auto outputTile = reinterpret_cast< __m128i*>(&output[i * TileSize]);
  22.  
  23. for (IndexType k = 0; k < NumChunks; ++k)
  24. acc[k] = biasesTile[k];
  25.  
  26. for (IndexType j = 0; j < numNnzInputIndices; j += 4)
  27. {
  28. const auto mul0 = _mm_set1_epi16(input[nnzInputIndices[j+0]] | (input[nnzInputIndices[j+1]] << 8));
  29. const auto mul2 = _mm_set1_epi16(input[nnzInputIndices[j+2]] | (input[nnzInputIndices[j+3]] << 8));
  30. const auto col0 = reinterpret_cast<const __m128i*>(&weights[nnzInputIndices[j+0] * PaddedOutputDimensions + i * TileSize]);
  31. const auto col1 = reinterpret_cast<const __m128i*>(&weights[nnzInputIndices[j+1] * PaddedOutputDimensions + i * TileSize]);
  32. const auto col2 = reinterpret_cast<const __m128i*>(&weights[nnzInputIndices[j+2] * PaddedOutputDimensions + i * TileSize]);
  33. const auto col3 = reinterpret_cast<const __m128i*>(&weights[nnzInputIndices[j+3] * PaddedOutputDimensions + i * TileSize]);
  34. for (IndexType k = 0; k < NumChunks / 4; ++k)
  35. {
  36. __m128i prod0 = _mm_maddubs_epi16(mul0, _mm_unpacklo_epi8(col0[k], col1[k]));
  37. __m128i prod1 = _mm_maddubs_epi16(mul0, _mm_unpackhi_epi8(col0[k], col1[k]));
  38. __m128i prod2 = _mm_maddubs_epi16(mul2, _mm_unpacklo_epi8(col2[k], col3[k]));
  39. __m128i prod3 = _mm_maddubs_epi16(mul2, _mm_unpackhi_epi8(col2[k], col3[k]));
  40. acc[k*4 + 0] = _mm_add_epi32(acc[k*4 + 0], _mm_madd_epi16(ones, _mm_unpacklo_epi16(prod0, prod2)));
  41. acc[k*4 + 1] = _mm_add_epi32(acc[k*4 + 1], _mm_madd_epi16(ones, _mm_unpackhi_epi16(prod0, prod2)));
  42. acc[k*4 + 2] = _mm_add_epi32(acc[k*4 + 2], _mm_madd_epi16(ones, _mm_unpacklo_epi16(prod1, prod3)));
  43. acc[k*4 + 3] = _mm_add_epi32(acc[k*4 + 3], _mm_madd_epi16(ones, _mm_unpackhi_epi16(prod1, prod3)));
  44. }
  45. }
  46.  
  47. for (IndexType k = 0; k < NumChunks; ++k)
  48. outputTile[k] = acc[k];
  49. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement