globmont

AoC 21-16

Dec 16th, 2021 (edited)
183
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 5.45 KB | None | 0 0
  1. from abc import ABC, abstractmethod
  2. from typing import List
  3. from functools import reduce
  4. import operator
  5.  
  6. class Packet(ABC):
  7.     def __init__(self, bits: str):
  8.         self.bits = bits        
  9.         self.header = Packet.parse_header(self.bits)
  10.         self.value = None
  11.        
  12.         self.parse()
  13.        
  14.     def parse_header(bits):
  15.         version_bits = bits[:3]
  16.         type_bits = bits[3:6]
  17.        
  18.         header = dict(
  19.             version=int(version_bits, 2),
  20.             type_code=int(type_bits, 2),
  21.         )
  22.        
  23.         header["type"] = "literal" if header["type_code"] == 4 else "operator"
  24.         header["length_type_id"] = None if header["type"] == "literal" else int(bits[6])
  25.        
  26.         return header
  27.        
  28.     def consume_packet_bits(bits):
  29.         header = Packet.parse_header(bits)
  30.         end_idx = 6
  31.         if header["type"] == "literal":
  32.             return LiteralPacket.get_packet(bits)
  33.        
  34.         elif header["type"] == "operator":
  35.             return OperatorPacket.get_packet(bits)
  36.                    
  37.         return None
  38.                    
  39.     @abstractmethod
  40.     def parse(self):
  41.         pass
  42.        
  43.    
  44. class LiteralPacket(Packet):  
  45.     def __repr__(self) -> str:
  46.         return f"LiteralPacketV{self.header['version']}[value={self.value}]"
  47.    
  48.     def get_packet(bits):
  49.         end_idx = 6
  50.         while int(bits[end_idx]) == 1:
  51.             end_idx += 5
  52.        
  53.         end_idx += 5
  54.         return LiteralPacket(bits[:end_idx]), bits[end_idx:]
  55.    
  56.     def parse(self):
  57.         packet_body = self.bits[6:]
  58.         binary_value = ""
  59.         for bit_idx in range(0, len(packet_body), 5):
  60.             last_nibble = packet_body[bit_idx] == "0"
  61.             binary_value += packet_body[bit_idx + 1:bit_idx + 5]
  62.                        
  63.             if last_nibble:
  64.                 break
  65.        
  66.         self.value = int(binary_value, 2)
  67.         return self.value
  68.        
  69.  
  70. class OperatorPacket(Packet):
  71.     def __init__(self, bits: str, subpackets: List):
  72.         self.subpackets = subpackets
  73.         super().__init__(bits)
  74.    
  75.     def __repr__(self) -> str:
  76.         operator_packet_type = "Length" if int(self.header["length_type_id"]) == 0 else "Count"
  77.         return f"{operator_packet_type}OperatorPacketV{self.header['version']}[value={self.value} | {', '.join([str(packet) for packet in self.subpackets])}]"
  78.    
  79.     def get_packet(bits):
  80.         header = Packet.parse_header(bits)
  81.         packet = None
  82.         end_idx = 7
  83.         subpackets = []
  84.         if header["length_type_id"] == 0:
  85.             # 15 bit packet length descriptor
  86.             payload_length = int(bits[7:7+15], 2)
  87.             end_idx += 15 + payload_length
  88.            
  89.             remainder = bits[end_idx - payload_length:end_idx]
  90.             payload_bits_consumed = 0
  91.             while payload_bits_consumed < payload_length and remainder != "":
  92.                 next_packet, remainder = Packet.consume_packet_bits(remainder)
  93.                 subpackets.append(next_packet)
  94.                 payload_bits_consumed += len(next_packet.bits)
  95.            
  96.         elif header["length_type_id"] == 1:
  97.             # 11 bit packet count descriptor
  98.             payload_count = int(bits[7:7+11], 2)
  99.             end_idx += 11
  100.             remainder = bits[end_idx:]
  101.             for subpacket_idx in range(payload_count):
  102.                 if remainder == "":
  103.                     break
  104.                
  105.                 next_packet, remainder = Packet.consume_packet_bits(remainder)
  106.                 subpackets.append(next_packet)
  107.                 end_idx += len(next_packet.bits)
  108.                
  109.         packet = OperatorPacket(bits[:end_idx], subpackets=subpackets)    
  110.         return packet, bits[end_idx:]
  111.  
  112.    
  113.     def parse(self):
  114.         type_code = self.header["type_code"]
  115.         subpacket_values = (packet.value for packet in self.subpackets)
  116.         if type_code == 0:
  117.             # Sum
  118.             self.value = sum(subpacket_values)
  119.         elif type_code == 1:
  120.             # Product
  121.             self.value = reduce(operator.mul, subpacket_values)
  122.         elif type_code == 2:
  123.             # Min
  124.             self.value = min(subpacket_values)
  125.         elif type_code == 3:
  126.             # Max
  127.             self.value = max(subpacket_values)
  128.         elif type_code == 5:
  129.             # Greater than
  130.             self.value = int(self.subpackets[0].value > self.subpackets[1].value)
  131.         elif type_code == 6:
  132.             # Less than
  133.             self.value = int(self.subpackets[0].value < self.subpackets[1].value)
  134.         elif type_code == 7:
  135.             # Equal to
  136.             self.value = int(self.subpackets[0].value == self.subpackets[1].value)
  137.        
  138. def compute_version_sum(packet):
  139.     if isinstance(packet, LiteralPacket):
  140.         return packet.header["version"]
  141.     elif isinstance(packet, OperatorPacket):
  142.         return packet.header["version"] + sum(compute_version_sum(subpacket) for subpacket in packet.subpackets)
  143.     else:
  144.         return -1
  145.  
  146. if __name__ == "__main__":
  147.     file = "inputs/2021/16/data.txt"
  148.     with open(file, "r") as f:
  149.         transmission = f.read()
  150.    
  151.     binary_transmission = ""
  152.     for hex_digit in transmission:    
  153.         binary_transmission += bin(int(hex_digit, 16))[2:].zfill(4)
  154.    
  155.     packet, remainder = Packet.consume_packet_bits(binary_transmission)
  156.     print(f"Part 1: {compute_version_sum(packet)}")
  157.     print(f"Part 2: {packet.value}")
  158.  
Add Comment
Please, Sign In to add comment