Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- from abc import ABC, abstractmethod
- from typing import List
- from functools import reduce
- import operator
class Packet(ABC):
    """Base class for a decoded BITS packet.

    A packet is built from its exact bit string (a str of '0'/'1' characters).
    The header is decoded on construction, then the subclass's ``parse`` is
    called to fill in ``self.value``.
    """

    def __init__(self, bits: str):
        # Exact bit string that makes up this packet (header + body).
        self.bits = bits
        self.header = self.parse_header(self.bits)
        # Filled in by the subclass's parse().
        self.value = None
        self.parse()

    @staticmethod
    def parse_header(bits: str) -> dict:
        """Decode the header fields at the start of ``bits``.

        Returns a dict with:
          - ``version``: int from bits 0-2,
          - ``type_code``: int from bits 3-5,
          - ``type``: ``"literal"`` when type_code is 4, else ``"operator"``,
          - ``length_type_id``: int from bit 6 for operator packets, None for literals.

        Raises ValueError if the header bits are missing or not binary, and
        IndexError if an operator header has no bit 6.
        """
        header = dict(
            version=int(bits[:3], 2),
            type_code=int(bits[3:6], 2),
        )
        header["type"] = "literal" if header["type_code"] == 4 else "operator"
        # Only operator packets carry a length-type bit right after the 6-bit header.
        header["length_type_id"] = None if header["type"] == "literal" else int(bits[6])
        return header

    @staticmethod
    def consume_packet_bits(bits: str) -> "tuple[Packet, str]":
        """Decode the packet at the start of ``bits``.

        Returns a ``(packet, remaining_bits)`` tuple, where ``remaining_bits``
        is everything after the decoded packet.
        """
        header = Packet.parse_header(bits)
        # parse_header classifies every type code as either "literal" or
        # "operator", so these two cases are exhaustive.
        if header["type"] == "literal":
            return LiteralPacket.get_packet(bits)
        return OperatorPacket.get_packet(bits)

    @abstractmethod
    def parse(self):
        """Compute ``self.value`` from ``self.bits`` (implemented by subclasses)."""
class LiteralPacket(Packet):
    """A BITS packet with type code 4 that encodes a single integer value.

    After the 6-bit header, the value is stored in 5-bit groups. Each group's
    first bit is 1 if another group follows and 0 for the last group. The
    remaining 4 bits of each group are concatenated to form the value.
    """

    def __repr__(self) -> str:
        return f"LiteralPacketV{self.header['version']}[value={self.value}]"

    @staticmethod
    def get_packet(bits: str) -> "tuple[LiteralPacket, str]":
        """Split the literal packet at the start of ``bits`` from what follows.

        Returns a ``(packet, remaining_bits)`` tuple.
        """
        end_idx = 6  # first value group starts right after the 6-bit header
        # Skip every group flagged "more groups follow".
        while int(bits[end_idx]) == 1:
            end_idx += 5
        # Include the final group (lead bit 0).
        end_idx += 5
        return LiteralPacket(bits[:end_idx]), bits[end_idx:]

    def parse(self):
        """Decode the value from the 5-bit groups and store it in ``self.value``.

        Returns the decoded integer.
        """
        packet_body = self.bits[6:]
        nibbles = []
        for group_start in range(0, len(packet_body), 5):
            nibbles.append(packet_body[group_start + 1:group_start + 5])
            # A group whose lead bit is 0 is the last one.
            if packet_body[group_start] == "0":
                break
        self.value = int("".join(nibbles), 2)
        return self.value
class OperatorPacket(Packet):
    """A BITS packet whose value is computed from its sub-packets' values.

    The header's ``type_code`` selects the operation: 0 sum, 1 product,
    2 minimum, 3 maximum, 5 greater-than, 6 less-than, 7 equal-to. The three
    comparisons yield 1 or 0 and compare the first two sub-packets.
    """

    def __init__(self, bits: str, subpackets: List):
        # Must be set before super().__init__(), which calls self.parse().
        self.subpackets = subpackets
        super().__init__(bits)

    def __repr__(self) -> str:
        operator_packet_type = "Length" if self.header["length_type_id"] == 0 else "Count"
        children = ", ".join(str(packet) for packet in self.subpackets)
        return f"{operator_packet_type}OperatorPacketV{self.header['version']}[value={self.value} | {children}]"

    @staticmethod
    def get_packet(bits: str) -> "tuple[OperatorPacket, str]":
        """Decode the operator packet at the start of ``bits``, including its sub-packets.

        Bit 6, the length type ID, selects how the sub-packets are delimited:
          - 0: the next 15 bits give the total bit length of the sub-packets;
          - 1: the next 11 bits give the number of sub-packets.

        Returns a ``(packet, remaining_bits)`` tuple.
        """
        header = Packet.parse_header(bits)
        end_idx = 7  # 6 header bits + 1 length-type bit
        subpackets = []
        if header["length_type_id"] == 0:
            # 15-bit total payload length in bits
            payload_length = int(bits[7:7 + 15], 2)
            end_idx += 15 + payload_length
            remainder = bits[end_idx - payload_length:end_idx]
            payload_bits_consumed = 0
            while payload_bits_consumed < payload_length and remainder:
                next_packet, remainder = Packet.consume_packet_bits(remainder)
                subpackets.append(next_packet)
                payload_bits_consumed += len(next_packet.bits)
        elif header["length_type_id"] == 1:
            # 11-bit sub-packet count
            payload_count = int(bits[7:7 + 11], 2)
            end_idx += 11
            remainder = bits[end_idx:]
            for _ in range(payload_count):
                if not remainder:
                    break
                next_packet, remainder = Packet.consume_packet_bits(remainder)
                subpackets.append(next_packet)
                end_idx += len(next_packet.bits)
        return OperatorPacket(bits[:end_idx], subpackets=subpackets), bits[end_idx:]

    def parse(self):
        """Apply the operation selected by ``type_code`` to the sub-packet values.

        Stores the result in ``self.value`` and returns it. The value stays
        None for a type code not listed above.
        """
        type_code = self.header["type_code"]
        values = [packet.value for packet in self.subpackets]
        if type_code == 0:
            # Sum
            self.value = sum(values)
        elif type_code == 1:
            # Product
            self.value = reduce(operator.mul, values)
        elif type_code == 2:
            # Min
            self.value = min(values)
        elif type_code == 3:
            # Max
            self.value = max(values)
        elif type_code == 5:
            # Greater than
            self.value = int(values[0] > values[1])
        elif type_code == 6:
            # Less than
            self.value = int(values[0] < values[1])
        elif type_code == 7:
            # Equal to
            self.value = int(values[0] == values[1])
        return self.value
def compute_version_sum(packet):
    """Total the version numbers of ``packet`` and every packet nested inside it.

    A literal packet contributes only its own version. An operator packet adds
    the recursive totals of its sub-packets. Any other object yields -1.
    """
    if isinstance(packet, LiteralPacket):
        return packet.header["version"]
    if not isinstance(packet, OperatorPacket):
        return -1
    own_version = packet.header["version"]
    nested_total = sum(map(compute_version_sum, packet.subpackets))
    return own_version + nested_total
def hex_to_binary(transmission: str) -> str:
    """Expand a hexadecimal string into a bit string, four bits per hex digit.

    Surrounding whitespace, such as the trailing newline of an input file, is
    ignored. Each digit is zero-padded to 4 bits so leading zeros are kept.
    """
    return "".join(f"{int(hex_digit, 16):04b}" for hex_digit in transmission.strip())


if __name__ == "__main__":
    file = "inputs/2021/16/data.txt"
    with open(file, "r", encoding="utf-8") as f:
        transmission = f.read()
    binary_transmission = hex_to_binary(transmission)
    packet, _ = Packet.consume_packet_bits(binary_transmission)
    print(f"Part 1: {compute_version_sum(packet)}")
    print(f"Part 2: {packet.value}")
Add Comment
Please, Sign In to add comment