Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
import heapq
class HuffmanNode:
    """A node of a Huffman tree.

    A node stores a ``symbol`` (``None`` by default), a frequency weight,
    and links to its parent and its two children. Nodes order themselves by
    ``freq`` so they can be stored directly in a ``heapq`` heap.
    """

    def __init__(self, symbol=None, freq=None):
        """Create a detached node.

        Args:
            symbol: Value carried by the node, or None.
            freq: Weight used when ordering nodes.
        """
        self.symbol = symbol
        self.freq = freq
        self.parent = None
        self.left = None
        self.right = None

    def __lt__(self, other):
        """Order nodes by frequency only (what ``heapq`` relies on)."""
        return self.freq < other.freq

    def is_leaf(self):
        """Return True when the node has neither a left nor a right child."""
        return not self.left and not self.right

    def get_code(self):
        """Return this leaf's code as a string of '0'/'1' (debugging aid).

        The code is read by walking up to the root: a step taken from a
        left child contributes '0', a step from a right child '1'.

        Raises:
            ValueError: If the node is not a leaf.
        """
        if not self.is_leaf():
            raise ValueError("Not a leaf node.")
        # Bits are collected leaf-to-root, so they are reversed at the end.
        bits = []
        node = self
        while node.parent:
            bits.append('0' if node.parent.left is node else '1')
            node = node.parent
        return ''.join(reversed(bits))
class Huffman:
    """Byte-oriented Huffman encoder/decoder.

    Layout of the data produced by ``encode``:

    * 256 bytes: the scaled frequency (0..255) of every byte value 0..255;
    * 1 byte: the number of padding bits appended to the payload (1..8);
    * the payload: the concatenated codes, zero-padded to whole bytes.

    ``decode`` rebuilds the same tree from the stored frequencies, so the
    tree construction must stay deterministic for a given frequency table.
    """

    BYTE_MAX_NUM = 255

    def __init__(self):
        self.origin = None
        self.compressed = None
        self.huffman_tree = None
        # Per-byte-value frequency, indexed by the byte value.
        self.freqs = [0 for _ in range(self.BYTE_MAX_NUM + 1)]
        # Per-byte-value code string, indexed by the byte value.
        self.coding_table = [0 for _ in range(self.BYTE_MAX_NUM + 1)]
        # Code string -> byte value.
        self.reverse_table = {}
        self.coding_str = ''

    def _minimize_frequencies(self):
        """Scale the frequencies so that each one fits in a single byte."""
        max_freq = max(self.freqs)
        if not max_freq:
            # Empty input: all counts are already 0 and dividing by the
            # maximum would raise ZeroDivisionError.
            return
        for symbol, freq in enumerate(self.freqs):
            scale_freq = int(self.BYTE_MAX_NUM * (freq / max_freq))
            # A symbol that occurs must not be scaled down to zero.
            if not scale_freq and freq:
                scale_freq = 1
            self.freqs[symbol] = scale_freq

    def _get_symbol_frequencies(self):
        """Count each byte value of ``self.origin``, then scale the counts."""
        for symbol in self.origin:
            self.freqs[symbol] += 1
        self._minimize_frequencies()

    def _initial_node_heap(self):
        """Push one leaf per byte value (zero-frequency ones included)."""
        self._heap = []
        # Push order affects how equal frequencies are resolved, and thus
        # the tree shape; encode and decode both go through this method.
        for symbol, freq in enumerate(self.freqs):
            heapq.heappush(self._heap, HuffmanNode(symbol, freq))

    def _build_huffman_tree(self):
        """Build the tree from ``self.freqs`` and return its root."""
        self._initial_node_heap()
        while len(self._heap) > 1:
            node1 = heapq.heappop(self._heap)
            node2 = heapq.heappop(self._heap)
            new_node = HuffmanNode(symbol=None, freq=node1.freq + node2.freq)
            new_node.left, new_node.right = node1, node2
            node1.parent, node2.parent = new_node, new_node
            heapq.heappush(self._heap, new_node)
        self.huffman_tree = heapq.heappop(self._heap)
        del self._heap
        return self.huffman_tree

    def _build_coding_table(self, node, code_str=''):
        """Fill ``coding_table`` and ``reverse_table`` for the subtree at *node*.

        Args:
            node: Root of the subtree to walk (None is ignored).
            code_str: Code accumulated on the path from the tree root.
        """
        if node is None:
            return
        if node.symbol is not None:
            self.coding_table[node.symbol] = code_str
            self.reverse_table[code_str] = node.symbol
        self._build_coding_table(node.left, code_str + '0')
        self._build_coding_table(node.right, code_str + '1')

    def _pading_coding_str(self):
        """Zero-pad ``coding_str`` to whole bytes and prepend the pad count.

        The pad count is always 1..8: an already aligned string receives a
        full byte of padding, which lets the decoder strip it with a plain
        negative slice.
        """
        pading_count = 8 - len(self.coding_str) % 8
        self.coding_str += '0' * pading_count
        state_str = '{:08b}'.format(pading_count)
        self.coding_str = state_str + self.coding_str

    def _prefix_coding_freqs(self):
        """Prepend the 256 frequencies, 8 bits each, to ``coding_str``."""
        coding_freqs = ''.join('{:08b}'.format(freq) for freq in self.freqs)
        self.coding_str = coding_freqs + self.coding_str

    def _build_codeing_str(self):
        """Build the full bit string (frequencies, pad count, payload)."""
        self.coding_str = ''.join(
            self.coding_table[symbol] for symbol in self.origin
        )
        self._pading_coding_str()
        self._prefix_coding_freqs()
        return self.coding_str

    def _get_compressed(self):
        """Pack ``coding_str`` into bytes, store and return them."""
        assert len(self.coding_str) % 8 == 0
        self.compressed = bytes(
            int(self.coding_str[index:index + 8], 2)
            for index in range(0, len(self.coding_str), 8)
        )
        return self.compressed

    def _read_frequencies_from_compressed(self):
        """Load the frequency table from the first 256 compressed bytes."""
        coding_freqs = self.compressed[:self.BYTE_MAX_NUM + 1]
        for index, freq in enumerate(coding_freqs):
            self.freqs[index] = freq

    def _get_real_coding_from_compressed(self):
        """Return the payload bit string with the trailing padding removed."""
        pading_count = self.compressed[self.BYTE_MAX_NUM + 1]
        byte_coding_str = self.compressed[self.BYTE_MAX_NUM + 2:]
        # '{:08b}' keeps the leading zeros that bin() would drop.
        coding_str = ''.join('{:08b}'.format(num) for num in byte_coding_str)
        assert len(coding_str) % 8 == 0
        return coding_str[:-pading_count]

    def _decode_compressed(self):
        """Walk ``huffman_tree`` over the payload bits and return the bytes."""
        real_coding_str = self._get_real_coding_from_compressed()
        decoded = bytearray()
        node = self.huffman_tree
        for state in real_coding_str:
            if state == '0':
                node = node.left
            elif state == '1':
                node = node.right
            if node.symbol is not None:
                assert 0 <= node.symbol <= self.BYTE_MAX_NUM
                decoded.append(node.symbol)
                # A symbol was emitted; restart from the root.
                node = self.huffman_tree
        return bytes(decoded)

    def clear(self):
        """Reset every attribute to its initial state."""
        self.__init__()

    def encode(self, origin):
        """Compress *origin* (a bytes-like object) and return the bytes."""
        self.clear()
        self.origin = origin
        self._get_symbol_frequencies()
        self._build_huffman_tree()
        self._build_coding_table(self.huffman_tree)
        self._build_codeing_str()
        return self._get_compressed()

    def compresse(self, filename, output_filename=None):
        """Compress the file *filename*.

        Args:
            filename: Path of the file to read.
            output_filename: Destination path; defaults to
                ``filename + '.hfm'``.

        Returns:
            True once the output has been written.
        """
        with open(filename, 'rb') as file:
            origin = file.read()
        compressed_content = self.encode(origin)
        if output_filename is None:
            output_filename = filename + '.hfm'
        with open(output_filename, 'wb') as file:
            file.write(compressed_content)
        return True

    def decode(self, compressed):
        """Decompress bytes produced by ``encode`` and return the original."""
        self.clear()
        self.compressed = compressed
        self._read_frequencies_from_compressed()
        self._build_huffman_tree()
        return self._decode_compressed()

    def uncompresse(self, filename, output_filename=None):
        """Decompress the file *filename*.

        Args:
            filename: Path of the compressed file to read.
            output_filename: Destination path; defaults to *filename*
                without its ``.hfm`` suffix, or ``filename + '.dhfm'`` when
                that suffix is absent.

        Returns:
            True once the output has been written.
        """
        with open(filename, 'rb') as file:
            compressed = file.read()
        decode_content = self.decode(compressed)
        if output_filename is None:
            if filename.endswith('.hfm'):
                output_filename = filename[:-4]
            else:
                output_filename = filename + '.dhfm'
        with open(output_filename, 'wb') as file:
            file.write(decode_content)
        return True
Add Comment
Please, Sign In to add comment