Not a member of Pastebin yet?
Sign up —
it unlocks many cool features!
#if defined(__GNUC__) || defined(__clang__)
/* Lets the AVX2/FMA kernel compile even when this translation unit is not
 * built with -mavx2 -mfma. The caller is still responsible for checking that
 * the CPU supports AVX2 and FMA before calling linear(). */
#define LINEAR_AVX2_FMA __attribute__((target("avx2,fma")))
#else
#define LINEAR_AVX2_FMA
#endif

/* Output columns handled by one vectorised slice: 4 accumulators x 8 floats. */
enum { LINEAR_SLICE_COLS = 32 };

/* Computes columns [col0, col0 + 32) of scores for every row using AVX2 FMA. */
static LINEAR_AVX2_FMA void linear_slice32(int m, int k, const float *A, int lda,
                                           const float *B, int ldb, float *scores,
                                           int lds, const float *bias, int col0)
{
    /* Unaligned loads: bias and the B rows carry no alignment guarantee. */
    const __m256 b0 = _mm256_loadu_ps(bias + col0 + 0);
    const __m256 b1 = _mm256_loadu_ps(bias + col0 + 8);
    const __m256 b2 = _mm256_loadu_ps(bias + col0 + 16);
    const __m256 b3 = _mm256_loadu_ps(bias + col0 + 24);

    for (int i = 0; i < m; i++) {
        const float *a_row = A + (size_t)i * (size_t)lda;
        __m256 acc0 = b0, acc1 = b1, acc2 = b2, acc3 = b3;

        /* Every j in [0, k) contributes; the old `j < k-16` bound dropped the
         * last 16..31 terms. */
        for (int j = 0; j < k; j++) {
            const __m256 a = _mm256_set1_ps(a_row[j]);
            const float *b_row = B + (size_t)j * (size_t)ldb + col0;
            acc0 = _mm256_fmadd_ps(a, _mm256_loadu_ps(b_row + 0), acc0);
            acc1 = _mm256_fmadd_ps(a, _mm256_loadu_ps(b_row + 8), acc1);
            acc2 = _mm256_fmadd_ps(a, _mm256_loadu_ps(b_row + 16), acc2);
            acc3 = _mm256_fmadd_ps(a, _mm256_loadu_ps(b_row + 24), acc3);
        }

        /* Non-temporal stores: the destination must be 32-byte aligned. */
        float *out = scores + (size_t)i * (size_t)lds + col0;
        _mm256_stream_ps(out + 0, acc0);
        _mm256_stream_ps(out + 8, acc1);
        _mm256_stream_ps(out + 16, acc2);
        _mm256_stream_ps(out + 24, acc3);
    }

    /* Streaming stores are weakly ordered; fence so they are globally visible
     * before this slice's work is considered finished. */
    _mm_sfence();
}

/* Computes the leftover columns [col0, n) that do not fill a 32-wide slice. */
static void linear_tail_scalar(int m, int k, const float *A, int lda,
                               const float *B, int ldb, float *scores, int lds,
                               const float *bias, int col0, int n)
{
    for (int i = 0; i < m; i++) {
        const float *a_row = A + (size_t)i * (size_t)lda;
        float *out = scores + (size_t)i * (size_t)lds;
        for (int c = col0; c < n; c++) {
            float sum = bias[c];
            for (int j = 0; j < k; j++) {
                sum += a_row[j] * B[(size_t)j * (size_t)ldb + c];
            }
            out[c] = sum;
        }
    }
}

/**
 * Affine map with bias:
 *   scores[i*lds + c] = bias[c] + sum_{j<k} A[i*lda + j] * B[j*ldb + c]
 * for 0 <= i < m and 0 <= c < n.
 *
 * NOTE(review): the original inner loop was only a placeholder ("standard FMA
 * operations"). The layout used here (A row-major m x k with row stride lda,
 * B row-major k x n with row stride ldb) is inferred from the 4 x 8-lane
 * accumulators kept per 32-column slice. Confirm against callers.
 *
 * Columns are split into 32-wide slices, which are distributed across OpenMP
 * threads (the runtime chooses the thread count). Any remaining n % 32 columns
 * are computed by a scalar loop on the calling thread. That loop uses a
 * separate multiply and add, so its results may differ in the last ulp from
 * the FMA path.
 *
 * Preconditions for the vectorised path:
 * - scores is 32-byte aligned and lds is a multiple of 8, as required by
 *   _mm256_stream_ps;
 * - the CPU supports AVX2 and FMA.
 *
 * bias is only read; it keeps its non-const type so the definition stays
 * compatible with existing prototypes. Returns immediately when m <= 0 or
 * n <= 0. When k <= 0, every output equals the bias.
 */
void linear(int m, int n, int k, const float* A, int lda, const float* B, int ldb, float* scores, int lds, float* bias)
{
    if (m <= 0 || n <= 0) {
        return;
    }

    /* The old code set n/32 threads globally (0 when n < 32, which is invalid)
     * and indexed by an undeclared `t`. A worksharing loop over slices needs
     * neither. */
    const int slices = n / LINEAR_SLICE_COLS;

    #pragma omp parallel for schedule(static)
    for (int t = 0; t < slices; t++) {
        linear_slice32(m, k, A, lda, B, ldb, scores, lds, bias,
                       t * LINEAR_SLICE_COLS);
    }

    const int done = slices * LINEAR_SLICE_COLS;
    if (done < n) {
        linear_tail_scalar(m, k, A, lda, B, ldb, scores, lds, bias, done, n);
    }
}

#undef LINEAR_AVX2_FMA
Advertisement
Add Comment
Please sign in to add a comment.