Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
#include <assert.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>

#include <algorithm>
#include <limits>

#include <smmintrin.h>
//Compiler-portability macros:
//  FORCEINLINE - request unconditional inlining of a function
//  NOINLINE    - forbid inlining of a function
//  ASSUME(c)   - promise the optimizer that c holds (undefined behavior if it does not)
#ifdef _MSC_VER
#define FORCEINLINE __forceinline
#define NOINLINE __declspec(noinline)
#define ASSUME(cond) __assume(cond)
#else
#define FORCEINLINE __attribute__((always_inline)) inline
#define NOINLINE __attribute__((noinline))
//wrapped in do/while(0) so the macro is a single statement: a bare `if` here
//would capture the `else` of an enclosing `if (...) ASSUME(...); else ...`
#define ASSUME(cond) do { if (!(cond)) __builtin_unreachable(); } while (0)
#endif
//from http://stackoverflow.com/q/2786899/556899
//MIN / MAX: smaller / larger of two values; on a tie MIN yields y and MAX yields x.
//Arguments are parenthesized so that expressions such as `a | b` can be passed safely.
//Note: each argument is evaluated more than once, so pass side-effect-free expressions.
#define MIN(x, y) ((x) < (y) ? (x) : (y))
#define MAX(x, y) ((x) < (y) ? (y) : (x))
//CMP_SWAP: afterwards x holds the smaller and y the larger of the two values.
//The temporaries carry a long prefix so they cannot shadow operands named `a` or `b`.
#define CMP_SWAP(x, y) {\
    auto cmp_swap_lo_ = MIN(x, y);\
    auto cmp_swap_hi_ = MAX(x, y);\
    (x) = cmp_swap_lo_;\
    (y) = cmp_swap_hi_;\
}
//implementation by STL (note: destructive!)
//Rearranges arr[0..n) with std::nth_element and returns the element that ends up
//at position idx.
template<class T> T get_kth_stl (T *arr, int n, int idx) {
    T *const kth = arr + idx;
    T *const last = arr + n;
    std::nth_element(arr, kth, last);
    return *kth;
}
//partial sort by bubble sort or something similar
//For every position 0..idx the element there is compare-exchanged with each later
//element, so the positions up to idx end up in ascending order (destructive).
//Returns the element at position idx afterwards.
template<class T> T get_kth_bubble_sort (T *arr, size_t n, size_t idx) {
    for (size_t pos = 0; pos <= idx; ++pos) {
        for (size_t later = pos + 1; later < n; ++later) {
            //compare-exchange written out in place; on a tie the two values trade places
            const bool inOrder = arr[pos] < arr[later];
            auto lo = inOrder ? arr[pos] : arr[later];
            auto hi = inOrder ? arr[later] : arr[pos];
            arr[pos] = lo;
            arr[later] = hi;
        }
    }
    return arr[idx];
}
//partial sort by selection sort
//Moves the smallest remaining element to the front, position by position, until
//position idx is filled (destructive). Returns the element at position idx.
template<class T> T get_kth_selection_sort (T *arr, size_t n, size_t idx) {
    size_t front = 0;
    while (front <= idx) {
        //scan the unsorted tail for its first minimum
        size_t best = front;
        T bestVal = arr[front];
        for (size_t probe = front + 1; probe < n; ++probe) {
            if (!(bestVal > arr[probe]))
                continue;
            bestVal = arr[probe];
            best = probe;
        }
        std::swap(arr[front], arr[best]);
        ++front;
    }
    return arr[idx];
}
//counting selection
//Returns the value that would sit at index idx if arr[0..n) were sorted ascending,
//without modifying arr. For each candidate x it counts the elements smaller than x
//and equal to x; x is the answer when cntLess <= idx < cntLess + cntEq.
//Expects 0 <= idx < n; otherwise no candidate matches, assert(0) fires in debug
//builds and -1 is returned.
template<class T> T get_kth_count (const T *arr, int n, int idx) {
    //indices are int like n, so a non-positive n simply skips the loops
    for (int i = 0; i < n; i++) {
        const T x = arr[i];
        //count number of "less" and "equal" elements
        int cntLess = 0, cntEq = 0;
        for (int j = 0; j < n; j++) {
            cntLess += arr[j] < x;
            cntEq += arr[j] == x;
        }
        //fast range checking from here: http://stackoverflow.com/a/17095534/556899
        //(a negative difference wraps to a huge unsigned value and fails the test)
        if ((unsigned int)(idx - cntLess) < (unsigned int)cntEq)
            return x;
    }
    assert(0);
    return -1;
}
- static FORCEINLINE int reduce_sum_v16(__m128i vals) {
- vals = _mm_hadd_epi16(vals, vals);
- vals = _mm_hadd_epi16(vals, vals);
- vals = _mm_hadd_epi16(vals, vals);
- return _mm_extract_epi16(vals, 0);
- }
- //vectorized counting selection for 16-bit numbers
- int16_t get_kth_count_v16 (const int16_t *arr, int n, int idx) {
- for (size_t i = 0; i < n; i++) {
- auto x = arr[i];
- __m128i xx = _mm_set1_epi16(x);
- __m128i cntLess = _mm_setzero_si128();
- __m128i cntEq = _mm_setzero_si128();
- for (size_t j = 0; j < n; j+=8) {
- __m128i vals = _mm_loadu_si128((__m128i*)&arr[j]);
- cntLess = _mm_sub_epi16(cntLess, _mm_cmplt_epi16(vals, xx));
- cntEq = _mm_sub_epi16(cntEq, _mm_cmpeq_epi16(vals, xx));
- }
- int ncLess = reduce_sum_v16(cntLess);
- int ncEq = reduce_sum_v16(cntEq);
- if ((unsigned int)(idx - ncLess) < ncEq)
- return x;
- }
- assert(0);
- return -1;
- }
- //vectorized counting selection for 24 (or less) 16-bit numbers
- FORCEINLINE int16_t get_kth_count_v16_n24 (const int16_t *arr, int n, int idx) {
- assert(n <= 24);
- __m128i a = _mm_loadu_si128((__m128i*)&arr[0]);
- __m128i b = _mm_loadu_si128((__m128i*)&arr[8]);
- __m128i c = _mm_loadu_si128((__m128i*)&arr[16]);
- for (size_t i = 0; i < n; i++) {
- auto x = arr[i];
- __m128i xx = _mm_set1_epi16(x);
- __m128i cntLess = _mm_cmplt_epi16(a, xx);
- __m128i cntEq = _mm_cmpeq_epi16(a, xx);
- cntLess = _mm_add_epi16(cntLess, _mm_cmplt_epi16(b, xx));
- cntEq = _mm_add_epi16(cntEq, _mm_cmpeq_epi16(b, xx));
- cntLess = _mm_add_epi16(cntLess, _mm_cmplt_epi16(c, xx));
- cntEq = _mm_add_epi16(cntEq, _mm_cmpeq_epi16(c, xx));
- int ncLess = reduce_sum_v16(cntLess);
- int ncEq = reduce_sum_v16(cntEq);
- if ((unsigned short)(idx + ncLess) < -(short)ncEq)
- return x;
- }
- ASSUME(false);
- }
//vectorized counting selection, transposed
//Returns the value that would sit at index idx if arr[0..n) were sorted ascending,
//without modifying arr. Eight candidates are examined at once: each element arr[j]
//is broadcast and compared against a register holding the eight candidates, so the
//lanes hold per-candidate "less" / "equal" counts.
//Candidates are loaded in groups of 8, i.e. up to the next multiple of 8 past n, so
//the buffer must be readable there; the benchmark driver pads inputs with max-values.
//Expects 0 <= idx < n; otherwise assert(0) fires in debug builds and -1 is returned.
int16_t get_kth_count_v16t (const int16_t *arr, int n, int idx) {
    const __m128i idxV = _mm_set1_epi16(idx);
    //indices are int like n, so a non-positive n simply skips the loops
    for (int i = 0; i < n; i += 8) {
        const __m128i xx = _mm_loadu_si128((const __m128i*)&arr[i]);
        __m128i cntLess = _mm_setzero_si128();
        __m128i cntEq = _mm_setzero_si128();
        for (int j = 0; j < n; j++) {
            const __m128i vAll = _mm_set1_epi16(arr[j]);
            //a true comparison yields -1 in its lane, so subtracting adds one to the count
            cntLess = _mm_sub_epi16(cntLess, _mm_cmplt_epi16(vAll, xx));
            cntEq = _mm_sub_epi16(cntEq, _mm_cmpeq_epi16(vAll, xx));
        }
        //lane is set iff !(idx < cntLess) && idx < cntLess + cntEq
        const __m128i mask = _mm_andnot_si128(_mm_cmplt_epi16(idxV, cntLess), _mm_cmplt_epi16(idxV, _mm_add_epi16(cntLess, cntEq)));
        if (int bm = _mm_movemask_epi8(mask)) {
            //movemask yields one bit per byte, so 16-bit lane t owns bits 2t and 2t+1
            for (int t = 0; t < 8; t++)
                if (bm & (1 << (2*t)))
                    return arr[i + t];
        }
    }
    assert(0);
    return -1;
}
//vectorized counting selection, vectorization along both directions
//Returns the value that would sit at index idx if arr[0..n) were sorted ascending,
//without modifying arr. Eight candidates are examined at once (outer loop), and the
//compared elements are also loaded eight at a time (inner loop) and then broadcast
//one by one with unpack shuffles.
//Both loops read whole groups of 8, i.e. up to the next multiple of 8 past n, so the
//buffer must be readable there; the benchmark driver pads inputs with max-values so
//that the padding is never counted as "less".
//Expects 0 <= idx < n; otherwise assert(0) fires in debug builds and -1 is returned.
int16_t get_kth_count_v16both (const int16_t *arr, int n, int idx) {
    const __m128i idxV = _mm_set1_epi16(idx);
    //indices are int like n, so a non-positive n simply skips the loops
    for (int i = 0; i < n; i += 8) {
        const __m128i xx = _mm_loadu_si128((const __m128i*)&arr[i]);
        __m128i cntLess = _mm_setzero_si128();
        __m128i cntEq = _mm_setzero_si128();
        for (int j = 0; j < n; j += 8) {
            const __m128i vj = _mm_loadu_si128((const __m128i*)&arr[j]);
            //each unpack of a register with itself doubles the run length of every
            //element: 1 -> 2 -> 4 -> 8 copies
            const __m128i a0 = _mm_unpacklo_epi16(vj, vj);
            const __m128i a1 = _mm_unpackhi_epi16(vj, vj);
            const __m128i b0 = _mm_unpacklo_epi32(a0, a0);
            const __m128i b1 = _mm_unpackhi_epi32(a0, a0);
            const __m128i b2 = _mm_unpacklo_epi32(a1, a1);
            const __m128i b3 = _mm_unpackhi_epi32(a1, a1);
            //vAll[k] holds element k of vj in all eight lanes
            const __m128i vAll[8] = {
                _mm_unpacklo_epi32(b0, b0), _mm_unpackhi_epi32(b0, b0),
                _mm_unpacklo_epi32(b1, b1), _mm_unpackhi_epi32(b1, b1),
                _mm_unpacklo_epi32(b2, b2), _mm_unpackhi_epi32(b2, b2),
                _mm_unpacklo_epi32(b3, b3), _mm_unpackhi_epi32(b3, b3),
            };
            for (int k = 0; k < 8; k++) {
                //a true comparison yields -1 in its lane, so subtracting adds one
                cntLess = _mm_sub_epi16(cntLess, _mm_cmplt_epi16(vAll[k], xx));
                cntEq = _mm_sub_epi16(cntEq, _mm_cmpeq_epi16(vAll[k], xx));
            }
        }
        //lane is set iff !(idx < cntLess) && idx < cntLess + cntEq
        const __m128i mask = _mm_andnot_si128(_mm_cmplt_epi16(idxV, cntLess), _mm_cmplt_epi16(idxV, _mm_add_epi16(cntLess, cntEq)));
        if (int bm = _mm_movemask_epi8(mask)) {
            //movemask yields one bit per byte, so 16-bit lane t owns bits 2t and 2t+1
            for (int t = 0; t < 8; t++)
                if (bm & (1 << (2*t)))
                    return arr[i + t];
        }
    }
    assert(0);
    return -1;
}
//answer by Morwenn, see http://stackoverflow.com/a/33319374/556899
//updated, now it does not fully sort the array
//Applies a fixed sequence of compare-exchange steps to first[0..23) (destructive)
//and returns first[idx]. n must be 23.
//NOTE(review): since the array is not fully sorted, presumably only the median
//position (idx == 11, which is what main() requests) is meant to be queried - confirm.
//The non-Windows exchange uses XOR, so it only works for integer element types.
template<class T> T get_kth_network_sort (T *first, int n, int idx) {
    assert(n == 23);
#ifdef _WIN32
    //that's much better on MSVC
    #define ORDER_PAIR(i, j) CMP_SWAP(first[i], first[j])
#else
    //original proposal, better on GCC:
    //first[i] receives the minimum; XOR-ing first[j] with (old first[i] ^ minimum)
    //leaves it unchanged when the pair was in order and turns it into old first[i]
    //otherwise
    #define ORDER_PAIR(i, j) {\
        auto wasLo = first[i];\
        auto wasHi = first[j];\
        auto nowLo = first[i] = std::min(wasLo, wasHi);\
        first[j] ^= wasLo ^ nowLo;\
    }
#endif
    //--- elements 0..11 ---
    ORDER_PAIR(0, 1);
    ORDER_PAIR(2, 3);
    ORDER_PAIR(4, 5);
    ORDER_PAIR(6, 7);
    ORDER_PAIR(8, 9);
    ORDER_PAIR(10, 11);
    ORDER_PAIR(1, 3);
    ORDER_PAIR(5, 7);
    ORDER_PAIR(9, 11);
    ORDER_PAIR(0, 2);
    ORDER_PAIR(4, 6);
    ORDER_PAIR(8, 10);
    ORDER_PAIR(1, 2);
    ORDER_PAIR(5, 6);
    ORDER_PAIR(9, 10);
    ORDER_PAIR(1, 5);
    ORDER_PAIR(6, 10);
    ORDER_PAIR(5, 9);
    ORDER_PAIR(2, 6);
    ORDER_PAIR(1, 5);
    ORDER_PAIR(6, 10);
    ORDER_PAIR(0, 4);
    ORDER_PAIR(7, 11);
    ORDER_PAIR(3, 7);
    ORDER_PAIR(4, 8);
    ORDER_PAIR(0, 4);
    ORDER_PAIR(7, 11);
    ORDER_PAIR(1, 4);
    ORDER_PAIR(7, 10);
    ORDER_PAIR(3, 8);
    ORDER_PAIR(2, 3);
    ORDER_PAIR(8, 9);
    ORDER_PAIR(2, 4);
    ORDER_PAIR(7, 9);
    ORDER_PAIR(3, 5);
    ORDER_PAIR(6, 8);
    ORDER_PAIR(3, 4);
    ORDER_PAIR(5, 6);
    ORDER_PAIR(7, 8);
    //--- elements 12..22 ---
    ORDER_PAIR(12, 13);
    ORDER_PAIR(14, 15);
    ORDER_PAIR(16, 17);
    ORDER_PAIR(18, 19);
    ORDER_PAIR(20, 21);
    ORDER_PAIR(13, 15);
    ORDER_PAIR(17, 19);
    ORDER_PAIR(12, 14);
    ORDER_PAIR(16, 18);
    ORDER_PAIR(20, 22);
    ORDER_PAIR(13, 14);
    ORDER_PAIR(17, 18);
    ORDER_PAIR(21, 22);
    ORDER_PAIR(13, 17);
    ORDER_PAIR(18, 22);
    ORDER_PAIR(17, 21);
    ORDER_PAIR(14, 18);
    ORDER_PAIR(13, 17);
    ORDER_PAIR(18, 22);
    ORDER_PAIR(12, 16);
    ORDER_PAIR(15, 19);
    ORDER_PAIR(16, 20);
    ORDER_PAIR(12, 16);
    ORDER_PAIR(13, 16);
    ORDER_PAIR(19, 22);
    ORDER_PAIR(15, 20);
    ORDER_PAIR(14, 15);
    ORDER_PAIR(20, 21);
    ORDER_PAIR(14, 16);
    ORDER_PAIR(19, 21);
    ORDER_PAIR(15, 17);
    ORDER_PAIR(18, 20);
    ORDER_PAIR(15, 16);
    ORDER_PAIR(17, 18);
    ORDER_PAIR(19, 20);
    //--- exchanges between the two groups ---
    ORDER_PAIR(0, 12);
    ORDER_PAIR(2, 14);
    ORDER_PAIR(4, 16);
    ORDER_PAIR(6, 18);
    ORDER_PAIR(8, 20);
    ORDER_PAIR(10, 22);
    ORDER_PAIR(2, 12);
    ORDER_PAIR(10, 20);
    ORDER_PAIR(4, 12);
    ORDER_PAIR(6, 14);
    ORDER_PAIR(8, 16);
    ORDER_PAIR(10, 18);
    ORDER_PAIR(8, 12);
    ORDER_PAIR(10, 14);
    ORDER_PAIR(10, 12);
    ORDER_PAIR(1, 13);
    ORDER_PAIR(3, 15);
    ORDER_PAIR(5, 17);
    ORDER_PAIR(7, 19);
    ORDER_PAIR(9, 21);
    ORDER_PAIR(3, 13);
    ORDER_PAIR(11, 21);
    ORDER_PAIR(5, 13);
    ORDER_PAIR(7, 15);
    ORDER_PAIR(9, 17);
    ORDER_PAIR(11, 19);
    ORDER_PAIR(9, 13);
    ORDER_PAIR(11, 15);
    ORDER_PAIR(11, 13);
    ORDER_PAIR(11, 12);
#undef ORDER_PAIR
    return first[idx];
}
//Testing code below

//type of elements
using Elem = short;

//padded size of array
static constexpr int SIZE = 32;
//number of actual input elements
static constexpr int COUNT = 23;
//sorted index of element being searched
static constexpr int IDX = COUNT/2;
//number of timed passes over the whole sample set
static constexpr int CYCLES = 1<<13;
//number of input arrays
static constexpr int SAMPLES = 1<<10;

//input holds the generated arrays; work is the scratch copy handed to the
//destructive algorithms
Elem input[SAMPLES][SIZE];
Elem work[SAMPLES][SIZE];
- int main() {
- for (int i = 0; i < SAMPLES; i++) {
- for (int j = 0; j < COUNT; j++)
- input[i][j] = Elem(rand() & 0x7FFF);
- //note: input arrays must be padded with max-values
- for (int j = COUNT; j < SIZE; j++)
- input[i][j] = std::numeric_limits<Elem>::max();
- }
- /* for (int i = 0; i < SAMPLES; i++) {
- Elem tmp[SIZE];
- memcpy(tmp, input[i], sizeof(tmp));
- int ans = get_kth_stl(tmp, COUNT, IDX);
- memcpy(tmp, input[i], sizeof(tmp));
- int res1 = get_kth_count_v16(tmp, COUNT, IDX);
- if (ans != res1) {
- printf("error: %d, %d\n", ans, res1);
- for (int j = 0; j < COUNT; j++)
- printf("%d ", int(input[i][j]));
- printf("\n");
- }
- }*/
- double memcpyTime;
- {
- int start = clock();
- for (int t = 0; t < CYCLES; t++)
- memcpy(work, input, sizeof(work));
- double elapsed = double(clock() - start) / CLOCKS_PER_SEC;
- printf("memcpy only: %0.3lf\n", elapsed);
- memcpyTime = elapsed;
- }
- {
- int start = clock();
- int check = 0;
- for (int t = 0; t < CYCLES; t++) {
- memcpy(work, input, sizeof(work));
- for (int i = 0; i < SAMPLES; i++)
- check += get_kth_stl(work[i], COUNT, IDX);
- }
- double elapsed = double(clock() - start) / CLOCKS_PER_SEC;
- printf("std::nth_element: %0.3lf (%d)\n", elapsed - memcpyTime, check);
- }
- {
- int start = clock();
- int check = 0;
- for (int t = 0; t < CYCLES; t++) {
- memcpy(work, input, sizeof(work));
- for (int i = 0; i < SAMPLES; i++)
- check += get_kth_bubble_sort(work[i], COUNT, IDX);
- }
- double elapsed = double(clock() - start) / CLOCKS_PER_SEC;
- printf("bubble sort: %0.3lf (%d)\n", elapsed - memcpyTime, check);
- }
- {
- int start = clock();
- int check = 0;
- for (int t = 0; t < CYCLES; t++) {
- memcpy(work, input, sizeof(work));
- for (int i = 0; i < SAMPLES; i++)
- check += get_kth_selection_sort(work[i], COUNT, IDX);
- }
- double elapsed = double(clock() - start) / CLOCKS_PER_SEC;
- printf("selection sort: %0.3lf (%d)\n", elapsed - memcpyTime, check);
- }
- {
- int start = clock();
- int check = 0;
- for (int t = 0; t < CYCLES; t++) {
- memcpy(work, input, sizeof(work));
- for (int i = 0; i < SAMPLES; i++)
- check += get_kth_network_sort(work[i], COUNT, IDX);
- }
- double elapsed = double(clock() - start) / CLOCKS_PER_SEC;
- printf("network sort: %0.3lf (%d)\n", elapsed - memcpyTime, check);
- }
- {
- int start = clock();
- int check = 0;
- for (int t = 0; t < CYCLES; t++)
- for (int i = 0; i < SAMPLES; i++)
- check += get_kth_count(input[i], COUNT, IDX);
- double elapsed = double(clock() - start) / CLOCKS_PER_SEC;
- printf("trivial count: %0.3lf (%d)\n", elapsed, check);
- }
- {
- int start = clock();
- int check = 0;
- for (int t = 0; t < CYCLES; t++)
- for (int i = 0; i < SAMPLES; i++)
- check += get_kth_count_v16(input[i], COUNT, IDX);
- double elapsed = double(clock() - start) / CLOCKS_PER_SEC;
- printf("vectorized count: %0.3lf (%d)\n", elapsed, check);
- }
- {
- int start = clock();
- int check = 0;
- for (int t = 0; t < CYCLES; t++)
- for (int i = 0; i < SAMPLES; i++)
- check += get_kth_count_v16_n24(input[i], COUNT, IDX);
- double elapsed = double(clock() - start) / CLOCKS_PER_SEC;
- printf("vectorized count (n<=24): %0.3lf (%d)\n", elapsed, check);
- }
- {
- int start = clock();
- int check = 0;
- for (int t = 0; t < CYCLES; t++)
- for (int i = 0; i < SAMPLES; i++)
- check += get_kth_count_v16t(input[i], COUNT, IDX);
- double elapsed = double(clock() - start) / CLOCKS_PER_SEC;
- printf("vectorized count (T): %0.3lf (%d)\n", elapsed, check);
- }
- {
- int start = clock();
- int check = 0;
- for (int t = 0; t < CYCLES; t++)
- for (int i = 0; i < SAMPLES; i++)
- check += get_kth_count_v16both(input[i], COUNT, IDX);
- double elapsed = double(clock() - start) / CLOCKS_PER_SEC;
- printf("vectorized count (both): %0.3lf (%d)\n", elapsed, check);
- }
- return 0;
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement