Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- #include <tmmintrin.h>
- #include <assert.h>
- #include <stdio.h>
- #include <stdint.h>
- #include <time.h>
- #include <string.h>
- #include <algorithm>
// Number of characters in the string being permuted.
static const int LEN = 12;
- //=========================== original code =========================
- void permutation_original(int k, char *s) {
- for (size_t j = 1; j < LEN; ++j) {
- std::swap(s[k % (j + 1)], s[j]);
- k = k / (j + 1);
- }
- }
- void iterate_original(const char *s, int numK) {
- assert(strlen(s) == LEN);
- uint8_t checksum[16] = {0};
- char s_copy[LEN + 1] = {0};
- for (int k = 0; k < numK; k++) {
- memcpy(s_copy, s, LEN);
- permutation_original(k, s_copy);
- //note: checksum is computed via 32-bit addition
- for (int i = 0; i < 4; i++)
- ((uint32_t*)(void*)checksum)[i] += ((uint32_t*)(void*)s_copy)[i];
- }
- printf("k < %d: ", numK);
- for (int i = 0; i < LEN; i++)
- printf(" %02X", (uint32_t)checksum[i]);
- printf("\n");
- }
//=========================== SSE solution =========================
// The permutation index k is split into DIGITS mixed-radix digits.  Each
// digit selects one precomputed PSHUFB mask that applies a contiguous run of
// the swaps performed by permutation_original.
const int DIGITS = 3;

// Radix of each digit: the product of the moduli that digit covers.
const int BASE0 = 1 * 2 * 3 * 4 * 5 * 6;  // 720  (moduli 1..6)
const int BASE1 = 7 * 8 * 9;              // 504  (moduli 7..9)
const int BASE2 = 10 * 11 * 12;           // 1320 (moduli 10..12)

// Digit d covers moduli DIVISION[d] + 1 .. DIVISION[d + 1] of the range [1..12].
const int DIVISION[DIGITS + 1] = {0, 6, 9, 12};

// PSHUFB mask lookup tables, one entry per digit value.  Each table has 4
// extra trailing entries that precompute_mask_tables() fills with all-ones
// (-1) masks.  Those produce an all-zero shuffle result, so a 4x-unrolled
// loop (see iterate_many) may safely step past the last real entry.
__m128i mask0[BASE0 + 4];
__m128i mask1[BASE1 + 4];
__m128i mask2[BASE2 + 4];
- __m128i* get_mask_ptr(int idx) {
- if (idx == 0) return mask0;
- if (idx == 1) return mask1;
- if (idx == 2) return mask2;
- assert(0); return 0;
- }
- int get_mask_cnt(int idx) {
- if (idx == 0) return BASE0;
- if (idx == 1) return BASE1;
- if (idx == 2) return BASE2;
- assert(0); return 0;
- }
- void precompute_mask_tables() {
- assert(DIVISION[0] == 0);
- assert(DIVISION[DIGITS] == LEN);
- for (int d = 0; d < DIGITS; d++) {
- int minJ = DIVISION[d] + 1;
- int maxJ = DIVISION[d + 1];
- assert(maxJ >= minJ);
- //check that BASEd constant is correct
- int cnt = 1;
- for (int j = minJ; j <= maxJ; j++)
- cnt *= j;
- assert(cnt == get_mask_cnt(d));
- //clear with -1 masks (which produces zero output)
- __m128i *lut = get_mask_ptr(d);
- memset(lut, -1, (cnt+4) * sizeof(lut[0]));
- //iterate over all possible values of digit
- for (int k = 0; k < cnt; k++) {
- //generate identity permutation
- char order[16];
- for (int i = 0; i < LEN; i++)
- order[i] = i;
- //apply swaps encoded in 'k'
- int kk = k;
- for (int j = minJ; j <= maxJ; j++) {
- int pos = kk % j;
- kk /= j;
- std::swap(order[pos], order[j-1]);
- }
- assert(kk == 0);
- //store the order into mask
- lut[k] = _mm_loadu_si128((__m128i*)order);
- }
- }
- }
- __m128i permutation_fast(int k, __m128i s) {
- s = _mm_shuffle_epi8(s, mask0[k % BASE0]); k /= BASE0;
- s = _mm_shuffle_epi8(s, mask1[k % BASE1]); k /= BASE1;
- s = _mm_shuffle_epi8(s, mask2[k ]);
- return s;
- }
- void iterate_fast(__m128i s, int numK) {
- __m128i checksum = _mm_setzero_si128();
- for (int k = 0; k < numK; k++) {
- __m128i res = permutation_fast(k, s);
- checksum = _mm_add_epi32(checksum, res);
- }
- uint8_t tmp[16];
- _mm_storeu_si128((__m128i*)tmp, checksum);
- printf("k < %d: ", numK);
- for (int i = 0; i < LEN; i++)
- printf(" %02X", (uint32_t)tmp[i]);
- printf("\n");
- }
- //======================= many-at-once improvement =====================
- void iterate_many(__m128i s, int numK) {
- assert(DIGITS == 3);
- assert(numK == BASE0 * BASE1 * BASE2);
- __m128i checksum = _mm_setzero_si128();
- //note: permutations are iterated NOT in sequental order of 'k' !
- //in fact, k = k0 + k1 * BASE0 + k2 * BASE0*BASE1
- for (int k0 = 0; k0 < BASE0; k0++) {
- __m128i s0 = _mm_shuffle_epi8(s, mask0[k0]);
- for (int k1 = 0; k1 < BASE1; k1++) {
- __m128i s1 = _mm_shuffle_epi8(s0, mask1[k1]);
- for (int k2 = 0; k2 < BASE2; k2+=4) {
- //note: loop is manually unrolled by 4 times
- __m128i sx0 = _mm_shuffle_epi8(s1, mask2[k2+0]);
- __m128i sx1 = _mm_shuffle_epi8(s1, mask2[k2+1]);
- __m128i sx2 = _mm_shuffle_epi8(s1, mask2[k2+2]);
- __m128i sx3 = _mm_shuffle_epi8(s1, mask2[k2+3]);
- checksum = _mm_add_epi32(checksum, sx0);
- checksum = _mm_add_epi32(checksum, sx1);
- checksum = _mm_add_epi32(checksum, sx2);
- checksum = _mm_add_epi32(checksum, sx3);
- }
- }
- }
- uint8_t tmp[16];
- _mm_storeu_si128((__m128i*)tmp, checksum);
- printf("k < %d: ", numK);
- for (int i = 0; i < LEN; i++)
- printf(" %02X", (uint32_t)tmp[i]);
- printf("\n");
- }
//======================================================================
// Timing helpers.  START(name) records clock() in time_start_<name>.
// END(name) records time_end_<name> and prints the elapsed processor time in
// seconds.  Both declare variables, so END(name) must follow START(name) in
// the same scope, and each name may be used only once per scope.
// The last line of each macro deliberately has no trailing backslash.  A
// trailing backslash would splice the following line (another #define, or
// main) into the macro body.  Timestamps use clock_t rather than int so long
// runs are not truncated.
#define START(name) \
    clock_t time_start_##name = clock();
#define END(name) \
    clock_t time_end_##name = clock(); \
    printf("%6.3lf s taken by '%s'\n", double(time_end_##name - time_start_##name) / CLOCKS_PER_SEC, #name);
- int main() {
- const char source[16] = "abacabadabac";
- __m128i src_reg = _mm_loadu_si128((__m128i*)source);
- START(precompute);
- precompute_mask_tables();
- END(precompute);
- const int CHECK_CNT = 12739451;
- START(original_check);
- iterate_original (source , CHECK_CNT);
- END(original_check);
- START(fast_check);
- iterate_fast (src_reg, CHECK_CNT);
- END(fast_check);
- const int ALL_CNT = BASE0 * BASE1 * BASE2;
- /* START(original_full); //note: long wait!
- iterate_original (source, ALL_CNT);
- END(original_full);*/
- START(fast_all);
- iterate_fast (src_reg, ALL_CNT);
- END(fast_all);
- START(many_all);
- iterate_many (src_reg, ALL_CNT);
- END(many_all);
- return 0;
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement