Advertisement
Guest User

Untitled

a guest
Feb 25th, 2020
97
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 8.02 KB | None | 0 0
  1. #include "q-learning.h"
  2.  
  3. QTable::QTable(Environment env, double learning_rate, double discounting_rate)
  4.     : qtable_(env.NumberOfStates(),
  5.               std::vector<double>(env.NumberOfActions(), 0.0)),
  6.       learning_rate_(learning_rate),
  7.       discounting_rate_(discounting_rate) {}
  8.  
  9. QTable::QTable(Environment env) : QTable(env, 0.5, 0.7) {}
  10.  
  11. void QTable::SetLearningRate(double learning_rate) {
  12.   learning_rate_ = learning_rate;
  13. }
  14.  
  15. void QTable::SetDiscountingRate(double discounting_rate) {
  16.   discounting_rate_ = discounting_rate;
  17. }
  18.  
  19. void QTable::UpdateQValue(State state, Action action, State new_state,
  20.                           Reward reward) {
  21.   auto max_new_state_qvalue =
  22.       *std::max_element(qtable_[new_state].begin(), qtable_[new_state].end());
  23.  
  24.   qtable_[state][static_cast<int>(action)] =
  25.       qtable_[state][static_cast<int>(action)] +
  26.       learning_rate_ * (reward + discounting_rate_ * max_new_state_qvalue -
  27.                         qtable_[state][static_cast<int>(action)]);
  28. }
  29.  
  30. Action QTable::GetBestAction(State state) {
  31.   auto iter = std::max_element(qtable_[state].begin(), qtable_[state].end());
  32.   return static_cast<Action>(iter - qtable_[state].begin());
  33. }
  34.  
  35. void QTable::Reset() {
  36.   for (auto state_vector : qtable_) {
  37.     std::fill(state_vector.begin(), state_vector.end(), 0.0);
  38.   }
  39. }
  40.  
  41. void QTable::Render(int n_cols, const Maze& maze) {
  42.   std::vector<char> direction = {'R', 'L', 'U', 'D'};
  43.   for (int index = 0; index < static_cast<int>(qtable_.size()); ++index) {
  44.     if (index % n_cols == 0) {
  45.       std::cout << "\n";
  46.     }
  47.     if (maze[index] == maze.kWall) {
  48.       std::cout << "W"
  49.                 << " ";
  50.     } else {
  51.       std::cout << direction[static_cast<int>(GetBestAction(index))] << " ";
  52.     }
  53.   }
  54.   std::cout << "\n";
  55. }
  56.  
  57. void Epsilon::Update(int episode) {
  58.   value = min_epsilon +
  59.           (max_epsilon - min_epsilon) * std::exp(-decay_rate * episode);
  60. }
  61.  
  62. QTable Train(int n_episodes, int max_steps, Epsilon epsilon) {
  63.   Environment env;
  64.   QTable qtable(env);
  65.   std::mt19937 random_generator(1531413);
  66.   std::uniform_real_distribution<> dist(0, 1);
  67.  
  68.   for (int episode = 0; episode < n_episodes; ++episode) {
  69.     State state = env.Reset();
  70.  
  71.     int n_steps = 0;
  72.     bool is_done = false;
  73.  
  74.     while (n_steps < max_steps && !is_done) {
  75.       Action action = qtable.GetBestAction(state);
  76.       if (dist(random_generator) < epsilon.value) {
  77.         action = env.SampleAction();
  78.       }
  79.       Observation observation = env.Step(action);
  80.       qtable.UpdateQValue(state, action, observation.state, observation.reward);
  81.       is_done = observation.is_done;
  82.       state = observation.state;
  83.     }
  84.     epsilon.Update(episode);
  85.     qtable.Render(env.GetMaze().NumberOfCols(), env.GetMaze());
  86.   }
  87.  
  88.   return qtable;
  89. }
  90.  
  91. const int kSquareSize = 64;
  92.  
  93. std::pair<sf::RectangleShape, sf::RectangleShape> UpdateAgent(Environment env,
  94.                                                               State old_state,
  95.                                                               State new_state) {
  96.   auto old_state_coord = env.ConvertStateToCoordinate(old_state);
  97.   std::cout << "old_state: " << old_state_coord.first << " "
  98.             << old_state_coord.second << "\n";
  99.   sf::RectangleShape old_rectangle;
  100.   old_rectangle.setSize(sf::Vector2f(64, 64));
  101.   old_rectangle.setOutlineColor(sf::Color::Black);
  102.   if (env.GetMaze()[env.ConvertCoordinateToState(
  103.           {old_state_coord.first, old_state_coord.second})] ==
  104.       env.GetMaze().kStart) {
  105.     old_rectangle.setFillColor(sf::Color(154, 205, 50));
  106.   }
  107.  
  108.   old_rectangle.setOutlineThickness(2);
  109.   old_rectangle.setPosition(old_state_coord.first * 64,
  110.                             old_state_coord.second * 64);
  111.  
  112.   if (env.GetMaze()[env.ConvertCoordinateToState(
  113.           {old_state_coord.first, old_state_coord.second})] ==
  114.       env.GetMaze().kEnd) {
  115.     old_rectangle.setFillColor(sf::Color(220, 20, 60));
  116.   }
  117.  
  118.   auto new_state_coord = env.ConvertStateToCoordinate(new_state);
  119.   std::cout << "new_state: " << new_state_coord.first << " "
  120.             << new_state_coord.second << "\n";
  121.   sf::RectangleShape new_rectangle;
  122.   new_rectangle.setSize(sf::Vector2f(64, 64));
  123.   new_rectangle.setOutlineColor(sf::Color::Black);
  124.   new_rectangle.setOutlineThickness(2);
  125.   new_rectangle.setPosition(new_state_coord.first * 64,
  126.                             new_state_coord.second * 64);
  127.   new_rectangle.setFillColor(sf::Color(135, 206, 250));
  128.   return {old_rectangle, new_rectangle};
  129. }
  130.  
  131. sf::RectangleShape GetRect(int size, std::pair<int, int> position,
  132.                            sf::Color fill_color, sf::Color outline_color,
  133.                            int thickness) {
  134.   sf::RectangleShape rectangle;
  135.   rectangle.setSize(sf::Vector2f(size, size));
  136.   rectangle.setOutlineColor(outline_color);
  137.   rectangle.setFillColor(fill_color);
  138.   rectangle.setOutlineThickness(thickness);
  139.   rectangle.setPosition(position.first, position.second);
  140.   return rectangle;
  141. }
  142.  
  143. sf::Text GetText(std::string message, const sf::Font& font, int char_size,
  144.                  std::pair<int, int> position,
  145.                  sf::Color color = sf::Color::Black) {
  146.   sf::Text text(message, font);
  147.   text.setCharacterSize(char_size);
  148.   text.setFillColor(color);
  149.   text.setPosition(position.first, position.second);
  150.   return text;
  151. }
  152.  
  153. void DrawGrid(sf::RenderWindow* window, const Maze& maze) {
  154.   sf::Font font;
  155.   font.loadFromFile(
  156.       "/home/foksly/Documents/road-to-nips/feedback-learning/q-learning/"
  157.       "q-learning/roboto.ttf");
  158.  
  159.   int thickness = 1;
  160.   for (int row = 0; row < maze.NumberOfRows(); ++row) {
  161.     for (int col = 0; col < maze.NumberOfCols(); ++col) {
  162.       sf::RectangleShape rectangle = GetRect(
  163.           kSquareSize - 2 * thickness,
  164.           {col * kSquareSize + thickness, row * kSquareSize + thickness},
  165.           sf::Color::White, sf::Color::Black, thickness);
  166.       if (maze[maze.ConvertCoordinateToState({row, col})] == maze.kStart) {
  167.         rectangle.setFillColor(sf::Color(154, 205, 50));
  168.         sf::Text text =
  169.             GetText("S", font, 50, {col * kSquareSize + 16, row * kSquareSize});
  170.         window->draw(rectangle);
  171.         window->draw(text);
  172.       } else if (maze[maze.ConvertCoordinateToState({row, col})] == maze.kEnd) {
  173.         rectangle.setFillColor(sf::Color(220, 20, 60));
  174.         sf::Text text =
  175.             GetText("E", font, 50, {col * kSquareSize + 16, row * kSquareSize});
  176.         window->draw(rectangle);
  177.         window->draw(text);
  178.       } else if (maze[maze.ConvertCoordinateToState({row, col})] ==
  179.                  maze.kWall) {
  180.         rectangle.setFillColor(sf::Color::Blue);
  181.         window->draw(rectangle);
  182.       } else {
  183.         window->draw(rectangle);
  184.       }
  185.     }
  186.   }
  187. }
  188.  
  189. void StartVisualization(int n_episodes, int max_steps, Epsilon epsilon) {
  190.   Maze maze = Maze(10, 10, 25, 99);
  191.   Environment env(maze);
  192.  
  193.   QTable qtable(env);
  194.  
  195.   sf::RenderWindow window(sf::VideoMode(maze.NumberOfRows() * kSquareSize,
  196.                                         maze.NumberOfCols() * kSquareSize),
  197.                           "Grid World Q-learning", sf::Style::Close);
  198.  
  199.   std::mt19937 random_generator(1531413);
  200.   std::uniform_real_distribution<> dist(0, 1);
  201.  
  202.   while (window.isOpen()) {
  203.     sf::Event event;
  204.     while (window.pollEvent(event)) {
  205.       if (event.type == sf::Event::Closed) {
  206.         window.close();
  207.       }
  208.       if (event.type == sf::Event::MouseButtonPressed) {
  209.         if (event.mouseButton.button == sf::Mouse::Left) {
  210.           env.maze_[env.maze_.ConvertCoordinateToState(
  211.               {event.mouseButton.y / kSquareSize,
  212.                event.mouseButton.x / kSquareSize})] = env.maze_.kWall;
  213.         }
  214.         if (event.mouseButton.button == sf::Mouse::Right) {
  215.           env.maze_[env.maze_.ConvertCoordinateToState(
  216.               {event.mouseButton.y / kSquareSize,
  217.                event.mouseButton.x / kSquareSize})] = env.maze_.kGrid;
  218.         }
  219.       }
  220.     }
  221.  
  222.     window.clear(sf::Color::White);
  223.  
  224.     DrawGrid(&window, env.maze_);
  225.  
  226.     window.display();
  227.   }
  228. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement