SirSheepe

bigint

Jan 8th, 2021 (edited)
95
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  local bigint = {}

  -- Wrap a plain Lua number or string as a bigint so it can be used as an
  -- operand of the arithmetic operators (tables are passed through as-is)
  local function as_bigint(value)
      if type(value) == "table" then
          return value
      end
      return bigint.new(value)
  end

  -- True when big has at least one digit and every digit is 0
  local function is_zero(big)
      if #big.digits == 0 then
          return false
      end
      for i = 1, #big.digits do
          if big.digits[i] ~= 0 then
              return false
          end
      end
      return true
  end

  -- Return a new bigint with the same sign and digits (installed as big:clone())
  local function clone_bigint(big)
      local newint = bigint.new()
      newint.sign = big.sign
      for i = 1, #big.digits do
          newint.digits[i] = big.digits[i]
      end
      return newint
  end

  -- Operator metatable shared by every bigint. Operands may be bigints or plain
  -- Lua numbers/strings (converted with bigint.new).
  local bigint_metatable = {
      __add = function(lhs, rhs)
          return bigint.add(as_bigint(lhs), as_bigint(rhs))
      end,
      -- Unary minus returns a negated copy and leaves the operand untouched;
      -- zero always stays "+"
      __unm = function(operand)
          local negated = clone_bigint(operand)
          if (negated.sign == "+") and not is_zero(negated) then
              negated.sign = "-"
          else
              negated.sign = "+"
          end
          return negated
      end,
      __sub = function(lhs, rhs)
          return bigint.subtract(as_bigint(lhs), as_bigint(rhs))
      end,
      __mul = function(lhs, rhs)
          return bigint.multiply(as_bigint(lhs), as_bigint(rhs))
      end,
      -- Quotient only; bigint.divide also returns the remainder
      __div = function(lhs, rhs)
          return (bigint.divide(as_bigint(lhs), as_bigint(rhs)))
      end,
      -- Remainder of the division, computed by bigint.modulus
      __mod = function(lhs, rhs)
          return bigint.modulus(as_bigint(lhs), as_bigint(rhs))
      end,
      __pow = function(lhs, rhs)
          return bigint.exponentiate(as_bigint(lhs), as_bigint(rhs))
      end
  }

  -- Create a new bigint or convert a number or string into a big.
  -- Returns an empty, positive bigint if no number or string is given.
  --
  -- Strings contribute every decimal digit they contain and are negative when
  -- their first non-space character is "-". Numbers must be finite integers;
  -- they are formatted without exponent notation, so e.g. 1e20 keeps all of
  -- its digits. Leading zeroes are dropped (a lone 0 is kept) and zero is "+".
  function bigint.new(num)
      local self = {
          sign = "+",
          digits = {},
          clone = clone_bigint
      }
      setmetatable(self, bigint_metatable)

      if num then
          local num_string
          if type(num) == "number" then
              assert((num == num) and (num ~= math.huge) and (num ~= -math.huge),
                     "cannot convert non-finite number " .. tostring(num))
              assert(num == math.floor(num),
                     "cannot convert non-integer number " .. tostring(num))
              if math.type and (math.type(num) == "integer") then
                  num_string = string.format("%d", num)
              else
                  -- %.0f prints the exact value of an integral double, where
                  -- tostring would switch to exponent notation
                  num_string = string.format("%.0f", num)
              end
          else
              num_string = tostring(num)
          end

          for digit in string.gmatch(num_string, "[0-9]") do
              self.digits[#self.digits + 1] = tonumber(digit)
          end
          if string.match(num_string, "^%s*%-") then
              self.sign = "-"
          end

          -- Canonical form: no leading zeroes (compare ranks by digit count),
          -- and zero is always positive
          while (#self.digits > 1) and (self.digits[1] == 0) do
              table.remove(self.digits, 1)
          end
          if is_zero(self) then
              self.sign = "+"
          end
      end

      return self
  end
  63.  
  -- Check the type of a big
  -- Normally only runs when global variable "strict" == true, but checking can be
  -- forced by supplying "true" as the second argument.
  -- Raises an error (via assert) describing the first problem found; returns
  -- true otherwise, including when checking is skipped.
  function bigint.check(big, force)
      if not (strict or force) then
          return true
      end

      assert(type(big) == "table", tostring(big) .. " is not a bigint")
      assert(type(big.digits) == "table", "bigint has no digits table")
      assert(#big.digits > 0, "bigint is empty")
      assert(type(big.sign) == "string", "bigint is unsigned")
      assert((big.sign == "+") or (big.sign == "-"),
             "bigint has invalid sign " .. big.sign)
      for i = 1, #big.digits do
          local digit = big.digits[i]
          -- tostring() so that non-number digits (tables, booleans) still get
          -- this message instead of a concatenation error
          assert(type(digit) == "number", tostring(digit) .. " is not a number")
          assert(digit < 10, digit .. " is greater than or equal to 10")
          assert(digit >= 0, digit .. " is negative")
          assert(digit == math.floor(digit), digit .. " is not a whole number")
      end
      return true
  end
  78.  
  -- Absolute value: a copy of big with identical digits whose sign is forced
  -- to "+". The argument itself is never modified.
  function bigint.abs(big)
      bigint.check(big)
      local magnitude = big:clone()
      magnitude.sign = "+"
      return magnitude
  end
  87.  
  -- Convert a big to a number or string
  --
  -- output_type selects the form:
  --   nil, "number", "n" : a Lua number built with tonumber (long bigints lose
  --                        precision)
  --   "string", "s"      : the exact decimal string, prefixed by "-" if negative
  --   anything else      : short scientific form "d.dd*10^e" showing the first
  --                        `precision` digits (default 3, fewer if the big is
  --                        shorter than that)
  function bigint.unserialize(big, output_type, precision)
      bigint.check(big)

      local sign_prefix = ""
      if big.sign == "-" then
          sign_prefix = "-"
      end

      local as_number = (output_type == nil) or (output_type == "number")
                        or (output_type == "n")
      local as_string = (output_type == "string") or (output_type == "s")

      if as_number or as_string then
          -- Unserialization to a string or number requires reconstructing the
          -- entire number; table.concat avoids quadratic string building
          local parts = { sign_prefix }
          for i = 1, #big.digits do
              parts[i + 1] = math.floor(big.digits[i]) -- lazy way of getting rid of .0$
          end
          local num = table.concat(parts)
          if as_number then
              return tonumber(num)
          end
          return num
      end

      -- Unserialization to human-readable form or scientific notation only
      -- requires reading the first few digits
      if precision == nil then
          precision = 3
      else
          assert(precision > 0, "Precision cannot be less than 1")
          assert(math.floor(precision) == precision,
                 "Precision must be a positive integer")
      end
      assert(#big.digits > 0, "cannot format an empty bigint")

      -- The first `precision` digits (clamped to the digits available), the
      -- first one separated from the others by a decimal point
      local shown = math.min(precision, #big.digits)
      local parts = { sign_prefix, big.digits[1] }
      if shown > 1 then
          parts[#parts + 1] = "."
          for i = 2, shown do
              parts[#parts + 1] = big.digits[i]
          end
      end
      parts[#parts + 1] = "*10^" .. (#big.digits - 1)
      return table.concat(parts)
  end
  142.  
  -- Basic comparisons
  -- Accepts symbols (<, >=, ~=) and Unix shell-like options (lt, ge, ne)
  -- Zero equals zero whatever its sign ("-0" == "+0"), and leading zeroes are
  -- ignored, so digits {0, 5} compare like {5}. Returns false for an
  -- unrecognised comparison string.

  -- Index of the first non-zero digit and how many digits run from there to
  -- the end (0 for zero or an empty digit list)
  local function significant_digits(digits)
      local first = 1
      while (first <= #digits) and (digits[first] == 0) do
          first = first + 1
      end
      return first, #digits - first + 1
  end

  -- -1, 0 or 1 for a big whose significant digit count is `length`
  local function signum(big, length)
      if length == 0 then
          return 0
      elseif big.sign == "-" then
          return -1
      end
      return 1
  end

  function bigint.compare(big1, big2, comparison)
      bigint.check(big1)
      bigint.check(big2)

      local start1, length1 = significant_digits(big1.digits)
      local start2, length2 = significant_digits(big2.digits)
      local sign1, sign2 = signum(big1, length1), signum(big2, length2)

      -- order is -1, 0 or 1 as big1 is less than, equal to or greater than big2
      local order
      if sign1 ~= sign2 then
          order = (sign1 < sign2) and -1 or 1
      else
          -- Same sign: compare magnitudes (more digits wins, otherwise walk left
          -- to right), then reverse the outcome for negative numbers
          local magnitude = 0
          if length1 ~= length2 then
              magnitude = (length1 > length2) and 1 or -1
          else
              for offset = 0, length1 - 1 do
                  local digit1 = big1.digits[start1 + offset]
                  local digit2 = big2.digits[start2 + offset]
                  if digit1 ~= digit2 then
                      magnitude = (digit1 > digit2) and 1 or -1
                      break
                  end
              end
          end
          order = magnitude * sign1
      end

      if (comparison == "<") or (comparison == "lt") then
          return order < 0
      elseif (comparison == ">") or (comparison == "gt") then
          return order > 0
      elseif (comparison == "==") or (comparison == "eq") then
          return order == 0
      elseif (comparison == ">=") or (comparison == "ge") then
          return order >= 0
      elseif (comparison == "<=") or (comparison == "le") then
          return order <= 0
      elseif (comparison == "~=") or (comparison == "!=") or (comparison == "ne") then
          return order ~= 0
      end
      return false
  end
  193.  
  -- BACKEND: Add big1 and big2, ignoring signs
  -- Column-by-column long addition starting at the least significant digit.
  -- The result is a fresh "+" big; the inputs' signs are never consulted.
  function bigint.add_raw(big1, big2)
      bigint.check(big1)
      bigint.check(big2)

      local left, right = big1.digits, big2.digits
      local left_len, right_len = #left, #right
      local width = math.max(left_len, right_len)
      local sum = bigint.new()
      local carry = 0

      for position = width, 1, -1 do
          local shift = width - position
          local column = (left[left_len - shift] or 0)
                       + (right[right_len - shift] or 0)
                       + carry
          if column >= 10 then
              carry = 1
              column = column - 10
          else
              carry = 0
          end
          sum.digits[position] = column
      end

      -- A carry out of the most significant column becomes a new leading 1
      if carry == 1 then
          table.insert(sum.digits, 1, 1)
      end

      return sum
  end
  233.  
  -- BACKEND: Subtract big2 from big1, ignoring signs
  -- Requires |big1| >= |big2| and raises otherwise. The result starts as a
  -- clone of big1 (so it keeps big1's sign) and has leading zeroes stripped.
  function bigint.subtract_raw(big1, big2)
      -- Type checking is done by bigint.compare
      if not bigint.compare(bigint.abs(big1), bigint.abs(big2), ">=") then
          -- The message is only built on failure: unserialize walks every digit,
          -- and this function sits on the hot path of add, subtract and divide
          error("Size of " .. bigint.unserialize(big1, "string") .. " is less than "
                .. bigint.unserialize(big2, "string"), 0)
      end

      local result = big1:clone()
      local max_digits = #big1.digits
      local borrow = 0

      -- Walk backwards right to left, like in long subtraction
      for digit = 0, max_digits - 1 do
          local diff = (big1.digits[#big1.digits - digit] or 0)
                     - (big2.digits[#big2.digits - digit] or 0)
                     - borrow

          if diff < 0 then
              borrow = 1
              diff = diff + 10
          else
              borrow = 0
          end

          result.digits[max_digits - digit] = diff
      end

      -- Strip leading zeroes if any, but not if 0 is the only digit
      while (#result.digits > 1) and (result.digits[1] == 0) do
          table.remove(result.digits, 1)
      end

      return result
  end
  271.  
  -- FRONTEND: Addition and subtraction operations, accounting for signs
  -- Same signs: add the magnitudes and keep that sign. Mixed signs: subtract the
  -- smaller magnitude from the larger and take the larger one's sign. Returns
  -- nil when both signs are equal but neither "+" nor "-".
  -- NOTE(review): when the magnitudes are equal and the signs differ, the zero
  -- result takes big2's sign, so 5 + (-5) yields a "-" zero.
  function bigint.add(big1, big2)
      -- Type checking is done by bigint.compare / bigint.add_raw

      if big1.sign == big2.sign then
          if big1.sign == "+" then
              return bigint.add_raw(big1, big2)
          elseif big1.sign == "-" then
              local total = bigint.add_raw(big1, big2)
              total.sign = "-"
              return total
          end
          return nil
      end

      local bigger, smaller = big2, big1
      if bigint.compare(bigint.abs(big1), bigint.abs(big2), ">") then
          bigger, smaller = big1, big2
      end
      local difference = bigint.subtract_raw(bigger, smaller)
      difference.sign = bigger.sign
      return difference
  end
  -- FRONTEND: big1 - big2, computed as big1 + (-big2); big2 is not modified
  function bigint.subtract(big1, big2)
      -- Type checking is done by bigint.compare in bigint.add
      local negated = big2:clone()
      negated.sign = (big2.sign == "+") and "-" or "+"
      return bigint.add(big1, negated)
  end
  310.  
  -- BACKEND: Multiply a big by a single digit big, ignoring signs
  -- The result is a fresh "+" big (built from bigint.new()) with leading
  -- zeroes stripped, so e.g. 123 * 0 gives 0 rather than 000.
  function bigint.multiply_single(big1, big2)
      bigint.check(big1)
      bigint.check(big2)
      if #big2.digits ~= 1 then
          -- Message only built on failure (multiply calls this once per digit)
          error(bigint.unserialize(big2, "string") .. " has more than one digit", 0)
      end

      local result = bigint.new()
      local multiplier = big2.digits[1]
      local carry = 0

      -- Walk backwards right to left, like in long multiplication
      for position = #big1.digits, 1, -1 do
          local product = big1.digits[position] * multiplier + carry
          carry = math.floor(product / 10)
          result.digits[position] = product - (carry * 10)
      end

      -- Leftover carry once the most significant digit has been processed
      -- (e.g. 9 * 9 = 81 leaves a carry of 8)
      if carry > 0 then
          table.insert(result.digits, 1, carry)
      end

      -- A 0 multiplier yields all-zero digits; keep just a single 0
      while (#result.digits > 1) and (result.digits[1] == 0) do
          table.remove(result.digits, 1)
      end

      return result
  end
  346.  
  -- FRONTEND: Multiply two bigs, accounting for signs
  -- Schoolbook long multiplication: the operand with more digits is multiplied
  -- by each digit of the other (right to left), each partial product is shifted
  -- by appending zeroes, and the partials are summed with bigint.add. The sign
  -- is "+" when both signs match and "-" otherwise; a zero operand gives 0.
  function bigint.multiply(big1, big2)
      -- Type checking done by bigint.multiply_single

      if (bigint.unserialize(big1) == 0) or (bigint.unserialize(big2) == 0) then
          return bigint.new(0)
      end

      -- "long"/"short" refer to digit counts, not numeric size
      local long, short = big1, big2
      if #big1.digits < #big2.digits then
          long, short = big2, big1
      end

      local total = bigint.new(0)
      local short_len = #short.digits
      for shift = 0, short_len - 1 do
          local partial = bigint.multiply_single(long,
                                                 bigint.new(short.digits[short_len - shift]))
          -- "Placeholding zeroes"
          for _ = 1, shift do
              partial.digits[#partial.digits + 1] = 0
          end
          total = bigint.add(total, partial)
      end

      total.sign = (long.sign == short.sign) and "+" or "-"
      return total
  end
  390.  
  391.  
  -- Raise a big to a non-negative integer power (TODO: negative integer power)
  -- power may be a big or a plain Lua number/string (converted with bigint.new).
  -- Uses exponentiation by squaring, so only O(log power) multiplications are
  -- performed, and it always returns a new big rather than `big` itself.
  function bigint.exponentiate(big, power)
      -- Type checking for big done by bigint.multiply
      if type(power) ~= "table" then
          power = bigint.new(power)
      end
      assert(bigint.compare(power, bigint.new(0), ">="),
             " negative powers are not supported")

      -- The exponent is walked as a Lua number; powers at or above 2^53 would
      -- not be represented exactly (and such results are not computable anyway)
      local remaining = bigint.unserialize(power)
      assert((type(remaining) == "number") and (remaining < 2^53),
             "power is too large")

      local result = bigint.new(1)
      local base = big
      while remaining > 0 do
          if remaining % 2 == 1 then
              result = bigint.multiply(result, base)
          end
          remaining = math.floor(remaining / 2)
          if remaining > 0 then
              base = bigint.multiply(base, base)
          end
      end

      return result
  end
  415.  
  -- BACKEND: Divide two bigs (decimals not supported), returning big result and
  -- big remainder
  -- WARNING: Only supports positive integers

  -- Copy of big with leading zeroes removed (a single 0 is kept), so that the
  -- digit-count based ordering in bigint.compare sees canonical values
  local function without_leading_zeroes(big)
      local copy = big:clone()
      while (#copy.digits > 1) and (copy.digits[1] == 0) do
          table.remove(copy.digits, 1)
      end
      return copy
  end

  function bigint.divide_raw(big1, big2)
      -- Type checking done by bigint.compare
      local dividend = without_leading_zeroes(big1)
      local divisor = without_leading_zeroes(big2)

      -- Preconditions come before the fast paths so that 0 / 0 is rejected
      -- too; an empty digit list counts as zero
      assert((#divisor.digits > 1) or ((divisor.digits[1] or 0) ~= 0),
             "error: divide by zero")
      assert(big1.sign == "+", "error: big1 is not positive")
      assert(big2.sign == "+", "error: big2 is not positive")

      if bigint.compare(dividend, divisor, "==") then
          return bigint.new(1), bigint.new(0)
      elseif bigint.compare(dividend, divisor, "<") then
          -- Quotient is 0 and the whole dividend is left over
          return bigint.new(0), dividend
      end

      local result = bigint.new()
      local remainder = bigint.new() -- running remainder of the long division

      -- Walk left to right among digits in the dividend, like in long division
      for i = 1, #dividend.digits do
          -- Bring down the next digit. A zero remainder is overwritten instead
          -- of extended, so the running value never has a leading zero.
          local rest = remainder.digits
          if (#rest == 1) and (rest[1] == 0) then
              rest[1] = dividend.digits[i]
          else
              rest[#rest + 1] = dividend.digits[i]
          end

          -- Largest factor with factor * divisor <= remainder, found by repeated
          -- addition; it is at most 9 because remainder < 10 * divisor here
          local factor = 0
          local multiple = bigint.new(0)
          local next_multiple = bigint.add(multiple, divisor)
          while bigint.compare(next_multiple, remainder, "<=") do
              factor = factor + 1
              multiple = next_multiple
              next_multiple = bigint.add(next_multiple, divisor)
          end

          if factor > 0 then
              -- subtract_raw clones the (positive) remainder, so an exact step
              -- leaves a "+" zero rather than a "-" one
              remainder = bigint.subtract_raw(remainder, multiple)
          end

          -- Append the factor to the result, skipping leading zeroes
          if (factor > 0) or (#result.digits > 0) then
              result.digits[#result.digits + 1] = factor
          end
      end

      -- The remainder of the final step is the function's overall remainder
      return result, remainder
  end
  483.  
  -- FRONTEND: Divide two bigs (decimals not supported), returning big result and
  -- big remainder, accounting for signs
  -- The quotient is "+" when the signs match and "-" otherwise; the remainder
  -- takes big1's sign (the same convention bigint.modulus applies). A zero
  -- quotient or remainder is always "+".

  -- True when no digit of big is non-zero (an empty digit list counts as zero)
  local function has_only_zero_digits(big)
      for i = 1, #big.digits do
          if big.digits[i] ~= 0 then
              return false
          end
      end
      return true
  end

  function bigint.divide(big1, big2)
      local result, remainder = bigint.divide_raw(bigint.abs(big1),
                                                  bigint.abs(big2))

      if (big1.sign == big2.sign) or has_only_zero_digits(result) then
          result.sign = "+"
      else
          result.sign = "-"
      end

      if has_only_zero_digits(remainder) then
          remainder.sign = "+"
      else
          remainder.sign = big1.sign
      end

      return result, remainder
  end
  497.  
  -- FRONTEND: Return only the remainder from bigint.divide
  function bigint.modulus(big1, big2)
      local _, remainder = bigint.divide(big1, big2)

      -- Remainder will always have the same sign as the dividend per C standard
      -- https://en.wikipedia.org/wiki/Modulo_operation#Remainder_calculation_for_the_modulo_operation
      -- except that a zero remainder stays "+" so it never reads as "-0"
      local is_zero = true
      for i = 1, #remainder.digits do
          if remainder.digits[i] ~= 0 then
              is_zero = false
              break
          end
      end

      if is_zero then
          remainder.sign = "+"
      else
          remainder.sign = big1.sign
      end
      return remainder
  end
  507.  
  -- Export the module table
  return bigint
RAW Paste Data