Advertisement
stgatilov

Vectorize getting all permutations of string

Jan 8th, 2017
579
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 6.09 KB | None | 0 0
  1. #include <tmmintrin.h>
  2. #include <assert.h>
  3. #include <stdio.h>
  4. #include <stdint.h>
  5. #include <time.h>
  6. #include <string.h>
  7. #include <algorithm>
  8.  
  //length of the string being permuted
  const int LEN = 12;

  //=========================== original code =========================

  //Applies the k-th permutation to the first LEN characters of s, in place.
  //For every position pos = 1..LEN-1 the next mixed-radix digit of k (base pos+1)
  //selects the character that s[pos] is swapped with.
  void permutation_original(int k, char *s) {
    size_t pos = 1;
    while (pos < LEN) {
      const size_t radix = pos + 1;
      //k takes part in unsigned arithmetic here (it is converted to size_t)
      const size_t partner = k % radix;
      std::swap(s[partner], s[pos]);
      k = k / radix;
      ++pos;
    }
  }
  20.  
  21. void iterate_original(const char *s, int numK) {
  22.     assert(strlen(s) == LEN);
  23.  
  24.     uint8_t checksum[16] = {0};
  25.     char s_copy[LEN + 1] = {0};
  26.  
  27.     for (int k = 0; k < numK; k++) {
  28.         memcpy(s_copy, s, LEN);
  29.         permutation_original(k, s_copy);
  30.         //note: checksum is computed via 32-bit addition
  31.         for (int i = 0; i < 4; i++)
  32.             ((uint32_t*)(void*)checksum)[i] += ((uint32_t*)(void*)s_copy)[i];
  33.     }
  34.  
  35.     printf("k < %d: ", numK);
  36.     for (int i = 0; i < LEN; i++)
  37.         printf(" %02X", (uint32_t)checksum[i]);
  38.     printf("\n");
  39. }
  40.  
  41. //=========================== SSE solution =========================
  42.  
  43. //number of digits in the mixed radix system that encodes a permutation index k
  44. const int DIGITS = 3;
  45. //the bases of the mixed radix system: each one is the product of the moduli j covered by that digit
  46. const int BASE0 = 1*2*3*4*5*6;  //720
  47. const int BASE1 = 7*8*9;        //504
  48. const int BASE2 = 10*11*12;     //1320
  49. //how the [1..12] range of j is broken into digits: digit d covers j = DIVISION[d]+1 .. DIVISION[d+1]
  50. const int DIVISION[DIGITS+1] = {0, 6, 9, 12};
  51.  
  52. //lookup tables of shuffle masks, one per digit; each has 4 padding entries, which precompute_mask_tables() sets to -1 masks (zero shuffle output)
  53. __m128i mask0[BASE0 + 4];
  54. __m128i mask1[BASE1 + 4];
  55. __m128i mask2[BASE2 + 4];
  56.  
  57. __m128i* get_mask_ptr(int idx) {
  58.     if (idx == 0) return mask0;
  59.     if (idx == 1) return mask1;
  60.     if (idx == 2) return mask2;
  61.     assert(0); return 0;
  62. }
  63. int get_mask_cnt(int idx) {
  64.     if (idx == 0) return BASE0;
  65.     if (idx == 1) return BASE1;
  66.     if (idx == 2) return BASE2;
  67.     assert(0); return 0;
  68. }
  69. void precompute_mask_tables() {
  70.     assert(DIVISION[0] == 0);
  71.     assert(DIVISION[DIGITS] == LEN);
  72.  
  73.     for (int d = 0; d < DIGITS; d++) {
  74.         int minJ = DIVISION[d] + 1;
  75.         int maxJ = DIVISION[d + 1];
  76.         assert(maxJ >= minJ);
  77.  
  78.         //check that BASEd constant is correct
  79.         int cnt = 1;
  80.         for (int j = minJ; j <= maxJ; j++)
  81.             cnt *= j;
  82.         assert(cnt == get_mask_cnt(d));
  83.  
  84.         //clear with -1 masks (which produces zero output)
  85.         __m128i *lut = get_mask_ptr(d);
  86.         memset(lut, -1, (cnt+4) * sizeof(lut[0]));
  87.  
  88.         //iterate over all possible values of digit
  89.         for (int k = 0; k < cnt; k++) {
  90.             //generate identity permutation
  91.             char order[16];
  92.             for (int i = 0; i < LEN; i++)
  93.                 order[i] = i;
  94.             //apply swaps encoded in 'k'
  95.             int kk = k;
  96.             for (int j = minJ; j <= maxJ; j++) {
  97.                 int pos = kk % j;
  98.                 kk /= j;
  99.                 std::swap(order[pos], order[j-1]);
  100.             }
  101.             assert(kk == 0);
  102.             //store the order into mask
  103.             lut[k] = _mm_loadu_si128((__m128i*)order);
  104.         }
  105.     }
  106. }
  107.  
  108. __m128i permutation_fast(int k, __m128i s) {
  109.     s = _mm_shuffle_epi8(s, mask0[k % BASE0]); k /= BASE0;
  110.     s = _mm_shuffle_epi8(s, mask1[k % BASE1]); k /= BASE1;
  111.     s = _mm_shuffle_epi8(s, mask2[k        ]);
  112.     return s;
  113. }
  114.  
  115. void iterate_fast(__m128i s, int numK) {
  116.     __m128i checksum = _mm_setzero_si128();
  117.     for (int k = 0; k < numK; k++) {
  118.         __m128i res = permutation_fast(k, s);
  119.         checksum = _mm_add_epi32(checksum, res);
  120.     }
  121.  
  122.     uint8_t tmp[16];
  123.     _mm_storeu_si128((__m128i*)tmp, checksum);
  124.     printf("k < %d: ", numK);
  125.     for (int i = 0; i < LEN; i++)
  126.         printf(" %02X", (uint32_t)tmp[i]);
  127.     printf("\n");
  128. }
  129.  
  130. //======================= many-at-once improvement =====================
  131.  
  132. void iterate_many(__m128i s, int numK) {
  133.     assert(DIGITS == 3);
  134.     assert(numK == BASE0 * BASE1 * BASE2);
  135.  
  136.     __m128i checksum = _mm_setzero_si128();
  137.     //note: permutations are iterated NOT in sequental order of 'k' !
  138.     //in fact, k = k0 + k1 * BASE0 + k2 * BASE0*BASE1
  139.     for (int k0 = 0; k0 < BASE0; k0++) {
  140.         __m128i s0 = _mm_shuffle_epi8(s, mask0[k0]);
  141.         for (int k1 = 0; k1 < BASE1; k1++) {
  142.             __m128i s1 = _mm_shuffle_epi8(s0, mask1[k1]);
  143.             for (int k2 = 0; k2 < BASE2; k2+=4) {
  144.                 //note: loop is manually unrolled by 4 times
  145.                 __m128i sx0 = _mm_shuffle_epi8(s1, mask2[k2+0]);
  146.                 __m128i sx1 = _mm_shuffle_epi8(s1, mask2[k2+1]);
  147.                 __m128i sx2 = _mm_shuffle_epi8(s1, mask2[k2+2]);
  148.                 __m128i sx3 = _mm_shuffle_epi8(s1, mask2[k2+3]);
  149.                 checksum = _mm_add_epi32(checksum, sx0);
  150.                 checksum = _mm_add_epi32(checksum, sx1);
  151.                 checksum = _mm_add_epi32(checksum, sx2);
  152.                 checksum = _mm_add_epi32(checksum, sx3);
  153.             }
  154.         }
  155.     }
  156.  
  157.     uint8_t tmp[16];
  158.     _mm_storeu_si128((__m128i*)tmp, checksum);
  159.     printf("k < %d: ", numK);
  160.     for (int i = 0; i < LEN; i++)
  161.         printf(" %02X", (uint32_t)tmp[i]);
  162.     printf("\n");
  163. }
  164.  
  165.  
  166. //======================================================================
  167.  
  //Minimal CPU-time stopwatch built on clock().
  //START(name) stores the current clock() value in a local variable time_start_<name>;
  //END(name) stores time_end_<name> and prints the elapsed seconds, labelled with name.
  //Both have to be used in the same scope, START first.
  //The tick counts are kept as clock_t: an int can overflow once the count exceeds INT_MAX
  //(about 36 minutes when CLOCKS_PER_SEC is 1000000).
  #define START(name) \
      clock_t time_start_##name = clock();

  #define END(name) \
      clock_t time_end_##name = clock(); \
      printf("%6.3lf s taken by '%s'\n", double(time_end_##name - time_start_##name) / CLOCKS_PER_SEC, #name);
  174.  
  175. int main() {
  176.     const char source[16] = "abacabadabac";
  177.     __m128i src_reg = _mm_loadu_si128((__m128i*)source);
  178.  
  179.     START(precompute);
  180.         precompute_mask_tables();
  181.     END(precompute);
  182.  
  183.     const int CHECK_CNT = 12739451;
  184.     START(original_check);
  185.         iterate_original (source , CHECK_CNT);
  186.     END(original_check);
  187.     START(fast_check);
  188.         iterate_fast     (src_reg, CHECK_CNT);
  189.     END(fast_check);
  190.  
  191.     const int ALL_CNT = BASE0 * BASE1 * BASE2;
  192. /*  START(original_full);   //note: long wait!
  193.         iterate_original (source, ALL_CNT);
  194.     END(original_full);*/
  195.     START(fast_all);
  196.         iterate_fast (src_reg, ALL_CNT);
  197.     END(fast_all);
  198.     START(many_all);
  199.         iterate_many (src_reg, ALL_CNT);
  200.     END(many_all);
  201.  
  202.  
  203.     return 0;
  204. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement