Advertisement
stgatilov

Vectorized k-th element (median) selection for small arrays

Oct 24th, 2015
521
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 17.24 KB | None | 0 0
#include <assert.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>

#include <algorithm>
#include <limits>

#include <smmintrin.h>
  8.  
  9. #ifdef _MSC_VER
  10.     #define FORCEINLINE __forceinline
  11.     #define NOINLINE __declspec(noinline)
  12.     #define ASSUME(cond) __assume(cond)
  13. #else
  14.     #define FORCEINLINE __attribute__((always_inline)) inline
  15.     #define NOINLINE __attribute__((noinline))
  16.     #define ASSUME(cond) if (!(cond)) __builtin_unreachable()
  17. #endif
  18.  
  19. //from http://stackoverflow.com/q/2786899/556899
  20. #define MIN(x, y) (x < y ? x : y)
  21. #define MAX(x, y) (x < y ? y : x)
  22. #define CMP_SWAP(x, y) {\
  23.     auto a = MIN(x, y);\
  24.     auto b = MAX(x, y);\
  25.     x = a;\
  26.     y = b;\
  27. }
  28.  
  29. //implementation by STL (note: destructive!)
  30. template<class T> T get_kth_stl (T *arr, int n, int idx) {
  31.     std::nth_element(arr, arr+idx, arr+n);
  32.     return arr[idx];
  33. }
  34.  
  35. //partial sort by bubble sort or something similar
  36. template<class T> T get_kth_bubble_sort (T *arr, size_t n, size_t idx) {
  37.     for (size_t i = 0; i <= idx; i++)
  38.         for (size_t j = i+1; j < n; j++)
  39.             CMP_SWAP(arr[i], arr[j]);
  40.     return arr[idx];
  41. }
  42.  
  43. //partial sort by selection sort
  44. template<class T> T get_kth_selection_sort (T *arr, size_t n, size_t idx) {
  45.     for (size_t i = 0; i <= idx; i++) {
  46.         T minVal = arr[i];
  47.         size_t minIdx = i;
  48.         for (size_t j = i+1; j < n; j++)
  49.             if (minVal > arr[j]) {
  50.                 minVal = arr[j];
  51.                 minIdx = j;
  52.             }
  53.         std::swap(arr[i], arr[minIdx]);
  54.     }
  55.     return arr[idx];
  56. }
  57.  
  58. //counting selection
  59. template<class T> T get_kth_count (const T *arr, int n, int idx) {
  60.     for (size_t i = 0; i < n; i++) {
  61.         auto x = arr[i];
  62.         //count number of "less" and "equal" elements
  63.         int cntLess = 0, cntEq = 0;
  64.         for (size_t j = 0; j < n; j++) {
  65.             cntLess += arr[j] < x;
  66.             cntEq += arr[j] == x;
  67.         }
  68.         //fast range checking from here: http://stackoverflow.com/a/17095534/556899
  69.         if ((unsigned int)(idx - cntLess) < cntEq)
  70.             return x;
  71.     }
  72.     assert(0);
  73.     return -1;
  74. }
  75.  
  76. static FORCEINLINE int reduce_sum_v16(__m128i vals) {
  77.     vals = _mm_hadd_epi16(vals, vals);
  78.     vals = _mm_hadd_epi16(vals, vals);
  79.     vals = _mm_hadd_epi16(vals, vals);
  80.     return _mm_extract_epi16(vals, 0);
  81. }
  82.  
  83. //vectorized counting selection for 16-bit numbers
  84. int16_t get_kth_count_v16 (const int16_t *arr, int n, int idx) {
  85.     for (size_t i = 0; i < n; i++) {
  86.         auto x = arr[i];
  87.         __m128i xx = _mm_set1_epi16(x);
  88.         __m128i cntLess = _mm_setzero_si128();
  89.         __m128i cntEq = _mm_setzero_si128();
  90.         for (size_t j = 0; j < n; j+=8) {
  91.             __m128i vals = _mm_loadu_si128((__m128i*)&arr[j]);
  92.             cntLess = _mm_sub_epi16(cntLess, _mm_cmplt_epi16(vals, xx));
  93.             cntEq = _mm_sub_epi16(cntEq, _mm_cmpeq_epi16(vals, xx));
  94.         }
  95.         int ncLess = reduce_sum_v16(cntLess);
  96.         int ncEq = reduce_sum_v16(cntEq);
  97.         if ((unsigned int)(idx - ncLess) < ncEq)
  98.             return x;
  99.     }
  100.     assert(0);
  101.     return -1;
  102. }
  103.  
  104. //vectorized counting selection for 24 (or less) 16-bit numbers
  105. FORCEINLINE int16_t get_kth_count_v16_n24 (const int16_t *arr, int n, int idx) {
  106.     assert(n <= 24);
  107.     __m128i a = _mm_loadu_si128((__m128i*)&arr[0]);
  108.     __m128i b = _mm_loadu_si128((__m128i*)&arr[8]);
  109.     __m128i c = _mm_loadu_si128((__m128i*)&arr[16]);
  110.  
  111.     for (size_t i = 0; i < n; i++) {
  112.         auto x = arr[i];
  113.         __m128i xx = _mm_set1_epi16(x);
  114.  
  115.         __m128i cntLess = _mm_cmplt_epi16(a, xx);
  116.         __m128i cntEq = _mm_cmpeq_epi16(a, xx);
  117.         cntLess = _mm_add_epi16(cntLess, _mm_cmplt_epi16(b, xx));
  118.         cntEq = _mm_add_epi16(cntEq, _mm_cmpeq_epi16(b, xx));
  119.         cntLess = _mm_add_epi16(cntLess, _mm_cmplt_epi16(c, xx));
  120.         cntEq = _mm_add_epi16(cntEq, _mm_cmpeq_epi16(c, xx));
  121.  
  122.         int ncLess = reduce_sum_v16(cntLess);
  123.         int ncEq = reduce_sum_v16(cntEq);
  124.         if ((unsigned short)(idx + ncLess) < -(short)ncEq)
  125.             return x;
  126.     }
  127.     ASSUME(false);
  128. }
  129.  
  130. //vectorized counting selection, transposed
  131. int16_t get_kth_count_v16t (const int16_t *arr, int n, int idx) {
  132.     __m128i idxV = _mm_set1_epi16(idx);
  133.     for (size_t i = 0; i < n; i += 8) {
  134.         auto xx = _mm_loadu_si128((__m128i*)&arr[i]);
  135.  
  136.         __m128i cntLess = _mm_setzero_si128();
  137.         __m128i cntEq = _mm_setzero_si128();
  138.         for (size_t j = 0; j < n; j++) {
  139.             __m128i vAll = _mm_set1_epi16(arr[j]);
  140.             cntLess = _mm_sub_epi16(cntLess, _mm_cmplt_epi16(vAll, xx));
  141.             cntEq = _mm_sub_epi16(cntEq, _mm_cmpeq_epi16(vAll, xx));
  142.         }
  143.  
  144.         __m128i mask = _mm_andnot_si128(_mm_cmplt_epi16(idxV, cntLess), _mm_cmplt_epi16(idxV, _mm_add_epi16(cntLess, cntEq)));
  145.         if (int bm = _mm_movemask_epi8(mask)) {
  146.             for (int t = 0; t < 8; t++)
  147.                 if (bm & (1 << (2*t)))
  148.                     return arr[i + t];
  149.         }
  150.     }
  151.     assert(0);
  152.     return -1;
  153. }
  154.  
  155. //vectorized counting selection, vectorization along both directions
  156. int16_t get_kth_count_v16both (const int16_t *arr, int n, int idx) {
  157.     __m128i idxV = _mm_set1_epi16(idx);
  158.     for (size_t i = 0; i < n; i += 8) {
  159.         auto xx = _mm_loadu_si128((__m128i*)&arr[i]);
  160.  
  161.         __m128i cntLess = _mm_setzero_si128();
  162.         __m128i cntEq = _mm_setzero_si128();
  163.         for (size_t j = 0; j < n; j += 8) {
  164.             __m128i vj = _mm_loadu_si128((__m128i*)&arr[j]);
  165.  
  166.             __m128i a0 = _mm_unpacklo_epi16(vj, vj);
  167.             __m128i a1 = _mm_unpackhi_epi16(vj, vj);
  168.             __m128i b0 = _mm_unpacklo_epi32(a0, a0);
  169.             __m128i b1 = _mm_unpackhi_epi32(a0, a0);
  170.             __m128i b2 = _mm_unpacklo_epi32(a1, a1);
  171.             __m128i b3 = _mm_unpackhi_epi32(a1, a1);
  172.             __m128i vAll0 = _mm_unpacklo_epi32(b0, b0);
  173.             __m128i vAll1 = _mm_unpackhi_epi32(b0, b0);
  174.             __m128i vAll2 = _mm_unpacklo_epi32(b1, b1);
  175.             __m128i vAll3 = _mm_unpackhi_epi32(b1, b1);
  176.             __m128i vAll4 = _mm_unpacklo_epi32(b2, b2);
  177.             __m128i vAll5 = _mm_unpackhi_epi32(b2, b2);
  178.             __m128i vAll6 = _mm_unpacklo_epi32(b3, b3);
  179.             __m128i vAll7 = _mm_unpackhi_epi32(b3, b3);
  180.  
  181.             #define DOIT(k) \
  182.                 cntLess = _mm_sub_epi16(cntLess, _mm_cmplt_epi16(vAll##k, xx)); \
  183.                 cntEq = _mm_sub_epi16(cntEq, _mm_cmpeq_epi16(vAll##k, xx));
  184.             DOIT(0);
  185.             DOIT(1);
  186.             DOIT(2);
  187.             DOIT(3);
  188.             DOIT(4);
  189.             DOIT(5);
  190.             DOIT(6);
  191.             DOIT(7);
  192.             #undef DOIT
  193.         }
  194.  
  195.         __m128i mask = _mm_andnot_si128(_mm_cmplt_epi16(idxV, cntLess), _mm_cmplt_epi16(idxV, _mm_add_epi16(cntLess, cntEq)));
  196.         if (int bm = _mm_movemask_epi8(mask)) {
  197.             for (int t = 0; t < 8; t++)
  198.                 if (bm & (1 << (2*t)))
  199.                     return arr[i + t];
  200.         }
  201.     }
  202.     assert(0);
  203.     return -1;
  204. }
  205.  
  206.  
  207. //answer by Morwenn, see http://stackoverflow.com/a/33319374/556899
  208. //updated, now it does not fully sort the array
  209. template<class T> T get_kth_network_sort (T *first, int n, int idx) {
  210.     assert(n == 23);
  211. #ifdef _WIN32
  212.   //that's much better on MSVC
  213.     #define swap_if(x, y, c) CMP_SWAP(x, y)
  214. #else
  215.   //original proposal, better on GCC
  216.     #define swap_if(x, y, c) {\
  217.         auto dx = x;\
  218.         auto dy = y;\
  219.         auto tmp = x = std::min(dx, dy);\
  220.         y ^= dx ^ tmp;\
  221.     }
  222. #endif
  223.     swap_if(first[0u], first[1u], compare);
  224.     swap_if(first[2u], first[3u], compare);
  225.     swap_if(first[4u], first[5u], compare);
  226.     swap_if(first[6u], first[7u], compare);
  227.     swap_if(first[8u], first[9u], compare);
  228.     swap_if(first[10u], first[11u], compare);
  229.     swap_if(first[1u], first[3u], compare);
  230.     swap_if(first[5u], first[7u], compare);
  231.     swap_if(first[9u], first[11u], compare);
  232.     swap_if(first[0u], first[2u], compare);
  233.     swap_if(first[4u], first[6u], compare);
  234.     swap_if(first[8u], first[10u], compare);
  235.     swap_if(first[1u], first[2u], compare);
  236.     swap_if(first[5u], first[6u], compare);
  237.     swap_if(first[9u], first[10u], compare);
  238.     swap_if(first[1u], first[5u], compare);
  239.     swap_if(first[6u], first[10u], compare);
  240.     swap_if(first[5u], first[9u], compare);
  241.     swap_if(first[2u], first[6u], compare);
  242.     swap_if(first[1u], first[5u], compare);
  243.     swap_if(first[6u], first[10u], compare);
  244.     swap_if(first[0u], first[4u], compare);
  245.     swap_if(first[7u], first[11u], compare);
  246.     swap_if(first[3u], first[7u], compare);
  247.     swap_if(first[4u], first[8u], compare);
  248.     swap_if(first[0u], first[4u], compare);
  249.     swap_if(first[7u], first[11u], compare);
  250.     swap_if(first[1u], first[4u], compare);
  251.     swap_if(first[7u], first[10u], compare);
  252.     swap_if(first[3u], first[8u], compare);
  253.     swap_if(first[2u], first[3u], compare);
  254.     swap_if(first[8u], first[9u], compare);
  255.     swap_if(first[2u], first[4u], compare);
  256.     swap_if(first[7u], first[9u], compare);
  257.     swap_if(first[3u], first[5u], compare);
  258.     swap_if(first[6u], first[8u], compare);
  259.     swap_if(first[3u], first[4u], compare);
  260.     swap_if(first[5u], first[6u], compare);
  261.     swap_if(first[7u], first[8u], compare);
  262.     swap_if(first[12u], first[13u], compare);
  263.     swap_if(first[14u], first[15u], compare);
  264.     swap_if(first[16u], first[17u], compare);
  265.     swap_if(first[18u], first[19u], compare);
  266.     swap_if(first[20u], first[21u], compare);
  267.     swap_if(first[13u], first[15u], compare);
  268.     swap_if(first[17u], first[19u], compare);
  269.     swap_if(first[12u], first[14u], compare);
  270.     swap_if(first[16u], first[18u], compare);
  271.     swap_if(first[20u], first[22u], compare);
  272.     swap_if(first[13u], first[14u], compare);
  273.     swap_if(first[17u], first[18u], compare);
  274.     swap_if(first[21u], first[22u], compare);
  275.     swap_if(first[13u], first[17u], compare);
  276.     swap_if(first[18u], first[22u], compare);
  277.     swap_if(first[17u], first[21u], compare);
  278.     swap_if(first[14u], first[18u], compare);
  279.     swap_if(first[13u], first[17u], compare);
  280.     swap_if(first[18u], first[22u], compare);
  281.     swap_if(first[12u], first[16u], compare);
  282.     swap_if(first[15u], first[19u], compare);
  283.     swap_if(first[16u], first[20u], compare);
  284.     swap_if(first[12u], first[16u], compare);
  285.     swap_if(first[13u], first[16u], compare);
  286.     swap_if(first[19u], first[22u], compare);
  287.     swap_if(first[15u], first[20u], compare);
  288.     swap_if(first[14u], first[15u], compare);
  289.     swap_if(first[20u], first[21u], compare);
  290.     swap_if(first[14u], first[16u], compare);
  291.     swap_if(first[19u], first[21u], compare);
  292.     swap_if(first[15u], first[17u], compare);
  293.     swap_if(first[18u], first[20u], compare);
  294.     swap_if(first[15u], first[16u], compare);
  295.     swap_if(first[17u], first[18u], compare);
  296.     swap_if(first[19u], first[20u], compare);
  297.     swap_if(first[0u], first[12u], compare);
  298.     swap_if(first[2u], first[14u], compare);
  299.     swap_if(first[4u], first[16u], compare);
  300.     swap_if(first[6u], first[18u], compare);
  301.     swap_if(first[8u], first[20u], compare);
  302.     swap_if(first[10u], first[22u], compare);
  303.     swap_if(first[2u], first[12u], compare);
  304.     swap_if(first[10u], first[20u], compare);
  305.     swap_if(first[4u], first[12u], compare);
  306.     swap_if(first[6u], first[14u], compare);
  307.     swap_if(first[8u], first[16u], compare);
  308.     swap_if(first[10u], first[18u], compare);
  309.     swap_if(first[8u], first[12u], compare);
  310.     swap_if(first[10u], first[14u], compare);
  311.     swap_if(first[10u], first[12u], compare);
  312.     swap_if(first[1u], first[13u], compare);
  313.     swap_if(first[3u], first[15u], compare);
  314.     swap_if(first[5u], first[17u], compare);
  315.     swap_if(first[7u], first[19u], compare);
  316.     swap_if(first[9u], first[21u], compare);
  317.     swap_if(first[3u], first[13u], compare);
  318.     swap_if(first[11u], first[21u], compare);
  319.     swap_if(first[5u], first[13u], compare);
  320.     swap_if(first[7u], first[15u], compare);
  321.     swap_if(first[9u], first[17u], compare);
  322.     swap_if(first[11u], first[19u], compare);
  323.     swap_if(first[9u], first[13u], compare);
  324.     swap_if(first[11u], first[15u], compare);
  325.     swap_if(first[11u], first[13u], compare);
  326.     swap_if(first[11u], first[12u], compare);
  327.     #undef swap_if
  328.     return first[idx];
  329. }
  330.  
  331.  
  332. //Testing code below
  333.  
  334. typedef short Elem;                 //type of elements
  335. static const int SIZE = 32;         //padded size of array
  336. static const int COUNT = 23;        //number of actual input elements
  337. static const int IDX = COUNT/2;     //sorted index of element being searched
  338.  
  339. static const int CYCLES = 1<<13;
  340. static const int SAMPLES = 1<<10;
  341. Elem input[SAMPLES][SIZE];
  342. Elem work[SAMPLES][SIZE];
  343.  
  344. int main() {
  345.     for (int i = 0; i < SAMPLES; i++) {
  346.         for (int j = 0; j < COUNT; j++)
  347.             input[i][j] = Elem(rand() & 0x7FFF);
  348.         //note: input arrays must be padded with max-values
  349.         for (int j = COUNT; j < SIZE; j++)
  350.             input[i][j] = std::numeric_limits<Elem>::max();
  351.     }
  352.  
  353. /*  for (int i = 0; i < SAMPLES; i++) {
  354.         Elem tmp[SIZE];
  355.  
  356.         memcpy(tmp, input[i], sizeof(tmp));
  357.         int ans = get_kth_stl(tmp, COUNT, IDX);
  358.  
  359.         memcpy(tmp, input[i], sizeof(tmp));
  360.         int res1 = get_kth_count_v16(tmp, COUNT, IDX);
  361.        
  362.         if (ans != res1) {
  363.             printf("error: %d, %d\n", ans, res1);
  364.             for (int j = 0; j < COUNT; j++)
  365.                 printf("%d ", int(input[i][j]));
  366.             printf("\n");
  367.         }
  368.     }*/
  369.  
  370.  
  371.     double memcpyTime;
  372.     {
  373.         int start = clock();
  374.         for (int t = 0; t < CYCLES; t++)
  375.             memcpy(work, input, sizeof(work));
  376.         double elapsed = double(clock() - start) / CLOCKS_PER_SEC;
  377.         printf("memcpy only: %0.3lf\n", elapsed);
  378.         memcpyTime = elapsed;
  379.     }
  380.  
  381.  
  382.     {
  383.         int start = clock();
  384.         int check = 0;
  385.         for (int t = 0; t < CYCLES; t++) {
  386.             memcpy(work, input, sizeof(work));
  387.             for (int i = 0; i < SAMPLES; i++)
  388.                 check += get_kth_stl(work[i], COUNT, IDX);
  389.         }
  390.         double elapsed = double(clock() - start) / CLOCKS_PER_SEC;
  391.         printf("std::nth_element: %0.3lf (%d)\n", elapsed - memcpyTime, check);
  392.     }
  393.  
  394.     {
  395.         int start = clock();
  396.         int check = 0;
  397.         for (int t = 0; t < CYCLES; t++) {
  398.             memcpy(work, input, sizeof(work));
  399.             for (int i = 0; i < SAMPLES; i++)
  400.                 check += get_kth_bubble_sort(work[i], COUNT, IDX);
  401.         }
  402.         double elapsed = double(clock() - start) / CLOCKS_PER_SEC;
  403.         printf("bubble sort: %0.3lf (%d)\n", elapsed - memcpyTime, check);
  404.     }
  405.  
  406.     {
  407.         int start = clock();
  408.         int check = 0;
  409.         for (int t = 0; t < CYCLES; t++) {
  410.             memcpy(work, input, sizeof(work));
  411.             for (int i = 0; i < SAMPLES; i++)
  412.                 check += get_kth_selection_sort(work[i], COUNT, IDX);
  413.         }
  414.         double elapsed = double(clock() - start) / CLOCKS_PER_SEC;
  415.         printf("selection sort: %0.3lf (%d)\n", elapsed - memcpyTime, check);
  416.     }
  417.  
  418.     {
  419.         int start = clock();
  420.         int check = 0;
  421.         for (int t = 0; t < CYCLES; t++) {
  422.             memcpy(work, input, sizeof(work));
  423.             for (int i = 0; i < SAMPLES; i++)
  424.                 check += get_kth_network_sort(work[i], COUNT, IDX);
  425.         }
  426.         double elapsed = double(clock() - start) / CLOCKS_PER_SEC;
  427.         printf("network sort: %0.3lf (%d)\n", elapsed - memcpyTime, check);
  428.     }
  429.  
  430.     {
  431.         int start = clock();
  432.         int check = 0;
  433.         for (int t = 0; t < CYCLES; t++)
  434.             for (int i = 0; i < SAMPLES; i++)
  435.                 check += get_kth_count(input[i], COUNT, IDX);
  436.         double elapsed = double(clock() - start) / CLOCKS_PER_SEC;
  437.         printf("trivial count: %0.3lf (%d)\n", elapsed, check);
  438.     }
  439.  
  440.     {
  441.         int start = clock();
  442.         int check = 0;
  443.         for (int t = 0; t < CYCLES; t++)
  444.             for (int i = 0; i < SAMPLES; i++)
  445.                 check += get_kth_count_v16(input[i], COUNT, IDX);
  446.         double elapsed = double(clock() - start) / CLOCKS_PER_SEC;
  447.         printf("vectorized count: %0.3lf (%d)\n", elapsed, check);
  448.     }
  449.  
  450.     {
  451.         int start = clock();
  452.         int check = 0;
  453.         for (int t = 0; t < CYCLES; t++)
  454.             for (int i = 0; i < SAMPLES; i++)
  455.                 check += get_kth_count_v16_n24(input[i], COUNT, IDX);
  456.         double elapsed = double(clock() - start) / CLOCKS_PER_SEC;
  457.         printf("vectorized count (n<=24): %0.3lf (%d)\n", elapsed, check);
  458.     }
  459.  
  460.     {
  461.         int start = clock();
  462.         int check = 0;
  463.         for (int t = 0; t < CYCLES; t++)
  464.             for (int i = 0; i < SAMPLES; i++)
  465.                 check += get_kth_count_v16t(input[i], COUNT, IDX);
  466.         double elapsed = double(clock() - start) / CLOCKS_PER_SEC;
  467.         printf("vectorized count (T): %0.3lf (%d)\n", elapsed, check);
  468.     }
  469.  
  470.     {
  471.         int start = clock();
  472.         int check = 0;
  473.         for (int t = 0; t < CYCLES; t++)
  474.             for (int i = 0; i < SAMPLES; i++)
  475.                 check += get_kth_count_v16both(input[i], COUNT, IDX);
  476.         double elapsed = double(clock() - start) / CLOCKS_PER_SEC;
  477.         printf("vectorized count (both): %0.3lf (%d)\n", elapsed, check);
  478.     }
  479.  
  480.     return 0;
  481. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement