Advertisement
foryou97

RSA know some bits

Aug 16th, 2018
112
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 7.40 KB | None | 0 0
# Recover an RSA private key from a partially corrupted dump; this works as long as roughly 27% of the key bits are known.
  2. #!/usr/bin/env python
  3. import gmpy
  4. import math
  5. from itertools import product
  6.  
  7.  
  8. #get the base64 out, replace corrupted bytes with base64 value
  9. def fixup(b64, pad):
  10.   ret = ''.join(b64.split("\n")[1:-2]).replace(" ",pad)
  11.   ret = ret.ljust(((len(ret)-1)/4+1)*4,'=')
  12.   return ret
  13.  
  14. def s2bin(s):
  15.   return bin(int(s.encode("hex"),16))[2:]
  16.  
# Disabled alternative, kept for reference only: this triple-quoted string is
# never executed. It locates each key component's bit offsets inside a
# reference PEM ("poop"). It builds two bit strings of a corrupted base64 PEM
# ("corrupted2") by decoding it once with 'A' and once with '/' substituted
# for the corrupted characters. Bits that agree in both are known.
# It references RSA and libnum, which are not imported in this script, so it
# would not run as-is. The active code below parses "corrupted.pem" instead.
'''
#get known offsets from a real pem file
# FIXME: this is dangerous, some places shift around if the leading byte is 0. Maybe make
#  a few of these or something and try multiple times... since this one may be wrong...
dat = s2bin(''.join(open("poop").read().split("\n")[1:-2]).decode("base64"))
rsak = RSA.importKey(open("poop").read())
d = {}
#ugly, but meh, I'm lazy
d['n'] = bin(rsak.n)[2:].strip("L")
d['e'] = bin(rsak.e)[2:].strip("L")
d['q'] = bin(rsak.q)[2:].strip("L")
d['p'] = bin(rsak.p)[2:].strip("L")
d['d'] = bin(rsak.d)[2:].strip("L")
d['c'] = bin(libnum.invmod(rsak.q,rsak.p))[2:].strip("L")
d['dq'] = bin(rsak.d % (rsak.q - 1))[2:].strip("L")
d['dp'] = bin(rsak.d % (rsak.p - 1))[2:].strip("L")


inp = open("corrupted2").read()

s1 = s2bin(fixup(inp,'A').decode("base64"))
s2 = s2bin(fixup(inp,'/').decode("base64"))
'''
  40.  
  41. inp = open("corrupted.pem").read().split("\n")
  42. cn = ''.join([x[4:].replace(":","") for x in inp[2:11]])
  43. cd = ''.join([x[4:].replace(":","") for x in inp[13:22]])
  44. cp = ''.join([x[4:].replace(":","") for x in inp[23:28]])
  45. cq = ''.join([x[4:].replace(":","") for x in inp[29:34]])
  46. cpd = ''.join([x[4:].replace(":","") for x in inp[35:40]])
  47. cqd = ''.join([x[4:].replace(":","") for x in inp[41:46]])
  48.  
  49.  
# Disabled alternative, kept for reference only: this triple-quoted string is
# never executed. From the s1/s2/dat/d values of the other disabled block, it
# builds a per-bit known/unknown map. It prints that map colored per key
# component, then derives the partial values (pd, pp, ...) and known-bit
# masks (kpd, kpp, ...). It calls `colored`, which is not imported here. The
# active code below assigns the same variable names from the hex dump.
'''
known = [a if (a == b) else '2' for (a,b) in zip(s1,s2)]

#figure out where in our pem file things live
mask = [" "]*(len(s1)+100)
for v in ['n','e','p','q','d','c','dq','dp']:
 for i in range(dat.index(d[v]),dat.index(d[v])+len(d[v])):
   mask[i] = v

#representation of what we know about our missing key bits...
print ''.join([(colored(v[-1],'green' if (b != '2') else 'red')) if (v != ' ') else ' ' for (v,b) in zip(mask,known)])


pd = int(''.join([str(int(b)&1) if (v == 'd') else '' for (v,b) in zip(mask,known)]),2)|1 #these are odd...
kpd = int(''.join([str(1-(int(b)>>1)) if (v == 'd') else '' for (v,b) in zip(mask,known)]),2)|1

pq = int(''.join([str(int(b)&1) if (v == 'q') else '' for (v,b) in zip(mask,known)]),2)|1
kpq = int(''.join([str(1-(int(b)>>1)) if (v == 'q') else '' for (v,b) in zip(mask,known)]),2)|1

pp = int(''.join([str(int(b)&1) if (v == 'p') else '' for (v,b) in zip(mask,known)]),2)|1
kpp = int(''.join([str(1-(int(b)>>1)) if (v == 'p') else '' for (v,b) in zip(mask,known)]),2)|1

pdp = int(''.join([str(int(b)&1) if (v == 'dp') else '' for (v,b) in zip(mask,known)]),2)
kpdp = int(''.join([str(1-(int(b)>>1)) if (v == 'dp') else '' for (v,b) in zip(mask,known)]),2)|1

pdq = int(''.join([str(int(b)&1) if (v == 'dq') else '' for (v,b) in zip(mask,known)]),2)
kpdq = int(''.join([str(1-(int(b)>>1)) if (v == 'dq') else '' for (v,b) in zip(mask,known)]),2)|1
'''
  78. pd = int(cd.replace(" ","0"),16)|1
  79. kpd = int(cd.replace(" ","0"),16)^int(cd.replace(" ","F"),16)^int("F"*len(cd),16)|1
  80.  
  81. pp = int(cp.replace(" ","0"),16)|1
  82. kpp = int(cp.replace(" ","0"),16)^int(cp.replace(" ","F"),16)^int("F"*len(cp),16)|1
  83.  
  84. pq = int(cq.replace(" ","0"),16)|1
  85. kpq = int(cq.replace(" ","0"),16)^int(cq.replace(" ","F"),16)^int("F"*len(cq),16)|1
  86.  
  87. pdp = int(cpd.replace(" ","0"),16)|1
  88. kpdp = int(cpd.replace(" ","0"),16)^int(cpd.replace(" ","F"),16)^int("F"*len(cpd),16)|1
  89.  
  90. pdq = int(cqd.replace(" ","0"),16)|1
  91. kpdq = int(cqd.replace(" ","0"),16)^int(cqd.replace(" ","F"),16)^int("F"*len(cqd),16)|1
  92.  
  93.  
  94.  
# Public key: modulus N, given as 256 hex digits (a 1024-bit modulus; the
# hard-coded 512s later in the script are half of that), and public exponent
# e = 0x010001 = 65537.
N = int("dbfabdb1495d3276e7626b84796e9fc20fa13c1744f10c8c3f3e3c2c6040c2e7f313dfa3d1fe10d1ae577cfeab7452aa53102eef7be0099c022560e57a5c30d50940642d1b097dd2109ae02f2dcff8198cd5a395fcac4266107848b9dd63c387d2538e50415343042033ea09c084155e652b0f062340d5d4717a402a9d806a6b",16)
e = int("010001",16)
  98.  
  99.  
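# The helpers below implement the key-reconstruction algorithm of Heninger & Shacham, "Reconstructing RSA Private Keys from Random Key Bits" (http://eprint.iacr.org/2008/510.pdf).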
  101. def hamming_weight(x):
  102.   return bin(x).count("1")
  103.  
  104. def dtwiddle(k):
  105.   return ((k*(N+1)+1)/e)
  106.  
  107.  
  108. kp_candidates = []
  109. for kp in xrange(e):
  110.   kp_candidates.append(hamming_weight((dtwiddle(kp)^pd)&pd))
  111.  
  112. k = kp_candidates.index(min(kp_candidates))
  113. del kp_candidates
  114.  
  115. #we now know the first half of d!
  116. dtwid = dtwiddle(k)
  117. pd = (pd & ((1<<512)-1)) | (dtwid & (((1<<512)-1)<<512) )
  118. kpd |= (((1<<512)-1)<<512)
  119.  
  120. def kptest(kp):
  121.   return ((kp*kp)%e - ((k*(N-1)+1)%e)*kp - k) %e == 0
  122.  
  123. def tau(t):
  124.   for i in xrange(t.bit_length()):
  125.     if (t&(1<<i)):
  126.       return i
  127.   return 0
  128.  
  129. kpps = []
  130. for kpv in xrange(e):
  131.   if kptest(kpv):
  132.     kpps.append(kpv)
  133.  
  134. #if these don't work, we need to try the other way!
  135. kq = kpps[0]
  136. kp = kpps[1]
  137.  
  138. dpps = [gmpy.invert(e,2<<(1+tau(kpps[j]))) for j in xrange(2)]
  139.  
  140. def gb(v,i):
  141.   return (v>>i)&1
  142.  
  143. def bp(x,i):
  144.   return x&((1<<i)-1)
  145.  
  146. def sb(x,i,v):
  147.   return (x & ~(1<<i))|(v << i)
  148.  
  149. #fix e*pdp[q] % (2**(1+tau(kp[q]))) == 1
  150. # FIXME we're lazy and e = 0x10001, so just treat e=1
  151. pdp = (pdp & (~(2**(1+tau(kp)) - 1))) | 1
  152. pdq = (pdq & (~(2**(1+tau(kq)) - 1))) | 1
  153.  
  154. def calc_slice(vals,i):
  155.   pp,pq,pd,pdp,pdq = vals
  156.   c1 = gb(N-bp(pp,i)*bp(pq,i),i)
  157.   c2 = gb(k*(N+1)+1 - k*(bp(pp,i)+bp(pq,i)) - e*bp(pd,i),i+tau(k))
  158.   c3 = gb(kp*(bp(pp,i)-1)+1 - e*bp(pdp,i),i+tau(kp))
  159.   c4 = gb(kq*(bp(pq,i)-1)+1 - e*bp(pdq,i),i+tau(kq))
  160.   feasible = []
  161.   for (pi,qi,di,dpi,dqi) in product([0,1],repeat=5):
  162.     if (c1 == (pi^qi)) and (c2 == (di^pi^qi)) and (c3 == (dpi^pi)) and (c4 == (dqi^qi)):
  163.       feasible.append((pi,qi,di,dpi,dqi))
  164.  
  165.   return feasible
  166.  
  167. def force_slice(bits,vals,i):
  168.   pp,pq,pd,pdp,pdq = vals
  169.   pp = sb(pp,i,bits[0])
  170.   pq = sb(pq,i,bits[1])
  171.   pd = sb(pd,tau(k)+i,bits[2])
  172.   pdp = sb(pdp,tau(kp)+i,bits[3])
  173.   pdq = sb(pdq,tau(kq)+i,bits[4])
  174.   return (pp,pq,pd,pdp,pdq)
  175.  
  176. def pick_slice(vals,i):
  177.   slices = calc_slice(vals,i)
  178.   pp,pq,pd,pdp,pdq = vals
  179.   best_slices = []
  180.   for s in slices:
  181.     score = 0
  182.     #fixme shouldn't count to score if it is an "unknown" bit
  183.     if gb(pp,i) != s[0]:
  184.       score += gb(kpp,i)
  185.     if gb(pq,i) != s[1]:
  186.       score += gb(kpq,i)
  187.     if gb(pd,i+tau(k)) != s[2]:
  188.       score += gb(kpd,i+tau(k))
  189.     if gb(pdp,i+tau(kp)) != s[3]:
  190.       score += gb(kpdp,i+tau(kp))
  191.     if gb(pdq,i+tau(kq)) != s[4]:
  192.       score += gb(kpdq,i+tau(kq))
  193.     if score == 0:
  194.       best_slices.append(s)
  195. #    else:
  196. #      print s, score
  197. #      best_slices.append(s)
  198.   return best_slices
  199.  
  200. def actual_slice(i):
  201.   return (gb(pp,i),gb(pq,i),gb(pd,tau(k)+i),gb(pdp,tau(kp)+i),gb(pdq,tau(kq)+i))
  202.  
  203. def solve(stop=512):
  204.   global pp,pq,pd,pdp,pdq
  205.   c = 1
  206.   backtrack = []
  207.   i = 1
  208.   while i < stop:
  209.     if (i%100 == 0):
  210.       print i, len(backtrack)
  211. #    print i
  212.     vals = (pp,pq,pd,pdp,pdq)
  213.     gg = pick_slice(vals,i)
  214.     choose = 0
  215.     if (len(gg) == 0):
  216. #      print "oops going back (%d)"%(len(backtrack))
  217.       saved,vals = backtrack.pop()
  218.       i = saved
  219.       gg = pick_slice(vals,i)
  220.       choose = 1
  221.  
  222.     if (choose == 0):
  223.       c*= len(gg)
  224.       if len(gg) > 1:
  225.         backtrack.append((i,vals))
  226. #        print i, len(backtrack)
  227.     pp,pq,pd,pdp,pdq = force_slice(gg[choose],vals,i)
  228.     i += 1
  229.   return c
  230.  
  231. solve(512)
  232. print pp*pq == N
  233.  
  234. msg = open("encrypted").read()
  235. msg = msg.encode('hex')
  236. msg = int(msg, 16)
  237. msg = pow(msg, pd, N)
  238. msg = hex(msg).strip('L').split('00')[-1]
  239. print msg.decode('hex')
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement