foobar76

Bit-efficient uniform random number generators

Oct 7th, 2019
422
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C 6.59 KB | None | 0 0
#include <assert.h>
#include <errno.h>
#include <inttypes.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
  6.  
  7. #define TEST_ROUNDS 1000000000
  8.  
  9. #undef UNROLLED_LOOP    /* Use unrolled loop variants (for branch predictors) */
  10. #define LONG_MULTIPLICATION     /* Long multiplication vs. bitmask with
  11.                                  * rejection */
  12.  
  13. #undef BIT_ACCOUNTING   /* Keep track of used bits */
  14. #undef PRINT_RESULTS    /* Print individual get_uniform_rand() results */
  15.  
  16. #undef BOGUS_RANDOMNESS /* Rotate seed by one bit per request instead of
  17.                          * proper PRNG */
  18.  
  19. /* The state word must be initialized to non-zero */
  20. static uint64_t state;
  21.  
  22. static uint64_t buf[2];
  23. static uint64_t left[2] = {0, 0};
  24.  
  25. /* By Marsaglia(?) */
  26. static uint64_t
  27. xorshift64s(void)
  28. {
  29.   uint64_t x = state;
  30.  
  31.   x ^= x >> 12;
  32.   x ^= x << 25;
  33.   x ^= x >> 27;
  34.   state = x;
  35.  
  36.   return x * 2685821657736338717ULL;
  37. }
  38.  
  39. static void
  40. fill_primary(uint64_t n)
  41. {
  42.   if (64 - left[0] >= left[1])
  43.   {
  44.     if (__builtin_expect(left[0] != 0, 1))
  45.     {
  46.       buf[0] = buf[0] | (buf[1] >> left[0]);
  47.       buf[1] <<= left[0];
  48.     }
  49.     else
  50.     {
  51.       buf[0] = buf[1];
  52.       buf[1] = 0;
  53.     }
  54.  
  55.     left[0] += left[1];
  56.     buf[1] = xorshift64s();
  57.     left[1] = 64;
  58.   }
  59.  
  60.   if (__builtin_expect(left[0] >= n, 1))
  61.   {
  62.     return;
  63.   }
  64.  
  65.   buf[0] |= buf[1] >> left[0];
  66.   buf[1] <<= 64 - left[0];
  67.   left[1] -= 64 - left[0];
  68.   left[0] = 64;
  69. }
  70.  
  71. static void
  72. init_fill(uint64_t x)
  73. {
  74.   state = x;
  75.   buf[1] = xorshift64s();
  76.   left[1] = 64;
  77.  
  78.   fill_primary(64);
  79. }
  80.  
  81. /* Provides n random bits on most significant bits. */
  82. static uint64_t
  83. peek_bits(uint64_t n)
  84. {
  85. #ifndef BOGUS_RANDOMNESS
  86.   if (__builtin_expect(left[0] < n, 0))
  87.   {
  88.     fill_primary(n);
  89.   }
  90.  
  91.   return buf[0];
  92. #else
  93.   state = (state << 1) + (state >> 63);
  94.  
  95.   return state;
  96. #endif
  97. }
  98.  
  99. #ifdef BIT_ACCOUNTING
  100. static uint64_t used_bits = 0;
  101. #endif
  102.  
  103. /* Consumes bits peeked by earlier peek_bits() call. Multiple calls can be
  104.  * made for one peek_bits() call, but sum of arguments must be less or equal
  105.  * to earlier peek. */
  106. static void
  107. consume_bits(uint64_t n)
  108. {
  109.   buf[0] <<= n;
  110.   left[0] -= n;
  111.  
  112. #ifdef BIT_ACCOUNTING
  113.   used_bits += n;
  114. #endif
  115. }
  116.  
  117. #ifndef LONG_MULTIPLICATION
  118. /* Variable length bitmask with rejection method. */
  119.  
  120. #ifndef UNROLLED_LOOP
  121. static uint64_t
  122. get_uniform_rand(uint64_t range)
  123. {
  124.   uint64_t shift = __builtin_clzl(range);
  125.   uint64_t res;
  126.  
  127.   if (__builtin_expect(range == 0, 0))
  128.   {
  129.     return 0;
  130.   }
  131.  
  132.   do
  133.   {
  134.     res = peek_bits(64 - shift) >> shift;
  135.     consume_bits(64 - shift);
  136.   }
  137.   while (__builtin_expect(res >= range, 0));
  138.  
  139.   return res;
  140. }
  141.  
  142. #else   /* UNROLLED_LOOP */
  143. static uint64_t
  144. get_uniform_rand(uint64_t range)
  145. {
  146.   uint64_t rounds = 2;
  147.   uint64_t shift = __builtin_clzl(range);
  148.   uint64_t bitsperloop;
  149.   uint64_t res;
  150.   uint64_t nores = 1;
  151.   uint64_t neededrounds = 0;
  152.  
  153.   if (__builtin_expect(range == 0, 0))
  154.   {
  155.     return 0;
  156.   }
  157.  
  158.   if (rounds * (64 - shift) <= 64)
  159.   {
  160.     bitsperloop = rounds * (64 - shift);
  161.   }
  162.   else
  163.   {
  164.     bitsperloop = 64 - shift;
  165.     rounds = 1;
  166.   }
  167.  
  168.   do
  169.   {
  170.     uint64_t rbits = peek_bits(bitsperloop);
  171.     uint64_t neededrounds = 0;
  172.  
  173.     /* Unrolled loop */
  174.     for (uint64_t i = 0; i < rounds; i++)
  175.     {
  176.       res = nores ? rbits >> shift : res;
  177.       rbits <<= 64 - shift;
  178.       neededrounds += nores;
  179.       nores = res >= range;
  180.     }
  181.  
  182.     consume_bits(neededrounds * (64 - shift));
  183.   }
  184.   while (__builtin_expect(nores, 0));
  185.  
  186.   return res;
  187. }
  188.  
  189. #endif  /* UNROLLED_LOOP */
  190.  
  191. #else   /* LONG_MULTIPLICATION */
  192. /* These versions effectively perform a long multiplication until a fixed
  193.  * point integer portion of the computation can't change or a carry makes it
  194.  * increment. */
  195.  
  196. #ifndef UNROLLED_LOOP
  197. static uint64_t
  198. get_uniform_rand(uint64_t range)
  199. {
  200.   uint64_t shift = __builtin_clzl(range);
  201.   uint64_t add = range << shift;
  202.   unsigned __int128 mul;
  203.   uint64_t res, frac;
  204.   uint64_t rbits;
  205.   uint64_t v0 = 1, v1 = 0;
  206.  
  207.   if (__builtin_expect(range == 0, 0))
  208.   {
  209.     return 0;
  210.   }
  211.  
  212.   rbits = peek_bits(64 - shift);
  213.   consume_bits(64 - shift);
  214.  
  215.   frac = mul = (unsigned __int128)(rbits & (~0ULL << shift)) * range;
  216.   res = mul >> 64;
  217.  
  218.   /* This is essentially arbitrary-precision multiplication, tracking 65
  219.    * bits. */
  220.   while ((v0 ^ v1) & (frac > -add))
  221.   {
  222.     uint64_t tmp;
  223.  
  224.     /* old top bit */
  225.     v0 = frac >> 63;
  226.     frac <<= 1;
  227.  
  228.     tmp = frac;
  229.     frac += ((int64_t) peek_bits(1) >> 63) & add;
  230.     /* new top bit */
  231.     v1 = (tmp > frac);
  232.  
  233.     consume_bits(1);
  234.  
  235.     /* If both v0 and v1 are set a bit carries to the result. */
  236.     res += v0 & v1;
  237.   }
  238.  
  239.   return res;
  240. }
  241.  
  242. #else   /* UNROLLED_LOOP */
  243.  
  244. static uint64_t
  245. get_uniform_rand(uint64_t range)
  246. {
  247.   const uint64_t rounds = 2;
  248.   uint64_t shift = __builtin_clzl(range);
  249.   uint64_t add = range << shift;
  250.   unsigned __int128 mul;
  251.   uint64_t res, frac;
  252.   uint64_t rbits;
  253.   uint64_t v0 = 1, v1 = 0;
  254.   uint64_t ok = 1;
  255.  
  256.   if (__builtin_expect(range == 0, 0))
  257.   {
  258.     return 0;
  259.   }
  260.  
  261.   rbits = peek_bits(64 - shift);
  262.   consume_bits(64 - shift);
  263.  
  264.   frac = mul = (unsigned __int128)(rbits & (~0ULL << shift)) * range;
  265.   res = mul >> 64;
  266.  
  267.   /* Can the fractional part still cause res to increment? */
  268.   do
  269.   {
  270.     uint64_t neededbits = 0;
  271.  
  272.     rbits = peek_bits(rounds);
  273.  
  274.     /* Unrolled loop */
  275.     for (uint64_t i = 0; i < rounds; i++)
  276.     {
  277.       uint64_t tmp;
  278.  
  279.       ok &= (v0 ^ v1) & (frac > -add);
  280.       neededbits += ok;
  281.  
  282.       /* old top bit */
  283.       v0 = frac >> 63;
  284.       frac <<= 1;
  285.  
  286.       /* random bit */
  287.       v1 = rbits >> 63;
  288.       rbits <<= 1;
  289.  
  290.       tmp = frac;
  291.       frac += -(ok & v1) & add;
  292.       /* new top bit */
  293.       v1 = (tmp > frac);
  294.  
  295.       /* If both v0 and v1 are set a bit carries to the result. */
  296.       res += ok & v0 & v1;
  297.     }
  298.  
  299.     consume_bits(neededbits);
  300.   }
  301.   while (__builtin_expect(ok, 0));
  302.  
  303.   return res;
  304. }
  305.  
  306. #endif  /* UNROLLED_LOOP */
  307. #endif  /* LONG_MULTIPLICATION */
  308.  
  309. int
  310. main(int argc, char **argv)
  311. {
  312.   uint64_t mod;
  313.   uint64_t x = 0;
  314.  
  315.   if (argc != 3)
  316.   {
  317.     return 1;
  318.   }
  319.   /* Largest possible return value minus one. */
  320.   mod = atoll(argv[1]);
  321.  
  322.   /* Seed to xorshift64s */
  323.   init_fill(atoll(argv[2]));
  324.  
  325.   for (int i = 0; i < TEST_ROUNDS; i++)
  326.   {
  327.     uint64_t res = get_uniform_rand((volatile uint64_t)mod);
  328.  
  329. #ifdef PRINT_RESULTS
  330.     printf("res %llu\n", res);
  331. #endif
  332.  
  333.     /* assert(res < mod); */
  334.     x += res;
  335.   }
  336.  
  337. #ifdef BIT_ACCOUNTING
  338.   printf("used bits: %llu\n", used_bits);
  339. #endif
  340.  
  341.   printf("sum %llu\n", x);
  342.   return 0;
  343. }
Add Comment
Please, Sign In to add comment