Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
#include "q-learning.h"

#include <iostream>
- QTable::QTable(Environment env, double learning_rate, double discounting_rate)
- : qtable_(env.NumberOfStates(),
- std::vector<double>(env.NumberOfActions(), 0.0)),
- learning_rate_(learning_rate),
- discounting_rate_(discounting_rate) {}
- QTable::QTable(Environment env) : QTable(env, 0.5, 0.7) {}
- void QTable::SetLearningRate(double learning_rate) {
- learning_rate_ = learning_rate;
- }
- void QTable::SetDiscountingRate(double discounting_rate) {
- discounting_rate_ = discounting_rate;
- }
- void QTable::UpdateQValue(State state, Action action, State new_state,
- Reward reward) {
- auto max_new_state_qvalue =
- *std::max_element(qtable_[new_state].begin(), qtable_[new_state].end());
- qtable_[state][static_cast<int>(action)] =
- qtable_[state][static_cast<int>(action)] +
- learning_rate_ * (reward + discounting_rate_ * max_new_state_qvalue -
- qtable_[state][static_cast<int>(action)]);
- }
- Action QTable::GetBestAction(State state) {
- auto iter = std::max_element(qtable_[state].begin(), qtable_[state].end());
- return static_cast<Action>(iter - qtable_[state].begin());
- }
- void QTable::Reset() {
- for (auto state_vector : qtable_) {
- std::fill(state_vector.begin(), state_vector.end(), 0.0);
- }
- }
- void QTable::Render(int n_cols, const Maze& maze) {
- std::vector<char> direction = {'R', 'L', 'U', 'D'};
- for (int index = 0; index < static_cast<int>(qtable_.size()); ++index) {
- if (index % n_cols == 0) {
- std::cout << "\n";
- }
- if (maze[index] == maze.kWall) {
- std::cout << "W"
- << " ";
- } else {
- std::cout << direction[static_cast<int>(GetBestAction(index))] << " ";
- }
- }
- std::cout << "\n";
- }
- void Epsilon::Update(int episode) {
- value = min_epsilon +
- (max_epsilon - min_epsilon) * std::exp(-decay_rate * episode);
- }
- QTable Train(int n_episodes, int max_steps, Epsilon epsilon) {
- Environment env;
- QTable qtable(env);
- std::mt19937 random_generator(1531413);
- std::uniform_real_distribution<> dist(0, 1);
- for (int episode = 0; episode < n_episodes; ++episode) {
- State state = env.Reset();
- int n_steps = 0;
- bool is_done = false;
- while (n_steps < max_steps && !is_done) {
- Action action = qtable.GetBestAction(state);
- if (dist(random_generator) < epsilon.value) {
- action = env.SampleAction();
- }
- Observation observation = env.Step(action);
- qtable.UpdateQValue(state, action, observation.state, observation.reward);
- is_done = observation.is_done;
- state = observation.state;
- }
- epsilon.Update(episode);
- qtable.Render(env.GetMaze().NumberOfCols(), env.GetMaze());
- }
- return qtable;
- }
// Side length, in pixels, of one maze cell in the SFML visualisation.
constexpr int kSquareSize = 64;
- std::pair<sf::RectangleShape, sf::RectangleShape> UpdateAgent(Environment env,
- State old_state,
- State new_state) {
- auto old_state_coord = env.ConvertStateToCoordinate(old_state);
- std::cout << "old_state: " << old_state_coord.first << " "
- << old_state_coord.second << "\n";
- sf::RectangleShape old_rectangle;
- old_rectangle.setSize(sf::Vector2f(64, 64));
- old_rectangle.setOutlineColor(sf::Color::Black);
- if (env.GetMaze()[env.ConvertCoordinateToState(
- {old_state_coord.first, old_state_coord.second})] ==
- env.GetMaze().kStart) {
- old_rectangle.setFillColor(sf::Color(154, 205, 50));
- }
- old_rectangle.setOutlineThickness(2);
- old_rectangle.setPosition(old_state_coord.first * 64,
- old_state_coord.second * 64);
- if (env.GetMaze()[env.ConvertCoordinateToState(
- {old_state_coord.first, old_state_coord.second})] ==
- env.GetMaze().kEnd) {
- old_rectangle.setFillColor(sf::Color(220, 20, 60));
- }
- auto new_state_coord = env.ConvertStateToCoordinate(new_state);
- std::cout << "new_state: " << new_state_coord.first << " "
- << new_state_coord.second << "\n";
- sf::RectangleShape new_rectangle;
- new_rectangle.setSize(sf::Vector2f(64, 64));
- new_rectangle.setOutlineColor(sf::Color::Black);
- new_rectangle.setOutlineThickness(2);
- new_rectangle.setPosition(new_state_coord.first * 64,
- new_state_coord.second * 64);
- new_rectangle.setFillColor(sf::Color(135, 206, 250));
- return {old_rectangle, new_rectangle};
- }
- sf::RectangleShape GetRect(int size, std::pair<int, int> position,
- sf::Color fill_color, sf::Color outline_color,
- int thickness) {
- sf::RectangleShape rectangle;
- rectangle.setSize(sf::Vector2f(size, size));
- rectangle.setOutlineColor(outline_color);
- rectangle.setFillColor(fill_color);
- rectangle.setOutlineThickness(thickness);
- rectangle.setPosition(position.first, position.second);
- return rectangle;
- }
- sf::Text GetText(std::string message, const sf::Font& font, int char_size,
- std::pair<int, int> position,
- sf::Color color = sf::Color::Black) {
- sf::Text text(message, font);
- text.setCharacterSize(char_size);
- text.setFillColor(color);
- text.setPosition(position.first, position.second);
- return text;
- }
- void DrawGrid(sf::RenderWindow* window, const Maze& maze) {
- sf::Font font;
- font.loadFromFile(
- "/home/foksly/Documents/road-to-nips/feedback-learning/q-learning/"
- "q-learning/roboto.ttf");
- int thickness = 1;
- for (int row = 0; row < maze.NumberOfRows(); ++row) {
- for (int col = 0; col < maze.NumberOfCols(); ++col) {
- sf::RectangleShape rectangle = GetRect(
- kSquareSize - 2 * thickness,
- {col * kSquareSize + thickness, row * kSquareSize + thickness},
- sf::Color::White, sf::Color::Black, thickness);
- if (maze[maze.ConvertCoordinateToState({row, col})] == maze.kStart) {
- rectangle.setFillColor(sf::Color(154, 205, 50));
- sf::Text text =
- GetText("S", font, 50, {col * kSquareSize + 16, row * kSquareSize});
- window->draw(rectangle);
- window->draw(text);
- } else if (maze[maze.ConvertCoordinateToState({row, col})] == maze.kEnd) {
- rectangle.setFillColor(sf::Color(220, 20, 60));
- sf::Text text =
- GetText("E", font, 50, {col * kSquareSize + 16, row * kSquareSize});
- window->draw(rectangle);
- window->draw(text);
- } else if (maze[maze.ConvertCoordinateToState({row, col})] ==
- maze.kWall) {
- rectangle.setFillColor(sf::Color::Blue);
- window->draw(rectangle);
- } else {
- window->draw(rectangle);
- }
- }
- }
- }
- void StartVisualization(int n_episodes, int max_steps, Epsilon epsilon) {
- Maze maze = Maze(10, 10, 25, 99);
- Environment env(maze);
- QTable qtable(env);
- sf::RenderWindow window(sf::VideoMode(maze.NumberOfRows() * kSquareSize,
- maze.NumberOfCols() * kSquareSize),
- "Grid World Q-learning", sf::Style::Close);
- std::mt19937 random_generator(1531413);
- std::uniform_real_distribution<> dist(0, 1);
- while (window.isOpen()) {
- sf::Event event;
- while (window.pollEvent(event)) {
- if (event.type == sf::Event::Closed) {
- window.close();
- }
- if (event.type == sf::Event::MouseButtonPressed) {
- if (event.mouseButton.button == sf::Mouse::Left) {
- env.maze_[env.maze_.ConvertCoordinateToState(
- {event.mouseButton.y / kSquareSize,
- event.mouseButton.x / kSquareSize})] = env.maze_.kWall;
- }
- if (event.mouseButton.button == sf::Mouse::Right) {
- env.maze_[env.maze_.ConvertCoordinateToState(
- {event.mouseButton.y / kSquareSize,
- event.mouseButton.x / kSquareSize})] = env.maze_.kGrid;
- }
- }
- }
- window.clear(sf::Color::White);
- DrawGrid(&window, env.maze_);
- window.display();
- }
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement