from ple.games.flappybird import FlappyBird
from ple import PLE
import random
import math
import os
import pickle
import matplotlib.pyplot as plt
import numpy as np
import time

class MCAgent:

    def __init__(self, gamma=1, eps=0.1, alpha=0.1):
        self.Q = {}
        self.G = 0
        self.pol = {}
        self.scores = []
        self.frames = []
        self.sar_triple = []
        self.gamma = gamma
        self.eps = eps
        self.alpha = alpha
        self.returns = {}
        for i in range(15):
            for j in range(15):
                for k in range(15):
                    for l in range(15):
                        state = (i, j, k, l)
                        self.Q[state, 1] = 0
                        self.Q[state, 0] = 0
                        self.pol[state] = (1 - self.eps, 0)
                        self.returns[state, 1] = []
                        self.returns[state, 0] = []

    def reward_values(self, a, b, c):
        """ returns the reward values used for training

        Note: These are only the rewards used for training.
        The rewards used for evaluating the agent will always be
        1 for passing through each pipe and 0 for all other state
        transitions.
        """
        return {"positive": a, "tick": b, "loss": c}
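
    # For example, the values passed in run() below give
    #   reward_values(1.0, 0.0, -5.0) -> {"positive": 1.0, "tick": 0.0, "loss": -5.0}
    # i.e. a reward for passing a pipe, a per-frame reward, and a penalty on game over.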

    def observe(self, s1, a, r, s2, end):
        """ this function is called during training on each step of the game where
        the state transition is going from state s1 with action a to state s2 and
        yields the reward r. If s2 is a terminal state, end==True, otherwise end==False.

        Unless a terminal state was reached, two subsequent calls to observe will be for
        subsequent steps in the same episode. That is, s1 in the second call will be s2
        from the first call.
        """
        if end:
            return
        # Average every return observed for (s1, a) and make the policy
        # epsilon-greedy with respect to the updated Q-values.
        self.returns[s1, a].append(r)
        # print("sum", self.returns[s1, a], "len", len(self.returns[s1, a]))
        self.Q[s1, a] = sum(self.returns[s1, a]) / len(self.returns[s1, a])
        argmax = 0 if self.Q[s1, 0] > self.Q[s1, 1] else 1
        self.pol[s1] = ((1 - self.eps) + (self.eps / 2), argmax)
        print("r: ", r, "Q", self.Q[s1, a], "pol", self.pol[s1][1])
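
    # Example call during training (a hypothetical transition, with both states
    # already discretised by state_translate below):
    #   agent.observe((7, 5, 7, 6), 0, 1.0, (7, 5, 6, 6), False)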

    def state_translate(self, state):
        # Discretisation of the environment: clamp each raw value to its valid
        # range, then map it onto a 15 x 15 x 15 x 15 grid.
        if state["player_vel"] < -8:
            state["player_vel"] = -8
        if state["next_pipe_top_y"] < 0:
            state["next_pipe_top_y"] = 0
        if state["next_pipe_dist_to_player"] < 0:
            state["next_pipe_dist_to_player"] = 0
        if state["player_y"] < 0:
            state["player_y"] = 0
        player_y = math.floor(state["player_y"] * (15 / 513))
        next_pipe_top_y = math.floor(state["next_pipe_top_y"] * (15 / 513))
        next_pipe_dist_to_player = math.floor(
            state["next_pipe_dist_to_player"] * (15 / 310))
        player_vel = math.floor((state["player_vel"] + 8) * (15 / 19))
        state = (player_y, next_pipe_top_y,
                 next_pipe_dist_to_player, player_vel)
        return state
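
    # Worked example (raw values assumed, for illustration only):
    #   player_y = 256                 -> floor(256 * 15/513) = 7
    #   next_pipe_top_y = 192          -> floor(192 * 15/513) = 5
    #   next_pipe_dist_to_player = 150 -> floor(150 * 15/310) = 7
    #   player_vel = 0                 -> floor((0 + 8) * 15/19) = 6
    # giving the discretised state tuple (7, 5, 7, 6).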

    def training_policy(self, state):
        """ Returns the index of the action that should be done in state while training the agent.
        Possible actions in Flappy Bird are 0 (flap the wing) or 1 (do nothing).

        training_policy is called once per frame in the game while training
        """
        rand = random.uniform(0, 1)
        s = self.state_translate(state)
        if rand < self.pol[s][0]:
            return self.pol[s][1]
        else:
            return random.randint(0, 1)
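
    # With eps=0.1 the stored greedy probability is 0.95 once a state has been
    # updated by observe, so the greedy action is picked with probability
    # 0.95 + 0.05 * 0.5 = 0.975 and the other action with probability 0.025.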

    def writeToFile(self, file):
        if os.path.exists(file):
            os.remove(file)
        with open(file, 'wb') as f:
            pickle.dump(self.pol, f)

    def readFromFile(self, file):
        with open(file, 'rb') as f:
            self.pol = pickle.load(f)

    def policy(self, state):
        """ Returns the index of the action that should be done in state when training is completed.
        Possible actions in Flappy Bird are 0 (flap the wing) or 1 (do nothing).

        policy is called once per frame in the game (30 times per second in real-time)
        and needs to be sufficiently fast to not slow down the game.
        """
        return self.pol[self.state_translate(state)][1]
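
# Minimal usage sketch outside run_game (assumes a policy has already been
# trained and saved with writeToFile; 'mc_agent.txt' is the file name used in
# run() below, and game_state stands for env.game.getGameState()):
#   agent = MCAgent()
#   agent.readFromFile('mc_agent.txt')
#   action = agent.policy(game_state)   # 0 = flap, 1 = do nothing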

def run_game(nb_episodes, agent, a, b, c, train=True):
    """ Runs nb_episodes episodes of the game with agent picking the moves.
    An episode of FlappyBird ends with the bird crashing into a pipe or going off screen.
    """
    reward_values = agent.reward_values(a, b, c)

    env = PLE(FlappyBird(), fps=30, display_screen=(not train), force_fps=train, rng=None,
              reward_values=reward_values)
    env.init()
    oldState = agent.state_translate(env.game.getGameState())
    score = 0
    frame = 0
    count = 0
    while nb_episodes > 0:
        # Training or testing
        frame += 1
        if train:
            action = agent.training_policy(env.game.getGameState())
            reward = env.act(env.getActionSet()[action])
            newState = agent.state_translate(env.game.getGameState())
            agent.sar_triple.append((oldState, action, reward))
            oldState = newState
        else:
            action = agent.policy(env.game.getGameState())
            reward = env.act(env.getActionSet()[action])

        score += reward

        # reset the environment if the game is over
        if env.game_over():
            if train:
                n = 0  # Loop counter
                agent.sar_triple.reverse()  # Iterate in reverse order
                old_sar = (0, 0, 0)  # First loop will be terminal state
                for sar in agent.sar_triple:
                    agent.G += sar[2]
                    end = (n == 0)  # We want to look at 2 states at a time
                    n += 1
                    agent.observe(old_sar[0], old_sar[1], agent.G * agent.gamma**n, sar[0], end)
                    old_sar = sar

            agent.scores.append(score)
            agent.frames.append(frame)
            env.reset_game()
            nb_episodes -= 1
            agent.sar_triple = []
            score = 0
            count += 1
            agent.G = 0
            print("Iteration ", count)

def run():
    agent = MCAgent(0.995, 0.1, 0.1)
    # reward structure
    a = 1.0
    b = 0.0
    c = -5.0
    run_game(5000, agent, a, b, c)

    # See training results
    plt.plot(agent.frames, agent.scores)
    plt.show()

    # Save the policy
    agent.writeToFile('mc_agent.txt')

    # Test the policy
    agent.readFromFile('mc_agent.txt')

    input("Press enter to watch the agent try its best")
    run_game(50, agent, a, b, c, False)


run()
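
# To run this script (assumed environment, adjust as needed): it requires
# pygame, matplotlib and the PyGame Learning Environment (ple), which is
# typically installed from its GitHub repository rather than from PyPI.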