Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import argparse
- from datetime import datetime
- import itertools
- import json
- import functools
- import sys
def show_binary(byte):
    """Return *byte* rendered as a zero-padded, eight-digit binary string."""
    return format(byte, '08b')
class StreamReader:
    """Forward-only cursor over a ``bytes`` object with byte and bit reads.

    The reader has two modes, switched with :meth:`set_bitwise`:

    * byte mode (default): :meth:`get_byte`, :meth:`get_int` and
      :meth:`get_string` consume whole bytes;
    * bitwise mode: :meth:`get_bit` / :meth:`get_bits` consume single bits,
      taking each byte's least-significant bit first.
    """

    def __init__(self, data):
        """Wrap *data*; raises ``AssertionError`` if it is not ``bytes``."""
        assert isinstance(data, bytes)
        self._data = data
        self._byte_index = 0          # index of the next unread byte
        self._byte_total = len(data)
        self._bitwise = False         # True while in bitwise mode
        self._bit_buffer = 0          # unread bits of the current byte
        self._bit_pending = 0         # how many bits of the buffer are unread

    def has_more(self):
        """Return True if unread bytes (or, in bitwise mode, bits) remain."""
        if self._byte_index < self._byte_total:
            return True
        return bool(self._bitwise and self._bit_pending)

    def get_byte(self, for_bits=False):
        """Consume and return the next byte as an int.

        Only allowed in byte mode unless *for_bits* is true (used internally
        to refill the bit buffer).  Raises ``IndexError`` at end of data.
        """
        assert not self._bitwise or for_bits
        result = self._data[self._byte_index]
        self._byte_index += 1
        return result

    def get_bytes(self, size):
        """Consume and return the next *size* bytes as a ``bytes`` slice.

        Raises ``ValueError`` for a negative size and ``IndexError`` if fewer
        than *size* bytes remain; in both cases the cursor is left untouched
        instead of being moved outside the data.
        """
        if size < 0:
            raise ValueError("size must be non-negative")
        if size > self._byte_total - self._byte_index:
            raise IndexError("not enough bytes left in stream")
        start = self._byte_index
        self._byte_index = start + size
        return self._data[start:self._byte_index]

    def get_int(self, size):
        """Consume *size* bytes and return them as a little-endian integer."""
        assert not self._bitwise
        # Read byte by byte so running out of data raises IndexError.
        raw = bytes([self.get_byte() for _ in range(size)])
        return int.from_bytes(raw, "little")

    def get_string(self):
        """Consume bytes up to and including a NUL terminator.

        Returns the bytes before the terminator as a ``str`` with each byte
        mapped to the code point of the same value.  Raises ``IndexError`` if
        the data ends before a NUL byte is found.
        """
        assert not self._bitwise
        chars = []
        while True:
            char = self.get_byte()
            if char == 0:
                return "".join(chars)
            chars.append(chr(char))

    def set_bitwise(self, value):
        """Enter (True) or leave (False) bitwise mode.

        Leaving bitwise mode asserts that no buffered bits are pending; call
        :meth:`discard_remaining_bits` first to drop them deliberately.
        """
        if self._bitwise is True and value is False:
            assert self._bit_pending == 0
        self._bitwise = value

    def get_bit(self):
        """Consume and return the next bit (0 or 1), LSB of each byte first."""
        assert self._bitwise
        if not self._bit_pending:
            self._bit_buffer = self.get_byte(for_bits=True)
            self._bit_pending = 8
        result = self._bit_buffer & 1
        self._bit_buffer >>= 1
        self._bit_pending -= 1
        return result

    def get_bits(self, size, reverse=False):
        """Consume *size* bits and assemble them into an integer.

        By default the first bit read becomes the most-significant bit of the
        result.  With ``reverse=True`` the first bit read becomes the
        least-significant bit.
        """
        bits = [self.get_bit() for _ in range(size)]
        if reverse:
            bits.reverse()
        value = 0
        for bit in bits:
            value = (value << 1) | bit
        return value

    def discard_remaining_bits(self):
        """Drop the unread bits of the current byte (realign to a byte)."""
        self._bit_pending = 0

    def get_remaining(self):
        """Return all not-yet-consumed bytes without consuming them."""
        return self._data[self._byte_index:]
class HuffmanDecoder:
    """Canonical Huffman decoder built from per-symbol code lengths.

    ``bit_lengths[i]`` is the code length of symbol ``i``; a length of zero
    means the symbol is unused and gets no code.  Codes are assigned with the
    algorithm from RFC 1951 section 3.2.2: shorter codes come first, and
    codes of equal length are numbered consecutively in symbol order.

    ``self.tree`` is a binary trie of nested dicts keyed by bit value (0/1),
    walked from the first bit of a code; each leaf is the symbol index.
    """

    def __init__(self, bit_lengths):
        """Build ``self.tree`` from the sequence *bit_lengths*."""
        # Number of codes of each length.  Unused (zero-length) symbols must
        # NOT be counted: the count for length 0 feeds the first code of
        # length 1 below, so counting them would shift every code.
        bit_length_to_count = {}
        max_length = 0
        for bit_length in bit_lengths:
            if bit_length == 0:
                continue
            max_length = max(max_length, bit_length)
            bit_length_to_count[bit_length] = (
                bit_length_to_count.get(bit_length, 0) + 1
            )

        # First code of each length: take the first code of the previous
        # length, skip past all codes of that length, then append a zero bit.
        code = 0
        bit_length_to_next_code = {}
        for bit_length in range(1, max_length + 1):
            code = (code + bit_length_to_count.get(bit_length - 1, 0)) << 1
            bit_length_to_next_code[bit_length] = code

        # Hand out consecutive codes in symbol order and insert each code
        # into the trie, most-significant bit first.
        self.tree = {}
        for index, bit_length in enumerate(bit_lengths):
            if bit_length == 0:
                continue
            code = bit_length_to_next_code[bit_length]
            bit_length_to_next_code[bit_length] += 1
            code_str = '{:0{width}b}'.format(code, width=bit_length)
            node = self.tree
            for char in code_str[:-1]:
                node = node.setdefault(int(char), {})
            node[int(code_str[-1])] = index

    def extract(self, stream):
        """Read bits from *stream* until a leaf is reached; return its symbol.

        *stream* must provide ``get_bit()``.  The bits consumed are printed
        as a list (debug output).  Raises ``KeyError`` if the bits read do
        not correspond to any code in the tree.
        """
        current = self.tree
        chars = []
        while isinstance(current, dict):
            char = stream.get_bit()
            chars.append(char)
            current = current[char]
        print(chars)
        return current
class GzipDecompress:
    """Work-in-progress gzip (RFC 1952) / DEFLATE (RFC 1951) decoder.

    Parses the gzip member header and the header of the first DEFLATE block,
    printing diagnostic information along the way.  Stored (uncompressed)
    blocks are copied into ``self.result``; the Huffman-coded block types are
    only partially implemented (see the notes on the individual methods).
    """

    def __init__(self, data):
        """Prepare to decode *data*, the raw bytes of a gzip file."""
        self._stream = StreamReader(data)

    def run(self):
        """Decode the input and return the collected output.

        Returns ``self.result``, a list of decoded byte values (ints).  Only
        the first gzip member is processed at the moment.
        """
        self.result = []
        while self._stream.has_more():
            self._read_gzip_block()
            # Multi-member files are not supported yet: stop after one member.
            break
        return self.result

    def _read_gzip_block(self):
        """Parse one gzip member header, then hand over to the DEFLATE reader."""
        # Reference = https://tools.ietf.org/html/rfc1952
        assert self._stream.get_byte() == 31   # ID1
        assert self._stream.get_byte() == 139  # ID2
        self._stream.get_byte()  # CM (compression method), currently unused
        flags = self._stream.get_byte()
        print("flags", show_binary(flags))
        has_header_crc = bool(flags & (1 << 1))  # FHCRC
        has_extra = bool(flags & (1 << 2))       # FEXTRA
        has_filename = bool(flags & (1 << 3))    # FNAME
        has_comment = bool(flags & (1 << 4))     # FCOMMENT
        modified_time = self._stream.get_int(size=4)
        print("mtime", datetime.fromtimestamp(modified_time))
        extra_flags = self._stream.get_byte()
        print("extra_flags", show_binary(extra_flags))
        self._stream.get_byte()  # OS (operating system), currently unused
        # The optional fields follow in this fixed order; each one present
        # must be consumed or the compressed data would start at the wrong
        # offset.
        if has_extra:
            extra_length = self._stream.get_int(size=2)
            self._stream.get_bytes(extra_length)
        filename = self._stream.get_string() if has_filename else None
        print("filename", filename)
        comment = self._stream.get_string() if has_comment else None
        print("comment", comment)
        if has_header_crc:
            self._stream.get_int(size=2)  # CRC16 of the header, not verified
        self._read_deflate_block()

    def _read_deflate_block(self):
        """Read one DEFLATE block header and dispatch on the block type.

        Raises ``NotImplementedError`` for the reserved block type 3.
        """
        # Reference = https://tools.ietf.org/html/rfc1951
        print("!!!", self._stream.get_remaining())
        self._stream.set_bitwise(True)
        is_final_block = self._stream.get_bit()
        print("is_final_block", is_final_block)
        # BTYPE is a plain 2-bit integer, stored least-significant bit first
        # like every non-Huffman field, hence reverse=True.
        block_type = self._stream.get_bits(2, reverse=True)
        print("block_type", block_type)
        if block_type == 0:
            # Stored block: skip to the next byte boundary and go back to
            # byte mode, because the byte-level readers refuse to run while
            # the stream is in bitwise mode.
            self._stream.discard_remaining_bits()
            self._stream.set_bitwise(False)
            length = self._stream.get_int(size=2)
            self._stream.get_int(size=2)  # one's complement of length
            self.result.extend(self._stream.get_bytes(length))
        elif block_type == 1:
            self._read_fixed_huffman_codes()
        elif block_type == 2:
            self._read_dynamic_huffman_codes()
        else:
            raise NotImplementedError()

    def _read_fixed_huffman_codes(self):
        """Decode symbols of a fixed-Huffman block until end-of-block (256).

        NOTE(review): decoded symbols are only printed, never appended to
        ``self.result``, and symbols above 256 get no special handling here.
        """
        # Code lengths of the fixed literal/length alphabet:
        # symbols 0-143 -> 8 bits, 144-255 -> 9, 256-279 -> 7, 280-287 -> 8.
        bit_lengths = [8] * 144 + [9] * 112 + [7] * 24 + [8] * 8
        # For testing: bit_lengths = [3, 3, 3, 3, 3, 2, 4, 4] generates the
        # tree shown in rfc1951.
        decoder = HuffmanDecoder(bit_lengths)
        while True:
            value = decoder.extract(self._stream)
            print(">>>", value)
            if value == 256:
                break

    def _read_dynamic_huffman_codes(self):
        """Read the start of a dynamic-Huffman block header (unfinished).

        Reads HLIT, HDIST and HCLEN and the code lengths of the code-length
        alphabet, builds a decoder from them and prints diagnostics.  Always
        ends by raising an exception: ``NotImplementedError`` if the end of
        this method is reached.
        """
        literal_count = self._stream.get_bits(5, reverse=True) + 257
        distance_count = self._stream.get_bits(5, reverse=True) + 1
        code_count = self._stream.get_bits(4, reverse=True) + 4
        print(literal_count, distance_count, code_count)
        bit_lengths = []
        for _ in range(code_count):
            bit_lengths.append(self._stream.get_bits(3, reverse=True))
        ordering = [16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15]
        print(bit_lengths)
        print(ordering)
        # The lengths arrive in the permuted order given by `ordering`; put
        # each at its symbol's index.  Symbols not transmitted keep length 0.
        lengths_by_symbol = [0] * len(ordering)
        for symbol, length in zip(ordering, bit_lengths):
            lengths_by_symbol[symbol] = length
        bit_lengths = lengths_by_symbol
        print(bit_lengths)
        decoder = HuffmanDecoder(bit_lengths)
        print(json.dumps(decoder.tree, indent=4))
        # NOTE(review): this decoder covers only the 19 code-length symbols
        # (0-18), so it can never yield 256; the loop below therefore only
        # ends when extract() raises.  The code-length expansion and the
        # literal/distance decoding still need to be written.
        while True:
            value = decoder.extract(self._stream)
            print(">>>", value)
            if value == 256:
                break
        try:
            while True:
                print(self._stream.get_bit(), end="")
        except IndexError:
            print("\n")
        raise NotImplementedError()
if __name__ == "__main__":
    # Command-line entry point: read the file named on the command line as
    # raw bytes, run the decompressor on it and print what run() returns.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("filename", help="Name of input file.")
    arguments = arg_parser.parse_args()
    with open(arguments.filename, "rb") as source:
        payload = source.read()
    decompressor = GzipDecompress(payload)
    print(decompressor.run())
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement