Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
/**
 * The four headings the reindeer can face while moving through the maze.
 *
 * <p>The declaration order (NORTH, SOUTH, EAST, WEST) is significant for {@code values()} and
 * {@code ordinal()}, so it must not be rearranged.
 */
enum ReindeerDirection {
    NORTH,
    SOUTH,
    EAST,
    WEST
}
- class MazeState {
- private int points;
- private final Point location;
- private final ReindeerDirection direction;
- public MazeState(Point location, ReindeerDirection direction, int points) {
- this.location = location;
- this.direction = direction;
- this.points = points;
- }
- public Point getLocation() { return location; }
- public ReindeerDirection getDirection() { return direction; }
- public int getPoints() { return points; }
- public void setPoints(int points) { this.points = points; }
- public String toString() { return this.location.toString(); }
- @Override
- public boolean equals(Object o) {
- if (this == o) return true;
- if (o == null || getClass() != o.getClass()) return false;
- MazeState mazeState = (MazeState) o;
- return Objects.equals(location, mazeState.location) && direction == mazeState.direction;
- }
- @Override
- public int hashCode() {
- return Objects.hash(location, direction);
- }
- }
- class MazePointComparator implements Comparator<MazeState> {
- @Override
- public int compare(MazeState x, MazeState y) {
- return Integer.compare(x.getPoints(), y.getPoints());
- }
- }
- // Given a maze state, return the next possible maze states. At each coordinate, the reindeer
- // can either keep moving in the same direction, or turn 90 degrees. We accumulate the total
- // points accrued in the maze state as well.
- private static List<MazeState> getNeighbors(char[][] grid, MazeState node) {
- List<MazeState> neighbors = new ArrayList<>();
- Point p = node.getLocation();
- ReindeerDirection direction = node.getDirection();
- int points = node.getPoints();
- if (direction == ReindeerDirection.NORTH) {
- if (grid[p.x-1][p.y] == '.') {
- neighbors.add(new MazeState(new Point(p.x-1, p.y), direction, points + 1));
- }
- if (grid[p.x][p.y-1] == '.') {
- neighbors.add(new MazeState(p, ReindeerDirection.WEST, points + 1000));
- }
- if (grid[p.x][p.y+1] == '.') {
- neighbors.add(new MazeState(p, ReindeerDirection.EAST, points + 1000));
- }
- } else if (direction == ReindeerDirection.EAST) {
- if (grid[p.x][p.y+1] == '.') {
- neighbors.add(new MazeState(new Point(p.x, p.y+1), direction, points + 1));
- }
- if (grid[p.x-1][p.y] == '.') {
- neighbors.add(new MazeState(p, ReindeerDirection.NORTH, points + 1000));
- }
- if (grid[p.x+1][p.y] == '.') {
- neighbors.add(new MazeState(p, ReindeerDirection.SOUTH, points + 1000));
- }
- } else if (direction == ReindeerDirection.SOUTH) {
- if (grid[p.x+1][p.y] == '.') {
- neighbors.add(new MazeState(new Point(p.x+1, p.y), direction, points + 1));
- }
- if (grid[p.x][p.y-1] == '.') {
- neighbors.add(new MazeState(new Point(p.x, p.y), ReindeerDirection.WEST, points + 1000));
- }
- if (grid[p.x][p.y+1] == '.') {
- neighbors.add(new MazeState(new Point(p.x, p.y), ReindeerDirection.EAST, points + 1000));
- }
- } else if (direction == ReindeerDirection.WEST) {
- if (grid[p.x][p.y-1] == '.') {
- neighbors.add(new MazeState(new Point(p.x, p.y-1), direction, points + 1));
- }
- if (grid[p.x-1][p.y] == '.') {
- neighbors.add(new MazeState(new Point(p.x, p.y), ReindeerDirection.NORTH, points + 1000));
- }
- if (grid[p.x+1][p.y] == '.') {
- neighbors.add(new MazeState(new Point(p.x, p.y), ReindeerDirection.SOUTH, points + 1000));
- }
- }
- return neighbors;
- }
- // Part 1: Run Dijkstra's algorithm on the maze states, not the coordinates of the grid.
- private static int part1(char[][] grid, Point start, Point end) {
- Set<MazeState> visited = new HashSet<>();
- PriorityQueue<MazeState> pq = new PriorityQueue<>(new MazePointComparator());
- MazeState initial = new MazeState(start, ReindeerDirection.EAST, 0);
- pq.add(initial);
- while (!pq.isEmpty()) {
- MazeState node = pq.poll();
- Point p = node.getLocation();
- if (visited.contains(node)) continue;
- // Because we accumulate the points in the maze state, if we've reached the end,
- // then return the points of the end state.
- if (p.equals(end)) {
- return node.getPoints();
- }
- visited.add(node);
- List<MazeState> neighbors = getNeighbors(grid, node);
- for (MazeState neighbor : neighbors) {
- if (!visited.contains(neighbor)) {
- pq.add(neighbor);
- }
- }
- }
- return 0;
- }
- // Part 2: Similar to part 1, except now we don't just stop once we reach the end state.
- // We keep processing potential branching paths until all meaningful nodes have been visited.
- // We maintain a map of best known costs to reach a given maze state.
- // We also maintain a map of parent maze states, so we can reconstruct the path.
- private static int part2(char[][] grid, Point start, Point end) {
- Map<MazeState, Integer> costs = new HashMap<>(); // maps the best known costs to reach a given maze state so far
- Set<MazeState> visited = new HashSet<>();
- PriorityQueue<MazeState> pq = new PriorityQueue<>(new MazePointComparator());
- // for each maze state key, tracks a list of parent maze states that could get us to that maze state.
- Map<MazeState, List<MazeState>> parents = new HashMap<>();
- MazeState initial = new MazeState(start, ReindeerDirection.EAST, 0);
- pq.add(initial);
- costs.put(initial, 0);
- while (!pq.isEmpty()) {
- MazeState node = pq.poll();
- if (visited.contains(node)) continue;
- visited.add(node);
- List<MazeState> neighbors = getNeighbors(grid, node);
- // For each neighbor, look in our costs hashmap for the best known cost to get to that neighbor so far.
- // If we then calculate that the best known cost to get to our current node plus the cost it takes
- // to get from our current node to the neighbor is less than the best known cost to get to our neighbor,
- // then update it. In other words, we update if: cost[node] + weight[node, neighbor] < cost[neighbor]
- for (MazeState neighbor : neighbors) {
- if (!visited.contains(neighbor)) {
- // Check if our costs map contains these maze states. If not, initialize their costs to infinity.
- if (!costs.containsKey(node)) {
- costs.put(node, Integer.MAX_VALUE);
- }
- if (!costs.containsKey(neighbor)) {
- costs.put(neighbor, Integer.MAX_VALUE);
- }
- // One key thing to note here: the cost it takes to get from the current node
- // to the neighbor is (neighbor.getPoints() - node.getPoints()) because our `getNeighbors()`
- // calculation accumulates the points.
- int newCost = costs.get(node) + (neighbor.getPoints() - node.getPoints());
- // If we've found a better cost, update it.
- if (newCost <= costs.get(neighbor)) {
- costs.put(neighbor, newCost);
- pq.add(neighbor);
- // And keep track of where we've come from
- if (parents.containsKey(neighbor)) {
- parents.get(neighbor).add(node);
- } else {
- List<MazeState> parentsOfNeighbor = new ArrayList<>();
- parentsOfNeighbor.add(node);
- parents.put(neighbor, parentsOfNeighbor);
- }
- }
- }
- }
- }
- Set<Point> uniquePoints = new HashSet<>();
- // Find the min points in takes to get to the end.
- int min = Integer.MAX_VALUE;
- for (MazeState state : parents.keySet()) {
- if (state.getLocation().equals(end)) {
- min = Math.min(min, state.getPoints());
- }
- }
- // Find each end state that results in the minimum points in reaching it.
- List<MazeState> endStates = new ArrayList<>();
- for (MazeState state : parents.keySet()) {
- if (state.getLocation().equals(end) && state.getPoints() == min) {
- endStates.add(state);
- }
- }
- // Iterate through each end state and walk backwards.
- for (MazeState endState: endStates) {
- Stack<MazeState> stack = new Stack<>();
- stack.add(endState);
- // Walk backwards from the end state, visiting the parent/previous states that got us to that state
- // and add their points to a set.
- while (!stack.isEmpty()) {
- MazeState curr = stack.pop();
- uniquePoints.add(curr.getLocation());
- if (parents.containsKey(curr)) {
- for (MazeState ms : parents.get(curr)) {
- stack.push(ms);
- }
- }
- }
- }
- return uniquePoints.size();
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement