Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
#pragma once

#include <cstddef>
#include <iostream>
#include <limits>
#include <random>
#include <stdexcept>
#include <unordered_map>
#include <utility>
#include <vector>
// Moves an agent can request. `Size` is not a move; its value (4) is the
// number of real actions that precede it.
enum class Action {
  Right = 0,
  Left = 1,
  Up = 2,
  Down = 3,
  Size = 4
};

// A state is an integer identifier.
using State = int;

// A reward is an integer.
using Reward = int;
- class Maze {
- public:
- Maze(int n_rows, int n_cols, State start, State end)
- : n_rows_(n_rows),
- n_cols_(n_cols),
- maze_(n_rows * n_cols, kGrid),
- start_(start),
- end_(end) {
- maze_[start] = kStart;
- maze_[end] = kEnd;
- // assert kStart != kend
- value_to_reward[kGrid] = -1;
- value_to_reward[kStart] = -1;
- value_to_reward[kEnd] = 100;
- value_to_reward[kKey] = 1;
- }
- Maze(int size, State start, State end) : Maze(size, size, start, end) {}
- Maze(int num_rows, int num_cols, std::pair<int, int> start, std::pair<int, int> end)
- : Maze(num_rows, num_cols, num_cols * start.first + start.second,
- num_cols * end.first + end.second) {}
- Maze(int size, std::pair<int, int> start, std::pair<int, int> end)
- : Maze(size, size, start, end) {}
- char operator[](State state) { return maze_[state]; };
- size_t Size() const { return maze_.size(); }
- int NumberOfRows() const { return n_rows_; };
- int NumberOfCols() const { return n_cols_; };
- std::vector<char> GetValidForStepValues() const { return {kGrid, kKey, kStart, kEnd}; }
- Reward GetRewardInState(State state) { return value_to_reward[maze_[state]]; }
- State GetStartState() const { return start_; }
- State GetEndState() const { return end_; }
- const char kGrid = 'G';
- const char kStart = 'S';
- const char kEnd = 'E';
- const char kWall = 'W';
- const char kDoor = 'D';
- const char kKey = 'K';
- private:
- std::vector<char> maze_;
- int n_rows_;
- int n_cols_;
- State start_;
- State end_;
- std::unordered_map<char, Reward> value_to_reward;
- };
- struct Observation {
- State state;
- Reward reward;
- bool is_done;
- };
- class Environment {
- public:
- explicit Environment(Maze maze)
- : maze_(maze), current_state_(maze.GetStartState()), random_generator(151343) {}
- Environment() : Environment({5, 0, 24}) {}
- void Reset() { current_state_ = maze_.GetStartState(); }
- Observation Step(Action action) {
- State next_state = GetNextState(current_state_, action);
- Reward reward = GetRewardForAction(current_state_, action);
- bool is_done = false;
- if (next_state == maze_.GetEndState()) {
- is_done = true;
- }
- return {next_state, reward, is_done};
- }
- int GetNumberOfStates() { return static_cast<int>(maze_.Size()); }
- State SampleState() {
- std::uniform_int_distribution<int> dist(0, static_cast<int>(maze_.Size()) - 1);
- return dist(random_generator);
- }
- int GetNumberOfActions() { return static_cast<int>(Action::Size); }
- Action SampleAction() {
- std::uniform_int_distribution<int> dist(0, static_cast<int>(Action::Size) - 1);
- return static_cast<Action>(dist(random_generator));
- }
- private:
- Maze maze_;
- State current_state_;
- std::mt19937 random_generator;
- State ConvertCoordinateToState(std::pair<int, int> coordinate) {
- return maze_.NumberOfCols() * coordinate.first + coordinate.second;
- };
- std::pair<int, int> ConvertStateToCoordinate(State state) {
- return {state / maze_.NumberOfCols(), state % maze_.NumberOfCols()};
- }
- bool IsValidForStep(State state, std::vector<char> valid_values) {
- bool is_valid = false;
- for (auto value : valid_values) {
- is_valid |= maze_[state] == value;
- }
- return is_valid;
- }
- State GetNextState(State state, Action action) {
- std::pair<int, int> state_coordinates = ConvertStateToCoordinate(state);
- if (action == Action::Right && state_coordinates.second < maze_.NumberOfCols() - 1) {
- State next_state =
- ConvertCoordinateToState({state_coordinates.first, state_coordinates.second + 1});
- if (IsValidForStep(next_state, maze_.GetValidForStepValues())) {
- return next_state;
- }
- }
- if (action == Action::Left && state_coordinates.second > 0) {
- State next_state =
- ConvertCoordinateToState({state_coordinates.first, state_coordinates.second - 1});
- if (IsValidForStep(next_state, maze_.GetValidForStepValues())) {
- return next_state;
- }
- }
- if (action == Action::Up && state_coordinates.first > 0) {
- State next_state =
- ConvertCoordinateToState({state_coordinates.first - 1, state_coordinates.second});
- if (IsValidForStep(next_state, maze_.GetValidForStepValues())) {
- return next_state;
- }
- }
- if (action == Action::Down && state_coordinates.first < maze_.NumberOfRows() - 1) {
- State next_state =
- ConvertCoordinateToState({state_coordinates.first + 1, state_coordinates.second});
- if (IsValidForStep(next_state, maze_.GetValidForStepValues())) {
- return next_state;
- }
- }
- return state;
- }
- Reward GetRewardForAction(State state, Action action) {
- State next_state = GetNextState(state, action);
- return maze_.GetRewardInState(next_state);
- }
- };
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement