Advertisement
Guest User

Untitled

a guest
Feb 24th, 2020
84
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 5.46 KB | None | 0 0
  1. #pragma once
  2.  
  3. #include <iostream>
  4. #include <random>
  5. #include <unordered_map>
  6. #include <vector>
  7.  
  8. enum class Action { Right, Left, Up, Down, Size = 4 };
  9.  
  10. typedef int State;
  11. typedef int Reward;
  12.  
  13. class Maze {
  14.    public:
  15.     Maze(int n_rows, int n_cols, State start, State end)
  16.         : n_rows_(n_rows),
  17.           n_cols_(n_cols),
  18.           maze_(n_rows * n_cols, kGrid),
  19.           start_(start),
  20.           end_(end) {
  21.         maze_[start] = kStart;
  22.         maze_[end] = kEnd;
  23.         // assert kStart != kend
  24.  
  25.         value_to_reward[kGrid] = -1;
  26.         value_to_reward[kStart] = -1;
  27.         value_to_reward[kEnd] = 100;
  28.         value_to_reward[kKey] = 1;
  29.     }
  30.  
  31.     Maze(int size, State start, State end) : Maze(size, size, start, end) {}
  32.  
  33.     Maze(int num_rows, int num_cols, std::pair<int, int> start, std::pair<int, int> end)
  34.         : Maze(num_rows, num_cols, num_cols * start.first + start.second,
  35.                num_cols * end.first + end.second) {}
  36.  
  37.     Maze(int size, std::pair<int, int> start, std::pair<int, int> end)
  38.         : Maze(size, size, start, end) {}
  39.  
  40.     char operator[](State state) { return maze_[state]; };
  41.  
  42.     size_t Size() const { return maze_.size(); }
  43.     int NumberOfRows() const { return n_rows_; };
  44.     int NumberOfCols() const { return n_cols_; };
  45.  
  46.     std::vector<char> GetValidForStepValues() const { return {kGrid, kKey, kStart, kEnd}; }
  47.  
  48.     Reward GetRewardInState(State state) { return value_to_reward[maze_[state]]; }
  49.  
  50.     State GetStartState() const { return start_; }
  51.  
  52.     State GetEndState() const { return end_; }
  53.  
  54.     const char kGrid = 'G';
  55.     const char kStart = 'S';
  56.     const char kEnd = 'E';
  57.     const char kWall = 'W';
  58.     const char kDoor = 'D';
  59.     const char kKey = 'K';
  60.  
  61.    private:
  62.     std::vector<char> maze_;
  63.     int n_rows_;
  64.     int n_cols_;
  65.  
  66.     State start_;
  67.     State end_;
  68.  
  69.     std::unordered_map<char, Reward> value_to_reward;
  70. };
  71.  
// Result of a single Environment::Step call.
// Kept as a plain aggregate: Step builds it with brace initialisation, so the
// field order below is part of the contract.
struct Observation {
    State state;    // state produced by the step
    Reward reward;  // reward looked up for that state
    bool is_done;   // true when that state is the maze's end state
};
  77.  
  78. class Environment {
  79.    public:
  80.     explicit Environment(Maze maze)
  81.         : maze_(maze), current_state_(maze.GetStartState()), random_generator(151343) {}
  82.  
  83.     Environment() : Environment({5, 0, 24}) {}
  84.  
  85.     void Reset() { current_state_ = maze_.GetStartState(); }
  86.  
  87.     Observation Step(Action action) {
  88.         State next_state = GetNextState(current_state_, action);
  89.         Reward reward = GetRewardForAction(current_state_, action);
  90.         bool is_done = false;
  91.         if (next_state == maze_.GetEndState()) {
  92.             is_done = true;
  93.         }
  94.         return {next_state, reward, is_done};
  95.     }
  96.  
  97.     int GetNumberOfStates() { return static_cast<int>(maze_.Size()); }
  98.  
  99.     State SampleState() {
  100.         std::uniform_int_distribution<int> dist(0, static_cast<int>(maze_.Size()) - 1);
  101.         return dist(random_generator);
  102.     }
  103.  
  104.     int GetNumberOfActions() { return static_cast<int>(Action::Size); }
  105.  
  106.     Action SampleAction() {
  107.         std::uniform_int_distribution<int> dist(0, static_cast<int>(Action::Size) - 1);
  108.         return static_cast<Action>(dist(random_generator));
  109.     }
  110.  
  111.    private:
  112.     Maze maze_;
  113.     State current_state_;
  114.     std::mt19937 random_generator;
  115.  
  116.     State ConvertCoordinateToState(std::pair<int, int> coordinate) {
  117.         return maze_.NumberOfCols() * coordinate.first + coordinate.second;
  118.     };
  119.  
  120.     std::pair<int, int> ConvertStateToCoordinate(State state) {
  121.         return {state / maze_.NumberOfCols(), state % maze_.NumberOfCols()};
  122.     }
  123.  
  124.     bool IsValidForStep(State state, std::vector<char> valid_values) {
  125.         bool is_valid = false;
  126.         for (auto value : valid_values) {
  127.             is_valid |= maze_[state] == value;
  128.         }
  129.         return is_valid;
  130.     }
  131.  
  132.     State GetNextState(State state, Action action) {
  133.         std::pair<int, int> state_coordinates = ConvertStateToCoordinate(state);
  134.         if (action == Action::Right && state_coordinates.second < maze_.NumberOfCols() - 1) {
  135.             State next_state =
  136.                 ConvertCoordinateToState({state_coordinates.first, state_coordinates.second + 1});
  137.             if (IsValidForStep(next_state, maze_.GetValidForStepValues())) {
  138.                 return next_state;
  139.             }
  140.         }
  141.         if (action == Action::Left && state_coordinates.second > 0) {
  142.             State next_state =
  143.                 ConvertCoordinateToState({state_coordinates.first, state_coordinates.second - 1});
  144.             if (IsValidForStep(next_state, maze_.GetValidForStepValues())) {
  145.                 return next_state;
  146.             }
  147.         }
  148.         if (action == Action::Up && state_coordinates.first > 0) {
  149.             State next_state =
  150.                 ConvertCoordinateToState({state_coordinates.first - 1, state_coordinates.second});
  151.             if (IsValidForStep(next_state, maze_.GetValidForStepValues())) {
  152.                 return next_state;
  153.             }
  154.         }
  155.         if (action == Action::Down && state_coordinates.first < maze_.NumberOfRows() - 1) {
  156.             State next_state =
  157.                 ConvertCoordinateToState({state_coordinates.first + 1, state_coordinates.second});
  158.             if (IsValidForStep(next_state, maze_.GetValidForStepValues())) {
  159.                 return next_state;
  160.             }
  161.         }
  162.         return state;
  163.     }
  164.  
  165.     Reward GetRewardForAction(State state, Action action) {
  166.         State next_state = GetNextState(state, action);
  167.         return maze_.GetRewardInState(next_state);
  168.     }
  169. };
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement