Guest User

Untitled

a guest
Jun 24th, 2018
84
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 6.71 KB | None | 0 0
  1. import heapq
  2.  
  3.  
  4. class HuffmanNode:
  5. def __init__(self, symbol=None, freq=None):
  6. self.symbol = symbol
  7. self.freq = freq
  8. self.parent = None
  9. self.left = None
  10. self.right = None
  11.  
  12. def __lt__(self, other):
  13. return self.freq < other.freq
  14.  
  15. def is_leaf(self):
  16. return not self.left and not self.right
  17.  
  18. def get_code(self):
  19. # 调试用
  20. if not self.is_laef():
  21. raise ValueError("Not a leaf node.")
  22.  
  23. code = ''
  24. node = self
  25. while node.parent:
  26. if node.parent.left == node:
  27. code = '0' + code
  28. else:
  29. code = '1' + code
  30. code = code.parent
  31.  
  32. return code
  33.  
  34.  
  35. class Huffman:
  36. BYTE_MAX_NUM = 255
  37.  
  38. def __init__(self):
  39. self.origin = None
  40. self.compressed = None
  41. self.huffman_tree = None
  42. self.freqs = [0 for _ in range(self.BYTE_MAX_NUM + 1)]
  43. self.coding_table = [0 for _ in range(self.BYTE_MAX_NUM + 1)]
  44. self.reverse_table = {}
  45. self.coding_str = ''
  46.  
  47. def _minimize_frequencies(self):
  48. # 缩小字频使其在一个字节范围以内
  49. max_freq = max(self.freqs)
  50.  
  51. for symbol, freq in enumerate(self.freqs):
  52. scale_freq = int(self.BYTE_MAX_NUM * (freq / max_freq))
  53. scale_freq = 1 if not scale_freq and freq else scale_freq
  54.  
  55. self.freqs[symbol] = scale_freq
  56.  
  57. def _get_symbol_frequencies(self):
  58. for symbol in self.origin:
  59. self.freqs[symbol] += 1
  60.  
  61. self._minimize_frequencies()
  62.  
  63. def _initial_node_heap(self):
  64. self._heap = []
  65. for symbol, freq in enumerate(self.freqs):
  66. node = HuffmanNode(symbol, freq)
  67. heapq.heappush(self._heap, node)
  68.  
  69. def _build_huffman_tree(self):
  70. self._initial_node_heap()
  71.  
  72. while len(self._heap) > 1:
  73. node1 = heapq.heappop(self._heap)
  74. node2 = heapq.heappop(self._heap)
  75.  
  76. new_node = HuffmanNode(symbol=None, freq=node1.freq + node2.freq)
  77. new_node.left, new_node.right = node1, node2
  78. node1.parent, node2.parent = new_node, new_node
  79. heapq.heappush(self._heap, new_node)
  80.  
  81. self.huffman_tree = heapq.heappop(self._heap)
  82. del self._heap
  83. return self.huffman_tree
  84.  
  85. def _build_coding_table(self, node, code_str=''):
  86. if node is None:
  87. return
  88.  
  89. if node.symbol is not None:
  90. self.coding_table[node.symbol] = code_str
  91. self.reverse_table[code_str] = node.symbol
  92.  
  93. self._build_coding_table(node.left, code_str + '0')
  94. self._build_coding_table(node.right, code_str + '1')
  95.  
  96. def _pading_coding_str(self):
  97. pading_count = 8 - len(self.coding_str) % 8
  98. self.coding_str += '0' * pading_count
  99. state_str = '{:08b}'.format(pading_count)
  100. self.coding_str = state_str + self.coding_str
  101.  
  102. def _prefix_coding_freqs(self):
  103. coding_freqs = []
  104. for freq in self.freqs:
  105. coding_freqs.append('{:08b}'.format(freq))
  106. coding_freqs = ''.join(coding_freqs)
  107. self.coding_str = coding_freqs + self.coding_str
  108.  
  109. def _build_codeing_str(self):
  110. temp = []
  111. for symbol in self.origin:
  112. temp.append(self.coding_table[symbol])
  113. self.coding_str = ''.join(temp)
  114.  
  115. self._pading_coding_str()
  116. self._prefix_coding_freqs()
  117.  
  118. return self.coding_str
  119.  
  120. def _get_compressed(self):
  121. assert(len(self.coding_str) % 8 == 0)
  122.  
  123. b = bytearray()
  124. for index in range(0, len(self.coding_str), 8):
  125. code_num = int(self.coding_str[index:index + 8], 2)
  126. b.append(code_num)
  127.  
  128. self.compressed = bytes(b)
  129. return self.compressed
  130.  
  131. def _read_frequencies_from_compressed(self):
  132. coding_freqs = self.compressed[:self.BYTE_MAX_NUM + 1]
  133. for index, freq in enumerate(coding_freqs):
  134. self.freqs[index] = freq
  135.  
  136. def _get_real_coding_from_compressed(self):
  137. pading_count = self.compressed[self.BYTE_MAX_NUM + 1]
  138. byte_coding_str = self.compressed[self.BYTE_MAX_NUM + 2:]
  139. coding_str = []
  140. for num in byte_coding_str:
  141. temp = bin(num)[2:]
  142. # 补足省略掉的前导零
  143. temp = '0' * (8 - len(temp)) + temp
  144. assert(len(temp) == 8)
  145. coding_str.append(temp)
  146. coding_str = ''.join(coding_str)
  147. assert(len(coding_str) % 8 == 0)
  148. real_coding_str = coding_str[:-pading_count]
  149. return real_coding_str
  150.  
  151. def _decode_compressed(self):
  152. real_coding_str = self._get_real_coding_from_compressed()
  153. decode_content = []
  154.  
  155. node = self.huffman_tree
  156. for state in real_coding_str:
  157. if state == '0':
  158. node = node.left
  159. elif state == '1':
  160. node = node.right
  161.  
  162. if node.symbol is not None:
  163. assert(0 <= node.symbol <= self.BYTE_MAX_NUM)
  164. hex_str = hex(node.symbol)[2:]
  165. # fromhex方法将两个字符识别为一个16进制数
  166. # 所以单个数需要补零
  167. hex_str = '0' + hex_str if len(hex_str) == 1 else hex_str
  168. decode_content.append(hex_str)
  169. node = self.huffman_tree
  170.  
  171. decode_content = ''.join(decode_content)
  172. return bytes.fromhex(decode_content)
  173.  
  174. def clear(self):
  175. self.__init__()
  176.  
  177. def encode(self, origin):
  178. self.clear()
  179. self.origin = origin
  180. self._get_symbol_frequencies()
  181. self._build_huffman_tree()
  182. self._build_coding_table(self.huffman_tree)
  183. self._build_codeing_str()
  184.  
  185. return self._get_compressed()
  186.  
  187. def compresse(self, filename, output_filename=None):
  188. with open(filename, 'rb') as file:
  189. origin = file.read()
  190.  
  191. compressed_content = self.encode(origin)
  192. if output_filename is None:
  193. output_filename = filename + '.hfm'
  194. with open(output_filename, 'wb') as file:
  195. file.write(compressed_content)
  196.  
  197. return True
  198.  
  199. def decode(self, compressed):
  200. self.clear()
  201. self.compressed = compressed
  202. self._read_frequencies_from_compressed()
  203. self._build_huffman_tree()
  204. return self._decode_compressed()
  205.  
  206. def uncompresse(self, filename, output_filename=None):
  207. with open(filename, 'rb') as file:
  208. compressed = file.read()
  209.  
  210. decode_content = self.decode(compressed)
  211. if output_filename is None:
  212. if filename.endswith('.hfm'):
  213. output_filename = filename[:-4]
  214. else:
  215. output_filename = filename + '.dhfm'
  216.  
  217. with open(output_filename, 'wb') as file:
  218. file.write(decode_content)
  219.  
  220. return True
Add Comment
Please, Sign In to add comment