import time
from concurrent.futures import ThreadPoolExecutor
from typing import List, Tuple, Optional
import numpy as np
from array import array
import ctypes
import heapq
from sys import byteorder
class Node:
    __slots__ = ['char', 'freq', 'left', 'right']

    def __init__(self, char: str, freq: int, left=None, right=None):
        self.char = char
        self.freq = freq
        self.left = left
        self.right = right
class HybridLookupTable:
    """Hybrid approach combining a direct lookup table for short codes and a dictionary for long codes"""
    __slots__ = ['short_table', 'long_codes', 'max_short_bits']

    def __init__(self, max_short_bits: int = 8):
        self.max_short_bits = max_short_bits
        self.short_table = [(None, 0)] * (1 << max_short_bits)  # (char, code_len) per bit window
        self.long_codes = {}

    def add_code(self, code: str, char: str) -> None:
        code_int = int(code, 2)
        code_len = len(code)
        if code_len <= self.max_short_bits:
            # For short codes, fill every table slot whose top bits equal the code
            prefix_mask = (1 << (self.max_short_bits - code_len)) - 1
            base_index = code_int << (self.max_short_bits - code_len)
            for i in range(prefix_mask + 1):
                self.short_table[base_index | i] = (char, code_len)
        else:
            # For long codes, store in a dictionary keyed by the code's integer value
            self.long_codes[code_int] = (char, code_len)

    def lookup(self, bits: int, length: int) -> Optional[Tuple[str, int]]:
        """Look up a bit pattern and return (character, code length) if found"""
        if length <= self.max_short_bits:
            char, code_len = self.short_table[bits & ((1 << self.max_short_bits) - 1)]
            # Unassigned slots hold (None, 0); report those as misses
            return (char, code_len) if char is not None else None
        # Try matching long codes
        for code_bits, (char, code_len) in self.long_codes.items():
            if code_len <= length and (bits >> (length - code_len)) == code_bits:
                return (char, code_len)
        return None
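
# Minimal usage sketch for HybridLookupTable (illustrative only; the symbol 'x' and
# the code '101' are made up):
#
#     table = HybridLookupTable(max_short_bits=8)
#     table.add_code('101', 'x')     # fills slots 0b10100000..0b10111111 with ('x', 3)
#     table.lookup(0b10110011, 8)    # -> ('x', 3): the top three bits match the code
#     table.lookup(0b01010101, 8)    # -> None: no registered code is a prefix here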
class BitBuffer:
    """Fast bit buffer implementation using ctypes"""
    __slots__ = ['buffer', 'bits_in_buffer']

    def __init__(self):
        self.buffer = ctypes.c_uint64(0)
        self.bits_in_buffer = 0

    def add_byte(self, byte: int) -> None:
        self.buffer.value = (self.buffer.value << 8) | byte
        self.bits_in_buffer += 8

    def peek_bits(self, num_bits: int) -> int:
        return (self.buffer.value >> (self.bits_in_buffer - num_bits)) & ((1 << num_bits) - 1)

    def consume_bits(self, num_bits: int) -> None:
        self.buffer.value &= (1 << (self.bits_in_buffer - num_bits)) - 1
        self.bits_in_buffer -= num_bits
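
# Minimal usage sketch for BitBuffer (values chosen for illustration):
#
#     buf = BitBuffer()
#     buf.add_byte(0b11000011)   # bits_in_buffer == 8
#     buf.peek_bits(2)           # -> 0b11, the two most significant buffered bits
#     buf.consume_bits(2)        # drops them; bits_in_buffer == 6
#     buf.peek_bits(6)           # -> 0b000011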
class ChunkDecoder:
    """Decoder for a chunk of compressed data"""
    __slots__ = ['lookup_table', 'tree', 'chunk_size']

    def __init__(self, lookup_table, tree, chunk_size=1024):
        self.lookup_table = lookup_table
        self.tree = tree
        self.chunk_size = chunk_size

    def decode_chunk(self, data: memoryview, start_bit: int, end_bit: int) -> Tuple[List[str], int]:
        """Decode a chunk of bits and return (decoded_chars, bits_consumed)"""
        result = []
        pos = start_bit
        buffer = BitBuffer()
        bytes_processed = start_bit >> 3
        bit_offset = start_bit & 7
        # Pre-fill buffer
        for _ in range(8):
            if bytes_processed < len(data):
                buffer.add_byte(data[bytes_processed])
                bytes_processed += 1
        # Skip initial bit offset
        if bit_offset:
            buffer.consume_bits(bit_offset)
        while pos < end_bit and buffer.bits_in_buffer > 0:
            char_info = None
            # Try the lookup table first when a full 8-bit window is available
            if buffer.bits_in_buffer >= 8:
                lookup_bits = buffer.peek_bits(8)
                char_info = self.lookup_table.lookup(lookup_bits, 8)
            if char_info:
                char, code_len = char_info
                buffer.consume_bits(code_len)
                result.append(char)
                pos += code_len
            else:
                # Fall back to tree traversal (long codes and the final partial window)
                node = self.tree
                while node.left and node.right and buffer.bits_in_buffer > 0:
                    bit = buffer.peek_bits(1)
                    buffer.consume_bits(1)
                    node = node.right if bit else node.left
                    pos += 1
                if not (node.left or node.right):
                    result.append(node.char)
            # Refill buffer if needed
            while buffer.bits_in_buffer <= 56 and bytes_processed < len(data):
                buffer.add_byte(data[bytes_processed])
                bytes_processed += 1
        return result, pos - start_bit
class OptimizedHuffmanDecoder:
    def __init__(self, num_threads=4, chunk_size=1024):
        self.tree = None
        self.freqs = {}
        self.codes = {}  # char -> bit-string code, filled by _build_codes (used by encode)
        self.lookup_table = HybridLookupTable()
        self.num_threads = num_threads
        self.chunk_size = chunk_size
        self._setup_lookup_tables()

    def _setup_lookup_tables(self):
        # Pre-calculate bit manipulation tables (currently unused by the decode paths)
        self.bit_masks = array('Q', [(1 << i) - 1 for i in range(65)])
        self.bit_shifts = array('B', [x & 7 for x in range(8)])
    def _build_efficient_tree(self) -> None:
        # Build the tree with a real min-heap; the running index breaks frequency
        # ties so Node objects are never compared directly.
        heap = [(freq, i, Node(char, freq)) for i, (char, freq) in enumerate(self.freqs.items())]
        heapq.heapify(heap)
        counter = len(heap)
        while len(heap) > 1:
            freq1, _, node1 = heapq.heappop(heap)
            freq2, _, node2 = heapq.heappop(heap)
            # Create parent node
            parent = Node(node1.char + node2.char, freq1 + freq2, node1, node2)
            heapq.heappush(heap, (freq1 + freq2, counter, parent))
            counter += 1
        self.tree = heap[0][2] if heap else None
        # Rebuild the code tables from scratch for this tree
        self.lookup_table = HybridLookupTable()
        self.codes = {}
        self._build_codes(self.tree)
    def _build_codes(self, node: Node, code: str = '') -> None:
        """Build the lookup table and the char -> code map via depth-first traversal"""
        if not node:
            return
        if not node.left and not node.right:
            if code:  # Never store empty codes (the single-symbol tree case is not handled)
                self.lookup_table.add_code(code, node.char)
                self.codes[node.char] = code
            return
        if node.left:
            self._build_codes(node.left, code + '0')
        if node.right:
            self._build_codes(node.right, code + '1')
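
    # Illustrative example (exact codes depend on this implementation's tie-breaking):
    # for freqs == {'a': 5, 'b': 2, 'c': 1}, _build_efficient_tree yields a one-bit
    # code for 'a' and two-bit codes for 'b' and 'c', e.g.
    # codes == {'c': '00', 'b': '01', 'a': '1'}.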
    def _parse_header_fast(self, data: memoryview) -> int:
        """Optimized header parsing"""
        pos = 12  # Skip the 12-byte fixed header (file_len, always0, chars_count)
        chars_count = int.from_bytes(data[8:12], byteorder)
        self.freqs = {}
        # Read chars_count records of 8 bytes each
        for _ in range(chars_count):
            count = int.from_bytes(data[pos:pos + 4], byteorder)
            char = chr(data[pos + 4])
            self.freqs[char] = count
            pos += 8
        return pos
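
    # Header layout assumed by _parse_header_fast and produced by encode below
    # (integers use the host byte order via sys.byteorder; the sample data in
    # __main__ is little-endian):
    #   bytes 0-3   original text length
    #   bytes 4-7   always zero
    #   bytes 8-11  number of distinct characters N
    #   N records of 8 bytes: 4-byte frequency, 1-byte character, 3 bytes padding
    #   then 4-byte packed bit count, 4-byte packed byte count, 4 reserved bytes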
    def _decode_bits_parallel(self, data: memoryview, total_bits: int) -> str:
        """Parallel decoding using multiple threads"""
        chunk_bits = (total_bits + self.num_threads - 1) // self.num_threads
        chunks = []
        # Create chunks, rounding each chunk start down to a byte boundary.
        # Note: correct output also requires every chunk start to fall on a Huffman
        # code boundary, which this simple bit-offset split does not guarantee.
        for i in range(0, total_bits, chunk_bits):
            start_bit = i & ~7  # align to byte boundary
            end_bit = min(i + chunk_bits, total_bits)
            chunks.append((start_bit, end_bit))
        # Create decoders for each thread
        decoders = [
            ChunkDecoder(self.lookup_table, self.tree, self.chunk_size)
            for _ in range(len(chunks))
        ]
        # Process chunks in parallel
        with ThreadPoolExecutor(max_workers=self.num_threads) as executor:
            futures = [
                executor.submit(decoder.decode_chunk, data, start, end)
                for decoder, (start, end) in zip(decoders, chunks)
            ]
            # Collect results
            results = []
            for future in futures:
                chunk_result, _ = future.result()
                results.extend(chunk_result)
        return ''.join(results)
    def _decode_bits_optimized(self, data: memoryview, total_bits: int) -> str:
        """Optimized single-threaded decoding for small inputs"""
        if total_bits > self.chunk_size:
            return self._decode_bits_parallel(data, total_bits)
        result = []
        buffer = BitBuffer()
        pos = 0
        bytes_processed = 0
        # Pre-fill buffer
        while bytes_processed < min(8, len(data)):
            buffer.add_byte(data[bytes_processed])
            bytes_processed += 1
        while pos < total_bits and buffer.bits_in_buffer > 0:
            char_info = None
            # Use the lookup table for common patterns when a full 8-bit window is available
            if buffer.bits_in_buffer >= 8:
                lookup_bits = buffer.peek_bits(8)
                char_info = self.lookup_table.lookup(lookup_bits, 8)
            if char_info:
                char, code_len = char_info
                buffer.consume_bits(code_len)
                result.append(char)
                pos += code_len
            else:
                # Tree traversal for long codes and for the final partial window
                node = self.tree
                while node.left and node.right and buffer.bits_in_buffer > 0:
                    bit = buffer.peek_bits(1)
                    buffer.consume_bits(1)
                    node = node.right if bit else node.left
                    pos += 1
                if not (node.left or node.right):
                    result.append(node.char)
            # Refill buffer
            while buffer.bits_in_buffer <= 56 and bytes_processed < len(data):
                buffer.add_byte(data[bytes_processed])
                bytes_processed += 1
        return ''.join(result)
    def decode_hex(self, hex_string: str) -> str:
        # bytes.fromhex already does the hex decoding; no numpy round-trip is needed
        clean_hex = hex_string.replace(' ', '')
        return self.decode_bytes(bytes.fromhex(clean_hex))
    def decode_bytes(self, data: bytes) -> str:
        view = memoryview(data)
        pos = self._parse_header_fast(view)
        self._build_efficient_tree()
        # Get packed data info (bit count, byte count, reserved) using numpy for faster parsing
        header = np.frombuffer(data[pos:pos + 12], dtype=np.uint32)
        packed_bits = int(header[0])
        packed_bytes = int(header[1])
        pos += 12
        # Choose decoding method based on size
        if packed_bits > self.chunk_size:
            return self._decode_bits_parallel(view[pos:pos + packed_bytes], packed_bits)
        else:
            return self._decode_bits_optimized(view[pos:pos + packed_bytes], packed_bits)
    def encode(self, text: str) -> bytes:
        """Encode text using Huffman coding - for testing purposes"""
        # Count frequencies
        self.freqs = {}
        for char in text:
            self.freqs[char] = self.freqs.get(char, 0) + 1
        # Build tree and codes
        self._build_efficient_tree()
        # Convert text to bits using the char -> code map built by _build_codes
        bits = []
        for char in text:
            code = self.codes[char]
            bits.extend(1 if b == '1' else 0 for b in code)
        # Pack bits into bytes, most significant bit first
        packed_bytes = []
        for i in range(0, len(bits), 8):
            byte = 0
            for j in range(min(8, len(bits) - i)):
                if bits[i + j]:
                    byte |= 1 << (7 - j)
            packed_bytes.append(byte)
        # Create header
        header = bytearray()
        header.extend(len(text).to_bytes(4, byteorder))
        header.extend(b'\x00' * 4)  # always0
        header.extend(len(self.freqs).to_bytes(4, byteorder))
        # Add frequency table
        for char, freq in self.freqs.items():
            header.extend(freq.to_bytes(4, byteorder))
            header.extend(char.encode('ascii'))
            header.extend(b'\x00' * 3)  # padding
        # Add packed data info
        header.extend(len(bits).to_bytes(4, byteorder))
        header.extend(len(packed_bytes).to_bytes(4, byteorder))
        header.extend(b'\x00' * 4)  # unpacked_bytes
        # Combine header and packed data
        return bytes(header) + bytes(packed_bytes)
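
    # Round-trip sketch (illustrative; uses only the methods defined above and a
    # fresh decoder instance for the decode side):
    #
    #     enc = OptimizedHuffmanDecoder()
    #     blob = enc.encode('hello huffman')
    #     assert OptimizedHuffmanDecoder().decode_bytes(blob) == 'hello huffman'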
if __name__ == '__main__':
    # Create decoder with custom settings
    decoder = OptimizedHuffmanDecoder(
        num_threads=4,   # Number of threads for parallel processing
        chunk_size=1024  # Minimum size for parallel processing
    )
    # Sample from the original paste; its packed payload appears truncated relative
    # to the declared byte count, so the decoded output will be partial.
    test_hex = 'A7 64 00 00 00 00 00 00 0C 00 00 00 38 25 00 00 2D 00 00 00 08 69 00 00 30 00 00 00 2E 13 00 00 31 00 00 00 D4 13 00 00 32 00 00 00 0F 0D 00 00 33 00 00 00 78 08 00 00 34 00 00 00 A4 0A 00 00 35 00 00 00 63 0E 00 00 36 00 00 00 AC 09 00 00 37 00 00 00 D0 07 00 00 38 00 00 00 4D 09 00 00 39 00 00 00 68 0C 00 00 7C 00 00 00 73 21 03 00 2F 64 00 00 01 0B 01 00 C9 63 2A C7 21 77 40 77 25 8D AB E9 E5 E7 80 77'
    start_time = time.perf_counter()
    # Decode data
    result = decoder.decode_hex(test_hex)
    execution_time_ms = (time.perf_counter() - start_time) * 1000  # Convert to milliseconds
    print(f"\nTotal execution time: {execution_time_ms:.2f} milliseconds")
    print(result)