Advertisement
Guest User

Untitled

a guest
Mar 25th, 2019
77
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 6.59 KB | None | 0 0
  1. # Tudor Berariu, 2016
  2. # Razvan Chitu, 2018
  3.  
  4. # Standard library imports
  5. from collections import defaultdict
  6. from argparse import ArgumentParser
  7. from random import choice, randint
  8. from time import sleep
  9. # External library imports
  10. from matplotlib import pyplot as plt
  11. import numpy as np
  12. # Local imports
  13. from mini_pacman_new import Game
  14.  
  15.  
  16. def epsilon_greedy(Q, state, legal_actions, epsilon):
  17. # TODO (2) : Epsilon greedy
  18. p = randint(1, 100)
  19.  
  20. if p < epsilon * 100:
  21. return choice(legal_actions)
  22. else:
  23. return best_action(Q, state, legal_actions)
  24.  
  25.  
  26. def best_action(Q, state, legal_actions):
  27. # TODO (3) : Best action
  28.  
  29. # # search for unexplored actions
  30. unexplored_actions = []
  31. # for action in legal_actions:
  32. # if (state, action) not in Q.keys():
  33. # unexplored_actions.append(action)
  34. #
  35. # if unexplored_actions:
  36. # return choice(unexplored_actions)
  37.  
  38. # all actions have been explored, search for the one with highest q
  39. max_q = -999999999
  40. max_action = None
  41.  
  42. # all_q_values = []
  43.  
  44. for action in legal_actions:
  45. if (state, action) not in Q.keys():
  46. unexplored_actions.append(action)
  47. elif Q[(state, action)] > max_q:
  48. max_q = Q[(state, action)]
  49. max_action = action
  50.  
  51. # # if all actions have the same max_q, pick a random one
  52. # if all_q_values and all(elem == max_q for elem in all_q_values):
  53. # return choice(legal_actions)
  54. if unexplored_actions:
  55. return choice(unexplored_actions)
  56.  
  57. return max_action
  58.  
  59.  
  60. def q_learning(map_file, learning_rate, discount, epsilon, train_episodes,
  61. eval_every, eval_episodes, verbose, plot_scores, sleep_interval,
  62. final_show):
  63. # Q will use (state, action) tuples as key.
  64. # Use Q.get(..., 0) for default values.
  65. Q = {}
  66. train_scores = []
  67. eval_scores = []
  68.  
  69. # for each episode ...
  70. for train_ep in range(1, train_episodes + 1):
  71. game = Game(map_file)
  72.  
  73. # display current state and sleep
  74. if verbose:
  75. print(game.state)
  76. sleep(sleep_interval)
  77.  
  78. # while current state is not terminal
  79. while not game.is_over():
  80. # choose one of the legal actions
  81. state, actions = game.state, game.legal_actions
  82. action = epsilon_greedy(Q, state, actions, epsilon)
  83.  
  84. # apply action and get the next state and the reward
  85. reward, msg = game.apply_action(action)
  86. next_state, next_actions = game.state, game.legal_actions
  87.  
  88. # TODO (1) : Q-Learning
  89. max_a_prime = -999999999999
  90. for a_prime in next_actions:
  91. if Q.get((next_state, a_prime), 0) > max_a_prime:
  92. max_a_prime = Q.get((next_state, a_prime), 0)
  93.  
  94. Q[(state, action)] = Q.get((state, action), 0) + learning_rate * (
  95. reward + discount * max_a_prime - Q.get((state, action), 0))
  96.  
  97. # display current state and sleep
  98. if verbose:
  99. print(msg);
  100. print(game.state);
  101. sleep(sleep_interval)
  102.  
  103. print("Episode %6d / %6d" % (train_ep, train_episodes))
  104. train_scores.append(game.score)
  105.  
  106. # evaluate the greedy policy
  107. if train_ep % eval_every == 0:
  108. avg_score = .0
  109.  
  110. # TODO (4) : Evaluate
  111. for _ in range(eval_episodes):
  112. game = Game(map_file)
  113. while not game.is_over():
  114. state, actions = game.state, game.legal_actions
  115. action = best_action(Q, state, actions)
  116. reward, msg = game.apply_action(action)
  117. avg_score += game.score
  118. avg_score /= eval_episodes
  119.  
  120. eval_scores.append(avg_score)
  121.  
  122. # --------------------------------------------------------------------------
  123. if final_show:
  124. game = Game(map_file)
  125. while not game.is_over():
  126. state, actions = game.state, game.legal_actions
  127. action = best_action(Q, state, actions)
  128. reward, msg = game.apply_action(action)
  129. print(msg)
  130. print(game.state)
  131. sleep(sleep_interval)
  132.  
  133. if plot_scores:
  134. plt.xlabel("Episode")
  135. plt.ylabel("Average score")
  136. plt.plot(
  137. np.linspace(1, train_episodes, train_episodes),
  138. np.convolve(train_scores, [0.2, 0.2, 0.2, 0.2, 0.2], "same"),
  139. linewidth=1.0, color="blue"
  140. )
  141. plt.plot(
  142. np.linspace(eval_every, train_episodes, len(eval_scores)),
  143. eval_scores, linewidth=2.0, color="red"
  144. )
  145. plt.show()
  146.  
  147.  
  148. def main():
  149. parser = ArgumentParser()
  150. # Input file
  151. parser.add_argument("--map_file", type=str, default="mini_map",
  152. help="File to read map from.")
  153. # Meta-parameters
  154. parser.add_argument("--learning_rate", type=float, default=0.1,
  155. help="Learning rate")
  156. parser.add_argument("--discount", type=float, default=0.99,
  157. help="Value for the discount factor")
  158. parser.add_argument("--epsilon", type=float, default=0.05,
  159. help="Probability to choose a random action.")
  160. # Training and evaluation episodes
  161. parser.add_argument("--train_episodes", type=int, default=1000,
  162. help="Number of episodes")
  163. parser.add_argument("--eval_every", type=int, default=10,
  164. help="Evaluate policy every ... games.")
  165. parser.add_argument("--eval_episodes", type=int, default=10,
  166. help="Number of games to play for evaluation.")
  167. # Display
  168. parser.add_argument("--verbose", dest="verbose",
  169. action="store_true", help="Print each state")
  170. parser.add_argument("--plot", dest="plot_scores", action="store_true",
  171. help="Plot scores in the end", )
  172. parser.add_argument("--sleep", type=float, default=0.1,
  173. help="Seconds to 'sleep' between moves.")
  174. parser.add_argument("--final_show", dest="final_show",
  175. action="store_true",
  176. help="Demonstrate final strategy.")
  177. args = parser.parse_args()
  178.  
  179. q_learning(
  180. args.map_file, args.learning_rate, args.discount, args.epsilon,
  181. args.train_episodes, args.eval_every, args.eval_episodes, args.verbose,
  182. args.plot_scores, args.sleep, args.final_show
  183. )
  184.  
  185.  
# Script entry point: run the CLI only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement