SHARE
TWEET

Untitled

a guest May 24th, 2019 133 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. import sys
  2.  
  3. ASCII = set(
  4.     map(chr, range(0x20, 0x7f)) +
  5.     ['\t', '\n']
  6. )
  7.  
  8. def xor(*args):
  9.     if len(args) > 2:
  10.         xs, ys = args[0], xor(*args[1:])
  11.     else:
  12.         xs, ys = args[0], args[1]
  13.  
  14.     out = []
  15.     for x, y in zip(map(ord, xs), map(ord, ys)):
  16.         out.append(x ^ y)
  17.  
  18.     return ''.join(map(chr, out))
  19.  
  20. def repr_ascii(s):
  21.     o = ''
  22.     for c in s:
  23.         o += c if c in ASCII else '!'
  24.     return repr(o)[1:-1]
  25.  
  26.  
  27. def padding_PKCS5(n):
  28.     return chr(n) * n
  29.  
  30. class PaddingOracle:
  31.     def __init__(self, query, block_size=16, nested=1, padding=padding_PKCS5, output = sys.stdout):
  32.         self.err  = None
  33.         self.nested = nested
  34.         self.padding = padding
  35.         self.output = output
  36.         self.block_size = block_size
  37.  
  38.         # handle different query function types
  39.  
  40.         if query.func_code.co_argcount == 1:
  41.             self.query = lambda iv, ct: query(iv + ct)
  42.         elif query.func_code.co_argcount == 2:
  43.             self.query = query
  44.         else:
  45.             raise ValueError('Query function must take one/two arguments')
  46.  
  47.     def encrypt_block(self, bl, mid, pt):
  48.         assert len(bl) == self.block_size
  49.         assert len(pt) == self.block_size
  50.  
  51.         # ensure that dec(bl) -> pt
  52.  
  53.         iv  = 'A' * self.block_size
  54.         ptt = self.decrypt_block(iv = iv, ct = bl, mid = mid)
  55.  
  56.         return xor(iv, pt, ptt)
  57.  
  58.     def encrypt(self, pt):
  59.  
  60.         pad = self.block_size - (len(pt) % self.block_size)
  61.         pt  = pt + chr(pad) * pad
  62.         mid = 'B'*self.block_size * self.nested
  63.         ct  = ''
  64.  
  65.         assert len(pt) % self.block_size == 0
  66.  
  67.         bs  = [
  68.             pt[i:i+self.block_size] for i in range(0, len(pt), self.block_size)
  69.         ]
  70.  
  71.         assert len(bs) > 0
  72.  
  73.         for pblock in bs[::-1]:
  74.  
  75.             assert len(mid) == self.block_size * self.nested
  76.  
  77.             bl  = mid[-self.block_size:]
  78.             mid = mid[:-self.block_size]
  79.             iv  = self.encrypt_block(bl = bl, pt = pblock, mid = mid)
  80.             mid = iv + mid
  81.             ct  = bl + ct
  82.  
  83.         ct = mid + ct
  84.  
  85.         assert len(ct) == len(pt) + self.block_size * self.nested
  86.  
  87.         return ct[:self.block_size], ct[self.block_size:]
  88.  
  89.     def decrypt(self, ct, iv = None):
  90.  
  91.         if iv is not None:
  92.             ct = iv + ct
  93.  
  94.         assert iv is None or len(iv) == self.block_size
  95.         assert len(ct) > self.block_size*self.nested
  96.         assert len(ct) % self.block_size == 0
  97.  
  98.         blocks = [
  99.             ct[i:i+self.block_size] for i in range(0, len(ct), self.block_size)
  100.         ]
  101.  
  102.         pt = ''
  103.  
  104.         for i in range(0, len(blocks) - self.nested):
  105.  
  106.             bs = blocks[i:i+self.nested+1]
  107.  
  108.             assert len(bs) == self.nested + 1
  109.  
  110.             pt += self.decrypt_block(
  111.                 iv  = bs[0],
  112.                 ct  = bs[-1],
  113.                 mid = ''.join(bs[1:-1])
  114.             )
  115.  
  116.         return pt[:-ord(pt[-1])]
  117.  
  118.     def decrypt_block(self, iv, ct, mid = ''):
  119.         assert len(iv) == self.block_size
  120.         assert len(ct) == self.block_size
  121.         assert len(mid) == (self.nested - 1) * self.block_size
  122.  
  123.         def query(iv, b2):
  124.             return self.query(iv, mid + b2)
  125.  
  126.         # Case A: there is exactly one byte of padding
  127.         # byte 15 is \x01
  128.         # if byte 14 is \x02, then padding will be valid when we flip byte 15 with \x03
  129.  
  130.         pt = ''
  131.         for i in range(self.block_size-1, -1, -1):
  132.             for val in range(0x100):
  133.  
  134.                 pad = self.padding(self.block_size - i)
  135.  
  136.                 iv_flipped = \
  137.                         iv[:i]\
  138.                         + xor(iv[i], chr(val))\
  139.                         + xor(
  140.                             pad[:-1],
  141.                             pt,
  142.                             iv[i+1:]
  143.                         )
  144.  
  145.                 assert len(iv_flipped) == self.block_size
  146.  
  147.                 t = chr(val ^ ord(pad[-1])) + pt
  148.  
  149.                 # print progress
  150.  
  151.                 if self.output:
  152.  
  153.                     r = repr_ascii(t)
  154.                     p = ' ' * (2 * self.block_size - len(r))
  155.  
  156.                     self.output.write('byte %2d, pt %s : %s%s\r' %
  157.                         (
  158.                             i,
  159.                             t.encode('hex').rjust(self.block_size * 2, '?'),
  160.                             r,
  161.                             p
  162.                         )
  163.                     )
  164.                     self.output.flush()
  165.  
  166.                 # query the oracle
  167.  
  168.                 if query(iv_flipped, ct):
  169.  
  170.                     # check for edge-case false positive
  171.  
  172.                     if i == self.block_size - 1:
  173.                         q = query(
  174.                             xor(
  175.                                 iv_flipped,
  176.                                 '\x00' * (self.block_size-2) + '\x01\x00',
  177.                             ),
  178.                             ct
  179.                         )
  180.  
  181.                         if not q:
  182.                             continue
  183.  
  184.                     pt = chr(val ^ ord(pad[-1])) + pt
  185.  
  186.                     break
  187.             else:
  188.                 assert False, 'all 256 xored values failed'
  189.  
  190.         if self.output:
  191.             self.output.write('\n')
  192.             self.output.flush()
  193.  
  194.         return pt
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
Not a member of Pastebin yet?
Sign Up, it unlocks many cool features!
 
Top