G2A Many GEOs
SHARE
TWEET

Untitled

a guest Apr 10th, 2020 190 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1.  
  2. from collections import defaultdict
  3. from fractions import Fraction
  4. from math import log2, ceil
  5.  
  6. END_OF_MESSAGE = 256
  7.  
  8.  
  9.  
  10. from functools import wraps
  11. from time import time
  12.  
  13. def timing(f):
  14.     @wraps(f)
  15.     def wrap(*args, **kw):
  16.         ts = time()
  17.         result = f(*args, **kw)
  18.         te = time()
  19.         print('func:%r args:[%r, %r] took: %2.4f sec' % \
  20.           (f.__name__, args, kw, te-ts))
  21.         return result
  22.     return wrap
  23.  
  24. def build_prob(occurences_list):
  25.  
  26.     # occurences_list = {
  27.     # "A":2,
  28.     # "B":1,
  29.     # "C":1,
  30.     # END_OF_MESSAGE:1
  31.     # }
  32.  
  33.     output_prob = dict()
  34.     occurences_all_count = sum(occurences_list.values())
  35.     cumulative_count = 0
  36.  
  37.     for symbol, occurences in occurences_list.items():
  38.         # prob_pair = Fraction(cumulative_count, occurences_all_count), Fraction(cumulative_count + occurences, occurences_all_count)
  39.         prob_pair =cumulative_count/ occurences_all_count, (cumulative_count + occurences)/occurences_all_count
  40.         output_prob[symbol] = prob_pair
  41.         cumulative_count += occurences
  42.  
  43.     return output_prob
  44.  
  45.    
  46. def returnBits(interval, bit_counter):
  47.     output = ""
  48.     while True:
  49.         l, r = interval
  50.         if r < 0.5:        #if interval is contained in [0, 0.5)
  51.             output += '0' + '1'*bit_counter
  52.             bit_counter = 0
  53.             interval = 2*l, 2*r
  54.         elif l >= 0.5:     #if interval is contained in [0.5, 1)
  55.             output += '1' + '0'*bit_counter
  56.             interval = 2*l - 1, 2*r - 1
  57.             bit_counter = 0
  58.         elif (l >= 0.25 and r < 0.75):  #if interval is contained in [0.25, 0.75)
  59.             bit_counter += 1
  60.             interval = 2*l - 0.5, 2*r - 0.5
  61.         else:
  62.             break
  63.     return interval, bit_counter, output
  64.  
  65.  
  66.  
  67. def end_arithmetic_encoding(interval, bit_counter):
  68.     bit_counter += 1
  69.     l, r = interval
  70.    
  71.     if l < 0.25:
  72.         output = '0' + '1'*bit_counter
  73.     else:
  74.         output = '1' + '0'*bit_counter
  75.     return output
  76.  
  77.  
  78. def arithmetic_encoding(input_codes, occurences):
  79.     start = 0.0
  80.     end = 1.0
  81.     output = ""
  82.     bit_counter = 0
  83.     occurences_list = occurences.copy()
  84.  
  85.     for code in input_codes:
  86.         probability_dict = build_prob(occurences_list)
  87.         l, r = probability_dict[code]
  88.  
  89.         start, end = start + (end - start)*l, start + (end - start)*r
  90.        
  91.         (start, end), bit_counter, temp_output = returnBits((start, end), bit_counter)
  92.        
  93.         output += temp_output
  94.         occurences_list[code] += 1
  95.  
  96.     temp_output = end_arithmetic_encoding((start, end), bit_counter)
  97.     output += temp_output
  98.  
  99.     return output
  100.  
  101. # def arithmetic_decoding(input_fraction, occurences_list):
  102. #     start = 0.0
  103. #     end = 1.0
  104. #     output_codes = []
  105. #     code = None
  106.    
  107. #     while code != END_OF_MESSAGE:
  108. #         probability_dict = build_prob(occurences_list)
  109. #         for code, (l, r) in probability_dict.items():
  110. #             if(start + (end-start)*l <= input_fraction < start + (end-start)*r):
  111. #                 start, end = start + (end-start)*l, start + (end-start)*r
  112.  
  113. #                 if(code != END_OF_MESSAGE):
  114. #                     output_codes.append(code)
  115. #                 print(input_fraction)
  116. #                 while True:
  117. #                     if end < 0.5:        #if interval is contained in [0, 0.5)
  118. #                         pass
  119. #                     elif start >= 0.5:     #if interval is contained in [0.5, 1)
  120. #                         input_fraction -= 0.5
  121. #                         start -= 0.5
  122. #                         end -= 0.5
  123. #                     elif (start >= 0.25 and end <= 0.75):  #if interval is contained in [0.25, 0.75)
  124. #                         input_fraction -= 0.25
  125. #                         start -= 0.25
  126. #                         end -= 0.25
  127. #                     else:
  128. #                         break
  129.  
  130. #                     start *= 2
  131. #                     end *= 2
  132. #                     input_fraction *= 2
  133.  
  134. #                 occurences_list[code] += 1
  135. #                 break
  136. #     return "".join(output_codes)
  137.  
  138. def arithmetic_decoding(input_bits, occurences):
  139.     start = 0.0
  140.     end = 1.0
  141.     output_codes = []
  142.     code = None
  143.     occurences_list = occurences.copy()
  144.  
  145.     interval_sizes = [r-l for (l,r) in build_prob(occurences_list).values()]
  146.     smallest_probability = min(interval_sizes)
  147.     # buffer_size = 2 + ceil(-log2(smallest_probability))
  148.     buffer_size = len(input_bits)
  149.     buffer = BitBuffer(buffer_size, input_bits)
  150.     buffer.load_buffer()
  151.  
  152.     while code != END_OF_MESSAGE:
  153.         probability_dict = build_prob(occurences_list)
  154.         for code, (l, r) in probability_dict.items():
  155.             buffer_value = float(buffer)
  156.  
  157.             if(start + (end-start)*l <= buffer_value < start + (end-start)*r):
  158.                 start, end = start + (end-start)*l, start + (end-start)*r
  159.  
  160.                 if(code != END_OF_MESSAGE):
  161.                     output_codes.append(code)
  162.                 while True:
  163.                     if end < 0.5:        #if interval is contained in [0, 0.5)
  164.                         pass
  165.                     elif start >= 0.5:     #if interval is contained in [0.5, 1)
  166.                         buffer.substract_bit(0) #substract 0.5
  167.                         start -= 0.5
  168.                         end -= 0.5
  169.                     elif (start >= 0.25 and end < 0.75):  #if interval is contained in [0.25, 0.75)
  170.                         buffer.substract_bit(1) #substract 0.25
  171.                         start -= 0.25
  172.                         end -= 0.25
  173.                     else:
  174.                         break
  175.  
  176.                     start *= 2
  177.                     end *= 2
  178.                     buffer <<= 1    # shift all bits in buffer to the left, and move one bit from input to the right
  179.  
  180.                 occurences_list[code] += 1
  181.                 break
  182.     return (output_codes)
  183.                    
  184.  
  185.  
  186. @timing
  187. def encode_from_file(in_filename, out_filename='out.bin'):
  188.     occurences = {x: 1 for x in range(256+1)}
  189.     data = list(open(in_filename, "rb").read())
  190.     encoded_binary = arithmetic_encoding(data+[END_OF_MESSAGE], occurences)
  191.     bits = bitarray(encoded_binary)
  192.     with open(out_filename, 'wb') as f:
  193.         bits.tofile(f)
  194.  
  195. @timing
  196. def decode_from_file(in_filename, out_filename='decoded.bin'):
  197.     occurences = {x: 1 for x in range(256+1)}
  198.  
  199.     encoded_binary = bitarray()
  200.     with open(in_filename, 'rb') as f:
  201.         encoded_binary.fromfile(f)
  202.     decoded = arithmetic_decoding(encoded_binary, occurences)
  203.     with open(out_filename, 'wb') as f:
  204.         f.write(bytearray(decoded))
  205.  
  206.  
  207. # occurences = {x: 1 for x in range(256+1)}
  208. # message = "ABCA"
  209. # message = [0, 1, 2, 3, 2, 2, 2, 255, 1, 2]
  210. # encoded_binary = arithmetic_encoding(message+[END_OF_MESSAGE], occurences)
  211. # print(encoded_binary)
  212. # decoded = arithmetic_decoding(encoded_binary, occurences) #
  213. # print(message)
  214. # print(decoded)
  215. # assert message == decoded
  216. # print("ok")
  217.  
  218. in_file = 'test0.bin'
  219. encoded_file = 'out.bin'
  220. decoded_file = 'decoded0.bin'
  221.  
  222. print('starting encoding')
  223.  
  224. encode_from_file(in_file, encoded_file)
  225. print('finished encoding')
  226. print('starting decoding')
  227.  
  228. decode_from_file(encoded_file, decoded_file)
  229. print('finished decoding')
  230.  
  231. import filecmp
  232. if filecmp.cmp(in_file, decoded_file):
  233.     print("DZIALA!")
  234. else:
  235.     print("jest problem")
RAW Paste Data
Ledger Nano X - The secure hardware wallet
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
Top