Advertisement
Guest User

Untitled

a guest
Apr 10th, 2020
233
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 7.36 KB | None | 0 0
  1.  
  2. from collections import defaultdict
  3. from fractions import Fraction
  4. from math import log2, ceil
  5.  
# Sentinel symbol marking the end of a message. Byte values occupy 0..255,
# so 256 is the first code that cannot collide with real data; the encoder
# appends it to the input and the decoder stops when it decodes it.
END_OF_MESSAGE = 256
  7.  
  8.  
  9.  
  10. from functools import wraps
  11. from time import time
  12.  
  13. def timing(f):
  14.     @wraps(f)
  15.     def wrap(*args, **kw):
  16.         ts = time()
  17.         result = f(*args, **kw)
  18.         te = time()
  19.         print('func:%r args:[%r, %r] took: %2.4f sec' % \
  20.           (f.__name__, args, kw, te-ts))
  21.         return result
  22.     return wrap
  23.  
  24. def build_prob(occurences_list):
  25.  
  26.     # occurences_list = {
  27.     # "A":2,
  28.     # "B":1,
  29.     # "C":1,
  30.     # END_OF_MESSAGE:1
  31.     # }
  32.  
  33.     output_prob = dict()
  34.     occurences_all_count = sum(occurences_list.values())
  35.     cumulative_count = 0
  36.  
  37.     for symbol, occurences in occurences_list.items():
  38.         # prob_pair = Fraction(cumulative_count, occurences_all_count), Fraction(cumulative_count + occurences, occurences_all_count)
  39.         prob_pair =cumulative_count/ occurences_all_count, (cumulative_count + occurences)/occurences_all_count
  40.         output_prob[symbol] = prob_pair
  41.         cumulative_count += occurences
  42.  
  43.     return output_prob
  44.  
  45.    
  46. def returnBits(interval, bit_counter):
  47.     output = ""
  48.     while True:
  49.         l, r = interval
  50.         if r < 0.5:        #if interval is contained in [0, 0.5)
  51.             output += '0' + '1'*bit_counter
  52.             bit_counter = 0
  53.             interval = 2*l, 2*r
  54.         elif l >= 0.5:     #if interval is contained in [0.5, 1)
  55.             output += '1' + '0'*bit_counter
  56.             interval = 2*l - 1, 2*r - 1
  57.             bit_counter = 0
  58.         elif (l >= 0.25 and r < 0.75):  #if interval is contained in [0.25, 0.75)
  59.             bit_counter += 1
  60.             interval = 2*l - 0.5, 2*r - 0.5
  61.         else:
  62.             break
  63.     return interval, bit_counter, output
  64.  
  65.  
  66.  
  67. def end_arithmetic_encoding(interval, bit_counter):
  68.     bit_counter += 1
  69.     l, r = interval
  70.    
  71.     if l < 0.25:
  72.         output = '0' + '1'*bit_counter
  73.     else:
  74.         output = '1' + '0'*bit_counter
  75.     return output
  76.  
  77.  
  78. def arithmetic_encoding(input_codes, occurences):
  79.     start = 0.0
  80.     end = 1.0
  81.     output = ""
  82.     bit_counter = 0
  83.     occurences_list = occurences.copy()
  84.  
  85.     for code in input_codes:
  86.         probability_dict = build_prob(occurences_list)
  87.         l, r = probability_dict[code]
  88.  
  89.         start, end = start + (end - start)*l, start + (end - start)*r
  90.        
  91.         (start, end), bit_counter, temp_output = returnBits((start, end), bit_counter)
  92.        
  93.         output += temp_output
  94.         occurences_list[code] += 1
  95.  
  96.     temp_output = end_arithmetic_encoding((start, end), bit_counter)
  97.     output += temp_output
  98.  
  99.     return output
  100.  
  101. # def arithmetic_decoding(input_fraction, occurences_list):
  102. #     start = 0.0
  103. #     end = 1.0
  104. #     output_codes = []
  105. #     code = None
  106.    
  107. #     while code != END_OF_MESSAGE:
  108. #         probability_dict = build_prob(occurences_list)
  109. #         for code, (l, r) in probability_dict.items():
  110. #             if(start + (end-start)*l <= input_fraction < start + (end-start)*r):
  111. #                 start, end = start + (end-start)*l, start + (end-start)*r
  112.  
  113. #                 if(code != END_OF_MESSAGE):
  114. #                     output_codes.append(code)
  115. #                 print(input_fraction)
  116. #                 while True:
  117. #                     if end < 0.5:        #if interval is contained in [0, 0.5)
  118. #                         pass
  119. #                     elif start >= 0.5:     #if interval is contained in [0.5, 1)
  120. #                         input_fraction -= 0.5
  121. #                         start -= 0.5
  122. #                         end -= 0.5
  123. #                     elif (start >= 0.25 and end <= 0.75):  #if interval is contained in [0.25, 0.75)
  124. #                         input_fraction -= 0.25
  125. #                         start -= 0.25
  126. #                         end -= 0.25
  127. #                     else:
  128. #                         break
  129.  
  130. #                     start *= 2
  131. #                     end *= 2
  132. #                     input_fraction *= 2
  133.  
  134. #                 occurences_list[code] += 1
  135. #                 break
  136. #     return "".join(output_codes)
  137.  
def arithmetic_decoding(input_bits, occurences):
    """Decode a bit sequence produced by ``arithmetic_encoding``.

    ``input_bits`` is the encoded bit sequence (it must support ``len``) and
    ``occurences`` the initial symbol counts; the mapping is copied, so the
    caller's object is not modified.  The counts are updated after every
    decoded symbol, mirroring the encoder's adaptive model.

    Returns the list of decoded symbols, without ``END_OF_MESSAGE``.

    NOTE(review): ``BitBuffer`` is not defined in the visible source.  From
    its use here it is presumed to expose the leading bits of the input as a
    binary fraction in [0, 1) via ``float()``, to shift in the next input bit
    on ``<<=`` and to clear a leading bit on ``substract_bit`` -- confirm
    against its implementation.
    """
    start = 0.0
    end = 1.0
    output_codes = []
    code = None
    occurences_list = occurences.copy()

    # NOTE: interval_sizes / smallest_probability are only needed by the
    # commented-out buffer sizing below; the active code never reads them.
    interval_sizes = [r-l for (l,r) in build_prob(occurences_list).values()]
    smallest_probability = min(interval_sizes)
    # buffer_size = 2 + ceil(-log2(smallest_probability))
    # The buffer is currently sized to hold the whole input at once.
    buffer_size = len(input_bits)
    buffer = BitBuffer(buffer_size, input_bits)
    buffer.load_buffer()

    while code != END_OF_MESSAGE:
        # Rebuild the table from the updated counts (adaptive model).
        probability_dict = build_prob(occurences_list)
        # NOTE(review): if no symbol interval contains the buffer value, the
        # for loop finishes without a break and ``code`` is left as the last
        # key of the table; the outer loop then ends only if that key is
        # END_OF_MESSAGE.
        for code, (l, r) in probability_dict.items():
            buffer_value = float(buffer)

            # Find the symbol whose sub-interval contains the encoded value.
            if(start + (end-start)*l <= buffer_value < start + (end-start)*r):
                start, end = start + (end-start)*l, start + (end-start)*r

                if(code != END_OF_MESSAGE):
                    output_codes.append(code)
                # Apply the same rescaling cases as the encoder so that both
                # sides keep identical intervals.
                while True:
                    if end < 0.5:        # interval is contained in [0, 0.5)
                        pass
                    elif start >= 0.5:     # interval is contained in [0.5, 1)
                        buffer.substract_bit(0) # subtract 0.5 (presumably clears the first bit)
                        start -= 0.5
                        end -= 0.5
                    elif (start >= 0.25 and end < 0.75):  # interval is contained in [0.25, 0.75)
                        buffer.substract_bit(1) # subtract 0.25 (presumably acts on the second bit)
                        start -= 0.25
                        end -= 0.25
                    else:
                        break

                    # Double the interval and advance the buffer by one bit.
                    start *= 2
                    end *= 2
                    buffer <<= 1    # shift all bits in buffer to the left, and move one bit from input to the right

                occurences_list[code] += 1
                break
    return (output_codes)
  183.                    
  184.  
  185.  
  186. @timing
  187. def encode_from_file(in_filename, out_filename='out.bin'):
  188.     occurences = {x: 1 for x in range(256+1)}
  189.     data = list(open(in_filename, "rb").read())
  190.     encoded_binary = arithmetic_encoding(data+[END_OF_MESSAGE], occurences)
  191.     bits = bitarray(encoded_binary)
  192.     with open(out_filename, 'wb') as f:
  193.         bits.tofile(f)
  194.  
  195. @timing
  196. def decode_from_file(in_filename, out_filename='decoded.bin'):
  197.     occurences = {x: 1 for x in range(256+1)}
  198.  
  199.     encoded_binary = bitarray()
  200.     with open(in_filename, 'rb') as f:
  201.         encoded_binary.fromfile(f)
  202.     decoded = arithmetic_decoding(encoded_binary, occurences)
  203.     with open(out_filename, 'wb') as f:
  204.         f.write(bytearray(decoded))
  205.  
  206.  
  207. # occurences = {x: 1 for x in range(256+1)}
  208. # message = "ABCA"
  209. # message = [0, 1, 2, 3, 2, 2, 2, 255, 1, 2]
  210. # encoded_binary = arithmetic_encoding(message+[END_OF_MESSAGE], occurences)
  211. # print(encoded_binary)
  212. # decoded = arithmetic_decoding(encoded_binary, occurences) #
  213. # print(message)
  214. # print(decoded)
  215. # assert message == decoded
  216. # print("ok")
  217.  
  218. in_file = 'test0.bin'
  219. encoded_file = 'out.bin'
  220. decoded_file = 'decoded0.bin'
  221.  
  222. print('starting encoding')
  223.  
  224. encode_from_file(in_file, encoded_file)
  225. print('finished encoding')
  226. print('starting decoding')
  227.  
  228. decode_from_file(encoded_file, decoded_file)
  229. print('finished decoding')
  230.  
  231. import filecmp
  232. if filecmp.cmp(in_file, decoded_file):
  233.     print("DZIALA!")
  234. else:
  235.     print("jest problem")
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement