Advertisement
Guest User

solver.py

a guest
May 13th, 2023
192
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 6.98 KB | None | 0 0
  1. # CHANGELOG:
  2. # 5/13/2023 8:00 AM PDT (UTC-7): fixed bug that occurs when n is a prime power in solve_composite
  3.  
  4. # Dedicated to public domain/licensed using [CC0](https://creativecommons.org/publicdomain/zero/1.0/).
  5.  
# In Python, the result of x % m is never negative when m is positive, but
# this isn't true in some languages (e.g. C, Java), so this code
# is written without relying on that guarantee.
  9.  
  10. # code could be greatly simplified if you only want one solution
  11.  
  12. # copied from SO
  13. def egcd(a, b):
  14.     # returns Bezout coefficients directly
  15.     if a == 0:
  16.         return (b, 0, 1)
  17.     else:
  18.         g, y, x = egcd(b % a, a)
  19.         return (g, x - (b // a) * y, y)
  20. def modinv(n, mod):
  21.     g, x, y = egcd(n, mod)
  22.     if g != 1:
  23.         raise Exception('not relatively prime')
  24.     return (x % mod + mod) % mod
  25.  
  26. def bezout(a, b):
  27.     g, x, y = egcd(a, b)
  28.     # given modinv:
  29.     # x = modinv(a, b)
  30.     # y = (1 - x * a) // b
  31.     return x, y
  32.  
  33. def solve_prime(A, b, mod):
  34.     rows = len(A)
  35.     cols = len(A[0])
  36.     solved = 0
  37.     freevars = []
  38.     for pivot in range(cols):
  39.         for r in range(solved, rows):
  40.             if A[r][pivot] != 0:
  41.                 break
  42.         else: # if all coeffs zero
  43.             freevars.append(pivot)
  44.             continue
  45.         # swap row into solved
  46.         for c in range(cols):
  47.             temp = A[r][c]
  48.             A[r][c] = A[solved][c]
  49.             A[solved][c] = temp
  50.         temp = b[r]
  51.         b[r] = b[solved]
  52.         b[solved] = temp
  53.         r = solved
  54.         solved += 1
  55.         # invert row
  56.         inv = modinv(A[r][pivot], mod)
  57.         for c in range(cols):
  58.             A[r][c] = (A[r][c] * inv) % mod
  59.         b[r] = (b[r] * inv) % mod
  60.         # subtract from other rows
  61.         for other in range(rows):
  62.             if other == r: continue
  63.             mul = A[other][pivot]
  64.             for c in range(cols):
  65.                 v = (A[other][c] - mul * A[r][c]) % mod
  66.                 v = (v + mod) % mod
  67.                 A[other][c] = v
  68.             b[other] = ((b[other] - mul * b[r]) % mod + mod) % mod
  69.     solns = []
  70.     # loop over assignments to free vars
  71.     free_assn = [0] * len(freevars)
  72.     while True:
  73.         # get assignment
  74.         assn = [0] * cols
  75.         # assign free vars
  76.         for i in range(len(freevars)):
  77.             assn[freevars[i]] = free_assn[i]
  78.         # assign solved vars
  79.         curvar = 0
  80.         for r in range(solved):
  81.             while A[r][curvar] == 0:
  82.                 curvar += 1
  83.                 assert curvar < cols
  84.             assn[curvar] = b[r]
  85.             for i in freevars:
  86.                 assn[curvar] = (assn[curvar] - A[r][i] * assn[i]) % mod
  87.                 assn[curvar] = (assn[curvar] + mod) % mod
  88.         solns.append(assn)
  89.         # increment free_assn
  90.         for j in reversed(range(len(freevars))):
  91.             free_assn[j] += 1
  92.             if free_assn[j] == mod:
  93.                 free_assn[j] = 0
  94.             else:
  95.                 break
  96.         else:
  97.             break
  98.     # print('A =', A)
  99.     # print('b =', b)
  100.     # print('solved =', solved)
  101.     # print('freevars =', freevars)
  102.     return solns
  103.  
  104. # print(solve_prime([
  105. #     [0, 0, 0],
  106. #     [0, 0, 0],
  107. #     [0, 1, 1],
  108. #     ], [0, 0, 1], 2))
  109. # -> [[0, 1, 0], [0, 0, 1], [1, 1, 0], [1, 0, 1]]
  110.  
  111. # print(solve_prime([
  112. #     [2, 0, 2],
  113. #     [1, 1, 0],
  114. #     [0, 2, 2],
  115. #     ], [2, 2, 0], 3))
  116. # -> [[0, 2, 1]]
  117.  
  118. def solve_primepow(A, b, p, k):
  119.     # for finding multiple solutions, since the matrix itself
  120.     # doesn't change for each solution, it's possible to cache
  121.     # the row-operation matrix to avoid recomputing RREF
  122.     rows = len(A)
  123.     cols = len(A[0])
  124.     assert k >= 1
  125.     if k == 1:
  126.         return solve_prime(A, b, p)
  127.     # make a copy as solve_prime modifies A, b
  128.     orig_A = [A[r][:] for r in range(rows)]
  129.     orig_b = b[:]
  130.     x1s = solve_prime(A, b, p)
  131.     newmod = pow(p, k - 1)
  132.     mod = newmod * p
  133.     solns = []
  134.     for x1 in x1s:
  135.         # make a copy of A, b as we need to modify it
  136.         A = [orig_A[r][:] for r in range(rows)]
  137.         b = orig_b[:]
  138.         # substitute p x' + x1
  139.         # in terms of p x', the RHS is then less by (coeff) * x1
  140.         for r in range(rows):
  141.             for c in range(cols):
  142.                 b[r] = (b[r] - A[r][c] * x1[c] % mod) % mod
  143.             assert b[r] % p == 0
  144.             b[r] = b[r] // p
  145.         for r in range(rows):
  146.             for c in range(cols):
  147.                 A[r][c] %= newmod
  148.         xps = solve_primepow(A, b, p, k - 1)
  149.         for xp in xps:
  150.             soln = [0] * cols
  151.             for c in range(cols):
  152.                 soln[c] = p * xp[c] + x1[c]
  153.             solns.append(soln)
  154.     return solns
  155.  
  156. # print(solve_primepow([[1, 3, 2], [1, 0, 1], [0, 2, 2]], [2, 2, 0], 2, 2))
  157. # -> [[2, 0, 0], [0, 2, 2], [1, 1, 1], [3, 3, 3]]
  158.  
  159. def prime_factor(n):
  160.     ret = []
  161.     two_exp = 0
  162.     while n % 2 == 0:
  163.         two_exp += 1
  164.         n //= 2
  165.     if two_exp > 0:
  166.         ret.append((2, two_exp))
  167.     p = 3
  168.     while p <= n:
  169.         if n % p == 0:
  170.             # note: p is always prime
  171.             # as if p were composite there would be a smaller
  172.             # prime factor which would have already been divided out
  173.             exp = 0
  174.             while n % p == 0:
  175.                 exp += 1
  176.                 n //= p
  177.             ret.append((p, exp))
  178.         p += 2
  179.     return ret
  180.  
  181. # print(prime_factor(588)) # [(2, 2), (3, 1), (7, 2)]
  182.  
  183. def solve_composite(A, b, n):
  184.     if n == 1:
  185.         # everything = 0
  186.         return [[0] * cols]
  187.     # a lot of this can be pre-computed for a given n
  188.     rows = len(A)
  189.     cols = len(A[0])
  190.     # each soln is a list of CRT constraints (mod, vals)
  191.     soln_cons = [[]]
  192.     for p, e in prime_factor(n):
  193.         mod = pow(p, e)
  194.         # make copies bc modify
  195.         A_temp = [[A[r][c] % mod for c in range(cols)] for r in range(rows)]
  196.         b_temp = b[:]
  197.         old_solns = soln_cons
  198.         cur_solns = solve_primepow(A_temp, b_temp, p, e)
  199.         soln_cons = []
  200.         for old in old_solns:
  201.             for cur in cur_solns:
  202.                 soln = old[:]
  203.                 soln.append((mod, cur))
  204.                 soln_cons.append(soln)
  205.     # compute solutions using CRT
  206.     solns = []
  207.     for cons in soln_cons:
  208.         # there's probably more efficient ways to do CRT
  209.         # than pairwise but this works
  210.         while len(cons) >= 2:
  211.             mod2, vals2 = cons.pop()
  212.             mod1, vals1 = cons.pop()
  213.             m1, m2 = bezout(mod1, mod2)
  214.             assert m1 * mod1 + m2 * mod2 == 1
  215.             mod = mod1 * mod2
  216.             c1 = (m2 * mod2 % mod + mod) % mod
  217.             c2 = (m1 * mod1 % mod + mod) % mod
  218.             soln = [0] * cols
  219.             for c in range(cols):
  220.                 soln[c] = (c1 * vals1[c] + c2 * vals2[c]) % mod
  221.             cons.append((mod, soln))
  222.         assert len(cons) == 1
  223.         mod, vals = cons[0]
  224.         solns.append(vals)
  225.     return solns
  226.  
  227. print(solve_composite(
  228.     [
  229.         [2, 0, 2],
  230.         [4, 4, 0],
  231.         [0, 5, 5]
  232.     ], [2, 2, 3], 6))
  233.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement