AgentE382

Lua 5.1 BigInt

Jun 10th, 2014
319
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Lua 12.12 KB | None | 0 0
  1. --
  2. -- Copyright (c) 2010 Ted Unangst <[email protected]>
  3. --
  4. -- Permission to use, copy, modify, and distribute this software for any
  5. -- purpose with or without fee is hereby granted, provided that the above
  6. -- copyright notice and this permission notice appear in all copies.
  7. --
  8. -- THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
  9. -- WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
  10. -- MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
  11. -- ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
  12. -- WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
  13. -- ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
  14. -- OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
  15. --
  16.  
  17. --
  18. -- Lua version ported/copied from the C version, copyright as follows.
  19. --
  20. -- Copyright (c) 2000 by Jef Poskanzer <[email protected]>.
  21. -- All rights reserved.
  22. --
  23. -- Redistribution and use in source and binary forms, with or without
  24. -- modification, are permitted provided that the following conditions
  25. -- are met:
  26. -- 1. Redistributions of source code must retain the above copyright
  27. --    notice, this list of conditions and the following disclaimer.
  28. -- 2. Redistributions in binary form must reproduce the above copyright
  29. --    notice, this list of conditions and the following disclaimer in the
  30. --    documentation and/or other materials provided with the distribution.
  31. --
  32. -- THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
  33. -- ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  34. -- IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
  35. -- ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
  36. -- FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
  37. -- DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
  38. -- OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
  39. -- HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
  40. -- LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
  41. -- OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
  42. -- SUCH DAMAGE.
  43. --
  44.  
  45. --
  46. -- Usage is pretty obvious.  bigint(n) constructs a bigint.  n can be
  47. -- a number or a string.  The regular operators are all overridden via
  48. -- metatable, and also generally work with regular numbers too.
  49. -- Note that comparisons between bigints and numbers *don't* work.
  50. -- bigint() uses a small cache, so bigint(constant) should be fast.
  51. --
  52. -- Only the basic operations are here.  No gcd, factorial, prime test, ...
  53. --
  54. -- Technical notes:
  55. -- Primary difference from the C version is the obvious 1 indexing, this
  56. -- shows up in a variety of ways and places, along with slightly different
  57. -- handling of pre-allocated comps arrays.
  58. -- I made a brief effort to sync some of the comments, but you're better
  59. -- off just reading the C code.
  60. -- C version homepage: http://www.acme.com/software/bigint/
  61. --
  62.  
  63. -- two functions to help make Lua act more like C
  64. local function fl(x)
  65.     if x < 0 then
  66.         return math.ceil(x) + 0 -- make -0 go away
  67.     else
  68.         return math.floor(x)
  69.     end
  70. end
  71.  
  72. local function cmod(a, b)
  73.     local x = a % b
  74.     if a < 0 and x > 0 then
  75.         x = x - b
  76.     end
  77.     return x
  78. end
  79.  
  80.  
  81. local radix = 2^24 -- maybe up to 2^26 is safe?
  82. local radix_sqrt = fl(math.sqrt(radix))
  83.  
  84. local bigintmt -- forward decl
  85.  
  86. local function alloc()
  87.     local bi = {}
  88.     setmetatable(bi, bigintmt)
  89.     bi.comps = {}
  90.     bi.sign = 1;
  91.     return bi
  92. end
  93.  
  94. local function clone(a)
  95.     local bi = alloc()
  96.     bi.sign = a.sign
  97.     local c = bi.comps
  98.     local ac = a.comps
  99.     for i = 1, #ac do
  100.         c[i] = ac[i]
  101.     end
  102.     return bi
  103. end
  104.  
  105. local function normalize(bi, notrunc)
  106.     local c = bi.comps
  107.     local v
  108.     -- borrow for negative components
  109.     for i = 1, #c - 1 do
  110.         v = c[i]
  111.         if v < 0 then
  112.             c[i+1] = c[i+1] + fl(v / radix) - 1
  113.             v = cmod(v, radix)
  114.             if v ~= 0 then
  115.                 c[i] = v + radix
  116.             else
  117.                 c[i] = v
  118.                 c[i+1] = c[i+1] + 1
  119.             end
  120.         end
  121.     end
  122.     -- is top component negative?
  123.     if c[#c] < 0 then
  124.         -- switch the sign and fix components
  125.         bi.sign = -bi.sign
  126.         for i = 1, #c - 1 do
  127.             v = c[i]
  128.             c[i] = radix - v
  129.             c[i+1] = c[i+1] + 1
  130.         end
  131.         c[#c] = -c[#c]
  132.     end
  133.     -- carry for components larger than radix
  134.     for i = 1, #c do
  135.         v = c[i]
  136.         if v > radix then
  137.             c[i+1] = (c[i+1] or 0) + fl(v / radix)
  138.             c[i] = cmod(v, radix)
  139.         end
  140.     end
  141.     -- trim off leading zeros
  142.     if not notrunc then
  143.         for i = #c, 2, -1 do
  144.             if c[i] == 0 then
  145.                 c[i] = nil
  146.             else
  147.                 break
  148.             end
  149.         end
  150.     end
  151.     -- check for -0
  152.     if #c == 1 and c[1] == 0 and bi.sign == -1 then
  153.         bi.sign = 1
  154.     end
  155. end
  156.  
  157. local function negate(a)
  158.     local bi = clone(a)
  159.     bi.sign = -bi.sign
  160.     return bi
  161. end
  162.  
  163. local function compare(a, b)
  164.     local ac, bc = a.comps, b.comps
  165.     local as, bs = a.sign, b.sign
  166.     if ac == bc then
  167.         return 0
  168.     elseif as > bs then
  169.         return 1
  170.     elseif as < bs then
  171.         return -1
  172.     elseif #ac > #bc then
  173.         return as
  174.     elseif #ac < #bc then
  175.         return -as
  176.     end
  177.     for i = #ac, 1, -1 do
  178.         if ac[i] > bc[i] then
  179.             return as
  180.         elseif ac[i] < bc[i] then
  181.             return -as
  182.         end
  183.     end
  184.     return 0
  185. end
  186.  
  187. local function lt(a, b)
  188.     return compare(a, b) < 0
  189. end
  190.  
  191. local function eq(a, b)
  192.     return compare(a, b) == 0
  193. end
  194.  
  195. local function le(a, b)
  196.     return compare(a, b) <= 0
  197. end
  198.  
  199. local function addint(a, n)
  200.     local bi = clone(a)
  201.     if bi.sign == 1 then
  202.         bi.comps[1] = bi.comps[1] + n
  203.     else
  204.         bi.comps[1] = bi.comps[1] - n
  205.     end
  206.     normalize(bi)
  207.     return bi
  208. end
  209.  
  210. local function add(a, b)
  211.     if type(a) == "number" then
  212.         return addint(b, a)
  213.     elseif type(b) == "number" then
  214.         return addint(a, b)
  215.     end
  216.     local bi = clone(a)
  217.     local sign = bi.sign == b.sign
  218.     local c = bi.comps
  219.     for i = #c + 1, #b.comps do
  220.         c[i] = 0
  221.     end
  222.     local bc = b.comps
  223.     for i = 1, #bc do
  224.         local v = bc[i]
  225.         if sign then
  226.             c[i] = c[i] + v
  227.         else
  228.             c[i] = c[i] - v
  229.         end
  230.     end
  231.     normalize(bi)
  232.     return bi
  233. end
  234.  
  235. local function sub(a, b)
  236.     if type(b) == "number" then
  237.         return addint(a, -b)
  238.     elseif type(a) == "number" then
  239.         a = bigint(a)
  240.     end
  241.     return add(a, negate(b))
  242. end
  243.  
  244. local function mulint(a, b)
  245.     local bi = clone(a)
  246.     if b < 0 then
  247.         b = -b
  248.         bi.sign = -bi.sign
  249.     end
  250.     local bc = bi.comps
  251.     for i = 1, #bc do
  252.         bc[i] = bc[i] * b
  253.     end
  254.     normalize(bi)
  255.     return bi
  256. end
  257.  
  258. local function multiply(a, b)
  259.     local bi = alloc()
  260.     local c = bi.comps
  261.     local ac, bc = a.comps, b.comps
  262.     for i = 1, #ac + #bc do
  263.         c[i] = 0
  264.     end
  265.     for i = 1, #ac do
  266.         for j = 1, #bc do
  267.             c[i+j-1] = c[i+j-1] + ac[i] * bc[j]
  268.         end
  269.         -- keep the zeroes
  270.         normalize(bi, true)
  271.     end
  272.     normalize(bi)
  273.     if bi ~= bigint(0) then
  274.         bi.sign = a.sign * b.sign
  275.     end
  276.     return bi
  277. end
  278.  
  279. local function kmul(a, b)
  280.     local ac, bc = a.comps, b.comps
  281.     local an, bn = #a.comps, #b.comps
  282.     local bi, bj, bk, bl = alloc(), alloc(), alloc(), alloc()
  283.     local ic, jc, kc, lc = bi.comps, bj.comps, bk.comps, bl.comps
  284.  
  285.     local n = fl((math.max(an, bn) + 1) / 2)
  286.     for i = 1, n do
  287.         ic[i] = (i + n <= an) and ac[i+n] or 0
  288.         jc[i] = (i <= an) and ac[i] or 0
  289.         kc[i] = (i + n <= bn) and bc[i+n] or 0
  290.         lc[i] = (i <= bn) and bc[i] or 0
  291.     end
  292.     normalize(bi)
  293.     normalize(bj)
  294.     normalize(bk)
  295.     normalize(bl)
  296.     local ik = bi * bk
  297.     local jl = bj * bl
  298.     local mid = (bi + bj) * (bk + bl) - ik - jl
  299.     local mc = mid.comps
  300.     local ikc = ik.comps
  301.     local jlc = jl.comps
  302.     for i = 1, #mc do
  303.         jlc[i+n] = (jlc[i+n] or 0) + mc[i]
  304.     end
  305.     for i = 1, #ikc do
  306.         jlc[i+n*2] = (jlc[i+n*2] or 0) + ikc[i]
  307.     end
  308.     jl.sign = a.sign * b.sign
  309.     normalize(jl)
  310.     return jl
  311. end
  312.  
  313. local kthresh = 12
  314.  
  315. local function mul(a, b)
  316.     if type(a) == "number" then
  317.         return mulint(b, a)
  318.     elseif type(b) == "number" then
  319.         return mulint(a, b)
  320.     end
  321.     if #a.comps < kthresh or #b.comps < kthresh then
  322.         return multiply(a, b)
  323.     end
  324.     return kmul(a, b)
  325. end
  326.  
  327. local function divint(numer, denom)
  328.     local bi = clone(numer)
  329.     if denom < 0 then
  330.         denom = -denom
  331.         bi.sign = -bi.sign
  332.     end
  333.     local r = 0
  334.     local c = bi.comps
  335.     for i = #c, 1, -1 do
  336.         r = r * radix + c[i]
  337.         c[i] = fl(r / denom)
  338.         r = cmod(r, denom)
  339.     end
  340.     normalize(bi)
  341.     return bi
  342. end
  343.  
  344. local function multi_divide(numer, denom)
  345.     local n = #denom.comps
  346.     local approx = divint(numer, denom.comps[n])
  347.     for i = n, #approx.comps do
  348.         approx.comps[i - n + 1] = approx.comps[i]
  349.     end
  350.     for i = #approx.comps, #approx.comps - n + 2, -1 do
  351.         approx.comps[i] = nil
  352.     end
  353.     local rem = approx * denom - numer
  354.     if rem < denom then
  355.         quotient = approx
  356.     else
  357.         quotient = approx - multi_divide(rem, denom)
  358.     end
  359.     return quotient
  360. end
  361.  
  362. local function multi_divide_wrap(numer, denom)
  363.     -- we use a successive approximation method, but it doesn't work
  364.     -- if the high order component is too small.  adjust if needed.
  365.     if denom.comps[#denom.comps] < radix_sqrt then
  366.         numer = mulint(numer, radix_sqrt)
  367.         denom = mulint(denom, radix_sqrt)
  368.     end
  369.     return multi_divide(numer, denom)
  370. end
  371.  
  372. local function div(numer, denom)
  373.     if type(denom) == "number" then
  374.         if denom == 0 then
  375.             error("divide by 0", 2)
  376.         end
  377.         return divint(numer, denom)
  378.     elseif type(numer) == "number" then
  379.         numer = bigint(numer)
  380.     end
  381.     -- check signs and trivial cases
  382.     local sign = 1
  383.     local cmp = compare(denom, bigint(0))
  384.     if cmp == 0 then
  385.         error("divide by 0", 2)
  386.     elseif cmp == -1 then
  387.         sign = -sign
  388.         denom = negate(denom)
  389.     end
  390.     cmp = compare(numer, bigint(0))
  391.     if cmp == 0 then
  392.         return bigint(0)
  393.     elseif cmp == -1 then
  394.         sign = -sign
  395.         numer = negate(numer)
  396.     end
  397.     cmp = compare(numer, denom)
  398.     if cmp == -1 then
  399.         return bigint(0)
  400.     elseif cmp == 0 then
  401.         return bigint(sign)
  402.     end
  403.     local bi
  404.     -- if small enough, do it the easy way
  405.     if #denom.comps == 1 then
  406.         bi = divint(numer, denom.comps[1])
  407.     else
  408.         bi = multi_divide_wrap(numer, denom)
  409.     end
  410.     if sign == -1 then
  411.         bi = negate(bi)
  412.     end
  413.     return bi
  414. end
  415.  
  416. local function intrem(bi, m)
  417.     if m < 0 then
  418.         m = -m
  419.     end
  420.     local rad_r = 1
  421.     local r = 0
  422.     local bc = bi.comps
  423.     for i = 1, #bc do
  424.         local v = bc[i]
  425.         r = cmod(r + v * rad_r, m)
  426.         rad_r = cmod(rad_r * radix, m)
  427.     end
  428.     if bi.sign < 1 then
  429.         r = -r
  430.     end
  431.     return r
  432. end
  433.  
  434. local function intmod(bi, m)
  435.     local r = intrem(bi, m)
  436.     if r < 0 then
  437.         r = r + m
  438.     end
  439.     return r
  440. end
  441.  
  442. local function rem(bi, m)
  443.     if type(m) == "number" then
  444.         return bigint(intrem(bi, m))
  445.     elseif type(bi) == "number" then
  446.         bi = bigint(bi)
  447.     end
  448.     return bi - ((bi / m) * m)
  449. end
  450.  
  451. local function mod(a, m)
  452.     local bi = rem(a, m)
  453.     if bi.sign == -1 then
  454.         bi = bi + m
  455.     end
  456.     return bi
  457. end
  458.  
  459. local printscale = 10000000
  460. local printscalefmt = string.format("%%.%dd", math.log10(printscale))
  461. local function makestr(bi, s)
  462.     if bi >= bigint(printscale) then
  463.         makestr(divint(bi, printscale), s)
  464.     end
  465.     table.insert(s, string.format(printscalefmt, intmod(bi, printscale)))
  466. end
  467.  
  468. local function biginttostring(bi)
  469.     local s = {}
  470.     if bi < bigint(0) then
  471.         bi = negate(bi)
  472.         table.insert(s, "-")
  473.     end
  474.     makestr(bi, s)
  475.     s = table.concat(s):gsub("^0*", "")
  476.     if s == "" then s = "0" end
  477.     return s
  478. end
  479.  
  480. bigintmt = {
  481.     __add = add,
  482.     __sub = sub,
  483.     __mul = mul,
  484.     __div = div,
  485.     __mod = mod,
  486.     __unm = negate,
  487.     __eq = eq,
  488.     __lt = lt,
  489.     __le = le,
  490.     __tostring = biginttostring
  491. }
  492.  
  493. local cache = {}
  494. local ncache = 0
  495.  
  496. function bigint(n)
  497.     if cache[n] then
  498.         return cache[n]
  499.     end
  500.     local bi
  501.     if type(n) == "string" then
  502.         local digits = { n:byte(1, -1) }
  503.         for i = 1, #digits do
  504.             digits[i] = string.char(digits[i])
  505.         end
  506.         local start = 1
  507.         local sign = 1
  508.         if digits[i] == '-' then
  509.             sign = -1
  510.             start = 2
  511.         end
  512.         bi = bigint(0)
  513.         for i = start, #digits do
  514.             bi = addint(mulint(bi, 10), tonumber(digits[i]))
  515.         end
  516.         bi = mulint(bi, sign)
  517.     else
  518.         bi = alloc()
  519.         bi.comps[1] = n
  520.         normalize(bi)
  521.     end
  522.     if ncache > 100 then
  523.         cache = {}
  524.         ncache = 0
  525.     end
  526.     cache[n] = bi
  527.     ncache = ncache + 1
  528.     return bi
  529. end
  530.  
  531. function mp(b,e,m)  -- This function by AgentE382, based on Martin Huesser's function, KillaVanilla's function, and the C API function.
  532.     local o = bigint(1)
  533.     local r = clone(o)
  534.     local z = bigint(0)
  535.     while e > z do
  536.         if e % 2 == o then
  537.             r = (r * b) % m
  538.         end
  539.         e = e / 2
  540.         b = (b * b) % m
  541.     end
  542.     return r
  543. end
Advertisement
Add Comment
Please, Sign In to add comment