Advertisement
Guest User

Untitled

a guest
Nov 9th, 2024
58
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 13.58 KB | None | 0 0
  1. import time
  2. from concurrent.futures import ThreadPoolExecutor
  3. from typing import List, Tuple, Optional
  4. import numpy as np
  5. from array import array
  6. import ctypes
  7. from line_profiler._line_profiler import byteorder
  8.  
  9. class Node:
  10.     __slots__ = ['char', 'freq', 'left', 'right']
  11.  
  12.     def __init__(self, char: str, freq: int, left=None, right=None):
  13.         self.char = char
  14.         self.freq = freq
  15.         self.left = left
  16.         self.right = right
  17.  
  18.  
  19.  
  20.  
  21. class HybridLookupTable:
  22.    
  23. """Hybrid approach combining direct lookup for short codes and binary search for long codes"""
  24.    
  25. __slots__ = ['short_table', 'long_codes', 'max_short_bits']
  26.  
  27.     def __init__(self, max_short_bits: int = 8):
  28.         self.max_short_bits = max_short_bits
  29.         self.short_table = [(None, 0)] * (1 << max_short_bits)  # Changed to tuple list for safety
  30.         self.long_codes = {}
  31.  
  32.     def add_code(self, code: str, char: str) -> None:
  33.         code_int = int(code, 2)
  34.         code_len = len(code)
  35.  
  36.         if code_len <= self.max_short_bits:
  37.             # For short codes, use lookup table with limited prefix expansion
  38.             prefix_mask = (1 << (self.max_short_bits - code_len)) - 1
  39.             base_index = code_int << (self.max_short_bits - code_len)
  40.             for i in range(prefix_mask + 1):
  41.                 self.short_table[base_index | i] = (char, code_len)
  42.         else:
  43.             # For long codes, store in dictionary
  44.             self.long_codes[code_int] = (char, code_len)
  45.  
  46.     def lookup(self, bits: int, length: int) -> Optional[Tuple[str, int]]:
  47.        
  48. """Look up a bit pattern and return (character, code length) if found"""
  49.        
  50. if length <= self.max_short_bits:
  51.             return self.short_table[bits & ((1 << self.max_short_bits) - 1)]
  52.  
  53.         # Try matching long codes
  54.         for code_bits, (char, code_len) in self.long_codes.items():
  55.             if code_len <= length:
  56.                 mask = (1 << code_len) - 1
  57.                 if (bits >> (length - code_len)) == (code_bits & mask):
  58.                     return (char, code_len)
  59.         return None
  60. class BitBuffer:
  61.    
  62. """Fast bit buffer implementation using ctypes"""
  63.    
  64. __slots__ = ['buffer', 'bits_in_buffer']
  65.  
  66.     def __init__(self):
  67.         self.buffer = ctypes.c_uint64(0)
  68.         self.bits_in_buffer = 0
  69.     def add_byte(self, byte: int) -> None:
  70.         self.buffer.value = (self.buffer.value << 8) | byte
  71.         self.bits_in_buffer += 8
  72.     def peek_bits(self, num_bits: int) -> int:
  73.         return (self.buffer.value >> (self.bits_in_buffer - num_bits)) & ((1 << num_bits) - 1)
  74.  
  75.     def consume_bits(self, num_bits: int) -> None:
  76.         self.buffer.value &= (1 << (self.bits_in_buffer - num_bits)) - 1
  77.         self.bits_in_buffer -= num_bits
  78.  
  79.  
  80. class ChunkDecoder:
  81.    
  82. """Decoder for a chunk of compressed data"""
  83.    
  84. __slots__ = ['lookup_table', 'tree', 'chunk_size']
  85.  
  86.     def __init__(self, lookup_table, tree, chunk_size=1024):
  87.         self.lookup_table = lookup_table
  88.         self.tree = tree
  89.         self.chunk_size = chunk_size
  90.  
  91.     def decode_chunk(self, data: memoryview, start_bit: int, end_bit: int) -> Tuple[List[str], int]:
  92.        
  93. """Decode a chunk of bits and return (decoded_chars, bits_consumed)"""
  94.        
  95. result = []
  96.         pos = start_bit
  97.         buffer = BitBuffer()
  98.         bytes_processed = start_bit >> 3
  99.         bit_offset = start_bit & 7
  100.         # Pre-fill buffer
  101.         for _ in range(8):
  102.             if bytes_processed < len(data):
  103.                 buffer.add_byte(data[bytes_processed])
  104.                 bytes_processed += 1
  105.         # Skip initial bit offset
  106.         if bit_offset:
  107.             buffer.consume_bits(bit_offset)
  108.  
  109.         while pos < end_bit and buffer.bits_in_buffer >= 8:
  110.             # Try lookup table first (optimized for 8-bit codes)
  111.             lookup_bits = buffer.peek_bits(8)
  112.             char_info = self.lookup_table.lookup(lookup_bits, 8)
  113.  
  114.             if char_info:
  115.                 char, code_len = char_info
  116.                 buffer.consume_bits(code_len)
  117.                 result.append(char)
  118.                 pos += code_len
  119.             else:
  120.                 # Fall back to tree traversal
  121.                 node = self.tree
  122.                 while node.left and node.right and buffer.bits_in_buffer > 0:
  123.                     bit = buffer.peek_bits(1)
  124.                     buffer.consume_bits(1)
  125.                     node = node.right if bit else node.left
  126.                     pos += 1
  127.                 if not (node.left or node.right):
  128.                     result.append(node.char)
  129.  
  130.             # Refill buffer if needed
  131.             while buffer.bits_in_buffer <= 56 and bytes_processed < len(data):
  132.                 buffer.add_byte(data[bytes_processed])
  133.                 bytes_processed += 1
  134.         return result, pos - start_bit
  135.  
  136.  
  137. class OptimizedHuffmanDecoder:
  138.     def __init__(self, num_threads=4, chunk_size=1024):
  139.         self.tree = None
  140.         self.freqs = {}
  141.         self.lookup_table = HybridLookupTable()
  142.         self.num_threads = num_threads
  143.         self.chunk_size = chunk_size
  144.         self._setup_lookup_tables()
  145.  
  146.     def _setup_lookup_tables(self):
  147.         # Pre-calculate bit manipulation tables
  148.         self.bit_masks = array('Q', [(1 << i) - 1 for i in range(65)])
  149.         self.bit_shifts = array('B', [x & 7 for x in range(8)])
  150.  
  151.     def _build_efficient_tree(self) -> None:
  152.         # Use list-based heap instead of sorting
  153.         nodes = [(freq, i, Node(char, freq)) for i, (char, freq) in enumerate(self.freqs.items())]
  154.  
  155.         # Convert to min-heap
  156.         nodes.sort(reverse=True)  # Sort once at the beginning
  157.         while len(nodes) > 1:
  158.             freq1, _, node1 = nodes.pop()
  159.             freq2, _, node2 = nodes.pop()
  160.  
  161.             # Create parent node
  162.             parent = Node(node1.char + node2.char, freq1 + freq2, node1, node2)
  163.             nodes.append((freq1 + freq2, len(nodes), parent))
  164.             nodes.sort(reverse=True)
  165.  
  166.         self.tree = nodes[0][2] if nodes else None
  167.         self._build_codes(self.tree)
  168.  
  169.     def _build_codes(self, node: Node, code: str = '') -> None:
  170.        
  171. """Build lookup table using depth-first traversal"""
  172.        
  173. if not node:
  174.             return
  175.         if not node.left and not node.right:
  176.             if code:  # Never store empty codes
  177.                 self.lookup_table.add_code(code, node.char)
  178.             return
  179.         if node.left:
  180.             self._build_codes(node.left, code + '0')
  181.         if node.right:
  182.             self._build_codes(node.right, code + '1')
  183.  
  184.     def _parse_header_fast(self, data: memoryview) -> int:
  185.        
  186. """Optimized header parsing"""
  187.        
  188. pos = 12  # Skip first 12 bytes (file_len, always0, chars_count)
  189.         chars_count = int.from_bytes(data[8:12], byteorder)
  190.  
  191.         # Pre-allocate dictionary space
  192.         self.freqs = {}
  193.         self.freqs.clear()
  194.  
  195.         # Process all characters in a single loop
  196.         for _ in range(chars_count):
  197.             count = int.from_bytes(data[pos:pos + 4], byteorder)
  198.             char = chr(data[pos + 4])  # Faster than decode
  199.             self.freqs[char] = count
  200.             pos += 8
  201.         return pos
  202.  
  203.     def _decode_bits_parallel(self, data: memoryview, total_bits: int) -> str:
  204.        
  205. """Parallel decoding using multiple threads"""
  206.        
  207. chunk_bits = (total_bits + self.num_threads - 1) // self.num_threads
  208.         chunks = []
  209.  
  210.         # Create chunks ensuring they align with byte boundaries when possible
  211.         for i in range(0, total_bits, chunk_bits):
  212.             end_bit = min(i + chunk_bits, total_bits)
  213.             if i > 0:
  214.                 # Align to byte boundary when possible
  215.                 while (i & 7) != 0 and i > 0:
  216.                     i -= 1
  217.             chunks.append((i, end_bit))
  218.  
  219.         # Create decoders for each thread
  220.         decoders = [
  221.             ChunkDecoder(self.lookup_table, self.tree, self.chunk_size)
  222.             for _ in range(len(chunks))
  223.         ]
  224.  
  225.         # Process chunks in parallel
  226.         with ThreadPoolExecutor(max_workers=self.num_threads) as executor:
  227.             futures = [
  228.                 executor.submit(decoder.decode_chunk, data, start, end)
  229.                 for decoder, (start, end) in zip(decoders, chunks)
  230.             ]
  231.  
  232.             # Collect results
  233.             results = []
  234.             for future in futures:
  235.                 chunk_result, _ = future.result()
  236.                 results.extend(chunk_result)
  237.  
  238.         return ''.join(results)
  239.  
  240.     def _decode_bits_optimized(self, data: memoryview, total_bits: int) -> str:
  241.        
  242. """Optimized single-threaded decoding for small inputs"""
  243.        
  244. if total_bits > self.chunk_size:
  245.             return self._decode_bits_parallel(data, total_bits)
  246.  
  247.         result = []
  248.         buffer = BitBuffer()
  249.         pos = 0
  250.         bytes_processed = 0
  251.         # Pre-fill buffer
  252.         while bytes_processed < min(8, len(data)):
  253.             buffer.add_byte(data[bytes_processed])
  254.             bytes_processed += 1
  255.         while pos < total_bits:
  256.             # Use lookup table for common patterns
  257.             if buffer.bits_in_buffer >= 8:
  258.                 lookup_bits = buffer.peek_bits(8)
  259.                 char_info = self.lookup_table.lookup(lookup_bits, 8)
  260.  
  261.                 if char_info:
  262.                     char, code_len = char_info
  263.                     buffer.consume_bits(code_len)
  264.                     result.append(char)
  265.                     pos += code_len
  266.                 else:
  267.                     # Tree traversal for uncommon patterns
  268.                     node = self.tree
  269.                     while node.left and node.right and buffer.bits_in_buffer > 0:
  270.                         bit = buffer.peek_bits(1)
  271.                         buffer.consume_bits(1)
  272.                         node = node.right if bit else node.left
  273.                         pos += 1
  274.                     if not (node.left or node.right):
  275.                         result.append(node.char)
  276.  
  277.             # Refill buffer
  278.             while buffer.bits_in_buffer <= 56 and bytes_processed < len(data):
  279.                 buffer.add_byte(data[bytes_processed])
  280.                 bytes_processed += 1
  281.             if buffer.bits_in_buffer == 0:
  282.                 break
  283.         return ''.join(result)
  284.  
  285.     def decode_hex(self, hex_string: str) -> str:
  286.         # Use numpy for faster hex decoding
  287.         clean_hex = hex_string.replace(' ', '')
  288.         data = np.frombuffer(bytes.fromhex(clean_hex), dtype=np.uint8)
  289.         return self.decode_bytes(data.tobytes())
  290.  
  291.     def decode_bytes(self, data: bytes) -> str:
  292.         view = memoryview(data)
  293.         pos = self._parse_header_fast(view)
  294.  
  295.         self._build_efficient_tree()
  296.  
  297.         # Get packed data info using numpy for faster parsing
  298.         header = np.frombuffer(data[pos:pos + 12], dtype=np.uint32)
  299.         packed_bits = int(header[0])
  300.         packed_bytes = int(header[1])
  301.         pos += 12
  302.         # Choose decoding method based on size
  303.         if packed_bits > self.chunk_size:
  304.             return self._decode_bits_parallel(view[pos:pos + packed_bytes], packed_bits)
  305.         else:
  306.             return self._decode_bits_optimized(view[pos:pos + packed_bytes], packed_bits)
  307.  
  308.     def encode(self, text: str) -> bytes:
  309.        
  310. """Encode text using Huffman coding - for testing purposes"""
  311.        
  312. # Count frequencies
  313.         self.freqs = {}
  314.         for char in text:
  315.             self.freqs[char] = self.freqs.get(char, 0) + 1
  316.         # Build tree and codes
  317.         self._build_efficient_tree()
  318.  
  319.         # Convert text to bits
  320.         bits = []
  321.         for char in text:
  322.             code = self.lookup_table.get_code(char)
  323.             bits.extend(code)
  324.  
  325.         # Pack bits into bytes
  326.         packed_bytes = []
  327.         for i in range(0, len(bits), 8):
  328.             byte = 0
  329.             for j in range(min(8, len(bits) - i)):
  330.                 if bits[i + j]:
  331.                     byte |= 1 << (7 - j)
  332.             packed_bytes.append(byte)
  333.  
  334.         # Create header
  335.         header = bytearray()
  336.         header.extend(len(text).to_bytes(4, byteorder))
  337.         header.extend(b'\x00' * 4)  # always0
  338.         header.extend(len(self.freqs).to_bytes(4, byteorder))
  339.  
  340.         # Add frequency table
  341.         for char, freq in self.freqs.items():
  342.             header.extend(freq.to_bytes(4, byteorder))
  343.             header.extend(char.encode('ascii'))
  344.             header.extend(b'\x00' * 3)  # padding
  345.         # Add packed data info
  346.         header.extend(len(bits).to_bytes(4, byteorder))
  347.         header.extend(len(packed_bytes).to_bytes(4, byteorder))
  348.         header.extend(b'\x00' * 4)  # unpacked_bytes
  349.         # Combine header and packed data
  350.         return bytes(header + bytes(packed_bytes))
  351.  
  352. if __name__ == '__main__':
  353.     # Create decoder with custom settings
  354.     decoder = OptimizedHuffmanDecoder(
  355.         num_threads=4,  # Number of threads for parallel processing
  356.         chunk_size=1024  # Minimum size for parallel processing
  357.     )
  358.  
  359.     test_hex = 'A7 64 00 00 00 00 00 00 0C 00 00 00 38 25 00 00 2D 00 00 00 08 69 00 00 30 00 00 00 2E 13 00 00 31 00 00 00 D4 13 00 00 32 00 00 00 0F 0D 00 00 33 00 00 00 78 08 00 00 34 00 00 00 A4 0A 00 00 35 00 00 00 63 0E 00 00 36 00 00 00 AC 09 00 00 37 00 00 00 D0 07 00 00 38 00 00 00 4D 09 00 00 39 00 00 00 68 0C 00 00 7C 00 00 00 73 21 03 00 2F 64 00 00 01 0B 01 00 C9 63 2A C7 21 77 40 77 25 8D AB E9 E5 E7 80 77'
  360.     start_time = time.perf_counter()
  361.     # Decode data
  362.     result = decoder.decode_hex(test_hex)
  363.     execution_time_ms = (time.perf_counter() - start_time) * 1000  # Convert to milliseconds
  364.     print(f"\nTotal execution time: {execution_time_ms:.2f} milliseconds")
  365.     print(result)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement