Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
// Removed SSE code path (diff fragment): for every tile of TileSize outputs it
// starts from the biases and adds, for each listed input index idx,
//   input[idx] * weights[idx * PaddedOutputDimensions + out]
// and then stores the 32-bit sums to output.
// NOTE(review): the enclosing function is not visible in this fragment. The element
// types are inferred from the intrinsics used (input looks like unsigned 8-bit,
// weights like signed 8-bit, biases/output like 32-bit) - confirm against the
// declarations.
// Index list, cache-line aligned. The padding loop below appends at most 3 extra entries.
- alignas(CacheLineSize) std::uint16_t nnzInputIndices[InputDimensions+16];
- IndexType numNnzInputIndices = 0;
// Presumably fills nnzInputIndices with the positions of the non-zero entries of
// input and sets numNnzInputIndices to their count - verify against the helper.
- non_zero_indices(input, nnzInputIndices, numNnzInputIndices);
// Tile geometry: NumChunks accumulator registers of ChunkSize 32-bit lanes each,
// i.e. TileSize = 32 outputs per tile.
- constexpr IndexType ChunkSize = 4;
- constexpr IndexType NumChunks = 8;
- constexpr IndexType TileSize = NumChunks * ChunkSize;
- static_assert(PaddedOutputDimensions % TileSize == 0);
- constexpr IndexType NumTiles = PaddedOutputDimensions / TileSize;
// All-ones 16-bit vector: _mm_madd_epi16(ones, x) sums adjacent 16-bit pairs of x
// into 32-bit lanes.
- const __m128i ones = _mm_set1_epi16(1);
// The inner loop consumes four indices per iteration, so round the count up to a
// multiple of 4 using InputDimensions as the filler index.
// NOTE(review): this makes the loop read input[InputDimensions] and the weights
// row at index InputDimensions; that is only safe if both are readable and
// contribute zero - confirm against the buffer sizes/padding.
- while (numNnzInputIndices % 4 != 0)
- nnzInputIndices[numNnzInputIndices++] = InputDimensions;
- __m128i acc[NumChunks];
- for (IndexType i = 0; i < NumTiles; ++i)
- {
// The __m128i pointers are dereferenced directly, so these addresses must be
// 16-byte aligned.
- auto biasesTile = reinterpret_cast<const __m128i*>(&biases[i * TileSize]);
- auto outputTile = reinterpret_cast< __m128i*>(&output[i * TileSize]);
// Start every accumulator from the bias values of this tile.
- for (IndexType k = 0; k < NumChunks; ++k)
- acc[k] = biasesTile[k];
- for (IndexType j = 0; j < numNnzInputIndices; j += 4)
- {
// Pack two input values into one 16-bit lane (first in the low byte, second in
// the high byte, assuming 8-bit inputs) and broadcast it to all lanes.
- const auto mul0 = _mm_set1_epi16(input[nnzInputIndices[j+0]] | (input[nnzInputIndices[j+1]] << 8));
- const auto mul2 = _mm_set1_epi16(input[nnzInputIndices[j+2]] | (input[nnzInputIndices[j+3]] << 8));
// Weight rows for the four selected inputs, offset to the current output tile.
- const auto col0 = reinterpret_cast<const __m128i*>(&weights[nnzInputIndices[j+0] * PaddedOutputDimensions + i * TileSize]);
- const auto col1 = reinterpret_cast<const __m128i*>(&weights[nnzInputIndices[j+1] * PaddedOutputDimensions + i * TileSize]);
- const auto col2 = reinterpret_cast<const __m128i*>(&weights[nnzInputIndices[j+2] * PaddedOutputDimensions + i * TileSize]);
- const auto col3 = reinterpret_cast<const __m128i*>(&weights[nnzInputIndices[j+3] * PaddedOutputDimensions + i * TileSize]);
// Each iteration loads one 16-byte vector per weight row and updates four
// accumulators (16 outputs).
- for (IndexType k = 0; k < NumChunks / 4; ++k)
- {
// Interleave bytes of two rows so each 16-bit lane holds the two weights for the
// same output; _mm_maddubs_epi16 (unsigned bytes * signed bytes, adjacent
// products added with signed 16-bit saturation) then yields in_a*w_a + in_b*w_b.
// prod0/prod2 cover byte positions 0..7 of the vector, prod1/prod3 cover 8..15.
- __m128i prod0 = _mm_maddubs_epi16(mul0, _mm_unpacklo_epi8(col0[k], col1[k]));
- __m128i prod1 = _mm_maddubs_epi16(mul0, _mm_unpackhi_epi8(col0[k], col1[k]));
- __m128i prod2 = _mm_maddubs_epi16(mul2, _mm_unpacklo_epi8(col2[k], col3[k]));
- __m128i prod3 = _mm_maddubs_epi16(mul2, _mm_unpackhi_epi8(col2[k], col3[k]));
// Interleave the 16-bit partial sums of the two input pairs, add each adjacent
// pair into a 32-bit lane via madd-with-ones, and accumulate. The four lines
// handle positions 0..3, 4..7, 8..11 and 12..15 of the 16-output group.
- acc[k*4 + 0] = _mm_add_epi32(acc[k*4 + 0], _mm_madd_epi16(ones, _mm_unpacklo_epi16(prod0, prod2)));
- acc[k*4 + 1] = _mm_add_epi32(acc[k*4 + 1], _mm_madd_epi16(ones, _mm_unpackhi_epi16(prod0, prod2)));
- acc[k*4 + 2] = _mm_add_epi32(acc[k*4 + 2], _mm_madd_epi16(ones, _mm_unpacklo_epi16(prod1, prod3)));
- acc[k*4 + 3] = _mm_add_epi32(acc[k*4 + 3], _mm_madd_epi16(ones, _mm_unpackhi_epi16(prod1, prod3)));
- }
- }
// Write the finished tile back to output.
- for (IndexType k = 0; k < NumChunks; ++k)
- outputTile[k] = acc[k];
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement