Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- # CHANGELOG:
- # 5/13/2023 8:00 AM PDT (UTC-7): fixed bug that occurs when n is a prime power in solve_composite
- # Dedicated to public domain/licensed using [CC0](https://creativecommons.org/publicdomain/zero/1.0/).
- # in Python, the result of % takes the sign of the divisor, so it is
- # never negative for a positive modulus, but this isn't true in some
- # languages, so this code is written without this assumption.
- # code could be greatly simplified if you only want one solution
- # copied from SO
def egcd(a, b):
    """Extended Euclidean algorithm; returns Bezout coefficients directly.

    Returns a tuple ``(g, x, y)`` with ``a * x + b * y == g``.  For
    non-negative inputs ``g`` is the greatest common divisor of ``a`` and
    ``b``.

    Implemented iteratively: the recursive formulation needs one stack frame
    per Euclid step and hits Python's recursion limit for very large
    integers (e.g. consecutive Fibonacci numbers).  The coefficients returned
    are exactly those the recursive version produces.
    """
    # Forward pass: plain Euclid on the pair (a, b) -> (b % a, a),
    # remembering the quotient used at every step.
    quotients = []
    while a != 0:
        quotients.append(b // a)
        a, b = b % a, a

    # Base case: a == 0, so 0 * 0 + b * 1 == b and b is the gcd.
    x, y = 0, 1

    # Backward pass: if (b % a) * x + a * y == g, then with q = b // a
    # we get a * (y - q * x) + b * x == g.
    for q in reversed(quotients):
        x, y = y - q * x, x
    return (b, x, y)
def modinv(n, mod):
    """Return the multiplicative inverse of ``n`` modulo ``mod``.

    The result lies in ``[0, mod)`` for a positive ``mod``.

    Raises:
        ValueError: if ``n`` and ``mod`` are not relatively prime, in which
            case no inverse exists.
    """
    # Reduce first so that negative inputs are handled like their residue
    # (written without assuming % yields a non-negative result).
    n = (n % mod + mod) % mod
    g, x, _ = egcd(n, mod)
    if g != 1:
        raise ValueError('not relatively prime')
    return (x % mod + mod) % mod
def bezout(a, b):
    """Return the pair ``(x, y)`` of Bezout coefficients from ``egcd(a, b)``.

    The coefficients satisfy ``a * x + b * y == g`` where ``g`` is the first
    element of ``egcd(a, b)``.

    If only ``modinv`` is available, the same pair can be obtained for
    coprime inputs as:
        x = modinv(a, b)
        y = (1 - x * a) // b
    """
    coefficients = egcd(a, b)
    return coefficients[1], coefficients[2]
def solve_prime(A, b, mod):
    """Return every solution ``x`` of ``A x = b`` modulo ``mod``.

    The system is brought to reduced row echelon form by Gauss-Jordan
    elimination, then one solution is produced for each assignment of the
    free variables.

    Args:
        A: coefficient matrix, a non-empty list of equal-length rows.
            Modified in place (reduced and row-reduced).
        b: right-hand side, one entry per row of ``A``.  Modified in place.
        mod: the modulus.  It is expected to be prime so that every
            non-zero pivot is invertible; ``modinv`` raises otherwise.

    Returns:
        A list of solutions, each a list of ``len(A[0])`` residues in
        ``[0, mod)``.  There are ``mod ** (number of free variables)``
        solutions, enumerated with the last free variable changing fastest.
        The list is empty when the system is inconsistent.
    """
    from itertools import product

    rows = len(A)
    cols = len(A[0])

    # Normalise every entry into [0, mod) up front.  Otherwise an entry such
    # as `mod` itself looks non-zero in the pivot search below although it
    # is congruent to zero and has no inverse.
    for r in range(rows):
        for c in range(cols):
            A[r][c] = (A[r][c] % mod + mod) % mod
        b[r] = (b[r] % mod + mod) % mod

    solved = 0        # rows [0, solved) already hold a pivot
    pivot_cols = []   # pivot_cols[r] is the pivot column of row r
    freevars = []     # columns without a pivot
    for pivot in range(cols):
        # look for a row at or below `solved` with a non-zero coefficient
        for r in range(solved, rows):
            if A[r][pivot] != 0:
                break
        else:  # all coefficients zero: this variable is free
            freevars.append(pivot)
            continue

        # swap the row into the solved region
        A[r], A[solved] = A[solved], A[r]
        b[r], b[solved] = b[solved], b[r]
        r = solved
        solved += 1
        pivot_cols.append(pivot)

        # scale the row so that the pivot becomes 1
        inv = modinv(A[r][pivot], mod)
        for c in range(cols):
            A[r][c] = (A[r][c] * inv) % mod
        b[r] = (b[r] * inv) % mod

        # clear the pivot column in every other row
        for other in range(rows):
            if other == r:
                continue
            mul = A[other][pivot]
            for c in range(cols):
                A[other][c] = ((A[other][c] - mul * A[r][c]) % mod + mod) % mod
            b[other] = ((b[other] - mul * b[r]) % mod + mod) % mod

    # Rows without a pivot are all-zero on the left, so they read
    # 0 = b[r]; a non-zero right-hand side means there is no solution.
    for r in range(solved, rows):
        if b[r] != 0:
            return []

    solns = []
    # loop over all assignments to the free variables
    for free_vals in product(range(mod), repeat=len(freevars)):
        assn = [0] * cols
        for var, val in zip(freevars, free_vals):
            assn[var] = val
        # each pivot variable is determined by its row:
        # x_pivot = b[r] - sum(A[r][i] * x_i for free i)
        for r, var in enumerate(pivot_cols):
            value = b[r]
            for i in freevars:
                value = ((value - A[r][i] * assn[i]) % mod + mod) % mod
            assn[var] = value
        solns.append(assn)
    return solns
- # print(solve_prime([
- # [0, 0, 0],
- # [0, 0, 0],
- # [0, 1, 1],
- # ], [0, 0, 1], 2))
- # -> [[0, 1, 0], [0, 0, 1], [1, 1, 0], [1, 0, 1]]
- # print(solve_prime([
- # [2, 0, 2],
- # [1, 1, 0],
- # [0, 2, 2],
- # ], [2, 2, 0], 3))
- # -> [[0, 2, 1]]
def solve_primepow(A, b, p, k):
    """Return every solution ``x`` of ``A x = b`` modulo ``p ** k``.

    Solutions are built digit by digit in base ``p``: each solution ``x1``
    modulo ``p`` is substituted as ``x = p * x' + x1``, which leaves a
    system for ``x'`` modulo ``p ** (k - 1)`` that is solved recursively.

    For finding multiple solutions, since the matrix itself doesn't change
    for each solution, it would be possible to cache the row-operation
    matrix to avoid recomputing the RREF; this is not done here.

    Args:
        A: coefficient matrix, a non-empty list of equal-length rows.
        b: right-hand side, one entry per row of ``A``.
        p: a prime.
        k: the exponent, at least 1.

    Returns:
        A list of solutions, each a list of ``len(A[0])`` residues in
        ``[0, p ** k)``; empty if the system has no solution.

    Raises:
        ValueError: if ``k < 1``.

    ``A`` and ``b`` are left unmodified; working copies are made internally.
    """
    if k < 1:
        raise ValueError('k must be at least 1')
    rows = len(A)
    cols = len(A[0])

    # solve_prime modifies its arguments and works on residues mod p, so it
    # gets reduced copies (entries may still be reduced only mod p ** k).
    x1s = solve_prime(
        [[(v % p + p) % p for v in row] for row in A],
        [(v % p + p) % p for v in b],
        p,
    )
    if k == 1:
        return x1s

    newmod = pow(p, k - 1)
    mod = newmod * p
    # the coefficient matrix of the lifted system is the same for every x1
    reduced_A = [[(v % newmod + newmod) % newmod for v in row] for row in A]

    solns = []
    for x1 in x1s:
        # substitute p x' + x1:
        # in terms of p x', the RHS is then less by (coeff) * x1
        residuals = []
        for r in range(rows):
            value = b[r]
            for c in range(cols):
                value = ((value - A[r][c] * x1[c]) % mod + mod) % mod
            residuals.append(value)
        # A x1 = b (mod p) makes every residual a multiple of p; a
        # candidate violating this is not a solution mod p and cannot be
        # lifted, so it is skipped.
        if any(value % p != 0 for value in residuals):
            continue
        lifted_b = [value // p for value in residuals]
        for xp in solve_primepow(reduced_A, lifted_b, p, k - 1):
            solns.append([p * xp[c] + x1[c] for c in range(cols)])
    return solns
- # print(solve_primepow([[1, 3, 2], [1, 0, 1], [0, 2, 2]], [2, 2, 0], 2, 2))
- # -> [[2, 0, 0], [0, 2, 2], [1, 1, 1], [3, 3, 3]]
def prime_factor(n):
    """Factor a positive integer by trial division.

    Args:
        n: the integer to factor, at least 1.

    Returns:
        A list of ``(prime, exponent)`` pairs in increasing order of prime;
        empty for ``n == 1``.

    Raises:
        ValueError: if ``n < 1`` (zero is divisible by everything, so the
            division loop would never terminate).
    """
    if n < 1:
        raise ValueError('n must be a positive integer')
    ret = []
    # handle 2 separately so that only odd candidates need testing
    two_exp = 0
    while n % 2 == 0:
        two_exp += 1
        n //= 2
    if two_exp > 0:
        ret.append((2, two_exp))
    p = 3
    # any composite n has a prime factor no larger than sqrt(n),
    # so candidates beyond that need not be tried
    while p * p <= n:
        if n % p == 0:
            # note: p is always prime
            # as if p were composite there would be a smaller
            # prime factor which would have already been divided out
            exp = 0
            while n % p == 0:
                exp += 1
                n //= p
            ret.append((p, exp))
        p += 2
    # whatever remains has no factor <= its square root, so it is prime
    if n > 1:
        ret.append((n, 1))
    return ret
- # print(prime_factor(588)) # [(2, 2), (3, 1), (7, 2)]
def solve_composite(A, b, n):
    """Return every solution ``x`` of ``A x = b`` modulo ``n``.

    ``n`` is split into prime powers, the system is solved modulo each of
    them with ``solve_primepow``, and every combination of per-prime-power
    solutions is merged with the Chinese remainder theorem.

    Args:
        A: coefficient matrix, a non-empty list of equal-length rows.
        b: right-hand side, one entry per row of ``A``.
        n: the modulus, at least 1.

    Returns:
        A list of solutions, each a list of ``len(A[0])`` residues in
        ``[0, n)``; empty if the system has no solution.

    Raises:
        ValueError: if ``n < 1``.
    """
    if n < 1:
        raise ValueError('n must be a positive integer')
    rows = len(A)
    cols = len(A[0])
    if n == 1:
        # everything = 0
        return [[0] * cols]

    # a lot of this can be pre-computed for a given n
    # each entry is a list of CRT constraints (mod, vals), one per prime power
    soln_cons = [[]]
    for p, e in prime_factor(n):
        mod = pow(p, e)
        # pass copies because the solvers may modify their arguments
        A_temp = [[A[r][c] % mod for c in range(cols)] for r in range(rows)]
        b_temp = b[:]
        cur_solns = solve_primepow(A_temp, b_temp, p, e)
        # extend every partial combination by every solution for this prime power
        soln_cons = [old + [(mod, cur)] for old in soln_cons for cur in cur_solns]

    # compute solutions using CRT
    solns = []
    for cons in soln_cons:
        # there's probably more efficient ways to do CRT
        # than pairwise but this works
        mod1, vals1 = cons[0]
        for mod2, vals2 in cons[1:]:
            m1, m2 = bezout(mod1, mod2)
            assert m1 * mod1 + m2 * mod2 == 1
            mod = mod1 * mod2
            # c1 is 1 mod mod1 and 0 mod mod2; c2 is the other way round
            c1 = (m2 * mod2 % mod + mod) % mod
            c2 = (m1 * mod1 % mod + mod) % mod
            vals1 = [(c1 * v1 + c2 * v2) % mod for v1, v2 in zip(vals1, vals2)]
            mod1 = mod
        solns.append(vals1)
    return solns
# Example: print every solution of a 3x3 system modulo 6.
# -> [[0, 5, 4], [0, 2, 1], [3, 5, 4], [3, 2, 1]]
demo_A = [
    [2, 0, 2],
    [4, 4, 0],
    [0, 5, 5],
]
demo_b = [2, 2, 3]
print(solve_composite(demo_A, demo_b, 6))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement