Advertisement
Guest User

Untitled

a guest
Dec 12th, 2022
241
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 4.95 KB | None | 0 0
  1. import heapq
  2. import itertools
  3. import math
  4. from dataclasses import dataclass, field
  5. from typing import List, Optional, Dict, Set, Tuple
  6. import pathlib
  7.  
  8.  
  9. def read_data():
  10.     with open(f"data/{pathlib.Path(__file__).stem}.txt") as raw_data:
  11.         return [line.strip() for line in raw_data.readlines()]
  12.  
  13.  
  14. @dataclass(frozen=True)
  15. class Vec:
  16.     x: int
  17.     y: int
  18.  
  19.     def __add__(self, delta: 'Vec') -> 'Vec':
  20.         return Vec(self.x + delta.x, self.y + delta.y)
  21.  
  22.  
  23. def deltas() -> List[Vec]:
  24.     return [Vec(-1, 0), Vec(1, 0), Vec(0, -1), Vec(0, 1)]
  25.  
  26.  
  27. @dataclass
  28. class Node:
  29.     pos: Vec
  30.     height: int
  31.     adj: Dict[Vec, Optional['Node']] = field(default_factory=lambda: {d: None for d in deltas()})
  32.     path_value: int = math.inf
  33.  
  34.     def neighbors(self):
  35.         return [node for node in self.adj.values() if node]
  36.  
  37.     def __hash__(self):
  38.         return hash(self.pos)
  39.  
  40.     def __eq__(self, other):
  41.         return isinstance(other, Node) and self.pos == other.pos
  42.  
  43.     def __lt__(self, other):
  44.         return isinstance(other, Node) and self.path_value < other.path_value
  45.  
  46.  
  47. @dataclass
  48. class NodeMap:
  49.     nodes: List[List['Node']]
  50.  
  51.     def __getitem__(self, item: Vec) -> Optional[Node]:
  52.         if (0 <= item.y < len(self.nodes)) and (0 <= item.x < len(self.nodes[0])):
  53.             return self.nodes[item.y][item.x]
  54.         return None
  55.  
  56.     def solve_adj(self):
  57.         """ figures out what neighbors each node can access """
  58.         for node in self.all_nodes():
  59.             for delta, adj_node in [(delta, self[node.pos + delta]) for delta in deltas()]:
  60.                 if adj_node and adj_node.height <= node.height + 1:
  61.                     node.adj[delta] = adj_node
  62.  
  63.     def dijkstra(self, start: Vec, end: Vec):
  64.         """ Solves the shortest path between start and end. If path is possible, nodes forming the shortest path will have their path values updated with distance to start """
  65.         start = self[start]
  66.         end = self[end]
  67.  
  68.         unvisited: Set[Node] = {node for node in self.all_nodes()}
  69.         for node in unvisited:
  70.             node.path_value = math.inf
  71.         start.path_value = 0
  72.  
  73.         next_node = [start]
  74.  
  75.         while next_node:
  76.             current = heapq.heappop(next_node)
  77.             if current == end:
  78.                 break
  79.             if current not in unvisited:
  80.                 continue
  81.  
  82.             new_path_value = current.path_value + 1
  83.  
  84.             for node in current.neighbors():
  85.                 if node in unvisited:
  86.                     node.path_value = min(node.path_value, new_path_value)
  87.                     heapq.heappush(next_node, node)
  88.  
  89.             unvisited.remove(current)
  90.  
  91.     def all_nodes(self):
  92.         yield from itertools.chain(*self.nodes)
  93.  
  94.  
  95. def print_nodes(nodes: NodeMap):
  96.     right = Vec(1, 0)
  97.     left = Vec(-1, 0)
  98.     up = Vec(0, -1)
  99.     down = Vec(0, 1)
  100.     for row in nodes.nodes:
  101.         for node in row:
  102.             if node.adj[up]:
  103.                 print("    ↑    ", end='')
  104.             else:
  105.                 print("         ", end='')
  106.         print()
  107.         for node in row:
  108.             if node.adj[left]:
  109.                 print("←──", end='')
  110.             else:
  111.                 print("   ", end='')
  112.             print(f"{node.path_value:^3}", end='')
  113.             if node.adj[right]:
  114.                 print("──→", end='')
  115.             else:
  116.                 print("   ", end='')
  117.         print()
  118.         for node in row:
  119.             if node.adj[down]:
  120.                 print("    ↓    ", end='')
  121.             else:
  122.                 print("         ", end='')
  123.         print()
  124.  
  125.  
  126. def parse_node_map(data: List[str]) -> Tuple[NodeMap, Vec, Vec]:
  127.     nodes = [[Node(Vec(0, 0), 0) for _ in line] for line in data]
  128.     start = None
  129.     end = None
  130.  
  131.     for idx_row, row in enumerate(data):
  132.         for idx_col, ch in enumerate(row):
  133.             vec = Vec(idx_col, idx_row)
  134.             match ch:
  135.                 case 'S':
  136.                     height = 0
  137.                     start = vec
  138.                 case 'E':
  139.                     height = 25
  140.                     end = vec
  141.                 case s:
  142.                     height = ord(s) - ord('a')
  143.             nodes[idx_row][idx_col] = Node(vec, height)
  144.  
  145.     node_map = NodeMap(nodes)
  146.     node_map.solve_adj()
  147.  
  148.     return node_map, start, end
  149.  
  150.  
  151. def part1(data: List[str]) -> int:
  152.     node_map, start, end = parse_node_map(data)
  153.     node_map.dijkstra(start, end)
  154.     return node_map[end].path_value
  155.  
  156.  
  157. def part2(data: List[str]) -> int:
  158.     node_map, _, end = parse_node_map(data)
  159.     best_path = math.inf
  160.     for start_point in [node.pos for node in node_map.all_nodes() if node.height == 0]:
  161.         node_map.dijkstra(start_point, end)
  162.         best_path = min(node_map[end].path_value, best_path)
  163.     return best_path
  164.  
  165.  
  166. def main():
  167.     data = read_data()
  168.     print(f"Part 1: {part1(data)}")
  169.     print(f"Part 2: {part2(data)}")
  170.  
  171.  
  172. if __name__ == "__main__":
  173.     main()
  174.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement